Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug fix atp #127

Merged
merged 8 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,14 @@ def run_server_read_loop(self) -> None:
server_fatal=False,
error_msg=f"Unknown runtime message ID: {msg_id}",
)
self.stderr.write(
f"Unknown kind of runtime message: {msg_id}"
self.stderr.write(bytes(
f"Unknown kind of runtime message: {msg_id}",
encoding="utf")
)

except cbor2.CBORDecodeError as err:
self.stderr.write(f"Error while decoding CBOR: {err}")
self.stderr.write(bytes(
f"Error while decoding CBOR: {err}", encoding="utf"))
self.send_error_message(
"",
step_fatal=False,
Expand Down
73 changes: 37 additions & 36 deletions src/arcaflow_plugin_sdk/test_atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import signal
import time
import unittest
from threading import Event
from threading import Condition, Lock
from typing import List, TextIO, Tuple, Union

from arcaflow_plugin_sdk import atp, plugin, schema
Expand All @@ -30,25 +30,26 @@ def hello_world(params: Input) -> Tuple[str, Union[Output]]:
return "success", Output("Hello, {}!".format(params.name))


# noinspection PyTypeChecker
@plugin.step(
id="hello-world-broken",
name="Broken!",
description="Throws an exception with the text 'abcde'",
outputs={"success": Output},
)
def hello_world_broken(_: Input) -> Tuple[str, Union[Output]]:
def hello_world_broken(_: Input) -> Tuple[str, Output]:
print("Hello world!")
raise Exception("abcde")


@dataclasses.dataclass
class StepTestInput:
wait_time_seconds: float
expected_signal_count: int


@dataclasses.dataclass
class SignalTestInput:
final: bool # The last one will trigger the end of the step.
value: int


Expand All @@ -59,14 +60,16 @@ class SignalTestOutput:

class SignalTestStep:
signal_values: List[int]
exit_event: Event
exit_condition: Condition
lock: Lock

def __init__(self):
# Due to the way Python works, this MUST be done here, and not inlined
# above, or else it will be shared by all objects, resulting in a
# shared list and event, which would cause problems.
self.signal_values = []
self.exit_event = Event()
self.lock = Lock()
self.exit_condition = Condition(self.lock)

@plugin.step_with_signals(
id="signal_test_step",
Expand All @@ -80,7 +83,11 @@ def __init__(self):
def signal_test_step(
self, params: StepTestInput
) -> Tuple[str, Union[SignalTestOutput]]:
self.exit_event.wait(params.wait_time_seconds)
with self.exit_condition:
self.exit_condition.wait_for(
lambda:
len(self.signal_values) >= params.expected_signal_count,
timeout=params.wait_time_seconds)
return "success", SignalTestOutput(self.signal_values)

@plugin.signal_handler(
Expand All @@ -92,12 +99,11 @@ def signal_test_step(
),
)
def signal_test_signal_handler(self, signal_input: SignalTestInput):
if signal_input.value < 0:
self.exit_event.set()
raise Exception("Value below zero.")
self.signal_values.append(signal_input.value)
if signal_input.final:
self.exit_event.set()
with self.exit_condition:
if signal_input.value < 0:
raise Exception(f"Value below zero: {signal_input.value}")
self.signal_values.append(signal_input.value)
self.exit_condition.notify()


test_schema = plugin.build_schema(hello_world)
Expand Down Expand Up @@ -199,30 +205,31 @@ def test_step_with_signals(self):
)

client.start_work(
self.id(), "signal_test_step", {"wait_time_seconds": "5"}
self.id(), "signal_test_step",
{"wait_time_seconds": "5", "expected_signal_count": "3"}
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "1"},
{"value": "1"},
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "2"},
{"value": "2"},
)
client.send_signal(
self.id(),
"record_value",
{"final": "true", "value": "3"},
{"value": "3"},
)
result = client.read_single_result()
self.assertEqual(result.run_id, self.id())
client.send_client_done()
self.assertEqual(result.debug_logs, "")
self.assertEqual(result.output_id, "success")
self.assertListEqual(
result.output_data["signals_received"], [1, 2, 3]
sorted(result.output_data["signals_received"]), [1, 2, 3]
)
finally:
self._cleanup(pid, stdin_writer, stdout_reader)
Expand Down Expand Up @@ -250,27 +257,29 @@ def test_multi_step_with_signals(self):
step_b_id = self.id() + "_b"

client.start_work(
step_a_id, "signal_test_step", {"wait_time_seconds": "5"}
step_a_id, "signal_test_step",
{"wait_time_seconds": 5.0, "expected_signal_count": 2}
)
client.start_work(
step_b_id, "signal_test_step", {"wait_time_seconds": "5"}
step_b_id, "signal_test_step",
{"wait_time_seconds": 5.0, "expected_signal_count": 1}
)
client.send_signal(
step_a_id,
"record_value",
{"final": "false", "value": "1"},
{"value": "1"},
)
client.send_signal(
step_b_id,
"record_value",
{"final": "true", "value": "2"},
{"value": "2"},
)
step_b_result = client.read_single_result()

client.send_signal(
step_a_id,
"record_value",
{"final": "true", "value": "3"},
{"value": "3"},
)
step_a_result = client.read_single_result()
client.send_client_done()
Expand Down Expand Up @@ -350,6 +359,7 @@ def test_invalid_runtime_message_id(self):
client.start_output()
client.read_hello()

# noinspection PyTypeChecker
client.send_runtime_message(1000, "", "")

with self.assertRaises(atp.PluginClientStateException) as context:
Expand Down Expand Up @@ -380,30 +390,21 @@ def test_error_in_signal(self):
)

client.start_work(
self.id(), "signal_test_step", {"wait_time_seconds": "5"}
self.id(), "signal_test_step",
{"wait_time_seconds": 5.0, "expected_signal_count": 1}
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "1"},
{"value": "-1"},
)
client.send_signal(
self.id(),
"record_value",
{"final": "false", "value": "-1"},
)
result = client.read_single_result()
self.assertEqual(result.run_id, self.id())
self.assertEqual(result.debug_logs, "")
self.assertEqual(result.output_id, "success")
self.assertListEqual(result.output_data["signals_received"], [1])

# Note: The exception is raised after the step finishes in the test
# class
with self.assertRaises(atp.PluginClientStateException) as context:
_, _, _, _ = client.read_single_result()
client.read_single_result()
client.send_client_done()
self.assertIn("Value below zero.", str(context.exception))
self.assertIn("Value below zero: -1", str(context.exception))

finally:
self._cleanup(pid, stdin_writer, stdout_reader, True)
Expand Down