diff --git a/ethcoder/typed_data.go b/ethcoder/typed_data.go index 6da4f664..8efa9558 100644 --- a/ethcoder/typed_data.go +++ b/ethcoder/typed_data.go @@ -22,6 +22,38 @@ type TypedData struct { type TypedDataTypes map[string][]TypedDataArgument +// ValidateTypeGraph checks the type graph for cycles. A cycle would cause +// infinite recursion in EncodeType/encodeValue, leading to an unrecoverable +// stack overflow. This must be called before any recursive type traversal. +func (t TypedDataTypes) ValidateTypeGraph() error { + for typeName := range t { + if err := t.walkTypeGraph(typeName, make(map[string]bool)); err != nil { + return err + } + } + return nil +} + +func (t TypedDataTypes) walkTypeGraph(current string, visiting map[string]bool) error { + if visiting[current] { + return fmt.Errorf("cycle detected in type graph at %q", current) + } + visiting[current] = true + defer delete(visiting, current) + for _, field := range t[current] { + baseType := field.Type + if i := strings.Index(baseType, "["); i > 0 { + baseType = baseType[:i] + } + if _, ok := t[baseType]; ok { + if err := t.walkTypeGraph(baseType, visiting); err != nil { + return err + } + } + } + return nil +} + func (t TypedDataTypes) EncodeType(primaryType string) (string, error) { args, ok := t[primaryType] if !ok { @@ -231,6 +263,10 @@ func (t *TypedData) encodeValue(typ string, value interface{}) ([]byte, error) { // * the digest is the hash of the fully encoded EIP712 message // * the encoded message is the fully encoded EIP712 message (0x1901 + domain + hashStruct(message)) func (t *TypedData) Encode() ([]byte, []byte, error) { + if err := t.Types.ValidateTypeGraph(); err != nil { + return nil, nil, err + } + EIP191_HEADER := "0x1901" // EIP191 for typed data eip191Header, err := HexDecode(EIP191_HEADER) if err != nil { diff --git a/ethcoder/typed_data_json.go b/ethcoder/typed_data_json.go index 79db176d..c9ade308 100644 --- a/ethcoder/typed_data_json.go +++ b/ethcoder/typed_data_json.go @@ -202,6 +202,11 @@ func (t *TypedData) UnmarshalJSON(data []byte) error { domain.ChainID = chainID } + // Validate the type graph for cycles before any recursive traversal + if err := raw.Types.ValidateTypeGraph(); err != nil { + return err + } + // Decode the raw message into Go runtime types message, err := typedDataDecodeRawMessageMap(raw.Types.Map(), raw.PrimaryType, raw.Message) if err != nil { diff --git a/ethcoder/typed_data_test.go b/ethcoder/typed_data_test.go index 8ed67d94..8bc56d16 100644 --- a/ethcoder/typed_data_test.go +++ b/ethcoder/typed_data_test.go @@ -3,6 +3,7 @@ package ethcoder_test import ( "encoding/json" "math/big" + "strings" "testing" "github.com/0xsequence/ethkit/ethcoder" @@ -731,3 +732,109 @@ func TestTypedDataFromJSONPart6(t *testing.T) { require.NoError(t, err) require.Equal(t, digest, digest2) } + +func TestTypedDataCycleDetection(t *testing.T) { + t.Run("simple cycle A->B->A", func(t *testing.T) { + types := ethcoder.TypedDataTypes{ + "EIP712Domain": {}, + "A": {{Name: "b", Type: "B"}}, + "B": {{Name: "a", Type: "A"}}, + } + err := types.ValidateTypeGraph() + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "cycle detected")) + }) + + t.Run("self-referencing type A->A", func(t *testing.T) { + types := ethcoder.TypedDataTypes{ + "EIP712Domain": {}, + "A": {{Name: "self", Type: "A"}}, + } + err := types.ValidateTypeGraph() + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "cycle detected")) + }) + + t.Run("longer cycle A->B->C->A", func(t *testing.T) { + types := ethcoder.TypedDataTypes{ + "EIP712Domain": {}, + "A": {{Name: "b", Type: "B"}}, + "B": {{Name: "c", Type: "C"}}, + "C": {{Name: "a", Type: "A"}}, + } + err := types.ValidateTypeGraph() + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "cycle detected")) + }) + + t.Run("cycle through array type A->B[]->A", func(t *testing.T) { + types := ethcoder.TypedDataTypes{ + "EIP712Domain": {}, + "A": {{Name: "bs", Type: "B[]"}}, + "B": {{Name: "a", Type: "A"}}, + } + err := types.ValidateTypeGraph() + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "cycle detected")) + }) + + t.Run("valid DAG with diamond shape", func(t *testing.T) { + types := ethcoder.TypedDataTypes{ + "EIP712Domain": {}, + "A": {{Name: "b", Type: "B"}, {Name: "c", Type: "C"}}, + "B": {{Name: "d", Type: "D"}}, + "C": {{Name: "d", Type: "D"}}, + "D": {{Name: "value", Type: "uint256"}}, + } + err := types.ValidateTypeGraph() + require.NoError(t, err) + }) + + t.Run("valid simple types no cycle", func(t *testing.T) { + types := ethcoder.TypedDataTypes{ + "EIP712Domain": {}, + "Person": { + {Name: "name", Type: "string"}, + {Name: "wallet", Type: "address"}, + }, + } + err := types.ValidateTypeGraph() + require.NoError(t, err) + }) + + t.Run("cycle rejected during JSON unmarshal", func(t *testing.T) { + typedDataJson := `{ + "types": { + "EIP712Domain": [], + "A": [{"name": "b", "type": "B"}], + "B": [{"name": "a", "type": "A"}] + }, + "primaryType": "A", + "domain": {}, + "message": {"b": {"a": {}}} + }` + _, err := ethcoder.TypedDataFromJSON(typedDataJson) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "cycle detected")) + }) + + t.Run("cycle rejected during Encode", func(t *testing.T) { + typedData := ðcoder.TypedData{ + Types: ethcoder.TypedDataTypes{ + "EIP712Domain": {}, + "A": {{Name: "b", Type: "B"}}, + "B": {{Name: "a", Type: "A"}}, + }, + PrimaryType: "A", + Domain: ethcoder.TypedDataDomain{}, + Message: map[string]interface{}{"b": map[string]interface{}{"a": map[string]interface{}{}}}, + } + _, _, err := typedData.Encode() + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "cycle detected")) + + _, err = typedData.EncodeDigest() + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "cycle detected")) + }) +}