Skip to content

Commit 7780df3

Browse files
committed
Revert "feat(llm): Add custom HTTP headers support to ChatNVIDIA provider (#1461)"
This reverts commit aafd733.
1 parent d2bfaea commit 7780df3

File tree

2 files changed

+5
-506
lines changed

2 files changed

+5
-506
lines changed

nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,21 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import inspect
1716
import logging
1817
from functools import wraps
19-
from typing import Any, Dict, List, Optional
18+
from typing import Any, List, Optional
2019

2120
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
2221
from langchain_core.language_models.chat_models import generate_from_stream
2322
from langchain_core.messages import BaseMessage
2423
from langchain_core.outputs import ChatResult
2524
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
26-
from pydantic import Field
25+
from pydantic.v1 import Field
2726

28-
log = logging.getLogger(__name__) # pragma: no cover
27+
log = logging.getLogger(__name__)
2928

3029

31-
def stream_decorator(func): # pragma: no cover
30+
def stream_decorator(func):
3231
@wraps(func)
3332
def wrapper(
3433
self,
@@ -52,52 +51,10 @@ def wrapper(
5251

5352
# NOTE: this needs to have the same name as the original class,
5453
# otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
55-
class ChatNVIDIA(ChatNVIDIAOriginal): # pragma: no cover
54+
class ChatNVIDIA(ChatNVIDIAOriginal):
5655
streaming: bool = Field(
5756
default=False, description="Whether to use streaming or not"
5857
)
59-
custom_headers: Optional[Dict[str, str]] = Field(
60-
default=None, description="Custom HTTP headers to send with requests"
61-
)
62-
63-
def __init__(self, **kwargs: Any):
64-
super().__init__(**kwargs)
65-
if self.custom_headers:
66-
custom_headers_error = (
67-
"custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
68-
"Your version does not support the required client structure or "
69-
"extra_headers parameter. Please upgrade: "
70-
"pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
71-
)
72-
if not hasattr(self._client, "get_req"):
73-
raise RuntimeError(custom_headers_error)
74-
75-
sig = inspect.signature(self._client.get_req)
76-
if "extra_headers" not in sig.parameters:
77-
raise RuntimeError(custom_headers_error)
78-
79-
self._wrap_client_methods()
80-
81-
def _wrap_client_methods(self):
82-
original_get_req = self._client.get_req
83-
original_get_req_stream = self._client.get_req_stream
84-
85-
def wrapped_get_req(payload: dict = None, extra_headers: dict = None):
86-
payload = payload or {}
87-
extra_headers = extra_headers or {}
88-
merged_headers = {**extra_headers, **self.custom_headers}
89-
return original_get_req(payload=payload, extra_headers=merged_headers)
90-
91-
def wrapped_get_req_stream(payload: dict = None, extra_headers: dict = None):
92-
payload = payload or {}
93-
extra_headers = extra_headers or {}
94-
merged_headers = {**extra_headers, **self.custom_headers}
95-
return original_get_req_stream(
96-
payload=payload, extra_headers=merged_headers
97-
)
98-
99-
object.__setattr__(self._client, "get_req", wrapped_get_req)
100-
object.__setattr__(self._client, "get_req_stream", wrapped_get_req_stream)
10158

10259
@stream_decorator
10360
def _generate(

0 commit comments

Comments
 (0)