diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index d3740100fff3..4d3b396e23c9 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -144,7 +144,7 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, // Synchronize among ranks cudaDeviceSynchronize(); - tensorrt_llm::mpi::MpiComm::world().barrier(); + tensorrt_llm::mpi::MpiComm::session().barrier(); return metainfo; } diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch_dwdp.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch_dwdp.slurm new file mode 100644 index 000000000000..145c3099267b --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch_dwdp.slurm @@ -0,0 +1,189 @@ +#!/bin/bash +set -euo pipefail + +# Parse named arguments +while [[ $# -gt 0 ]]; do + case $1 in + # Benchmark Configuration + --benchmark-mode) benchmark_mode="$2"; shift 2 ;; + + # Environment and paths + --trtllm-repo) trtllm_repo="$2"; shift 2 ;; + --work-dir) work_dir="$2"; shift 2 ;; + --full-logdir) full_logdir="$2"; shift 2 ;; + --container-name) container_name="$2"; shift 2 ;; + --container-mount) container_mount="$2"; shift 2 ;; + --container-image) container_image="$2"; shift 2 ;; + --build-wheel) build_wheel="$2"; shift 2 ;; + --cuda-architectures) cuda_architectures="$2"; shift 2 ;; + --trtllm-wheel-path) trtllm_wheel_path="$2"; shift 2 ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +# Print all parsed arguments +echo "Parsed arguments:" +echo +echo "Benchmark Configuration:" +echo " benchmark_mode: ${benchmark_mode}" +echo +echo "Environment Configuration:" +echo " trtllm_repo: ${trtllm_repo}" +echo " work_dir: ${work_dir}" +echo " full_logdir: ${full_logdir}" +echo " container_mount: ${container_mount}" +echo " container_image: ${container_image}" +echo " build_wheel: ${build_wheel}" +echo " cuda_architectures: ${cuda_architectures}" +echo " 
trtllm_wheel_path: ${trtllm_wheel_path}" + +# Set TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 for gen_only_no_context mode +if [ "${benchmark_mode}" = "gen_only_no_context" ]; then + export TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 + echo "Setting TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 for gen_only_no_context mode" +fi + +# Function to cleanup on failure +cleanup_on_failure() { + echo "Error: $1" + scancel ${SLURM_JOB_ID} + exit 1 +} + +replace_placeholder() { + file_path="$1" + all_nodes_str="$2" + new_file_path="$3" + cp "$file_path" "$new_file_path" + IFS=',' read -r -a node_array <<< "$all_nodes_str" + for i in "${!node_array[@]}"; do + current_val="${node_array[$i]}" + placeholder="" + + # Use sed to replace the placeholder with the value in-place + sed -i "s|$placeholder|$current_val|g" "${new_file_path}" + echo "Replaced $placeholder with $current_val in ${new_file_path}" + done +} + +env > ${full_logdir}/environment.txt + +# Start container +echo "Starting container..." +if ! srun -l --container-image=${container_image} \ + --container-name=${container_name} \ + --container-mounts=${container_mount} \ + --mpi=pmix \ + echo "Container up." &> ${full_logdir}/1_container_launch.log; then + cleanup_on_failure "Failed to start container. Check ${full_logdir}/1_container_launch.log" +fi + +# Install TensorRT-LLM +if [ -n "${trtllm_wheel_path}" ]; then + # Install from pre-built wheel if path is provided + echo "Installing TensorRT-LLM from wheel: ${trtllm_wheel_path}..." + if ! srun --container-name=${container_name} \ + --container-mounts=${container_mount} --no-container-mount-home \ + --mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \ + bash -c "pip install ${trtllm_wheel_path}[devel]" \ + &> ${full_logdir}/2_install.log; then + cleanup_on_failure "TensorRT-LLM wheel installation failed. 
Check ${full_logdir}/2_install.log for details" + fi + echo "TensorRT-LLM wheel installation completed successfully" +elif [ -d "${trtllm_repo}" ]; then + # Build and install from repository if no wheel path provided + echo "Installing TensorRT-LLM from ${trtllm_repo}..." + TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown") + echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}" + + if [ "${build_wheel}" = "true" ]; then + echo "Building TensorRT-LLM wheel on one node..." + build_command="python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt --benchmarks --use_ccache --clean" + if [ -n "${cuda_architectures:-}" ]; then + build_command="${build_command} --cuda_architectures \"${cuda_architectures}\"" + fi + if ! srun --container-name=${container_name} \ + --container-mounts=${container_mount} \ + --mpi=pmix --overlap -N 1 --ntasks-per-node=1 \ + bash -c "cd ${trtllm_repo} && ${build_command}" \ + &> ${full_logdir}/2_build.log; then + cleanup_on_failure "TensorRT-LLM build failed. Check ${full_logdir}/2_build.log for details" + fi + echo "TensorRT-LLM build completed successfully" + fi + + echo "Installing TensorRT-LLM..." + if ! srun --container-name=${container_name} \ + --container-mounts=${container_mount} --no-container-mount-home \ + --mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \ + bash -c "cd ${trtllm_repo} && pip install -e .[devel]" \ + &> ${full_logdir}/2_install.log; then + cleanup_on_failure "TensorRT-LLM installation failed. Check ${full_logdir}/2_install.log for details" + fi + echo "TensorRT-LLM installation completed successfully" +else + echo "trtllm_wheel_path and trtllm_repo are not provided, will use the installed TensorRT-LLM from the container" + # get_env file is in the same directory as this script + get_env_file=${work_dir}/get_env.py + if ! 
srun --container-name=${container_name} \ + --container-mounts=${container_mount} --no-container-mount-home \ + --mpi=pmix --overlap -N 1 --ntasks-per-node=1 \ + bash -c "python ${get_env_file} -e ${full_logdir}/env_vars.json" \ + &> ${full_logdir}/2_get_env.log; then + cleanup_on_failure "Failed to get TensorRT-LLM environment variables. Check ${full_logdir}/2_get_env.log for details" + fi + echo "TensorRT-LLM environment variables saved to ${full_logdir}/env_vars.json" +fi + +# Get node lists and replace the placeholder with the actual node names +echo "SLURM_NODELIST: ${SLURM_NODELIST}" +all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort)) +all_nodes_str=$(IFS=','; echo "${all_nodes[*]}") +echo "all_nodes_str: ${all_nodes_str}" + +start_server_cmds_base_file=${full_logdir}/start_server_cmds_base.sh +start_server_cmds_file=${full_logdir}/start_server_cmds.sh +replace_placeholder "${start_server_cmds_base_file}" "${all_nodes_str}" "${start_server_cmds_file}" +server_config_base_file=${full_logdir}/server_config_base.yaml +server_config_file=${full_logdir}/server_config.yaml +replace_placeholder "${server_config_base_file}" "${all_nodes_str}" "${server_config_file}" +mpi_worker_config_base_file=${full_logdir}/mpi_worker_config_base.yaml +mpi_worker_config_file=${full_logdir}/mpi_worker_config.yaml +if [ -f "${mpi_worker_config_base_file}" ]; then + replace_placeholder "${mpi_worker_config_base_file}" "${all_nodes_str}" "${mpi_worker_config_file}" +fi +client_cmds_base_file=${full_logdir}/client_cmds_base.sh +client_cmds_file=${full_logdir}/client_cmds.sh +replace_placeholder "${client_cmds_base_file}" "${all_nodes_str}" "${client_cmds_file}" + +# start the servers (skip ctx workers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set). +echo "Starting worker commands from ${start_server_cmds_file}..." 
+cat ${start_server_cmds_file} | while read cmd; do + # Skip ctx worker commands if in gen-only mode + # CTX appears as argument to start_worker.sh and in log filename + if [ "${TRTLLM_DISAGG_BENCHMARK_GEN_ONLY:-0}" = "1" ] && [[ "$cmd" == *"start_worker.sh CTX"* ]]; then + echo "Skipping ctx worker command (TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set): ${cmd}" + continue + fi + echo "Executing command: ${cmd}" + eval "${cmd}" +done +echo "Server is ready!" + +# Start client commands +echo "Starting client commands from ${client_cmds_file}..." +while read -r cmd <&3; do + echo "Starting client command: ${cmd}" + eval "${cmd}" + if [ $? -ne 0 ]; then + cleanup_on_failure "Command failed: ${cmd}." + fi +done 3< "${client_cmds_file}" + +echo "Job completed successfully, total runtime: $SECONDS seconds" + +# try to kill the server and workers +scancel ${SLURM_JOB_ID} diff --git a/examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh b/examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh new file mode 100644 index 000000000000..b0fe5bf88443 --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh @@ -0,0 +1,61 @@ +#! 
/bin/bash +set -u +set -e +set -x + +config_file=${1} +numa_bind=${2} +log_dir=${3} +enable_nsys=${4} +ctx_profile_range=${5} +gen_profile_range=${6} +num_ctx_gpus=${7} +ctx_worker_env_var=${8} +gen_worker_env_var=${9} + +unset UCX_NET_DEVICES +unset UCX_TLS + +echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname)" + +if [ "${SLURM_PROCID}" -lt "${num_ctx_gpus}" ]; then + worker_role="CTX" + worker_env_var=${ctx_worker_env_var} + profile_range=${ctx_profile_range} +else + worker_role="GEN" + worker_env_var=${gen_worker_env_var} + profile_range=${gen_profile_range} +fi + +echo "worker_role: ${worker_role}, profile_range: ${profile_range}" + +for env_var in ${worker_env_var}; do + export "${env_var}" + echo "Exported: ${env_var}" +done + +if [ "${numa_bind}" = "true" ]; then + numa_bind_cmd="numactl -m 0,1" + echo "numactl -m 0,1 - Only allocate memory from nodes on GB200/GB300 NVL72" +else + numa_bind_cmd="" + echo "Not binding memory. If on GB200/GB300 NVL72, use \"numactl -m 0,1\" to only allocate memory from nodes." 
+fi + +echo "config_file: ${config_file}" + +nsys_prefix="" +if [ "${enable_nsys}" != "true" ]; then + echo "nsys is not enabled, start normal flow" +else + nsys_file=${log_dir}/nsys_worker_proc_${worker_role}_${SLURM_PROCID} + export TLLM_PROFILE_RECORD_GC=1 + export TLLM_NVTX_DEBUG=1 + export NSYS_MPI_STORE_TEAMS_PER_RANK=1 + export TLLM_PROFILE_START_STOP=${profile_range} + echo "nsys is enabled on ${worker_role} ranks, TLLM_PROFILE_START_STOP=${profile_range}" + nsys_prefix="nsys profile -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none" +fi + +${nsys_prefix} ${numa_bind_cmd} trtllm-serve disaggregated_mpi_worker -c ${config_file} diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py index 716260fea333..a1bcedaf5087 100644 --- a/examples/disaggregated/slurm/benchmark/submit.py +++ b/examples/disaggregated/slurm/benchmark/submit.py @@ -105,10 +105,13 @@ def assign_servers( server_allocations[server_type][i] = server_allocation port += 1 - assign_servers(allocations, "GEN", num_gen_servers, gen_world_size, - gpus_per_node) + # Keep the allocation order aligned with disagg_utils, which builds + # server_configs as ctx_cfgs + gen_cfgs and assigns rank offsets in that + # same order during split_world_comm(). 
assign_servers(allocations, "CTX", num_ctx_servers, ctx_world_size, gpus_per_node) + assign_servers(allocations, "GEN", num_gen_servers, gen_world_size, + gpus_per_node) return allocations @@ -506,17 +509,13 @@ def submit_job(config, log_dir, dry_run): } } - # Generate start worker commands with placeholder hostnames for server_type in allocations.keys(): server_cfg = server_configs[server_type] for server_id in allocations[server_type].keys(): allocation = allocations[server_type][server_id] - # Get GPU IDs for this server from allocation - # When multi-node, all nodes have same device list, so use first node [0] gpu_ids = list(allocation["nodes"].values())[0] - # Build environment for this worker cuda_devices = ','.join(map(str, gpu_ids)) worker_env = build_worker_environment( worker_config=worker_config, @@ -529,7 +528,6 @@ def submit_job(config, log_dir, dry_run): ) export_str = format_export_string(worker_env) - # Use script_dir for start_worker.sh cmd = [ "srun -l", f"--nodelist {','.join(allocation['nodes'].keys())}", diff --git a/examples/disaggregated/slurm/benchmark/submit_dwdp.py b/examples/disaggregated/slurm/benchmark/submit_dwdp.py new file mode 100644 index 000000000000..cc0529f52fad --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/submit_dwdp.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python3 +"""Submit DWDP disaggregated benchmark jobs. + +This script handles the DWDP-specific submission flow which requires MPI-based +worker launching via ``trtllm-serve disaggregated_mpi_worker``. It reuses +shared utilities from ``submit.py`` for config parsing, GPU allocation, and +sbatch command construction. 
+""" + +import argparse +import glob +import json +import os +import shutil +import subprocess +import sys +import traceback +from datetime import datetime + +import yaml +from submit import ( + allocate_gpus, + build_server_environment, + calculate_nodes, + convert_allocations_to_server_config, + convert_envs_to_str, + format_export_string, + load_config, + replace_env_in_file, + save_env_file, + save_worker_config, +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Submit DWDP disaggregated benchmark job") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("-c", "--config", type=str, help="Path to the configuration YAML file") + group.add_argument( + "-d", "--dir", type=str, help="Directory containing YAML configuration files" + ) + parser.add_argument("--log-dir", type=str, default=None, help="Log directory") + parser.add_argument( + "--dry-run", action="store_true", help="Dry run the Python part, test purpose only" + ) + return parser.parse_args() + + +def generate_mpi_worker_config( + worker_config, allocations, env_config, disagg_hostname, disagg_port, output_path +): + """Generate a config YAML compatible with ``trtllm-serve disaggregated_mpi_worker``.""" + + def _build_urls(server_type): + urls = [] + for server_id in sorted(allocations.get(server_type, {}).keys()): + inst = allocations[server_type][server_id] + host = list(inst["nodes"].keys())[0] + urls.append(f"{host}:{inst['port']}") + return urls + + ctx_urls = _build_urls("CTX") + gen_urls = _build_urls("GEN") + + ctx_section = dict(worker_config["ctx"]) + ctx_section["num_instances"] = len(ctx_urls) + ctx_section["urls"] = ctx_urls + + gen_section = dict(worker_config["gen"]) + gen_section["num_instances"] = len(gen_urls) + gen_section["urls"] = gen_urls + + config = { + "model": env_config["model_path"], + "hostname": disagg_hostname, + "port": disagg_port, + "backend": "pytorch", + "max_retries": 100, + "context_servers": ctx_section, + 
"generation_servers": gen_section, + } + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + +def submit_dwdp_job(config, log_dir, dry_run): + """Submit a DWDP disaggregated benchmark job.""" + slurm_config = config["slurm"] + slurm_config.setdefault("extra_args", "") + slurm_config.setdefault("set_segment", True) + + hw_config = config["hardware"] + env_config = config["environment"] + worker_config = config["worker_config"] + benchmark_config = config["benchmark"] + + if "work_dir" in env_config and os.path.isdir(env_config["work_dir"]): + script_dir = env_config["work_dir"] + else: + script_dir = os.path.dirname(os.path.abspath(__file__)) + + if "accuracy" not in config: + config["accuracy"] = { + "enable_accuracy_test": False, + "model": "local-completions", + "tasks": "gsm8k", + "model_args_extra": ( + "num_concurrent=512,max_retries=3," + "tokenized_requests=false,timeout=1200," + "max_gen_toks=256,max_length=4096" + ), + } + + env_config.setdefault("trtllm_repo", "") + env_config.setdefault("build_wheel", False) + env_config.setdefault("cuda_architectures", "") + env_config.setdefault("trtllm_wheel_path", "") + env_config.setdefault("worker_env_var", "") + env_config.setdefault("server_env_var", "") + + profiling_config = config.get("profiling", {}) + profiling_config.setdefault("nsys_on", False) + profiling_config.setdefault("ctx_profile_range", "10-30") + profiling_config.setdefault("gen_profile_range", "200-250") + + ctx_num = hw_config["num_ctx_servers"] + gen_num = hw_config["num_gen_servers"] + gpus_per_node = hw_config["gpus_per_node"] + + ctx_tp_size = worker_config["ctx"].get("tensor_parallel_size", 1) + ctx_cp_size = worker_config["ctx"].get("context_parallel_size", 1) + ctx_pp_size = worker_config["ctx"].get("pipeline_parallel_size", 1) + ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size + ctx_nodes = 
calculate_nodes(ctx_world_size, ctx_num, gpus_per_node) + + gen_tp_size = worker_config["gen"].get("tensor_parallel_size", 1) + gen_cp_size = worker_config["gen"].get("context_parallel_size", 1) + gen_pp_size = worker_config["gen"].get("pipeline_parallel_size", 1) + gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size + gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node) + ucx_warmup_requests = ( + 2 * ctx_world_size * gen_world_size if benchmark_config["mode"] == "e2e" else 0 + ) + + total_nodes = ctx_nodes + gen_nodes + total_tasks = total_nodes * gpus_per_node + + dwdp_size = worker_config.get("ctx", {}).get("dwdp_config", {}).get("dwdp_size", 1) + + isl = benchmark_config["input_length"] + osl = benchmark_config["output_length"] + gen_batch_size = worker_config["gen"]["max_batch_size"] + + load_balancer_config = worker_config["gen"].get("moe_config", {}).get("load_balancer", {}) + if isinstance(load_balancer_config, str): + with open(load_balancer_config, "r") as f: + load_balancer_config = yaml.safe_load(f) + eplb_num_slots = load_balancer_config.get("num_slots", 0) + + mtp_size = worker_config["gen"].get("speculative_config", {}).get("num_nextn_predict_layers", 0) + + if "log_dir" in env_config and env_config["log_dir"]: + log_dir = env_config["log_dir"] + if log_dir is None: + log_base = os.path.join(script_dir, "logs") + + date_prefix = datetime.now().strftime("%Y%m%d-%H%M%S") + log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}") + + dir_suffix = ( + f"disagg_ctx{ctx_num}_dwdp{dwdp_size}_gen{gen_num}" + f"_dep{gen_tp_size}_batch{gen_batch_size}" + f"_eplb{eplb_num_slots}_mtp{mtp_size}" + ) + + log_dir = os.path.join(log_base, dir_suffix) + + if os.path.exists(log_dir): + if not os.path.exists(os.path.join(log_dir, "trtllm_config.yaml")): + print(f"[WARNING] Removing existing log directory: {log_dir}") + shutil.rmtree(log_dir) + else: + print(f"[WARNING] trtllm_config.yaml exists, not removing the directory: {log_dir}") + for 
file in os.listdir(log_dir): + if file != "trtllm_config.yaml" and not file.startswith("concurrency_"): + if os.path.isdir(os.path.join(log_dir, file)): + shutil.rmtree(os.path.join(log_dir, file)) + else: + os.remove(os.path.join(log_dir, file)) + os.makedirs(log_dir, exist_ok=True) + print(f"Log will be saved to: {log_dir}") + + ctx_config_path = os.path.join(log_dir, "ctx_config.yaml") + gen_config_path = os.path.join(log_dir, "gen_config.yaml") + save_worker_config(worker_config["ctx"], ctx_config_path) + save_worker_config(worker_config["gen"], gen_config_path) + + allocations = allocate_gpus( + total_nodes=total_nodes, + gpus_per_node=gpus_per_node, + num_gen_servers=gen_num, + num_ctx_servers=ctx_num, + gen_world_size=gen_world_size, + ctx_world_size=ctx_world_size, + ) + with open(os.path.join(log_dir, "allocations.json"), "w") as f: + json.dump(allocations, f, indent=2) + + server_config = convert_allocations_to_server_config(allocations) + with open(os.path.join(log_dir, "server_config_base.yaml"), "w") as f: + yaml.dump(server_config, f) + disagg_server_hostname = server_config["hostname"] + disagg_server_port = server_config["port"] + + container_name = "disaggr-test" + start_server_cmds = [] + container_mount_str = env_config["container_mount"] + container_mount_str += f",{script_dir}:{script_dir}" + + # --- DWDP mode: single srun with disaggregated_mpi_worker --- + mpi_config_base_path = os.path.join(log_dir, "mpi_worker_config_base.yaml") + mpi_config_path = os.path.join(log_dir, "mpi_worker_config.yaml") + generate_mpi_worker_config( + worker_config, + allocations, + env_config, + disagg_server_hostname, + disagg_server_port, + mpi_config_base_path, + ) + + ctx_node_list = [] + for sid in sorted(allocations.get("CTX", {}).keys()): + for node in allocations["CTX"][sid]["nodes"]: + if node not in ctx_node_list: + ctx_node_list.append(node) + gen_node_list = [] + for sid in sorted(allocations.get("GEN", {}).keys()): + for node in 
allocations["GEN"][sid]["nodes"]: + if node not in gen_node_list: + gen_node_list.append(node) + mpi_nodelist = ctx_node_list + gen_node_list + total_mpi_tasks = ctx_num * ctx_world_size + gen_num * gen_world_size + mpi_num_nodes = len(mpi_nodelist) + num_ctx_gpus = ctx_num * ctx_world_size + worker_env_var = env_config.get("worker_env_var", "") + ctx_worker_env_var = env_config.get("ctx_worker_env_var", "") + gen_worker_env_var = env_config.get("gen_worker_env_var", "") + dwdp_ctx_worker_env_var = worker_env_var + ( + f" {ctx_worker_env_var}" if ctx_worker_env_var else "" + ) + dwdp_gen_worker_env_var = worker_env_var + ( + f" {gen_worker_env_var}" if gen_worker_env_var else "" + ) + + cmd = [ + "srun -l", + f"--nodelist {','.join(mpi_nodelist)}", + f"-N {mpi_num_nodes}", + f"--ntasks {total_mpi_tasks}", + f"--ntasks-per-node {gpus_per_node}", + f"--container-image {env_config['container_image']}", + f"--container-name {container_name}", + f"--container-mounts {container_mount_str}", + "--no-container-mount-home --mpi=pmix --overlap", + f"bash {os.path.join(script_dir, 'start_worker_dwdp.sh')}", + mpi_config_path, + str(slurm_config["numa_bind"]).lower(), + log_dir, + str(profiling_config["nsys_on"]).lower(), + f"'{profiling_config['ctx_profile_range']}'", + f"'{profiling_config['gen_profile_range']}'", + str(num_ctx_gpus), + f"'{dwdp_ctx_worker_env_var}'", + f"'{dwdp_gen_worker_env_var}'", + f"&> {log_dir}/3_output_workers.log &", + ] + start_server_cmds.append(" ".join(cmd)) + + # Generate start server commands + server_env = build_server_environment(env_config, benchmark_config["mode"]) + export_str = format_export_string(server_env) + + cmd = [ + "srun -l", + f"--nodelist {disagg_server_hostname}", + f"--container-name={container_name}", + f'--export="{export_str}"', + f"--container-image={env_config['container_image']}", + f"--container-mounts={container_mount_str}", + "--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1", + f"bash 
{os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')}", + f"&> {log_dir}/4_output_server.log &", + ] + start_server_cmds.append(" ".join(cmd)) + + save_env_file( + os.path.join(log_dir, "env_vars.json"), + env_config.get("server_env_var", ""), + env_config.get("worker_env_var", ""), + env_config.get("ctx_worker_env_var", ""), + env_config.get("gen_worker_env_var", ""), + ) + + # Generate wait server command + cmd = [ + "srun -l", + f"--container-name={container_name}", + f"--container-mounts={container_mount_str}", + "--mpi=pmix --overlap -N 1 -n 1", + f"bash {os.path.join(script_dir, 'wait_server.sh')} {disagg_server_hostname} {disagg_server_port}", + f"&> {log_dir}/5_wait_server.log", + ] + start_server_cmds.append(" ".join(cmd)) + + with open(os.path.join(log_dir, "start_server_cmds_base.sh"), "w") as f: + f.write("\n".join(start_server_cmds) + "\n") + + # Generate client commands + client_cmds = [] + client_slurm_prefix = [ + f"srun -l --container-name={container_name}", + f"--container-mounts={container_mount_str}", + "--mpi=pmix --overlap -N 1 -n 1", + ] + if benchmark_config.get("enable_benchmark", True): + env_var = config["benchmark"].get("env_var", {}) + benchmark_prefix = client_slurm_prefix + [f'--export "{convert_envs_to_str(env_var)}"'] + if benchmark_config["use_nv_sa_benchmark"]: + if benchmark_config["mode"] == "gen_only": + print("[ERROR] SA benchmark client script is not supported for gen_only mode") + sys.exit(1) + benchmark_cmd = [ + f"bash {os.path.join(script_dir, 'run_benchmark_nv_sa.sh')}", + ( + f"'{env_config['model_path']}' {isl} {osl}" + f" {benchmark_config['benchmark_ratio']}" + f" {benchmark_config['multi_round']} {gen_num}" + f" '{benchmark_config['concurrency_list']}'" + f" {benchmark_config['streaming']} '{log_dir}'" + f" {disagg_server_hostname} {disagg_server_port}" + f" {ucx_warmup_requests}" + ), + f"&> {log_dir}/6_bench.log", + ] + client_cmds.append(" ".join(benchmark_prefix + 
benchmark_cmd)) + else: + benchmark_cmd = [ + f"bash {os.path.join(script_dir, 'run_benchmark.sh')}", + ( + f"'{env_config['model_path']}'" + f" '{benchmark_config['dataset_file']}'" + f" {benchmark_config['multi_round']} {gen_num}" + f" '{benchmark_config['concurrency_list']}'" + f" {benchmark_config['streaming']} '{log_dir}'" + f" {disagg_server_hostname} {disagg_server_port}" + f" {ucx_warmup_requests}" + ), + f"&> {log_dir}/6_bench.log", + ] + client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd)) + + if config["accuracy"]["enable_accuracy_test"]: + env_var = config["accuracy"].get("env_var", {}) + accuracy_prefix = client_slurm_prefix + [f'--export "{convert_envs_to_str(env_var)}"'] + for task in config["accuracy"]["tasks"]: + extra_kwargs = config["accuracy"]["tasks"][task].get("extra_kwargs", {}) + extra_kwargs_str = "" + for key, value in extra_kwargs.items(): + if isinstance(value, bool): + if value: + extra_kwargs_str += f" --{key}" + elif key == "custom_config": + extra_kwargs_str += ( + f" --include_path={replace_env_in_file(log_dir, value, env_var)}" + ) + else: + extra_kwargs_str += f" --{key}='{value}'" + end_point_map = { + "local-completions": "v1/completions", + "local-chat-completions": "v1/chat/completions", + } + model = config["accuracy"]["tasks"][task]["model"] + accuracy_cmd = [ + "lm_eval", + "--model", + model, + "--tasks", + task, + "--model_args", + f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}", + "--log_samples", + "--output_path", + f"{log_dir}/accuracy_eval_{task}", + extra_kwargs_str, + f"&> {log_dir}/7_accuracy_eval_{task}.log", + ] + client_cmds.append(" ".join(accuracy_prefix + accuracy_cmd)) + + done_cmd = ["echo", "${SLURM_JOB_NODELIST}", ">", f"{log_dir}/8_done_${{SLURM_JOB_ID}}.txt"] + client_cmds.append(" ".join(done_cmd)) + + with open(os.path.join(log_dir, "client_cmds_base.sh"), "w") as 
f: + f.write("\n".join(client_cmds) + "\n") + + slurm_script_file = slurm_config["script_file"] + if not os.path.isabs(slurm_script_file): + slurm_script_file = os.path.join(script_dir, slurm_script_file) + + if not os.path.exists(slurm_script_file): + print(f"[ERROR] SLURM script file not found: {slurm_script_file}", file=sys.stderr) + sys.exit(1) + + # yapf: disable + cmd = [ + 'sbatch', + f'--partition={slurm_config["partition"]}', + f'--account={slurm_config["account"]}', + f'--time={slurm_config["job_time"]}', + f'--job-name={slurm_config["job_name"]}', + f'--nodes={total_nodes}', + f'--ntasks={total_tasks}', + f'--ntasks-per-node={hw_config["gpus_per_node"]}', + *([] if not slurm_config['set_segment'] + else [f'--segment={total_nodes}']), + f'--output={log_dir}/slurm-%j.out', + f'--error={log_dir}/slurm-%j.err', + *([arg for arg in slurm_config['extra_args'].split() if arg]), + slurm_script_file, + + '--benchmark-mode', benchmark_config['mode'], + + '--trtllm-repo', env_config['trtllm_repo'], + '--work-dir', script_dir, + '--full-logdir', log_dir, + '--container-name', container_name, + '--container-mount', container_mount_str, + '--container-image', env_config['container_image'], + '--build-wheel', str(env_config['build_wheel']).lower(), + '--cuda-architectures', env_config['cuda_architectures'], + '--trtllm-wheel-path', env_config['trtllm_wheel_path'], + ] + # yapf: enable + + if dry_run: + print( + "[WARNING] Dry run mode, will not submit the job. This should be used for test purpose only." 
+ ) + print("sbatch command:") + print(" ".join(cmd)) + return + else: + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + print(f"Error submitting job: {e}", file=sys.stderr) + sys.exit(1) + + +def main(): + args = parse_args() + + if args.config: + config_files = [args.config] + else: + yaml_pattern = os.path.join(args.dir, "*.yaml") + config_files = sorted(glob.glob(yaml_pattern)) + + if not config_files: + print(f"No YAML files found in directory: {args.dir}", file=sys.stderr) + sys.exit(1) + + print(f"Found {len(config_files)} YAML file(s) in {args.dir}") + + for config_file in config_files: + print(f"Processing: {config_file}") + try: + config = load_config(config_file) + submit_dwdp_job(config, args.log_dir, args.dry_run) + print(f"Successfully submitted job for: {config_file}\n") + except Exception as e: + traceback.print_exc() + print(f"Error processing {config_file}: {e}", file=sys.stderr) + continue + + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 065b1faa9129..f91d331dbe52 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -233,23 +233,23 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper): - permuted_idx_to_expanded_idx specifies the gather pattern - Shape inference uses permuted_idx_to_expanded_idx size instead of a size - Input tensor layout: - 0: a - Original input activation (not permuted) - 1: b - Weight tensor - 2: a_sf - Scale factor for a - 3: b_sf - Scale factor for b - 4: alpha - Per-expert scaling factor - 5: tile_idx_to_group_idx - Tile to expert mapping - 6: tile_idx_to_mn_limit - Tile M/N limits - 7: permuted_idx_to_expanded_idx - Token permutation mapping - 8: num_non_exiting_tiles - Number of valid tiles - 9: global_sf - Global scale factor + Input layout (positions 1, 3, 4 are lists for 
multi-B support): + 0: a - tensor, original input activation + 1: b_list - list of tensors, weight tensors + 2: a_sf - tensor, scale factor for a + 3: b_sf_list - list of tensors, scale factors for b + 4: alpha_list - list of tensors, per-expert scaling factors + 5: tile_idx_to_group_idx - tensor, tile to expert mapping + 6: tile_idx_to_mn_limit - tensor, tile M/N limits + 7: permuted_idx_to_expanded_idx - tensor, token permutation mapping + 8: num_non_exiting_tiles - tensor, number of valid tiles + 9: global_sf - tensor, global scale factor """ # Override: use permuted_idx_to_expanded_idx for shape inference IDX_PERMUTED_IDX_TO_EXPANDED_IDX = 7 IDX_SHAPE_INFER = IDX_PERMUTED_IDX_TO_EXPANDED_IDX - def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: + def inputs_pre_hook(self, inputs: List) -> List: """Pre-hook for gather-based SwiGLU fusion kernel. Generates: @@ -257,9 +257,22 @@ def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: - tile_idx_to_mn_limit - permuted_idx_to_expanded_idx (for gather operation) - num_non_exiting_tiles + + Input layout (positions 1, 3, 4 are lists): + 0: a - tensor + 1: b_list - list of tensors + 2: a_sf - tensor + 3: b_sf_list - list of tensors + 4: alpha_list - list of tensors + 5: tile_idx_to_group_idx - tensor + 6: tile_idx_to_mn_limit - tensor + 7: permuted_idx_to_expanded_idx - tensor + 8: num_non_exiting_tiles - tensor + 9: global_sf - tensor """ - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, \ - permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf = inputs + a, b_list, a_sf, b_sf_list, alpha_list, tile_idx_to_group_idx, \ + tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, \ + num_non_exiting_tiles, global_sf = inputs # Verify permuted_idx_to_expanded_idx index matches the class constant assert inputs[ self. 
@@ -291,7 +304,7 @@ def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: local_num_experts=self.num_local_experts, tile_tokens_dim=self.tile_size, ) - return (a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, + return (a, b_list, a_sf, b_sf_list, alpha_list, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf) @@ -873,8 +886,10 @@ def get_valid_tactics( **kwargs, ) -> List[Tuple[int, int]]: a, b, *_ = inputs + b_list = b if isinstance(b, (list, tuple)) else [b] m, k = a.size(0), a.size(1) * 2 - l, n = b.size(0), b.size(1) + l = sum(bi.size(0) for bi in b_list) + n = b_list[0].size(1) mma_tiler_mn_candidates = [(self.tile_size, 128), (self.tile_size, 256)] @@ -1130,7 +1145,8 @@ def __init__(self, local_expert_offset: int, tile_size: int, output_dtype: torch.dtype, - scaling_vector_size: int = 16): + scaling_vector_size: int = 16, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None): super().__init__() self.num_experts = num_experts self.top_k = top_k @@ -1141,6 +1157,7 @@ def __init__(self, assert output_dtype == torch.bfloat16 self.output_dtype = output_dtype self.scaling_vector_size = scaling_vector_size + self.b_tensor_l_sizes = b_tensor_l_sizes if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( @@ -1161,6 +1178,7 @@ def unique_id(self): self.tile_size, self.output_dtype, self.scaling_vector_size, + self.b_tensor_l_sizes, ) def get_valid_tactics( @@ -1169,9 +1187,12 @@ def get_valid_tactics( profile: OptimizationProfile, **kwargs, ) -> List[Tuple[int, int]]: - a, b, *_ = inputs + a, b_list, *_ = inputs + if not isinstance(b_list, (list, tuple)): + raise TypeError("weight must be a list of tensors") m, k = a.size(0), a.size(1) * 2 - l, n = b.size(0), b.size(1) + l = sum(bi.size(0) for bi in b_list) + n = b_list[0].size(1) mma_tiler_mn_candidates = [(self.tile_size, 128), (self.tile_size, 256)] @@ -1237,29 +1258,45 @@ def get_tuning_config(self) -> TuningConfig: def 
forward(self, inputs: List[torch.Tensor], tactic: Optional[tuple]) -> torch.Tensor: - a, b, a_sf, b_sf, alpha, c, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs + a, b_list, a_sf, b_sf_list, alpha_list, c, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs + if not isinstance(b_list, (list, tuple)): + raise TypeError("weight must be a list of tensors") + if not isinstance(b_sf_list, (list, tuple)): + raise TypeError("weight_scale must be a list of tensors") + if not isinstance(alpha_list, (list, tuple)): + raise TypeError("alpha must be a list of tensors") + assert len(b_list) == len(b_sf_list) == len(alpha_list) + b_tensor_l_sizes = tuple(bi.size(0) for bi in b_list) + + b0 = b_list[0] + b_sf0 = b_sf_list[0] + alpha0 = alpha_list[0] assert a.dtype == torch.float4_e2m1fn_x2 assert a.dim() == 2 - assert b.dtype == torch.float4_e2m1fn_x2 - assert b.dim() == 3 + assert b0.dtype == torch.float4_e2m1fn_x2 + assert b0.dim() == 3 assert a_sf.dtype == torch.uint8 assert a_sf.dim() == 1 - assert b_sf.dtype == torch.uint8 - assert b_sf.dim() == 3 - assert alpha.dtype == torch.float32 - assert alpha.dim() == 1 + assert b_sf0.dtype == torch.uint8 + assert b_sf0.dim() == 3 + assert alpha0.dtype == torch.float32 + assert alpha0.dim() == 1 m, k = a.size(0), a.size(1) * 2 - l, n = b.size(0), b.size(1) + sum(bi.size(0) for bi in b_list) + n = b0.size(1) scale_k = k // self.scaling_vector_size assert m % self.tile_size == 0 assert k % (self.scaling_vector_size * 4) == 0 - assert b.size(2) * 2 == k + assert b0.size(2) * 2 == k assert a_sf.size(0) == m * scale_k - assert b_sf.size(0) == l - assert b_sf.size(1) == n - assert b_sf.size(2) == scale_k - assert alpha.size(0) == l + for bi, bsfi, ai in zip(b_list, b_sf_list, alpha_list): + assert bi.size(1) == n + assert bi.size(2) * 2 == k + assert bsfi.size(0) == bi.size(0) + assert 
bsfi.size(1) == n + assert bsfi.size(2) == scale_k + assert ai.size(0) == bi.size(0) assert c.dtype == self.output_dtype assert c.dim() == 2 @@ -1283,20 +1320,10 @@ def forward(self, inputs: List[torch.Tensor], a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32) - b_ptr = make_ptr(cutlass.Float4E2M1FN, - b.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32) a_sf_ptr = make_ptr(cutlass.Float8E4M3FN, a_sf.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - b_sf_ptr = make_ptr(cutlass.Float8E4M3FN, - b_sf.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16) - alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), - cute.AddressSpace.gmem) tile_idx_to_group_idx_ptr = make_ptr( cutlass.Int32, tile_idx_to_group_idx.data_ptr(), cute.AddressSpace.gmem) @@ -1317,6 +1344,20 @@ def forward(self, inputs: List[torch.Tensor], cute.AddressSpace.gmem, assumed_align=16) + b_ptr = tuple( + make_ptr(cutlass.Float4E2M1FN, + bi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) for bi in b_list) + b_sf_ptr = tuple( + make_ptr(cutlass.Float8E4M3FN, + bsfi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16) for bsfi in b_sf_list) + alpha_ptr = tuple( + make_ptr(cutlass.Float32, ai.data_ptr(), cute.AddressSpace.gmem) + for ai in alpha_list) + torch_stream = torch.cuda.current_stream() stream = cuda.CUstream(torch_stream.cuda_stream) @@ -1330,7 +1371,7 @@ def forward(self, inputs: List[torch.Tensor], 0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})" cache_key = (self.scaling_vector_size, self.tile_size, mma_tiler_mn, - cluster_shape_mn, raster_along_m) + cluster_shape_mn, raster_along_m, b_tensor_l_sizes) if cache_key not in self.__class__.kernel_cache: gemm = self.__class__.kernel_class( sf_vec_size=self.scaling_vector_size, @@ -1338,14 +1379,14 @@ def forward(self, inputs: List[torch.Tensor], cluster_shape_mn=cluster_shape_mn, use_blkred=True, raster_along_m=raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes, ) 
# Compute max active clusters on current device hardware_info = cutlass.utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mn[0] * cluster_shape_mn[1]) - compiled_gemm = cute.compile( - gemm.wrapper, + compile_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -1360,9 +1401,13 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - l, num_tokens, self.top_k, + ] + + compiled_gemm = cute.compile( + gemm.wrapper, + *compile_args, tile_size=self.tile_size, scaling_vector_size=self.scaling_vector_size, max_active_clusters=max_active_clusters, @@ -1372,7 +1417,7 @@ def forward(self, inputs: List[torch.Tensor], else: compiled_gemm = self.__class__.kernel_cache[cache_key] - compiled_gemm( + exec_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -1387,11 +1432,10 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - l, num_tokens, self.top_k, - stream=stream, - ) + ] + compiled_gemm(*exec_args, stream=stream) return c @torch.library.custom_op( @@ -1400,10 +1444,10 @@ def forward(self, inputs: List[torch.Tensor], device_types="cuda") def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], output: torch.Tensor, tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, @@ -1420,9 +1464,11 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( ) -> None: tuner = AutoTuner.get() + b_tensor_l_sizes = tuple(w.size(0) + for w in weight) if len(weight) > 1 else None runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner( num_experts, top_k, num_local_experts, local_expert_offset, - tile_size, output_dtype, scaling_vector_size) + tile_size, output_dtype, scaling_vector_size, b_tensor_l_sizes) inputs = [ input, weight, input_scale, weight_scale, alpha, output, @@ -1445,10 +1491,10 @@ def 
cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( device_types="cuda") def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, permuted_idx_to_expanded_idx: torch.Tensor, @@ -1463,7 +1509,7 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( scaling_vector_size: int = 16, ) -> torch.Tensor: num_tokens = token_final_scales.size(0) - n = weight.size(1) + n = weight[0].size(1) output = torch.zeros(num_tokens, n, dtype=output_dtype, @@ -1494,10 +1540,10 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( "trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell") def _( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, permuted_idx_to_expanded_idx: torch.Tensor, @@ -1512,7 +1558,7 @@ def _( scaling_vector_size: int = 16, ) -> torch.Tensor: num_tokens = token_final_scales.size(0) - n = weight.size(1) + n = weight[0].size(1) return torch.empty(num_tokens, n, dtype=output_dtype, @@ -1843,13 +1889,23 @@ class Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( kernel_cache = dict() tuning_config_cache = dict() + # Maximum number of B tensors supported (must match kernel's MAX_B_TENSORS) + MAX_B_TENSORS = 4 + def __init__(self, num_experts: int, top_k: int, num_local_experts: int, local_expert_offset: int, tile_size: int, - scaling_vector_size: int = 16): + scaling_vector_size: int = 16, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None): + """Initialize the runner. 
+ + Args: + b_tensor_l_sizes: Tuple of L sizes for each B tensor in multi-B mode. + None for single-B mode. Used for kernel cache key. + """ super().__init__() self.num_experts = num_experts self.top_k = top_k @@ -1861,6 +1917,7 @@ def __init__(self, ) self.tile_size = tile_size self.scaling_vector_size = scaling_vector_size + self.b_tensor_l_sizes = b_tensor_l_sizes if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( @@ -1880,19 +1937,24 @@ def unique_id(self): self.local_expert_offset, self.tile_size, self.scaling_vector_size, + self.b_tensor_l_sizes, ) def get_valid_tactics( self, - inputs: List[torch.Tensor], + inputs: List, profile: OptimizationProfile, **kwargs, ) -> List[Tuple[int, int]]: - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, *_ = inputs + # Tuning uses layout: a, b_list, a_sf, b_sf_list, alpha_list, ... + a = inputs[0] + b_list = inputs[1] # List of B tensors + permuted_idx_to_expanded_idx = inputs[7] # m is the permuted size from permuted_idx_to_expanded_idx, not from a m = permuted_idx_to_expanded_idx.size(0) k = a.size(1) * 2 - l, n = b.size(0), b.size(1) + l = sum(bi.size(0) for bi in b_list) + n = b_list[0].size(1) mma_tiler_mn_candidates = [(self.tile_size, 128), (self.tile_size, 256)] @@ -1932,6 +1994,9 @@ def get_tuning_config(self) -> TuningConfig: self.num_local_experts, self.local_expert_offset, self.tile_size) + # Tuning uses layout: + # a, b_list, a_sf, b_sf_list, alpha_list, tile_idx, tile_mn_limit, permuted_idx, ... 
+ # Constraint indices adjusted for list inputs at positions 1, 3, 4 self.__class__.tuning_config_cache[key] = TuningConfig( # Use permuted_idx_to_expanded_idx (IDX_SHAPE_INFER) for tuning dynamic_tensor_specs=(DynamicTensorSpec( @@ -1953,41 +2018,57 @@ def get_tuning_config(self) -> TuningConfig: ) return self.__class__.tuning_config_cache[key] - def forward(self, inputs: List[torch.Tensor], + def forward(self, inputs: List, tactic: Optional[tuple]) -> torch.Tensor: - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf = inputs - # Verify permuted_idx_to_expanded_idx index matches the class constant - assert inputs[ - GatherGroupedGemmInputsHelper. - IDX_PERMUTED_IDX_TO_EXPANDED_IDX] is permuted_idx_to_expanded_idx + """Forward pass supporting both single tensor and list inputs. + + Input layout (positions 1, 3, 4 are lists for multi-B support): + 0: a - tensor + 1: b_list - list of tensors + 2: a_sf - tensor + 3: b_sf_list - list of tensors + 4: alpha_list - list of tensors + 5: tile_idx_to_group_idx - tensor + 6: tile_idx_to_mn_limit - tensor + 7: permuted_idx_to_expanded_idx - tensor + 8: num_non_exiting_tiles - tensor + 9: global_sf - tensor + """ + a, b_list, a_sf, b_sf_list, alpha_list, tile_idx_to_group_idx, \ + tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, \ + num_non_exiting_tiles, global_sf = inputs + + b_tensor_l_sizes = tuple(bi.size(0) for bi in b_list) + + b0 = b_list[0] # Use first B for shape inference + + # Verify input dtypes and dimensions assert a.dtype == torch.float4_e2m1fn_x2 assert a.dim() == 2 - assert b.dtype == torch.float4_e2m1fn_x2 - assert b.dim() == 3 + assert b0.dtype == torch.float4_e2m1fn_x2 + assert b0.dim() == 3 assert a_sf.dtype == torch.uint8 assert a_sf.dim() == 2 - assert b_sf.dtype == torch.uint8 - assert b_sf.dim() == 3 - assert alpha.dtype == torch.float32 - assert alpha.dim() == 1 + assert b_sf_list[0].dtype == torch.uint8 + assert 
b_sf_list[0].dim() == 3 + assert alpha_list[0].dtype == torch.float32 + assert alpha_list[0].dim() == 1 # a.size(0) is orig_m (original input size before gather) # permuted_idx_to_expanded_idx.size(0) is m (permuted size after gather) orig_m, k = a.size(0), a.size(1) * 2 m = permuted_idx_to_expanded_idx.size(0) - l, n = b.size(0), b.size(1) + n = b0.size(1) + sum(bi.size(0) for bi in b_list) scale_k = k // self.scaling_vector_size interm_size = n // 2 + assert m % self.tile_size == 0 assert k % (self.scaling_vector_size * 4) == 0 assert n % (self.scaling_vector_size * 4 * 2) == 0 - assert b.size(2) * 2 == k + assert b0.size(2) * 2 == k assert a_sf.size(0) == orig_m assert a_sf.size(1) == scale_k - assert b_sf.size(0) == l - assert b_sf.size(1) == n - assert b_sf.size(2) == scale_k - assert alpha.size(0) == l num_tiles = m // self.tile_size assert tile_idx_to_group_idx.dtype == torch.int32 @@ -2001,29 +2082,29 @@ def forward(self, inputs: List[torch.Tensor], assert global_sf.dtype == torch.float32 assert global_sf.numel() == 1 + # Allocate output tensors c = torch.empty(m, interm_size // 2, dtype=a.dtype, device=a.device) c_sf = torch.empty(m * interm_size // self.scaling_vector_size, dtype=a_sf.dtype, device=a_sf.device) + # Create common pointers a_ptr = make_ptr(cutlass.Float4E2M1FN, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32) - b_ptr = make_ptr(cutlass.Float4E2M1FN, - b.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32) a_sf_ptr = make_ptr(cutlass.Float8E4M3FN, a_sf.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - b_sf_ptr = make_ptr(cutlass.Float8E4M3FN, - b_sf.data_ptr(), + c_ptr = make_ptr(cutlass.Float4E2M1FN, + c.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) + c_sf_ptr = make_ptr(cutlass.Float8E4M3FN, + c_sf.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), - cute.AddressSpace.gmem) tile_idx_to_group_idx_ptr = make_ptr( cutlass.Int32, 
tile_idx_to_group_idx.data_ptr(), cute.AddressSpace.gmem) @@ -2038,14 +2119,20 @@ def forward(self, inputs: List[torch.Tensor], cute.AddressSpace.gmem) global_sf_ptr = make_ptr(cutlass.Float32, global_sf.data_ptr(), cute.AddressSpace.gmem) - c_ptr = make_ptr(cutlass.Float4E2M1FN, - c.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32) - c_sf_ptr = make_ptr(cutlass.Float8E4M3FN, - c_sf.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16) + + b_ptr = tuple( + make_ptr(cutlass.Float4E2M1FN, + bi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) for bi in b_list) + b_sf_ptr = tuple( + make_ptr(cutlass.Float8E4M3FN, + bsfi.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16) for bsfi in b_sf_list) + alpha_ptr = tuple( + make_ptr(cutlass.Float32, ai.data_ptr(), cute.AddressSpace.gmem) + for ai in alpha_list) torch_stream = torch.cuda.current_stream() stream = cuda.CUstream(torch_stream.cuda_stream) @@ -2060,7 +2147,9 @@ def forward(self, inputs: List[torch.Tensor], 0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})" cache_key = (self.scaling_vector_size, self.tile_size, self.top_k, - mma_tiler_mn, cluster_shape_mn, raster_along_m) + mma_tiler_mn, cluster_shape_mn, raster_along_m, + b_tensor_l_sizes) + if cache_key not in self.__class__.kernel_cache: gemm = self.__class__.kernel_class( sf_vec_size=self.scaling_vector_size, @@ -2069,14 +2158,13 @@ def forward(self, inputs: List[torch.Tensor], vectorized_f32=True, topk=self.top_k, raster_along_m=raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes, ) - # Compute max active clusters on current device hardware_info = cutlass.utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mn[0] * cluster_shape_mn[1]) - compiled_gemm = cute.compile( - gemm.wrapper, + compile_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -2093,7 +2181,11 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - l, + ] + + compiled_gemm = cute.compile( + 
gemm.wrapper, + *compile_args, tile_size=self.tile_size, scaling_vector_size=self.scaling_vector_size, max_active_clusters=max_active_clusters, @@ -2103,7 +2195,7 @@ def forward(self, inputs: List[torch.Tensor], else: compiled_gemm = self.__class__.kernel_cache[cache_key] - compiled_gemm( + exec_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -2120,21 +2212,22 @@ def forward(self, inputs: List[torch.Tensor], m, n, k, - l, - stream=stream, - ) + ] + + compiled_gemm(*exec_args, stream=stream) + return c, c_sf @torch.library.custom_op( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b", mutates_args=(), device_types="cuda") - def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( input: torch.Tensor, - weight: torch.Tensor, + weight: List[torch.Tensor], input_scale: torch.Tensor, - weight_scale: torch.Tensor, - alpha: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, permuted_idx_to_expanded_idx: torch.Tensor, @@ -2147,11 +2240,20 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( tile_size: int, scaling_vector_size: int = 16, ) -> Tuple[torch.Tensor, torch.Tensor]: + """CuteDSL-based NVFP4 gather grouped GEMM with SwiGLU fusion (multi-B list interface). + + Args: + weight: List of B tensors. Single-B mode: [b], multi-B mode: [b0, b1, ...]. + weight_scale: List of scale tensors, matching weight. + alpha: List of alpha tensors, matching weight. 
+ """ tuner = AutoTuner.get() + b_tensor_l_sizes = tuple(w.size(0) for w in weight) + runner = Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( num_experts, top_k, num_local_experts, local_expert_offset, - tile_size, scaling_vector_size) + tile_size, scaling_vector_size, b_tensor_l_sizes) inputs = [ input, weight, input_scale, weight_scale, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, @@ -2159,17 +2261,97 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( ] _, best_tactic = tuner.choose_one( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b", [runner], runner.get_tuning_config(), inputs, ) - output = runner(inputs, tactic=best_tactic) + + # Call forward with inputs list + output = runner.forward(inputs, tactic=best_tactic) return output + @torch.library.register_fake( + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b") + def _fake_multi_b( + input: torch.Tensor, + weight: List[torch.Tensor], + input_scale: torch.Tensor, + weight_scale: List[torch.Tensor], + alpha: List[torch.Tensor], + tile_idx_to_group_idx: torch.Tensor, + tile_idx_to_mn_limit: torch.Tensor, + permuted_idx_to_expanded_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + global_sf: torch.Tensor, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + scaling_vector_size: int = 16, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m = permuted_idx_to_expanded_idx.size(0) + n = weight[0].size(1) + interm_size = n // 2 + output = torch.empty(m, + interm_size // 2, + dtype=input.dtype, + device=input.device) + output_scale = torch.empty(m * interm_size // scaling_vector_size, + dtype=input_scale.dtype, + device=input_scale.device) + return output, output_scale + + @torch.library.custom_op( + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + mutates_args=(), + device_types="cuda") + def 
cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + tile_idx_to_group_idx: torch.Tensor, + tile_idx_to_mn_limit: torch.Tensor, + permuted_idx_to_expanded_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + global_sf: torch.Tensor, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + scaling_vector_size: int = 16, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """CuteDSL-based NVFP4 gather grouped GEMM with SwiGLU fusion (single-B tensor interface). + + Thin wrapper: wraps single tensors into lists and calls + cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b. + """ + return torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + input, + [weight], + input_scale, + [weight_scale], + [alpha], + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, + num_non_exiting_tiles, + global_sf, + num_experts, + top_k, + num_local_experts, + local_expert_offset, + tile_size, + scaling_vector_size, + ) + @torch.library.register_fake( "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell") - def _( + def _fake_single_b( input: torch.Tensor, weight: torch.Tensor, input_scale: torch.Tensor, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 5a4c5d0a0067..c339787301e6 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -377,6 +377,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel: ... 
) """ + # Maximum number of B tensors supported + MAX_B_TENSORS = 4 + def __init__( self, sf_vec_size: int, @@ -385,6 +388,7 @@ def __init__( vectorized_f32: bool, topk: cutlass.Int64, raster_along_m: bool = False, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with gather operation and SwiGLU fusion. @@ -420,6 +424,10 @@ def __init__( :type vectorized_f32: bool :param topk: Number of experts selected per token (used for token ID mapping). :type topk: cutlass.Int64 + :param b_tensor_l_sizes: Optional tuple of L sizes for each B tensor. + E.g., (8, 8, 16) means 3 B tensors with L=8, 8, 16. Sum equals total L. + If None, single B tensor mode (backward compatible). + :type b_tensor_l_sizes: Optional[Tuple[int, ...]] """ self.sf_vec_size = sf_vec_size @@ -502,6 +510,26 @@ def __init__( self.vectorized_f32 = vectorized_f32 + # Multi-B tensor configuration + if b_tensor_l_sizes is None: + self.num_b_tensors = 1 + self.b_tensor_l_sizes = None + # Offsets padded for safe indexing in kernel + self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS + else: + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) + def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -696,17 +724,17 @@ def _setup_attributes(self): def __call__( self, a: cute.Tensor, - b: cute.Tensor, + b: Union[cute.Tensor, Tuple[cute.Tensor, ...]], c: cute.Tensor, sfa: cute.Tensor, - sfb: cute.Tensor, + sfb: Union[cute.Tensor, Tuple[cute.Tensor, ...]], sfc_tensor: 
Optional[cute.Tensor], norm_const_tensor: Optional[cute.Tensor], tile_idx_to_expert_idx: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, token_id_mapping_tensor: cute.Tensor, num_non_exiting_tiles: cute.Tensor, - alpha: cute.Tensor, + alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, @@ -774,11 +802,14 @@ def __call__( """ # Setup static attributes before smem/grid/tma computation self.a_dtype: Type[cutlass.Numeric] = a.element_type - self.b_dtype: Type[cutlass.Numeric] = b.element_type + # Handle tuple of B tensors + b_tuple = b if isinstance(b, tuple) else (b,) + sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,) + self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type self.c_dtype: Type[cutlass.Numeric] = c.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() - self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tuple[0]).mma_major_mode() self.c_layout = utils.LayoutEnum.from_tensor(c) # Check if input data types are compatible with MMA instruction @@ -788,10 +819,28 @@ def __call__( # Setup attributes that dependent on gemm inputs self._setup_attributes() - # Setup sfb tensor by filling B tensor to scale factor atom layout - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) - sfb = cute.make_tensor(sfb.iterator, sfb_layout) + # Setup sfb tensors - create layout for each B tensor (use const_expr, not loop) + sfb_layout_0 = blockscaled_utils.tile_atom_to_shape_SF(b_tuple[0].shape, self.sf_vec_size) + sfb_tensor_0 = cute.make_tensor(sfb_tuple[0].iterator, sfb_layout_0) + sfb_tensors = [sfb_tensor_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + sfb_layout_1 = blockscaled_utils.tile_atom_to_shape_SF( + 
b_tuple[1].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[1].iterator, sfb_layout_1)) + if cutlass.const_expr(self.num_b_tensors >= 3): + sfb_layout_2 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[2].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[2].iterator, sfb_layout_2)) + if cutlass.const_expr(self.num_b_tensors >= 4): + sfb_layout_3 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[3].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[3].iterator, sfb_layout_3)) + sfb_tuple = tuple(sfb_tensors) + # Backward compat alias + sfb = sfb_tuple[0] # Setup sfc tensor by filling C tensor to scale factor atom layout self.generate_sfc = sfc_tensor is not None and norm_const_tensor is not None @@ -821,51 +870,82 @@ def __call__( ) atom_thr_size = cute.size(tiled_mma.thr_id.shape) - # Setup TMA load for B + # Setup TMA ops (shared across all B tensors) b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) - b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, - b, - b_smem_layout, - self.mma_tiler, - tiled_mma, - self.cluster_layout_vmnk.shape, - ) - - # Setup TMA load for SFB sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, - sfb, - sfb_smem_layout, - self.mma_tiler_sfb, - tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, - internal_type=cutlass.Int16, - ) - - # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF)) - # logical blocks for SFB when cta_tile_shape_n=192. 
- if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): - x = tma_tensor_sfb.stride[0][1] - y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) - new_shape = ( - (tma_tensor_sfb.shape[0][0], ((2, 2), y)), - tma_tensor_sfb.shape[1], - tma_tensor_sfb.shape[2], + # Helper to create TMA for one B tensor + def _make_tma_b(b_tensor, sfb_tensor): + atom_b, tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, ) - # Use right multiplication for ScaledBasis (3 * x instead of x * 3) - x_times_3 = 3 * x - new_stride = ( - (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), - tma_tensor_sfb.stride[1], - tma_tensor_sfb.stride[2], + atom_sfb, tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, ) - tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) - tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + # Handle overlapping layout for SFB when cta_tile_shape_n=192 + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tensor_sfb.stride[0][1] + y = cute.ceil_div(tensor_sfb.shape[0][1], 4) + new_shape = ( + (tensor_sfb.shape[0][0], ((2, 2), y)), + tensor_sfb.shape[1], + tensor_sfb.shape[2], + ) + x_times_3 = 3 * x + new_stride = ( + (tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tensor_sfb.stride[1], + tensor_sfb.stride[2], + ) + tensor_sfb = cute.make_tensor( + tensor_sfb.iterator, cute.make_layout(new_shape, stride=new_stride) + ) + return atom_b, tensor_b, atom_sfb, tensor_sfb + + # Create TMA for all B tensors (use const_expr, not loop) + atom_b_0, tensor_b_0, atom_sfb_0, tensor_sfb_0 = _make_tma_b(b_tuple[0], sfb_tuple[0]) + tma_atoms_b = [atom_b_0] + tma_tensors_b = [tensor_b_0] + tma_atoms_sfb = [atom_sfb_0] + tma_tensors_sfb = [tensor_sfb_0] + if 
cutlass.const_expr(self.num_b_tensors >= 2): + atom_b_1, tensor_b_1, atom_sfb_1, tensor_sfb_1 = _make_tma_b(b_tuple[1], sfb_tuple[1]) + tma_atoms_b.append(atom_b_1) + tma_tensors_b.append(tensor_b_1) + tma_atoms_sfb.append(atom_sfb_1) + tma_tensors_sfb.append(tensor_sfb_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + atom_b_2, tensor_b_2, atom_sfb_2, tensor_sfb_2 = _make_tma_b(b_tuple[2], sfb_tuple[2]) + tma_atoms_b.append(atom_b_2) + tma_tensors_b.append(tensor_b_2) + tma_atoms_sfb.append(atom_sfb_2) + tma_tensors_sfb.append(tensor_sfb_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + atom_b_3, tensor_b_3, atom_sfb_3, tensor_sfb_3 = _make_tma_b(b_tuple[3], sfb_tuple[3]) + tma_atoms_b.append(atom_b_3) + tma_tensors_b.append(tensor_b_3) + tma_atoms_sfb.append(atom_sfb_3) + tma_tensors_sfb.append(tensor_sfb_3) + tma_atoms_b = tuple(tma_atoms_b) + tma_tensors_b = tuple(tma_tensors_b) + tma_atoms_sfb = tuple(tma_atoms_sfb) + tma_tensors_sfb = tuple(tma_tensors_sfb) + + # Handle alpha tuple (convert to tuple if single tensor) + alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) @@ -990,11 +1070,11 @@ class SharedStorage2cta: tiled_mma, tiled_mma_sfb, a, - tma_atom_b, - tma_tensor_b, + tma_atoms_b, # Tuple of TMA atoms for B + tma_tensors_b, # Tuple of TMA tensors for B sfa, - tma_atom_sfb, - tma_tensor_sfb, + tma_atoms_sfb, # Tuple of TMA atoms for SFB + tma_tensors_sfb, # Tuple of TMA tensors for SFB tma_atom_c, tma_tensor_c, sfc_tensor, @@ -1003,7 +1083,7 @@ class SharedStorage2cta: tile_idx_to_mn_limit, token_id_mapping_tensor, num_non_exiting_tiles, - alpha, + alpha_tuple, self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, self.a_smem_layout_staged, @@ -1074,11 +1154,11 @@ def kernel( tiled_mma: cute.TiledMma, tiled_mma_sfb: cute.TiledMma, mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, 
+ tma_atoms_b: Tuple[cute.CopyAtom, ...], + mB_nkl_tuple: Tuple[cute.Tensor, ...], mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, + tma_atoms_sfb: Tuple[cute.CopyAtom, ...], + mSFB_nkl_tuple: Tuple[cute.Tensor, ...], tma_atom_c: cute.CopyAtom, mC_mnl: cute.Tensor, mSFC_mnl: Optional[cute.Tensor], @@ -1087,7 +1167,7 @@ def kernel( tile_idx_to_mn_limit: cute.Tensor, token_id_mapping_tensor: cute.Tensor, num_non_exiting_tiles: cute.Tensor, - alpha: cute.Tensor, + alpha_tuple: Tuple[cute.Tensor, ...], cluster_layout_vmnk: cute.Layout, cluster_layout_sfb_vmnk: cute.Layout, a_smem_layout_staged: cute.ComposedLayout, @@ -1109,8 +1189,18 @@ def kernel( # Prefetch tma desc # if warp_idx == self.tma_b_warp_id: - cpasync.prefetch_descriptor(tma_atom_b) - cpasync.prefetch_descriptor(tma_atom_sfb) + # Prefetch TMA descriptors for all B tensors using const_expr conditions + cpasync.prefetch_descriptor(tma_atoms_b[0]) + cpasync.prefetch_descriptor(tma_atoms_sfb[0]) + if cutlass.const_expr(self.num_b_tensors >= 2): + cpasync.prefetch_descriptor(tma_atoms_b[1]) + cpasync.prefetch_descriptor(tma_atoms_sfb[1]) + if cutlass.const_expr(self.num_b_tensors >= 3): + cpasync.prefetch_descriptor(tma_atoms_b[2]) + cpasync.prefetch_descriptor(tma_atoms_sfb[2]) + if cutlass.const_expr(self.num_b_tensors >= 4): + cpasync.prefetch_descriptor(tma_atoms_b[3]) + cpasync.prefetch_descriptor(tma_atoms_sfb[3]) cpasync.prefetch_descriptor(tma_atom_c) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 @@ -1270,22 +1360,52 @@ def kernel( gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), (None, None, None) ) - # (bN, bK, loopN, loopK, loopL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + # (bN, bK, loopN, loopK, loopL) - Use const_expr conditions for tuple indexing + gB_nkl_0 = cute.local_tile( + mB_nkl_tuple[0], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, 
None) ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gB_nkl_1 = cute.local_tile( + mB_nkl_tuple[1], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gB_nkl_2 = cute.local_tile( + mB_nkl_tuple[2], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gB_nkl_3 = cute.local_tile( + mB_nkl_tuple[3], cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(self.cta_tile_shape_mnk_sfa, (None, 0, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, + # (bN, bK, RestN, RestK, RestL) - Use const_expr conditions for tuple indexing + gSFB_nkl_0 = cute.local_tile( + mSFB_nkl_tuple[0], cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gSFB_nkl_1 = cute.local_tile( + mSFB_nkl_tuple[1], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gSFB_nkl_2 = cute.local_tile( + mSFB_nkl_tuple[2], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gSFB_nkl_3 = cute.local_tile( + mSFB_nkl_tuple[3], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) gToken_ml = cute.local_tile( token_id_mapping_tensor, cute.slice_(self.cta_tile_shape_mnk, (None, 0, 0)), (None,) @@ -1302,43 +1422,106 @@ def kernel( # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) - # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - const_expr conditions + tCgB_0 = 
thr_mma.partition_B(gB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgB_1 = thr_mma.partition_B(gB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgB_2 = thr_mma.partition_B(gB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgB_3 = thr_mma.partition_B(gB_nkl_3) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - const_expr conditions + tCgSFB_0 = thr_mma_sfb.partition_B(gSFB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgSFB_1 = thr_mma_sfb.partition_B(gSFB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgSFB_2 = thr_mma_sfb.partition_B(gSFB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgSFB_3 = thr_mma_sfb.partition_B(gSFB_nkl_3) # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) tCgC = thr_mma.partition_C(gC_mnl) # # Partition global/shared tensor for TMA load B # - # TMA load B partition_S/D b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, - block_in_cluster_coord_vmnk[1], - b_cta_layout, - cute.group_modes(sB, 0, 3), - cute.group_modes(tCgB, 0, 3), - ) - - # TMA load SFB partition_S/D sfb_cta_layout = cute.make_layout( cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape ) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( - tma_atom_sfb, + sB_grouped = cute.group_modes(sB, 0, 3) + sSFB_grouped = cute.group_modes(sSFB, 0, 3) + + # TMA partition for B tensor 0 + tBsB_0, tBgB_0 = cpasync.tma_partition( + tma_atoms_b[0], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_0, 0, 3), + ) + tBsSFB_0, tBgSFB_0 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[0], block_in_cluster_coord_sfb_vmnk[1], sfb_cta_layout, - cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), - ) - tBsSFB = 
cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) + sSFB_grouped, + cute.group_modes(tCgSFB_0, 0, 3), + ) + tBsSFB_0 = cute.filter_zeros(tBsSFB_0) + tBgSFB_0 = cute.filter_zeros(tBgSFB_0) + + # TMA partition for B tensor 1 (tBsB shared memory partition is same for all, use _ to ignore) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgB_1 = cpasync.tma_partition( + tma_atoms_b[1], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_1, 0, 3), + ) + _, tBgSFB_1 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[1], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_1, 0, 3), + ) + tBgSFB_1 = cute.filter_zeros(tBgSFB_1) + + # TMA partition for B tensor 2 + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgB_2 = cpasync.tma_partition( + tma_atoms_b[2], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_2, 0, 3), + ) + _, tBgSFB_2 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[2], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_2, 0, 3), + ) + tBgSFB_2 = cute.filter_zeros(tBgSFB_2) + + # TMA partition for B tensor 3 + if cutlass.const_expr(self.num_b_tensors >= 4): + _, tBgB_3 = cpasync.tma_partition( + tma_atoms_b[3], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_3, 0, 3), + ) + _, tBgSFB_3 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[3], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_3, 0, 3), + ) + tBgSFB_3 = cute.filter_zeros(tBgSFB_3) # # Partition shared/tensor memory tensor for TiledMMA_A/B/C @@ -1849,20 +2032,13 @@ def kernel( tile_info[1], tile_info[2], ) - # - # Slice to per mma tile index - # - # ((atom_v, rest_v), loopK) - tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + expert_idx = mma_tile_coord_mnl[2] # Apply SFB slicing hack 
when cta_tile_shape_n=64 slice_n = mma_tile_coord_mnl[1] if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): slice_n = mma_tile_coord_mnl[1] // 2 - # ((atom_v, rest_v), RestK) - tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt b_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) @@ -1872,35 +2048,247 @@ def kernel( # Tma load loop # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): - # Conditionally wait for B buffer empty b_pipeline.producer_acquire(b_producer_state, peek_ab_empty_status) - - tBgB_k = tBgB_slice[(None, b_producer_state.count)] - tBgSFB_k = tBgSFB_slice[(None, b_producer_state.count)] - tBsB_pipe = tBsB[(None, b_producer_state.index)] - tBsSFB_pipe = tBsSFB[(None, b_producer_state.index)] - + tBsB_pipe = tBsB_0[(None, b_producer_state.index)] + tBsSFB_pipe = tBsSFB_0[(None, b_producer_state.index)] tma_bar = b_pipeline.producer_get_barrier(b_producer_state) - # TMA load B - cute.copy( - tma_atom_b, - tBgB_k, - tBsB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=b_full_mcast_mask, - ) - - # TMA load SFB - cute.copy( - tma_atom_sfb, - tBgSFB_k, - tBsSFB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=sfb_full_mcast_mask, - ) + # Select correct B tensor based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 1): + # Single B tensor - original logic + tBgB_slice = tBgB_0[(None, mma_tile_coord_mnl[1], None, expert_idx)] + tBgSFB_slice = tBgSFB_0[(None, slice_n, None, expert_idx)] + cute.copy( + tma_atoms_b[0], + tBgB_slice[(None, b_producer_state.count)], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_slice[(None, b_producer_state.count)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # Multi-B tensor - select based on expert_idx + # Use nested const_expr ifs to avoid index out of range at compile time + if 
cutlass.const_expr(self.num_b_tensors == 2): + # Exactly 2 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, b_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, b_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif cutlass.const_expr(self.num_b_tensors == 3): + # Exactly 3 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, b_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, b_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + 
else: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, b_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # 4 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, b_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, b_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[3]: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, b_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_3 = expert_idx - self.b_tensor_l_offsets[3] + cute.copy( + tma_atoms_b[3], + tBgB_3[ + ( + None, + 
mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_3, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[3], + tBgSFB_3[(None, slice_n, b_producer_state.count, local_l_3)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 b_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) if b_producer_state.count < k_tile_cnt: @@ -2343,9 +2731,38 @@ def kernel( # # Get alpha for current group # - expert_idx = mma_tile_coord_mnl[2] - alpha_val = alpha[expert_idx] + + # Select alpha from correct tensor based on expert_idx + # Initialize alpha_val first to avoid DSL "None prior to if" error + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + if cutlass.const_expr(self.num_b_tensors == 1): + pass # Already initialized above + elif cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx >= self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif cutlass.const_expr(self.num_b_tensors == 3): + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif expert_idx >= self.b_tensor_l_offsets[2]: + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + else: + # 4 B tensors + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif ( + expert_idx >= self.b_tensor_l_offsets[2] + and expert_idx < self.b_tensor_l_offsets[3] + ): + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + elif expert_idx >= self.b_tensor_l_offsets[3]: + alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]] # # Slice to per mma tile index @@ -3313,12 +3730,12 @@ def 
can_implement( def wrapper( self, a_ptr: cute.Pointer, - b_ptr: cute.Pointer, + b_ptr_tuple: Tuple[cute.Pointer, ...], a_sf_ptr: cute.Pointer, - b_sf_ptr: cute.Pointer, + b_sf_ptr_tuple: Tuple[cute.Pointer, ...], c_ptr: cute.Pointer, c_sf_ptr: cute.Pointer, - alpha_ptr: cute.Pointer, + alpha_ptr_tuple: Tuple[cute.Pointer, ...], tile_idx_to_group_idx_ptr: cute.Pointer, tile_idx_to_mn_limit_ptr: cute.Pointer, token_id_mapping_ptr: cute.Pointer, @@ -3328,40 +3745,102 @@ def wrapper( m: cutlass.Int64, n: cutlass.Int64, k: cutlass.Int64, - l: cutlass.Int64, # noqa: E741 tile_size: cutlass.Constexpr, scaling_vector_size: cutlass.Constexpr, max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): + """Unified wrapper supporting both single-B and multi-B tensors. + + B tensors are always passed as tuples (length 1 for single-B). + L sizes are configured via b_tensor_l_sizes in __init__. + """ scale_k = k // scaling_vector_size interm_size = n // 2 num_tiles = m // tile_size + total_l = self.b_tensor_l_offsets[self.num_b_tensors] + a = cute.make_tensor( a_ptr, layout=cute.make_ordered_layout((orig_m, k, 1), order=(1, 0, 2)) ) - b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2))) a_sf = cute.make_tensor( a_sf_ptr, layout=cute.make_ordered_layout((orig_m, scale_k, 1), order=(1, 0, 2)) ) - b_sf = cute.make_tensor( - b_sf_ptr, - layout=cute.make_ordered_layout( - (32, 4, n // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5) - ), - ) c = cute.make_tensor( c_ptr, layout=cute.make_ordered_layout((m, interm_size, 1), order=(1, 0, 2)) ) c_sf = cute.make_tensor( c_sf_ptr, layout=cute.make_ordered_layout( - (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l), + (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l), order=(2, 1, 4, 0, 3, 5), ), ) - alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,))) + + # Create B and alpha tensors using 
const_expr conditions + l_0 = self.b_tensor_l_sizes[0] + alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,))) + b_0 = cute.make_tensor( + b_ptr_tuple[0], layout=cute.make_ordered_layout((n, k, l_0), order=(1, 0, 2)) + ) + b_sf_0 = cute.make_tensor( + b_sf_ptr_tuple[0], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_0), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple = [b_0] + b_sf_tuple = [b_sf_0] + alpha_tuple = [alpha_0] + + if cutlass.const_expr(self.num_b_tensors >= 2): + l_1 = self.b_tensor_l_sizes[1] + alpha_1 = cute.make_tensor(alpha_ptr_tuple[1], layout=cute.make_layout((l_1,))) + b_1 = cute.make_tensor( + b_ptr_tuple[1], layout=cute.make_ordered_layout((n, k, l_1), order=(1, 0, 2)) + ) + b_sf_1 = cute.make_tensor( + b_sf_ptr_tuple[1], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_1), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_1) + b_sf_tuple.append(b_sf_1) + alpha_tuple.append(alpha_1) + + if cutlass.const_expr(self.num_b_tensors >= 3): + l_2 = self.b_tensor_l_sizes[2] + alpha_2 = cute.make_tensor(alpha_ptr_tuple[2], layout=cute.make_layout((l_2,))) + b_2 = cute.make_tensor( + b_ptr_tuple[2], layout=cute.make_ordered_layout((n, k, l_2), order=(1, 0, 2)) + ) + b_sf_2 = cute.make_tensor( + b_sf_ptr_tuple[2], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_2), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_2) + b_sf_tuple.append(b_sf_2) + alpha_tuple.append(alpha_2) + + if cutlass.const_expr(self.num_b_tensors >= 4): + l_3 = self.b_tensor_l_sizes[3] + alpha_3 = cute.make_tensor(alpha_ptr_tuple[3], layout=cute.make_layout((l_3,))) + b_3 = cute.make_tensor( + b_ptr_tuple[3], layout=cute.make_ordered_layout((n, k, l_3), order=(1, 0, 2)) + ) + b_sf_3 = cute.make_tensor( + b_sf_ptr_tuple[3], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_3), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_3) + 
b_sf_tuple.append(b_sf_3) + alpha_tuple.append(alpha_3) tile_idx_to_group_idx = cute.make_tensor( tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,)) @@ -3377,17 +3856,17 @@ def wrapper( return self( a, - b, + tuple(b_tuple), c, a_sf, - b_sf, + tuple(b_sf_tuple), c_sf, global_sf, tile_idx_to_group_idx, tile_idx_to_mn_limit, token_id_mapping, num_non_exiting_tiles, - alpha, + tuple(alpha_tuple), max_active_clusters=max_active_clusters, stream=stream, epilogue_op=epilogue_op, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index 50d36beff868..babf3dbcb261 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Tuple, Type, Union +from typing import Optional, Tuple, Type, Union import cuda.bindings.driver as cuda import cutlass @@ -339,6 +339,9 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: ... ) """ + # Maximum number of B tensors supported + MAX_B_TENSORS = 4 + def __init__( self, sf_vec_size: int, @@ -346,6 +349,7 @@ def __init__( cluster_shape_mn: Tuple[int, int], use_blkred: bool = False, raster_along_m: bool = False, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel. @@ -363,6 +367,10 @@ def __init__( :type cluster_shape_mn: Tuple[int, int] :param raster_along_m: Boolean, True to use raster along M. :type raster_along_m: bool + :param b_tensor_l_sizes: Optional tuple of L sizes for each B tensor. 
+ E.g., (8, 8, 16) means 3 B tensors with L=8, 8, 16. Sum equals total L. + If None, single B tensor mode (backward compatible). + :type b_tensor_l_sizes: Optional[Tuple[int, ...]] """ self.sf_vec_size = sf_vec_size @@ -424,6 +432,26 @@ def __init__( # TMEM offset for final accumulator self.tmem_final_offset = 384 + # Multi-B tensor configuration + if b_tensor_l_sizes is None: + self.num_b_tensors = 1 + self.b_tensor_l_sizes = None + # Offsets padded for safe indexing in kernel + self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS + else: + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) + def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -602,14 +630,14 @@ def _setup_attributes(self): def __call__( self, a: cute.Tensor, - b: cute.Tensor, + b: Union[cute.Tensor, Tuple[cute.Tensor, ...]], out: cute.Tensor, sfa: cute.Tensor, - sfb: cute.Tensor, + sfb: Union[cute.Tensor, Tuple[cute.Tensor, ...]], tile_idx_to_expert_idx: cute.Tensor, num_non_exiting_tiles: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, - alpha: cute.Tensor, + alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, permuted_idx_to_expanded_idx: cute.Tensor, @@ -639,7 +667,7 @@ def __call__( :param num_non_exiting_tiles: Number of valid tiles (valid_m/cta_tile_m), shape (1,) :type num_non_exiting_tiles: cute.Tensor :param alpha: Alpha tensor for each group - :type alpha: cute.Tensor + :type alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]] :param max_active_clusters: Maximum number of 
active clusters :type max_active_clusters: cutlass.Constexpr :param stream: CUDA stream for asynchronous execution @@ -654,12 +682,16 @@ def __call__( """ # Setup static attributes before smem/grid/tma computation self.a_dtype: Type[cutlass.Numeric] = a.element_type - self.b_dtype: Type[cutlass.Numeric] = b.element_type + # Handle tuple of B tensors + b_tuple = b if isinstance(b, tuple) else (b,) + sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,) + alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,) + self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type self.out_dtype: Type[cutlass.Numeric] = out.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type self.final_scale_dtype: Type[cutlass.Numeric] = token_final_scales.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() - self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tuple[0]).mma_major_mode() self.gemm_output_layout = utils.LayoutEnum.ROW_MAJOR self.topK = token_final_scales.shape[1] @@ -675,8 +707,27 @@ def __call__( sfa = cute.make_tensor(sfa.iterator, sfa_layout) # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) - sfb = cute.make_tensor(sfb.iterator, sfb_layout) + sfb_layout_0 = blockscaled_utils.tile_atom_to_shape_SF(b_tuple[0].shape, self.sf_vec_size) + sfb_tensor_0 = cute.make_tensor(sfb_tuple[0].iterator, sfb_layout_0) + sfb_tensors = [sfb_tensor_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + sfb_layout_1 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[1].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[1].iterator, sfb_layout_1)) + if cutlass.const_expr(self.num_b_tensors >= 3): + sfb_layout_2 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[2].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[2].iterator, sfb_layout_2)) + 
if cutlass.const_expr(self.num_b_tensors >= 4): + sfb_layout_3 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[3].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[3].iterator, sfb_layout_3)) + sfb_tuple = tuple(sfb_tensors) + # Backward compat alias + sfb = sfb_tuple[0] tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, @@ -714,14 +765,6 @@ def __call__( # Setup TMA load for B b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, - b, - b_smem_layout, - self.mma_tiler, - tiled_mma, - self.cluster_layout_vmnk.shape, - ) # Setup TMA load for SFA sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) @@ -739,34 +782,74 @@ def __call__( # Setup TMA load for SFB sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, - sfb, - sfb_smem_layout, - self.mma_tiler_sfb, - tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, - internal_type=cutlass.Int16, - ) - - if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): - x = tma_tensor_sfb.stride[0][1] - y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) - new_shape = ( - (tma_tensor_sfb.shape[0][0], ((2, 2), y)), - tma_tensor_sfb.shape[1], - tma_tensor_sfb.shape[2], + # Helper to create TMA for one B tensor + def _make_tma_b(b_tensor, sfb_tensor): + atom_b, tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, ) - # Use right multiplication for ScaledBasis (3 * x instead of x * 3) - x_times_3 = 3 * x - new_stride = ( - (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), - 
tma_tensor_sfb.stride[1], - tma_tensor_sfb.stride[2], + atom_sfb, tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, ) - tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) - tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tensor_sfb.stride[0][1] + y = cute.ceil_div(tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tensor_sfb.shape[0][0], ((2, 2), y)), + tensor_sfb.shape[1], + tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tensor_sfb.stride[1], + tensor_sfb.stride[2], + ) + tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tensor_sfb = cute.make_tensor(tensor_sfb.iterator, tensor_sfb_new_layout) + return atom_b, tensor_b, atom_sfb, tensor_sfb + + # Create TMA for all B tensors (use const_expr, not loop) + atom_b_0, tensor_b_0, atom_sfb_0, tensor_sfb_0 = _make_tma_b(b_tuple[0], sfb_tuple[0]) + tma_atoms_b = [atom_b_0] + tma_tensors_b = [tensor_b_0] + tma_atoms_sfb = [atom_sfb_0] + tma_tensors_sfb = [tensor_sfb_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + atom_b_1, tensor_b_1, atom_sfb_1, tensor_sfb_1 = _make_tma_b(b_tuple[1], sfb_tuple[1]) + tma_atoms_b.append(atom_b_1) + tma_tensors_b.append(tensor_b_1) + tma_atoms_sfb.append(atom_sfb_1) + tma_tensors_sfb.append(tensor_sfb_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + atom_b_2, tensor_b_2, atom_sfb_2, tensor_sfb_2 = _make_tma_b(b_tuple[2], sfb_tuple[2]) + tma_atoms_b.append(atom_b_2) + tma_tensors_b.append(tensor_b_2) + tma_atoms_sfb.append(atom_sfb_2) + tma_tensors_sfb.append(tensor_sfb_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + atom_b_3, tensor_b_3, 
atom_sfb_3, tensor_sfb_3 = _make_tma_b(b_tuple[3], sfb_tuple[3]) + tma_atoms_b.append(atom_b_3) + tma_tensors_b.append(tensor_b_3) + tma_atoms_sfb.append(atom_sfb_3) + tma_tensors_sfb.append(tensor_sfb_3) + tma_atoms_b = tuple(tma_atoms_b) + tma_tensors_b = tuple(tma_tensors_b) + tma_atoms_sfb = tuple(tma_atoms_sfb) + tma_tensors_sfb = tuple(tma_tensors_sfb) a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) @@ -777,7 +860,7 @@ def __call__( ) * atom_thr_size self.tile_sched_params, grid = self._compute_grid( - (a.shape[0], b.shape[0], a.shape[2]), + (a.shape[0], b_tuple[0].shape[0], a.shape[2]), self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters, @@ -862,17 +945,17 @@ class SharedStorage: tiled_mma_sfb, tma_atom_a, tma_tensor_a, - tma_atom_b, - tma_tensor_b, + tma_atoms_b, + tma_tensors_b, tma_atom_sfa, tma_tensor_sfa, - tma_atom_sfb, - tma_tensor_sfb, + tma_atoms_sfb, + tma_tensors_sfb, out, tile_idx_to_expert_idx, num_non_exiting_tiles, tile_idx_to_mn_limit, - alpha, + alpha_tuple, permuted_idx_to_expanded_idx, token_final_scales, self.cluster_layout_vmnk, @@ -947,17 +1030,17 @@ def kernel( tiled_mma_sfb: cute.TiledMma, tma_atom_a: cute.CopyAtom, mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, + tma_atoms_b: Tuple[cute.CopyAtom, ...], + mB_nkl_tuple: Tuple[cute.Tensor, ...], tma_atom_sfa: cute.CopyAtom, mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, + tma_atoms_sfb: Tuple[cute.CopyAtom, ...], + mSFB_nkl_tuple: Tuple[cute.Tensor, ...], out: cute.Tensor, tile_idx_to_expert_idx: cute.Tensor, num_non_exiting_tiles: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, - alpha: cute.Tensor, + alpha_tuple: Tuple[cute.Tensor, ...], permuted_idx_to_expanded_idx: cute.Tensor, token_final_scales: cute.Tensor, cluster_layout_vmnk: cute.Layout, @@ -984,9 +1067,18 @@ def kernel( # if warp_idx == self.tma_warp_id: 
cpasync.prefetch_descriptor(tma_atom_a) - cpasync.prefetch_descriptor(tma_atom_b) cpasync.prefetch_descriptor(tma_atom_sfa) - cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atoms_b[0]) + cpasync.prefetch_descriptor(tma_atoms_sfb[0]) + if cutlass.const_expr(self.num_b_tensors >= 2): + cpasync.prefetch_descriptor(tma_atoms_b[1]) + cpasync.prefetch_descriptor(tma_atoms_sfb[1]) + if cutlass.const_expr(self.num_b_tensors >= 3): + cpasync.prefetch_descriptor(tma_atoms_b[2]) + cpasync.prefetch_descriptor(tma_atoms_sfb[2]) + if cutlass.const_expr(self.num_b_tensors >= 4): + cpasync.prefetch_descriptor(tma_atoms_b[3]) + cpasync.prefetch_descriptor(tma_atoms_sfb[3]) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 @@ -1119,9 +1211,29 @@ def kernel( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) # (bN, bK, loopN, loopK, loopL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + gB_nkl_0 = cute.local_tile( + mB_nkl_tuple[0], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gB_nkl_1 = cute.local_tile( + mB_nkl_tuple[1], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gB_nkl_2 = cute.local_tile( + mB_nkl_tuple[2], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gB_nkl_3 = cute.local_tile( + mB_nkl_tuple[3], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( @@ -1129,11 +1241,29 @@ def kernel( ) # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, + gSFB_nkl_0 = cute.local_tile( + mSFB_nkl_tuple[0], cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gSFB_nkl_1 = 
cute.local_tile( + mSFB_nkl_tuple[1], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gSFB_nkl_2 = cute.local_tile( + mSFB_nkl_tuple[2], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gSFB_nkl_3 = cute.local_tile( + mSFB_nkl_tuple[3], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) k_tile_cnt = cutlass.Int32(cute.size(gA_mkl, mode=[3])) @@ -1145,11 +1275,23 @@ def kernel( # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) tCgA = thr_mma.partition_A(gA_mkl) # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - tCgB = thr_mma.partition_B(gB_nkl) + tCgB_0 = thr_mma.partition_B(gB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgB_1 = thr_mma.partition_B(gB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgB_2 = thr_mma.partition_B(gB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgB_3 = thr_mma.partition_B(gB_nkl_3) # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgSFA = thr_mma.partition_A(gSFA_mkl) # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + tCgSFB_0 = thr_mma_sfb.partition_B(gSFB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgSFB_1 = thr_mma_sfb.partition_B(gSFB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgSFB_2 = thr_mma_sfb.partition_B(gSFB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgSFB_3 = thr_mma_sfb.partition_B(gSFB_nkl_3) # # Partition global/shared tensor for TMA load A/B @@ -1169,13 +1311,37 @@ def kernel( b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), loopM, loopK, loopL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, + tBsB_0, tBgB_0 = cpasync.tma_partition( + tma_atoms_b[0], block_in_cluster_coord_vmnk[1], b_cta_layout, cute.group_modes(sB, 0, 3), - 
cute.group_modes(tCgB, 0, 3), - ) + cute.group_modes(tCgB_0, 0, 3), + ) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgB_1 = cpasync.tma_partition( + tma_atoms_b[1], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB_1, 0, 3), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgB_2 = cpasync.tma_partition( + tma_atoms_b[2], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB_2, 0, 3), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + _, tBgB_3 = cpasync.tma_partition( + tma_atoms_b[3], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB_3, 0, 3), + ) # TMA load SFA partition_S/D sfa_cta_layout = a_cta_layout @@ -1199,15 +1365,42 @@ def kernel( ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( - tma_atom_sfb, + tBsSFB_0, tBgSFB_0 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[0], block_in_cluster_coord_sfb_vmnk[1], sfb_cta_layout, cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), - ) - tBsSFB = cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) + cute.group_modes(tCgSFB_0, 0, 3), + ) + tBsSFB_0 = cute.filter_zeros(tBsSFB_0) + tBgSFB_0 = cute.filter_zeros(tBgSFB_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgSFB_1 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[1], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB_1, 0, 3), + ) + tBgSFB_1 = cute.filter_zeros(tBgSFB_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgSFB_2 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[2], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB_2, 0, 3), + ) + tBgSFB_2 = cute.filter_zeros(tBgSFB_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + 
_, tBgSFB_3 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[3], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB_3, 0, 3), + ) + tBgSFB_3 = cute.filter_zeros(tBgSFB_3) # # Partition shared/tensor memory tensor for TiledMMA_A/B/C @@ -1395,19 +1588,15 @@ def kernel( # # ((atom_v, rest_v), loopK) tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, 0)] - # ((atom_v, rest_v), loopK) - tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)] + expert_idx = mma_tile_coord_mnl[2] slice_n = mma_tile_coord_mnl[1] if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): slice_n = mma_tile_coord_mnl[1] // 2 - # ((atom_v, rest_v), RestK) - tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) @@ -1418,13 +1607,11 @@ def kernel( # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): tAgA_k = tAgA_slice[(None, ab_producer_state.count)] - tBgB_k = tBgB_slice[(None, ab_producer_state.count)] tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)] - tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)] tAsA_pipe = tAsA[(None, ab_producer_state.index)] - tBsB_pipe = tBsB[(None, ab_producer_state.index)] + tBsB_pipe = tBsB_0[(None, ab_producer_state.index)] tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)] - tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)] + tBsSFB_pipe = tBsSFB_0[(None, ab_producer_state.index)] tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) @@ -1439,14 +1626,6 @@ def kernel( tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask, ) - cute.copy( - tma_atom_b, - tBgB_k, - tBsB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=b_full_mcast_mask, - ) - cute.copy( tma_atom_sfa, tAgSFA_k, @@ -1454,13 +1633,235 @@ def 
kernel( tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask, ) - cute.copy( - tma_atom_sfb, - tBgSFB_k, - tBsSFB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=sfb_full_mcast_mask, - ) + # Select correct B tensor based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 1): + tBgB_slice = tBgB_0[(None, mma_tile_coord_mnl[1], None, expert_idx)] + tBgSFB_slice = tBgSFB_0[(None, slice_n, None, expert_idx)] + cute.copy( + tma_atoms_b[0], + tBgB_slice[(None, ab_producer_state.count)], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + if cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, ab_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, ab_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif cutlass.const_expr(self.num_b_tensors == 3): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, 
+ mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, ab_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, ab_producer_state.count, local_l_1)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, ab_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[(None, slice_n, ab_producer_state.count, local_l_0)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[(None, slice_n, ab_producer_state.count, local_l_1)], 
+ tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[3]: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[(None, slice_n, ab_producer_state.count, local_l_2)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_3 = expert_idx - self.b_tensor_l_offsets[3] + cute.copy( + tma_atoms_b[3], + tBgB_3[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_3, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[3], + tBgSFB_3[(None, slice_n, ab_producer_state.count, local_l_3)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() @@ -1798,7 +2199,33 @@ def kernel( # expert_idx = mma_tile_coord_mnl[2] - alpha_val = alpha[expert_idx] + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + if cutlass.const_expr(self.num_b_tensors == 1): + pass + elif cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx >= self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif cutlass.const_expr(self.num_b_tensors == 3): + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]] + elif expert_idx >= self.b_tensor_l_offsets[2]: + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + else: + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][expert_idx - 
self.b_tensor_l_offsets[1]] + elif ( + expert_idx >= self.b_tensor_l_offsets[2] + and expert_idx < self.b_tensor_l_offsets[3] + ): + alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]] + elif expert_idx >= self.b_tensor_l_offsets[3]: + alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]] tile_m_start = tile_info[0] * self.cta_tile_shape_mnk[0] permuted_row = tile_m_start + epi_tidx @@ -2496,11 +2923,11 @@ def can_implement( def wrapper( self, a_ptr: cute.Pointer, - b_ptr: cute.Pointer, + b_ptr_tuple: Tuple[cute.Pointer, ...], a_sf_ptr: cute.Pointer, - b_sf_ptr: cute.Pointer, + b_sf_ptr_tuple: Tuple[cute.Pointer, ...], c_ptr: cute.Pointer, - alpha_ptr: cute.Pointer, + alpha_ptr_tuple: Tuple[cute.Pointer, ...], tile_idx_to_group_idx_ptr: cute.Pointer, tile_idx_to_mn_limit_ptr: cute.Pointer, permuted_idx_to_expanded_idx_ptr: cute.Pointer, @@ -2509,7 +2936,6 @@ def wrapper( m: cutlass.Int64, n: cutlass.Int64, k: cutlass.Int64, - l: cutlass.Int64, # noqa: E741 num_tokens: cutlass.Int64, top_k: cutlass.Int64, tile_size: cutlass.Constexpr, @@ -2518,26 +2944,87 @@ def wrapper( stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): + """Unified wrapper supporting both single-B and multi-B tensors. + + B tensors are always passed as tuples (length 1 for single-B). + L sizes are configured via b_tensor_l_sizes in __init__. 
+ """ scale_k = k // scaling_vector_size num_tiles = m // tile_size + a = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout((m, k, 1), order=(1, 0, 2))) - b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2))) a_sf = cute.make_tensor( a_sf_ptr, layout=cute.make_ordered_layout( (32, 4, m // 128, 4, scale_k // 4, 1), order=(2, 1, 4, 0, 3, 5) ), ) - b_sf = cute.make_tensor( - b_sf_ptr, - layout=cute.make_ordered_layout( - (32, 4, n // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5) - ), - ) c = cute.make_tensor( c_ptr, layout=cute.make_ordered_layout((num_tokens, n, 1), order=(1, 0, 2)) ) - alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,))) + + l_0 = self.b_tensor_l_sizes[0] + alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,))) + b_0 = cute.make_tensor( + b_ptr_tuple[0], layout=cute.make_ordered_layout((n, k, l_0), order=(1, 0, 2)) + ) + b_sf_0 = cute.make_tensor( + b_sf_ptr_tuple[0], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_0), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple = [b_0] + b_sf_tuple = [b_sf_0] + alpha_tuple = [alpha_0] + + if cutlass.const_expr(self.num_b_tensors >= 2): + l_1 = self.b_tensor_l_sizes[1] + alpha_1 = cute.make_tensor(alpha_ptr_tuple[1], layout=cute.make_layout((l_1,))) + b_1 = cute.make_tensor( + b_ptr_tuple[1], layout=cute.make_ordered_layout((n, k, l_1), order=(1, 0, 2)) + ) + b_sf_1 = cute.make_tensor( + b_sf_ptr_tuple[1], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_1), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_1) + b_sf_tuple.append(b_sf_1) + alpha_tuple.append(alpha_1) + + if cutlass.const_expr(self.num_b_tensors >= 3): + l_2 = self.b_tensor_l_sizes[2] + alpha_2 = cute.make_tensor(alpha_ptr_tuple[2], layout=cute.make_layout((l_2,))) + b_2 = cute.make_tensor( + b_ptr_tuple[2], layout=cute.make_ordered_layout((n, k, l_2), order=(1, 0, 2)) + ) + b_sf_2 = cute.make_tensor( + 
b_sf_ptr_tuple[2], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_2), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_2) + b_sf_tuple.append(b_sf_2) + alpha_tuple.append(alpha_2) + + if cutlass.const_expr(self.num_b_tensors >= 4): + l_3 = self.b_tensor_l_sizes[3] + alpha_3 = cute.make_tensor(alpha_ptr_tuple[3], layout=cute.make_layout((l_3,))) + b_3 = cute.make_tensor( + b_ptr_tuple[3], layout=cute.make_ordered_layout((n, k, l_3), order=(1, 0, 2)) + ) + b_sf_3 = cute.make_tensor( + b_sf_ptr_tuple[3], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_3), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_3) + b_sf_tuple.append(b_sf_3) + alpha_tuple.append(alpha_3) tile_idx_to_group_idx = cute.make_tensor( tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,)) @@ -2558,14 +3045,14 @@ def wrapper( return self( a, - b, + tuple(b_tuple), c, a_sf, - b_sf, + tuple(b_sf_tuple), tile_idx_to_group_idx, num_non_exiting_tiles, tile_idx_to_mn_limit, - alpha, + tuple(alpha_tuple), max_active_clusters=max_active_clusters, stream=stream, permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index 855bec5f35ef..beff454807ff 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -36,6 +36,7 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe.interface import MoE from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod +from tensorrt_llm._torch.pyexecutor.dwdp import get_global_dwdp_manager from tensorrt_llm._torch.utils import AuxStreamType, EventType, Fp4QuantizedTensor from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig @@ -254,6 +255,19 @@ def __init__( # Validate configuration 
self.validate_config() + # ========== Optional DWDP integration ========== + self.dwdp_manager = get_global_dwdp_manager() + self.dwdp_handle_collector = None + self.dwdp_rank = None + self.enable_dwdp = False + if self.dwdp_manager is not None and self._should_enable_dwdp(): + self.enable_dwdp = True + self.dwdp_handle_collector = self.dwdp_manager.add_layer( + layer_idx=self.layer_idx, + ) + self.dwdp_rank = self.dwdp_manager.dwdp_rank + self.backend.dwdp_handle_collector = self.dwdp_handle_collector + # Mark as _weights_removed to skip ConfigurableMoE's post_load_weights in model_loader # The backend's post_load_weights will be called directly by model_loader # This avoids duplicate post_load_weights calls (once for ConfigurableMoE, once for backend) @@ -280,6 +294,22 @@ def validate_config(self): "apply_router_weight_on_input only supports top-1 routing" ) + def _should_enable_dwdp(self) -> bool: + # DWDP is currently supported only for CuteDslFusedMoE with NVFP4 quantization. + if not isinstance(self.backend, CuteDslFusedMoE): + return False + + quant_config = getattr(self.backend, "quant_config", None) + if quant_config is None: + quant_config = getattr(self.model_config, "quant_config", None) + if quant_config is None: + return False + + quant_mode = getattr(quant_config, "layer_quant_mode", None) + return bool( + quant_mode is not None and hasattr(quant_mode, "has_nvfp4") and quant_mode.has_nvfp4() + ) + def _create_comm_strategy(self, model_config: ModelConfig) -> Optional[Communication]: """ Create communication strategy based on configuration @@ -484,6 +514,10 @@ def forward_impl( do_finalize, ) + # DWDP: record compute and trigger next prefetch (per-layer, not per-chunk) + if self.enable_dwdp: + self.dwdp_manager.record_compute_and_prefetch_next(self.layer_idx) + # ========== Step 4: Handle output truncation and EPLB repeat ========== if self.use_dp and self.parallel_size > 1: outputs = outputs[: all_rank_num_tokens[self.mapping.tp_rank]] @@ -1163,6 
+1197,11 @@ def _get_backend_kwargs( all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype ) + if self.enable_dwdp: + kwargs["dwdp_weight_view"] = self.dwdp_manager.build_weight_view( + self.layer_idx, self.backend + ) + # DeepGemm-specific parameters elif self.backend.__class__ == DeepGemmFusedMoE: if workspace is not None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 1273262f5f42..36dcbc87fff7 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -14,6 +14,7 @@ # limitations under the License. import math +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -41,6 +42,26 @@ from .routing import BaseMoeRoutingMethod +@dataclass +class NvFp4WeightView: + """Bundles all NVFP4 weight tensors for MoE computation. + + Provides a unified interface for both non-DWDP and DWDP paths: + - Non-DWDP: each list contains 1 element (local weight). + - DWDP: each list contains N elements (one per DWDP rank), + where the local rank's entry holds the actual model weight + and other ranks' entries hold prefetched buffer tensors. 
+ """ + w3_w1_weight: List[torch.Tensor] + fc1_weight_scale: List[torch.Tensor] + fc1_global_scale: List[torch.Tensor] + w2_weight: List[torch.Tensor] + fc2_weight_scale: List[torch.Tensor] + fc2_global_scale: List[torch.Tensor] + expert_size_per_partition: int + slot_start: int + + @torch.compile(options={"max-autotune": True}) def swiglu_fused_moe(x): x, gate = x.chunk(2, dim=-1) @@ -425,6 +446,7 @@ def __init__( init_load_balancer=init_load_balancer, without_comm=without_comm, ) + if self.aux_stream_dict is None: self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {} if AuxStreamType.MoeOutputMemset not in self.aux_stream_dict: @@ -436,6 +458,19 @@ def __init__( if key not in self.event_dict: self.event_dict[key] = torch.cuda.Event() + def _build_local_weight_view(self) -> NvFp4WeightView: + """Build weight view for non-DWDP path (single-element lists).""" + return NvFp4WeightView( + w3_w1_weight=[self.w3_w1_weight], + fc1_weight_scale=[self.quant_scales.fc1_weight_block], + fc1_global_scale=[self.quant_scales.fc1_global], + w2_weight=[self.w2_weight], + fc2_weight_scale=[self.quant_scales.fc2_weight_block], + fc2_global_scale=[self.quant_scales.fc2_global], + expert_size_per_partition=self.expert_size_per_partition, + slot_start=self.slot_start, + ) + def select_alltoall_method_type(self) -> AlltoallMethodType: return AlltoallMethodType.NotEnabled @@ -499,8 +534,21 @@ def run_moe_nvfp4( x_sf: Optional[torch.Tensor] = None, moe_output: Optional[torch.Tensor] = None, enable_alltoall: bool = False, + weight_view: Optional[NvFp4WeightView] = None, ) -> torch.Tensor: + """NVFP4 MoE computation with unified interface. + + Handles both non-DWDP and DWDP paths transparently: + - Non-DWDP (single-element weight lists): uses run_moe_nvfp4_impl. + Supports both fused-finalize and non-fused-finalize paths. + - DWDP (multi-element weight lists): uses run_moe_nvfp4_impl_dwdp. + Requires fused-finalize. + + Args: + weight_view: Bundled weight tensors. 
If None, local weights are used. + """ assert self.has_nvfp4 + assert weight_view is not None output_dtype = torch.bfloat16 if moe_output is None: @@ -513,24 +561,29 @@ def run_moe_nvfp4( self.hidden_size) assert moe_output.dtype == output_dtype - # After DeepEPLowLatency dispatch, token_selected_experts has shape - # [N, 1] instead of [N, top_k], because each row is already assigned - # to exactly one expert. Use the tensor shape as the effective top_k. effective_top_k = token_selected_experts.size(-1) + is_dwdp = len(weight_view.w3_w1_weight) > 1 + forward_impl = self.run_moe_nvfp4_impl_dwdp if is_dwdp else self.run_moe_nvfp4_impl + tuner = AutoTuner.get() runner = CuteDslFusedMoENvfp4Runner( - forward_impl=self.run_moe_nvfp4_impl, + forward_impl=forward_impl, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=weight_view.expert_size_per_partition, + local_expert_offset=weight_view.slot_start, enable_finalize_fusion=self.use_fused_finalize, enable_alltoall=enable_alltoall, ) inputs = [ - x, token_selected_experts, token_final_scales, x_sf, moe_output + x, + token_selected_experts, + token_final_scales, + x_sf, + moe_output, + weight_view, ] _, best_tactic = tuner.choose_one( "CuteDslFusedMoE::run_moe_nvfp4", @@ -547,22 +600,23 @@ def run_moe_nvfp4_impl( token_final_scales: Optional[torch.Tensor], x_sf: torch.Tensor, moe_output: torch.Tensor, + weight_view: NvFp4WeightView, enable_alltoall: bool = False, tile_size: int = 128, ) -> torch.Tensor: + """Non-DWDP NVFP4 MoE implementation using single-tensor ops.""" output_dtype = torch.bfloat16 - - # Use effective top_k from tensor shape rather than routing config. - # After DeepEPLowLatency dispatch, each row maps to one expert (top_k=1). 
effective_top_k = token_selected_experts.size(1) + esp = weight_view.expert_size_per_partition + slot_start = weight_view.slot_start tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort( token_selected_experts=token_selected_experts, token_final_scales=token_final_scales, num_experts=self.num_slots, top_k=effective_top_k, - local_expert_offset=self.slot_start, - local_num_experts=self.expert_size_per_partition, + local_expert_offset=slot_start, + local_num_experts=esp, tile_tokens_dim=tile_size, ) @@ -573,10 +627,10 @@ def run_moe_nvfp4_impl( x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( input=x.view(torch.float4_e2m1fn_x2), - weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2), + weight=weight_view.w3_w1_weight[0].view(torch.float4_e2m1fn_x2), input_scale=x_sf.view(torch.uint8), - weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8), - alpha=self.quant_scales.fc1_global, + weight_scale=weight_view.fc1_weight_scale[0].view(torch.uint8), + alpha=weight_view.fc1_global_scale[0], tile_idx_to_group_idx=tile_idx_to_expert_idx, tile_idx_to_mn_limit=tile_idx_to_mn_limit, permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, @@ -584,8 +638,8 @@ def run_moe_nvfp4_impl( global_sf=self.fc2_input_scale, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=esp, + local_expert_offset=slot_start, tile_size=tile_size, ) @@ -609,11 +663,12 @@ def run_moe_nvfp4_impl( torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( input=x.view(torch.float4_e2m1fn_x2), - weight=self.w2_weight.view(torch.float4_e2m1fn_x2), + weight=[weight_view.w2_weight[0].view(torch.float4_e2m1fn_x2)], input_scale=x_sf.view(torch.uint8), - weight_scale=self.quant_scales.fc2_weight_block.view( - 
torch.uint8), - alpha=self.quant_scales.fc2_global, + weight_scale=[ + weight_view.fc2_weight_scale[0].view(torch.uint8) + ], + alpha=[weight_view.fc2_global_scale[0]], output=moe_output, tile_idx_to_group_idx=tile_idx_to_expert_idx, tile_idx_to_mn_limit=tile_idx_to_mn_limit, @@ -622,25 +677,24 @@ def run_moe_nvfp4_impl( token_final_scales=token_final_scales, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=esp, + local_expert_offset=slot_start, tile_size=tile_size, output_dtype=output_dtype, ) else: x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell( input=x.view(torch.float4_e2m1fn_x2), - weight=self.w2_weight.view(torch.float4_e2m1fn_x2), + weight=weight_view.w2_weight[0].view(torch.float4_e2m1fn_x2), input_scale=x_sf.view(torch.uint8), - weight_scale=self.quant_scales.fc2_weight_block.view( - torch.uint8), - alpha=self.quant_scales.fc2_global, + weight_scale=weight_view.fc2_weight_scale[0].view(torch.uint8), + alpha=weight_view.fc2_global_scale[0], tile_idx_to_group_idx=tile_idx_to_expert_idx, num_non_exiting_tiles=num_non_exiting_tiles, num_experts=self.num_slots, top_k=effective_top_k, - num_local_experts=self.expert_size_per_partition, - local_expert_offset=self.slot_start, + num_local_experts=esp, + local_expert_offset=slot_start, tile_size=tile_size, output_dtype=output_dtype, ) @@ -652,6 +706,108 @@ def run_moe_nvfp4_impl( ) return moe_output + def run_moe_nvfp4_impl_dwdp( + self, + x: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + x_sf: torch.Tensor, + moe_output: torch.Tensor, + weight_view: NvFp4WeightView, + enable_alltoall: bool = False, + tile_size: int = 128, + ) -> torch.Tensor: + """DWDP NVFP4 MoE implementation using multi-B list ops. + + Requires fused-finalize since the non-fused FC2 op does not support + multiple B weight tensors. 
+ """ + assert self.use_fused_finalize, ( + "DWDP requires fused finalize (cute_dsl_nvfp4_grouped_gemm_blackwell " + "does not support multiple B weight tensors)") + output_dtype = torch.bfloat16 + effective_top_k = token_selected_experts.size(1) + esp = weight_view.expert_size_per_partition + slot_start = weight_view.slot_start + + tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort( + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + num_experts=self.num_slots, + top_k=effective_top_k, + local_expert_offset=slot_start, + local_num_experts=esp, + tile_tokens_dim=tile_size, + ) + + self.event_dict[EventType.Main].record() + moe_output.record_stream( + self.aux_stream_dict[AuxStreamType.MoeOutputMemset]) + + x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + input=x.view(torch.float4_e2m1fn_x2), + weight=[ + w.view(torch.float4_e2m1fn_x2) for w in weight_view.w3_w1_weight + ], + input_scale=x_sf.view(torch.uint8), + weight_scale=[ + ws.view(torch.uint8) for ws in weight_view.fc1_weight_scale + ], + alpha=weight_view.fc1_global_scale, + tile_idx_to_group_idx=tile_idx_to_expert_idx, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + global_sf=self.fc2_input_scale, + num_experts=self.num_slots, + top_k=effective_top_k, + num_local_experts=esp, + local_expert_offset=slot_start, + tile_size=tile_size, + ) + + with torch.cuda.stream( + self.aux_stream_dict[AuxStreamType.MoeOutputMemset]): + self.event_dict[EventType.Main].wait() + torch.ops.trtllm.moe_output_memset_inplace( + input=moe_output, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + 
num_non_exiting_tiles=num_non_exiting_tiles, + tile_tokens_dim=tile_size, + top_k=effective_top_k, + ep_size=self.mapping.moe_ep_size, + enable_alltoall=enable_alltoall, + ) + self.event_dict[EventType.MoeOutputMemset].record() + self.event_dict[EventType.MoeOutputMemset].wait() + + torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( + input=x.view(torch.float4_e2m1fn_x2), + weight=[ + w.view(torch.float4_e2m1fn_x2) for w in weight_view.w2_weight + ], + input_scale=x_sf.view(torch.uint8), + weight_scale=[ + ws.view(torch.uint8) for ws in weight_view.fc2_weight_scale + ], + alpha=weight_view.fc2_global_scale, + output=moe_output, + tile_idx_to_group_idx=tile_idx_to_expert_idx, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + token_final_scales=token_final_scales, + num_experts=self.num_slots, + top_k=effective_top_k, + num_local_experts=esp, + local_expert_offset=slot_start, + tile_size=tile_size, + output_dtype=output_dtype, + ) + return moe_output + def run_moe_fp8_block_scales( self, x: torch.Tensor, @@ -739,6 +895,7 @@ def run_moe( x_sf: Optional[torch.Tensor] = None, moe_output: Optional[torch.Tensor] = None, enable_alltoall: bool = False, + **kwargs, ) -> torch.Tensor: """ Run MoE computation with CuteDSL backend. @@ -759,16 +916,21 @@ def run_moe( Returns: final_hidden_states tensor. 
""" + # Execute MoE computation if self.has_nvfp4: - return self.run_moe_nvfp4( + weight_view = kwargs.get( + "dwdp_weight_view") or self._build_local_weight_view() + result = self.run_moe_nvfp4( x=x, token_selected_experts=token_selected_experts, token_final_scales=token_final_scales, x_sf=x_sf, moe_output=moe_output, - enable_alltoall=enable_alltoall) + enable_alltoall=enable_alltoall, + weight_view=weight_view, + ) elif self.has_deepseek_fp8_block_scales: - return self.run_moe_fp8_block_scales( + result = self.run_moe_fp8_block_scales( x=x, token_selected_experts=token_selected_experts, token_final_scales=token_final_scales, @@ -778,6 +940,7 @@ def run_moe( raise ValueError( f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}." ) + return result def forward_chunk( self, @@ -815,3 +978,9 @@ def forward_chunk( x_sf=x_sf, enable_alltoall=False) return x + + def load_weights(self, weights: Dict[str, torch.Tensor]): + super().load_weights(weights) + dwdp_handle_collector = getattr(self, "dwdp_handle_collector", None) + if dwdp_handle_collector is not None: + dwdp_handle_collector.register_weights(self) diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index f5e8e1e6f5bb..08420ab0b815 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -47,6 +47,7 @@ def _warn_and_return(reason: str) -> Tuple[bool, Optional[str]]: from ...model_config import ModelConfig +from ...pyexecutor.dwdp import get_global_dwdp_manager from ...utils import (ActivationType, AuxStreamType, Fp4QuantizedTensor, get_model_extra_attrs, is_gated_activation, is_torch_compiling) @@ -306,6 +307,30 @@ def __init__( self.initial_global_assignments = list(range(self.num_experts)) self.allreduce = None + # Override expert layout if DWDP is enabled + self._init_dwdp_expert_layout() + + def _init_dwdp_expert_layout(self): + 
"""Override expert layout when DWDP is enabled.""" + dwdp_manager = get_global_dwdp_manager() + if dwdp_manager is None: + return + assert self.layer_load_balancer is None, ( + "DWDP and EPLB (MoE load balancer) cannot be used together. " + "Disable one of dwdp_config or moe_load_balancer.") + self.num_slots = self.num_experts + self.expert_size_per_partition = dwdp_manager.num_experts_per_worker + dwdp_size = dwdp_manager.dwdp_size + self.initial_global_assignments = [ + (ep_rank * self.num_experts // dwdp_size + local_slot_id) % + self.num_experts for ep_rank in range(dwdp_size) + for local_slot_id in range(self.expert_size_per_partition) + ] + self.slot_start = dwdp_manager.start_expert_id + self.slot_end = self.slot_start + self.expert_size_per_partition + self.initial_local_expert_ids = list( + range(self.slot_start, self.slot_end)) + def _init_load_balancer( self, model_config: ModelConfig, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 112b7437df0f..fdc608ebad4e 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -31,6 +31,7 @@ get_spec_decoder, should_use_separate_draft_kv_cache) from .config_utils import (get_qwen3_hybrid_layer_masks, is_mla, is_nemotron_hybrid, is_qwen3_hybrid) +from .dwdp import DwdpManager from .guided_decoder import GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver @@ -1122,6 +1123,7 @@ def create_py_executor_instance( cache_transceiver_config: Optional[CacheTransceiverConfig] = None, virtual_memory_pools: Optional[dict] = None, execution_stream: Optional[torch.cuda.Stream] = None, + dwdp_manager: Optional[DwdpManager] = None, ) -> PyExecutor: kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None) @@ -1360,7 +1362,9 @@ def create_py_executor_instance( peft_cache_config=peft_cache_config, 
virtual_memory_pools=virtual_memory_pools, execution_stream=execution_stream, - waiting_queue_policy=waiting_queue_policy) + waiting_queue_policy=waiting_queue_policy, + dwdp_manager=dwdp_manager, + ) def create_torch_sampler_args( diff --git a/tensorrt_llm/_torch/pyexecutor/dwdp.py b/tensorrt_llm/_torch/pyexecutor/dwdp.py new file mode 100644 index 000000000000..5793aab4bac0 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/dwdp.py @@ -0,0 +1,585 @@ +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from cuda.bindings import driver as cuda_driver +from cuda.bindings import runtime as cudart +from mpi4py.MPI import COMM_WORLD + +from tensorrt_llm._torch.distributed import MPIDist +from tensorrt_llm._utils import global_mpi_rank, nvtx_range +from tensorrt_llm.llmapi.llm_args import DwdpConfig + +# Parameter names to collect handles for +WEIGHT_PARAMS = ["w3_w1_weight", "w2_weight"] +BIAS_PARAMS = ["w3_w1_bias", "w2_bias"] +# Quant scale params vary by quantization method +QUANT_SCALE_PARAMS = [ + "w3_w1_weight_scale", + "w2_weight_scale", # NVFP4/MXFP4 + "fc31_alpha", + "fc2_alpha", # NVFP4 alpha +] + + +_global_dwdp_manager: Optional["DwdpManager"] = None + + +def set_global_dwdp_manager(manager: "DwdpManager"): + global _global_dwdp_manager + _global_dwdp_manager = manager + + +def get_global_dwdp_manager() -> Optional["DwdpManager"]: + return _global_dwdp_manager + + +def check_cuda_error(err, context: str = ""): + """Check CUDA error.""" + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"CUDA error in {context}: {err}") + + +class DwdpLayerHandleCollector: + """ + Dwdp Layer Handle Collector for IPC handle coordination and prefetch buffer management. 
+    """
+
+    def __init__(
+        self,
+        layer_idx: int,
+    ):
+        self.layer_idx = layer_idx
+
+        # Local IPC handles: param_name -> handle_bytes
+        self.local_ipc_handles: Dict[str, bytes] = {}
+        # Local pointers: param_name -> data_ptr (for verification)
+        self.local_ptrs: Dict[str, int] = {}
+        # Local offsets: param_name -> offset from allocation base
+        # IPC handle points to allocation base, we need offset to get actual tensor data
+        self.local_offsets: Dict[str, int] = {}
+        # Parameter shapes: param_name -> shape (without expert dim)
+        self.param_shapes: Dict[str, torch.Size] = {}
+        # Parameter dtypes: param_name -> dtype
+        self.param_dtypes: Dict[str, torch.dtype] = {}
+        # Peer pointers: (peer_rank, param_name) -> ptr (already adjusted with offset)
+        self.peer_ptrs: Dict[Tuple[int, str], int] = {}
+
+    def register_weights(self, module: nn.Module):
+        """
+        Register weights from a MoE module and create IPC handles.
+
+        Called after module.load_weights() completes.
+
+        Args:
+            module: The MoE module with loaded weights
+        """
+        params_to_register = []
+        # Weights (check if present and not None)
+        for param_name in WEIGHT_PARAMS:
+            if hasattr(module, param_name) and getattr(module, param_name, None) is not None:
+                params_to_register.append(param_name)
+        # Bias (optional) -- NOTE(review): guard checks a generic "bias" attribute yet registers w3_w1_bias/w2_bias; getattr(module, param_name) below has no default and raises AttributeError if those are absent. Consider per-param checks as done for WEIGHT_PARAMS.
+        if hasattr(module, "bias"):
+            params_to_register.extend(BIAS_PARAMS)
+        # Quant scales (optional, depends on quant method)
+        for param_name in QUANT_SCALE_PARAMS:
+            if hasattr(module, param_name) and getattr(module, param_name, None) is not None:
+                params_to_register.append(param_name)
+
+        # Register each parameter
+        for param_name in params_to_register:
+            param = getattr(module, param_name)
+            if isinstance(param, nn.Parameter):
+                param = param.data
+            if param is None:
+                continue
+            if not param.is_cuda or not param.is_contiguous():
+                raise ValueError(f"Parameter {param_name} is not on GPU or is not contiguous")
+            self._register_param(param_name, param)
+
+    def _register_param(self, param_name: str, param: 
torch.Tensor):
+        # Get IPC handle - note: handle points to the CUDA allocation base, not tensor's data_ptr
+        tensor_ptr = param.data_ptr()
+        err, handle = cudart.cudaIpcGetMemHandle(tensor_ptr)
+        check_cuda_error(err, f"get handle for {param_name}")
+
+        # Get allocation base address using Driver API cuMemGetAddressRange
+        # This returns the actual base address and size of the CUDA allocation
+        # cudaPointerGetAttributes.devicePointer returns the input pointer, not base!
+        err, alloc_base, alloc_size = cuda_driver.cuMemGetAddressRange(tensor_ptr)
+        if err != cuda_driver.CUresult.CUDA_SUCCESS:
+            raise RuntimeError(f"cuMemGetAddressRange failed for {param_name}: {err}")
+
+        # Calculate offset from allocation base
+        # Convert CUdeviceptr to int for arithmetic
+        offset = tensor_ptr - int(alloc_base)
+
+        self.local_ipc_handles[param_name] = bytes(handle.reserved)
+        self.local_ptrs[param_name] = tensor_ptr
+        self.local_offsets[param_name] = offset
+        self.param_shapes[param_name] = param.shape[1:]
+        self.param_dtypes[param_name] = param.dtype
+
+    def get_peer_ptr(self, peer_rank: int, param_name: str) -> int:
+        """Get pointer to parameter on peer rank."""
+        return self.peer_ptrs[(peer_rank, param_name)]
+
+    def cleanup(self):
+        """Clean up peer handles. NOTE(review): peer_ptrs holds base+offset pointers, but cudaIpcCloseMemHandle expects exactly the pointer returned by cudaIpcOpenMemHandle (the mapping base); retain the opened base pointers for cleanup."""
+        for _, ptr in self.peer_ptrs.items():
+            cudart.cudaIpcCloseMemHandle(ptr)
+        self.peer_ptrs.clear()
+
+
+class DwdpPrefetchBuffer:
+    """
+    Ping-pong buffer for expert weight prefetching.
+
+    Buffer Selection Strategy:
+    - Even layers (0, 2, 4, ...) use buffer[0]
+    - Odd layers (1, 3, 5, ...) 
use buffer[1] + - This ensures layer N-1's prefetch doesn't overwrite layer N's data + + Synchronization Strategy: + - prefetch_events[buffer_idx][layer_idx]: Recorded when prefetch completes + Waited by forward() before using prefetched data + - compute_events[buffer_idx][layer_idx]: Recorded when forward() completes + Waited by next prefetch before overwriting buffer + + Buffer Layout (organized by rank): + - buffers[buffer_idx][param_name] = List[Optional[Tensor]] + - len(list) == dwdp_size + - list[peer_rank] = Tensor[num_prefetch_experts, ...] for peer_rank != dwdp_rank + - list[dwdp_rank] = None (local weight used directly, not prefetched) + """ + + def __init__( + self, + dwdp_size: int, + dwdp_rank: int, + num_experts_per_worker: int, + num_prefetch_experts: int, + num_layers: int, + first_moe_layer_idx: int, + param_shapes: Dict[str, torch.Size], + param_dtypes: Dict[str, torch.dtype], + ): + self.dwdp_size = dwdp_size + self.num_prefetch_experts = num_prefetch_experts + self.num_experts_per_worker = num_experts_per_worker + self.num_layers = num_layers + self.first_moe_layer_idx = first_moe_layer_idx + self.num_buffers = 2 # Ping-pong + self.dwdp_rank = dwdp_rank + + self.param_shapes = param_shapes + self.param_dtypes = param_dtypes + + self.device = torch.cuda.current_device() + + # buffers[buffer_idx][param_name] = List[Optional[Tensor]] + # list[peer_rank] contains prefetched weights from that rank + # list[dwdp_rank] = None (local weights used directly) + self.buffers: List[Dict[str, List[Optional[torch.Tensor]]]] = [] + + for _ in range(self.num_buffers): + buffer = {} + for param_name, shape in param_shapes.items(): + dtype = param_dtypes[param_name] + # Pre-allocate list of length dwdp_size, one slot per rank + # tensor_list[dwdp_rank] = None (local weights used directly) + # tensor_list[peer_rank] = Tensor for prefetched weights from peer + tensor_list: List[Optional[torch.Tensor]] = [None] * dwdp_size + for peer_rank in range(dwdp_size): + if 
peer_rank != dwdp_rank: + buffer_shape = (self.num_prefetch_experts,) + tuple(shape) + tensor_list[peer_rank] = torch.empty( + buffer_shape, + dtype=dtype, + device=self.device, + ) + buffer[param_name] = tensor_list + self.buffers.append(buffer) + + self.max_layer_idx = num_layers + first_moe_layer_idx + self.prefetch_events: List[List[torch.cuda.Event]] = [ + [torch.cuda.Event() for _ in range(self.max_layer_idx // self.num_buffers + 1)] + for _ in range(self.num_buffers) + ] + self.compute_events: List[List[torch.cuda.Event]] = [ + [torch.cuda.Event() for _ in range(self.max_layer_idx // self.num_buffers + 1)] + for _ in range(self.num_buffers) + ] + self.prefetch_stream = torch.cuda.Stream(device=self.device) + + def initialize_compute_events(self): + for buffer_idx in range(self.num_buffers): + self.compute_events[buffer_idx][0].record(torch.cuda.current_stream()) + + def record_prefetch_event(self, layer_idx: int): + self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record( + self.prefetch_stream + ) + + def record_compute_event(self, layer_idx: int): + self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers].record( + torch.cuda.current_stream() + ) + + def wait_prefetch_event(self, layer_idx: int): + torch.cuda.current_stream().wait_event( + self.prefetch_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers] + ) + + def wait_compute_event(self, layer_idx: int): + self.prefetch_stream.wait_event( + self.compute_events[layer_idx % self.num_buffers][layer_idx // self.num_buffers] + ) + + +class DwdpManager: + """ + Dwdp Manager for IPC handle coordination and prefetch buffer management. + + This manager: + - Tracks IPC handles for all MoE layers across Context workers + - Manages double-buffered prefetch buffers for remote expert weights + - Provides expert tensor routing (local vs. 
prefetched) + + """ + + def __init__( + self, + config: DwdpConfig, + dist: Optional[object] = None, + ): + self.config = config + self.dist = dist + self.dwdp_size = config.dwdp_size + self.num_experts_per_worker = config.num_experts_per_worker + self.num_groups = config.num_groups + + self._init_dwdp_group() + + # Per-layer IPC handle collectors (indexed by layer_idx) + self.ipc_collectors: List[DwdpLayerHandleCollector] = [] + + # Prefetch buffer (initialized later in create_py_executor) + self.prefetch_buffer: Optional[DwdpPrefetchBuffer] = None + # Auto-detected from first add_layer() call + self.first_moe_layer_idx: Optional[int] = None + + # Peer expert ranges: (peer_rank, (start_expert_id, end_expert_id)) + self.peer_expert_ranges: Dict[int, Tuple[int, int]] = {} + + self.dwdp_rank = self.rank % self.dwdp_size + self.num_prefetch_experts = config.num_prefetch_experts + self.start_expert_id = self.num_prefetch_experts * self.dwdp_rank + self.end_expert_id = self.start_expert_id + self.num_experts_per_worker + + def __enter__(self): + set_global_dwdp_manager(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cleanup() + set_global_dwdp_manager(None) + return False + + def _init_dwdp_group(self): + if not isinstance(self.dist, MPIDist): + raise RuntimeError("DWDP requires MPI backend (MPIDist)") + + self.rank = global_mpi_rank() + + # Calculate which group this rank belongs to + # With num_groups=2, dwdp_size=4: + # Group 0: ranks [0, 1, 2, 3] + # Group 1: ranks [4, 5, 6, 7] + self.group_id = self.rank // self.dwdp_size + group_start_rank = self.group_id * self.dwdp_size + ranks = list(range(group_start_rank, group_start_rank + self.dwdp_size)) + + new_group = COMM_WORLD.group.Incl(ranks) + self.dwdp_group = COMM_WORLD.Create_group(new_group) + + def is_enabled(self) -> bool: + return self.dwdp_size > 1 + + def cleanup(self): + """Release all IPC handles and clean up resources.""" + for collector in self.ipc_collectors: + 
collector.cleanup() + self.ipc_collectors.clear() + if self.dwdp_group is not None: + self.dwdp_group.Free() + self.dwdp_group = None + + def add_layer( + self, + layer_idx: int, + ) -> "DwdpLayerHandleCollector": + """ + Add a new layer IPC handle collector. + + Called from CuteDslFusedMoE.__init__() during model construction. + """ + if self.first_moe_layer_idx is None: + self.first_moe_layer_idx = layer_idx + collector = DwdpLayerHandleCollector(layer_idx=layer_idx) + self.ipc_collectors.append(collector) + return collector + + def exchange_all_handles(self): + """ + Exchange IPC handles with peer Context workers via Dwdp Group AllGather. + + Called after all weights are loaded, before creating prefetch buffer. + """ + + # Collect all local handles with explicit worker info + local_data = { + "dwdp_rank": self.dwdp_rank, + "expert_start_id": self.start_expert_id, + "expert_end_id": self.end_expert_id, + "ipc_collectors": [], + } + for collector in self.ipc_collectors: + local_data["ipc_collectors"].append( + { + "layer_idx": collector.layer_idx, + "handles": collector.local_ipc_handles, + "offsets": collector.local_offsets, + } + ) + + # AllGather from all Context workers in DWDP group + all_data = self.dwdp_group.allgather(local_data) + + # Open handles from peer workers + for peer_data in all_data: + peer_rank = peer_data["dwdp_rank"] + self.peer_expert_ranges[peer_rank] = ( + peer_data["expert_start_id"], + peer_data["expert_end_id"], + ) + + if peer_rank == self.dwdp_rank: + continue + for layer_idx, ipc_collector in enumerate(peer_data["ipc_collectors"]): + collector = self.ipc_collectors[layer_idx] + peer_offsets = ipc_collector["offsets"] + for param_name, handle_bytes in ipc_collector["handles"].items(): + # Reconstruct and open handle + handle = cudart.cudaIpcMemHandle_t() + handle.reserved = list(handle_bytes) + + err, base_ptr = cudart.cudaIpcOpenMemHandle( + handle, cudart.cudaIpcMemLazyEnablePeerAccess + ) + check_cuda_error(err, f"open handle 
rank={peer_rank}") + + # Apply offset to get actual tensor pointer + # IPC handle points to allocation base, offset gives us the tensor location + offset = peer_offsets[param_name] + actual_ptr = base_ptr + offset + collector.peer_ptrs[(peer_rank, param_name)] = actual_ptr + + def initialize_prefetch_buffer(self): + """ + Initialize the prefetch buffer. + + Called in create_py_executor() after model loading. + """ + self.prefetch_buffer = DwdpPrefetchBuffer( + dwdp_size=self.dwdp_size, + dwdp_rank=self.dwdp_rank, + num_experts_per_worker=self.num_experts_per_worker, + num_prefetch_experts=self.num_prefetch_experts, + num_layers=len(self.ipc_collectors), + first_moe_layer_idx=self.first_moe_layer_idx, + param_shapes=self.ipc_collectors[0].param_shapes, + param_dtypes=self.ipc_collectors[0].param_dtypes, + ) + self.prefetch_buffer.initialize_compute_events() + + def prefetch_first_layers(self): + """Prefetch the first num_buffers layers as warmup.""" + if self.prefetch_buffer is None: + raise RuntimeError("Prefetch buffer is not initialized") + start = self.first_moe_layer_idx + for layer_idx in range(start, start + self.prefetch_buffer.num_buffers): + self.prefetch_layer(layer_idx) + self.prefetch_buffer.record_prefetch_event(layer_idx) + + def build_weight_view(self, layer_idx: int, backend): + """Build NvFp4WeightView from prefetch buffer and local weights. + + Assembles weight tensors from all DWDP ranks: + - Peer ranks: uses prefetched buffer tensors + - Local rank: uses backend's actual model weights + + Args: + layer_idx: The MoE layer index. + backend: The CuteDslFusedMoE backend holding local model weights. + + Returns: + NvFp4WeightView with all weights assembled. 
+ """ + from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import NvFp4WeightView + + buffer_data = self.wait_prefetch_and_get_buffer(layer_idx) + required_keys = ( + "w3_w1_weight", + "w3_w1_weight_scale", + "fc31_alpha", + "w2_weight", + "w2_weight_scale", + "fc2_alpha", + ) + missing_keys = [key for key in required_keys if key not in buffer_data] + if missing_keys: + raise ValueError( + f"DWDP buffer missing required keys {missing_keys} for layer {layer_idx}." + ) + + w3_w1_weight_list = buffer_data["w3_w1_weight"] + fc1_weight_scale_list = buffer_data["w3_w1_weight_scale"] + fc1_global_scale_list = buffer_data["fc31_alpha"] + w2_weight_list = buffer_data["w2_weight"] + fc2_weight_scale_list = buffer_data["w2_weight_scale"] + fc2_global_scale_list = buffer_data["fc2_alpha"] + + w3_w1_weight_list[self.dwdp_rank] = backend.w3_w1_weight + fc1_weight_scale_list[self.dwdp_rank] = backend.quant_scales.fc1_weight_block + fc1_global_scale_list[self.dwdp_rank] = backend.quant_scales.fc1_global + w2_weight_list[self.dwdp_rank] = backend.w2_weight + fc2_weight_scale_list[self.dwdp_rank] = backend.quant_scales.fc2_weight_block + fc2_global_scale_list[self.dwdp_rank] = backend.quant_scales.fc2_global + + return NvFp4WeightView( + w3_w1_weight=w3_w1_weight_list, + fc1_weight_scale=fc1_weight_scale_list, + fc1_global_scale=fc1_global_scale_list, + w2_weight=w2_weight_list, + fc2_weight_scale=fc2_weight_scale_list, + fc2_global_scale=fc2_global_scale_list, + expert_size_per_partition=backend.num_slots, + slot_start=0, + ) + + def wait_prefetch_and_get_buffer( + self, layer_idx: int + ) -> Optional[Dict[str, List[Optional[torch.Tensor]]]]: + """Wait for prefetch to complete and return the buffer for this layer. 
+ + Returns: + Dict mapping param_name to List[Optional[Tensor]] where: + - list[peer_rank] = Tensor for prefetched weights from that peer + - list[dwdp_rank] = None (local weights used directly) + """ + if self.prefetch_buffer is None: + raise RuntimeError("Prefetch buffer is not initialized") + self.prefetch_buffer.wait_prefetch_event(layer_idx) + buffer_idx = layer_idx % self.prefetch_buffer.num_buffers + return self.prefetch_buffer.buffers[buffer_idx] + + def record_compute_and_prefetch_next(self, layer_idx: int): + """Record compute completion and trigger prefetch for layer_idx + num_buffers.""" + if self.prefetch_buffer is None: + raise RuntimeError("Prefetch buffer is not initialized") + # Record compute event for current layer + self.prefetch_buffer.record_compute_event(layer_idx) + + next_layer_idx = layer_idx + self.prefetch_buffer.num_buffers + if next_layer_idx >= self.prefetch_buffer.max_layer_idx: + return + # prefetch_layer handles stream internally: local copy on default stream, peer copy on prefetch stream + self.prefetch_layer(next_layer_idx, wait_compute_layer_idx=layer_idx) + self.prefetch_buffer.record_prefetch_event(next_layer_idx) + + def _get_prefetch_src_offset_from_peer(self, peer_rank: int) -> int: + """ + Calculate the source offset (in number of experts) to fetch from a peer. 
+ + Returns: + src_offset: Offset into peer's local expert tensor to start copying from + + Example: 256 experts, rank0: [0, 200), rank1: [56, 256) + - rank0 needs [200, 256) from rank1: + src_offset = 200 - 56 = 144 (fetch last 56 experts from rank1) + - rank1 needs [0, 56) from rank0: + src_offset = 0 - 0 = 0 (fetch first 56 experts from rank0) + """ + peer_start, peer_end = self.peer_expert_ranges[peer_rank] + + # What I need = global - what I have + # From peer = what I need ∩ what peer has + if self.dwdp_rank < peer_rank: + # I'm earlier rank, need experts after my end (tail of peer's experts) + prefetch_end = peer_end + prefetch_start = prefetch_end - self.num_prefetch_experts + else: + # I'm later rank, need experts before my start (head of peer's experts) + prefetch_start = peer_start + + src_offset = prefetch_start - peer_start + return src_offset + + @nvtx_range("dwdp_prefetch_layer") + def prefetch_layer(self, layer_idx: int, wait_compute_layer_idx: Optional[int] = None): + """ + Prefetch layer data from peer ranks. + + Args: + layer_idx: The layer to prefetch + wait_compute_layer_idx: If provided, wait for this layer's compute to complete + before overwriting buffer (used when prefetching next layer) + + Note: Local weights are used directly by the kernel, no copy needed. + Peer copy runs on prefetch stream. 
+ """ + moe_idx = layer_idx - self.first_moe_layer_idx + param_names = self.ipc_collectors[moe_idx].param_shapes.keys() + collector = self.ipc_collectors[moe_idx] + buffer_idx = layer_idx % self.prefetch_buffer.num_buffers + + # Peer copy on prefetch stream + # Local weights are used directly - no local copy needed + with torch.cuda.stream(self.prefetch_buffer.prefetch_stream): + # Wait for compute to complete before overwriting buffer + if wait_compute_layer_idx is not None: + self.prefetch_buffer.wait_compute_event(wait_compute_layer_idx) + + for peer_rank in range(self.dwdp_size): + if peer_rank == self.dwdp_rank: + continue # Skip local rank - local weights used directly + + src_expert_offset = self._get_prefetch_src_offset_from_peer(peer_rank) + + for param_name in param_names: + param_shape = collector.param_shapes[param_name] + param_dtype = collector.param_dtypes[param_name] + expert_size = param_shape.numel() * param_dtype.itemsize + + # src_ptr points to peer's tensor start, add offset for specific experts + base_ptr = collector.get_peer_ptr(peer_rank, param_name) + src_ptr = base_ptr + src_expert_offset * expert_size + + # dst_tensor is directly indexed by peer_rank in the list + dst_tensor = self.prefetch_buffer.buffers[buffer_idx][param_name][peer_rank] + dst_ptr = dst_tensor.data_ptr() + + data_size = self.num_prefetch_experts * expert_size + + (err,) = cudart.cudaMemcpyAsync( + dst_ptr, + src_ptr, + data_size, + cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, + self.prefetch_buffer.prefetch_stream.cuda_stream, + ) + check_cuda_error( + err, f"prefetch layer {layer_idx} peer_rank {peer_rank} {param_name}" + ) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 693fa3c75554..f673ec95da72 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -47,6 +47,7 @@ from ..speculative.drafter import Drafter from ..speculative.spec_sampler_base 
import SampleStateTensorsSpec from ..speculative.speculation_gate import SpeculationGate +from .dwdp import DwdpManager from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder from .handle_additional_outputs import HandleAdditionalOutputs @@ -285,7 +286,9 @@ def __init__( virtual_memory_pools: Optional[dict] = None, hang_detection_timeout: Optional[int] = None, execution_stream: Optional[torch.cuda.Stream] = None, - waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS): + waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS, + adp_router: Optional[ADPRouter] = None, + dwdp_manager: Optional[DwdpManager] = None): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = dist.rank @@ -537,6 +540,11 @@ def on_detected(): if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) + if dwdp_manager is not None and not self.disable_overlap_scheduler: + raise ValueError( + "DWDP requires disable_overlap_scheduler=True. 
" + "Overlap scheduler is not yet supported with DWDP.") + if self.drafter is not None: if self.event_loop.__name__ == self._executor_loop_pp.__name__: raise NotImplementedError( @@ -552,6 +560,8 @@ def on_detected(): self._maybe_init_kv_connector_manager() + self.dwdp_manager = dwdp_manager + if start_worker: self.start_worker() @@ -780,6 +790,9 @@ def shutdown(self): if (isinstance(self.sampler, AsyncWorkerMixin) and self.sampler.async_worker_enabled()): self.sampler.async_worker_stop() + if self.dwdp_manager is not None: + self.dwdp_manager.__exit__(None, None, None) + self.dwdp_manager = None def can_enqueue_requests(self) -> bool: """ @@ -1994,6 +2007,8 @@ def _executor_loop(self): with self.perf_manager.record_perf_events( gpu_forward_start, gpu_forward_end) as fwd_timing: + if self.dwdp_manager is not None: + self.dwdp_manager.prefetch_first_layers() batch_outputs = self._forward_step(scheduled_batch) guided_decoder_failed_requests = None diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 127ca5b62d1c..78d39cb45f89 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -37,6 +37,7 @@ create_py_executor_instance, instantiate_sampler, is_mla, validate_feature_combination) from .config_utils import is_nemotron_hybrid, is_qwen3_hybrid +from .dwdp import DwdpManager from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .model_engine import PyTorchModelEngine @@ -384,6 +385,14 @@ def create_py_executor( ) logger.info("ATTENTION RUNTIME FEATURES: ", attn_runtime_features) + # Initialize DWDP Manager (only for context workers in disaggregated serving) + dwdp_manager: Optional[DwdpManager] = None + if llm_args.dwdp_config is not None: + assert mapping.tp_size == 1 and llm_args.dwdp_config.dwdp_size > 1, "DWDP requires TP=1 and dwdp_size > 1" + 
dwdp_manager = DwdpManager(config=llm_args.dwdp_config, dist=dist) + dwdp_manager.__enter__() + logger.info(f"Dwdp Manager initialized. Config: {llm_args.dwdp_config}") + mem_monitor = _ExecutorMemoryMonitor() @contextmanager @@ -748,6 +757,11 @@ def drafting_loop_wrapper(model): max_seq_len = kv_cache_creator._max_seq_len update_sampler_max_seq_len(max_seq_len, sampler) + # Exchange IPC Handles and Initialize Dwdp Prefetch Buffer + if dwdp_manager is not None: + dwdp_manager.exchange_all_handles() + dwdp_manager.initialize_prefetch_buffer() + # Resource managers for speculative decoding # For user-specified drafters, use extra_resource_managers in PyTorchBackend config # to provide a resource manager if required. @@ -865,6 +879,7 @@ def drafting_loop_wrapper(model): cache_transceiver_config=cache_transceiver_config, virtual_memory_pools=vm_pools, execution_stream=execution_stream, + dwdp_manager=dwdp_manager, ) _adjust_torch_mem_fraction() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 4180f91b4a85..125d3803c71d 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2387,6 +2387,30 @@ def model_name(self) -> Union[str, Path]: return self.model if isinstance(self.model, str) else None +class DwdpConfig(StrictBaseModel): + """Configuration for Distributed Weight Data Parallelism (DWDP). + + DWDP accelerates the context (prefill) phase of disaggregated MoE serving + by combining data parallelism with NVLink-based expert weight sharing. + Each worker holds a subset of experts locally and asynchronously prefetches + the remaining experts from peer workers via CUDA IPC, enabling fully + asynchronous execution across ranks without synchronization barriers. + + Currently supported with the CuteDSL MoE backend and NVFP4 quantization + on NVLink-connected multi-GPU systems. 
+ """ + dwdp_size: int = Field(default=1, + description="The number of GPUs per DWDP group.") + num_groups: int = Field( + default=1, + description= + "The number of DWDP groups. Total workers = num_groups * dwdp_size.") + num_experts_per_worker: int = Field( + default=0, description="The number of experts per worker.") + num_prefetch_experts: int = Field( + default=0, description="The number of prefetch experts per worker.") + + class BaseLlmArgs(StrictBaseModel): """ Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both. @@ -3275,6 +3299,11 @@ class TorchLlmArgs(BaseLlmArgs): description="NVFP4 GEMM backend config.", status="beta") + dwdp_config: Optional[DwdpConfig] = Field( + default=None, + description="DWDP (Distributed Weight Data Parallelism) config.", + status="prototype") + attn_backend: str = Field(default='TRTLLM', description="Attention backend to use.", status="beta") @@ -3745,6 +3774,7 @@ def update_llm_args_with_extra_dict( "nvfp4_gemm_config": Nvfp4GemmConfig, "attention_dp_config": AttentionDpConfig, "kv_cache_config": KvCacheConfig, + "dwdp_config": DwdpConfig, } for field_name, field_type in field_mapping.items(): if field_name in llm_args_dict: diff --git a/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py new file mode 100644 index 000000000000..ce1644c04711 --- /dev/null +++ b/tests/integration/defs/accuracy/test_dwdp_disaggregated_serving.py @@ -0,0 +1,282 @@ +"""DWDP disaggregated serving accuracy tests. + +Separated from test_disaggregated_serving.py to isolate MPI-dependent test +infrastructure for easier maintenance. 
+""" + +import contextlib +import os +import subprocess +import tempfile +import time +from typing import Any, Dict, Optional + +import openai +import pytest +import requests +import yaml + +from defs.common import get_free_port_in_ci as get_free_port +from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams +from tensorrt_llm.llmapi.llm_args import LlmArgs +from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer + +from ..conftest import llm_models_root, skip_pre_blackwell +from ..trt_test_alternative import popen +from .accuracy_core import LlmapiAccuracyTestHarness +from .test_disaggregated_serving import ( + DEFAULT_SERVER_WAITING_TIMEOUT, + DEFAULT_TEST_TIMEOUT, + DuckLLM, + MyThreadPoolExecutor, + Result, + run_accuracy_test, +) + + +@contextlib.contextmanager +def launch_dwdp_disaggregated_llm( + worker_config: Dict[str, Any], + frontend_config: Dict[str, Any], + model_path: str, + total_gpus: int, + server_waiting_timeout: int = DEFAULT_SERVER_WAITING_TIMEOUT, + max_workers: int = 128, +): + """Launch DWDP disaggregated serving via mpirun. + + DWDP requires all workers (CTX + GEN) in a single MPI world for + IPC handle exchange and DWDP group formation. This function starts + all workers with ``mpirun`` and launches a separate disaggregated + frontend server for the client-facing OpenAI API. 
+ """ + temp_dir = tempfile.TemporaryDirectory() + worker_config_path = os.path.join(temp_dir.name, "worker_config.yaml") + frontend_config_path = os.path.join(temp_dir.name, "frontend_config.yaml") + + with open(worker_config_path, "w") as f: + yaml.dump(worker_config, f, default_flow_style=False, sort_keys=False) + with open(frontend_config_path, "w") as f: + yaml.dump(frontend_config, f, default_flow_style=False, sort_keys=False) + + serve_port = frontend_config["port"] + + child_env = { + k: v for k, v in os.environ.items() if not k.startswith(("OMPI_", "PMIX_", "PMI_")) + } + + mpi_cmd = [ + "mpirun", + "--allow-run-as-root", + "-n", + str(total_gpus), + "trtllm-serve", + "disaggregated_mpi_worker", + "-c", + worker_config_path, + ] + + frontend_cmd = [ + "trtllm-serve", + "disaggregated", + "-c", + frontend_config_path, + "--server_start_timeout", + str(server_waiting_timeout), + "-r", + "360000", + ] + + with ( + MyThreadPoolExecutor(max_workers=max_workers) as thread_pool, + temp_dir, + popen(mpi_cmd, env=child_env) as mpi_proc, + popen(frontend_cmd, env=child_env) as frontend_proc, + ): + start_time = time.time() + server_is_ready = False + while time.time() - start_time < server_waiting_timeout: + time.sleep(5) + for proc, name in [ + (mpi_proc, "mpirun"), + (frontend_proc, "frontend"), + ]: + if proc.poll() is not None: + raise Exception(f"{name} process exited with code {proc.returncode}") + try: + response = requests.get(f"http://localhost:{serve_port}/cluster_info") + if response.status_code == 200: + cluster_info = response.json() + if cluster_info.get("is_ready"): + print(f"DWDP cluster ready: {cluster_info}") + server_is_ready = True + break + except requests.exceptions.ConnectionError: + continue + if not server_is_ready: + pytest.fail(f"DWDP server not ready after {server_waiting_timeout}s") + + model_name = worker_config.get("model", model_path) + client = openai.OpenAI( + api_key="1234567890", base_url=f"http://localhost:{serve_port}/v1", 
timeout=1800000 + ) + + def send_request(prompt: str, sampling_params: SamplingParams, streaming: bool): + kwargs = {} + if sampling_params is not None: + kwargs.update( + max_tokens=sampling_params.max_tokens, + temperature=( + sampling_params.temperature if sampling_params.top_p is not None else 0 + ), + top_p=sampling_params.top_p, + stop=sampling_params.stop, + seed=sampling_params.seed, + ) + response = client.completions.create( + model=model_name, prompt=prompt, stream=streaming, **kwargs + ) + result = Result( + id=0, + sampling_params=sampling_params, + outputs=[CompletionOutput(text=response.choices[0].text, index=0)], + ) + requested_output = RequestOutput._from_generation_result(result, prompt=prompt) + setattr(requested_output, "result", result.result) + return requested_output + + def generate_async( + prompt: str, sampling_params: Optional[SamplingParams] = None, streaming: bool = False + ): + future = thread_pool.submit(send_request, prompt, sampling_params, streaming) + thread_pool.futures.append(future) + return future + + args = LlmArgs(model=model_path) + tokenizer = load_hf_tokenizer(model_path) + try: + yield DuckLLM(args, tokenizer, generate_async) + finally: + all_procs = [frontend_proc, mpi_proc] + for proc in all_procs: + if proc.poll() is None: + proc.terminate() + deadline = time.monotonic() + 5 + for proc in all_procs: + remaining = max(0, deadline - time.monotonic()) + try: + proc.wait(timeout=remaining) + except subprocess.TimeoutExpired: + try: + proc.kill() + except ProcessLookupError: + pass + except OSError: + pass + + +@pytest.mark.timeout(DEFAULT_TEST_TIMEOUT) +class TestDwdpDeepSeekV3Lite(LlmapiAccuracyTestHarness): + MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite" + + @pytest.mark.skip_less_device(4) + @skip_pre_blackwell + def test_dwdp_accuracy(self): + model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp" + + ctx_port_0 = get_free_port() + ctx_port_1 = get_free_port() + gen_port = get_free_port() + serve_port 
= get_free_port() + + ctx_server_config = { + "num_instances": 2, + "urls": [ + f"localhost:{ctx_port_0}", + f"localhost:{ctx_port_1}", + ], + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "disable_overlap_scheduler": True, + "enable_autotuner": False, + "enable_chunked_prefill": False, + "cuda_graph_config": None, + "max_batch_size": 16, + "max_num_tokens": 8192, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.4, + "enable_block_reuse": False, + "enable_partial_reuse": False, + "tokens_per_block": 32, + }, + "cache_transceiver_config": { + "backend": "UCX", + "max_tokens_in_buffer": 8192, + }, + "moe_config": { + "backend": "CUTEDSL", + }, + "dwdp_config": { + "dwdp_size": 2, + "num_groups": 1, + "num_experts_per_worker": 36, + "num_prefetch_experts": 36, + }, + } + + gen_server_config = { + "num_instances": 1, + "urls": [f"localhost:{gen_port}"], + "tensor_parallel_size": 2, + "pipeline_parallel_size": 1, + "disable_overlap_scheduler": True, + "enable_autotuner": False, + "enable_chunked_prefill": False, + "cuda_graph_config": None, + "max_batch_size": 128, + "max_num_tokens": 1024, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.5, + "enable_block_reuse": False, + "enable_partial_reuse": False, + "tokens_per_block": 32, + }, + "cache_transceiver_config": { + "backend": "UCX", + "max_tokens_in_buffer": 8192, + }, + "moe_config": { + "backend": "CUTEDSL", + }, + } + + worker_config = { + "model": model_path, + "hostname": "localhost", + "port": serve_port, + "backend": "pytorch", + "context_servers": ctx_server_config, + "generation_servers": gen_server_config, + } + + frontend_config = { + "backend": "pytorch", + "hostname": "localhost", + "port": serve_port, + "context_servers": { + "num_instances": 2, + "urls": [ + f"localhost:{ctx_port_0}", + f"localhost:{ctx_port_1}", + ], + }, + "generation_servers": { + "num_instances": 1, + "urls": [f"localhost:{gen_port}"], + }, + } + + with launch_dwdp_disaggregated_llm( + worker_config, 
frontend_config, model_path, total_gpus=4, max_workers=128 + ) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 63ae65f6cc05..d74631bc958f 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -380,6 +380,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend +accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index e0539b70ea96..3cb1cefbfaaf 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -95,6 +95,7 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_on-trtllm] - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_off-trtllm] - accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90) + - accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy - condition: ranges: diff --git 
a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 3fb0015c9833..3d7c46e5d0aa 100644 --- a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -92,6 +92,37 @@ cvt_sf_M32x4xrm_K4xrk_L_to_MKL = kernel_module.cvt_sf_M32x4xrm_K4xrk_L_to_MKL +def split_groups_to_b_tensors( + num_groups: int, num_b_tensors: int +) -> Tuple[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]: + """Split groups into multiple B tensors. + + :param num_groups: Total number of groups (experts) + :param num_b_tensors: Number of B tensors to split into + :return: Tuple of (b_tensor_l_sizes, groups_per_b_tensor) + - b_tensor_l_sizes: L size for each B tensor + - groups_per_b_tensor: Tuple of group indices for each B tensor + """ + # Distribute groups evenly across B tensors + base_groups_per_tensor = num_groups // num_b_tensors + remainder = num_groups % num_b_tensors + + b_tensor_l_sizes = [] + groups_per_b_tensor = [] + current_group = 0 + + for i in range(num_b_tensors): + # Add one extra group to first 'remainder' tensors + num_groups_in_tensor = base_groups_per_tensor + (1 if i < remainder else 0) + b_tensor_l_sizes.append(num_groups_in_tensor) + groups_per_b_tensor.append( + tuple(range(current_group, current_group + num_groups_in_tensor)) + ) + current_group += num_groups_in_tensor + + return tuple(b_tensor_l_sizes), tuple(groups_per_b_tensor) + + def create_mask(group_m_list, mma_tiler_m, permuted_m=None): """Create mask and group mapping for contiguous grouped GEMM. @@ -375,6 +406,8 @@ def create_tensors( sf_vec_size, mma_tiler_m, permuted_m=None, + b_tensor_l_sizes=None, + groups_per_b_tensor=None, ): """Create tensors for contiguous grouped GEMM with gather operation and SwiGLU fusion. 
@@ -383,7 +416,7 @@ def create_tensors( Returns tensors including: - A: Input matrix (MxKx1) - - B: Weight matrix with interleaved up/gate weights (NxKxL) + - B: Weight matrix with interleaved up/gate weights (NxKxL) or list of tensors for multi-B - C: Output matrix (Mx(N/2)x1), N is halved due to SwiGLU fusion - SFA, SFB: Scale factor matrices for A and B - SFC: Scale factor matrix for C (only when c_dtype is Float4E2M1FN) @@ -392,31 +425,14 @@ def create_tensors( - num_non_exiting_tiles: Number of valid tiles to process :param mma_tiler_m: MMA tile size in M dimension (from mma_tiler_mn[0]), also used for alignment - :param permuted_m: Optional padded M dimension for cuda_graph support. If provided, - A matrix, C matrix, token_id_mapping, and scale factor A will be padded to this size. - The kernel exits when tile_idx >= num_non_exiting_tiles. - - Example with CUDA graph padding: - # For MoE: m=4096, topK=8, num_local_experts=256, experts_per_rank=8 - permuted_m = 4096 * 8 + 8 * 255 # = 34808 - tensors = create_tensors( - num_groups=8, # num_local_experts - group_m_list=[512, 1024, ...], # actual group sizes - n=4096, k=7168, - a_major="k", b_major="k", cd_major="n", - ab_dtype=cutlass.Float4E2M1FN, - c_dtype=cutlass.BFloat16, - sf_dtype=cutlass.Float8E4M3FN, - sf_vec_size=16, - mma_tiler_m=128, # MMA tile size in M dimension, also used for alignment - permuted_m=34808 # Enable padding for cuda_graph - ) - # Returns tensors with A, C, SFA, and token_id_mapping padded to permuted_m size, - # kernel exits early when tile_idx >= num_non_exiting_tiles + :param permuted_m: Optional padded M dimension for cuda_graph support. + :param b_tensor_l_sizes: Optional tuple of L sizes for multi-B tensor mode. + :param groups_per_b_tensor: Optional tuple of group indices for each B tensor. 
""" torch.manual_seed(1111) - alpha_torch_cpu = torch.randn((num_groups,), dtype=torch.float32) + # Determine if multi-B tensor mode + multi_b_mode = b_tensor_l_sizes is not None ( valid_m, @@ -427,21 +443,14 @@ def create_tensors( ) = create_mask(group_m_list, mma_tiler_m, permuted_m) max_m = max(group_m_list) - - # Use permuted_m for A/C tensors if provided (for cuda_graph support) tensor_m = permuted_m if permuted_m is not None else valid_m a_torch_cpu = cutlass_torch.matrix(1, max_m, k, a_major == "m", cutlass.Float32) - b_torch_cpu = cutlass_torch.matrix(num_groups, n, k, b_major == "n", cutlass.Float32) - # C tensor also uses tensor_m (permuted_m) for cuda_graph support c_torch_cpu = cutlass_torch.matrix(1, tensor_m, n // 2, cd_major == "m", cutlass.Float32) a_tensor, a_torch_gpu = cutlass_torch.cute_tensor_like( a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 ) - b_tensor, b_torch_gpu = cutlass_torch.cute_tensor_like( - b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 - ) c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 ) @@ -452,26 +461,83 @@ def create_tensors( stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, ) - b_tensor.mark_compact_shape_dynamic( - mode=1 if b_major == "k" else 0, - stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), - divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, - ) c_tensor.mark_compact_shape_dynamic( mode=1 if cd_major == "n" else 0, stride_order=(2, 0, 1) if cd_major == "n" else (2, 1, 0), divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, ) + if multi_b_mode: + # Multi-B tensor mode: create multiple B tensors + b_torch_cpu_list = [] + b_tensor_list = [] + b_torch_gpu_list = [] + sfb_torch_cpu_list = [] + sfb_tensor_list = [] + sfb_torch_gpu_list = [] + alpha_torch_cpu_list = [] + alpha_tensor_list = [] + + for l_size in 
b_tensor_l_sizes: + # Create alpha for this B tensor + alpha_torch_cpu = torch.randn((l_size,), dtype=torch.float32) + alpha_torch_cpu_list.append(alpha_torch_cpu) + alpha = from_dlpack(alpha_torch_cpu.cuda()).mark_layout_dynamic() + alpha_tensor_list.append(alpha) + + # Create B tensor + b_torch_cpu = cutlass_torch.matrix(l_size, n, k, b_major == "n", cutlass.Float32) + b_tensor, b_torch_gpu = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, + ) + b_torch_cpu_list.append(b_torch_cpu) + b_tensor_list.append(b_tensor) + b_torch_gpu_list.append(b_torch_gpu) + + # Create SFB tensor + sfb_torch_cpu, sfb_tensor, sfb_torch_gpu = create_scale_factor_tensor( + l_size, n, k, sf_vec_size, sf_dtype + ) + sfb_torch_cpu_list.append(sfb_torch_cpu) + sfb_tensor_list.append(sfb_tensor) + sfb_torch_gpu_list.append(sfb_torch_gpu) + + b_tensor = b_tensor_list + b_torch_cpu = b_torch_cpu_list + b_torch_gpu = b_torch_gpu_list + sfb_tensor = sfb_tensor_list + sfb_torch_cpu = sfb_torch_cpu_list + sfb_torch_gpu = sfb_torch_gpu_list + alpha_torch_cpu = alpha_torch_cpu_list + alpha = alpha_tensor_list + else: + # Single B tensor mode + alpha_torch_cpu = torch.randn((num_groups,), dtype=torch.float32) + alpha = from_dlpack(alpha_torch_cpu.cuda()).mark_layout_dynamic() + + b_torch_cpu = cutlass_torch.matrix(num_groups, n, k, b_major == "n", cutlass.Float32) + b_tensor, b_torch_gpu = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, + ) + sfb_torch_cpu, sfb_tensor, sfb_torch_gpu = 
create_scale_factor_tensor( + num_groups, n, k, sf_vec_size, sf_dtype + ) + # Use tensor_m (permuted_m if provided) for scale factor A sfa_torch_cpu, sfa_tensor, sfa_torch_gpu = create_scale_factor_tensor_unswizzled( 1, max_m, k, sf_vec_size, sf_dtype ) - sfb_torch_cpu, sfb_tensor, sfb_torch_gpu = create_scale_factor_tensor( - num_groups, n, k, sf_vec_size, sf_dtype - ) - token_id_mapping_cpu, token_id_mapping, token_id_mapping_torch = create_token_id_mapping_tensor( group_m_list, mma_tiler_m, max_token_id=max_m, permuted_m=permuted_m ) @@ -480,8 +546,6 @@ def create_tensors( tile_idx_to_mn_limit = from_dlpack(_tile_idx_to_mn_limit).mark_layout_dynamic() num_non_exiting_tiles = from_dlpack(_num_non_exiting_tiles).mark_layout_dynamic() - alpha = from_dlpack(alpha_torch_cpu.cuda()).mark_layout_dynamic() - # Create sfc_tensor and norm_const_tensor when c_dtype is Float4E2M1FN sfc_torch_cpu = None sfc_tensor = None @@ -555,6 +619,7 @@ def run( permuted_m: int = None, use_cupti: bool = False, raster_along_m: bool = False, + num_b_tensors: int = None, **kwargs, ): """Run contiguous grouped GEMM with gather operation and SwiGLU fusion for FC1 layer. @@ -566,24 +631,16 @@ def run( Note: Output C has N/2 columns since SwiGLU combines pairs of (up, gate) from interleaved B weights. 
- This function: - - Creates tensors including token_id_mapping for gather operation - - Uses LDGSTS for loading A and SFA matrices with gather capability - - Uses TMA for loading B and SFB matrices with multicast - - Performs SwiGLU activation fusion in epilogue - - Optionally performs quantization fusion for Float4E2M1FN output - - Performs reference checking (if not skipped) - - Benchmarks kernel performance - - :param nkl: (N, K, L) dimensions where L is the number of experts/groups - :param group_m_list: List of M values for each group - :param mma_tiler_mn: MMA tile shape (M, N), where mma_tiler_mn[0] is used for group M alignment - :param permuted_m: Optional padded M dimension for CUDA graph support. If provided, - A/C matrices, token_id_mapping, and scale factor A will be padded to this size. + :param num_b_tensors: If specified, enables multi-B tensor mode (2-4 tensors). """ + # Determine if multi-B tensor mode + multi_b_mode = num_b_tensors is not None + print("Running Blackwell Persistent Contiguous Grouped GEMM with Gather test:") print(f"nkl: {nkl}") print(f"group_m_list: {group_m_list}") + if multi_b_mode: + print(f"Multi-B tensor mode: {num_b_tensors} B tensors") print( f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}, " f"Scale factor dtype: {sf_dtype}, SF Vec size: {sf_vec_size}" @@ -608,6 +665,14 @@ def run( if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") + # Split groups into multiple B tensors if multi-B mode + b_tensor_l_sizes = None + groups_per_b_tensor = None + if multi_b_mode: + b_tensor_l_sizes, groups_per_b_tensor = split_groups_to_b_tensors(num_groups, num_b_tensors) + print(f"b_tensor_l_sizes: {b_tensor_l_sizes}") + print(f"groups_per_b_tensor: {groups_per_b_tensor}") + # Skip unsupported testcase # Note: For grouped GEMM, we use mma_tiler_mn[0] as the m parameter for can_implement check # since individual group M values vary @@ -677,7 +742,10 @@ def run( sf_vec_size, 
mma_tiler_mn[0], # mma_tiler_m, also used for alignment permuted_m, + b_tensor_l_sizes=b_tensor_l_sizes, + groups_per_b_tensor=groups_per_b_tensor, ) + # Configure gemm kernel gemm = BlockScaledContiguousGatherGroupedGemmKernel( sf_vec_size, @@ -686,6 +754,7 @@ def run( True, topk=1, raster_along_m=raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes if multi_b_mode else None, ) # Compute max active clusters on current device @@ -700,40 +769,78 @@ def run( current_stream = cuda.CUstream(torch_stream.cuda_stream) # Compile gemm kernel # sfc_tensor is optional and can be set as None (Python's None value) if not needed. - compiled_gemm = cute.compile( - gemm, - a_tensor, - b_tensor, - c_tensor, - sfa_tensor, - sfb_tensor, - sfc_tensor, - norm_const_tensor, - tile_idx_to_expert_idx, - tile_idx_to_mn_limit, - token_id_mapping, - num_non_exiting_tiles, - alpha, - max_active_clusters, - current_stream, - ) + if multi_b_mode: + # Multi-B tensor mode: pass tuples + compiled_gemm = cute.compile( + gemm, + a_tensor, + tuple(b_tensor), + c_tensor, + sfa_tensor, + tuple(sfb_tensor), + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + tuple(alpha), + max_active_clusters, + current_stream, + ) + else: + # Single-B tensor mode + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + alpha, + max_active_clusters, + current_stream, + ) # Execution - compiled_gemm( - a_tensor, - b_tensor, - c_tensor, - sfa_tensor, - sfb_tensor, - sfc_tensor, - norm_const_tensor, - tile_idx_to_expert_idx, - tile_idx_to_mn_limit, - token_id_mapping, - num_non_exiting_tiles, - alpha, - current_stream, - ) + if multi_b_mode: + compiled_gemm( + a_tensor, + tuple(b_tensor), + c_tensor, + sfa_tensor, + tuple(sfb_tensor), + sfc_tensor, + 
norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + tuple(alpha), + current_stream, + ) + else: + compiled_gemm( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + alpha, + current_stream, + ) torch.cuda.synchronize() # Compute reference result @@ -751,10 +858,28 @@ def run( for i, group_m in enumerate(aligned_group_m_list): end = start + group_m res_a = a_torch_cpu_f32[token_id_mapping_cpu[start:end]] - res_b = torch.einsum("nk,nk->nk", b_torch_cpu[:, :, i], sfb_torch_cpu[:, :, i]) - gemm_result[0, start:end, :] = ( - torch.einsum("mk,nk->mn", res_a, res_b) * alpha_torch_cpu[i] - ) + + if multi_b_mode: + # Find which B tensor this group belongs to + b_tensor_idx = None + local_group_idx = None + for b_idx, groups in enumerate(groups_per_b_tensor): + if i in groups: + b_tensor_idx = b_idx + local_group_idx = groups.index(i) + break + assert b_tensor_idx is not None, f"Group {i} not found in any B tensor" + res_b = torch.einsum( + "nk,nk->nk", + b_torch_cpu[b_tensor_idx][:, :, local_group_idx], + sfb_torch_cpu[b_tensor_idx][:, :, local_group_idx], + ) + alpha_val = alpha_torch_cpu[b_tensor_idx][local_group_idx] + else: + res_b = torch.einsum("nk,nk->nk", b_torch_cpu[:, :, i], sfb_torch_cpu[:, :, i]) + alpha_val = alpha_torch_cpu[i] + + gemm_result[0, start:end, :] = torch.einsum("mk,nk->mn", res_a, res_b) * alpha_val start = end # Step 2: Apply SwiGLU on interleaved GEMM result @@ -1020,24 +1145,7 @@ def generate_tensors(): token_id_mapping, num_non_exiting_tiles, alpha, - a_torch_cpu, - b_torch_cpu, - c_torch_cpu, - sfa_torch_cpu, - sfb_torch_cpu, - sfc_torch_cpu, - norm_const_torch_cpu, - alpha_torch_cpu, - a_torch_gpu, - b_torch_gpu, - c_torch_gpu, - sfa_torch_gpu, - sfb_torch_gpu, - sfc_torch_gpu, - norm_const_torch_gpu, - aligned_group_m_list, - valid_m, - 
token_id_mapping_cpu, + *_, ) = create_tensors( num_groups, group_m_list, @@ -1052,40 +1160,67 @@ def generate_tensors(): sf_vec_size, mma_tiler_mn[0], # mma_tiler_m, also used for alignment permuted_m, + b_tensor_l_sizes=b_tensor_l_sizes, + groups_per_b_tensor=groups_per_b_tensor, ) - return cute.testing.JitArguments( - a_tensor, - b_tensor, - c_tensor, - sfa_tensor, - sfb_tensor, - sfc_tensor, - norm_const_tensor, - tile_idx_to_expert_idx, - tile_idx_to_mn_limit, - token_id_mapping, - num_non_exiting_tiles, - alpha, - current_stream, - ) + if multi_b_mode: + return cute.testing.JitArguments( + a_tensor, + tuple(b_tensor), + c_tensor, + sfa_tensor, + tuple(sfb_tensor), + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + tuple(alpha), + current_stream, + ) + else: + return cute.testing.JitArguments( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping, + num_non_exiting_tiles, + alpha, + current_stream, + ) workspace_count = 1 if use_cold_l2: # Calculate actual tensor_m used (with padding if permuted_m provided) tensor_m = permuted_m if permuted_m is not None else valid_m + if multi_b_mode: + b_bytes = sum(t.numel() * t.element_size() for t in b_torch_gpu) + sfb_bytes = sum(t.numel() * t.element_size() for t in sfb_torch_gpu) + alpha_bytes = sum(t.numel() * t.element_size() for t in alpha_torch_cpu) + else: + b_bytes = b_torch_gpu.numel() * b_torch_gpu.element_size() + sfb_bytes = sfb_torch_gpu.numel() * sfb_torch_gpu.element_size() + alpha_bytes = alpha_torch_cpu.numel() * alpha_torch_cpu.element_size() one_workspace_bytes = ( a_torch_gpu.numel() * a_torch_gpu.element_size() - + b_torch_gpu.numel() * b_torch_gpu.element_size() + + b_bytes + c_torch_gpu.numel() * c_torch_gpu.element_size() + sfa_torch_gpu.numel() * sfa_torch_gpu.element_size() - + sfb_torch_gpu.numel() * 
sfb_torch_gpu.element_size() + + sfb_bytes + (tensor_m // mma_tiler_mn[0]) * 4 # tile_idx_to_expert_idx length (tiles) * sizeof(int32) + (tensor_m // mma_tiler_mn[0]) * 4 # tile_idx_to_mn_limit length (tiles) * sizeof(int32) + tensor_m * 4 # token_id_mapping_tensor length (elements) * sizeof(int32) + 1 * 4 # num_non_exiting_tiles (1 element) * sizeof(int32) - + alpha_torch_cpu.numel() * alpha_torch_cpu.element_size() + + alpha_bytes ) workspace_count = cute.testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations @@ -1245,6 +1380,13 @@ def read_benchmark_file( parser.add_argument( "--raster_along_m", action="store_true", default=False, help="Raster along M dimension" ) + parser.add_argument( + "--num_b_tensors", + type=int, + default=None, + help="Number of B tensors to split into (for multi-B tensor test). " + "If specified, enables multi-B tensor mode. Must be 2, 3, or 4.", + ) args = parser.parse_args() @@ -1279,6 +1421,16 @@ def read_benchmark_file( if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") + if args.num_b_tensors is not None: + if args.num_b_tensors < 2 or args.num_b_tensors > 4: + parser.error("--num_b_tensors must be 2, 3, or 4") + n, k, num_groups = nkl + if num_groups < args.num_b_tensors: + parser.error( + f"--num_b_tensors ({args.num_b_tensors}) cannot be greater than " + f"number of groups ({num_groups})" + ) + exec_time = run( nkl, group_m_list, @@ -1300,6 +1452,7 @@ def read_benchmark_file( args.permuted_m, args.use_cupti, args.raster_along_m, + args.num_b_tensors, ) print(f"Execution time: {exec_time:.2f} us") print("PASS") diff --git a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py index b60201dbc3a8..671cdc5db765 100644 --- a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ 
b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -297,6 +297,7 @@ def create_tensors( mma_tiler_mn, permuted_m=None, seq_len=None, + b_tensor_l_sizes=None, ): """Create tensors for contiguous grouped GEMM. @@ -304,6 +305,7 @@ def create_tensors( A matrix, C matrix, and scale factor A will be padded to this size. The kernel exits when tile_idx >= num_non_exiting_tiles. :param seq_len: Sequence length (number of output tokens for C tensor) + :param b_tensor_l_sizes: Optional tuple of L sizes for multi-B tensor mode. :return: Tuple of (a_tensor, b_tensor, out_tensor, sfa_tensor, sfb_tensor, tile_idx_to_expert_idx, num_non_exiting_tiles, alpha, a_torch_cpu, b_torch_cpu, out_torch_cpu, sfa_torch_cpu, sfb_torch_cpu, @@ -331,6 +333,11 @@ def create_tensors( """ torch.manual_seed(1111) + multi_b_mode = b_tensor_l_sizes is not None + if multi_b_mode: + total_l = sum(b_tensor_l_sizes) + if total_l != l: + raise ValueError(f"Sum of b_tensor_l_sizes ({total_l}) must equal total L ({l}).") alpha_torch_cpu = torch.ones((l,), dtype=torch.float32) * 0.1 ( @@ -394,6 +401,50 @@ def create_tensors( out_torch_gpu.fill_(0) + if multi_b_mode: + b_torch_cpu_list = [] + b_tensor_list = [] + b_torch_gpu_list = [] + sfb_torch_cpu_list = [] + sfb_tensor_list = [] + sfb_torch_gpu_list = [] + alpha_torch_cpu_list = [] + alpha_tensor_list = [] + + for l_size in b_tensor_l_sizes: + alpha_cpu = torch.ones((l_size,), dtype=torch.float32) * 0.1 + alpha_torch_cpu_list.append(alpha_cpu) + alpha_tensor_list.append(from_dlpack(alpha_cpu.cuda()).mark_layout_dynamic()) + + b_cpu = cutlass_torch.matrix(l_size, n, k, b_major == "n", cutlass.Float32) + b_tensor_i, b_torch_gpu_i = cutlass_torch.cute_tensor_like( + b_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor_i.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN 
else 16, + ) + b_torch_cpu_list.append(b_cpu) + b_tensor_list.append(b_tensor_i) + b_torch_gpu_list.append(b_torch_gpu_i) + + sfb_cpu, sfb_tensor_i, sfb_torch_gpu_i = create_scale_factor_tensor( + l_size, n, k, sf_vec_size, sf_dtype + ) + sfb_torch_cpu_list.append(sfb_cpu) + sfb_tensor_list.append(sfb_tensor_i) + sfb_torch_gpu_list.append(sfb_torch_gpu_i) + + b_tensor = b_tensor_list + b_torch_cpu = b_torch_cpu_list + b_torch_gpu = b_torch_gpu_list + sfb_tensor = sfb_tensor_list + sfb_torch_cpu = sfb_torch_cpu_list + sfb_torch_gpu = sfb_torch_gpu_list + alpha = alpha_tensor_list + alpha_torch_cpu = alpha_torch_cpu_list + return ( a_tensor, b_tensor, @@ -436,6 +487,13 @@ def verify_reference_result( topK: int, seq_len: int, ) -> torch.Tensor: + if isinstance(b_torch_cpu, list): + b_torch_cpu = torch.cat(b_torch_cpu, dim=2) + if isinstance(sfb_torch_cpu, list): + sfb_torch_cpu = torch.cat(sfb_torch_cpu, dim=2) + if isinstance(alpha_torch_cpu, list): + alpha_torch_cpu = torch.cat(alpha_torch_cpu, dim=0) + gemm_output = torch.empty((1, valid_m, n), dtype=torch.float32) valid_mask = torch.zeros((valid_m,), dtype=torch.bool, device="cuda") ######### gemm calculation ######### @@ -502,6 +560,7 @@ def run( raster_along_m: bool = False, use_blkred: bool = False, use_cupti: bool = False, + b_tensor_l_sizes=None, **kwargs, ): """Prepare A/B/C tensors, launch GPU kernel, and reference checking. 
@@ -559,6 +618,10 @@ def run( # Unpack parameters n, k, l = nkl # noqa: E741 + multi_b_mode = b_tensor_l_sizes is not None + total_l = sum(b_tensor_l_sizes) if multi_b_mode else l + if multi_b_mode and total_l != l: + raise ValueError(f"Sum of b_tensor_l_sizes ({total_l}) must equal L ({l}).") if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") @@ -574,7 +637,7 @@ def run( m_aligned, n, k, - l, + total_l, a_major, b_major, out_major, @@ -623,6 +686,7 @@ def run( mma_tiler_mn, # cta_tile_m permuted_m, seq_len, # Pass seq_len as num_tokens for C tensor shape + b_tensor_l_sizes=b_tensor_l_sizes, ) # Calculate actual tensor_m used (with padding if permuted_m provided) @@ -649,6 +713,7 @@ def run( cluster_shape_mn, use_blkred=use_blkred, raster_along_m=raster_along_m, + b_tensor_l_sizes=b_tensor_l_sizes if multi_b_mode else None, ) # Compute max active clusters on current device @@ -661,29 +726,27 @@ def run( current_stream = cutlass_torch.default_stream() # Compile gemm kernel - compiled_gemm = cute.compile( - gemm, - a_tensor, - b_tensor, - out_tensor, - sfa_tensor, - sfb_tensor, - tile_idx_to_expert_idx, - num_non_exiting_tiles, - tile_idx_to_mn_limit, - alpha, - max_active_clusters, - current_stream, - permuted_idx_to_expanded_idx, - token_final_scales, - options="--opt-level 2", - ) - - # Compute reference result - if not skip_ref_check: - print("Verifying results...") - # Execution - compiled_gemm( + if multi_b_mode: + compiled_gemm = cute.compile( + gemm, + a_tensor, + tuple(b_tensor), + out_tensor, + sfa_tensor, + tuple(sfb_tensor), + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + tuple(alpha), + max_active_clusters, + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + options="--opt-level 2", + ) + else: + compiled_gemm = cute.compile( + gemm, a_tensor, b_tensor, out_tensor, @@ -693,11 +756,48 @@ def run( num_non_exiting_tiles, tile_idx_to_mn_limit, alpha, + 
max_active_clusters, current_stream, permuted_idx_to_expanded_idx, token_final_scales, + options="--opt-level 2", ) + # Compute reference result + if not skip_ref_check: + print("Verifying results...") + # Execution + if multi_b_mode: + compiled_gemm( + a_tensor, + tuple(b_tensor), + out_tensor, + sfa_tensor, + tuple(sfb_tensor), + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + tuple(alpha), + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + ) + else: + compiled_gemm( + a_tensor, + b_tensor, + out_tensor, + sfa_tensor, + sfb_tensor, + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + alpha, + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + ) + torch.cuda.synchronize() ref_result = verify_reference_result( a_torch_cpu, @@ -792,6 +892,7 @@ def generate_tensors(): mma_tiler_mn, # cta_tile_m permuted_m, seq_len, # Pass seq_len as num_tokens for C tensor shape + b_tensor_l_sizes=b_tensor_l_sizes, ) ( @@ -808,6 +909,21 @@ def generate_tensors(): final_scale_dtype, ) + if multi_b_mode: + return cute.testing.JitArguments( + a_tensor, + tuple(b_tensor), + out_tensor, + sfa_tensor, + tuple(sfb_tensor), + tile_idx_to_expert_idx, + num_non_exiting_tiles, + tile_idx_to_mn_limit, + tuple(alpha), + current_stream, + permuted_idx_to_expanded_idx, + token_final_scales, + ) return cute.testing.JitArguments( a_tensor, b_tensor, @@ -825,16 +941,22 @@ def generate_tensors(): workspace_count = 1 if use_cold_l2: + + def _tensor_list_bytes(tensors): + if isinstance(tensors, list): + return sum(t.numel() * t.element_size() for t in tensors) + return tensors.numel() * tensors.element_size() + one_workspace_bytes = ( - a_torch_gpu.numel() * a_torch_gpu.element_size() - + b_torch_gpu.numel() * b_torch_gpu.element_size() - + out_torch_gpu.numel() * out_torch_gpu.element_size() - + sfa_torch_gpu.numel() * sfa_torch_gpu.element_size() - + sfb_torch_gpu.numel() * sfb_torch_gpu.element_size() + 
_tensor_list_bytes(a_torch_gpu) + + _tensor_list_bytes(b_torch_gpu) + + _tensor_list_bytes(out_torch_gpu) + + _tensor_list_bytes(sfa_torch_gpu) + + _tensor_list_bytes(sfb_torch_gpu) + (tensor_m // mma_tiler_mn[0]) * 4 # tile_idx_to_expert_idx length (tiles) * sizeof(int32) + 1 * 4 # num_non_exiting_tiles (1 element) * sizeof(int32) - + alpha_torch_cpu.numel() * alpha_torch_cpu.element_size() + + _tensor_list_bytes(alpha_torch_cpu) ) workspace_count = cute.testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations @@ -861,6 +983,14 @@ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: except ValueError: raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.") + def split_groups_to_b_tensors(num_groups: int, num_b_tensors: int) -> Tuple[int, ...]: + if num_b_tensors <= 0: + raise argparse.ArgumentTypeError("num_b_tensors must be positive.") + base = num_groups // num_b_tensors + remainder = num_groups % num_b_tensors + sizes = [base + (1 if i < remainder else 0) for i in range(num_b_tensors)] + return tuple(sizes) + def read_benchmark_file( filepath: str, ) -> Tuple[Tuple[int, int, int], Tuple[int, ...]]: @@ -1054,6 +1184,19 @@ def parse_benchmark_arg( help="Use CUPTI to measure execution time", ) + parser.add_argument( + "--num_b_tensors", + type=int, + default=1, + help="Number of B tensors for multi-B mode (default: 1).", + ) + parser.add_argument( + "--b_tensor_l_sizes", + type=parse_comma_separated_ints, + default=None, + help="Comma-separated L sizes for each B tensor (e.g., 8,8,16). 
Overrides --num_b_tensors.", + ) + args = parser.parse_args() # Process arguments to generate nkl and group_m_list @@ -1071,6 +1214,17 @@ def parse_benchmark_arg( if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") + _, _, l = nkl # noqa: E741 + b_tensor_l_sizes = None + if args.b_tensor_l_sizes is not None: + b_tensor_l_sizes = args.b_tensor_l_sizes + if args.num_b_tensors != 1 and args.num_b_tensors != len(b_tensor_l_sizes): + parser.error("--num_b_tensors must match length of --b_tensor_l_sizes") + if sum(b_tensor_l_sizes) != l: + parser.error("--b_tensor_l_sizes must sum to L") + elif args.num_b_tensors > 1: + b_tensor_l_sizes = split_groups_to_b_tensors(l, args.num_b_tensors) + exec_time = run( nkl, group_m_list, @@ -1095,6 +1249,7 @@ def parse_benchmark_arg( args.raster_along_m, args.use_blkred, args.use_cupti, + b_tensor_l_sizes=b_tensor_l_sizes, ) print("exec_time: ", exec_time) print("PASS") diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index 91024f5e4b77..3f35261ddc61 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -548,10 +548,10 @@ def test_nvfp4_grouped_gemm_finalize_blackwell( c = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( a, - b, + [b], a_sf, - b_sf, - alpha, + [b_sf], + [alpha], tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, @@ -586,6 +586,35 @@ def test_nvfp4_grouped_gemm_finalize_blackwell( match_ratio = torch.isclose(c, c_ref, rtol=1.6e-2, atol=1e-5).sum().item() / c.numel() assert match_ratio > 0.99 + if num_local_experts > 1: + split_sizes = (num_local_experts // 2, num_local_experts - num_local_experts // 2) + b_list = list(torch.split(b, split_sizes, dim=0)) + b_sf_list = list(torch.split(b_sf, split_sizes, dim=0)) + alpha_list = list(torch.split(alpha, split_sizes, dim=0)) + 
c_multi = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( + a, + b_list, + a_sf, + b_sf_list, + alpha_list, + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, + num_non_exiting_tiles, + token_final_scales, + num_experts=num_experts, + top_k=top_k, + num_local_experts=num_local_experts, + local_expert_offset=0, + tile_size=tile_size, + output_dtype=torch.bfloat16, + scaling_vector_size=sf_vec_size, + ) + multi_match_ratio = ( + torch.isclose(c_multi, c_ref, rtol=1.6e-2, atol=1e-5).sum().item() / c_ref.numel() + ) + assert multi_match_ratio > 0.99 + @pytest.mark.skipif( get_sm_version() not in (100, 103), @@ -839,13 +868,13 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( global_sf = c_ref[:num_valid_permuted_tokens].abs().max().float() / (448 * 6) c_ref, c_sf_ref = torch.ops.trtllm.fp4_quantize(c_ref, 1 / global_sf, sf_vec_size, False) - # Call gather kernel - c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + # Call gather kernel (single-B via multi_b op with single-element lists) + c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( a, - b_interleaved, + [b_interleaved], a_sf_unswizzled, - b_sf_interleaved, - alpha, + [b_sf_interleaved], + [alpha], tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, @@ -891,3 +920,45 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( c_sf_valid = torch.cat(c_sf_valid) c_sf_ref_valid = torch.cat(c_sf_ref_valid) check_accuracy(c_sf_valid, c_sf_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95) + + if num_local_experts > 1: + split_sizes = ( + num_local_experts // 2, + num_local_experts - num_local_experts // 2, + ) + b_interleaved_list = list(torch.split(b_interleaved, split_sizes, dim=0)) + b_sf_interleaved_list = list(torch.split(b_sf_interleaved, split_sizes, dim=0)) + alpha_list = list(torch.split(alpha, split_sizes, dim=0)) + c_multi, c_sf_multi = ( + 
torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + a, + b_interleaved_list, + a_sf_unswizzled, + b_sf_interleaved_list, + alpha_list, + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, + num_non_exiting_tiles, + torch.tensor([1 / global_sf], dtype=torch.float32, device="cuda"), + num_experts=num_experts, + top_k=top_k, + num_local_experts=num_local_experts, + local_expert_offset=0, + tile_size=tile_size, + scaling_vector_size=sf_vec_size, + ) + ) + c_multi_valid = c_multi[:num_valid_permuted_tokens].view(torch.uint8)[valid_token_mask] + check_accuracy(c_multi_valid, c_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95) + + c_sf_multi_unswizzled = unswizzle_sf( + c_sf_multi, max_num_permuted_tokens, interm_size, sf_vec_size + ) + c_sf_multi_valid = [] + for i in range(num_valid_permuted_tokens): + if i >= tile_idx_to_mn_limit_list[i // tile_size]: + continue + c_sf_multi_valid.append(c_sf_multi_unswizzled[i]) + c_sf_multi_valid = torch.cat(c_sf_multi_valid) + check_accuracy(c_sf_multi_valid, c_sf_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 8b2971b5ac2c..3b774c8076c4 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -87,6 +87,10 @@ methods: annotation: Optional[tensorrt_llm.llmapi.llm_args.AttentionDpConfig] default: null status: beta + dwdp_config: + annotation: Optional[tensorrt_llm.llmapi.llm_args.DwdpConfig] + default: null + status: prototype checkpoint_loader: annotation: Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader] default: null