1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import inspect
1716import logging
1817from functools import wraps
19- from typing import Any , Dict , List , Optional
18+ from typing import Any , List , Optional
2019
2120from langchain_core .callbacks .manager import CallbackManagerForLLMRun
2221from langchain_core .language_models .chat_models import generate_from_stream
2322from langchain_core .messages import BaseMessage
2423from langchain_core .outputs import ChatResult
2524from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
26- from pydantic import Field
25+ from pydantic . v1 import Field
2726
28- log = logging .getLogger (__name__ ) # pragma: no cover
27+ log = logging .getLogger (__name__ )
2928
3029
31- def stream_decorator (func ): # pragma: no cover
30+ def stream_decorator (func ):
3231 @wraps (func )
3332 def wrapper (
3433 self ,
@@ -52,52 +51,10 @@ def wrapper(
5251
5352# NOTE: this needs to have the same name as the original class,
5453# otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
55- class ChatNVIDIA (ChatNVIDIAOriginal ): # pragma: no cover
54+ class ChatNVIDIA (ChatNVIDIAOriginal ):
5655 streaming : bool = Field (
5756 default = False , description = "Whether to use streaming or not"
5857 )
59- custom_headers : Optional [Dict [str , str ]] = Field (
60- default = None , description = "Custom HTTP headers to send with requests"
61- )
62-
63- def __init__ (self , ** kwargs : Any ):
64- super ().__init__ (** kwargs )
65- if self .custom_headers :
66- custom_headers_error = (
67- "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
68- "Your version does not support the required client structure or "
69- "extra_headers parameter. Please upgrade: "
70- "pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
71- )
72- if not hasattr (self ._client , "get_req" ):
73- raise RuntimeError (custom_headers_error )
74-
75- sig = inspect .signature (self ._client .get_req )
76- if "extra_headers" not in sig .parameters :
77- raise RuntimeError (custom_headers_error )
78-
79- self ._wrap_client_methods ()
80-
81- def _wrap_client_methods (self ):
82- original_get_req = self ._client .get_req
83- original_get_req_stream = self ._client .get_req_stream
84-
85- def wrapped_get_req (payload : dict = None , extra_headers : dict = None ):
86- payload = payload or {}
87- extra_headers = extra_headers or {}
88- merged_headers = {** extra_headers , ** self .custom_headers }
89- return original_get_req (payload = payload , extra_headers = merged_headers )
90-
91- def wrapped_get_req_stream (payload : dict = None , extra_headers : dict = None ):
92- payload = payload or {}
93- extra_headers = extra_headers or {}
94- merged_headers = {** extra_headers , ** self .custom_headers }
95- return original_get_req_stream (
96- payload = payload , extra_headers = merged_headers
97- )
98-
99- object .__setattr__ (self ._client , "get_req" , wrapped_get_req )
100- object .__setattr__ (self ._client , "get_req_stream" , wrapped_get_req_stream )
10158
10259 @stream_decorator
10360 def _generate (
0 commit comments