package vm import ( "context" "encoding/binary" "encoding/json" "log" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" "go-blockchain/blockchain" ) // registerHostModule builds and instantiates the "env" host module. // argsJSON holds the raw JSON args_json bytes from the CALL_CONTRACT payload; // they are exposed to the WASM contract via get_args / get_arg_str / get_arg_u64. // // Host functions available to WASM contracts: // // env.get_state(keyPtr, keyLen, dstPtr, dstLen i32) → written i32 // env.get_state_len(keyPtr, keyLen i32) → valLen i32 // env.set_state(keyPtr, keyLen, valPtr, valLen i32) // env.get_balance(pubPtr, pubLen i32) → balance i64 // env.transfer(fromPtr, fromLen, toPtr, toLen i32, amount i64) → errCode i32 // env.get_caller(bufPtr, bufLen i32) → written i32 // env.get_block_height() → height i64 // env.get_contract_treasury(bufPtr, bufLen i32) → written i32 // env.log(msgPtr, msgLen i32) // env.put_u64(keyPtr, keyLen i32, val i64) // env.get_u64(keyPtr, keyLen i32) → val i64 // env.get_args(dstPtr, dstLen i32) → written i32 // env.get_arg_str(idx, dstPtr, dstLen i32) → written i32 // env.get_arg_u64(idx i32) → val i64 // env.call_contract(cidPtr, cidLen, mthPtr, mthLen, argPtr, argLen i32) → errCode i32 func registerHostModule(ctx context.Context, rt wazero.Runtime, env blockchain.VMHostEnv, argsJSON []byte) (api.Closer, error) { b := rt.NewHostModuleBuilder("env") // --- get_state_len(keyPtr i32, keyLen i32) → valLen i32 --- b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { keyPtr := api.DecodeU32(stack[0]) keyLen := api.DecodeU32(stack[1]) key, ok := m.Memory().Read(keyPtr, keyLen) if !ok { stack[0] = 0 return } val, _ := env.GetState(key) stack[0] = api.EncodeU32(uint32(len(val))) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). Export("get_state_len") // --- get_state(keyPtr i32, keyLen i32, dstPtr i32, dstLen i32) → written i32 --- b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { keyPtr := api.DecodeU32(stack[0]) keyLen := api.DecodeU32(stack[1]) dstPtr := api.DecodeU32(stack[2]) dstLen := api.DecodeU32(stack[3]) key, ok := m.Memory().Read(keyPtr, keyLen) if !ok { stack[0] = 0 return } val, _ := env.GetState(key) n := uint32(len(val)) if n > dstLen { n = dstLen } if n > 0 { m.Memory().Write(dstPtr, val[:n]) } stack[0] = api.EncodeU32(n) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). Export("get_state") // --- set_state(keyPtr i32, keyLen i32, valPtr i32, valLen i32) --- b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { keyPtr := api.DecodeU32(stack[0]) keyLen := api.DecodeU32(stack[1]) valPtr := api.DecodeU32(stack[2]) valLen := api.DecodeU32(stack[3]) key, okK := m.Memory().Read(keyPtr, keyLen) val, okV := m.Memory().Read(valPtr, valLen) if !okK || !okV { return } // Copy slices — WASM memory may be invalidated after the call returns. keyCopy := make([]byte, len(key)) copy(keyCopy, key) valCopy := make([]byte, len(val)) copy(valCopy, val) _ = env.SetState(keyCopy, valCopy) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{}). Export("set_state") // --- get_balance(pubPtr i32, pubLen i32) → balance i64 --- b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { pubPtr := api.DecodeU32(stack[0]) pubLen := api.DecodeU32(stack[1]) pub, ok := m.Memory().Read(pubPtr, pubLen) if !ok { stack[0] = 0 return } bal, _ := env.GetBalance(string(pub)) stack[0] = api.EncodeI64(int64(bal)) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}). Export("get_balance") // --- transfer(fromPtr, fromLen, toPtr, toLen i32, amount i64) → errCode i32 --- // Returns 0 on success, 1 on failure (insufficient balance or bad args). b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { fromPtr := api.DecodeU32(stack[0]) fromLen := api.DecodeU32(stack[1]) toPtr := api.DecodeU32(stack[2]) toLen := api.DecodeU32(stack[3]) amount := uint64(int64(stack[4])) from, okF := m.Memory().Read(fromPtr, fromLen) to, okT := m.Memory().Read(toPtr, toLen) if !okF || !okT { stack[0] = api.EncodeU32(1) return } if err := env.Transfer(string(from), string(to), amount); err != nil { stack[0] = api.EncodeU32(1) return } stack[0] = api.EncodeU32(0) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}). Export("transfer") // --- get_caller(bufPtr i32, bufLen i32) → written i32 --- b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { bufPtr := api.DecodeU32(stack[0]) bufLen := api.DecodeU32(stack[1]) caller := env.GetCaller() n := uint32(len(caller)) if n > bufLen { n = bufLen } if n > 0 { m.Memory().Write(bufPtr, []byte(caller[:n])) } stack[0] = api.EncodeU32(n) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). Export("get_caller") // --- get_block_height() → height i64 --- b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { stack[0] = api.EncodeI64(int64(env.GetBlockHeight())) }), []api.ValueType{}, []api.ValueType{api.ValueTypeI64}). Export("get_block_height") // --- log(msgPtr i32, msgLen i32) --- b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { msgPtr := api.DecodeU32(stack[0]) msgLen := api.DecodeU32(stack[1]) msg, ok := m.Memory().Read(msgPtr, msgLen) if !ok { return } env.Log(string(msg)) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{}). Export("log") // --- put_u64(keyPtr, keyLen i32, val i64) --- // Convenience: stores a uint64 as 8-byte big-endian. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { keyPtr := api.DecodeU32(stack[0]) keyLen := api.DecodeU32(stack[1]) val := uint64(int64(stack[2])) key, ok := m.Memory().Read(keyPtr, keyLen) if !ok { return } buf := make([]byte, 8) binary.BigEndian.PutUint64(buf, val) keyCopy := make([]byte, len(key)) copy(keyCopy, key) _ = env.SetState(keyCopy, buf) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{}). Export("put_u64") // --- get_u64(keyPtr, keyLen i32) → val i64 --- // Reads an 8-byte big-endian uint64 stored by put_u64. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { keyPtr := api.DecodeU32(stack[0]) keyLen := api.DecodeU32(stack[1]) key, ok := m.Memory().Read(keyPtr, keyLen) if !ok { stack[0] = 0 return } val, _ := env.GetState(key) if len(val) < 8 { stack[0] = 0 return } stack[0] = api.EncodeI64(int64(binary.BigEndian.Uint64(val[:8]))) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}). Export("get_u64") // ── Argument accessors ──────────────────────────────────────────────────── // argsJSON is the JSON array from CallContractPayload.ArgsJSON. // These functions let contracts read typed call arguments at runtime. // Lazily parse argsJSON once and cache the result. var parsedArgs []json.RawMessage if len(argsJSON) > 0 { _ = json.Unmarshal(argsJSON, &parsedArgs) } // --- get_args(dstPtr i32, dstLen i32) → written i32 --- // Copies the raw args_json bytes into WASM memory. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { dstPtr := api.DecodeU32(stack[0]) dstLen := api.DecodeU32(stack[1]) n := uint32(len(argsJSON)) if n > dstLen { n = dstLen } if n > 0 { m.Memory().Write(dstPtr, argsJSON[:n]) } stack[0] = api.EncodeU32(n) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). Export("get_args") // --- get_arg_str(idx i32, dstPtr i32, dstLen i32) → written i32 --- // Parses args_json as a JSON array, reads the idx-th element as a string, // copies the UTF-8 bytes (without quotes) into WASM memory. // Returns 0 if idx is out of range or the element is not a JSON string. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { idx := api.DecodeU32(stack[0]) dstPtr := api.DecodeU32(stack[1]) dstLen := api.DecodeU32(stack[2]) if int(idx) >= len(parsedArgs) { stack[0] = 0 return } var s string if err := json.Unmarshal(parsedArgs[idx], &s); err != nil { stack[0] = 0 return } n := uint32(len(s)) if n > dstLen { n = dstLen } if n > 0 { m.Memory().Write(dstPtr, []byte(s[:n])) } stack[0] = api.EncodeU32(n) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). Export("get_arg_str") // --- get_arg_u64(idx i32) → val i64 --- // Parses args_json as a JSON array, reads the idx-th element as a uint64. // Returns 0 if idx is out of range or the element is not a JSON number. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { idx := api.DecodeU32(stack[0]) if int(idx) >= len(parsedArgs) { stack[0] = 0 return } var n uint64 if err := json.Unmarshal(parsedArgs[idx], &n); err != nil { stack[0] = 0 return } stack[0] = api.EncodeI64(int64(n)) }), []api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}). Export("get_arg_u64") // --- get_contract_treasury(bufPtr i32, bufLen i32) → written i32 --- // Returns the contract's ownerless treasury address as a 64-char hex string. // Derived as hex(sha256(contractID + ":treasury")); no private key exists. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { bufPtr := api.DecodeU32(stack[0]) bufLen := api.DecodeU32(stack[1]) treasury := env.GetContractTreasury() n := uint32(len(treasury)) if n > bufLen { n = bufLen } if n > 0 { m.Memory().Write(bufPtr, []byte(treasury[:n])) } stack[0] = api.EncodeU32(n) }), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). Export("get_contract_treasury") // --- call_contract(cidPtr, cidLen, methodPtr, methodLen, argsPtr, argsLen i32) → i32 --- // Calls a method on another deployed contract. Returns 0 on success, 1 on error. // The sub-call's caller is set to the current contract ID. // Gas consumed by the sub-call is charged to the parent call's gas counter. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) { cidPtr := api.DecodeU32(stack[0]) cidLen := api.DecodeU32(stack[1]) mthPtr := api.DecodeU32(stack[2]) mthLen := api.DecodeU32(stack[3]) argPtr := api.DecodeU32(stack[4]) argLen := api.DecodeU32(stack[5]) cid, ok1 := m.Memory().Read(cidPtr, cidLen) mth, ok2 := m.Memory().Read(mthPtr, mthLen) if !ok1 || !ok2 { stack[0] = api.EncodeU32(1) return } var args []byte if argLen > 0 { args, _ = m.Memory().Read(argPtr, argLen) } // Give the sub-call whatever gas the parent has remaining. gc := gasFromContext(ctx) var gasLeft uint64 = 10_000 // safe fallback if no counter if gc != nil { gasLeft = gc.Remaining() } if gasLeft == 0 { stack[0] = api.EncodeU32(1) return } gasUsed, err := env.CallContract(string(cid), string(mth), args, gasLeft) // Charge parent counter for what the sub-call actually consumed. if gc != nil && gasUsed > 0 { _ = gc.charge(gasUsed) } if err != nil { stack[0] = api.EncodeU32(1) return } stack[0] = api.EncodeU32(0) }), []api.ValueType{ api.ValueTypeI32, api.ValueTypeI32, // contract_id api.ValueTypeI32, api.ValueTypeI32, // method api.ValueTypeI32, api.ValueTypeI32, // args_json }, []api.ValueType{api.ValueTypeI32}). Export("call_contract") // --- gas_tick() --- // Called by instrumented loop headers; charges 1 gas unit per iteration. // Panics (traps the WASM module) when the gas budget is exhausted. b.NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, _ api.Module, _ []uint64) { gc := gasFromContext(ctx) if gc == nil { return } if err := gc.charge(1); err != nil { panic(err.Error()) // wazero catches this and surfaces it as a Call error } }), []api.ValueType{}, []api.ValueType{}). Export("gas_tick") inst, err := b.Instantiate(ctx) if err != nil { log.Printf("[VM] registerHostModule: %v", err) return nil, err } return inst, nil }