
Commit 93ad679

Add Gemma 3 test. (#14837)
Summary: Also lets the multimodal runner accept image/audio data as floats.

Differential Revision: D84001548

Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent c212b42 commit 93ad679
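In practice, this change lets a multimodal prompt carry pre-normalized float tensors instead of raw uint8 bytes. A rough end-to-end sketch under stated assumptions: the MultimodalRunner initializer, the file paths, and the 896 side size are assumed for illustration; only the generate call shape, the float Image path, and the asNormalizedImage helper come from the diffs below.

```swift
import ExecuTorchLLM
import UIKit

// Sketch only: runner initializer and paths are placeholders, not from this diff.
let runner = MultimodalRunner(
  modelPath: "/path/to/model.pte",
  tokenizerPath: "/path/to/tokenizer.model"
)
let uiImage = UIImage(contentsOfFile: "/path/to/image.jpg")!
do {
  try runner.generate([
    MultimodalInput(uiImage.asNormalizedImage(896)),  // new float path
    MultimodalInput("Describe this image."),
  ], Config {
    $0.sequenceLength = 768
  }) { token in
    print(token, terminator: "")
  }
} catch {
  print("Generation failed: \(error)")
}
```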

4 files changed: +175 -47 lines

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h

Lines changed: 16 additions & 0 deletions
@@ -44,6 +44,12 @@ __attribute__((objc_subclassing_restricted))
                          channels:(NSInteger)channels
     NS_DESIGNATED_INITIALIZER;
 
+- (instancetype)initWithFloatData:(NSData *)data
+                            width:(NSInteger)width
+                           height:(NSInteger)height
+                         channels:(NSInteger)channels
+    NS_DESIGNATED_INITIALIZER;
+
 @property(nonatomic, readonly) NSData *data;
 
 @property(nonatomic, readonly) NSInteger width;
@@ -52,6 +58,8 @@ __attribute__((objc_subclassing_restricted))
 
 @property(nonatomic, readonly) NSInteger channels;
 
+@property(nonatomic, readonly) BOOL isFloat;
+
 + (instancetype)new NS_UNAVAILABLE;
 - (instancetype)init NS_UNAVAILABLE;
 
@@ -80,6 +88,12 @@ __attribute__((objc_subclassing_restricted))
                            frames:(NSInteger)frames
     NS_DESIGNATED_INITIALIZER;
 
+- (instancetype)initWithFloatData:(NSData *)data
+                        batchSize:(NSInteger)batchSize
+                             bins:(NSInteger)bins
+                           frames:(NSInteger)frames
+    NS_DESIGNATED_INITIALIZER;
+
 @property(nonatomic, readonly) NSData *data;
 
 @property(nonatomic, readonly) NSInteger batchSize;
@@ -88,6 +102,8 @@ __attribute__((objc_subclassing_restricted))
 
 @property(nonatomic, readonly) NSInteger frames;
 
+@property(nonatomic, readonly) BOOL isFloat;
+
 + (instancetype)new NS_UNAVAILABLE;
 - (instancetype)init NS_UNAVAILABLE;
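From Swift, where these classes bridge as `Image` and company (the tests below use `Image(float:width:height:channels:)`), the new initializer looks roughly like this. A minimal sketch, assuming the planar channels-first float layout that the test helpers below produce:

```swift
import ExecuTorchLLM
import Foundation

// A 2x2 RGB image as planar float32 data: one full plane per channel.
let planarFloats = [Float](repeating: 0.5, count: 3 * 2 * 2)
let floatImage = Image(
  float: planarFloats.withUnsafeBufferPointer { Data(buffer: $0) },
  width: 2,
  height: 2,
  channels: 3
)
// floatImage.isFloat == true; the existing Image(data:width:height:channels:)
// initializer still takes uint8 bytes and reports isFloat == false.
```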

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm

Lines changed: 77 additions & 7 deletions
@@ -32,6 +32,22 @@ - (instancetype)initWithData:(NSData *)data
     _width = width;
     _height = height;
     _channels = channels;
+    _isFloat = NO;
+  }
+  return self;
+}
+
+- (instancetype)initWithFloatData:(NSData *)data
+                            width:(NSInteger)width
+                           height:(NSInteger)height
+                         channels:(NSInteger)channels {
+  self = [super init];
+  if (self) {
+    _data = [data copy];
+    _width = width;
+    _height = height;
+    _channels = channels;
+    _isFloat = YES;
   }
   return self;
 }
@@ -53,6 +69,22 @@ - (instancetype)initWithData:(NSData *)data
     _batchSize = batchSize;
     _bins = bins;
     _frames = frames;
+    _isFloat = NO;
+  }
+  return self;
+}
+
+- (instancetype)initWithFloatData:(NSData *)data
+                        batchSize:(NSInteger)batchSize
+                             bins:(NSInteger)bins
+                           frames:(NSInteger)frames {
+  self = [super init];
+  if (self) {
+    _data = [data copy];
+    _batchSize = batchSize;
+    _bins = bins;
+    _frames = frames;
+    _isFloat = YES;
   }
   return self;
 }
@@ -170,20 +202,58 @@ - (BOOL)generateWithInputs:(NSArray<ExecuTorchLLMMultimodalInput *> *)inputs
     return NO;
   }
   std::vector<llm::MultimodalInput> nativeInputs;
+  nativeInputs.reserve((size_t)inputs.count);
   for (ExecuTorchLLMMultimodalInput *input in inputs) {
     switch (input.type) {
       case ExecuTorchLLMMultimodalInputTypeText:
         nativeInputs.emplace_back(llm::MultimodalInput(input.text.UTF8String));
         break;
       case ExecuTorchLLMMultimodalInputTypeImage: {
         ExecuTorchLLMImage *image = input.image;
-        std::vector<uint8_t> data((uint8_t *)image.data.bytes, (uint8_t *)image.data.bytes + image.data.length);
-        nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
-          std::move(data),
-          (int32_t)image.width,
-          (int32_t)image.height,
-          (int32_t)image.channels
-        )));
+        if (image.isFloat) {
+          const float *buffer = (const float *)image.data.bytes;
+          size_t elementCount = (size_t)image.data.length / sizeof(float);
+          std::vector<float> data(buffer, buffer + elementCount);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
+            std::move(data),
+            (int32_t)image.width,
+            (int32_t)image.height,
+            (int32_t)image.channels
+          )));
+        } else {
+          const uint8_t *buffer = (const uint8_t *)image.data.bytes;
+          std::vector<uint8_t> data(buffer, buffer + image.data.length);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
+            std::move(data),
+            (int32_t)image.width,
+            (int32_t)image.height,
+            (int32_t)image.channels
+          )));
+        }
+        break;
+      }
+      case ExecuTorchLLMMultimodalInputTypeAudio: {
+        ExecuTorchLLMAudio *audio = input.audio;
+        if (audio.isFloat) {
+          const float *buffer = (const float *)audio.data.bytes;
+          size_t elementCount = (size_t)audio.data.length / sizeof(float);
+          std::vector<float> data(buffer, buffer + elementCount);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Audio(
+            std::move(data),
+            (int32_t)audio.batchSize,
+            (int32_t)audio.bins,
+            (int32_t)audio.frames
+          )));
+        } else {
+          const uint8_t *buffer = (const uint8_t *)audio.data.bytes;
+          std::vector<uint8_t> data(buffer, buffer + audio.data.length);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Audio(
+            std::move(data),
+            (int32_t)audio.batchSize,
+            (int32_t)audio.bins,
+            (int32_t)audio.frames
+          )));
+        }
         break;
       }
       default: {
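The float branch derives the element count from the byte length rather than from the declared dimensions, so callers must pass a buffer whose length is an exact multiple of sizeof(float) and consistent with the stated shape. A minimal sketch of that contract from Swift, assuming the `Audio(float:...)` initializer bridges the same way the `Image(float:...)` one does in the tests:

```swift
import ExecuTorchLLM
import Foundation

// Hypothetical mel-spectrogram input: 1 batch x 128 bins x 100 frames, float32.
let (batchSize, bins, frames) = (1, 128, 100)
let spectrogram = [Float](repeating: 0, count: batchSize * bins * frames)
let data = spectrogram.withUnsafeBufferPointer { Data(buffer: $0) }
// The bridge computes elementCount = data.count / sizeof(float), so the
// byte length must match the declared dimensions exactly.
assert(data.count == batchSize * bins * frames * MemoryLayout<Float>.size)
let audio = Audio(float: data, batchSize: batchSize, bins: bins, frames: frames)
```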

extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift

Lines changed: 80 additions & 38 deletions
@@ -11,47 +11,89 @@ import ExecuTorchLLM
 import XCTest
 
 extension UIImage {
-  func asImage() -> Image {
-    let targetSide = CGFloat(336)
-    let scale = max(targetSide / size.width, targetSide / size.height)
-    let scaledSize = CGSize(width: size.width * scale, height: size.height * scale)
+  func centerCropped(to sideSize: CGFloat) -> UIImage {
+    precondition(sideSize > 0)
     let format = UIGraphicsImageRendererFormat.default()
     format.scale = 1
-    let scaledImage = UIGraphicsImageRenderer(size: scaledSize, format: format).image { _ in
-      draw(in: CGRect(origin: .zero, size: scaledSize))
-    }
-    guard let scaledCGImage = scaledImage.cgImage else {
-      return Image(data: Data(), width: 336, height: 336, channels: 3)
-    }
-    let cropRect = CGRect(
-      x: ((scaledSize.width - targetSide) * 0.5).rounded(.down),
-      y: ((scaledSize.height - targetSide) * 0.5).rounded(.down),
-      width: targetSide.rounded(.down),
-      height: targetSide.rounded(.down)
-    )
-    let croppedCGImage = scaledCGImage.cropping(to: cropRect) ?? scaledCGImage
-    let imageWidth = croppedCGImage.width
-    let imageHeight = croppedCGImage.height
-    let pixelCount = imageWidth * imageHeight
-    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
-    let context = CGContext(
+    format.opaque = false
+    return UIGraphicsImageRenderer(size: CGSize(width: sideSize, height: sideSize), format: format)
+      .image { _ in
+        let scaleFactor = max(sideSize / size.width, sideSize / size.height)
+        let scaledWidth = size.width * scaleFactor
+        let scaledHeight = size.height * scaleFactor
+        let originX = (sideSize - scaledWidth) / 2
+        let originY = (sideSize - scaledHeight) / 2
+        draw(in: CGRect(x: originX, y: originY, width: scaledWidth, height: scaledHeight))
+      }
+  }
+
+  func rgbBytes() -> [UInt8]? {
+    guard let cgImage = cgImage else { return nil }
+    let pixelWidth = Int(cgImage.width)
+    let pixelHeight = Int(cgImage.height)
+    let pixelCount = pixelWidth * pixelHeight
+    let bytesPerPixel = 4
+    let bytesPerRow = pixelWidth * bytesPerPixel
+    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * bytesPerPixel)
+    guard let context = CGContext(
       data: &rgbaBuffer,
-      width: imageWidth,
-      height: imageHeight,
+      width: pixelWidth,
+      height: pixelHeight,
       bitsPerComponent: 8,
-      bytesPerRow: imageWidth * 4,
+      bytesPerRow: bytesPerRow,
       space: CGColorSpaceCreateDeviceRGB(),
       bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
-    )!
-    context.draw(croppedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
-    var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
+    ) else { return nil }
+
+    context.draw(cgImage, in: CGRect(x: 0, y: 0, width: pixelWidth, height: pixelHeight))
+
+    var rgbBytes = [UInt8](repeating: 0, count: pixelCount * 3)
+    for pixelIndex in 0..<pixelCount {
+      let sourceIndex = pixelIndex * bytesPerPixel
+      rgbBytes[pixelIndex] = rgbaBuffer[sourceIndex + 0]
+      rgbBytes[pixelIndex + pixelCount] = rgbaBuffer[sourceIndex + 1]
+      rgbBytes[pixelIndex + 2 * pixelCount] = rgbaBuffer[sourceIndex + 2]
+    }
+    return rgbBytes
+  }
+
+  func rgbBytesNormalized(mean: [Float] = [0, 0, 0], std: [Float] = [1, 1, 1]) -> [Float]? {
+    precondition(mean.count == 3 && std.count == 3)
+    precondition(std[0] != 0 && std[1] != 0 && std[2] != 0)
+    guard let rgbBytes = rgbBytes() else { return nil }
+    let pixelCount = rgbBytes.count / 3
+    var rgbBytesNormalized = [Float](repeating: 0, count: pixelCount * 3)
     for pixelIndex in 0..<pixelCount {
-      let sourceOffset = pixelIndex * 4
-      planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]
-      planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]
-      planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]
+      rgbBytesNormalized[pixelIndex] =
+        (Float(rgbBytes[pixelIndex]) / 255.0 - mean[0]) / std[0]
+      rgbBytesNormalized[pixelIndex + pixelCount] =
+        (Float(rgbBytes[pixelIndex + pixelCount]) / 255.0 - mean[1]) / std[1]
+      rgbBytesNormalized[pixelIndex + 2 * pixelCount] =
+        (Float(rgbBytes[pixelIndex + 2 * pixelCount]) / 255.0 - mean[2]) / std[2]
     }
-    return Image(data: Data(planarRGB), width: 336, height: 336, channels: 3)
+    return rgbBytesNormalized
+  }
+
+  func asImage(_ sideSize: CGFloat) -> Image {
+    return Image(
+      data: Data(centerCropped(to: sideSize).rgbBytes() ?? []),
+      width: Int(sideSize),
+      height: Int(sideSize),
+      channels: 3
+    )
+  }
+
+  func asNormalizedImage(
+    _ sideSize: CGFloat,
+    mean: [Float] = [0.485, 0.456, 0.406],
+    std: [Float] = [0.229, 0.224, 0.225]
+  ) -> Image {
+    return Image(
+      float: (centerCropped(to: sideSize).rgbBytesNormalized(mean: mean, std: std) ?? []).withUnsafeBufferPointer { Data(buffer: $0) },
+      width: Int(sideSize),
+      height: Int(sideSize),
+      channels: 3
+    )
   }
 }
 
@@ -120,7 +162,7 @@ class MultimodalRunnerTest: XCTestCase {
     let sequenceLength = 768
     let bundle = Bundle(for: type(of: self))
     guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
+          let tokenizerPath = bundle.path(forResource: "llava_tokenizer", ofType: "bin"),
           let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
           let uiImage = UIImage(contentsOfFile: imagePath) else {
       XCTFail("Couldn't find model or tokenizer files")
@@ -132,8 +174,8 @@ class MultimodalRunnerTest: XCTestCase {
     do {
       try runner.generate([
         MultimodalInput(systemPrompt),
-        MultimodalInput(uiImage.asImage()),
-        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+        MultimodalInput(uiImage.asImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
       ], Config {
         $0.sequenceLength = sequenceLength
       }) { token in
@@ -149,8 +191,8 @@ class MultimodalRunnerTest: XCTestCase {
     do {
       try runner.generate([
        MultimodalInput(systemPrompt),
-        MultimodalInput(uiImage.asImage()),
-        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+        MultimodalInput(uiImage.asImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
       ], Config {
         $0.sequenceLength = sequenceLength
       }) { token in
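These helpers make the preprocessing explicit at the call site: asImage yields planar uint8 RGB, asNormalizedImage yields planar float32 RGB mapped through (x/255 - mean)/std. A hedged sketch of the two paths side by side; the 896 side size and the 0.5 mean/std values are illustrative, not taken from this diff (336 matches the LLaVA target side in the replaced code):

```swift
import ExecuTorchLLM
import UIKit

let uiImage = UIImage(contentsOfFile: "/path/to/IMG_0005.jpg")!

// uint8 path: planar RGB bytes at the model's expected resolution.
let byteImage = uiImage.asImage(336)

// float path: pre-normalized planar RGB floats for models that expect
// normalized tensors (e.g. a Gemma 3 style pipeline).
let floatImage = uiImage.asNormalizedImage(896, mean: [0.5, 0.5, 0.5],
                                           std: [0.5, 0.5, 0.5])
```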

extension/llm/apple/ExecuTorchLLM/__tests__/TextRunnerTest.swift

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ class TextRunnerTest: XCTestCase {
   func testLLaMA() {
     let bundle = Bundle(for: type(of: self))
     guard let modelPath = bundle.path(forResource: "llama3_2-1B", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "model") else {
+          let tokenizerPath = bundle.path(forResource: "llama_tokenizer", ofType: "model") else {
       XCTFail("Couldn't find model or tokenizer files")
       return
     }
@@ -77,7 +77,7 @@ class TextRunnerTest: XCTestCase {
   func testPhi4() {
     let bundle = Bundle(for: type(of: self))
     guard let modelPath = bundle.path(forResource: "phi4-mini", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "json") else {
+          let tokenizerPath = bundle.path(forResource: "phi_tokenizer", ofType: "json") else {
       XCTFail("Couldn't find model or tokenizer files")
       return
     }
