diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py index b8c1761cdef..87430fd7a96 100644 --- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py +++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py @@ -199,37 +199,53 @@ def test_send_single_request_cleans_up_after_finished_payload(build_adapter, mon assert args[1] == "ext-finished" -def test_update_request_payload(build_adapter): - adapter, _ = build_adapter() - - first: OmniPayload = { +def test_load_poll_non_ar_merges_into_existing_additional_information(build_adapter): + adapter, connector = build_adapter(stage_id=2, model_mode="diffusion") + request = _req("req-non-ar", RequestStatus.WAITING, external_req_id="ext-non-ar") + request.additional_information = { "hidden_states": {"output": torch.tensor([[1.0]])}, - "codes": {"audio": [1]}, - "meta": {"finished": torch.tensor(False, dtype=torch.bool)}, + "ids": {"prompt": [11, 12]}, + "meta": {"finished": torch.tensor(False, dtype=torch.bool), "step": 1}, } - adapter._update_request_payload("ext", first) - second: OmniPayload = { + request.num_computed_tokens = 9 + + payload: OmniPayload = { "hidden_states": {"output": torch.tensor([[2.0]])}, - "codes": {"audio": [2]}, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + "ids": {"all": [21, 22]}, + "codes": {"audio": torch.tensor([7, 8], dtype=torch.long)}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool), "phase": "decode"}, + "kv_metadata": {"foo": "bar"}, } - merged = adapter._update_request_payload("ext", second) + connector.get.return_value = (payload, 8) - assert torch.equal(merged["hidden_states"]["output"], torch.tensor([[1.0], [2.0]])) - assert merged["codes"]["audio"] == [1, 2] - assert merged["meta"]["finished"].item() is True + assert adapter._poll_single_request(request) is True + + assert request.prompt_token_ids == [7, 8] + assert request.num_computed_tokens == 0 + assert torch.equal( + request.additional_information["hidden_states"]["output"], + torch.tensor([[2.0]]), + ) + assert request.additional_information["ids"]["prompt"] == [11, 12] + assert request.additional_information["ids"]["all"] == [21, 22] + # non-ar merge path intentionally doesn't overwrite meta.finished. + assert request.additional_information["meta"]["finished"].item() is False + assert request.additional_information["meta"]["phase"] == "decode" + assert request.additional_information["kv_metadata"] == {"foo": "bar"} + assert "req-non-ar" in adapter._finished_load_reqs + assert "req-non-ar" in adapter.finished_requests def test_load_poll_ar_request_additional_information_concats_tensors(build_adapter): adapter, connector = build_adapter(stage_id=2, model_mode="ar") request = _req("req-merged", RequestStatus.WAITING, external_req_id="ext-merged") - - adapter.request_ids_mapping["req-merged"] = "ext-merged" - adapter.request_payload["ext-merged"] = { + request.additional_information = { "hidden_states": {"output": torch.tensor([[1.0]])}, "ids": {"prompt": [11, 12]}, "meta": {"finished": torch.tensor(False, dtype=torch.bool)}, } + + adapter.request_ids_mapping["req-merged"] = "ext-merged" payload: OmniPayload = { "hidden_states": {"output": torch.tensor([[2.0]])}, "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, @@ -238,12 +254,8 @@ def test_load_poll_ar_request_additional_information_concats_tensors(build_adapt adapter._poll_single_request(request) - assert torch.equal( - request.additional_information["hidden_states"]["output"], - torch.tensor([[1.0], [2.0]]), - ) - # Keys absent from the new chunk are dropped (matches main's behavior). - assert "ids" not in request.additional_information + # AR mode now forwards the latest payload directly. + assert request.additional_information == payload assert request.additional_information["meta"]["finished"].item() is True