diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 3ae5a22e8810..66fa10f14774 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -346,12 +346,7 @@ class StaticGraphPredictor(BasePredictor): def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None): super().__init__(config, tokenizer) - params_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".pdiparams") - if paddle.framework.use_pir_api(): - model_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".json") - else: - model_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".pdmodel") - inference_config = paddle.inference.Config(model_path, params_path) + inference_config = paddle.inference.Config(self.config.model_name_or_path, self.config.model_prefix) if self.config.device == "gpu": # set GPU configs accordingly