Patch for Gaudi Text-Generation Pipeline#690
Conversation
|
Hi @regisss. As per our conversation this morning, I have changed the output format of the pipeline class in order to make it work with langchain==0.0.191. I would appreciate it if you could try running the Will work on adding some langchain examples to the blog next if these changes look good. Thanks! |
regisss
left a comment
There was a problem hiding this comment.
Maybe we can test the type of the output and keep the former way if it is a string (I guess it is, maybe the type is different) else return the new format. WDYT?
I tried the former output type (string) but faced some issues with langchain==0.0.191. It looks like langchain expects the output to be a dictionary wrapped in a list as shown here: https://github.com/langchain-ai/langchain/blob/b3ae6bcd3f42ec85ee65eb29c922ab22a17a0210/langchain/llms/huggingface_pipeline.py#L169 |
|
Yep I understand. What I propose is to have something like: if isinstance(output, str):
# return the same as before
else:
# return LangChain formatWould that work for you? |
Where do you suggest adding this check? One way would be adding an input argument to the pipeline constructor (use_with_langchain=False) and returning output in langchain format if it's True: class GaudiTextGenerationPipeline(TextGenerationPipeline):
def __init__(self, args, logger, use_with_langchain=False):
self.model, self.tokenizer, self.generation_config = initialize_model(args, logger)
self.task = "text-generation"
self.device = args.device
if args.do_sample:
self.generation_config.temperature = args.temperature
self.generation_config.top_p = args.top_p
self.max_padding_length = args.max_input_tokens if args.max_input_tokens > 0 else 100
self.use_hpu_graphs = args.use_hpu_graphs
self.profiling_steps = args.profiling_steps
self.profiling_warmup_steps = args.profiling_warmup_steps
self.use_with_langchain = use_with_langchain
if self.use_with_langchain:
self.generation_config.ignore_eos = False
import habana_frameworks.torch.hpu as torch_hpu
logger.info("Graph compilation...")
for _ in range(3):
self("Here is my prompt")
torch_hpu.synchronize()
def __call__(self, prompt: str):
model_inputs = self.tokenizer.encode_plus(
prompt, return_tensors="pt", max_length=self.max_padding_length, padding="max_length", truncation=True
)
for t in model_inputs:
if torch.is_tensor(model_inputs[t]):
model_inputs[t] = model_inputs[t].to(self.device)
output = self.model.generate(
**model_inputs,
generation_config=self.generation_config,
lazy_mode=True,
hpu_graphs=self.use_hpu_graphs,
profiling_steps=self.profiling_steps,
profiling_warmup_steps=self.profiling_warmup_steps,
).cpu()
output_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
if self.use_with_langchain:
return [{"generated_text": output_text}]
return output_textWDYT? |
|
@sjagtap1803 That looks good to me 👍 |
regisss
left a comment
There was a problem hiding this comment.
Could you also run the following from the root of the repo please?
pip install -U ruff
make style
Besides, do you already know which version of LangChain should be used? We should specify it in the README.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
I applied code formatting by running Regarding the LangChain version, I will test a few examples with 0.0.191 later today and update the blog accordingly. If the examples run as expected, I will specify the version in the README. |
What does this PR do?
This PR includes some minor changes to the text-generation pipeline code for langchain==0.0.191 compatibility.
Before submitting