diff --git a/mempool/txgraph/README.md b/mempool/txgraph/README.md new file mode 100644 index 0000000000..9ce587dc98 --- /dev/null +++ b/mempool/txgraph/README.md @@ -0,0 +1,356 @@ +# txgraph: Transaction Graph for Bitcoin Mempool + +The `txgraph` package provides a high-performance transaction graph data +structure for tracking relationships between unconfirmed Bitcoin transactions. +It enables efficient ancestor/descendant queries, transaction package +identification, and cluster analysis—all critical for modern mempool policies +including TRUC (v3 transactions), CPFP (Child-Pays-For-Parent), and ephemeral +dust handling. + +## Why txgraph? + +Bitcoin's mempool needs to understand transaction dependencies to make +intelligent relay and mining decisions. When a child transaction spends outputs +from an unconfirmed parent, the mempool must: + +- **Enforce policy limits**: Ancestor/descendant count and size restrictions +(BIP 125) + +- **Enable package relay**: Validate and relay transaction groups atomically + +- **Calculate effective fee rates**: Consider CPFP when prioritizing transactions + +- **Detect conflicts**: Identify transactions that spend the same outputs + +- **Support RBF**: Compute incentive compatibility for replacements + +The `txgraph` package provides the foundational graph structure that makes +these operations efficient, replacing O(n) linear scans with O(1) lookups and +cached computations. + +## Core Concepts + +### Transaction Graph + +A **transaction graph** is a directed acyclic graph (DAG) where: + +- **Nodes** represent unconfirmed transactions in the mempool + +- **Edges** represent spend relationships (parent → child) + +- An edge from tx A to tx B means tx B spends an output from tx A + +The graph structure enables efficient traversal for ancestor/descendant queries +without repeatedly scanning the entire mempool. 
+ +### Clusters + +A **cluster** is a connected component in the transaction graph—a set of +transactions that are related through spend relationships. Clusters are +important for: + +- **RBF validation**: Replacement transactions must improve the fee of the +entire cluster + +- **Mining optimization**: Miners can evaluate clusters as atomic units + +- **Eviction policy**: When the mempool is full, low-fee clusters are +candidates for removal + +### Transaction Packages + +A **package** is a specific type of transaction group identified by structure +and validation rules: + +- **1P1C (One Parent, One Child)**: Simple CPFP pattern with exactly one parent +and one child + +- **TRUC (v3)**: BIP 431 packages with topology restrictions to prevent pinning + +- **Ephemeral**: Packages containing transactions with dust outputs that must +be spent + +- **Standard**: General connected transaction groups + +Packages enable package-aware relay policies and optimized block template +construction. + +## Installation + +```bash +go get github.com/btcsuite/btcd/mempool/txgraph +``` + +## Quick Start + +### Example 1: Building a Transaction Graph + +```go +package main + +import ( + "fmt" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/mempool/txgraph" + "github.com/btcsuite/btcd/wire" +) + +func main() { + // Create a new transaction graph with default configuration. + graph := txgraph.New(txgraph.DefaultConfig()) + + // Create a parent transaction (typically from P2P relay). + parentTx := createTransaction() + parentDesc := &txgraph.TxDesc{ + TxHash: *parentTx.Hash(), + VirtualSize: 200, + Fee: 1000, + FeePerKB: 5000, + } + + // Add parent to the graph. + if err := graph.AddTransaction(parentTx, parentDesc); err != nil { + panic(err) + } + + // Create a child that spends from the parent. 
+ childTx := createChildTransaction(parentTx) // Spends parent's output + childDesc := &txgraph.TxDesc{ + TxHash: *childTx.Hash(), + VirtualSize: 150, + Fee: 2000, + FeePerKB: 13333, + } + + // Add child to the graph - edges are created automatically. + if err := graph.AddTransaction(childTx, childDesc); err != nil { + panic(err) + } + + // Query ancestors (returns parent). + ancestors := graph.GetAncestors(*childTx.Hash(), -1) // -1 = unlimited depth + fmt.Printf("Child has %d ancestor(s)\n", len(ancestors)) + + // Query descendants (returns child). + descendants := graph.GetDescendants(*parentTx.Hash(), -1) + fmt.Printf("Parent has %d descendant(s)\n", len(descendants)) +} +``` + +### Example 2: Iterating Over Clusters + +```go +package main + +import ( + "fmt" + "github.com/btcsuite/btcd/mempool/txgraph" +) + +func analyzeMempool(graph *txgraph.TxGraph) { + // Iterate over all connected components (clusters) in the mempool. + for cluster := range graph.IterateClusters() { + fmt.Printf("Cluster %d:\n", cluster.ID) + fmt.Printf(" Transactions: %d\n", cluster.Size) + fmt.Printf(" Total fees: %d sats\n", cluster.TotalFees) + fmt.Printf(" Total size: %d vbytes\n", cluster.TotalVSize) + + // Calculate cluster fee rate. + if cluster.TotalVSize > 0 { + feeRate := (cluster.TotalFees * 1000) / cluster.TotalVSize + fmt.Printf(" Fee rate: %d sat/vB\n", feeRate) + } + + // Identify entry points (root transactions with no parents). + fmt.Printf(" Roots: %d\n", len(cluster.Roots)) + + // Identify leaf transactions (no children, candidates for eviction). + fmt.Printf(" Leaves: %d\n", len(cluster.Leaves)) + } +} +``` + +### Example 3: Package Identification and Validation + +```go +package main + +import ( + "fmt" + "github.com/btcsuite/btcd/mempool/txgraph" +) + +func validatePackages(graph *txgraph.TxGraph, analyzer txgraph.PackageAnalyzer) { + // Identify all packages in the graph. 
+ packages, err := graph.IdentifyPackages() + if err != nil { + panic(err) + } + + for _, pkg := range packages { + fmt.Printf("Package %s (type: %v):\n", + pkg.ID.Hash.String()[:8], pkg.Type) + + // Check package properties. + fmt.Printf(" Transactions: %d\n", len(pkg.Transactions)) + fmt.Printf(" Total fees: %d sats\n", pkg.TotalFees) + fmt.Printf(" Package fee rate: %d sat/vB\n", pkg.FeeRate) + + // Validate package according to its type-specific rules. + if err := graph.ValidatePackage(pkg); err != nil { + fmt.Printf(" ❌ Validation failed: %v\n", err) + continue + } + fmt.Printf(" ✅ Valid package\n") + + // Check topology properties. + topo := pkg.Topology + if topo.IsLinear { + fmt.Printf(" Structure: Linear chain (depth %d)\n", topo.MaxDepth) + } else if topo.IsTree { + fmt.Printf(" Structure: Tree (max width %d)\n", topo.MaxWidth) + } else { + fmt.Printf(" Structure: Complex DAG\n") + } + } +} +``` + +### Example 4: Advanced Iteration with Options + +```go +package main + +import ( + "fmt" + "github.com/btcsuite/btcd/mempool/txgraph" +) + +func findHighFeeTransactions(graph *txgraph.TxGraph, minFeeRate int64) { + // Use functional options to configure iteration. + highFeeFilter := func(node *txgraph.TxGraphNode) bool { + return node.TxDesc.FeePerKB >= minFeeRate + } + + // Iterate in fee rate order (high to low), filtered by minimum fee rate. + for node := range graph.Iterate( + txgraph.WithOrder(txgraph.TraversalFeeRate), + txgraph.WithFilter(highFeeFilter), + ) { + fmt.Printf("High-fee tx: %s (%d sat/kB)\n", + node.TxHash.String()[:8], + node.TxDesc.FeePerKB, + ) + } +} + +func getAncestorsWithLimit(graph *txgraph.TxGraph, txHash chainhash.Hash) { + // Iterate ancestors up to depth 3 in BFS order. 
+ for ancestor := range graph.Iterate( + txgraph.WithOrder(txgraph.TraversalBFS), + txgraph.WithStartNode(&txHash), + txgraph.WithDirection(txgraph.DirectionBackward), + txgraph.WithMaxDepth(3), + txgraph.WithIncludeStart(false), // Exclude starting transaction + ) { + fmt.Printf("Ancestor: %s\n", ancestor.TxHash.String()[:8]) + } +} +``` + +## Common Patterns + +### Building Graphs Incrementally + +Add transactions as they arrive from P2P relay. The graph automatically creates +edges when it detects spend relationships: + +```go +// As each transaction arrives... +if err := graph.AddTransaction(tx, txDesc); err != nil { + // Handle error (duplicate, invalid, etc.) +} +// Edges to existing parents are created automatically. +``` + +### Package-Aware Relay + +Use package identification to enable package relay policies: + +```go +packages, _ := graph.IdentifyPackages() +for _, pkg := range packages { + if pkg.Type == txgraph.PackageTypeTRUC { + // Apply TRUC-specific relay rules + } +} +``` + +### Ancestor/Descendant Limits + +Enforce BIP 125 policy limits before accepting transactions: + +```go +const maxAncestorCount = 25 +const maxDescendantCount = 25 + +ancestors := graph.GetAncestors(txHash, -1) +if len(ancestors) > maxAncestorCount { + return errors.New("exceeds ancestor limit") +} + +descendants := graph.GetDescendants(txHash, -1) +if len(descendants) > maxDescendantCount { + return errors.New("exceeds descendant limit") +} +``` + +### Efficient Graph Cleanup + +Remove confirmed transactions efficiently when blocks arrive: + +```go +for _, tx := range block.Transactions() { + // Cascade removal: removes tx and all descendants + graph.RemoveTransaction(*tx.Hash()) +} +``` + +## Package Analyzer Interface + +The `PackageAnalyzer` interface abstracts protocol-specific validation logic, +enabling testing and future upgrades without modifying the core graph: + +```go +type PackageAnalyzer interface { + IsTRUCTransaction(tx *wire.MsgTx) bool + HasEphemeralDust(tx 
*wire.MsgTx) bool + IsZeroFee(desc *TxDesc) bool + ValidateTRUCPackage(nodes []*TxGraphNode) bool + ValidateEphemeralPackage(nodes []*TxGraphNode) bool + AnalyzePackageType(nodes []*TxGraphNode) PackageType +} +``` + +Implement this interface to customize package validation for your use case or +to add new package types. + +## Performance Characteristics + +- **Transaction lookup**: O(1) via hash map +- **Add transaction**: O(1) for graph insertion + O(k) for k parent edges +- **Remove transaction**: O(1) for node + O(d) for d descendants (cascade) +- **Ancestor/descendant query**: O(a) or O(d) where a/d is count +- **Package identification**: O(n) where n is number of root nodes + +## Thread Safety + +All graph operations are thread-safe and protected by a read-write mutex. Read +operations (queries, iteration) can proceed concurrently, while write +operations (add, remove) acquire exclusive access. + +## Further Reading + +- **API Documentation**: Run `go doc github.com/btcsuite/btcd/mempool/txgraph` +- **BIP 431 (TRUC)**: https://github.com/bitcoin/bips/blob/master/bip-0431.mediawiki +- **BIP 125 (RBF)**: https://github.com/bitcoin/bips/blob/master/bip-0125.mediawiki diff --git a/mempool/txgraph/bench_test.go b/mempool/txgraph/bench_test.go new file mode 100644 index 0000000000..f736fd919c --- /dev/null +++ b/mempool/txgraph/bench_test.go @@ -0,0 +1,478 @@ +package txgraph + +import ( + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +// BenchmarkAddTransaction benchmarks adding transactions to the graph. +func BenchmarkAddTransaction(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Pre-create transactions to isolate graph operations from test data + // generation overhead in timing measurements. 
+ txs := make([]*wire.MsgTx, b.N) + descs := make([]*TxDesc, b.N) + + for i := 0; i < b.N; i++ { + tx, desc := createTestTx(nil, 1) + txs[i] = tx.MsgTx() + descs[i] = desc + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + btcTx := btcutil.NewTx(txs[i]) + g.AddTransaction(btcTx, descs[i]) + } + + b.ReportMetric(float64(g.GetNodeCount()), "nodes") +} + +// BenchmarkAddTransactionWithEdges benchmarks adding connected transactions. +// This measures the overhead of automatic edge detection when transactions +// form dependency chains, which is critical for understanding the cost of +// maintaining parent-child relationships in the mempool graph. +func BenchmarkAddTransactionWithEdges(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create initial parent to enable edge creation for subsequent + // transactions. This setup allows us to measure the overhead of + // automatic edge detection as each new transaction references the + // previous one. + parent, parentDesc := createTestTx(nil, 1) + g.AddTransaction(parent, parentDesc) + + b.ResetTimer() + + prevHash := parent.Hash() + for i := 0; i < b.N; i++ { + tx, desc := createTestTx([]wire.OutPoint{ + {Hash: *prevHash, Index: 0}, + }, 1) + + g.AddTransaction(tx, desc) + prevHash = tx.Hash() + } + + b.ReportMetric(float64(g.GetNodeCount()), "nodes") + b.ReportMetric(float64(g.GetMetrics().EdgeCount), "edges") +} + +// BenchmarkRemoveTransaction benchmarks removing transactions. +// This measures the cost of graph cleanup operations, which occur when +// transactions confirm in blocks or are evicted from the mempool. +func BenchmarkRemoveTransaction(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Pre-populate the graph with transactions to isolate removal + // performance from insertion overhead. 
+ hashes := make([]*chainhash.Hash, b.N) + for i := 0; i < b.N; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + hashes[i] = tx.Hash() + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + g.RemoveTransaction(*hashes[i]) + } +} + +// BenchmarkGetAncestors benchmarks ancestor queries at various chain depths. +// This is critical for mempool policy enforcement, as ancestor limits are +// checked before accepting transactions (BIP 125). +func BenchmarkGetAncestors(b *testing.B) { + benchmarkSizes := []int{10, 100, 1000} + + for _, size := range benchmarkSizes { + b.Run(fmt.Sprintf("depth_%d", size), func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create a linear chain of transactions to test ancestor + // traversal at different depths. This simulates worst-case + // scenarios where transactions have deep dependency chains. + var lastHash *chainhash.Hash + for i := 0; i < size; i++ { + var inputs []wire.OutPoint + if lastHash != nil { + inputs = append(inputs, wire.OutPoint{ + Hash: *lastHash, + Index: 0, + }) + } + + tx, desc := createTestTx(inputs, 1) + g.AddTransaction(tx, desc) + lastHash = tx.Hash() + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + ancestors := g.GetAncestors(*lastHash, -1) + _ = ancestors + } + }) + } +} + +// BenchmarkGetDescendants benchmarks descendant queries at various depths. +// This is essential for RBF validation, where replacement transactions must +// account for all descendants of conflicting transactions. +func BenchmarkGetDescendants(b *testing.B) { + benchmarkSizes := []int{10, 100, 1000} + + for _, size := range benchmarkSizes { + b.Run(fmt.Sprintf("depth_%d", size), func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create a linear chain to test descendant traversal from the + // root transaction. This measures the cost of walking forward + // through dependency chains. 
+ var firstHash *chainhash.Hash + var lastHash *chainhash.Hash + + for i := 0; i < size; i++ { + var inputs []wire.OutPoint + if lastHash != nil { + inputs = append(inputs, wire.OutPoint{ + Hash: *lastHash, + Index: 0, + }) + } + + tx, desc := createTestTx(inputs, 1) + g.AddTransaction(tx, desc) + + if i == 0 { + firstHash = tx.Hash() + } + lastHash = tx.Hash() + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + descendants := g.GetDescendants(*firstHash, -1) + _ = descendants + } + }) + } +} + +// BenchmarkIterateDFS benchmarks depth-first iteration over the graph. +// DFS is used for dependency-aware traversal when processing transaction +// chains in topological order matters. +func BenchmarkIterateDFS(b *testing.B) { + graphSizes := []int{100, 1000, 5000} + + for _, size := range graphSizes { + b.Run(fmt.Sprintf("nodes_%d", size), func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create a graph with independent transactions to measure the + // cost of DFS traversal across disconnected components. + for i := 0; i < size; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + count := 0 + for range g.Iterate( + WithOrder(TraversalDFS), + ) { + count++ + } + } + }) + } +} + +// BenchmarkIterateBFS benchmarks breadth-first iteration over the graph. +// BFS is useful for level-order traversal when analyzing transaction packages +// layer by layer, such as identifying ancestor sets at specific depths. +func BenchmarkIterateBFS(b *testing.B) { + graphSizes := []int{100, 1000, 5000} + + for _, size := range graphSizes { + b.Run(fmt.Sprintf("nodes_%d", size), func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create a graph with independent transactions to measure BFS + // performance across multiple disconnected components. 
+ for i := 0; i < size; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + count := 0 + for range g.Iterate( + WithOrder(TraversalBFS), + ) { + count++ + } + } + }) + } +} + +// BenchmarkIterateTopological benchmarks topological iteration. +// Topological ordering is critical for block template construction and +// ensuring transactions are processed before their descendants. +func BenchmarkIterateTopological(b *testing.B) { + graphSizes := []int{100, 1000, 5000} + + for _, size := range graphSizes { + b.Run(fmt.Sprintf("nodes_%d", size), func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create a mix of chains and roots to test topological sort + // performance on realistic graph structures. Breaking chains + // every 10 transactions creates multiple entry points. + var lastHash *chainhash.Hash + for i := 0; i < size; i++ { + var inputs []wire.OutPoint + if lastHash != nil && i%10 != 0 { + inputs = append(inputs, wire.OutPoint{ + Hash: *lastHash, + Index: 0, + }) + } + + tx, desc := createTestTx(inputs, 1) + g.AddTransaction(tx, desc) + lastHash = tx.Hash() + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + count := 0 + for range g.Iterate( + WithOrder(TraversalTopological), + ) { + count++ + } + } + }) + } +} + +// BenchmarkIdentifyPackages benchmarks package identification across the graph. +// This is used for package relay policies and determining which transaction +// groups should be evaluated together for mining and validation. +func BenchmarkIdentifyPackages(b *testing.B) { + packageCounts := []int{10, 100, 500} + + for _, count := range packageCounts { + b.Run(fmt.Sprintf("packages_%d", count), func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create 1P1C (one parent, one child) packages, which are the + // most common pattern for CPFP. 
This tests the cost of scanning + // for package structures across many independent clusters. + for i := 0; i < count; i++ { + parent, parentDesc := createTestTx(nil, 1) + g.AddTransaction(parent, parentDesc) + + child, childDesc := createTestTx([]wire.OutPoint{ + {Hash: *parent.Hash(), Index: 0}, + }, 1) + g.AddTransaction(child, childDesc) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + packages, _ := g.IdentifyPackages() + _ = packages + } + }) + } +} + +// BenchmarkPackageCreation benchmarks creating packages from graph nodes. +// This measures the cost of constructing TxPackage objects with computed +// aggregate metrics (total fees, sizes, topology analysis). +func BenchmarkPackageCreation(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Create a simple parent-child pair to benchmark package construction + // overhead. The package creation involves metric aggregation and + // topology analysis. + parent, parentDesc := createTestTx(nil, 1) + child, childDesc := createTestTx([]wire.OutPoint{ + {Hash: *parent.Hash(), Index: 0}, + }, 1) + + g.AddTransaction(parent, parentDesc) + g.AddTransaction(child, childDesc) + + parentNode, _ := g.GetNode(*parent.Hash()) + childNode, _ := g.GetNode(*child.Hash()) + nodes := []*TxGraphNode{parentNode, childNode} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + pkg, _ := g.CreatePackage(nodes) + _ = pkg + } +} + +// BenchmarkClusterOperations benchmarks cluster management operations. +// Clusters are connected components in the graph, critical for RBF validation +// and mempool eviction policies. +func BenchmarkClusterOperations(b *testing.B) { + b.Run("cluster_creation", func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + b.ResetTimer() + + // Measure the cost of creating new independent clusters as + // unrelated transactions are added to the mempool. 
+ for i := 0; i < b.N; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + } + + b.ReportMetric(float64(g.GetClusterCount()), "clusters") + }) + + b.Run("cluster_merging", func(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Pre-create separate clusters that will be merged by transactions + // spending from multiple parents. This simulates the common pattern + // of consolidation transactions in the mempool. + parents := make([]*chainhash.Hash, 100) + for i := 0; i < 100; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + parents[i] = tx.Hash() + } + + b.ResetTimer() + + // Merge clusters by creating transactions that spend from multiple + // parents, measuring the cost of cluster union operations. + for i := 0; i < b.N; i++ { + idx1 := i % len(parents) + idx2 := (i + 1) % len(parents) + + inputs := []wire.OutPoint{ + {Hash: *parents[idx1], Index: 0}, + {Hash: *parents[idx2], Index: 0}, + } + + tx, desc := createTestTx(inputs, 1) + g.AddTransaction(tx, desc) + } + + b.ReportMetric(float64(g.GetClusterCount()), "final_clusters") + }) +} + +// BenchmarkGetMetrics benchmarks metric collection from the graph. +// Metrics are used for monitoring mempool health and making eviction decisions +// when the mempool reaches capacity limits. +func BenchmarkGetMetrics(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Populate the graph with transactions to measure the cost of collecting + // aggregate statistics across a realistically sized mempool graph. + for i := 0; i < 1000; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + metrics := g.GetMetrics() + _ = metrics + } +} + +// BenchmarkConcurrentOperations benchmarks concurrent graph access patterns. +// The graph uses RWMutex for thread safety, allowing concurrent reads while +// serializing writes. 
This benchmark measures contention and throughput under +// realistic mixed workloads. +func BenchmarkConcurrentOperations(b *testing.B) { + b.ReportAllocs() + + g := New(DefaultConfig()) + + // Pre-populate the graph to provide transactions for read operations. + // This establishes a baseline graph state before measuring concurrent + // access patterns. + hashes := make([]*chainhash.Hash, 1000) + for i := 0; i < 1000; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + hashes[i] = tx.Hash() + } + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + idx := i % len(hashes) + i++ + + // Mix of read and write operations to simulate realistic + // mempool access patterns: lookups, traversals, and insertions. + switch i % 4 { + case 0: + g.GetNode(*hashes[idx]) + case 1: + g.GetAncestors(*hashes[idx], 5) + case 2: + g.GetDescendants(*hashes[idx], 5) + case 3: + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + } + } + }) +} + diff --git a/mempool/txgraph/collections.go b/mempool/txgraph/collections.go new file mode 100644 index 0000000000..f85a3cdac0 --- /dev/null +++ b/mempool/txgraph/collections.go @@ -0,0 +1,269 @@ +package txgraph + +import ( + "container/heap" + "iter" +) + +// Stack implements a generic LIFO stack data structure with O(1) Push and Pop. +// The zero value is ready to use. +type Stack[T any] struct { + items []T +} + +// NewStack creates a new empty stack with optional initial capacity. +func NewStack[T any](capacity ...int) *Stack[T] { + cap := 0 + if len(capacity) > 0 { + cap = capacity[0] + } + return &Stack[T]{ + items: make([]T, 0, cap), + } +} + +// Push adds an item to the top of the stack. +func (s *Stack[T]) Push(item T) { + s.items = append(s.items, item) +} + +// Pop removes and returns the item at the top of the stack. +// Returns false if the stack is empty. 
+func (s *Stack[T]) Pop() (T, bool) { + if len(s.items) == 0 { + var zero T + return zero, false + } + idx := len(s.items) - 1 + item := s.items[idx] + s.items = s.items[:idx] + return item, true +} + +// Peek returns the item at the top of the stack without removing it. +// Returns false if the stack is empty. +func (s *Stack[T]) Peek() (T, bool) { + if len(s.items) == 0 { + var zero T + return zero, false + } + return s.items[len(s.items)-1], true +} + +// Len returns the number of items in the stack. +func (s *Stack[T]) Len() int { + return len(s.items) +} + +// IsEmpty returns true if the stack contains no items. +func (s *Stack[T]) IsEmpty() bool { + return len(s.items) == 0 +} + +// Clear removes all items from the stack. +func (s *Stack[T]) Clear() { + s.items = s.items[:0] +} + +// Iterate returns an iterator that yields items from top to bottom without +// modifying the stack. +func (s *Stack[T]) Iterate() iter.Seq[T] { + return func(yield func(T) bool) { + for i := len(s.items) - 1; i >= 0; i-- { + if !yield(s.items[i]) { + return + } + } + } +} + +// Queue implements a generic FIFO queue with amortized O(1) Enqueue and +// Dequeue operations. The zero value is ready to use. +type Queue[T any] struct { + items []T +} + +// NewQueue creates a new empty queue with optional initial capacity. +func NewQueue[T any](capacity ...int) *Queue[T] { + cap := 0 + if len(capacity) > 0 { + cap = capacity[0] + } + return &Queue[T]{ + items: make([]T, 0, cap), + } +} + +// Enqueue adds an item to the back of the queue. +func (q *Queue[T]) Enqueue(item T) { + q.items = append(q.items, item) +} + +// Dequeue removes and returns the item at the front of the queue. +// Returns false if the queue is empty. +func (q *Queue[T]) Dequeue() (T, bool) { + if len(q.items) == 0 { + var zero T + return zero, false + } + item := q.items[0] + q.items = q.items[1:] + return item, true +} + +// Peek returns the item at the front of the queue without removing it. 
+// Returns false if the queue is empty. +func (q *Queue[T]) Peek() (T, bool) { + if len(q.items) == 0 { + var zero T + return zero, false + } + return q.items[0], true +} + +// Len returns the number of items in the queue. +func (q *Queue[T]) Len() int { + return len(q.items) +} + +// IsEmpty returns true if the queue contains no items. +func (q *Queue[T]) IsEmpty() bool { + return len(q.items) == 0 +} + +// Clear removes all items from the queue. +func (q *Queue[T]) Clear() { + q.items = q.items[:0] +} + +// Iterate returns an iterator that yields items from front to back without +// modifying the queue. +func (q *Queue[T]) Iterate() iter.Seq[T] { + return func(yield func(T) bool) { + for _, item := range q.items { + if !yield(item) { + return + } + } + } +} + +// PriorityQueue implements a generic priority queue using container/heap +// ordered by a comparison function. The zero value is NOT ready to use; +// use NewPriorityQueue to create an instance. +type PriorityQueue[T any] struct { + impl *heapImpl[T] +} + +// NewPriorityQueue creates a new priority queue with the given comparison +// function where less(a, b) returns true if a has higher priority than b. +// For a max-heap use: func(a, b) { return a > b }. +func NewPriorityQueue[T any]( + less func(a, b T) bool, + capacity ...int, +) *PriorityQueue[T] { + + cap := 0 + if len(capacity) > 0 { + cap = capacity[0] + } + return &PriorityQueue[T]{ + impl: &heapImpl[T]{ + items: make([]T, 0, cap), + less: less, + }, + } +} + +// Push adds an item to the priority queue. +func (pq *PriorityQueue[T]) Push(item T) { + heap.Push(pq.impl, item) +} + +// Pop removes and returns the highest priority item from the queue. +// Returns false if the queue is empty. +func (pq *PriorityQueue[T]) Pop() (T, bool) { + if pq.impl.Len() == 0 { + var zero T + return zero, false + } + return heap.Pop(pq.impl).(T), true +} + +// Peek returns the highest priority item without removing it. +// Returns false if the queue is empty. 
+func (pq *PriorityQueue[T]) Peek() (T, bool) { + if pq.impl.Len() == 0 { + var zero T + return zero, false + } + return pq.impl.items[0], true +} + +// Len returns the number of items in the priority queue. +func (pq *PriorityQueue[T]) Len() int { + return pq.impl.Len() +} + +// IsEmpty returns true if the priority queue contains no items. +func (pq *PriorityQueue[T]) IsEmpty() bool { + return pq.impl.Len() == 0 +} + +// Clear removes all items from the priority queue. +func (pq *PriorityQueue[T]) Clear() { + pq.impl.items = pq.impl.items[:0] +} + +// Iterate returns an iterator that yields items in priority order by +// creating a temporary copy to avoid modifying the original queue. +func (pq *PriorityQueue[T]) Iterate() iter.Seq[T] { + return func(yield func(T) bool) { + tmpItems := make([]T, len(pq.impl.items)) + copy(tmpItems, pq.impl.items) + tmp := &PriorityQueue[T]{ + impl: &heapImpl[T]{ + items: tmpItems, + less: pq.impl.less, + }, + } + + for !tmp.IsEmpty() { + item, _ := tmp.Pop() + if !yield(item) { + return + } + } + } +} + +// heapImpl implements heap.Interface to integrate with container/heap. 
+type heapImpl[T any] struct { + items []T + less func(a, b T) bool +} + +func (h *heapImpl[T]) Len() int { + return len(h.items) +} + +func (h *heapImpl[T]) Less(i, j int) bool { + return h.less(h.items[i], h.items[j]) +} + +func (h *heapImpl[T]) Swap(i, j int) { + h.items[i], h.items[j] = h.items[j], h.items[i] +} + +func (h *heapImpl[T]) Push(x any) { + h.items = append(h.items, x.(T)) +} + +func (h *heapImpl[T]) Pop() any { + n := len(h.items) - 1 + item := h.items[n] + h.items = h.items[:n] + return item +} + +var _ heap.Interface = (*heapImpl[int])(nil) \ No newline at end of file diff --git a/mempool/txgraph/collections_test.go b/mempool/txgraph/collections_test.go new file mode 100644 index 0000000000..84ba0e232c --- /dev/null +++ b/mempool/txgraph/collections_test.go @@ -0,0 +1,501 @@ +package txgraph + +import ( + "testing" +) + +// TestStackBasicOperations verifies LIFO behavior and empty state handling. +func TestStackBasicOperations(t *testing.T) { + t.Parallel() + + stack := NewStack[int]() + + if !stack.IsEmpty() { + t.Error("New stack should be empty") + } + if stack.Len() != 0 { + t.Errorf("Expected length 0, got %d", stack.Len()) + } + + _, ok := stack.Pop() + if ok { + t.Error("Pop on empty stack should return false") + } + + _, ok = stack.Peek() + if ok { + t.Error("Peek on empty stack should return false") + } + + stack.Push(1) + stack.Push(2) + stack.Push(3) + + if stack.Len() != 3 { + t.Errorf("Expected length 3, got %d", stack.Len()) + } + if stack.IsEmpty() { + t.Error("Stack should not be empty") + } + + val, ok := stack.Peek() + if !ok || val != 3 { + t.Errorf("Expected Peek to return 3, got %d", val) + } + if stack.Len() != 3 { + t.Error("Peek should not modify stack size") + } + + val, ok = stack.Pop() + if !ok || val != 3 { + t.Errorf("Expected Pop to return 3, got %d", val) + } + val, ok = stack.Pop() + if !ok || val != 2 { + t.Errorf("Expected Pop to return 2, got %d", val) + } + val, ok = stack.Pop() + if !ok || val != 1 { + 
t.Errorf("Expected Pop to return 1, got %d", val) + } + + if !stack.IsEmpty() { + t.Error("Stack should be empty after all pops") + } +} + +// TestStackIterate verifies iteration yields items top to bottom without +// modifying the stack state. +func TestStackIterate(t *testing.T) { + t.Parallel() + + stack := NewStack[int]() + stack.Push(1) + stack.Push(2) + stack.Push(3) + + expected := []int{3, 2, 1} + idx := 0 + for val := range stack.Iterate() { + if val != expected[idx] { + t.Errorf( + "Expected %d at index %d, got %d", + expected[idx], + idx, + val, + ) + } + idx++ + } + + if idx != 3 { + t.Errorf("Expected 3 iterations, got %d", idx) + } + + if stack.Len() != 3 { + t.Errorf( + "Iteration should not modify stack, length is %d", + stack.Len(), + ) + } +} + +// TestStackClear verifies Clear operation empties the stack. +func TestStackClear(t *testing.T) { + t.Parallel() + + stack := NewStack[int]() + stack.Push(1) + stack.Push(2) + stack.Push(3) + + stack.Clear() + + if !stack.IsEmpty() { + t.Error("Stack should be empty after Clear") + } + if stack.Len() != 0 { + t.Errorf("Expected length 0 after Clear, got %d", stack.Len()) + } +} + +// TestQueueBasicOperations verifies FIFO behavior and empty state handling. 
+func TestQueueBasicOperations(t *testing.T) { + t.Parallel() + + queue := NewQueue[int]() + + if !queue.IsEmpty() { + t.Error("New queue should be empty") + } + if queue.Len() != 0 { + t.Errorf("Expected length 0, got %d", queue.Len()) + } + + _, ok := queue.Dequeue() + if ok { + t.Error("Dequeue on empty queue should return false") + } + + _, ok = queue.Peek() + if ok { + t.Error("Peek on empty queue should return false") + } + + queue.Enqueue(1) + queue.Enqueue(2) + queue.Enqueue(3) + + if queue.Len() != 3 { + t.Errorf("Expected length 3, got %d", queue.Len()) + } + if queue.IsEmpty() { + t.Error("Queue should not be empty") + } + + val, ok := queue.Peek() + if !ok || val != 1 { + t.Errorf("Expected Peek to return 1, got %d", val) + } + if queue.Len() != 3 { + t.Error("Peek should not modify queue size") + } + + val, ok = queue.Dequeue() + if !ok || val != 1 { + t.Errorf("Expected Dequeue to return 1, got %d", val) + } + val, ok = queue.Dequeue() + if !ok || val != 2 { + t.Errorf("Expected Dequeue to return 2, got %d", val) + } + val, ok = queue.Dequeue() + if !ok || val != 3 { + t.Errorf("Expected Dequeue to return 3, got %d", val) + } + + if !queue.IsEmpty() { + t.Error("Queue should be empty after all dequeues") + } +} + +// TestQueueIterate verifies iteration yields items front to back without +// modifying the queue state. +func TestQueueIterate(t *testing.T) { + t.Parallel() + + queue := NewQueue[int]() + queue.Enqueue(1) + queue.Enqueue(2) + queue.Enqueue(3) + + expected := []int{1, 2, 3} + idx := 0 + for val := range queue.Iterate() { + if val != expected[idx] { + t.Errorf( + "Expected %d at index %d, got %d", + expected[idx], + idx, + val, + ) + } + idx++ + } + + if idx != 3 { + t.Errorf("Expected 3 iterations, got %d", idx) + } + + if queue.Len() != 3 { + t.Errorf( + "Iteration should not modify queue, length is %d", + queue.Len(), + ) + } +} + +// TestQueueClear verifies Clear operation empties the queue. 
+func TestQueueClear(t *testing.T) { + t.Parallel() + + queue := NewQueue[int]() + queue.Enqueue(1) + queue.Enqueue(2) + queue.Enqueue(3) + + queue.Clear() + + if !queue.IsEmpty() { + t.Error("Queue should be empty after Clear") + } + if queue.Len() != 0 { + t.Errorf("Expected length 0 after Clear, got %d", queue.Len()) + } +} + +// TestPriorityQueueBasicOperations verifies max-heap behavior maintains +// priority ordering across Push/Pop operations. +func TestPriorityQueueBasicOperations(t *testing.T) { + t.Parallel() + + pq := NewPriorityQueue(func(a, b int) bool { + return a > b + }) + + if !pq.IsEmpty() { + t.Error("New priority queue should be empty") + } + if pq.Len() != 0 { + t.Errorf("Expected length 0, got %d", pq.Len()) + } + + _, ok := pq.Pop() + if ok { + t.Error("Pop on empty priority queue should return false") + } + + _, ok = pq.Peek() + if ok { + t.Error("Peek on empty priority queue should return false") + } + + pq.Push(3) + pq.Push(1) + pq.Push(4) + pq.Push(2) + + if pq.Len() != 4 { + t.Errorf("Expected length 4, got %d", pq.Len()) + } + if pq.IsEmpty() { + t.Error("Priority queue should not be empty") + } + + val, ok := pq.Peek() + if !ok || val != 4 { + t.Errorf("Expected Peek to return 4, got %d", val) + } + if pq.Len() != 4 { + t.Error("Peek should not modify priority queue size") + } + + expected := []int{4, 3, 2, 1} + for i, exp := range expected { + val, ok := pq.Pop() + if !ok { + t.Errorf("Pop %d failed", i) + } + if val != exp { + t.Errorf( + "Expected Pop to return %d at position %d, got %d", + exp, + i, + val, + ) + } + } + + if !pq.IsEmpty() { + t.Error("Priority queue should be empty after all pops") + } +} + +// TestPriorityQueueMinHeap verifies min-heap behavior with reversed +// comparison function. 
+func TestPriorityQueueMinHeap(t *testing.T) { + t.Parallel() + + pq := NewPriorityQueue(func(a, b int) bool { + return a < b + }) + + pq.Push(3) + pq.Push(1) + pq.Push(4) + pq.Push(2) + + expected := []int{1, 2, 3, 4} + for i, exp := range expected { + val, ok := pq.Pop() + if !ok { + t.Errorf("Pop %d failed", i) + } + if val != exp { + t.Errorf( + "Expected Pop to return %d at position %d, got %d", + exp, + i, + val, + ) + } + } +} + +// TestPriorityQueueIterate verifies iteration yields items in priority order +// without modifying the original queue state. +func TestPriorityQueueIterate(t *testing.T) { + t.Parallel() + + pq := NewPriorityQueue(func(a, b int) bool { + return a > b + }) + + pq.Push(3) + pq.Push(1) + pq.Push(4) + pq.Push(2) + + expected := []int{4, 3, 2, 1} + idx := 0 + for val := range pq.Iterate() { + if val != expected[idx] { + t.Errorf( + "Expected %d at index %d, got %d", + expected[idx], + idx, + val, + ) + } + idx++ + } + + if idx != 4 { + t.Errorf("Expected 4 iterations, got %d", idx) + } + + if pq.Len() != 4 { + t.Errorf( + "Iteration should not modify priority queue, length is %d", + pq.Len(), + ) + } + + val, _ := pq.Pop() + if val != 4 { + t.Errorf( + "Original priority queue corrupted, expected 4, got %d", + val, + ) + } +} + +// TestPriorityQueueClear verifies Clear operation empties the priority queue. +func TestPriorityQueueClear(t *testing.T) { + t.Parallel() + + pq := NewPriorityQueue(func(a, b int) bool { + return a > b + }) + + pq.Push(1) + pq.Push(2) + pq.Push(3) + + pq.Clear() + + if !pq.IsEmpty() { + t.Error("Priority queue should be empty after Clear") + } + if pq.Len() != 0 { + t.Errorf("Expected length 0 after Clear, got %d", pq.Len()) + } +} + +// TestCollectionsWithCapacity verifies pre-allocation with initial capacity +// doesn't affect empty state behavior. 
+func TestCollectionsWithCapacity(t *testing.T) { + t.Parallel() + + stack := NewStack[int](10) + if stack.Len() != 0 { + t.Error("Stack with capacity should start empty") + } + + queue := NewQueue[int](10) + if queue.Len() != 0 { + t.Error("Queue with capacity should start empty") + } + + pq := NewPriorityQueue(func(a, b int) bool { return a > b }, 10) + if pq.Len() != 0 { + t.Error("PriorityQueue with capacity should start empty") + } +} + +// TestStackWithStrings verifies Stack works correctly with non-int types. +func TestStackWithStrings(t *testing.T) { + t.Parallel() + + stack := NewStack[string]() + stack.Push("first") + stack.Push("second") + stack.Push("third") + + val, _ := stack.Pop() + if val != "third" { + t.Errorf("Expected 'third', got '%s'", val) + } + val, _ = stack.Pop() + if val != "second" { + t.Errorf("Expected 'second', got '%s'", val) + } + val, _ = stack.Pop() + if val != "first" { + t.Errorf("Expected 'first', got '%s'", val) + } +} + +// TestEarlyExitFromIteration verifies iterators respect early break and don't +// continue yielding values after consumer stops. 
+func TestEarlyExitFromIteration(t *testing.T) { + t.Parallel() + + stack := NewStack[int]() + for i := 0; i < 10; i++ { + stack.Push(i) + } + + count := 0 + for range stack.Iterate() { + count++ + if count == 3 { + break + } + } + if count != 3 { + t.Errorf("Expected 3 iterations, got %d", count) + } + + queue := NewQueue[int]() + for i := 0; i < 10; i++ { + queue.Enqueue(i) + } + + count = 0 + for range queue.Iterate() { + count++ + if count == 3 { + break + } + } + if count != 3 { + t.Errorf("Expected 3 iterations, got %d", count) + } + + pq := NewPriorityQueue(func(a, b int) bool { return a > b }) + for i := 0; i < 10; i++ { + pq.Push(i) + } + + count = 0 + for range pq.Iterate() { + count++ + if count == 3 { + break + } + } + if count != 3 { + t.Errorf("Expected 3 iterations, got %d", count) + } +} \ No newline at end of file diff --git a/mempool/txgraph/doc.go b/mempool/txgraph/doc.go new file mode 100644 index 0000000000..1778955970 --- /dev/null +++ b/mempool/txgraph/doc.go @@ -0,0 +1,99 @@ +// Package txgraph provides a transaction graph data structure for efficiently +// tracking relationships between transactions in the mempool. It supports +// package identification, ancestor/descendant queries, and various traversal +// strategies using Go's iter.Seq iterators. +// +// # Core Features +// +// - Efficient O(1) lookups for transactions by hash +// - Automatic edge creation based on transaction inputs/outputs +// - Cluster management for connected components +// - Package identification (1P1C, TRUC, ephemeral dust) +// - Orphan transaction detection with configurable predicates +// - Multiple traversal strategies (DFS, BFS, topological) +// - Thread-safe operations with fine-grained locking +// +// # Graph Structure +// +// The graph maintains transactions as nodes with edges representing +// parent-child relationships. When a transaction spends outputs from +// another transaction, an edge is created from parent to child. 
+// +// # Packages +// +// The graph can identify various package types: +// - 1P1C: One parent, one child packages for CPFP +// - TRUC: Version 3 transactions with topology restrictions +// - Ephemeral: Packages with ephemeral dust outputs +// - Standard: General connected transaction packages +// +// # Iteration +// +// The graph supports multiple iteration strategies using iter.Seq: +// +// // Iterate over all nodes in DFS order +// for node := range graph.Iterate(IteratorOption{Order: TraversalDFS}) { +// // Process node +// } +// +// // Iterate over ancestors of a specific transaction +// for node := range graph.Iterate(IteratorOption{ +// Order: TraversalAncestors, +// StartNode: txHash, +// MaxDepth: 10, +// }) { +// // Process ancestor +// } +// +// # Orphan Detection +// +// The graph can identify orphan transactions - transactions with unconfirmed +// inputs that are not present in the mempool. A configurable predicate function +// determines whether an input is confirmed on-chain: +// +// // Define predicate to check if inputs are confirmed +// isConfirmed := func(outpoint wire.OutPoint) bool { +// return utxoSet.IsConfirmed(outpoint) +// } +// +// // Get all orphan transactions +// orphans := graph.GetOrphans(isConfirmed) +// +// // Or iterate over orphans +// for orphan := range graph.IterateOrphans(isConfirmed) { +// // Process orphan +// } +// +// If the predicate is nil, all transactions with no parents in the graph +// are considered orphans. +// +// # Thread Safety +// +// All graph operations are thread-safe. The implementation uses a +// hierarchical locking strategy to minimize contention while maintaining +// consistency. 
+// +// # Example Usage +// +// // Create a new graph +// graph := txgraph.New(txgraph.DefaultConfig()) +// +// // Add a transaction +// err := graph.AddTransaction(tx, txDesc) +// +// // Get ancestors +// ancestors := graph.GetAncestors(txHash, maxDepth) +// +// // Identify packages +// packages, err := graph.IdentifyPackages() +// +// // Iterate with custom filter +// for node := range graph.Iterate(IteratorOption{ +// Order: TraversalFeeRate, +// Filter: func(n *TxGraphNode) bool { +// return n.TxDesc.FeePerKB > 10000 +// }, +// }) { +// // Process high-fee transactions +// } +package txgraph \ No newline at end of file diff --git a/mempool/txgraph/graph.go b/mempool/txgraph/graph.go new file mode 100644 index 0000000000..ec1562ee02 --- /dev/null +++ b/mempool/txgraph/graph.go @@ -0,0 +1,914 @@ +package txgraph + +import ( + "errors" + "fmt" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +var ( + // ErrTransactionExists is returned when attempting to add a duplicate + // transaction. + ErrTransactionExists = errors.New("transaction already exists in graph") + + // ErrNodeNotFound is returned when a node is not found in the graph. + ErrNodeNotFound = errors.New("node not found in graph") + + // ErrInvalidTopology is returned when a package has invalid topology. + ErrInvalidTopology = errors.New("invalid package topology") + + // ErrDisconnectedPackage is returned when package nodes are not + // connected. + ErrDisconnectedPackage = errors.New("package contains disconnected " + + "nodes") + + // ErrCycleDetected is returned when a cycle is detected in the graph. + ErrCycleDetected = errors.New("cycle detected in graph") + + // ErrMaxDepthExceeded is returned when max traversal depth is exceeded. + ErrMaxDepthExceeded = errors.New("maximum depth exceeded") + + // ErrInvalidEdge is returned when attempting to create an invalid edge. 
+ ErrInvalidEdge = errors.New("invalid edge") +) + +// Config defines configuration for the transaction graph. +type Config struct { + // MaxNodes limits graph capacity to prevent unbounded memory growth. + // When reached, new transaction additions will be rejected, triggering + // mempool eviction policies in the caller. + MaxNodes int + + // MaxEdges limits the total number of parent-child relationships. This + // provides defense against attacks that try to create extremely + // connected transaction graphs to degrade performance. + MaxEdges int + + // EnableCaching enables memoization of expensive computations like + // ancestor/descendant counts. This trades memory for speed, which is + // beneficial in production but may complicate debugging. + EnableCaching bool + + // CacheTimeout defines how long cached computation results remain + // valid. Shorter timeouts trade freshness for computation cost. + CacheTimeout time.Duration + + // MaxPackageSize limits the number of transactions in a package. + // Bitcoin Core uses 101 (25 ancestors + 25 descendants + 1 root), + // enforced here to prevent package relay DoS attacks. + MaxPackageSize int + + // PackageAnalyzer provides protocol-specific validation logic. If nil, + // package identification will use heuristics only without protocol + // enforcement (useful for testing or non-standard configurations). + PackageAnalyzer PackageAnalyzer +} + +// DefaultConfig returns the default graph configuration. +func DefaultConfig() *Config { + return &Config{ + MaxNodes: 100000, + MaxEdges: 200000, + EnableCaching: true, + CacheTimeout: 5 * time.Second, + MaxPackageSize: 101, + } +} + +// TxGraph implements the Graph interface. +type TxGraph struct { + config *Config + + // analyzer provides protocol-specific validation logic for transaction + // packages. Abstracted as an interface to enable testing with mocks and + // to support future protocol upgrades without modifying the core graph. 
+ analyzer PackageAnalyzer + + // nodes stores all transactions currently in the mempool. Hash map + // provides O(1) lookups by transaction ID, which is critical for + // performance as the mempool can contain thousands of transactions. + nodes map[chainhash.Hash]*TxGraphNode + + // indexes contains auxiliary data structures for O(1) lookups that + // would otherwise require O(n) graph traversal. + indexes struct { + // spentBy maps an outpoint to the transaction that spends it. + // This enables orphan transaction handling: when a transaction + // arrives before its parent, we can quickly connect them once + // the parent arrives without rescanning all transactions. + spentBy map[wire.OutPoint]*TxGraphNode + + // clusters maps cluster IDs to connected components. Clusters + // are used for RBF validation (replacement must improve entire + // cluster fee) and mempool eviction policies. + clusters map[ClusterID]*TxCluster + + // nodeToCluster provides O(1) cluster lookup for any + // transaction, avoiding repeated graph traversal to find + // connected components. + nodeToCluster map[chainhash.Hash]ClusterID + + // packages maps package IDs to identified transaction packages. + // Packages are used for package relay policies and block + // template construction. + packages map[PackageID]*TxPackage + + // nodeToPackage enables quick package membership checks without + // recomputing package structures. + nodeToPackage map[chainhash.Hash]PackageID + + // trucTxs indexes v3 (TRUC) transactions for efficient TRUC + // policy enforcement without scanning all nodes. + trucTxs map[chainhash.Hash]*TxGraphNode + + // ephemeralTxs indexes transactions with ephemeral dust outputs + // for package validation and relay policy checks. + ephemeralTxs map[chainhash.Hash]*TxGraphNode + } + + // metrics tracks aggregate statistics using atomic operations to enable + // lock-free reads. 
This is essential because metrics are frequently + // queried for monitoring and eviction decisions. + metrics struct { + nodeCount int32 + edgeCount int32 + packageCount int32 + clusterCount int32 + trucCount int32 + ephemeralCount int32 + } + + // nextClusterID generates monotonically increasing cluster identifiers. + // Uses atomic.Uint64 to avoid contention on the main graph mutex during + // high-frequency transaction additions. + nextClusterID atomic.Uint64 + + // mu protects the graph structure. RWMutex allows concurrent reads + // (queries, iteration) while serializing writes (add/remove operations). + mu sync.RWMutex +} + +// New creates a new transaction graph. +func New(config *Config) *TxGraph { + if config == nil { + config = DefaultConfig() + } + + g := &TxGraph{ + config: config, + analyzer: config.PackageAnalyzer, + nodes: make(map[chainhash.Hash]*TxGraphNode), + } + + g.indexes.spentBy = make(map[wire.OutPoint]*TxGraphNode) + g.indexes.clusters = make(map[ClusterID]*TxCluster) + g.indexes.nodeToCluster = make(map[chainhash.Hash]ClusterID) + g.indexes.packages = make(map[PackageID]*TxPackage) + g.indexes.nodeToPackage = make(map[chainhash.Hash]PackageID) + g.indexes.trucTxs = make(map[chainhash.Hash]*TxGraphNode) + g.indexes.ephemeralTxs = make(map[chainhash.Hash]*TxGraphNode) + + return g +} + +// AddTransaction adds a transaction to the graph. +func (g *TxGraph) AddTransaction(tx *btcutil.Tx, txDesc *TxDesc) error { + g.mu.Lock() + defer g.mu.Unlock() + + hash := tx.Hash() + + // Check if already exists. + if _, exists := g.nodes[*hash]; exists { + return ErrTransactionExists + } + + // Check capacity. + if int(atomic.LoadInt32(&g.metrics.nodeCount)) >= g.config.MaxNodes { + return fmt.Errorf("graph at capacity: %d nodes", + g.config.MaxNodes) + } + + // Create new node. 
+ node := &TxGraphNode{ + TxHash: *hash, + Tx: tx, + TxDesc: txDesc, + Parents: make(map[chainhash.Hash]*TxGraphNode), + Children: make(map[chainhash.Hash]*TxGraphNode), + } + + // Set metadata. + node.Metadata.AddedTime = time.Now() + node.Metadata.IsTRUC = (tx.MsgTx().Version == 3) + node.Metadata.ClusterID = ClusterID(g.nextClusterID.Add(1)) + + // Add to graph. + g.nodes[*hash] = node + + // Connect edges based on inputs. + for _, txIn := range tx.MsgTx().TxIn { + parentHash := txIn.PreviousOutPoint.Hash + + // Always update spentBy index, even if parent doesn't exist yet. + // This handles the orphan case where a child arrives before its + // parent: when the parent eventually arrives, we use this index + // to quickly find and connect all waiting children without + // rescanning the entire mempool. This is a time-space tradeoff: + // we maintain extra index entries for orphans to avoid O(n) + // scans on every parent arrival. + g.indexes.spentBy[txIn.PreviousOutPoint] = node + + if parent, exists := g.nodes[parentHash]; exists { + // Create bidirectional edge between parent and child. + node.Parents[parentHash] = parent + parent.Children[*hash] = node + + atomic.AddInt32(&g.metrics.edgeCount, 1) + } + } + + // Find children that spend this transaction. + for i := range tx.MsgTx().TxOut { + outpoint := wire.OutPoint{ + Hash: *hash, + Index: uint32(i), + } + + if child, exists := g.indexes.spentBy[outpoint]; exists { + node.Children[child.TxHash] = child + child.Parents[*hash] = node + atomic.AddInt32(&g.metrics.edgeCount, 1) + } + } + + // Update cluster assignment after all edges are established. Cluster + // identification requires knowing the node's complete set of parents + // and children to determine which existing clusters need to be merged. + // Doing this before edge creation would result in incorrect cluster + // assignments when the node bridges multiple clusters. + g.updateClusterAssignment(node) + + // Update feature-specific indexes. 
+ if node.Metadata.IsTRUC { + g.indexes.trucTxs[*hash] = node + atomic.AddInt32(&g.metrics.trucCount, 1) + } + + // Update metrics. + atomic.AddInt32(&g.metrics.nodeCount, 1) + + return nil +} + +// RemoveTransaction removes a transaction from the graph. +func (g *TxGraph) RemoveTransaction(hash chainhash.Hash) error { + g.mu.Lock() + defer g.mu.Unlock() + + _, exists := g.nodes[hash] + if !exists { + return ErrNodeNotFound + } + + // Collect all descendants to remove. + toRemove := g.collectDescendantsToRemove(hash) + + // Remove in reverse order (children before parents). This ordering is + // required because removeTransactionUnsafe updates parent nodes to + // remove child references. If we removed parents first, we'd be + // modifying the Children map of deleted nodes, which could cause + // panics or leave dangling pointers. Removing children first ensures + // all parent.Children map updates refer to live nodes. + for i := len(toRemove) - 1; i >= 0; i-- { + // Skip if node was already removed. This can occur when a node + // has multiple parents that were both in the removal set. + if _, exists := g.nodes[toRemove[i]]; !exists { + continue + } + if err := g.removeTransactionUnsafe(toRemove[i]); err != nil { + return err + } + } + + return nil +} + +// RemoveTransactionNoCascade removes a transaction without removing its +// descendants. This is used when a transaction is confirmed in a block - the +// transaction leaves the mempool but its children remain valid (they now +// reference a confirmed input). +// +// The children will have their parent reference updated to remove the +// confirmed transaction, but they remain in the graph. +// +// This is in contrast to RemoveTransaction which cascades and removes all +// descendants, which is appropriate when a transaction is evicted/invalidated +// (e.g., replaced by RBF). 
+func (g *TxGraph) RemoveTransactionNoCascade(hash chainhash.Hash) error { + g.mu.Lock() + defer g.mu.Unlock() + + _, exists := g.nodes[hash] + if !exists { + return ErrNodeNotFound + } + + // Remove just this transaction, without cascading to descendants. + // The removeTransactionUnsafe function will clean up edges from + // children. + return g.removeTransactionUnsafe(hash) +} + +// collectDescendantsToRemove collects all descendants including the node +// itself. Must be called with lock held. Returns a slice of hashes in +// topological order (parents before children). +func (g *TxGraph) collectDescendantsToRemove( + hash chainhash.Hash, +) []chainhash.Hash { + node, exists := g.nodes[hash] + if !exists { + // Node doesn't exist, return empty slice. + return nil + } + + result := []chainhash.Hash{hash} + + for childHash := range node.Children { + childDescendants := g.collectDescendantsToRemove(childHash) + result = append(result, childDescendants...) + } + + return result +} + +// removeTransactionUnsafe removes a single transaction without recursion. Must +// be called with lock held. +func (g *TxGraph) removeTransactionUnsafe(hash chainhash.Hash) error { + node, exists := g.nodes[hash] + if !exists { + return ErrNodeNotFound + } + + // Remove edges from parents. + for parentHash, parent := range node.Parents { + delete(parent.Children, hash) + + atomic.AddInt32(&g.metrics.edgeCount, -1) + + // Remove from spent index. + for _, txIn := range node.Tx.MsgTx().TxIn { + if txIn.PreviousOutPoint.Hash == parentHash { + delete(g.indexes.spentBy, txIn.PreviousOutPoint) + } + } + } + + // Remove edges from children. + for _, child := range node.Children { + delete(child.Parents, hash) + + atomic.AddInt32(&g.metrics.edgeCount, -1) + } + + // Remove from feature indexes. 
+ if node.Metadata.IsTRUC { + delete(g.indexes.trucTxs, hash) + + atomic.AddInt32(&g.metrics.trucCount, -1) + } + if node.Metadata.IsEphemeral { + delete(g.indexes.ephemeralTxs, hash) + + atomic.AddInt32(&g.metrics.ephemeralCount, -1) + } + + // Remove from cluster. + if clusterID, exists := g.indexes.nodeToCluster[hash]; exists { + if cluster, exists := g.indexes.clusters[clusterID]; exists { + delete(cluster.Nodes, hash) + if len(cluster.Nodes) == 0 { + delete(g.indexes.clusters, clusterID) + atomic.AddInt32(&g.metrics.clusterCount, -1) + } + } + + delete(g.indexes.nodeToCluster, hash) + } + + // Remove from packages. + if pkgID := node.Metadata.PackageID; pkgID != nil { + if pkg, exists := g.indexes.packages[*pkgID]; exists { + delete(pkg.Transactions, hash) + + if len(pkg.Transactions) == 0 { + delete(g.indexes.packages, *pkgID) + + atomic.AddInt32(&g.metrics.packageCount, -1) + } + } + + delete(g.indexes.nodeToPackage, hash) + } + + // Remove node. + delete(g.nodes, hash) + + atomic.AddInt32(&g.metrics.nodeCount, -1) + + return nil +} + +// GetNode retrieves a node from the graph. +func (g *TxGraph) GetNode(hash chainhash.Hash) (*TxGraphNode, bool) { + g.mu.RLock() + defer g.mu.RUnlock() + + node, exists := g.nodes[hash] + return node, exists +} + +// HasTransaction checks if a transaction exists in the graph. +func (g *TxGraph) HasTransaction(hash chainhash.Hash) bool { + g.mu.RLock() + defer g.mu.RUnlock() + + _, exists := g.nodes[hash] + return exists +} + +// AddEdge adds an edge between two nodes. +func (g *TxGraph) AddEdge(parentHash, childHash chainhash.Hash) error { + g.mu.Lock() + defer g.mu.Unlock() + + parent, parentExists := g.nodes[parentHash] + child, childExists := g.nodes[childHash] + + if !parentExists || !childExists { + return ErrNodeNotFound + } + + // Check if edge already exists. + if _, exists := parent.Children[childHash]; exists { + return nil // Already connected + } + + // Check for cycles. 
+ if g.wouldCreateCycle(parent, child) { + return ErrCycleDetected + } + + // Add edge. + parent.Children[childHash] = child + child.Parents[parentHash] = parent + atomic.AddInt32(&g.metrics.edgeCount, 1) + + // Update clusters. + g.mergeNodeClusters(parent, child) + + return nil +} + +// RemoveEdge removes an edge between two nodes. +func (g *TxGraph) RemoveEdge(parentHash, childHash chainhash.Hash) error { + g.mu.Lock() + defer g.mu.Unlock() + + parent, parentExists := g.nodes[parentHash] + child, childExists := g.nodes[childHash] + + if !parentExists || !childExists { + return ErrNodeNotFound + } + + // Remove edge if it exists. + if _, exists := parent.Children[childHash]; exists { + delete(parent.Children, childHash) + delete(child.Parents, parentHash) + atomic.AddInt32(&g.metrics.edgeCount, -1) + } + + return nil +} + +// GetAncestors returns all ancestors of a transaction up to maxDepth. +func (g *TxGraph) GetAncestors(hash chainhash.Hash, + maxDepth int) map[chainhash.Hash]*TxGraphNode { + + g.mu.RLock() + defer g.mu.RUnlock() + + node, exists := g.nodes[hash] + if !exists { + return nil + } + + // Check cache if enabled. + if g.config.EnableCaching && + time.Since(node.cachedMetrics.LastUpdated) < g.config.CacheTimeout { + // For now, skip cache and compute directly. + // TODO: Implement proper caching. + } + + ancestors := make(map[chainhash.Hash]*TxGraphNode) + visited := make(map[chainhash.Hash]bool) + g.collectAncestorsRecursive(node, ancestors, visited, 0, maxDepth) + + return ancestors +} + +// collectAncestorsRecursive recursively collects ancestors. 
+func (g *TxGraph) collectAncestorsRecursive( + node *TxGraphNode, ancestors map[chainhash.Hash]*TxGraphNode, + visited map[chainhash.Hash]bool, currentDepth, maxDepth int) { + if maxDepth >= 0 && currentDepth >= maxDepth { + return + } + + for hash, parent := range node.Parents { + if visited[hash] { + continue + } + + visited[hash] = true + ancestors[hash] = parent + + g.collectAncestorsRecursive( + parent, ancestors, visited, currentDepth+1, maxDepth, + ) + } +} + +// GetDescendants returns all descendants of a transaction up to maxDepth. +func (g *TxGraph) GetDescendants(hash chainhash.Hash, + maxDepth int) map[chainhash.Hash]*TxGraphNode { + + g.mu.RLock() + defer g.mu.RUnlock() + + node, exists := g.nodes[hash] + if !exists { + return nil + } + + descendants := make(map[chainhash.Hash]*TxGraphNode) + visited := make(map[chainhash.Hash]bool) + + g.collectDescendantsRecursive(node, descendants, visited, 0, maxDepth) + + return descendants +} + +// collectDescendantsRecursive recursively collects descendants. +func (g *TxGraph) collectDescendantsRecursive( + node *TxGraphNode, + descendants map[chainhash.Hash]*TxGraphNode, + visited map[chainhash.Hash]bool, + currentDepth, maxDepth int, +) { + if maxDepth >= 0 && currentDepth >= maxDepth { + return + } + + for hash, child := range node.Children { + if visited[hash] { + continue + } + visited[hash] = true + descendants[hash] = child + + g.collectDescendantsRecursive( + child, descendants, visited, currentDepth+1, maxDepth, + ) + } +} + +// GetCluster returns the cluster containing the specified transaction. 
+func (g *TxGraph) GetCluster(hash chainhash.Hash) (*TxCluster, error) { + g.mu.RLock() + defer g.mu.RUnlock() + + clusterID, exists := g.indexes.nodeToCluster[hash] + if !exists { + return nil, ErrNodeNotFound + } + + cluster, exists := g.indexes.clusters[clusterID] + if !exists { + return nil, fmt.Errorf("cluster %d not found", clusterID) + } + + return cluster, nil +} + +// GetPackage returns the package containing the specified transaction. +func (g *TxGraph) GetPackage(hash chainhash.Hash) (*TxPackage, error) { + g.mu.RLock() + defer g.mu.RUnlock() + + pkgID, exists := g.indexes.nodeToPackage[hash] + if !exists { + return nil, ErrNodeNotFound + } + + pkg, exists := g.indexes.packages[pkgID] + if !exists { + return nil, fmt.Errorf("package not found") + } + + return pkg, nil +} + +// GetOrphans returns all orphan transactions as a slice. +// See IterateOrphans for the definition of an orphan transaction. +func (g *TxGraph) GetOrphans(isConfirmed InputConfirmedPredicate, +) []*TxGraphNode { + + return slices.Collect(g.IterateOrphans(isConfirmed)) +} + +// ValidatePackage validates a transaction package. +func (g *TxGraph) ValidatePackage(pkg *TxPackage) error { + if pkg == nil { + return fmt.Errorf("nil package") + } + + // Check empty package. + if len(pkg.Transactions) == 0 { + return fmt.Errorf("empty package") + } + + // Check size limits. + if len(pkg.Transactions) > g.config.MaxPackageSize { + return fmt.Errorf("package too large: %d transactions", + len(pkg.Transactions)) + } + + // Validate topology consistency. + if pkg.Topology.MaxDepth > len(pkg.Transactions)-1 { + return fmt.Errorf("invalid topology: max depth %d "+ + "exceeds possible depth for %d transactions", + pkg.Topology.MaxDepth, len(pkg.Transactions)) + } + + // Validate topology based on type. + switch pkg.Type { + case PackageType1P1C: + if len(pkg.Transactions) != 2 { + return ErrInvalidTopology + } + // TODO: Validate 1 parent, 1 child relationship. 
+ + case PackageTypeTRUC: + // All transactions must be version 3. + for _, node := range pkg.Transactions { + if !node.Metadata.IsTRUC { + return fmt.Errorf("non-TRUC " + + "transaction in TRUC package") + } + } + } + + return nil +} + +// GetMetrics returns current graph metrics. +func (g *TxGraph) GetMetrics() GraphMetrics { + return GraphMetrics{ + NodeCount: int(atomic.LoadInt32(&g.metrics.nodeCount)), + EdgeCount: int(atomic.LoadInt32(&g.metrics.edgeCount)), + PackageCount: int(atomic.LoadInt32(&g.metrics.packageCount)), + TRUCCount: int(atomic.LoadInt32(&g.metrics.trucCount)), + EphemeralCount: int(atomic.LoadInt32( + &g.metrics.ephemeralCount, + )), + ClusterCount: int(atomic.LoadInt32(&g.metrics.clusterCount)), + } +} + +// GetNodeCount returns the number of nodes in the graph. +func (g *TxGraph) GetNodeCount() int { + return int(atomic.LoadInt32(&g.metrics.nodeCount)) +} + +// GetClusterCount returns the number of clusters in the graph. +func (g *TxGraph) GetClusterCount() int { + return int(atomic.LoadInt32(&g.metrics.clusterCount)) +} + +// wouldCreateCycle checks if adding an edge would create a cycle in the DAG. +func (g *TxGraph) wouldCreateCycle(parent, child *TxGraphNode) bool { + // Cycle detection uses the fundamental DAG property: adding edge + // parent→child creates a cycle if and only if there's already a path + // child→parent. We check this by attempting to reach parent from child + // through existing edges. The graph is a directed acyclic graph (DAG) + // where edges point from parent to child (spending relationship), so + // any path child→parent would violate the acyclic property when + // combined with the new parent→child edge. + visited := make(map[chainhash.Hash]bool) + return g.isReachable(child, parent.TxHash, visited) +} + +// isReachable performs depth-first search to check if target is reachable from +// source via children edges. 
+func (g *TxGraph) isReachable(source *TxGraphNode, target chainhash.Hash,
+	visited map[chainhash.Hash]bool) bool {
+
+	// Reached the target: a path exists.
+	if source.TxHash == target {
+		return true
+	}
+
+	// Skip subgraphs that were already explored. Besides preventing
+	// redundant work, this guards the recursion against revisiting
+	// shared descendants.
+	if visited[source.TxHash] {
+		return false
+	}
+	visited[source.TxHash] = true
+
+	// Walk forward along spend edges (the parent→child direction) and
+	// stop as soon as any branch reaches the target.
+	for _, next := range source.Children {
+		if g.isReachable(next, target, visited) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// updateClusterAssignment places a node into the correct cluster: a brand
+// new one when it has no clustered neighbors, an existing one when all of
+// its neighbors agree, or the result of merging several clusters when the
+// node bridges them.
+func (g *TxGraph) updateClusterAssignment(node *TxGraphNode) {
+	// Collect the distinct cluster IDs of every connected neighbor,
+	// parents and children alike.
+	related := make(map[ClusterID]bool)
+	for _, p := range node.Parents {
+		if cid, ok := g.indexes.nodeToCluster[p.TxHash]; ok {
+			related[cid] = true
+		}
+	}
+	for _, c := range node.Children {
+		if cid, ok := g.indexes.nodeToCluster[c.TxHash]; ok {
+			related[cid] = true
+		}
+	}
+
+	switch len(related) {
+	case 0:
+		// No clustered neighbors: start a fresh cluster.
+		g.createNewCluster(node)
+
+	case 1:
+		// Exactly one neighboring cluster: join it.
+		for cid := range related {
+			g.addToCluster(node, cid)
+		}
+
+	default:
+		// The node bridges several clusters: merge them all.
+		g.mergeClusters(node, related)
+	}
+}
+
+// createNewCluster creates a new cluster for a node.
+func (g *TxGraph) createNewCluster(node *TxGraphNode) { + clusterID := ClusterID(g.nextClusterID.Add(1)) + + cluster := &TxCluster{ + ID: clusterID, + Nodes: make(map[chainhash.Hash]*TxGraphNode), + Size: 1, + } + + cluster.Nodes[node.TxHash] = node + cluster.Roots = []*TxGraphNode{node} + cluster.Leaves = []*TxGraphNode{node} + + g.indexes.clusters[clusterID] = cluster + g.indexes.nodeToCluster[node.TxHash] = clusterID + node.Metadata.ClusterID = clusterID + + atomic.AddInt32(&g.metrics.clusterCount, 1) +} + +// addToCluster adds a node to an existing cluster. +func (g *TxGraph) addToCluster(node *TxGraphNode, clusterID ClusterID) { + cluster, exists := g.indexes.clusters[clusterID] + if !exists { + g.createNewCluster(node) + return + } + + cluster.Nodes[node.TxHash] = node + cluster.Size++ + + g.indexes.nodeToCluster[node.TxHash] = clusterID + + node.Metadata.ClusterID = clusterID + + // Update roots and leaves. + g.updateClusterBoundaries(cluster) +} + +// mergeClusters merges multiple clusters into one. +func (g *TxGraph) mergeClusters( + node *TxGraphNode, clusterIDs map[ClusterID]bool, +) { + // Use the lowest cluster ID as the target to maintain deterministic + // behavior and stable cluster identification across runs. This choice + // is arbitrary but consistent: we could choose any cluster as the + // merge target, but always picking the lowest ID ensures that cluster + // IDs don't change unnecessarily during graph operations, which + // simplifies debugging and testing. + var targetID ClusterID + first := true + for cid := range clusterIDs { + if first || cid < targetID { + targetID = cid + first = false + } + } + + targetCluster := g.indexes.clusters[targetID] + if targetCluster == nil { + g.createNewCluster(node) + return + } + + // Add the new node. + targetCluster.Nodes[node.TxHash] = node + + g.indexes.nodeToCluster[node.TxHash] = targetID + node.Metadata.ClusterID = targetID + + // Merge other clusters into target. 
+ for cid := range clusterIDs { + if cid == targetID { + continue + } + + if cluster, exists := g.indexes.clusters[cid]; exists { + for hash, n := range cluster.Nodes { + targetCluster.Nodes[hash] = n + g.indexes.nodeToCluster[hash] = targetID + n.Metadata.ClusterID = targetID + } + + delete(g.indexes.clusters, cid) + + atomic.AddInt32(&g.metrics.clusterCount, -1) + } + } + + targetCluster.Size = len(targetCluster.Nodes) + g.updateClusterBoundaries(targetCluster) +} + +// mergeNodeClusters merges clusters when adding an edge. +func (g *TxGraph) mergeNodeClusters(parent, child *TxGraphNode) { + parentCluster := g.indexes.nodeToCluster[parent.TxHash] + childCluster := g.indexes.nodeToCluster[child.TxHash] + + if parentCluster == childCluster { + return + } + + clusters := make(map[ClusterID]bool) + clusters[parentCluster] = true + clusters[childCluster] = true + + g.mergeClusters(parent, clusters) +} + +// updateClusterBoundaries updates roots and leaves of a cluster. +func (g *TxGraph) updateClusterBoundaries(cluster *TxCluster) { + cluster.Roots = nil + cluster.Leaves = nil + + for _, node := range cluster.Nodes { + if len(node.Parents) == 0 { + cluster.Roots = append(cluster.Roots, node) + } + if len(node.Children) == 0 { + cluster.Leaves = append(cluster.Leaves, node) + } + } +} + diff --git a/mempool/txgraph/graph_test.go b/mempool/txgraph/graph_test.go new file mode 100644 index 0000000000..f9eae67a5a --- /dev/null +++ b/mempool/txgraph/graph_test.go @@ -0,0 +1,1532 @@ +package txgraph + +import ( + "crypto/rand" + "encoding/binary" + "slices" + "sync/atomic" + "testing" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +// txCounter is a global atomic counter for generating unique transaction +// hashes. 
Using atomic operations allows concurrent test execution without
+// hash collisions, which would cause spurious test failures.
+var txCounter uint64
+
+// h dereferences a hash pointer to a value. This helper reduces visual noise
+// in tests that need to pass hash values rather than pointers to assertion
+// functions.
+func h(hash *chainhash.Hash) chainhash.Hash {
+	return *hash
+}
+
+// txGenerator generates unique test transactions using an atomic counter.
+// This provides deterministic transaction creation for property-based tests
+// where we need reproducible test cases but still require unique transaction
+// IDs to avoid graph collisions.
+type txGenerator struct {
+	// counter points at the shared atomic counter embedded in each
+	// generated output script to force hash uniqueness.
+	counter *uint64
+}
+
+// newTxGenerator creates a new transaction generator using the global
+// txCounter to ensure uniqueness across all test cases.
+func newTxGenerator() *txGenerator {
+	return &txGenerator{counter: &txCounter}
+}
+
+// createTx creates a test transaction with a guaranteed unique hash by
+// embedding an atomic counter in each output's pkScript. This ensures
+// transaction uniqueness for tests that need deterministic behavior, unlike
+// createTestTx which uses random data and may occasionally collide.
+func (gen *txGenerator) createTx(inputs []wire.OutPoint,
+	numOutputs int) (*btcutil.Tx, *TxDesc) {
+
+	tx := wire.NewMsgTx(wire.TxVersion)
+
+	// NewTxIn copies the outpoint value, so taking the address of the
+	// loop variable here is safe.
+	for _, input := range inputs {
+		tx.AddTxIn(wire.NewTxIn(&input, nil, nil))
+	}
+
+	// Embed atomic counter in pkScript to guarantee unique transaction
+	// hash. The counter is atomically incremented for each output, which
+	// means even transactions with the same inputs will have different
+	// hashes due to different output scripts.
+	for i := 0; i < numOutputs; i++ {
+		counter := atomic.AddUint64(gen.counter, 1)
+		pkScript := make([]byte, 8)
+		binary.BigEndian.PutUint64(pkScript, counter)
+
+		tx.AddTxOut(wire.NewTxOut(100000, pkScript))
+	}
+
+	btcTx := btcutil.NewTx(tx)
+
+	// Create a minimal TxDesc with fixed fee values. Tests that need
+	// specific fee rates should create their own descriptors.
+	txDesc := &TxDesc{
+		TxHash:      *btcTx.Hash(),
+		VirtualSize: int64(btcTx.MsgTx().SerializeSize()),
+		Fee:         1000,
+		FeePerKB:    10000,
+		Added:       time.Now(),
+	}
+
+	return btcTx, txDesc
+}
+
+// createTestTx creates a test transaction with random output addresses.
+// This is used by legacy tests that don't require deterministic transaction
+// IDs. The randomness provides high probability of uniqueness but with a
+// small chance of collision, which is acceptable for simple unit tests but
+// not for property-based testing.
+func createTestTx(inputs []wire.OutPoint,
+	numOutputs int) (*btcutil.Tx, *TxDesc) {
+
+	tx := wire.NewMsgTx(wire.TxVersion)
+
+	for _, input := range inputs {
+		tx.AddTxIn(wire.NewTxIn(&input, nil, nil))
+	}
+
+	// Generate random P2PKH addresses for each output. Using valid
+	// Bitcoin addresses (rather than raw random bytes) makes test
+	// transactions more realistic and easier to debug.
+	//
+	// NOTE(review): errors from rand.Read, NewAddressPubKeyHash, and
+	// PayToAddrScript are deliberately ignored; acceptable for test-only
+	// data, but a failure would surface as a confusing nil-script panic
+	// downstream — consider asserting on them.
+	for i := 0; i < numOutputs; i++ {
+		randBytes := make([]byte, 20)
+		rand.Read(randBytes)
+
+		addr, _ := btcutil.NewAddressPubKeyHash(
+			randBytes, &chaincfg.MainNetParams,
+		)
+		pkScript, _ := txscript.PayToAddrScript(addr)
+
+		tx.AddTxOut(wire.NewTxOut(100000, pkScript))
+	}
+
+	btcTx := btcutil.NewTx(tx)
+
+	txDesc := &TxDesc{
+		TxHash:      *btcTx.Hash(),
+		VirtualSize: int64(btcTx.MsgTx().SerializeSize()),
+		Fee:         1000,
+		FeePerKB:    10000,
+		Added:       time.Now(),
+	}
+
+	return btcTx, txDesc
+}
+
+// TestGraphAddRemove verifies that transactions can be added to and removed
+// from the graph, including proper error handling for duplicate adds and
+// removal of non-existent transactions.
+func TestGraphAddRemove(t *testing.T) { + g := New(DefaultConfig()) + + tx1, desc1 := createTestTx(nil, 2) + + err := g.AddTransaction(tx1, desc1) + require.NoError(t, err) + + node, exists := g.GetNode(*tx1.Hash()) + require.True(t, exists) + require.NotNil(t, node) + require.Equal(t, *tx1.Hash(), node.TxHash) + + // Duplicate addition should return ErrTransactionExists rather than + // allowing graph corruption from multiple nodes with the same hash. + err = g.AddTransaction(tx1, desc1) + require.ErrorIs(t, err, ErrTransactionExists) + + err = g.RemoveTransaction(*tx1.Hash()) + require.NoError(t, err) + + _, exists = g.GetNode(*tx1.Hash()) + require.False(t, exists) + + // Removing a non-existent transaction should fail gracefully rather + // than panicking or corrupting graph state. + err = g.RemoveTransaction(*tx1.Hash()) + require.ErrorIs(t, err, ErrNodeNotFound) +} + +// TestGraphEdges verifies that parent-child edges are automatically created +// when a transaction spends outputs from another transaction in the graph. +// This is critical for maintaining the dependency graph that drives ancestor/ +// descendant queries and cluster formation. +func TestGraphEdges(t *testing.T) { + g := New(DefaultConfig()) + + parent, parentDesc := createTestTx(nil, 2) + err := g.AddTransaction(parent, parentDesc) + require.NoError(t, err) + + parentOut := wire.OutPoint{ + Hash: *parent.Hash(), + Index: 0, + } + child, childDesc := createTestTx([]wire.OutPoint{parentOut}, 1) + + // Edge creation should happen automatically during AddTransaction by + // detecting that the child spends from the parent's outputs. 
+ err = g.AddTransaction(child, childDesc) + require.NoError(t, err) + + parentNode, _ := g.GetNode(*parent.Hash()) + childNode, _ := g.GetNode(*child.Hash()) + + require.Len(t, parentNode.Children, 1) + require.Len(t, childNode.Parents, 1) + require.NotNil(t, parentNode.Children[*child.Hash()]) + require.NotNil(t, childNode.Parents[*parent.Hash()]) + + metrics := g.GetMetrics() + require.Equal(t, 2, metrics.NodeCount) + require.Equal(t, 1, metrics.EdgeCount) + + err = g.RemoveEdge(*parent.Hash(), *child.Hash()) + require.NoError(t, err) + + // Edge removal should update both nodes' relationship maps and + // decrement the edge count metric. + parentNode, _ = g.GetNode(*parent.Hash()) + childNode, _ = g.GetNode(*child.Hash()) + require.Len(t, parentNode.Children, 0) + require.Len(t, childNode.Parents, 0) +} + +// TestGraphAncestorsDescendants verifies that ancestor and descendant +// queries correctly traverse the transaction dependency graph. These queries +// are essential for enforcing BIP 125 ancestor/descendant limits and +// calculating package fee rates for mining. +func TestGraphAncestorsDescendants(t *testing.T) { + g := New(DefaultConfig()) + + // Build a linear chain to test traversal in both directions. + var prevHash *chainhash.Hash + var txs []*btcutil.Tx + + for i := 0; i < 4; i++ { + var inputs []wire.OutPoint + if prevHash != nil { + inputs = append(inputs, wire.OutPoint{ + Hash: *prevHash, + Index: 0, + }) + } + + tx, desc := createTestTx(inputs, 1) + err := g.AddTransaction(tx, desc) + require.NoError(t, err) + + txs = append(txs, tx) + prevHash = tx.Hash() + } + + // Ancestors of tx3 should include all transactions it depends on. + ancestors := g.GetAncestors(*txs[2].Hash(), -1) + require.Len(t, ancestors, 2) + require.NotNil(t, ancestors[*txs[0].Hash()]) + require.NotNil(t, ancestors[*txs[1].Hash()]) + + // Depth limit should stop traversal at the specified level. 
+ ancestors = g.GetAncestors(*txs[2].Hash(), 1) + require.Len(t, ancestors, 1) + require.NotNil(t, ancestors[*txs[1].Hash()]) + + // Descendants of tx1 should include all transactions that depend on + // it, directly or indirectly. + descendants := g.GetDescendants(*txs[0].Hash(), -1) + require.Len(t, descendants, 3) + require.NotNil(t, descendants[*txs[1].Hash()]) + require.NotNil(t, descendants[*txs[2].Hash()]) + require.NotNil(t, descendants[*txs[3].Hash()]) + + descendants = g.GetDescendants(*txs[0].Hash(), 2) + require.Len(t, descendants, 2) + require.NotNil(t, descendants[*txs[1].Hash()]) + require.NotNil(t, descendants[*txs[2].Hash()]) +} + +// TestCycleDetection verifies that the graph prevents cycles, which would +// violate the DAG property required for transaction dependencies. Cycles +// would make ancestor/descendant queries infinite loop and break topological +// ordering for block template construction. +func TestCycleDetection(t *testing.T) { + g := New(DefaultConfig()) + + tx1, desc1 := createTestTx(nil, 1) + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + + err := g.AddTransaction(tx1, desc1) + require.NoError(t, err) + err = g.AddTransaction(tx2, desc2) + require.NoError(t, err) + + // Attempting to add an edge that would create a cycle (tx2 -> tx1 + // when tx1 -> tx2 already exists) must be rejected to maintain the + // DAG invariant. + err = g.AddEdge(*tx2.Hash(), *tx1.Hash()) + require.ErrorIs(t, err, ErrCycleDetected) +} + +// TestClusterManagement verifies that transactions are correctly grouped +// into clusters (connected components) and that clusters merge when a +// transaction bridges two previously separate clusters. This is essential +// for RBF validation where replacement transactions must improve the fee +// rate of the entire cluster. +func TestClusterManagement(t *testing.T) { + g := New(DefaultConfig()) + + // Create two independent transaction chains. 
Each chain forms its own + // cluster since there are no spending relationships between them. + tx1, desc1 := createTestTx(nil, 2) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 2, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx(nil, 2) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + tx4, desc4 := createTestTx( + []wire.OutPoint{{Hash: *tx3.Hash(), Index: 0}}, 2, + ) + require.NoError(t, g.AddTransaction(tx4, desc4)) + + require.Equal(t, 2, g.GetClusterCount()) + + // Create a transaction that spends from both chains. This bridges + // the two clusters, forcing them to merge into a single connected + // component. + tx5Inputs := []wire.OutPoint{ + {Hash: *tx2.Hash(), Index: 0}, + {Hash: *tx4.Hash(), Index: 0}, + } + tx5, desc5 := createTestTx(tx5Inputs, 1) + require.NoError(t, g.AddTransaction(tx5, desc5)) + + require.Equal(t, 1, g.GetClusterCount()) + + // All transactions should now belong to the same cluster. + cluster1, err := g.GetCluster(*tx1.Hash()) + require.NoError(t, err) + cluster5, err := g.GetCluster(*tx5.Hash()) + require.NoError(t, err) + require.Equal(t, cluster1.ID, cluster5.ID) + require.Len(t, cluster1.Nodes, 5) +} + +// TestTRUCDetection tests version 3 transaction detection. +func TestTRUCDetection(t *testing.T) { + g := New(DefaultConfig()) + + // Create version 3 transaction. 
+ tx := wire.NewMsgTx(3) + tx.AddTxOut(wire.NewTxOut(100000, nil)) + btcTx := btcutil.NewTx(tx) + + desc := &TxDesc{ + TxHash: *btcTx.Hash(), + VirtualSize: int64(tx.SerializeSize()), + Fee: 1000, + FeePerKB: 10000, + Added: time.Now(), + } + + err := g.AddTransaction(btcTx, desc) + require.NoError(t, err) + + node, exists := g.GetNode(*btcTx.Hash()) + require.True(t, exists) + require.True(t, node.Metadata.IsTRUC) + + metrics := g.GetMetrics() + require.Equal(t, 1, metrics.TRUCCount) +} + +// TestHasTransaction tests the HasTransaction method. +func TestHasTransaction(t *testing.T) { + g := New(DefaultConfig()) + + // Add a transaction. + tx, desc := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx, desc)) + + // Test HasTransaction. + require.True(t, g.HasTransaction(*tx.Hash())) + + // Test non-existent transaction. + nonExistent := &wire.MsgTx{Version: 1} + nonExistentHash := nonExistent.TxHash() + require.False(t, g.HasTransaction(nonExistentHash)) +} + +// TestGetNodeCount tests the GetNodeCount method. +func TestGetNodeCount(t *testing.T) { + g := New(DefaultConfig()) + + // Initially should be 0. + require.Equal(t, 0, g.GetNodeCount()) + + // Add transactions. + tx1, desc1 := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + require.Equal(t, 1, g.GetNodeCount()) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + require.Equal(t, 2, g.GetNodeCount()) +} + +// TestAddEdgeErrors tests error cases in AddEdge. +func TestAddEdgeErrors(t *testing.T) { + g := New(DefaultConfig()) + + // Try to add edge between non-existent nodes. + tx1Msg := wire.NewMsgTx(1) + tx2Msg := wire.NewMsgTx(1) + hash1 := tx1Msg.TxHash() + hash2 := tx2Msg.TxHash() + + err := g.AddEdge(hash1, hash2) + require.Error(t, err) + require.Equal(t, ErrNodeNotFound, err) + + // Add one node and try to add edge. 
+ tx1, desc1 := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + err = g.AddEdge(*tx1.Hash(), hash2) + require.Error(t, err) + require.Equal(t, ErrNodeNotFound, err) + + // Add second node. + tx2, desc2 := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + // Add valid edge. + err = g.AddEdge(*tx1.Hash(), *tx2.Hash()) + require.NoError(t, err) + + // Try to add duplicate edge. + err = g.AddEdge(*tx1.Hash(), *tx2.Hash()) + require.NoError(t, err) + + // Try to create cycle. + err = g.AddEdge(*tx2.Hash(), *tx1.Hash()) + require.Error(t, err) + require.Equal(t, ErrCycleDetected, err) +} + +// TestRemoveTransactionComplex tests complex removal scenarios. +func TestRemoveTransactionComplex(t *testing.T) { + g := New(DefaultConfig()) + + // Create a chain: tx1 -> tx2 -> tx3. + tx1, desc1 := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + // Remove middle transaction (should recursively remove tx3 too). + err := g.RemoveTransaction(*tx2.Hash()) + require.NoError(t, err) + + // Verify tx3 was also removed. + require.False(t, g.HasTransaction(*tx3.Hash())) + require.False(t, g.HasTransaction(*tx2.Hash())) + + // tx1 should still exist. + require.True(t, g.HasTransaction(*tx1.Hash())) + + // Verify tx1 has no children. + node1, exists := g.GetNode(*tx1.Hash()) + require.True(t, exists) + require.Len(t, node1.Children, 0) +} + +// Property-based tests using rapid. + +// TestPropertyNoSelfLoops verifies graph never has self-loops. +func TestPropertyNoSelfLoops(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + g := New(DefaultConfig()) + + // Generate random transactions. 
+ numTxs := rapid.IntRange(1, 20).Draw(t, "numTxs") + txs := make([]*btcutil.Tx, 0, numTxs) + + for i := 0; i < numTxs; i++ { + // Randomly connect to previous transactions. + var inputs []wire.OutPoint + if len(txs) > 0 && rapid.Bool().Draw(t, "hasParent") { + parentIdx := rapid.IntRange(0, len(txs)-1).Draw( + t, "parentIdx", + ) + inputs = append(inputs, wire.OutPoint{ + Hash: *txs[parentIdx].Hash(), + Index: 0, + }) + } + + tx, desc := createTestTx(inputs, 1) + + err := g.AddTransaction(tx, desc) + if err == nil { + txs = append(txs, tx) + } else { + // It's OK if we get duplicate transaction + // errors in random tests + require.ErrorIs(t, err, ErrTransactionExists) + } + } + + // Property: No node should have itself as parent or child. + for hash, node := range g.nodes { + require.Nil( + t, node.Parents[hash], + "Node has itself as parent", + ) + require.Nil( + t, node.Children[hash], + "Node has itself as child", + ) + } + }) +} + +// TestPropertyParentChildSymmetry verifies parent-child relationships are +// symmetric. +func TestPropertyParentChildSymmetry(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + g := New(DefaultConfig()) + + // Generate random DAG. + numTxs := rapid.IntRange(2, 15).Draw(t, "numTxs") + txs := make([]*btcutil.Tx, numTxs) + + for i := 0; i < numTxs; i++ { + var inputs []wire.OutPoint + if i > 0 && rapid.Bool().Draw(t, "hasParent") { + // Connect to random previous transaction. + parentIdx := rapid.IntRange(0, i-1).Draw( + t, "parentIdx", + ) + inputs = append(inputs, wire.OutPoint{ + Hash: *txs[parentIdx].Hash(), + Index: 0, + }) + } + + tx, desc := createTestTx(inputs, 1) + txs[i] = tx + g.AddTransaction(tx, desc) + } + + // Property: If A is parent of B, then B is child of A. 
+ for _, node := range g.nodes { + for _, parent := range node.Parents { + require.NotNil( + t, parent.Children[node.TxHash], + "Parent-child relationship not "+ + "symmetric", + ) + } + for _, child := range node.Children { + require.NotNil( + t, child.Parents[node.TxHash], + "Child-parent relationship not "+ + "symmetric", + ) + } + } + }) +} + +// TestPropertyMetricsConsistency verifies metrics are consistent with actual +// state. +func TestPropertyMetricsConsistency(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + g := New(DefaultConfig()) + + // Generate random operations. + ops := rapid.IntRange(10, 50).Draw(t, "ops") + addedTxs := make(map[chainhash.Hash]*btcutil.Tx) + + for i := 0; i < ops; i++ { + op := rapid.IntRange(0, 2).Draw(t, "operation") + + switch op { + case 0, 1: + var inputs []wire.OutPoint + if len(addedTxs) > 0 && + rapid.Bool().Draw(t, "hasParent") { + + // Pick random parent from added txs. + for hash := range addedTxs { + inputs = append(inputs, wire.OutPoint{ + Hash: hash, + Index: 0, + }) + break + } + } + + tx, desc := createTestTx(inputs, 1) + if err := g.AddTransaction(tx, desc); err == nil { + addedTxs[*tx.Hash()] = tx + } + + case 2: + if len(addedTxs) > 0 { + // Remove random transaction. + for hash := range addedTxs { + if err := g.RemoveTransaction(hash); err == nil { + delete(addedTxs, hash) + } + break + } + } + } + } + + // Property: Metrics should match actual counts. + metrics := g.GetMetrics() + require.Equal(t, len(g.nodes), metrics.NodeCount, + "Node count mismatch") + + // Count actual edges. + actualEdges := 0 + for _, node := range g.nodes { + actualEdges += len(node.Children) + } + require.Equal(t, actualEdges, metrics.EdgeCount, + "Edge count mismatch") + }) +} + +// TestPropertyPackageTopology verifies package topology calculations. +func TestPropertyPackageTopology(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + g := New(DefaultConfig()) + + // Create a simple 1P1C package. 
+ parent, parentDesc := createTestTx(nil, 1) + child, childDesc := createTestTx( + []wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1) + + g.AddTransaction(parent, parentDesc) + g.AddTransaction(child, childDesc) + + // Try to identify packages. + packages, err := g.IdentifyPackages() + require.NoError(t, err) + + // Property: 1P1C package should be identified correctly. + found1P1C := false + for _, pkg := range packages { + if pkg.Type == PackageType1P1C { + found1P1C = true + require.Len(t, pkg.Transactions, 2) + require.Equal(t, 1, pkg.Topology.MaxDepth) + require.Equal(t, 1, pkg.Topology.MaxWidth) + require.True(t, pkg.Topology.IsLinear) + require.True(t, pkg.Topology.IsTree) + } + } + require.True(t, found1P1C, "1P1C package not identified") + }) +} + +// TestPropertyIteratorCompleteness verifies iterators visit all expected nodes. +func TestPropertyIteratorCompleteness(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + g := New(DefaultConfig()) + + // Create random transactions. + numTxs := rapid.IntRange(5, 20).Draw(t, "numTxs") + expectedHashes := make(map[chainhash.Hash]bool) + + for i := 0; i < numTxs; i++ { + tx, desc := createTestTx(nil, 1) + g.AddTransaction(tx, desc) + expectedHashes[*tx.Hash()] = true + } + + // Property: BFS iterator should visit all nodes. + visitedBFS := make(map[chainhash.Hash]bool) + for node := range g.Iterate( + WithOrder(TraversalBFS), + ) { + visitedBFS[node.TxHash] = true + } + require.Equal(t, expectedHashes, visitedBFS, + "BFS didn't visit all nodes") + + // Property: DFS iterator should visit all nodes. 
+ visitedDFS := make(map[chainhash.Hash]bool) + for node := range g.Iterate( + WithOrder(TraversalDFS), + ) { + visitedDFS[node.TxHash] = true + } + require.Equal(t, expectedHashes, visitedDFS, + "DFS didn't visit all nodes") + }) +} + +// TestAddTransactionReverseOrder tests the "Find children that spend this +// transaction" code path by adding transactions in reverse topological order +// (children before parents). This tests that the spentBy index is maintained +// correctly even when children are added before their parents, which is +// critical for Bitcoin mempool behavior. +func TestAddTransactionReverseOrder(t *testing.T) { + g := New(DefaultConfig()) + + // Create parent transaction but don't add it yet. + parent, parentDesc := createTestTx(nil, 2) + parentHash := parent.Hash() + + // Create child that spends output 0 of parent. + child1, child1Desc := createTestTx( + []wire.OutPoint{{Hash: *parentHash, Index: 0}}, 1, + ) + + // Create child that spends output 1 of parent. + child2, child2Desc := createTestTx( + []wire.OutPoint{{Hash: *parentHash, Index: 1}}, 1, + ) + + // Add children FIRST (children arrive before parent in mempool). This + // should populate spentBy index even though parent doesn't exist yet. + require.NoError(t, g.AddTransaction(child1, child1Desc)) + require.NoError(t, g.AddTransaction(child2, child2Desc)) + + // Verify children exist but have no parents yet (orphaned). + child1Node, _ := g.GetNode(*child1.Hash()) + child2Node, _ := g.GetNode(*child2.Hash()) + require.Len(t, child1Node.Parents, 0, "child1 should have no parents yet") + require.Len(t, child2Node.Parents, 0, "child2 should have no parents yet") + + // Now add parent - should trigger "Find children that spend this + // transaction". This reconnects the orphaned children to their parent. + require.NoError(t, g.AddTransaction(parent, parentDesc)) + + // Verify parent is connected to both children. 
+ parentNode, exists := g.GetNode(h(parentHash)) + require.True(t, exists) + require.Len(t, parentNode.Children, 2, "parent should have 2 children") + require.NotNil(t, parentNode.Children[*child1.Hash()]) + require.NotNil(t, parentNode.Children[*child2.Hash()]) + + // Verify children are now connected to parent. + child1Node, _ = g.GetNode(*child1.Hash()) + child2Node, _ = g.GetNode(*child2.Hash()) + require.Len(t, child1Node.Parents, 1, "child1 should have 1 parent") + require.Len(t, child2Node.Parents, 1, "child2 should have 1 parent") + require.NotNil(t, child1Node.Parents[*parentHash]) + require.NotNil(t, child2Node.Parents[*parentHash]) + + // Verify edge count is correct (2 edges). + metrics := g.GetMetrics() + require.Equal(t, 2, metrics.EdgeCount, "should have 2 edges") +} + +// TestRemoveTransactionWithChildren tests edge removal from children +// in removeTransactionUnsafe. +func TestRemoveTransactionWithChildren(t *testing.T) { + g := New(DefaultConfig()) + + // Create a diamond pattern: + // tx1 + // / \ + // tx2 tx3 + // \ / + // tx4 + tx1, desc1 := createTestTx(nil, 2) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 2, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 2, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + tx4Inputs := []wire.OutPoint{ + {Hash: *tx2.Hash(), Index: 0}, + {Hash: *tx3.Hash(), Index: 0}, + } + tx4, desc4 := createTestTx(tx4Inputs, 1) + require.NoError(t, g.AddTransaction(tx4, desc4)) + + // Verify initial state. + require.Equal(t, 4, g.GetNodeCount()) + require.Equal(t, 4, g.GetMetrics().EdgeCount) + + // Remove tx2 (should also remove tx4 as descendant). + require.NoError(t, g.RemoveTransaction(*tx2.Hash())) + + // Verify tx2 and tx4 are removed (tx4 loses both parents). 
+ require.False(t, g.HasTransaction(*tx2.Hash())) + require.False(t, g.HasTransaction(*tx4.Hash())) + + // Verify tx1 and tx3 still exist. + require.True(t, g.HasTransaction(*tx1.Hash())) + require.True(t, g.HasTransaction(*tx3.Hash())) + + // Verify tx3 has its parent edge to tx1 still intact but child edge to + // tx4 is gone. + tx3Node, _ := g.GetNode(*tx3.Hash()) + require.Len(t, tx3Node.Parents, 1) + require.NotNil(t, tx3Node.Parents[*tx1.Hash()]) + require.Len( + t, tx3Node.Children, 0, "tx4 removed so tx3 should have "+ + "no children", + ) + + // Verify tx1 has only tx3 as child now. + tx1Node, _ := g.GetNode(*tx1.Hash()) + require.Len(t, tx1Node.Children, 1) + require.NotNil(t, tx1Node.Children[*tx3.Hash()]) +} + +// TestRemoveTransactionPackageCleanup tests package cleanup in +// removeTransactionUnsafe. +func TestRemoveTransactionPackageCleanup(t *testing.T) { + g := New(DefaultConfig()) + + // Create a 1P1C package. + parent, parentDesc := createTestTx(nil, 1) + child, childDesc := createTestTx( + []wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1, + ) + + require.NoError(t, g.AddTransaction(parent, parentDesc)) + require.NoError(t, g.AddTransaction(child, childDesc)) + + // Identify and assign packages. + packages, err := g.IdentifyPackages() + require.NoError(t, err) + require.Len(t, packages, 1) + + // Assign package IDs. + for i := range packages { + pkg := packages[i] + for hash := range pkg.Transactions { + node, _ := g.GetNode(hash) + node.Metadata.PackageID = &pkg.ID + } + g.indexes.packages[pkg.ID] = pkg + for hash := range pkg.Transactions { + g.indexes.nodeToPackage[hash] = pkg.ID + } + } + + // Verify package exists. + require.Equal(t, 1, g.GetMetrics().PackageCount) + + // Remove parent (should remove child and clean up package). + require.NoError(t, g.RemoveTransaction(*parent.Hash())) + + // Verify both transactions are gone. 
+ require.False(t, g.HasTransaction(*parent.Hash())) + require.False(t, g.HasTransaction(*child.Hash())) + + // Verify package is cleaned up. + require.Equal(t, 0, g.GetMetrics().PackageCount) + require.Len(t, g.indexes.packages, 0) + require.Len(t, g.indexes.nodeToPackage, 0) +} + +// TestPropertyTransactionChains tests random transaction chains with +// add/remove operations. +func TestPropertyTransactionChains(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + g := New(DefaultConfig()) + + // Generate random chain length. + chainLength := rapid.IntRange(2, 10).Draw(t, "chainLength") + + // Build a chain of transactions in packages. + chain := make([]*btcutil.Tx, chainLength) + var prevHash *chainhash.Hash + + for i := 0; i < chainLength; i++ { + var inputs []wire.OutPoint + if prevHash != nil { + inputs = append(inputs, wire.OutPoint{ + Hash: *prevHash, + Index: 0, + }) + } + + tx, desc := createTestTx(inputs, 1) + err := g.AddTransaction(tx, desc) + require.NoError(t, err) + + chain[i] = tx + prevHash = tx.Hash() + } + + // Verify all transactions were added. + require.Equal(t, chainLength, g.GetNodeCount()) + + // Property: Edge count should be chainLength - 1. + expectedEdges := chainLength - 1 + require.Equal(t, expectedEdges, g.GetMetrics().EdgeCount, + "Edge count mismatch after adding chain") + + // Now remove transactions in random order and verify + // invariants. + indices := make([]int, chainLength) + for i := range indices { + indices[i] = i + } + removalOrder := rapid.Permutation(indices).Draw( + t, "removalOrder", + ) + + for _, idx := range removalOrder { + tx := chain[idx] + if !g.HasTransaction(*tx.Hash()) { + // Already removed as descendant. + continue + } + + // Remember state before removal. + nodeCountBefore := g.GetNodeCount() + edgeCountBefore := g.GetMetrics().EdgeCount + + // Remove transaction. + err := g.RemoveTransaction(*tx.Hash()) + require.NoError(t, err) + + // Verify transaction is gone. 
+ require.False(t, g.HasTransaction(*tx.Hash())) + + // Verify node count decreased. + require.Less(t, g.GetNodeCount(), nodeCountBefore) + + // Property: No dangling edges. + for _, node := range g.nodes { + for parentHash, parent := range node.Parents { + require.NotNil( + t, parent.Children[node.TxHash], + "Dangling edge: parent %v "+ + "doesn't have child %v", + parentHash, node.TxHash, + ) + } + for childHash, child := range node.Children { + require.NotNil( + t, child.Parents[node.TxHash], + "Dangling edge: child %v "+ + "doesn't have parent %v", + childHash, node.TxHash, + ) + } + } + + // Property: Edge count is consistent. + actualEdges := 0 + for _, node := range g.nodes { + actualEdges += len(node.Children) + } + require.Equal( + t, actualEdges, g.GetMetrics().EdgeCount, + "Edge count inconsistent after removal", + ) + require.LessOrEqual( + t, g.GetMetrics().EdgeCount, edgeCountBefore, + "Edge count should not increase after removal", + ) + } + + // Property: After removing all, graph should be empty. + require.Equal(t, 0, g.GetNodeCount(), "Graph should be empty") + require.Equal( + t, 0, g.GetMetrics().EdgeCount, "No edges should remain", + ) + }) +} + +// TestIterateClusterWithFilter tests cluster iteration with filters. +func TestIterateClusterWithFilter(t *testing.T) { + g := New(DefaultConfig()) + + // Create a cluster with multiple transactions. + tx1, desc1 := createTestTx(nil, 2) + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 2, + ) + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + + require.NoError(t, g.AddTransaction(tx1, desc1)) + require.NoError(t, g.AddTransaction(tx2, desc2)) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + // All should be in same cluster. + require.Equal(t, 1, g.GetClusterCount()) + + // Test iteration with filter (only nodes with 2 outputs). 
+ var filtered []*TxGraphNode + filterFunc := func(node *TxGraphNode) bool { + return len(node.Tx.MsgTx().TxOut) == 2 + } + + for node := range g.Iterate( + WithOrder(TraversalCluster), WithStartNode(tx1.Hash()), + WithFilter(filterFunc), + ) { + filtered = append(filtered, node) + } + + // Should only get tx1 and tx2 (both have 2 outputs), not tx3 (1 + // output). + require.Len(t, filtered, 2) + + // Test iterating all clusters with nil start node. + var allInClusters []*TxGraphNode + for node := range g.Iterate(WithOrder(TraversalCluster)) { + allInClusters = append(allInClusters, node) + } + require.Len( + t, allInClusters, 3, "should iterate all nodes when start "+ + "is nil", + ) + + // Test early exit from yield function. + count := 0 + for range g.Iterate(WithOrder(TraversalCluster)) { + count++ + if count >= 2 { + break // Exit early + } + } + require.Equal( + t, 2, count, "should stop iteration when yield returns false", + ) +} + +// TestPropertyComplexPackageChains tests complex transaction graphs in +// packages. +func TestPropertyComplexPackageChains(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + g := New(DefaultConfig()) + gen := newTxGenerator() + + // Create a more complex graph structure. + numRoots := rapid.IntRange(1, 3).Draw(t, "numRoots") + roots := make([]*btcutil.Tx, numRoots) + + // Create root transactions. + for i := 0; i < numRoots; i++ { + tx, desc := gen.createTx(nil, 2) + require.NoError(t, g.AddTransaction(tx, desc)) + roots[i] = tx + } + + // Create children that may spend from multiple roots. + numChildren := rapid.IntRange(1, 5).Draw(t, "numChildren") + allTxs := make([]*btcutil.Tx, 0, numRoots+numChildren) + allTxHashes := make([]chainhash.Hash, 0, numRoots+numChildren) + allTxs = append(allTxs, roots...) + for _, tx := range roots { + allTxHashes = append(allTxHashes, *tx.Hash()) + } + + for i := 0; i < numChildren; i++ { + // Randomly select parents from existing transactions. 
+ maxParents := 2 + if len(allTxs) < maxParents { + maxParents = len(allTxs) + } + numParents := rapid.IntRange(1, maxParents).Draw( + t, "numParents", + ) + inputs := make([]wire.OutPoint, 0, numParents) + + txIndices := make([]int, len(allTxs)) + for j := range txIndices { + txIndices[j] = j + } + selectedParents := rapid.Permutation(txIndices).Draw( + t, "parentPerm", + ) + for j := 0; j < numParents; j++ { + parentTx := allTxs[selectedParents[j]] + numOutputs := len(parentTx.MsgTx().TxOut) + if numOutputs > 0 { + // Use different output indexes to + // avoid conflicts. + outputIdx := uint32(j % numOutputs) + inputs = append(inputs, wire.OutPoint{ + Hash: *parentTx.Hash(), + Index: outputIdx, + }) + } + } + + if len(inputs) > 0 { + tx, desc := gen.createTx(inputs, 1) + if err := g.AddTransaction(tx, desc); err == nil { + allTxs = append(allTxs, tx) + allTxHashes = append( + allTxHashes, *tx.Hash(), + ) + } + } + } + + initialNodeCount := g.GetNodeCount() + initialEdgeCount := g.GetMetrics().EdgeCount + + // Property: Removing transactions maintains graph invariants. + if len(allTxHashes) > 0 && g.GetNodeCount() > 0 { + // Remove a random transaction using the stored hash + // value. + removeIdx := rapid.IntRange(0, len(allTxHashes)-1).Draw( + t, "removeIdx", + ) + hashToRemove := allTxHashes[removeIdx] + + // Debug: check if hash is in graph. + hasIt := g.HasTransaction(hashToRemove) + t.Logf("Hash to remove: %v, HasTransaction=%v", + hashToRemove, hasIt) + + // Only proceed if transaction is actually in the graph. + if !hasIt { + // Transaction may not have been added due to + // conflicts. + t.Logf("Skipping because transaction not in " + + "graph") + return + } + + t.Logf("About to call RemoveTransaction") + err := g.RemoveTransaction(hashToRemove) + t.Logf("RemoveTransaction returned: %v", err) + + require.NoError( + t, err, "failed to remove transaction %v", + hashToRemove, + ) + + // Property: Node count decreased. 
+ require.LessOrEqual( + t, g.GetNodeCount(), initialNodeCount, + ) + + // Property: Edge count decreased or stayed same. + require.LessOrEqual( + t, g.GetMetrics().EdgeCount, initialEdgeCount, + ) + + // Property: All remaining edges are valid. + for _, node := range g.nodes { + for _, parent := range node.Parents { + require.NotNil( + t, parent.Children[node.TxHash], + "Invalid parent-child relationship", + ) + } + for _, child := range node.Children { + require.NotNil( + t, child.Parents[node.TxHash], + "Invalid child-parent relationship", + ) + } + } + + // Property: Metrics are consistent. + actualEdges := 0 + for _, node := range g.nodes { + actualEdges += len(node.Children) + } + require.Equal(t, actualEdges, g.GetMetrics().EdgeCount) + } + }) +} + +// TestWithIncludeStart tests the WithIncludeStart iterator option. +func TestWithIncludeStart(t *testing.T) { + t.Parallel() + + g := New(DefaultConfig()) + gen := newTxGenerator() + + // Create a chain of transactions: tx1 -> tx2 -> tx3 -> tx4. + tx1, desc1 := gen.createTx(nil, 1) + tx2, desc2 := gen.createTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + tx3, desc3 := gen.createTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + tx4, desc4 := gen.createTx( + []wire.OutPoint{{Hash: *tx3.Hash(), Index: 0}}, 1, + ) + + require.NoError(t, g.AddTransaction(tx1, desc1)) + require.NoError(t, g.AddTransaction(tx2, desc2)) + require.NoError(t, g.AddTransaction(tx3, desc3)) + require.NoError(t, g.AddTransaction(tx4, desc4)) + + // Test with IncludeStart = false (default behavior). Starting from + // tx2, we should get descendants without tx2 itself. 
+ visitedExclude := slices.Collect(g.Iterate( + WithStartNode(tx2.Hash()), + WithOrder(TraversalDescendants), + WithIncludeStart(false), + )) + require.Len( + t, visitedExclude, 2, "Should visit 2 descendants (tx3, tx4)", + ) + excludeHashes := slicesMap( + visitedExclude, + func(n *TxGraphNode) chainhash.Hash { return n.TxHash }, + ) + + require.Contains(t, excludeHashes, *tx3.Hash()) + require.Contains(t, excludeHashes, *tx4.Hash()) + require.NotContains( + t, excludeHashes, *tx2.Hash(), "Should not include "+ + "starting node", + ) + + // Test with IncludeStart = true. Starting from tx2, we should get + // descendants including tx2 itself. + visitedInclude := slices.Collect(g.Iterate( + WithStartNode(tx2.Hash()), + WithOrder(TraversalDescendants), + WithIncludeStart(true), + )) + + require.Len( + t, visitedInclude, 3, "Should visit 3 nodes (tx2, tx3, tx4)", + ) + + includeHashes := slicesMap( + visitedInclude, + func(n *TxGraphNode) chainhash.Hash { return n.TxHash }, + ) + + require.Contains( + t, includeHashes, *tx2.Hash(), "Should include starting node", + ) + require.Contains(t, includeHashes, *tx3.Hash()) + require.Contains(t, includeHashes, *tx4.Hash()) + + // Test with ancestors direction and IncludeStart = true. + visitedAncestors := slices.Collect(g.Iterate( + WithStartNode(tx3.Hash()), + WithOrder(TraversalAncestors), + WithIncludeStart(true), + )) + require.Len( + t, visitedAncestors, 3, "Should visit 3 nodes (tx3, tx2, tx1)", + ) + ancestorHashes := slicesMap( + visitedAncestors, + func(n *TxGraphNode) chainhash.Hash { return n.TxHash }, + ) + + require.Contains( + t, ancestorHashes, *tx3.Hash(), "Should include starting node", + ) + require.Contains(t, ancestorHashes, *tx2.Hash()) + require.Contains(t, ancestorHashes, *tx1.Hash()) +} + +// slicesMap maps a slice using a transform function. 
+func slicesMap[T any, U any](slice []T, fn func(T) U) []U { + result := make([]U, len(slice)) + for i, v := range slice { + result[i] = fn(v) + } + return result +} + +// TestValidatePackageErrors tests error cases in ValidatePackage. +func TestValidatePackageErrors(t *testing.T) { + t.Parallel() + + g := New(DefaultConfig()) + + // Test nil package. + err := g.ValidatePackage(nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil package") + + // Test empty package. + emptyPkg := &TxPackage{ + ID: PackageID{Type: PackageType1P1C}, + Transactions: make(map[chainhash.Hash]*TxGraphNode), + } + err = g.ValidatePackage(emptyPkg) + require.Error(t, err) + require.Contains(t, err.Error(), "empty package") + + // Test package too large. + cfg := DefaultConfig() + largePkg := &TxPackage{ + ID: PackageID{Type: PackageType1P1C}, + Transactions: make(map[chainhash.Hash]*TxGraphNode), + } + // Create more transactions than allowed by config. + maxSize := cfg.MaxPackageSize + gen := newTxGenerator() + for i := 0; i < maxSize+1; i++ { + tx, _ := gen.createTx(nil, 1) + node := &TxGraphNode{ + TxHash: *tx.Hash(), + } + largePkg.Transactions[*tx.Hash()] = node + } + err = g.ValidatePackage(largePkg) + require.Error(t, err) + require.Contains(t, err.Error(), "package too large") +} + +// TestRemoveTransactionWithTRUC tests removal of TRUC transactions. +func TestRemoveTransactionWithTRUC(t *testing.T) { + t.Parallel() + + g := New(DefaultConfig()) + gen := newTxGenerator() + + // Create a TRUC transaction. + tx1, desc1 := gen.createTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + // Mark it as TRUC by directly accessing the internal structure. + node, exists := g.GetNode(*tx1.Hash()) + require.True(t, exists) + node.Metadata.IsTRUC = true + g.indexes.trucTxs[*tx1.Hash()] = node + + // Update metrics. + oldTrucCount := g.GetMetrics().TRUCCount + atomic.AddInt32(&g.metrics.trucCount, 1) + + // Verify TRUC transaction is tracked. 
+ require.Equal(t, oldTrucCount+1, g.GetMetrics().TRUCCount) + _, exists = g.indexes.trucTxs[*tx1.Hash()] + require.True(t, exists, "TRUC transaction should be in index") + + // Remove the TRUC transaction. + err := g.RemoveTransaction(*tx1.Hash()) + require.NoError(t, err) + + // Verify TRUC index is cleaned up. + _, exists = g.indexes.trucTxs[*tx1.Hash()] + require.False( + t, exists, "TRUC transaction should be removed from index", + ) + require.Equal( + t, oldTrucCount, g.GetMetrics().TRUCCount, + "TRUC count should be decremented", + ) +} + +// TestRemoveTransactionWithEphemeral tests removal of ephemeral transactions. +func TestRemoveTransactionWithEphemeral(t *testing.T) { + t.Parallel() + + g := New(DefaultConfig()) + gen := newTxGenerator() + + // Create an ephemeral transaction. + tx1, desc1 := gen.createTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + // Mark it as ephemeral by directly accessing the internal structure. + node, exists := g.GetNode(*tx1.Hash()) + require.True(t, exists) + node.Metadata.IsEphemeral = true + g.indexes.ephemeralTxs[*tx1.Hash()] = node + + // Update metrics. + oldEphemeralCount := g.GetMetrics().EphemeralCount + atomic.AddInt32(&g.metrics.ephemeralCount, 1) + + // Verify ephemeral transaction is tracked. + require.Equal(t, oldEphemeralCount+1, g.GetMetrics().EphemeralCount) + _, exists = g.indexes.ephemeralTxs[*tx1.Hash()] + require.True(t, exists, "Ephemeral transaction should be in index") + + // Remove the ephemeral transaction. + err := g.RemoveTransaction(*tx1.Hash()) + require.NoError(t, err) + + // Verify ephemeral index is cleaned up. 
+ _, exists = g.indexes.ephemeralTxs[*tx1.Hash()] + require.False( + t, exists, "Ephemeral transaction should be removed from index", + ) + require.Equal( + t, oldEphemeralCount, g.GetMetrics().EphemeralCount, + "Ephemeral count should be decremented", + ) +} + +// TestRemoveTransactionNoCascade tests that RemoveTransactionNoCascade +// properly removes a transaction without cascading to its children, and cleans +// up edges. +func TestRemoveTransactionNoCascade(t *testing.T) { + t.Parallel() + + g := New(DefaultConfig()) + gen := newTxGenerator() + + // Create a structure where a child has TWO parents: + // parent1 parent2 + // \ / + // child + // + // When we remove parent1 with NoCascade, the child should remain + // (still has parent2), but the edge from parent1 to child must be + // cleaned up. + parent1, parent1Desc := gen.createTx(nil, 1) + require.NoError(t, g.AddTransaction(parent1, parent1Desc)) + + parent2, parent2Desc := gen.createTx(nil, 1) + require.NoError(t, g.AddTransaction(parent2, parent2Desc)) + + // Create a child that spends from BOTH parents. + childInputs := []wire.OutPoint{ + {Hash: *parent1.Hash(), Index: 0}, + {Hash: *parent2.Hash(), Index: 0}, + } + child, childDesc := gen.createTx(childInputs, 1) + require.NoError(t, g.AddTransaction(child, childDesc)) + + // Verify initial state: child has 2 parents. + childNode, _ := g.GetNode(*child.Hash()) + require.Len(t, childNode.Parents, 2) + require.NotNil(t, childNode.Parents[*parent1.Hash()]) + require.NotNil(t, childNode.Parents[*parent2.Hash()]) + + parent1Node, _ := g.GetNode(*parent1.Hash()) + require.Len(t, parent1Node.Children, 1) + require.NotNil(t, parent1Node.Children[*child.Hash()]) + + initialEdgeCount := g.GetMetrics().EdgeCount + require.Equal(t, 2, initialEdgeCount) + + // Remove parent1 without cascade (simulating confirmation). Child + // should remain because it still has parent2. 
+	err := g.RemoveTransactionNoCascade(*parent1.Hash())
+	require.NoError(t, err)
+
+	// Verify parent1 is gone.
+	require.False(t, g.HasTransaction(*parent1.Hash()))
+
+	// Verify child still exists (still has parent2).
+	require.True(t, g.HasTransaction(*child.Hash()))
+
+	// Verify parent2 still exists.
+	require.True(t, g.HasTransaction(*parent2.Hash()))
+
+	// CRITICAL: Verify that the child's parent1 reference is cleaned
+	// up.
+	childNode, _ = g.GetNode(*child.Hash())
+	require.Len(
+		t, childNode.Parents, 1, "child should have 1 parent remaining",
+	)
+	require.Nil(
+		t, childNode.Parents[*parent1.Hash()],
+		"parent1 should not exist in child's parent map",
+	)
+	require.NotNil(
+		t, childNode.Parents[*parent2.Hash()],
+		"parent2 should still exist in child's parent map",
+	)
+
+	// Verify edge count is decremented correctly (1 edge removed).
+	finalEdgeCount := g.GetMetrics().EdgeCount
+	require.Equal(
+		t, initialEdgeCount-1, finalEdgeCount,
+		"should have removed 1 edge",
+	)
+}
+
+// TestAddToClusterWhenClusterDoesNotExist tests the path where addToCluster
+// is called with a cluster ID that doesn't exist.
+func TestAddToClusterWhenClusterDoesNotExist(t *testing.T) {
+	t.Parallel()
+
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	// Create two transactions that will form a cluster.
+	tx1, desc1 := gen.createTx(nil, 1)
+	tx2, desc2 := gen.createTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1)
+
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	// Get the cluster ID for tx1.
+	clusterID := g.indexes.nodeToCluster[*tx1.Hash()]
+	require.NotEqual(t, ClusterID(0), clusterID, "tx1 should be in a cluster")
+
+	// Now manually delete the cluster from the index to simulate the
+	// scenario where addToCluster is called with a non-existent cluster.
+	delete(g.indexes.clusters, clusterID)
+
+	// Create a new transaction.
+ tx3, desc3 := gen.createTx(nil, 1) + require.NoError(t, g.AddTransaction(tx3, desc3)) + node3, exists := g.GetNode(*tx3.Hash()) + require.True(t, exists) + + // Manually call addToCluster with the deleted cluster ID. This + // should trigger the !exists path and create a new cluster. + oldClusterCount := g.GetMetrics().ClusterCount + g.addToCluster(node3, clusterID) + + // Verify a new cluster was created. + require.Equal( + t, oldClusterCount+1, g.GetMetrics().ClusterCount, + "New cluster should be created", + ) + newClusterID := g.indexes.nodeToCluster[*tx3.Hash()] + require.NotEqual(t, ClusterID(0), newClusterID, "tx3 should be in a cluster") + + // Verify the cluster exists and contains tx3. + cluster, exists := g.indexes.clusters[newClusterID] + require.True(t, exists, "New cluster should exist") + require.Equal(t, 1, cluster.Size, "Cluster should have 1 node") + require.Contains(t, cluster.Nodes, *tx3.Hash(), "Cluster should contain tx3") +} + diff --git a/mempool/txgraph/interfaces.go b/mempool/txgraph/interfaces.go new file mode 100644 index 0000000000..693074632b --- /dev/null +++ b/mempool/txgraph/interfaces.go @@ -0,0 +1,683 @@ +package txgraph + +import ( + "iter" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +// TxDesc contains transaction metadata for graph nodes. +// This is a simplified version to avoid circular mempool dependencies while +// still providing the information needed for ancestor/descendant calculations, +// fee rate analysis, and package validation. +type TxDesc struct { + // TxHash is the transaction identifier used for graph lookups and + // relationship tracking. + TxHash chainhash.Hash + + // VirtualSize is used to calculate ancestor/descendant size limits and + // to compute effective fee rates for package evaluation. 
+ VirtualSize int64 + + // Fee is tracked to enable package fee calculations and to determine + // whether fee-based policies are satisfied. + Fee int64 + + // FeePerKB enables sorting and filtering transactions by fee rate, + // which is critical for block template construction and RBF logic. + FeePerKB int64 + + // Added tracks insertion time to enable time-based expiration and to + // provide ordering for transactions with identical fee rates. + Added time.Time +} + +// PackageType represents the type of transaction package. +type PackageType uint8 + +const ( + // PackageTypeUnknown represents an unknown package type. + PackageTypeUnknown PackageType = iota + + // PackageTypeStandard represents a standard package. + PackageTypeStandard + + // PackageType1P1C represents a one-parent-one-child package. + PackageType1P1C + + // PackageTypeTRUC represents a TRUC (v3) constrained package. + PackageTypeTRUC + + // PackageTypeEphemeral represents a package with ephemeral dust. + PackageTypeEphemeral +) + +// ClusterID uniquely identifies a connected component of transactions. +type ClusterID uint64 + +// PackageID uniquely identifies a transaction package. +type PackageID struct { + // Hash identifies the root transaction of the package, which serves + // as the canonical identifier since all packages are rooted at a + // specific transaction. + Hash chainhash.Hash + + // Type distinguishes between package types (1P1C, TRUC, ephemeral) + // since the same root transaction could theoretically belong to + // multiple package interpretations. + Type PackageType +} + +// FeeratePoint represents a point on a feerate diagram. +type FeeratePoint struct { + // CumulativeSize tracks the total size up to this point, enabling + // efficient comparison of feerate diagrams for RBF validation. + CumulativeSize int64 + + // CumulativeFee tracks the total fees up to this point, used to + // compute marginal fee rates and incentive compatibility. 
+ CumulativeFee int64 + + // Feerate stores the marginal feerate at this point for quick + // comparisons without recomputing from cumulative values. + Feerate int64 +} + +// PackageTopology describes the shape of a package. +type PackageTopology struct { + // MaxDepth tracks the longest ancestor chain in the package, enabling + // enforcement of depth-based limits like TRUC's single-parent rule. + MaxDepth int + + // MaxWidth tracks the maximum number of siblings at any level, + // enabling detection of fan-out patterns. + MaxWidth int + + // TotalNodes counts transactions in the package for quick size checks + // without iterating the transaction map. + TotalNodes int + + // IsLinear indicates a simple chain structure (A->B->C) which enables + // optimizations for 1P1C packages and simpler validation logic. + IsLinear bool + + // IsTree indicates no transaction has multiple parents (no diamond + // patterns), which simplifies fee rate calculations and ensures + // unambiguous ancestor relationships. + IsTree bool +} + +// TxEdge represents metadata about an edge. +type TxEdge struct { + // OutPoints identifies which specific outputs are being spent in this + // relationship, enabling detection of conflicts and double-spends. + OutPoints []wire.OutPoint + + // Value tracks the total satoshi amount flowing through this edge, + // enabling economic analysis of transaction relationships. + Value int64 + + // Created records when this edge was established, useful for + // time-based analysis and debugging. + Created time.Time +} + +// GraphMetrics provides statistics about the transaction graph. +type GraphMetrics struct { + // NodeCount tracks the total number of transactions in the graph for + // capacity planning and monitoring. + NodeCount int + + // EdgeCount tracks the total number of parent-child relationships, + // indicating graph connectivity and complexity. 
+ EdgeCount int + + // PackageCount tracks identified packages for relay and mining + // optimization monitoring. + PackageCount int + + // TRUCCount tracks v3 transactions for monitoring adoption of the TRUC + // policy and ensuring topology restrictions are enforced. + TRUCCount int + + // EphemeralCount tracks transactions with ephemeral dust, which require + // special handling to ensure dust is always spent. + EphemeralCount int + + // MaxAncestors tracks the largest ancestor set size, helping identify + // transactions approaching policy limits. + MaxAncestors int + + // MaxDescendants tracks the largest descendant set size for policy + // limit monitoring and potential eviction candidates. + MaxDescendants int + + // AveragePackageSize provides insight into typical package complexity + // for resource planning and optimization. + AveragePackageSize float64 + + // ClusterCount tracks the number of connected components, indicating + // mempool fragmentation and potential for cluster-based eviction. + ClusterCount int +} + +// TxGraphNode represents a single transaction in the graph. +type TxGraphNode struct { + // TxHash enables O(1) lookups in maps without dereferencing Tx. + TxHash chainhash.Hash + + // Tx provides access to inputs and outputs for validation and edge + // creation. + Tx *btcutil.Tx + + // TxDesc stores fee and size information needed for policy decisions. + TxDesc *TxDesc + + // Parents maps to transactions that this transaction spends outputs + // from. Using a map enables O(1) parent existence checks during graph + // traversal and cycle detection. + Parents map[chainhash.Hash]*TxGraphNode + + // Children maps to transactions that spend this transaction's outputs. + // Map structure allows efficient child removal during eviction without + // scanning slices. + Children map[chainhash.Hash]*TxGraphNode + + // cachedMetrics stores expensive-to-compute graph properties to avoid + // repeated traversals during policy checks. 
The cache is invalidated + // when ancestors or descendants change. + cachedMetrics struct { + // AncestorCount enables quick checks against BIP 125 limits. + AncestorCount int32 + + // DescendantCount enforces mempool policy limits efficiently. + DescendantCount int32 + + // AncestorSize tracks cumulative size for package limit checks. + AncestorSize int64 + + // DescendantSize enables fast descendant limit validation. + DescendantSize int64 + + // AncestorFees supports CPFP calculations. + AncestorFees int64 + + // DescendantFees enables descendant fee rate computations. + DescendantFees int64 + + // LastUpdated allows cache invalidation based on graph changes. + LastUpdated time.Time + } + + // Metadata holds feature-specific flags and relationships that don't + // affect core graph structure but enable specialized processing. + Metadata struct { + // IsTRUC marks v3 transactions for topology validation. + IsTRUC bool + + // IsEphemeral identifies transactions with dust outputs that + // must be spent in the same package. + IsEphemeral bool + + // PackageID associates this transaction with its package for + // group validation and eviction. + PackageID *PackageID + + // ClusterID groups connected transactions for RBF and CPFP + // conflict detection. + ClusterID ClusterID + + // AddedTime enables time-based eviction policies. + AddedTime time.Time + } +} + +// TxCluster represents a connected component in the graph. +type TxCluster struct { + // ID uniquely identifies this cluster for tracking relationships across + // graph mutations. + ID ClusterID + + // Nodes stores all transactions in this connected component, enabling + // O(1) membership tests during cluster merges and splits. + Nodes map[chainhash.Hash]*TxGraphNode + + // Roots identifies transactions with no unconfirmed parents in this + // cluster. These are entry points for package evaluation and block + // template building. 
+ Roots []*TxGraphNode + + // Leaves identifies transactions with no children in this cluster. + // These are candidates for eviction when the mempool is full. + Leaves []*TxGraphNode + + // Size tracks the number of transactions for quick cluster size checks + // without iterating the Nodes map. + Size int + + // TotalFees aggregates fees across the cluster to compute effective + // fee rates for mining and eviction decisions. + TotalFees int64 + + // TotalVSize aggregates virtual sizes to enforce cluster size limits + // and to calculate cluster fee rates. + TotalVSize int64 + + // FeerateDiagram caches the feerate diagram used for RBF incentive + // compatibility checks. This is expensive to compute so we cache it. + FeerateDiagram []FeeratePoint + + // LastUpdated tracks when metrics were last computed, enabling + // invalidation when the cluster changes. + LastUpdated time.Time +} + +// TxPackage represents a set of related transactions. +type TxPackage struct { + // ID uniquely identifies this package for tracking and validation. + ID PackageID + + // Transactions stores all members of the package. Using a map enables + // efficient membership checks during package validation. + Transactions map[chainhash.Hash]*TxGraphNode + + // Root identifies the root transaction that anchors this package. + // All package types are rooted at a specific transaction. + Root *TxGraphNode + + // TotalFees aggregates fees across the package to compute effective + // package fee rates for relay and mining decisions. + TotalFees int64 + + // TotalSize aggregates sizes to enforce package size limits and to + // calculate package fee rates. + TotalSize int64 + + // FeeRate stores the computed package feerate in sats per vbyte for + // quick comparisons during relay and block template construction. + FeeRate int64 + + // Type identifies the package category (1P1C, TRUC, ephemeral) which + // determines what validation rules apply. 
+ Type PackageType + + // Topology describes the shape of the package, enabling topology-based + // validation rules like TRUC's single-child restriction. + Topology PackageTopology + + // IsValid caches the validation result to avoid repeated expensive + // validation checks during processing. + IsValid bool + + // LastValidated tracks when validation occurred, enabling cache + // invalidation if the package changes. + LastValidated time.Time +} + +// EdgePair represents a parent-child relationship. +type EdgePair struct { + // Parent is the transaction being spent from, providing context for + // graph traversal and validation. + Parent *TxGraphNode + + // Child is the transaction doing the spending, enabling forward + // traversal during descendant queries. + Child *TxGraphNode + + // Edge contains metadata about the specific outputs being spent, + // enabling detailed analysis of fund flows. + Edge *TxEdge +} + +// Graph defines the primary interface for transaction graph operations. +type Graph interface { + // AddTransaction inserts a new transaction into the graph and + // automatically creates edges to any parent transactions already in + // the graph. This enables incremental graph construction as + // transactions arrive from the network. + AddTransaction(tx *btcutil.Tx, txDesc *TxDesc) error + + // RemoveTransaction removes a transaction and recursively removes all + // descendants, maintaining graph consistency. This is used during + // block confirmations and mempool evictions to prevent orphaned + // children from remaining in the graph. + RemoveTransaction(hash chainhash.Hash) error + + // RemoveTransactionNoCascade removes only the specified transaction + // without touching descendants. This is useful when descendants will + // be explicitly handled or when the caller needs fine-grained control + // over eviction ordering. + RemoveTransactionNoCascade(hash chainhash.Hash) error + + // GetNode retrieves a transaction from the graph by hash. 
The boolean + // return indicates existence, enabling distinction between missing + // transactions and nil nodes. + GetNode(hash chainhash.Hash) (*TxGraphNode, bool) + + // HasTransaction checks if a transaction exists in the graph without + // retrieving it, enabling efficient existence checks when the node + // data isn't needed. + HasTransaction(hash chainhash.Hash) bool + + // AddEdge creates a parent-child relationship between two transactions + // that are already in the graph. This enables explicit edge management + // when transaction dependencies need to be added after initial + // insertion. + AddEdge(parent, child chainhash.Hash) error + + // RemoveEdge severs a parent-child relationship without removing the + // transactions themselves. This is useful for handling reorganizations + // where relationships change but transactions remain valid. + RemoveEdge(parent, child chainhash.Hash) error + + // GetAncestors returns all ancestor transactions up to maxDepth. + // This is used to enforce ancestor count/size limits for policy + // validation and to compute ancestor fee rates for CPFP. + GetAncestors( + hash chainhash.Hash, maxDepth int, + ) map[chainhash.Hash]*TxGraphNode + + // GetDescendants returns all descendant transactions up to maxDepth. + // This is used to enforce descendant limits and to identify all + // transactions that must be removed when evicting a parent. + GetDescendants( + hash chainhash.Hash, maxDepth int, + ) map[chainhash.Hash]*TxGraphNode + + // GetCluster retrieves the connected component containing the given + // transaction. This enables cluster-based fee rate calculations for + // mining and RBF validation. + GetCluster(hash chainhash.Hash) (*TxCluster, error) + + // GetOrphans returns transactions with unconfirmed inputs not in the + // mempool. A transaction is an orphan if it has no parents in the + // graph AND its inputs are not confirmed (as determined by the + // predicate). 
If isConfirmed is nil, all transactions with no parents + // are considered orphans. + GetOrphans(isConfirmed InputConfirmedPredicate) []*TxGraphNode + + // IdentifyPackages scans the graph to detect transaction packages + // (1P1C, TRUC, ephemeral). This enables package-aware relay and mining + // optimizations by grouping related transactions. + IdentifyPackages() ([]*TxPackage, error) + + // GetPackage retrieves a previously identified package by its root + // transaction hash. This enables efficient package lookups during + // relay validation and block template construction. + GetPackage(hash chainhash.Hash) (*TxPackage, error) + + // ValidatePackage checks if a package satisfies all type-specific + // rules (topology, size, TRUC constraints). This ensures only valid + // packages are relayed and mined. + ValidatePackage(pkg *TxPackage) error + + // Iterate returns an iterator over graph nodes using the specified + // order and filters. This enables lazy evaluation of large result sets + // without allocating memory for all matches upfront. + Iterate(opts IteratorOption) iter.Seq[*TxGraphNode] + + // IteratePairs returns an iterator over parent-child edges in the + // graph. This enables efficient edge-based analysis like conflict + // detection and fund flow tracking. + IteratePairs(opts IteratorOption) iter.Seq[EdgePair] + + // IteratePackages returns an iterator over all identified packages. + // This enables package-by-package processing during block template + // construction and relay decisions. + IteratePackages() iter.Seq[*TxPackage] + + // IterateClusters returns an iterator over connected components in the + // graph. This enables cluster-based fee rate analysis for mining and + // cluster-aware eviction policies. + IterateClusters() iter.Seq[*TxCluster] + + // IterateOrphans iterates over transactions with unconfirmed inputs + // not in the mempool. See GetOrphans for the definition of an + // orphan transaction. 
+ IterateOrphans(isConfirmed InputConfirmedPredicate) iter.Seq[*TxGraphNode] + + // GetMetrics returns comprehensive statistics about the graph. This + // enables monitoring of mempool health, capacity planning, and + // detection of unusual graph structures. + GetMetrics() GraphMetrics + + // GetNodeCount returns the number of transactions in the graph. This + // provides a quick way to check mempool size without computing full + // metrics. + GetNodeCount() int + + // GetClusterCount returns the number of connected components. This + // indicates mempool fragmentation and is useful for understanding the + // effectiveness of cluster-based optimizations. + GetClusterCount() int +} + +// TraversalOrder defines the traversal strategy for graph iteration. +type TraversalOrder uint8 + +const ( + // TraversalDefault iterates all nodes without specific order. + TraversalDefault TraversalOrder = iota + + // TraversalDFS performs depth-first search. + TraversalDFS + + // TraversalBFS performs breadth-first search. + TraversalBFS + + // TraversalTopological visits in topological order. + TraversalTopological + + // TraversalReverseTopo visits in reverse topological order. + TraversalReverseTopo + + // TraversalAncestors visits ancestors only. + TraversalAncestors + + // TraversalDescendants visits descendants only. + TraversalDescendants + + // TraversalCluster visits all transactions in the same cluster. + TraversalCluster + + // TraversalFeeRate visits in order by fee rate (high to low). + TraversalFeeRate +) + +// TraversalDirection specifies the direction of traversal. +type TraversalDirection uint8 + +const ( + // DirectionForward traverses from parents to children. + DirectionForward TraversalDirection = iota + + // DirectionBackward traverses from children to parents. + DirectionBackward + + // DirectionBoth traverses in both directions. + DirectionBoth +) + +// IteratorOption configures graph iteration behavior. 
+type IteratorOption struct { + Order TraversalOrder + MaxDepth int + Filter func(*TxGraphNode) bool + StartNode *chainhash.Hash + Direction TraversalDirection + IncludeStart bool +} + +// DefaultIteratorOption returns an IteratorOption with sensible defaults. +func DefaultIteratorOption() IteratorOption { + return IteratorOption{ + Order: TraversalDefault, + MaxDepth: -1, + Direction: DirectionForward, + IncludeStart: false, + } +} + +// IterOption is a functional option for configuring iteration. +type IterOption func(*IteratorOption) + +// WithOrder sets the traversal order. +func WithOrder(order TraversalOrder) IterOption { + return func(o *IteratorOption) { + o.Order = order + } +} + +// WithMaxDepth sets the maximum traversal depth (-1 for unlimited). +func WithMaxDepth(depth int) IterOption { + return func(o *IteratorOption) { + o.MaxDepth = depth + } +} + +// WithFilter sets a filter predicate. +func WithFilter(filter func(*TxGraphNode) bool) IterOption { + return func(o *IteratorOption) { + o.Filter = filter + } +} + +// WithStartNode sets the starting node for traversal. +func WithStartNode(hash *chainhash.Hash) IterOption { + return func(o *IteratorOption) { + o.StartNode = hash + } +} + +// WithDirection sets the traversal direction. +func WithDirection(direction TraversalDirection) IterOption { + return func(o *IteratorOption) { + o.Direction = direction + } +} + +// WithIncludeStart sets whether to include the starting node. +func WithIncludeStart(include bool) IterOption { + return func(o *IteratorOption) { + o.IncludeStart = include + } +} + +// GraphQuery provides advanced query operations. +type GraphQuery interface { + // FindTransactions searches for transactions matching the specified + // criteria. This enables complex filtering operations like finding + // all TRUC transactions above a certain fee rate. + FindTransactions(criteria TxCriteria) []*TxGraphNode + + // FindPackages searches for packages matching the specified criteria. 
+	// This enables targeted package queries like finding all valid 1P1C
+	// packages above a minimum fee rate.
+	FindPackages(criteria PackageCriteria) []*TxPackage
+
+	// FindPath searches for a dependency path between two transactions.
+	// This is useful for understanding transaction relationships and
+	// debugging unexpected dependencies.
+	FindPath(from, to *chainhash.Hash) []*TxGraphNode
+
+	// HasPath checks if a dependency path exists without computing it.
+	// This enables efficient reachability checks for cycle detection and
+	// conflict analysis.
+	HasPath(from, to *chainhash.Hash) bool
+
+	// GetTopologicalOrder returns all transactions in topological order,
+	// ensuring parents appear before children. This is essential for block
+	// template construction where dependencies must be satisfied.
+	GetTopologicalOrder() []*TxGraphNode
+
+	// DetectCycles finds circular dependencies in the graph, which should
+	// never exist but can occur due to bugs. Each inner slice represents
+	// one cycle detected in the graph.
+	DetectCycles() [][]*TxGraphNode
+
+	// GetFeerateDistribution computes the cumulative feerate diagram for
+	// all transactions. This enables analysis of mempool composition and
+	// fee rate distributions.
+	GetFeerateDistribution() []FeeratePoint
+
+	// GetPackageFeerates computes the effective fee rate for each package.
+	// This enables package-based comparisons for relay and mining
+	// decisions.
+	GetPackageFeerates() map[PackageID]int64
+}
+
+// TxCriteria defines criteria for finding transactions.
+type TxCriteria struct {
+	// MinFeeRate filters for transactions at or above this fee rate,
+	// enabling queries for high-priority transactions.
+	MinFeeRate int64
+
+	// MaxFeeRate filters for transactions at or below this fee rate,
+	// enabling queries for low-fee transactions that may need eviction.
+ MaxFeeRate int64 + + // MinSize filters for transactions at or above this size, useful for + // identifying large transactions that consume significant mempool + // space. + MinSize int64 + + // MaxSize filters for transactions at or below this size, useful for + // finding small transactions or enforcing size limits. + MaxSize int64 + + // IsTRUC filters by v3 transaction status. Nil means don't filter, + // true means only v3, false means only non-v3. + IsTRUC *bool + + // IsEphemeral filters by ephemeral dust status. Nil means don't + // filter, enabling queries specific to ephemeral transactions. + IsEphemeral *bool + + // HasAncestors filters by ancestor presence. Nil means don't filter, + // true finds transactions with parents, false finds root transactions. + HasAncestors *bool + + // HasChildren filters by child presence. Nil means don't filter, true + // finds transactions with children, false finds leaf transactions. + HasChildren *bool +} + +// PackageCriteria defines criteria for finding packages. +type PackageCriteria struct { + // Type filters by package type (1P1C, TRUC, ephemeral), enabling + // type-specific package queries. + Type PackageType + + // MinSize filters for packages at or above this transaction count, + // useful for finding complex multi-transaction packages. + MinSize int + + // MaxSize filters for packages at or below this transaction count, + // useful for finding simple packages or enforcing limits. + MaxSize int + + // MinFeeRate filters for packages at or above this effective fee rate, + // enabling high-fee package identification. + MinFeeRate int64 + + // MaxFeeRate filters for packages at or below this effective fee rate, + // useful for low-fee package queries. + MaxFeeRate int64 + + // IsValid filters by validation status. Nil means don't filter, + // enabling queries for valid or invalid packages separately. 
+ IsValid *bool +} + +// InputConfirmedPredicate is a function that checks if a transaction input +// references a confirmed UTXO. This is used to distinguish between: +// - Orphans: transactions with unconfirmed inputs not in the mempool +// - Root transactions: transactions with confirmed inputs (not orphans) +// +// The predicate takes an outpoint and returns true if that output is confirmed +// on-chain, false if it's unconfirmed or doesn't exist. +type InputConfirmedPredicate func(outpoint wire.OutPoint) bool + diff --git a/mempool/txgraph/iterator.go b/mempool/txgraph/iterator.go new file mode 100644 index 0000000000..7ac3410b5c --- /dev/null +++ b/mempool/txgraph/iterator.go @@ -0,0 +1,653 @@ +package txgraph + +import ( + "iter" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +// Iterate returns an iterator over graph nodes. +func (g *TxGraph) Iterate(options ...IterOption) iter.Seq[*TxGraphNode] { + // Build options with defaults. + opts := DefaultIteratorOption() + for _, option := range options { + option(&opts) + } + + return func(yield func(*TxGraphNode) bool) { + g.mu.RLock() + defer g.mu.RUnlock() + + // Get starting node. + var startNode *TxGraphNode + if opts.StartNode != nil { + startNode = g.nodes[*opts.StartNode] + if startNode == nil { + return + } + } + + // Select traversal implementation. + switch opts.Order { + case TraversalDefault: + // Default to iterating all nodes. + for _, node := range g.nodes { + // Apply filter. 
+ if opts.Filter != nil && !opts.Filter(node) { + continue + } + if !yield(node) { + return + } + } + case TraversalDFS: + g.iterateDFS(startNode, opts, yield) + case TraversalBFS: + g.iterateBFS(startNode, opts, yield) + case TraversalTopological: + g.iterateTopological(opts, yield) + case TraversalReverseTopo: + g.iterateReverseTopological(opts, yield) + case TraversalAncestors: + g.iterateAncestors(startNode, opts, yield) + case TraversalDescendants: + g.iterateDescendants(startNode, opts, yield) + case TraversalCluster: + g.iterateCluster(startNode, opts, yield) + case TraversalFeeRate: + g.iterateFeeRate(opts, yield) + } + } +} + +// iterateDFS performs depth-first traversal. +func (g *TxGraph) iterateDFS(start *TxGraphNode, opts IteratorOption, yield func(*TxGraphNode) bool) { + visited := make(map[chainhash.Hash]bool) + stack := NewStack[*TxGraphNode]() + depth := make(map[chainhash.Hash]int) + + // Initialize with start node or all roots. + if start != nil { + if opts.IncludeStart { + stack.Push(start) + depth[start.TxHash] = 0 + } else { + // Mark start node as visited to prevent revisiting. + visited[start.TxHash] = true + depth[start.TxHash] = 0 + // Start with children/parents based on direction. + g.addNeighborsToStack(start, stack, depth, opts.Direction, 1) + } + } else { + // Find all root nodes (no parents). + for _, node := range g.nodes { + if len(node.Parents) == 0 { + stack.Push(node) + depth[node.TxHash] = 0 + } + } + } + + for !stack.IsEmpty() { + // Pop from stack. + node, _ := stack.Pop() + + // Skip if already visited. + if visited[node.TxHash] { + continue + } + visited[node.TxHash] = true + + // Check depth limit. + if opts.MaxDepth >= 0 && depth[node.TxHash] > opts.MaxDepth { + continue + } + + // Yield node to consumer if filter passes. + if opts.Filter == nil || opts.Filter(node) { + if !yield(node) { + return // Consumer wants to stop + } + } + + // Always add neighbors to continue traversal. 
+ nextDepth := depth[node.TxHash] + 1 + g.addNeighborsToStack(node, stack, depth, opts.Direction, nextDepth) + } +} + +// iterateBFS performs breadth-first traversal. +func (g *TxGraph) iterateBFS(start *TxGraphNode, opts IteratorOption, yield func(*TxGraphNode) bool) { + visited := make(map[chainhash.Hash]bool) + queue := NewQueue[*TxGraphNode]() + depth := make(map[chainhash.Hash]int) + + // Initialize with start node or all roots. + if start != nil { + if opts.IncludeStart { + queue.Enqueue(start) + depth[start.TxHash] = 0 + } else { + // Mark start node as visited to prevent revisiting. + visited[start.TxHash] = true + depth[start.TxHash] = 0 + // Start with children/parents based on direction. + g.addNeighborsToQueue(start, queue, depth, opts.Direction, 1) + } + } else { + for _, node := range g.nodes { + if len(node.Parents) == 0 { + queue.Enqueue(node) + depth[node.TxHash] = 0 + } + } + } + + for !queue.IsEmpty() { + // Dequeue from front. + node, _ := queue.Dequeue() + + // Skip if already visited. + if visited[node.TxHash] { + continue + } + visited[node.TxHash] = true + + // Check depth limit. + if opts.MaxDepth >= 0 && depth[node.TxHash] > opts.MaxDepth { + continue + } + + // Yield to consumer if filter passes. + if opts.Filter == nil || opts.Filter(node) { + if !yield(node) { + return + } + } + + // Always add neighbors to continue traversal. + nextDepth := depth[node.TxHash] + 1 + g.addNeighborsToQueue(node, queue, depth, opts.Direction, nextDepth) + } +} + +// iterateTopological performs topological traversal. +func (g *TxGraph) iterateTopological(opts IteratorOption, yield func(*TxGraphNode) bool) { + // Calculate in-degrees. + inDegree := make(map[chainhash.Hash]int) + for hash, node := range g.nodes { + inDegree[hash] = len(node.Parents) + } + + // Find all nodes with no parents. + queue := NewQueue[*TxGraphNode]() + for hash, degree := range inDegree { + if degree == 0 { + queue.Enqueue(g.nodes[hash]) + } + } + + // Process in topological order. 
+ for !queue.IsEmpty() { + node, _ := queue.Dequeue() + + // Apply filter. + if opts.Filter != nil && !opts.Filter(node) { + continue + } + + // Yield to consumer. + if !yield(node) { + return + } + + // Update in-degrees and queue children. + for _, child := range node.Children { + inDegree[child.TxHash]-- + if inDegree[child.TxHash] == 0 { + queue.Enqueue(child) + } + } + } +} + +// iterateReverseTopological performs reverse topological traversal. +func (g *TxGraph) iterateReverseTopological(opts IteratorOption, yield func(*TxGraphNode) bool) { + // Calculate out-degrees. + outDegree := make(map[chainhash.Hash]int) + for hash, node := range g.nodes { + outDegree[hash] = len(node.Children) + } + + // Find all nodes with no children. + queue := NewQueue[*TxGraphNode]() + for hash, degree := range outDegree { + if degree == 0 { + queue.Enqueue(g.nodes[hash]) + } + } + + // Process in reverse topological order. + for !queue.IsEmpty() { + node, _ := queue.Dequeue() + + // Apply filter. + if opts.Filter != nil && !opts.Filter(node) { + continue + } + + // Yield to consumer. + if !yield(node) { + return + } + + // Update out-degrees and queue parents. + for _, parent := range node.Parents { + outDegree[parent.TxHash]-- + if outDegree[parent.TxHash] == 0 { + queue.Enqueue(parent) + } + } + } +} + +// iterateAncestors iterates over all ancestors of a node. +func (g *TxGraph) iterateAncestors(start *TxGraphNode, opts IteratorOption, yield func(*TxGraphNode) bool) { + if start == nil { + return + } + + visited := make(map[chainhash.Hash]bool) + queue := NewQueue[*TxGraphNode]() + depth := make(map[chainhash.Hash]int) + + // Include start node if requested. + if opts.IncludeStart { + queue.Enqueue(start) + depth[start.TxHash] = 0 + } else { + // Start with parents only. 
+ for _, parent := range start.Parents { + if parent != nil { + queue.Enqueue(parent) + depth[parent.TxHash] = 1 + } + } + } + + for !queue.IsEmpty() { + node, _ := queue.Dequeue() + + // Use the hash from the node itself. + nodeHash := node.TxHash + if visited[nodeHash] { + continue + } + visited[nodeHash] = true + + // Check depth limit. + nodeDepth := depth[node.TxHash] + if opts.MaxDepth >= 0 && nodeDepth > opts.MaxDepth { + continue + } + + // Yield to consumer if filter passes. + if opts.Filter == nil || opts.Filter(node) { + if !yield(node) { + return + } + } + + // Add parents to continue traversal (depth starts at 1 for parents). + nextDepth := depth[node.TxHash] + 1 + for _, parent := range node.Parents { + if !visited[parent.TxHash] { + queue.Enqueue(parent) + depth[parent.TxHash] = nextDepth + } + } + } +} + +// iterateDescendants iterates over all descendants of a node. +func (g *TxGraph) iterateDescendants(start *TxGraphNode, opts IteratorOption, yield func(*TxGraphNode) bool) { + if start == nil { + return + } + + visited := make(map[chainhash.Hash]bool) + queue := NewQueue[*TxGraphNode]() + depth := make(map[chainhash.Hash]int) + + // Include start node if requested. + if opts.IncludeStart { + queue.Enqueue(start) + depth[start.TxHash] = 0 + } else { + // Start with children only. + for _, child := range start.Children { + queue.Enqueue(child) + depth[child.TxHash] = 1 + } + } + + for !queue.IsEmpty() { + node, _ := queue.Dequeue() + + if visited[node.TxHash] { + continue + } + visited[node.TxHash] = true + + // Check depth limit. + if opts.MaxDepth >= 0 && depth[node.TxHash] > opts.MaxDepth { + continue + } + + // Yield to consumer if filter passes. + if opts.Filter == nil || opts.Filter(node) { + if !yield(node) { + return + } + } + + // Add children to continue traversal (depth starts at 1 for children). 
+ nextDepth := depth[node.TxHash] + 1 + for _, child := range node.Children { + if !visited[child.TxHash] { + queue.Enqueue(child) + depth[child.TxHash] = nextDepth + } + } + } +} + +// iterateCluster iterates over all nodes in the same cluster. +func (g *TxGraph) iterateCluster(start *TxGraphNode, opts IteratorOption, yield func(*TxGraphNode) bool) { + if start == nil { + // Iterate all clusters. + for _, cluster := range g.indexes.clusters { + for _, node := range cluster.Nodes { + if opts.Filter != nil && !opts.Filter(node) { + continue + } + if !yield(node) { + return + } + } + } + return + } + + // Find cluster for start node. + clusterID, exists := g.indexes.nodeToCluster[start.TxHash] + if !exists { + return + } + + cluster, exists := g.indexes.clusters[clusterID] + if !exists { + return + } + + // Iterate nodes in cluster. + for _, node := range cluster.Nodes { + if opts.Filter != nil && !opts.Filter(node) { + continue + } + if !yield(node) { + return + } + } +} + +// iterateFeeRate iterates nodes ordered by fee rate using a max-heap to +// efficiently yield transactions in descending order of fee rate. +func (g *TxGraph) iterateFeeRate( + opts IteratorOption, + yield func(*TxGraphNode) bool, +) { + + pq := NewPriorityQueue(func(a, b *TxGraphNode) bool { + return a.TxDesc.FeePerKB > b.TxDesc.FeePerKB + }, len(g.nodes)) + + for _, node := range g.nodes { + if opts.Filter != nil && !opts.Filter(node) { + continue + } + pq.Push(node) + } + + for !pq.IsEmpty() { + node, _ := pq.Pop() + if !yield(node) { + return + } + } +} + +// IteratePairs returns an iterator over parent-child pairs. +func (g *TxGraph) IteratePairs(options ...IterOption) iter.Seq[EdgePair] { + // Build options with defaults. 
+ opts := DefaultIteratorOption() + for _, option := range options { + option(&opts) + } + + return func(yield func(EdgePair) bool) { + g.mu.RLock() + defer g.mu.RUnlock() + + visited := make(map[string]bool) // Track visited edges + + // Iterate directly over all nodes in the graph. + for _, node := range g.nodes { + // Apply filter if specified. + if opts.Filter != nil && !opts.Filter(node) { + continue + } + + for _, child := range node.Children { + // Create unique edge key. + edgeKey := node.TxHash.String() + "->" + child.TxHash.String() + if visited[edgeKey] { + continue + } + visited[edgeKey] = true + + // Create edge metadata. + edge := &TxEdge{ + OutPoints: g.findOutpoints(node, child), + Created: node.Metadata.AddedTime, + } + + pair := EdgePair{ + Parent: node, + Child: child, + Edge: edge, + } + + if !yield(pair) { + return + } + } + } + } +} + +// IteratePackages returns an iterator over packages. +func (g *TxGraph) IteratePackages() iter.Seq[*TxPackage] { + return func(yield func(*TxPackage) bool) { + g.mu.RLock() + defer g.mu.RUnlock() + + for _, pkg := range g.indexes.packages { + if !yield(pkg) { + return + } + } + } +} + +// IterateClusters returns an iterator over clusters. +func (g *TxGraph) IterateClusters() iter.Seq[*TxCluster] { + return func(yield func(*TxCluster) bool) { + g.mu.RLock() + defer g.mu.RUnlock() + + for _, cluster := range g.indexes.clusters { + if !yield(cluster) { + return + } + } + } +} + +// IterateOrphans returns an iterator over orphan transactions. +// A transaction is considered an orphan if: +// 1. It has no parents in the graph (len(Parents) == 0) +// 2. AND at least one of its inputs is unconfirmed (as determined by the +// isConfirmed predicate) +// +// If isConfirmed is nil, all transactions with no parents are yielded. This +// is useful when the caller cannot determine chain state and wants to +// identify all potentially orphaned transactions. 
+func (g *TxGraph) IterateOrphans( + isConfirmed InputConfirmedPredicate, +) iter.Seq[*TxGraphNode] { + return func(yield func(*TxGraphNode) bool) { + g.mu.RLock() + defer g.mu.RUnlock() + + for _, node := range g.nodes { + // Orphans by definition have no parents in the mempool, since + // they're waiting for unconfirmed parent transactions that + // haven't arrived yet. + if len(node.Parents) > 0 { + continue + } + + // Without chain state access, we conservatively treat all + // parentless transactions as potential orphans. + if isConfirmed == nil { + if !yield(node) { + return + } + continue + } + + // Distinguish between true orphans (waiting for unconfirmed + // parents) and root transactions (spending confirmed UTXOs). + // A transaction is only an orphan if at least one input + // references an unconfirmed output not in the mempool. + hasUnconfirmedInput := false + for _, txIn := range node.Tx.MsgTx().TxIn { + if !isConfirmed(txIn.PreviousOutPoint) { + hasUnconfirmedInput = true + break + } + } + + // Root transactions with all confirmed inputs are not orphans, + // as they're not waiting for any parent transactions. + if hasUnconfirmedInput { + if !yield(node) { + return + } + } + } + } +} + +// addNeighborsToStack adds neighbors to DFS stack based on direction. 
+func (g *TxGraph) addNeighborsToStack( + node *TxGraphNode, + stack *Stack[*TxGraphNode], + depth map[chainhash.Hash]int, + direction TraversalDirection, + nextDepth int, +) { + switch direction { + case DirectionForward: + for _, child := range node.Children { + if _, exists := depth[child.TxHash]; !exists { + stack.Push(child) + depth[child.TxHash] = nextDepth + } + } + case DirectionBackward: + for _, parent := range node.Parents { + if _, exists := depth[parent.TxHash]; !exists { + stack.Push(parent) + depth[parent.TxHash] = nextDepth + } + } + case DirectionBoth: + for _, child := range node.Children { + if _, exists := depth[child.TxHash]; !exists { + stack.Push(child) + depth[child.TxHash] = nextDepth + } + } + for _, parent := range node.Parents { + if _, exists := depth[parent.TxHash]; !exists { + stack.Push(parent) + depth[parent.TxHash] = nextDepth + } + } + } +} + +// addNeighborsToQueue adds neighbors to BFS queue based on direction. +func (g *TxGraph) addNeighborsToQueue( + node *TxGraphNode, + queue *Queue[*TxGraphNode], + depth map[chainhash.Hash]int, + direction TraversalDirection, + nextDepth int, +) { + switch direction { + case DirectionForward: + for _, child := range node.Children { + if _, exists := depth[child.TxHash]; !exists { + queue.Enqueue(child) + depth[child.TxHash] = nextDepth + } + } + case DirectionBackward: + for _, parent := range node.Parents { + if _, exists := depth[parent.TxHash]; !exists { + queue.Enqueue(parent) + depth[parent.TxHash] = nextDepth + } + } + case DirectionBoth: + for _, child := range node.Children { + if _, exists := depth[child.TxHash]; !exists { + queue.Enqueue(child) + depth[child.TxHash] = nextDepth + } + } + for _, parent := range node.Parents { + if _, exists := depth[parent.TxHash]; !exists { + queue.Enqueue(parent) + depth[parent.TxHash] = nextDepth + } + } + } +} + +// findOutpoints finds the outpoints connecting parent to child. 
+func (g *TxGraph) findOutpoints(parent, child *TxGraphNode) []wire.OutPoint { + var outpoints []wire.OutPoint + + for _, txIn := range child.Tx.MsgTx().TxIn { + if txIn.PreviousOutPoint.Hash == parent.TxHash { + outpoints = append(outpoints, txIn.PreviousOutPoint) + } + } + + return outpoints +} \ No newline at end of file diff --git a/mempool/txgraph/iterator_test.go b/mempool/txgraph/iterator_test.go new file mode 100644 index 0000000000..95bd6b74ed --- /dev/null +++ b/mempool/txgraph/iterator_test.go @@ -0,0 +1,522 @@ +package txgraph + +import ( + "slices" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/require" +) + +// TestIteratorTraversalMethods verifies that different traversal orders +// (ancestors, descendants, fee rate, topological) produce correct results. +// The iterator patterns are critical for mempool operations: ancestor/ +// descendant traversal for policy limits, fee rate sorting for mining, and +// topological ordering for block template construction. +func TestIteratorTraversalMethods(t *testing.T) { + g := New(DefaultConfig()) + + // Create a diamond-shaped DAG to test multiple traversal paths. This + // structure ensures that traversal algorithms handle nodes with + // multiple parents and children correctly. 
+ // tx1 + // / \ + // tx2 tx3 + // \ / + // tx4 + tx1, desc1 := createTestTx(nil, 2) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + tx4, desc4 := createTestTx([]wire.OutPoint{ + {Hash: *tx2.Hash(), Index: 0}, + {Hash: *tx3.Hash(), Index: 0}, + }, 1) + require.NoError(t, g.AddTransaction(tx4, desc4)) + + t.Run("Ancestors", func(t *testing.T) { + tx4Node, exists := g.GetNode(*tx4.Hash()) + require.True(t, exists, "tx4 node should exist") + require.Len(t, tx4Node.Parents, 2, "tx4 should have 2 parents") + + // Ancestor traversal should find all transactions that tx4 + // depends on, including both direct parents (tx2, tx3) and the + // grandparent (tx1). This traversal is used for enforcing BIP + // 125 ancestor count and size limits. + count := 0 + ancestors := make(map[chainhash.Hash]bool) + for node := range g.Iterate( + WithOrder(TraversalAncestors), + WithStartNode(tx4.Hash()), + ) { + t.Logf("Visited ancestor: %v", node.TxHash) + ancestors[node.TxHash] = true + count++ + } + require.Equal(t, 3, count, "Should have 3 ancestors") + require.True(t, ancestors[*tx1.Hash()]) + require.True(t, ancestors[*tx2.Hash()]) + require.True(t, ancestors[*tx3.Hash()]) + }) + + t.Run("Descendants", func(t *testing.T) { + // Descendant traversal should find all transactions that depend + // on tx1, either directly or transitively. This is used for RBF + // validation where we need to compute the total fees of all + // transactions that would be evicted by a replacement. 
+ count := 0 + descendants := make(map[chainhash.Hash]bool) + for node := range g.Iterate( + WithOrder(TraversalDescendants), + WithStartNode(tx1.Hash()), + ) { + descendants[node.TxHash] = true + count++ + } + require.Equal(t, 3, count, "Should have 3 descendants") + require.True(t, descendants[*tx2.Hash()]) + require.True(t, descendants[*tx3.Hash()]) + require.True(t, descendants[*tx4.Hash()]) + }) + + t.Run("FeeRate", func(t *testing.T) { + // Fee rate traversal should iterate transactions in descending + // fee rate order. This is used for block template construction + // where miners want to include the highest fee transactions + // first to maximize revenue. + var lastFeeRate int64 = -1 + for node := range g.Iterate( + WithOrder(TraversalFeeRate), + ) { + if lastFeeRate == -1 { + lastFeeRate = node.TxDesc.FeePerKB + } else { + require.LessOrEqual( + t, node.TxDesc.FeePerKB, lastFeeRate, + ) + lastFeeRate = node.TxDesc.FeePerKB + } + } + }) + + t.Run("ReverseTopological", func(t *testing.T) { + // Reverse topological order visits children before parents. This + // is useful for transaction removal where we must remove + // descendants before ancestors to avoid dangling references. + var order []chainhash.Hash + for node := range g.Iterate( + WithOrder(TraversalReverseTopo), + ) { + order = append(order, node.TxHash) + } + require.Equal(t, 4, len(order)) + + // Verify that children appear before their parents in the + // traversal order, which is required for safe removal. 
+ tx4Idx := -1 + tx2Idx := -1 + tx3Idx := -1 + tx1Idx := -1 + for i, hash := range order { + switch hash { + case *tx4.Hash(): + tx4Idx = i + case *tx2.Hash(): + tx2Idx = i + case *tx3.Hash(): + tx3Idx = i + case *tx1.Hash(): + tx1Idx = i + } + } + require.Less(t, tx4Idx, tx2Idx) + require.Less(t, tx4Idx, tx3Idx) + require.Less(t, tx2Idx, tx1Idx) + require.Less(t, tx3Idx, tx1Idx) + }) + + t.Run("Cluster", func(t *testing.T) { + // Cluster traversal should visit all transactions in the same + // connected component. Since all transactions in this test are + // connected via spending relationships, they form one cluster. + count := 0 + for range g.Iterate( + WithOrder(TraversalCluster), + WithStartNode(tx1.Hash()), + ) { + count++ + } + require.Equal(t, 4, count, "All txs should be in same cluster") + }) + + t.Run("MaxDepth", func(t *testing.T) { + // Depth limiting should stop traversal after a specified number + // of levels. This prevents unbounded traversal in deep + // dependency chains. + count := 0 + for range g.Iterate( + WithOrder(TraversalDescendants), + WithStartNode(tx1.Hash()), + WithMaxDepth(1), + ) { + count++ + } + require.Equal( + t, 2, count, "Should only get direct children with "+ + "depth 1", + ) + }) + + t.Run("Filter", func(t *testing.T) { + // Filter predicates allow selective iteration based on node + // properties. This is useful for finding specific transaction + // patterns without traversing the entire graph. + count := 0 + nodes := make(map[chainhash.Hash]bool) + for node := range g.Iterate( + WithOrder(TraversalBFS), + WithFilter(func(n *TxGraphNode) bool { + // Select only transactions with exactly one + // parent. This pattern identifies simple chains + // without merge points. 
+ return len(n.Parents) == 1 + }), + ) { + nodes[node.TxHash] = true + count++ + } + require.Equal(t, 2, count, "Should only get tx2 and tx3") + require.True(t, nodes[*tx2.Hash()], "Should include tx2") + require.True(t, nodes[*tx3.Hash()], "Should include tx3") + }) + + t.Run("DirectionBackward", func(t *testing.T) { + // Backward direction traverses toward ancestors (parents). This + // is useful for finding all transactions that must confirm + // before a given transaction can be included in a block. + count := 0 + for range g.Iterate( + WithOrder(TraversalDFS), + WithStartNode(tx4.Hash()), + WithDirection(DirectionBackward), + ) { + count++ + } + require.Equal(t, 3, count, "Should traverse backward to parents") + }) + + t.Run("DirectionBoth", func(t *testing.T) { + // Bidirectional traversal visits both parents and children. + // This is useful for analyzing the local neighborhood of a + // transaction without full graph traversal. + count := 0 + visited := make(map[chainhash.Hash]bool) + for node := range g.Iterate( + WithOrder(TraversalBFS), + WithStartNode(tx2.Hash()), + WithDirection(DirectionBoth), + WithMaxDepth(1), + ) { + visited[node.TxHash] = true + count++ + } + require.Equal( + t, 2, count, "Should traverse to both parent and child", + ) + require.True( + t, visited[*tx1.Hash()], "Should visit tx1 (parent)", + ) + require.True( + t, visited[*tx4.Hash()], "Should visit tx4 (child)", + ) + }) +} + +// TestIteratePairs verifies that edge pair iteration produces all parent- +// child relationships in the graph. This is useful for analyzing spending +// patterns and computing aggregate statistics about transaction dependencies. 
+func TestIteratePairs(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Build a simple two-edge chain tx1 -> tx2 -> tx3 so the expected
+	// edge set is known exactly.
+	tx1, desc1 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	edges := make(map[string]bool)
+	for pair := range g.IteratePairs() {
+		edgeKey := pair.Parent.TxHash.String() + "->" +
+			pair.Child.TxHash.String()
+		edges[edgeKey] = true
+
+		// Each edge pair should include metadata about which outputs
+		// are being spent, enabling detailed dependency analysis.
+		require.NotNil(t, pair.Edge)
+		require.NotEmpty(t, pair.Edge.OutPoints)
+	}
+
+	// Exactly one edge per parent->child spend relationship is expected
+	// for the two-edge chain built above.
+	require.Len(t, edges, 2)
+	require.True(t, edges[tx1.Hash().String()+"->"+tx2.Hash().String()])
+	require.True(t, edges[tx2.Hash().String()+"->"+tx3.Hash().String()])
+}
+
+// TestIteratePackages verifies that package iteration produces all
+// identified transaction packages in the graph. Package iteration enables
+// efficient processing of transaction groups for package relay policies and
+// mining optimization.
+func TestIteratePackages(t *testing.T) {
+	// Basic 1P1C (one parent, one child) detection works without an
+	// analyzer, making this test independent of protocol-specific logic.
+	g := New(DefaultConfig())
+
+	// Create multiple independent 1P1C packages, which is the most common
+	// pattern for CPFP (child pays for parent) transactions.
+	for i := 0; i < 3; i++ {
+		parent, parentDesc := createTestTx(nil, 1)
+		require.NoError(t, g.AddTransaction(parent, parentDesc))
+
+		child, childDesc := createTestTx(
+			[]wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1)
+		require.NoError(t, g.AddTransaction(child, childDesc))
+	}
+
+	identifiedPkgs, err := g.IdentifyPackages()
+	require.NoError(t, err)
+	require.Len(t, identifiedPkgs, 3, "Should identify 3 packages")
+
+	packageCount := 0
+	for pkg := range g.IteratePackages() {
+		packageCount++
+		require.NotEmpty(t, pkg.ID)
+		require.NotEmpty(t, pkg.Transactions)
+		require.NotNil(t, pkg.Topology)
+	}
+
+	require.Equal(t, 3, packageCount, "Should have 3 packages")
+}
+
+// TestIterateClusters tests cluster iteration.
+func TestIterateClusters(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Create 2 separate chains.
+	// Chain 1.
+	tx1, desc1 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+	tx2, desc2 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	// Chain 2.
+	tx3, desc3 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+	tx4, desc4 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx3.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx4, desc4))
+
+	// Count clusters.
+	clusterCount := 0
+	totalNodes := 0
+	for cluster := range g.IterateClusters() {
+		clusterCount++
+		totalNodes += len(cluster.Nodes)
+		require.NotEmpty(t, cluster.ID)
+	}
+
+	require.Equal(t, 2, clusterCount, "Should have 2 clusters")
+	require.Equal(t, 4, totalNodes, "Should have 4 total nodes")
+}
+
+// TestIterateClusterFromNode tests cluster iteration from specific node.
+func TestIterateClusterFromNode(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Create a cluster of 3 connected transactions.
+	tx1, desc1 := createTestTx(nil, 2)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 1)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// Create separate transaction.
+	tx4, desc4 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx4, desc4))
+
+	// Iterate cluster containing tx2.
+	count := 0
+	seenHashes := make(map[chainhash.Hash]bool)
+	for node := range g.Iterate(
+		WithOrder(TraversalCluster),
+		WithStartNode(tx2.Hash()),
+	) {
+		count++
+		seenHashes[node.TxHash] = true
+	}
+
+	require.Equal(t, 3, count, "Should see 3 nodes in cluster")
+	require.True(t, seenHashes[*tx1.Hash()])
+	require.True(t, seenHashes[*tx2.Hash()])
+	require.True(t, seenHashes[*tx3.Hash()])
+	require.False(t, seenHashes[*tx4.Hash()], "Should not see tx4")
+}
+
+// TestTraversalDefault tests the default traversal order.
+func TestTraversalDefault(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Add some transactions.
+	tx1, desc1 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// Test default traversal (should iterate all nodes).
+	count := 0
+	nodes := make(map[string]bool)
+	for node := range g.Iterate(
+		WithOrder(TraversalDefault),
+	) {
+		count++
+		nodes[node.TxHash.String()] = true
+	}
+
+	require.Equal(t, 3, count, "Should iterate over all 3 nodes")
+	require.True(t, nodes[tx1.Hash().String()])
+	require.True(t, nodes[tx2.Hash().String()])
+	require.True(t, nodes[tx3.Hash().String()])
+}
+
+// TestIterateTopological tests topological iteration.
+func TestIterateTopological(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Create a DAG:
+	// tx1 -> tx2 -> tx3
+	//    \-> tx4
+	tx1, desc1 := createTestTx(nil, 2)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx([]wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	tx4, desc4 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 1)
+	require.NoError(t, g.AddTransaction(tx4, desc4))
+
+	// Test topological order.
+	var order []string
+	for node := range g.Iterate(
+		WithOrder(TraversalTopological),
+	) {
+		order = append(order, node.TxHash.String())
+	}
+
+	require.Equal(t, 4, len(order))
+
+	// tx1 must come before tx2 and tx4.
+	tx1Idx := slices.Index(order, tx1.Hash().String())
+	tx2Idx := slices.Index(order, tx2.Hash().String())
+	tx3Idx := slices.Index(order, tx3.Hash().String())
+	tx4Idx := slices.Index(order, tx4.Hash().String())
+
+	require.Less(t, tx1Idx, tx2Idx)
+	require.Less(t, tx1Idx, tx4Idx)
+	require.Less(t, tx2Idx, tx3Idx)
+}
+
+// TestIterateClusterComplete tests complete cluster iteration.
+func TestIterateClusterComplete(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Create a cluster with multiple nodes.
+	tx1, desc1 := createTestTx(nil, 2)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 1)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// Create separate cluster.
+	tx4, desc4 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx4, desc4))
+
+	// Test cluster iteration from tx2.
+	count := 0
+	nodes := make(map[string]bool)
+	for node := range g.Iterate(
+		WithOrder(TraversalCluster),
+		WithStartNode(tx2.Hash()),
+	) {
+		count++
+		nodes[node.TxHash.String()] = true
+	}
+
+	require.Equal(t, 3, count, "Should find 3 nodes in cluster")
+	require.True(t, nodes[tx1.Hash().String()])
+	require.True(t, nodes[tx2.Hash().String()])
+	require.True(t, nodes[tx3.Hash().String()])
+	require.False(t, nodes[tx4.Hash().String()], "tx4 should not be in cluster")
+}
+
+// TestAddNeighborsToStack verifies that stack neighbor addition correctly
+// adds parents, children, or both depending on traversal direction. This
+// internal helper is critical for DFS/BFS traversal correctness.
+func TestAddNeighborsToStack(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Create chain: tx1 -> tx2 -> tx3.
+	tx1, desc1 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx([]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx([]wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// Test DFS with DirectionBoth.
+	visited := make(map[string]bool)
+	for node := range g.Iterate(
+		WithOrder(TraversalDFS),
+		WithStartNode(tx2.Hash()),
+		WithDirection(DirectionBoth),
+	) {
+		visited[node.TxHash.String()] = true
+	}
+
+	// Should visit both parent (tx1) and child (tx3).
+	require.True(t, visited[tx1.Hash().String()])
+	require.True(t, visited[tx3.Hash().String()])
+	require.False(t, visited[tx2.Hash().String()]) // Start node excluded
+}
diff --git a/mempool/txgraph/mock_analyzer_test.go b/mempool/txgraph/mock_analyzer_test.go
new file mode 100644
index 0000000000..16ba77c994
--- /dev/null
+++ b/mempool/txgraph/mock_analyzer_test.go
@@ -0,0 +1,52 @@
+package txgraph
+
+import (
+	"github.com/btcsuite/btcd/wire"
+	"github.com/stretchr/testify/mock"
+)
+
+// MockPackageAnalyzer is a mock implementation of PackageAnalyzer for testing.
+// Each method records its arguments via m.Called and returns whatever value
+// the test primed with On(...).Return(...).
+type MockPackageAnalyzer struct {
+	mock.Mock
+}
+
+// IsTRUCTransaction checks if transaction is TRUC.
+func (m *MockPackageAnalyzer) IsTRUCTransaction(tx *wire.MsgTx) bool {
+	args := m.Called(tx)
+	return args.Bool(0)
+}
+
+// HasEphemeralDust checks if transaction has ephemeral dust outputs.
+func (m *MockPackageAnalyzer) HasEphemeralDust(tx *wire.MsgTx) bool {
+	args := m.Called(tx)
+	return args.Bool(0)
+}
+
+// IsZeroFee checks if transaction has zero fees.
+func (m *MockPackageAnalyzer) IsZeroFee(desc *TxDesc) bool {
+	args := m.Called(desc)
+	return args.Bool(0)
+}
+
+// ValidateTRUCPackage validates TRUC package topology rules.
+func (m *MockPackageAnalyzer) ValidateTRUCPackage(nodes []*TxGraphNode) bool {
+	args := m.Called(nodes)
+	return args.Bool(0)
+}
+
+// ValidateEphemeralPackage validates ephemeral dust package rules.
+func (m *MockPackageAnalyzer) ValidateEphemeralPackage(
+	nodes []*TxGraphNode) bool {
+
+	args := m.Called(nodes)
+	return args.Bool(0)
+}
+
+// AnalyzePackageType determines the type of package based on its structure.
+func (m *MockPackageAnalyzer) AnalyzePackageType(
+	nodes []*TxGraphNode) PackageType {
+
+	// The type assertion panics if the test primed a non-PackageType
+	// return value, surfacing mis-configured expectations immediately.
+	args := m.Called(nodes)
+	return args.Get(0).(PackageType)
+}
+
diff --git a/mempool/txgraph/orphans_test.go b/mempool/txgraph/orphans_test.go
new file mode 100644
index 0000000000..f6ef29b438
--- /dev/null
+++ b/mempool/txgraph/orphans_test.go
@@ -0,0 +1,383 @@
+package txgraph
+
+import (
+	"testing"
+
+	"github.com/btcsuite/btcd/chaincfg/chainhash"
+	"github.com/btcsuite/btcd/wire"
+	"github.com/stretchr/testify/require"
+)
+
+// TestGetOrphansNoPredicate verifies that orphan detection correctly
+// identifies transactions without in-graph parents when no external
+// confirmation predicate is provided. This ensures the mempool can
+// distinguish between transactions that form complete chains versus those
+// awaiting unknown parents.
+func TestGetOrphansNoPredicate(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	// Create a rootless transaction to serve as a graph root, simulating
+	// a transaction that spends confirmed outputs.
+	tx1, desc1 := gen.createTx(nil, 2)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	// Create a child transaction to establish that having an in-graph
+	// parent excludes a transaction from being considered an orphan.
+	tx2, desc2 := gen.createTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	// Create a transaction with a parent that doesn't exist in the graph
+	// to verify it's correctly identified as an orphan.
+	unknownHash := chainhash.Hash{}
+	for i := range unknownHash {
+		unknownHash[i] = byte(i)
+	}
+	tx3, desc3 := gen.createTx(
+		[]wire.OutPoint{{Hash: unknownHash, Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// Without a predicate, only graph topology matters: transactions
+	// without in-graph parents are orphans.
+	orphans := g.GetOrphans(nil)
+	require.Len(
+		t, orphans, 2,
+		"should have 2 parentless transactions (tx1 and tx3)",
+	)
+
+	// Build a hash set to enable efficient lookup verification of which
+	// transactions were classified as orphans.
+	orphanHashes := make(map[chainhash.Hash]bool)
+	for _, orphan := range orphans {
+		orphanHashes[orphan.TxHash] = true
+	}
+	require.True(t, orphanHashes[*tx1.Hash()])
+	require.True(t, orphanHashes[*tx3.Hash()])
+	require.False(t, orphanHashes[*tx2.Hash()])
+}
+
+// TestGetOrphansWithPredicate verifies that providing a confirmation
+// predicate allows the mempool to distinguish between transactions spending
+// confirmed versus unconfirmed outputs, which is essential for package
+// acceptance logic where only truly orphaned transactions need special
+// handling.
+func TestGetOrphansWithPredicate(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	confirmedHash := chainhash.Hash{9, 9, 9}
+
+	// Create a transaction spending a confirmed output, which should not
+	// be treated as an orphan despite having no in-graph parent.
+	tx1, desc1 := gen.createTx(
+		[]wire.OutPoint{{Hash: confirmedHash, Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	unconfirmedHash := chainhash.Hash{1, 2, 3}
+
+	// Create a transaction spending an unconfirmed output not in the
+	// graph, which should be identified as a true orphan.
+	tx2, desc2 := gen.createTx(
+		[]wire.OutPoint{{Hash: unconfirmedHash, Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	// Create a child transaction to verify that in-graph parents
+	// continue to exclude transactions from orphan status.
+	tx3, desc3 := gen.createTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// This predicate simulates blockchain state by declaring which
+	// outputs exist in the UTXO set.
+	isConfirmed := func(outpoint wire.OutPoint) bool {
+		return outpoint.Hash == confirmedHash
+	}
+
+	orphans := g.GetOrphans(isConfirmed)
+	require.Len(
+		t, orphans, 1, "only tx2 should be orphan (unconfirmed input)",
+	)
+	require.Equal(t, *tx2.Hash(), orphans[0].TxHash)
+}
+
+// TestGetOrphansAllConfirmed verifies the critical distinction that
+// transactions spending only confirmed outputs are never orphans, even when
+// they lack in-graph parents. This ensures the mempool correctly handles
+// package relay where root transactions spending confirmed outputs don't
+// need special orphan processing.
+func TestGetOrphansAllConfirmed(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	confirmedHash := chainhash.Hash{9, 9, 9}
+
+	// Create a transaction that has no in-graph parent but spends a
+	// confirmed output, testing the predicate overrides topology-based
+	// orphan detection.
+	tx1, desc1 := gen.createTx(
+		[]wire.OutPoint{{Hash: confirmedHash, Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	// The predicate declares this output as confirmed, meaning it exists
+	// in the UTXO set and doesn't need a parent transaction.
+	isConfirmed := func(outpoint wire.OutPoint) bool {
+		return outpoint.Hash == confirmedHash
+	}
+
+	// Despite lacking an in-graph parent, the confirmed input means this
+	// transaction is not awaiting any missing dependencies.
+	orphans := g.GetOrphans(isConfirmed)
+	require.Len(
+		t, orphans, 0, "tx1 should not be an orphan (confirmed input)",
+	)
+}
+
+// TestGetOrphansMultipleInputs verifies that orphan detection uses AND
+// semantics for multiple inputs: a transaction is an orphan if ANY input is
+// unconfirmed and not in the graph. This matters for package relay where
+// transactions with partially confirmed inputs still need their unconfirmed
+// parent transactions included in the package.
+func TestGetOrphansMultipleInputs(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	confirmedHash := chainhash.Hash{9, 9, 9}
+	unconfirmedHash := chainhash.Hash{1, 2, 3}
+
+	// Create a transaction with mixed input types to verify that the
+	// presence of even one unconfirmed input makes it an orphan.
+	tx1, desc1 := gen.createTx([]wire.OutPoint{
+		{Hash: confirmedHash, Index: 0},
+		{Hash: unconfirmedHash, Index: 0},
+	}, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	// The predicate only recognizes one of the two inputs as confirmed.
+	isConfirmed := func(outpoint wire.OutPoint) bool {
+		return outpoint.Hash == confirmedHash
+	}
+
+	// The transaction is an orphan because it's waiting on the
+	// unconfirmed input, regardless of the confirmed input.
+	orphans := g.GetOrphans(isConfirmed)
+	require.Len(
+		t, orphans, 1, "should be orphan due to unconfirmed input",
+	)
+	require.Equal(t, *tx1.Hash(), orphans[0].TxHash)
+}
+
+// TestIterateOrphansNoPredicate verifies that the iterator-based orphan
+// detection provides the same semantics as batch retrieval but enables
+// memory-efficient processing of large orphan sets without materializing
+// the entire collection.
+func TestIterateOrphansNoPredicate(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	// Create a rootless transaction to establish one orphan candidate.
+	tx1, desc1 := gen.createTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	// Create a child transaction to verify the iterator correctly
+	// excludes transactions with in-graph parents.
+	tx2, desc2 := gen.createTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	// Create another orphan with an unknown parent to test that the
+	// iterator finds all parentless transactions.
+	unknownHash := chainhash.Hash{1, 2, 3}
+	tx3, desc3 := gen.createTx(
+		[]wire.OutPoint{{Hash: unknownHash, Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// Collect orphans via iteration, which provides the same results as
+	// GetOrphans but allows early termination.
+	var orphans []*TxGraphNode
+	for orphan := range g.IterateOrphans(nil) {
+		orphans = append(orphans, orphan)
+	}
+	require.Len(
+		t, orphans, 2, "should iterate over 2 parentless transactions",
+	)
+
+	// Build a hash set to verify that iteration found the correct
+	// orphans and excluded the child transaction.
+	orphanHashes := make(map[chainhash.Hash]bool)
+	for _, orphan := range orphans {
+		orphanHashes[orphan.TxHash] = true
+	}
+	require.True(t, orphanHashes[*tx1.Hash()])
+	require.True(t, orphanHashes[*tx3.Hash()])
+	require.False(t, orphanHashes[*tx2.Hash()])
+}
+
+// TestIterateOrphansWithPredicate verifies that the iterator correctly
+// applies confirmation predicates during traversal, enabling efficient
+// streaming identification of transactions that need parent resolution
+// during package relay without buffering all candidates.
+func TestIterateOrphansWithPredicate(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	confirmedHash := chainhash.Hash{9, 9, 9}
+	unconfirmedHash := chainhash.Hash{1, 2, 3}
+
+	// Create a transaction spending a confirmed output to verify the
+	// predicate excludes it from iteration results.
+	tx1, desc1 := gen.createTx(
+		[]wire.OutPoint{{Hash: confirmedHash, Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	// Create a transaction spending an unconfirmed output to ensure the
+	// iterator identifies it as requiring parent resolution.
+	tx2, desc2 := gen.createTx(
+		[]wire.OutPoint{{Hash: unconfirmedHash, Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	// The predicate simulates UTXO set lookup during iteration.
+	isConfirmed := func(outpoint wire.OutPoint) bool {
+		return outpoint.Hash == confirmedHash
+	}
+
+	var orphans []*TxGraphNode
+	for orphan := range g.IterateOrphans(isConfirmed) {
+		orphans = append(orphans, orphan)
+	}
+	require.Len(
+		t, orphans, 1,
+		"should iterate over 1 orphan with unconfirmed input",
+	)
+	require.Equal(t, *tx2.Hash(), orphans[0].TxHash)
+}
+
+// TestIterateOrphansEarlyStop verifies that the iterator properly supports
+// early termination via break, which is critical for implementing efficient
+// resource-limited operations like "find first N orphans" without
+// traversing the entire graph.
+func TestIterateOrphansEarlyStop(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	// Create multiple orphan transactions to verify early termination
+	// prevents processing all of them.
+	for i := 0; i < 5; i++ {
+		tx, desc := gen.createTx(nil, 1)
+		require.NoError(t, g.AddTransaction(tx, desc))
+	}
+
+	// Break out of iteration early to demonstrate that the iterator
+	// doesn't force full traversal of the orphan set. The break relies
+	// on the range-over-func iterator honoring yield's false return.
+	count := 0
+	for range g.IterateOrphans(nil) {
+		count++
+		if count >= 2 {
+			break
+		}
+	}
+
+	require.Equal(t, 2, count, "should stop iteration after 2 orphans")
+}
+
+// TestGetOrphansEmptyGraph verifies the boundary condition that an empty
+// graph correctly reports no orphans, ensuring the detection logic handles
+// degenerate cases that may occur during mempool initialization or after
+// block acceptance clears all transactions.
+func TestGetOrphansEmptyGraph(t *testing.T) {
+	g := New(DefaultConfig())
+
+	orphans := g.GetOrphans(nil)
+	require.Len(t, orphans, 0, "empty graph should have no orphans")
+}
+
+// TestIterateOrphansEmptyGraph verifies the iterator's boundary condition
+// handling by ensuring iteration over an empty graph immediately completes
+// without yielding any values, which is essential for correct behavior in
+// newly initialized or fully cleared mempools.
+func TestIterateOrphansEmptyGraph(t *testing.T) {
+	g := New(DefaultConfig())
+
+	var orphans []*TxGraphNode
+	for orphan := range g.IterateOrphans(nil) {
+		orphans = append(orphans, orphan)
+	}
+	require.Len(t, orphans, 0, "empty graph should yield no orphans")
+}
+
+// TestGetOrphansNoOrphans verifies the case where graph topology changes
+// (transaction removal) would create orphans, but a confirmation predicate
+// indicates the missing parent is now in the blockchain. This models the
+// scenario where a block confirms a transaction that had children in the
+// mempool.
+func TestGetOrphansNoOrphans(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	// Create a parent transaction that will be removed to simulate block
+	// confirmation.
+	tx1, desc1 := gen.createTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	// Create a child transaction that will lose its in-graph parent but
+	// should not become an orphan if the parent is confirmed.
+	tx2, desc2 := gen.createTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	// Simulate block confirmation by declaring the removed transaction's
+	// outputs as now existing in the UTXO set. The closure captures
+	// tx1's hash, so it remains valid after the removal below.
+	isConfirmed := func(outpoint wire.OutPoint) bool {
+		return outpoint.Hash == *tx1.Hash()
+	}
+
+	require.NoError(t, g.RemoveTransaction(*tx1.Hash()))
+
+	orphans := g.GetOrphans(isConfirmed)
+	require.Len(
+		t, orphans, 0, "tx2 should not be orphan (confirmed input)",
+	)
+}
+
+// TestGetOrphansAllInputsConfirmed verifies that when a transaction has
+// multiple inputs, the AND semantics work in reverse: a transaction is NOT
+// an orphan if ALL inputs are confirmed. This ensures transactions
+// combining multiple confirmed outputs don't incorrectly get flagged for
+// special orphan handling.
+func TestGetOrphansAllInputsConfirmed(t *testing.T) {
+	g := New(DefaultConfig())
+	gen := newTxGenerator()
+
+	confirmedHash1 := chainhash.Hash{1, 1, 1}
+	confirmedHash2 := chainhash.Hash{2, 2, 2}
+
+	// Create a transaction spending multiple confirmed outputs to verify
+	// that satisfying all inputs excludes it from orphan status.
+	tx1, desc1 := gen.createTx([]wire.OutPoint{
+		{Hash: confirmedHash1, Index: 0},
+		{Hash: confirmedHash2, Index: 0},
+	}, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	// The predicate confirms both inputs exist in the UTXO set.
+	isConfirmed := func(outpoint wire.OutPoint) bool {
+		return outpoint.Hash == confirmedHash1 ||
+			outpoint.Hash == confirmedHash2
+	}
+
+	orphans := g.GetOrphans(isConfirmed)
+	require.Len(
+		t, orphans, 0, "should not be orphan (all inputs confirmed)",
+	)
+}
\ No newline at end of file
diff --git a/mempool/txgraph/package.go b/mempool/txgraph/package.go
new file mode 100644
index 0000000000..8b59d1eb29
--- /dev/null
+++ b/mempool/txgraph/package.go
@@ -0,0 +1,752 @@
+package txgraph
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"sort"
+	"sync/atomic"
+
+	"github.com/btcsuite/btcd/chaincfg/chainhash"
+)
+
+// IdentifyPackages scans the graph to detect and classify transaction
+// packages according to their type (1P1C, TRUC, ephemeral, standard).
+// Package classification enables the relay system to apply appropriate
+// validation rules based on BIP 431 (TRUC) and ephemeral dust policies.
+// We prioritize more specific package types (TRUC, ephemeral) before
+// checking general types (1P1C, standard) because stricter validation
+// rules must be enforced when applicable to ensure consensus
+// compatibility.
+func (g *TxGraph) IdentifyPackages() ([]*TxPackage, error) {
+	g.mu.RLock()
+	defer g.mu.RUnlock()
+
+	packages := make([]*TxPackage, 0)
+	processed := make(map[chainhash.Hash]bool)
+
+	// NOTE(review): g.nodes is a map, so root visitation order is
+	// nondeterministic; when components overlap, which package claims a
+	// shared descendant first may vary between runs — confirm whether
+	// deterministic package output is required here.
+	for hash, node := range g.nodes {
+		if processed[hash] {
+			continue
+		}
+
+		// Package roots have no unconfirmed parents, making them the
+		// starting point for package formation. Non-roots are
+		// processed as children of their parent packages.
+		if !g.isPackageRoot(node) {
+			continue
+		}
+
+		// Try to form different package types in priority order.
+		// Specific types like TRUC and ephemeral have stricter
+		// topology and fee requirements that override the more
+		// permissive standard package rules.
+		//
+		// TODO(rosabeef): revisit
+		if pkg := g.tryTRUCPackage(node); pkg != nil {
+			packages = append(packages, pkg)
+			g.markPackageProcessed(pkg, processed)
+
+		} else if pkg := g.tryEphemeralPackage(node); pkg != nil {
+			packages = append(packages, pkg)
+			g.markPackageProcessed(pkg, processed)
+
+		} else if pkg := g.try1P1CPackage(node); pkg != nil {
+			packages = append(packages, pkg)
+			g.markPackageProcessed(pkg, processed)
+
+		} else if pkg := g.tryStandardPackage(node); pkg != nil {
+
+			packages = append(packages, pkg)
+			g.markPackageProcessed(pkg, processed)
+		}
+	}
+
+	return packages, nil
+}
+
+// isPackageRoot determines whether a node serves as the starting point
+// for package formation. Roots are chosen as transactions with no
+// unconfirmed parents because this ensures we process packages from their
+// topological origin, allowing proper parent-child relationship
+// validation and preventing double-counting of descendants.
+func (g *TxGraph) isPackageRoot(node *TxGraphNode) bool {
+	// Only nodes without unconfirmed parents qualify as package roots.
+	// Nodes with parents will be included when their parent package is
+	// formed.
+	return len(node.Parents) == 0
+}
+
+// try1P1CPackage attempts to form a 1-parent-1-child package for CPFP
+// fee bumping. This package type is preferred in relay policy because it
+// allows efficient validation and ensures predictable topology. The
+// strict 1P1C relationship guarantees that fee calculations are
+// unambiguous and that the package cannot be fragmented during relay.
+func (g *TxGraph) try1P1CPackage(node *TxGraphNode) *TxPackage {
+	// Require exactly one child to maintain the 1P1C topology invariant.
+	if len(node.Children) != 1 {
+		return nil
+	}
+
+	// Extract the single child from the map.
+	var child *TxGraphNode
+	for _, c := range node.Children {
+		child = c
+		break
+	}
+
+	// Verify the child has only this parent in the mempool. Multiple
+	// unconfirmed parents would violate 1P1C topology and complicate
+	// fee calculation for package relay.
+	unconfirmedParents := 0
+	for _, parent := range child.Parents {
+		if _, exists := g.nodes[parent.TxHash]; exists {
+			unconfirmedParents++
+		}
+	}
+
+	if unconfirmedParents != 1 {
+		return nil
+	}
+
+	// Construct the package with the parent as root to establish proper
+	// topological ordering for relay and mining.
+	pkg := &TxPackage{
+		Type:         PackageType1P1C,
+		Transactions: make(map[chainhash.Hash]*TxGraphNode),
+		Root:         node,
+	}
+
+	pkg.Transactions[node.TxHash] = node
+	pkg.Transactions[child.TxHash] = child
+
+	// Compute aggregate package metrics for fee rate evaluation. Package
+	// fee rate determines mining priority and relay acceptance.
+	pkg.TotalFees = node.TxDesc.Fee + child.TxDesc.Fee
+	pkg.TotalSize = node.TxDesc.VirtualSize + child.TxDesc.VirtualSize
+	if pkg.TotalSize > 0 {
+		pkg.FeeRate = pkg.TotalFees * 1000 / pkg.TotalSize
+	}
+
+	// Generate deterministic package ID for deduplication and indexing.
+	pkg.ID = g.generatePackageID(pkg)
+
+	// Record topology characteristics for validation and optimization.
+	// 1P1C packages always form a simple linear chain.
+	pkg.Topology = PackageTopology{
+		MaxDepth:   1,
+		MaxWidth:   1,
+		TotalNodes: 2,
+		IsLinear:   true,
+		IsTree:     true,
+	}
+
+	return pkg
+}
+
+// tryTRUCPackage attempts to form a TRUC (v3 transaction) package
+// according to BIP 431. TRUC transactions use version 3 to signal opt-in
+// topology restrictions that enable more efficient package relay and RBF.
+// The restricted topology (max 1 parent, 1 child) prevents pinning
+// attacks while maintaining CPFP fee bumping capability.
+func (g *TxGraph) tryTRUCPackage(node *TxGraphNode) *TxPackage {
+	// Require analyzer to be configured for TRUC validation. Without
+	// the analyzer, we cannot enforce BIP 431 rules.
+	if g.analyzer == nil {
+		return nil
+	}
+
+	// Verify this is a version 3 transaction signaling TRUC opt-in.
+	if !g.analyzer.IsTRUCTransaction(node.Tx.MsgTx()) {
+		return nil
+	}
+
+	// Enforce BIP 431 topology restriction: TRUC transactions cannot
+	// have multiple children as this would enable pinning vectors.
+	if len(node.Children) > 1 {
+		return nil
+	}
+
+	pkg := &TxPackage{
+		Type:         PackageTypeTRUC,
+		Transactions: make(map[chainhash.Hash]*TxGraphNode),
+		Root:         node,
+	}
+
+	// Add the root TRUC transaction to the package.
+	pkg.Transactions[node.TxHash] = node
+	totalFees := node.TxDesc.Fee
+	totalSize := node.TxDesc.VirtualSize
+	nodeCount := 1
+	maxDepth := 0
+
+	// Include the TRUC child if present, maintaining the restricted
+	// topology required by BIP 431.
+	if len(node.Children) == 1 {
+		var child *TxGraphNode
+		for _, c := range node.Children {
+			child = c
+			break
+		}
+
+		// Both parent and child must be v3 to form a valid TRUC
+		// package. Mixing v3 and non-v3 transactions would violate
+		// BIP 431 and create relay policy ambiguity.
+ if !g.analyzer.IsTRUCTransaction(child.Tx.MsgTx()) { + return nil + } + + pkg.Transactions[child.TxHash] = child + totalFees += child.TxDesc.Fee + totalSize += child.TxDesc.VirtualSize + nodeCount++ + maxDepth = 1 + } + + pkg.TotalFees = totalFees + pkg.TotalSize = totalSize + if pkg.TotalSize > 0 { + pkg.FeeRate = pkg.TotalFees * 1000 / pkg.TotalSize + } + + // Generate deterministic package ID for indexing and deduplication. + pkg.ID = g.generatePackageID(pkg) + + // Record topology metadata. TRUC packages are always linear + // chains due to the 1-parent-1-child restriction enforced above. + pkg.Topology = PackageTopology{ + MaxDepth: maxDepth, + MaxWidth: 1, + TotalNodes: nodeCount, + IsLinear: true, + IsTree: true, + } + + return pkg +} + +// tryEphemeralPackage attempts to form an ephemeral dust package. +// These packages allow zero-fee parent transactions with dust outputs +// (typically P2A anchors) that must be spent by a child transaction +// within the same package. This pattern enables efficient fee bumping +// for commitment transactions while preventing dust accumulation on the +// UTXO set. +func (g *TxGraph) tryEphemeralPackage(node *TxGraphNode) *TxPackage { + // Require analyzer for ephemeral dust detection and validation. + if g.analyzer == nil { + return nil + } + + // Verify the parent contains ephemeral dust outputs that must be + // spent within the package to prevent UTXO set pollution. + if !g.analyzer.HasEphemeralDust(node.Tx.MsgTx()) { + return nil + } + + // Ephemeral parents typically have zero fee since they rely on + // CPFP for mining incentive. Non-zero fee would be economically + // wasteful. + if !g.analyzer.IsZeroFee(node.TxDesc) { + return nil + } + + // Require at least one child to spend the dust output. Unspent + // dust violates the ephemeral package contract and would persist + // in the UTXO set. 
+ if len(node.Children) == 0 { + return nil + } + + pkg := &TxPackage{ + Type: PackageTypeEphemeral, + Transactions: make(map[chainhash.Hash]*TxGraphNode), + Root: node, + } + + // Add the parent transaction containing ephemeral dust outputs. + pkg.Transactions[node.TxHash] = node + totalFees := node.TxDesc.Fee + totalSize := node.TxDesc.VirtualSize + + // Include all children that spend the ephemeral outputs. Multiple + // children may be present if the dust spending is batched with + // other operations. + for _, child := range node.Children { + pkg.Transactions[child.TxHash] = child + totalFees += child.TxDesc.Fee + totalSize += child.TxDesc.VirtualSize + } + + pkg.TotalFees = totalFees + pkg.TotalSize = totalSize + if pkg.TotalSize > 0 { + pkg.FeeRate = pkg.TotalFees * 1000 / pkg.TotalSize + } + + // Generate deterministic package ID for tracking and deduplication. + pkg.ID = g.generatePackageID(pkg) + + // Compute topology metrics for validation. Ephemeral packages + // may have multiple children unlike TRUC packages. + pkg.Topology = g.calculateTopology(pkg) + + return pkg +} + +// tryStandardPackage attempts to form a standard package from connected +// transactions. Standard packages serve as the fallback for transactions +// that don't qualify for specialized types. We use BFS traversal to +// discover the full connected component up to the maximum package size, +// enabling package relay for arbitrary transaction topologies. +func (g *TxGraph) tryStandardPackage(node *TxGraphNode) *TxPackage { + // Use breadth-first search to explore the connected component. + // BFS ensures we discover transactions in topological order and + // can efficiently limit package size. 
	// NOTE(review): when the loop below stops early at MaxPackageSize,
	// the truncated package may exclude parents of included nodes and
	// the seed `node` used as Root may not be a true root — confirm
	// downstream validation tolerates truncated components.
	visited := make(map[chainhash.Hash]bool)
	queue := []*TxGraphNode{node}

	pkg := &TxPackage{
		Type:         PackageTypeStandard,
		Transactions: make(map[chainhash.Hash]*TxGraphNode),
		Root:         node,
	}

	totalFees := int64(0)
	totalSize := int64(0)

	for len(queue) > 0 &&
		len(pkg.Transactions) < g.config.MaxPackageSize {

		current := queue[0]
		queue = queue[1:]

		// A node can be enqueued more than once before it is
		// visited; skip duplicates here.
		if visited[current.TxHash] {
			continue
		}
		visited[current.TxHash] = true

		// Include this transaction in the package and update
		// metrics.
		pkg.Transactions[current.TxHash] = current
		totalFees += current.TxDesc.Fee
		totalSize += current.TxDesc.VirtualSize

		// Traverse descendants to include child transactions in the
		// package for CPFP consideration.
		for _, child := range current.Children {
			if !visited[child.TxHash] {
				queue = append(queue, child)
			}
		}

		// Also traverse ancestors to capture the full connected
		// component. This ensures we don't fragment packages that
		// have complex dependency relationships.
		for _, parent := range current.Parents {
			if !visited[parent.TxHash] {
				queue = append(queue, parent)
			}
		}
	}

	// Single-transaction packages provide no relay benefit. Package
	// formation is only useful when grouping related transactions for
	// aggregate fee calculation.
	if len(pkg.Transactions) <= 1 {
		return nil
	}

	pkg.TotalFees = totalFees
	pkg.TotalSize = totalSize
	if pkg.TotalSize > 0 {
		// Fee rate in sat/kvB over the whole package.
		pkg.FeeRate = pkg.TotalFees * 1000 / pkg.TotalSize
	}

	// Generate deterministic package ID based on constituent
	// transactions.
	pkg.ID = g.generatePackageID(pkg)

	// Compute topology metrics for validation and optimization
	// decisions.
	pkg.Topology = g.calculateTopology(pkg)

	return pkg
}

// generatePackageID generates a deterministic unique identifier for a
// package by hashing its constituent transaction IDs. Determinism is
// essential to enable package deduplication across peers and to provide
// stable references for package tracking during relay. Lexicographic
// sorting ensures the same set of transactions always produces the same
// ID regardless of discovery order.
func (g *TxGraph) generatePackageID(pkg *TxPackage) PackageID {
	// Collect all transaction hashes from the package.
	hashes := make([]chainhash.Hash, 0, len(pkg.Transactions))
	for hash := range pkg.Transactions {
		hashes = append(hashes, hash)
	}

	// Sort lexicographically to ensure deterministic ID generation.
	// Without sorting, the iteration order of the map would produce
	// non-deterministic results.
	sort.Slice(hashes, func(i, j int) bool {
		return bytes.Compare(hashes[i][:], hashes[j][:]) < 0
	})

	// Concatenate sorted hashes and compute SHA256 to create a
	// compact unique identifier.
	var buf bytes.Buffer
	for _, hash := range hashes {
		buf.Write(hash[:])
	}

	packageHash := sha256.Sum256(buf.Bytes())

	return PackageID{
		Hash: chainhash.Hash(packageHash),
		Type: pkg.Type,
	}
}

// calculateTopology analyzes the structure of a package to compute its
// topological properties. These metrics inform validation decisions
// (depth limits), mining optimization (linear chains are simpler), and
// relay policy (tree structures avoid diamond dependencies). We use BFS
// from the root to measure depth and width at each level.
//
// NOTE(review): BFS records each node at its shortest-path depth from
// the root, so in a non-tree DAG (e.g. root->child plus root->mid->child)
// MaxDepth can understate the longest dependency chain — confirm which
// depth semantics policy checks expect.
func (g *TxGraph) calculateTopology(pkg *TxPackage) PackageTopology {
	topology := PackageTopology{
		TotalNodes: len(pkg.Transactions),
	}

	// Traverse from root using BFS to compute depth and width
	// metrics. Without a root, we cannot determine meaningful
	// topology.
	if pkg.Root != nil {
		// visited maps each node to the depth it was first seen at.
		visited := make(map[chainhash.Hash]int)
		queue := []struct {
			node  *TxGraphNode
			depth int
		}{{pkg.Root, 0}}

		maxDepth := 0
		widthByDepth := make(map[int]int)

		for len(queue) > 0 {
			current := queue[0]
			queue = queue[1:]

			if _, exists := visited[current.node.TxHash]; exists {
				continue
			}
			visited[current.node.TxHash] = current.depth

			widthByDepth[current.depth]++
			if current.depth > maxDepth {
				maxDepth = current.depth
			}

			// Explore children to measure package depth. Only
			// include children that are part of this package to
			// avoid counting external dependencies.
			for _, child := range current.node.Children {
				_, exists := pkg.Transactions[child.TxHash]
				if exists {
					if _, visited := visited[child.TxHash]; !visited {
						queue = append(
							queue, struct {
								node  *TxGraphNode
								depth int
							}{child, current.depth + 1},
						)
					}
				}
			}
		}

		topology.MaxDepth = maxDepth

		// Determine maximum width (most siblings at any depth
		// level). High width indicates parallel transaction
		// structure.
		for _, width := range widthByDepth {
			if width > topology.MaxWidth {
				topology.MaxWidth = width
			}
		}

		// Check if package forms a linear chain. Linear chains are
		// simpler to validate and optimize for block template
		// construction.
		topology.IsLinear = true
		for _, width := range widthByDepth {
			if width > 1 {
				topology.IsLinear = false
				break
			}
		}

		// Verify tree structure (no diamond dependencies). Tree
		// topology simplifies relay and prevents pinning attacks.
		topology.IsTree = g.isPackageTree(pkg)
	}

	return topology
}

// isPackageTree verifies that a package forms a tree structure without
// diamond patterns (nodes with multiple parents). Tree topology is
// desirable because it prevents dependency cycles and simplifies relay
// logic. Diamond patterns create ambiguity in fee assignment and can
// enable pinning attacks where an attacker creates complex dependencies
// to prevent package relay.
func (g *TxGraph) isPackageTree(pkg *TxPackage) bool {
	// Scan all nodes to detect any with multiple parents within the
	// package. Multiple parents indicate a diamond pattern where
	// dependencies merge.
	for _, node := range pkg.Transactions {
		parentsInPackage := 0
		for parentHash := range node.Parents {
			if _, exists := pkg.Transactions[parentHash]; exists {
				parentsInPackage++
				if parentsInPackage > 1 {
					return false
				}
			}
		}
	}
	return true
}

// markPackageProcessed records that a package has been formed and
// indexes all its transactions. This prevents duplicate package formation
// when a transaction could belong to multiple overlapping packages,
// ensuring each transaction is counted in exactly one package. The
// indexes enable efficient lookup of packages by transaction ID during
// relay and validation.
func (g *TxGraph) markPackageProcessed(
	pkg *TxPackage, processed map[chainhash.Hash]bool) {

	for hash := range pkg.Transactions {
		processed[hash] = true

		// Map each transaction to its package for reverse lookup
		// during validation and relay operations.
		g.indexes.nodeToPackage[hash] = pkg.ID
	}

	// Store package in the global index to enable iteration over
	// all packages for mempool queries and eviction.
	g.indexes.packages[pkg.ID] = pkg

	// Increment package count for monitoring and metrics collection.
	atomic.AddInt32(&g.metrics.packageCount, 1)
}

// isPackageConnected verifies that all nodes in a package form a
// single connected component. Package relay requires connectivity because
// disconnected transactions cannot provide CPFP benefits to each other
// and should be treated as separate packages. We use BFS to test
// reachability from an arbitrary starting node to all other nodes.
func (g *TxGraph) isPackageConnected(nodes []*TxGraphNode) bool {
	// Zero or one node is trivially connected.
	if len(nodes) <= 1 {
		return true
	}

	// Build lookup table for O(1) membership testing during
	// traversal.
	nodeSet := make(map[chainhash.Hash]bool)
	for _, node := range nodes {
		nodeSet[node.TxHash] = true
	}

	// Use BFS from an arbitrary start node to test reachability. If
	// all nodes are reachable, the graph is connected.
	visited := make(map[chainhash.Hash]bool)
	queue := []*TxGraphNode{nodes[0]}
	visited[nodes[0].TxHash] = true

	for len(queue) > 0 {
		current := queue[0]
		queue = queue[1:]

		// Traverse parent edges to find ancestors in the package.
		for _, parent := range current.Parents {
			if nodeSet[parent.TxHash] && !visited[parent.TxHash] {
				visited[parent.TxHash] = true
				queue = append(queue, parent)
			}
		}

		// Traverse child edges to find descendants in the package.
		for _, child := range current.Children {
			if nodeSet[child.TxHash] && !visited[child.TxHash] {
				visited[child.TxHash] = true
				queue = append(queue, child)
			}
		}
	}

	// If we visited all nodes, the graph is connected. Unvisited
	// nodes indicate disconnected components.
	return len(visited) == len(nodes)
}

// CreatePackage constructs a TxPackage from a specified set of graph
// nodes, providing explicit control over package formation. This is used
// during relay validation when incoming packages specify their
// membership, or during RBF evaluation when comparing conflicting package
// groups. The function enforces connectivity requirements, computes
// topology metrics, and classifies package type based on BIP 431 and
// ephemeral dust rules to ensure appropriate relay policies are applied.
// Returns ErrInvalidTopology for an empty or oversized node set,
// ErrDisconnectedPackage when the nodes do not form one component, or
// any error produced by ValidatePackage.
func (g *TxGraph) CreatePackage(nodes []*TxGraphNode) (*TxPackage, error) {
	if len(nodes) == 0 {
		return nil, ErrInvalidTopology
	}

	if len(nodes) > g.config.MaxPackageSize {
		return nil, ErrInvalidTopology
	}

	pkg := &TxPackage{
		Type:         PackageTypeStandard,
		Transactions: make(map[chainhash.Hash]*TxGraphNode),
	}

	totalFees := int64(0)
	totalSize := int64(0)

	// Populate the package with all provided nodes and compute
	// aggregate metrics.
	for _, node := range nodes {
		pkg.Transactions[node.TxHash] = node
		totalFees += node.TxDesc.Fee
		totalSize += node.TxDesc.VirtualSize
	}

	// Enforce connectivity requirement for multi-transaction
	// packages. Disconnected transactions cannot provide CPFP
	// benefits and should be separate packages.
	if len(nodes) > 1 {
		if !g.isPackageConnected(nodes) {
			return nil, ErrDisconnectedPackage
		}
	}

	// Identify the root node (no parents within package) to
	// establish topological ordering for validation and relay. The
	// root serves as the starting point for depth-first traversal.
	// The first such node in slice order wins if several qualify.
	for _, node := range nodes {
		hasParentInPackage := false
		for parentHash := range node.Parents {
			if _, exists := pkg.Transactions[parentHash]; exists {
				hasParentInPackage = true
				break
			}
		}
		if !hasParentInPackage {
			pkg.Root = node
			break
		}
	}

	// Classify package type (TRUC, ephemeral, 1P1C, standard) to
	// determine which relay policies apply.
	pkg.Type = g.determinePackageType(pkg)

	// Finalize package metrics for fee rate evaluation and relay
	// decisions.
	pkg.TotalFees = totalFees
	pkg.TotalSize = totalSize
	if pkg.TotalSize > 0 {
		// Fee rate in sat/kvB over the whole package.
		pkg.FeeRate = pkg.TotalFees * 1000 / pkg.TotalSize
	}

	// Generate deterministic package ID for tracking and deduplication.
	pkg.ID = g.generatePackageID(pkg)

	// Compute topology characteristics for policy enforcement.
	pkg.Topology = g.calculateTopology(pkg)

	// Run validation rules appropriate to the package type.
	// Validation may fail if topology constraints are violated or
	// required conditions (e.g., dust spending) are not met.
	if err := g.ValidatePackage(pkg); err != nil {
		return nil, err
	}

	pkg.IsValid = true
	// NOTE(review): LastValidated is taken from the first node's
	// AddedTime rather than the current time — confirm this is the
	// intended validation timestamp.
	pkg.LastValidated = nodes[0].Metadata.AddedTime

	return pkg, nil
}

// determinePackageType classifies a package by examining its topology
// and transaction properties. Classification determines which relay
// policies apply: TRUC packages follow BIP 431 rules, ephemeral packages
// must spend dust outputs, 1P1C packages get simplified validation, and
// standard packages use default rules. We check types in priority order
// from most to least restrictive.
//
// NOTE(review): because the 1P1C check runs first, a two-transaction
// package whose members are all v3 is classified as 1P1C rather than
// TRUC — confirm this priority matches the intended relay policy.
func (g *TxGraph) determinePackageType(pkg *TxPackage) PackageType {
	// Check for 1-parent-1-child topology by verifying exactly two
	// transactions with a single parent-child relationship.
	if len(pkg.Transactions) == 2 {
		// parentCount counts nodes with no in-package parent;
		// childCount counts nodes with no in-package child.
		parentCount := 0
		childCount := 0
		for _, node := range pkg.Transactions {
			hasParentInPkg := false
			hasChildInPkg := false

			for parentHash := range node.Parents {
				_, exists := pkg.Transactions[parentHash]
				if exists {
					hasParentInPkg = true
					break
				}
			}

			for childHash := range node.Children {
				_, exists := pkg.Transactions[childHash]
				if exists {
					hasChildInPkg = true
					break
				}
			}

			if !hasParentInPkg {
				parentCount++
			}
			if !hasChildInPkg {
				childCount++
			}
		}

		if parentCount == 1 && childCount == 1 {
			return PackageType1P1C
		}
	}

	// Check for TRUC by verifying all transactions are version 3.
	// Mixed versions are not allowed in TRUC packages per BIP 431.
	allTRUC := true
	for _, node := range pkg.Transactions {
		if !node.Metadata.IsTRUC {
			allTRUC = false
			break
		}
	}
	if allTRUC {
		return PackageTypeTRUC
	}

	// Check for ephemeral dust. If any transaction creates
	// ephemeral outputs, the entire package is classified as
	// ephemeral and must meet dust spending requirements.
	for _, node := range pkg.Transactions {
		if node.Metadata.IsEphemeral {
			return PackageTypeEphemeral
		}
	}

	// Default to standard package type for all other cases.
	return PackageTypeStandard
}

diff --git a/mempool/txgraph/package_analyzer.go b/mempool/txgraph/package_analyzer.go
new file mode 100644
index 0000000000..00e48ada8c
--- /dev/null
+++ b/mempool/txgraph/package_analyzer.go
@@ -0,0 +1,43 @@
package txgraph

import (
	"github.com/btcsuite/btcd/wire"
)

// PackageAnalyzer defines the interface for analyzing transaction packages.
// This abstraction allows for different implementations of package validation
// rules without coupling the graph to specific protocol details. This design
// enables testing with mock analyzers and future protocol upgrades without
// modifying core graph logic.
type PackageAnalyzer interface {
	// IsTRUCTransaction checks if a transaction is version 3 and thus
	// subject to TRUC topology restrictions. This enables the graph to
	// identify which transactions need special validation and to enforce
	// v3-specific relay policies.
	IsTRUCTransaction(tx *wire.MsgTx) bool

	// HasEphemeralDust checks if a transaction contains dust outputs that
	// must be spent in the same package. This enables enforcement of the
	// ephemeral dust policy where such outputs cannot exist unspent.
	HasEphemeralDust(tx *wire.MsgTx) bool

	// IsZeroFee identifies transactions with zero fees, which require
	// special handling in packages since they cannot be mined alone. This
	// is critical for CPFP and ephemeral dust validation.
	IsZeroFee(desc *TxDesc) bool

	// ValidateTRUCPackage enforces BIP 431 topology restrictions on v3
	// packages. This includes single-parent rules, size limits, and
	// topology constraints that prevent pinning attacks.
	ValidateTRUCPackage(nodes []*TxGraphNode) bool

	// ValidateEphemeralPackage ensures all ephemeral dust outputs are
	// spent within the package.
This prevents unspendable dust from
	// entering the UTXO set and enables zero-fee parent transactions.
	ValidateEphemeralPackage(nodes []*TxGraphNode) bool

	// AnalyzePackageType determines the most specific package type
	// (1P1C, TRUC, ephemeral, standard) based on structure and properties.
	// This classification determines which validation rules apply.
	AnalyzePackageType(nodes []*TxGraphNode) PackageType
}
\ No newline at end of file
diff --git a/mempool/txgraph/package_test.go b/mempool/txgraph/package_test.go
new file mode 100644
index 0000000000..ada8cb87ee
--- /dev/null
+++ b/mempool/txgraph/package_test.go
@@ -0,0 +1,549 @@
package txgraph

import (
	"testing"

	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// TestCreatePackage verifies manual package creation through the graph
// API, ensuring that directly constructed packages maintain proper
// topology analysis and fee calculations. This is critical for package
// relay because nodes must be able to construct 1P1C packages on demand
// for relay validation without full package identification.
func TestCreatePackage(t *testing.T) {
	g := New(DefaultConfig())

	// Create simple parent-child pair. The child has higher fee to
	// demonstrate CPFP (Child Pays For Parent) economics.
	parent, parentDesc := createTestTx(nil, 1)
	require.NoError(t, g.AddTransaction(parent, parentDesc))

	child, childDesc := createTestTx(
		[]wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1)
	// Set higher fee to simulate CPFP scenario where the child
	// incentivizes mining of the parent.
	childDesc.Fee = 2000
	require.NoError(t, g.AddTransaction(child, childDesc))

	// Retrieve the graph nodes to construct a package manually.
	parentNode, exists := g.GetNode(*parent.Hash())
	require.True(t, exists)
	childNode, exists := g.GetNode(*child.Hash())
	require.True(t, exists)

	// Create package and verify it succeeds for connected transactions.
	pkg, err := g.CreatePackage([]*TxGraphNode{parentNode, childNode})
	require.NoError(t, err)
	require.NotNil(t, pkg)

	// Verify the package correctly identifies as 1P1C with proper
	// topology properties for relay validation.
	require.NotEmpty(t, pkg.ID)
	require.Equal(t, PackageType1P1C, pkg.Type)
	require.Len(t, pkg.Transactions, 2)
	require.Equal(t, int64(3000), pkg.TotalFees)
	require.True(t, pkg.Topology.IsLinear)
	require.True(t, pkg.Topology.IsTree)
	require.Equal(t, 1, pkg.Topology.MaxDepth)

	// Verify both transactions are properly indexed in the package.
	require.NotNil(t, pkg.Transactions[*parent.Hash()])
	require.NotNil(t, pkg.Transactions[*child.Hash()])
}

// TestTRUCPackageIdentification verifies that v3 transaction packages
// are correctly identified as TRUC packages per BIP 431. This ensures
// the graph can distinguish TRUC packages from standard packages,
// enabling enforcement of v3 topology restrictions (single parent,
// single child) during package relay validation.
func TestTRUCPackageIdentification(t *testing.T) {
	// Use a mock analyzer to simulate TRUC transaction detection
	// without requiring full BIP 431 validation logic.
	mockAnalyzer := new(MockPackageAnalyzer)
	config := DefaultConfig()
	config.PackageAnalyzer = mockAnalyzer

	g := New(config)

	// Create parent with version 3 to indicate TRUC transaction.
	parentTx := wire.NewMsgTx(3)
	parentTx.AddTxOut(wire.NewTxOut(100000, nil))
	parent := btcutil.NewTx(parentTx)
	parentDesc := &TxDesc{
		TxHash:      *parent.Hash(),
		VirtualSize: int64(parent.MsgTx().SerializeSize()),
		Fee:         1000,
		FeePerKB:    10000,
	}

	// Create child also with version 3 to form valid TRUC package.
	childTx := wire.NewMsgTx(3)
	childTx.AddTxIn(wire.NewTxIn(&wire.OutPoint{
		Hash:  *parent.Hash(),
		Index: 0,
	}, nil, nil))
	childTx.AddTxOut(wire.NewTxOut(90000, nil))
	child := btcutil.NewTx(childTx)
	childDesc := &TxDesc{
		TxHash:      *child.Hash(),
		VirtualSize: int64(child.MsgTx().SerializeSize()),
		Fee:         2000,
		FeePerKB:    20000,
	}

	// Configure mock to report transactions as TRUC (v3).
	mockAnalyzer.On("IsTRUCTransaction", mock.Anything).Return(true)

	require.NoError(t, g.AddTransaction(parent, parentDesc))
	require.NoError(t, g.AddTransaction(child, childDesc))

	// Trigger package identification to classify the package type.
	packages, err := g.IdentifyPackages()
	require.NoError(t, err)

	// Verify a TRUC package was identified with correct topology
	// constraints (depth 1, width 1) per BIP 431 requirements.
	found := false
	for _, pkg := range packages {
		if pkg.Type == PackageTypeTRUC {
			found = true
			require.Len(t, pkg.Transactions, 2)
			// TRUC packages enforce strict 1P1C topology for
			// relay rules.
			require.Equal(t, 1, pkg.Topology.MaxDepth)
			require.Equal(t, 1, pkg.Topology.MaxWidth)
		}
	}
	require.True(t, found, "Should identify TRUC package")

	mockAnalyzer.AssertExpectations(t)
}

// TestEphemeralPackageIdentification verifies that zero-fee transactions
// with ephemeral dust outputs are correctly classified as ephemeral
// packages. This is essential for package relay because ephemeral dust
// allows zero-fee parents to be relayed as long as a child spends the
// dust output and pays sufficient fees for the entire package.
func TestEphemeralPackageIdentification(t *testing.T) {
	// Use mock analyzer to simulate ephemeral dust detection logic
	// without full policy validation.
	mockAnalyzer := new(MockPackageAnalyzer)
	config := DefaultConfig()
	config.PackageAnalyzer = mockAnalyzer

	g := New(config)

	// Create parent with zero-value output representing ephemeral
	// dust that must be spent immediately.
	parentTx := wire.NewMsgTx(wire.TxVersion)
	parentTx.AddTxOut(wire.NewTxOut(0, nil))
	parent := btcutil.NewTx(parentTx)
	parentDesc := &TxDesc{
		TxHash:      *parent.Hash(),
		VirtualSize: int64(parent.MsgTx().SerializeSize()),
		// Zero fee is allowed because child will pay for the
		// entire package.
		Fee:      0,
		FeePerKB: 0,
	}

	// Create child that spends the ephemeral dust output, paying
	// fees for both transactions in the package.
	childTx := wire.NewMsgTx(wire.TxVersion)
	childTx.AddTxIn(wire.NewTxIn(&wire.OutPoint{
		Hash:  *parent.Hash(),
		Index: 0,
	}, nil, nil))
	childTx.AddTxOut(wire.NewTxOut(100000, nil))
	child := btcutil.NewTx(childTx)
	childDesc := &TxDesc{
		TxHash:      *child.Hash(),
		VirtualSize: int64(child.MsgTx().SerializeSize()),
		Fee:         1000,
		FeePerKB:    10000,
	}

	// Configure mock to identify this as an ephemeral dust package.
	mockAnalyzer.On("IsTRUCTransaction", mock.Anything).Return(false)
	mockAnalyzer.On("HasEphemeralDust", mock.Anything).Return(true)
	mockAnalyzer.On("IsZeroFee", mock.Anything).Return(true)

	require.NoError(t, g.AddTransaction(parent, parentDesc))
	require.NoError(t, g.AddTransaction(child, childDesc))

	// Trigger package identification to detect ephemeral package.
	packages, err := g.IdentifyPackages()
	require.NoError(t, err)

	// Verify an ephemeral package was identified, enabling special
	// relay treatment for zero-fee parents with dust outputs.
	found := false
	for _, pkg := range packages {
		if pkg.Type == PackageTypeEphemeral {
			found = true
			require.Len(t, pkg.Transactions, 2)
		}
	}
	require.True(t, found, "Should identify ephemeral package")

	mockAnalyzer.AssertExpectations(t)
}

// TestStandardPackageIdentification verifies that multi-child packages
// are classified as standard packages rather than specialized types.
// This is important for package relay because standard packages have
// different validation rules than 1P1C packages and may not qualify
// for certain relay optimizations.
func TestStandardPackageIdentification(t *testing.T) {
	g := New(DefaultConfig())

	// Create parent with two outputs to enable multiple children,
	// which disqualifies this from being a 1P1C package.
	parent, parentDesc := createTestTx(nil, 2)
	require.NoError(t, g.AddTransaction(parent, parentDesc))

	// Create two children spending different outputs to form a
	// non-linear topology.
	child1, child1Desc := createTestTx(
		[]wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1)
	require.NoError(t, g.AddTransaction(child1, child1Desc))

	child2, child2Desc := createTestTx(
		[]wire.OutPoint{{Hash: *parent.Hash(), Index: 1}}, 1)
	require.NoError(t, g.AddTransaction(child2, child2Desc))

	// Trigger package identification to classify the package type.
	packages, err := g.IdentifyPackages()
	require.NoError(t, err)

	// Verify standard package classification due to having two
	// children, which violates 1P1C constraints.
	found := false
	for _, pkg := range packages {
		if pkg.Type == PackageTypeStandard {
			found = true
			require.Len(t, pkg.Transactions, 3)
			require.Equal(t, 1, pkg.Topology.MaxDepth)
			require.Equal(t, 2, pkg.Topology.MaxWidth)
			require.False(t, pkg.Topology.IsLinear)
			require.True(t, pkg.Topology.IsTree)
		}
	}
	require.True(t, found, "Should identify standard package")
}

// TestPackageTopologyCalculation verifies that complex DAG structures
// are correctly analyzed for depth, width, and tree properties. This
// is critical for package relay because topology properties determine
// which relay policies apply (e.g., TRUC requires linear topology,
// standard packages may allow more complex structures).
func TestPackageTopologyCalculation(t *testing.T) {
	g := New(DefaultConfig())

	// Create a diamond-shaped DAG structure with convergent paths
	// to test non-tree topology detection:
	//        root
	//       /    \
	//    mid1    mid2
	//       \    /
	//        leaf

	root, rootDesc := createTestTx(nil, 2)
	require.NoError(t, g.AddTransaction(root, rootDesc))

	mid1, mid1Desc := createTestTx(
		[]wire.OutPoint{{Hash: *root.Hash(), Index: 0}}, 1)
	require.NoError(t, g.AddTransaction(mid1, mid1Desc))

	mid2, mid2Desc := createTestTx(
		[]wire.OutPoint{{Hash: *root.Hash(), Index: 1}}, 1)
	require.NoError(t, g.AddTransaction(mid2, mid2Desc))

	// Create leaf that spends from both mid transactions, forming
	// a diamond pattern with reconvergent paths.
	leaf, leafDesc := createTestTx([]wire.OutPoint{
		{Hash: *mid1.Hash(), Index: 0},
		{Hash: *mid2.Hash(), Index: 0},
	}, 1)
	require.NoError(t, g.AddTransaction(leaf, leafDesc))

	// Trigger package identification to compute topology metrics.
	packages, err := g.IdentifyPackages()
	require.NoError(t, err)

	// Verify the diamond structure is recognized as a single
	// package with non-tree topology due to convergence.
	require.Len(t, packages, 1)
	pkg := packages[0]

	require.Equal(t, 2, pkg.Topology.MaxDepth)
	require.Equal(t, 2, pkg.Topology.MaxWidth)
	require.False(t, pkg.Topology.IsLinear)
	// Diamond pattern violates tree property due to multiple
	// paths to leaf node.
	require.False(t, pkg.Topology.IsTree)
	require.Equal(t, 4, pkg.Topology.TotalNodes)
}

// TestPackageWithDisconnectedNodes verifies that package creation
// fails when given unrelated transactions with no spending
// relationship. This validates the package invariant that all
// transactions must form a connected subgraph, which is required for
// package relay because relay policies assume topological ordering.
func TestPackageWithDisconnectedNodes(t *testing.T) {
	g := New(DefaultConfig())

	// Create two transactions with no spending relationship to
	// test disconnected graph detection.
	tx1, desc1 := createTestTx(nil, 1)
	require.NoError(t, g.AddTransaction(tx1, desc1))

	tx2, desc2 := createTestTx(nil, 1)
	require.NoError(t, g.AddTransaction(tx2, desc2))

	// Retrieve both nodes to attempt invalid package construction.
	node1, _ := g.GetNode(*tx1.Hash())
	node2, _ := g.GetNode(*tx2.Hash())

	// Verify that attempting to create a package from disconnected
	// transactions returns an error, preventing invalid packages.
	_, err := g.CreatePackage([]*TxGraphNode{node1, node2})
	require.Error(t, err)
	require.Contains(t, err.Error(), "disconnected")
}

// TestPackageFeeCalculation verifies that package-level fee rates are
// correctly computed as the aggregate fee divided by aggregate size.
// This is essential for package relay because miners prioritize
// packages by aggregate fee rate, and CPFP scenarios depend on
// accurate package-level fee calculations to incentivize mining.
func TestPackageFeeCalculation(t *testing.T) {
	g := New(DefaultConfig())

	// Create parent with low individual fee rate to simulate
	// transactions that benefit from CPFP.
	parent, parentDesc := createTestTx(nil, 1)
	parentDesc.Fee = 100
	parentDesc.VirtualSize = 100
	parentDesc.FeePerKB = 1000
	require.NoError(t, g.AddTransaction(parent, parentDesc))

	// Create child with high fee to boost the package fee rate,
	// demonstrating CPFP economics.
	child, childDesc := createTestTx(
		[]wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1)
	childDesc.Fee = 900
	childDesc.VirtualSize = 100
	childDesc.FeePerKB = 9000
	require.NoError(t, g.AddTransaction(child, childDesc))

	// Trigger package identification to compute aggregate metrics.
	packages, err := g.IdentifyPackages()
	require.NoError(t, err)

	// Verify package fee rate is calculated correctly as
	// aggregate fees over aggregate size, resulting in a higher
	// rate than the parent alone.
	require.Len(t, packages, 1)
	pkg := packages[0]

	require.Equal(t, int64(1000), pkg.TotalFees)
	require.Equal(t, int64(200), pkg.TotalSize)
	require.Equal(t, int64(5000), pkg.FeeRate)
}

// TestComplexPackageIdentification verifies that highly connected
// transaction graphs with multiple roots and convergent paths are
// unified into a single package. This tests the package identification
// algorithm's ability to trace all spending relationships, which is
// critical for package relay to ensure complete package submission.
func TestComplexPackageIdentification(t *testing.T) {
	g := New(DefaultConfig())

	// Create a complex graph with multiple roots and convergent
	// spending patterns to test advanced package detection:
	//   root1    root2
	//   |   \   /   |
	//   |   mid1    |
	//   |   /   \   |
	//   child1  child2

	root1, root1Desc := createTestTx(nil, 2)
	require.NoError(t, g.AddTransaction(root1, root1Desc))

	root2, root2Desc := createTestTx(nil, 2)
	require.NoError(t, g.AddTransaction(root2, root2Desc))

	// Create middle transaction spending from both roots, forming
	// a convergence point.
	mid1, mid1Desc := createTestTx([]wire.OutPoint{
		{Hash: *root1.Hash(), Index: 1},
		{Hash: *root2.Hash(), Index: 0},
	}, 2)
	require.NoError(t, g.AddTransaction(mid1, mid1Desc))

	// Create children spending from both direct roots and the
	// middle transaction to form complex dependencies.
	child1, child1Desc := createTestTx([]wire.OutPoint{
		{Hash: *root1.Hash(), Index: 0},
		{Hash: *mid1.Hash(), Index: 0},
	}, 1)
	require.NoError(t, g.AddTransaction(child1, child1Desc))

	child2, child2Desc := createTestTx([]wire.OutPoint{
		{Hash: *root2.Hash(), Index: 1},
		{Hash: *mid1.Hash(), Index: 1},
	}, 1)
	require.NoError(t, g.AddTransaction(child2, child2Desc))

	// Trigger package identification to unify connected subgraph.
	packages, err := g.IdentifyPackages()
	require.NoError(t, err)

	// Verify all five transactions are unified into one standard
	// package with non-tree topology due to convergence points.
	require.Len(t, packages, 1)
	pkg := packages[0]

	require.Equal(t, PackageTypeStandard, pkg.Type)
	require.Len(t, pkg.Transactions, 5)
	require.Equal(t, 2, pkg.Topology.MaxDepth)
	require.False(t, pkg.Topology.IsLinear)
	// Convergent paths violate tree property.
	require.False(t, pkg.Topology.IsTree)
}

// TestEmptyPackageIdentification verifies that package identification
// on an empty graph returns an empty list rather than failing. This
// validates the edge case handling for package relay initialization.
func TestEmptyPackageIdentification(t *testing.T) {
	g := New(DefaultConfig())

	packages, err := g.IdentifyPackages()
	require.NoError(t, err)
	require.Empty(t, packages)
}

// TestSingleTransactionPackage verifies that isolated transactions
// can be packaged individually with correct topology properties. This
// ensures single-transaction packages are properly handled during
// package relay, even though they provide no CPFP benefit.
func TestSingleTransactionPackage(t *testing.T) {
	g := New(DefaultConfig())

	// Create an isolated transaction with no dependencies to test
	// the degenerate single-node package case.
	tx, desc := createTestTx(nil, 1)
	require.NoError(t, g.AddTransaction(tx, desc))

	// Retrieve the node for manual package construction.
	node, exists := g.GetNode(*tx.Hash())
	require.True(t, exists)

	// Create a package containing only this single transaction.
	pkg, err := g.CreatePackage([]*TxGraphNode{node})
	require.NoError(t, err)
	require.NotNil(t, pkg)

	// Verify the single-transaction package has correct topology
	// properties (depth 0, linear, tree).
	require.Len(t, pkg.Transactions, 1)
	require.Equal(t, PackageTypeStandard, pkg.Type)
	require.Equal(t, 0, pkg.Topology.MaxDepth)
	require.Equal(t, 1, pkg.Topology.MaxWidth)
	require.True(t, pkg.Topology.IsLinear)
	require.True(t, pkg.Topology.IsTree)
}

// TestGetPackage verifies that packages can be retrieved by any
// transaction hash within the package, and that lookups fail for
// transactions not in packages. This supports package relay by
// enabling efficient package discovery given any member transaction.
func TestGetPackage(t *testing.T) {
	g := New(DefaultConfig())

	// Create a simple package to test package lookup functionality.
	parent, parentDesc := createTestTx(nil, 1)
	require.NoError(t, g.AddTransaction(parent, parentDesc))

	child, childDesc := createTestTx(
		[]wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1)
	require.NoError(t, g.AddTransaction(child, childDesc))

	packages, err := g.IdentifyPackages()
	require.NoError(t, err)
	require.Len(t, packages, 1)

	// Verify package can be retrieved using parent transaction
	// hash, demonstrating reverse lookup capability.
+ pkg, err := g.GetPackage(*parent.Hash()) + require.NoError(t, err) + require.NotNil(t, pkg) + require.Equal(t, packages[0].ID, pkg.ID) + + // Verify package can also be retrieved using child transaction + // hash, showing any package member can be used as lookup key. + pkg, err = g.GetPackage(*child.Hash()) + require.NoError(t, err) + require.NotNil(t, pkg) + require.Equal(t, packages[0].ID, pkg.ID) + + // Verify lookup fails for transaction not in any package, + // properly handling missing keys. + orphan := wire.NewMsgTx(1) + orphanHash := orphan.TxHash() + pkg, err = g.GetPackage(orphanHash) + require.Error(t, err) + require.Nil(t, pkg) +} + +// TestValidatePackage verifies that package validation detects +// malformed packages including nil packages, empty packages, and +// packages with inconsistent topology metadata. This is critical for +// package relay because invalid packages must be rejected before +// network propagation to prevent DoS attacks or relay failures. +func TestValidatePackage(t *testing.T) { + g := New(DefaultConfig()) + + // Create a well-formed package for baseline validation test. + parent, parentDesc := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(parent, parentDesc)) + + child, childDesc := createTestTx( + []wire.OutPoint{{Hash: *parent.Hash(), Index: 0}}, 1) + require.NoError(t, g.AddTransaction(child, childDesc)) + + parentNode, _ := g.GetNode(*parent.Hash()) + childNode, _ := g.GetNode(*child.Hash()) + + pkg, err := g.CreatePackage([]*TxGraphNode{parentNode, childNode}) + require.NoError(t, err) + + // Verify that a properly constructed package passes validation. + err = g.ValidatePackage(pkg) + require.NoError(t, err) + + // Verify nil package is rejected to prevent nil pointer errors. + err = g.ValidatePackage(nil) + require.Error(t, err) + + // Verify package with inconsistent topology is rejected, + // preventing relay of corrupted package metadata. 
+	pkg.Topology.IsTree = false
+	pkg.Topology.MaxDepth = 10
+	err = g.ValidatePackage(pkg)
+	require.Error(t, err)
+
+	// Verify empty package is rejected as it provides no value for
+	// package relay.
+	emptyPkg := &TxPackage{
+		Transactions: make(map[chainhash.Hash]*TxGraphNode),
+	}
+	err = g.ValidatePackage(emptyPkg)
+	require.Error(t, err)
+}
diff --git a/mempool/txgraph/traversal_test.go b/mempool/txgraph/traversal_test.go
new file mode 100644
index 0000000000..cc0de7e7cc
--- /dev/null
+++ b/mempool/txgraph/traversal_test.go
@@ -0,0 +1,450 @@
+package txgraph
+
+import (
+	"slices"
+	"testing"
+
+	"github.com/btcsuite/btcd/wire"
+	"github.com/stretchr/testify/require"
+)
+
+// TestIterateBFSWithIncludeStart validates that the IncludeStart option
+// correctly controls whether the starting transaction is included in BFS
+// traversal results. This is critical for mempool operations where some
+// algorithms need to process the anchor transaction itself (e.g., fee
+// calculation), while others only need its descendants (e.g., eviction
+// checking).
+func TestIterateBFSWithIncludeStart(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Create chain: tx1 -> tx2 -> tx3.
+	tx1, desc1 := createTestTx(nil, 1)
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1,
+	)
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	// When IncludeStart is true, the iterator must emit the start node
+	// first, followed by its descendants in breadth-first order. This
+	// allows callers to process the entire subgraph atomically.
+ visited := slices.Collect(g.Iterate( + WithOrder(TraversalBFS), + WithStartNode(tx2.Hash()), + WithIncludeStart(true), + )) + + require.Len(t, visited, 2, "should include start node") + require.Equal( + t, *tx2.Hash(), visited[0].TxHash, + "start node should be first", + ) + require.Equal(t, *tx3.Hash(), visited[1].TxHash) + + // When IncludeStart is false (default behavior), the iterator skips + // the start node and only yields descendants. This is useful when + // checking if removing a transaction would affect other transactions. + visited = slices.Collect(g.Iterate( + WithOrder(TraversalBFS), + WithStartNode(tx2.Hash()), + WithIncludeStart(false), + )) + + require.Len(t, visited, 1, "should not include start node") + require.Equal(t, *tx3.Hash(), visited[0].TxHash) +} + +// TestIterateDFSWithIncludeStart verifies that depth-first traversal +// respects the IncludeStart option when exploring transaction trees. DFS +// is used in mempool for dependency chain validation where we need to +// explore each branch fully before backtracking, making it essential for +// detecting circular dependencies and validating TRUC topology rules. +func TestIterateDFSWithIncludeStart(t *testing.T) { + g := New(DefaultConfig()) + + // Build a simple tree where tx1 has two children (tx2, tx3). + // This tests DFS behavior on branching structures rather than + // linear chains. + tx1, desc1 := createTestTx(nil, 2) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + // With IncludeStart=true, the root must appear first, allowing + // algorithms to validate parent properties before processing + // children. 
+ visited := slices.Collect(g.Iterate( + WithOrder(TraversalDFS), + WithStartNode(tx1.Hash()), + WithIncludeStart(true), + WithDirection(DirectionForward), + )) + + require.Len(t, visited, 3, "should include start node") + require.Equal( + t, *tx1.Hash(), visited[0].TxHash, + "start node should be first", + ) + + // With IncludeStart=false, we get only descendants. This mode is + // used when validating that child transactions remain valid after + // hypothetically removing the parent. + visited = slices.Collect(g.Iterate( + WithOrder(TraversalDFS), + WithStartNode(tx1.Hash()), + WithIncludeStart(false), + WithDirection(DirectionForward), + )) + + require.Len(t, visited, 2, "should not include start node") + + // Verify we got both children but not the parent. + seenHashes := make(map[string]bool) + for _, n := range visited { + seenHashes[n.TxHash.String()] = true + } + require.False(t, seenHashes[tx1.Hash().String()]) + require.True(t, seenHashes[tx2.Hash().String()]) + require.True(t, seenHashes[tx3.Hash().String()]) +} + +// TestIterateReverseTopoWithFilter validates that reverse topological +// ordering correctly processes transactions from leaves to roots while +// applying filters. This traversal order is critical for block template +// construction where transactions must be evaluated in an order that +// ensures all children are processed before their parents, allowing +// accurate ancestor fee rate calculations. +func TestIterateReverseTopoWithFilter(t *testing.T) { + g := New(DefaultConfig()) + + // Build a dependency chain with varying fee rates to test + // filtering during topological iteration. 
+ tx1, desc1 := createTestTx(nil, 1) + desc1.FeePerKB = 1000 + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + desc2.FeePerKB = 5000 + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + desc3.FeePerKB = 10000 + require.NoError(t, g.AddTransaction(tx3, desc3)) + + // The filter allows selective processing of only high-value + // transactions, which is used during mempool eviction to prioritize + // keeping profitable transaction chains. + highFeeFilter := func(n *TxGraphNode) bool { + return n.TxDesc.FeePerKB >= 5000 + } + + visited := slices.Collect(g.Iterate( + WithOrder(TraversalReverseTopo), + WithFilter(highFeeFilter), + )) + + // Verify only high-fee transactions appear, in reverse topological + // order (children before parents). + require.Len(t, visited, 2, "should filter out low-fee transaction") + require.Equal( + t, *tx3.Hash(), visited[0].TxHash, + "tx3 should be first (reverse topo)", + ) + require.Equal(t, *tx2.Hash(), visited[1].TxHash) +} + +// TestIterateWithDirectionBackward verifies that backward traversal walks +// from children to ancestors. This direction is essential for calculating +// the total ancestor set of a transaction, which determines whether adding +// a new transaction would exceed ancestor count/size limits defined in +// mempool policy. +func TestIterateWithDirectionBackward(t *testing.T) { + g := New(DefaultConfig()) + + // Create chain: tx1 -> tx2 -> tx3. 
+ tx1, desc1 := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + // Starting from tx3 and moving backward visits all ancestors in + // breadth-first order. This is how we compute ancestor fees and + // validate ancestor limits. + visited := slices.Collect(g.Iterate( + WithOrder(TraversalBFS), + WithStartNode(tx3.Hash()), + WithDirection(DirectionBackward), + WithIncludeStart(false), + )) + + require.Len(t, visited, 2, "should visit tx2 and tx1") + require.Equal( + t, *tx2.Hash(), visited[0].TxHash, + "tx2 is immediate parent", + ) + require.Equal(t, *tx1.Hash(), visited[1].TxHash) +} + +// TestIterateWithDirectionBoth validates bidirectional traversal, which +// explores both ancestors and descendants from a starting transaction. +// This is used when evicting a transaction from the mempool to identify +// the complete cluster that would be affected, as we must remove all +// descendants while considering ancestor fee implications. +func TestIterateWithDirectionBoth(t *testing.T) { + g := New(DefaultConfig()) + + // Create chain: tx1 -> tx2 -> tx3. + tx1, desc1 := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + // From tx2, bidirectional traversal finds both its ancestor (tx1) + // and descendant (tx3). This gives the complete cluster affected by + // any change to tx2. 
+ visited := slices.Collect(g.Iterate( + WithOrder(TraversalBFS), + WithStartNode(tx2.Hash()), + WithDirection(DirectionBoth), + WithIncludeStart(false), + )) + + require.Len( + t, visited, 2, + "should visit both tx1 (parent) and tx3 (child)", + ) + seenHashes := make(map[string]bool) + for _, n := range visited { + seenHashes[n.TxHash.String()] = true + } + require.True(t, seenHashes[tx1.Hash().String()]) + require.True(t, seenHashes[tx3.Hash().String()]) + require.False( + t, seenHashes[tx2.Hash().String()], + "should not include start", + ) +} + +// TestIteratePairsWithOptions validates that IteratePairs correctly emits +// parent-child relationships as pairs, which is essential for CPFP (Child +// Pays For Parent) analysis. By iterating edges rather than nodes, we can +// efficiently compute fee deltas and determine which children are boosting +// low-fee ancestors. +func TestIteratePairsWithOptions(t *testing.T) { + g := New(DefaultConfig()) + + // Build a tree with one parent and two children to test edge + // enumeration. + tx1, desc1 := createTestTx(nil, 2) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + // IteratePairs emits one pair per edge, allowing us to analyze + // each parent-child relationship independently for fee rate + // calculations. + pairs := slices.Collect(g.IteratePairs( + WithOrder(TraversalDefault), + WithStartNode(tx1.Hash()), + WithDirection(DirectionForward), + )) + + require.Len(t, pairs, 2, "should have 2 edges from tx1") + + // Verify each pair represents a valid edge from tx1 to one of its + // children. 
+ for _, pair := range pairs { + require.Equal(t, *tx1.Hash(), pair.Parent.TxHash) + require.True(t, + pair.Child.TxHash == *tx2.Hash() || + pair.Child.TxHash == *tx3.Hash(), + "child should be tx2 or tx3", + ) + } +} + +// TestIteratePairsWithFilter validates that filters are applied to +// parent-child pairs, enabling selective analysis of specific +// relationships. This is used in RBF (Replace-By-Fee) scenarios where we +// need to identify which high-value dependencies would be broken by +// replacing a transaction. +func TestIteratePairsWithFilter(t *testing.T) { + g := New(DefaultConfig()) + + // Create two independent parent-child chains with different fee + // rates to test filtering at the edge level. + tx1, desc1 := createTestTx(nil, 1) + desc1.FeePerKB = 10000 + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx(nil, 1) + desc2.FeePerKB = 1000 + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + tx4, desc4 := createTestTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx4, desc4)) + + // The filter applies to parent nodes in the pairs, allowing us to + // focus analysis on edges originating from high-fee transactions. + highFeeFilter := func(n *TxGraphNode) bool { + return n.TxDesc.FeePerKB >= 5000 + } + + pairs := slices.Collect(g.IteratePairs( + WithOrder(TraversalDefault), + WithFilter(highFeeFilter), + )) + + // Only the edge from high-fee tx1 should appear. + require.Len(t, pairs, 1, "should filter out low-fee parent edges") + require.Equal(t, *tx1.Hash(), pairs[0].Parent.TxHash) + require.Equal(t, *tx3.Hash(), pairs[0].Child.TxHash) +} + +// TestIterateBackwardWithMaxDepth ensures that depth limits correctly +// bound backward traversal. 
This prevents unbounded ancestor walks in +// large transaction chains and enables efficient "bounded ancestor +// search" needed for quick policy checks without traversing the entire +// mempool history. +func TestIterateBackwardWithMaxDepth(t *testing.T) { + g := New(DefaultConfig()) + + // Build a longer chain to test depth limiting. + tx1, desc1 := createTestTx(nil, 1) + require.NoError(t, g.AddTransaction(tx1, desc1)) + + tx2, desc2 := createTestTx( + []wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx2, desc2)) + + tx3, desc3 := createTestTx( + []wire.OutPoint{{Hash: *tx2.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx3, desc3)) + + tx4, desc4 := createTestTx( + []wire.OutPoint{{Hash: *tx3.Hash(), Index: 0}}, 1, + ) + require.NoError(t, g.AddTransaction(tx4, desc4)) + + // MaxDepth limits how far back we search. This is critical for + // performance when checking if a transaction has any recent + // unconfirmed ancestors without walking the entire chain. + visited := slices.Collect(g.Iterate( + WithOrder(TraversalBFS), + WithStartNode(tx4.Hash()), + WithDirection(DirectionBackward), + WithMaxDepth(2), + WithIncludeStart(false), + )) + + // Depth 1 yields tx3, depth 2 yields tx2. tx1 at depth 3 is + // excluded. + require.Len(t, visited, 2, "should respect maxDepth") + require.Equal(t, *tx3.Hash(), visited[0].TxHash) + require.Equal(t, *tx2.Hash(), visited[1].TxHash) +} + +// TestIterateDFSBackwardWithFilter validates depth-first backward +// traversal with filtering on diamond-shaped DAGs. This pattern is +// crucial for analyzing complex dependency structures where a transaction +// has multiple parents, as seen in coinjoin and batched payment scenarios +// where we need to identify which specific ancestor chains meet certain +// criteria. 
+func TestIterateDFSBackwardWithFilter(t *testing.T) {
+	g := New(DefaultConfig())
+
+	// Build a diamond: tx1 branches to tx2 and tx3, which both feed into
+	// tx4. This tests filtering when multiple paths exist to ancestors.
+	tx1, desc1 := createTestTx(nil, 2)
+	desc1.FeePerKB = 1000
+	require.NoError(t, g.AddTransaction(tx1, desc1))
+
+	tx2, desc2 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 0}}, 1,
+	)
+	desc2.FeePerKB = 10000
+	require.NoError(t, g.AddTransaction(tx2, desc2))
+
+	tx3, desc3 := createTestTx(
+		[]wire.OutPoint{{Hash: *tx1.Hash(), Index: 1}}, 1,
+	)
+	desc3.FeePerKB = 500
+	require.NoError(t, g.AddTransaction(tx3, desc3))
+
+	tx4, desc4 := createTestTx([]wire.OutPoint{
+		{Hash: *tx2.Hash(), Index: 0},
+		{Hash: *tx3.Hash(), Index: 0},
+	}, 1)
+	require.NoError(t, g.AddTransaction(tx4, desc4))
+
+	// When traversing backward with a filter, we only follow paths
+	// through ancestors that pass the predicate. This allows
+	// identifying specific "valuable" dependency chains while ignoring
+	// low-value branches.
+	highFeeFilter := func(n *TxGraphNode) bool {
+		return n.TxDesc.FeePerKB >= 5000
+	}
+
+	visited := slices.Collect(g.Iterate(
+		WithOrder(TraversalDFS),
+		WithStartNode(tx4.Hash()),
+		WithDirection(DirectionBackward),
+		WithFilter(highFeeFilter),
+		WithIncludeStart(false),
+	))
+
+	// Only tx2 passes the filter. tx3 and tx1 are both too low-fee.
+	require.Len(t, visited, 1, "should filter out low-fee nodes")
+	require.Equal(t, *tx2.Hash(), visited[0].TxHash)
+}