Skip to content

Commit

Permalink
fix: parameters aren't encoded (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhermawan authored Aug 6, 2024
1 parent 46a2cf8 commit a92f36a
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 12 deletions.
25 changes: 20 additions & 5 deletions Playground/OKPlayground/Views/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,34 @@ struct ChatView: View {
@Environment(ViewModel.self) private var viewModel

@State private var model: String? = nil
@State private var temperature: Double = 0.5
@State private var prompt = ""
@State private var response = ""
@State private var cancellables = Set<AnyCancellable>()

var body: some View {
NavigationStack {
Form {
Section {
Picker("Model", selection: $model) {
Section("Model") {
Picker("Selected Model", selection: $model) {
ForEach(viewModel.models, id: \.self) { model in
Text(model)
.tag(model as String?)
}
}

}

Section("Temperature") {
Slider(value: $temperature, in: 0...1, step: 0.1) {
Text("Temperature")
} minimumValueLabel: {
Text("0")
} maximumValueLabel: {
Text("1")
}
}

Section("Prompt") {
TextField("Prompt", text: $prompt)
}

Expand All @@ -53,7 +66,8 @@ struct ChatView: View {

guard let model = model else { return }
let messages = [OKChatRequestData.Message(role: .user, content: prompt)]
let data = OKChatRequestData(model: model, messages: messages)
var data = OKChatRequestData(model: model, messages: messages)
data.options = OKCompletionOptions(temperature: temperature)

Task {
for try await chunk in viewModel.ollamaKit.chat(data: data) {
Expand All @@ -67,7 +81,8 @@ struct ChatView: View {

guard let model = model else { return }
let messages = [OKChatRequestData.Message(role: .user, content: prompt)]
let data = OKChatRequestData(model: model, messages: messages)
var data = OKChatRequestData(model: model, messages: messages)
data.options = OKCompletionOptions(temperature: temperature)

viewModel.ollamaKit.chat(data: data)
.sink { completion in
Expand Down
25 changes: 20 additions & 5 deletions Playground/OKPlayground/Views/GenerateView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,34 @@ struct GenerateView: View {
@Environment(ViewModel.self) private var viewModel

@State private var model: String? = nil
@State private var temperature: Double = 0.5
@State private var prompt = ""
@State private var response = ""
@State private var cancellables = Set<AnyCancellable>()

var body: some View {
NavigationStack {
Form {
Section {
Picker("Model", selection: $model) {
Section("Model") {
Picker("Selected Model", selection: $model) {
ForEach(viewModel.models, id: \.self) { model in
Text(model)
.tag(model as String?)
}
}

}

Section("Temperature") {
Slider(value: $temperature, in: 0...1, step: 0.1) {
Text("Temperature")
} minimumValueLabel: {
Text("0")
} maximumValueLabel: {
Text("1")
}
}

Section("Prompt") {
TextField("Prompt", text: $prompt)
}

Expand All @@ -52,7 +65,8 @@ struct GenerateView: View {
self.response = ""

guard let model = model else { return }
let data = OKGenerateRequestData(model: model, prompt: prompt)
var data = OKGenerateRequestData(model: model, prompt: prompt)
data.options = OKCompletionOptions(temperature: temperature)

Task {
for try await chunk in viewModel.ollamaKit.generate(data: data) {
Expand All @@ -65,7 +79,8 @@ struct GenerateView: View {
self.response = ""

guard let model = model else { return }
let data = OKGenerateRequestData(model: model, prompt: prompt)
var data = OKGenerateRequestData(model: model, prompt: prompt)
data.options = OKCompletionOptions(temperature: temperature)

viewModel.ollamaKit.generate(data: data)
.sink { completion in
Expand Down
17 changes: 17 additions & 0 deletions Sources/OllamaKit/RequestData/Completion/OKCompletionOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,21 @@ public struct OKCompletionOptions: Encodable {
/// `minP` ensures that tokens below a certain probability threshold are excluded,
/// focusing the model's output on more probable sequences. Default is 0.0, meaning no filtering.
public var minP: Double?

/// Creates a set of completion options to attach to a chat or generate request.
///
/// Every parameter defaults to `nil`; a `nil` option is omitted from the encoded
/// request, so the Ollama server falls back to the model's own default for it.
public init(
    mirostat: Int? = nil,
    mirostatEta: Double? = nil,
    mirostatTau: Double? = nil,
    numCtx: Int? = nil,
    repeatLastN: Int? = nil,
    repeatPenalty: Double? = nil,
    temperature: Double? = nil,
    seed: Int? = nil,
    stop: String? = nil,
    tfsZ: Double? = nil,
    numPredict: Int? = nil,
    topK: Int? = nil,
    topP: Double? = nil,
    minP: Double? = nil
) {
    self.mirostat = mirostat
    self.mirostatEta = mirostatEta
    self.mirostatTau = mirostatTau
    self.numCtx = numCtx
    self.repeatLastN = repeatLastN
    self.repeatPenalty = repeatPenalty
    self.temperature = temperature
    self.seed = seed
    self.stop = stop
    self.tfsZ = tfsZ
    self.numPredict = numPredict
    self.topK = topK
    self.topP = topP
    self.minP = minP
}
}
20 changes: 19 additions & 1 deletion Sources/OllamaKit/RequestData/OKChatRequestData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import Foundation

/// A structure that encapsulates data for chat requests to the Ollama API.
public struct OKChatRequestData: Encodable {
public struct OKChatRequestData {
private let stream: Bool

/// A string representing the model identifier to be used for the chat session.
Expand Down Expand Up @@ -60,3 +60,21 @@ public struct OKChatRequestData: Encodable {
}
}
}

extension OKChatRequestData: Encodable {
    /// Encodes the request body for the Ollama `/api/chat` endpoint.
    ///
    /// `options` must be encoded as a nested `"options"` object — the Ollama API
    /// ignores unknown top-level keys, so flattening the completion options into
    /// the root (via `options.encode(to: encoder)`) silently drops parameters
    /// such as `temperature`. Encoding under the `.options` key fixes that.
    public func encode(to encoder: Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)
        try container.encode(stream, forKey: .stream)
        try container.encode(model, forKey: .model)
        try container.encode(messages, forKey: .messages)
        try container.encodeIfPresent(tools, forKey: .tools)
        // `encodeIfPresent` omits the key entirely when no options are set.
        try container.encodeIfPresent(options, forKey: .options)
    }

    private enum CodingKeys: String, CodingKey {
        case stream, model, messages, tools, options
    }
}
22 changes: 21 additions & 1 deletion Sources/OllamaKit/RequestData/OKGenerateRequestData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import Foundation

/// A structure that encapsulates the data required for generating responses using the Ollama API.
public struct OKGenerateRequestData: Encodable {
public struct OKGenerateRequestData {
private let stream: Bool

/// A string representing the identifier of the model.
Expand Down Expand Up @@ -36,3 +36,23 @@ public struct OKGenerateRequestData: Encodable {
self.images = images
}
}

extension OKGenerateRequestData: Encodable {
    /// Encodes the request body for the Ollama `/api/generate` endpoint.
    ///
    /// `options` must be encoded as a nested `"options"` object — the Ollama API
    /// ignores unknown top-level keys, so flattening the completion options into
    /// the root (via `options.encode(to: encoder)`) silently drops parameters
    /// such as `temperature`. Encoding under the `.options` key fixes that.
    public func encode(to encoder: Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)
        try container.encode(stream, forKey: .stream)
        try container.encode(model, forKey: .model)
        try container.encode(prompt, forKey: .prompt)
        try container.encodeIfPresent(images, forKey: .images)
        try container.encodeIfPresent(system, forKey: .system)
        try container.encodeIfPresent(context, forKey: .context)
        // `encodeIfPresent` omits the key entirely when no options are set.
        try container.encodeIfPresent(options, forKey: .options)
    }

    private enum CodingKeys: String, CodingKey {
        case stream, model, prompt, images, system, context, options
    }
}

0 comments on commit a92f36a

Please sign in to comment.