405 changes: 405 additions & 0 deletions run_deepseek_batch_split.sh

Large diffs are not rendered by default.

66 changes: 62 additions & 4 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -405,6 +405,16 @@ def run(
self.spec_decoding_generation_lengths,
self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
]

#print(f"[DEBUG] TrtllmAttention.forward - q shape: {q.shape}, \
# sequence_length shape: {self.sequence_length.shape}, \
# host_past_key_value_lengths shape: {self.host_past_key_value_lengths.shape}, \
# context_lengths shape: {self.context_lengths.shape}, \
# host_context_lengths shape: {self.host_context_lengths.shape}, \
# host_request_types shape: {self.host_request_types.shape}, \
# kv_cache_block_offsets shape: {self.kv_cache_block_offsets.shape if self.kv_cache_block_offsets is not None else None}, \
# k shape: {k.shape if k is not None else None}, v shape: {v.shape if v is not None else None}, output shape: {output.shape if output is not None else None}, output_sf shape: {output_sf.shape if output_sf is not None else None}, \
# out_dtype: {out_dtype if out_dtype is not None else None}, is_fused_qkv: {is_fused_qkv}, update_kv_cache: {update_kv_cache}, attention_mask: {attention_mask}")

torch.ops.trtllm.attention_inplace(
q,
@@ -685,11 +695,22 @@ def __post_init__(self) -> None:
pin_memory=True,
)

def prepare(self) -> None:
def prepare(self, splitBatchOverlap: Optional[int] = None) -> None:
print(f"[DEBUG] TrtllmAttention.prepare {splitBatchOverlap}")
extra_attrs = get_model_extra_attrs()
# If model extra attrs is set, attention_metadata is setup in executor.
if extra_attrs is None:
get_global_attrs().attention_metadata = weakref.ref(self)
if splitBatchOverlap is not None:
#print(f"[DEBUG] TrtllmAttention.prepare - splitBatchOverlap is not None")
if splitBatchOverlap == 1:
print(f"[DEBUG] TrtllmAttention.prepare - splitBatchOverlap is 1")
get_global_attrs().attention_metadata_half1 = weakref.ref(self)
else:
print(f"[DEBUG] TrtllmAttention.prepare - splitBatchOverlap is not 1")
get_global_attrs().attention_metadata_half2 = weakref.ref(self)
else:
print(f"[DEBUG] TrtllmAttention.prepare - extra_attrs is None Setting self refrence to global attention_metadata")
get_global_attrs().attention_metadata = weakref.ref(self)
if self.kv_cache_manager is None:
# Convert the attention metadata to a TRT-LLM no cache attention metadata.
assert self.kv_cache_manager is None, "no cache attention should not have KV cache manager"
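For context on the new splitBatchOverlap argument: each half of a split batch registers its metadata under a dedicated global slot (attention_metadata_half1 or attention_metadata_half2), while an unsplit batch keeps using attention_metadata. Below is a minimal, self-contained sketch of that registration pattern; GlobalAttrs, DummyMetadata, and the driver at the bottom are hypothetical stand-ins for illustration, not TensorRT-LLM APIs, and only the slot names and branching mirror the diff above.

# Sketch of the split-batch metadata registration pattern; class names are
# hypothetical stand-ins, only the slot names mirror the diff above.
import weakref
from typing import Optional


class GlobalAttrs:
    """Stand-in for the object returned by get_global_attrs()."""
    attention_metadata: Optional[weakref.ref] = None
    attention_metadata_half1: Optional[weakref.ref] = None
    attention_metadata_half2: Optional[weakref.ref] = None


_GLOBAL_ATTRS = GlobalAttrs()


class DummyMetadata:
    """Stand-in for TrtllmAttentionMetadata."""

    def prepare(self, split_batch_overlap: Optional[int] = None) -> None:
        # Half 1 of a split batch goes to attention_metadata_half1, any other
        # half to attention_metadata_half2; an unsplit batch keeps the
        # original attention_metadata slot.
        if split_batch_overlap is not None:
            if split_batch_overlap == 1:
                _GLOBAL_ATTRS.attention_metadata_half1 = weakref.ref(self)
            else:
                _GLOBAL_ATTRS.attention_metadata_half2 = weakref.ref(self)
        else:
            _GLOBAL_ATTRS.attention_metadata = weakref.ref(self)


if __name__ == "__main__":
    half1, half2 = DummyMetadata(), DummyMetadata()
    half1.prepare(split_batch_overlap=1)
    half2.prepare(split_batch_overlap=2)
    assert _GLOBAL_ATTRS.attention_metadata_half1() is half1
    assert _GLOBAL_ATTRS.attention_metadata_half2() is half2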
@@ -711,17 +732,21 @@ def prepare(self) -> None:
dtype=torch.int,
device='cpu',
)
#print(f"[DEBUG] TrtllmAttention.prepare - num_seqs: {self.num_seqs}")
self.prompt_lens_cpu[:self.num_seqs].copy_(prompt_lens)
self.prompt_lens_cuda[:self.num_seqs].copy_(
self.prompt_lens_cpu[:self.num_seqs], non_blocking=True)

# number of tokens in the kv cache for each sequence in the batch
#print(f"[DEBUG] TrtllmAttention.prepare - self.kv_cache_params.use_cache: {self.kv_cache_params.use_cache}")
#print(f"[DEBUG] TrtllmAttention.prepare - self.kv_cache_params.num_cached_tokens_per_seq: {self.kv_cache_params.num_cached_tokens_per_seq}")
cached_token_lens = torch.tensor(
self.kv_cache_params.num_cached_tokens_per_seq,
dtype=torch.int,
device='cpu',
) if self.kv_cache_params.use_cache else None

#print(f"[DEBUG] TrtllmAttention.prepare - cached_token_lens: {cached_token_lens}")
#print(f"[DEBUG] TrtllmAttention.prepare - self.seq_lens_kv: {self.seq_lens_kv}")
if self.enable_flash_mla:
self.prepare_flash_mla()
# number of tokens needed in the kv cache for each sequence after the next pass
@@ -754,6 +779,8 @@ def prepare(self) -> None:
self.kv_cache_block_offsets[:, :self.num_seqs].copy_(
self.host_kv_cache_block_offsets[:, :self.num_seqs],
non_blocking=True)
#print(f"[DEBUG] TrtllmAttention.prepare {splitBatchOverlap} - self.kv_lens[:self.num_seqs]: {self.kv_lens[:self.num_seqs]}")
#print(f"[DEBUG] TrtllmAttention.prepare {splitBatchOverlap} - self.kv_cache_manager.max_seq_len: {self.kv_cache_manager.max_seq_len}")
assert self.kv_lens[:self.num_seqs].max(
) <= self.kv_cache_manager.max_seq_len, f"Please set max_seq_len to at least {self.kv_lens[:self.num_seqs].max()} for kv cache manager."

@@ -765,6 +792,7 @@ def prepare(self) -> None:
num_seqs]

def prepare_flash_mla(self) -> None:
#print(f"[DEBUG] TrtllmAttention.prepare_flash_mla - self.request_ids: {self.request_ids}")
block_ids_per_seq = self.kv_cache_manager.get_block_ids_per_seq(
self.request_ids).pin_memory()
num_blocks = block_ids_per_seq.shape[1]
@@ -1094,6 +1122,15 @@ def forward(
output_sf: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
#print(f"[DEBUG] TrtllmAttention.forward - q shape: {q.shape}, dtype: {q.dtype}")
#print(f"[DEBUG] TrtllmAttention.forward - k shape: {k.shape if k is not None else None}, dtype: {k.dtype if k is not None else None}")
#print(f"[DEBUG] TrtllmAttention.forward - v shape: {v.shape if v is not None else None}, dtype: {v.dtype if v is not None else None}")
#print(f"[DEBUG] TrtllmAttention.forward - attention_input_type: {attention_input_type}")
#print(f"[DEBUG] TrtllmAttention.forward - latent_cache shape: {latent_cache.shape if latent_cache is not None else None}")
#print(f"[DEBUG] TrtllmAttention.forward - q_pe shape: {q_pe.shape if q_pe is not None else None}")
#print(f"[DEBUG] TrtllmAttention.forward - output shape: {output.shape if output is not None else None}")
#print(f"[DEBUG] TrtllmAttention.forward - attn_metadata: {metadata}")

assert isinstance(
metadata,
TrtllmAttentionMetadata,
@@ -1120,6 +1157,22 @@ def forward(
use_paged_context_fmha=use_paged_context_fmha,
is_mla_enable=self.is_mla_enable,
)
#print(f"[DEBUG] TrtllmAttention.forward wrapper.plan layer_idx: {self.get_local_layer_idx(metadata)}\
# tokens_per_block: {metadata.tokens_per_block}, max_num_requests: {metadata.max_num_requests},\
# max_seq_len: {metadata.max_seq_len}, max_num_tokens: {metadata.max_num_tokens}\
# attention_window_size: {attention_window_size}, sink_token_length: {0}, beam_width: {metadata.beam_width}\
# sequence_length: {metadata.kv_lens_cuda_runtime.shape if metadata.kv_lens_cuda_runtime is not None else None}, host_past_key_value_lengths: {metadata.kv_lens_runtime.shape if metadata.kv_lens_runtime is not None else None}\
# context_lengths: {metadata.prompt_lens_cuda_runtime.shape if metadata.prompt_lens_cuda_runtime is not None else None}, host_context_lengths: {metadata.prompt_lens_cpu_runtime.shape if metadata.prompt_lens_cpu_runtime is not None else None}\
# host_request_types: {metadata.host_request_types_runtime.shape if metadata.host_request_types_runtime is not None else None}, kv_cache_block_offsets: {metadata.kv_cache_block_offsets.shape if metadata.kv_cache_block_offsets is not None else None}\
# host_kv_cache_block_offsets: {metadata.host_kv_cache_block_offsets.shape if metadata.host_kv_cache_block_offsets is not None else None}, host_kv_cache_pool_pointers: {metadata.host_kv_cache_pool_pointers.shape if metadata.host_kv_cache_pool_pointers is not None else None}\
# host_kv_cache_pool_mapping: {metadata.host_kv_cache_pool_mapping.shape if metadata.host_kv_cache_pool_mapping is not None else None}, block_ids_per_seq: {metadata.block_ids_per_seq.shape if metadata.block_ids_per_seq is not None else None}\
# workspace: {metadata.workspace.shape if metadata.workspace is not None else None}, cache_indirection: {metadata.cache_indirection.shape if metadata.cache_indirection is not None else None}, kv_scale_orig_quant: {self.kv_scale_orig_quant}\
# kv_scale_quant_orig: {self.kv_scale_quant_orig}, out_scale: {out_scale.shape if out_scale is not None else None}, out_scale_sf: {out_scale_sf.shape if out_scale_sf is not None else None}\
# latent_cache: {latent_cache.shape if latent_cache is not None else None}, q_pe: {q_pe.shape if q_pe is not None else None}, mrope_config: {mrope_config}, mla_context_paged_kv: {mla_context_paged_kv.shape if mla_context_paged_kv is not None else None}\
# mla_context_kv_cache_block_offsets: {mla_context_kv_cache_block_offsets.shape if mla_context_kv_cache_block_offsets is not None else None}, softmax_stats_tensor: {softmax_stats_tensor.shape if softmax_stats_tensor is not None else None}\
# is_spec_decoding_enabled: {metadata.is_spec_decoding_enabled}, use_spec_decoding: {metadata.use_spec_decoding}\
# spec_decoding_position_offsets: {metadata.spec_decoding_position_offsets.shape if metadata.spec_decoding_position_offsets is not None else None}, spec_decoding_packed_mask: {metadata.spec_decoding_packed_mask.shape if metadata.spec_decoding_packed_mask is not None else None}\
# spec_decoding_generation_lengths: {metadata.spec_decoding_generation_lengths}")
self.wrapper.plan(
layer_idx=self.get_local_layer_idx(metadata),
tokens_per_block=metadata.tokens_per_block,
@@ -1174,6 +1227,7 @@ def forward(
# TODO(qijun): revisit fp8_context_fmha logic
out_dtype = torch.float8_e4m3fn

#print(f"[DEBUG] TrtllmAttention.forward - About to call CUDA kernel with out_dtype: {out_dtype}")
output, output_sf = self.wrapper.run(
q,
k,
@@ -1187,7 +1241,11 @@ def forward(

if out_dtype == torch.uint8:
assert output_sf is not None
return Fp4QuantizedTensor(output, output_sf)
result = Fp4QuantizedTensor(output, output_sf)
#print(f"[DEBUG] TrtllmAttention.forward - Returning Fp4QuantizedTensor with shape: {result.shape}")
return result

#print(f"[DEBUG] TrtllmAttention.forward - Final output shape: {output.shape}, dtype: {output.dtype}")
return output

@classmethod
29 changes: 23 additions & 6 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -169,15 +169,29 @@ def allgather(
if mapping.tp_size == 1:
return input

# print(f"[DEBUG] allgather - input shape: {mapping.tp_group}")
# print(f"[DEBUG] allgather - input shape: {mapping}")
# print(f"[DEBUG] allgather - sizes : {sizes}")
if sizes is not None:
assert len(sizes) == len(mapping.tp_group)
if isinstance(input, torch.Tensor):
assert input.shape[dim] == sizes[mapping.tp_rank]
else:
assert all([
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
#for val in input:
# if val is not None:
# print(f"[DEBUG] allgather - val shape: {val.shape}, dtype: {val.dtype}")
# print(f"[DEBUG] allgather - val shape: {val.shape[dim]}, dtype: {val.dtype}")
# print(f"[DEBUG] allgather - sizes[mapping.tp_rank]: {sizes[mapping.tp_rank]}")
if os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1":
assert all([
val.shape[dim] == sizes[mapping.tp_rank] / 2 or val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
else:
assert all([
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
Comment on lines +185 to +194

🛠️ Refactor suggestion

Fix line length violation and improve readability.

The conditional assertion logic implements valid functionality for split batch overlap, but the line exceeds the 120-character limit and could be more readable.

Apply this diff to fix the line length and improve readability:

             if os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1":
-                assert all([
-                    val.shape[dim] == sizes[mapping.tp_rank] / 2 or val.shape[dim] == sizes[mapping.tp_rank] for val in input
-                    if val is not None
-                ])
+                expected_size = sizes[mapping.tp_rank]
+                assert all([
+                    val.shape[dim] == expected_size / 2 or val.shape[dim] == expected_size
+                    for val in input if val is not None
+                ])
             else:
                 assert all([
                     val.shape[dim] == sizes[mapping.tp_rank] for val in input
                     if val is not None
                 ])
🧰 Tools
🪛 Ruff (0.12.2)

186-186: Line too long (125 > 120)

(E501)
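
To make the relaxed assertion concrete: with ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP=1, a tensor passed to allgather may carry either the full per-rank size or exactly half of it along the gather dimension, since each half of a split batch travels separately. The snippet below is a standalone sketch of that check under the assumption of an even per-rank size; check_split_batch_sizes and its signature are hypothetical helpers, not part of tensorrt_llm/_torch/distributed/ops.py.

# Standalone sketch of the size check that ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP
# relaxes in allgather(); the helper name and signature are hypothetical.
import os
from typing import List, Optional, Sequence

import torch


def check_split_batch_sizes(
    tensors: Sequence[Optional[torch.Tensor]],
    sizes: List[int],
    tp_rank: int,
    dim: int = 0,
) -> None:
    expected = sizes[tp_rank]
    split_overlap = os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1"
    for val in tensors:
        if val is None:
            continue
        if split_overlap:
            # Either the full per-rank size or one half of it is acceptable
            # along `dim` (assumes the per-rank size is even).
            assert val.shape[dim] in (expected, expected // 2), (
                f"got {val.shape[dim]}, expected {expected} or {expected // 2}")
        else:
            assert val.shape[dim] == expected, (
                f"got {val.shape[dim]}, expected {expected}")


if __name__ == "__main__":
    os.environ["ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP"] = "1"
    full, half = torch.zeros(8, 16), torch.zeros(4, 16)
    check_split_batch_sizes([full, half, None], sizes=[8, 8], tp_rank=0, dim=0)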


# Inputs are reshaped in this way to pass necessary shape information to the allgather op
if isinstance(input, torch.Tensor):
@@ -192,7 +206,10 @@ def allgather(
val.contiguous().view(-1, val_info['numel_base'])
for val, val_info in zip(input, output_info)
]

#for val in input:
# print(f"[DEBUG] allgather - input shape: {val.shape}")
#for size in sizes:
# print(f"[DEBUG] allgather - size: {size}")
output = torch_op(
input,
sizes,
@@ -546,7 +563,7 @@ def forward(
hidden_states: hidden_states of the model
residual: residual tensor
"""

# print(f"[DEBUG] MoEAllReduce.forward - input shape: {input.shape}, dtype: {input.dtype}")
return torch.ops.trtllm.moe_allreduce(
active_experts_token_input=input,
residual=all_reduce_params.residual,