Skip to content

Commit d666192

Browse files
committed
better context
1 parent 515c072 commit d666192

File tree

3 files changed

+67
-8
lines changed

3 files changed

+67
-8
lines changed

Sources/GenKit/Sessions/ChatSession.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ public struct ChatSessionRequest {
168168
public private(set) var history: [Message] = []
169169
public private(set) var tools: [Tool] = []
170170
public private(set) var tool: Tool? = nil
171-
public private(set) var context: [String] = []
171+
public private(set) var context: [String: String] = [:]
172172
public private(set) var temperature: Float? = nil
173173

174174
public init(service: ChatService, model: Model, toolCallback: ToolCallback? = nil) {
@@ -198,7 +198,7 @@ public struct ChatSessionRequest {
198198
}
199199
}
200200

201-
public mutating func with(context: [String]) {
201+
public mutating func with(context: [String: String]) {
202202
self.context = context
203203
}
204204

@@ -211,11 +211,11 @@ public struct ChatSessionRequest {
211211

212212
// Apply user context
213213
var systemContext = ""
214-
if !context.isEmpty {
214+
if let memories = context["MEMORIES"] {
215215
systemContext = """
216216
The following is context about the current user:
217217
<user_context>
218-
\(context.joined(separator: "\n"))
218+
\(memories)
219219
</user_context>
220220
"""
221221
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import Foundation
2+
3+
public class PipelineSession {
4+
public static let shared = PipelineSession()
5+
6+
public func completion(_ request: Pipeline.Request, runLoopLimit: Int = 10) async throws -> Pipeline.Response {
7+
var pipelineResponse = Pipeline.Response(steps: [])
8+
9+
for step in request.steps {
10+
let instructions = PromptTemplate(step.instructions, with: step.inputs)
11+
12+
var req = ChatSessionRequest(service: step.service, model: step.model)
13+
req.with(history: [.user(content: instructions)])
14+
15+
let resp = try await ChatSession.shared.completion(req)
16+
let stepCompletion = Pipeline.Response.Step(
17+
instructions: instructions,
18+
inputs: step.inputs,
19+
outputs: step.outputs,
20+
messages: resp.messages
21+
)
22+
pipelineResponse.steps.append(stepCompletion)
23+
24+
// TODO: Extract the expected output variables
25+
}
26+
27+
return pipelineResponse
28+
}
29+
}
30+
31+
public struct Pipeline {
32+
33+
public struct Request: Sendable {
34+
public var steps: [Step]
35+
36+
public struct Step: Sendable {
37+
public var service: ChatService
38+
public var model: Model
39+
public var instructions: String
40+
public var inputs: [String: String]
41+
public var outputs: [String: String]
42+
}
43+
}
44+
45+
public struct Response: Codable, Sendable {
46+
public var steps: [Step]
47+
48+
public struct Step: Codable, Sendable {
49+
public var instructions: String
50+
public var inputs: [String: String]
51+
public var outputs: [String: String]
52+
public var messages: [Message]
53+
}
54+
}
55+
}
56+
57+
enum PipelineSessionError: Error {
58+
case unknown
59+
}

Sources/GenKit/Sessions/VisionSession.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public struct VisionSessionRequest {
7070

7171
public private(set) var system: String? = nil
7272
public private(set) var history: [Message] = []
73-
public private(set) var context: [String] = []
73+
public private(set) var context: [String: String] = [:]
7474
public private(set) var temperature: Float? = nil
7575

7676
public init(service: VisionService, model: Model, toolCallback: ToolCallback? = nil) {
@@ -87,7 +87,7 @@ public struct VisionSessionRequest {
8787
self.history = history
8888
}
8989

90-
public mutating func with(context: [String]) {
90+
public mutating func with(context: [String: String]) {
9191
self.context = context
9292
}
9393

@@ -100,11 +100,11 @@ public struct VisionSessionRequest {
100100

101101
// Apply user context
102102
var systemContext = ""
103-
if !context.isEmpty {
103+
if let memories = context["MEMORIES"] {
104104
systemContext = """
105105
The following is context about the current user:
106106
<user_context>
107-
\(context.joined(separator: "\n"))
107+
\(memories)
108108
</user_context>
109109
"""
110110
}

0 commit comments

Comments
 (0)