Skip to content

Commit

Permalink
bug fix atp (#127)
Browse files Browse the repository at this point in the history
* cast stderr writes to byte strings encoded as UTF-8

* fix race condition in test_atp.py by waiting for a condition variable instead of an exit event
  • Loading branch information
mfleader authored Apr 2, 2024
1 parent b4e163a commit f453b25
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 40 deletions.
8 changes: 5 additions & 3 deletions src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,14 @@ def run_server_read_loop(self) -> None:
server_fatal=False,
error_msg=f"Unknown runtime message ID: {msg_id}",
)
self.stderr.write(
f"Unknown kind of runtime message: {msg_id}"
self.stderr.write(bytes(
f"Unknown kind of runtime message: {msg_id}",
encoding="utf")
)

except cbor2.CBORDecodeError as err:
self.stderr.write(f"Error while decoding CBOR: {err}")
self.stderr.write(bytes(
f"Error while decoding CBOR: {err}", encoding="utf"))
self.send_error_message(
"",
step_fatal=False,
Expand Down
75 changes: 38 additions & 37 deletions src/arcaflow_plugin_sdk/test_atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import signal
import time
import unittest
from threading import Event
from threading import Condition, Lock
from typing import List, TextIO, Tuple, Union

from arcaflow_plugin_sdk import atp, plugin, schema
Expand All @@ -25,30 +25,31 @@ class Output:
description="Says hello :)",
outputs={"success": Output},
)
def hello_world(params: Input) -> Tuple[str, Union[Output]]:
def hello_world(params: Input) -> Tuple[str, Output]:
print("Hello world!")
return "success", Output("Hello, {}!".format(params.name))


# noinspection PyTypeChecker
@plugin.step(
id="hello-world-broken",
name="Broken!",
description="Throws an exception with the text 'abcde'",
outputs={"success": Output},
)
def hello_world_broken(_: Input) -> Tuple[str, Union[Output]]:
def hello_world_broken(_: Input) -> Tuple[str, Output]:
print("Hello world!")
raise Exception("abcde")


@dataclasses.dataclass
class StepTestInput:
wait_time_seconds: float
expected_signal_count: int


@dataclasses.dataclass
class SignalTestInput:
final: bool # The last one will trigger the end of the step.
value: int


Expand All @@ -59,14 +60,16 @@ class SignalTestOutput:

class SignalTestStep:
signal_values: List[int]
exit_event: Event
exit_condition: Condition
lock: Lock

def __init__(self):
# Due to the way Python works, this MUST be done here, and not inlined
# above, or else it will be shared by all objects, resulting in a
# shared list and event, which would cause problems.
self.signal_values = []
self.exit_event = Event()
self.lock = Lock()
self.exit_condition = Condition(self.lock)

@plugin.step_with_signals(
id="signal_test_step",
Expand All @@ -80,7 +83,11 @@ def __init__(self):
def signal_test_step(
self, params: StepTestInput
) -> Tuple[str, Union[SignalTestOutput]]:
self.exit_event.wait(params.wait_time_seconds)
with self.exit_condition:
self.exit_condition.wait_for(
lambda:
len(self.signal_values) >= params.expected_signal_count,
timeout=params.wait_time_seconds)
return "success", SignalTestOutput(self.signal_values)

@plugin.signal_handler(
Expand All @@ -92,12 +99,11 @@ def signal_test_step(
),
)
def signal_test_signal_handler(self, signal_input: SignalTestInput):
if signal_input.value < 0:
self.exit_event.set()
raise Exception("Value below zero.")
self.signal_values.append(signal_input.value)
if signal_input.final:
self.exit_event.set()
with self.exit_condition:
if signal_input.value < 0:
raise Exception(f"Value below zero: {signal_input.value}")
self.signal_values.append(signal_input.value)
self.exit_condition.notify()


test_schema = plugin.build_schema(hello_world)
Expand Down Expand Up @@ -199,30 +205,31 @@ def test_step_with_signals(self):
)

client.start_work(
self.id(), "signal_test_step", {"wait_time_seconds": "5"}
self.id(), "signal_test_step",
{"wait_time_seconds": 5.0, "expected_signal_count": 3}
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "1"},
{"value": 1},
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "2"},
{"value": 2},
)
client.send_signal(
self.id(),
"record_value",
{"final": "true", "value": "3"},
{"value": 3},
)
result = client.read_single_result()
self.assertEqual(result.run_id, self.id())
client.send_client_done()
self.assertEqual(result.debug_logs, "")
self.assertEqual(result.output_id, "success")
self.assertListEqual(
result.output_data["signals_received"], [1, 2, 3]
sorted(result.output_data["signals_received"]), [1, 2, 3]
)
finally:
self._cleanup(pid, stdin_writer, stdout_reader)
Expand Down Expand Up @@ -250,27 +257,29 @@ def test_multi_step_with_signals(self):
step_b_id = self.id() + "_b"

client.start_work(
step_a_id, "signal_test_step", {"wait_time_seconds": "5"}
step_a_id, "signal_test_step",
{"wait_time_seconds": 5.0, "expected_signal_count": 2}
)
client.start_work(
step_b_id, "signal_test_step", {"wait_time_seconds": "5"}
step_b_id, "signal_test_step",
{"wait_time_seconds": 5.0, "expected_signal_count": 1}
)
client.send_signal(
step_a_id,
"record_value",
{"final": "false", "value": "1"},
{"value": 1},
)
client.send_signal(
step_b_id,
"record_value",
{"final": "true", "value": "2"},
{"value": 2},
)
step_b_result = client.read_single_result()

client.send_signal(
step_a_id,
"record_value",
{"final": "true", "value": "3"},
{"value": 3},
)
step_a_result = client.read_single_result()
client.send_client_done()
Expand Down Expand Up @@ -350,6 +359,7 @@ def test_invalid_runtime_message_id(self):
client.start_output()
client.read_hello()

# noinspection PyTypeChecker
client.send_runtime_message(1000, "", "")

with self.assertRaises(atp.PluginClientStateException) as context:
Expand Down Expand Up @@ -380,30 +390,21 @@ def test_error_in_signal(self):
)

client.start_work(
self.id(), "signal_test_step", {"wait_time_seconds": "5"}
self.id(), "signal_test_step",
{"wait_time_seconds": 5.0, "expected_signal_count": 1}
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "1"},
{"value": -1},
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "-1"},
)
result = client.read_single_result()
self.assertEqual(result.run_id, self.id())
self.assertEqual(result.debug_logs, "")
self.assertEqual(result.output_id, "success")
self.assertListEqual(result.output_data["signals_received"], [1])

# Note: The exception is raised after the step finishes in the test
# class
with self.assertRaises(atp.PluginClientStateException) as context:
_, _, _, _ = client.read_single_result()
client.read_single_result()
client.send_client_done()
self.assertIn("Value below zero.", str(context.exception))
self.assertIn("Value below zero: -1", str(context.exception))

finally:
self._cleanup(pid, stdin_writer, stdout_reader, True)
Expand Down

0 comments on commit f453b25

Please sign in to comment.