Merged
60 changes: 34 additions & 26 deletions cpp/tensorrt_llm/kernels/mlaKernels.cu

@@ -368,7 +368,6 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets,
     bool const* helix_is_inactive_rank)
 {
-
     // Constants.
     using VecT = typename VecType<T>::Type;
     using GPTJEltT = typename VecType<T>::GPTJEltType;

@@ -475,23 +474,27 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,

     if (valid_token)
     {
-        if (head_idx == head_num && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
+        if (head_idx == head_num)
         {
             auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;

-            auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
-            auto inBlockIdx = kv_cache.getKVLocalIdx(
-                token_kv_idx, 0, TOTAL_VEC_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
-            if (cache_type == KvCacheDataType::FP8)
-            {
-
-                quantCopy<T, ELTS_PER_VEC>(
-                    reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
-                    reinterpret_cast<T const*>(&data), quant_scale_kv_val);
-            }
-            else
-                reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
+            // If helix parallelism is being used, only write to KV cache if current rank is active.
+            if (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx])
+            {
+                auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
+                auto inBlockIdx = kv_cache.getKVLocalIdx(
+                    token_kv_idx, 0, TOTAL_VEC_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
+                if (cache_type == KvCacheDataType::FP8)
+                {
+
+                    quantCopy<T, ELTS_PER_VEC>(
+                        reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
+                        reinterpret_cast<T const*>(&data), quant_scale_kv_val);
+                }
+                else
+                    reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
+            }
         }
         else

@@ -529,28 +532,33 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     auto local_token_idx = global_token_idx % seq_len;
     bool valid_token = global_token_idx < total_s_len;

-    if (valid_token && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
+    if (valid_token)
     {
         if (head_dim_vec_idx == 0)
         {
             seqQOffset[batch_idx + 1] = head_num * seq_len * (batch_idx + 1);
         }

         auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
         auto const src_kv_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM);

-        auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
-        auto inBlockIdx = kv_cache.getKVLocalIdx(token_kv_idx, 0, TOTAL_VEC_PER_HEAD, head_dim_vec_idx);
-
-        if (cache_type == KvCacheDataType::FP8)
-        {
-            quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
-                fuse_buf + src_kv_global_offset + head_dim_idx, quant_scale_kv_val);
-        }
-        else
-            reinterpret_cast<VecT*>(kDst)[inBlockIdx]
-                = *reinterpret_cast<VecT const*>(&fuse_buf[src_kv_global_offset + head_dim_idx]);
+        // If helix parallelism is being used, only write to KV cache if current rank is active.
+        if (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx])
+        {
+            auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
+            auto inBlockIdx = kv_cache.getKVLocalIdx(token_kv_idx, 0, TOTAL_VEC_PER_HEAD, head_dim_vec_idx);
+
+            if (cache_type == KvCacheDataType::FP8)
+            {
+                quantCopy<T, ELTS_PER_VEC>(
+                    reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
+                    fuse_buf + src_kv_global_offset + head_dim_idx, quant_scale_kv_val);
+            }
+            else
+                reinterpret_cast<VecT*>(kDst)[inBlockIdx]
+                    = *reinterpret_cast<VecT const*>(&fuse_buf[src_kv_global_offset + head_dim_idx]);
+        }
     }
 }
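Taken together, the hunks above apply the same restructuring: the per-token work, including the seqQOffset bookkeeping, now runs for every valid token, and only the KV-cache write itself is skipped when the current helix rank is marked inactive for that batch entry. Below is a minimal, self-contained CUDA sketch of that gating pattern; the kernel name and parameters (write_kv_sketch, seq_q_offset, new_kv, the flat float layout) are simplified stand-ins for illustration, not the TensorRT-LLM kernel.

// Minimal sketch of the inactive-rank gating pattern under the assumptions named above.
#include <cuda_runtime.h>

__global__ void write_kv_sketch(float* kv_cache,                        // [batch][seq][dim], flattened
                                int* seq_q_offset,                      // [batch + 1] bookkeeping, always filled
                                float const* new_kv,                    // [batch][seq][dim] values to append
                                bool const* helix_is_inactive_rank,     // [batch], or nullptr when helix is off
                                int batch_size, int seq_len, int head_dim)
{
    int const global_token_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int const total_tokens = batch_size * seq_len;
    if (global_token_idx >= total_tokens)
        return;

    int const batch_idx = global_token_idx / seq_len;
    int const local_token_idx = global_token_idx % seq_len;

    // Bookkeeping runs for every valid token, regardless of helix activity,
    // so downstream kernels always see consistent per-batch offsets.
    if (local_token_idx == 0)
        seq_q_offset[batch_idx + 1] = seq_len * (batch_idx + 1);

    // Only the KV-cache write is gated: inactive helix ranks skip it for this batch entry.
    if (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx])
    {
        size_t const dst = (static_cast<size_t>(batch_idx) * seq_len + local_token_idx) * head_dim;
        for (int d = 0; d < head_dim; ++d)
            kv_cache[dst + d] = new_kv[dst + d];
    }
}

Making the bookkeeping unconditional is presumably what resolves the BS>1 helix accuracy issue tracked in https://nvbugs/5637012, which is why the temporary GSM8K accuracy spec and the max_batch_size=8 overrides are removed in the files below.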
5 changes: 0 additions & 5 deletions tests/integration/defs/accuracy/references/gsm8k.yaml

@@ -70,11 +70,6 @@ deepseek-ai/DeepSeek-V3-Lite:
     kv_cache_quant_algo: FP8
     spec_dec_algo: MTP
     accuracy: 64.14
-  # https://nvbugs/5637012: Currently, BS>1 has accuracy issues with helix for GSM8K.
-  # BS=1 has expected accuracy but will be too slow for CI testing. So, adding this
-  # accuracy spec while we investigate the issue.
-  - extra_acc_spec: helix_with_bs8
-    accuracy: 50.0
 deepseek-ai/DeepSeek-R1:
   - quant_algo: NVFP4
     accuracy: 95.42
@@ -862,7 +862,6 @@ def test_auto_dtype_with_helix(self):
             "pipeline_parallel_size": 1,
             "tensor_parallel_size": 2,
             "context_parallel_size": 1,
-            "max_batch_size": 8,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
             "enable_chunked_prefill": False,

@@ -879,7 +878,6 @@ def test_auto_dtype_with_helix(self):
                 "cp_type": "HELIX",
                 "tokens_per_block": 32
             },
-            "max_batch_size": 8,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
             "enable_chunked_prefill": False,

@@ -907,7 +905,7 @@ def test_auto_dtype_with_helix(self):
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
         task = GSM8K(self.MODEL_NAME)
-        task.evaluate(llm, extra_acc_spec="helix_with_bs8")
+        task.evaluate(llm)

     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(60000)
1 change: 0 additions & 1 deletion tests/unittest/_torch/modules/test_mla_helix.py

@@ -837,7 +837,6 @@ def _run_single_rank(func, *args, **kwargs):
         raise Exception(f"\n\nError occurred. Original traceback is\n{tb}\n")


-# note: due to bad numerics with smaller context sizes, we allow up to 2% mismatches
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
 @pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
 def test_mla_helix_distributed(