diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
index ee7a2763f7d..1c9695b4e84 100644
--- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -54,7 +54,6 @@ def __init__(self, alpha: float, output_dtype: torch.dtype):
                 f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
             )
 
-    # rewrite the hash function because the value of self.alpha doesn't affect the tactic.
     def unique_id(self):
         return (self.output_dtype, )
 
@@ -531,6 +530,17 @@ def __init__(self,
                 f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
             )
 
+    def unique_id(self):
+        return (
+            self.num_experts,
+            self.top_k,
+            self.num_local_experts,
+            self.local_expert_offset,
+            self.tile_size,
+            self.output_dtype,
+            self.scaling_vector_size,
+        )
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -571,7 +581,7 @@ get_valid_tactics(
         return valid_tactics
 
     def get_tuning_config(self) -> TuningConfig:
-        key = hash(self)
+        key = self.unique_id()
         if key not in self.__class__.tuning_config_cache:
             helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
                                              self.num_local_experts,
@@ -807,6 +817,17 @@ def __init__(self,
                 f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
             )
 
+    def unique_id(self):
+        return (
+            self.num_experts,
+            self.top_k,
+            self.num_local_experts,
+            self.local_expert_offset,
+            self.tile_size,
+            self.output_dtype,
+            self.scaling_vector_size,
+        )
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -847,7 +868,7 @@ get_valid_tactics(
         return valid_tactics
 
     def get_tuning_config(self) -> TuningConfig:
-        key = hash(self)
+        key = self.unique_id()
         if key not in self.__class__.tuning_config_cache:
             helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
                                              self.num_local_experts,
@@ -1124,6 +1145,16 @@ def __init__(self,
                 f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
             )
 
+    def unique_id(self):
+        return (
+            self.num_experts,
+            self.top_k,
+            self.num_local_experts,
+            self.local_expert_offset,
+            self.tile_size,
+            self.scaling_vector_size,
+        )
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -1164,7 +1195,7 @@ get_valid_tactics(
         return valid_tactics
 
     def get_tuning_config(self) -> TuningConfig:
-        key = hash(self)
+        key = self.unique_id()
         if key not in self.__class__.tuning_config_cache:
             helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
                                              self.num_local_experts,
@@ -1443,6 +1474,16 @@ def __init__(self,
         self.output_dtype = output_dtype
         self.scaling_vector_size = scaling_vector_size
 
+    def unique_id(self):
+        return (
+            self.num_experts,
+            self.top_k,
+            self.num_local_experts,
+            self.local_expert_offset,
+            self.output_dtype,
+            self.scaling_vector_size,
+        )
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -1452,7 +1493,7 @@ get_valid_tactics(
         return [128]
 
     def get_tuning_config(self) -> TuningConfig:
-        key = hash(self)
+        key = self.unique_id()
         if key not in self.__class__.tuning_config_cache:
             helper = FusedMoEInputsHelper(self.num_experts, self.top_k,
                                           self.num_local_experts,
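
The diff above keys each class-level `tuning_config_cache` by an explicit `unique_id()` tuple instead of `hash(self)`, so the cache key depends only on the fields that actually influence tactic selection. A minimal sketch of this keying pattern, assuming hypothetical class and field names rather than the actual runner classes in `cute_dsl_custom_ops.py`:

```python
# Sketch of the cache-keying pattern (hypothetical names, not the
# TensorRT-LLM runner classes themselves).
from typing import Dict, Tuple


class ExampleRunner:
    # One cache shared by all instances of the class, keyed by unique_id().
    tuning_config_cache: Dict[Tuple, str] = {}

    def __init__(self, alpha: float, num_experts: int, top_k: int):
        self.alpha = alpha              # does not affect tactic selection
        self.num_experts = num_experts  # does affect tactic selection
        self.top_k = top_k              # does affect tactic selection

    def unique_id(self) -> Tuple:
        # Only fields that influence the set of valid tactics go into the key;
        # alpha is deliberately excluded.
        return (self.num_experts, self.top_k)

    def get_tuning_config(self) -> str:
        key = self.unique_id()
        if key not in self.__class__.tuning_config_cache:
            # Stand-in for building a real TuningConfig.
            self.__class__.tuning_config_cache[key] = f"config{key}"
        return self.__class__.tuning_config_cache[key]


# Two runners that differ only in alpha reuse the same cached config.
a = ExampleRunner(alpha=1.0, num_experts=8, top_k=2)
b = ExampleRunner(alpha=0.5, num_experts=8, top_k=2)
assert a.get_tuning_config() is b.get_tuning_config()
```

Using an explicit tuple also makes the key deterministic across processes, whereas `hash(self)` falls back to identity-based hashing (distinct per instance) unless the class defines `__hash__`.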