@@ -477,6 +477,7 @@ impl SimulatedChannel {
477477
478478/// SimNetwork represents a high level network coordinator that is responsible for the task of actually propagating
479479/// payments through the simulated network.
480+ #[ async_trait]
480481pub trait SimNetwork : Send + Sync {
481482 /// Sends payments over the route provided through the network, reporting the final payment outcome to the sender
482483 /// channel provided.
@@ -490,7 +491,7 @@ pub trait SimNetwork: Send + Sync {
490491 ) ;
491492
492493 /// Looks up a node in the simulated network and a list of its channel capacities.
493- fn lookup_node ( & self , node : & PublicKey ) -> Result < ( NodeInfo , Vec < u64 > ) , LightningError > ;
494+ async fn lookup_node ( & self , node : & PublicKey ) -> Result < ( NodeInfo , Vec < u64 > ) , LightningError > ;
494495 /// Lists all nodes in the simulated network.
495496 fn list_nodes ( & self ) -> Vec < NodeInfo > ;
496497}
@@ -794,11 +795,17 @@ impl<T: SimNetwork, C: Clock> LightningNode for SimNode<T, C> {
794795 }
795796
796797 async fn get_node_info ( & self , node_id : & PublicKey ) -> Result < NodeInfo , LightningError > {
797- Ok ( self . network . lock ( ) . await . lookup_node ( node_id) ?. 0 )
798+ Ok ( self . network . lock ( ) . await . lookup_node ( node_id) . await ?. 0 )
798799 }
799800
800- async fn list_channels ( & self ) -> Result < Vec < u64 > , LightningError > {
801- Ok ( self . network . lock ( ) . await . lookup_node ( & self . info . pubkey ) ?. 1 )
801+ async fn channel_capacities ( & self ) -> Result < u64 , LightningError > {
802+ let channels = self
803+ . network
804+ . lock ( )
805+ . await
806+ . lookup_node ( & self . info . pubkey )
807+ . await ?;
808+ Ok ( channels. 1 . iter ( ) . sum ( ) )
802809 }
803810
804811 async fn get_graph ( & self ) -> Result < Graph , LightningError > {
@@ -1017,9 +1024,9 @@ async fn handle_intercepted_htlc(
10171024
10181025/// Graph is the top level struct that is used to coordinate simulation of lightning nodes.
10191026pub struct SimGraph {
1020- /// nodes caches the list of nodes in the network with a vector of their channel capacities , only used for quick
1027+ /// nodes caches the list of nodes in the network with a vector of their channel ids , only used for quick
10211028 /// lookup.
1022- nodes : HashMap < PublicKey , ( NodeInfo , Vec < u64 > ) > ,
1029+ nodes : HashMap < PublicKey , ( NodeInfo , Vec < ShortChannelID > ) > ,
10231030
10241031 /// channels maps the scid of a channel to its current simulation state.
10251032 channels : Arc < Mutex < HashMap < ShortChannelID , SimulatedChannel > > > ,
@@ -1051,7 +1058,7 @@ impl SimGraph {
10511058 default_custom_records : CustomRecords ,
10521059 shutdown_signal : ( Trigger , Listener ) ,
10531060 ) -> Result < Self , SimulationError > {
1054- let mut nodes: HashMap < PublicKey , ( NodeInfo , Vec < u64 > ) > = HashMap :: new ( ) ;
1061+ let mut nodes: HashMap < PublicKey , ( NodeInfo , Vec < ShortChannelID > ) > = HashMap :: new ( ) ;
10551062 let mut channels = HashMap :: new ( ) ;
10561063
10571064 for channel in graph_channels. iter ( ) {
@@ -1068,18 +1075,16 @@ impl SimGraph {
10681075 Entry :: Vacant ( v) => v. insert ( channel. clone ( ) ) ,
10691076 } ;
10701077
1071- if !channel. exclude_capacity {
1072- // It's okay to have duplicate pubkeys because one node can have many channels.
1073- for info in [ & channel. node_1 . policy , & channel. node_2 . policy ] {
1074- match nodes. entry ( info. pubkey ) {
1075- Entry :: Occupied ( o) => o. into_mut ( ) . 1 . push ( channel. capacity_msat ) ,
1076- Entry :: Vacant ( v) => {
1077- v. insert ( (
1078- node_info ( info. pubkey , info. alias . clone ( ) ) ,
1079- vec ! [ channel. capacity_msat] ,
1080- ) ) ;
1081- } ,
1082- }
1078+ // It's okay to have duplicate pubkeys because one node can have many channels.
1079+ for info in [ & channel. node_1 . policy , & channel. node_2 . policy ] {
1080+ match nodes. entry ( info. pubkey ) {
1081+ Entry :: Occupied ( o) => o. into_mut ( ) . 1 . push ( channel. short_channel_id ) ,
1082+ Entry :: Vacant ( v) => {
1083+ v. insert ( (
1084+ node_info ( info. pubkey , info. alias . clone ( ) ) ,
1085+ vec ! [ channel. short_channel_id] ,
1086+ ) ) ;
1087+ } ,
10831088 }
10841089 }
10851090 }
@@ -1101,9 +1106,11 @@ pub async fn ln_node_from_graph<C: Clock>(
11011106 routing_graph : Arc < LdkNetworkGraph > ,
11021107 clock : Arc < C > ,
11031108) -> Result < HashMap < PublicKey , Arc < Mutex < SimNode < SimGraph , C > > > > , LightningError > {
1104- let mut nodes: HashMap < PublicKey , Arc < Mutex < SimNode < SimGraph , C > > > > = HashMap :: new ( ) ;
1109+ let sim_graph = graph. lock ( ) . await ;
1110+ let mut nodes: HashMap < PublicKey , Arc < Mutex < SimNode < SimGraph , C > > > > =
1111+ HashMap :: with_capacity ( sim_graph. nodes . len ( ) ) ;
11051112
1106- for node in graph . lock ( ) . await . nodes . iter ( ) {
1113+ for node in sim_graph . nodes . iter ( ) {
11071114 nodes. insert (
11081115 * node. 0 ,
11091116 Arc :: new ( Mutex :: new ( SimNode :: new (
@@ -1182,6 +1189,7 @@ pub fn populate_network_graph<C: Clock>(
11821189 Ok ( graph)
11831190}
11841191
1192+ #[ async_trait]
11851193impl SimNetwork for SimGraph {
11861194 /// dispatch_payment asynchronously propagates a payment through the simulated network, returning a tracking
11871195 /// channel that can be used to obtain the result of the payment. At present, MPP payments are not supported.
@@ -1231,13 +1239,27 @@ impl SimNetwork for SimGraph {
12311239 }
12321240
12331241 /// lookup_node fetches a node's information and channel capacities.
1234- fn lookup_node ( & self , node : & PublicKey ) -> Result < ( NodeInfo , Vec < u64 > ) , LightningError > {
1235- match self . nodes . get ( node) {
1236- Some ( node) => Ok ( node. clone ( ) ) ,
1237- None => Err ( LightningError :: GetNodeInfoError (
1238- "Node not found" . to_string ( ) ,
1239- ) ) ,
1240- }
1242+ async fn lookup_node ( & self , node : & PublicKey ) -> Result < ( NodeInfo , Vec < u64 > ) , LightningError > {
1243+ let node_info = match self . nodes . get ( node) {
1244+ Some ( node) => node. clone ( ) ,
1245+ None => {
1246+ return Err ( LightningError :: GetNodeInfoError ( format ! (
1247+ "Node {} not found" ,
1248+ node
1249+ ) ) )
1250+ } ,
1251+ } ;
1252+
1253+ let channels = self . channels . lock ( ) . await ;
1254+ let capacities: Vec < u64 > = node_info
1255+ . 1
1256+ . iter ( )
1257+ . filter_map ( |scid| channels. get ( scid) )
1258+ . filter ( |channel| !channel. exclude_capacity )
1259+ . map ( |channel| channel. capacity_msat )
1260+ . collect ( ) ;
1261+
1262+ Ok ( ( node_info. 0 , capacities) )
12411263 }
12421264
12431265 fn list_nodes ( & self ) -> Vec < NodeInfo > {
@@ -1965,34 +1987,25 @@ mod tests {
19651987 . await
19661988 . unwrap ( ) ;
19671989
1968- let node_1_channels = nodes
1969- . get ( & pk1)
1970- . unwrap ( )
1971- . lock ( )
1972- . await
1973- . list_channels ( )
1974- . await
1975- . unwrap ( ) ;
1990+ assert ! ( nodes. len( ) == 3 ) ;
19761991
1977- // Node 1 has 2 channels but one was excluded so here we should only have the one that was
1978- // not excluded.
1979- assert ! ( node_1_channels. len( ) == 1 ) ;
1980- assert ! ( node_1_channels[ 0 ] == capacity_1) ;
1992+ let node_1 = nodes. get ( & pk1) . unwrap ( ) . lock ( ) . await ;
1993+ let node_1_capacity = node_1. channel_capacities ( ) . await . unwrap ( ) ;
19811994
1982- let node_2_channels = nodes
1983- . get ( & pk2)
1984- . unwrap ( )
1985- . lock ( )
1986- . await
1987- . list_channels ( )
1988- . await
1989- . unwrap ( ) ;
1995+ // Node 1 has 2 channels but one was excluded so here we should only have the capacity of
1996+ // the channel that was not excluded.
1997+ assert ! ( node_1_capacity == capacity_1) ;
19901998
1991- assert ! ( node_2_channels. len( ) == 1 ) ;
1992- assert ! ( node_2_channels[ 0 ] == capacity_1) ;
1999+ let node_2 = nodes. get ( & pk2) . unwrap ( ) . lock ( ) . await ;
2000+ let node_2_capacity = node_2. channel_capacities ( ) . await . unwrap ( ) ;
2001+ assert ! ( node_2_capacity == capacity_1) ;
19932002
1994- // Node 3's only channel was excluded so it won't be present here.
1995- assert ! ( !nodes. contains_key( & pk3) ) ;
2003+ // Node 3 should be returned from ln_node_from_graph but it won't have any channel capacity
2004+ // present because its only channel was excluded.
2005+ let node_3 = nodes. get ( & pk3) ;
2006+ assert ! ( node_3. is_some( ) ) ;
2007+ let node_3 = node_3. unwrap ( ) . lock ( ) . await ;
2008+ assert ! ( node_3. channel_capacities( ) . await . unwrap( ) == 0 ) ;
19962009 }
19972010
19982011 /// Tests basic functionality of a `SimulatedChannel` but does no endeavor to test the underlying
@@ -2062,6 +2075,7 @@ mod tests {
20622075 mock ! {
20632076 Network { }
20642077
2078+ #[ async_trait]
20652079 impl SimNetwork for Network {
20662080 fn dispatch_payment(
20672081 & mut self ,
@@ -2072,7 +2086,7 @@ mod tests {
20722086 sender: Sender <Result <PaymentResult , LightningError >>,
20732087 ) ;
20742088
2075- fn lookup_node( & self , node: & PublicKey ) -> Result <( NodeInfo , Vec <u64 >) , LightningError >;
2089+ async fn lookup_node( & self , node: & PublicKey ) -> Result <( NodeInfo , Vec <u64 >) , LightningError >;
20762090 fn list_nodes( & self ) -> Vec <NodeInfo >;
20772091 }
20782092 }
@@ -2103,12 +2117,17 @@ mod tests {
21032117 . lock ( )
21042118 . await
21052119 . expect_lookup_node ( )
2106- . returning ( move |_| Ok ( ( node_info ( lookup_pk, String :: default ( ) ) , vec ! [ 1 , 2 , 3 ] ) ) ) ;
2120+ . returning ( move |_| {
2121+ Ok ( (
2122+ node_info ( lookup_pk, String :: default ( ) ) ,
2123+ vec ! [ 10_000 , 20_000 , 10_000 ] ,
2124+ ) )
2125+ } ) ;
21072126
21082127 // Assert that we get three channels from the mock.
21092128 let node_info = node. get_node_info ( & lookup_pk) . await . unwrap ( ) ;
21102129 assert_eq ! ( lookup_pk, node_info. pubkey) ;
2111- assert_eq ! ( node. list_channels ( ) . await . unwrap( ) . len ( ) , 3 ) ;
2130+ assert_eq ! ( node. channel_capacities ( ) . await . unwrap( ) , 40_000 ) ;
21122131
21132132 // Next, we're going to test handling of in-flight payments. To do this, we'll mock out calls to our dispatch
21142133 // function to send different results depending on the destination.
0 commit comments