Commit 5cc1d38

fix: nvbugs/5187237: fix deterministic mode crash (#3448)
* nvbugs/5187237 nvbugs/5112075: fix deterministic mode error

* remove waive
  Signed-off-by: Xiwen Yu <[email protected]>

* Revert "remove waive"
  This reverts commit 0bf5486d19906d692bfb7a6262333c296b0087ac.
  Signed-off-by: Xiwen Yu <[email protected]>

* revert ar fusion
  Signed-off-by: Xiwen Yu <[email protected]>

---------

Signed-off-by: Xiwen Yu <[email protected]>
1 parent e36092b commit 5cc1d38

5 files changed (+15, -5 lines)


cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu

Lines changed: 2 additions & 1 deletion
@@ -28,7 +28,8 @@ __global__ void lamport_initialize_kernel(float* ptr, int size)
 
 void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
 {
-    lamport_initialize_kernel<<<bytes / 128, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
+    int grid_size = (bytes + 127) / 128;
+    lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
 }
 
 Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
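For readers skimming the hunk above: the launch grid is now sized with ceiling division rather than truncating division, so a buffer whose byte count is not a multiple of 128 still gets a block for its trailing bytes. A minimal sketch of the arithmetic in Python (the function names here are illustrative, not part of the codebase):

def grid_size_old(num_bytes: int, block: int = 128) -> int:
    # Truncating division, as in the removed line: a 200-byte buffer
    # would launch only one block and leave the tail uninitialized.
    return num_bytes // block

def grid_size_new(num_bytes: int, block: int = 128) -> int:
    # Ceiling division, as in the added lines: round up so every byte
    # is covered by some block.
    return (num_bytes + block - 1) // block

assert grid_size_old(200) == 1
assert grid_size_new(200) == 2
# Both forms still yield 0 for an empty buffer, which is why the
# size == 0 early return added in customAllReduceKernels.cu matters.
assert grid_size_new(0) == 0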

cpp/tensorrt_llm/kernels/customAllReduceKernels.cu

Lines changed: 4 additions & 0 deletions
@@ -1989,6 +1989,10 @@ void residualRmsNorm(
 void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream)
 {
     sync_check_cuda_error(stream);
+    if (size == 0)
+    {
+        return;
+    }
     switch (dataType)
     {
     case nvinfer1::DataType::kFLOAT:
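The early return above lets lamportInitialize be called with an empty buffer without ever reaching a kernel launch; a zero-sized grid is an invalid CUDA launch configuration and would show up as an error in the surrounding sync_check_cuda_error checks. A small Python sketch of the guard pattern, with a callable standing in for the real dtype-specific launch:

from typing import Callable

def lamport_initialize(size: int, launch: Callable[[int], None]) -> None:
    # Guard added by this commit: skip the launch entirely when there is
    # nothing to initialize, instead of submitting an empty grid.
    if size == 0:
        return
    launch(size)

# Usage: the lambda stands in for the real kernel launch.
lamport_initialize(0, lambda n: print(f"launch over {n} elements"))     # no output
lamport_initialize(4096, lambda n: print(f"launch over {n} elements"))  # launches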

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,8 @@ def get_deepseek_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
     if mapping not in deepseek_allreduce_workspaces:
         ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace(
             mapping,
-            CustomAllReduceHelper.max_workspace_size_auto(mapping.tp_size),
+            CustomAllReduceHelper.max_workspace_size_auto(
+                mapping.tp_size, support_deterministic=False),
         )
         deepseek_allreduce_workspaces[mapping] = (ipc_buffers, workspace)
     return deepseek_allreduce_workspaces[mapping][1]

tensorrt_llm/llmapi/llm.py

Lines changed: 4 additions & 0 deletions
@@ -517,6 +517,10 @@ def _build_model(self):
         if self.args.kv_cache_config is not None:
             executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.kv_cache_config)
+        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
+            # Disable KV cache reuse for deterministic mode
+            executor_config.kv_cache_config.enable_block_reuse = False
+            executor_config.kv_cache_config.enable_partial_reuse = False
         if self.args.peft_cache_config is not None:
             executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.peft_cache_config)
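The llm.py hunk gates KV cache behaviour on the FORCE_DETERMINISTIC environment variable: when it is set to "1", block reuse and partial reuse are turned off, presumably so cache reuse cannot introduce run-to-run variation. A sketch of that gating in isolation, with a stand-in dataclass instead of the real pybind config object:

import os
from dataclasses import dataclass

@dataclass
class KvCacheConfig:
    # Stand-in for the real pybind KvCacheConfig; only the two fields
    # touched by this commit are modelled here.
    enable_block_reuse: bool = True
    enable_partial_reuse: bool = True

def apply_deterministic_overrides(cfg: KvCacheConfig) -> KvCacheConfig:
    # Same check as in _build_model: FORCE_DETERMINISTIC=1 disables reuse.
    if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
        cfg.enable_block_reuse = False
        cfg.enable_partial_reuse = False
    return cfg

os.environ["FORCE_DETERMINISTIC"] = "1"
print(apply_deterministic_overrides(KvCacheConfig()))
# KvCacheConfig(enable_block_reuse=False, enable_partial_reuse=False)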

tensorrt_llm/plugin/plugin.py

Lines changed: 3 additions & 3 deletions
@@ -704,8 +704,8 @@ def set_workspace_tensor(self,
         )
 
     @staticmethod
-    def max_workspace_size_auto(tp_size: int) -> int:
-        if force_all_reduce_deterministic():
+    def max_workspace_size_auto(tp_size: int, support_deterministic) -> int:
+        if force_all_reduce_deterministic() and support_deterministic:
             workspace_size = os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE",
                                        "1000000000")
             return int(workspace_size)
@@ -746,7 +746,7 @@ def allocate_workspace(mapping: Mapping,
             lamport_buffers_0.local_ptr,
             lamport_buffers_1.local_ptr,
             lamport_buffers_2.local_ptr,
-            size * mapping.tp_size,
+            lamport_buffers_size,
         )
         buffers = [
             ipc_buffers_ping, ipc_buffers_pong, ipc_barriers_in,
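Two things change in plugin.py: max_workspace_size_auto gains a support_deterministic parameter, so the oversized FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE workspace is requested only when the caller opts in (the fusion path in ops.py above passes False), and the second hunk passes the actual lamport buffer size instead of size * mapping.tp_size. A sketch of the updated helper, assuming force_all_reduce_deterministic reads an environment flag (the exact variable is not shown in this diff) and using a placeholder fallback size:

import os

def force_all_reduce_deterministic() -> bool:
    # Assumption: the real helper checks an environment flag.
    return os.getenv("FORCE_DETERMINISTIC", "0") == "1"

def max_workspace_size_auto(tp_size: int, support_deterministic: bool) -> int:
    # Sketch of the updated signature: the deterministic override is applied
    # only when the caller can actually use that workspace.
    if force_all_reduce_deterministic() and support_deterministic:
        return int(os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE", "1000000000"))
    return 64 * 1024 * 1024  # placeholder; not the library's real sizing logic

# The deepseek allreduce fusion path opts out explicitly:
size = max_workspace_size_auto(8, support_deterministic=False)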
