diff --git a/aiter/dist/device_communicators/communicator_cuda.py b/aiter/dist/device_communicators/communicator_cuda.py index fcc7ee05b2..c7837376c8 100644 --- a/aiter/dist/device_communicators/communicator_cuda.py +++ b/aiter/dist/device_communicators/communicator_cuda.py @@ -149,6 +149,8 @@ def all_reduce( qr_comm is not None and not qr_comm.disabled and qr_comm.should_quick_allreduce(input_) + and (input_.nelement() * input_.element_size()) >= 4*1024*1024 # input shape should be such that quick reduce will show benefits. + # input shape estimated at 2 * max concurrency for now. if performance issues, subject to change ): out = qr_comm.quick_all_reduce(input_) assert out is not None