Custom chat model for an LLM served on a TensorRT-LLM Triton server #29547

Open
4 tasks done
DeekshithaDPrakash opened this issue Feb 3, 2025 · 0 comments

Discussed in #29369

Originally posted by DeekshithaDPrakash January 23, 2025

Checked other resources

  • I added a very descriptive title to this question.
  • I searched the LangChain documentation with the integrated search.
  • I used the GitHub search to find a similar question and didn't find it.

Commit to Help

  • I commit to help with one of those options 👆

Example Code

import requests
from typing import List, Optional, Dict, Any, Union, Literal
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, AIMessage, ChatResult, ChatGeneration, SystemMessage, HumanMessage
from langchain_core.language_models import LanguageModelInput
from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableMap
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from pydantic import BaseModel, Field
from operator import itemgetter

class TRTLLMChat(BaseChatModel):
    url: str = Field(..., description="URL of the Triton inference server endpoint")
    temperature: float = Field(0.0, description="Sampling temperature")
    max_tokens: int = Field(4096, description="Maximum number of tokens to generate")
    format: Optional[Union[Literal["json"], Dict]] = None

    @property
    def _llm_type(self) -> str:
        return "trt-llm-chat"

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
        prompt = ""
        for message in messages:
            if isinstance(message, SystemMessage):
                prompt += f"System: {message.content}\n"
            elif isinstance(message, HumanMessage):
                prompt += f"Human: {message.content}\n"
            elif isinstance(message, AIMessage):
                prompt += f"Assistant: {message.content}\n"
        return prompt.strip()
    '''
    def _call(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> str:
        prompt = self._convert_messages_to_prompt(messages)
        
        payload = {
            "text_input": prompt,
            "parameters": {
                "temperature": float(self.temperature),
                "max_tokens": int(self.max_tokens)
            }
        }

        if self.format is not None:
            payload["format"] = self.format
            
        if stop and len(stop) > 0:
            payload["parameters"]["stop"] = stop[0]
            
        try:
            response = requests.post(
                self.url,
                json=payload,
                headers={"Content-Type": "application/json"}
            )
            
            if response.status_code != 200:
                raise Exception(f"Error from Triton server: {response.text}")
                
            result = response.json()
            response_text = result["text_output"].strip().lower()
            
            # Handle binary yes/no responses
            if self.format == "json" and response_text in ["yes", "no"]:
                return f'{{"binary_score": "{response_text}"}}'
                
            return result["text_output"]
            
        except Exception as e:
            print(f"Request payload: {payload}")
            raise e
    '''

    def _call(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> str:
        prompt = self._convert_messages_to_prompt(messages)
        
        payload = {
            "text_input": prompt,
            "parameters": {
                "temperature": float(self.temperature),
                "max_tokens": int(self.max_tokens)
            }
        }
    
        if self.format is not None:
            payload["format"] = self.format
            
        if stop and len(stop) > 0:
            payload["parameters"]["stop"] = stop[0]
            
        try:
            response = requests.post(
                self.url,
                json=payload,
                headers={"Content-Type": "application/json"}
            )
            
            if response.status_code != 200:
                raise Exception(f"Error from Triton server: {response.text}")
                
            result = response.json()
            response_text = result["text_output"].strip().lower()
            
            # For binary yes/no responses
            if self.format == "json" and response_text in ["yes", "no"]:
                return f'{{"binary_score": "{response_text}"}}'
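            # Routing keywords; the Korean entries mean: vectorstore, aerospace,
            # satellite, launch vehicle, space, aviation, launch, solar cell, sun, battery.
            elif self.format == "json" and response_text in [
                'not_retrieve', 'vectorstore', '벡터스토어', 'kari', '항공우주',
                '위성', '발사체', '우주', '항공', '발사', '태양전지', '태양', '전지',
            ]:
                return f'{{"datasource": "{response_text}"}}'
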
            return response_text
            
        except Exception:
            # Log the payload for debugging, then re-raise with the original traceback.
            print(f"Request payload: {payload}")
            raise

    def with_structured_output(
        self,
        schema: Union[Dict, type],
        *,
        method: Literal["function_calling", "json_mode", "json_schema"] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")

        if method == "json_mode":
            llm = TRTLLMChat(
                url=self.url,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                format="json"
            )
        elif method == "json_schema":
            if isinstance(schema, type):
                llm = TRTLLMChat(
                    url=self.url,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    format=schema.model_json_schema()
                )
            else:
                llm = TRTLLMChat(
                    url=self.url,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    format=schema
                )
        else:
            llm = self

        output_parser = PydanticOutputParser(pydantic_object=schema) if isinstance(schema, type) else JsonOutputParser()

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser,
                parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        text = self._call(messages, stop)
        message = AIMessage(content=text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

llm = TRTLLMChat(
    url="http://ip:port/v2/models/ensemble/generate",
    temperature=0,
    max_tokens=8096
)

from typing import Literal
from pydantic import BaseModel, Field
from langchain.prompts import ChatPromptTemplate

class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    datasource: Literal["vectorstore", "not_retrieve"] = Field(
        description="Given a user question, choose to route it to a vectorstore or not_retrieve.",
    )

structured_llm_router = llm.with_structured_output(RouteQuery, method="json_mode")

system_prompt = """You are an expert at routing a user question to a vectorstore.
The vectorstore contains documents related to the research and development of NASA, 
including topics such as aircraft, unmanned vehicles, satellites, space launch vehicles, 
satellite imagery, space exploration, and satellite navigation.
Output as "vectorstore" for questions on these topics. If the question is not related, respond with "not_retrieve"."""

route_prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("human", "{question}")
])

question_router = route_prompt | structured_llm_router

result = question_router.invoke({"question": "Tell me about camel"})
print(result)


Error:

OutputParserException: Invalid json output: spider.

output: not_retrieve
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/OUTPUT_PARSING_FAILURE
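
One possible reading of this error, based only on the example code: with method="json_mode" the chain ends in a PydanticOutputParser, which expects a JSON object, while `_call` returns the bare lowercased text whenever the reply is not an exact match for one of the hard-coded strings (a trailing period or an extra word is enough). A minimal sketch of a more forgiving normalization step follows. The helper name `to_datasource_json`, the constant `ALLOWED_DATASOURCES`, and the choice of "not_retrieve" as a fallback are assumptions made for illustration, not part of LangChain or Triton.

```python
import json
import re

# Labels taken from the RouteQuery schema in the example code.
ALLOWED_DATASOURCES = ("vectorstore", "not_retrieve")


def to_datasource_json(raw_text: str, default: str = "not_retrieve") -> str:
    """Return a JSON object string with a single "datasource" key.

    Hypothetical helper: it is meant to replace the exact-match check
    in `_call`, so that the output parser always receives valid JSON.
    """
    text = raw_text.strip()

    # Case 1: the model already produced a JSON object with a valid label.
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict) and parsed.get("datasource") in ALLOWED_DATASOURCES:
        return json.dumps({"datasource": parsed["datasource"]})

    # Case 2: a label appears as a whole token somewhere in free text,
    # e.g. "Vectorstore." or "Answer: not_retrieve". The first label in
    # ALLOWED_DATASOURCES that matches wins.
    lowered = text.lower()
    for label in ALLOWED_DATASOURCES:
        if re.search(rf"(?<![a-z_]){re.escape(label)}(?![a-z_])", lowered):
            return json.dumps({"datasource": label})

    # Case 3: nothing recognizable, so fall back to the assumed default.
    return json.dumps({"datasource": default})


if __name__ == "__main__":
    for reply in ["not_retrieve", " Vectorstore.\n", "spider."]:
        print(repr(reply), "->", to_datasource_json(reply))
```

If this fits the use case, `_call` could return `to_datasource_json(result["text_output"])` when `self.format == "json"`. Falling back silently hides bad model output, so logging the raw reply in the fallback branch may be worth adding.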

Description

I am trying to build a custom chat model so that I can use an LLM served on a Triton Inference Server with LangChain/LangGraph and automate tasks with agents.

System Info

System Information

OS: Linux
OS Version: #137~20.04.1-Ubuntu SMP Fri Nov 15 14:46:54 UTC 2024
Python Version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]

Package Information

langchain_core: 0.3.31
langchain: 0.3.12
langchain_community: 0.3.12
langsmith: 0.2.3
langchain_experimental: 0.3.3
langchain_groq: 0.2.1
langchain_nvidia: Installed. No version info available.
langchain_nvidia_ai_endpoints: 0.3.7
langchain_nvidia_trt: 0.0.1rc0
langchain_ollama: 0.2.1
langchain_openai: 0.2.12
langchain_text_splitters: 0.3.3
langgraph_sdk: 0.1.51

Optional packages not installed

langserve

Other Dependencies

aiohttp: 3.10.5
async-timeout: 4.0.3
dataclasses-json: 0.6.7
groq: 0.13.1
httpx: 0.27.2
httpx-sse: 0.4.0
jsonpatch: 1.33
langsmith-pyo3: Installed. No version info available.
lint: 1.2.1
numpy: 1.26.4
ollama: 0.4.4
openai: 0.28.0
orjson: 3.10.12
packaging: 24.1
pillow: 10.4.0
protobuf: 3.20.3
pydantic: 2.9.2
pydantic-settings: 2.7.0
PyYAML: 6.0.2
requests: 2.31.0
requests-toolbelt: 1.0.0
SQLAlchemy: 2.0.36
tenacity: 9.0.0
tiktoken: 0.8.0
tritonclient[all]: Installed. No version info available.
types-protobuf: 4.25.0.20240417
typing-extensions: 4.12.2
