From bbe48fa0d17ee0539ef364bfed92715595be8088 Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Thu, 19 Mar 2026 01:08:08 -0700 Subject: [PATCH 01/12] [None][feat] Add DWDP (Distributed Weight Data Parallelism) support for MoE inference Core DWDP runtime (dwdp.py): - DwdpManager: IPC handle exchange across MPI ranks - DwdpHandleCollector: per-layer weight/scale/bias handle collection - Expert weight prefetching with double-buffering MoE integration (configurable_moe.py, fused_moe_cute_dsl.py, interface.py): - DWDP support in ConfigurableMoE with CuteDSL+NVFP4 backend - NvFp4WeightView for DWDP weight access patterns - Contiguous gather/scatter grouped GEMM kernels CuteDSL kernel extensions: - Blockscaled contiguous gather grouped GEMM with SwiGLU fusion - Blockscaled contiguous grouped GEMM finalize fusion Executor integration (py_executor.py, py_executor_creator.py, llm_args.py): - DwdpConfig dataclass for YAML-based configuration - DwdpManager initialization and per-step prefetching Disaggregated serving scripts: - start_worker_dwdp.sh for MPI-based worker launch - submit.py DWDP configuration support CI test: - DWDP accuracy test with DeepSeek-V3-Lite (NVFP4, 4 GPUs, GSM8K) Co-authored-by: wanqian-nv <221923321+wanqian-nv@users.noreply.github.com> Co-authored-by: zongfeijing <20381269+zongfeijing@users.noreply.github.com> Signed-off-by: tianyuz-nv --- cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp | 3 +- .../slurm/benchmark/disaggr_torch.slurm | 5 + .../slurm/benchmark/start_worker_dwdp.sh | 61 ++ .../disaggregated/slurm/benchmark/submit.py | 205 +++-- .../_torch/custom_ops/cute_dsl_custom_ops.py | 455 +++++++---- ...guous_gather_grouped_gemm_swiglu_fusion.py | 763 ++++++++++++++---- ...contiguous_grouped_gemm_finalize_fusion.py | 709 +++++++++++++--- .../modules/fused_moe/configurable_moe.py | 35 + .../modules/fused_moe/fused_moe_cute_dsl.py | 220 ++++- .../_torch/modules/fused_moe/interface.py | 26 + tensorrt_llm/_torch/pyexecutor/_util.py | 7 +- 
tensorrt_llm/_torch/pyexecutor/dwdp.py | 556 +++++++++++++ tensorrt_llm/_torch/pyexecutor/py_executor.py | 8 +- .../_torch/pyexecutor/py_executor_creator.py | 14 + tensorrt_llm/llmapi/llm_args.py | 20 + .../accuracy/test_disaggregated_serving.py | 243 ++++++ ...guous_gather_grouped_gemm_swiglu_fusion.py | 407 +++++++--- ...contiguous_grouped_gemm_finalize_fusion.py | 215 ++++- .../_torch/thop/parallel/test_cute_dsl_moe.py | 86 +- 19 files changed, 3380 insertions(+), 658 deletions(-) create mode 100644 examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh create mode 100644 tensorrt_llm/_torch/pyexecutor/dwdp.py diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index d81ae4e39909..f5b2e4b7d5e4 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -144,7 +144,8 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, // Synchronize among ranks cudaDeviceSynchronize(); - tensorrt_llm::mpi::MpiComm::world().barrier(); + // tensorrt_llm::mpi::MpiComm::world().barrier(); + tensorrt_llm::mpi::MpiComm::session().barrier(); return metainfo; } diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm index 9e0d66144478..145c3099267b 100644 --- a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm @@ -150,6 +150,11 @@ replace_placeholder "${start_server_cmds_base_file}" "${all_nodes_str}" "${start server_config_base_file=${full_logdir}/server_config_base.yaml server_config_file=${full_logdir}/server_config.yaml replace_placeholder "${server_config_base_file}" "${all_nodes_str}" "${server_config_file}" +mpi_worker_config_base_file=${full_logdir}/mpi_worker_config_base.yaml +mpi_worker_config_file=${full_logdir}/mpi_worker_config.yaml +if [ -f "${mpi_worker_config_base_file}" ]; then + replace_placeholder 
"${mpi_worker_config_base_file}" "${all_nodes_str}" "${mpi_worker_config_file}" +fi client_cmds_base_file=${full_logdir}/client_cmds_base.sh client_cmds_file=${full_logdir}/client_cmds.sh replace_placeholder "${client_cmds_base_file}" "${all_nodes_str}" "${client_cmds_file}" diff --git a/examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh b/examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh new file mode 100644 index 000000000000..b0fe5bf88443 --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh @@ -0,0 +1,61 @@ +#! /bin/bash +set -u +set -e +set -x + +config_file=${1} +numa_bind=${2} +log_dir=${3} +enable_nsys=${4} +ctx_profile_range=${5} +gen_profile_range=${6} +num_ctx_gpus=${7} +ctx_worker_env_var=${8} +gen_worker_env_var=${9} + +unset UCX_NET_DEVICES +unset UCX_TLS + +echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname)" + +if [ "${SLURM_PROCID}" -lt "${num_ctx_gpus}" ]; then + worker_role="CTX" + worker_env_var=${ctx_worker_env_var} + profile_range=${ctx_profile_range} +else + worker_role="GEN" + worker_env_var=${gen_worker_env_var} + profile_range=${gen_profile_range} +fi + +echo "worker_role: ${worker_role}, profile_range: ${profile_range}" + +for env_var in ${worker_env_var}; do + export "${env_var}" + echo "Exported: ${env_var}" +done + +if [ "${numa_bind}" = "true" ]; then + numa_bind_cmd="numactl -m 0,1" + echo "numactl -m 0,1 - Only allocate memory from nodes on GB200/GB300 NVL72" +else + numa_bind_cmd="" + echo "Not binding memory. If on GB200/GB300 NVL72, use \"numactl -m 0,1\" to only allocate memory from nodes." 
+fi + +echo "config_file: ${config_file}" + +nsys_prefix="" +if [ "${enable_nsys}" != "true" ]; then + echo "nsys is not enabled, start normal flow" +else + nsys_file=${log_dir}/nsys_worker_proc_${worker_role}_${SLURM_PROCID} + export TLLM_PROFILE_RECORD_GC=1 + export TLLM_NVTX_DEBUG=1 + export NSYS_MPI_STORE_TEAMS_PER_RANK=1 + export TLLM_PROFILE_START_STOP=${profile_range} + echo "nsys is enabled on ${worker_role} ranks, TLLM_PROFILE_START_STOP=${profile_range}" + nsys_prefix="nsys profile -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none" +fi + +${nsys_prefix} ${numa_bind_cmd} trtllm-serve disaggregated_mpi_worker -c ${config_file} diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py index 97903423f204..0a4efa620bbb 100644 --- a/examples/disaggregated/slurm/benchmark/submit.py +++ b/examples/disaggregated/slurm/benchmark/submit.py @@ -49,6 +49,44 @@ def save_worker_config(worker_config, output_path): yaml.dump(worker_config, f, default_flow_style=False) +def generate_mpi_worker_config(worker_config, allocations, env_config, + disagg_hostname, disagg_port, output_path): + """Generate a config YAML compatible with ``trtllm-serve disaggregated_mpi_worker``. 
+ """ + def _build_urls(server_type): + urls = [] + for server_id in sorted(allocations.get(server_type, {}).keys()): + inst = allocations[server_type][server_id] + host = list(inst["nodes"].keys())[0] + urls.append(f"{host}:{inst['port']}") + return urls + + ctx_urls = _build_urls("CTX") + gen_urls = _build_urls("GEN") + + ctx_section = dict(worker_config['ctx']) + ctx_section['num_instances'] = len(ctx_urls) + ctx_section['urls'] = ctx_urls + + gen_section = dict(worker_config['gen']) + gen_section['num_instances'] = len(gen_urls) + gen_section['urls'] = gen_urls + + config = { + 'model': env_config['model_path'], + 'hostname': disagg_hostname, + 'port': disagg_port, + 'backend': 'pytorch', + 'max_retries': 100, + 'context_servers': ctx_section, + 'generation_servers': gen_section, + } + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w') as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + def calculate_nodes(world_size, num_servers, gpus_per_node): """Calculate required nodes based on world size and server count.""" return math.ceil(world_size * num_servers / gpus_per_node) @@ -101,10 +139,13 @@ def assign_servers( assign_server(server_allocation, world_size, gpus_per_node) server_allocations[server_type][i] = server_allocation - assign_servers(allocations, "GEN", num_gen_servers, gen_world_size, - gpus_per_node) + # Keep the allocation order aligned with disagg_utils, which builds + # server_configs as ctx_cfgs + gen_cfgs and assigns rank offsets in that + # same order during split_world_comm(). 
assign_servers(allocations, "CTX", num_ctx_servers, ctx_world_size, gpus_per_node) + assign_servers(allocations, "GEN", num_gen_servers, gen_world_size, + gpus_per_node) return allocations @@ -172,6 +213,7 @@ def replace_env_in_file(log_dir, file_path, env_var): return tmp_dir + def build_worker_environment(worker_config, env_config, role, benchmark_mode, nsys_on, profile_range, concurrency, gpu_ids): """Build complete environment dictionary for worker processes. @@ -287,6 +329,7 @@ def format_export_string(env_dict): return ",".join(export_list) + def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var, gen_worker_env_var): @@ -373,6 +416,12 @@ def submit_job(config, log_dir, dry_run): total_nodes = ctx_nodes + gen_nodes total_tasks = total_nodes * gpus_per_node + # Detect DWDP mode: when enabled, use a single srun with + # trtllm-serve disaggregated_mpi_worker instead of per-instance sruns + dwdp_enabled = worker_config.get('ctx', {}).get( + 'dwdp_config', {}).get('enabled', False) + dwdp_size = worker_config.get('ctx', {}).get('dwdp_config', {}).get('dwdp_size', 1) + # Generate log directory path based on configuration isl = benchmark_config['input_length'] osl = benchmark_config['output_length'] @@ -401,10 +450,13 @@ def submit_job(config, log_dir, dry_run): log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}") # Determine directory suffix based on attention_dp - if gen_enable_attention_dp: - dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" + if dwdp_enabled: + dir_suffix = f"disagg_ctx{ctx_num}_dwdp{dwdp_size}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" else: - dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" + if gen_enable_attention_dp: + dir_suffix = 
f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" + else: + dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" # Create full log directory path log_dir = os.path.join(log_base, dir_suffix) @@ -481,53 +533,102 @@ def submit_job(config, log_dir, dry_run): } } - # Generate start worker commands with placeholder hostnames - for server_type in allocations.keys(): - server_cfg = server_configs[server_type] - - for server_id in allocations[server_type].keys(): - allocation = allocations[server_type][server_id] - # Get GPU IDs for this server from allocation - # When multi-node, all nodes have same device list, so use first node [0] - gpu_ids = list(allocation["nodes"].values())[0] - - # Build environment for this worker - worker_env = build_worker_environment( - worker_config=worker_config, - env_config=env_config, - role=server_type, - benchmark_mode=benchmark_config['mode'], - nsys_on=profiling_config['nsys_on'], - profile_range=server_cfg['profile_range'], - concurrency=benchmark_config['concurrency_list'].split(',')[0], - gpu_ids=gpu_ids, - ) - export_str = format_export_string(worker_env) - - # Use script_dir for start_worker.sh - cmd = [ - "srun -l", - f"--nodelist {','.join(allocation['nodes'].keys())}", - f"-N {len(allocation['nodes'])}", - f"--ntasks {server_cfg['world_size']}", - f"--ntasks-per-node {gpus_per_node}", - f"--export=\"{export_str}\"", - f"--container-image {env_config['container_image']}", - f"--container-name {container_name}", - f"--container-mounts {container_mount_str}", - "--no-container-mount-home --mpi=pmix --overlap", - f"bash {os.path.join(script_dir, 'start_worker.sh')}", - server_type, - str(server_id), - env_config['model_path'], - str(allocation["port"]), - str(slurm_config['numa_bind']).lower(), - log_dir, - str(profiling_config['nsys_on']).lower(), - server_cfg['config_path'], - f"&> 
{log_dir}/3_output_{server_type}_{server_id}.log &", - ] - start_server_cmds.append(" ".join(cmd)) + if dwdp_enabled: + # --- DWDP mode: single srun with disaggregated_mpi_worker --- + mpi_config_base_path = os.path.join(log_dir, 'mpi_worker_config_base.yaml') + mpi_config_path = os.path.join(log_dir, 'mpi_worker_config.yaml') + generate_mpi_worker_config(worker_config, allocations, env_config, + disagg_server_hostname, disagg_server_port, + mpi_config_base_path) + + # Nodelist: CTX nodes first, then GEN nodes (matches + # split_world_comm order: server_configs = ctx_cfgs + gen_cfgs) + ctx_node_list = [] + for sid in sorted(allocations.get("CTX", {}).keys()): + for node in allocations["CTX"][sid]["nodes"]: + if node not in ctx_node_list: + ctx_node_list.append(node) + gen_node_list = [] + for sid in sorted(allocations.get("GEN", {}).keys()): + for node in allocations["GEN"][sid]["nodes"]: + if node not in gen_node_list: + gen_node_list.append(node) + mpi_nodelist = ctx_node_list + gen_node_list + total_mpi_tasks = ctx_num * ctx_world_size + gen_num * gen_world_size + mpi_num_nodes = len(mpi_nodelist) + num_ctx_gpus = ctx_num * ctx_world_size + dwdp_ctx_worker_env_var = worker_env_var + \ + (f" {ctx_worker_env_var}" if ctx_worker_env_var else "") + dwdp_gen_worker_env_var = worker_env_var + \ + (f" {gen_worker_env_var}" if gen_worker_env_var else "") + + cmd = [ + "srun -l", + f"--nodelist {','.join(mpi_nodelist)}", + f"-N {mpi_num_nodes}", + f"--ntasks {total_mpi_tasks}", + f"--ntasks-per-node {gpus_per_node}", + f"--container-image {env_config['container_image']}", + f"--container-name {container_name}", + f"--container-mounts {container_mount_str}", + "--no-container-mount-home --mpi=pmix --overlap", + f"bash {os.path.join(script_dir, 'start_worker_dwdp.sh')}", + mpi_config_path, + str(slurm_config['numa_bind']).lower(), + log_dir, + str(profiling_config['nsys_on']).lower(), + f"'{profiling_config['ctx_profile_range']}'", + 
f"'{profiling_config['gen_profile_range']}'", + str(num_ctx_gpus), + f"'{dwdp_ctx_worker_env_var}'", + f"'{dwdp_gen_worker_env_var}'", + f"&> {log_dir}/3_output_workers.log &", + ] + start_server_cmds.append(" ".join(cmd)) + else: + # --- Standard mode: per-instance srun --- + for server_type in allocations.keys(): + server_cfg = server_configs[server_type] + + for server_id in allocations[server_type].keys(): + allocation = allocations[server_type][server_id] + gpu_ids = list(allocation["nodes"].values())[0] + + worker_env = build_worker_environment( + worker_config=worker_config, + env_config=env_config, + role=server_type, + benchmark_mode=benchmark_config['mode'], + nsys_on=profiling_config['nsys_on'], + profile_range=server_cfg['profile_range'], + concurrency=benchmark_config['concurrency_list'].split(',')[0], + gpu_ids=gpu_ids, + ) + export_str = format_export_string(worker_env) + + cmd = [ + "srun -l", + f"--nodelist {','.join(allocation['nodes'].keys())}", + f"-N {len(allocation['nodes'])}", + f"--ntasks {server_cfg['world_size']}", + f"--ntasks-per-node {gpus_per_node}", + f"--export=\"{export_str}\"", + f"--container-image {env_config['container_image']}", + f"--container-name {container_name}", + f"--container-mounts {container_mount_str}", + "--no-container-mount-home --mpi=pmix --overlap", + f"bash {os.path.join(script_dir, 'start_worker.sh')}", + server_type, + str(server_id), + env_config['model_path'], + str(allocation["port"]), + str(slurm_config['numa_bind']).lower(), + log_dir, + str(profiling_config['nsys_on']).lower(), + server_cfg['config_path'], + f"&> {log_dir}/3_output_{server_type}_{server_id}.log &", + ] + start_server_cmds.append(" ".join(cmd)) # Generate start server commands (use script_dir for start_server.sh) server_env = build_server_environment(env_config, benchmark_config['mode']) diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 
d6f616088951..06d992ffb411 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -228,23 +228,23 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper): - permuted_idx_to_expanded_idx specifies the gather pattern - Shape inference uses permuted_idx_to_expanded_idx size instead of a size - Input tensor layout: - 0: a - Original input activation (not permuted) - 1: b - Weight tensor - 2: a_sf - Scale factor for a - 3: b_sf - Scale factor for b - 4: alpha - Per-expert scaling factor - 5: tile_idx_to_group_idx - Tile to expert mapping - 6: tile_idx_to_mn_limit - Tile M/N limits - 7: permuted_idx_to_expanded_idx - Token permutation mapping - 8: num_non_exiting_tiles - Number of valid tiles - 9: global_sf - Global scale factor + Input layout (positions 1, 3, 4 are lists for multi-B support): + 0: a - tensor, original input activation + 1: b_list - list of tensors, weight tensors + 2: a_sf - tensor, scale factor for a + 3: b_sf_list - list of tensors, scale factors for b + 4: alpha_list - list of tensors, per-expert scaling factors + 5: tile_idx_to_group_idx - tensor, tile to expert mapping + 6: tile_idx_to_mn_limit - tensor, tile M/N limits + 7: permuted_idx_to_expanded_idx - tensor, token permutation mapping + 8: num_non_exiting_tiles - tensor, number of valid tiles + 9: global_sf - tensor, global scale factor """ # Override: use permuted_idx_to_expanded_idx for shape inference IDX_PERMUTED_IDX_TO_EXPANDED_IDX = 7 IDX_SHAPE_INFER = IDX_PERMUTED_IDX_TO_EXPANDED_IDX - def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: + def inputs_pre_hook(self, inputs: List) -> List: """Pre-hook for gather-based SwiGLU fusion kernel. 
Generates: @@ -252,9 +252,22 @@ def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: - tile_idx_to_mn_limit - permuted_idx_to_expanded_idx (for gather operation) - num_non_exiting_tiles + + Input layout (positions 1, 3, 4 are lists): + 0: a - tensor + 1: b_list - list of tensors + 2: a_sf - tensor + 3: b_sf_list - list of tensors + 4: alpha_list - list of tensors + 5: tile_idx_to_group_idx - tensor + 6: tile_idx_to_mn_limit - tensor + 7: permuted_idx_to_expanded_idx - tensor + 8: num_non_exiting_tiles - tensor + 9: global_sf - tensor """ - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, \ - permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf = inputs + a, b_list, a_sf, b_sf_list, alpha_list, tile_idx_to_group_idx, \ + tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, \ + num_non_exiting_tiles, global_sf = inputs # Verify permuted_idx_to_expanded_idx index matches the class constant assert inputs[ self. @@ -286,7 +299,7 @@ def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: local_num_experts=self.num_local_experts, tile_tokens_dim=self.tile_size, ) - return (a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, + return (a, b_list, a_sf, b_sf_list, alpha_list, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf) @@ -858,8 +871,10 @@ def get_valid_tactics( **kwargs, ) -> List[Tuple[int, int]]: a, b, *_ = inputs + b_list = b if isinstance(b, (list, tuple)) else [b] m, k = a.size(0), a.size(1) * 2 - l, n = b.size(0), b.size(1) + l = sum(bi.size(0) for bi in b_list) + n = b_list[0].size(1) mma_tiler_mn_candidates = [(self.tile_size, 128), (self.tile_size, 256)] @@ -1115,7 +1130,8 @@ def __init__(self, local_expert_offset: int, tile_size: int, output_dtype: torch.dtype, - scaling_vector_size: int = 16): + scaling_vector_size: int = 16, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None): super().__init__() self.num_experts = num_experts 
self.top_k = top_k @@ -1126,6 +1142,7 @@ def __init__(self, assert output_dtype == torch.bfloat16 self.output_dtype = output_dtype self.scaling_vector_size = scaling_vector_size + self.b_tensor_l_sizes = b_tensor_l_sizes if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( @@ -1146,6 +1163,7 @@ def unique_id(self): self.tile_size, self.output_dtype, self.scaling_vector_size, + self.b_tensor_l_sizes, ) def get_valid_tactics( @@ -1154,9 +1172,12 @@ def get_valid_tactics( profile: OptimizationProfile, **kwargs, ) -> List[Tuple[int, int]]: - a, b, *_ = inputs + a, b_list, *_ = inputs + if not isinstance(b_list, (list, tuple)): + raise TypeError("weight must be a list of tensors") m, k = a.size(0), a.size(1) * 2 - l, n = b.size(0), b.size(1) + l = sum(bi.size(0) for bi in b_list) + n = b_list[0].size(1) mma_tiler_mn_candidates = [(self.tile_size, 128), (self.tile_size, 256)] @@ -1222,29 +1243,45 @@ def get_tuning_config(self) -> TuningConfig: def forward(self, inputs: List[torch.Tensor], tactic: Optional[tuple]) -> torch.Tensor: - a, b, a_sf, b_sf, alpha, c, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs + a, b_list, a_sf, b_sf_list, alpha_list, c, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs + if not isinstance(b_list, (list, tuple)): + raise TypeError("weight must be a list of tensors") + if not isinstance(b_sf_list, (list, tuple)): + raise TypeError("weight_scale must be a list of tensors") + if not isinstance(alpha_list, (list, tuple)): + raise TypeError("alpha must be a list of tensors") + assert len(b_list) == len(b_sf_list) == len(alpha_list) + b_tensor_l_sizes = tuple(bi.size(0) for bi in b_list) + + b0 = b_list[0] + b_sf0 = b_sf_list[0] + alpha0 = alpha_list[0] assert a.dtype == torch.float4_e2m1fn_x2 assert a.dim() == 2 - assert b.dtype == torch.float4_e2m1fn_x2 - assert b.dim() == 3 
+ assert b0.dtype == torch.float4_e2m1fn_x2 + assert b0.dim() == 3 assert a_sf.dtype == torch.uint8 assert a_sf.dim() == 1 - assert b_sf.dtype == torch.uint8 - assert b_sf.dim() == 3 - assert alpha.dtype == torch.float32 - assert alpha.dim() == 1 + assert b_sf0.dtype == torch.uint8 + assert b_sf0.dim() == 3 + assert alpha0.dtype == torch.float32 + assert alpha0.dim() == 1 m, k = a.size(0), a.size(1) * 2 - l, n = b.size(0), b.size(1) + l = sum(bi.size(0) for bi in b_list) + n = b0.size(1) scale_k = k // self.scaling_vector_size assert m % self.tile_size == 0 assert k % (self.scaling_vector_size * 4) == 0 - assert b.size(2) * 2 == k + assert b0.size(2) * 2 == k assert a_sf.size(0) == m * scale_k - assert b_sf.size(0) == l - assert b_sf.size(1) == n - assert b_sf.size(2) == scale_k - assert alpha.size(0) == l + for bi, bsfi, ai in zip(b_list, b_sf_list, alpha_list): + assert bi.size(1) == n + assert bi.size(2) * 2 == k + assert bsfi.size(0) == bi.size(0) + assert bsfi.size(1) == n + assert bsfi.size(2) == scale_k + assert ai.size(0) == bi.size(0) assert c.dtype == self.output_dtype assert c.dim() == 2 @@ -1268,20 +1305,10 @@ def forward(self, inputs: List[torch.Tensor], a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32) - b_ptr = make_ptr(cutlass.Float4E2M1FN, - b.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32) a_sf_ptr = make_ptr(cutlass.Float8E4M3FN, a_sf.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - b_sf_ptr = make_ptr(cutlass.Float8E4M3FN, - b_sf.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16) - alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), - cute.AddressSpace.gmem) tile_idx_to_group_idx_ptr = make_ptr( cutlass.Int32, tile_idx_to_group_idx.data_ptr(), cute.AddressSpace.gmem) @@ -1302,6 +1329,23 @@ def forward(self, inputs: List[torch.Tensor], cute.AddressSpace.gmem, assumed_align=16) + b_ptr = tuple( + make_ptr(cutlass.Float4E2M1FN, + bi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) + for bi in b_list) + 
b_sf_ptr = tuple( + make_ptr(cutlass.Float8E4M3FN, + bsfi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16) + for bsfi in b_sf_list) + alpha_ptr = tuple( + make_ptr(cutlass.Float32, ai.data_ptr(), + cute.AddressSpace.gmem) + for ai in alpha_list) + torch_stream = torch.cuda.current_stream() stream = cuda.CUstream(torch_stream.cuda_stream) @@ -1315,7 +1359,7 @@ def forward(self, inputs: List[torch.Tensor], 0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})" cache_key = (self.scaling_vector_size, self.tile_size, mma_tiler_mn, - cluster_shape_mn, raster_along_m) + cluster_shape_mn, raster_along_m, b_tensor_l_sizes) if cache_key not in self.__class__.kernel_cache: gemm = self.__class__.kernel_class( sf_vec_size=self.scaling_vector_size, @@ -1323,14 +1367,14 @@ def forward(self, inputs: List[torch.Tensor], cluster_shape_mn=cluster_shape_mn, use_blkred=True, raster_along_m=raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes, ) # Compute max active clusters on current device hardware_info = cutlass.utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mn[0] * cluster_shape_mn[1]) - compiled_gemm = cute.compile( - gemm.wrapper, + compile_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -1345,9 +1389,12 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - l, - num_tokens, - self.top_k, + num_tokens, self.top_k, + ] + + compiled_gemm = cute.compile( + gemm.wrapper, + *compile_args, tile_size=self.tile_size, scaling_vector_size=self.scaling_vector_size, max_active_clusters=max_active_clusters, @@ -1357,7 +1404,7 @@ def forward(self, inputs: List[torch.Tensor], else: compiled_gemm = self.__class__.kernel_cache[cache_key] - compiled_gemm( + exec_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -1372,11 +1419,9 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - l, - num_tokens, - self.top_k, - stream=stream, - ) + num_tokens, self.top_k, + ] + compiled_gemm(*exec_args, stream=stream) return c 
@torch.library.custom_op( @@ -1385,10 +1430,10 @@ def forward(self, inputs: List[torch.Tensor], device_types="cuda") def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], output: torch.Tensor, tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, @@ -1405,9 +1450,11 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( ) -> None: tuner = AutoTuner.get() + b_tensor_l_sizes = tuple(w.size(0) + for w in weight) if len(weight) > 1 else None runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner( num_experts, top_k, num_local_experts, local_expert_offset, - tile_size, output_dtype, scaling_vector_size) + tile_size, output_dtype, scaling_vector_size, b_tensor_l_sizes) inputs = [ input, weight, input_scale, weight_scale, alpha, output, @@ -1430,10 +1477,10 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( device_types="cuda") def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, permuted_idx_to_expanded_idx: torch.Tensor, @@ -1448,7 +1495,7 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( scaling_vector_size: int = 16, ) -> torch.Tensor: num_tokens = token_final_scales.size(0) - n = weight.size(1) + n = weight[0].size(1) output = torch.zeros(num_tokens, n, dtype=output_dtype, @@ -1479,10 +1526,10 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( "trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell") def _( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: 
torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, permuted_idx_to_expanded_idx: torch.Tensor, @@ -1497,7 +1544,7 @@ def _( scaling_vector_size: int = 16, ) -> torch.Tensor: num_tokens = token_final_scales.size(0) - n = weight.size(1) + n = weight[0].size(1) return torch.empty(num_tokens, n, dtype=output_dtype, @@ -1828,13 +1875,23 @@ class Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( kernel_cache = dict() tuning_config_cache = dict() + # Maximum number of B tensors supported (must match kernel's MAX_B_TENSORS) + MAX_B_TENSORS = 4 + def __init__(self, num_experts: int, top_k: int, num_local_experts: int, local_expert_offset: int, tile_size: int, - scaling_vector_size: int = 16): + scaling_vector_size: int = 16, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None): + """Initialize the runner. + + Args: + b_tensor_l_sizes: Tuple of L sizes for each B tensor in multi-B mode. + None for single-B mode. Used for kernel cache key. + """ super().__init__() self.num_experts = num_experts self.top_k = top_k @@ -1846,6 +1903,7 @@ def __init__(self, ) self.tile_size = tile_size self.scaling_vector_size = scaling_vector_size + self.b_tensor_l_sizes = b_tensor_l_sizes if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( @@ -1865,19 +1923,24 @@ def unique_id(self): self.local_expert_offset, self.tile_size, self.scaling_vector_size, + self.b_tensor_l_sizes, ) def get_valid_tactics( self, - inputs: List[torch.Tensor], + inputs: List, profile: OptimizationProfile, **kwargs, ) -> List[Tuple[int, int]]: - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, *_ = inputs + # Tuning uses layout: a, b_list, a_sf, b_sf_list, alpha_list, ... 
+ a = inputs[0] + b_list = inputs[1] # List of B tensors + permuted_idx_to_expanded_idx = inputs[7] # m is the permuted size from permuted_idx_to_expanded_idx, not from a m = permuted_idx_to_expanded_idx.size(0) k = a.size(1) * 2 - l, n = b.size(0), b.size(1) + l = sum(bi.size(0) for bi in b_list) + n = b_list[0].size(1) mma_tiler_mn_candidates = [(self.tile_size, 128), (self.tile_size, 256)] @@ -1917,6 +1980,9 @@ def get_tuning_config(self) -> TuningConfig: self.num_local_experts, self.local_expert_offset, self.tile_size) + # Tuning uses layout: + # a, b_list, a_sf, b_sf_list, alpha_list, tile_idx, tile_mn_limit, permuted_idx, ... + # Constraint indices adjusted for list inputs at positions 1, 3, 4 self.__class__.tuning_config_cache[key] = TuningConfig( # Use permuted_idx_to_expanded_idx (IDX_SHAPE_INFER) for tuning dynamic_tensor_specs=(DynamicTensorSpec( @@ -1938,41 +2004,57 @@ def get_tuning_config(self) -> TuningConfig: ) return self.__class__.tuning_config_cache[key] - def forward(self, inputs: List[torch.Tensor], + def forward(self, inputs: List, tactic: Optional[tuple]) -> torch.Tensor: - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf = inputs - # Verify permuted_idx_to_expanded_idx index matches the class constant - assert inputs[ - GatherGroupedGemmInputsHelper. - IDX_PERMUTED_IDX_TO_EXPANDED_IDX] is permuted_idx_to_expanded_idx + """Forward pass supporting both single tensor and list inputs. 
+ + Input layout (positions 1, 3, 4 are lists for multi-B support): + 0: a - tensor + 1: b_list - list of tensors + 2: a_sf - tensor + 3: b_sf_list - list of tensors + 4: alpha_list - list of tensors + 5: tile_idx_to_group_idx - tensor + 6: tile_idx_to_mn_limit - tensor + 7: permuted_idx_to_expanded_idx - tensor + 8: num_non_exiting_tiles - tensor + 9: global_sf - tensor + """ + a, b_list, a_sf, b_sf_list, alpha_list, tile_idx_to_group_idx, \ + tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, \ + num_non_exiting_tiles, global_sf = inputs + + b_tensor_l_sizes = tuple(bi.size(0) for bi in b_list) + + b0 = b_list[0] # Use first B for shape inference + + # Verify input dtypes and dimensions assert a.dtype == torch.float4_e2m1fn_x2 assert a.dim() == 2 - assert b.dtype == torch.float4_e2m1fn_x2 - assert b.dim() == 3 + assert b0.dtype == torch.float4_e2m1fn_x2 + assert b0.dim() == 3 assert a_sf.dtype == torch.uint8 assert a_sf.dim() == 2 - assert b_sf.dtype == torch.uint8 - assert b_sf.dim() == 3 - assert alpha.dtype == torch.float32 - assert alpha.dim() == 1 + assert b_sf_list[0].dtype == torch.uint8 + assert b_sf_list[0].dim() == 3 + assert alpha_list[0].dtype == torch.float32 + assert alpha_list[0].dim() == 1 # a.size(0) is orig_m (original input size before gather) # permuted_idx_to_expanded_idx.size(0) is m (permuted size after gather) orig_m, k = a.size(0), a.size(1) * 2 m = permuted_idx_to_expanded_idx.size(0) - l, n = b.size(0), b.size(1) + n = b0.size(1) + l = sum(bi.size(0) for bi in b_list) scale_k = k // self.scaling_vector_size interm_size = n // 2 + assert m % self.tile_size == 0 assert k % (self.scaling_vector_size * 4) == 0 assert n % (self.scaling_vector_size * 4 * 2) == 0 - assert b.size(2) * 2 == k + assert b0.size(2) * 2 == k assert a_sf.size(0) == orig_m assert a_sf.size(1) == scale_k - assert b_sf.size(0) == l - assert b_sf.size(1) == n - assert b_sf.size(2) == scale_k - assert alpha.size(0) == l num_tiles = m // self.tile_size assert 
tile_idx_to_group_idx.dtype == torch.int32 @@ -1986,29 +2068,29 @@ def forward(self, inputs: List[torch.Tensor], assert global_sf.dtype == torch.float32 assert global_sf.numel() == 1 + # Allocate output tensors c = torch.empty(m, interm_size // 2, dtype=a.dtype, device=a.device) c_sf = torch.empty(m * interm_size // self.scaling_vector_size, dtype=a_sf.dtype, device=a_sf.device) + # Create common pointers a_ptr = make_ptr(cutlass.Float4E2M1FN, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32) - b_ptr = make_ptr(cutlass.Float4E2M1FN, - b.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32) a_sf_ptr = make_ptr(cutlass.Float8E4M3FN, a_sf.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - b_sf_ptr = make_ptr(cutlass.Float8E4M3FN, - b_sf.data_ptr(), + c_ptr = make_ptr(cutlass.Float4E2M1FN, + c.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) + c_sf_ptr = make_ptr(cutlass.Float8E4M3FN, + c_sf.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), - cute.AddressSpace.gmem) tile_idx_to_group_idx_ptr = make_ptr( cutlass.Int32, tile_idx_to_group_idx.data_ptr(), cute.AddressSpace.gmem) @@ -2023,14 +2105,23 @@ def forward(self, inputs: List[torch.Tensor], cute.AddressSpace.gmem) global_sf_ptr = make_ptr(cutlass.Float32, global_sf.data_ptr(), cute.AddressSpace.gmem) - c_ptr = make_ptr(cutlass.Float4E2M1FN, - c.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32) - c_sf_ptr = make_ptr(cutlass.Float8E4M3FN, - c_sf.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16) + + b_ptr = tuple( + make_ptr(cutlass.Float4E2M1FN, + bi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) + for bi in b_list) + b_sf_ptr = tuple( + make_ptr(cutlass.Float8E4M3FN, + bsfi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16) + for bsfi in b_sf_list) + alpha_ptr = tuple( + make_ptr(cutlass.Float32, ai.data_ptr(), + cute.AddressSpace.gmem) + for ai in alpha_list) torch_stream = 
torch.cuda.current_stream() stream = cuda.CUstream(torch_stream.cuda_stream) @@ -2045,7 +2136,9 @@ def forward(self, inputs: List[torch.Tensor], 0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})" cache_key = (self.scaling_vector_size, self.tile_size, self.top_k, - mma_tiler_mn, cluster_shape_mn, raster_along_m) + mma_tiler_mn, cluster_shape_mn, raster_along_m, + b_tensor_l_sizes) + if cache_key not in self.__class__.kernel_cache: gemm = self.__class__.kernel_class( sf_vec_size=self.scaling_vector_size, @@ -2054,31 +2147,22 @@ def forward(self, inputs: List[torch.Tensor], vectorized_f32=True, topk=self.top_k, raster_along_m=raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes, ) - # Compute max active clusters on current device hardware_info = cutlass.utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mn[0] * cluster_shape_mn[1]) + compile_args = [ + a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, + alpha_ptr, tile_idx_to_group_idx_ptr, + tile_idx_to_mn_limit_ptr, permuted_idx_to_expanded_idx_ptr, + num_non_exiting_tiles_ptr, global_sf_ptr, orig_m, m, n, k, + ] + compiled_gemm = cute.compile( gemm.wrapper, - a_ptr, - b_ptr, - a_sf_ptr, - b_sf_ptr, - c_ptr, - c_sf_ptr, - alpha_ptr, - tile_idx_to_group_idx_ptr, - tile_idx_to_mn_limit_ptr, - permuted_idx_to_expanded_idx_ptr, - num_non_exiting_tiles_ptr, - global_sf_ptr, - orig_m, - m, - n, - k, - l, + *compile_args, tile_size=self.tile_size, scaling_vector_size=self.scaling_vector_size, max_active_clusters=max_active_clusters, @@ -2088,38 +2172,27 @@ def forward(self, inputs: List[torch.Tensor], else: compiled_gemm = self.__class__.kernel_cache[cache_key] - compiled_gemm( - a_ptr, - b_ptr, - a_sf_ptr, - b_sf_ptr, - c_ptr, - c_sf_ptr, - alpha_ptr, - tile_idx_to_group_idx_ptr, - tile_idx_to_mn_limit_ptr, - permuted_idx_to_expanded_idx_ptr, - num_non_exiting_tiles_ptr, - global_sf_ptr, - orig_m, - m, - n, - k, - l, - stream=stream, - 
) + exec_args = [ + a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, alpha_ptr, + tile_idx_to_group_idx_ptr, tile_idx_to_mn_limit_ptr, + permuted_idx_to_expanded_idx_ptr, num_non_exiting_tiles_ptr, + global_sf_ptr, orig_m, m, n, k, + ] + + compiled_gemm(*exec_args, stream=stream) + return c, c_sf @torch.library.custom_op( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b", mutates_args=(), device_types="cuda") - def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, permuted_idx_to_expanded_idx: torch.Tensor, @@ -2132,11 +2205,20 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( tile_size: int, scaling_vector_size: int = 16, ) -> Tuple[torch.Tensor, torch.Tensor]: + """CuteDSL-based NVFP4 gather grouped GEMM with SwiGLU fusion (multi-B list interface). + + Args: + weight: List of B tensors. Single-B mode: [b], multi-B mode: [b0, b1, ...]. + weight_scale: List of scale tensors, matching weight. + alpha: List of alpha tensors, matching weight. 
+ """ tuner = AutoTuner.get() + b_tensor_l_sizes = tuple(w.size(0) for w in weight) + runner = Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( num_experts, top_k, num_local_experts, local_expert_offset, - tile_size, scaling_vector_size) + tile_size, scaling_vector_size, b_tensor_l_sizes) inputs = [ input, weight, input_scale, weight_scale, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, @@ -2144,17 +2226,86 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( ] _, best_tactic = tuner.choose_one( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b", [runner], runner.get_tuning_config(), inputs, ) - output = runner(inputs, tactic=best_tactic) + + # Call forward with inputs list + output = runner.forward(inputs, tactic=best_tactic) return output + @torch.library.register_fake( + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b") + def _fake_multi_b( + input: torch.Tensor, + weight: List[torch.Tensor], + input_scale: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], + tile_idx_to_group_idx: torch.Tensor, + tile_idx_to_mn_limit: torch.Tensor, + permuted_idx_to_expanded_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + global_sf: torch.Tensor, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + scaling_vector_size: int = 16, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m = permuted_idx_to_expanded_idx.size(0) + n = weight[0].size(1) + interm_size = n // 2 + output = torch.empty(m, + interm_size // 2, + dtype=input.dtype, + device=input.device) + output_scale = torch.empty(m * interm_size // scaling_vector_size, + dtype=input_scale.dtype, + device=input_scale.device) + return output, output_scale + + @torch.library.custom_op( + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + mutates_args=(), + device_types="cuda") + def 
cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + tile_idx_to_group_idx: torch.Tensor, + tile_idx_to_mn_limit: torch.Tensor, + permuted_idx_to_expanded_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + global_sf: torch.Tensor, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + scaling_vector_size: int = 16, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """CuteDSL-based NVFP4 gather grouped GEMM with SwiGLU fusion (single-B tensor interface). + + Thin wrapper: wraps single tensors into lists and calls + cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b. + """ + return torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + input, [weight], input_scale, [weight_scale], [alpha], + tile_idx_to_group_idx, tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf, + num_experts, top_k, num_local_experts, local_expert_offset, + tile_size, scaling_vector_size, + ) + @torch.library.register_fake( "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell") - def _( + def _fake_single_b( input: torch.Tensor, weight: torch.Tensor, input_scale: torch.Tensor, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 5a4c5d0a0067..c339787301e6 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -377,6 +377,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel: ... 
) """ + # Maximum number of B tensors supported + MAX_B_TENSORS = 4 + def __init__( self, sf_vec_size: int, @@ -385,6 +388,7 @@ def __init__( vectorized_f32: bool, topk: cutlass.Int64, raster_along_m: bool = False, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with gather operation and SwiGLU fusion. @@ -420,6 +424,10 @@ def __init__( :type vectorized_f32: bool :param topk: Number of experts selected per token (used for token ID mapping). :type topk: cutlass.Int64 + :param b_tensor_l_sizes: Optional tuple of L sizes for each B tensor. + E.g., (8, 8, 16) means 3 B tensors with L=8, 8, 16. Sum equals total L. + If None, single B tensor mode (backward compatible). + :type b_tensor_l_sizes: Optional[Tuple[int, ...]] """ self.sf_vec_size = sf_vec_size @@ -502,6 +510,26 @@ def __init__( self.vectorized_f32 = vectorized_f32 + # Multi-B tensor configuration + if b_tensor_l_sizes is None: + self.num_b_tensors = 1 + self.b_tensor_l_sizes = None + # Offsets padded for safe indexing in kernel + self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS + else: + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) + def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -696,17 +724,17 @@ def _setup_attributes(self): def __call__( self, a: cute.Tensor, - b: cute.Tensor, + b: Union[cute.Tensor, Tuple[cute.Tensor, ...]], c: cute.Tensor, sfa: cute.Tensor, - sfb: cute.Tensor, + sfb: Union[cute.Tensor, Tuple[cute.Tensor, ...]], sfc_tensor: 
Optional[cute.Tensor], norm_const_tensor: Optional[cute.Tensor], tile_idx_to_expert_idx: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, token_id_mapping_tensor: cute.Tensor, num_non_exiting_tiles: cute.Tensor, - alpha: cute.Tensor, + alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, @@ -774,11 +802,14 @@ def __call__( """ # Setup static attributes before smem/grid/tma computation self.a_dtype: Type[cutlass.Numeric] = a.element_type - self.b_dtype: Type[cutlass.Numeric] = b.element_type + # Handle tuple of B tensors + b_tuple = b if isinstance(b, tuple) else (b,) + sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,) + self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type self.c_dtype: Type[cutlass.Numeric] = c.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() - self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tuple[0]).mma_major_mode() self.c_layout = utils.LayoutEnum.from_tensor(c) # Check if input data types are compatible with MMA instruction @@ -788,10 +819,28 @@ def __call__( # Setup attributes that dependent on gemm inputs self._setup_attributes() - # Setup sfb tensor by filling B tensor to scale factor atom layout - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) - sfb = cute.make_tensor(sfb.iterator, sfb_layout) + # Setup sfb tensors - create layout for each B tensor (use const_expr, not loop) + sfb_layout_0 = blockscaled_utils.tile_atom_to_shape_SF(b_tuple[0].shape, self.sf_vec_size) + sfb_tensor_0 = cute.make_tensor(sfb_tuple[0].iterator, sfb_layout_0) + sfb_tensors = [sfb_tensor_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + sfb_layout_1 = blockscaled_utils.tile_atom_to_shape_SF( + 
b_tuple[1].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[1].iterator, sfb_layout_1)) + if cutlass.const_expr(self.num_b_tensors >= 3): + sfb_layout_2 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[2].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[2].iterator, sfb_layout_2)) + if cutlass.const_expr(self.num_b_tensors >= 4): + sfb_layout_3 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[3].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[3].iterator, sfb_layout_3)) + sfb_tuple = tuple(sfb_tensors) + # Backward compat alias + sfb = sfb_tuple[0] # Setup sfc tensor by filling C tensor to scale factor atom layout self.generate_sfc = sfc_tensor is not None and norm_const_tensor is not None @@ -821,51 +870,82 @@ def __call__( ) atom_thr_size = cute.size(tiled_mma.thr_id.shape) - # Setup TMA load for B + # Setup TMA ops (shared across all B tensors) b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) - b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, - b, - b_smem_layout, - self.mma_tiler, - tiled_mma, - self.cluster_layout_vmnk.shape, - ) - - # Setup TMA load for SFB sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, - sfb, - sfb_smem_layout, - self.mma_tiler_sfb, - tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, - internal_type=cutlass.Int16, - ) - - # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF)) - # logical blocks for SFB when cta_tile_shape_n=192. 
- if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): - x = tma_tensor_sfb.stride[0][1] - y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) - new_shape = ( - (tma_tensor_sfb.shape[0][0], ((2, 2), y)), - tma_tensor_sfb.shape[1], - tma_tensor_sfb.shape[2], + # Helper to create TMA for one B tensor + def _make_tma_b(b_tensor, sfb_tensor): + atom_b, tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, ) - # Use right multiplication for ScaledBasis (3 * x instead of x * 3) - x_times_3 = 3 * x - new_stride = ( - (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), - tma_tensor_sfb.stride[1], - tma_tensor_sfb.stride[2], + atom_sfb, tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, ) - tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) - tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + # Handle overlapping layout for SFB when cta_tile_shape_n=192 + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tensor_sfb.stride[0][1] + y = cute.ceil_div(tensor_sfb.shape[0][1], 4) + new_shape = ( + (tensor_sfb.shape[0][0], ((2, 2), y)), + tensor_sfb.shape[1], + tensor_sfb.shape[2], + ) + x_times_3 = 3 * x + new_stride = ( + (tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tensor_sfb.stride[1], + tensor_sfb.stride[2], + ) + tensor_sfb = cute.make_tensor( + tensor_sfb.iterator, cute.make_layout(new_shape, stride=new_stride) + ) + return atom_b, tensor_b, atom_sfb, tensor_sfb + + # Create TMA for all B tensors (use const_expr, not loop) + atom_b_0, tensor_b_0, atom_sfb_0, tensor_sfb_0 = _make_tma_b(b_tuple[0], sfb_tuple[0]) + tma_atoms_b = [atom_b_0] + tma_tensors_b = [tensor_b_0] + tma_atoms_sfb = [atom_sfb_0] + tma_tensors_sfb = [tensor_sfb_0] + if 
cutlass.const_expr(self.num_b_tensors >= 2): + atom_b_1, tensor_b_1, atom_sfb_1, tensor_sfb_1 = _make_tma_b(b_tuple[1], sfb_tuple[1]) + tma_atoms_b.append(atom_b_1) + tma_tensors_b.append(tensor_b_1) + tma_atoms_sfb.append(atom_sfb_1) + tma_tensors_sfb.append(tensor_sfb_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + atom_b_2, tensor_b_2, atom_sfb_2, tensor_sfb_2 = _make_tma_b(b_tuple[2], sfb_tuple[2]) + tma_atoms_b.append(atom_b_2) + tma_tensors_b.append(tensor_b_2) + tma_atoms_sfb.append(atom_sfb_2) + tma_tensors_sfb.append(tensor_sfb_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + atom_b_3, tensor_b_3, atom_sfb_3, tensor_sfb_3 = _make_tma_b(b_tuple[3], sfb_tuple[3]) + tma_atoms_b.append(atom_b_3) + tma_tensors_b.append(tensor_b_3) + tma_atoms_sfb.append(atom_sfb_3) + tma_tensors_sfb.append(tensor_sfb_3) + tma_atoms_b = tuple(tma_atoms_b) + tma_tensors_b = tuple(tma_tensors_b) + tma_atoms_sfb = tuple(tma_atoms_sfb) + tma_tensors_sfb = tuple(tma_tensors_sfb) + + # Handle alpha tuple (convert to tuple if single tensor) + alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) @@ -990,11 +1070,11 @@ class SharedStorage2cta: tiled_mma, tiled_mma_sfb, a, - tma_atom_b, - tma_tensor_b, + tma_atoms_b, # Tuple of TMA atoms for B + tma_tensors_b, # Tuple of TMA tensors for B sfa, - tma_atom_sfb, - tma_tensor_sfb, + tma_atoms_sfb, # Tuple of TMA atoms for SFB + tma_tensors_sfb, # Tuple of TMA tensors for SFB tma_atom_c, tma_tensor_c, sfc_tensor, @@ -1003,7 +1083,7 @@ class SharedStorage2cta: tile_idx_to_mn_limit, token_id_mapping_tensor, num_non_exiting_tiles, - alpha, + alpha_tuple, self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, self.a_smem_layout_staged, @@ -1074,11 +1154,11 @@ def kernel( tiled_mma: cute.TiledMma, tiled_mma_sfb: cute.TiledMma, mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, 
+ tma_atoms_b: Tuple[cute.CopyAtom, ...], + mB_nkl_tuple: Tuple[cute.Tensor, ...], mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, + tma_atoms_sfb: Tuple[cute.CopyAtom, ...], + mSFB_nkl_tuple: Tuple[cute.Tensor, ...], tma_atom_c: cute.CopyAtom, mC_mnl: cute.Tensor, mSFC_mnl: Optional[cute.Tensor], @@ -1087,7 +1167,7 @@ def kernel( tile_idx_to_mn_limit: cute.Tensor, token_id_mapping_tensor: cute.Tensor, num_non_exiting_tiles: cute.Tensor, - alpha: cute.Tensor, + alpha_tuple: Tuple[cute.Tensor, ...], cluster_layout_vmnk: cute.Layout, cluster_layout_sfb_vmnk: cute.Layout, a_smem_layout_staged: cute.ComposedLayout, @@ -1109,8 +1189,18 @@ def kernel( # Prefetch tma desc # if warp_idx == self.tma_b_warp_id: - cpasync.prefetch_descriptor(tma_atom_b) - cpasync.prefetch_descriptor(tma_atom_sfb) + # Prefetch TMA descriptors for all B tensors using const_expr conditions + cpasync.prefetch_descriptor(tma_atoms_b[0]) + cpasync.prefetch_descriptor(tma_atoms_sfb[0]) + if cutlass.const_expr(self.num_b_tensors >= 2): + cpasync.prefetch_descriptor(tma_atoms_b[1]) + cpasync.prefetch_descriptor(tma_atoms_sfb[1]) + if cutlass.const_expr(self.num_b_tensors >= 3): + cpasync.prefetch_descriptor(tma_atoms_b[2]) + cpasync.prefetch_descriptor(tma_atoms_sfb[2]) + if cutlass.const_expr(self.num_b_tensors >= 4): + cpasync.prefetch_descriptor(tma_atoms_b[3]) + cpasync.prefetch_descriptor(tma_atoms_sfb[3]) cpasync.prefetch_descriptor(tma_atom_c) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 @@ -1270,22 +1360,52 @@ def kernel( gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), (None, None, None) ) - # (bN, bK, loopN, loopK, loopL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + # (bN, bK, loopN, loopK, loopL) - Use const_expr conditions for tuple indexing + gB_nkl_0 = cute.local_tile( + mB_nkl_tuple[0], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, 
None) ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gB_nkl_1 = cute.local_tile( + mB_nkl_tuple[1], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gB_nkl_2 = cute.local_tile( + mB_nkl_tuple[2], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gB_nkl_3 = cute.local_tile( + mB_nkl_tuple[3], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(self.cta_tile_shape_mnk_sfa, (None, 0, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, + # (bN, bK, RestN, RestK, RestL) - Use const_expr conditions for tuple indexing + gSFB_nkl_0 = cute.local_tile( + mSFB_nkl_tuple[0], cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gSFB_nkl_1 = cute.local_tile( + mSFB_nkl_tuple[1], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gSFB_nkl_2 = cute.local_tile( + mSFB_nkl_tuple[2], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gSFB_nkl_3 = cute.local_tile( + mSFB_nkl_tuple[3], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) gToken_ml = cute.local_tile( token_id_mapping_tensor, cute.slice_(self.cta_tile_shape_mnk, (None, 0, 0)), (None,) @@ -1302,43 +1422,106 @@ def kernel( # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) - # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - const_expr conditions + tCgB_0 = 
thr_mma.partition_B(gB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgB_1 = thr_mma.partition_B(gB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgB_2 = thr_mma.partition_B(gB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgB_3 = thr_mma.partition_B(gB_nkl_3) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - const_expr conditions + tCgSFB_0 = thr_mma_sfb.partition_B(gSFB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgSFB_1 = thr_mma_sfb.partition_B(gSFB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgSFB_2 = thr_mma_sfb.partition_B(gSFB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgSFB_3 = thr_mma_sfb.partition_B(gSFB_nkl_3) # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) tCgC = thr_mma.partition_C(gC_mnl) # # Partition global/shared tensor for TMA load B # - # TMA load B partition_S/D b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, - block_in_cluster_coord_vmnk[1], - b_cta_layout, - cute.group_modes(sB, 0, 3), - cute.group_modes(tCgB, 0, 3), - ) - - # TMA load SFB partition_S/D sfb_cta_layout = cute.make_layout( cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape ) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( - tma_atom_sfb, + sB_grouped = cute.group_modes(sB, 0, 3) + sSFB_grouped = cute.group_modes(sSFB, 0, 3) + + # TMA partition for B tensor 0 + tBsB_0, tBgB_0 = cpasync.tma_partition( + tma_atoms_b[0], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_0, 0, 3), + ) + tBsSFB_0, tBgSFB_0 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[0], block_in_cluster_coord_sfb_vmnk[1], sfb_cta_layout, - cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), - ) - tBsSFB = 
cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) + sSFB_grouped, + cute.group_modes(tCgSFB_0, 0, 3), + ) + tBsSFB_0 = cute.filter_zeros(tBsSFB_0) + tBgSFB_0 = cute.filter_zeros(tBgSFB_0) + + # TMA partition for B tensor 1 (tBsB shared memory partition is same for all, use _ to ignore) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgB_1 = cpasync.tma_partition( + tma_atoms_b[1], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_1, 0, 3), + ) + _, tBgSFB_1 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[1], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_1, 0, 3), + ) + tBgSFB_1 = cute.filter_zeros(tBgSFB_1) + + # TMA partition for B tensor 2 + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgB_2 = cpasync.tma_partition( + tma_atoms_b[2], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_2, 0, 3), + ) + _, tBgSFB_2 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[2], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_2, 0, 3), + ) + tBgSFB_2 = cute.filter_zeros(tBgSFB_2) + + # TMA partition for B tensor 3 + if cutlass.const_expr(self.num_b_tensors >= 4): + _, tBgB_3 = cpasync.tma_partition( + tma_atoms_b[3], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_3, 0, 3), + ) + _, tBgSFB_3 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[3], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_3, 0, 3), + ) + tBgSFB_3 = cute.filter_zeros(tBgSFB_3) # # Partition shared/tensor memory tensor for TiledMMA_A/B/C @@ -1849,20 +2032,13 @@ def kernel( tile_info[1], tile_info[2], ) - # - # Slice to per mma tile index - # - # ((atom_v, rest_v), loopK) - tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + expert_idx = mma_tile_coord_mnl[2] # Apply SFB slicing hack 
when cta_tile_shape_n=64 slice_n = mma_tile_coord_mnl[1] if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): slice_n = mma_tile_coord_mnl[1] // 2 - # ((atom_v, rest_v), RestK) - tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt b_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) @@ -1872,35 +2048,247 @@ def kernel( # Tma load loop # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): - # Conditionally wait for B buffer empty b_pipeline.producer_acquire(b_producer_state, peek_ab_empty_status) - - tBgB_k = tBgB_slice[(None, b_producer_state.count)] - tBgSFB_k = tBgSFB_slice[(None, b_producer_state.count)] - tBsB_pipe = tBsB[(None, b_producer_state.index)] - tBsSFB_pipe = tBsSFB[(None, b_producer_state.index)] - + tBsB_pipe = tBsB_0[(None, b_producer_state.index)] + tBsSFB_pipe = tBsSFB_0[(None, b_producer_state.index)] tma_bar = b_pipeline.producer_get_barrier(b_producer_state) - # TMA load B - cute.copy( - tma_atom_b, - tBgB_k, - tBsB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=b_full_mcast_mask, - ) - - # TMA load SFB - cute.copy( - tma_atom_sfb, - tBgSFB_k, - tBsSFB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=sfb_full_mcast_mask, - ) + # Select correct B tensor based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 1): + # Single B tensor - original logic + tBgB_slice = tBgB_0[(None, mma_tile_coord_mnl[1], None, expert_idx)] + tBgSFB_slice = tBgSFB_0[(None, slice_n, None, expert_idx)] + cute.copy( + tma_atoms_b[0], + tBgB_slice[(None, b_producer_state.count)], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_slice[(None, b_producer_state.count)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # Multi-B tensor - select based on expert_idx + # Use nested const_expr ifs to avoid index out of range at compile time + if 
cutlass.const_expr(self.num_b_tensors == 2): + # Exactly 2 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, b_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, b_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif cutlass.const_expr(self.num_b_tensors == 3): + # Exactly 3 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, b_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, b_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + 
else: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, b_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # 4 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, b_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, b_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[3]: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, b_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_3 = expert_idx - self.b_tensor_l_offsets[3] + cute.copy( + tma_atoms_b[3], + tBgB_3[ + ( + None, + 
mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_3, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[3], + tBgSFB_3[(None, slice_n, b_producer_state.count, local_l_3)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 b_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) if b_producer_state.count < k_tile_cnt: @@ -2343,9 +2731,38 @@ def kernel( # # Get alpha for current group # - expert_idx = mma_tile_coord_mnl[2] - alpha_val = alpha[expert_idx] + + # Select alpha from correct tensor based on expert_idx + # Initialize alpha_val first to avoid DSL "None prior to if" error + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + if cutlass.const_expr(self.num_b_tensors == 1): + pass # Already initialized above + elif cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx >= self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif cutlass.const_expr(self.num_b_tensors == 3): + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif expert_idx >= self.b_tensor_l_offsets[2]: + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + else: + # 4 B tensors + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif ( + expert_idx >= self.b_tensor_l_offsets[2] + and expert_idx < self.b_tensor_l_offsets[3] + ): + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + elif expert_idx >= self.b_tensor_l_offsets[3]: + alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]] # # Slice to per mma tile index @@ -3313,12 +3730,12 @@ def 
can_implement( def wrapper( self, a_ptr: cute.Pointer, - b_ptr: cute.Pointer, + b_ptr_tuple: Tuple[cute.Pointer, ...], a_sf_ptr: cute.Pointer, - b_sf_ptr: cute.Pointer, + b_sf_ptr_tuple: Tuple[cute.Pointer, ...], c_ptr: cute.Pointer, c_sf_ptr: cute.Pointer, - alpha_ptr: cute.Pointer, + alpha_ptr_tuple: Tuple[cute.Pointer, ...], tile_idx_to_group_idx_ptr: cute.Pointer, tile_idx_to_mn_limit_ptr: cute.Pointer, token_id_mapping_ptr: cute.Pointer, @@ -3328,40 +3745,102 @@ def wrapper( m: cutlass.Int64, n: cutlass.Int64, k: cutlass.Int64, - l: cutlass.Int64, # noqa: E741 tile_size: cutlass.Constexpr, scaling_vector_size: cutlass.Constexpr, max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): + """Unified wrapper supporting both single-B and multi-B tensors. + + B tensors are always passed as tuples (length 1 for single-B). + L sizes are configured via b_tensor_l_sizes in __init__. + """ scale_k = k // scaling_vector_size interm_size = n // 2 num_tiles = m // tile_size + total_l = self.b_tensor_l_offsets[self.num_b_tensors] + a = cute.make_tensor( a_ptr, layout=cute.make_ordered_layout((orig_m, k, 1), order=(1, 0, 2)) ) - b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2))) a_sf = cute.make_tensor( a_sf_ptr, layout=cute.make_ordered_layout((orig_m, scale_k, 1), order=(1, 0, 2)) ) - b_sf = cute.make_tensor( - b_sf_ptr, - layout=cute.make_ordered_layout( - (32, 4, n // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5) - ), - ) c = cute.make_tensor( c_ptr, layout=cute.make_ordered_layout((m, interm_size, 1), order=(1, 0, 2)) ) c_sf = cute.make_tensor( c_sf_ptr, layout=cute.make_ordered_layout( - (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l), + (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l), order=(2, 1, 4, 0, 3, 5), ), ) - alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,))) + + # Create B and alpha tensors using 
const_expr conditions + l_0 = self.b_tensor_l_sizes[0] + alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,))) + b_0 = cute.make_tensor( + b_ptr_tuple[0], layout=cute.make_ordered_layout((n, k, l_0), order=(1, 0, 2)) + ) + b_sf_0 = cute.make_tensor( + b_sf_ptr_tuple[0], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_0), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple = [b_0] + b_sf_tuple = [b_sf_0] + alpha_tuple = [alpha_0] + + if cutlass.const_expr(self.num_b_tensors >= 2): + l_1 = self.b_tensor_l_sizes[1] + alpha_1 = cute.make_tensor(alpha_ptr_tuple[1], layout=cute.make_layout((l_1,))) + b_1 = cute.make_tensor( + b_ptr_tuple[1], layout=cute.make_ordered_layout((n, k, l_1), order=(1, 0, 2)) + ) + b_sf_1 = cute.make_tensor( + b_sf_ptr_tuple[1], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_1), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_1) + b_sf_tuple.append(b_sf_1) + alpha_tuple.append(alpha_1) + + if cutlass.const_expr(self.num_b_tensors >= 3): + l_2 = self.b_tensor_l_sizes[2] + alpha_2 = cute.make_tensor(alpha_ptr_tuple[2], layout=cute.make_layout((l_2,))) + b_2 = cute.make_tensor( + b_ptr_tuple[2], layout=cute.make_ordered_layout((n, k, l_2), order=(1, 0, 2)) + ) + b_sf_2 = cute.make_tensor( + b_sf_ptr_tuple[2], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_2), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_2) + b_sf_tuple.append(b_sf_2) + alpha_tuple.append(alpha_2) + + if cutlass.const_expr(self.num_b_tensors >= 4): + l_3 = self.b_tensor_l_sizes[3] + alpha_3 = cute.make_tensor(alpha_ptr_tuple[3], layout=cute.make_layout((l_3,))) + b_3 = cute.make_tensor( + b_ptr_tuple[3], layout=cute.make_ordered_layout((n, k, l_3), order=(1, 0, 2)) + ) + b_sf_3 = cute.make_tensor( + b_sf_ptr_tuple[3], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_3), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_3) + 
b_sf_tuple.append(b_sf_3) + alpha_tuple.append(alpha_3) tile_idx_to_group_idx = cute.make_tensor( tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,)) @@ -3377,17 +3856,17 @@ def wrapper( return self( a, - b, + tuple(b_tuple), c, a_sf, - b_sf, + tuple(b_sf_tuple), c_sf, global_sf, tile_idx_to_group_idx, tile_idx_to_mn_limit, token_id_mapping, num_non_exiting_tiles, - alpha, + tuple(alpha_tuple), max_active_clusters=max_active_clusters, stream=stream, epilogue_op=epilogue_op, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index 50d36beff868..babf3dbcb261 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Tuple, Type, Union +from typing import Optional, Tuple, Type, Union import cuda.bindings.driver as cuda import cutlass @@ -339,6 +339,9 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: ... ) """ + # Maximum number of B tensors supported + MAX_B_TENSORS = 4 + def __init__( self, sf_vec_size: int, @@ -346,6 +349,7 @@ def __init__( cluster_shape_mn: Tuple[int, int], use_blkred: bool = False, raster_along_m: bool = False, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel. @@ -363,6 +367,10 @@ def __init__( :type cluster_shape_mn: Tuple[int, int] :param raster_along_m: Boolean, True to use raster along M. :type raster_along_m: bool + :param b_tensor_l_sizes: Optional tuple of L sizes for each B tensor. 
+ E.g., (8, 8, 16) means 3 B tensors with L=8, 8, 16. Sum equals total L. + If None, single B tensor mode (backward compatible). + :type b_tensor_l_sizes: Optional[Tuple[int, ...]] """ self.sf_vec_size = sf_vec_size @@ -424,6 +432,26 @@ def __init__( # TMEM offset for final accumulator self.tmem_final_offset = 384 + # Multi-B tensor configuration + if b_tensor_l_sizes is None: + self.num_b_tensors = 1 + self.b_tensor_l_sizes = None + # Offsets padded for safe indexing in kernel + self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS + else: + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) + def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -602,14 +630,14 @@ def _setup_attributes(self): def __call__( self, a: cute.Tensor, - b: cute.Tensor, + b: Union[cute.Tensor, Tuple[cute.Tensor, ...]], out: cute.Tensor, sfa: cute.Tensor, - sfb: cute.Tensor, + sfb: Union[cute.Tensor, Tuple[cute.Tensor, ...]], tile_idx_to_expert_idx: cute.Tensor, num_non_exiting_tiles: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, - alpha: cute.Tensor, + alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, permuted_idx_to_expanded_idx: cute.Tensor, @@ -639,7 +667,7 @@ def __call__( :param num_non_exiting_tiles: Number of valid tiles (valid_m/cta_tile_m), shape (1,) :type num_non_exiting_tiles: cute.Tensor :param alpha: Alpha tensor for each group - :type alpha: cute.Tensor + :type alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]] :param max_active_clusters: Maximum number of 
active clusters :type max_active_clusters: cutlass.Constexpr :param stream: CUDA stream for asynchronous execution @@ -654,12 +682,16 @@ def __call__( """ # Setup static attributes before smem/grid/tma computation self.a_dtype: Type[cutlass.Numeric] = a.element_type - self.b_dtype: Type[cutlass.Numeric] = b.element_type + # Handle tuple of B tensors + b_tuple = b if isinstance(b, tuple) else (b,) + sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,) + alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,) + self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type self.out_dtype: Type[cutlass.Numeric] = out.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type self.final_scale_dtype: Type[cutlass.Numeric] = token_final_scales.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() - self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tuple[0]).mma_major_mode() self.gemm_output_layout = utils.LayoutEnum.ROW_MAJOR self.topK = token_final_scales.shape[1] @@ -675,8 +707,27 @@ def __call__( sfa = cute.make_tensor(sfa.iterator, sfa_layout) # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) - sfb = cute.make_tensor(sfb.iterator, sfb_layout) + sfb_layout_0 = blockscaled_utils.tile_atom_to_shape_SF(b_tuple[0].shape, self.sf_vec_size) + sfb_tensor_0 = cute.make_tensor(sfb_tuple[0].iterator, sfb_layout_0) + sfb_tensors = [sfb_tensor_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + sfb_layout_1 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[1].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[1].iterator, sfb_layout_1)) + if cutlass.const_expr(self.num_b_tensors >= 3): + sfb_layout_2 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[2].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[2].iterator, sfb_layout_2)) + 
if cutlass.const_expr(self.num_b_tensors >= 4): + sfb_layout_3 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[3].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[3].iterator, sfb_layout_3)) + sfb_tuple = tuple(sfb_tensors) + # Backward compat alias + sfb = sfb_tuple[0] tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, @@ -714,14 +765,6 @@ def __call__( # Setup TMA load for B b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, - b, - b_smem_layout, - self.mma_tiler, - tiled_mma, - self.cluster_layout_vmnk.shape, - ) # Setup TMA load for SFA sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) @@ -739,34 +782,74 @@ def __call__( # Setup TMA load for SFB sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, - sfb, - sfb_smem_layout, - self.mma_tiler_sfb, - tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, - internal_type=cutlass.Int16, - ) - - if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): - x = tma_tensor_sfb.stride[0][1] - y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) - new_shape = ( - (tma_tensor_sfb.shape[0][0], ((2, 2), y)), - tma_tensor_sfb.shape[1], - tma_tensor_sfb.shape[2], + # Helper to create TMA for one B tensor + def _make_tma_b(b_tensor, sfb_tensor): + atom_b, tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, ) - # Use right multiplication for ScaledBasis (3 * x instead of x * 3) - x_times_3 = 3 * x - new_stride = ( - (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), - 
tma_tensor_sfb.stride[1], - tma_tensor_sfb.stride[2], + atom_sfb, tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, ) - tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) - tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tensor_sfb.stride[0][1] + y = cute.ceil_div(tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tensor_sfb.shape[0][0], ((2, 2), y)), + tensor_sfb.shape[1], + tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tensor_sfb.stride[1], + tensor_sfb.stride[2], + ) + tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tensor_sfb = cute.make_tensor(tensor_sfb.iterator, tensor_sfb_new_layout) + return atom_b, tensor_b, atom_sfb, tensor_sfb + + # Create TMA for all B tensors (use const_expr, not loop) + atom_b_0, tensor_b_0, atom_sfb_0, tensor_sfb_0 = _make_tma_b(b_tuple[0], sfb_tuple[0]) + tma_atoms_b = [atom_b_0] + tma_tensors_b = [tensor_b_0] + tma_atoms_sfb = [atom_sfb_0] + tma_tensors_sfb = [tensor_sfb_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + atom_b_1, tensor_b_1, atom_sfb_1, tensor_sfb_1 = _make_tma_b(b_tuple[1], sfb_tuple[1]) + tma_atoms_b.append(atom_b_1) + tma_tensors_b.append(tensor_b_1) + tma_atoms_sfb.append(atom_sfb_1) + tma_tensors_sfb.append(tensor_sfb_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + atom_b_2, tensor_b_2, atom_sfb_2, tensor_sfb_2 = _make_tma_b(b_tuple[2], sfb_tuple[2]) + tma_atoms_b.append(atom_b_2) + tma_tensors_b.append(tensor_b_2) + tma_atoms_sfb.append(atom_sfb_2) + tma_tensors_sfb.append(tensor_sfb_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + atom_b_3, tensor_b_3, 
atom_sfb_3, tensor_sfb_3 = _make_tma_b(b_tuple[3], sfb_tuple[3]) + tma_atoms_b.append(atom_b_3) + tma_tensors_b.append(tensor_b_3) + tma_atoms_sfb.append(atom_sfb_3) + tma_tensors_sfb.append(tensor_sfb_3) + tma_atoms_b = tuple(tma_atoms_b) + tma_tensors_b = tuple(tma_tensors_b) + tma_atoms_sfb = tuple(tma_atoms_sfb) + tma_tensors_sfb = tuple(tma_tensors_sfb) a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) @@ -777,7 +860,7 @@ def __call__( ) * atom_thr_size self.tile_sched_params, grid = self._compute_grid( - (a.shape[0], b.shape[0], a.shape[2]), + (a.shape[0], b_tuple[0].shape[0], a.shape[2]), self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters, @@ -862,17 +945,17 @@ class SharedStorage: tiled_mma_sfb, tma_atom_a, tma_tensor_a, - tma_atom_b, - tma_tensor_b, + tma_atoms_b, + tma_tensors_b, tma_atom_sfa, tma_tensor_sfa, - tma_atom_sfb, - tma_tensor_sfb, + tma_atoms_sfb, + tma_tensors_sfb, out, tile_idx_to_expert_idx, num_non_exiting_tiles, tile_idx_to_mn_limit, - alpha, + alpha_tuple, permuted_idx_to_expanded_idx, token_final_scales, self.cluster_layout_vmnk, @@ -947,17 +1030,17 @@ def kernel( tiled_mma_sfb: cute.TiledMma, tma_atom_a: cute.CopyAtom, mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, + tma_atoms_b: Tuple[cute.CopyAtom, ...], + mB_nkl_tuple: Tuple[cute.Tensor, ...], tma_atom_sfa: cute.CopyAtom, mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, + tma_atoms_sfb: Tuple[cute.CopyAtom, ...], + mSFB_nkl_tuple: Tuple[cute.Tensor, ...], out: cute.Tensor, tile_idx_to_expert_idx: cute.Tensor, num_non_exiting_tiles: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, - alpha: cute.Tensor, + alpha_tuple: Tuple[cute.Tensor, ...], permuted_idx_to_expanded_idx: cute.Tensor, token_final_scales: cute.Tensor, cluster_layout_vmnk: cute.Layout, @@ -984,9 +1067,18 @@ def kernel( # if warp_idx == self.tma_warp_id: 
cpasync.prefetch_descriptor(tma_atom_a) - cpasync.prefetch_descriptor(tma_atom_b) cpasync.prefetch_descriptor(tma_atom_sfa) - cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atoms_b[0]) + cpasync.prefetch_descriptor(tma_atoms_sfb[0]) + if cutlass.const_expr(self.num_b_tensors >= 2): + cpasync.prefetch_descriptor(tma_atoms_b[1]) + cpasync.prefetch_descriptor(tma_atoms_sfb[1]) + if cutlass.const_expr(self.num_b_tensors >= 3): + cpasync.prefetch_descriptor(tma_atoms_b[2]) + cpasync.prefetch_descriptor(tma_atoms_sfb[2]) + if cutlass.const_expr(self.num_b_tensors >= 4): + cpasync.prefetch_descriptor(tma_atoms_b[3]) + cpasync.prefetch_descriptor(tma_atoms_sfb[3]) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 @@ -1119,9 +1211,29 @@ def kernel( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) # (bN, bK, loopN, loopK, loopL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + gB_nkl_0 = cute.local_tile( + mB_nkl_tuple[0], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gB_nkl_1 = cute.local_tile( + mB_nkl_tuple[1], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gB_nkl_2 = cute.local_tile( + mB_nkl_tuple[2], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gB_nkl_3 = cute.local_tile( + mB_nkl_tuple[3], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( @@ -1129,11 +1241,29 @@ def kernel( ) # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, + gSFB_nkl_0 = cute.local_tile( + mSFB_nkl_tuple[0], cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gSFB_nkl_1 = 
cute.local_tile( + mSFB_nkl_tuple[1], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gSFB_nkl_2 = cute.local_tile( + mSFB_nkl_tuple[2], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gSFB_nkl_3 = cute.local_tile( + mSFB_nkl_tuple[3], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) k_tile_cnt = cutlass.Int32(cute.size(gA_mkl, mode=[3])) @@ -1145,11 +1275,23 @@ def kernel( # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) tCgA = thr_mma.partition_A(gA_mkl) # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - tCgB = thr_mma.partition_B(gB_nkl) + tCgB_0 = thr_mma.partition_B(gB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgB_1 = thr_mma.partition_B(gB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgB_2 = thr_mma.partition_B(gB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgB_3 = thr_mma.partition_B(gB_nkl_3) # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgSFA = thr_mma.partition_A(gSFA_mkl) # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + tCgSFB_0 = thr_mma_sfb.partition_B(gSFB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgSFB_1 = thr_mma_sfb.partition_B(gSFB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgSFB_2 = thr_mma_sfb.partition_B(gSFB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgSFB_3 = thr_mma_sfb.partition_B(gSFB_nkl_3) # # Partition global/shared tensor for TMA load A/B @@ -1169,13 +1311,37 @@ def kernel( b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), loopM, loopK, loopL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, + tBsB_0, tBgB_0 = cpasync.tma_partition( + tma_atoms_b[0], block_in_cluster_coord_vmnk[1], b_cta_layout, cute.group_modes(sB, 0, 3), - 
cute.group_modes(tCgB, 0, 3), - ) + cute.group_modes(tCgB_0, 0, 3), + ) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgB_1 = cpasync.tma_partition( + tma_atoms_b[1], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB_1, 0, 3), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgB_2 = cpasync.tma_partition( + tma_atoms_b[2], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB_2, 0, 3), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + _, tBgB_3 = cpasync.tma_partition( + tma_atoms_b[3], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB_3, 0, 3), + ) # TMA load SFA partition_S/D sfa_cta_layout = a_cta_layout @@ -1199,15 +1365,42 @@ def kernel( ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( - tma_atom_sfb, + tBsSFB_0, tBgSFB_0 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[0], block_in_cluster_coord_sfb_vmnk[1], sfb_cta_layout, cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), - ) - tBsSFB = cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) + cute.group_modes(tCgSFB_0, 0, 3), + ) + tBsSFB_0 = cute.filter_zeros(tBsSFB_0) + tBgSFB_0 = cute.filter_zeros(tBgSFB_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgSFB_1 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[1], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB_1, 0, 3), + ) + tBgSFB_1 = cute.filter_zeros(tBgSFB_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgSFB_2 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[2], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB_2, 0, 3), + ) + tBgSFB_2 = cute.filter_zeros(tBgSFB_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + 
_, tBgSFB_3 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[3], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB_3, 0, 3), + ) + tBgSFB_3 = cute.filter_zeros(tBgSFB_3) # # Partition shared/tensor memory tensor for TiledMMA_A/B/C @@ -1395,19 +1588,15 @@ def kernel( # # ((atom_v, rest_v), loopK) tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, 0)] - # ((atom_v, rest_v), loopK) - tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)] + expert_idx = mma_tile_coord_mnl[2] slice_n = mma_tile_coord_mnl[1] if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): slice_n = mma_tile_coord_mnl[1] // 2 - # ((atom_v, rest_v), RestK) - tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) @@ -1418,13 +1607,11 @@ def kernel( # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): tAgA_k = tAgA_slice[(None, ab_producer_state.count)] - tBgB_k = tBgB_slice[(None, ab_producer_state.count)] tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)] - tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)] tAsA_pipe = tAsA[(None, ab_producer_state.index)] - tBsB_pipe = tBsB[(None, ab_producer_state.index)] + tBsB_pipe = tBsB_0[(None, ab_producer_state.index)] tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)] - tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)] + tBsSFB_pipe = tBsSFB_0[(None, ab_producer_state.index)] tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) @@ -1439,14 +1626,6 @@ def kernel( tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask, ) - cute.copy( - tma_atom_b, - tBgB_k, - tBsB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=b_full_mcast_mask, - ) - cute.copy( tma_atom_sfa, tAgSFA_k, @@ -1454,13 +1633,235 @@ def 
kernel( tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask, ) - cute.copy( - tma_atom_sfb, - tBgSFB_k, - tBsSFB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=sfb_full_mcast_mask, - ) + # Select correct B tensor based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 1): + tBgB_slice = tBgB_0[(None, mma_tile_coord_mnl[1], None, expert_idx)] + tBgSFB_slice = tBgSFB_0[(None, slice_n, None, expert_idx)] + cute.copy( + tma_atoms_b[0], + tBgB_slice[(None, ab_producer_state.count)], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + if cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, ab_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, ab_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif cutlass.const_expr(self.num_b_tensors == 3): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, 
+ mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, ab_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, ab_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, ab_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, ab_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, ab_producer_state.count, local_l_1)], 
+ tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[3]: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, ab_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_3 = expert_idx - self.b_tensor_l_offsets[3] + cute.copy( + tma_atoms_b[3], + tBgB_3[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_3, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[3], + tBgSFB_3[(None, slice_n, ab_producer_state.count, local_l_3)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() @@ -1798,7 +2199,33 @@ def kernel( # expert_idx = mma_tile_coord_mnl[2] - alpha_val = alpha[expert_idx] + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + if cutlass.const_expr(self.num_b_tensors == 1): + pass + elif cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx >= self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif cutlass.const_expr(self.num_b_tensors == 3): + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif expert_idx >= self.b_tensor_l_offsets[2]: + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + else: + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - 
self.b_tensor_l_offsets[1]] + elif ( + expert_idx >= self.b_tensor_l_offsets[2] + and expert_idx < self.b_tensor_l_offsets[3] + ): + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + elif expert_idx >= self.b_tensor_l_offsets[3]: + alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]] tile_m_start = tile_info[0] * self.cta_tile_shape_mnk[0] permuted_row = tile_m_start + epi_tidx @@ -2496,11 +2923,11 @@ def can_implement( def wrapper( self, a_ptr: cute.Pointer, - b_ptr: cute.Pointer, + b_ptr_tuple: Tuple[cute.Pointer, ...], a_sf_ptr: cute.Pointer, - b_sf_ptr: cute.Pointer, + b_sf_ptr_tuple: Tuple[cute.Pointer, ...], c_ptr: cute.Pointer, - alpha_ptr: cute.Pointer, + alpha_ptr_tuple: Tuple[cute.Pointer, ...], tile_idx_to_group_idx_ptr: cute.Pointer, tile_idx_to_mn_limit_ptr: cute.Pointer, permuted_idx_to_expanded_idx_ptr: cute.Pointer, @@ -2509,7 +2936,6 @@ def wrapper( m: cutlass.Int64, n: cutlass.Int64, k: cutlass.Int64, - l: cutlass.Int64, # noqa: E741 num_tokens: cutlass.Int64, top_k: cutlass.Int64, tile_size: cutlass.Constexpr, @@ -2518,26 +2944,87 @@ def wrapper( stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): + """Unified wrapper supporting both single-B and multi-B tensors. + + B tensors are always passed as tuples (length 1 for single-B). + L sizes are configured via b_tensor_l_sizes in __init__. 
+ """ scale_k = k // scaling_vector_size num_tiles = m // tile_size + a = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout((m, k, 1), order=(1, 0, 2))) - b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2))) a_sf = cute.make_tensor( a_sf_ptr, layout=cute.make_ordered_layout( (32, 4, m // 128, 4, scale_k // 4, 1), order=(2, 1, 4, 0, 3, 5) ), ) - b_sf = cute.make_tensor( - b_sf_ptr, - layout=cute.make_ordered_layout( - (32, 4, n // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5) - ), - ) c = cute.make_tensor( c_ptr, layout=cute.make_ordered_layout((num_tokens, n, 1), order=(1, 0, 2)) ) - alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,))) + + l_0 = self.b_tensor_l_sizes[0] + alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,))) + b_0 = cute.make_tensor( + b_ptr_tuple[0], layout=cute.make_ordered_layout((n, k, l_0), order=(1, 0, 2)) + ) + b_sf_0 = cute.make_tensor( + b_sf_ptr_tuple[0], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_0), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple = [b_0] + b_sf_tuple = [b_sf_0] + alpha_tuple = [alpha_0] + + if cutlass.const_expr(self.num_b_tensors >= 2): + l_1 = self.b_tensor_l_sizes[1] + alpha_1 = cute.make_tensor(alpha_ptr_tuple[1], layout=cute.make_layout((l_1,))) + b_1 = cute.make_tensor( + b_ptr_tuple[1], layout=cute.make_ordered_layout((n, k, l_1), order=(1, 0, 2)) + ) + b_sf_1 = cute.make_tensor( + b_sf_ptr_tuple[1], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_1), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_1) + b_sf_tuple.append(b_sf_1) + alpha_tuple.append(alpha_1) + + if cutlass.const_expr(self.num_b_tensors >= 3): + l_2 = self.b_tensor_l_sizes[2] + alpha_2 = cute.make_tensor(alpha_ptr_tuple[2], layout=cute.make_layout((l_2,))) + b_2 = cute.make_tensor( + b_ptr_tuple[2], layout=cute.make_ordered_layout((n, k, l_2), order=(1, 0, 2)) + ) + b_sf_2 = cute.make_tensor( + 
b_sf_ptr_tuple[2], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_2), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_2) + b_sf_tuple.append(b_sf_2) + alpha_tuple.append(alpha_2) + + if cutlass.const_expr(self.num_b_tensors >= 4): + l_3 = self.b_tensor_l_sizes[3] + alpha_3 = cute.make_tensor(alpha_ptr_tuple[3], layout=cute.make_layout((l_3,))) + b_3 = cute.make_tensor( + b_ptr_tuple[3], layout=cute.make_ordered_layout((n, k, l_3), order=(1, 0, 2)) + ) + b_sf_3 = cute.make_tensor( + b_sf_ptr_tuple[3], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_3), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_3) + b_sf_tuple.append(b_sf_3) + alpha_tuple.append(alpha_3) tile_idx_to_group_idx = cute.make_tensor( tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,)) @@ -2558,14 +3045,14 @@ def wrapper( return self( a, - b, + tuple(b_tuple), c, a_sf, - b_sf, + tuple(b_sf_tuple), tile_idx_to_group_idx, num_non_exiting_tiles, tile_idx_to_mn_limit, - alpha, + tuple(alpha_tuple), max_active_clusters=max_active_clusters, stream=stream, permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index 607b5d870e82..2737ce123a63 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -35,6 +35,7 @@ from tensorrt_llm._torch.expert_statistic import ExpertStatistic from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe.interface import MoE +from tensorrt_llm._torch.pyexecutor.dwdp import get_global_dwdp_manager from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod from tensorrt_llm._torch.utils import AuxStreamType, EventType, Fp4QuantizedTensor from tensorrt_llm.logger import logger @@ -253,6 +254,19 @@ def __init__( # Validate 
configuration self.validate_config() + # ========== Optional DWDP integration ========== + self.dwdp_manager = get_global_dwdp_manager() + self.dwdp_handle_collector = None + self.dwdp_rank = None + self.enable_dwdp = False + if self.dwdp_manager is not None and self._should_enable_dwdp(): + self.enable_dwdp = True + self.dwdp_handle_collector = self.dwdp_manager.add_layer( + layer_idx=self.layer_idx, + ) + self.dwdp_rank = self.dwdp_manager.dwdp_rank + self.backend.dwdp_handle_collector = self.dwdp_handle_collector + # Mark as _weights_removed to skip ConfigurableMoE's post_load_weights in model_loader # The backend's post_load_weights will be called directly by model_loader # This avoids duplicate post_load_weights calls (once for ConfigurableMoE, once for backend) @@ -279,6 +293,20 @@ def validate_config(self): "apply_router_weight_on_input only supports top-1 routing" ) + def _should_enable_dwdp(self) -> bool: + # DWDP is currently supported only for CuteDslFusedMoE with NVFP4 quantization. 
+ if not isinstance(self.backend, CuteDslFusedMoE): + return False + + quant_config = getattr(self.backend, "quant_config", None) + if quant_config is None: + quant_config = getattr(self.model_config, "quant_config", None) + if quant_config is None: + return False + + quant_mode = getattr(quant_config, "layer_quant_mode", None) + return bool(quant_mode is not None and hasattr(quant_mode, "has_nvfp4") and quant_mode.has_nvfp4()) + def _create_comm_strategy(self, model_config: ModelConfig) -> Optional[Communication]: """ Create communication strategy based on configuration @@ -753,6 +781,8 @@ def _forward_chunk_impl( router_logits, do_finalize, all_rank_num_tokens, output_dtype, x, workspace ), ) + if self.enable_dwdp: + self.dwdp_manager.record_compute_and_prefetch_next(self.layer_idx) # ========== Step 8: EPLB - Start CPU stage ========== self._load_balancer_start_set_cpu_stage(is_last_call) @@ -1131,6 +1161,11 @@ def _get_backend_kwargs( all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype ) + if self.enable_dwdp: + kwargs["dwdp_weight_view"] = self.dwdp_manager.build_weight_view( + self.layer_idx, self.backend + ) + # DeepGemm-specific parameters elif self.backend.__class__ == DeepGemmFusedMoE: if workspace is not None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 1273262f5f42..a1959850593f 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -14,6 +14,7 @@ # limitations under the License. import math +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -41,6 +42,26 @@ from .routing import BaseMoeRoutingMethod +@dataclass +class NvFp4WeightView: + """Bundles all NVFP4 weight tensors for MoE computation. 
+ + Provides a unified interface for both non-DWDP and DWDP paths: + - Non-DWDP: each list contains 1 element (local weight). + - DWDP: each list contains N elements (one per DWDP rank), + where the local rank's entry holds the actual model weight + and other ranks' entries hold prefetched buffer tensors. + """ + w3_w1_weight: List[torch.Tensor] + fc1_weight_scale: List[torch.Tensor] + fc1_global_scale: List[torch.Tensor] + w2_weight: List[torch.Tensor] + fc2_weight_scale: List[torch.Tensor] + fc2_global_scale: List[torch.Tensor] + expert_size_per_partition: int + slot_start: int + + @torch.compile(options={"max-autotune": True}) def swiglu_fused_moe(x): x, gate = x.chunk(2, dim=-1) @@ -425,6 +446,7 @@ def __init__( init_load_balancer=init_load_balancer, without_comm=without_comm, ) + if self.aux_stream_dict is None: self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {} if AuxStreamType.MoeOutputMemset not in self.aux_stream_dict: @@ -436,6 +458,20 @@ def __init__( if key not in self.event_dict: self.event_dict[key] = torch.cuda.Event() + + def _build_local_weight_view(self) -> NvFp4WeightView: + """Build weight view for non-DWDP path (single-element lists).""" + return NvFp4WeightView( + w3_w1_weight=[self.w3_w1_weight], + fc1_weight_scale=[self.quant_scales.fc1_weight_block], + fc1_global_scale=[self.quant_scales.fc1_global], + w2_weight=[self.w2_weight], + fc2_weight_scale=[self.quant_scales.fc2_weight_block], + fc2_global_scale=[self.quant_scales.fc2_global], + expert_size_per_partition=self.expert_size_per_partition, + slot_start=self.slot_start, + ) + def select_alltoall_method_type(self) -> AlltoallMethodType: return AlltoallMethodType.NotEnabled @@ -499,8 +535,21 @@ def run_moe_nvfp4( x_sf: Optional[torch.Tensor] = None, moe_output: Optional[torch.Tensor] = None, enable_alltoall: bool = False, + weight_view: Optional[NvFp4WeightView] = None, ) -> torch.Tensor: + """NVFP4 MoE computation with unified interface. 
+ + Handles both non-DWDP and DWDP paths transparently: + - Non-DWDP (single-element weight lists): uses run_moe_nvfp4_impl. + Supports both fused-finalize and non-fused-finalize paths. + - DWDP (multi-element weight lists): uses run_moe_nvfp4_impl_dwdp. + Requires fused-finalize. + + Args: + weight_view: Bundled weight tensors. If None, local weights are used. + """ assert self.has_nvfp4 + assert weight_view is not None output_dtype = torch.bfloat16 if moe_output is None: @@ -513,24 +562,25 @@ def run_moe_nvfp4( self.hidden_size) assert moe_output.dtype == output_dtype - # After DeepEPLowLatency dispatch, token_selected_experts has shape - # [N, 1] instead of [N, top_k], because each row is already assigned - # to exactly one expert. Use the tensor shape as the effective top_k. effective_top_k = token_selected_experts.size(-1) + is_dwdp = len(weight_view.w3_w1_weight) > 1 + forward_impl = self.run_moe_nvfp4_impl_dwdp if is_dwdp else self.run_moe_nvfp4_impl + tuner = AutoTuner.get() runner = CuteDslFusedMoENvfp4Runner( - forward_impl=self.run_moe_nvfp4_impl, + forward_impl=forward_impl, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=weight_view.expert_size_per_partition, + local_expert_offset=weight_view.slot_start, enable_finalize_fusion=self.use_fused_finalize, enable_alltoall=enable_alltoall, ) inputs = [ - x, token_selected_experts, token_final_scales, x_sf, moe_output + x, token_selected_experts, token_final_scales, x_sf, moe_output, + weight_view, ] _, best_tactic = tuner.choose_one( "CuteDslFusedMoE::run_moe_nvfp4", @@ -547,22 +597,23 @@ def run_moe_nvfp4_impl( token_final_scales: Optional[torch.Tensor], x_sf: torch.Tensor, moe_output: torch.Tensor, + weight_view: NvFp4WeightView, enable_alltoall: bool = False, tile_size: int = 128, ) -> torch.Tensor: + """Non-DWDP NVFP4 MoE implementation using single-tensor ops.""" output_dtype = 
torch.bfloat16 - - # Use effective top_k from tensor shape rather than routing config. - # After DeepEPLowLatency dispatch, each row maps to one expert (top_k=1). effective_top_k = token_selected_experts.size(1) + esp = weight_view.expert_size_per_partition + slot_start = weight_view.slot_start tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort( token_selected_experts=token_selected_experts, token_final_scales=token_final_scales, num_experts=self.num_slots, top_k=effective_top_k, - local_expert_offset=self.slot_start, - local_num_experts=self.expert_size_per_partition, + local_expert_offset=slot_start, + local_num_experts=esp, tile_tokens_dim=tile_size, ) @@ -573,10 +624,10 @@ def run_moe_nvfp4_impl( x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( input=x.view(torch.float4_e2m1fn_x2), - weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2), + weight=weight_view.w3_w1_weight[0].view(torch.float4_e2m1fn_x2), input_scale=x_sf.view(torch.uint8), - weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8), - alpha=self.quant_scales.fc1_global, + weight_scale=weight_view.fc1_weight_scale[0].view(torch.uint8), + alpha=weight_view.fc1_global_scale[0], tile_idx_to_group_idx=tile_idx_to_expert_idx, tile_idx_to_mn_limit=tile_idx_to_mn_limit, permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, @@ -584,8 +635,8 @@ def run_moe_nvfp4_impl( global_sf=self.fc2_input_scale, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=esp, + local_expert_offset=slot_start, tile_size=tile_size, ) @@ -609,11 +660,10 @@ def run_moe_nvfp4_impl( torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( input=x.view(torch.float4_e2m1fn_x2), - weight=self.w2_weight.view(torch.float4_e2m1fn_x2), + 
weight=[weight_view.w2_weight[0].view(torch.float4_e2m1fn_x2)], input_scale=x_sf.view(torch.uint8), - weight_scale=self.quant_scales.fc2_weight_block.view( - torch.uint8), - alpha=self.quant_scales.fc2_global, + weight_scale=[weight_view.fc2_weight_scale[0].view(torch.uint8)], + alpha=[weight_view.fc2_global_scale[0]], output=moe_output, tile_idx_to_group_idx=tile_idx_to_expert_idx, tile_idx_to_mn_limit=tile_idx_to_mn_limit, @@ -622,25 +672,24 @@ def run_moe_nvfp4_impl( token_final_scales=token_final_scales, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=esp, + local_expert_offset=slot_start, tile_size=tile_size, output_dtype=output_dtype, ) else: x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell( input=x.view(torch.float4_e2m1fn_x2), - weight=self.w2_weight.view(torch.float4_e2m1fn_x2), + weight=weight_view.w2_weight[0].view(torch.float4_e2m1fn_x2), input_scale=x_sf.view(torch.uint8), - weight_scale=self.quant_scales.fc2_weight_block.view( - torch.uint8), - alpha=self.quant_scales.fc2_global, + weight_scale=weight_view.fc2_weight_scale[0].view(torch.uint8), + alpha=weight_view.fc2_global_scale[0], tile_idx_to_group_idx=tile_idx_to_expert_idx, num_non_exiting_tiles=num_non_exiting_tiles, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=esp, + local_expert_offset=slot_start, tile_size=tile_size, output_dtype=output_dtype, ) @@ -652,6 +701,101 @@ def run_moe_nvfp4_impl( ) return moe_output + def run_moe_nvfp4_impl_dwdp( + self, + x: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + x_sf: torch.Tensor, + moe_output: torch.Tensor, + weight_view: NvFp4WeightView, + enable_alltoall: bool = False, + tile_size: int = 128, + ) -> torch.Tensor: + """DWDP NVFP4 MoE implementation using 
multi-B list ops. + + Requires fused-finalize since the non-fused FC2 op does not support + multiple B weight tensors. + """ + assert self.use_fused_finalize, ( + "DWDP requires fused finalize (cute_dsl_nvfp4_grouped_gemm_blackwell " + "does not support multiple B weight tensors)" + ) + output_dtype = torch.bfloat16 + effective_top_k = token_selected_experts.size(1) + esp = weight_view.expert_size_per_partition + slot_start = weight_view.slot_start + + tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort( + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + num_experts=self.num_slots, + top_k=effective_top_k, + local_expert_offset=slot_start, + local_num_experts=esp, + tile_tokens_dim=tile_size, + ) + + self.event_dict[EventType.Main].record() + moe_output.record_stream( + self.aux_stream_dict[AuxStreamType.MoeOutputMemset]) + + x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + input=x.view(torch.float4_e2m1fn_x2), + weight=[w.view(torch.float4_e2m1fn_x2) for w in weight_view.w3_w1_weight], + input_scale=x_sf.view(torch.uint8), + weight_scale=[ws.view(torch.uint8) for ws in weight_view.fc1_weight_scale], + alpha=weight_view.fc1_global_scale, + tile_idx_to_group_idx=tile_idx_to_expert_idx, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + global_sf=self.fc2_input_scale, + num_experts=self.num_slots, + top_k=effective_top_k, + num_local_experts=esp, + local_expert_offset=slot_start, + tile_size=tile_size, + ) + + with torch.cuda.stream( + self.aux_stream_dict[AuxStreamType.MoeOutputMemset]): + self.event_dict[EventType.Main].wait() + torch.ops.trtllm.moe_output_memset_inplace( + input=moe_output, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + 
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + tile_tokens_dim=tile_size, + top_k=effective_top_k, + ep_size=self.mapping.moe_ep_size, + enable_alltoall=enable_alltoall, + ) + self.event_dict[EventType.MoeOutputMemset].record() + self.event_dict[EventType.MoeOutputMemset].wait() + + torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( + input=x.view(torch.float4_e2m1fn_x2), + weight=[w.view(torch.float4_e2m1fn_x2) for w in weight_view.w2_weight], + input_scale=x_sf.view(torch.uint8), + weight_scale=[ws.view(torch.uint8) for ws in weight_view.fc2_weight_scale], + alpha=weight_view.fc2_global_scale, + output=moe_output, + tile_idx_to_group_idx=tile_idx_to_expert_idx, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + token_final_scales=token_final_scales, + num_experts=self.num_slots, + top_k=effective_top_k, + num_local_experts=esp, + local_expert_offset=slot_start, + tile_size=tile_size, + output_dtype=output_dtype, + ) + return moe_output + def run_moe_fp8_block_scales( self, x: torch.Tensor, @@ -739,6 +883,7 @@ def run_moe( x_sf: Optional[torch.Tensor] = None, moe_output: Optional[torch.Tensor] = None, enable_alltoall: bool = False, + **kwargs, ) -> torch.Tensor: """ Run MoE computation with CuteDSL backend. @@ -759,16 +904,20 @@ def run_moe( Returns: final_hidden_states tensor. 
""" + # Execute MoE computation if self.has_nvfp4: - return self.run_moe_nvfp4( + weight_view = kwargs.get("dwdp_weight_view") or self._build_local_weight_view() + result = self.run_moe_nvfp4( x=x, token_selected_experts=token_selected_experts, token_final_scales=token_final_scales, x_sf=x_sf, moe_output=moe_output, - enable_alltoall=enable_alltoall) + enable_alltoall=enable_alltoall, + weight_view=weight_view, + ) elif self.has_deepseek_fp8_block_scales: - return self.run_moe_fp8_block_scales( + result = self.run_moe_fp8_block_scales( x=x, token_selected_experts=token_selected_experts, token_final_scales=token_final_scales, @@ -778,6 +927,7 @@ def run_moe( raise ValueError( f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}." ) + return result def forward_chunk( self, @@ -815,3 +965,9 @@ def forward_chunk( x_sf=x_sf, enable_alltoall=False) return x + + def load_weights(self, weights: Dict[str, torch.Tensor]): + super().load_weights(weights) + dwdp_handle_collector = getattr(self, "dwdp_handle_collector", None) + if dwdp_handle_collector is not None: + dwdp_handle_collector.register_weights(self) \ No newline at end of file diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index f5e8e1e6f5bb..d1ed9288d857 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -51,6 +51,7 @@ def _warn_and_return(reason: str) -> Tuple[bool, Optional[str]]: get_model_extra_attrs, is_gated_activation, is_torch_compiling) from .routing import BaseMoeRoutingMethod +from ...pyexecutor.dwdp import get_global_dwdp_manager class MoEWeightLoadingMode(Enum): @@ -306,6 +307,31 @@ def __init__( self.initial_global_assignments = list(range(self.num_experts)) self.allreduce = None + # Override expert layout if DWDP is enabled + self._init_dwdp_expert_layout() + + def _init_dwdp_expert_layout(self): + """Override expert 
layout when DWDP is enabled.""" + dwdp_manager = get_global_dwdp_manager() + if dwdp_manager is None: + return + assert self.layer_load_balancer is None, ( + "DWDP and EPLB (MoE load balancer) cannot be used together. " + "Disable one of dwdp_config.enabled or moe_load_balancer." + ) + self.num_slots = self.num_experts + self.expert_size_per_partition = dwdp_manager.experts_per_worker + dwdp_size = dwdp_manager.dwdp_size + self.initial_global_assignments = [ + (ep_rank * self.num_experts // dwdp_size + local_slot_id) % + self.num_experts for ep_rank in range(dwdp_size) + for local_slot_id in range(self.expert_size_per_partition) + ] + self.slot_start = dwdp_manager.start_expert_id + self.slot_end = self.slot_start + self.expert_size_per_partition + self.initial_local_expert_ids = list( + range(self.slot_start, self.slot_end)) + def _init_load_balancer( self, model_config: ModelConfig, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 1a8e0682f14a..dc868bc050af 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -45,7 +45,7 @@ KVCacheV2DummyScheduler, SimpleScheduler, SimpleUnifiedScheduler) from .seq_slot_manager import SeqSlotManager - +from .dwdp import DwdpManager GB = 1 << 30 @@ -1031,6 +1031,7 @@ def create_py_executor_instance( cache_transceiver_config: Optional[CacheTransceiverConfig] = None, virtual_memory_pools: Optional[dict] = None, execution_stream: Optional[torch.cuda.Stream] = None, + dwdp_manager: Optional[DwdpManager] = None, ) -> PyExecutor: kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None) @@ -1218,7 +1219,9 @@ def create_py_executor_instance( peft_cache_config=peft_cache_config, virtual_memory_pools=virtual_memory_pools, execution_stream=execution_stream, - waiting_queue_policy=waiting_queue_policy) + waiting_queue_policy=waiting_queue_policy, + dwdp_manager=dwdp_manager, + ) def create_torch_sampler_args( diff --git 
a/tensorrt_llm/_torch/pyexecutor/dwdp.py b/tensorrt_llm/_torch/pyexecutor/dwdp.py new file mode 100644 index 000000000000..2b2aa1910c22 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/dwdp.py @@ -0,0 +1,556 @@ +import torch +import torch.nn as nn + +from tensorrt_llm.llmapi.llm_args import DwdpConfig +from typing import List, Optional, Dict, Tuple +from tensorrt_llm._torch.distributed import MPIDist +from tensorrt_llm._utils import global_mpi_rank +from mpi4py.MPI import COMM_WORLD + +from cuda.bindings import runtime as cudart +from cuda.bindings import driver as cuda_driver +from tensorrt_llm._utils import nvtx_range + + + +# Parameter names to collect handles for +WEIGHT_PARAMS = ['w3_w1_weight', 'w2_weight'] +BIAS_PARAMS = ['w3_w1_bias', 'w2_bias'] +# Quant scale params vary by quantization method +QUANT_SCALE_PARAMS = [ + 'w3_w1_weight_scale', 'w2_weight_scale', # NVFP4/MXFP4 + 'fc31_alpha', 'fc2_alpha', # NVFP4 alpha +] + + + +_global_dwdp_manager: Optional["DwdpManager"] = None + + +def set_global_dwdp_manager(manager: "DwdpManager"): + global _global_dwdp_manager + _global_dwdp_manager = manager + + +def get_global_dwdp_manager() -> Optional["DwdpManager"]: + return _global_dwdp_manager + + +def check_cuda_error(err, context: str = ""): + """Check CUDA error.""" + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"CUDA error in {context}: {err}") + + +class DwdpLayerHandleCollector: + """ + Dwdp Layer Handle Collector for IPC handle coordination and prefetch buffer management. 
+ """ + + def __init__( + self, + layer_idx: int, + ): + + self.layer_idx = layer_idx + + # Local IPC handles: param_name -> handle_bytes + self.local_ipc_handles: Dict[str, bytes] = {} + # Local pointers: param_name -> data_ptr (for verification) + self.local_ptrs: Dict[str, int] = {} + # Local offsets: param_name -> offset from allocation base + # IPC handle points to allocation base, we need offset to get actual tensor data + self.local_offsets: Dict[str, int] = {} + # Parameter shapes: param_name -> shape (without expert dim) + self.param_shapes: Dict[str, torch.Size] = {} + # Parameter dtypes: param_name -> dtype + self.param_dtypes: Dict[str, torch.dtype] = {} + # Peer pointers: (peer_rank, param_name) -> ptr (already adjusted with offset) + self.peer_ptrs: Dict[Tuple[int, str], int] = {} + + def register_weights(self, module: nn.Module): + """ + Register weights from a MoE module and create IPC handles. + + Called after module.load_weights() completes. + + Args: + module: The MoE module with loaded weights + """ + params_to_register = [] + # Weights (check if present and not None) + for param_name in WEIGHT_PARAMS: + if hasattr(module, param_name) and getattr(module, param_name, None) is not None: + params_to_register.append(param_name) + # Bias (optional) + if hasattr(module, 'bias'): + params_to_register.extend(BIAS_PARAMS) + # Quant scales (optional, depends on quant method) + for param_name in QUANT_SCALE_PARAMS: + if hasattr(module, param_name) and getattr(module, param_name, None) is not None: + params_to_register.append(param_name) + + # Register each parameter + for param_name in params_to_register: + param = getattr(module, param_name) + if isinstance(param, nn.Parameter): + param = param.data + if param is None: + continue + if not param.is_cuda or not param.is_contiguous(): + raise ValueError(f"Parameter {param_name} is not on GPU or is not contiguous") + self._register_param(param_name, param) + + def _register_param(self, param_name: str, param: 
torch.Tensor): + # Get IPC handle - note: handle points to the CUDA allocation base, not tensor's data_ptr + tensor_ptr = param.data_ptr() + err, handle = cudart.cudaIpcGetMemHandle(tensor_ptr) + check_cuda_error(err, f"get handle for {param_name}") + + # Get allocation base address using Driver API cuMemGetAddressRange + # This returns the actual base address and size of the CUDA allocation + # cudaPointerGetAttributes.devicePointer returns the input pointer, not base! + err, alloc_base, alloc_size = cuda_driver.cuMemGetAddressRange(tensor_ptr) + if err != cuda_driver.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"cuMemGetAddressRange failed for {param_name}: {err}") + + # Calculate offset from allocation base + # Convert CUdeviceptr to int for arithmetic + offset = tensor_ptr - int(alloc_base) + + self.local_ipc_handles[param_name] = bytes(handle.reserved) + self.local_ptrs[param_name] = tensor_ptr + self.local_offsets[param_name] = offset + self.param_shapes[param_name] = param.shape[1:] + self.param_dtypes[param_name] = param.dtype + + def get_peer_ptr(self, peer_rank: int, param_name: str) -> int: + """Get pointer to parameter on peer rank.""" + return self.peer_ptrs[(peer_rank, param_name)] + + def cleanup(self): + """Clean up peer handles.""" + for _, ptr in self.peer_ptrs.items(): + cudart.cudaIpcCloseMemHandle(ptr) + self.peer_ptrs.clear() + + +class DwdpPrefetchBuffer: + """ + Ping-pong buffer for expert weight prefetching. + + Buffer Selection Strategy: + - Even layers (0, 2, 4, ...) use buffer[0] + - Odd layers (1, 3, 5, ...) 
use buffer[1]
+    - This ensures layer N-1's prefetch doesn't overwrite layer N's data
+
+    Synchronization Strategy:
+    - prefetch_events[layer_idx % num_buffers][layer_idx // num_buffers]: Recorded when prefetch completes
+      Waited by forward() before using prefetched data
+    - compute_events[layer_idx % num_buffers][layer_idx // num_buffers]: Recorded when forward() completes
+      Waited by next prefetch before overwriting buffer
+
+    Buffer Layout (organized by rank):
+    - buffers[buffer_idx][param_name] = List[Optional[Tensor]]
+    - len(list) == dwdp_size
+    - list[peer_rank] = Tensor[num_prefetch_experts, ...] for peer_rank != dwdp_rank
+    - list[dwdp_rank] = None (local weight used directly, not prefetched)
+    """
+    def __init__(
+        self,
+        dwdp_size: int,
+        dwdp_rank: int,
+        experts_per_worker: int,
+        num_prefetch_experts: int,
+        num_layers: int,
+        first_moe_layer_idx: int,
+        param_shapes: Dict[str, torch.Size],
+        param_dtypes: Dict[str, torch.dtype],
+    ):
+
+        self.dwdp_size = dwdp_size
+        self.num_prefetch_experts = num_prefetch_experts
+        self.experts_per_worker = experts_per_worker
+        self.num_layers = num_layers
+        self.first_moe_layer_idx = first_moe_layer_idx
+        self.num_buffers = 2  # Ping-pong
+        self.dwdp_rank = dwdp_rank
+
+        self.param_shapes = param_shapes
+        self.param_dtypes = param_dtypes
+
+        self.device = torch.cuda.current_device()
+
+        # buffers[buffer_idx][param_name] = List[Optional[Tensor]]
+        # list[peer_rank] contains prefetched weights from that rank
+        # list[dwdp_rank] = None (local weights used directly)
+        self.buffers: List[Dict[str, List[Optional[torch.Tensor]]]] = []
+
+        for _ in range(self.num_buffers):
+            buffer = {}
+            for param_name, shape in param_shapes.items():
+                dtype = param_dtypes[param_name]
+                # Pre-allocate list of length dwdp_size, one slot per rank
+                # tensor_list[dwdp_rank] = None (local weights used directly)
+                # tensor_list[peer_rank] = Tensor for prefetched weights from peer
+                tensor_list: List[Optional[torch.Tensor]] = [None] * dwdp_size
+                for peer_rank in range(dwdp_size):
+                    if peer_rank !=
dwdp_rank: + buffer_shape = (self.num_prefetch_experts,) + tuple(shape) + tensor_list[peer_rank] = torch.empty( + buffer_shape, + dtype=dtype, + device=self.device, + ) + buffer[param_name] = tensor_list + self.buffers.append(buffer) + + self.max_layer_idx = num_layers + first_moe_layer_idx + self.prefetch_events: List[List[torch.cuda.Event]] = [ + [torch.cuda.Event() for _ in range(self.max_layer_idx//self.num_buffers + 1)] + for _ in range(self.num_buffers) + ] + self.compute_events: List[List[torch.cuda.Event]] = [ + [torch.cuda.Event() for _ in range(self.max_layer_idx//self.num_buffers + 1)] + for _ in range(self.num_buffers) + ] + self.prefetch_stream = torch.cuda.Stream(device=self.device) + + def initialize_compute_events(self): + for buffer_idx in range(self.num_buffers): + self.compute_events[buffer_idx][0].record(torch.cuda.current_stream()) + + def record_prefetch_event(self, layer_idx: int): + self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record(self.prefetch_stream) + + def record_compute_event(self, layer_idx: int): + self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record(torch.cuda.current_stream()) + + def wait_prefetch_event(self, layer_idx: int): + torch.cuda.current_stream().wait_event(self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers]) + + def wait_compute_event(self, layer_idx: int): + self.prefetch_stream.wait_event(self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers]) + + +class DwdpManager: + """ + Dwdp Manager for IPC handle coordination and prefetch buffer management. + + This manager: + - Tracks IPC handles for all MoE layers across Context workers + - Manages double-buffered prefetch buffers for remote expert weights + - Provides expert tensor routing (local vs. 
prefetched) + + """ + + def __init__( + self, + config: DwdpConfig, + dist: Optional[object] = None, + ): + + self.config = config + self.dist = dist + self.dwdp_size = config.dwdp_size + self.experts_per_worker = config.experts_per_worker + self.num_group = config.num_group + + self._init_dwdp_group() + + # Per-layer IPC handle collectors (indexed by layer_idx) + self.ipc_collectors: List[DwdpLayerHandleCollector] = [] + + # Prefetch buffer (initialized later in create_py_executor) + self.prefetch_buffer: Optional[DwdpPrefetchBuffer] = None + # Auto-detected from first add_layer() call + self.first_moe_layer_idx: Optional[int] = None + + # Peer expert ranges: (peer_rank, (start_expert_id, end_expert_id)) + self.peer_expert_ranges: Dict[int, Tuple[int, int]] = {} + + self.dwdp_rank = self.rank % self.dwdp_size + self.num_prefetch_experts = config.num_prefetch_experts + self.start_expert_id = self.num_prefetch_experts * self.dwdp_rank + self.end_expert_id = self.start_expert_id + self.experts_per_worker + + set_global_dwdp_manager(self) + + def _init_dwdp_group(self): + + if not isinstance(self.dist, MPIDist): + raise RuntimeError("DWDP requires MPI backend (MPIDist)") + + self.rank = global_mpi_rank() + + # Calculate which group this rank belongs to + # With num_group=2, dwdp_size=4: + # Group 0: ranks [0, 1, 2, 3] + # Group 1: ranks [4, 5, 6, 7] + self.group_id = self.rank // self.dwdp_size + group_start_rank = self.group_id * self.dwdp_size + ranks = list(range(group_start_rank, group_start_rank + self.dwdp_size)) + + new_group = COMM_WORLD.group.Incl(ranks) + self.dwdp_group = COMM_WORLD.Create_group(new_group) + + def is_enabled(self) -> bool: + return self.config.enabled and self.dwdp_size > 1 + + def add_layer( + self, + layer_idx: int, + ) -> "DwdpLayerHandleCollector": + """ + Add a new layer IPC handle collector. + + Called from CuteDslFusedMoE.__init__() during model construction. 
+ """ + if self.first_moe_layer_idx is None: + self.first_moe_layer_idx = layer_idx + collector = DwdpLayerHandleCollector( + layer_idx=layer_idx + ) + self.ipc_collectors.append(collector) + return collector + + def exchange_all_handles(self): + """ + Exchange IPC handles with peer Context workers via Dwdp Group AllGather. + + Called after all weights are loaded, before creating prefetch buffer. + """ + + # Collect all local handles with explicit worker info + local_data = { + 'dwdp_rank': self.dwdp_rank, + 'expert_start_id': self.start_expert_id, + 'expert_end_id': self.end_expert_id, + 'ipc_collectors': [], + } + for collector in self.ipc_collectors: + local_data['ipc_collectors'].append({ + 'layer_idx': collector.layer_idx, + 'handles': collector.local_ipc_handles, + 'offsets': collector.local_offsets, + }) + + # AllGather from all Context workers in DWDP group + all_data = self.dwdp_group.allgather(local_data) + + # Open handles from peer workers + for peer_data in all_data: + peer_rank = peer_data['dwdp_rank'] + self.peer_expert_ranges[peer_rank] = (peer_data['expert_start_id'], peer_data['expert_end_id']) + + if peer_rank == self.dwdp_rank: + continue + for layer_idx, ipc_collector in enumerate(peer_data['ipc_collectors']): + collector = self.ipc_collectors[layer_idx] + peer_offsets = ipc_collector['offsets'] + for param_name, handle_bytes in ipc_collector['handles'].items(): + # Reconstruct and open handle + handle = cudart.cudaIpcMemHandle_t() + handle.reserved = list(handle_bytes) + + err, base_ptr = cudart.cudaIpcOpenMemHandle( + handle, + cudart.cudaIpcMemLazyEnablePeerAccess + ) + check_cuda_error(err, f"open handle rank={peer_rank}") + + # Apply offset to get actual tensor pointer + # IPC handle points to allocation base, offset gives us the tensor location + offset = peer_offsets[param_name] + actual_ptr = base_ptr + offset + collector.peer_ptrs[(peer_rank, param_name)] = actual_ptr + + def initialize_prefetch_buffer(self): + """ + Initialize the 
prefetch buffer. + + Called in create_py_executor() after model loading. + """ + self.prefetch_buffer = DwdpPrefetchBuffer( + dwdp_size=self.dwdp_size, + dwdp_rank=self.dwdp_rank, + experts_per_worker=self.experts_per_worker, + num_prefetch_experts=self.num_prefetch_experts, + num_layers=len(self.ipc_collectors), + first_moe_layer_idx=self.first_moe_layer_idx, + param_shapes=self.ipc_collectors[0].param_shapes, + param_dtypes=self.ipc_collectors[0].param_dtypes, + ) + self.prefetch_buffer.initialize_compute_events() + + def prefetch_first_layers(self): + """Prefetch the first num_buffers layers as warmup.""" + if self.prefetch_buffer is None: + raise RuntimeError("Prefetch buffer is not initialized") + start = self.first_moe_layer_idx + for layer_idx in range(start, start + self.prefetch_buffer.num_buffers): + self.prefetch_layer(layer_idx) + self.prefetch_buffer.record_prefetch_event(layer_idx) + + def build_weight_view(self, layer_idx: int, backend): + """Build NvFp4WeightView from prefetch buffer and local weights. + + Assembles weight tensors from all DWDP ranks: + - Peer ranks: uses prefetched buffer tensors + - Local rank: uses backend's actual model weights + + Args: + layer_idx: The MoE layer index. + backend: The CuteDslFusedMoE backend holding local model weights. + + Returns: + NvFp4WeightView with all weights assembled. + """ + from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import NvFp4WeightView + + buffer_data = self.wait_prefetch_and_get_buffer(layer_idx) + required_keys = ( + "w3_w1_weight", "w3_w1_weight_scale", "fc31_alpha", + "w2_weight", "w2_weight_scale", "fc2_alpha", + ) + missing_keys = [key for key in required_keys if key not in buffer_data] + if missing_keys: + raise ValueError( + f"DWDP buffer missing required keys {missing_keys} for layer {layer_idx}." 
+ ) + + w3_w1_weight_list = buffer_data["w3_w1_weight"] + fc1_weight_scale_list = buffer_data["w3_w1_weight_scale"] + fc1_global_scale_list = buffer_data["fc31_alpha"] + w2_weight_list = buffer_data["w2_weight"] + fc2_weight_scale_list = buffer_data["w2_weight_scale"] + fc2_global_scale_list = buffer_data["fc2_alpha"] + + w3_w1_weight_list[self.dwdp_rank] = backend.w3_w1_weight + fc1_weight_scale_list[self.dwdp_rank] = backend.quant_scales.fc1_weight_block + fc1_global_scale_list[self.dwdp_rank] = backend.quant_scales.fc1_global + w2_weight_list[self.dwdp_rank] = backend.w2_weight + fc2_weight_scale_list[self.dwdp_rank] = backend.quant_scales.fc2_weight_block + fc2_global_scale_list[self.dwdp_rank] = backend.quant_scales.fc2_global + + return NvFp4WeightView( + w3_w1_weight=w3_w1_weight_list, + fc1_weight_scale=fc1_weight_scale_list, + fc1_global_scale=fc1_global_scale_list, + w2_weight=w2_weight_list, + fc2_weight_scale=fc2_weight_scale_list, + fc2_global_scale=fc2_global_scale_list, + expert_size_per_partition=backend.num_slots, + slot_start=0, + ) + + def wait_prefetch_and_get_buffer(self, layer_idx: int) -> Optional[Dict[str, List[Optional[torch.Tensor]]]]: + """Wait for prefetch to complete and return the buffer for this layer. 
+ + Returns: + Dict mapping param_name to List[Optional[Tensor]] where: + - list[peer_rank] = Tensor for prefetched weights from that peer + - list[dwdp_rank] = None (local weights used directly) + """ + if self.prefetch_buffer is None: + raise RuntimeError("Prefetch buffer is not initialized") + self.prefetch_buffer.wait_prefetch_event(layer_idx) + buffer_idx = layer_idx % self.prefetch_buffer.num_buffers + return self.prefetch_buffer.buffers[buffer_idx] + + def record_compute_and_prefetch_next(self, layer_idx: int): + """Record compute completion and trigger prefetch for layer_idx + num_buffers.""" + if self.prefetch_buffer is None: + raise RuntimeError("Prefetch buffer is not initialized") + # Record compute event for current layer + self.prefetch_buffer.record_compute_event(layer_idx) + + next_layer_idx = layer_idx + self.prefetch_buffer.num_buffers + if next_layer_idx >= self.prefetch_buffer.max_layer_idx: + return + # prefetch_layer handles stream internally: local copy on default stream, peer copy on prefetch stream + self.prefetch_layer(next_layer_idx, wait_compute_layer_idx=layer_idx) + self.prefetch_buffer.record_prefetch_event(next_layer_idx) + + def _get_prefetch_src_offset_from_peer(self, peer_rank: int) -> int: + """ + Calculate the source offset (in number of experts) to fetch from a peer. 
+ + Returns: + src_offset: Offset into peer's local expert tensor to start copying from + + Example: 256 experts, rank0: [0, 200), rank1: [56, 256) + - rank0 needs [200, 256) from rank1: + src_offset = 200 - 56 = 144 (fetch last 56 experts from rank1) + - rank1 needs [0, 56) from rank0: + src_offset = 0 - 0 = 0 (fetch first 56 experts from rank0) + """ + peer_start, peer_end = self.peer_expert_ranges[peer_rank] + + # What I need = global - what I have + # From peer = what I need ∩ what peer has + if self.dwdp_rank < peer_rank: + # I'm earlier rank, need experts after my end (tail of peer's experts) + prefetch_end = peer_end + prefetch_start = prefetch_end - self.num_prefetch_experts + else: + # I'm later rank, need experts before my start (head of peer's experts) + prefetch_start = peer_start + + src_offset = prefetch_start - peer_start + return src_offset + + @nvtx_range("dwdp_prefetch_layer") + def prefetch_layer(self, layer_idx: int, wait_compute_layer_idx: Optional[int] = None): + """ + Prefetch layer data from peer ranks. + + Args: + layer_idx: The layer to prefetch + wait_compute_layer_idx: If provided, wait for this layer's compute to complete + before overwriting buffer (used when prefetching next layer) + + Note: Local weights are used directly by the kernel, no copy needed. + Peer copy runs on prefetch stream. 
+ """ + moe_idx = layer_idx - self.first_moe_layer_idx + param_names = self.ipc_collectors[moe_idx].param_shapes.keys() + collector = self.ipc_collectors[moe_idx] + buffer_idx = layer_idx % self.prefetch_buffer.num_buffers + + # Peer copy on prefetch stream + # Local weights are used directly - no local copy needed + with torch.cuda.stream(self.prefetch_buffer.prefetch_stream): + # Wait for compute to complete before overwriting buffer + if wait_compute_layer_idx is not None: + self.prefetch_buffer.wait_compute_event(wait_compute_layer_idx) + + for peer_rank in range(self.dwdp_size): + if peer_rank == self.dwdp_rank: + continue # Skip local rank - local weights used directly + + src_expert_offset = self._get_prefetch_src_offset_from_peer(peer_rank) + + for param_name in param_names: + param_shape = collector.param_shapes[param_name] + param_dtype = collector.param_dtypes[param_name] + expert_size = param_shape.numel() * param_dtype.itemsize + + # src_ptr points to peer's tensor start, add offset for specific experts + base_ptr = collector.get_peer_ptr(peer_rank, param_name) + src_ptr = base_ptr + src_expert_offset * expert_size + + # dst_tensor is directly indexed by peer_rank in the list + dst_tensor = self.prefetch_buffer.buffers[buffer_idx][param_name][peer_rank] + dst_ptr = dst_tensor.data_ptr() + + data_size = self.num_prefetch_experts * expert_size + + err, = cudart.cudaMemcpyAsync( + dst_ptr, + src_ptr, + data_size, + cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, + self.prefetch_buffer.prefetch_stream.cuda_stream, + ) + check_cuda_error(err, f"prefetch layer {layer_idx} peer_rank {peer_rank} {param_name}") \ No newline at end of file diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index a267e165dd69..71147860659a 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -68,6 +68,7 @@ SerializableSchedulerOutput, WaitingQueue, 
create_waiting_queue) from .scheduler.adp_router import ADPRouter, DefaultADPRouter +from .dwdp import DwdpManager # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." @@ -281,7 +282,8 @@ def __init__( hang_detection_timeout: Optional[int] = None, execution_stream: Optional[torch.cuda.Stream] = None, waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS, - adp_router: Optional[ADPRouter] = None): + adp_router: Optional[ADPRouter] = None, + dwdp_manager: Optional[DwdpManager] = None): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = dist.rank @@ -531,6 +533,8 @@ def on_detected(): self._maybe_init_kv_connector_manager() + self.dwdp_manager = dwdp_manager + if start_worker: self.start_worker() @@ -1849,6 +1853,8 @@ def _executor_loop(self): with self.perf_manager.record_perf_events( gpu_forward_start, gpu_forward_end) as fwd_timing: + if self.dwdp_manager is not None: + self.dwdp_manager.prefetch_first_layers() batch_outputs = self._forward_step(scheduled_batch) guided_decoder_failed_requests = None diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 2d7614559e65..691b25f2ed7b 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -42,6 +42,7 @@ from .model_engine import PyTorchModelEngine from .model_loader import ModelLoader, _construct_checkpoint_loader from .py_executor import PyExecutor +from .dwdp import DwdpManager class _ExecutorMemoryMonitor: @@ -376,6 +377,13 @@ def create_py_executor( ) logger.info("ATTENTION RUNTIME FEATURES: ", attn_runtime_features) + # Initialize DWDP Manager (only for context workers in disaggregated serving) + dwdp_manager: Optional[DwdpManager] = None + if llm_args.dwdp_config is not None and 
llm_args.dwdp_config.enabled: + assert mapping.tp_size == 1 and llm_args.dwdp_config.dwdp_size > 1, "DWDP requires TP=1 and dwdp_size > 1" + dwdp_manager = DwdpManager(config=llm_args.dwdp_config, dist=dist) + logger.info(f"Dwdp Manager initialized. Config: {llm_args.dwdp_config}") + mem_monitor = _ExecutorMemoryMonitor() @contextmanager @@ -720,6 +728,11 @@ def drafting_loop_wrapper(model): max_seq_len = kv_cache_creator._max_seq_len update_sampler_max_seq_len(max_seq_len, sampler) + # Exchange IPC Handles and Initialize Dwdp Prefetch Buffer + if dwdp_manager is not None: + dwdp_manager.exchange_all_handles() + dwdp_manager.initialize_prefetch_buffer() + # Resource managers for speculative decoding # For user-specified drafters, use extra_resource_managers in PyTorchBackend config # to provide a resource manager if required. @@ -830,6 +843,7 @@ def drafting_loop_wrapper(model): cache_transceiver_config=cache_transceiver_config, virtual_memory_pools=vm_pools, execution_stream=execution_stream, + dwdp_manager=dwdp_manager, ) _adjust_torch_mem_fraction() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 91419488c46e..85eb1f943977 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2153,6 +2153,20 @@ def model_name(self) -> Union[str, Path]: return self.model if isinstance(self.model, str) else None +class DwdpConfig(StrictBaseModel): + """ + Configuration for DWDP. + """ + enabled: bool = Field(default=False, description="Whether to enable DWDP.") + dwdp_size: int = Field(default=1, description="The number of GPUs per DWDP group.") + num_group: int = Field(default=1, description="The number of DWDP groups. 
Total workers = num_group * dwdp_size.") + experts_per_worker: int = Field(default=0, description="The number of experts per worker.") + num_prefetch_experts: int = Field(default=0, description="The number of prefetch experts per worker.") + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + class BaseLlmArgs(StrictBaseModel): """ Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both. @@ -3033,6 +3047,11 @@ class TorchLlmArgs(BaseLlmArgs): description="NVFP4 GEMM backend config.", status="beta") + dwdp_config: DwdpConfig = Field( + default_factory=DwdpConfig, + description="DWDP (Distributed Weight Data Parallelism) config.", + status="beta") + attn_backend: str = Field(default='TRTLLM', description="Attention backend to use.", status="beta") @@ -3475,6 +3494,7 @@ def update_llm_args_with_extra_dict( "nvfp4_gemm_config": Nvfp4GemmConfig, "attention_dp_config": AttentionDpConfig, "kv_cache_config": KvCacheConfig, + "dwdp_config": DwdpConfig, } for field_name, field_type in field_mapping.items(): if field_name in llm_args_dict: diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 2e3ad4f1bb3e..5093a4a248de 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -1107,6 +1107,109 @@ def test_guided_decoding(self, backend: str, mtp_nextn: int, mocker): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"]) + @pytest.mark.skip_less_device(4) + @skip_pre_blackwell + def test_dwdp_accuracy(self): + model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp" + + ctx_port_0 = get_free_port() + ctx_port_1 = get_free_port() + gen_port = get_free_port() + serve_port = get_free_port() + + ctx_server_config = { + "num_instances": 2, + "urls": [ + f"localhost:{ctx_port_0}", + 
f"localhost:{ctx_port_1}", + ], + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "disable_overlap_scheduler": True, + "enable_autotuner": False, + "enable_chunked_prefill": False, + "cuda_graph_config": None, + "max_batch_size": 16, + "max_num_tokens": 8192, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.4, + "enable_block_reuse": False, + "enable_partial_reuse": False, + "tokens_per_block": 32, + }, + "cache_transceiver_config": { + "backend": "UCX", + "max_tokens_in_buffer": 8192, + }, + "moe_config": { + "backend": "CUTEDSL", + }, + "dwdp_config": { + "enabled": True, + "dwdp_size": 2, + "num_group": 1, + "experts_per_worker": 36, + "num_prefetch_experts": 36, + }, + } + + gen_server_config = { + "num_instances": 1, + "urls": [f"localhost:{gen_port}"], + "tensor_parallel_size": 2, + "pipeline_parallel_size": 1, + "disable_overlap_scheduler": True, + "enable_autotuner": False, + "enable_chunked_prefill": False, + "cuda_graph_config": None, + "max_batch_size": 128, + "max_num_tokens": 1024, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.5, + "enable_block_reuse": False, + "enable_partial_reuse": False, + "tokens_per_block": 32, + }, + "cache_transceiver_config": { + "backend": "UCX", + "max_tokens_in_buffer": 8192, + }, + "moe_config": { + "backend": "CUTEDSL", + }, + } + + worker_config = { + "model": model_path, + "hostname": "localhost", + "port": serve_port, + "backend": "pytorch", + "context_servers": ctx_server_config, + "generation_servers": gen_server_config, + } + + frontend_config = { + "backend": "pytorch", + "hostname": "localhost", + "port": serve_port, + "context_servers": { + "num_instances": 2, + "urls": [ + f"localhost:{ctx_port_0}", + f"localhost:{ctx_port_1}", + ], + }, + "generation_servers": { + "num_instances": 1, + "urls": [f"localhost:{gen_port}"], + }, + } + + with launch_dwdp_disaggregated_llm( + worker_config, frontend_config, model_path, + total_gpus=4, max_workers=128) as llm: + run_accuracy_test(llm, 
self.MODEL_NAME, ["GSM8K"]) + @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT) class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): @@ -1766,3 +1869,143 @@ def test_nixl_backend(self): with launch_disaggregated_llm(disagg_cfg, ctx_cfg, gen_cfg, self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + + +@contextlib.contextmanager +def launch_dwdp_disaggregated_llm( + worker_config: Dict[str, Any], + frontend_config: Dict[str, Any], + model_path: str, + total_gpus: int, + server_waiting_timeout: int = DEFAULT_SERVER_WAITING_TIMEOUT, + max_workers: int = 128, +): + """Launch DWDP disaggregated serving via mpirun. + + DWDP requires all workers (CTX + GEN) in a single MPI world for + IPC handle exchange and DWDP group formation. This function starts + all workers with ``mpirun`` and launches a separate disaggregated + frontend server for the client-facing OpenAI API. + """ + temp_dir = tempfile.TemporaryDirectory() + worker_config_path = os.path.join(temp_dir.name, "worker_config.yaml") + frontend_config_path = os.path.join(temp_dir.name, "frontend_config.yaml") + + with open(worker_config_path, "w") as f: + yaml.dump(worker_config, f, default_flow_style=False, sort_keys=False) + with open(frontend_config_path, "w") as f: + yaml.dump(frontend_config, f, default_flow_style=False, sort_keys=False) + + serve_port = frontend_config["port"] + + # Prevent the parent process's MPI state (set by mpi4py init during + # tensorrt_llm import) from leaking into the mpirun subprocess. + # mpirun must create a fresh MPI world for the DWDP workers. 
+ child_env = {k: v for k, v in os.environ.items() + if not k.startswith(('OMPI_', 'PMIX_', 'PMI_'))} + + mpi_cmd = [ + "mpirun", "--allow-run-as-root", "-n", + str(total_gpus), "trtllm-serve", "disaggregated_mpi_worker", "-c", + worker_config_path + ] + + frontend_cmd = [ + "trtllm-serve", "disaggregated", "-c", frontend_config_path, + "--server_start_timeout", + str(server_waiting_timeout), "-r", "360000" + ] + + with ( + MyThreadPoolExecutor(max_workers=max_workers) as thread_pool, + temp_dir, + popen(mpi_cmd, env=child_env) as mpi_proc, + popen(frontend_cmd, env=child_env) as frontend_proc, + ): + start_time = time.time() + server_is_ready = False + while time.time() - start_time < server_waiting_timeout: + time.sleep(5) + for proc, name in [ + (mpi_proc, "mpirun"), + (frontend_proc, "frontend"), + ]: + if proc.poll() is not None: + raise Exception( + f"{name} process exited with code {proc.returncode}" + ) + try: + response = requests.get( + f"http://localhost:{serve_port}/cluster_info") + if response.status_code == 200: + cluster_info = response.json() + if cluster_info.get("is_ready"): + print(f"DWDP cluster ready: {cluster_info}") + server_is_ready = True + break + except requests.exceptions.ConnectionError: + continue + if not server_is_ready: + pytest.fail( + f"DWDP server not ready after {server_waiting_timeout}s") + + model_name = worker_config.get("model", model_path) + client = openai.OpenAI(api_key="1234567890", + base_url=f"http://localhost:{serve_port}/v1", + timeout=1800000) + + def send_request(prompt: str, sampling_params: SamplingParams, + streaming: bool): + kwargs = {} + if sampling_params is not None: + kwargs.update( + max_tokens=sampling_params.max_tokens, + temperature=(sampling_params.temperature + if sampling_params.top_p is not None else 0), + top_p=sampling_params.top_p, + stop=sampling_params.stop, + seed=sampling_params.seed) + response = client.completions.create(model=model_name, + prompt=prompt, + stream=streaming, + **kwargs) + 
result = Result(id=0, + sampling_params=sampling_params, + outputs=[ + CompletionOutput(text=response.choices[0].text, + index=0) + ]) + requested_output = RequestOutput._from_generation_result( + result, prompt=prompt) + setattr(requested_output, "result", result.result) + return requested_output + + def generate_async(prompt: str, + sampling_params: Optional[SamplingParams] = None, + streaming: bool = False): + future = thread_pool.submit(send_request, prompt, sampling_params, + streaming) + thread_pool.futures.append(future) + return future + + args = LlmArgs(model=model_path) + tokenizer = load_hf_tokenizer(model_path) + try: + yield DuckLLM(args, tokenizer, generate_async) + finally: + all_procs = [frontend_proc, mpi_proc] + for proc in all_procs: + if proc.poll() is None: + proc.terminate() + deadline = time.monotonic() + 5 + for proc in all_procs: + remaining = max(0, deadline - time.monotonic()) + try: + proc.wait(timeout=remaining) + except subprocess.TimeoutExpired: + try: + proc.kill() + except ProcessLookupError: + pass + except OSError: + pass diff --git a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 3fb0015c9833..3d7c46e5d0aa 100644 --- a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -92,6 +92,37 @@ cvt_sf_M32x4xrm_K4xrk_L_to_MKL = kernel_module.cvt_sf_M32x4xrm_K4xrk_L_to_MKL +def split_groups_to_b_tensors( + num_groups: int, num_b_tensors: int +) -> Tuple[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]: + """Split groups into multiple B tensors. 
+ + :param num_groups: Total number of groups (experts) + :param num_b_tensors: Number of B tensors to split into + :return: Tuple of (b_tensor_l_sizes, groups_per_b_tensor) + - b_tensor_l_sizes: L size for each B tensor + - groups_per_b_tensor: Tuple of group indices for each B tensor + """ + # Distribute groups evenly across B tensors + base_groups_per_tensor = num_groups // num_b_tensors + remainder = num_groups % num_b_tensors + + b_tensor_l_sizes = [] + groups_per_b_tensor = [] + current_group = 0 + + for i in range(num_b_tensors): + # Add one extra group to first 'remainder' tensors + num_groups_in_tensor = base_groups_per_tensor + (1 if i < remainder else 0) + b_tensor_l_sizes.append(num_groups_in_tensor) + groups_per_b_tensor.append( + tuple(range(current_group, current_group + num_groups_in_tensor)) + ) + current_group += num_groups_in_tensor + + return tuple(b_tensor_l_sizes), tuple(groups_per_b_tensor) + + def create_mask(group_m_list, mma_tiler_m, permuted_m=None): """Create mask and group mapping for contiguous grouped GEMM. @@ -375,6 +406,8 @@ def create_tensors( sf_vec_size, mma_tiler_m, permuted_m=None, + b_tensor_l_sizes=None, + groups_per_b_tensor=None, ): """Create tensors for contiguous grouped GEMM with gather operation and SwiGLU fusion. @@ -383,7 +416,7 @@ def create_tensors( Returns tensors including: - A: Input matrix (MxKx1) - - B: Weight matrix with interleaved up/gate weights (NxKxL) + - B: Weight matrix with interleaved up/gate weights (NxKxL) or list of tensors for multi-B - C: Output matrix (Mx(N/2)x1), N is halved due to SwiGLU fusion - SFA, SFB: Scale factor matrices for A and B - SFC: Scale factor matrix for C (only when c_dtype is Float4E2M1FN) @@ -392,31 +425,14 @@ def create_tensors( - num_non_exiting_tiles: Number of valid tiles to process :param mma_tiler_m: MMA tile size in M dimension (from mma_tiler_mn[0]), also used for alignment - :param permuted_m: Optional padded M dimension for cuda_graph support. 
If provided, - A matrix, C matrix, token_id_mapping, and scale factor A will be padded to this size. - The kernel exits when tile_idx >= num_non_exiting_tiles. - - Example with CUDA graph padding: - # For MoE: m=4096, topK=8, num_local_experts=256, experts_per_rank=8 - permuted_m = 4096 * 8 + 8 * 255 # = 34808 - tensors = create_tensors( - num_groups=8, # num_local_experts - group_m_list=[512, 1024, ...], # actual group sizes - n=4096, k=7168, - a_major="k", b_major="k", cd_major="n", - ab_dtype=cutlass.Float4E2M1FN, - c_dtype=cutlass.BFloat16, - sf_dtype=cutlass.Float8E4M3FN, - sf_vec_size=16, - mma_tiler_m=128, # MMA tile size in M dimension, also used for alignment - permuted_m=34808 # Enable padding for cuda_graph - ) - # Returns tensors with A, C, SFA, and token_id_mapping padded to permuted_m size, - # kernel exits early when tile_idx >= num_non_exiting_tiles + :param permuted_m: Optional padded M dimension for cuda_graph support. + :param b_tensor_l_sizes: Optional tuple of L sizes for multi-B tensor mode. + :param groups_per_b_tensor: Optional tuple of group indices for each B tensor. 
""" torch.manual_seed(1111) - alpha_torch_cpu = torch.randn((num_groups,), dtype=torch.float32) + # Determine if multi-B tensor mode + multi_b_mode = b_tensor_l_sizes is not None ( valid_m, @@ -427,21 +443,14 @@ def create_tensors( ) = create_mask(group_m_list, mma_tiler_m, permuted_m) max_m = max(group_m_list) - - # Use permuted_m for A/C tensors if provided (for cuda_graph support) tensor_m = permuted_m if permuted_m is not None else valid_m a_torch_cpu = cutlass_torch.matrix(1, max_m, k, a_major == "m", cutlass.Float32) - b_torch_cpu = cutlass_torch.matrix(num_groups, n, k, b_major == "n", cutlass.Float32) - # C tensor also uses tensor_m (permuted_m) for cuda_graph support c_torch_cpu = cutlass_torch.matrix(1, tensor_m, n // 2, cd_major == "m", cutlass.Float32) a_tensor, a_torch_gpu = cutlass_torch.cute_tensor_like( a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 ) - b_tensor, b_torch_gpu = cutlass_torch.cute_tensor_like( - b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 - ) c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 ) @@ -452,26 +461,83 @@ def create_tensors( stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, ) - b_tensor.mark_compact_shape_dynamic( - mode=1 if b_major == "k" else 0, - stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), - divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, - ) c_tensor.mark_compact_shape_dynamic( mode=1 if cd_major == "n" else 0, stride_order=(2, 0, 1) if cd_major == "n" else (2, 1, 0), divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, ) + if multi_b_mode: + # Multi-B tensor mode: create multiple B tensors + b_torch_cpu_list = [] + b_tensor_list = [] + b_torch_gpu_list = [] + sfb_torch_cpu_list = [] + sfb_tensor_list = [] + sfb_torch_gpu_list = [] + alpha_torch_cpu_list = [] + alpha_tensor_list = [] + + for l_size in 
b_tensor_l_sizes: + # Create alpha for this B tensor + alpha_torch_cpu = torch.randn((l_size,), dtype=torch.float32) + alpha_torch_cpu_list.append(alpha_torch_cpu) + alpha = from_dlpack(alpha_torch_cpu.cuda()).mark_layout_dynamic() + alpha_tensor_list.append(alpha) + + # Create B tensor + b_torch_cpu = cutlass_torch.matrix(l_size, n, k, b_major == "n", cutlass.Float32) + b_tensor, b_torch_gpu = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, + ) + b_torch_cpu_list.append(b_torch_cpu) + b_tensor_list.append(b_tensor) + b_torch_gpu_list.append(b_torch_gpu) + + # Create SFB tensor + sfb_torch_cpu, sfb_tensor, sfb_torch_gpu = create_scale_factor_tensor( + l_size, n, k, sf_vec_size, sf_dtype + ) + sfb_torch_cpu_list.append(sfb_torch_cpu) + sfb_tensor_list.append(sfb_tensor) + sfb_torch_gpu_list.append(sfb_torch_gpu) + + b_tensor = b_tensor_list + b_torch_cpu = b_torch_cpu_list + b_torch_gpu = b_torch_gpu_list + sfb_tensor = sfb_tensor_list + sfb_torch_cpu = sfb_torch_cpu_list + sfb_torch_gpu = sfb_torch_gpu_list + alpha_torch_cpu = alpha_torch_cpu_list + alpha = alpha_tensor_list + else: + # Single B tensor mode + alpha_torch_cpu = torch.randn((num_groups,), dtype=torch.float32) + alpha = from_dlpack(alpha_torch_cpu.cuda()).mark_layout_dynamic() + + b_torch_cpu = cutlass_torch.matrix(num_groups, n, k, b_major == "n", cutlass.Float32) + b_tensor, b_torch_gpu = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, + ) + sfb_torch_cpu, sfb_tensor, sfb_torch_gpu = 
create_scale_factor_tensor( + num_groups, n, k, sf_vec_size, sf_dtype + ) + # Use tensor_m (permuted_m if provided) for scale factor A sfa_torch_cpu, sfa_tensor, sfa_torch_gpu = create_scale_factor_tensor_unswizzled( 1, max_m, k, sf_vec_size, sf_dtype ) - sfb_torch_cpu, sfb_tensor, sfb_torch_gpu = create_scale_factor_tensor( - num_groups, n, k, sf_vec_size, sf_dtype - ) - token_id_mapping_cpu, token_id_mapping, token_id_mapping_torch = create_token_id_mapping_tensor( group_m_list, mma_tiler_m, max_token_id=max_m, permuted_m=permuted_m ) @@ -480,8 +546,6 @@ def create_tensors( tile_idx_to_mn_limit = from_dlpack(_tile_idx_to_mn_limit).mark_layout_dynamic() num_non_exiting_tiles = from_dlpack(_num_non_exiting_tiles).mark_layout_dynamic() - alpha = from_dlpack(alpha_torch_cpu.cuda()).mark_layout_dynamic() - # Create sfc_tensor and norm_const_tensor when c_dtype is Float4E2M1FN sfc_torch_cpu = None sfc_tensor = None @@ -555,6 +619,7 @@ def run( permuted_m: int = None, use_cupti: bool = False, raster_along_m: bool = False, + num_b_tensors: int = None, **kwargs, ): """Run contiguous grouped GEMM with gather operation and SwiGLU fusion for FC1 layer. @@ -566,24 +631,16 @@ def run( Note: Output C has N/2 columns since SwiGLU combines pairs of (up, gate) from interleaved B weights. 
- This function: - - Creates tensors including token_id_mapping for gather operation - - Uses LDGSTS for loading A and SFA matrices with gather capability - - Uses TMA for loading B and SFB matrices with multicast - - Performs SwiGLU activation fusion in epilogue - - Optionally performs quantization fusion for Float4E2M1FN output - - Performs reference checking (if not skipped) - - Benchmarks kernel performance - - :param nkl: (N, K, L) dimensions where L is the number of experts/groups - :param group_m_list: List of M values for each group - :param mma_tiler_mn: MMA tile shape (M, N), where mma_tiler_mn[0] is used for group M alignment - :param permuted_m: Optional padded M dimension for CUDA graph support. If provided, - A/C matrices, token_id_mapping, and scale factor A will be padded to this size. + :param num_b_tensors: If specified, enables multi-B tensor mode (2-4 tensors). """ + # Determine if multi-B tensor mode + multi_b_mode = num_b_tensors is not None + print("Running Blackwell Persistent Contiguous Grouped GEMM with Gather test:") print(f"nkl: {nkl}") print(f"group_m_list: {group_m_list}") + if multi_b_mode: + print(f"Multi-B tensor mode: {num_b_tensors} B tensors") print( f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}, " f"Scale factor dtype: {sf_dtype}, SF Vec size: {sf_vec_size}" @@ -608,6 +665,14 @@ def run( if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") + # Split groups into multiple B tensors if multi-B mode + b_tensor_l_sizes = None + groups_per_b_tensor = None + if multi_b_mode: + b_tensor_l_sizes, groups_per_b_tensor = split_groups_to_b_tensors(num_groups, num_b_tensors) + print(f"b_tensor_l_sizes: {b_tensor_l_sizes}") + print(f"groups_per_b_tensor: {groups_per_b_tensor}") + # Skip unsupported testcase # Note: For grouped GEMM, we use mma_tiler_mn[0] as the m parameter for can_implement check # since individual group M values vary @@ -677,7 +742,10 @@ def run( sf_vec_size, 
mma_tiler_mn[0], # mma_tiler_m, also used for alignment permuted_m, + b_tensor_l_sizes=b_tensor_l_sizes, + groups_per_b_tensor=groups_per_b_tensor, ) + # Configure gemm kernel gemm = BlockScaledContiguousGatherGroupedGemmKernel( sf_vec_size, @@ -686,6 +754,7 @@ def run( True, topk=1, raster_along_m=raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes if multi_b_mode else None, ) # Compute max active clusters on current device @@ -700,40 +769,78 @@ def run( current_stream = cuda.CUstream(torch_stream.cuda_stream) # Compile gemm kernel # sfc_tensor is optional and can be set as None (Python's None value) if not needed. - compiled_gemm = cute.compile( - gemm, - a_tensor, - b_tensor, - c_tensor, - sfa_tensor, - sfb_tensor, - sfc_tensor, - norm_const_tensor, - tile_idx_to_expert_idx, - tile_idx_to_mn_limit, - token_id_mapping, - num_non_exiting_tiles, - alpha, - max_active_clusters, - current_stream, - ) + if multi_b_mode: + # Multi-B tensor mode: pass tuples + compiled_gemm = cute.compile( + gemm, + a_tensor, + tuple(b_tensor), + c_tensor, + sfa_tensor, + tuple(sfb_tensor), + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + tuple(alpha), + max_active_clusters, + current_stream, + ) + else: + # Single-B tensor mode + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + alpha, + max_active_clusters, + current_stream, + ) # Execution - compiled_gemm( - a_tensor, - b_tensor, - c_tensor, - sfa_tensor, - sfb_tensor, - sfc_tensor, - norm_const_tensor, - tile_idx_to_expert_idx, - tile_idx_to_mn_limit, - token_id_mapping, - num_non_exiting_tiles, - alpha, - current_stream, - ) + if multi_b_mode: + compiled_gemm( + a_tensor, + tuple(b_tensor), + c_tensor, + sfa_tensor, + tuple(sfb_tensor), + sfc_tensor, + 
norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + tuple(alpha), + current_stream, + ) + else: + compiled_gemm( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + alpha, + current_stream, + ) torch.cuda.synchronize() # Compute reference result @@ -751,10 +858,28 @@ def run( for i, group_m in enumerate(aligned_group_m_list): end = start + group_m res_a = a_torch_cpu_f32[token_id_mapping_cpu[start:end]] - res_b = torch.einsum("nk,nk->nk", b_torch_cpu[:, :, i], sfb_torch_cpu[:, :, i]) - gemm_result[0, start:end, :] = ( - torch.einsum("mk,nk->mn", res_a, res_b) * alpha_torch_cpu[i] - ) + + if multi_b_mode: + # Find which B tensor this group belongs to + b_tensor_idx = None + local_group_idx = None + for b_idx, groups in enumerate(groups_per_b_tensor): + if i in groups: + b_tensor_idx = b_idx + local_group_idx = groups.index(i) + break + assert b_tensor_idx is not None, f"Group {i} not found in any B tensor" + res_b = torch.einsum( + "nk,nk->nk", + b_torch_cpu[b_tensor_idx][:, :, local_group_idx], + sfb_torch_cpu[b_tensor_idx][:, :, local_group_idx], + ) + alpha_val = alpha_torch_cpu[b_tensor_idx][local_group_idx] + else: + res_b = torch.einsum("nk,nk->nk", b_torch_cpu[:, :, i], sfb_torch_cpu[:, :, i]) + alpha_val = alpha_torch_cpu[i] + + gemm_result[0, start:end, :] = torch.einsum("mk,nk->mn", res_a, res_b) * alpha_val start = end # Step 2: Apply SwiGLU on interleaved GEMM result @@ -1020,24 +1145,7 @@ def generate_tensors(): token_id_mapping, num_non_exiting_tiles, alpha, - a_torch_cpu, - b_torch_cpu, - c_torch_cpu, - sfa_torch_cpu, - sfb_torch_cpu, - sfc_torch_cpu, - norm_const_torch_cpu, - alpha_torch_cpu, - a_torch_gpu, - b_torch_gpu, - c_torch_gpu, - sfa_torch_gpu, - sfb_torch_gpu, - sfc_torch_gpu, - norm_const_torch_gpu, - aligned_group_m_list, - valid_m, - 
token_id_mapping_cpu, + *_, ) = create_tensors( num_groups, group_m_list, @@ -1052,40 +1160,67 @@ def generate_tensors(): sf_vec_size, mma_tiler_mn[0], # mma_tiler_m, also used for alignment permuted_m, + b_tensor_l_sizes=b_tensor_l_sizes, + groups_per_b_tensor=groups_per_b_tensor, ) - return cute.testing.JitArguments( - a_tensor, - b_tensor, - c_tensor, - sfa_tensor, - sfb_tensor, - sfc_tensor, - norm_const_tensor, - tile_idx_to_expert_idx, - tile_idx_to_mn_limit, - token_id_mapping, - num_non_exiting_tiles, - alpha, - current_stream, - ) + if multi_b_mode: + return cute.testing.JitArguments( + a_tensor, + tuple(b_tensor), + c_tensor, + sfa_tensor, + tuple(sfb_tensor), + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + tuple(alpha), + current_stream, + ) + else: + return cute.testing.JitArguments( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + alpha, + current_stream, + ) workspace_count = 1 if use_cold_l2: # Calculate actual tensor_m used (with padding if permuted_m provided) tensor_m = permuted_m if permuted_m is not None else valid_m + if multi_b_mode: + b_bytes = sum(t.numel() * t.element_size() for t in b_torch_gpu) + sfb_bytes = sum(t.numel() * t.element_size() for t in sfb_torch_gpu) + alpha_bytes = sum(t.numel() * t.element_size() for t in alpha_torch_cpu) + else: + b_bytes = b_torch_gpu.numel() * b_torch_gpu.element_size() + sfb_bytes = sfb_torch_gpu.numel() * sfb_torch_gpu.element_size() + alpha_bytes = alpha_torch_cpu.numel() * alpha_torch_cpu.element_size() one_workspace_bytes = ( a_torch_gpu.numel() * a_torch_gpu.element_size() - + b_torch_gpu.numel() * b_torch_gpu.element_size() + + b_bytes + c_torch_gpu.numel() * c_torch_gpu.element_size() + sfa_torch_gpu.numel() * sfa_torch_gpu.element_size() - + sfb_torch_gpu.numel() * 
sfb_torch_gpu.element_size() + + sfb_bytes + (tensor_m // mma_tiler_mn[0]) * 4 # tile_idx_to_expert_idx length (tiles) * sizeof(int32) + (tensor_m // mma_tiler_mn[0]) * 4 # tile_idx_to_mn_limit length (tiles) * sizeof(int32) + tensor_m * 4 # token_id_mapping_tensor length (elements) * sizeof(int32) + 1 * 4 # num_non_exiting_tiles (1 element) * sizeof(int32) - + alpha_torch_cpu.numel() * alpha_torch_cpu.element_size() + + alpha_bytes ) workspace_count = cute.testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations @@ -1245,6 +1380,13 @@ def read_benchmark_file( parser.add_argument( "--raster_along_m", action="store_true", default=False, help="Raster along M dimension" ) + parser.add_argument( + "--num_b_tensors", + type=int, + default=None, + help="Number of B tensors to split into (for multi-B tensor test). " + "If specified, enables multi-B tensor mode. Must be 2, 3, or 4.", + ) args = parser.parse_args() @@ -1279,6 +1421,16 @@ def read_benchmark_file( if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") + if args.num_b_tensors is not None: + if args.num_b_tensors < 2 or args.num_b_tensors > 4: + parser.error("--num_b_tensors must be 2, 3, or 4") + n, k, num_groups = nkl + if num_groups < args.num_b_tensors: + parser.error( + f"--num_b_tensors ({args.num_b_tensors}) cannot be greater than " + f"number of groups ({num_groups})" + ) + exec_time = run( nkl, group_m_list, @@ -1300,6 +1452,7 @@ def read_benchmark_file( args.permuted_m, args.use_cupti, args.raster_along_m, + args.num_b_tensors, ) print(f"Execution time: {exec_time:.2f} us") print("PASS") diff --git a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py index 848277c2afa7..aa8e16d40a12 100644 --- a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ 
b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -297,6 +297,7 @@ def create_tensors( mma_tiler_mn, permuted_m=None, seq_len=None, + b_tensor_l_sizes=None, ): """Create tensors for contiguous grouped GEMM. @@ -304,6 +305,7 @@ def create_tensors( A matrix, C matrix, and scale factor A will be padded to this size. The kernel exits when tile_idx >= num_non_exiting_tiles. :param seq_len: Sequence length (number of output tokens for C tensor) + :param b_tensor_l_sizes: Optional tuple of L sizes for multi-B tensor mode. :return: Tuple of (a_tensor, b_tensor, out_tensor, sfa_tensor, sfb_tensor, tile_idx_to_expert_idx, num_non_exiting_tiles, alpha, a_torch_cpu, b_torch_cpu, out_torch_cpu, sfa_torch_cpu, sfb_torch_cpu, @@ -331,6 +333,11 @@ def create_tensors( """ torch.manual_seed(1111) + multi_b_mode = b_tensor_l_sizes is not None + if multi_b_mode: + total_l = sum(b_tensor_l_sizes) + if total_l != l: + raise ValueError(f"Sum of b_tensor_l_sizes ({total_l}) must equal total L ({l}).") alpha_torch_cpu = torch.ones((l,), dtype=torch.float32) * 0.1 ( @@ -394,6 +401,50 @@ def create_tensors( out_torch_gpu.fill_(0) + if multi_b_mode: + b_torch_cpu_list = [] + b_tensor_list = [] + b_torch_gpu_list = [] + sfb_torch_cpu_list = [] + sfb_tensor_list = [] + sfb_torch_gpu_list = [] + alpha_torch_cpu_list = [] + alpha_tensor_list = [] + + for l_size in b_tensor_l_sizes: + alpha_cpu = torch.ones((l_size,), dtype=torch.float32) * 0.1 + alpha_torch_cpu_list.append(alpha_cpu) + alpha_tensor_list.append(from_dlpack(alpha_cpu.cuda()).mark_layout_dynamic()) + + b_cpu = cutlass_torch.matrix(l_size, n, k, b_major == "n", cutlass.Float32) + b_tensor_i, b_torch_gpu_i = cutlass_torch.cute_tensor_like( + b_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor_i.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN 
else 16, + ) + b_torch_cpu_list.append(b_cpu) + b_tensor_list.append(b_tensor_i) + b_torch_gpu_list.append(b_torch_gpu_i) + + sfb_cpu, sfb_tensor_i, sfb_torch_gpu_i = create_scale_factor_tensor( + l_size, n, k, sf_vec_size, sf_dtype + ) + sfb_torch_cpu_list.append(sfb_cpu) + sfb_tensor_list.append(sfb_tensor_i) + sfb_torch_gpu_list.append(sfb_torch_gpu_i) + + b_tensor = b_tensor_list + b_torch_cpu = b_torch_cpu_list + b_torch_gpu = b_torch_gpu_list + sfb_tensor = sfb_tensor_list + sfb_torch_cpu = sfb_torch_cpu_list + sfb_torch_gpu = sfb_torch_gpu_list + alpha = alpha_tensor_list + alpha_torch_cpu = alpha_torch_cpu_list + return ( a_tensor, b_tensor, @@ -436,6 +487,13 @@ def verify_reference_result( topK: int, seq_len: int, ) -> torch.Tensor: + if isinstance(b_torch_cpu, list): + b_torch_cpu = torch.cat(b_torch_cpu, dim=2) + if isinstance(sfb_torch_cpu, list): + sfb_torch_cpu = torch.cat(sfb_torch_cpu, dim=2) + if isinstance(alpha_torch_cpu, list): + alpha_torch_cpu = torch.cat(alpha_torch_cpu, dim=0) + gemm_output = torch.empty((1, valid_m, n), dtype=torch.float32) valid_mask = torch.zeros((valid_m,), dtype=torch.bool, device="cuda") ######### gemm calculation ######### @@ -501,6 +559,7 @@ def run( seq_len: int = 4096, raster_along_m: bool = False, use_cupti: bool = False, + b_tensor_l_sizes=None, **kwargs, ): """Prepare A/B/C tensors, launch GPU kernel, and reference checking. 
@@ -556,6 +615,10 @@ def run( # Unpack parameters n, k, l = nkl # noqa: E741 + multi_b_mode = b_tensor_l_sizes is not None + total_l = sum(b_tensor_l_sizes) if multi_b_mode else l + if multi_b_mode and total_l != l: + raise ValueError(f"Sum of b_tensor_l_sizes ({total_l}) must equal L ({l}).") if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") @@ -571,7 +634,7 @@ def run( m_aligned, n, k, - l, + total_l, a_major, b_major, out_major, @@ -620,6 +683,7 @@ def run( mma_tiler_mn, # cta_tile_m permuted_m, seq_len, # Pass seq_len as num_tokens for C tensor shape + b_tensor_l_sizes=b_tensor_l_sizes, ) # Calculate actual tensor_m used (with padding if permuted_m provided) @@ -645,6 +709,7 @@ def run( mma_tiler_mn, cluster_shape_mn, raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes if multi_b_mode else None, ) # Compute max active clusters on current device @@ -657,29 +722,27 @@ def run( current_stream = cutlass_torch.default_stream() # Compile gemm kernel - compiled_gemm = cute.compile( - gemm, - a_tensor, - b_tensor, - out_tensor, - sfa_tensor, - sfb_tensor, - tile_idx_to_expert_idx, - num_non_exiting_tiles, - tile_idx_to_mn_limit, - alpha, - max_active_clusters, - current_stream, - permuted_idx_to_expanded_idx, - token_final_scales, - options="--opt-level 2", - ) - - # Compute reference result - if not skip_ref_check: - print("Verifying results...") - # Execution - compiled_gemm( + if multi_b_mode: + compiled_gemm = cute.compile( + gemm, + a_tensor, + tuple(b_tensor), + out_tensor, + sfa_tensor, + tuple(sfb_tensor), + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + tuple(alpha), + max_active_clusters, + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + options="--opt-level 2", + ) + else: + compiled_gemm = cute.compile( + gemm, a_tensor, b_tensor, out_tensor, @@ -689,11 +752,48 @@ def run( num_non_exiting_tiles, tile_idx_to_mn_limit, alpha, + max_active_clusters, current_stream, 
permuted_idx_to_expanded_idx, token_final_scales, + options="--opt-level 2", ) + # Compute reference result + if not skip_ref_check: + print("Verifying results...") + # Execution + if multi_b_mode: + compiled_gemm( + a_tensor, + tuple(b_tensor), + out_tensor, + sfa_tensor, + tuple(sfb_tensor), + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + tuple(alpha), + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + ) + else: + compiled_gemm( + a_tensor, + b_tensor, + out_tensor, + sfa_tensor, + sfb_tensor, + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + alpha, + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + ) + torch.cuda.synchronize() ref_result = verify_reference_result( a_torch_cpu, @@ -788,6 +888,7 @@ def generate_tensors(): mma_tiler_mn, # cta_tile_m permuted_m, seq_len, # Pass seq_len as num_tokens for C tensor shape + b_tensor_l_sizes=b_tensor_l_sizes, ) ( @@ -804,6 +905,21 @@ def generate_tensors(): final_scale_dtype, ) + if multi_b_mode: + return cute.testing.JitArguments( + a_tensor, + tuple(b_tensor), + out_tensor, + sfa_tensor, + tuple(sfb_tensor), + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + tuple(alpha), + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + ) return cute.testing.JitArguments( a_tensor, b_tensor, @@ -821,16 +937,22 @@ def generate_tensors(): workspace_count = 1 if use_cold_l2: + + def _tensor_list_bytes(tensors): + if isinstance(tensors, list): + return sum(t.numel() * t.element_size() for t in tensors) + return tensors.numel() * tensors.element_size() + one_workspace_bytes = ( - a_torch_gpu.numel() * a_torch_gpu.element_size() - + b_torch_gpu.numel() * b_torch_gpu.element_size() - + out_torch_gpu.numel() * out_torch_gpu.element_size() - + sfa_torch_gpu.numel() * sfa_torch_gpu.element_size() - + sfb_torch_gpu.numel() * sfb_torch_gpu.element_size() + _tensor_list_bytes(a_torch_gpu) + + 
_tensor_list_bytes(b_torch_gpu) + + _tensor_list_bytes(out_torch_gpu) + + _tensor_list_bytes(sfa_torch_gpu) + + _tensor_list_bytes(sfb_torch_gpu) + (tensor_m // mma_tiler_mn[0]) * 4 # tile_idx_to_expert_idx length (tiles) * sizeof(int32) + 1 * 4 # num_non_exiting_tiles (1 element) * sizeof(int32) - + alpha_torch_cpu.numel() * alpha_torch_cpu.element_size() + + _tensor_list_bytes(alpha_torch_cpu) ) workspace_count = cute.testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations @@ -857,6 +979,14 @@ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: except ValueError: raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.") + def split_groups_to_b_tensors(num_groups: int, num_b_tensors: int) -> Tuple[int, ...]: + if num_b_tensors <= 0: + raise argparse.ArgumentTypeError("num_b_tensors must be positive.") + base = num_groups // num_b_tensors + remainder = num_groups % num_b_tensors + sizes = [base + (1 if i < remainder else 0) for i in range(num_b_tensors)] + return tuple(sizes) + def read_benchmark_file( filepath: str, ) -> Tuple[Tuple[int, int, int], Tuple[int, ...]]: @@ -1043,6 +1173,19 @@ def parse_benchmark_arg( help="Use CUPTI to measure execution time", ) + parser.add_argument( + "--num_b_tensors", + type=int, + default=1, + help="Number of B tensors for multi-B mode (default: 1).", + ) + parser.add_argument( + "--b_tensor_l_sizes", + type=parse_comma_separated_ints, + default=None, + help="Comma-separated L sizes for each B tensor (e.g., 8,8,16). 
Overrides --num_b_tensors.", + ) + args = parser.parse_args() # Process arguments to generate nkl and group_m_list @@ -1060,6 +1203,17 @@ def parse_benchmark_arg( if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") + _, _, l = nkl # noqa: E741 + b_tensor_l_sizes = None + if args.b_tensor_l_sizes is not None: + b_tensor_l_sizes = args.b_tensor_l_sizes + if args.num_b_tensors != 1 and args.num_b_tensors != len(b_tensor_l_sizes): + parser.error("--num_b_tensors must match length of --b_tensor_l_sizes") + if sum(b_tensor_l_sizes) != l: + parser.error("--b_tensor_l_sizes must sum to L") + elif args.num_b_tensors > 1: + b_tensor_l_sizes = split_groups_to_b_tensors(l, args.num_b_tensors) + exec_time = run( nkl, group_m_list, @@ -1083,6 +1237,7 @@ def parse_benchmark_arg( args.seq_len, args.raster_along_m, args.use_cupti, + b_tensor_l_sizes=b_tensor_l_sizes, ) print("exec_time: ", exec_time) print("PASS") diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index 91024f5e4b77..8dfa72bc9faa 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -548,10 +548,10 @@ def test_nvfp4_grouped_gemm_finalize_blackwell( c = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( a, - b, + [b], a_sf, - b_sf, - alpha, + [b_sf], + [alpha], tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, @@ -586,6 +586,35 @@ def test_nvfp4_grouped_gemm_finalize_blackwell( match_ratio = torch.isclose(c, c_ref, rtol=1.6e-2, atol=1e-5).sum().item() / c.numel() assert match_ratio > 0.99 + if num_local_experts > 1: + split_sizes = (num_local_experts // 2, num_local_experts - num_local_experts // 2) + b_list = list(torch.split(b, split_sizes, dim=0)) + b_sf_list = list(torch.split(b_sf, split_sizes, dim=0)) + alpha_list = list(torch.split(alpha, split_sizes, dim=0)) + 
c_multi = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( + a, + b_list, + a_sf, + b_sf_list, + alpha_list, + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, + num_non_exiting_tiles, + token_final_scales, + num_experts=num_experts, + top_k=top_k, + num_local_experts=num_local_experts, + local_expert_offset=0, + tile_size=tile_size, + output_dtype=torch.bfloat16, + scaling_vector_size=sf_vec_size, + ) + multi_match_ratio = ( + torch.isclose(c_multi, c_ref, rtol=1.6e-2, atol=1e-5).sum().item() / c_ref.numel() + ) + assert multi_match_ratio > 0.99 + @pytest.mark.skipif( get_sm_version() not in (100, 103), @@ -839,13 +868,13 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( global_sf = c_ref[:num_valid_permuted_tokens].abs().max().float() / (448 * 6) c_ref, c_sf_ref = torch.ops.trtllm.fp4_quantize(c_ref, 1 / global_sf, sf_vec_size, False) - # Call gather kernel - c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + # Call gather kernel (single-B via multi_b op with single-element lists) + c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( a, - b_interleaved, + [b_interleaved], a_sf_unswizzled, - b_sf_interleaved, - alpha, + [b_sf_interleaved], + [alpha], tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, @@ -891,3 +920,44 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( c_sf_valid = torch.cat(c_sf_valid) c_sf_ref_valid = torch.cat(c_sf_ref_valid) check_accuracy(c_sf_valid, c_sf_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95) + + if num_local_experts > 1: + split_sizes = ( + num_local_experts // 2, + num_local_experts - num_local_experts // 2, + ) + b_interleaved_list = list(torch.split(b_interleaved, split_sizes, dim=0)) + b_sf_interleaved_list = list(torch.split(b_sf_interleaved, split_sizes, dim=0)) + alpha_list = list(torch.split(alpha, split_sizes, dim=0)) + c_multi, c_sf_multi = ( + 
torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + a, + b_interleaved_list, + a_sf_unswizzled, + b_sf_interleaved_list, + alpha_list, + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, + num_non_exiting_tiles, + torch.tensor([1 / global_sf], dtype=torch.float32, device="cuda"), + num_experts=num_experts, + top_k=top_k, + num_local_experts=num_local_experts, + local_expert_offset=0, + tile_size=tile_size, + scaling_vector_size=sf_vec_size, + ) + ) + c_multi_valid = c_multi[:num_valid_permuted_tokens].view(torch.uint8)[valid_token_mask] + check_accuracy(c_multi_valid, c_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95) + + c_sf_multi_unswizzled = unswizzle_sf( + c_sf_multi, max_num_permuted_tokens, interm_size, sf_vec_size + ) + c_sf_multi_valid = [] + for i in range(num_valid_permuted_tokens): + if permuted_idx_to_expanded_idx[i].item() != helper.pad_val: + c_sf_multi_valid.append(c_sf_multi_unswizzled[i]) + c_sf_multi_valid = torch.cat(c_sf_multi_valid) + check_accuracy(c_sf_multi_valid, c_sf_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95) From f76d02671aae64b9f79b31c3c1885b5125c8702d Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 23 Mar 2026 00:21:21 -0700 Subject: [PATCH 02/12] Register DWDP accuracy test in CI test lists Signed-off-by: tianyuz-nv --- tests/integration/test_lists/qa/llm_function_core.txt | 1 + tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index ec19b2068832..df96a89fba34 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -378,6 +378,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with 
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_dwdp_accuracy accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index 97da746c27fd..73a041cbe421 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -91,6 +91,7 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_on-trtllm] - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_off-trtllm] - accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_dwdp_accuracy - condition: ranges: From f153099e0ec2665639791861fb36e13b7e76b0da Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 23 Mar 2026 02:57:56 -0700 Subject: [PATCH 03/12] Fix CI: remove forbidden from_dict, init helper, apply pre-commit formatting - Remove DwdpConfig.from_dict() to comply with Pydantic best practices test - Initialize GroupedGemmInputsHelper in test_nvfp4_gather_grouped_gemm_swiglu_blackwell - Apply pre-commit formatting (isort, yapf, ruff, autoflake, trailing whitespace) Signed-off-by: Tianyu Zhang Signed-off-by: tianyuz-nv Made-with: Cursor --- 
.../disaggregated/slurm/benchmark/submit.py | 16 +- .../_torch/custom_ops/cute_dsl_custom_ops.py | 89 +++++--- .../modules/fused_moe/configurable_moe.py | 6 +- .../modules/fused_moe/fused_moe_cute_dsl.py | 35 ++- .../_torch/modules/fused_moe/interface.py | 5 +- tensorrt_llm/_torch/pyexecutor/_util.py | 3 +- tensorrt_llm/_torch/pyexecutor/dwdp.py | 203 ++++++++++-------- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- .../_torch/pyexecutor/py_executor_creator.py | 2 +- tensorrt_llm/llmapi/llm_args.py | 17 +- .../accuracy/test_disaggregated_serving.py | 18 +- .../_torch/thop/parallel/test_cute_dsl_moe.py | 1 + 12 files changed, 233 insertions(+), 164 deletions(-) diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py index 1de0658fae4e..48577f4a9351 100644 --- a/examples/disaggregated/slurm/benchmark/submit.py +++ b/examples/disaggregated/slurm/benchmark/submit.py @@ -53,6 +53,7 @@ def generate_mpi_worker_config(worker_config, allocations, env_config, disagg_hostname, disagg_port, output_path): """Generate a config YAML compatible with ``trtllm-serve disaggregated_mpi_worker``. """ + def _build_urls(server_type): urls = [] for server_id in sorted(allocations.get(server_type, {}).keys()): @@ -229,7 +230,6 @@ def replace_env_in_file(log_dir, file_path, env_var): return tmp_dir - def build_worker_environment(worker_config, env_config, role, benchmark_mode, nsys_on, profile_range, concurrency): """Build complete environment dictionary for worker processes. 
@@ -362,7 +362,6 @@ def format_export_string(env_dict): return ",".join(export_list) - def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var, gen_worker_env_var): @@ -451,9 +450,10 @@ def submit_job(config, log_dir, dry_run): # Detect DWDP mode: when enabled, use a single srun with # trtllm-serve disaggregated_mpi_worker instead of per-instance sruns - dwdp_enabled = worker_config.get('ctx', {}).get( - 'dwdp_config', {}).get('enabled', False) - dwdp_size = worker_config.get('ctx', {}).get('dwdp_config', {}).get('dwdp_size', 1) + dwdp_enabled = worker_config.get('ctx', {}).get('dwdp_config', + {}).get('enabled', False) + dwdp_size = worker_config.get('ctx', {}).get('dwdp_config', + {}).get('dwdp_size', 1) # Generate log directory path based on configuration isl = benchmark_config['input_length'] @@ -560,7 +560,8 @@ def submit_job(config, log_dir, dry_run): if dwdp_enabled: # --- DWDP mode: single srun with disaggregated_mpi_worker --- - mpi_config_base_path = os.path.join(log_dir, 'mpi_worker_config_base.yaml') + mpi_config_base_path = os.path.join(log_dir, + 'mpi_worker_config_base.yaml') mpi_config_path = os.path.join(log_dir, 'mpi_worker_config.yaml') generate_mpi_worker_config(worker_config, allocations, env_config, disagg_server_hostname, disagg_server_port, @@ -627,7 +628,8 @@ def submit_job(config, log_dir, dry_run): benchmark_mode=benchmark_config['mode'], nsys_on=profiling_config['nsys_on'], profile_range=server_cfg['profile_range'], - concurrency=benchmark_config['concurrency_list'].split(',')[0], + concurrency=benchmark_config['concurrency_list'].split(',') + [0], ) export_str = format_export_string(worker_env) diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index c7d300e68ce2..719ce73398ab 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -1272,7 +1272,7 @@ def forward(self, 
inputs: List[torch.Tensor], assert alpha0.dim() == 1 m, k = a.size(0), a.size(1) * 2 - l = sum(bi.size(0) for bi in b_list) + sum(bi.size(0) for bi in b_list) n = b0.size(1) scale_k = k // self.scaling_vector_size assert m % self.tile_size == 0 @@ -1337,17 +1337,14 @@ def forward(self, inputs: List[torch.Tensor], make_ptr(cutlass.Float4E2M1FN, bi.data_ptr(), cute.AddressSpace.gmem, - assumed_align=32) - for bi in b_list) + assumed_align=32) for bi in b_list) b_sf_ptr = tuple( make_ptr(cutlass.Float8E4M3FN, bsfi.data_ptr(), cute.AddressSpace.gmem, - assumed_align=16) - for bsfi in b_sf_list) + assumed_align=16) for bsfi in b_sf_list) alpha_ptr = tuple( - make_ptr(cutlass.Float32, ai.data_ptr(), - cute.AddressSpace.gmem) + make_ptr(cutlass.Float32, ai.data_ptr(), cute.AddressSpace.gmem) for ai in alpha_list) torch_stream = torch.cuda.current_stream() @@ -1393,7 +1390,8 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - num_tokens, self.top_k, + num_tokens, + self.top_k, ] compiled_gemm = cute.compile( @@ -1423,7 +1421,8 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - num_tokens, self.top_k, + num_tokens, + self.top_k, ] compiled_gemm(*exec_args, stream=stream) return c @@ -2049,7 +2048,7 @@ def forward(self, inputs: List, orig_m, k = a.size(0), a.size(1) * 2 m = permuted_idx_to_expanded_idx.size(0) n = b0.size(1) - l = sum(bi.size(0) for bi in b_list) + sum(bi.size(0) for bi in b_list) scale_k = k // self.scaling_vector_size interm_size = n // 2 @@ -2114,17 +2113,14 @@ def forward(self, inputs: List, make_ptr(cutlass.Float4E2M1FN, bi.data_ptr(), cute.AddressSpace.gmem, - assumed_align=32) - for bi in b_list) + assumed_align=32) for bi in b_list) b_sf_ptr = tuple( make_ptr(cutlass.Float8E4M3FN, bsfi.data_ptr(), cute.AddressSpace.gmem, - assumed_align=16) - for bsfi in b_sf_list) + assumed_align=16) for bsfi in b_sf_list) alpha_ptr = tuple( - make_ptr(cutlass.Float32, ai.data_ptr(), - cute.AddressSpace.gmem) + make_ptr(cutlass.Float32, 
ai.data_ptr(), cute.AddressSpace.gmem) for ai in alpha_list) torch_stream = torch.cuda.current_stream() @@ -2158,10 +2154,22 @@ def forward(self, inputs: List, cluster_shape_mn[0] * cluster_shape_mn[1]) compile_args = [ - a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, - alpha_ptr, tile_idx_to_group_idx_ptr, - tile_idx_to_mn_limit_ptr, permuted_idx_to_expanded_idx_ptr, - num_non_exiting_tiles_ptr, global_sf_ptr, orig_m, m, n, k, + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_ptr, + c_sf_ptr, + alpha_ptr, + tile_idx_to_group_idx_ptr, + tile_idx_to_mn_limit_ptr, + permuted_idx_to_expanded_idx_ptr, + num_non_exiting_tiles_ptr, + global_sf_ptr, + orig_m, + m, + n, + k, ] compiled_gemm = cute.compile( @@ -2177,10 +2185,22 @@ def forward(self, inputs: List, compiled_gemm = self.__class__.kernel_cache[cache_key] exec_args = [ - a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, alpha_ptr, - tile_idx_to_group_idx_ptr, tile_idx_to_mn_limit_ptr, - permuted_idx_to_expanded_idx_ptr, num_non_exiting_tiles_ptr, - global_sf_ptr, orig_m, m, n, k, + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_ptr, + c_sf_ptr, + alpha_ptr, + tile_idx_to_group_idx_ptr, + tile_idx_to_mn_limit_ptr, + permuted_idx_to_expanded_idx_ptr, + num_non_exiting_tiles_ptr, + global_sf_ptr, + orig_m, + m, + n, + k, ] compiled_gemm(*exec_args, stream=stream) @@ -2300,11 +2320,22 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b. 
""" return torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( - input, [weight], input_scale, [weight_scale], [alpha], - tile_idx_to_group_idx, tile_idx_to_mn_limit, - permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf, - num_experts, top_k, num_local_experts, local_expert_offset, - tile_size, scaling_vector_size, + input, + [weight], + input_scale, + [weight_scale], + [alpha], + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, + num_non_exiting_tiles, + global_sf, + num_experts, + top_k, + num_local_experts, + local_expert_offset, + tile_size, + scaling_vector_size, ) @torch.library.register_fake( diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index f89a0bf60dc1..092519860741 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -35,8 +35,8 @@ from tensorrt_llm._torch.expert_statistic import ExpertStatistic from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe.interface import MoE -from tensorrt_llm._torch.pyexecutor.dwdp import get_global_dwdp_manager from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod +from tensorrt_llm._torch.pyexecutor.dwdp import get_global_dwdp_manager from tensorrt_llm._torch.utils import AuxStreamType, EventType, Fp4QuantizedTensor from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig @@ -305,7 +305,9 @@ def _should_enable_dwdp(self) -> bool: return False quant_mode = getattr(quant_config, "layer_quant_mode", None) - return bool(quant_mode is not None and hasattr(quant_mode, "has_nvfp4") and quant_mode.has_nvfp4()) + return bool( + quant_mode is not None and hasattr(quant_mode, "has_nvfp4") and quant_mode.has_nvfp4() + ) def _create_comm_strategy(self, model_config: ModelConfig) -> 
Optional[Communication]: """ diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index a1959850593f..36dcbc87fff7 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -458,7 +458,6 @@ def __init__( if key not in self.event_dict: self.event_dict[key] = torch.cuda.Event() - def _build_local_weight_view(self) -> NvFp4WeightView: """Build weight view for non-DWDP path (single-element lists).""" return NvFp4WeightView( @@ -579,7 +578,11 @@ def run_moe_nvfp4( ) inputs = [ - x, token_selected_experts, token_final_scales, x_sf, moe_output, + x, + token_selected_experts, + token_final_scales, + x_sf, + moe_output, weight_view, ] _, best_tactic = tuner.choose_one( @@ -662,7 +665,9 @@ def run_moe_nvfp4_impl( input=x.view(torch.float4_e2m1fn_x2), weight=[weight_view.w2_weight[0].view(torch.float4_e2m1fn_x2)], input_scale=x_sf.view(torch.uint8), - weight_scale=[weight_view.fc2_weight_scale[0].view(torch.uint8)], + weight_scale=[ + weight_view.fc2_weight_scale[0].view(torch.uint8) + ], alpha=[weight_view.fc2_global_scale[0]], output=moe_output, tile_idx_to_group_idx=tile_idx_to_expert_idx, @@ -719,8 +724,7 @@ def run_moe_nvfp4_impl_dwdp( """ assert self.use_fused_finalize, ( "DWDP requires fused finalize (cute_dsl_nvfp4_grouped_gemm_blackwell " - "does not support multiple B weight tensors)" - ) + "does not support multiple B weight tensors)") output_dtype = torch.bfloat16 effective_top_k = token_selected_experts.size(1) esp = weight_view.expert_size_per_partition @@ -742,9 +746,13 @@ def run_moe_nvfp4_impl_dwdp( x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( input=x.view(torch.float4_e2m1fn_x2), - weight=[w.view(torch.float4_e2m1fn_x2) for w in weight_view.w3_w1_weight], + weight=[ + w.view(torch.float4_e2m1fn_x2) for w in weight_view.w3_w1_weight + ], 
input_scale=x_sf.view(torch.uint8), - weight_scale=[ws.view(torch.uint8) for ws in weight_view.fc1_weight_scale], + weight_scale=[ + ws.view(torch.uint8) for ws in weight_view.fc1_weight_scale + ], alpha=weight_view.fc1_global_scale, tile_idx_to_group_idx=tile_idx_to_expert_idx, tile_idx_to_mn_limit=tile_idx_to_mn_limit, @@ -777,9 +785,13 @@ def run_moe_nvfp4_impl_dwdp( torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( input=x.view(torch.float4_e2m1fn_x2), - weight=[w.view(torch.float4_e2m1fn_x2) for w in weight_view.w2_weight], + weight=[ + w.view(torch.float4_e2m1fn_x2) for w in weight_view.w2_weight + ], input_scale=x_sf.view(torch.uint8), - weight_scale=[ws.view(torch.uint8) for ws in weight_view.fc2_weight_scale], + weight_scale=[ + ws.view(torch.uint8) for ws in weight_view.fc2_weight_scale + ], alpha=weight_view.fc2_global_scale, output=moe_output, tile_idx_to_group_idx=tile_idx_to_expert_idx, @@ -906,7 +918,8 @@ def run_moe( """ # Execute MoE computation if self.has_nvfp4: - weight_view = kwargs.get("dwdp_weight_view") or self._build_local_weight_view() + weight_view = kwargs.get( + "dwdp_weight_view") or self._build_local_weight_view() result = self.run_moe_nvfp4( x=x, token_selected_experts=token_selected_experts, @@ -970,4 +983,4 @@ def load_weights(self, weights: Dict[str, torch.Tensor]): super().load_weights(weights) dwdp_handle_collector = getattr(self, "dwdp_handle_collector", None) if dwdp_handle_collector is not None: - dwdp_handle_collector.register_weights(self) \ No newline at end of file + dwdp_handle_collector.register_weights(self) diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index d1ed9288d857..fcc3f3e3bb27 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -47,11 +47,11 @@ def _warn_and_return(reason: str) -> Tuple[bool, Optional[str]]: from ...model_config import ModelConfig 
+from ...pyexecutor.dwdp import get_global_dwdp_manager from ...utils import (ActivationType, AuxStreamType, Fp4QuantizedTensor, get_model_extra_attrs, is_gated_activation, is_torch_compiling) from .routing import BaseMoeRoutingMethod -from ...pyexecutor.dwdp import get_global_dwdp_manager class MoEWeightLoadingMode(Enum): @@ -317,8 +317,7 @@ def _init_dwdp_expert_layout(self): return assert self.layer_load_balancer is None, ( "DWDP and EPLB (MoE load balancer) cannot be used together. " - "Disable one of dwdp_config.enabled or moe_load_balancer." - ) + "Disable one of dwdp_config.enabled or moe_load_balancer.") self.num_slots = self.num_experts self.expert_size_per_partition = dwdp_manager.experts_per_worker dwdp_size = dwdp_manager.dwdp_size diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 1504096af076..94e0fb9debf1 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -31,6 +31,7 @@ get_spec_decoder, should_use_separate_draft_kv_cache) from .config_utils import (get_qwen3_hybrid_layer_masks, is_mla, is_nemotron_hybrid, is_qwen3_hybrid) +from .dwdp import DwdpManager from .guided_decoder import GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver @@ -47,7 +48,7 @@ KVCacheV2Scheduler, SimpleScheduler, SimpleUnifiedScheduler) from .seq_slot_manager import SeqSlotManager -from .dwdp import DwdpManager + GB = 1 << 30 diff --git a/tensorrt_llm/_torch/pyexecutor/dwdp.py b/tensorrt_llm/_torch/pyexecutor/dwdp.py index 2b2aa1910c22..892ee285444f 100644 --- a/tensorrt_llm/_torch/pyexecutor/dwdp.py +++ b/tensorrt_llm/_torch/pyexecutor/dwdp.py @@ -1,29 +1,27 @@ +from typing import Dict, List, Optional, Tuple + import torch import torch.nn as nn - -from tensorrt_llm.llmapi.llm_args import DwdpConfig -from typing import List, Optional, Dict, Tuple -from 
tensorrt_llm._torch.distributed import MPIDist -from tensorrt_llm._utils import global_mpi_rank -from mpi4py.MPI import COMM_WORLD - -from cuda.bindings import runtime as cudart from cuda.bindings import driver as cuda_driver -from tensorrt_llm._utils import nvtx_range - +from cuda.bindings import runtime as cudart +from mpi4py.MPI import COMM_WORLD +from tensorrt_llm._torch.distributed import MPIDist +from tensorrt_llm._utils import global_mpi_rank, nvtx_range +from tensorrt_llm.llmapi.llm_args import DwdpConfig # Parameter names to collect handles for -WEIGHT_PARAMS = ['w3_w1_weight', 'w2_weight'] -BIAS_PARAMS = ['w3_w1_bias', 'w2_bias'] +WEIGHT_PARAMS = ["w3_w1_weight", "w2_weight"] +BIAS_PARAMS = ["w3_w1_bias", "w2_bias"] # Quant scale params vary by quantization method QUANT_SCALE_PARAMS = [ - 'w3_w1_weight_scale', 'w2_weight_scale', # NVFP4/MXFP4 - 'fc31_alpha', 'fc2_alpha', # NVFP4 alpha + "w3_w1_weight_scale", + "w2_weight_scale", # NVFP4/MXFP4 + "fc31_alpha", + "fc2_alpha", # NVFP4 alpha ] - _global_dwdp_manager: Optional["DwdpManager"] = None @@ -46,12 +44,11 @@ class DwdpLayerHandleCollector: """ Dwdp Layer Handle Collector for IPC handle coordination and prefetch buffer management. """ - + def __init__( self, layer_idx: int, ): - self.layer_idx = layer_idx # Local IPC handles: param_name -> handle_bytes @@ -71,9 +68,9 @@ def __init__( def register_weights(self, module: nn.Module): """ Register weights from a MoE module and create IPC handles. - + Called after module.load_weights() completes. 
- + Args: module: The MoE module with loaded weights """ @@ -83,7 +80,7 @@ def register_weights(self, module: nn.Module): if hasattr(module, param_name) and getattr(module, param_name, None) is not None: params_to_register.append(param_name) # Bias (optional) - if hasattr(module, 'bias'): + if hasattr(module, "bias"): params_to_register.extend(BIAS_PARAMS) # Quant scales (optional, depends on quant method) for param_name in QUANT_SCALE_PARAMS: @@ -106,18 +103,18 @@ def _register_param(self, param_name: str, param: torch.Tensor): tensor_ptr = param.data_ptr() err, handle = cudart.cudaIpcGetMemHandle(tensor_ptr) check_cuda_error(err, f"get handle for {param_name}") - + # Get allocation base address using Driver API cuMemGetAddressRange # This returns the actual base address and size of the CUDA allocation # cudaPointerGetAttributes.devicePointer returns the input pointer, not base! err, alloc_base, alloc_size = cuda_driver.cuMemGetAddressRange(tensor_ptr) if err != cuda_driver.CUresult.CUDA_SUCCESS: raise RuntimeError(f"cuMemGetAddressRange failed for {param_name}: {err}") - + # Calculate offset from allocation base # Convert CUdeviceptr to int for arithmetic offset = tensor_ptr - int(alloc_base) - + self.local_ipc_handles[param_name] = bytes(handle.reserved) self.local_ptrs[param_name] = tensor_ptr self.local_offsets[param_name] = offset @@ -127,7 +124,7 @@ def _register_param(self, param_name: str, param: torch.Tensor): def get_peer_ptr(self, peer_rank: int, param_name: str) -> int: """Get pointer to parameter on peer rank.""" return self.peer_ptrs[(peer_rank, param_name)] - + def cleanup(self): """Clean up peer handles.""" for _, ptr in self.peer_ptrs.items(): @@ -138,24 +135,25 @@ def cleanup(self): class DwdpPrefetchBuffer: """ Ping-pong buffer for expert weight prefetching. - + Buffer Selection Strategy: - Even layers (0, 2, 4, ...) use buffer[0] - Odd layers (1, 3, 5, ...) 
use buffer[1] - This ensures layer N-1's prefetch doesn't overwrite layer N's data - + Synchronization Strategy: - prefetch_events[buffer_idx][layer_idx]: Recorded when prefetch completes Waited by forward() before using prefetched data - compute_events[buffer_idx][layer_idx]: Recorded when forward() completes Waited by next prefetch before overwriting buffer - + Buffer Layout (organized by rank): - buffers[buffer_idx][param_name] = List[Optional[Tensor]] - len(list) == dwdp_size - list[peer_rank] = Tensor[num_prefetch_experts, ...] for peer_rank != dwdp_rank - list[dwdp_rank] = None (local weight used directly, not prefetched) """ + def __init__( self, dwdp_size: int, @@ -167,7 +165,6 @@ def __init__( param_shapes: Dict[str, torch.Size], param_dtypes: Dict[str, torch.dtype], ): - self.dwdp_size = dwdp_size self.num_prefetch_experts = num_prefetch_experts self.experts_per_worker = experts_per_worker @@ -178,9 +175,9 @@ def __init__( self.param_shapes = param_shapes self.param_dtypes = param_dtypes - + self.device = torch.cuda.current_device() - + # buffers[buffer_idx][param_name] = List[Optional[Tensor]] # list[peer_rank] contains prefetched weights from that rank # list[dwdp_rank] = None (local weights used directly) @@ -204,14 +201,14 @@ def __init__( ) buffer[param_name] = tensor_list self.buffers.append(buffer) - + self.max_layer_idx = num_layers + first_moe_layer_idx self.prefetch_events: List[List[torch.cuda.Event]] = [ - [torch.cuda.Event() for _ in range(self.max_layer_idx//self.num_buffers + 1)] + [torch.cuda.Event() for _ in range(self.max_layer_idx // self.num_buffers + 1)] for _ in range(self.num_buffers) ] self.compute_events: List[List[torch.cuda.Event]] = [ - [torch.cuda.Event() for _ in range(self.max_layer_idx//self.num_buffers + 1)] + [torch.cuda.Event() for _ in range(self.max_layer_idx // self.num_buffers + 1)] for _ in range(self.num_buffers) ] self.prefetch_stream = torch.cuda.Stream(device=self.device) @@ -219,37 +216,44 @@ def __init__( def 
initialize_compute_events(self): for buffer_idx in range(self.num_buffers): self.compute_events[buffer_idx][0].record(torch.cuda.current_stream()) - + def record_prefetch_event(self, layer_idx: int): - self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record(self.prefetch_stream) + self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record( + self.prefetch_stream + ) def record_compute_event(self, layer_idx: int): - self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record(torch.cuda.current_stream()) + self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record( + torch.cuda.current_stream() + ) def wait_prefetch_event(self, layer_idx: int): - torch.cuda.current_stream().wait_event(self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers]) + torch.cuda.current_stream().wait_event( + self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers] + ) def wait_compute_event(self, layer_idx: int): - self.prefetch_stream.wait_event(self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers]) + self.prefetch_stream.wait_event( + self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers] + ) class DwdpManager: """ Dwdp Manager for IPC handle coordination and prefetch buffer management. - + This manager: - Tracks IPC handles for all MoE layers across Context workers - Manages double-buffered prefetch buffers for remote expert weights - Provides expert tensor routing (local vs. 
prefetched) - + """ - + def __init__( self, config: DwdpConfig, dist: Optional[object] = None, ): - self.config = config self.dist = dist self.dwdp_size = config.dwdp_size @@ -257,10 +261,10 @@ def __init__( self.num_group = config.num_group self._init_dwdp_group() - + # Per-layer IPC handle collectors (indexed by layer_idx) self.ipc_collectors: List[DwdpLayerHandleCollector] = [] - + # Prefetch buffer (initialized later in create_py_executor) self.prefetch_buffer: Optional[DwdpPrefetchBuffer] = None # Auto-detected from first add_layer() call @@ -277,12 +281,11 @@ def __init__( set_global_dwdp_manager(self) def _init_dwdp_group(self): - if not isinstance(self.dist, MPIDist): raise RuntimeError("DWDP requires MPI backend (MPIDist)") self.rank = global_mpi_rank() - + # Calculate which group this rank belongs to # With num_group=2, dwdp_size=4: # Group 0: ranks [0, 1, 2, 3] @@ -290,75 +293,77 @@ def _init_dwdp_group(self): self.group_id = self.rank // self.dwdp_size group_start_rank = self.group_id * self.dwdp_size ranks = list(range(group_start_rank, group_start_rank + self.dwdp_size)) - + new_group = COMM_WORLD.group.Incl(ranks) self.dwdp_group = COMM_WORLD.Create_group(new_group) def is_enabled(self) -> bool: return self.config.enabled and self.dwdp_size > 1 - + def add_layer( self, layer_idx: int, ) -> "DwdpLayerHandleCollector": """ Add a new layer IPC handle collector. - + Called from CuteDslFusedMoE.__init__() during model construction. """ if self.first_moe_layer_idx is None: self.first_moe_layer_idx = layer_idx - collector = DwdpLayerHandleCollector( - layer_idx=layer_idx - ) + collector = DwdpLayerHandleCollector(layer_idx=layer_idx) self.ipc_collectors.append(collector) return collector - + def exchange_all_handles(self): """ Exchange IPC handles with peer Context workers via Dwdp Group AllGather. - + Called after all weights are loaded, before creating prefetch buffer. 
""" - + # Collect all local handles with explicit worker info local_data = { - 'dwdp_rank': self.dwdp_rank, - 'expert_start_id': self.start_expert_id, - 'expert_end_id': self.end_expert_id, - 'ipc_collectors': [], + "dwdp_rank": self.dwdp_rank, + "expert_start_id": self.start_expert_id, + "expert_end_id": self.end_expert_id, + "ipc_collectors": [], } for collector in self.ipc_collectors: - local_data['ipc_collectors'].append({ - 'layer_idx': collector.layer_idx, - 'handles': collector.local_ipc_handles, - 'offsets': collector.local_offsets, - }) - + local_data["ipc_collectors"].append( + { + "layer_idx": collector.layer_idx, + "handles": collector.local_ipc_handles, + "offsets": collector.local_offsets, + } + ) + # AllGather from all Context workers in DWDP group all_data = self.dwdp_group.allgather(local_data) - + # Open handles from peer workers for peer_data in all_data: - peer_rank = peer_data['dwdp_rank'] - self.peer_expert_ranges[peer_rank] = (peer_data['expert_start_id'], peer_data['expert_end_id']) + peer_rank = peer_data["dwdp_rank"] + self.peer_expert_ranges[peer_rank] = ( + peer_data["expert_start_id"], + peer_data["expert_end_id"], + ) if peer_rank == self.dwdp_rank: continue - for layer_idx, ipc_collector in enumerate(peer_data['ipc_collectors']): + for layer_idx, ipc_collector in enumerate(peer_data["ipc_collectors"]): collector = self.ipc_collectors[layer_idx] - peer_offsets = ipc_collector['offsets'] - for param_name, handle_bytes in ipc_collector['handles'].items(): + peer_offsets = ipc_collector["offsets"] + for param_name, handle_bytes in ipc_collector["handles"].items(): # Reconstruct and open handle handle = cudart.cudaIpcMemHandle_t() handle.reserved = list(handle_bytes) - + err, base_ptr = cudart.cudaIpcOpenMemHandle( - handle, - cudart.cudaIpcMemLazyEnablePeerAccess + handle, cudart.cudaIpcMemLazyEnablePeerAccess ) check_cuda_error(err, f"open handle rank={peer_rank}") - + # Apply offset to get actual tensor pointer # IPC handle points to 
allocation base, offset gives us the tensor location offset = peer_offsets[param_name] @@ -368,7 +373,7 @@ def exchange_all_handles(self): def initialize_prefetch_buffer(self): """ Initialize the prefetch buffer. - + Called in create_py_executor() after model loading. """ self.prefetch_buffer = DwdpPrefetchBuffer( @@ -382,7 +387,7 @@ def initialize_prefetch_buffer(self): param_dtypes=self.ipc_collectors[0].param_dtypes, ) self.prefetch_buffer.initialize_compute_events() - + def prefetch_first_layers(self): """Prefetch the first num_buffers layers as warmup.""" if self.prefetch_buffer is None: @@ -410,8 +415,12 @@ def build_weight_view(self, layer_idx: int, backend): buffer_data = self.wait_prefetch_and_get_buffer(layer_idx) required_keys = ( - "w3_w1_weight", "w3_w1_weight_scale", "fc31_alpha", - "w2_weight", "w2_weight_scale", "fc2_alpha", + "w3_w1_weight", + "w3_w1_weight_scale", + "fc31_alpha", + "w2_weight", + "w2_weight_scale", + "fc2_alpha", ) missing_keys = [key for key in required_keys if key not in buffer_data] if missing_keys: @@ -444,9 +453,11 @@ def build_weight_view(self, layer_idx: int, backend): slot_start=0, ) - def wait_prefetch_and_get_buffer(self, layer_idx: int) -> Optional[Dict[str, List[Optional[torch.Tensor]]]]: + def wait_prefetch_and_get_buffer( + self, layer_idx: int + ) -> Optional[Dict[str, List[Optional[torch.Tensor]]]]: """Wait for prefetch to complete and return the buffer for this layer. - + Returns: Dict mapping param_name to List[Optional[Tensor]] where: - list[peer_rank] = Tensor for prefetched weights from that peer @@ -475,10 +486,10 @@ def record_compute_and_prefetch_next(self, layer_idx: int): def _get_prefetch_src_offset_from_peer(self, peer_rank: int) -> int: """ Calculate the source offset (in number of experts) to fetch from a peer. 
- + Returns: src_offset: Offset into peer's local expert tensor to start copying from - + Example: 256 experts, rank0: [0, 200), rank1: [56, 256) - rank0 needs [200, 256) from rank1: src_offset = 200 - 56 = 144 (fetch last 56 experts from rank1) @@ -486,7 +497,7 @@ def _get_prefetch_src_offset_from_peer(self, peer_rank: int) -> int: src_offset = 0 - 0 = 0 (fetch first 56 experts from rank0) """ peer_start, peer_end = self.peer_expert_ranges[peer_rank] - + # What I need = global - what I have # From peer = what I need ∩ what peer has if self.dwdp_rank < peer_rank: @@ -496,7 +507,7 @@ def _get_prefetch_src_offset_from_peer(self, peer_rank: int) -> int: else: # I'm later rank, need experts before my start (head of peer's experts) prefetch_start = peer_start - + src_offset = prefetch_start - peer_start return src_offset @@ -504,12 +515,12 @@ def _get_prefetch_src_offset_from_peer(self, peer_rank: int) -> int: def prefetch_layer(self, layer_idx: int, wait_compute_layer_idx: Optional[int] = None): """ Prefetch layer data from peer ranks. - + Args: layer_idx: The layer to prefetch wait_compute_layer_idx: If provided, wait for this layer's compute to complete before overwriting buffer (used when prefetching next layer) - + Note: Local weights are used directly by the kernel, no copy needed. Peer copy runs on prefetch stream. 
""" @@ -524,33 +535,35 @@ def prefetch_layer(self, layer_idx: int, wait_compute_layer_idx: Optional[int] = # Wait for compute to complete before overwriting buffer if wait_compute_layer_idx is not None: self.prefetch_buffer.wait_compute_event(wait_compute_layer_idx) - + for peer_rank in range(self.dwdp_size): if peer_rank == self.dwdp_rank: continue # Skip local rank - local weights used directly - + src_expert_offset = self._get_prefetch_src_offset_from_peer(peer_rank) - + for param_name in param_names: param_shape = collector.param_shapes[param_name] param_dtype = collector.param_dtypes[param_name] expert_size = param_shape.numel() * param_dtype.itemsize - + # src_ptr points to peer's tensor start, add offset for specific experts base_ptr = collector.get_peer_ptr(peer_rank, param_name) src_ptr = base_ptr + src_expert_offset * expert_size - + # dst_tensor is directly indexed by peer_rank in the list dst_tensor = self.prefetch_buffer.buffers[buffer_idx][param_name][peer_rank] dst_ptr = dst_tensor.data_ptr() - + data_size = self.num_prefetch_experts * expert_size - err, = cudart.cudaMemcpyAsync( + (err,) = cudart.cudaMemcpyAsync( dst_ptr, src_ptr, data_size, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, self.prefetch_buffer.prefetch_stream.cuda_stream, ) - check_cuda_error(err, f"prefetch layer {layer_idx} peer_rank {peer_rank} {param_name}") \ No newline at end of file + check_cuda_error( + err, f"prefetch layer {layer_idx} peer_rank {peer_rank} {param_name}" + ) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 8002ea8ab92d..a18108ef9175 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -47,6 +47,7 @@ from ..speculative.drafter import Drafter from ..speculative.spec_sampler_base import SampleStateTensorsSpec from ..speculative.speculation_gate import SpeculationGate +from .dwdp import DwdpManager from .executor_request_queue import 
ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder from .handle_additional_outputs import HandleAdditionalOutputs @@ -68,7 +69,6 @@ SerializableSchedulerOutput, WaitingQueue, create_waiting_queue) from .scheduler.adp_router import ADPRouter, DefaultADPRouter -from .dwdp import DwdpManager # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 8eafa62a2cf5..dfd0d9b5c44c 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -37,12 +37,12 @@ create_py_executor_instance, instantiate_sampler, is_mla, validate_feature_combination) from .config_utils import is_nemotron_hybrid, is_qwen3_hybrid +from .dwdp import DwdpManager from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .model_engine import PyTorchModelEngine from .model_loader import ModelLoader, _construct_checkpoint_loader from .py_executor import PyExecutor -from .dwdp import DwdpManager class _ExecutorMemoryMonitor: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 97c0e695a5cf..675056799011 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2378,14 +2378,17 @@ class DwdpConfig(StrictBaseModel): Configuration for DWDP. """ enabled: bool = Field(default=False, description="Whether to enable DWDP.") - dwdp_size: int = Field(default=1, description="The number of GPUs per DWDP group.") - num_group: int = Field(default=1, description="The number of DWDP groups. 
Total workers = num_group * dwdp_size.") - experts_per_worker: int = Field(default=0, description="The number of experts per worker.") - num_prefetch_experts: int = Field(default=0, description="The number of prefetch experts per worker.") + dwdp_size: int = Field(default=1, + description="The number of GPUs per DWDP group.") + num_group: int = Field( + default=1, + description= + "The number of DWDP groups. Total workers = num_group * dwdp_size.") + experts_per_worker: int = Field( + default=0, description="The number of experts per worker.") + num_prefetch_experts: int = Field( + default=0, description="The number of prefetch experts per worker.") - @classmethod - def from_dict(cls, data: dict): - return cls(**data) class BaseLlmArgs(StrictBaseModel): """ diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index c4f435f997e4..35825c154115 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -1237,9 +1237,11 @@ def test_dwdp_accuracy(self): }, } - with launch_dwdp_disaggregated_llm( - worker_config, frontend_config, model_path, - total_gpus=4, max_workers=128) as llm: + with launch_dwdp_disaggregated_llm(worker_config, + frontend_config, + model_path, + total_gpus=4, + max_workers=128) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) @@ -1948,8 +1950,11 @@ def launch_dwdp_disaggregated_llm( # Prevent the parent process's MPI state (set by mpi4py init during # tensorrt_llm import) from leaking into the mpirun subprocess. # mpirun must create a fresh MPI world for the DWDP workers. 
- child_env = {k: v for k, v in os.environ.items() - if not k.startswith(('OMPI_', 'PMIX_', 'PMI_'))} + child_env = { + k: v + for k, v in os.environ.items() + if not k.startswith(('OMPI_', 'PMIX_', 'PMI_')) + } mpi_cmd = [ "mpirun", "--allow-run-as-root", "-n", @@ -1979,8 +1984,7 @@ def launch_dwdp_disaggregated_llm( ]: if proc.poll() is not None: raise Exception( - f"{name} process exited with code {proc.returncode}" - ) + f"{name} process exited with code {proc.returncode}") try: response = requests.get( f"http://localhost:{serve_port}/cluster_info") diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index 8dfa72bc9faa..1432f2ec1fd1 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -763,6 +763,7 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( interm_size = 8192 num_experts = 256 num_local_experts = num_experts // ep_size + helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size) # Generate routing information routing_logits = torch.randn(num_tokens, num_experts, device="cuda") From 38d6e9298521ddaaab27715cc2fa13dfda8a81d0 Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 23 Mar 2026 21:52:05 -0700 Subject: [PATCH 04/12] Fix CI and add env var config for disaggregated benchmark - Add dwdp_config entry to llm.yaml API stability reference - Fix c_sf_multi_valid filtering to use tile boundary condition consistent with c_sf_ref_valid, resolving 72 MoE kernel test failures - Add worker_env_var extraction in disaggregated benchmark submit script Signed-off-by: tianyuz-nv --- examples/disaggregated/slurm/benchmark/submit.py | 3 +++ tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py | 5 +++-- tests/unittest/api_stability/references/llm.yaml | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/disaggregated/slurm/benchmark/submit.py 
b/examples/disaggregated/slurm/benchmark/submit.py index 48577f4a9351..fd2aa7da0513 100644 --- a/examples/disaggregated/slurm/benchmark/submit.py +++ b/examples/disaggregated/slurm/benchmark/submit.py @@ -583,6 +583,9 @@ def submit_job(config, log_dir, dry_run): total_mpi_tasks = ctx_num * ctx_world_size + gen_num * gen_world_size mpi_num_nodes = len(mpi_nodelist) num_ctx_gpus = ctx_num * ctx_world_size + worker_env_var = env_config.get('worker_env_var', '') + ctx_worker_env_var = env_config.get('ctx_worker_env_var', '') + gen_worker_env_var = env_config.get('gen_worker_env_var', '') dwdp_ctx_worker_env_var = worker_env_var + \ (f" {ctx_worker_env_var}" if ctx_worker_env_var else "") dwdp_gen_worker_env_var = worker_env_var + \ diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index 1432f2ec1fd1..a96fe088a3c3 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -958,7 +958,8 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( ) c_sf_multi_valid = [] for i in range(num_valid_permuted_tokens): - if permuted_idx_to_expanded_idx[i].item() != helper.pad_val: - c_sf_multi_valid.append(c_sf_multi_unswizzled[i]) + if i >= tile_idx_to_mn_limit_list[i // tile_size]: + continue + c_sf_multi_valid.append(c_sf_multi_unswizzled[i]) c_sf_multi_valid = torch.cat(c_sf_multi_valid) check_accuracy(c_sf_multi_valid, c_sf_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 4fe428b7a416..2b5c0e7d08e4 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -87,6 +87,10 @@ methods: annotation: Optional[tensorrt_llm.llmapi.llm_args.AttentionDpConfig] default: null status: beta + dwdp_config: + annotation: tensorrt_llm.llmapi.llm_args.DwdpConfig + 
default: null + status: beta checkpoint_loader: annotation: Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader] default: null From f52716d53661c482f0fa33e921a1f6f858dae801 Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 23 Mar 2026 22:33:57 -0700 Subject: [PATCH 05/12] Remove unused helper variable to fix ruff F841 Signed-off-by: tianyuz-nv --- tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index a96fe088a3c3..3f35261ddc61 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -763,7 +763,6 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( interm_size = 8192 num_experts = 256 num_local_experts = num_experts // ep_size - helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size) # Generate routing information routing_logits = torch.randn(num_tokens, num_experts, device="cuda") From b10dc898d41d34df92aa54c2951c3b6a02154d26 Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Sun, 29 Mar 2026 20:20:15 -0700 Subject: [PATCH 06/12] Improve DwdpConfig docstring per review feedback Signed-off-by: tianyuz-nv --- tensorrt_llm/llmapi/llm_args.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 739d9885ae87..129fa18da3bc 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2387,8 +2387,16 @@ def model_name(self) -> Union[str, Path]: class DwdpConfig(StrictBaseModel): - """ - Configuration for DWDP. + """Configuration for Distributed Weight Data Parallelism (DWDP). + + DWDP accelerates the context (prefill) phase of disaggregated MoE serving + by combining data parallelism with NVLink-based expert weight sharing. 
+ Each worker holds a subset of experts locally and asynchronously prefetches + the remaining experts from peer workers via CUDA IPC, enabling fully + asynchronous execution across ranks without synchronization barriers. + + Currently supported with the CuteDSL MoE backend and NVFP4 quantization + on NVLink-connected multi-GPU systems. """ enabled: bool = Field(default=False, description="Whether to enable DWDP.") dwdp_size: int = Field(default=1, From 70403705bb8294f8ba256ed6ca0ababe622a6694 Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 30 Mar 2026 04:11:13 -0700 Subject: [PATCH 07/12] Decouple DWDP from mainline disagg scripts per reviewer feedback Signed-off-by: tianyuz-nv --- .../slurm/benchmark/disaggr_torch.slurm | 5 - .../slurm/benchmark/disaggr_torch_dwdp.slurm | 189 +++++++ .../disaggregated/slurm/benchmark/submit.py | 201 ++------ .../slurm/benchmark/submit_dwdp.py | 483 ++++++++++++++++++ .../accuracy/test_disaggregated_serving.py | 247 --------- .../test_dwdp_disaggregated_serving.py | 283 ++++++++++ .../test_lists/qa/llm_function_core.txt | 2 +- .../test-db/l0_gb200_multi_gpus.yml | 2 +- 8 files changed, 1004 insertions(+), 408 deletions(-) create mode 100644 examples/disaggregated/slurm/benchmark/disaggr_torch_dwdp.slurm create mode 100644 examples/disaggregated/slurm/benchmark/submit_dwdp.py create mode 100644 tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm index 145c3099267b..9e0d66144478 100644 --- a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm @@ -150,11 +150,6 @@ replace_placeholder "${start_server_cmds_base_file}" "${all_nodes_str}" "${start server_config_base_file=${full_logdir}/server_config_base.yaml server_config_file=${full_logdir}/server_config.yaml replace_placeholder "${server_config_base_file}" 
"${all_nodes_str}" "${server_config_file}" -mpi_worker_config_base_file=${full_logdir}/mpi_worker_config_base.yaml -mpi_worker_config_file=${full_logdir}/mpi_worker_config.yaml -if [ -f "${mpi_worker_config_base_file}" ]; then - replace_placeholder "${mpi_worker_config_base_file}" "${all_nodes_str}" "${mpi_worker_config_file}" -fi client_cmds_base_file=${full_logdir}/client_cmds_base.sh client_cmds_file=${full_logdir}/client_cmds.sh replace_placeholder "${client_cmds_base_file}" "${all_nodes_str}" "${client_cmds_file}" diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch_dwdp.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch_dwdp.slurm new file mode 100644 index 000000000000..145c3099267b --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch_dwdp.slurm @@ -0,0 +1,189 @@ +#!/bin/bash +set -euo pipefail + +# Parse named arguments +while [[ $# -gt 0 ]]; do + case $1 in + # Benchmark Configuration + --benchmark-mode) benchmark_mode="$2"; shift 2 ;; + + # Environment and paths + --trtllm-repo) trtllm_repo="$2"; shift 2 ;; + --work-dir) work_dir="$2"; shift 2 ;; + --full-logdir) full_logdir="$2"; shift 2 ;; + --container-name) container_name="$2"; shift 2 ;; + --container-mount) container_mount="$2"; shift 2 ;; + --container-image) container_image="$2"; shift 2 ;; + --build-wheel) build_wheel="$2"; shift 2 ;; + --cuda-architectures) cuda_architectures="$2"; shift 2 ;; + --trtllm-wheel-path) trtllm_wheel_path="$2"; shift 2 ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +# Print all parsed arguments +echo "Parsed arguments:" +echo +echo "Benchmark Configuration:" +echo " benchmark_mode: ${benchmark_mode}" +echo +echo "Environment Configuration:" +echo " trtllm_repo: ${trtllm_repo}" +echo " work_dir: ${work_dir}" +echo " full_logdir: ${full_logdir}" +echo " container_mount: ${container_mount}" +echo " container_image: ${container_image}" +echo " build_wheel: ${build_wheel}" +echo " cuda_architectures: 
${cuda_architectures}" +echo " trtllm_wheel_path: ${trtllm_wheel_path}" + +# Set TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 for gen_only_no_context mode +if [ "${benchmark_mode}" = "gen_only_no_context" ]; then + export TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 + echo "Setting TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 for gen_only_no_context mode" +fi + +# Function to cleanup on failure +cleanup_on_failure() { + echo "Error: $1" + scancel ${SLURM_JOB_ID} + exit 1 +} + +replace_placeholder() { + file_path="$1" + all_nodes_str="$2" + new_file_path="$3" + cp "$file_path" "$new_file_path" + IFS=',' read -r -a node_array <<< "$all_nodes_str" + for i in "${!node_array[@]}"; do + current_val="${node_array[$i]}" + placeholder="" + + # Use sed to replace the placeholder with the value in-place + sed -i "s|$placeholder|$current_val|g" "${new_file_path}" + echo "Replaced $placeholder with $current_val in ${new_file_path}" + done +} + +env > ${full_logdir}/environment.txt + +# Start container +echo "Starting container..." +if ! srun -l --container-image=${container_image} \ + --container-name=${container_name} \ + --container-mounts=${container_mount} \ + --mpi=pmix \ + echo "Container up." &> ${full_logdir}/1_container_launch.log; then + cleanup_on_failure "Failed to start container. Check ${full_logdir}/1_container_launch.log" +fi + +# Install TensorRT-LLM +if [ -n "${trtllm_wheel_path}" ]; then + # Install from pre-built wheel if path is provided + echo "Installing TensorRT-LLM from wheel: ${trtllm_wheel_path}..." + if ! srun --container-name=${container_name} \ + --container-mounts=${container_mount} --no-container-mount-home \ + --mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \ + bash -c "pip install ${trtllm_wheel_path}[devel]" \ + &> ${full_logdir}/2_install.log; then + cleanup_on_failure "TensorRT-LLM wheel installation failed. 
Check ${full_logdir}/2_install.log for details" + fi + echo "TensorRT-LLM wheel installation completed successfully" +elif [ -d "${trtllm_repo}" ]; then + # Build and install from repository if no wheel path provided + echo "Installing TensorRT-LLM from ${trtllm_repo}..." + TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown") + echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}" + + if [ "${build_wheel}" = "true" ]; then + echo "Building TensorRT-LLM wheel on one node..." + build_command="python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt --benchmarks --use_ccache --clean" + if [ -n "${cuda_architectures:-}" ]; then + build_command="${build_command} --cuda_architectures \"${cuda_architectures}\"" + fi + if ! srun --container-name=${container_name} \ + --container-mounts=${container_mount} \ + --mpi=pmix --overlap -N 1 --ntasks-per-node=1 \ + bash -c "cd ${trtllm_repo} && ${build_command}" \ + &> ${full_logdir}/2_build.log; then + cleanup_on_failure "TensorRT-LLM build failed. Check ${full_logdir}/2_build.log for details" + fi + echo "TensorRT-LLM build completed successfully" + fi + + echo "Installing TensorRT-LLM..." + if ! srun --container-name=${container_name} \ + --container-mounts=${container_mount} --no-container-mount-home \ + --mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \ + bash -c "cd ${trtllm_repo} && pip install -e .[devel]" \ + &> ${full_logdir}/2_install.log; then + cleanup_on_failure "TensorRT-LLM installation failed. Check ${full_logdir}/2_install.log for details" + fi + echo "TensorRT-LLM installation completed successfully" +else + echo "trtllm_wheel_path and trtllm_repo are not provided, will use the installed TensorRT-LLM from the container" + # get_env file is in the same directory as this script + get_env_file=${work_dir}/get_env.py + if ! 
srun --container-name=${container_name} \ + --container-mounts=${container_mount} --no-container-mount-home \ + --mpi=pmix --overlap -N 1 --ntasks-per-node=1 \ + bash -c "python ${get_env_file} -e ${full_logdir}/env_vars.json" \ + &> ${full_logdir}/2_get_env.log; then + cleanup_on_failure "Failed to get TensorRT-LLM environment variables. Check ${full_logdir}/2_get_env.log for details" + fi + echo "TensorRT-LLM environment variables saved to ${full_logdir}/env_vars.json" +fi + +# Get node lists and replace the placeholder with the actual node names +echo "SLURM_NODELIST: ${SLURM_NODELIST}" +all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort)) +all_nodes_str=$(IFS=','; echo "${all_nodes[*]}") +echo "all_nodes_str: ${all_nodes_str}" + +start_server_cmds_base_file=${full_logdir}/start_server_cmds_base.sh +start_server_cmds_file=${full_logdir}/start_server_cmds.sh +replace_placeholder "${start_server_cmds_base_file}" "${all_nodes_str}" "${start_server_cmds_file}" +server_config_base_file=${full_logdir}/server_config_base.yaml +server_config_file=${full_logdir}/server_config.yaml +replace_placeholder "${server_config_base_file}" "${all_nodes_str}" "${server_config_file}" +mpi_worker_config_base_file=${full_logdir}/mpi_worker_config_base.yaml +mpi_worker_config_file=${full_logdir}/mpi_worker_config.yaml +if [ -f "${mpi_worker_config_base_file}" ]; then + replace_placeholder "${mpi_worker_config_base_file}" "${all_nodes_str}" "${mpi_worker_config_file}" +fi +client_cmds_base_file=${full_logdir}/client_cmds_base.sh +client_cmds_file=${full_logdir}/client_cmds.sh +replace_placeholder "${client_cmds_base_file}" "${all_nodes_str}" "${client_cmds_file}" + +# start the servers (skip ctx workers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set). +echo "Starting worker commands from ${start_server_cmds_file}..." 
+cat ${start_server_cmds_file} | while read cmd; do + # Skip ctx worker commands if in gen-only mode + # CTX appears as argument to start_worker.sh and in log filename + if [ "${TRTLLM_DISAGG_BENCHMARK_GEN_ONLY:-0}" = "1" ] && [[ "$cmd" == *"start_worker.sh CTX"* ]]; then + echo "Skipping ctx worker command (TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set): ${cmd}" + continue + fi + echo "Executing command: ${cmd}" + eval "${cmd}" +done +echo "Server is ready!" + +# Start client commands +echo "Starting client commands from ${client_cmds_file}..." +while read -r cmd <&3; do + echo "Starting client command: ${cmd}" + eval "${cmd}" + if [ $? -ne 0 ]; then + cleanup_on_failure "Command failed: ${cmd}." + fi +done 3< "${client_cmds_file}" + +echo "Job completed successfully, total runtime: $SECONDS seconds" + +# try to kill the server and workers +scancel ${SLURM_JOB_ID} diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py index fd2aa7da0513..0dfa3d6fc356 100644 --- a/examples/disaggregated/slurm/benchmark/submit.py +++ b/examples/disaggregated/slurm/benchmark/submit.py @@ -49,45 +49,6 @@ def save_worker_config(worker_config, output_path): yaml.dump(worker_config, f, default_flow_style=False) -def generate_mpi_worker_config(worker_config, allocations, env_config, - disagg_hostname, disagg_port, output_path): - """Generate a config YAML compatible with ``trtllm-serve disaggregated_mpi_worker``. 
- """ - - def _build_urls(server_type): - urls = [] - for server_id in sorted(allocations.get(server_type, {}).keys()): - inst = allocations[server_type][server_id] - host = list(inst["nodes"].keys())[0] - urls.append(f"{host}:{inst['port']}") - return urls - - ctx_urls = _build_urls("CTX") - gen_urls = _build_urls("GEN") - - ctx_section = dict(worker_config['ctx']) - ctx_section['num_instances'] = len(ctx_urls) - ctx_section['urls'] = ctx_urls - - gen_section = dict(worker_config['gen']) - gen_section['num_instances'] = len(gen_urls) - gen_section['urls'] = gen_urls - - config = { - 'model': env_config['model_path'], - 'hostname': disagg_hostname, - 'port': disagg_port, - 'backend': 'pytorch', - 'max_retries': 100, - 'context_servers': ctx_section, - 'generation_servers': gen_section, - } - - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, 'w') as f: - yaml.dump(config, f, default_flow_style=False, sort_keys=False) - - def calculate_nodes(world_size, num_servers, gpus_per_node): """Calculate required nodes based on world size and server count.""" return math.ceil(world_size * num_servers / gpus_per_node) @@ -448,13 +409,6 @@ def submit_job(config, log_dir, dry_run): total_nodes = ctx_nodes + gen_nodes total_tasks = total_nodes * gpus_per_node - # Detect DWDP mode: when enabled, use a single srun with - # trtllm-serve disaggregated_mpi_worker instead of per-instance sruns - dwdp_enabled = worker_config.get('ctx', {}).get('dwdp_config', - {}).get('enabled', False) - dwdp_size = worker_config.get('ctx', {}).get('dwdp_config', - {}).get('dwdp_size', 1) - # Generate log directory path based on configuration isl = benchmark_config['input_length'] osl = benchmark_config['output_length'] @@ -483,13 +437,10 @@ def submit_job(config, log_dir, dry_run): log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}") # Determine directory suffix based on attention_dp - if dwdp_enabled: - dir_suffix = 
f"disagg_ctx{ctx_num}_dwdp{dwdp_size}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" + if gen_enable_attention_dp: + dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" else: - if gen_enable_attention_dp: - dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" - else: - dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" + dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" # Create full log directory path log_dir = os.path.join(log_base, dir_suffix) @@ -558,108 +509,50 @@ def submit_job(config, log_dir, dry_run): } } - if dwdp_enabled: - # --- DWDP mode: single srun with disaggregated_mpi_worker --- - mpi_config_base_path = os.path.join(log_dir, - 'mpi_worker_config_base.yaml') - mpi_config_path = os.path.join(log_dir, 'mpi_worker_config.yaml') - generate_mpi_worker_config(worker_config, allocations, env_config, - disagg_server_hostname, disagg_server_port, - mpi_config_base_path) - - # Nodelist: CTX nodes first, then GEN nodes (matches - # split_world_comm order: server_configs = ctx_cfgs + gen_cfgs) - ctx_node_list = [] - for sid in sorted(allocations.get("CTX", {}).keys()): - for node in allocations["CTX"][sid]["nodes"]: - if node not in ctx_node_list: - ctx_node_list.append(node) - gen_node_list = [] - for sid in sorted(allocations.get("GEN", {}).keys()): - for node in allocations["GEN"][sid]["nodes"]: - if node not in gen_node_list: - gen_node_list.append(node) - mpi_nodelist = ctx_node_list + gen_node_list - total_mpi_tasks = ctx_num * ctx_world_size + gen_num * gen_world_size - mpi_num_nodes = len(mpi_nodelist) - num_ctx_gpus = ctx_num * ctx_world_size - worker_env_var = env_config.get('worker_env_var', '') - ctx_worker_env_var = 
env_config.get('ctx_worker_env_var', '') - gen_worker_env_var = env_config.get('gen_worker_env_var', '') - dwdp_ctx_worker_env_var = worker_env_var + \ - (f" {ctx_worker_env_var}" if ctx_worker_env_var else "") - dwdp_gen_worker_env_var = worker_env_var + \ - (f" {gen_worker_env_var}" if gen_worker_env_var else "") - - cmd = [ - "srun -l", - f"--nodelist {','.join(mpi_nodelist)}", - f"-N {mpi_num_nodes}", - f"--ntasks {total_mpi_tasks}", - f"--ntasks-per-node {gpus_per_node}", - f"--container-image {env_config['container_image']}", - f"--container-name {container_name}", - f"--container-mounts {container_mount_str}", - "--no-container-mount-home --mpi=pmix --overlap", - f"bash {os.path.join(script_dir, 'start_worker_dwdp.sh')}", - mpi_config_path, - str(slurm_config['numa_bind']).lower(), - log_dir, - str(profiling_config['nsys_on']).lower(), - f"'{profiling_config['ctx_profile_range']}'", - f"'{profiling_config['gen_profile_range']}'", - str(num_ctx_gpus), - f"'{dwdp_ctx_worker_env_var}'", - f"'{dwdp_gen_worker_env_var}'", - f"&> {log_dir}/3_output_workers.log &", - ] - start_server_cmds.append(" ".join(cmd)) - else: - # --- Standard mode: per-instance srun --- - for server_type in allocations.keys(): - server_cfg = server_configs[server_type] - - for server_id in allocations[server_type].keys(): - allocation = allocations[server_type][server_id] - gpu_ids = list(allocation["nodes"].values())[0] - - cuda_devices = ','.join(map(str, gpu_ids)) - worker_env = build_worker_environment( - worker_config=worker_config, - env_config=env_config, - role=server_type, - benchmark_mode=benchmark_config['mode'], - nsys_on=profiling_config['nsys_on'], - profile_range=server_cfg['profile_range'], - concurrency=benchmark_config['concurrency_list'].split(',') - [0], - ) - export_str = format_export_string(worker_env) - - cmd = [ - "srun -l", - f"--nodelist {','.join(allocation['nodes'].keys())}", - f"-N {len(allocation['nodes'])}", - f"--ntasks {server_cfg['world_size']}", - 
f"--ntasks-per-node {gpus_per_node}", - f"--export=\"{export_str}\"", - f"--container-image {env_config['container_image']}", - f"--container-name {container_name}", - f"--container-mounts {container_mount_str}", - "--no-container-mount-home --mpi=pmix --overlap", - f"bash {os.path.join(script_dir, 'start_worker.sh')}", - server_type, - str(server_id), - env_config['model_path'], - str(allocation["port"]), - str(slurm_config['numa_bind']).lower(), - log_dir, - str(profiling_config['nsys_on']).lower(), - server_cfg['config_path'], - cuda_devices, - f"&> {log_dir}/3_output_{server_type}_{server_id}.log &", - ] - start_server_cmds.append(" ".join(cmd)) + for server_type in allocations.keys(): + server_cfg = server_configs[server_type] + + for server_id in allocations[server_type].keys(): + allocation = allocations[server_type][server_id] + gpu_ids = list(allocation["nodes"].values())[0] + + cuda_devices = ','.join(map(str, gpu_ids)) + worker_env = build_worker_environment( + worker_config=worker_config, + env_config=env_config, + role=server_type, + benchmark_mode=benchmark_config['mode'], + nsys_on=profiling_config['nsys_on'], + profile_range=server_cfg['profile_range'], + concurrency=benchmark_config['concurrency_list'].split(',') + [0], + ) + export_str = format_export_string(worker_env) + + cmd = [ + "srun -l", + f"--nodelist {','.join(allocation['nodes'].keys())}", + f"-N {len(allocation['nodes'])}", + f"--ntasks {server_cfg['world_size']}", + f"--ntasks-per-node {gpus_per_node}", + f"--export=\"{export_str}\"", + f"--container-image {env_config['container_image']}", + f"--container-name {container_name}", + f"--container-mounts {container_mount_str}", + "--no-container-mount-home --mpi=pmix --overlap", + f"bash {os.path.join(script_dir, 'start_worker.sh')}", + server_type, + str(server_id), + env_config['model_path'], + str(allocation["port"]), + str(slurm_config['numa_bind']).lower(), + log_dir, + str(profiling_config['nsys_on']).lower(), + 
server_cfg['config_path'], + cuda_devices, + f"&> {log_dir}/3_output_{server_type}_{server_id}.log &", + ] + start_server_cmds.append(" ".join(cmd)) # Generate start server commands (use script_dir for start_server.sh) server_env = build_server_environment(env_config, benchmark_config['mode']) diff --git a/examples/disaggregated/slurm/benchmark/submit_dwdp.py b/examples/disaggregated/slurm/benchmark/submit_dwdp.py new file mode 100644 index 000000000000..0bb19cce7c8f --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/submit_dwdp.py @@ -0,0 +1,483 @@ +#!/usr/bin/env python3 +"""Submit DWDP disaggregated benchmark jobs. + +This script handles the DWDP-specific submission flow which requires MPI-based +worker launching via ``trtllm-serve disaggregated_mpi_worker``. It reuses +shared utilities from ``submit.py`` for config parsing, GPU allocation, and +sbatch command construction. +""" + +import argparse +import glob +import json +import math +import os +import shutil +import subprocess +import sys +import traceback +from datetime import datetime +from typing import Any, Dict, List + +import yaml + +from submit import ( + allocate_gpus, + build_server_environment, + build_worker_environment, + calculate_nodes, + convert_allocations_to_server_config, + convert_envs_to_str, + format_export_string, + load_config, + replace_env_in_file, + save_env_file, + save_worker_config, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Submit DWDP disaggregated benchmark job') + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('-c', + '--config', + type=str, + help='Path to the configuration YAML file') + group.add_argument('-d', + '--dir', + type=str, + help='Directory containing YAML configuration files') + parser.add_argument('--log-dir', + type=str, + default=None, + help='Log directory') + parser.add_argument('--dry-run', + action='store_true', + help='Dry run the Python part, test purpose only') + return 
parser.parse_args() + + +def generate_mpi_worker_config(worker_config, allocations, env_config, + disagg_hostname, disagg_port, output_path): + """Generate a config YAML compatible with ``trtllm-serve disaggregated_mpi_worker``.""" + + def _build_urls(server_type): + urls = [] + for server_id in sorted(allocations.get(server_type, {}).keys()): + inst = allocations[server_type][server_id] + host = list(inst["nodes"].keys())[0] + urls.append(f"{host}:{inst['port']}") + return urls + + ctx_urls = _build_urls("CTX") + gen_urls = _build_urls("GEN") + + ctx_section = dict(worker_config['ctx']) + ctx_section['num_instances'] = len(ctx_urls) + ctx_section['urls'] = ctx_urls + + gen_section = dict(worker_config['gen']) + gen_section['num_instances'] = len(gen_urls) + gen_section['urls'] = gen_urls + + config = { + 'model': env_config['model_path'], + 'hostname': disagg_hostname, + 'port': disagg_port, + 'backend': 'pytorch', + 'max_retries': 100, + 'context_servers': ctx_section, + 'generation_servers': gen_section, + } + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w') as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + +def submit_dwdp_job(config, log_dir, dry_run): + """Submit a DWDP disaggregated benchmark job.""" + slurm_config = config['slurm'] + slurm_config.setdefault('extra_args', '') + slurm_config.setdefault('set_segment', True) + + hw_config = config['hardware'] + env_config = config['environment'] + worker_config = config['worker_config'] + benchmark_config = config['benchmark'] + + if 'work_dir' in env_config and os.path.isdir(env_config['work_dir']): + script_dir = env_config['work_dir'] + else: + script_dir = os.path.dirname(os.path.abspath(__file__)) + + if 'accuracy' not in config: + config['accuracy'] = { + 'enable_accuracy_test': + False, + 'model': + 'local-completions', + 'tasks': + 'gsm8k', + 'model_args_extra': + 
'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096' + } + + env_config.setdefault('trtllm_repo', '') + env_config.setdefault('build_wheel', False) + env_config.setdefault('cuda_architectures', '') + env_config.setdefault('trtllm_wheel_path', '') + env_config.setdefault('worker_env_var', '') + env_config.setdefault('server_env_var', '') + + profiling_config = config.get('profiling', {}) + profiling_config.setdefault('nsys_on', False) + profiling_config.setdefault('ctx_profile_range', '10-30') + profiling_config.setdefault('gen_profile_range', '200-250') + + ctx_num = hw_config['num_ctx_servers'] + gen_num = hw_config['num_gen_servers'] + gpus_per_node = hw_config['gpus_per_node'] + + ctx_tp_size = worker_config['ctx'].get('tensor_parallel_size', 1) + ctx_cp_size = worker_config['ctx'].get('context_parallel_size', 1) + ctx_pp_size = worker_config['ctx'].get('pipeline_parallel_size', 1) + ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size + ctx_nodes = calculate_nodes(ctx_world_size, ctx_num, gpus_per_node) + + gen_tp_size = worker_config['gen'].get('tensor_parallel_size', 1) + gen_cp_size = worker_config['gen'].get('context_parallel_size', 1) + gen_pp_size = worker_config['gen'].get('pipeline_parallel_size', 1) + gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size + gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node) + ucx_warmup_requests = 2 * ctx_world_size * \ + gen_world_size if benchmark_config['mode'] == "e2e" else 0 + + total_nodes = ctx_nodes + gen_nodes + total_tasks = total_nodes * gpus_per_node + + dwdp_size = worker_config.get('ctx', {}).get('dwdp_config', + {}).get('dwdp_size', 1) + + isl = benchmark_config['input_length'] + osl = benchmark_config['output_length'] + gen_batch_size = worker_config['gen']['max_batch_size'] + + load_balancer_config = worker_config['gen'].get('moe_config', {}).get( + 'load_balancer', {}) + if isinstance(load_balancer_config, str): + with 
open(load_balancer_config, 'r') as f: + load_balancer_config = yaml.safe_load(f) + eplb_num_slots = load_balancer_config.get('num_slots', 0) + + mtp_size = worker_config['gen'].get('speculative_config', + {}).get('num_nextn_predict_layers', 0) + + if 'log_dir' in env_config and env_config['log_dir']: + log_dir = env_config['log_dir'] + if log_dir is None: + log_base = os.path.join(script_dir, "logs") + + date_prefix = datetime.now().strftime("%Y%m%d-%H%M%S") + log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}") + + dir_suffix = f"disagg_ctx{ctx_num}_dwdp{dwdp_size}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" + + log_dir = os.path.join(log_base, dir_suffix) + + if os.path.exists(log_dir): + if not os.path.exists(os.path.join(log_dir, 'trtllm_config.yaml')): + print(f"[WARNING] Removing existing log directory: {log_dir}") + shutil.rmtree(log_dir) + else: + print( + f"[WARNING] trtllm_config.yaml exists, not removing the directory: {log_dir}" + ) + for file in os.listdir(log_dir): + if file != 'trtllm_config.yaml' and not file.startswith( + 'concurrency_'): + if os.path.isdir(os.path.join(log_dir, file)): + shutil.rmtree(os.path.join(log_dir, file)) + else: + os.remove(os.path.join(log_dir, file)) + os.makedirs(log_dir, exist_ok=True) + print(f"Log will be saved to: {log_dir}") + + ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml') + gen_config_path = os.path.join(log_dir, 'gen_config.yaml') + save_worker_config(worker_config['ctx'], ctx_config_path) + save_worker_config(worker_config['gen'], gen_config_path) + + allocations = allocate_gpus( + total_nodes=total_nodes, + gpus_per_node=gpus_per_node, + num_gen_servers=gen_num, + num_ctx_servers=ctx_num, + gen_world_size=gen_world_size, + ctx_world_size=ctx_world_size, + ) + with open(os.path.join(log_dir, "allocations.json"), "w") as f: + json.dump(allocations, f, indent=2) + + server_config = convert_allocations_to_server_config(allocations) + with 
open(os.path.join(log_dir, "server_config_base.yaml"), "w") as f: + yaml.dump(server_config, f) + disagg_server_hostname = server_config['hostname'] + disagg_server_port = server_config['port'] + + container_name = "disaggr-test" + start_server_cmds = [] + container_mount_str = env_config['container_mount'] + container_mount_str += f",{script_dir}:{script_dir}" + + # --- DWDP mode: single srun with disaggregated_mpi_worker --- + mpi_config_base_path = os.path.join(log_dir, + 'mpi_worker_config_base.yaml') + mpi_config_path = os.path.join(log_dir, 'mpi_worker_config.yaml') + generate_mpi_worker_config(worker_config, allocations, env_config, + disagg_server_hostname, disagg_server_port, + mpi_config_base_path) + + ctx_node_list = [] + for sid in sorted(allocations.get("CTX", {}).keys()): + for node in allocations["CTX"][sid]["nodes"]: + if node not in ctx_node_list: + ctx_node_list.append(node) + gen_node_list = [] + for sid in sorted(allocations.get("GEN", {}).keys()): + for node in allocations["GEN"][sid]["nodes"]: + if node not in gen_node_list: + gen_node_list.append(node) + mpi_nodelist = ctx_node_list + gen_node_list + total_mpi_tasks = ctx_num * ctx_world_size + gen_num * gen_world_size + mpi_num_nodes = len(mpi_nodelist) + num_ctx_gpus = ctx_num * ctx_world_size + worker_env_var = env_config.get('worker_env_var', '') + ctx_worker_env_var = env_config.get('ctx_worker_env_var', '') + gen_worker_env_var = env_config.get('gen_worker_env_var', '') + dwdp_ctx_worker_env_var = worker_env_var + \ + (f" {ctx_worker_env_var}" if ctx_worker_env_var else "") + dwdp_gen_worker_env_var = worker_env_var + \ + (f" {gen_worker_env_var}" if gen_worker_env_var else "") + + cmd = [ + "srun -l", + f"--nodelist {','.join(mpi_nodelist)}", + f"-N {mpi_num_nodes}", + f"--ntasks {total_mpi_tasks}", + f"--ntasks-per-node {gpus_per_node}", + f"--container-image {env_config['container_image']}", + f"--container-name {container_name}", + f"--container-mounts {container_mount_str}", + 
"--no-container-mount-home --mpi=pmix --overlap", + f"bash {os.path.join(script_dir, 'start_worker_dwdp.sh')}", + mpi_config_path, + str(slurm_config['numa_bind']).lower(), + log_dir, + str(profiling_config['nsys_on']).lower(), + f"'{profiling_config['ctx_profile_range']}'", + f"'{profiling_config['gen_profile_range']}'", + str(num_ctx_gpus), + f"'{dwdp_ctx_worker_env_var}'", + f"'{dwdp_gen_worker_env_var}'", + f"&> {log_dir}/3_output_workers.log &", + ] + start_server_cmds.append(" ".join(cmd)) + + # Generate start server commands + server_env = build_server_environment(env_config, benchmark_config['mode']) + export_str = format_export_string(server_env) + + cmd = [ + "srun -l", + f"--nodelist {disagg_server_hostname}", + f"--container-name={container_name}", + f"--export=\"{export_str}\"", + f"--container-image={env_config['container_image']}", + f"--container-mounts={container_mount_str}", + f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1", + f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')}", + f"&> {log_dir}/4_output_server.log &", + ] + start_server_cmds.append(" ".join(cmd)) + + save_env_file( + os.path.join(log_dir, "env_vars.json"), + env_config.get('server_env_var', ''), + env_config.get('worker_env_var', ''), + env_config.get('ctx_worker_env_var', ''), + env_config.get('gen_worker_env_var', ''), + ) + + # Generate wait server command + cmd = [ + "srun -l", + f"--container-name={container_name}", + f"--container-mounts={container_mount_str}", + f"--mpi=pmix --overlap -N 1 -n 1", + f"bash {os.path.join(script_dir, 'wait_server.sh')} {disagg_server_hostname} {disagg_server_port}", + f"&> {log_dir}/5_wait_server.log", + ] + start_server_cmds.append(" ".join(cmd)) + + with open(os.path.join(log_dir, "start_server_cmds_base.sh"), "w") as f: + f.write("\n".join(start_server_cmds) + "\n") + + # Generate client commands + client_cmds = [] + client_slurm_prefix = [ + f"srun -l 
--container-name={container_name}", + f"--container-mounts={container_mount_str}", + f"--mpi=pmix --overlap -N 1 -n 1", + ] + if benchmark_config.get('enable_benchmark', True): + env_var = config['benchmark'].get('env_var', {}) + benchmark_prefix = client_slurm_prefix + [ + f"--export \"{convert_envs_to_str(env_var)}\"" + ] + if benchmark_config['use_nv_sa_benchmark']: + if benchmark_config['mode'] == "gen_only": + print( + f"[ERROR] SA benchmark client script is not supported for gen_only mode" + ) + sys.exit(1) + benchmark_cmd = [ + f"bash {os.path.join(script_dir, 'run_benchmark_nv_sa.sh')}", + f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}", + f"&> {log_dir}/6_bench.log" + ] + client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd)) + else: + benchmark_cmd = [ + f"bash {os.path.join(script_dir, 'run_benchmark.sh')}", + f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}", + f"&> {log_dir}/6_bench.log" + ] + client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd)) + + if config['accuracy']['enable_accuracy_test']: + env_var = config['accuracy'].get('env_var', {}) + accuracy_prefix = client_slurm_prefix + [ + f"--export \"{convert_envs_to_str(env_var)}\"" + ] + for task in config['accuracy']['tasks']: + extra_kwargs = config['accuracy']['tasks'][task].get( + 'extra_kwargs', {}) + extra_kwargs_str = "" + for key, value in extra_kwargs.items(): + if isinstance(value, bool): + if value: + extra_kwargs_str += f" --{key}" + elif key == "custom_config": + extra_kwargs_str += f" --include_path={replace_env_in_file(log_dir, 
value, env_var)}" + else: + extra_kwargs_str += f" --{key}='{value}'" + end_point_map = { + 'local-completions': 'v1/completions', + 'local-chat-completions': 'v1/chat/completions', + } + model = config['accuracy']['tasks'][task]['model'] + accuracy_cmd = [ + 'lm_eval', '--model', model, '--tasks', task, '--model_args', + f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}", + '--log_samples', '--output_path', + f'{log_dir}/accuracy_eval_{task}', extra_kwargs_str, + f"&> {log_dir}/7_accuracy_eval_{task}.log" + ] + client_cmds.append(" ".join(accuracy_prefix + accuracy_cmd)) + + done_cmd = [ + "echo", "${SLURM_JOB_NODELIST}", ">", + f"{log_dir}/8_done_${{SLURM_JOB_ID}}.txt" + ] + client_cmds.append(" ".join(done_cmd)) + + with open(os.path.join(log_dir, "client_cmds_base.sh"), "w") as f: + f.write("\n".join(client_cmds) + "\n") + + slurm_script_file = slurm_config['script_file'] + if not os.path.isabs(slurm_script_file): + slurm_script_file = os.path.join(script_dir, slurm_script_file) + + if not os.path.exists(slurm_script_file): + print(f"[ERROR] SLURM script file not found: {slurm_script_file}", + file=sys.stderr) + sys.exit(1) + + # yapf: disable + cmd = [ + 'sbatch', + f'--partition={slurm_config["partition"]}', + f'--account={slurm_config["account"]}', + f'--time={slurm_config["job_time"]}', + f'--job-name={slurm_config["job_name"]}', + f'--nodes={total_nodes}', + f'--ntasks={total_tasks}', + f'--ntasks-per-node={hw_config["gpus_per_node"]}', + *([] if not slurm_config['set_segment'] + else [f'--segment={total_nodes}']), + f'--output={log_dir}/slurm-%j.out', + f'--error={log_dir}/slurm-%j.err', + *([arg for arg in slurm_config['extra_args'].split() if arg]), + slurm_script_file, + + '--benchmark-mode', benchmark_config['mode'], + + '--trtllm-repo', env_config['trtllm_repo'], + '--work-dir', script_dir, + '--full-logdir', log_dir, + 
'--container-name', container_name, + '--container-mount', container_mount_str, + '--container-image', env_config['container_image'], + '--build-wheel', str(env_config['build_wheel']).lower(), + '--cuda-architectures', env_config['cuda_architectures'], + '--trtllm-wheel-path', env_config['trtllm_wheel_path'], + ] + # yapf: enable + + if dry_run: + print( + "[WARNING] Dry run mode, will not submit the job. This should be used for test purpose only." + ) + print("sbatch command:") + print(" ".join(cmd)) + return + else: + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + print(f"Error submitting job: {e}", file=sys.stderr) + sys.exit(1) + + +def main(): + args = parse_args() + + if args.config: + config_files = [args.config] + else: + yaml_pattern = os.path.join(args.dir, '*.yaml') + config_files = sorted(glob.glob(yaml_pattern)) + + if not config_files: + print(f"No YAML files found in directory: {args.dir}", + file=sys.stderr) + sys.exit(1) + + print(f"Found {len(config_files)} YAML file(s) in {args.dir}") + + for config_file in config_files: + print(f"Processing: {config_file}") + try: + config = load_config(config_file) + submit_dwdp_job(config, args.log_dir, args.dry_run) + print(f"Successfully submitted job for: {config_file}\n") + except Exception as e: + traceback.print_exc() + print(f"Error processing {config_file}: {e}", file=sys.stderr) + continue + + +if __name__ == '__main__': + main() diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 35825c154115..c1c30ea256f1 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -1139,111 +1139,6 @@ def test_guided_decoding(self, backend: str, mtp_nextn: int, mocker): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"]) - @pytest.mark.skip_less_device(4) - 
@skip_pre_blackwell - def test_dwdp_accuracy(self): - model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp" - - ctx_port_0 = get_free_port() - ctx_port_1 = get_free_port() - gen_port = get_free_port() - serve_port = get_free_port() - - ctx_server_config = { - "num_instances": 2, - "urls": [ - f"localhost:{ctx_port_0}", - f"localhost:{ctx_port_1}", - ], - "tensor_parallel_size": 1, - "pipeline_parallel_size": 1, - "disable_overlap_scheduler": True, - "enable_autotuner": False, - "enable_chunked_prefill": False, - "cuda_graph_config": None, - "max_batch_size": 16, - "max_num_tokens": 8192, - "kv_cache_config": { - "free_gpu_memory_fraction": 0.4, - "enable_block_reuse": False, - "enable_partial_reuse": False, - "tokens_per_block": 32, - }, - "cache_transceiver_config": { - "backend": "UCX", - "max_tokens_in_buffer": 8192, - }, - "moe_config": { - "backend": "CUTEDSL", - }, - "dwdp_config": { - "enabled": True, - "dwdp_size": 2, - "num_group": 1, - "experts_per_worker": 36, - "num_prefetch_experts": 36, - }, - } - - gen_server_config = { - "num_instances": 1, - "urls": [f"localhost:{gen_port}"], - "tensor_parallel_size": 2, - "pipeline_parallel_size": 1, - "disable_overlap_scheduler": True, - "enable_autotuner": False, - "enable_chunked_prefill": False, - "cuda_graph_config": None, - "max_batch_size": 128, - "max_num_tokens": 1024, - "kv_cache_config": { - "free_gpu_memory_fraction": 0.5, - "enable_block_reuse": False, - "enable_partial_reuse": False, - "tokens_per_block": 32, - }, - "cache_transceiver_config": { - "backend": "UCX", - "max_tokens_in_buffer": 8192, - }, - "moe_config": { - "backend": "CUTEDSL", - }, - } - - worker_config = { - "model": model_path, - "hostname": "localhost", - "port": serve_port, - "backend": "pytorch", - "context_servers": ctx_server_config, - "generation_servers": gen_server_config, - } - - frontend_config = { - "backend": "pytorch", - "hostname": "localhost", - "port": serve_port, - "context_servers": { - 
"num_instances": 2, - "urls": [ - f"localhost:{ctx_port_0}", - f"localhost:{ctx_port_1}", - ], - }, - "generation_servers": { - "num_instances": 1, - "urls": [f"localhost:{gen_port}"], - }, - } - - with launch_dwdp_disaggregated_llm(worker_config, - frontend_config, - model_path, - total_gpus=4, - max_workers=128) as llm: - run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) - @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT) class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): @@ -1918,145 +1813,3 @@ def test_nixl_backend(self): with launch_disaggregated_llm(disagg_cfg, ctx_cfg, gen_cfg, self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) - - -@contextlib.contextmanager -def launch_dwdp_disaggregated_llm( - worker_config: Dict[str, Any], - frontend_config: Dict[str, Any], - model_path: str, - total_gpus: int, - server_waiting_timeout: int = DEFAULT_SERVER_WAITING_TIMEOUT, - max_workers: int = 128, -): - """Launch DWDP disaggregated serving via mpirun. - - DWDP requires all workers (CTX + GEN) in a single MPI world for - IPC handle exchange and DWDP group formation. This function starts - all workers with ``mpirun`` and launches a separate disaggregated - frontend server for the client-facing OpenAI API. - """ - temp_dir = tempfile.TemporaryDirectory() - worker_config_path = os.path.join(temp_dir.name, "worker_config.yaml") - frontend_config_path = os.path.join(temp_dir.name, "frontend_config.yaml") - - with open(worker_config_path, "w") as f: - yaml.dump(worker_config, f, default_flow_style=False, sort_keys=False) - with open(frontend_config_path, "w") as f: - yaml.dump(frontend_config, f, default_flow_style=False, sort_keys=False) - - serve_port = frontend_config["port"] - - # Prevent the parent process's MPI state (set by mpi4py init during - # tensorrt_llm import) from leaking into the mpirun subprocess. - # mpirun must create a fresh MPI world for the DWDP workers. 
- child_env = { - k: v - for k, v in os.environ.items() - if not k.startswith(('OMPI_', 'PMIX_', 'PMI_')) - } - - mpi_cmd = [ - "mpirun", "--allow-run-as-root", "-n", - str(total_gpus), "trtllm-serve", "disaggregated_mpi_worker", "-c", - worker_config_path - ] - - frontend_cmd = [ - "trtllm-serve", "disaggregated", "-c", frontend_config_path, - "--server_start_timeout", - str(server_waiting_timeout), "-r", "360000" - ] - - with ( - MyThreadPoolExecutor(max_workers=max_workers) as thread_pool, - temp_dir, - popen(mpi_cmd, env=child_env) as mpi_proc, - popen(frontend_cmd, env=child_env) as frontend_proc, - ): - start_time = time.time() - server_is_ready = False - while time.time() - start_time < server_waiting_timeout: - time.sleep(5) - for proc, name in [ - (mpi_proc, "mpirun"), - (frontend_proc, "frontend"), - ]: - if proc.poll() is not None: - raise Exception( - f"{name} process exited with code {proc.returncode}") - try: - response = requests.get( - f"http://localhost:{serve_port}/cluster_info") - if response.status_code == 200: - cluster_info = response.json() - if cluster_info.get("is_ready"): - print(f"DWDP cluster ready: {cluster_info}") - server_is_ready = True - break - except requests.exceptions.ConnectionError: - continue - if not server_is_ready: - pytest.fail( - f"DWDP server not ready after {server_waiting_timeout}s") - - model_name = worker_config.get("model", model_path) - client = openai.OpenAI(api_key="1234567890", - base_url=f"http://localhost:{serve_port}/v1", - timeout=1800000) - - def send_request(prompt: str, sampling_params: SamplingParams, - streaming: bool): - kwargs = {} - if sampling_params is not None: - kwargs.update( - max_tokens=sampling_params.max_tokens, - temperature=(sampling_params.temperature - if sampling_params.top_p is not None else 0), - top_p=sampling_params.top_p, - stop=sampling_params.stop, - seed=sampling_params.seed) - response = client.completions.create(model=model_name, - prompt=prompt, - stream=streaming, - 
**kwargs) - result = Result(id=0, - sampling_params=sampling_params, - outputs=[ - CompletionOutput(text=response.choices[0].text, - index=0) - ]) - requested_output = RequestOutput._from_generation_result( - result, prompt=prompt) - setattr(requested_output, "result", result.result) - return requested_output - - def generate_async(prompt: str, - sampling_params: Optional[SamplingParams] = None, - streaming: bool = False): - future = thread_pool.submit(send_request, prompt, sampling_params, - streaming) - thread_pool.futures.append(future) - return future - - args = LlmArgs(model=model_path) - tokenizer = load_hf_tokenizer(model_path) - try: - yield DuckLLM(args, tokenizer, generate_async) - finally: - all_procs = [frontend_proc, mpi_proc] - for proc in all_procs: - if proc.poll() is None: - proc.terminate() - deadline = time.monotonic() + 5 - for proc in all_procs: - remaining = max(0, deadline - time.monotonic()) - try: - proc.wait(timeout=remaining) - except subprocess.TimeoutExpired: - try: - proc.kill() - except ProcessLookupError: - pass - except OSError: - pass diff --git a/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py new file mode 100644 index 000000000000..bea6bb8f7ae7 --- /dev/null +++ b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py @@ -0,0 +1,283 @@ +"""DWDP disaggregated serving accuracy tests. + +Separated from test_disaggregated_serving.py to isolate MPI-dependent test +infrastructure for easier maintenance. 
+""" + +import contextlib +import os +import subprocess +import tempfile +import time +from typing import Any, Dict, Optional + +import openai +import pytest +import requests +import yaml +from defs.common import get_free_port_in_ci as get_free_port + +from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams +from tensorrt_llm.llmapi.llm_args import LlmArgs +from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer + +from ..conftest import llm_models_root, skip_pre_blackwell +from ..trt_test_alternative import popen +from .accuracy_core import LlmapiAccuracyTestHarness +from .test_disaggregated_serving import ( + DEFAULT_SERVER_WAITING_TIMEOUT, + DEFAULT_TEST_TIMEOUT, + DuckLLM, + MyThreadPoolExecutor, + Result, + run_accuracy_test, +) + + +@contextlib.contextmanager +def launch_dwdp_disaggregated_llm( + worker_config: Dict[str, Any], + frontend_config: Dict[str, Any], + model_path: str, + total_gpus: int, + server_waiting_timeout: int = DEFAULT_SERVER_WAITING_TIMEOUT, + max_workers: int = 128, +): + """Launch DWDP disaggregated serving via mpirun. + + DWDP requires all workers (CTX + GEN) in a single MPI world for + IPC handle exchange and DWDP group formation. This function starts + all workers with ``mpirun`` and launches a separate disaggregated + frontend server for the client-facing OpenAI API. 
+ """ + temp_dir = tempfile.TemporaryDirectory() + worker_config_path = os.path.join(temp_dir.name, "worker_config.yaml") + frontend_config_path = os.path.join(temp_dir.name, "frontend_config.yaml") + + with open(worker_config_path, "w") as f: + yaml.dump(worker_config, f, default_flow_style=False, sort_keys=False) + with open(frontend_config_path, "w") as f: + yaml.dump(frontend_config, f, default_flow_style=False, sort_keys=False) + + serve_port = frontend_config["port"] + + child_env = { + k: v + for k, v in os.environ.items() + if not k.startswith(('OMPI_', 'PMIX_', 'PMI_')) + } + + mpi_cmd = [ + "mpirun", "--allow-run-as-root", "-n", + str(total_gpus), "trtllm-serve", "disaggregated_mpi_worker", "-c", + worker_config_path + ] + + frontend_cmd = [ + "trtllm-serve", "disaggregated", "-c", frontend_config_path, + "--server_start_timeout", + str(server_waiting_timeout), "-r", "360000" + ] + + with ( + MyThreadPoolExecutor(max_workers=max_workers) as thread_pool, + temp_dir, + popen(mpi_cmd, env=child_env) as mpi_proc, + popen(frontend_cmd, env=child_env) as frontend_proc, + ): + start_time = time.time() + server_is_ready = False + while time.time() - start_time < server_waiting_timeout: + time.sleep(5) + for proc, name in [ + (mpi_proc, "mpirun"), + (frontend_proc, "frontend"), + ]: + if proc.poll() is not None: + raise Exception( + f"{name} process exited with code {proc.returncode}") + try: + response = requests.get( + f"http://localhost:{serve_port}/cluster_info") + if response.status_code == 200: + cluster_info = response.json() + if cluster_info.get("is_ready"): + print(f"DWDP cluster ready: {cluster_info}") + server_is_ready = True + break + except requests.exceptions.ConnectionError: + continue + if not server_is_ready: + pytest.fail( + f"DWDP server not ready after {server_waiting_timeout}s") + + model_name = worker_config.get("model", model_path) + client = openai.OpenAI(api_key="1234567890", + base_url=f"http://localhost:{serve_port}/v1", + 
timeout=1800000) + + def send_request(prompt: str, sampling_params: SamplingParams, + streaming: bool): + kwargs = {} + if sampling_params is not None: + kwargs.update( + max_tokens=sampling_params.max_tokens, + temperature=(sampling_params.temperature + if sampling_params.top_p is not None else 0), + top_p=sampling_params.top_p, + stop=sampling_params.stop, + seed=sampling_params.seed) + response = client.completions.create(model=model_name, + prompt=prompt, + stream=streaming, + **kwargs) + result = Result(id=0, + sampling_params=sampling_params, + outputs=[ + CompletionOutput(text=response.choices[0].text, + index=0) + ]) + requested_output = RequestOutput._from_generation_result( + result, prompt=prompt) + setattr(requested_output, "result", result.result) + return requested_output + + def generate_async(prompt: str, + sampling_params: Optional[SamplingParams] = None, + streaming: bool = False): + future = thread_pool.submit(send_request, prompt, sampling_params, + streaming) + thread_pool.futures.append(future) + return future + + args = LlmArgs(model=model_path) + tokenizer = load_hf_tokenizer(model_path) + try: + yield DuckLLM(args, tokenizer, generate_async) + finally: + all_procs = [frontend_proc, mpi_proc] + for proc in all_procs: + if proc.poll() is None: + proc.terminate() + deadline = time.monotonic() + 5 + for proc in all_procs: + remaining = max(0, deadline - time.monotonic()) + try: + proc.wait(timeout=remaining) + except subprocess.TimeoutExpired: + try: + proc.kill() + except ProcessLookupError: + pass + except OSError: + pass + + +@pytest.mark.timeout(DEFAULT_TEST_TIMEOUT) +class TestDwdpDeepSeekV3Lite(LlmapiAccuracyTestHarness): + MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite" + + @pytest.mark.skip_less_device(4) + @skip_pre_blackwell + def test_dwdp_accuracy(self): + model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp" + + ctx_port_0 = get_free_port() + ctx_port_1 = get_free_port() + gen_port = get_free_port() + serve_port = 
get_free_port() + + ctx_server_config = { + "num_instances": 2, + "urls": [ + f"localhost:{ctx_port_0}", + f"localhost:{ctx_port_1}", + ], + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "disable_overlap_scheduler": True, + "enable_autotuner": False, + "enable_chunked_prefill": False, + "cuda_graph_config": None, + "max_batch_size": 16, + "max_num_tokens": 8192, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.4, + "enable_block_reuse": False, + "enable_partial_reuse": False, + "tokens_per_block": 32, + }, + "cache_transceiver_config": { + "backend": "UCX", + "max_tokens_in_buffer": 8192, + }, + "moe_config": { + "backend": "CUTEDSL", + }, + "dwdp_config": { + "enabled": True, + "dwdp_size": 2, + "num_group": 1, + "experts_per_worker": 36, + "num_prefetch_experts": 36, + }, + } + + gen_server_config = { + "num_instances": 1, + "urls": [f"localhost:{gen_port}"], + "tensor_parallel_size": 2, + "pipeline_parallel_size": 1, + "disable_overlap_scheduler": True, + "enable_autotuner": False, + "enable_chunked_prefill": False, + "cuda_graph_config": None, + "max_batch_size": 128, + "max_num_tokens": 1024, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.5, + "enable_block_reuse": False, + "enable_partial_reuse": False, + "tokens_per_block": 32, + }, + "cache_transceiver_config": { + "backend": "UCX", + "max_tokens_in_buffer": 8192, + }, + "moe_config": { + "backend": "CUTEDSL", + }, + } + + worker_config = { + "model": model_path, + "hostname": "localhost", + "port": serve_port, + "backend": "pytorch", + "context_servers": ctx_server_config, + "generation_servers": gen_server_config, + } + + frontend_config = { + "backend": "pytorch", + "hostname": "localhost", + "port": serve_port, + "context_servers": { + "num_instances": 2, + "urls": [ + f"localhost:{ctx_port_0}", + f"localhost:{ctx_port_1}", + ], + }, + "generation_servers": { + "num_instances": 1, + "urls": [f"localhost:{gen_port}"], + }, + } + + with 
launch_dwdp_disaggregated_llm(worker_config, + frontend_config, + model_path, + total_gpus=4, + max_workers=128) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index f149dff121c2..6e4f62fdc88f 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -380,7 +380,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_dwdp_accuracy +accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index c04cbe44c1f2..01fb27db5a96 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -93,7 +93,7 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_on-trtllm] - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_off-trtllm] - accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90) - - 
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_dwdp_accuracy + - accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy - condition: ranges: From 28390e50c2449fef831f8bd8ad2373e668afad81 Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 30 Mar 2026 08:51:26 -0700 Subject: [PATCH 08/12] fix(moe): remove commented-out barrier in moeA2AInitializeOp Remove the dead commented-out world().barrier() line as session().barrier() is the intended synchronization primitive. Signed-off-by: tianyuz-nv --- cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index 28bdd7fb62ac..4d3b396e23c9 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -144,7 +144,6 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, // Synchronize among ranks cudaDeviceSynchronize(); - // tensorrt_llm::mpi::MpiComm::world().barrier(); tensorrt_llm::mpi::MpiComm::session().barrier(); return metainfo; From a4085eacec932d05df3f740a684e6bf93486334c Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 30 Mar 2026 08:56:06 -0700 Subject: [PATCH 09/12] fix(dwdp): move record_compute_and_prefetch_next to per-layer level Move the DWDP prefetch trigger from _forward_chunk_impl() to forward_impl() to ensure it is called once per layer instead of once per chunk. When num_chunks > 1, _forward_chunk_impl() is called multiple times in a loop. The previous placement would trigger prefetch for layer_idx+2 after the first chunk's compute, overwriting the ping-pong buffer that subsequent chunks still need to read from, causing a potential data race with silent precision degradation. Moving the call to forward_impl() guarantees it executes exactly once per layer, after all chunks have completed, making DWDP compatible with multi-chunk MoE execution. 
This is a defensive fix: current DWDP configurations always have num_chunks=1 (moe_max_num_tokens >= local_tokens), so the bug has not been triggered in practice. Signed-off-by: tianyuz-nv --- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index 092519860741..35478e4a98c6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -513,6 +513,10 @@ def forward_impl( do_finalize, ) + # DWDP: record compute and trigger next prefetch (per-layer, not per-chunk) + if self.enable_dwdp: + self.dwdp_manager.record_compute_and_prefetch_next(self.layer_idx) + # ========== Step 4: Handle output truncation and EPLB repeat ========== if self.use_dp and self.parallel_size > 1: outputs = outputs[: all_rank_num_tokens[self.mapping.tp_rank]] @@ -808,8 +812,6 @@ def _forward_chunk_impl( router_logits, do_finalize, all_rank_num_tokens, output_dtype, x, workspace ), ) - if self.enable_dwdp: - self.dwdp_manager.record_compute_and_prefetch_next(self.layer_idx) # ========== Step 8: EPLB - Start CPU stage ========== self._load_balancer_start_set_cpu_stage(is_last_call) From a8a0a8f06bea7630b0c629e97d37ba056e2b72cf Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Mon, 30 Mar 2026 10:52:24 -0700 Subject: [PATCH 10/12] style: fix pre-commit formatting for DWDP files Signed-off-by: tianyuz-nv --- .../disaggregated/slurm/benchmark/submit.py | 3 +- .../slurm/benchmark/submit_dwdp.py | 351 +++++++++--------- .../test_dwdp_disaggregated_serving.py | 98 ++--- 3 files changed, 229 insertions(+), 223 deletions(-) diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py index 0dfa3d6fc356..a1bcedaf5087 100644 --- a/examples/disaggregated/slurm/benchmark/submit.py +++ 
b/examples/disaggregated/slurm/benchmark/submit.py @@ -524,8 +524,7 @@ def submit_job(config, log_dir, dry_run): benchmark_mode=benchmark_config['mode'], nsys_on=profiling_config['nsys_on'], profile_range=server_cfg['profile_range'], - concurrency=benchmark_config['concurrency_list'].split(',') - [0], + concurrency=benchmark_config['concurrency_list'].split(',')[0], ) export_str = format_export_string(worker_env) diff --git a/examples/disaggregated/slurm/benchmark/submit_dwdp.py b/examples/disaggregated/slurm/benchmark/submit_dwdp.py index 0bb19cce7c8f..cc0529f52fad 100644 --- a/examples/disaggregated/slurm/benchmark/submit_dwdp.py +++ b/examples/disaggregated/slurm/benchmark/submit_dwdp.py @@ -10,21 +10,17 @@ import argparse import glob import json -import math import os import shutil import subprocess import sys import traceback from datetime import datetime -from typing import Any, Dict, List import yaml - from submit import ( allocate_gpus, build_server_environment, - build_worker_environment, calculate_nodes, convert_allocations_to_server_config, convert_envs_to_str, @@ -37,29 +33,22 @@ def parse_args(): - parser = argparse.ArgumentParser( - description='Submit DWDP disaggregated benchmark job') + parser = argparse.ArgumentParser(description="Submit DWDP disaggregated benchmark job") group = parser.add_mutually_exclusive_group(required=True) - group.add_argument('-c', - '--config', - type=str, - help='Path to the configuration YAML file') - group.add_argument('-d', - '--dir', - type=str, - help='Directory containing YAML configuration files') - parser.add_argument('--log-dir', - type=str, - default=None, - help='Log directory') - parser.add_argument('--dry-run', - action='store_true', - help='Dry run the Python part, test purpose only') + group.add_argument("-c", "--config", type=str, help="Path to the configuration YAML file") + group.add_argument( + "-d", "--dir", type=str, help="Directory containing YAML configuration files" + ) + 
parser.add_argument("--log-dir", type=str, default=None, help="Log directory") + parser.add_argument( + "--dry-run", action="store_true", help="Dry run the Python part, test purpose only" + ) return parser.parse_args() -def generate_mpi_worker_config(worker_config, allocations, env_config, - disagg_hostname, disagg_port, output_path): +def generate_mpi_worker_config( + worker_config, allocations, env_config, disagg_hostname, disagg_port, output_path +): """Generate a config YAML compatible with ``trtllm-serve disaggregated_mpi_worker``.""" def _build_urls(server_type): @@ -73,130 +62,129 @@ def _build_urls(server_type): ctx_urls = _build_urls("CTX") gen_urls = _build_urls("GEN") - ctx_section = dict(worker_config['ctx']) - ctx_section['num_instances'] = len(ctx_urls) - ctx_section['urls'] = ctx_urls + ctx_section = dict(worker_config["ctx"]) + ctx_section["num_instances"] = len(ctx_urls) + ctx_section["urls"] = ctx_urls - gen_section = dict(worker_config['gen']) - gen_section['num_instances'] = len(gen_urls) - gen_section['urls'] = gen_urls + gen_section = dict(worker_config["gen"]) + gen_section["num_instances"] = len(gen_urls) + gen_section["urls"] = gen_urls config = { - 'model': env_config['model_path'], - 'hostname': disagg_hostname, - 'port': disagg_port, - 'backend': 'pytorch', - 'max_retries': 100, - 'context_servers': ctx_section, - 'generation_servers': gen_section, + "model": env_config["model_path"], + "hostname": disagg_hostname, + "port": disagg_port, + "backend": "pytorch", + "max_retries": 100, + "context_servers": ctx_section, + "generation_servers": gen_section, } os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, 'w') as f: + with open(output_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) def submit_dwdp_job(config, log_dir, dry_run): """Submit a DWDP disaggregated benchmark job.""" - slurm_config = config['slurm'] - slurm_config.setdefault('extra_args', '') - 
slurm_config.setdefault('set_segment', True) + slurm_config = config["slurm"] + slurm_config.setdefault("extra_args", "") + slurm_config.setdefault("set_segment", True) - hw_config = config['hardware'] - env_config = config['environment'] - worker_config = config['worker_config'] - benchmark_config = config['benchmark'] + hw_config = config["hardware"] + env_config = config["environment"] + worker_config = config["worker_config"] + benchmark_config = config["benchmark"] - if 'work_dir' in env_config and os.path.isdir(env_config['work_dir']): - script_dir = env_config['work_dir'] + if "work_dir" in env_config and os.path.isdir(env_config["work_dir"]): + script_dir = env_config["work_dir"] else: script_dir = os.path.dirname(os.path.abspath(__file__)) - if 'accuracy' not in config: - config['accuracy'] = { - 'enable_accuracy_test': - False, - 'model': - 'local-completions', - 'tasks': - 'gsm8k', - 'model_args_extra': - 'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096' + if "accuracy" not in config: + config["accuracy"] = { + "enable_accuracy_test": False, + "model": "local-completions", + "tasks": "gsm8k", + "model_args_extra": ( + "num_concurrent=512,max_retries=3," + "tokenized_requests=false,timeout=1200," + "max_gen_toks=256,max_length=4096" + ), } - env_config.setdefault('trtllm_repo', '') - env_config.setdefault('build_wheel', False) - env_config.setdefault('cuda_architectures', '') - env_config.setdefault('trtllm_wheel_path', '') - env_config.setdefault('worker_env_var', '') - env_config.setdefault('server_env_var', '') - - profiling_config = config.get('profiling', {}) - profiling_config.setdefault('nsys_on', False) - profiling_config.setdefault('ctx_profile_range', '10-30') - profiling_config.setdefault('gen_profile_range', '200-250') - - ctx_num = hw_config['num_ctx_servers'] - gen_num = hw_config['num_gen_servers'] - gpus_per_node = hw_config['gpus_per_node'] - - ctx_tp_size = 
worker_config['ctx'].get('tensor_parallel_size', 1) - ctx_cp_size = worker_config['ctx'].get('context_parallel_size', 1) - ctx_pp_size = worker_config['ctx'].get('pipeline_parallel_size', 1) + env_config.setdefault("trtllm_repo", "") + env_config.setdefault("build_wheel", False) + env_config.setdefault("cuda_architectures", "") + env_config.setdefault("trtllm_wheel_path", "") + env_config.setdefault("worker_env_var", "") + env_config.setdefault("server_env_var", "") + + profiling_config = config.get("profiling", {}) + profiling_config.setdefault("nsys_on", False) + profiling_config.setdefault("ctx_profile_range", "10-30") + profiling_config.setdefault("gen_profile_range", "200-250") + + ctx_num = hw_config["num_ctx_servers"] + gen_num = hw_config["num_gen_servers"] + gpus_per_node = hw_config["gpus_per_node"] + + ctx_tp_size = worker_config["ctx"].get("tensor_parallel_size", 1) + ctx_cp_size = worker_config["ctx"].get("context_parallel_size", 1) + ctx_pp_size = worker_config["ctx"].get("pipeline_parallel_size", 1) ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size ctx_nodes = calculate_nodes(ctx_world_size, ctx_num, gpus_per_node) - gen_tp_size = worker_config['gen'].get('tensor_parallel_size', 1) - gen_cp_size = worker_config['gen'].get('context_parallel_size', 1) - gen_pp_size = worker_config['gen'].get('pipeline_parallel_size', 1) + gen_tp_size = worker_config["gen"].get("tensor_parallel_size", 1) + gen_cp_size = worker_config["gen"].get("context_parallel_size", 1) + gen_pp_size = worker_config["gen"].get("pipeline_parallel_size", 1) gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node) - ucx_warmup_requests = 2 * ctx_world_size * \ - gen_world_size if benchmark_config['mode'] == "e2e" else 0 + ucx_warmup_requests = ( + 2 * ctx_world_size * gen_world_size if benchmark_config["mode"] == "e2e" else 0 + ) total_nodes = ctx_nodes + gen_nodes total_tasks = total_nodes * gpus_per_node - 
dwdp_size = worker_config.get('ctx', {}).get('dwdp_config', - {}).get('dwdp_size', 1) + dwdp_size = worker_config.get("ctx", {}).get("dwdp_config", {}).get("dwdp_size", 1) - isl = benchmark_config['input_length'] - osl = benchmark_config['output_length'] - gen_batch_size = worker_config['gen']['max_batch_size'] + isl = benchmark_config["input_length"] + osl = benchmark_config["output_length"] + gen_batch_size = worker_config["gen"]["max_batch_size"] - load_balancer_config = worker_config['gen'].get('moe_config', {}).get( - 'load_balancer', {}) + load_balancer_config = worker_config["gen"].get("moe_config", {}).get("load_balancer", {}) if isinstance(load_balancer_config, str): - with open(load_balancer_config, 'r') as f: + with open(load_balancer_config, "r") as f: load_balancer_config = yaml.safe_load(f) - eplb_num_slots = load_balancer_config.get('num_slots', 0) + eplb_num_slots = load_balancer_config.get("num_slots", 0) - mtp_size = worker_config['gen'].get('speculative_config', - {}).get('num_nextn_predict_layers', 0) + mtp_size = worker_config["gen"].get("speculative_config", {}).get("num_nextn_predict_layers", 0) - if 'log_dir' in env_config and env_config['log_dir']: - log_dir = env_config['log_dir'] + if "log_dir" in env_config and env_config["log_dir"]: + log_dir = env_config["log_dir"] if log_dir is None: log_base = os.path.join(script_dir, "logs") date_prefix = datetime.now().strftime("%Y%m%d-%H%M%S") log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}") - dir_suffix = f"disagg_ctx{ctx_num}_dwdp{dwdp_size}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}" + dir_suffix = ( + f"disagg_ctx{ctx_num}_dwdp{dwdp_size}_gen{gen_num}" + f"_dep{gen_tp_size}_batch{gen_batch_size}" + f"_eplb{eplb_num_slots}_mtp{mtp_size}" + ) log_dir = os.path.join(log_base, dir_suffix) if os.path.exists(log_dir): - if not os.path.exists(os.path.join(log_dir, 'trtllm_config.yaml')): + if not os.path.exists(os.path.join(log_dir, 
"trtllm_config.yaml")): print(f"[WARNING] Removing existing log directory: {log_dir}") shutil.rmtree(log_dir) else: - print( - f"[WARNING] trtllm_config.yaml exists, not removing the directory: {log_dir}" - ) + print(f"[WARNING] trtllm_config.yaml exists, not removing the directory: {log_dir}") for file in os.listdir(log_dir): - if file != 'trtllm_config.yaml' and not file.startswith( - 'concurrency_'): + if file != "trtllm_config.yaml" and not file.startswith("concurrency_"): if os.path.isdir(os.path.join(log_dir, file)): shutil.rmtree(os.path.join(log_dir, file)) else: @@ -204,10 +192,10 @@ def submit_dwdp_job(config, log_dir, dry_run): os.makedirs(log_dir, exist_ok=True) print(f"Log will be saved to: {log_dir}") - ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml') - gen_config_path = os.path.join(log_dir, 'gen_config.yaml') - save_worker_config(worker_config['ctx'], ctx_config_path) - save_worker_config(worker_config['gen'], gen_config_path) + ctx_config_path = os.path.join(log_dir, "ctx_config.yaml") + gen_config_path = os.path.join(log_dir, "gen_config.yaml") + save_worker_config(worker_config["ctx"], ctx_config_path) + save_worker_config(worker_config["gen"], gen_config_path) allocations = allocate_gpus( total_nodes=total_nodes, @@ -223,21 +211,25 @@ def submit_dwdp_job(config, log_dir, dry_run): server_config = convert_allocations_to_server_config(allocations) with open(os.path.join(log_dir, "server_config_base.yaml"), "w") as f: yaml.dump(server_config, f) - disagg_server_hostname = server_config['hostname'] - disagg_server_port = server_config['port'] + disagg_server_hostname = server_config["hostname"] + disagg_server_port = server_config["port"] container_name = "disaggr-test" start_server_cmds = [] - container_mount_str = env_config['container_mount'] + container_mount_str = env_config["container_mount"] container_mount_str += f",{script_dir}:{script_dir}" # --- DWDP mode: single srun with disaggregated_mpi_worker --- - mpi_config_base_path = 
os.path.join(log_dir, - 'mpi_worker_config_base.yaml') - mpi_config_path = os.path.join(log_dir, 'mpi_worker_config.yaml') - generate_mpi_worker_config(worker_config, allocations, env_config, - disagg_server_hostname, disagg_server_port, - mpi_config_base_path) + mpi_config_base_path = os.path.join(log_dir, "mpi_worker_config_base.yaml") + mpi_config_path = os.path.join(log_dir, "mpi_worker_config.yaml") + generate_mpi_worker_config( + worker_config, + allocations, + env_config, + disagg_server_hostname, + disagg_server_port, + mpi_config_base_path, + ) ctx_node_list = [] for sid in sorted(allocations.get("CTX", {}).keys()): @@ -253,13 +245,15 @@ def submit_dwdp_job(config, log_dir, dry_run): total_mpi_tasks = ctx_num * ctx_world_size + gen_num * gen_world_size mpi_num_nodes = len(mpi_nodelist) num_ctx_gpus = ctx_num * ctx_world_size - worker_env_var = env_config.get('worker_env_var', '') - ctx_worker_env_var = env_config.get('ctx_worker_env_var', '') - gen_worker_env_var = env_config.get('gen_worker_env_var', '') - dwdp_ctx_worker_env_var = worker_env_var + \ - (f" {ctx_worker_env_var}" if ctx_worker_env_var else "") - dwdp_gen_worker_env_var = worker_env_var + \ - (f" {gen_worker_env_var}" if gen_worker_env_var else "") + worker_env_var = env_config.get("worker_env_var", "") + ctx_worker_env_var = env_config.get("ctx_worker_env_var", "") + gen_worker_env_var = env_config.get("gen_worker_env_var", "") + dwdp_ctx_worker_env_var = worker_env_var + ( + f" {ctx_worker_env_var}" if ctx_worker_env_var else "" + ) + dwdp_gen_worker_env_var = worker_env_var + ( + f" {gen_worker_env_var}" if gen_worker_env_var else "" + ) cmd = [ "srun -l", @@ -273,9 +267,9 @@ def submit_dwdp_job(config, log_dir, dry_run): "--no-container-mount-home --mpi=pmix --overlap", f"bash {os.path.join(script_dir, 'start_worker_dwdp.sh')}", mpi_config_path, - str(slurm_config['numa_bind']).lower(), + str(slurm_config["numa_bind"]).lower(), log_dir, - str(profiling_config['nsys_on']).lower(), + 
str(profiling_config["nsys_on"]).lower(), f"'{profiling_config['ctx_profile_range']}'", f"'{profiling_config['gen_profile_range']}'", str(num_ctx_gpus), @@ -286,17 +280,17 @@ def submit_dwdp_job(config, log_dir, dry_run): start_server_cmds.append(" ".join(cmd)) # Generate start server commands - server_env = build_server_environment(env_config, benchmark_config['mode']) + server_env = build_server_environment(env_config, benchmark_config["mode"]) export_str = format_export_string(server_env) cmd = [ "srun -l", f"--nodelist {disagg_server_hostname}", f"--container-name={container_name}", - f"--export=\"{export_str}\"", + f'--export="{export_str}"', f"--container-image={env_config['container_image']}", f"--container-mounts={container_mount_str}", - f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1", + "--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1", f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')}", f"&> {log_dir}/4_output_server.log &", ] @@ -304,10 +298,10 @@ def submit_dwdp_job(config, log_dir, dry_run): save_env_file( os.path.join(log_dir, "env_vars.json"), - env_config.get('server_env_var', ''), - env_config.get('worker_env_var', ''), - env_config.get('ctx_worker_env_var', ''), - env_config.get('gen_worker_env_var', ''), + env_config.get("server_env_var", ""), + env_config.get("worker_env_var", ""), + env_config.get("ctx_worker_env_var", ""), + env_config.get("gen_worker_env_var", ""), ) # Generate wait server command @@ -315,7 +309,7 @@ def submit_dwdp_job(config, log_dir, dry_run): "srun -l", f"--container-name={container_name}", f"--container-mounts={container_mount_str}", - f"--mpi=pmix --overlap -N 1 -n 1", + "--mpi=pmix --overlap -N 1 -n 1", f"bash {os.path.join(script_dir, 'wait_server.sh')} {disagg_server_hostname} {disagg_server_port}", f"&> {log_dir}/5_wait_server.log", ] @@ -329,80 +323,94 @@ def submit_dwdp_job(config, log_dir, dry_run): client_slurm_prefix = [ f"srun -l 
--container-name={container_name}", f"--container-mounts={container_mount_str}", - f"--mpi=pmix --overlap -N 1 -n 1", + "--mpi=pmix --overlap -N 1 -n 1", ] - if benchmark_config.get('enable_benchmark', True): - env_var = config['benchmark'].get('env_var', {}) - benchmark_prefix = client_slurm_prefix + [ - f"--export \"{convert_envs_to_str(env_var)}\"" - ] - if benchmark_config['use_nv_sa_benchmark']: - if benchmark_config['mode'] == "gen_only": - print( - f"[ERROR] SA benchmark client script is not supported for gen_only mode" - ) + if benchmark_config.get("enable_benchmark", True): + env_var = config["benchmark"].get("env_var", {}) + benchmark_prefix = client_slurm_prefix + [f'--export "{convert_envs_to_str(env_var)}"'] + if benchmark_config["use_nv_sa_benchmark"]: + if benchmark_config["mode"] == "gen_only": + print("[ERROR] SA benchmark client script is not supported for gen_only mode") sys.exit(1) benchmark_cmd = [ f"bash {os.path.join(script_dir, 'run_benchmark_nv_sa.sh')}", - f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}", - f"&> {log_dir}/6_bench.log" + ( + f"'{env_config['model_path']}' {isl} {osl}" + f" {benchmark_config['benchmark_ratio']}" + f" {benchmark_config['multi_round']} {gen_num}" + f" '{benchmark_config['concurrency_list']}'" + f" {benchmark_config['streaming']} '{log_dir}'" + f" {disagg_server_hostname} {disagg_server_port}" + f" {ucx_warmup_requests}" + ), + f"&> {log_dir}/6_bench.log", ] client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd)) else: benchmark_cmd = [ f"bash {os.path.join(script_dir, 'run_benchmark.sh')}", - f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} 
'{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}", - f"&> {log_dir}/6_bench.log" + ( + f"'{env_config['model_path']}'" + f" '{benchmark_config['dataset_file']}'" + f" {benchmark_config['multi_round']} {gen_num}" + f" '{benchmark_config['concurrency_list']}'" + f" {benchmark_config['streaming']} '{log_dir}'" + f" {disagg_server_hostname} {disagg_server_port}" + f" {ucx_warmup_requests}" + ), + f"&> {log_dir}/6_bench.log", ] client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd)) - if config['accuracy']['enable_accuracy_test']: - env_var = config['accuracy'].get('env_var', {}) - accuracy_prefix = client_slurm_prefix + [ - f"--export \"{convert_envs_to_str(env_var)}\"" - ] - for task in config['accuracy']['tasks']: - extra_kwargs = config['accuracy']['tasks'][task].get( - 'extra_kwargs', {}) + if config["accuracy"]["enable_accuracy_test"]: + env_var = config["accuracy"].get("env_var", {}) + accuracy_prefix = client_slurm_prefix + [f'--export "{convert_envs_to_str(env_var)}"'] + for task in config["accuracy"]["tasks"]: + extra_kwargs = config["accuracy"]["tasks"][task].get("extra_kwargs", {}) extra_kwargs_str = "" for key, value in extra_kwargs.items(): if isinstance(value, bool): if value: extra_kwargs_str += f" --{key}" elif key == "custom_config": - extra_kwargs_str += f" --include_path={replace_env_in_file(log_dir, value, env_var)}" + extra_kwargs_str += ( + f" --include_path={replace_env_in_file(log_dir, value, env_var)}" + ) else: extra_kwargs_str += f" --{key}='{value}'" end_point_map = { - 'local-completions': 'v1/completions', - 'local-chat-completions': 'v1/chat/completions', + "local-completions": "v1/completions", + "local-chat-completions": "v1/chat/completions", } - model = config['accuracy']['tasks'][task]['model'] + model = config["accuracy"]["tasks"][task]["model"] accuracy_cmd = [ - 'lm_eval', '--model', model, '--tasks', task, '--model_args', + "lm_eval", + "--model", + model, + "--tasks", + task, + 
"--model_args", f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}", - '--log_samples', '--output_path', - f'{log_dir}/accuracy_eval_{task}', extra_kwargs_str, - f"&> {log_dir}/7_accuracy_eval_{task}.log" + "--log_samples", + "--output_path", + f"{log_dir}/accuracy_eval_{task}", + extra_kwargs_str, + f"&> {log_dir}/7_accuracy_eval_{task}.log", ] client_cmds.append(" ".join(accuracy_prefix + accuracy_cmd)) - done_cmd = [ - "echo", "${SLURM_JOB_NODELIST}", ">", - f"{log_dir}/8_done_${{SLURM_JOB_ID}}.txt" - ] + done_cmd = ["echo", "${SLURM_JOB_NODELIST}", ">", f"{log_dir}/8_done_${{SLURM_JOB_ID}}.txt"] client_cmds.append(" ".join(done_cmd)) with open(os.path.join(log_dir, "client_cmds_base.sh"), "w") as f: f.write("\n".join(client_cmds) + "\n") - slurm_script_file = slurm_config['script_file'] + slurm_script_file = slurm_config["script_file"] if not os.path.isabs(slurm_script_file): slurm_script_file = os.path.join(script_dir, slurm_script_file) if not os.path.exists(slurm_script_file): - print(f"[ERROR] SLURM script file not found: {slurm_script_file}", - file=sys.stderr) + print(f"[ERROR] SLURM script file not found: {slurm_script_file}", file=sys.stderr) sys.exit(1) # yapf: disable @@ -457,12 +465,11 @@ def main(): if args.config: config_files = [args.config] else: - yaml_pattern = os.path.join(args.dir, '*.yaml') + yaml_pattern = os.path.join(args.dir, "*.yaml") config_files = sorted(glob.glob(yaml_pattern)) if not config_files: - print(f"No YAML files found in directory: {args.dir}", - file=sys.stderr) + print(f"No YAML files found in directory: {args.dir}", file=sys.stderr) sys.exit(1) print(f"Found {len(config_files)} YAML file(s) in {args.dir}") @@ -479,5 +486,5 @@ def main(): continue -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py 
b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py index bea6bb8f7ae7..eabf9a5e4045 100644 --- a/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py @@ -15,8 +15,8 @@ import pytest import requests import yaml -from defs.common import get_free_port_in_ci as get_free_port +from defs.common import get_free_port_in_ci as get_free_port from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams from tensorrt_llm.llmapi.llm_args import LlmArgs from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer @@ -62,28 +62,36 @@ def launch_dwdp_disaggregated_llm( serve_port = frontend_config["port"] child_env = { - k: v - for k, v in os.environ.items() - if not k.startswith(('OMPI_', 'PMIX_', 'PMI_')) + k: v for k, v in os.environ.items() if not k.startswith(("OMPI_", "PMIX_", "PMI_")) } mpi_cmd = [ - "mpirun", "--allow-run-as-root", "-n", - str(total_gpus), "trtllm-serve", "disaggregated_mpi_worker", "-c", - worker_config_path + "mpirun", + "--allow-run-as-root", + "-n", + str(total_gpus), + "trtllm-serve", + "disaggregated_mpi_worker", + "-c", + worker_config_path, ] frontend_cmd = [ - "trtllm-serve", "disaggregated", "-c", frontend_config_path, + "trtllm-serve", + "disaggregated", + "-c", + frontend_config_path, "--server_start_timeout", - str(server_waiting_timeout), "-r", "360000" + str(server_waiting_timeout), + "-r", + "360000", ] with ( - MyThreadPoolExecutor(max_workers=max_workers) as thread_pool, - temp_dir, - popen(mpi_cmd, env=child_env) as mpi_proc, - popen(frontend_cmd, env=child_env) as frontend_proc, + MyThreadPoolExecutor(max_workers=max_workers) as thread_pool, + temp_dir, + popen(mpi_cmd, env=child_env) as mpi_proc, + popen(frontend_cmd, env=child_env) as frontend_proc, ): start_time = time.time() server_is_ready = False @@ -94,11 +102,9 @@ def launch_dwdp_disaggregated_llm( (frontend_proc, "frontend"), ]: if proc.poll() is not None: 
- raise Exception( - f"{name} process exited with code {proc.returncode}") + raise Exception(f"{name} process exited with code {proc.returncode}") try: - response = requests.get( - f"http://localhost:{serve_port}/cluster_info") + response = requests.get(f"http://localhost:{serve_port}/cluster_info") if response.status_code == 200: cluster_info = response.json() if cluster_info.get("is_ready"): @@ -108,45 +114,41 @@ def launch_dwdp_disaggregated_llm( except requests.exceptions.ConnectionError: continue if not server_is_ready: - pytest.fail( - f"DWDP server not ready after {server_waiting_timeout}s") + pytest.fail(f"DWDP server not ready after {server_waiting_timeout}s") model_name = worker_config.get("model", model_path) - client = openai.OpenAI(api_key="1234567890", - base_url=f"http://localhost:{serve_port}/v1", - timeout=1800000) + client = openai.OpenAI( + api_key="1234567890", base_url=f"http://localhost:{serve_port}/v1", timeout=1800000 + ) - def send_request(prompt: str, sampling_params: SamplingParams, - streaming: bool): + def send_request(prompt: str, sampling_params: SamplingParams, streaming: bool): kwargs = {} if sampling_params is not None: kwargs.update( max_tokens=sampling_params.max_tokens, - temperature=(sampling_params.temperature - if sampling_params.top_p is not None else 0), + temperature=( + sampling_params.temperature if sampling_params.top_p is not None else 0 + ), top_p=sampling_params.top_p, stop=sampling_params.stop, - seed=sampling_params.seed) - response = client.completions.create(model=model_name, - prompt=prompt, - stream=streaming, - **kwargs) - result = Result(id=0, - sampling_params=sampling_params, - outputs=[ - CompletionOutput(text=response.choices[0].text, - index=0) - ]) - requested_output = RequestOutput._from_generation_result( - result, prompt=prompt) + seed=sampling_params.seed, + ) + response = client.completions.create( + model=model_name, prompt=prompt, stream=streaming, **kwargs + ) + result = Result( + id=0, + 
sampling_params=sampling_params, + outputs=[CompletionOutput(text=response.choices[0].text, index=0)], + ) + requested_output = RequestOutput._from_generation_result(result, prompt=prompt) setattr(requested_output, "result", result.result) return requested_output - def generate_async(prompt: str, - sampling_params: Optional[SamplingParams] = None, - streaming: bool = False): - future = thread_pool.submit(send_request, prompt, sampling_params, - streaming) + def generate_async( + prompt: str, sampling_params: Optional[SamplingParams] = None, streaming: bool = False + ): + future = thread_pool.submit(send_request, prompt, sampling_params, streaming) thread_pool.futures.append(future) return future @@ -275,9 +277,7 @@ def test_dwdp_accuracy(self): }, } - with launch_dwdp_disaggregated_llm(worker_config, - frontend_config, - model_path, - total_gpus=4, - max_workers=128) as llm: + with launch_dwdp_disaggregated_llm( + worker_config, frontend_config, model_path, total_gpus=4, max_workers=128 + ) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) From 85def5350e90d27e97ff7a7e384bcf1db893ee92 Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Tue, 31 Mar 2026 02:56:02 -0700 Subject: [PATCH 11/12] refactor(dwdp): address reviewer feedback on DWDP config and lifecycle - Remove DwdpConfig.enabled field; use Optional[DwdpConfig]=None pattern - Change dwdp_config status from "beta" to "prototype" - Rename num_group -> num_groups, experts_per_worker -> num_experts_per_worker - Disallow DWDP with overlap scheduler via explicit ValueError - Add DwdpManager.cleanup() for IPC handle and MPI group release - Refactor DwdpManager as context manager (__enter__/__exit__) for controlled global registration and resource cleanup on shutdown Signed-off-by: tianyuz-nv --- .../_torch/modules/fused_moe/interface.py | 4 +-- tensorrt_llm/_torch/pyexecutor/dwdp.py | 32 ++++++++++++++----- tensorrt_llm/_torch/pyexecutor/py_executor.py | 8 +++++ .../_torch/pyexecutor/py_executor_creator.py | 3 
+- tensorrt_llm/llmapi/llm_args.py | 13 ++++---- .../test_dwdp_disaggregated_serving.py | 5 ++- 6 files changed, 44 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index fcc3f3e3bb27..08420ab0b815 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -317,9 +317,9 @@ def _init_dwdp_expert_layout(self): return assert self.layer_load_balancer is None, ( "DWDP and EPLB (MoE load balancer) cannot be used together. " - "Disable one of dwdp_config.enabled or moe_load_balancer.") + "Disable one of dwdp_config or moe_load_balancer.") self.num_slots = self.num_experts - self.expert_size_per_partition = dwdp_manager.experts_per_worker + self.expert_size_per_partition = dwdp_manager.num_experts_per_worker dwdp_size = dwdp_manager.dwdp_size self.initial_global_assignments = [ (ep_rank * self.num_experts // dwdp_size + local_slot_id) % diff --git a/tensorrt_llm/_torch/pyexecutor/dwdp.py b/tensorrt_llm/_torch/pyexecutor/dwdp.py index 892ee285444f..5793aab4bac0 100644 --- a/tensorrt_llm/_torch/pyexecutor/dwdp.py +++ b/tensorrt_llm/_torch/pyexecutor/dwdp.py @@ -158,7 +158,7 @@ def __init__( self, dwdp_size: int, dwdp_rank: int, - experts_per_worker: int, + num_experts_per_worker: int, num_prefetch_experts: int, num_layers: int, first_moe_layer_idx: int, @@ -167,7 +167,7 @@ def __init__( ): self.dwdp_size = dwdp_size self.num_prefetch_experts = num_prefetch_experts - self.experts_per_worker = experts_per_worker + self.num_experts_per_worker = num_experts_per_worker self.num_layers = num_layers self.first_moe_layer_idx = first_moe_layer_idx self.num_buffers = 2 # Ping-pong @@ -257,8 +257,8 @@ def __init__( self.config = config self.dist = dist self.dwdp_size = config.dwdp_size - self.experts_per_worker = config.experts_per_worker - self.num_group = config.num_group + self.num_experts_per_worker = 
config.num_experts_per_worker + self.num_groups = config.num_groups self._init_dwdp_group() @@ -276,9 +276,16 @@ def __init__( self.dwdp_rank = self.rank % self.dwdp_size self.num_prefetch_experts = config.num_prefetch_experts self.start_expert_id = self.num_prefetch_experts * self.dwdp_rank - self.end_expert_id = self.start_expert_id + self.experts_per_worker + self.end_expert_id = self.start_expert_id + self.num_experts_per_worker + def __enter__(self): set_global_dwdp_manager(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cleanup() + set_global_dwdp_manager(None) + return False def _init_dwdp_group(self): if not isinstance(self.dist, MPIDist): @@ -287,7 +294,7 @@ def _init_dwdp_group(self): self.rank = global_mpi_rank() # Calculate which group this rank belongs to - # With num_group=2, dwdp_size=4: + # With num_groups=2, dwdp_size=4: # Group 0: ranks [0, 1, 2, 3] # Group 1: ranks [4, 5, 6, 7] self.group_id = self.rank // self.dwdp_size @@ -298,7 +305,16 @@ def _init_dwdp_group(self): self.dwdp_group = COMM_WORLD.Create_group(new_group) def is_enabled(self) -> bool: - return self.config.enabled and self.dwdp_size > 1 + return self.dwdp_size > 1 + + def cleanup(self): + """Release all IPC handles and clean up resources.""" + for collector in self.ipc_collectors: + collector.cleanup() + self.ipc_collectors.clear() + if self.dwdp_group is not None: + self.dwdp_group.Free() + self.dwdp_group = None def add_layer( self, @@ -379,7 +395,7 @@ def initialize_prefetch_buffer(self): self.prefetch_buffer = DwdpPrefetchBuffer( dwdp_size=self.dwdp_size, dwdp_rank=self.dwdp_rank, - experts_per_worker=self.experts_per_worker, + num_experts_per_worker=self.num_experts_per_worker, num_prefetch_experts=self.num_prefetch_experts, num_layers=len(self.ipc_collectors), first_moe_layer_idx=self.first_moe_layer_idx, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 1fd246556b64..77ef1dcd5532 
100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -535,6 +535,11 @@ def on_detected(): if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) + if dwdp_manager is not None and not self.disable_overlap_scheduler: + raise ValueError( + "DWDP requires disable_overlap_scheduler=True. " + "Overlap scheduler is not yet supported with DWDP.") + if self.drafter is not None: if self.event_loop.__name__ == self._executor_loop_pp.__name__: raise NotImplementedError( @@ -767,6 +772,9 @@ def shutdown(self): if (isinstance(self.sampler, AsyncWorkerMixin) and self.sampler.async_worker_enabled()): self.sampler.async_worker_stop() + if self.dwdp_manager is not None: + self.dwdp_manager.__exit__(None, None, None) + self.dwdp_manager = None def can_enqueue_requests(self) -> bool: """ diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index dfd0d9b5c44c..78d39cb45f89 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -387,9 +387,10 @@ def create_py_executor( # Initialize DWDP Manager (only for context workers in disaggregated serving) dwdp_manager: Optional[DwdpManager] = None - if llm_args.dwdp_config is not None and llm_args.dwdp_config.enabled: + if llm_args.dwdp_config is not None: assert mapping.tp_size == 1 and llm_args.dwdp_config.dwdp_size > 1, "DWDP requires TP=1 and dwdp_size > 1" dwdp_manager = DwdpManager(config=llm_args.dwdp_config, dist=dist) + dwdp_manager.__enter__() logger.info(f"Dwdp Manager initialized. 
Config: {llm_args.dwdp_config}") mem_monitor = _ExecutorMemoryMonitor() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 129fa18da3bc..070ff3470fcd 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2398,14 +2398,13 @@ class DwdpConfig(StrictBaseModel): Currently supported with the CuteDSL MoE backend and NVFP4 quantization on NVLink-connected multi-GPU systems. """ - enabled: bool = Field(default=False, description="Whether to enable DWDP.") dwdp_size: int = Field(default=1, description="The number of GPUs per DWDP group.") - num_group: int = Field( + num_groups: int = Field( default=1, description= - "The number of DWDP groups. Total workers = num_group * dwdp_size.") - experts_per_worker: int = Field( + "The number of DWDP groups. Total workers = num_groups * dwdp_size.") + num_experts_per_worker: int = Field( default=0, description="The number of experts per worker.") num_prefetch_experts: int = Field( default=0, description="The number of prefetch experts per worker.") @@ -3299,10 +3298,10 @@ class TorchLlmArgs(BaseLlmArgs): description="NVFP4 GEMM backend config.", status="beta") - dwdp_config: DwdpConfig = Field( - default_factory=DwdpConfig, + dwdp_config: Optional[DwdpConfig] = Field( + default=None, description="DWDP (Distributed Weight Data Parallelism) config.", - status="beta") + status="prototype") attn_backend: str = Field(default='TRTLLM', description="Attention backend to use.", diff --git a/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py index eabf9a5e4045..ce1644c04711 100644 --- a/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py @@ -217,10 +217,9 @@ def test_dwdp_accuracy(self): "backend": "CUTEDSL", }, "dwdp_config": { - "enabled": True, "dwdp_size": 2, - "num_group": 1, - "experts_per_worker": 36, 
+ "num_groups": 1, + "num_experts_per_worker": 36, "num_prefetch_experts": 36, }, } From afcd735706bc8d8c2791b2b4991492369ecf5daa Mon Sep 17 00:00:00 2001 From: tianyuz-nv Date: Tue, 31 Mar 2026 10:35:10 -0700 Subject: [PATCH 12/12] test(api_stability): align llm.yaml dwdp_config with Optional + prototype Signed-off-by: tianyuz-nv --- tests/unittest/api_stability/references/llm.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 4db666edd805..3b774c8076c4 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -88,9 +88,9 @@ methods: default: null status: beta dwdp_config: - annotation: tensorrt_llm.llmapi.llm_args.DwdpConfig + annotation: Optional[tensorrt_llm.llmapi.llm_args.DwdpConfig] default: null - status: beta + status: prototype checkpoint_loader: annotation: Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader] default: null