core/vm/runtime: simplified runtime calling mechanism

Implemented `runtime.Call` which uses - unlike Execute - the given state
for the execution and the address of the contract you wish to execute.
Unlike `Execute`, `Call` requires a config.
This commit is contained in:
Jeffrey Wilcke 2016-02-09 23:20:42 +01:00
parent ecc876cec0
commit 6cace73bea
3 changed files with 117 additions and 18 deletions

@ -41,6 +41,7 @@ type Config struct {
DisableJit bool // "disable" so it's enabled by default DisableJit bool // "disable" so it's enabled by default
Debug bool Debug bool
State *state.StateDB
GetHashFn func(n uint64) common.Hash GetHashFn func(n uint64) common.Hash
} }
@ -94,12 +95,14 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) {
vm.EnableJit = !cfg.DisableJit vm.EnableJit = !cfg.DisableJit
vm.Debug = cfg.Debug vm.Debug = cfg.Debug
if cfg.State == nil {
db, _ := ethdb.NewMemDatabase()
cfg.State, _ = state.New(common.Hash{}, db)
}
var ( var (
db, _ = ethdb.NewMemDatabase() vmenv = NewEnv(cfg, cfg.State)
statedb, _ = state.New(common.Hash{}, db) sender = cfg.State.CreateAccount(cfg.Origin)
vmenv = NewEnv(cfg, statedb) receiver = cfg.State.CreateAccount(common.StringToAddress("contract"))
sender = statedb.CreateAccount(cfg.Origin)
receiver = statedb.CreateAccount(common.StringToAddress("contract"))
) )
// set the receiver's (the executing contract) code for execution. // set the receiver's (the executing contract) code for execution.
receiver.SetCode(code) receiver.SetCode(code)
@ -117,5 +120,43 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) {
if cfg.Debug { if cfg.Debug {
vm.StdErrFormat(vmenv.StructLogs()) vm.StdErrFormat(vmenv.StructLogs())
} }
return ret, statedb, err return ret, cfg.State, err
}
// Call executes the code given by the contract's address. It will return the
// EVM's return value or an error if it failed.
//
// Call, unlike Execute, requires a config and also requires the State field to
// be set.
func Call(address common.Address, input []byte, cfg *Config) ([]byte, error) {
setDefaults(cfg)
// defer the call to setting back the original values
defer func(debug, forceJit, enableJit bool) {
vm.Debug = debug
vm.ForceJit = forceJit
vm.EnableJit = enableJit
}(vm.Debug, vm.ForceJit, vm.EnableJit)
vm.ForceJit = !cfg.DisableJit
vm.EnableJit = !cfg.DisableJit
vm.Debug = cfg.Debug
vmenv := NewEnv(cfg, cfg.State)
sender := cfg.State.GetOrNewStateObject(cfg.Origin)
// Call the code with the given configuration.
ret, err := vmenv.Call(
sender,
address,
input,
cfg.GasLimit,
cfg.GasPrice,
cfg.Value,
)
if cfg.Debug {
vm.StdErrFormat(vmenv.StructLogs())
}
return ret, err
} }

@ -17,12 +17,15 @@
package runtime package runtime
import ( import (
"math/big"
"strings" "strings"
"testing" "testing"
"github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
) )
func TestDefaults(t *testing.T) { func TestDefaults(t *testing.T) {
@ -71,6 +74,49 @@ func TestEnvironment(t *testing.T) {
}, nil, nil) }, nil, nil)
} }
func TestExecute(t *testing.T) {
ret, _, err := Execute([]byte{
byte(vm.PUSH1), 10,
byte(vm.PUSH1), 0,
byte(vm.MSTORE),
byte(vm.PUSH1), 32,
byte(vm.PUSH1), 0,
byte(vm.RETURN),
}, nil, nil)
if err != nil {
t.Fatal("didn't expect error", err)
}
num := common.BytesToBig(ret)
if num.Cmp(big.NewInt(10)) != 0 {
t.Error("Expected 10, got", num)
}
}
func TestCall(t *testing.T) {
db, _ := ethdb.NewMemDatabase()
state, _ := state.New(common.Hash{}, db)
address := common.HexToAddress("0x0a")
state.SetCode(address, []byte{
byte(vm.PUSH1), 10,
byte(vm.PUSH1), 0,
byte(vm.MSTORE),
byte(vm.PUSH1), 32,
byte(vm.PUSH1), 0,
byte(vm.RETURN),
})
ret, err := Call(address, nil, &Config{State: state})
if err != nil {
t.Fatal("didn't expect error", err)
}
num := common.BytesToBig(ret)
if num.Cmp(big.NewInt(10)) != 0 {
t.Error("Expected 10, got", num)
}
}
func TestRestoreDefaults(t *testing.T) { func TestRestoreDefaults(t *testing.T) {
Execute(nil, nil, &Config{Debug: true}) Execute(nil, nil, &Config{Debug: true})
if vm.ForceJit { if vm.ForceJit {

@ -25,6 +25,21 @@ import (
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
) )
// GetHashFn returns a function for which the VM env can query block hashes thru
// up to the limit defined by the Yellow Paper and uses the given block chain
// to query for information.
func GetHashFn(ref common.Hash, chain *BlockChain) func(n uint64) common.Hash {
return func(n uint64) common.Hash {
for block := chain.GetBlock(ref); block != nil; block = chain.GetBlock(block.ParentHash()) {
if block.NumberU64() == n {
return block.Hash()
}
}
return common.Hash{}
}
}
type VMEnv struct { type VMEnv struct {
state *state.StateDB state *state.StateDB
header *types.Header header *types.Header
@ -32,17 +47,20 @@ type VMEnv struct {
depth int depth int
chain *BlockChain chain *BlockChain
typ vm.Type typ vm.Type
getHashFn func(uint64) common.Hash
// structured logging // structured logging
logs []vm.StructLog logs []vm.StructLog
} }
func NewEnv(state *state.StateDB, chain *BlockChain, msg Message, header *types.Header) *VMEnv { func NewEnv(state *state.StateDB, chain *BlockChain, msg Message, header *types.Header) *VMEnv {
return &VMEnv{ return &VMEnv{
chain: chain, chain: chain,
state: state, state: state,
header: header, header: header,
msg: msg, msg: msg,
typ: vm.StdVmTy, typ: vm.StdVmTy,
getHashFn: GetHashFn(header.ParentHash, chain),
} }
} }
@ -59,13 +77,7 @@ func (self *VMEnv) SetDepth(i int) { self.depth = i }
func (self *VMEnv) VmType() vm.Type { return self.typ } func (self *VMEnv) VmType() vm.Type { return self.typ }
func (self *VMEnv) SetVmType(t vm.Type) { self.typ = t } func (self *VMEnv) SetVmType(t vm.Type) { self.typ = t }
func (self *VMEnv) GetHash(n uint64) common.Hash { func (self *VMEnv) GetHash(n uint64) common.Hash {
for block := self.chain.GetBlock(self.header.ParentHash); block != nil; block = self.chain.GetBlock(block.ParentHash()) { return self.getHashFn(n)
if block.NumberU64() == n {
return block.Hash()
}
}
return common.Hash{}
} }
func (self *VMEnv) AddLog(log *vm.Log) { func (self *VMEnv) AddLog(log *vm.Log) {