From 45df468a9e3e6be3a3df321c64ee2898b9b91232 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:42:12 +0100 Subject: [PATCH] Add generate kwargs to VQA pipeline --- src/transformers/pipelines/visual_question_answering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py index f456835d7090..9106b19d3367 100644 --- a/src/transformers/pipelines/visual_question_answering.py +++ b/src/transformers/pipelines/visual_question_answering.py @@ -123,9 +123,9 @@ def preprocess(self, inputs, padding=False, truncation=False, timeout=None): model_inputs.update(image_features) return model_inputs - def _forward(self, model_inputs): + def _forward(self, model_inputs, **generate_kwargs): if self.model.can_generate(): - model_outputs = self.model.generate(**model_inputs) + model_outputs = self.model.generate(**model_inputs, **generate_kwargs) else: model_outputs = self.model(**model_inputs) return model_outputs