Skip to content

Commit

Permalink
community: fix some features on Naver ChatModel & embedding model (#2…
Browse files Browse the repository at this point in the history
…8228)

# Description

- adding stopReason to response_metadata to call stream and astream
- excluding NCP_APIGW_API_KEY input required validation
- to remove warning Field "model_name" has conflict with protected
namespace "model_".

cc. @vbarda
  • Loading branch information
hyper-clova authored Nov 20, 2024
1 parent 4da3562 commit 218b4e0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 25 deletions.
37 changes: 21 additions & 16 deletions libs/community/langchain_community/chat_models/naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_env
from pydantic import AliasChoices, Field, SecretStr, model_validator
from pydantic import AliasChoices, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

_DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com"
Expand All @@ -51,6 +51,12 @@ def _convert_chunk_to_message_chunk(
role = message.get("role")
content = message.get("content") or ""

if sse.event == "result":
response_metadata = {}
if "stopReason" in sse_data:
response_metadata["stopReason"] = sse_data["stopReason"]
return AIMessageChunk(content="", response_metadata=response_metadata)

if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
Expand Down Expand Up @@ -124,8 +130,6 @@ async def _aiter_sse(
event_data = sse.json()
if sse.event == "signal" and event_data.get("data", {}) == "[DONE]":
return
if sse.event == "result":
return
yield sse


Expand Down Expand Up @@ -210,10 +214,7 @@ class ChatClovaX(BaseChatModel):
timeout: int = Field(gt=0, default=90)
max_retries: int = Field(ge=1, default=2)

class Config:
"""Configuration for this pydantic object."""

populate_by_name = True
model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

@property
def _default_params(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -286,7 +287,7 @@ def validate_model_after(self) -> Self:

if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
)

if not self.base_url:
Expand All @@ -311,22 +312,28 @@ def validate_model_after(self) -> Self:
return self

def default_headers(self) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}

clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value()
if self.ncp_clovastudio_api_key
else None
)
if clovastudio_api_key:
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key

apigw_api_key = (
self.ncp_apigw_api_key.get_secret_value()
if self.ncp_apigw_api_key
else None
)
return {
"Content-Type": "application/json",
"Accept": "application/json",
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
"X-NCP-APIGW-API-KEY": apigw_api_key,
}
if apigw_api_key:
headers["X-NCP-APIGW-API-KEY"] = apigw_api_key

return headers

def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
Expand Down Expand Up @@ -363,8 +370,6 @@ def iter_sse() -> Iterator[ServerSentEvent]:
and event_data.get("data", {}) == "[DONE]"
):
return
if sse.event == "result":
return
if sse.event == "error":
raise SSEError(message=sse.data)
yield sse
Expand Down
24 changes: 15 additions & 9 deletions libs/community/langchain_community/embeddings/naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
Expand Down Expand Up @@ -86,8 +87,7 @@ class ClovaXEmbeddings(BaseModel, Embeddings):

timeout: int = Field(gt=0, default=60)

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

@property
def lc_secrets(self) -> Dict[str, str]:
Expand Down Expand Up @@ -115,7 +115,7 @@ def validate_model_after(self) -> Self:

if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
)

if not self.base_url:
Expand Down Expand Up @@ -143,22 +143,28 @@ def validate_model_after(self) -> Self:
return self

def default_headers(self) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}

clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value()
if self.ncp_clovastudio_api_key
else None
)
if clovastudio_api_key:
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key

apigw_api_key = (
self.ncp_apigw_api_key.get_secret_value()
if self.ncp_apigw_api_key
else None
)
return {
"Content-Type": "application/json",
"Accept": "application/json",
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
"X-NCP-APIGW-API-KEY": apigw_api_key,
}
if apigw_api_key:
headers["X-NCP-APIGW-API-KEY"] = apigw_api_key

return headers

def _embed_text(self, text: str) -> List[float]:
payload = {"text": text}
Expand Down

0 comments on commit 218b4e0

Please sign in to comment.