diff --git a/graph/db/benchmark_test.go b/graph/db/benchmark_test.go index 1d9a782fc96..f1366459b9c 100644 --- a/graph/db/benchmark_test.go +++ b/graph/db/benchmark_test.go @@ -19,6 +19,8 @@ import ( "github.com/lightningnetwork/lnd/kvdb/postgres" "github.com/lightningnetwork/lnd/kvdb/sqlbase" "github.com/lightningnetwork/lnd/kvdb/sqlite" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" "github.com/stretchr/testify/require" "golang.org/x/time/rate" @@ -643,6 +645,10 @@ func BenchmarkGraphReadMethods(b *testing.B) { nativeSQLPostgresConn, } + // We use a counter to make sure that any call-back is doing something + // useful, otherwise the compiler may optimize it away in the future. + var counter int64 + tests := []struct { name string fn func(b testing.TB, store V1Store) @@ -652,6 +658,11 @@ func BenchmarkGraphReadMethods(b *testing.B) { fn: func(b testing.TB, store V1Store) { err := store.ForEachNode( ctx, func(_ NodeRTx) error { + // Increment the counter to + // ensure the callback is doing + // something. + counter++ + return nil }, func() {}, ) @@ -667,6 +678,11 @@ func BenchmarkGraphReadMethods(b *testing.B) { _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { + // Increment the counter to + // ensure the callback is doing + // something. + counter++ + return nil }, func() {}, ) @@ -682,6 +698,43 @@ func BenchmarkGraphReadMethods(b *testing.B) { require.NoError(b, err) }, }, + { + name: "ForEachNodeCacheable", + fn: func(b testing.TB, store V1Store) { + err := store.ForEachNodeCacheable( + ctx, func(_ route.Vertex, + _ *lnwire.FeatureVector) error { + + // Increment the counter to + // ensure the callback is doing + // something. + counter++ + + return nil + }, func() {}, + ) + require.NoError(b, err) + }, + }, + { + name: "ForEachNodeCached", + fn: func(b testing.TB, store V1Store) { + //nolint:ll + err := store.ForEachNodeCached( + ctx, func(route.Vertex, + map[uint64]*DirectedChannel) error { + + // Increment the counter to + // ensure the callback is doing + // something. + counter++ + + return nil + }, func() {}, + ) + require.NoError(b, err) + }, + }, } for _, test := range tests { diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 02a3f637280..9d439f44862 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -100,6 +100,7 @@ type SQLQueries interface { GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error) + ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error) ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error) ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error) ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error) @@ -756,7 +757,7 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context, } return forEachNodeChannel( - ctx, db, s.cfg.ChainHash, nodeID, + ctx, db, s.cfg, nodeID, func(info *models.ChannelEdgeInfo, outPolicy *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -814,7 +815,7 @@ func (s *SQLStore) ForEachNode(ctx context.Context, node *models.LightningNode) error { return cb(newSQLGraphNodeTx( - db, s.cfg.ChainHash, dbNodeID, node, + db, s.cfg, dbNodeID, node, )) }, ) @@ -824,24 +825,24 @@ func (s *SQLStore) ForEachNode(ctx context.Context, // sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the // SQLStore and a SQL transaction. type sqlGraphNodeTx struct { - db SQLQueries - id int64 - node *models.LightningNode - chain chainhash.Hash + db SQLQueries + id int64 + node *models.LightningNode + cfg *SQLStoreConfig } // A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx // interface. var _ NodeRTx = (*sqlGraphNodeTx)(nil) -func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash, +func newSQLGraphNodeTx(db SQLQueries, cfg *SQLStoreConfig, id int64, node *models.LightningNode) *sqlGraphNodeTx { return &sqlGraphNodeTx{ - db: db, - chain: chain, - id: id, - node: node, + db: db, + cfg: cfg, + id: id, + node: node, } } @@ -861,7 +862,7 @@ func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo, ctx := context.TODO() - return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb) + return forEachNodeChannel(ctx, s.db, s.cfg, s.id, cb) } // FetchNode fetches the node with the given pub key under the same transaction @@ -878,7 +879,7 @@ func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) { nodePub, err) } - return newSQLGraphNodeTx(s.db, s.chain, id, node), nil + return newSQLGraphNodeTx(s.db, s.cfg, id, node), nil } // ForEachNodeDirectedChannel iterates through all channels of a given node, @@ -903,8 +904,6 @@ func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex, // graph, executing the passed callback with each node encountered. If the // callback returns an error, then the transaction is aborted and the iteration // stops early. -// -// NOTE: This is a part of the V1Store interface. func (s *SQLStore) ForEachNodeCacheable(ctx context.Context, cb func(route.Vertex, *lnwire.FeatureVector) error, reset func()) error { @@ -912,14 +911,8 @@ func (s *SQLStore) ForEachNodeCacheable(ctx context.Context, err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { return forEachNodeCacheable( ctx, s.cfg.QueryCfg, db, - func(nodeID int64, nodePub route.Vertex) error { - features, err := getNodeFeatures( - ctx, db, nodeID, - ) - if err != nil { - return fmt.Errorf("unable to fetch "+ - "node features: %w", err) - } + func(_ int64, nodePub route.Vertex, + features *lnwire.FeatureVector) error { return cb(nodePub, features) }, @@ -959,9 +952,7 @@ func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, return fmt.Errorf("unable to fetch node: %w", err) } - return forEachNodeChannel( - ctx, db, s.cfg.ChainHash, dbNode.ID, cb, - ) + return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb) }, reset) } @@ -1093,116 +1084,185 @@ func (s *SQLStore) ForEachNodeCached(ctx context.Context, cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error, reset func()) error { - handleNode := func(db SQLQueries, nodeID int64, - nodePub route.Vertex) error { + type nodeCachedBatchData struct { + features map[int64][]int + chanBatchData *batchChannelData + chanMap map[int64][]sqlc.ListChannelsForNodeIDsRow + } - features, err := getNodeFeatures(ctx, db, nodeID) - if err != nil { - return fmt.Errorf("unable to fetch node(id=%d) "+ - "features: %w", nodeID, err) - } + return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + // pageQueryFunc is used to query the next page of nodes. + pageQueryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) { - toNodeCallback := func() route.Vertex { - return nodePub + return db.ListNodeIDsAndPubKeys( + ctx, sqlc.ListNodeIDsAndPubKeysParams{ + Version: int16(ProtocolV1), + ID: lastID, + Limit: limit, + }, + ) } - rows, err := db.ListChannelsByNodeID( - ctx, sqlc.ListChannelsByNodeIDParams{ - Version: int16(ProtocolV1), - NodeID1: nodeID, - }, - ) - if err != nil { - return fmt.Errorf("unable to fetch channels of "+ - "node(id=%d): %w", nodeID, err) - } + // batchDataFunc is then used to batch load the data required + // for each page of nodes. + batchDataFunc := func(ctx context.Context, + nodeIDs []int64) (*nodeCachedBatchData, error) { - channels := make(map[uint64]*DirectedChannel, len(rows)) - for _, row := range rows { - node1, node2, err := buildNodeVertices( - row.Node1Pubkey, row.Node2Pubkey, + // Batch load node features. + nodeFeatures, err := batchLoadNodeFeaturesHelper( + ctx, s.cfg.QueryCfg, db, nodeIDs, ) if err != nil { - return err + return nil, fmt.Errorf("unable to batch load "+ + "node features: %w", err) } - e, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1, node2, + // Batch load ALL unique channels for ALL nodes in this + // page. + allChannels, err := db.ListChannelsForNodeIDs( + ctx, sqlc.ListChannelsForNodeIDsParams{ + Version: int16(ProtocolV1), + Node1Ids: nodeIDs, + Node2Ids: nodeIDs, + }, ) if err != nil { - return fmt.Errorf("unable to build channel "+ - "info: %w", err) + return nil, fmt.Errorf("unable to batch "+ + "fetch channels for nodes: %w", err) } - dbPol1, dbPol2, err := extractChannelPolicies(row) - if err != nil { - return fmt.Errorf("unable to extract channel "+ - "policies: %w", err) + // Deduplicate channels and collect IDs. + var ( + allChannelIDs []int64 + allPolicyIDs []int64 + ) + uniqueChannels := make( + map[int64]sqlc.ListChannelsForNodeIDsRow, + ) + + for _, channel := range allChannels { + channelID := channel.GraphChannel.ID + + // Only process each unique channel once. + _, exists := uniqueChannels[channelID] + if exists { + continue + } + + uniqueChannels[channelID] = channel + allChannelIDs = append(allChannelIDs, channelID) + + if channel.Policy1ID.Valid { + allPolicyIDs = append( + allPolicyIDs, + channel.Policy1ID.Int64, + ) + } + if channel.Policy2ID.Valid { + allPolicyIDs = append( + allPolicyIDs, + channel.Policy2ID.Int64, + ) + } } - p1, p2, err := getAndBuildChanPolicies( - ctx, db, dbPol1, dbPol2, e.ChannelID, node1, - node2, + // Batch load channel data for all unique channels. + channelBatchData, err := batchLoadChannelData( + ctx, s.cfg.QueryCfg, db, allChannelIDs, + allPolicyIDs, ) if err != nil { - return fmt.Errorf("unable to build channel "+ - "policies: %w", err) + return nil, fmt.Errorf("unable to batch "+ + "load channel data: %w", err) } - // Determine the outgoing and incoming policy - // for this channel and node combo. - outPolicy, inPolicy := p1, p2 - if p1 != nil && p1.ToNode == nodePub { - outPolicy, inPolicy = p2, p1 - } else if p2 != nil && p2.ToNode != nodePub { - outPolicy, inPolicy = p2, p1 + // Create map of node ID to channels that involve this + // node. + nodeIDSet := make(map[int64]bool) + for _, nodeID := range nodeIDs { + nodeIDSet[nodeID] = true } - var cachedInPolicy *models.CachedEdgePolicy - if inPolicy != nil { - cachedInPolicy = models.NewCachedPolicy( - inPolicy, - ) - cachedInPolicy.ToNodePubKey = toNodeCallback - cachedInPolicy.ToNodeFeatures = features + nodeChannelMap := make( + map[int64][]sqlc.ListChannelsForNodeIDsRow, + ) + for _, channel := range uniqueChannels { + // Add channel to both nodes if they're in our + // current page. + node1 := channel.GraphChannel.NodeID1 + if nodeIDSet[node1] { + nodeChannelMap[node1] = append( + nodeChannelMap[node1], channel, + ) + } + node2 := channel.GraphChannel.NodeID2 + if nodeIDSet[node2] { + nodeChannelMap[node2] = append( + nodeChannelMap[node2], channel, + ) + } } - var inboundFee lnwire.Fee - if outPolicy != nil { - outPolicy.InboundFee.WhenSome( - func(fee lnwire.Fee) { - inboundFee = fee - }, - ) + return &nodeCachedBatchData{ + features: nodeFeatures, + chanBatchData: channelBatchData, + chanMap: nodeChannelMap, + }, nil + } + + // processItem is used to process each node in the current page. + processItem := func(ctx context.Context, + nodeData sqlc.ListNodeIDsAndPubKeysRow, + batchData *nodeCachedBatchData) error { + + // Build feature vector for this node. + fv := lnwire.EmptyFeatureVector() + features, exists := batchData.features[nodeData.ID] + if exists { + for _, bit := range features { + fv.Set(lnwire.FeatureBit(bit)) + } } - directedChannel := &DirectedChannel{ - ChannelID: e.ChannelID, - IsNode1: nodePub == e.NodeKey1Bytes, - OtherNode: e.NodeKey2Bytes, - Capacity: e.Capacity, - OutPolicySet: outPolicy != nil, - InPolicy: cachedInPolicy, - InboundFee: inboundFee, + var nodePub route.Vertex + copy(nodePub[:], nodeData.PubKey) + + nodeChannels := batchData.chanMap[nodeData.ID] + + toNodeCallback := func() route.Vertex { + return nodePub } - if nodePub == e.NodeKey2Bytes { - directedChannel.OtherNode = e.NodeKey1Bytes + // Build cached channels map for this node. + channels := make(map[uint64]*DirectedChannel) + for _, channelRow := range nodeChannels { + directedChan, err := buildDirectedChannel( + s.cfg.ChainHash, nodeData.ID, nodePub, + channelRow, batchData.chanBatchData, fv, + toNodeCallback, + ) + if err != nil { + return err + } + + channels[directedChan.ChannelID] = directedChan } - channels[e.ChannelID] = directedChannel + return cb(nodePub, channels) } - return cb(nodePub, channels) - } + return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( + ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc, + func(node sqlc.ListNodeIDsAndPubKeysRow) int64 { + return node.ID + }, + func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, + error) { - return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - return forEachNodeCacheable( - ctx, s.cfg.QueryCfg, db, - func(nodeID int64, nodePub route.Vertex) error { - return handleNode(db, nodeID, nodePub) + return node.ID, nil }, + batchDataFunc, processItem, ) }, reset) } @@ -2951,18 +3011,27 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries, } // forEachNodeCacheable fetches all V1 node IDs and pub keys from the database, -// and executes the provided callback for each node. +// and executes the provided callback for each node. It does so via pagination +// along with batch loading of the node feature bits. func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig, - db SQLQueries, - cb func(nodeID int64, nodePub route.Vertex) error) error { + db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex, + features *lnwire.FeatureVector) error) error { handleNode := func(_ context.Context, - node sqlc.ListNodeIDsAndPubKeysRow) error { + dbNode sqlc.ListNodeIDsAndPubKeysRow, + featureBits map[int64][]int) error { + + fv := lnwire.EmptyFeatureVector() + if features, exists := featureBits[dbNode.ID]; exists { + for _, bit := range features { + fv.Set(lnwire.FeatureBit(bit)) + } + } var pub route.Vertex - copy(pub[:], node.PubKey) + copy(pub[:], dbNode.PubKey) - return cb(node.ID, pub) + return processNode(dbNode.ID, pub, fv) } queryFunc := func(ctx context.Context, lastID int64, @@ -2981,8 +3050,19 @@ func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig, return row.ID } - return sqldb.ExecutePaginatedQuery( - ctx, cfg, int64(-1), queryFunc, extractCursor, handleNode, + collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) { + return node.ID, nil + } + + batchQueryFunc := func(ctx context.Context, + nodeIDs []int64) (map[int64][]int, error) { + + return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs) + } + + return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( + ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc, + batchQueryFunc, handleNode, ) } @@ -2991,11 +3071,11 @@ func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig, // edge information, the outgoing policy and the incoming policy for the // channel and node combo. func forEachNodeChannel(ctx context.Context, db SQLQueries, - chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo, + cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { - // Get all the V1 channels for this node.Add commentMore actions + // Get all the V1 channels for this node. rows, err := db.ListChannelsByNodeID( ctx, sqlc.ListChannelsByNodeIDParams{ Version: int16(ProtocolV1), @@ -3006,6 +3086,29 @@ func forEachNodeChannel(ctx context.Context, db SQLQueries, return fmt.Errorf("unable to fetch channels: %w", err) } + // Collect all the channel and policy IDs. + var ( + chanIDs = make([]int64, 0, len(rows)) + policyIDs = make([]int64, 0, 2*len(rows)) + ) + for _, row := range rows { + chanIDs = append(chanIDs, row.GraphChannel.ID) + + if row.Policy1ID.Valid { + policyIDs = append(policyIDs, row.Policy1ID.Int64) + } + if row.Policy2ID.Valid { + policyIDs = append(policyIDs, row.Policy2ID.Int64) + } + } + + batchData, err := batchLoadChannelData( + ctx, cfg.QueryCfg, db, chanIDs, policyIDs, + ) + if err != nil { + return fmt.Errorf("unable to batch load channel data: %w", err) + } + // Call the call-back for each channel and its known policies. for _, row := range rows { node1, node2, err := buildNodeVertices( @@ -3016,8 +3119,9 @@ func forEachNodeChannel(ctx context.Context, db SQLQueries, err) } - edge, err := getAndBuildEdgeInfo( - ctx, db, chain, row.GraphChannel, node1, node2, + edge, err := buildEdgeInfoWithBatchData( + cfg.ChainHash, row.GraphChannel, node1, node2, + batchData, ) if err != nil { return fmt.Errorf("unable to build channel info: %w", @@ -3030,8 +3134,8 @@ func forEachNodeChannel(ctx context.Context, db SQLQueries, "policies: %w", err) } - p1, p2, err := getAndBuildChanPolicies( - ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2, + p1, p2, err := buildChanPoliciesWithBatchData( + dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData, ) if err != nil { return fmt.Errorf("unable to build channel "+ @@ -4403,6 +4507,50 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, return policy1, policy2, nil + case sqlc.ListChannelsForNodeIDsRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.GraphChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + MessageFlags: r.Policy1MessageFlags, + ChannelFlags: r.Policy1ChannelFlags, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.GraphChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + MessageFlags: r.Policy2MessageFlags, + ChannelFlags: r.Policy2ChannelFlags, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + case sqlc.ListChannelsByNodeIDRow: if r.Policy1ID.Valid { policy1 = &sqlc.GraphChannelPolicy{ @@ -5110,3 +5258,83 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, collectFunc, batchDataFunc, processItem, ) } + +// buildDirectedChannel builds a DirectedChannel instance from the provided +// data. +func buildDirectedChannel(chain chainhash.Hash, nodeID int64, + nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow, + channelBatchData *batchChannelData, features *lnwire.FeatureVector, + toNodeCallback func() route.Vertex) (*DirectedChannel, error) { + + node1, node2, err := buildNodeVertices( + channelRow.Node1Pubkey, channelRow.Node2Pubkey, + ) + if err != nil { + return nil, fmt.Errorf("unable to build node vertices: %w", err) + } + + edge, err := buildEdgeInfoWithBatchData( + chain, channelRow.GraphChannel, node1, node2, channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel info: %w", err) + } + + dbPol1, dbPol2, err := extractChannelPolicies(channelRow) + if err != nil { + return nil, fmt.Errorf("unable to extract channel policies: %w", + err) + } + + p1, p2, err := buildChanPoliciesWithBatchData( + dbPol1, dbPol2, edge.ChannelID, node1, node2, + channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel policies: %w", + err) + } + + // Determine outgoing and incoming policy for this specific node. + p1ToNode := channelRow.GraphChannel.NodeID2 + p2ToNode := channelRow.GraphChannel.NodeID1 + outPolicy, inPolicy := p1, p2 + if (p1 != nil && p1ToNode == nodeID) || + (p2 != nil && p2ToNode != nodeID) { + + outPolicy, inPolicy = p2, p1 + } + + // Build cached policy. + var cachedInPolicy *models.CachedEdgePolicy + if inPolicy != nil { + cachedInPolicy = models.NewCachedPolicy(inPolicy) + cachedInPolicy.ToNodePubKey = toNodeCallback + cachedInPolicy.ToNodeFeatures = features + } + + // Extract inbound fee. + var inboundFee lnwire.Fee + if outPolicy != nil { + outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) { + inboundFee = fee + }) + } + + // Build directed channel. + directedChannel := &DirectedChannel{ + ChannelID: edge.ChannelID, + IsNode1: nodePub == edge.NodeKey1Bytes, + OtherNode: edge.NodeKey2Bytes, + Capacity: edge.Capacity, + OutPolicySet: outPolicy != nil, + InPolicy: cachedInPolicy, + InboundFee: inboundFee, + } + + if nodePub == edge.NodeKey2Bytes { + directedChannel.OtherNode = edge.NodeKey1Bytes + } + + return directedChannel, nil +} diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index fa6d50803f0..6db5d32c8af 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -2320,6 +2320,190 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo return items, nil } +const listChannelsForNodeIDs = `-- name: ListChannelsForNodeIDs :many +SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + -- TODO(elle): use sqlc.embed to embed policy structs + -- once this issue is resolved: + -- https://github.com/sqlc-dev/sqlc/issues/2997 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = $1 + AND (c.node_id_1 IN (/*SLICE:node1_ids*/?) + OR c.node_id_2 IN (/*SLICE:node2_ids*/?)) +` + +type ListChannelsForNodeIDsParams struct { + Version int16 + Node1Ids []int64 + Node2Ids []int64 +} + +type ListChannelsForNodeIDsRow struct { + GraphChannel GraphChannel + Node1Pubkey []byte + Node2Pubkey []byte + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte +} + +func (q *Queries) ListChannelsForNodeIDs(ctx context.Context, arg ListChannelsForNodeIDsParams) ([]ListChannelsForNodeIDsRow, error) { + query := listChannelsForNodeIDs + var queryParams []interface{} + queryParams = append(queryParams, arg.Version) + if len(arg.Node1Ids) > 0 { + for _, v := range arg.Node1Ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:node1_ids*/?", makeQueryParams(len(queryParams), len(arg.Node1Ids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:node1_ids*/?", "NULL", 1) + } + if len(arg.Node2Ids) > 0 { + for _, v := range arg.Node2Ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:node2_ids*/?", makeQueryParams(len(queryParams), len(arg.Node2Ids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:node2_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChannelsForNodeIDsRow + for rows.Next() { + var i ListChannelsForNodeIDsRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.Node1Pubkey, + &i.Node2Pubkey, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listChannelsPaginated = `-- name: ListChannelsPaginated :many SELECT id, bitcoin_key_1, bitcoin_key_2, outpoint FROM graph_channels c diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index f7fad44ab5f..55fd74b755a 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -98,6 +98,7 @@ type Querier interface { IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error) IsZombieChannel(ctx context.Context, arg IsZombieChannelParams) (bool, error) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error) + ListChannelsForNodeIDs(ctx context.Context, arg ListChannelsForNodeIDsParams) ([]ListChannelsForNodeIDsRow, error) ListChannelsPaginated(ctx context.Context, arg ListChannelsPaginatedParams) ([]ListChannelsPaginatedRow, error) ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg ListChannelsWithPoliciesForCachePaginatedParams) ([]ListChannelsWithPoliciesForCachePaginatedRow, error) ListChannelsWithPoliciesPaginated(ctx context.Context, arg ListChannelsWithPoliciesPaginatedParams) ([]ListChannelsWithPoliciesPaginatedRow, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index ae497941341..ae276a5cb6e 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -446,6 +446,59 @@ WHERE version = $1 ORDER BY scid DESC LIMIT 1; +-- name: ListChannelsForNodeIDs :many +SELECT sqlc.embed(c), + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + -- TODO(elle): use sqlc.embed to embed policy structs + -- once this issue is resolved: + -- https://github.com/sqlc-dev/sqlc/issues/2997 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = $1 + AND (c.node_id_1 IN (sqlc.slice('node1_ids')/*SLICE:node1_ids*/) + OR c.node_id_2 IN (sqlc.slice('node2_ids')/*SLICE:node2_ids*/)); + -- name: ListChannelsByNodeID :many SELECT sqlc.embed(c), n1.pub_key AS node1_pubkey,