
Commit a4ee2da

Feature/add mlx (#73)
1 parent 6633d8a commit a4ee2da

29 files changed (+516 −1395 lines)

Package.swift

+14 −7
@@ -27,8 +27,10 @@ let package = Package(
         .library(name: "SpeziLLMFog", targets: ["SpeziLLMFog"])
     ],
     dependencies: [
+        .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.18.1"),
+        .package(url: "https://github.com/ml-explore/mlx-swift-examples", from: "1.16.0"),
+        .package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.12")),
         .package(url: "https://github.com/StanfordBDHG/OpenAI", .upToNextMinor(from: "0.2.9")),
-        .package(url: "https://github.com/StanfordBDHG/llama.cpp", .upToNextMinor(from: "0.3.3")),
         .package(url: "https://github.com/StanfordSpezi/Spezi", from: "1.2.1"),
         .package(url: "https://github.com/StanfordSpezi/SpeziFoundation", from: "2.0.0-beta.3"),
         .package(url: "https://github.com/StanfordSpezi/SpeziStorage", from: "1.0.2"),
@@ -49,19 +51,24 @@ let package = Package(
             name: "SpeziLLMLocal",
             dependencies: [
                 .target(name: "SpeziLLM"),
-                .product(name: "llama", package: "llama.cpp"),
                 .product(name: "SpeziFoundation", package: "SpeziFoundation"),
-                .product(name: "Spezi", package: "Spezi")
-            ],
-            swiftSettings: [
-                .interoperabilityMode(.Cxx)
+                .product(name: "Spezi", package: "Spezi"),
+                .product(name: "MLX", package: "mlx-swift"),
+                .product(name: "MLXFast", package: "mlx-swift"),
+                .product(name: "MLXNN", package: "mlx-swift"),
+                .product(name: "MLXOptimizers", package: "mlx-swift"),
+                .product(name: "MLXRandom", package: "mlx-swift"),
+                .product(name: "Transformers", package: "swift-transformers"),
+                .product(name: "LLM", package: "mlx-swift-examples")
             ]
         ),
         .target(
            name: "SpeziLLMLocalDownload",
            dependencies: [
                .product(name: "SpeziOnboarding", package: "SpeziOnboarding"),
-                .product(name: "SpeziViews", package: "SpeziViews")
+                .product(name: "SpeziViews", package: "SpeziViews"),
+                .target(name: "SpeziLLMLocal"),
+                .product(name: "LLM", package: "mlx-swift-examples")
            ]
        ),
        .target(
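Because llama.cpp and its C++ interop requirement are dropped here, packages that consume `SpeziLLMLocal` no longer need the `.interoperabilityMode(.Cxx)` Swift setting that the old README mandated. The following is a minimal sketch of a consuming `Package.swift` under that assumption; the package name, platform, and version requirement are illustrative and not taken from this commit.

```swift
// swift-tools-version:5.9

// Hypothetical manifest of a package consuming SpeziLLMLocal after this change.
// Note: no `swiftSettings: [.interoperabilityMode(.Cxx)]` entry is required anymore,
// since the llama.cpp dependency (and with it the C++ interop) was removed.
import PackageDescription

let package = Package(
    name: "ExampleConsumingPackage",   // illustrative name
    platforms: [.iOS(.v17)],           // illustrative platform requirement
    dependencies: [
        // Version requirement is illustrative; pin to a release that contains this change.
        .package(url: "https://github.com/StanfordSpezi/SpeziLLM", .upToNextMinor(from: "0.6.0"))
    ],
    targets: [
        .target(
            name: "ExampleConsumingTarget",
            dependencies: [
                .product(name: "SpeziLLMLocal", package: "SpeziLLM")
            ]
        )
    ]
)
```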

README.md

+7 −30
@@ -57,37 +57,13 @@ The section below highlights the setup and basic use of the [SpeziLLMLocal](http
 
 ### Spezi LLM Local
 
-The target enables developers to easily execute medium-size Language Models (LLMs) locally on-device via the [llama.cpp framework](https://github.com/ggerganov/llama.cpp). The module allows you to interact with the locally run LLM via purely Swift-based APIs, no interaction with low-level C or C++ code is necessary, building on top of the infrastructure of the [SpeziLLM target](https://swiftpackageindex.com/stanfordspezi/spezillm/documentation/spezillm).
+The target enables developers to easily execute medium-size Language Models (LLMs) locally on-device. The module allows you to interact with the locally run LLM via purely Swift-based APIs, no interaction with low-level code is necessary, building on top of the infrastructure of the [SpeziLLM target](https://swiftpackageindex.com/stanfordspezi/spezillm/documentation/spezillm).
+
+> [!IMPORTANT]
+> Spezi LLM Local is not compatible with simulators. The underlying [`mlx-swift`](https://github.com/ml-explore/mlx-swift) requires a modern Metal MTLGPUFamily and the simulator does not provide that.
 
 > [!IMPORTANT]
-> Important: In order to use the LLM local target, one needs to set build parameters in the consuming Xcode project or the consuming SPM package to enable the [Swift / C++ Interop](https://www.swift.org/documentation/cxx-interop/), introduced in Xcode 15 and Swift 5.9. Keep in mind that this is true for nested dependencies, one needs to set this configuration recursivly for the entire dependency tree towards the llama.cpp SPM package. <!-- markdown-link-check-disable-line -->
->
-> **For Xcode projects:**
-> - Open your [build settings in Xcode](https://developer.apple.com/documentation/xcode/configuring-the-build-settings-of-a-target/) by selecting *PROJECT_NAME > TARGET_NAME > Build Settings*.
-> - Within the *Build Settings*, search for the `C++ and Objective-C Interoperability` setting and set it to `C++ / Objective-C++`. This enables the project to use the C++ headers from llama.cpp.
->
-> **For SPM packages:**
-> - Open the `Package.swift` file of your [SPM package]((https://www.swift.org/documentation/package-manager/)) <!-- markdown-link-check-disable-line -->
-> - Within the package `target` that consumes the llama.cpp package, add the `interoperabilityMode(_:)` Swift build setting like that:
-> ```swift
-> /// Adds the dependency to the Spezi LLM SPM package
-> dependencies: [
->     .package(url: "https://github.com/StanfordSpezi/SpeziLLM", .upToNextMinor(from: "0.6.0"))
-> ],
-> targets: [
->     .target(
->         name: "ExampleConsumingTarget",
->         /// State the dependence of the target to SpeziLLMLocal
->         dependencies: [
->             .product(name: "SpeziLLMLocal", package: "SpeziLLM")
->         ],
->         /// Important: Configure the `.interoperabilityMode(_:)` within the `swiftSettings`
->         swiftSettings: [
->             .interoperabilityMode(.Cxx)
->         ]
->     )
-> ]
-> ```
+> Important: To use the LLM local target, some LLMs require adding the [Increase Memory Limit](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_increased-memory-limit) entitlement to the project.
 
 #### Setup
 
@@ -123,7 +99,8 @@ struct LLMLocalDemoView: View {
         // Instantiate the `LLMLocalSchema` to an `LLMLocalSession` via the `LLMRunner`.
         let llmSession: LLMLocalSession = runner(
             with: LLMLocalSchema(
-                modelPath: URL(string: "URL to the local model file")!
+                model: .llama3_8B_4bit,
+                formatChat: LLMLocalSchema.PromptFormattingDefaults.llama3
             )
         )
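For orientation, here is a sketch of how the updated README snippet might read once this diff is applied. Only the `LLMLocalSchema(model:formatChat:)` call is taken from this commit; the surrounding view, the `generate()` call, and the token-stream handling are reconstructed from the usual SpeziLLM usage pattern and should be treated as assumptions, not as part of this change.

```swift
import SpeziLLM
import SpeziLLMLocal
import SwiftUI

struct LLMLocalDemoView: View {
    @Environment(LLMRunner.self) var runner
    @State var responseText = ""

    var body: some View {
        Text(responseText)
            .task {
                // Instantiate the `LLMLocalSchema` to an `LLMLocalSession` via the `LLMRunner`.
                // (The schema arguments below are the ones introduced by this commit.)
                let llmSession: LLMLocalSession = runner(
                    with: LLMLocalSchema(
                        model: .llama3_8B_4bit,
                        formatChat: LLMLocalSchema.PromptFormattingDefaults.llama3
                    )
                )

                do {
                    // Assumed SpeziLLM pattern: stream generated tokens as they arrive.
                    for try await token in try await llmSession.generate() {
                        responseText.append(token)
                    }
                } catch {
                    // Handle generation errors appropriately in a real app.
                }
            }
    }
}
```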

LLMLocalContextParameters.swift
@@ -1,218 +1,27 @@
 //
 // This source file is part of the Stanford Spezi open source project
 //
-// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md)
+// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md)
 //
 // SPDX-License-Identifier: MIT
 //
 
 import Foundation
-@preconcurrency import llama
 
 
 /// Represents the context parameters of the LLM.
-///
-/// Internally, these data points are passed as a llama.cpp `llama_context_params` C struct to the LLM.
 public struct LLMLocalContextParameters: Sendable {
-    // swiftlint:disable identifier_name
-    /// Swift representation of the `ggml_type` of llama.cpp, indicating data types within KV caches.
-    public enum GGMLType: UInt32 {
-        case f32 = 0
-        case f16
-        case q4_0
-        case q4_1
-        case q5_0 = 6
-        case q5_1
-        case q8_0
-        case q8_1
-        /// k-quantizations
-        case q2_k
-        case q3_k
-        case q4_k
-        case q5_k
-        case q6_k
-        case q8_k
-        case iq2_xxs
-        case iq2_xs
-        case i8
-        case i16
-        case i32
-    }
-    // swiftlint:enable identifier_name
-
-
-    /// Wrapped C struct from the llama.cpp library, later-on passed to the LLM
-    private var wrapped: llama_context_params
-
-
-    /// Context parameters in llama.cpp's low-level C representation
-    var llamaCppRepresentation: llama_context_params {
-        wrapped
-    }
-
     /// RNG seed of the LLM
-    var seed: UInt32 {
-        get {
-            wrapped.seed
-        }
-        set {
-            wrapped.seed = newValue
-        }
-    }
-
-    /// Context window size in tokens (0 = take default window size from model)
-    var contextWindowSize: UInt32 {
-        get {
-            wrapped.n_ctx
-        }
-        set {
-            wrapped.n_ctx = newValue
-        }
-    }
-
-    /// Maximum batch size during prompt processing
-    var batchSize: UInt32 {
-        get {
-            wrapped.n_batch
-        }
-        set {
-            wrapped.n_batch = newValue
-        }
-    }
-
-    /// Number of threads used by LLM for generation of output
-    var threadCount: UInt32 {
-        get {
-            wrapped.n_threads
-        }
-        set {
-            wrapped.n_threads = newValue
-        }
-    }
-
-    /// Number of threads used by LLM for batch processing
-    var threadCountBatch: UInt32 {
-        get {
-            wrapped.n_threads_batch
-        }
-        set {
-            wrapped.n_threads_batch = newValue
-        }
-    }
-
-    /// RoPE base frequency (0 = take default from model)
-    var ropeFreqBase: Float {
-        get {
-            wrapped.rope_freq_base
-        }
-        set {
-            wrapped.rope_freq_base = newValue
-        }
-    }
-
-    /// RoPE frequency scaling factor (0 = take default from model)
-    var ropeFreqScale: Float {
-        get {
-            wrapped.rope_freq_scale
-        }
-        set {
-            wrapped.rope_freq_scale = newValue
-        }
-    }
-
-    /// If `true`, offload the KQV ops (including the KV cache) to GPU
-    var offloadKQV: Bool {
-        get {
-            wrapped.offload_kqv
-        }
-        set {
-            wrapped.offload_kqv = newValue
-        }
-    }
-
-    /// ``GGMLType`` of the key of the KV cache
-    var kvKeyType: GGMLType {
-        get {
-            GGMLType(rawValue: wrapped.type_k.rawValue) ?? .f16
-        }
-        set {
-            wrapped.type_k = ggml_type(rawValue: newValue.rawValue)
-        }
-    }
-
-    /// ``GGMLType`` of the value of the KV cache
-    var kvValueType: GGMLType {
-        get {
-            GGMLType(rawValue: wrapped.type_v.rawValue) ?? .f16
-        }
-        set {
-            wrapped.type_v = ggml_type(rawValue: newValue.rawValue)
-        }
-    }
-
-    /// If `true`, the (deprecated) `llama_eval()` call computes all logits, not just the last one
-    var computeAllLogits: Bool {
-        get {
-            wrapped.logits_all
-        }
-        set {
-            wrapped.logits_all = newValue
-        }
-    }
-
-    /// If `true`, the mode is set to embeddings only
-    var embeddingsOnly: Bool {
-        get {
-            wrapped.embeddings
-        }
-        set {
-            wrapped.embeddings = newValue
-        }
-    }
+    var seed: UInt64?
 
     /// Creates the ``LLMLocalContextParameters`` which wrap the underlying llama.cpp `llama_context_params` C struct.
     /// Is passed to the underlying llama.cpp model in order to configure the context of the LLM.
     ///
     /// - Parameters:
-    ///   - seed: RNG seed of the LLM, defaults to `4294967295` (which represents a random seed).
-    ///   - contextWindowSize: Context window size in tokens, defaults to `1024`.
-    ///   - batchSize: Maximum batch size during prompt processing, defaults to `1024` tokens.
-    ///   - threadCount: Number of threads used by LLM for generation of output, defaults to the processor count of the device.
-    ///   - threadCountBatch: Number of threads used by LLM for batch processing, defaults to the processor count of the device.
-    ///   - ropeFreqBase: RoPE base frequency, defaults to `0` indicating the default from model.
-    ///   - ropeFreqScale: RoPE frequency scaling factor, defaults to `0` indicating the default from model.
-    ///   - offloadKQV: Offloads the KQV ops (including the KV cache) to GPU, defaults to `true`.
-    ///   - kvKeyType: ``GGMLType`` of the key of the KV cache, defaults to ``GGMLType/f16``.
-    ///   - kvValueType: ``GGMLType`` of the value of the KV cache, defaults to ``GGMLType/f16``.
-    ///   - computeAllLogits: `llama_eval()` call computes all logits, not just the last one. Defaults to `false`.
-    ///   - embeddingsOnly: Embedding-only mode, defaults to `false`.
+    ///   - seed: RNG seed of the LLM, defaults to a random seed.
     public init(
-        seed: UInt32 = 4294967295,
-        contextWindowSize: UInt32 = 1024,
-        batchSize: UInt32 = 1024,
-        threadCount: UInt32 = .init(ProcessInfo.processInfo.processorCount),
-        threadCountBatch: UInt32 = .init(ProcessInfo.processInfo.processorCount),
-        ropeFreqBase: Float = 0.0,
-        ropeFreqScale: Float = 0.0,
-        offloadKQV: Bool = true,
-        kvKeyType: GGMLType = .f16,
-        kvValueType: GGMLType = .f16,
-        computeAllLogits: Bool = false,
-        embeddingsOnly: Bool = false
+        seed: UInt64? = nil
     ) {
-        self.wrapped = llama_context_default_params()
-
         self.seed = seed
-        self.contextWindowSize = contextWindowSize
-        self.batchSize = batchSize
-        self.threadCount = threadCount
-        self.threadCountBatch = threadCountBatch
-        self.ropeFreqBase = ropeFreqBase
-        self.ropeFreqScale = ropeFreqScale
-        self.offloadKQV = offloadKQV
-        self.kvKeyType = kvKeyType
-        self.kvValueType = kvValueType
-        self.computeAllLogits = computeAllLogits
-        self.embeddingsOnly = embeddingsOnly
     }
 }
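With llama.cpp gone, the context parameters collapse to a single optional RNG seed; the context window, batching, threading, and KV-cache typing knobs above no longer exist on this type and are presumably handled by the MLX backend. A minimal sketch of the simplified initializer in use, based solely on the signature shown in this diff (the seed value is arbitrary and purely illustrative):

```swift
import SpeziLLMLocal

// Default configuration: no explicit seed, so the backend picks a random one.
let defaultContext = LLMLocalContextParameters()

// Reproducible generation: pin the RNG seed (42 is an arbitrary example value).
let seededContext = LLMLocalContextParameters(seed: 42)
```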
