diff --git a/examples/text-generation/text-generation-pipeline/pipeline.py b/examples/text-generation/text-generation-pipeline/pipeline.py index 15cb96a3d4..6361e25fa2 100644 --- a/examples/text-generation/text-generation-pipeline/pipeline.py +++ b/examples/text-generation/text-generation-pipeline/pipeline.py @@ -58,14 +58,21 @@ def __call__(self, prompt): if torch.is_tensor(model_inputs[t]): model_inputs[t] = model_inputs[t].to(self.device) + from optimum.habana.utils import HabanaProfile + + profiler = HabanaProfile( + warmup=self.profiling_warmup_steps, + active=self.profiling_steps, + record_shapes=self.profiling_record_shapes, + name="generate", + ) + output = self.model.generate( **model_inputs, generation_config=self.generation_config, lazy_mode=True, hpu_graphs=self.use_hpu_graphs, - profiling_steps=self.profiling_steps, - profiling_warmup_steps=self.profiling_warmup_steps, - profiling_record_shapes=self.profiling_record_shapes, + profiler=profiler, ).cpu() if use_batch: