diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
index e0c50eccc3d..842c48725f8 100644
--- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -19,6 +19,199 @@ except ImportError:
     from cuda import cuda
 
+
+class GroupedGemmInputsHelper:
+
+    def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
+                 local_expert_offset: int, tile_size: int):
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.num_local_experts = num_local_experts
+        self.local_expert_offset = local_expert_offset
+        self.tile_size = tile_size
+        # Padding values should never be accessed.
+        # Intentionally use a large padding value to expose issues early.
+        self.pad_val = int(2e9)
+
+    def get_max_num_tiles(self, num_tokens: int) -> int:
+        num_expanded_tokens = num_tokens * self.top_k
+        if num_expanded_tokens <= self.num_local_experts:
+            return num_expanded_tokens
+        return (num_expanded_tokens +
+                (self.tile_size - 1) * self.num_local_experts) // self.tile_size
+
+    def get_max_num_permuted_tokens(self, num_tokens: int) -> int:
+        return self.get_max_num_tiles(num_tokens) * self.tile_size
+
+    def infer_num_tokens(self, max_num_permuted_tokens: int) -> int:
+        """Infer the maximum possible number of tokens given the max_num_permuted_tokens.
+        """
+        max_num_tiles = max_num_permuted_tokens // self.tile_size
+        if max_num_tiles >= self.num_local_experts:
+            return (max_num_permuted_tokens - (self.tile_size - 1) *
+                    (self.num_local_experts - 1)) // self.top_k
+        return max_num_tiles // self.top_k
+
+    def gen_tuning_buckets(self, max_num_tokens: int) -> List[int]:
+        buckets = get_last_power_of_2_num_tokens_buckets(
+            self.infer_num_tokens(max_num_tokens))
+        return sorted(
+            list(set(self.get_max_num_permuted_tokens(x) for x in buckets)))
+
+    def map_to_tuning_buckets(self, x: int) -> int:
+        return self.get_max_num_permuted_tokens(
+            last_positive_power_of_2(self.infer_num_tokens(x)))
+
+    def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
+        return self.infer_num_tokens(input_shapes[0][0])
+
+    def infer_shape_max_num_tiles(self, input_shapes: List[torch.Size]) -> int:
+        return input_shapes[0][0] // self.tile_size
+
+    def infer_shape_max_num_permuted_tokens(
+            self, input_shapes: List[torch.Size]) -> int:
+        return self.infer_shape_max_num_tiles(input_shapes) * self.tile_size
+
+    def generate_num_tokens_per_expert(self, num_tokens: int) -> List[int]:
+        average_num_tokens_per_expert = num_tokens * self.top_k / self.num_experts
+        balance = 0
+        num_tokens_per_expert = []
+        for i in range(self.num_local_experts):
+            balance += average_num_tokens_per_expert
+            if balance <= 1e-3:
+                continue
+            curr_num_tokens = int(balance) + 1
+            num_tokens_per_expert.append(curr_num_tokens)
+            balance -= curr_num_tokens
+        return num_tokens_per_expert
+
+    def generate_tile_idx_to_group_idx(
+            self, num_tokens_per_expert: List[int]) -> List[int]:
+        tile_idx_to_group_idx = []
+        for i, curr_num_tokens in enumerate(num_tokens_per_expert):
+            curr_num_tiles = (curr_num_tokens + self.tile_size -
+                              1) // self.tile_size
+            tile_idx_to_group_idx.extend([i] * curr_num_tiles)
+        return tile_idx_to_group_idx
+
+    def generate_tile_idx_to_mn_limit(
+            self, num_tokens_per_expert: List[int]) -> List[int]:
+        tile_idx_to_mn_limit = []
+        for i, curr_num_tokens in enumerate(num_tokens_per_expert):
+            curr_num_tiles = (curr_num_tokens + self.tile_size -
+                              1) // self.tile_size
+            prev_mn_limit = len(tile_idx_to_mn_limit) * self.tile_size
+            for j in range(curr_num_tiles):
+                tile_idx_to_mn_limit.append(prev_mn_limit + min(
+                    (j + 1) * self.tile_size, curr_num_tokens))
+        return tile_idx_to_mn_limit
+
+    def generate_permuted_idx_to_expanded_idx(
+            self, num_tokens: int,
+            num_tokens_per_expert: List[int]) -> List[int]:
+        permuted_idx_to_expanded_idx = []
+        colmajor_expanded_idx = 0
+        for i, curr_num_tokens in enumerate(num_tokens_per_expert):
+            curr_num_tiles = (curr_num_tokens + self.tile_size -
+                              1) // self.tile_size
+            for j in range(curr_num_tiles * self.tile_size):
+                if j < curr_num_tokens:
+                    token_idx = colmajor_expanded_idx % num_tokens
+                    topk_idx = colmajor_expanded_idx // num_tokens
+                    expanded_idx = token_idx * self.top_k + topk_idx
+                    permuted_idx_to_expanded_idx.append(expanded_idx)
+                    colmajor_expanded_idx += 1
+                else:
+                    permuted_idx_to_expanded_idx.append(self.pad_val)
+        return permuted_idx_to_expanded_idx
+
+    def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
+        num_tokens = self.infer_num_tokens(a.size(0))
+        num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
+        tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
+            num_tokens_per_expert)
+        num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
+        num_padding_tiles_val = tile_idx_to_group_idx.size(
+            0) - num_non_exiting_tiles_val
+        assert num_non_exiting_tiles_val > 0
+        assert num_padding_tiles_val >= 0
+
+        tile_idx_to_group_idx = torch.tensor(
+            tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
+            dtype=tile_idx_to_group_idx.dtype,
+            device=tile_idx_to_group_idx.device)
+        num_non_exiting_tiles = torch.tensor(
+            [num_non_exiting_tiles_val],
+            dtype=num_non_exiting_tiles.dtype,
+            device=num_non_exiting_tiles.device)
+        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others
+
+    def inputs_pre_hook_finalize_fusion(
+            self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
+        num_tokens = self.infer_num_tokens(a.size(0))
+        num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
+        tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
+            num_tokens_per_expert)
+        tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit(
+            num_tokens_per_expert)
+        permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx(
+            num_tokens, num_tokens_per_expert)
+        num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
+        num_padding_tiles_val = tile_idx_to_group_idx.size(
+            0) - num_non_exiting_tiles_val
+        assert num_non_exiting_tiles_val > 0
+        assert num_padding_tiles_val >= 0
+        assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val
+        assert len(permuted_idx_to_expanded_idx_list
+                   ) == num_non_exiting_tiles_val * self.tile_size
+
+        tile_idx_to_group_idx = torch.tensor(
+            tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
+            dtype=tile_idx_to_group_idx.dtype,
+            device=tile_idx_to_group_idx.device)
+        tile_idx_to_mn_limit = torch.tensor(
+            tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val,
+            dtype=tile_idx_to_mn_limit.dtype,
+            device=tile_idx_to_mn_limit.device)
+        permuted_idx_to_expanded_idx = torch.tensor(
+            permuted_idx_to_expanded_idx_list + [self.pad_val] *
+            (num_padding_tiles_val * self.tile_size),
+            dtype=permuted_idx_to_expanded_idx.dtype,
+            device=permuted_idx_to_expanded_idx.device)
+        num_non_exiting_tiles = torch.tensor(
+            [num_non_exiting_tiles_val],
+            dtype=num_non_exiting_tiles.dtype,
+            device=num_non_exiting_tiles.device)
+        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales
+
+
+class FusedMoEInputsHelper:
+
+    def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
+                 local_expert_offset: int):
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.num_local_experts = num_local_experts
+        self.local_expert_offset = local_expert_offset
+
+    def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
+        return input_shapes[0][0]
+
+    def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        x, x_sf, token_selected_experts, token_final_scales, *others = inputs
+        num_tokens = token_selected_experts.size(0)
+        new_token_final_scales, new_token_selected_experts = torch.randn(
+            num_tokens, self.num_experts,
+            device=token_selected_experts.device).topk(self.top_k, dim=-1)
+        new_token_selected_experts = new_token_selected_experts.to(
+            token_selected_experts.dtype)
+        new_token_final_scales = new_token_final_scales.softmax(dim=-1).to(
+            token_final_scales.dtype)
+        return x, x_sf, new_token_selected_experts, new_token_final_scales, *others
+
+
 if IS_CUTLASS_DSL_AVAILABLE:
     import cutlass
@@ -443,180 +636,6 @@ def _(
         ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
         return ret
 
-    class GroupedGemmInputsHelper:
-
-        def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
-                     local_expert_offset: int, tile_size: int):
-            self.num_experts = num_experts
-            self.top_k = top_k
-            self.num_local_experts = num_local_experts
-            self.local_expert_offset = local_expert_offset
-            self.tile_size = tile_size
-            # Padding values should never be accessed.
-            # Intentionally use a large padding value to expose issues early.
-            self.pad_val = int(2e9)
-
-        def get_max_num_tiles(self, num_tokens: int) -> int:
-            num_expanded_tokens = num_tokens * self.top_k
-            if num_expanded_tokens <= self.num_local_experts:
-                return num_expanded_tokens
-            return (
-                num_expanded_tokens +
-                (self.tile_size - 1) * self.num_local_experts) // self.tile_size
-
-        def get_max_num_permuted_tokens(self, num_tokens: int) -> int:
-            return self.get_max_num_tiles(num_tokens) * self.tile_size
-
-        def infer_num_tokens(self, max_num_permuted_tokens: int) -> int:
-            """Infer the maximum possible number of tokens given the max_num_permuted_tokens.
- """ - max_num_tiles = max_num_permuted_tokens // self.tile_size - if max_num_tiles >= self.num_local_experts: - return (max_num_permuted_tokens - (self.tile_size - 1) * - (self.num_local_experts - 1)) // self.top_k - return max_num_tiles // self.top_k - - def gen_tuning_buckets(self, max_num_tokens: int) -> List[int]: - buckets = get_last_power_of_2_num_tokens_buckets( - self.infer_num_tokens(max_num_tokens)) - return sorted( - list(set(self.get_max_num_permuted_tokens(x) for x in buckets))) - - def map_to_tuning_buckets(self, x: int) -> int: - return self.get_max_num_permuted_tokens( - last_positive_power_of_2(self.infer_num_tokens(x))) - - def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int: - return self.infer_num_tokens(input_shapes[0][0]) - - def infer_shape_max_num_tiles(self, - input_shapes: List[torch.Size]) -> int: - return input_shapes[0][0] // self.tile_size - - def infer_shape_max_num_permuted_tokens( - self, input_shapes: List[torch.Size]) -> int: - return self.infer_shape_max_num_tiles(input_shapes) * self.tile_size - - def generate_num_tokens_per_expert(self, num_tokens: int) -> List[int]: - average_num_tokens_per_expert = num_tokens * self.top_k / self.num_experts - balance = 0 - num_tokens_per_expert = [] - for i in range(self.num_local_experts): - balance += average_num_tokens_per_expert - if balance <= 1e-3: - continue - curr_num_tokens = int(balance) + 1 - num_tokens_per_expert.append(curr_num_tokens) - balance -= curr_num_tokens - return num_tokens_per_expert - - def generate_tile_idx_to_group_idx( - self, num_tokens_per_expert: List[int]) -> List[int]: - tile_idx_to_group_idx = [] - for i, curr_num_tokens in enumerate(num_tokens_per_expert): - curr_num_tiles = (curr_num_tokens + self.tile_size - - 1) // self.tile_size - tile_idx_to_group_idx.extend([i] * curr_num_tiles) - return tile_idx_to_group_idx - - def generate_tile_idx_to_mn_limit( - self, num_tokens_per_expert: List[int]) -> List[int]: - tile_idx_to_mn_limit = [] - for i, curr_num_tokens in enumerate(num_tokens_per_expert): - curr_num_tiles = (curr_num_tokens + self.tile_size - - 1) // self.tile_size - prev_mn_limit = len(tile_idx_to_mn_limit) * self.tile_size - for j in range(curr_num_tiles): - tile_idx_to_mn_limit.append(prev_mn_limit + min( - (j + 1) * self.tile_size, curr_num_tokens)) - return tile_idx_to_mn_limit - - def generate_permuted_idx_to_expanded_idx( - self, num_tokens: int, - num_tokens_per_expert: List[int]) -> List[int]: - permuted_idx_to_expanded_idx = [] - colmajor_expanded_idx = 0 - for i, curr_num_tokens in enumerate(num_tokens_per_expert): - curr_num_tiles = (curr_num_tokens + self.tile_size - - 1) // self.tile_size - for j in range(curr_num_tiles * self.tile_size): - if j < curr_num_tokens: - token_idx = colmajor_expanded_idx % num_tokens - topk_idx = colmajor_expanded_idx // num_tokens - expanded_idx = token_idx * self.top_k + topk_idx - permuted_idx_to_expanded_idx.append(expanded_idx) - colmajor_expanded_idx += 1 - else: - permuted_idx_to_expanded_idx.append(self.pad_val) - return permuted_idx_to_expanded_idx - - def inputs_pre_hook(self, - inputs: List[torch.Tensor]) -> List[torch.Tensor]: - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs - num_tokens = self.infer_num_tokens(a.size(0)) - num_tokens_per_expert = self.generate_num_tokens_per_expert( - num_tokens) - tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx( - num_tokens_per_expert) - num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list) - 
-            num_padding_tiles_val = tile_idx_to_group_idx.size(
-                0) - num_non_exiting_tiles_val
-            assert num_non_exiting_tiles_val > 0
-            assert num_padding_tiles_val >= 0
-
-            tile_idx_to_group_idx = torch.tensor(
-                tile_idx_to_group_idx_list +
-                [self.pad_val] * num_padding_tiles_val,
-                dtype=tile_idx_to_group_idx.dtype,
-                device=tile_idx_to_group_idx.device)
-            num_non_exiting_tiles = torch.tensor(
-                [num_non_exiting_tiles_val],
-                dtype=num_non_exiting_tiles.dtype,
-                device=num_non_exiting_tiles.device)
-            return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others
-
-        def inputs_pre_hook_finalize_fusion(
-                self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-            a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
-            num_tokens = self.infer_num_tokens(a.size(0))
-            num_tokens_per_expert = self.generate_num_tokens_per_expert(
-                num_tokens)
-            tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
-                num_tokens_per_expert)
-            tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit(
-                num_tokens_per_expert)
-            permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx(
-                num_tokens, num_tokens_per_expert)
-            num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
-            num_padding_tiles_val = tile_idx_to_group_idx.size(
-                0) - num_non_exiting_tiles_val
-            assert num_non_exiting_tiles_val > 0
-            assert num_padding_tiles_val >= 0
-            assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val
-            assert len(permuted_idx_to_expanded_idx_list
-                       ) == num_non_exiting_tiles_val * self.tile_size
-
-            tile_idx_to_group_idx = torch.tensor(
-                tile_idx_to_group_idx_list +
-                [self.pad_val] * num_padding_tiles_val,
-                dtype=tile_idx_to_group_idx.dtype,
-                device=tile_idx_to_group_idx.device)
-            tile_idx_to_mn_limit = torch.tensor(
-                tile_idx_to_mn_limit_list +
-                [self.pad_val] * num_padding_tiles_val,
-                dtype=tile_idx_to_mn_limit.dtype,
-                device=tile_idx_to_mn_limit.device)
-            permuted_idx_to_expanded_idx = torch.tensor(
-                permuted_idx_to_expanded_idx_list + [self.pad_val] *
-                (num_padding_tiles_val * self.tile_size),
-                dtype=permuted_idx_to_expanded_idx.dtype,
-                device=permuted_idx_to_expanded_idx.device)
-            num_non_exiting_tiles = torch.tensor(
-                [num_non_exiting_tiles_val],
-                dtype=num_non_exiting_tiles.dtype,
-                device=num_non_exiting_tiles.device)
-            return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales
-
     class Sm100BlockScaledContiguousGroupedGemmRunner(TunableRunner):
         kernel_class = Sm100BlockScaledContiguousGroupedGemmKernel
         kernel_cache = dict()
@@ -1544,32 +1563,6 @@ def _(
             device=input_scale.device)
         return output, output_scale
 
-    class FusedMoEInputsHelper:
-
-        def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
-                     local_expert_offset: int):
-            self.num_experts = num_experts
-            self.top_k = top_k
-            self.num_local_experts = num_local_experts
-            self.local_expert_offset = local_expert_offset
-
-        def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
-            return input_shapes[0][0]
-
-        def inputs_pre_hook(self,
-                            inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-            x, x_sf, token_selected_experts, token_final_scales, *others = inputs
-            num_tokens = token_selected_experts.size(0)
-            new_token_final_scales, new_token_selected_experts = torch.randn(
-                num_tokens,
-                self.num_experts,
-                device=token_selected_experts.device).topk(self.top_k,
-                                                           dim=-1)
-            new_token_selected_experts = new_token_selected_experts.to(
-                token_selected_experts.dtype)
-            new_token_final_scales = new_token_final_scales.softmax(dim=-1).to(
-                token_final_scales.dtype)
-            return x, x_sf, new_token_selected_experts, new_token_final_scales, *others
-
     class Sm100BlockScaledFusedMoERunner(TunableRunner):
         tuning_config_cache = dict()
diff --git a/tensorrt_llm/commands/eval.py b/tensorrt_llm/commands/eval.py
index 2ac2c95bd84..d849a7c91a4 100644
--- a/tensorrt_llm/commands/eval.py
+++ b/tensorrt_llm/commands/eval.py
@@ -117,10 +117,6 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
          revision: Optional[str], extra_llm_api_options: Optional[str],
          disable_kv_cache_reuse: bool):
     logger.set_level(log_level)
-    build_config = BuildConfig(max_batch_size=max_batch_size,
-                               max_num_tokens=max_num_tokens,
-                               max_beam_width=max_beam_width,
-                               max_seq_len=max_seq_len)
 
     kv_cache_config = KvCacheConfig(
         free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
@@ -135,7 +131,6 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
         "gpus_per_node": gpus_per_node,
         "trust_remote_code": trust_remote_code,
         "revision": revision,
-        "build_config": build_config,
         "kv_cache_config": kv_cache_config,
     }
 
@@ -145,10 +140,17 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
 
     profiler.start("trtllm init")
     if backend == 'pytorch':
-        llm_args.pop("build_config", None)
-        llm = PyTorchLLM(**llm_args)
+        llm = PyTorchLLM(**llm_args,
+                         max_batch_size=max_batch_size,
+                         max_num_tokens=max_num_tokens,
+                         max_beam_width=max_beam_width,
+                         max_seq_len=max_seq_len)
     elif backend == 'tensorrt':
-        llm = LLM(**llm_args)
+        build_config = BuildConfig(max_batch_size=max_batch_size,
+                                   max_num_tokens=max_num_tokens,
+                                   max_beam_width=max_beam_width,
+                                   max_seq_len=max_seq_len)
+        llm = LLM(**llm_args, build_config=build_config)
     else:
         raise click.BadParameter(
            f"{backend} is not a known backend, check help for available options.",
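
A minimal, standalone sketch of the tile bookkeeping that the relocated GroupedGemmInputsHelper encodes. It is illustrative only and not part of the patch; the MoE configuration, the token counts, and the helper function ceil_div below are made-up assumptions chosen for demonstration, and the sizing formula mirrors the general branch of get_max_num_tiles.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Hypothetical toy configuration, for demonstration only.
top_k, num_local_experts, tile_size = 2, 4, 128
num_tokens = 300

# Worst-case sizing: every local expert can contribute at most
# (tile_size - 1) padding slots to the permuted buffer.
num_expanded_tokens = num_tokens * top_k
max_num_tiles = (num_expanded_tokens +
                 (tile_size - 1) * num_local_experts) // tile_size
max_num_permuted_tokens = max_num_tiles * tile_size
print(max_num_tiles, max_num_permuted_tokens)  # 8 1024

# With perfectly balanced routing (150 tokens per local expert), each expert
# needs ceil(150 / 128) = 2 tiles, so the tile -> expert-group map consumed by
# the contiguous grouped GEMM looks like this:
num_tokens_per_expert = [150, 150, 150, 150]
tile_idx_to_group_idx = []
for expert, n in enumerate(num_tokens_per_expert):
    tile_idx_to_group_idx.extend([expert] * ceil_div(n, tile_size))
print(tile_idx_to_group_idx)  # [0, 0, 1, 1, 2, 2, 3, 3]
assert len(tile_idx_to_group_idx) <= max_num_tiles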