chore: initial commit for v0.0.1
DChain single-node blockchain + React Native messenger client. Core: - PBFT consensus with multi-sig validator admission + equivocation slashing - BadgerDB + schema migration scaffold (CurrentSchemaVersion=0) - libp2p gossipsub (tx/v1, blocks/v1, relay/v1, version/v1) - Native Go contracts (username_registry) alongside WASM (wazero) - WebSocket gateway with topic-based fanout + Ed25519-nonce auth - Relay mailbox with NaCl envelope encryption (X25519 + Ed25519) - Prometheus /metrics, per-IP rate limit, body-size cap Deployment: - Single-node compose (deploy/single/) with Caddy TLS + optional Prometheus - 3-node dev compose (docker-compose.yml) with mocked internet topology - 3-validator prod compose (deploy/prod/) for federation - Auto-update from Gitea via /api/update-check + systemd timer - Build-time version injection (ldflags → node --version) - UI / Swagger toggle flags (DCHAIN_DISABLE_UI, DCHAIN_DISABLE_SWAGGER) Client (client-app/): - Expo / React Native / NativeWind - E2E NaCl encryption, typing indicator, contact requests - Auto-discovery of canonical contracts, chain_id aware, WS reconnect on node switch Documentation: - README.md, CHANGELOG.md, CONTEXT.md - deploy/single/README.md with 6 operator scenarios - deploy/UPDATE_STRATEGY.md with 4-layer forward-compat design - docs/contracts/*.md per contract
This commit is contained in:
71
vm/abi.go
Normal file
71
vm/abi.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ABI describes the callable interface of a deployed contract.
|
||||
type ABI struct {
|
||||
Methods []ABIMethod `json:"methods"`
|
||||
}
|
||||
|
||||
// ABIMethod describes a single callable method.
|
||||
type ABIMethod struct {
|
||||
Name string `json:"name"`
|
||||
Args []ABIArg `json:"args"` // may be nil / empty for zero-arg methods
|
||||
}
|
||||
|
||||
// ABIArg describes one parameter of a method.
|
||||
type ABIArg struct {
|
||||
Name string `json:"name"` // e.g. "amount"
|
||||
Type string `json:"type,omitempty"` // e.g. "uint64", "string", "bytes"
|
||||
}
|
||||
|
||||
// ParseABI deserializes an ABI from JSON.
|
||||
func ParseABI(jsonStr string) (*ABI, error) {
|
||||
var a ABI
|
||||
if err := json.Unmarshal([]byte(jsonStr), &a); err != nil {
|
||||
return nil, fmt.Errorf("invalid ABI JSON: %w", err)
|
||||
}
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
// HasMethod returns true if the ABI declares the named method.
|
||||
func (a *ABI) HasMethod(name string) bool {
|
||||
for _, m := range a.Methods {
|
||||
if m.Name == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate checks that method exists in the ABI and args_json has the right
|
||||
// number of elements. argsJSON may be empty ("" or "[]") for zero-arg methods.
|
||||
func (a *ABI) Validate(method string, argsJSON []byte) error {
|
||||
var target *ABIMethod
|
||||
for i := range a.Methods {
|
||||
if a.Methods[i].Name == method {
|
||||
target = &a.Methods[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
return fmt.Errorf("method %q not found in ABI", method)
|
||||
}
|
||||
if len(target.Args) == 0 {
|
||||
return nil // no args expected — nothing to validate
|
||||
}
|
||||
if len(argsJSON) > 0 {
|
||||
var args []any
|
||||
if err := json.Unmarshal(argsJSON, &args); err != nil {
|
||||
return fmt.Errorf("args_json is not a JSON array: %w", err)
|
||||
}
|
||||
if len(args) != len(target.Args) {
|
||||
return fmt.Errorf("method %q expects %d args, got %d",
|
||||
method, len(target.Args), len(args))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
111
vm/gas.go
Normal file
111
vm/gas.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
"github.com/tetratelabs/wazero/experimental"
|
||||
|
||||
"go-blockchain/blockchain"
|
||||
)
|
||||
|
||||
// ErrOutOfGas is returned when a contract call exhausts its gas limit.
|
||||
// It wraps blockchain.ErrTxFailed so the call is skipped rather than
|
||||
// aborting the entire block.
|
||||
var ErrOutOfGas = fmt.Errorf("%w: out of gas", blockchain.ErrTxFailed)
|
||||
|
||||
// gasKey is the context key used to pass the gas counter into host functions.
|
||||
type gasKey struct{}
|
||||
|
||||
// gasCounter holds a mutable gas counter accessible through a context value.
|
||||
type gasCounter struct {
|
||||
used atomic.Uint64
|
||||
limit uint64
|
||||
}
|
||||
|
||||
// withGasCounter attaches a new gas counter to ctx.
|
||||
func withGasCounter(ctx context.Context, limit uint64) (context.Context, *gasCounter) {
|
||||
gc := &gasCounter{limit: limit}
|
||||
return context.WithValue(ctx, gasKey{}, gc), gc
|
||||
}
|
||||
|
||||
// gasFromContext retrieves the gas counter from ctx; panics if not present.
|
||||
func gasFromContext(ctx context.Context) *gasCounter {
|
||||
gc, _ := ctx.Value(gasKey{}).(*gasCounter)
|
||||
return gc
|
||||
}
|
||||
|
||||
// charge attempts to consume n gas units. Returns ErrOutOfGas if the limit is exceeded.
|
||||
func (gc *gasCounter) charge(n uint64) error {
|
||||
if gc == nil {
|
||||
return nil
|
||||
}
|
||||
newUsed := gc.used.Add(n)
|
||||
if newUsed > gc.limit {
|
||||
return ErrOutOfGas
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Used returns total gas consumed so far.
|
||||
func (gc *gasCounter) Used() uint64 {
|
||||
if gc == nil {
|
||||
return 0
|
||||
}
|
||||
return gc.used.Load()
|
||||
}
|
||||
|
||||
// Remaining returns gas budget minus gas used. Returns 0 when exhausted.
|
||||
func (gc *gasCounter) Remaining() uint64 {
|
||||
if gc == nil {
|
||||
return 0
|
||||
}
|
||||
used := gc.used.Load()
|
||||
if used >= gc.limit {
|
||||
return 0
|
||||
}
|
||||
return gc.limit - used
|
||||
}
|
||||
|
||||
// gasListenerFactory is a wazero FunctionListenerFactory that charges gas on
|
||||
// every WASM function call. Each function call costs gasPerCall units.
|
||||
// True instruction-level metering would require bytecode instrumentation;
|
||||
// call-level metering is a pragmatic approximation that prevents runaway contracts.
|
||||
const gasPerCall uint64 = 100
|
||||
|
||||
type gasListenerFactory struct{}
|
||||
|
||||
func (gasListenerFactory) NewFunctionListener(def api.FunctionDefinition) experimental.FunctionListener {
|
||||
return gasListener{}
|
||||
}
|
||||
|
||||
type gasListener struct{}
|
||||
|
||||
func (gasListener) Before(ctx context.Context, _ api.Module, _ api.FunctionDefinition, _ []uint64, _ experimental.StackIterator) {
|
||||
gc := gasFromContext(ctx)
|
||||
if gc != nil {
|
||||
// Ignore error here — we can't abort from Before.
|
||||
// isOutOfGas() is checked after Call() returns.
|
||||
_ = gc.charge(gasPerCall)
|
||||
}
|
||||
}
|
||||
|
||||
func (gasListener) After(_ context.Context, _ api.Module, _ api.FunctionDefinition, _ []uint64) {}
|
||||
func (gasListener) Abort(_ context.Context, _ api.Module, _ api.FunctionDefinition, _ error) {}
|
||||
|
||||
// isOutOfGas reports whether err indicates gas exhaustion.
|
||||
func isOutOfGas(gc *gasCounter) bool {
|
||||
if gc == nil {
|
||||
return false
|
||||
}
|
||||
return gc.Used() > gc.limit
|
||||
}
|
||||
|
||||
// ensure gasListenerFactory satisfies the interface at compile time.
|
||||
var _ experimental.FunctionListenerFactory = gasListenerFactory{}
|
||||
|
||||
// ensure errors chain correctly.
|
||||
var _ = errors.Is(ErrOutOfGas, blockchain.ErrTxFailed)
|
||||
382
vm/host.go
Normal file
382
vm/host.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"log"
|
||||
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
|
||||
"go-blockchain/blockchain"
|
||||
)
|
||||
|
||||
// registerHostModule builds and instantiates the "env" host module.
|
||||
// argsJSON holds the raw JSON args_json bytes from the CALL_CONTRACT payload;
|
||||
// they are exposed to the WASM contract via get_args / get_arg_str / get_arg_u64.
|
||||
//
|
||||
// Host functions available to WASM contracts:
|
||||
//
|
||||
// env.get_state(keyPtr, keyLen, dstPtr, dstLen i32) → written i32
|
||||
// env.get_state_len(keyPtr, keyLen i32) → valLen i32
|
||||
// env.set_state(keyPtr, keyLen, valPtr, valLen i32)
|
||||
// env.get_balance(pubPtr, pubLen i32) → balance i64
|
||||
// env.transfer(fromPtr, fromLen, toPtr, toLen i32, amount i64) → errCode i32
|
||||
// env.get_caller(bufPtr, bufLen i32) → written i32
|
||||
// env.get_block_height() → height i64
|
||||
// env.get_contract_treasury(bufPtr, bufLen i32) → written i32
|
||||
// env.log(msgPtr, msgLen i32)
|
||||
// env.put_u64(keyPtr, keyLen i32, val i64)
|
||||
// env.get_u64(keyPtr, keyLen i32) → val i64
|
||||
// env.get_args(dstPtr, dstLen i32) → written i32
|
||||
// env.get_arg_str(idx, dstPtr, dstLen i32) → written i32
|
||||
// env.get_arg_u64(idx i32) → val i64
|
||||
// env.call_contract(cidPtr, cidLen, mthPtr, mthLen, argPtr, argLen i32) → errCode i32
|
||||
func registerHostModule(ctx context.Context, rt wazero.Runtime, env blockchain.VMHostEnv, argsJSON []byte) (api.Closer, error) {
|
||||
b := rt.NewHostModuleBuilder("env")
|
||||
|
||||
// --- get_state_len(keyPtr i32, keyLen i32) → valLen i32 ---
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
keyPtr := api.DecodeU32(stack[0])
|
||||
keyLen := api.DecodeU32(stack[1])
|
||||
key, ok := m.Memory().Read(keyPtr, keyLen)
|
||||
if !ok {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
val, _ := env.GetState(key)
|
||||
stack[0] = api.EncodeU32(uint32(len(val)))
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("get_state_len")
|
||||
|
||||
// --- get_state(keyPtr i32, keyLen i32, dstPtr i32, dstLen i32) → written i32 ---
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
keyPtr := api.DecodeU32(stack[0])
|
||||
keyLen := api.DecodeU32(stack[1])
|
||||
dstPtr := api.DecodeU32(stack[2])
|
||||
dstLen := api.DecodeU32(stack[3])
|
||||
key, ok := m.Memory().Read(keyPtr, keyLen)
|
||||
if !ok {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
val, _ := env.GetState(key)
|
||||
n := uint32(len(val))
|
||||
if n > dstLen {
|
||||
n = dstLen
|
||||
}
|
||||
if n > 0 {
|
||||
m.Memory().Write(dstPtr, val[:n])
|
||||
}
|
||||
stack[0] = api.EncodeU32(n)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("get_state")
|
||||
|
||||
// --- set_state(keyPtr i32, keyLen i32, valPtr i32, valLen i32) ---
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
keyPtr := api.DecodeU32(stack[0])
|
||||
keyLen := api.DecodeU32(stack[1])
|
||||
valPtr := api.DecodeU32(stack[2])
|
||||
valLen := api.DecodeU32(stack[3])
|
||||
key, okK := m.Memory().Read(keyPtr, keyLen)
|
||||
val, okV := m.Memory().Read(valPtr, valLen)
|
||||
if !okK || !okV {
|
||||
return
|
||||
}
|
||||
// Copy slices — WASM memory may be invalidated after the call returns.
|
||||
keyCopy := make([]byte, len(key))
|
||||
copy(keyCopy, key)
|
||||
valCopy := make([]byte, len(val))
|
||||
copy(valCopy, val)
|
||||
_ = env.SetState(keyCopy, valCopy)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{}).
|
||||
Export("set_state")
|
||||
|
||||
// --- get_balance(pubPtr i32, pubLen i32) → balance i64 ---
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
pubPtr := api.DecodeU32(stack[0])
|
||||
pubLen := api.DecodeU32(stack[1])
|
||||
pub, ok := m.Memory().Read(pubPtr, pubLen)
|
||||
if !ok {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
bal, _ := env.GetBalance(string(pub))
|
||||
stack[0] = api.EncodeI64(int64(bal))
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}).
|
||||
Export("get_balance")
|
||||
|
||||
// --- transfer(fromPtr, fromLen, toPtr, toLen i32, amount i64) → errCode i32 ---
|
||||
// Returns 0 on success, 1 on failure (insufficient balance or bad args).
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
fromPtr := api.DecodeU32(stack[0])
|
||||
fromLen := api.DecodeU32(stack[1])
|
||||
toPtr := api.DecodeU32(stack[2])
|
||||
toLen := api.DecodeU32(stack[3])
|
||||
amount := uint64(int64(stack[4]))
|
||||
from, okF := m.Memory().Read(fromPtr, fromLen)
|
||||
to, okT := m.Memory().Read(toPtr, toLen)
|
||||
if !okF || !okT {
|
||||
stack[0] = api.EncodeU32(1)
|
||||
return
|
||||
}
|
||||
if err := env.Transfer(string(from), string(to), amount); err != nil {
|
||||
stack[0] = api.EncodeU32(1)
|
||||
return
|
||||
}
|
||||
stack[0] = api.EncodeU32(0)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("transfer")
|
||||
|
||||
// --- get_caller(bufPtr i32, bufLen i32) → written i32 ---
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
bufPtr := api.DecodeU32(stack[0])
|
||||
bufLen := api.DecodeU32(stack[1])
|
||||
caller := env.GetCaller()
|
||||
n := uint32(len(caller))
|
||||
if n > bufLen {
|
||||
n = bufLen
|
||||
}
|
||||
if n > 0 {
|
||||
m.Memory().Write(bufPtr, []byte(caller[:n]))
|
||||
}
|
||||
stack[0] = api.EncodeU32(n)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("get_caller")
|
||||
|
||||
// --- get_block_height() → height i64 ---
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
stack[0] = api.EncodeI64(int64(env.GetBlockHeight()))
|
||||
}), []api.ValueType{}, []api.ValueType{api.ValueTypeI64}).
|
||||
Export("get_block_height")
|
||||
|
||||
// --- log(msgPtr i32, msgLen i32) ---
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
msgPtr := api.DecodeU32(stack[0])
|
||||
msgLen := api.DecodeU32(stack[1])
|
||||
msg, ok := m.Memory().Read(msgPtr, msgLen)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
env.Log(string(msg))
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{}).
|
||||
Export("log")
|
||||
|
||||
// --- put_u64(keyPtr, keyLen i32, val i64) ---
|
||||
// Convenience: stores a uint64 as 8-byte big-endian.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
keyPtr := api.DecodeU32(stack[0])
|
||||
keyLen := api.DecodeU32(stack[1])
|
||||
val := uint64(int64(stack[2]))
|
||||
key, ok := m.Memory().Read(keyPtr, keyLen)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, val)
|
||||
keyCopy := make([]byte, len(key))
|
||||
copy(keyCopy, key)
|
||||
_ = env.SetState(keyCopy, buf)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{}).
|
||||
Export("put_u64")
|
||||
|
||||
// --- get_u64(keyPtr, keyLen i32) → val i64 ---
|
||||
// Reads an 8-byte big-endian uint64 stored by put_u64.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
keyPtr := api.DecodeU32(stack[0])
|
||||
keyLen := api.DecodeU32(stack[1])
|
||||
key, ok := m.Memory().Read(keyPtr, keyLen)
|
||||
if !ok {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
val, _ := env.GetState(key)
|
||||
if len(val) < 8 {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
stack[0] = api.EncodeI64(int64(binary.BigEndian.Uint64(val[:8])))
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}).
|
||||
Export("get_u64")
|
||||
|
||||
// ── Argument accessors ────────────────────────────────────────────────────
|
||||
// argsJSON is the JSON array from CallContractPayload.ArgsJSON.
|
||||
// These functions let contracts read typed call arguments at runtime.
|
||||
|
||||
// Lazily parse argsJSON once and cache the result.
|
||||
var parsedArgs []json.RawMessage
|
||||
if len(argsJSON) > 0 {
|
||||
_ = json.Unmarshal(argsJSON, &parsedArgs)
|
||||
}
|
||||
|
||||
// --- get_args(dstPtr i32, dstLen i32) → written i32 ---
|
||||
// Copies the raw args_json bytes into WASM memory.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
dstPtr := api.DecodeU32(stack[0])
|
||||
dstLen := api.DecodeU32(stack[1])
|
||||
n := uint32(len(argsJSON))
|
||||
if n > dstLen {
|
||||
n = dstLen
|
||||
}
|
||||
if n > 0 {
|
||||
m.Memory().Write(dstPtr, argsJSON[:n])
|
||||
}
|
||||
stack[0] = api.EncodeU32(n)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("get_args")
|
||||
|
||||
// --- get_arg_str(idx i32, dstPtr i32, dstLen i32) → written i32 ---
|
||||
// Parses args_json as a JSON array, reads the idx-th element as a string,
|
||||
// copies the UTF-8 bytes (without quotes) into WASM memory.
|
||||
// Returns 0 if idx is out of range or the element is not a JSON string.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
idx := api.DecodeU32(stack[0])
|
||||
dstPtr := api.DecodeU32(stack[1])
|
||||
dstLen := api.DecodeU32(stack[2])
|
||||
if int(idx) >= len(parsedArgs) {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(parsedArgs[idx], &s); err != nil {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
n := uint32(len(s))
|
||||
if n > dstLen {
|
||||
n = dstLen
|
||||
}
|
||||
if n > 0 {
|
||||
m.Memory().Write(dstPtr, []byte(s[:n]))
|
||||
}
|
||||
stack[0] = api.EncodeU32(n)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("get_arg_str")
|
||||
|
||||
// --- get_arg_u64(idx i32) → val i64 ---
|
||||
// Parses args_json as a JSON array, reads the idx-th element as a uint64.
|
||||
// Returns 0 if idx is out of range or the element is not a JSON number.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
idx := api.DecodeU32(stack[0])
|
||||
if int(idx) >= len(parsedArgs) {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
var n uint64
|
||||
if err := json.Unmarshal(parsedArgs[idx], &n); err != nil {
|
||||
stack[0] = 0
|
||||
return
|
||||
}
|
||||
stack[0] = api.EncodeI64(int64(n))
|
||||
}), []api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}).
|
||||
Export("get_arg_u64")
|
||||
|
||||
// --- get_contract_treasury(bufPtr i32, bufLen i32) → written i32 ---
|
||||
// Returns the contract's ownerless treasury address as a 64-char hex string.
|
||||
// Derived as hex(sha256(contractID + ":treasury")); no private key exists.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
bufPtr := api.DecodeU32(stack[0])
|
||||
bufLen := api.DecodeU32(stack[1])
|
||||
treasury := env.GetContractTreasury()
|
||||
n := uint32(len(treasury))
|
||||
if n > bufLen {
|
||||
n = bufLen
|
||||
}
|
||||
if n > 0 {
|
||||
m.Memory().Write(bufPtr, []byte(treasury[:n]))
|
||||
}
|
||||
stack[0] = api.EncodeU32(n)
|
||||
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("get_contract_treasury")
|
||||
|
||||
// --- call_contract(cidPtr, cidLen, methodPtr, methodLen, argsPtr, argsLen i32) → i32 ---
|
||||
// Calls a method on another deployed contract. Returns 0 on success, 1 on error.
|
||||
// The sub-call's caller is set to the current contract ID.
|
||||
// Gas consumed by the sub-call is charged to the parent call's gas counter.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
|
||||
cidPtr := api.DecodeU32(stack[0])
|
||||
cidLen := api.DecodeU32(stack[1])
|
||||
mthPtr := api.DecodeU32(stack[2])
|
||||
mthLen := api.DecodeU32(stack[3])
|
||||
argPtr := api.DecodeU32(stack[4])
|
||||
argLen := api.DecodeU32(stack[5])
|
||||
|
||||
cid, ok1 := m.Memory().Read(cidPtr, cidLen)
|
||||
mth, ok2 := m.Memory().Read(mthPtr, mthLen)
|
||||
if !ok1 || !ok2 {
|
||||
stack[0] = api.EncodeU32(1)
|
||||
return
|
||||
}
|
||||
var args []byte
|
||||
if argLen > 0 {
|
||||
args, _ = m.Memory().Read(argPtr, argLen)
|
||||
}
|
||||
|
||||
// Give the sub-call whatever gas the parent has remaining.
|
||||
gc := gasFromContext(ctx)
|
||||
var gasLeft uint64 = 10_000 // safe fallback if no counter
|
||||
if gc != nil {
|
||||
gasLeft = gc.Remaining()
|
||||
}
|
||||
if gasLeft == 0 {
|
||||
stack[0] = api.EncodeU32(1)
|
||||
return
|
||||
}
|
||||
|
||||
gasUsed, err := env.CallContract(string(cid), string(mth), args, gasLeft)
|
||||
|
||||
// Charge parent counter for what the sub-call actually consumed.
|
||||
if gc != nil && gasUsed > 0 {
|
||||
_ = gc.charge(gasUsed)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
stack[0] = api.EncodeU32(1)
|
||||
return
|
||||
}
|
||||
stack[0] = api.EncodeU32(0)
|
||||
}), []api.ValueType{
|
||||
api.ValueTypeI32, api.ValueTypeI32, // contract_id
|
||||
api.ValueTypeI32, api.ValueTypeI32, // method
|
||||
api.ValueTypeI32, api.ValueTypeI32, // args_json
|
||||
}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export("call_contract")
|
||||
|
||||
// --- gas_tick() ---
|
||||
// Called by instrumented loop headers; charges 1 gas unit per iteration.
|
||||
// Panics (traps the WASM module) when the gas budget is exhausted.
|
||||
b.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, _ api.Module, _ []uint64) {
|
||||
gc := gasFromContext(ctx)
|
||||
if gc == nil {
|
||||
return
|
||||
}
|
||||
if err := gc.charge(1); err != nil {
|
||||
panic(err.Error()) // wazero catches this and surfaces it as a Call error
|
||||
}
|
||||
}), []api.ValueType{}, []api.ValueType{}).
|
||||
Export("gas_tick")
|
||||
|
||||
inst, err := b.Instantiate(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[VM] registerHostModule: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return inst, nil
|
||||
}
|
||||
949
vm/instrument.go
Normal file
949
vm/instrument.go
Normal file
@@ -0,0 +1,949 @@
|
||||
// Package vm — WASM bytecode instrumentation for instruction-level gas metering.
|
||||
//
|
||||
// Instrument rewrites a WASM binary so that every loop header calls the host
|
||||
// function env.gas_tick(). This guarantees that any infinite loop—or any loop
|
||||
// whose iteration count is proportional to attacker-controlled input—is bounded
|
||||
// by the transaction gas limit.
|
||||
//
|
||||
// Algorithm
|
||||
//
|
||||
// 1. Parse sections.
|
||||
// 2. Find or add the function type () → () in the type section.
|
||||
// 3. Append env.gas_tick as the last function import (so existing import
|
||||
// indices are unchanged). Record numOldFuncImports = M.
|
||||
// 4. Rewrite the export section: any function export with index ≥ M → +1.
|
||||
// 5. Rewrite the code section:
|
||||
// • "call N" where N ≥ M → "call N+1" (defined-function shift)
|
||||
// • "loop bt …" injects "call $gas_tick" immediately after "loop bt"
|
||||
// 6. Reassemble.
|
||||
//
|
||||
// The call to gas_tick is also metered by the FunctionListener (100 gas/call),
|
||||
// so each loop iteration costs at least 101 gas units total.
|
||||
//
|
||||
// If the binary already contains an env.gas_tick import the function returns
|
||||
// the original bytes unchanged (idempotent).
|
||||
package vm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ── section and opcode constants ─────────────────────────────────────────────
|
||||
|
||||
const (
|
||||
wasmMagic = "\x00asm"
|
||||
wasmVersion = "\x01\x00\x00\x00"
|
||||
secType = 1
|
||||
secImport = 2
|
||||
secExport = 7
|
||||
secCode = 10
|
||||
importFunc = 0x00
|
||||
exportFunc = 0x00
|
||||
opBlock = 0x02
|
||||
opLoop = 0x03
|
||||
opIf = 0x04
|
||||
opElse = 0x05
|
||||
opEnd = 0x0B
|
||||
opBr = 0x0C
|
||||
opBrIf = 0x0D
|
||||
opBrTable = 0x0E
|
||||
opCall = 0x10
|
||||
opCallInd = 0x11
|
||||
opLocalGet = 0x20
|
||||
opLocalSet = 0x21
|
||||
opLocalTee = 0x22
|
||||
opGlobalGet = 0x23
|
||||
opGlobalSet = 0x24
|
||||
opTableGet = 0x25
|
||||
opTableSet = 0x26
|
||||
opI32Const = 0x41
|
||||
opI64Const = 0x42
|
||||
opF32Const = 0x43
|
||||
opF64Const = 0x44
|
||||
opMemSize = 0x3F
|
||||
opMemGrow = 0x40
|
||||
opRefNull = 0xD0
|
||||
opRefIsNull = 0xD1
|
||||
opRefFunc = 0xD2
|
||||
opSelectT = 0x1C
|
||||
opPrefixFC = 0xFC
|
||||
)
|
||||
|
||||
// ── LEB128 helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
func readU32Leb(b []byte) (uint32, int) {
|
||||
var x uint32
|
||||
var s uint
|
||||
for i, by := range b {
|
||||
if i == 5 {
|
||||
return 0, -1
|
||||
}
|
||||
x |= uint32(by&0x7F) << s
|
||||
s += 7
|
||||
if by&0x80 == 0 {
|
||||
return x, i + 1
|
||||
}
|
||||
}
|
||||
return 0, -1
|
||||
}
|
||||
|
||||
func readI32Leb(b []byte) (int32, int) {
|
||||
var x int32
|
||||
var s uint
|
||||
for i, by := range b {
|
||||
if i == 5 {
|
||||
return 0, -1
|
||||
}
|
||||
x |= int32(by&0x7F) << s
|
||||
s += 7
|
||||
if by&0x80 == 0 {
|
||||
if s < 32 && (by&0x40) != 0 {
|
||||
x |= ^0 << s
|
||||
}
|
||||
return x, i + 1
|
||||
}
|
||||
}
|
||||
return 0, -1
|
||||
}
|
||||
|
||||
func readI64Leb(b []byte) (int64, int) {
|
||||
var x int64
|
||||
var s uint
|
||||
for i, by := range b {
|
||||
if i == 10 {
|
||||
return 0, -1
|
||||
}
|
||||
x |= int64(by&0x7F) << s
|
||||
s += 7
|
||||
if by&0x80 == 0 {
|
||||
if s < 64 && (by&0x40) != 0 {
|
||||
x |= ^int64(0) << s
|
||||
}
|
||||
return x, i + 1
|
||||
}
|
||||
}
|
||||
return 0, -1
|
||||
}
|
||||
|
||||
func appendU32Leb(out []byte, v uint32) []byte {
|
||||
for {
|
||||
b := byte(v & 0x7F)
|
||||
v >>= 7
|
||||
if v != 0 {
|
||||
b |= 0x80
|
||||
}
|
||||
out = append(out, b)
|
||||
if v == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// skipU32Leb returns the number of bytes consumed by one unsigned LEB128.
|
||||
func skipU32Leb(b []byte) (int, error) {
|
||||
_, n := readU32Leb(b)
|
||||
if n <= 0 {
|
||||
return 0, errors.New("bad unsigned LEB128")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func skipI32Leb(b []byte) (int, error) {
|
||||
_, n := readI32Leb(b)
|
||||
if n <= 0 {
|
||||
return 0, errors.New("bad signed LEB128 i32")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func skipI64Leb(b []byte) (int, error) {
|
||||
_, n := readI64Leb(b)
|
||||
if n <= 0 {
|
||||
return 0, errors.New("bad signed LEB128 i64")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ── WASM type for gas_tick: () → () ──────────────────────────────────────────
|
||||
// Encoded as: 0x60 (functype) 0x00 (0 params) 0x00 (0 results)
|
||||
var gasTickFuncType = []byte{0x60, 0x00, 0x00}
|
||||
|
||||
// ── Top-level Instrument ──────────────────────────────────────────────────────
|
||||
|
||||
// rawSection holds one parsed section.
|
||||
type rawSection struct {
|
||||
id byte
|
||||
data []byte
|
||||
}
|
||||
|
||||
// Instrument returns a copy of wasm with env.gas_tick calls injected at every
|
||||
// loop header. Returns wasm unchanged if already instrumented or if wasm has
|
||||
// no code section. Returns an error only on malformed binaries.
|
||||
func Instrument(wasm []byte) ([]byte, error) {
|
||||
if len(wasm) < 8 || string(wasm[:4]) != wasmMagic || string(wasm[4:8]) != wasmVersion {
|
||||
return nil, errors.New("not a valid WASM binary")
|
||||
}
|
||||
off := 8
|
||||
var sections []rawSection
|
||||
for off < len(wasm) {
|
||||
if off >= len(wasm) {
|
||||
break
|
||||
}
|
||||
id := wasm[off]
|
||||
off++
|
||||
size, n := readU32Leb(wasm[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("bad section size LEB128")
|
||||
}
|
||||
off += n
|
||||
end := off + int(size)
|
||||
if end > len(wasm) {
|
||||
return nil, fmt.Errorf("section %d size %d exceeds binary length", id, size)
|
||||
}
|
||||
sections = append(sections, rawSection{id: id, data: wasm[off:end]})
|
||||
off = end
|
||||
}
|
||||
|
||||
// Locate key sections.
|
||||
typeIdx, importIdx, exportIdx, codeIdx := -1, -1, -1, -1
|
||||
for i, s := range sections {
|
||||
switch s.id {
|
||||
case secType:
|
||||
typeIdx = i
|
||||
case secImport:
|
||||
importIdx = i
|
||||
case secExport:
|
||||
exportIdx = i
|
||||
case secCode:
|
||||
codeIdx = i
|
||||
}
|
||||
}
|
||||
if codeIdx < 0 {
|
||||
return wasm, nil // nothing to instrument
|
||||
}
|
||||
|
||||
// Idempotency: already instrumented?
|
||||
if importIdx >= 0 && containsGasTick(sections[importIdx].data) {
|
||||
return wasm, nil
|
||||
}
|
||||
|
||||
// ── Step 1: find or add gas_tick type ─────────────────────────────────────
|
||||
var gasTickTypeIdx uint32
|
||||
if typeIdx >= 0 {
|
||||
var err error
|
||||
sections[typeIdx].data, gasTickTypeIdx, err = ensureGasTickType(sections[typeIdx].data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("type section: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Create a minimal type section containing only the gas_tick type.
|
||||
ts := appendU32Leb(nil, 1) // count = 1
|
||||
ts = append(ts, gasTickFuncType...)
|
||||
sections = insertBefore(sections, secImport, rawSection{id: secType, data: ts})
|
||||
// Recalculate section indices.
|
||||
typeIdx, importIdx, exportIdx, codeIdx = findSections(sections)
|
||||
gasTickTypeIdx = 0
|
||||
}
|
||||
|
||||
// ── Step 2: add gas_tick import, count old func imports ───────────────────
|
||||
var numOldFuncImports uint32
|
||||
var gasFnIdx uint32
|
||||
if importIdx >= 0 {
|
||||
var err error
|
||||
sections[importIdx].data, numOldFuncImports, err = appendGasTickImport(sections[importIdx].data, gasTickTypeIdx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("import section: %w", err)
|
||||
}
|
||||
gasFnIdx = numOldFuncImports // gas_tick is the new last import
|
||||
} else {
|
||||
// Build import section with just gas_tick.
|
||||
is, err := buildImportSection(gasTickTypeIdx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sections = insertBefore(sections, secExport, rawSection{id: secImport, data: is})
|
||||
typeIdx, importIdx, exportIdx, codeIdx = findSections(sections)
|
||||
numOldFuncImports = 0
|
||||
gasFnIdx = 0
|
||||
}
|
||||
|
||||
// ── Step 3: adjust export section ─────────────────────────────────────────
|
||||
if exportIdx >= 0 {
|
||||
var err error
|
||||
sections[exportIdx].data, err = adjustExportFuncIndices(sections[exportIdx].data, numOldFuncImports)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("export section: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Step 4: rewrite code section ──────────────────────────────────────────
|
||||
var err error
|
||||
sections[codeIdx].data, err = rewriteCodeSection(sections[codeIdx].data, numOldFuncImports, gasFnIdx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("code section: %w", err)
|
||||
}
|
||||
|
||||
// ── Reassemble ────────────────────────────────────────────────────────────
|
||||
out := make([]byte, 0, len(wasm)+128)
|
||||
out = append(out, []byte(wasmMagic+wasmVersion)...)
|
||||
for _, s := range sections {
|
||||
out = append(out, s.id)
|
||||
out = appendU32Leb(out, uint32(len(s.data)))
|
||||
out = append(out, s.data...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ── type section helpers ──────────────────────────────────────────────────────
|
||||
|
||||
// ensureGasTickType finds () → () in the type section or appends it.
|
||||
// Returns the (possibly rewritten) section bytes and the type index.
|
||||
func ensureGasTickType(data []byte) ([]byte, uint32, error) {
|
||||
off := 0
|
||||
count, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad type count")
|
||||
}
|
||||
off += n
|
||||
start := off
|
||||
for i := uint32(0); i < count; i++ {
|
||||
entryStart := off
|
||||
if off >= len(data) || data[off] != 0x60 {
|
||||
return nil, 0, fmt.Errorf("type %d: expected 0x60", i)
|
||||
}
|
||||
off++
|
||||
// params
|
||||
pc, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, fmt.Errorf("type %d: bad param count", i)
|
||||
}
|
||||
off += n
|
||||
paramStart := off
|
||||
off += int(pc) // each valtype is 1 byte
|
||||
// results
|
||||
rc, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, fmt.Errorf("type %d: bad result count", i)
|
||||
}
|
||||
off += n
|
||||
off += int(rc)
|
||||
// Is this () → () ?
|
||||
if pc == 0 && rc == 0 {
|
||||
_ = paramStart
|
||||
_ = entryStart
|
||||
_ = start
|
||||
return data, i, nil // already present
|
||||
}
|
||||
}
|
||||
// Not found — append.
|
||||
newData := appendU32Leb(nil, count+1)
|
||||
newData = append(newData, data[start:]...) // all existing type entries
|
||||
newData = append(newData, gasTickFuncType...)
|
||||
return newData, count, nil
|
||||
}
|
||||
|
||||
// ── import section helpers ────────────────────────────────────────────────────
|
||||
|
||||
// containsGasTick returns true if env.gas_tick is already imported.
|
||||
func containsGasTick(data []byte) bool {
|
||||
off := 0
|
||||
count, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
for i := uint32(0); i < count; i++ {
|
||||
// mod name
|
||||
ml, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
mod := string(data[off : off+int(ml)])
|
||||
off += int(ml)
|
||||
// field name
|
||||
fl, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
field := string(data[off : off+int(fl)])
|
||||
off += int(fl)
|
||||
// importdesc kind
|
||||
if off >= len(data) {
|
||||
return false
|
||||
}
|
||||
kind := data[off]
|
||||
off++
|
||||
if mod == "env" && field == "gas_tick" {
|
||||
return true
|
||||
}
|
||||
switch kind {
|
||||
case 0x00: // function: type idx
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
case 0x01: // table: reftype + limits
|
||||
off++ // reftype
|
||||
lkind := data[off]
|
||||
off++
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
if lkind == 1 {
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
}
|
||||
case 0x02: // memory: limits
|
||||
lkind := data[off]
|
||||
off++
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
if lkind == 1 {
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return false
|
||||
}
|
||||
off += n
|
||||
}
|
||||
case 0x03: // global: valtype + mutability
|
||||
off += 2
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// appendGasTickImport adds env.gas_tick as the last function import.
|
||||
// Returns (newData, numOldFuncImports, error).
|
||||
func appendGasTickImport(data []byte, typeIdx uint32) ([]byte, uint32, error) {
|
||||
off := 0
|
||||
count, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad import count")
|
||||
}
|
||||
off += n
|
||||
start := off
|
||||
var numFuncImports uint32
|
||||
for i := uint32(0); i < count; i++ {
|
||||
// mod name
|
||||
ml, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad import mod len")
|
||||
}
|
||||
off += n + int(ml)
|
||||
// field name
|
||||
fl, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad import field len")
|
||||
}
|
||||
off += n + int(fl)
|
||||
// importdesc
|
||||
if off >= len(data) {
|
||||
return nil, 0, errors.New("truncated importdesc")
|
||||
}
|
||||
kind := data[off]
|
||||
off++
|
||||
switch kind {
|
||||
case 0x00: // function
|
||||
numFuncImports++
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad func import type idx")
|
||||
}
|
||||
off += n
|
||||
case 0x01: // table
|
||||
off++ // reftype
|
||||
lkind := data[off]
|
||||
off++
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad table import limit")
|
||||
}
|
||||
off += n
|
||||
if lkind == 1 {
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad table import max")
|
||||
}
|
||||
off += n
|
||||
}
|
||||
case 0x02: // memory
|
||||
lkind := data[off]
|
||||
off++
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad mem import limit")
|
||||
}
|
||||
off += n
|
||||
if lkind == 1 {
|
||||
_, n = readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, 0, errors.New("bad mem import max")
|
||||
}
|
||||
off += n
|
||||
}
|
||||
case 0x03: // global
|
||||
off += 2
|
||||
}
|
||||
}
|
||||
|
||||
// Build new import section: count+1, existing entries, then gas_tick entry.
|
||||
newData := appendU32Leb(nil, count+1)
|
||||
newData = append(newData, data[start:]...)
|
||||
// env.gas_tick entry
|
||||
newData = appendU32Leb(newData, 3) // "env" length
|
||||
newData = append(newData, "env"...)
|
||||
newData = appendU32Leb(newData, 8) // "gas_tick" length
|
||||
newData = append(newData, "gas_tick"...)
|
||||
newData = append(newData, importFunc) // kind = function
|
||||
newData = appendU32Leb(newData, typeIdx)
|
||||
return newData, numFuncImports, nil
|
||||
}
|
||||
|
||||
// buildImportSection builds a section containing only env.gas_tick.
|
||||
func buildImportSection(typeIdx uint32) ([]byte, error) {
|
||||
var data []byte
|
||||
data = appendU32Leb(data, 1) // count = 1
|
||||
data = appendU32Leb(data, 3)
|
||||
data = append(data, "env"...)
|
||||
data = appendU32Leb(data, 8)
|
||||
data = append(data, "gas_tick"...)
|
||||
data = append(data, importFunc)
|
||||
data = appendU32Leb(data, typeIdx)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ── export section adjustment ─────────────────────────────────────────────────
|
||||
|
||||
// adjustExportFuncIndices increments function export indices ≥ threshold by 1.
|
||||
func adjustExportFuncIndices(data []byte, threshold uint32) ([]byte, error) {
|
||||
off := 0
|
||||
count, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("bad export count")
|
||||
}
|
||||
off += n
|
||||
|
||||
out := appendU32Leb(nil, count)
|
||||
for i := uint32(0); i < count; i++ {
|
||||
// name
|
||||
nl, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, fmt.Errorf("export %d: bad name len", i)
|
||||
}
|
||||
off += n
|
||||
out = appendU32Leb(out, nl)
|
||||
out = append(out, data[off:off+int(nl)]...)
|
||||
off += int(nl)
|
||||
// kind
|
||||
if off >= len(data) {
|
||||
return nil, fmt.Errorf("export %d: truncated kind", i)
|
||||
}
|
||||
kind := data[off]
|
||||
off++
|
||||
out = append(out, kind)
|
||||
// index
|
||||
idx, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, fmt.Errorf("export %d: bad index", i)
|
||||
}
|
||||
off += n
|
||||
if kind == exportFunc && idx >= threshold {
|
||||
idx++
|
||||
}
|
||||
out = appendU32Leb(out, idx)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ── code section rewriting ────────────────────────────────────────────────────
|
||||
|
||||
func rewriteCodeSection(data []byte, numOldImports, gasFnIdx uint32) ([]byte, error) {
|
||||
off := 0
|
||||
count, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("bad code section count")
|
||||
}
|
||||
off += n
|
||||
|
||||
out := appendU32Leb(nil, count)
|
||||
for i := uint32(0); i < count; i++ {
|
||||
bodySize, n := readU32Leb(data[off:])
|
||||
if n <= 0 {
|
||||
return nil, fmt.Errorf("code entry %d: bad body size", i)
|
||||
}
|
||||
off += n
|
||||
body := data[off : off+int(bodySize)]
|
||||
off += int(bodySize)
|
||||
|
||||
newBody, err := rewriteFuncBody(body, numOldImports, gasFnIdx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("code entry %d: %w", i, err)
|
||||
}
|
||||
out = appendU32Leb(out, uint32(len(newBody)))
|
||||
out = append(out, newBody...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func rewriteFuncBody(body []byte, numOldImports, gasFnIdx uint32) ([]byte, error) {
|
||||
off := 0
|
||||
// local variable declarations
|
||||
localGroupCount, n := readU32Leb(body[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("bad local group count")
|
||||
}
|
||||
var out []byte
|
||||
out = appendU32Leb(out, localGroupCount)
|
||||
off += n
|
||||
for i := uint32(0); i < localGroupCount; i++ {
|
||||
cnt, n := readU32Leb(body[off:])
|
||||
if n <= 0 {
|
||||
return nil, fmt.Errorf("local group %d: bad count", i)
|
||||
}
|
||||
off += n
|
||||
if off >= len(body) {
|
||||
return nil, fmt.Errorf("local group %d: missing valtype", i)
|
||||
}
|
||||
out = appendU32Leb(out, cnt)
|
||||
out = append(out, body[off]) // valtype
|
||||
off++
|
||||
}
|
||||
// instruction stream
|
||||
instrs, err := rewriteExpr(body[off:], numOldImports, gasFnIdx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, instrs...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// rewriteExpr processes one expression (terminated by the matching end).
|
||||
func rewriteExpr(code []byte, numOldImports, gasFnIdx uint32) ([]byte, error) {
|
||||
off := 0
|
||||
var out []byte
|
||||
depth := 1 // implicit outer scope; final end at depth==1 terminates expression
|
||||
|
||||
for off < len(code) {
|
||||
op := code[off]
|
||||
|
||||
switch op {
|
||||
case opBlock, opIf:
|
||||
off++
|
||||
btLen, err := blocktypeLen(code[off:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("block/if blocktype: %w", err)
|
||||
}
|
||||
out = append(out, op)
|
||||
out = append(out, code[off:off+btLen]...)
|
||||
off += btLen
|
||||
depth++
|
||||
|
||||
case opLoop:
|
||||
off++
|
||||
btLen, err := blocktypeLen(code[off:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loop blocktype: %w", err)
|
||||
}
|
||||
out = append(out, op)
|
||||
out = append(out, code[off:off+btLen]...)
|
||||
off += btLen
|
||||
depth++
|
||||
// Inject gas_tick call at loop header.
|
||||
out = append(out, opCall)
|
||||
out = appendU32Leb(out, gasFnIdx)
|
||||
|
||||
case opElse:
|
||||
out = append(out, op)
|
||||
off++
|
||||
|
||||
case opEnd:
|
||||
out = append(out, op)
|
||||
off++
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
case opCall:
|
||||
off++
|
||||
idx, n := readU32Leb(code[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("call: bad function index LEB128")
|
||||
}
|
||||
off += n
|
||||
if idx >= numOldImports {
|
||||
idx++ // shift defined-function index
|
||||
}
|
||||
out = append(out, opCall)
|
||||
out = appendU32Leb(out, idx)
|
||||
|
||||
case opCallInd:
|
||||
off++
|
||||
typeIdx, n := readU32Leb(code[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("call_indirect: bad type index")
|
||||
}
|
||||
off += n
|
||||
tableIdx, n := readU32Leb(code[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("call_indirect: bad table index")
|
||||
}
|
||||
off += n
|
||||
out = append(out, opCallInd)
|
||||
out = appendU32Leb(out, typeIdx)
|
||||
out = appendU32Leb(out, tableIdx)
|
||||
|
||||
case opRefFunc:
|
||||
off++
|
||||
idx, n := readU32Leb(code[off:])
|
||||
if n <= 0 {
|
||||
return nil, errors.New("ref.func: bad index")
|
||||
}
|
||||
off += n
|
||||
if idx >= numOldImports {
|
||||
idx++
|
||||
}
|
||||
out = append(out, opRefFunc)
|
||||
out = appendU32Leb(out, idx)
|
||||
|
||||
default:
|
||||
// Copy instruction verbatim; instrLen handles all remaining opcodes.
|
||||
ilen, err := instrLen(code[off:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("at offset %d: %w", off, err)
|
||||
}
|
||||
out = append(out, code[off:off+ilen]...)
|
||||
off += ilen
|
||||
}
|
||||
}
|
||||
return nil, errors.New("expression missing terminating end")
|
||||
}
|
||||
|
||||
// blocktypeLen returns the byte length of a blocktype immediate.
|
||||
// Blocktypes are SLEB128-encoded: value types and void are single negative bytes;
|
||||
// type-index blocktypes are non-negative LEB128 (multi-byte for indices ≥ 64).
|
||||
func blocktypeLen(b []byte) (int, error) {
|
||||
_, n := readI32Leb(b)
|
||||
if n <= 0 {
|
||||
return 0, errors.New("bad blocktype encoding")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// instrLen returns the total byte length (opcode + immediates) of the instruction
|
||||
// at code[0]. It handles all opcodes EXCEPT block/loop/if/else/end/call/
|
||||
// call_indirect/ref.func — those are handled directly in rewriteExpr.
|
||||
func instrLen(code []byte) (int, error) {
|
||||
if len(code) == 0 {
|
||||
return 0, errors.New("empty instruction stream")
|
||||
}
|
||||
op := code[0]
|
||||
switch {
|
||||
// 1-byte, no immediates — covers all integer/float arithmetic, comparisons,
|
||||
// conversion, drop, select, return, unreachable, nop, else, end, ref.is_null.
|
||||
case op == 0x00 || op == 0x01 || op == 0x0F || op == 0x1A || op == 0x1B ||
|
||||
op == opRefIsNull ||
|
||||
(op >= 0x45 && op <= 0xC4):
|
||||
return 1, nil
|
||||
|
||||
// br, br_if, local.get/set/tee, global.get/set, table.get/set — one u32 LEB128
|
||||
case op == opBr || op == opBrIf ||
|
||||
op == opLocalGet || op == opLocalSet || op == opLocalTee ||
|
||||
op == opGlobalGet || op == opGlobalSet ||
|
||||
op == opTableGet || op == opTableSet:
|
||||
n, err := skipU32Leb(code[1:])
|
||||
return 1 + n, err
|
||||
|
||||
// br_table: u32 count N, then N+1 u32 values
|
||||
case op == opBrTable:
|
||||
off := 1
|
||||
cnt, n := readU32Leb(code[off:])
|
||||
if n <= 0 {
|
||||
return 0, errors.New("br_table: bad count")
|
||||
}
|
||||
off += n
|
||||
for i := uint32(0); i <= cnt; i++ {
|
||||
n, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("br_table target %d: %w", i, err)
|
||||
}
|
||||
off += n
|
||||
}
|
||||
return off, nil
|
||||
|
||||
// memory load/store (0x28–0x3E): align u32 + offset u32
|
||||
case op >= 0x28 && op <= 0x3E:
|
||||
off := 1
|
||||
n, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("mem instr 0x%02X align: %w", op, err)
|
||||
}
|
||||
off += n
|
||||
n, err = skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("mem instr 0x%02X offset: %w", op, err)
|
||||
}
|
||||
return 1 + (off - 1 + n), nil
|
||||
|
||||
// memory.size (0x3F), memory.grow (0x40): reserved byte 0x00
|
||||
case op == opMemSize || op == opMemGrow:
|
||||
return 2, nil
|
||||
|
||||
// i32.const: signed LEB128 i32
|
||||
case op == opI32Const:
|
||||
n, err := skipI32Leb(code[1:])
|
||||
return 1 + n, err
|
||||
|
||||
// i64.const: signed LEB128 i64
|
||||
case op == opI64Const:
|
||||
n, err := skipI64Leb(code[1:])
|
||||
return 1 + n, err
|
||||
|
||||
// f32.const: 4 raw bytes
|
||||
case op == opF32Const:
|
||||
if len(code) < 5 {
|
||||
return 0, errors.New("f32.const: truncated")
|
||||
}
|
||||
return 5, nil
|
||||
|
||||
// f64.const: 8 raw bytes
|
||||
case op == opF64Const:
|
||||
if len(code) < 9 {
|
||||
return 0, errors.New("f64.const: truncated")
|
||||
}
|
||||
return 9, nil
|
||||
|
||||
// ref.null: 1 byte reftype
|
||||
case op == opRefNull:
|
||||
return 2, nil
|
||||
|
||||
// select with types (0x1C): u32 count + that many valtypes (each 1 byte)
|
||||
case op == opSelectT:
|
||||
cnt, n := readU32Leb(code[1:])
|
||||
if n <= 0 {
|
||||
return 0, errors.New("select t: bad count")
|
||||
}
|
||||
return 1 + n + int(cnt), nil
|
||||
|
||||
// 0xFC prefix — saturating truncation, table/memory bulk ops
|
||||
case op == opPrefixFC:
|
||||
sub, n := readU32Leb(code[1:])
|
||||
if n <= 0 {
|
||||
return 0, errors.New("0xFC: bad sub-opcode")
|
||||
}
|
||||
off := 1 + n
|
||||
// sub-ops that have additional immediates
|
||||
switch sub {
|
||||
case 8: // memory.init: data_idx u32, mem idx u8 (0x00)
|
||||
n2, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
off++ // reserved memory index byte
|
||||
case 9: // data.drop: data_idx u32
|
||||
n2, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
case 10: // memory.copy: two reserved 0x00 bytes
|
||||
off += 2
|
||||
case 11: // memory.fill: reserved 0x00
|
||||
off++
|
||||
case 12: // table.init: elem_idx u32, table_idx u32
|
||||
n2, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
n2, err = skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
case 13: // elem.drop: elem_idx u32
|
||||
n2, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
case 14: // table.copy: dst_idx u32, src_idx u32
|
||||
n2, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
n2, err = skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
case 15, 16: // table.grow, table.size: table_idx u32
|
||||
n2, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
case 17: // table.fill: table_idx u32
|
||||
n2, err := skipU32Leb(code[off:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
off += n2
|
||||
// 0–7: i32/i64 saturating truncation — no additional immediates
|
||||
}
|
||||
return off, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown opcode 0x%02X", op)
|
||||
}
|
||||
}
|
||||
|
||||
// ── section list helpers ──────────────────────────────────────────────────────
|
||||
|
||||
func findSections(ss []rawSection) (typeIdx, importIdx, exportIdx, codeIdx int) {
|
||||
typeIdx, importIdx, exportIdx, codeIdx = -1, -1, -1, -1
|
||||
for i, s := range ss {
|
||||
switch s.id {
|
||||
case secType:
|
||||
typeIdx = i
|
||||
case secImport:
|
||||
importIdx = i
|
||||
case secExport:
|
||||
exportIdx = i
|
||||
case secCode:
|
||||
codeIdx = i
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// insertBefore inserts ns before the first section with the given id,
|
||||
// or at the end if no such section exists.
|
||||
func insertBefore(ss []rawSection, id byte, ns rawSection) []rawSection {
|
||||
for i, s := range ss {
|
||||
if s.id == id {
|
||||
result := make([]rawSection, 0, len(ss)+1)
|
||||
result = append(result, ss[:i]...)
|
||||
result = append(result, ns)
|
||||
result = append(result, ss[i:]...)
|
||||
return result
|
||||
}
|
||||
}
|
||||
return append(ss, ns)
|
||||
}
|
||||
|
||||
184
vm/vm.go
Normal file
184
vm/vm.go
Normal file
@@ -0,0 +1,184 @@
|
||||
// Package vm provides a WASM-based smart contract execution engine.
|
||||
//
|
||||
// The engine uses wazero (pure-Go, no CGO) in interpreter mode to guarantee
|
||||
// deterministic execution across all platforms — a requirement for consensus.
|
||||
//
|
||||
// Contract lifecycle:
|
||||
// 1. DEPLOY_CONTRACT tx → chain calls Validate() to compile-check the WASM,
|
||||
// then stores the bytecode in BadgerDB.
|
||||
// 2. CALL_CONTRACT tx → chain calls Call() with the stored WASM bytes,
|
||||
// the method name, JSON args, and a gas limit.
|
||||
//
|
||||
// Gas model: each WASM function call costs gasPerCall (100) units.
|
||||
// Gas cost in µT = gasUsed × blockchain.GasPrice.
|
||||
package vm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/experimental"
|
||||
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
|
||||
|
||||
"go-blockchain/blockchain"
|
||||
)
|
||||
|
||||
// VM is a WASM execution engine. Create one per process; it is safe for
|
||||
// concurrent use. Compiled modules are cached by contract ID.
|
||||
type VM struct {
|
||||
rt wazero.Runtime
|
||||
mu sync.RWMutex
|
||||
cache map[string]wazero.CompiledModule // contractID → compiled module
|
||||
}
|
||||
|
||||
// NewVM creates a VM with an interpreter-mode wazero runtime.
|
||||
// Interpreter mode guarantees identical behaviour on all OS/arch combos,
|
||||
// which is required for all nodes to reach the same state.
|
||||
//
|
||||
// WASI preview 1 is pre-instantiated to support contracts compiled with
|
||||
// TinyGo's wasip1 target (tinygo build -target wasip1). Contracts that do
|
||||
// not import from wasi_snapshot_preview1 are unaffected.
|
||||
func NewVM(ctx context.Context) *VM {
|
||||
cfg := wazero.NewRuntimeConfigInterpreter()
|
||||
cfg = cfg.WithDebugInfoEnabled(false)
|
||||
// Enable cooperative cancellation: fn.Call(ctx) will return when ctx is
|
||||
// cancelled, even if the contract is in an infinite loop. Without this,
|
||||
// a buggy or malicious contract that dodges gas metering (e.g. a tight
|
||||
// loop of opcodes not hooked by the gas listener) would hang the
|
||||
// AddBlock goroutine forever, freezing the entire chain.
|
||||
cfg = cfg.WithCloseOnContextDone(true)
|
||||
// Attach call-level gas metering.
|
||||
ctx = experimental.WithFunctionListenerFactory(ctx, gasListenerFactory{})
|
||||
rt := wazero.NewRuntimeWithConfig(ctx, cfg)
|
||||
// Instantiate WASI so TinyGo contracts can call proc_exit and basic I/O.
|
||||
// The sandbox is fully isolated — no filesystem or network access.
|
||||
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
|
||||
return &VM{
|
||||
rt: rt,
|
||||
cache: make(map[string]wazero.CompiledModule),
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases the underlying wazero runtime.
|
||||
func (v *VM) Close(ctx context.Context) error {
|
||||
return v.rt.Close(ctx)
|
||||
}
|
||||
|
||||
// Validate compiles the WASM bytes without executing them.
|
||||
// Returns an error if the bytes are not valid WASM or import unknown symbols.
|
||||
// Implements blockchain.ContractVM.
|
||||
func (v *VM) Validate(ctx context.Context, wasmBytes []byte) error {
|
||||
_, err := v.rt.CompileModule(ctx, wasmBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid WASM module: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call executes method on the contract identified by contractID.
|
||||
// wasmBytes is the compiled WASM (loaded from DB by the chain).
|
||||
// Returns gas consumed. Returns ErrOutOfGas (wrapping ErrTxFailed) on exhaustion.
|
||||
// Implements blockchain.ContractVM.
|
||||
func (v *VM) Call(
|
||||
ctx context.Context,
|
||||
contractID string,
|
||||
wasmBytes []byte,
|
||||
method string,
|
||||
argsJSON []byte,
|
||||
gasLimit uint64,
|
||||
env blockchain.VMHostEnv,
|
||||
) (gasUsed uint64, err error) {
|
||||
// Attach gas counter to context.
|
||||
ctx, gc := withGasCounter(ctx, gasLimit)
|
||||
ctx = experimental.WithFunctionListenerFactory(ctx, gasListenerFactory{})
|
||||
|
||||
// Compile (or retrieve from cache).
|
||||
compiled, err := v.compiled(ctx, contractID, wasmBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("compile contract %s: %w", contractID, err)
|
||||
}
|
||||
|
||||
// Instantiate the "env" host module for this call, wiring in typed args.
|
||||
hostInst, err := registerHostModule(ctx, v.rt, env, argsJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("register host module: %w", err)
|
||||
}
|
||||
defer hostInst.Close(ctx)
|
||||
|
||||
// Instantiate the contract module.
|
||||
modCfg := wazero.NewModuleConfig().
|
||||
WithName(""). // anonymous — allows multiple concurrent instances
|
||||
WithStartFunctions() // do not auto-call _start
|
||||
mod, err := v.rt.InstantiateModule(ctx, compiled, modCfg)
|
||||
if err != nil {
|
||||
return gc.Used(), fmt.Errorf("instantiate contract: %w", err)
|
||||
}
|
||||
defer mod.Close(ctx)
|
||||
|
||||
// Look up the exported method.
|
||||
fn := mod.ExportedFunction(method)
|
||||
if fn == nil {
|
||||
return gc.Used(), fmt.Errorf("%w: method %q not exported by contract %s",
|
||||
blockchain.ErrTxFailed, method, contractID)
|
||||
}
|
||||
|
||||
// Call. WASM functions called from Go pass args via the stack; our contracts
|
||||
// use no parameters — all input/output goes through host state functions.
|
||||
_, callErr := fn.Call(ctx)
|
||||
gasUsed = gc.Used()
|
||||
|
||||
if callErr != nil {
|
||||
log.Printf("[VM] contract %s.%s error: %v", contractID[:8], method, callErr)
|
||||
if isOutOfGas(gc) {
|
||||
return gasUsed, ErrOutOfGas
|
||||
}
|
||||
return gasUsed, fmt.Errorf("%w: %v", blockchain.ErrTxFailed, callErr)
|
||||
}
|
||||
if isOutOfGas(gc) {
|
||||
return gasUsed, ErrOutOfGas
|
||||
}
|
||||
return gasUsed, nil
|
||||
}
|
||||
|
||||
// compiled returns a cached compiled module, compiling it first if not cached.
|
||||
// Before compiling, the WASM bytes are instrumented with loop-header gas_tick
|
||||
// calls so that infinite loops are bounded by the gas limit.
|
||||
func (v *VM) compiled(ctx context.Context, contractID string, wasmBytes []byte) (wazero.CompiledModule, error) {
|
||||
v.mu.RLock()
|
||||
cm, ok := v.cache[contractID]
|
||||
v.mu.RUnlock()
|
||||
if ok {
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
// Double-check after acquiring write lock.
|
||||
if cm, ok = v.cache[contractID]; ok {
|
||||
return cm, nil
|
||||
}
|
||||
// Instrument: inject gas_tick at loop headers.
|
||||
// On instrumentation failure, fall back to the original bytes so that
|
||||
// unusual WASM features do not prevent execution entirely.
|
||||
instrumented, err := Instrument(wasmBytes)
|
||||
if err != nil {
|
||||
log.Printf("[VM] instrument contract %s: %v (using original bytes)", contractID[:min8(contractID)], err)
|
||||
instrumented = wasmBytes
|
||||
}
|
||||
compiled, err := v.rt.CompileModule(ctx, instrumented)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.cache[contractID] = compiled
|
||||
return compiled, nil
|
||||
}
|
||||
|
||||
func min8(s string) int {
|
||||
if len(s) < 8 {
|
||||
return len(s)
|
||||
}
|
||||
return 8
|
||||
}
|
||||
684
vm/vm_test.go
Normal file
684
vm/vm_test.go
Normal file
@@ -0,0 +1,684 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go-blockchain/blockchain"
|
||||
)
|
||||
|
||||
// ── mock host env ─────────────────────────────────────────────────────────────
|
||||
|
||||
type mockEnv struct {
|
||||
state map[string][]byte
|
||||
balances map[string]uint64
|
||||
caller string
|
||||
blockHeight uint64
|
||||
logs []string
|
||||
}
|
||||
|
||||
func newMockEnv(caller string) *mockEnv {
|
||||
return &mockEnv{
|
||||
state: make(map[string][]byte),
|
||||
balances: make(map[string]uint64),
|
||||
caller: caller,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockEnv) GetState(key []byte) ([]byte, error) {
|
||||
v := m.state[string(key)]
|
||||
return v, nil
|
||||
}
|
||||
func (m *mockEnv) SetState(key, value []byte) error {
|
||||
m.state[string(key)] = append([]byte(nil), value...)
|
||||
return nil
|
||||
}
|
||||
func (m *mockEnv) GetBalance(pub string) (uint64, error) { return m.balances[pub], nil }
|
||||
func (m *mockEnv) Transfer(from, to string, amount uint64) error {
|
||||
if m.balances[from] < amount {
|
||||
return errors.New("insufficient balance")
|
||||
}
|
||||
m.balances[from] -= amount
|
||||
m.balances[to] += amount
|
||||
return nil
|
||||
}
|
||||
func (m *mockEnv) GetCaller() string { return m.caller }
|
||||
func (m *mockEnv) GetBlockHeight() uint64 { return m.blockHeight }
|
||||
func (m *mockEnv) GetContractTreasury() string {
|
||||
return "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
}
|
||||
func (m *mockEnv) Log(msg string) { m.logs = append(m.logs, msg) }
|
||||
func (m *mockEnv) CallContract(contractID, method string, argsJSON []byte, gasLimit uint64) (uint64, error) {
|
||||
return 0, fmt.Errorf("CallContract not supported in test mock")
|
||||
}
|
||||
|
||||
// counterWASM loads counter.wasm relative to the test file.
|
||||
func counterWASM(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile("../contracts/counter/counter.wasm")
|
||||
if err != nil {
|
||||
t.Fatalf("load counter.wasm: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// readU64State reads an 8-byte big-endian uint64 from env state.
|
||||
func readU64State(env *mockEnv, key string) uint64 {
|
||||
v := env.state[key]
|
||||
if len(v) < 8 {
|
||||
return 0
|
||||
}
|
||||
return binary.BigEndian.Uint64(v[:8])
|
||||
}
|
||||
|
||||
// ── unit tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestValidate_ValidWASM(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
if err := v.Validate(ctx, wasmBytes); err != nil {
|
||||
t.Fatalf("Validate valid WASM: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_InvalidBytes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
if err := v.Validate(ctx, []byte("not a wasm file")); err == nil {
|
||||
t.Fatal("expected error for invalid WASM bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_EmptyBytes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
if err := v.Validate(ctx, []byte{}); err == nil {
|
||||
t.Fatal("expected error for empty bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCall_UnknownMethod(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
env := newMockEnv("caller1")
|
||||
_, err := v.Call(ctx, "test-id", wasmBytes, "nonexistent", nil, 100_000, env)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown method")
|
||||
}
|
||||
if !errors.Is(err, blockchain.ErrTxFailed) {
|
||||
t.Fatalf("expected ErrTxFailed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCall_GasExhausted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
env := newMockEnv("caller1")
|
||||
// Gas limit of 1 unit is far too low for any real work.
|
||||
_, err := v.Call(ctx, "test-id", wasmBytes, "increment", nil, 1, env)
|
||||
if err == nil {
|
||||
t.Fatal("expected ErrOutOfGas")
|
||||
}
|
||||
if !errors.Is(err, ErrOutOfGas) {
|
||||
t.Fatalf("expected ErrOutOfGas, got: %v", err)
|
||||
}
|
||||
if !errors.Is(err, blockchain.ErrTxFailed) {
|
||||
t.Fatalf("ErrOutOfGas must wrap ErrTxFailed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── integration tests (counter contract) ─────────────────────────────────────
|
||||
|
||||
func TestCounter_Increment(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
env := newMockEnv("alice")
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
gasUsed, err := v.Call(ctx, "ctr", wasmBytes, "increment", nil, 1_000_000, env)
|
||||
if err != nil {
|
||||
t.Fatalf("increment %d: %v", i, err)
|
||||
}
|
||||
if gasUsed == 0 {
|
||||
t.Errorf("increment %d: expected gas > 0", i)
|
||||
}
|
||||
got := readU64State(env, "counter")
|
||||
if got != uint64(i) {
|
||||
t.Errorf("after increment %d: counter=%d, want %d", i, got, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCounter_Get(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
env := newMockEnv("alice")
|
||||
|
||||
// Increment twice then call get.
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "increment", nil, 1_000_000, env); err != nil {
|
||||
t.Fatalf("increment: %v", err)
|
||||
}
|
||||
}
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "get", nil, 1_000_000, env); err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
// get logs "get called"
|
||||
logged := false
|
||||
for _, l := range env.logs {
|
||||
if l == "get called" {
|
||||
logged = true
|
||||
}
|
||||
}
|
||||
if !logged {
|
||||
t.Error("expected 'get called' in logs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCounter_Reset_AuthorizedOwner(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
env := newMockEnv("alice")
|
||||
|
||||
// Increment 5 times.
|
||||
for i := 0; i < 5; i++ {
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "increment", nil, 1_000_000, env); err != nil {
|
||||
t.Fatalf("increment: %v", err)
|
||||
}
|
||||
}
|
||||
if got := readU64State(env, "counter"); got != 5 {
|
||||
t.Fatalf("before reset: counter=%d, want 5", got)
|
||||
}
|
||||
|
||||
// First reset — alice becomes owner.
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "reset", nil, 1_000_000, env); err != nil {
|
||||
t.Fatalf("reset (owner set): %v", err)
|
||||
}
|
||||
if got := readU64State(env, "counter"); got != 0 {
|
||||
t.Fatalf("after reset: counter=%d, want 0", got)
|
||||
}
|
||||
|
||||
// Increment again then reset again (alice is owner).
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "increment", nil, 1_000_000, env); err != nil {
|
||||
t.Fatalf("increment: %v", err)
|
||||
}
|
||||
}
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "reset", nil, 1_000_000, env); err != nil {
|
||||
t.Fatalf("reset (owner confirmed): %v", err)
|
||||
}
|
||||
if got := readU64State(env, "counter"); got != 0 {
|
||||
t.Fatalf("second reset: counter=%d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCounter_Reset_UnauthorizedRejected(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
envAlice := newMockEnv("alice")
|
||||
|
||||
// Alice increments and performs first reset (sets herself as owner).
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "increment", nil, 1_000_000, envAlice); err != nil {
|
||||
t.Fatalf("increment: %v", err)
|
||||
}
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "reset", nil, 1_000_000, envAlice); err != nil {
|
||||
t.Fatalf("reset (set owner): %v", err)
|
||||
}
|
||||
|
||||
// Bob tries to reset using the same state but with his caller ID.
|
||||
// We simulate this by setting the owner in bob's env to alice's value.
|
||||
envBob := newMockEnv("bob")
|
||||
envBob.state = envAlice.state // shared state (bob sees alice as owner)
|
||||
|
||||
// Bob increments so counter > 0.
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "increment", nil, 1_000_000, envBob); err != nil {
|
||||
t.Fatalf("bob increment: %v", err)
|
||||
}
|
||||
counterBefore := readU64State(envBob, "counter")
|
||||
|
||||
// Bob tries reset — should be rejected (logs "unauthorized").
|
||||
if _, err := v.Call(ctx, "ctr", wasmBytes, "reset", nil, 1_000_000, envBob); err != nil {
|
||||
t.Fatalf("bob reset call error: %v", err)
|
||||
}
|
||||
// Counter should not be 0.
|
||||
if got := readU64State(envBob, "counter"); got == 0 {
|
||||
t.Errorf("bob reset succeeded (counter=%d), expected it to be rejected (was %d)", got, counterBefore)
|
||||
}
|
||||
// Should have logged "unauthorized".
|
||||
logged := false
|
||||
for _, l := range envBob.logs {
|
||||
if l == "unauthorized" {
|
||||
logged = true
|
||||
}
|
||||
}
|
||||
if !logged {
|
||||
t.Error("expected 'unauthorized' log when bob tries to reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCounter_GasReturned(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasmBytes := counterWASM(t)
|
||||
env := newMockEnv("alice")
|
||||
|
||||
gasUsed, err := v.Call(ctx, "ctr", wasmBytes, "increment", nil, 1_000_000, env)
|
||||
if err != nil {
|
||||
t.Fatalf("increment: %v", err)
|
||||
}
|
||||
if gasUsed == 0 || gasUsed >= 1_000_000 {
|
||||
t.Errorf("unexpected gas: %d", gasUsed)
|
||||
}
|
||||
t.Logf("increment gasUsed=%d", gasUsed)
|
||||
}
|
||||
|
||||
func TestABI_Validate(t *testing.T) {
|
||||
a, err := ParseABI(`{"methods":[{"name":"increment","args":[]},{"name":"get","args":[]},{"name":"reset","args":[]}]}`)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseABI: %v", err)
|
||||
}
|
||||
if !a.HasMethod("increment") {
|
||||
t.Error("expected HasMethod(increment)")
|
||||
}
|
||||
if a.HasMethod("nonexistent") {
|
||||
t.Error("unexpected HasMethod(nonexistent)")
|
||||
}
|
||||
if err := a.Validate("increment", nil); err != nil {
|
||||
t.Errorf("Validate increment: %v", err)
|
||||
}
|
||||
if err := a.Validate("unknown", nil); err == nil {
|
||||
t.Error("expected error for unknown method")
|
||||
}
|
||||
}
|
||||
|
||||
func TestABI_NamedArgs(t *testing.T) {
|
||||
a, err := ParseABI(`{"methods":[
|
||||
{"name":"register","args":[{"name":"name","type":"string"}]},
|
||||
{"name":"transfer","args":[{"name":"name","type":"string"},{"name":"new_owner","type":"string"}]}
|
||||
]}`)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseABI: %v", err)
|
||||
}
|
||||
if !a.HasMethod("register") {
|
||||
t.Error("expected HasMethod(register)")
|
||||
}
|
||||
// register expects 1 arg
|
||||
if err := a.Validate("register", []byte(`["alice"]`)); err != nil {
|
||||
t.Errorf("Validate register 1 arg: %v", err)
|
||||
}
|
||||
if err := a.Validate("register", []byte(`["alice","extra"]`)); err == nil {
|
||||
t.Error("expected error for too many args")
|
||||
}
|
||||
// transfer expects 2 args
|
||||
if err := a.Validate("transfer", []byte(`["alice","bob_pubkey"]`)); err != nil {
|
||||
t.Errorf("Validate transfer 2 args: %v", err)
|
||||
}
|
||||
// Inspect arg metadata
|
||||
if a.Methods[0].Args[0].Name != "name" {
|
||||
t.Errorf("arg name: want 'name', got %q", a.Methods[0].Args[0].Name)
|
||||
}
|
||||
if a.Methods[0].Args[0].Type != "string" {
|
||||
t.Errorf("arg type: want 'string', got %q", a.Methods[0].Args[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
// ── name registry contract tests ──────────────────────────────────────────────
|
||||
|
||||
func nameRegistryWASM(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile("../contracts/name_registry/name_registry.wasm")
|
||||
if err != nil {
|
||||
t.Fatalf("load name_registry.wasm: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestNameRegistry_Register(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasm := nameRegistryWASM(t)
|
||||
env := newMockEnv("alice_pubkey_hex")
|
||||
|
||||
_, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, env)
|
||||
if err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
// State key "alice" should now contain the caller pubkey bytes.
|
||||
val := env.state["alice"]
|
||||
if string(val) != "alice_pubkey_hex" {
|
||||
t.Errorf("state[alice] = %q, want %q", val, "alice_pubkey_hex")
|
||||
}
|
||||
// Should have logged something containing "registered"
|
||||
if len(env.logs) == 0 || !strings.Contains(env.logs[len(env.logs)-1], "registered") {
|
||||
t.Errorf("expected last log 'registered', got %v", env.logs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameRegistry_NameTaken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasm := nameRegistryWASM(t)
|
||||
env := newMockEnv("alice_pubkey_hex")
|
||||
|
||||
// First registration succeeds.
|
||||
if _, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, env); err != nil {
|
||||
t.Fatalf("first register: %v", err)
|
||||
}
|
||||
// Second registration by a different caller should log "name taken".
|
||||
env2 := newMockEnv("bob_pubkey_hex")
|
||||
env2.state = env.state // shared state
|
||||
logsBefore := len(env2.logs)
|
||||
if _, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, env2); err != nil {
|
||||
t.Fatalf("second register: %v", err)
|
||||
}
|
||||
found := false
|
||||
for _, l := range env2.logs[logsBefore:] {
|
||||
if strings.Contains(l, "name taken") {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected 'name taken' log, got %v", env2.logs)
|
||||
}
|
||||
// Owner should still be alice.
|
||||
if string(env.state["alice"]) != "alice_pubkey_hex" {
|
||||
t.Error("owner changed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameRegistry_Resolve(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasm := nameRegistryWASM(t)
|
||||
env := newMockEnv("alice_pubkey_hex")
|
||||
|
||||
if _, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, env); err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
env.logs = nil
|
||||
if _, err := v.Call(ctx, "reg", wasm, "resolve", []byte(`["alice"]`), 1_000_000, env); err != nil {
|
||||
t.Fatalf("resolve: %v", err)
|
||||
}
|
||||
// Should log the owner pubkey (verbose: "owner: alice_pubkey_hex").
|
||||
if len(env.logs) == 0 || !strings.Contains(env.logs[0], "alice_pubkey_hex") {
|
||||
t.Errorf("resolve logged %v, want something containing alice_pubkey_hex", env.logs)
|
||||
}
|
||||
|
||||
// Resolve unknown name logs "not found".
|
||||
env.logs = nil
|
||||
if _, err := v.Call(ctx, "reg", wasm, "resolve", []byte(`["unknown"]`), 1_000_000, env); err != nil {
|
||||
t.Fatalf("resolve unknown: %v", err)
|
||||
}
|
||||
if len(env.logs) == 0 || !strings.Contains(env.logs[0], "not found") {
|
||||
t.Errorf("expected 'not found', got %v", env.logs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameRegistry_Transfer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasm := nameRegistryWASM(t)
|
||||
envAlice := newMockEnv("alice_pubkey_hex")
|
||||
|
||||
if _, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, envAlice); err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
// Alice transfers "alice" to bob.
|
||||
if _, err := v.Call(ctx, "reg", wasm, "transfer",
|
||||
[]byte(`["alice","bob_pubkey_hex"]`), 1_000_000, envAlice); err != nil {
|
||||
t.Fatalf("transfer: %v", err)
|
||||
}
|
||||
if string(envAlice.state["alice"]) != "bob_pubkey_hex" {
|
||||
t.Errorf("state[alice] = %q, want bob_pubkey_hex", envAlice.state["alice"])
|
||||
}
|
||||
lastLog := envAlice.logs[len(envAlice.logs)-1]
|
||||
if !strings.Contains(lastLog, "transferred") {
|
||||
t.Errorf("expected 'transferred' log, got %q", lastLog)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameRegistry_Transfer_Unauthorized(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasm := nameRegistryWASM(t)
|
||||
envAlice := newMockEnv("alice_pubkey_hex")
|
||||
|
||||
if _, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, envAlice); err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
|
||||
// Bob tries to transfer alice's name.
|
||||
envBob := newMockEnv("bob_pubkey_hex")
|
||||
envBob.state = envAlice.state
|
||||
if _, err := v.Call(ctx, "reg", wasm, "transfer",
|
||||
[]byte(`["alice","bob_pubkey_hex"]`), 1_000_000, envBob); err != nil {
|
||||
t.Fatalf("transfer call: %v", err)
|
||||
}
|
||||
// Owner should still be alice.
|
||||
if string(envBob.state["alice"]) != "alice_pubkey_hex" {
|
||||
t.Errorf("unauthorized transfer succeeded, state = %q", envBob.state["alice"])
|
||||
}
|
||||
found := false
|
||||
for _, l := range envBob.logs {
|
||||
if strings.Contains(l, "unauthorized") {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected 'unauthorized' log, got %v", envBob.logs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameRegistry_Release(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasm := nameRegistryWASM(t)
|
||||
env := newMockEnv("alice_pubkey_hex")
|
||||
|
||||
if _, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, env); err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
if _, err := v.Call(ctx, "reg", wasm, "release", []byte(`["alice"]`), 1_000_000, env); err != nil {
|
||||
t.Fatalf("release: %v", err)
|
||||
}
|
||||
// State should now be empty (released).
|
||||
val := env.state["alice"]
|
||||
if len(val) > 0 {
|
||||
t.Errorf("after release state[alice] = %q, want empty", val)
|
||||
}
|
||||
lastLog := env.logs[len(env.logs)-1]
|
||||
if !strings.Contains(lastLog, "released") {
|
||||
t.Errorf("expected 'released' log, got %q", lastLog)
|
||||
}
|
||||
|
||||
// After release, anyone can re-register.
|
||||
env2 := newMockEnv("charlie_pubkey_hex")
|
||||
env2.state = env.state
|
||||
if _, err := v.Call(ctx, "reg", wasm, "register", []byte(`["alice"]`), 1_000_000, env2); err != nil {
|
||||
t.Fatalf("re-register after release: %v", err)
|
||||
}
|
||||
if string(env2.state["alice"]) != "charlie_pubkey_hex" {
|
||||
t.Errorf("re-register: state[alice] = %q, want charlie_pubkey_hex", env2.state["alice"])
|
||||
}
|
||||
}
|
||||
|
||||
// ── Phase 9: instruction-level gas metering ───────────────────────────────────
|
||||
|
||||
// infiniteLoopWASM is a hand-encoded WASM module that exports one function,
|
||||
// "loop_forever", which contains an unconditional infinite loop.
|
||||
// The Instrument function must inject gas_tick so the loop is terminated.
|
||||
//
|
||||
// WAT equivalent:
|
||||
//
|
||||
// (module
|
||||
// (func (export "loop_forever")
|
||||
// (loop $L (br $L))
|
||||
// )
|
||||
// )
|
||||
var infiniteLoopWASM = []byte{
|
||||
// magic + version
|
||||
0x00, 0x61, 0x73, 0x6D, 0x01, 0x00, 0x00, 0x00,
|
||||
// type section: 1 type — () → ()
|
||||
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
|
||||
// function section: 1 function using type 0
|
||||
0x03, 0x02, 0x01, 0x00,
|
||||
// export section: export "loop_forever" as function 0
|
||||
0x07, 0x10, 0x01, 0x0C,
|
||||
0x6C, 0x6F, 0x6F, 0x70, 0x5F, 0x66, 0x6F, 0x72, 0x65, 0x76, 0x65, 0x72, // "loop_forever"
|
||||
0x00, 0x00,
|
||||
// code section: 1 entry — body: loop void; br 0; end; end
|
||||
0x0A, 0x09, 0x01, 0x07, 0x00,
|
||||
0x03, 0x40, // loop void
|
||||
0x0C, 0x00, // br 0
|
||||
0x0B, // end (loop)
|
||||
0x0B, // end (function)
|
||||
}
|
||||
|
||||
// TestInstrument_InfiniteLoop verifies that Instrument succeeds and produces
|
||||
// valid WASM for a module containing an infinite loop.
|
||||
func TestInstrument_InfiniteLoop(t *testing.T) {
|
||||
instrumented, err := Instrument(infiniteLoopWASM)
|
||||
if err != nil {
|
||||
t.Fatalf("Instrument: %v", err)
|
||||
}
|
||||
if len(instrumented) <= len(infiniteLoopWASM) {
|
||||
t.Errorf("instrumented binary (%d B) not larger than original (%d B)",
|
||||
len(instrumented), len(infiniteLoopWASM))
|
||||
}
|
||||
// Must still be valid WASM.
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
if err := v.Validate(ctx, instrumented); err != nil {
|
||||
t.Fatalf("Validate instrumented: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInstrument_Idempotent verifies that instrumenting an already-instrumented
|
||||
// binary is a no-op (returns identical bytes).
|
||||
func TestInstrument_Idempotent(t *testing.T) {
|
||||
once, err := Instrument(infiniteLoopWASM)
|
||||
if err != nil {
|
||||
t.Fatalf("first Instrument: %v", err)
|
||||
}
|
||||
twice, err := Instrument(once)
|
||||
if err != nil {
|
||||
t.Fatalf("second Instrument: %v", err)
|
||||
}
|
||||
if !bytes.Equal(once, twice) {
|
||||
t.Error("second instrumentation changed the binary — not idempotent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInfiniteLoop_TrappedByGas verifies that an infinite loop is terminated
|
||||
// when the gas budget is exhausted.
|
||||
func TestInfiniteLoop_TrappedByGas(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
env := newMockEnv("alice")
|
||||
_, err := v.Call(ctx, "inf-loop", infiniteLoopWASM, "loop_forever", nil, 10_000, env)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from infinite loop — should have been trapped by gas")
|
||||
}
|
||||
if !errors.Is(err, ErrOutOfGas) {
|
||||
t.Fatalf("expected ErrOutOfGas, got: %v", err)
|
||||
}
|
||||
if !errors.Is(err, blockchain.ErrTxFailed) {
|
||||
t.Fatalf("ErrOutOfGas must wrap ErrTxFailed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNameRegistry_GasIncludesLoopCost verifies that calling a contract method
|
||||
// that exercises a loop (bytes_equal in name_registry) produces a non-zero
|
||||
// gasUsed that is higher than the function-call minimum.
|
||||
func TestNameRegistry_GasIncludesLoopCost(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
|
||||
wasm := nameRegistryWASM(t)
|
||||
env := newMockEnv("alice_pubkey_hex")
|
||||
|
||||
gasUsed, err := v.Call(ctx, "reg-gas", wasm, "register", []byte(`["alice"]`), 1_000_000, env)
|
||||
if err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
// The name-registry register function calls bytes_equal in a loop.
|
||||
// With loop-header metering, gas > just function-call overhead.
|
||||
if gasUsed == 0 {
|
||||
t.Error("gasUsed == 0, expected > 0")
|
||||
}
|
||||
t.Logf("name_registry.register gasUsed = %d", gasUsed)
|
||||
}
|
||||
|
||||
// TestInstrument_CounterWASM verifies that the counter contract (no loops)
|
||||
// is instrumented without error and remains functionally identical.
|
||||
func TestInstrument_CounterWASM(t *testing.T) {
|
||||
original := counterWASM(t)
|
||||
instrumented, err := Instrument(original)
|
||||
if err != nil {
|
||||
t.Fatalf("Instrument counter: %v", err)
|
||||
}
|
||||
// Counter WASM has no loops — instrumented size may equal original.
|
||||
ctx := context.Background()
|
||||
v := NewVM(ctx)
|
||||
defer v.Close(ctx)
|
||||
if err := v.Validate(ctx, instrumented); err != nil {
|
||||
t.Fatalf("Validate instrumented counter: %v", err)
|
||||
}
|
||||
// Function still works after instrumentation.
|
||||
env := newMockEnv("alice")
|
||||
if _, err := v.Call(ctx, "ctr-instr", instrumented, "increment", nil, 1_000_000, env); err != nil {
|
||||
t.Fatalf("increment on instrumented counter: %v", err)
|
||||
}
|
||||
if readU64State(env, "counter") != 1 {
|
||||
t.Error("counter not incremented after instrumentation")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user