Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add proxy configuration for Cohere model #4152

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions api/core/model_runtime/model_providers/cohere/cohere.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ provider_credential_schema:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
show_on: [ ]
- variable: base_url
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
model_credential_schema:
model:
label:
Expand Down Expand Up @@ -70,3 +79,12 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: base_url
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
9 changes: 5 additions & 4 deletions api/core/model_runtime/model_providers/cohere/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _generate(self, model: str, credentials: dict,
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))

if stop:
model_parameters['end_sequences'] = stop
Expand Down Expand Up @@ -233,7 +233,8 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen

return response

def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
def _handle_generate_stream_response(self, model: str, credentials: dict,
response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
Expand Down Expand Up @@ -317,7 +318,7 @@ def _chat_generate(self, model: str, credentials: dict,
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))

if stop:
model_parameters['stop_sequences'] = stop
Expand Down Expand Up @@ -636,7 +637,7 @@ def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> i
:return: number of tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))

response = client.tokenize(
text=text,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _invoke(self, model: str, credentials: dict,
)

# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.rerank(
query=query,
documents=docs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
return []

# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))

response = client.tokenize(
text=text,
Expand Down Expand Up @@ -180,7 +180,7 @@ def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) ->
:return: embeddings and used tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))

# call embedding model
response = client.embed(
Expand Down