 //
 // This source file is part of the Stanford Spezi open source project
 //
-// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md)
+// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md)
 //
 // SPDX-License-Identifier: MIT
 //
 
 import Foundation
-@preconcurrency import llama
 
 
 /// Represents the context parameters of the LLM.
-///
-/// Internally, these data points are passed as a llama.cpp `llama_context_params` C struct to the LLM.
 public struct LLMLocalContextParameters: Sendable {
-    // swiftlint:disable identifier_name
-    /// Swift representation of the `ggml_type` of llama.cpp, indicating data types within KV caches.
-    public enum GGMLType: UInt32 {
-        case f32 = 0
-        case f16
-        case q4_0
-        case q4_1
-        case q5_0 = 6
-        case q5_1
-        case q8_0
-        case q8_1
-        /// k-quantizations
-        case q2_k
-        case q3_k
-        case q4_k
-        case q5_k
-        case q6_k
-        case q8_k
-        case iq2_xxs
-        case iq2_xs
-        case i8
-        case i16
-        case i32
-    }
-    // swiftlint:enable identifier_name
-
-
-    /// Wrapped C struct from the llama.cpp library, later-on passed to the LLM
-    private var wrapped: llama_context_params
-
-
-    /// Context parameters in llama.cpp's low-level C representation
-    var llamaCppRepresentation: llama_context_params {
-        wrapped
-    }
-
     /// RNG seed of the LLM
-    var seed: UInt32 {
-        get {
-            wrapped.seed
-        }
-        set {
-            wrapped.seed = newValue
-        }
-    }
-
-    /// Context window size in tokens (0 = take default window size from model)
-    var contextWindowSize: UInt32 {
-        get {
-            wrapped.n_ctx
-        }
-        set {
-            wrapped.n_ctx = newValue
-        }
-    }
-
-    /// Maximum batch size during prompt processing
-    var batchSize: UInt32 {
-        get {
-            wrapped.n_batch
-        }
-        set {
-            wrapped.n_batch = newValue
-        }
-    }
-
-    /// Number of threads used by LLM for generation of output
-    var threadCount: UInt32 {
-        get {
-            wrapped.n_threads
-        }
-        set {
-            wrapped.n_threads = newValue
-        }
-    }
-
-    /// Number of threads used by LLM for batch processing
-    var threadCountBatch: UInt32 {
-        get {
-            wrapped.n_threads_batch
-        }
-        set {
-            wrapped.n_threads_batch = newValue
-        }
-    }
-
-    /// RoPE base frequency (0 = take default from model)
-    var ropeFreqBase: Float {
-        get {
-            wrapped.rope_freq_base
-        }
-        set {
-            wrapped.rope_freq_base = newValue
-        }
-    }
-
-    /// RoPE frequency scaling factor (0 = take default from model)
-    var ropeFreqScale: Float {
-        get {
-            wrapped.rope_freq_scale
-        }
-        set {
-            wrapped.rope_freq_scale = newValue
-        }
-    }
-
-    /// If `true`, offload the KQV ops (including the KV cache) to GPU
-    var offloadKQV: Bool {
-        get {
-            wrapped.offload_kqv
-        }
-        set {
-            wrapped.offload_kqv = newValue
-        }
-    }
-
-    /// ``GGMLType`` of the key of the KV cache
-    var kvKeyType: GGMLType {
-        get {
-            GGMLType(rawValue: wrapped.type_k.rawValue) ?? .f16
-        }
-        set {
-            wrapped.type_k = ggml_type(rawValue: newValue.rawValue)
-        }
-    }
-
-    /// ``GGMLType`` of the value of the KV cache
-    var kvValueType: GGMLType {
-        get {
-            GGMLType(rawValue: wrapped.type_v.rawValue) ?? .f16
-        }
-        set {
-            wrapped.type_v = ggml_type(rawValue: newValue.rawValue)
-        }
-    }
-
-    /// If `true`, the (deprecated) `llama_eval()` call computes all logits, not just the last one
-    var computeAllLogits: Bool {
-        get {
-            wrapped.logits_all
-        }
-        set {
-            wrapped.logits_all = newValue
-        }
-    }
-
-    /// If `true`, the mode is set to embeddings only
-    var embeddingsOnly: Bool {
-        get {
-            wrapped.embeddings
-        }
-        set {
-            wrapped.embeddings = newValue
-        }
-    }
+    var seed: UInt64?
 
     /// Creates the ``LLMLocalContextParameters`` which wrap the underlying llama.cpp `llama_context_params` C struct.
     /// Is passed to the underlying llama.cpp model in order to configure the context of the LLM.
     ///
     /// - Parameters:
-    /// - seed: RNG seed of the LLM, defaults to `4294967295` (which represents a random seed).
-    /// - contextWindowSize: Context window size in tokens, defaults to `1024`.
-    /// - batchSize: Maximum batch size during prompt processing, defaults to `1024` tokens.
-    /// - threadCount: Number of threads used by LLM for generation of output, defaults to the processor count of the device.
-    /// - threadCountBatch: Number of threads used by LLM for batch processing, defaults to the processor count of the device.
-    /// - ropeFreqBase: RoPE base frequency, defaults to `0` indicating the default from model.
-    /// - ropeFreqScale: RoPE frequency scaling factor, defaults to `0` indicating the default from model.
-    /// - offloadKQV: Offloads the KQV ops (including the KV cache) to GPU, defaults to `true`.
-    /// - kvKeyType: ``GGMLType`` of the key of the KV cache, defaults to ``GGMLType/f16``.
-    /// - kvValueType: ``GGMLType`` of the value of the KV cache, defaults to ``GGMLType/f16``.
-    /// - computeAllLogits: `llama_eval()` call computes all logits, not just the last one. Defaults to `false`.
-    /// - embeddingsOnly: Embedding-only mode, defaults to `false`.
+    /// - seed: RNG seed of the LLM, defaults to a random seed.
    public init(
-        seed: UInt32 = 4294967295,
-        contextWindowSize: UInt32 = 1024,
-        batchSize: UInt32 = 1024,
-        threadCount: UInt32 = .init(ProcessInfo.processInfo.processorCount),
-        threadCountBatch: UInt32 = .init(ProcessInfo.processInfo.processorCount),
-        ropeFreqBase: Float = 0.0,
-        ropeFreqScale: Float = 0.0,
-        offloadKQV: Bool = true,
-        kvKeyType: GGMLType = .f16,
-        kvValueType: GGMLType = .f16,
-        computeAllLogits: Bool = false,
-        embeddingsOnly: Bool = false
+        seed: UInt64? = nil
     ) {
-        self.wrapped = llama_context_default_params()
-
         self.seed = seed
-        self.contextWindowSize = contextWindowSize
-        self.batchSize = batchSize
-        self.threadCount = threadCount
-        self.threadCountBatch = threadCountBatch
-        self.ropeFreqBase = ropeFreqBase
-        self.ropeFreqScale = ropeFreqScale
-        self.offloadKQV = offloadKQV
-        self.kvKeyType = kvKeyType
-        self.kvValueType = kvValueType
-        self.computeAllLogits = computeAllLogits
-        self.embeddingsOnly = embeddingsOnly
     }
 }
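A minimal usage sketch of the slimmed-down parameter struct after this change. It relies only on the public initializer shown in the diff above; the importing module name `SpeziLLMLocal` and the concrete seed value are assumptions for illustration, not part of this diff.

import SpeziLLMLocal   // assumed module hosting `LLMLocalContextParameters`

// Default configuration: `seed` stays `nil`, so the local LLM backend
// falls back to a random seed for token sampling.
let defaultContextParameters = LLMLocalContextParameters()

// Supplying a fixed seed (value chosen arbitrarily here) makes generation
// reproducible across runs, which is useful for tests and debugging.
let reproducibleContextParameters = LLMLocalContextParameters(seed: 42)

Compared to the previous llama.cpp-backed version, callers no longer tune the context window size, batch size, thread counts, or KV-cache types through this struct; only the RNG seed remains configurable here.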