diff --git a/lnprototest/clightning/clightning.py b/lnprototest/clightning/clightning.py index c574103..0ff79fd 100644 --- a/lnprototest/clightning/clightning.py +++ b/lnprototest/clightning/clightning.py @@ -55,6 +55,7 @@ class Runner(lnprototest.Runner): def __init__(self, config: Any): super().__init__(config) self.running = False + self.rpc = None self.cleanup_callbacks: List[Callable[[], None]] = [] self.fundchannel_future: Optional[Any] = None self.is_fundchannel_kill = False @@ -81,8 +82,8 @@ def __init__(self, config: Any): stdout=subprocess.PIPE, check=True, ) - .stdout.decode("utf-8") - .splitlines() + .stdout.decode("utf-8") + .splitlines() ) self.options: Dict[str, str] = {} for o in opts: @@ -91,6 +92,7 @@ def __init__(self, config: Any): else: k, v = o.split("/") self.options[k] = v + self.start() def get_keyset(self) -> KeySet: return KeySet( @@ -111,7 +113,7 @@ def is_running(self) -> bool: return self.running def start(self) -> None: - if self.running: + if self.is_running(): return self.proc = subprocess.Popen( [ @@ -139,6 +141,7 @@ def start(self) -> None: self.rpc = pyln.client.LightningRpc( os.path.join(self.lightning_dir, "regtest", "lightning-rpc") ) + def node_ready(rpc: pyln.client.LightningRpc) -> bool: try: rpc.getinfo() @@ -152,22 +155,12 @@ def node_ready(rpc: pyln.client.LightningRpc) -> bool: for i in range(5): self.rpc.newaddr() - def kill_fundchannel(self) -> None: - fut = self.fundchannel_future - self.fundchannel_future = None - self.is_fundchannel_kill = True - if fut: - try: - fut.result(0) - except (SpecFileError, futures.TimeoutError): - pass - def shutdown(self) -> None: for cb in self.cleanup_callbacks: cb() def stop(self) -> None: - if self.running is False: + if not self.running: return self.shutdown() self.rpc.stop() @@ -176,6 +169,16 @@ def stop(self) -> None: for c in self.conns.values(): cast(CLightningConn, c).connection.connection.close() + def kill_fundchannel(self) -> None: + fut = self.fundchannel_future + self.fundchannel_future = None + self.is_fundchannel_kill = True + if fut: + try: + fut.result(0) + except (SpecFileError, futures.TimeoutError): + pass + def connect(self, event: Event, connprivkey: str) -> None: self.add_conn(CLightningConn(connprivkey, self.lightning_port)) @@ -223,12 +226,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None: raise EventError(event, "Connection closed") def fundchannel( - self, - event: Event, - conn: Conn, - amount: int, - feerate: int = 253, - expect_fail: bool = False, + self, + event: Event, + conn: Conn, + amount: int, + feerate: int = 253, + expect_fail: bool = False, ) -> None: """ event - the event which cause this, for error logging @@ -248,11 +251,11 @@ def fundchannel( self.fundchannel_future = None def _fundchannel( - runner: Runner, - conn: Conn, - amount: int, - feerate: int, - expect_fail: bool = False, + runner: Runner, + conn: Conn, + amount: int, + feerate: int, + expect_fail: bool = False, ) -> str: peer_id = conn.pubkey.format().hex() # Need to supply feerate here, since regtest cannot estimate fees @@ -276,14 +279,14 @@ def _done(fut: Any) -> None: self.cleanup_callbacks.append(self.kill_fundchannel) def init_rbf( - self, - event: Event, - conn: Conn, - channel_id: str, - amount: int, - utxo_txid: str, - utxo_outnum: int, - feerate: int, + self, + event: Event, + conn: Conn, + channel_id: str, + amount: int, + utxo_txid: str, + utxo_outnum: int, + feerate: int, ) -> None: if self.fundchannel_future: @@ -356,7 +359,7 @@ def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None: self.rpc.sendpay([routestep], payhash) def get_output_message( - self, conn: Conn, event: Event, timeout: int = TIMEOUT + self, conn: Conn, event: Event, timeout: int = TIMEOUT ) -> Optional[bytes]: fut = self.executor.submit(cast(CLightningConn, conn).connection.read_message) try: @@ -373,11 +376,11 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]: return msg.hex() def check_final_error( - self, - event: Event, - conn: Conn, - expected: bool, - must_not_events: List[MustNotMsg], + self, + event: Event, + conn: Conn, + expected: bool, + must_not_events: List[MustNotMsg], ) -> None: if not expected: # Inject raw packet to ensure it hangs up *after* processing all previous ones. diff --git a/lnprototest/dummyrunner.py b/lnprototest/dummyrunner.py index c32c84e..63185fc 100644 --- a/lnprototest/dummyrunner.py +++ b/lnprototest/dummyrunner.py @@ -16,7 +16,6 @@ class DummyRunner(Runner): - def __init__(self, config: Any): super().__init__(config) @@ -87,12 +86,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None: print("[RECV {} {}]".format(event, outbuf.hex())) def fundchannel( - self, - event: Event, - conn: Conn, - amount: int, - feerate: int = 253, - expect_fail: bool = False, + self, + event: Event, + conn: Conn, + amount: int, + feerate: int = 253, + expect_fail: bool = False, ) -> None: if self.config.getoption("verbose"): print( @@ -102,14 +101,14 @@ def fundchannel( ) def init_rbf( - self, - event: Event, - conn: Conn, - channel_id: str, - amount: int, - utxo_txid: str, - utxo_outnum: int, - feerate: int, + self, + event: Event, + conn: Conn, + channel_id: str, + amount: int, + utxo_txid: str, + utxo_outnum: int, + feerate: int, ) -> None: if self.config.getoption("verbose"): print( @@ -144,21 +143,21 @@ def fake_field(ftype: FieldType) -> str: if ftype.elemtype.name == "byte": return "00" * ftype.arraysize return ( - "[" - + ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize) - + "]" + "[" + + ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize) + + "]" ) elif ftype.name in ( - "byte", - "u8", - "u16", - "u32", - "u64", - "tu16", - "tu32", - "tu64", - "bigsize", - "varint", + "byte", + "u8", + "u16", + "u32", + "u64", + "tu16", + "tu32", + "tu64", + "bigsize", + "varint", ): return "0" elif ftype.name in ("chain_hash", "channel_id", "sha256"): @@ -201,11 +200,11 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]: return "Dummy error" def check_final_error( - self, - event: Event, - conn: Conn, - expected: bool, - must_not_events: List[MustNotMsg], + self, + event: Event, + conn: Conn, + expected: bool, + must_not_events: List[MustNotMsg], ) -> None: pass diff --git a/lnprototest/runner.py b/lnprototest/runner.py index d65d6e2..c081020 100644 --- a/lnprototest/runner.py +++ b/lnprototest/runner.py @@ -49,7 +49,8 @@ def __init__(self, config: Any): self.stash: Dict[str, Dict[str, Any]] = {} def __enter__(self) -> "Runner": - """Call the method when enter inside the class the first time""" + """Call the method when enter inside the class the first time. + doc: https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers""" self.start() return self @@ -233,6 +234,7 @@ def close_channel(self, channel_id: str) -> bool: a boolean value if it succeeded with success""" pass + def remote_revocation_basepoint() -> Callable[[Runner, Event, str], str]: """Get the remote revocation basepoint""" diff --git a/tests/conftest.py b/tests/conftest.py index 2cc04b6..cc468c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,8 +27,7 @@ def pytest_addoption(parser: Any) -> None: @pytest.fixture() # type: ignore def runner(pytestconfig: Any) -> Any: parts = pytestconfig.getoption("runner").rpartition(".") - runner = importlib.import_module(parts[0]).__dict__[parts[2]](pytestconfig) - yield runner + yield importlib.import_module(parts[0]).__dict__[parts[2]](pytestconfig) @pytest.fixture()