Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,37 +199,53 @@ def test_send_single_request_cleans_up_after_finished_payload(build_adapter, mon
assert args[1] == "ext-finished"


def test_update_request_payload(build_adapter):
adapter, _ = build_adapter()

first: OmniPayload = {
def test_load_poll_non_ar_merges_into_existing_additional_information(build_adapter):
adapter, connector = build_adapter(stage_id=2, model_mode="diffusion")
request = _req("req-non-ar", RequestStatus.WAITING, external_req_id="ext-non-ar")
request.additional_information = {
"hidden_states": {"output": torch.tensor([[1.0]])},
"codes": {"audio": [1]},
"meta": {"finished": torch.tensor(False, dtype=torch.bool)},
"ids": {"prompt": [11, 12]},
"meta": {"finished": torch.tensor(False, dtype=torch.bool), "step": 1},
}
adapter._update_request_payload("ext", first)
second: OmniPayload = {
request.num_computed_tokens = 9

payload: OmniPayload = {
"hidden_states": {"output": torch.tensor([[2.0]])},
"codes": {"audio": [2]},
"meta": {"finished": torch.tensor(True, dtype=torch.bool)},
"ids": {"all": [21, 22]},
"codes": {"audio": torch.tensor([7, 8], dtype=torch.long)},
"meta": {"finished": torch.tensor(True, dtype=torch.bool), "phase": "decode"},
"kv_metadata": {"foo": "bar"},
}
merged = adapter._update_request_payload("ext", second)
connector.get.return_value = (payload, 8)

assert torch.equal(merged["hidden_states"]["output"], torch.tensor([[1.0], [2.0]]))
assert merged["codes"]["audio"] == [1, 2]
assert merged["meta"]["finished"].item() is True
assert adapter._poll_single_request(request) is True

assert request.prompt_token_ids == [7, 8]
assert request.num_computed_tokens == 0
assert torch.equal(
request.additional_information["hidden_states"]["output"],
torch.tensor([[2.0]]),
)
assert request.additional_information["ids"]["prompt"] == [11, 12]
assert request.additional_information["ids"]["all"] == [21, 22]
# non-ar merge path intentionally doesn't overwrite meta.finished.
assert request.additional_information["meta"]["finished"].item() is False
assert request.additional_information["meta"]["phase"] == "decode"
assert request.additional_information["kv_metadata"] == {"foo": "bar"}
assert "req-non-ar" in adapter._finished_load_reqs
assert "req-non-ar" in adapter.finished_requests


def test_load_poll_ar_request_additional_information_concats_tensors(build_adapter):
adapter, connector = build_adapter(stage_id=2, model_mode="ar")
request = _req("req-merged", RequestStatus.WAITING, external_req_id="ext-merged")

adapter.request_ids_mapping["req-merged"] = "ext-merged"
adapter.request_payload["ext-merged"] = {
request.additional_information = {
"hidden_states": {"output": torch.tensor([[1.0]])},
"ids": {"prompt": [11, 12]},
"meta": {"finished": torch.tensor(False, dtype=torch.bool)},
}

adapter.request_ids_mapping["req-merged"] = "ext-merged"
payload: OmniPayload = {
"hidden_states": {"output": torch.tensor([[2.0]])},
"meta": {"finished": torch.tensor(True, dtype=torch.bool)},
Expand All @@ -238,12 +254,8 @@ def test_load_poll_ar_request_additional_information_concats_tensors(build_adapt

adapter._poll_single_request(request)

assert torch.equal(
request.additional_information["hidden_states"]["output"],
torch.tensor([[1.0], [2.0]]),
)
# Keys absent from the new chunk are dropped (matches main's behavior).
assert "ids" not in request.additional_information
# AR mode now forwards the latest payload directly.
assert request.additional_information == payload
assert request.additional_information["meta"]["finished"].item() is True


Expand Down
Loading