feat: refit refactoring with zmq and overlapping #1267
Conversation
📝 Walkthrough

Switches colocated refit from grouped CUDA-IPC handles to ZMQ-based streaming. Introduces ZMQ sockets on policy (REQ) and generation workers (REP), streams packed weights in aligned buffers, and awaits update futures. Refactors interfaces and vLLM paths accordingly; retains NCCL broadcast for non-colocated mode. Adds memory querying utilities and the zmq dependency.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Trainer as Trainer (GRPO)
    participant Policy as Policy Worker(s) [REQ]
    participant Gen as Generation Worker(s) [REP]
    note over Trainer,Gen: Colocated refit via ZMQ streaming
    Trainer->>Policy: stream_weights_via_ipc_zmq(buffer_size_bytes)
    Trainer->>Gen: update_weights_via_ipc_zmq()
    par Stream batches
        loop For each packed buffer
            Policy->>Gen: ZMQ SEND(buffer bytes, metadata)
            Gen-->>Policy: ZMQ RECV(ack 0)
            Gen->>Gen: Rebuild tensors, load weights, sync stream
        end
        Policy->>Gen: ZMQ SEND("complete")
        Gen-->>Policy: ZMQ RECV(ack "done")
    end
    Gen-->>Trainer: futures resolved (success flags)
    Policy-->>Trainer: futures resolved
    Trainer->>Trainer: Validate all successes
    alt Any failure
        Trainer-->>Trainer: Raise refit update error
    end
```

```mermaid
sequenceDiagram
    participant Trainer as Trainer (GRPO)
    participant Policy as Policy
    participant Gen as Generation
    note over Trainer,Gen: Non-colocated path unchanged
    Trainer->>Policy: broadcast_weights_for_collective()
    Policy-->>Gen: NCCL broadcast
    Gen-->>Trainer: success
```
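For illustration, the per-buffer handshake in the first diagram can be sketched with pyzmq over an in-process transport. The endpoint name and message framing here are hypothetical; the real protocol lives in the policy and generation workers:

```python
import threading
import zmq

ADDR = "inproc://refit-demo"  # hypothetical endpoint; real code uses TCP/IPC addresses
ctx = zmq.Context.instance()
bound = threading.Event()
received = []  # (metadata, payload) pairs seen by the generation side

def generation_worker():
    """REP side: ack each packed buffer, then ack the completion message."""
    sock = ctx.socket(zmq.REP)
    sock.bind(ADDR)
    bound.set()
    while True:
        frames = sock.recv_multipart()
        if frames[0] == b"complete":
            sock.send(b"done")          # final handshake
            break
        meta, payload = frames          # e.g. tensor metadata + raw packed bytes
        received.append((meta, payload))
        sock.send(b"\x00")              # ack: buffer consumed, sender may reuse it
    sock.close()

def policy_worker(buffers):
    """REQ side: stream each buffer and wait for the ack before reusing it."""
    sock = ctx.socket(zmq.REQ)
    sock.connect(ADDR)
    for meta, payload in buffers:
        sock.send_multipart([meta, payload])
        assert sock.recv() == b"\x00"   # gen worker finished copying this buffer
    sock.send_multipart([b"complete"])
    assert sock.recv() == b"done"
    sock.close()

t = threading.Thread(target=generation_worker)
t.start()
bound.wait()
policy_worker([(b"layer.0.weight", b"\x01" * 16), (b"layer.1.weight", b"\x02" * 16)])
t.join()
```

The REQ/REP pairing gives the strict send/ack alternation shown in the diagram, which is what lets the sender know a buffer is free to refill.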
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Actionable comments posted: 5
🧹 Nitpick comments (8)
nemo_rl/models/generation/interfaces.py (1)
236-238: Keep the IPC update contract abstract
`GenerationInterface` previously forced subclasses to implement the IPC update hook. Dropping `@abstractmethod` allows concrete classes to instantiate without overriding, only to fail later when the method is invoked. Please mark `update_weights_via_ipc_zmq` as abstract (as before) to enforce the contract at construction time.

```diff
-    def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
+    @abstractmethod
+    def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
```
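The construction-time enforcement the reviewer asks for can be demonstrated in isolation. This `GenerationInterface` is a simplified stand-in, assuming the real interface derives from `abc.ABC`:

```python
from abc import ABC, abstractmethod

class GenerationInterface(ABC):
    """Simplified sketch of the interface under review."""

    @abstractmethod
    def update_weights_via_ipc_zmq(self):
        """Subclasses must implement the ZMQ-based weight update."""

class ForgetfulBackend(GenerationInterface):
    pass  # missing the override

try:
    ForgetfulBackend()  # abstract methods are enforced at construction time
    constructed = True
except TypeError:
    constructed = False
```

Without `@abstractmethod`, `ForgetfulBackend()` would construct fine and only fail when the method is first called, which is exactly the late failure the review warns about.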
nemo_rl/models/generation/vllm/vllm_worker_async.py (1)

753-763: Refresh the docstring to match the new signature

The method no longer accepts `ipc_handles`, but the docstring still documents that parameter. Please align the docstring with the current API to avoid confusing downstream readers.

```diff
-    async def update_weights_via_ipc_zmq_async(
-        self,
-    ) -> bool:
-        """Async version of update_weights_from_ipc_handles.
-
-        Args:
-            ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
+    async def update_weights_via_ipc_zmq_async(self) -> bool:
+        """Async counterpart of ``update_weights_via_ipc_zmq`` on the sync worker.

         Returns:
             bool: True if weights were successfully updated, False otherwise.
         """
```

nemo_rl/models/policy/dtensor_policy_worker.py (2)
1718-1724: Return type mismatch: cast NVML free memory to int

`nvml.get_free_memory_bytes` returns float; the method is annotated to return int.

```diff
-        return get_free_memory_bytes(device_idx)
+        return int(get_free_memory_bytes(device_idx))
```
1726-1751: Make streamed tensors contiguous and validate buffer size.
- Add an assert to catch unintended zero-sized buffers.
- Ensure tensors are contiguous before packing to avoid view/copy pitfalls.
```diff
     def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None:
         self.maybe_init_zmq()
@@
         def dtensor_params_generator():
             """Generator that yields (name, tensor) pairs, converting DTensors to local tensors."""
             for name, tensor in self.model.state_dict().items():
                 if isinstance(tensor, DTensor):
                     # Convert DTensor to full tensor for streaming
                     full_tensor = tensor.full_tensor()
-                    # Convert to target dtype
-                    yield name, full_tensor.to(self.dtype, non_blocking=True)
+                    # Convert to target dtype and ensure contiguity
+                    yield name, full_tensor.to(self.dtype, non_blocking=True).contiguous()
                 else:
-                    # Convert to target dtype
-                    yield name, tensor.to(self.dtype, non_blocking=True)
+                    # Convert to target dtype and ensure contiguity
+                    yield name, tensor.to(self.dtype, non_blocking=True).contiguous()
@@
+        assert buffer_size_bytes > 0, "buffer_size_bytes must be > 0"
         stream_weights_via_ipc_zmq_impl(
             params_generator=dtensor_params_generator(),
             buffer_size_bytes=buffer_size_bytes,
             zmq_socket=self.zmq_socket,
             rank=self.rank,
             worker_name=str(self),
         )
```
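The contiguity pitfall the reviewer flags can be illustrated with NumPy (used here only for brevity; the same issue applies to torch views, where `.contiguous()` plays the role of `np.ascontiguousarray`):

```python
import numpy as np

w = np.arange(6, dtype=np.float32).reshape(2, 3)
view = w.T  # a strided view, like a transposed/sliced model parameter

assert not view.flags["C_CONTIGUOUS"]  # raw memory order != logical order

# A naive zero-copy pack (copying nbytes starting at the parent's data
# pointer) would serialize the parent layout, not the transposed one:
raw = np.frombuffer(w.tobytes(), dtype=np.float32)   # what a raw memcpy sees
logical = view.ravel()                               # what the receiver expects
assert not np.array_equal(raw, logical)

packed = np.ascontiguousarray(view)  # analogous to tensor.contiguous()
assert np.array_equal(np.frombuffer(packed.tobytes(), dtype=np.float32), logical)
```

This is why packing a non-contiguous tensor into a flat byte buffer without first making it contiguous silently corrupts the reconstructed weights on the receiver.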
nemo_rl/models/policy/lm_policy.py (1)

659-665: Cast free memory to int for the annotated return type

```diff
-        free_memory_bytes = min(ray.get(future) for future in futures)
-        return free_memory_bytes
+        free_memory_bytes = min(ray.get(future) for future in futures)
+        return int(free_memory_bytes)
```

nemo_rl/models/generation/vllm/vllm_generation.py (1)
751-779: Docstring and return type drift; clarify no-arg API and futures return

Update to reflect ZMQ streaming and that this returns futures, not bool.

```diff
-    def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
-        """Update weights of the policy using IPC handles, considering tensor parallelism.
-
-        For tp > 1, only the leader in each tensor parallel tied worker group will update weights.
-
-        Args:
-            ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
-
-        Returns:
-            bool: True if weights were successfully updated, False otherwise.
-        """
+    def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
+        """Kick off ZMQ-based IPC weight updates on workers.
+
+        Returns ray futures; caller should await them.
+        """
@@
         # Use run_all_workers_single_data since no data needs to be passed
         futures = self.worker_group.run_all_workers_single_data(
             method_name,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
         )
-        # this function should co-work with lm_policy, so we should wait for all futures to complete outside
+        # Wait for completion outside to coordinate with the policy side
         return futures
```

nemo_rl/models/generation/vllm/vllm_worker.py (2)
701-739: Align the NVTX tag and error message with the new API name; keep the collective_rpc call

```diff
-    @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_ipc_handles")
+    @wrap_with_nvtx_name("vllm_generation_worker/update_weights_via_ipc_zmq")
     def update_weights_via_ipc_zmq(self) -> bool:
@@
         if self.cfg["vllm_cfg"]["async_engine"]:
             raise RuntimeError(
-                "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead."
+                "update_weights_via_ipc_zmq cannot be used with async_engine=True. Use update_weights_via_ipc_zmq_async instead."
             )
@@
-        result_or_coro = self.llm.collective_rpc(
-            "update_weights_via_ipc_zmq",
-            args=tuple(),
-        )
+        result_or_coro = self.llm.collective_rpc("update_weights_via_ipc_zmq", args=tuple())
```
727-733: Propagate worker failure context or raise

The current print-and-return-False hides errors. Consider raising RuntimeError to fail fast in orchestration paths.
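A fail-fast variant might look like the following sketch (`DemoWorker` and its `_collective_rpc` hook are stand-ins for illustration, not the real vLLM worker API):

```python
class DemoWorker:
    """Toy stand-in for the vLLM worker; `_collective_rpc` is a hypothetical hook."""

    def __init__(self, results):
        self._results = results  # per-engine-rank success flags

    def _collective_rpc(self, method_name):
        return self._results

    def update_weights_via_ipc_zmq(self):
        # Fail fast with context instead of print-and-return-False, so
        # orchestration code surfaces the broken refit immediately.
        results = self._collective_rpc("update_weights_via_ipc_zmq")
        if not all(results):
            raise RuntimeError(f"some engine ranks reported failure: {results}")
        return True

assert DemoWorker([True, True]).update_weights_via_ipc_zmq()
try:
    DemoWorker([True, False]).update_weights_via_ipc_zmq()
    failed = False
except RuntimeError:
    failed = True
```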
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (12)
- nemo_rl/algorithms/grpo.py (1 hunks)
- nemo_rl/models/generation/interfaces.py (1 hunks)
- nemo_rl/models/generation/vllm/vllm_backend.py (2 hunks)
- nemo_rl/models/generation/vllm/vllm_generation.py (2 hunks)
- nemo_rl/models/generation/vllm/vllm_worker.py (2 hunks)
- nemo_rl/models/generation/vllm/vllm_worker_async.py (2 hunks)
- nemo_rl/models/policy/dtensor_policy_worker.py (2 hunks)
- nemo_rl/models/policy/interfaces.py (1 hunks)
- nemo_rl/models/policy/lm_policy.py (1 hunks)
- nemo_rl/models/policy/megatron_policy_worker.py (5 hunks)
- nemo_rl/models/policy/utils.py (2 hunks)
- pyproject.toml (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/vllm/vllm_backend.py
- nemo_rl/models/policy/utils.py
- nemo_rl/models/policy/dtensor_policy_worker.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/policy/interfaces.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/vllm/vllm_backend.py
- nemo_rl/models/policy/utils.py
- nemo_rl/models/policy/dtensor_policy_worker.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/policy/interfaces.py
🧬 Code graph analysis (9)
nemo_rl/models/generation/interfaces.py (3)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): update_weights_via_ipc_zmq (84-161)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): update_weights_via_ipc_zmq (751-779)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): update_weights_via_ipc_zmq (702-738)

nemo_rl/models/generation/vllm/vllm_generation.py (4)
- nemo_rl/models/generation/interfaces.py (1): update_weights_via_ipc_zmq (236-238)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): update_weights_via_ipc_zmq (84-161)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): update_weights_via_ipc_zmq (702-738)
- nemo_rl/distributed/worker_groups.py (1): run_all_workers_single_data (728-772)

nemo_rl/algorithms/grpo.py (9)
- nemo_rl/models/policy/dtensor_policy_worker.py (2): get_free_memory_bytes (1718-1723), stream_weights_via_ipc_zmq (1727-1751)
- nemo_rl/models/policy/lm_policy.py (2): get_free_memory_bytes (659-664), stream_weights_via_ipc_zmq (666-671)
- nemo_rl/models/policy/megatron_policy_worker.py (2): get_free_memory_bytes (1625-1630), stream_weights_via_ipc_zmq (1634-1653)
- nemo_rl/utils/nvml.py (1): get_free_memory_bytes (80-90)
- nemo_rl/models/policy/interfaces.py (1): stream_weights_via_ipc_zmq (161-164)
- nemo_rl/models/generation/interfaces.py (1): update_weights_via_ipc_zmq (236-238)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): update_weights_via_ipc_zmq (84-161)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): update_weights_via_ipc_zmq (751-779)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): update_weights_via_ipc_zmq (702-738)

nemo_rl/models/generation/vllm/vllm_worker.py (3)
- nemo_rl/models/generation/interfaces.py (1): update_weights_via_ipc_zmq (236-238)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): update_weights_via_ipc_zmq (84-161)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): update_weights_via_ipc_zmq (751-779)

nemo_rl/models/policy/lm_policy.py (4)
- nemo_rl/models/policy/dtensor_policy_worker.py (2): get_free_memory_bytes (1718-1723), stream_weights_via_ipc_zmq (1727-1751)
- nemo_rl/models/policy/megatron_policy_worker.py (2): get_free_memory_bytes (1625-1630), stream_weights_via_ipc_zmq (1634-1653)
- nemo_rl/distributed/worker_groups.py (1): run_all_workers_single_data (728-772)
- nemo_rl/models/policy/interfaces.py (1): stream_weights_via_ipc_zmq (161-164)

nemo_rl/models/generation/vllm/vllm_backend.py (8)
- nemo_rl/models/policy/dtensor_policy_worker.py (4): get_zmq_address (1699-1701), report_device_id (1686-1697), maybe_init_zmq (1703-1708), prepare_refit_info (1711-1716)
- nemo_rl/models/policy/megatron_policy_worker.py (4): get_zmq_address (1545-1547), report_device_id (1532-1543), maybe_init_zmq (1549-1554), prepare_refit_info (1558-1572)
- nemo_rl/models/generation/vllm/vllm_worker.py (3): report_device_id (681-695), prepare_refit_info (697-699), update_weights_via_ipc_zmq (702-738)
- nemo_rl/models/generation/interfaces.py (2): prepare_refit_info (232-234), update_weights_via_ipc_zmq (236-238)
- nemo_rl/models/generation/vllm/vllm_generation.py (2): prepare_refit_info (732-749), update_weights_via_ipc_zmq (751-779)
- nemo_rl/utils/nsys.py (1): wrap_with_nvtx_name (82-94)
- nemo_rl/models/policy/utils.py (1): calculate_aligned_size (256-266)
- nemo_rl/models/generation/fp8.py (2): is_fp8_model (202-213), load_weights (289-316)

nemo_rl/models/policy/dtensor_policy_worker.py (7)
- nemo_rl/models/generation/vllm/vllm_backend.py (4): get_zmq_address (57-59), report_device_id (52-55), maybe_init_zmq (61-72), prepare_refit_info (74-81)
- nemo_rl/models/policy/megatron_policy_worker.py (6): get_zmq_address (1545-1547), report_device_id (1532-1543), maybe_init_zmq (1549-1554), prepare_refit_info (1558-1572), get_free_memory_bytes (1625-1630), stream_weights_via_ipc_zmq (1634-1653)
- nemo_rl/models/policy/interfaces.py (2): prepare_refit_info (157-158), stream_weights_via_ipc_zmq (161-164)
- nemo_rl/models/policy/lm_policy.py (3): prepare_refit_info (648-657), get_free_memory_bytes (659-664), stream_weights_via_ipc_zmq (666-671)
- nemo_rl/utils/nvml.py (1): get_free_memory_bytes (80-90)
- nemo_rl/utils/nsys.py (1): wrap_with_nvtx_name (82-94)
- nemo_rl/models/policy/utils.py (1): stream_weights_via_ipc_zmq_impl (269-384)

nemo_rl/models/policy/megatron_policy_worker.py (3)
- nemo_rl/models/policy/dtensor_policy_worker.py (5): get_zmq_address (1699-1701), report_device_id (1686-1697), maybe_init_zmq (1703-1708), get_free_memory_bytes (1718-1723), stream_weights_via_ipc_zmq (1727-1751)
- nemo_rl/models/policy/lm_policy.py (2): get_free_memory_bytes (659-664), stream_weights_via_ipc_zmq (666-671)
- nemo_rl/models/policy/utils.py (1): stream_weights_via_ipc_zmq_impl (269-384)

nemo_rl/models/policy/interfaces.py (3)
- nemo_rl/models/policy/dtensor_policy_worker.py (1): stream_weights_via_ipc_zmq (1727-1751)
- nemo_rl/models/policy/lm_policy.py (1): stream_weights_via_ipc_zmq (666-671)
- nemo_rl/models/policy/megatron_policy_worker.py (1): stream_weights_via_ipc_zmq (1634-1653)
🪛 Ruff (0.13.2)
nemo_rl/models/generation/vllm/vllm_generation.py
763-763: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/models/generation/vllm/vllm_backend.py
155-155: Consider moving this statement to an else block
(TRY300)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (2)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
1711-1716: Ensure dtype metadata matches the streamed dtype across workers

You record `(shape, self.dtype)`. That's fine only if the streaming generator also casts to `self.dtype` for every tensor, and the receiver reconstructs using the same dtype. Please confirm `MegatronPolicyWorker.prepare_refit_info` also uses the same dtype to avoid size mismatches.

nemo_rl/models/policy/lm_policy.py (1)
666-672: LGTM: streaming API returns futures as intended.
ℹ️ File Consistency Check

Check based on commit: 0d395af (PR #1267)

✅ DTensor Policy Worker Synchronization Check: both DTensor policy worker files were modified in this PR.

Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
@coderabbitai generate docstrings

✅ Actions performed

Docstrings generation started.
Docstrings generation was requested by @ZhiyuLi-Nvidia (#1267 (comment)). The following files were modified:

- `nemo_rl/algorithms/grpo.py`
- `nemo_rl/models/generation/interfaces.py`
- `nemo_rl/models/generation/vllm/vllm_backend.py`
- `nemo_rl/models/generation/vllm/vllm_generation.py`
- `nemo_rl/models/generation/vllm/vllm_worker.py`
- `nemo_rl/models/generation/vllm/vllm_worker_async.py`
- `nemo_rl/models/policy/dtensor_policy_worker.py`
- `nemo_rl/models/policy/dtensor_policy_worker_v2.py`
- `nemo_rl/models/policy/interfaces.py`
- `nemo_rl/models/policy/lm_policy.py`
- `nemo_rl/models/policy/megatron_policy_worker.py`
- `nemo_rl/models/policy/utils.py`
Note: Generated docstrings for this pull request at #1270
@guyueh1 @terrykong @parthchadha @yuki-97 @yfw could you take a review?
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
What does this PR do?
Summary
As titled, we now use Ray solely for lightweight coordination and notifications, while leveraging ZeroMQ (ZMQ) for high-throughput, peer-to-peer (P2P) inter-process communication (IPC). This design avoids costly serialization and enables superior overlap between computation and data transfer.
Asynchronous Overlap with Ping-Pong Buffers: The weight update process (refit) now operates within an inner generator loop, triggered by notifications from the Ray coordinator. We achieve significant overlap between the weight gathering generator and the actual weight updates by using asynchronous ZMQ send/recv operations combined with a ping-pong buffer. This allows one buffer to be filled with incoming weights while the other is being used for computation, maximizing hardware utilization.
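The ping-pong overlap described above can be sketched transport-agnostically with a small pool of reusable buffers (helper names here are hypothetical; the real implementation uses async ZMQ send/recv inside `stream_weights_via_ipc_zmq_impl`):

```python
import queue
import threading

def stream_with_ping_pong(params, pack, send, buffer_size=64, num_buffers=2):
    """Overlap packing with sending using a pool of reusable buffers.

    While the sender thread ships buffer A, the main thread is already
    packing the next parameter into buffer B ("ping-pong").
    """
    free = queue.Queue()    # buffers ready to be filled
    ready = queue.Queue()   # (buffer, nbytes) pairs ready to be sent
    for _ in range(num_buffers):
        free.put(bytearray(buffer_size))

    def sender():
        while True:
            item = ready.get()
            if item is None:        # completion sentinel
                return
            buf, n = item
            send(bytes(buf[:n]))    # ship the packed bytes
            free.put(buf)           # recycle the buffer for refilling

    t = threading.Thread(target=sender)
    t.start()
    for p in params:
        buf = free.get()            # blocks only if every buffer is in flight
        n = pack(p, buf)
        ready.put((buf, n))
    ready.put(None)
    t.join()

def pack(param, buf):
    """Copy a parameter's bytes into the front of `buf`; return bytes written."""
    buf[: len(param)] = param
    return len(param)

sent = []
stream_with_ping_pong([b"w0", b"w1", b"w2"], pack, sent.append)
```

With two buffers, packing parameter i+1 proceeds while parameter i is in flight; the `free` queue provides the back-pressure that stops the producer from overwriting a buffer still being sent.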
Unified Communication Protocol
The ZMQ-based P2P communication logic has been standardized and applied across multiple components. This includes the async and sync LLM engines and all supported backends: DTensor, DTensor v2, and Mcore.
Significant Code Simplification
By refactoring and aligning all components to use this single, efficient communication protocol, we have greatly simplified the codebase. This effort resulted in a net removal of approximately 170 lines of redundant or complex code.