diff --git a/data/basics/testing/nearzero.go b/data/basics/testing/nearzero.go index 4620e77c72..6eb26eb04f 100644 --- a/data/basics/testing/nearzero.go +++ b/data/basics/testing/nearzero.go @@ -18,6 +18,7 @@ package testing import ( "reflect" + "slices" "testing" "time" ) @@ -35,7 +36,7 @@ func NearZeros(t *testing.T, sample any) []any { if typ.Kind() != reflect.Struct { t.Fatalf("NearZeros: sample must be a struct, got %s", typ.Kind()) } - paths := CollectPaths(typ, []int{}) + paths := collectPaths(typ, nil, nil) var results []any for _, path := range paths { inst := makeInstanceWithNonZeroField(typ, path) @@ -45,14 +46,22 @@ func NearZeros(t *testing.T, sample any) []any { } // CollectPaths walks over the struct type (recursively) and returns a slice of -// index paths. Each path points to exactly one (exported) sub-field. +// index paths. Each path points to exactly one (exported) sub-field. If the +// type supplied is recursive, the path terminates at the recursion point. func CollectPaths(typ reflect.Type, prefix []int) [][]int { + return collectPaths(typ, prefix, []reflect.Type{}) +} + +// collectPaths walks over the struct type (recursively) and returns a slice of +// index paths. Each path points to exactly one (exported) sub-field. +// It tracks types in the current path to avoid infinite loops on recursive types. +func collectPaths(typ reflect.Type, prefix []int, pathStack []reflect.Type) [][]int { var paths [][]int switch typ.Kind() { case reflect.Ptr, reflect.Slice, reflect.Array: // Look through container to the element - return CollectPaths(typ.Elem(), prefix) + return collectPaths(typ.Elem(), prefix, pathStack) case reflect.Map: // Record as a leaf because we will just make a single entry in the map @@ -64,13 +73,23 @@ func CollectPaths(typ reflect.Type, prefix []int) [][]int { return [][]int{prefix} } + // Check if this type is already in the path stack (cycle detection) + if slices.Contains(pathStack, typ) { + // We've encountered a cycle, treat this as a leaf + return [][]int{prefix} + } + + // Add this type to the path stack + // Clone to avoid sharing the underlying array across branches + newStack := append(slices.Clone(pathStack), typ) + for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) if !field.IsExported() { continue } newPath := append(append([]int(nil), prefix...), i) - subPaths := CollectPaths(field.Type, newPath) + subPaths := collectPaths(field.Type, newPath, newStack) // If recursion yielded deeper paths, use them if len(subPaths) > 0 { diff --git a/data/transactions/teal.go b/data/transactions/teal.go index 075466ba79..6f886d8680 100644 --- a/data/transactions/teal.go +++ b/data/transactions/teal.go @@ -58,17 +58,18 @@ func (ed EvalDelta) Equal(o EvalDelta) bool { return false } - // GlobalDeltas must be equal if !ed.GlobalDelta.Equal(o.GlobalDelta) { return false } - // Logs must be equal + if !slices.Equal(ed.SharedAccts, o.SharedAccts) { + return false + } + if !slices.Equal(ed.Logs, o.Logs) { return false } - // InnerTxns must be equal if len(ed.InnerTxns) != len(o.InnerTxns) { return false } diff --git a/data/transactions/transaction_test.go b/data/transactions/transaction_test.go index 289e5f018b..30e1bdc5f5 100644 --- a/data/transactions/transaction_test.go +++ b/data/transactions/transaction_test.go @@ -19,11 +19,13 @@ package transactions import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/algorand/go-algorand/config" "github.com/algorand/go-algorand/crypto" "github.com/algorand/go-algorand/data/basics" + basics_testing "github.com/algorand/go-algorand/data/basics/testing" "github.com/algorand/go-algorand/protocol" "github.com/algorand/go-algorand/test/partitiontest" ) @@ -76,6 +78,7 @@ func TestTransactionHash(t *testing.T) { func TestTransactionIDChanges(t *testing.T) { partitiontest.PartitionTest(t) + t.Parallel() txn := Transaction{ Type: "pay", @@ -115,3 +118,39 @@ func TestTransactionIDChanges(t *testing.T) { t.Errorf("txid does not depend on lastvalid") } } + +func TestApplyDataEquality(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + var empty ApplyData + for _, nz := range basics_testing.NearZeros(t, ApplyData{}) { + ad := nz.(ApplyData) + assert.False(t, ad.Equal(empty), "Equal() seems to be disregarding something %+v", ad) + } + +} + +func TestEvalDataEquality(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + var empty EvalDelta + for _, nz := range basics_testing.NearZeros(t, EvalDelta{}) { + ed := nz.(EvalDelta) + assert.False(t, ed.Equal(empty), "Equal() seems to be disregarding something %+v", ed) + } + +} + +func TestLogicSigEquality(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + var empty LogicSig + for _, nz := range basics_testing.NearZeros(t, LogicSig{}) { + ls := nz.(LogicSig) + assert.False(t, ls.Equal(&empty), "Equal() seems to be disregarding something %+v", ls) + } + +}