Commit d1ac25d

add json mode and codestral support
fsndzomga committed May 31, 2024
1 parent 0e531fe

Showing 1 changed file with 29 additions and 5 deletions.
lib/mistral_rb.rb: 34 changes (29 additions & 5 deletions)

@@ -20,7 +20,7 @@ def initialize(api_key: ENV["MISTRAL_API_KEY"], base_uri: "https://api.mistral.a
     self.class.base_uri base_uri
   end

-  def create_chat_completion(model:, messages:, temperature: 0.7, top_p: 1, max_tokens: nil, stream: false, safe_prompt: false, random_seed: nil, tools: nil, tool_choice: nil)
+  def create_chat_completion(model:, messages:, temperature: 0.7, top_p: 1, max_tokens: nil, stream: false, safe_prompt: false, random_seed: nil, tools: nil, tool_choice: nil, response_format: nil)
     body = {
       model: model,
       messages: messages,
@@ -31,12 +31,10 @@ def create_chat_completion(model:, messages:, temperature: 0.7, top_p: 1, max_to
       safe_prompt: safe_prompt,
       random_seed: random_seed,
       tools: tools,
-      tool_choice: tool_choice
+      tool_choice: tool_choice,
+      response_format: response_format
     }.compact.to_json

-    # Debugging: print the request body
-    puts "Request Body: #{body}"
-
     if stream
       # Use on_data callback for streaming
       self.class.post("/chat/completions", body: body, headers: @headers, stream_body: true) do |fragment, _, _|
@@ -51,6 +49,32 @@ def create_chat_completion(model:, messages:, temperature: 0.7, top_p: 1, max_to
       end
     end

+  def create_fim_completion(prompt:, suffix:, model:, temperature: 0.7, top_p: 1, max_tokens: nil, min_tokens: 0, stream: false, random_seed: nil, stop: [])
+    body = {
+      prompt: prompt,
+      suffix: suffix,
+      model: model,
+      temperature: temperature,
+      top_p: top_p,
+      max_tokens: max_tokens,
+      min_tokens: min_tokens,
+      stream: stream,
+      random_seed: random_seed,
+      stop: stop
+    }.compact.to_json
+
+    if stream
+      self.class.post("/fim/completions", body: body, headers: @headers, stream_body: true) do |fragment, _, _|
+        processed_chunk = handle_stream_chunk(fragment)
+        yield(processed_chunk) if block_given? && processed_chunk
+      end
+    else
+      response = self.class.post("/fim/completions", body: body, headers: @headers)
+      parsed_response = handle_response(response)
+      MistralModels::CompletionResponse.new(parsed_response)
+    end
+  end
+
   def create_embeddings(model:, input:, encoding_format: "float")
     body = {
       model: model,
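Taken together, the chat changes wire JSON mode through: the new response_format keyword is merged into the request body and, being nil by default, is dropped by .compact when unused. A minimal usage sketch, assuming the client class is named MistralAPI (hypothetical here; use whatever class lib/mistral_rb.rb actually defines) and response_format: { type: "json_object" }, the value Mistral's chat endpoint documents for JSON mode:

  require "mistral_rb"
  require "json"

  client = MistralAPI.new(api_key: ENV["MISTRAL_API_KEY"])

  # Request strictly-JSON output; response_format is forwarded as-is
  # in the POST body to /chat/completions.
  response = client.create_chat_completion(
    model: "mistral-small-latest",
    messages: [{ role: "user", content: "List three Ruby web frameworks as JSON." }],
    response_format: { type: "json_object" }
  )

  # Assumes CompletionResponse mirrors the API's JSON shape
  # (choices -> message -> content).
  puts JSON.parse(response.choices.first.message.content)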

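The new create_fim_completion method is what backs Codestral: /fim/completions is Mistral's fill-in-the-middle endpoint, where prompt is the code before the cursor, suffix is the code after it, and the model generates the span in between. A sketch under the same assumptions (codestral-latest is the model name Mistral documents for FIM):

  fim = client.create_fim_completion(
    model: "codestral-latest",
    prompt: "def fibonacci(n)\n",
    suffix: "\nend",
    max_tokens: 64
  )

  # Assumes the FIM response is chat-shaped, which is why the method
  # reuses MistralModels::CompletionResponse.
  puts fim.choices.first.message.content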
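Both methods also stream: with stream: true, the HTTParty call sets stream_body and each response fragment is passed through handle_stream_chunk before being yielded, with nil chunks filtered out by the block_given? && processed_chunk guard. A sketch that just prints whatever object handle_stream_chunk yields (check MistralModels for its exact shape):

  client.create_fim_completion(
    model: "codestral-latest",
    prompt: "def fibonacci(n)\n",
    suffix: "\nend",
    stream: true
  ) do |chunk|
    # chunk is the processed stream fragment, not the raw SSE text.
    print chunk
  end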