add chain id into sign bytes to avoid replay attack (#18)

This commit is contained in:
zjubfd 2020-07-09 15:46:37 +08:00 committed by GitHub
parent 8d48b0deb8
commit f4816ee8b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 18 deletions

@ -156,7 +156,7 @@ func isToSystemContract(to common.Address) bool {
} }
// ecrecover extracts the Ethereum account address from a signed header. // ecrecover extracts the Ethereum account address from a signed header.
func ecrecover(header *types.Header, sigCache *lru.ARCCache) (common.Address, error) { func ecrecover(header *types.Header, sigCache *lru.ARCCache, chainId *big.Int) (common.Address, error) {
// If the signature's already cached, return that // If the signature's already cached, return that
hash := header.Hash() hash := header.Hash()
if address, known := sigCache.Get(hash); known { if address, known := sigCache.Get(hash); known {
@ -169,7 +169,7 @@ func ecrecover(header *types.Header, sigCache *lru.ARCCache) (common.Address, er
signature := header.Extra[len(header.Extra)-extraSeal:] signature := header.Extra[len(header.Extra)-extraSeal:]
// Recover the public key and the Ethereum address // Recover the public key and the Ethereum address
pubkey, err := crypto.Ecrecover(SealHash(header).Bytes(), signature) pubkey, err := crypto.Ecrecover(SealHash(header, chainId).Bytes(), signature)
if err != nil { if err != nil {
return common.Address{}, err return common.Address{}, err
} }
@ -187,9 +187,9 @@ func ecrecover(header *types.Header, sigCache *lru.ARCCache) (common.Address, er
// Note, the method requires the extra data to be at least 65 bytes, otherwise it // Note, the method requires the extra data to be at least 65 bytes, otherwise it
// panics. This is done to avoid accidentally using both forms (signature present // panics. This is done to avoid accidentally using both forms (signature present
// or not), which could be abused to produce different hashes for the same header. // or not), which could be abused to produce different hashes for the same header.
func ParliaRLP(header *types.Header) []byte { func ParliaRLP(header *types.Header, chainId *big.Int) []byte {
b := new(bytes.Buffer) b := new(bytes.Buffer)
encodeSigHeader(b, header) encodeSigHeader(b, header, chainId)
return b.Bytes() return b.Bytes()
} }
@ -498,7 +498,7 @@ func (p *Parlia) snapshot(chain consensus.ChainReader, number uint64, hash commo
headers[i], headers[len(headers)-1-i] = headers[len(headers)-1-i], headers[i] headers[i], headers[len(headers)-1-i] = headers[len(headers)-1-i], headers[i]
} }
snap, err := snap.apply(headers, chain, parents) snap, err := snap.apply(headers, chain, parents, p.chainConfig.ChainID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -546,7 +546,7 @@ func (p *Parlia) verifySeal(chain consensus.ChainReader, header *types.Header, p
} }
// Resolve the authorization key and check against validators // Resolve the authorization key and check against validators
signer, err := ecrecover(header, p.signatures) signer, err := ecrecover(header, p.signatures, p.chainConfig.ChainID)
if err != nil { if err != nil {
return err return err
} }
@ -821,7 +821,7 @@ func (p *Parlia) Seal(chain consensus.ChainReader, block *types.Block, results c
log.Info("Sealing block with", "number", number, "delay", delay, "headerDifficulty", header.Difficulty, "val", val.Hex()) log.Info("Sealing block with", "number", number, "delay", delay, "headerDifficulty", header.Difficulty, "val", val.Hex())
// Sign all the things! // Sign all the things!
sig, err := signFn(accounts.Account{Address: val}, accounts.MimetypeParlia, ParliaRLP(header)) sig, err := signFn(accounts.Account{Address: val}, accounts.MimetypeParlia, ParliaRLP(header, p.chainConfig.ChainID))
if err != nil { if err != nil {
return err return err
} }
@ -839,7 +839,7 @@ func (p *Parlia) Seal(chain consensus.ChainReader, block *types.Block, results c
select { select {
case results <- block.WithSeal(header): case results <- block.WithSeal(header):
default: default:
log.Warn("Sealing result is not read by miner", "sealhash", SealHash(header)) log.Warn("Sealing result is not read by miner", "sealhash", SealHash(header, p.chainConfig.ChainID))
} }
}() }()
@ -869,7 +869,7 @@ func CalcDifficulty(snap *Snapshot, signer common.Address) *big.Int {
// SealHash returns the hash of a block prior to it being sealed. // SealHash returns the hash of a block prior to it being sealed.
func (p *Parlia) SealHash(header *types.Header) common.Hash { func (p *Parlia) SealHash(header *types.Header) common.Hash {
return SealHash(header) return SealHash(header, p.chainConfig.ChainID)
} }
// APIs implements consensus.Engine, returning the user facing RPC API to query snapshot. // APIs implements consensus.Engine, returning the user facing RPC API to query snapshot.
@ -1109,15 +1109,16 @@ func (p *Parlia) applyTransaction(
// =========================== utility function ========================== // =========================== utility function ==========================
// SealHash returns the hash of a block prior to it being sealed. // SealHash returns the hash of a block prior to it being sealed.
func SealHash(header *types.Header) (hash common.Hash) { func SealHash(header *types.Header, chainId *big.Int) (hash common.Hash) {
hasher := sha3.NewLegacyKeccak256() hasher := sha3.NewLegacyKeccak256()
encodeSigHeader(hasher, header) encodeSigHeader(hasher, header, chainId)
hasher.Sum(hash[:0]) hasher.Sum(hash[:0])
return hash return hash
} }
func encodeSigHeader(w io.Writer, header *types.Header) { func encodeSigHeader(w io.Writer, header *types.Header, chainId *big.Int) {
err := rlp.Encode(w, []interface{}{ err := rlp.Encode(w, []interface{}{
chainId,
header.ParentHash, header.ParentHash,
header.UncleHash, header.UncleHash,
header.Coinbase, header.Coinbase,

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"math/big"
"sort" "sort"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -123,7 +124,7 @@ func (s *Snapshot) copy() *Snapshot {
return cpy return cpy
} }
func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainReader, parents []*types.Header) (*Snapshot, error) { func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainReader, parents []*types.Header, chainId *big.Int) (*Snapshot, error) {
// Allow passing in no headers for cleaner code // Allow passing in no headers for cleaner code
if len(headers) == 0 { if len(headers) == 0 {
return s, nil return s, nil
@ -153,7 +154,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainReader, p
delete(snap.Recents, number-limit) delete(snap.Recents, number-limit)
} }
// Resolve the authorization key and check against signers // Resolve the authorization key and check against signers
validator, err := ecrecover(header, s.sigCache) validator, err := ecrecover(header, s.sigCache, chainId)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -285,7 +285,7 @@ func (api *SignerAPI) determineSignatureFormat(ctx context.Context, contentType
header.Extra = newExtra header.Extra = newExtra
} }
// Get back the rlp data, encoded by us // Get back the rlp data, encoded by us
sighash, parliaRlp, err := parliaHeaderHashAndRlp(header) sighash, parliaRlp, err := parliaHeaderHashAndRlp(header, api.chainID)
if err != nil { if err != nil {
return nil, useEthereumV, err return nil, useEthereumV, err
} }
@ -351,13 +351,13 @@ func cliqueHeaderHashAndRlp(header *types.Header) (hash, rlp []byte, err error)
return hash, rlp, err return hash, rlp, err
} }
func parliaHeaderHashAndRlp(header *types.Header) (hash, rlp []byte, err error) { func parliaHeaderHashAndRlp(header *types.Header, chainId *big.Int) (hash, rlp []byte, err error) {
if len(header.Extra) < 65 { if len(header.Extra) < 65 {
err = fmt.Errorf("clique header extradata too short, %d < 65", len(header.Extra)) err = fmt.Errorf("clique header extradata too short, %d < 65", len(header.Extra))
return return
} }
rlp = parlia.ParliaRLP(header) rlp = parlia.ParliaRLP(header, chainId)
hash = parlia.SealHash(header).Bytes() hash = parlia.SealHash(header, chainId).Bytes()
return hash, rlp, err return hash, rlp, err
} }