Skip to content

Commit

Permalink
lnprototest: refactroing abstract class and fixed the #14
Browse files Browse the repository at this point in the history
kill all the process when the class was removed from the scope.

Signed-off-by: Vincenzo Palazzo <[email protected]>
  • Loading branch information
vincenzopalazzo committed Mar 9, 2022
1 parent d5b7efe commit 22a05fc
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 59 deletions.
54 changes: 27 additions & 27 deletions lnprototest/clightning/clightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __init__(self, config: Any):
stdout=subprocess.PIPE,
check=True,
)
.stdout.decode("utf-8")
.splitlines()
.stdout.decode("utf-8")
.splitlines()
)
self.options: Dict[str, str] = {}
for o in opts:
Expand Down Expand Up @@ -232,12 +232,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
raise EventError(event, "Connection closed")

def fundchannel(
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
) -> None:
"""
event - the event which cause this, for error logging
Expand All @@ -257,11 +257,11 @@ def fundchannel(
self.fundchannel_future = None

def _fundchannel(
runner: Runner,
conn: Conn,
amount: int,
feerate: int,
expect_fail: bool = False,
runner: Runner,
conn: Conn,
amount: int,
feerate: int,
expect_fail: bool = False,
) -> str:
peer_id = conn.pubkey.format().hex()
# Need to supply feerate here, since regtest cannot estimate fees
Expand All @@ -285,14 +285,14 @@ def _done(fut: Any) -> None:
self.cleanup_callbacks.append(self.kill_fundchannel)

def init_rbf(
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
) -> None:

if self.fundchannel_future:
Expand Down Expand Up @@ -365,7 +365,7 @@ def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None:
self.rpc.sendpay([routestep], payhash)

def get_output_message(
self, conn: Conn, event: Event, timeout: int = TIMEOUT
self, conn: Conn, event: Event, timeout: int = TIMEOUT
) -> Optional[bytes]:
fut = self.executor.submit(cast(CLightningConn, conn).connection.read_message)
try:
Expand All @@ -382,11 +382,11 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]:
return msg.hex()

def check_final_error(
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
) -> None:
if not expected:
# Inject raw packet to ensure it hangs up *after* processing all previous ones.
Expand Down
65 changes: 33 additions & 32 deletions lnprototest/dummyrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


class DummyRunner(Runner):

def __init__(self, config: Any):
super().__init__(config)

Expand Down Expand Up @@ -86,12 +87,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
print("[RECV {} {}]".format(event, outbuf.hex()))

def fundchannel(
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
) -> None:
if self.config.getoption("verbose"):
print(
Expand All @@ -101,14 +102,14 @@ def fundchannel(
)

def init_rbf(
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
) -> None:
if self.config.getoption("verbose"):
print(
Expand Down Expand Up @@ -143,21 +144,21 @@ def fake_field(ftype: FieldType) -> str:
if ftype.elemtype.name == "byte":
return "00" * ftype.arraysize
return (
"["
+ ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize)
+ "]"
"["
+ ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize)
+ "]"
)
elif ftype.name in (
"byte",
"u8",
"u16",
"u32",
"u64",
"tu16",
"tu32",
"tu64",
"bigsize",
"varint",
"byte",
"u8",
"u16",
"u32",
"u64",
"tu16",
"tu32",
"tu64",
"bigsize",
"varint",
):
return "0"
elif ftype.name in ("chain_hash", "channel_id", "sha256"):
Expand Down Expand Up @@ -200,11 +201,11 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]:
return "Dummy error"

def check_final_error(
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
) -> None:
pass

Expand Down
9 changes: 9 additions & 0 deletions lnprototest/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def __init__(self, config: Any):
else:
self.logger.setLevel(logging.INFO)

def __enter__(self) -> "Runner":
"""Call the method when enter inside the class the first time"""
self.start()
return self

def __del__(self):
"""Call the method when the class is out of scope"""
self.stop()

def _is_dummy(self) -> bool:
"""The DummyRunner returns True here, as it can't do some things"""
return False
Expand Down

0 comments on commit 22a05fc

Please sign in to comment.