// Package vm — WASM bytecode instrumentation for instruction-level gas metering. // // Instrument rewrites a WASM binary so that every loop header calls the host // function env.gas_tick(). This guarantees that any infinite loop—or any loop // whose iteration count is proportional to attacker-controlled input—is bounded // by the transaction gas limit. // // Algorithm // // 1. Parse sections. // 2. Find or add the function type () → () in the type section. // 3. Append env.gas_tick as the last function import (so existing import // indices are unchanged). Record numOldFuncImports = M. // 4. Rewrite the export section: any function export with index ≥ M → +1. // 5. Rewrite the code section: // • "call N" where N ≥ M → "call N+1" (defined-function shift) // • "loop bt …" injects "call $gas_tick" immediately after "loop bt" // 6. Reassemble. // // The call to gas_tick is also metered by the FunctionListener (100 gas/call), // so each loop iteration costs at least 101 gas units total. // // If the binary already contains an env.gas_tick import the function returns // the original bytes unchanged (idempotent). package vm import ( "errors" "fmt" ) // ── section and opcode constants ───────────────────────────────────────────── const ( wasmMagic = "\x00asm" wasmVersion = "\x01\x00\x00\x00" secType = 1 secImport = 2 secExport = 7 secCode = 10 importFunc = 0x00 exportFunc = 0x00 opBlock = 0x02 opLoop = 0x03 opIf = 0x04 opElse = 0x05 opEnd = 0x0B opBr = 0x0C opBrIf = 0x0D opBrTable = 0x0E opCall = 0x10 opCallInd = 0x11 opLocalGet = 0x20 opLocalSet = 0x21 opLocalTee = 0x22 opGlobalGet = 0x23 opGlobalSet = 0x24 opTableGet = 0x25 opTableSet = 0x26 opI32Const = 0x41 opI64Const = 0x42 opF32Const = 0x43 opF64Const = 0x44 opMemSize = 0x3F opMemGrow = 0x40 opRefNull = 0xD0 opRefIsNull = 0xD1 opRefFunc = 0xD2 opSelectT = 0x1C opPrefixFC = 0xFC ) // ── LEB128 helpers ──────────────────────────────────────────────────────────── func readU32Leb(b []byte) (uint32, int) { var x uint32 var s uint for i, by := range b { if i == 5 { return 0, -1 } x |= uint32(by&0x7F) << s s += 7 if by&0x80 == 0 { return x, i + 1 } } return 0, -1 } func readI32Leb(b []byte) (int32, int) { var x int32 var s uint for i, by := range b { if i == 5 { return 0, -1 } x |= int32(by&0x7F) << s s += 7 if by&0x80 == 0 { if s < 32 && (by&0x40) != 0 { x |= ^0 << s } return x, i + 1 } } return 0, -1 } func readI64Leb(b []byte) (int64, int) { var x int64 var s uint for i, by := range b { if i == 10 { return 0, -1 } x |= int64(by&0x7F) << s s += 7 if by&0x80 == 0 { if s < 64 && (by&0x40) != 0 { x |= ^int64(0) << s } return x, i + 1 } } return 0, -1 } func appendU32Leb(out []byte, v uint32) []byte { for { b := byte(v & 0x7F) v >>= 7 if v != 0 { b |= 0x80 } out = append(out, b) if v == 0 { break } } return out } // skipU32Leb returns the number of bytes consumed by one unsigned LEB128. func skipU32Leb(b []byte) (int, error) { _, n := readU32Leb(b) if n <= 0 { return 0, errors.New("bad unsigned LEB128") } return n, nil } func skipI32Leb(b []byte) (int, error) { _, n := readI32Leb(b) if n <= 0 { return 0, errors.New("bad signed LEB128 i32") } return n, nil } func skipI64Leb(b []byte) (int, error) { _, n := readI64Leb(b) if n <= 0 { return 0, errors.New("bad signed LEB128 i64") } return n, nil } // ── WASM type for gas_tick: () → () ────────────────────────────────────────── // Encoded as: 0x60 (functype) 0x00 (0 params) 0x00 (0 results) var gasTickFuncType = []byte{0x60, 0x00, 0x00} // ── Top-level Instrument ────────────────────────────────────────────────────── // rawSection holds one parsed section. type rawSection struct { id byte data []byte } // Instrument returns a copy of wasm with env.gas_tick calls injected at every // loop header. Returns wasm unchanged if already instrumented or if wasm has // no code section. Returns an error only on malformed binaries. func Instrument(wasm []byte) ([]byte, error) { if len(wasm) < 8 || string(wasm[:4]) != wasmMagic || string(wasm[4:8]) != wasmVersion { return nil, errors.New("not a valid WASM binary") } off := 8 var sections []rawSection for off < len(wasm) { if off >= len(wasm) { break } id := wasm[off] off++ size, n := readU32Leb(wasm[off:]) if n <= 0 { return nil, errors.New("bad section size LEB128") } off += n end := off + int(size) if end > len(wasm) { return nil, fmt.Errorf("section %d size %d exceeds binary length", id, size) } sections = append(sections, rawSection{id: id, data: wasm[off:end]}) off = end } // Locate key sections. typeIdx, importIdx, exportIdx, codeIdx := -1, -1, -1, -1 for i, s := range sections { switch s.id { case secType: typeIdx = i case secImport: importIdx = i case secExport: exportIdx = i case secCode: codeIdx = i } } if codeIdx < 0 { return wasm, nil // nothing to instrument } // Idempotency: already instrumented? if importIdx >= 0 && containsGasTick(sections[importIdx].data) { return wasm, nil } // ── Step 1: find or add gas_tick type ───────────────────────────────────── var gasTickTypeIdx uint32 if typeIdx >= 0 { var err error sections[typeIdx].data, gasTickTypeIdx, err = ensureGasTickType(sections[typeIdx].data) if err != nil { return nil, fmt.Errorf("type section: %w", err) } } else { // Create a minimal type section containing only the gas_tick type. ts := appendU32Leb(nil, 1) // count = 1 ts = append(ts, gasTickFuncType...) sections = insertBefore(sections, secImport, rawSection{id: secType, data: ts}) // Recalculate section indices. typeIdx, importIdx, exportIdx, codeIdx = findSections(sections) gasTickTypeIdx = 0 } // ── Step 2: add gas_tick import, count old func imports ─────────────────── var numOldFuncImports uint32 var gasFnIdx uint32 if importIdx >= 0 { var err error sections[importIdx].data, numOldFuncImports, err = appendGasTickImport(sections[importIdx].data, gasTickTypeIdx) if err != nil { return nil, fmt.Errorf("import section: %w", err) } gasFnIdx = numOldFuncImports // gas_tick is the new last import } else { // Build import section with just gas_tick. is, err := buildImportSection(gasTickTypeIdx) if err != nil { return nil, err } sections = insertBefore(sections, secExport, rawSection{id: secImport, data: is}) typeIdx, importIdx, exportIdx, codeIdx = findSections(sections) numOldFuncImports = 0 gasFnIdx = 0 } // ── Step 3: adjust export section ───────────────────────────────────────── if exportIdx >= 0 { var err error sections[exportIdx].data, err = adjustExportFuncIndices(sections[exportIdx].data, numOldFuncImports) if err != nil { return nil, fmt.Errorf("export section: %w", err) } } // ── Step 4: rewrite code section ────────────────────────────────────────── var err error sections[codeIdx].data, err = rewriteCodeSection(sections[codeIdx].data, numOldFuncImports, gasFnIdx) if err != nil { return nil, fmt.Errorf("code section: %w", err) } // ── Reassemble ──────────────────────────────────────────────────────────── out := make([]byte, 0, len(wasm)+128) out = append(out, []byte(wasmMagic+wasmVersion)...) for _, s := range sections { out = append(out, s.id) out = appendU32Leb(out, uint32(len(s.data))) out = append(out, s.data...) } return out, nil } // ── type section helpers ────────────────────────────────────────────────────── // ensureGasTickType finds () → () in the type section or appends it. // Returns the (possibly rewritten) section bytes and the type index. func ensureGasTickType(data []byte) ([]byte, uint32, error) { off := 0 count, n := readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad type count") } off += n start := off for i := uint32(0); i < count; i++ { entryStart := off if off >= len(data) || data[off] != 0x60 { return nil, 0, fmt.Errorf("type %d: expected 0x60", i) } off++ // params pc, n := readU32Leb(data[off:]) if n <= 0 { return nil, 0, fmt.Errorf("type %d: bad param count", i) } off += n paramStart := off off += int(pc) // each valtype is 1 byte // results rc, n := readU32Leb(data[off:]) if n <= 0 { return nil, 0, fmt.Errorf("type %d: bad result count", i) } off += n off += int(rc) // Is this () → () ? if pc == 0 && rc == 0 { _ = paramStart _ = entryStart _ = start return data, i, nil // already present } } // Not found — append. newData := appendU32Leb(nil, count+1) newData = append(newData, data[start:]...) // all existing type entries newData = append(newData, gasTickFuncType...) return newData, count, nil } // ── import section helpers ──────────────────────────────────────────────────── // containsGasTick returns true if env.gas_tick is already imported. func containsGasTick(data []byte) bool { off := 0 count, n := readU32Leb(data[off:]) if n <= 0 { return false } off += n for i := uint32(0); i < count; i++ { // mod name ml, n := readU32Leb(data[off:]) if n <= 0 { return false } off += n mod := string(data[off : off+int(ml)]) off += int(ml) // field name fl, n := readU32Leb(data[off:]) if n <= 0 { return false } off += n field := string(data[off : off+int(fl)]) off += int(fl) // importdesc kind if off >= len(data) { return false } kind := data[off] off++ if mod == "env" && field == "gas_tick" { return true } switch kind { case 0x00: // function: type idx _, n = readU32Leb(data[off:]) if n <= 0 { return false } off += n case 0x01: // table: reftype + limits off++ // reftype lkind := data[off] off++ _, n = readU32Leb(data[off:]) if n <= 0 { return false } off += n if lkind == 1 { _, n = readU32Leb(data[off:]) if n <= 0 { return false } off += n } case 0x02: // memory: limits lkind := data[off] off++ _, n = readU32Leb(data[off:]) if n <= 0 { return false } off += n if lkind == 1 { _, n = readU32Leb(data[off:]) if n <= 0 { return false } off += n } case 0x03: // global: valtype + mutability off += 2 } } return false } // appendGasTickImport adds env.gas_tick as the last function import. // Returns (newData, numOldFuncImports, error). func appendGasTickImport(data []byte, typeIdx uint32) ([]byte, uint32, error) { off := 0 count, n := readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad import count") } off += n start := off var numFuncImports uint32 for i := uint32(0); i < count; i++ { // mod name ml, n := readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad import mod len") } off += n + int(ml) // field name fl, n := readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad import field len") } off += n + int(fl) // importdesc if off >= len(data) { return nil, 0, errors.New("truncated importdesc") } kind := data[off] off++ switch kind { case 0x00: // function numFuncImports++ _, n = readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad func import type idx") } off += n case 0x01: // table off++ // reftype lkind := data[off] off++ _, n = readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad table import limit") } off += n if lkind == 1 { _, n = readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad table import max") } off += n } case 0x02: // memory lkind := data[off] off++ _, n = readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad mem import limit") } off += n if lkind == 1 { _, n = readU32Leb(data[off:]) if n <= 0 { return nil, 0, errors.New("bad mem import max") } off += n } case 0x03: // global off += 2 } } // Build new import section: count+1, existing entries, then gas_tick entry. newData := appendU32Leb(nil, count+1) newData = append(newData, data[start:]...) // env.gas_tick entry newData = appendU32Leb(newData, 3) // "env" length newData = append(newData, "env"...) newData = appendU32Leb(newData, 8) // "gas_tick" length newData = append(newData, "gas_tick"...) newData = append(newData, importFunc) // kind = function newData = appendU32Leb(newData, typeIdx) return newData, numFuncImports, nil } // buildImportSection builds a section containing only env.gas_tick. func buildImportSection(typeIdx uint32) ([]byte, error) { var data []byte data = appendU32Leb(data, 1) // count = 1 data = appendU32Leb(data, 3) data = append(data, "env"...) data = appendU32Leb(data, 8) data = append(data, "gas_tick"...) data = append(data, importFunc) data = appendU32Leb(data, typeIdx) return data, nil } // ── export section adjustment ───────────────────────────────────────────────── // adjustExportFuncIndices increments function export indices ≥ threshold by 1. func adjustExportFuncIndices(data []byte, threshold uint32) ([]byte, error) { off := 0 count, n := readU32Leb(data[off:]) if n <= 0 { return nil, errors.New("bad export count") } off += n out := appendU32Leb(nil, count) for i := uint32(0); i < count; i++ { // name nl, n := readU32Leb(data[off:]) if n <= 0 { return nil, fmt.Errorf("export %d: bad name len", i) } off += n out = appendU32Leb(out, nl) out = append(out, data[off:off+int(nl)]...) off += int(nl) // kind if off >= len(data) { return nil, fmt.Errorf("export %d: truncated kind", i) } kind := data[off] off++ out = append(out, kind) // index idx, n := readU32Leb(data[off:]) if n <= 0 { return nil, fmt.Errorf("export %d: bad index", i) } off += n if kind == exportFunc && idx >= threshold { idx++ } out = appendU32Leb(out, idx) } return out, nil } // ── code section rewriting ──────────────────────────────────────────────────── func rewriteCodeSection(data []byte, numOldImports, gasFnIdx uint32) ([]byte, error) { off := 0 count, n := readU32Leb(data[off:]) if n <= 0 { return nil, errors.New("bad code section count") } off += n out := appendU32Leb(nil, count) for i := uint32(0); i < count; i++ { bodySize, n := readU32Leb(data[off:]) if n <= 0 { return nil, fmt.Errorf("code entry %d: bad body size", i) } off += n body := data[off : off+int(bodySize)] off += int(bodySize) newBody, err := rewriteFuncBody(body, numOldImports, gasFnIdx) if err != nil { return nil, fmt.Errorf("code entry %d: %w", i, err) } out = appendU32Leb(out, uint32(len(newBody))) out = append(out, newBody...) } return out, nil } func rewriteFuncBody(body []byte, numOldImports, gasFnIdx uint32) ([]byte, error) { off := 0 // local variable declarations localGroupCount, n := readU32Leb(body[off:]) if n <= 0 { return nil, errors.New("bad local group count") } var out []byte out = appendU32Leb(out, localGroupCount) off += n for i := uint32(0); i < localGroupCount; i++ { cnt, n := readU32Leb(body[off:]) if n <= 0 { return nil, fmt.Errorf("local group %d: bad count", i) } off += n if off >= len(body) { return nil, fmt.Errorf("local group %d: missing valtype", i) } out = appendU32Leb(out, cnt) out = append(out, body[off]) // valtype off++ } // instruction stream instrs, err := rewriteExpr(body[off:], numOldImports, gasFnIdx) if err != nil { return nil, err } out = append(out, instrs...) return out, nil } // rewriteExpr processes one expression (terminated by the matching end). func rewriteExpr(code []byte, numOldImports, gasFnIdx uint32) ([]byte, error) { off := 0 var out []byte depth := 1 // implicit outer scope; final end at depth==1 terminates expression for off < len(code) { op := code[off] switch op { case opBlock, opIf: off++ btLen, err := blocktypeLen(code[off:]) if err != nil { return nil, fmt.Errorf("block/if blocktype: %w", err) } out = append(out, op) out = append(out, code[off:off+btLen]...) off += btLen depth++ case opLoop: off++ btLen, err := blocktypeLen(code[off:]) if err != nil { return nil, fmt.Errorf("loop blocktype: %w", err) } out = append(out, op) out = append(out, code[off:off+btLen]...) off += btLen depth++ // Inject gas_tick call at loop header. out = append(out, opCall) out = appendU32Leb(out, gasFnIdx) case opElse: out = append(out, op) off++ case opEnd: out = append(out, op) off++ depth-- if depth == 0 { return out, nil } case opCall: off++ idx, n := readU32Leb(code[off:]) if n <= 0 { return nil, errors.New("call: bad function index LEB128") } off += n if idx >= numOldImports { idx++ // shift defined-function index } out = append(out, opCall) out = appendU32Leb(out, idx) case opCallInd: off++ typeIdx, n := readU32Leb(code[off:]) if n <= 0 { return nil, errors.New("call_indirect: bad type index") } off += n tableIdx, n := readU32Leb(code[off:]) if n <= 0 { return nil, errors.New("call_indirect: bad table index") } off += n out = append(out, opCallInd) out = appendU32Leb(out, typeIdx) out = appendU32Leb(out, tableIdx) case opRefFunc: off++ idx, n := readU32Leb(code[off:]) if n <= 0 { return nil, errors.New("ref.func: bad index") } off += n if idx >= numOldImports { idx++ } out = append(out, opRefFunc) out = appendU32Leb(out, idx) default: // Copy instruction verbatim; instrLen handles all remaining opcodes. ilen, err := instrLen(code[off:]) if err != nil { return nil, fmt.Errorf("at offset %d: %w", off, err) } out = append(out, code[off:off+ilen]...) off += ilen } } return nil, errors.New("expression missing terminating end") } // blocktypeLen returns the byte length of a blocktype immediate. // Blocktypes are SLEB128-encoded: value types and void are single negative bytes; // type-index blocktypes are non-negative LEB128 (multi-byte for indices ≥ 64). func blocktypeLen(b []byte) (int, error) { _, n := readI32Leb(b) if n <= 0 { return 0, errors.New("bad blocktype encoding") } return n, nil } // instrLen returns the total byte length (opcode + immediates) of the instruction // at code[0]. It handles all opcodes EXCEPT block/loop/if/else/end/call/ // call_indirect/ref.func — those are handled directly in rewriteExpr. func instrLen(code []byte) (int, error) { if len(code) == 0 { return 0, errors.New("empty instruction stream") } op := code[0] switch { // 1-byte, no immediates — covers all integer/float arithmetic, comparisons, // conversion, drop, select, return, unreachable, nop, else, end, ref.is_null. case op == 0x00 || op == 0x01 || op == 0x0F || op == 0x1A || op == 0x1B || op == opRefIsNull || (op >= 0x45 && op <= 0xC4): return 1, nil // br, br_if, local.get/set/tee, global.get/set, table.get/set — one u32 LEB128 case op == opBr || op == opBrIf || op == opLocalGet || op == opLocalSet || op == opLocalTee || op == opGlobalGet || op == opGlobalSet || op == opTableGet || op == opTableSet: n, err := skipU32Leb(code[1:]) return 1 + n, err // br_table: u32 count N, then N+1 u32 values case op == opBrTable: off := 1 cnt, n := readU32Leb(code[off:]) if n <= 0 { return 0, errors.New("br_table: bad count") } off += n for i := uint32(0); i <= cnt; i++ { n, err := skipU32Leb(code[off:]) if err != nil { return 0, fmt.Errorf("br_table target %d: %w", i, err) } off += n } return off, nil // memory load/store (0x28–0x3E): align u32 + offset u32 case op >= 0x28 && op <= 0x3E: off := 1 n, err := skipU32Leb(code[off:]) if err != nil { return 0, fmt.Errorf("mem instr 0x%02X align: %w", op, err) } off += n n, err = skipU32Leb(code[off:]) if err != nil { return 0, fmt.Errorf("mem instr 0x%02X offset: %w", op, err) } return 1 + (off - 1 + n), nil // memory.size (0x3F), memory.grow (0x40): reserved byte 0x00 case op == opMemSize || op == opMemGrow: return 2, nil // i32.const: signed LEB128 i32 case op == opI32Const: n, err := skipI32Leb(code[1:]) return 1 + n, err // i64.const: signed LEB128 i64 case op == opI64Const: n, err := skipI64Leb(code[1:]) return 1 + n, err // f32.const: 4 raw bytes case op == opF32Const: if len(code) < 5 { return 0, errors.New("f32.const: truncated") } return 5, nil // f64.const: 8 raw bytes case op == opF64Const: if len(code) < 9 { return 0, errors.New("f64.const: truncated") } return 9, nil // ref.null: 1 byte reftype case op == opRefNull: return 2, nil // select with types (0x1C): u32 count + that many valtypes (each 1 byte) case op == opSelectT: cnt, n := readU32Leb(code[1:]) if n <= 0 { return 0, errors.New("select t: bad count") } return 1 + n + int(cnt), nil // 0xFC prefix — saturating truncation, table/memory bulk ops case op == opPrefixFC: sub, n := readU32Leb(code[1:]) if n <= 0 { return 0, errors.New("0xFC: bad sub-opcode") } off := 1 + n // sub-ops that have additional immediates switch sub { case 8: // memory.init: data_idx u32, mem idx u8 (0x00) n2, err := skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 off++ // reserved memory index byte case 9: // data.drop: data_idx u32 n2, err := skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 case 10: // memory.copy: two reserved 0x00 bytes off += 2 case 11: // memory.fill: reserved 0x00 off++ case 12: // table.init: elem_idx u32, table_idx u32 n2, err := skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 n2, err = skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 case 13: // elem.drop: elem_idx u32 n2, err := skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 case 14: // table.copy: dst_idx u32, src_idx u32 n2, err := skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 n2, err = skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 case 15, 16: // table.grow, table.size: table_idx u32 n2, err := skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 case 17: // table.fill: table_idx u32 n2, err := skipU32Leb(code[off:]) if err != nil { return 0, err } off += n2 // 0–7: i32/i64 saturating truncation — no additional immediates } return off, nil default: return 0, fmt.Errorf("unknown opcode 0x%02X", op) } } // ── section list helpers ────────────────────────────────────────────────────── func findSections(ss []rawSection) (typeIdx, importIdx, exportIdx, codeIdx int) { typeIdx, importIdx, exportIdx, codeIdx = -1, -1, -1, -1 for i, s := range ss { switch s.id { case secType: typeIdx = i case secImport: importIdx = i case secExport: exportIdx = i case secCode: codeIdx = i } } return } // insertBefore inserts ns before the first section with the given id, // or at the end if no such section exists. func insertBefore(ss []rawSection, id byte, ns rawSection) []rawSection { for i, s := range ss { if s.id == id { result := make([]rawSection, 0, len(ss)+1) result = append(result, ss[:i]...) result = append(result, ns) result = append(result, ss[i:]...) return result } } return append(ss, ns) }