@@ -10,60 +10,157 @@ import ExecuTorchLLM
1010import XCTest
1111
extension UIImage {
  /// Returns a square copy of the image, aspect-fill scaled and center-cropped to `sideSize`.
  ///
  /// The renderer is forced to scale 1 so the output bitmap is exactly
  /// `sideSize` x `sideSize` pixels regardless of the device's screen scale.
  /// - Parameter sideSize: Edge length in pixels of the square output; must be positive.
  func centerCropped(to sideSize: CGFloat) -> UIImage {
    precondition(sideSize > 0)
    let format = UIGraphicsImageRendererFormat.default()
    format.scale = 1
    format.opaque = false
    return UIGraphicsImageRenderer(size: CGSize(width: sideSize, height: sideSize), format: format)
      .image { _ in
        // Aspect-fill: scale so the smaller dimension reaches sideSize,
        // then center the overflow so the crop is symmetric.
        let scaleFactor = max(sideSize / size.width, sideSize / size.height)
        let scaledWidth = size.width * scaleFactor
        let scaledHeight = size.height * scaleFactor
        let originX = (sideSize - scaledWidth) / 2
        let originY = (sideSize - scaledHeight) / 2
        draw(in: CGRect(x: originX, y: originY, width: scaledWidth, height: scaledHeight))
      }
  }

  /// Returns the image's pixels as planar RGB bytes: all red values, then all
  /// green, then all blue (CHW layout, as consumed by the model runner).
  /// - Returns: `3 * width * height` bytes, or `nil` if the image has no
  ///   backing `CGImage` or the bitmap context could not be created.
  func rgbBytes() -> [UInt8]? {
    guard let cgImage else { return nil }
    // cgImage.width/height are already Int — no conversion needed.
    let pixelWidth = cgImage.width
    let pixelHeight = cgImage.height
    let pixelCount = pixelWidth * pixelHeight
    let bytesPerPixel = 4
    let bytesPerRow = pixelWidth * bytesPerPixel
    // Render into an interleaved RGBA buffer first; CGContext has no planar output.
    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * bytesPerPixel)
    guard let context = CGContext(
      data: &rgbaBuffer,
      width: pixelWidth,
      height: pixelHeight,
      bitsPerComponent: 8,
      bytesPerRow: bytesPerRow,
      space: CGColorSpaceCreateDeviceRGB(),
      bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
    ) else { return nil }

    context.draw(cgImage, in: CGRect(x: 0, y: 0, width: pixelWidth, height: pixelHeight))

    // De-interleave RGBA -> planar RGB, dropping alpha.
    var rgbBytes = [UInt8](repeating: 0, count: pixelCount * 3)
    for pixelIndex in 0..<pixelCount {
      let sourceIndex = pixelIndex * bytesPerPixel
      rgbBytes[pixelIndex] = rgbaBuffer[sourceIndex + 0]
      rgbBytes[pixelIndex + pixelCount] = rgbaBuffer[sourceIndex + 1]
      rgbBytes[pixelIndex + 2 * pixelCount] = rgbaBuffer[sourceIndex + 2]
    }
    return rgbBytes
  }

  /// Returns planar RGB floats scaled to [0, 1] and normalized per channel as
  /// `(value / 255 - mean[c]) / std[c]`.
  /// - Parameters:
  ///   - mean: Per-channel (R, G, B) means; must contain exactly 3 values.
  ///   - std: Per-channel (R, G, B) standard deviations; must contain exactly
  ///     3 non-zero values.
  /// - Returns: `3 * width * height` floats in CHW order, or `nil` if the
  ///   pixel data could not be extracted.
  func rgbBytesNormalized(mean: [Float] = [0, 0, 0], std: [Float] = [1, 1, 1]) -> [Float]? {
    precondition(mean.count == 3 && std.count == 3)
    precondition(std[0] != 0 && std[1] != 0 && std[2] != 0)
    guard let rgbBytes = rgbBytes() else { return nil }
    let pixelCount = rgbBytes.count / 3
    var rgbBytesNormalized = [Float](repeating: 0, count: pixelCount * 3)
    for pixelIndex in 0..<pixelCount {
      rgbBytesNormalized[pixelIndex] =
        (Float(rgbBytes[pixelIndex]) / 255.0 - mean[0]) / std[0]
      rgbBytesNormalized[pixelIndex + pixelCount] =
        (Float(rgbBytes[pixelIndex + pixelCount]) / 255.0 - mean[1]) / std[1]
      rgbBytesNormalized[pixelIndex + 2 * pixelCount] =
        (Float(rgbBytes[pixelIndex + 2 * pixelCount]) / 255.0 - mean[2]) / std[2]
    }
    return rgbBytesNormalized
  }

  /// Center-crops to `sideSize` and wraps the raw planar RGB bytes in an
  /// ExecuTorchLLM `Image`. Falls back to empty data if pixel extraction
  /// fails (the runner will then reject the input).
  func asImage(_ sideSize: CGFloat) -> Image {
    return Image(
      data: Data(centerCropped(to: sideSize).rgbBytes() ?? []),
      width: Int(sideSize),
      height: Int(sideSize),
      channels: 3
    )
  }

  /// Center-crops to `sideSize` and wraps normalized planar RGB floats in an
  /// ExecuTorchLLM `Image`. Defaults are the standard ImageNet mean/std.
  func asNormalizedImage(
    _ sideSize: CGFloat,
    mean: [Float] = [0.485, 0.456, 0.406],
    std: [Float] = [0.229, 0.224, 0.225]
  ) -> Image {
    return Image(
      float: (centerCropped(to: sideSize).rgbBytesNormalized(mean: mean, std: std) ?? []).withUnsafeBufferPointer { Data(buffer: $0) },
      width: Int(sideSize),
      height: Int(sideSize),
      channels: 3
    )
  }
}
5698
5799class MultimodalRunnerTest : XCTestCase {
58- let systemPrompt = " A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
59- let assistantPrompt = " ASSISTANT: "
100+ let systemPrompt = " A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. "
60101 let userPrompt = " What's on the picture? "
61- let sequenceLength = 768
102+
103+ func testGemma( ) {
104+ let chatTemplate = " <start_of_turn>user \n %@<end_of_turn> \n <start_of_turn>model "
105+ let sideSize : CGFloat = 896
106+ let sequenceLength = 768
107+ let bundle = Bundle ( for: type ( of: self ) )
108+ guard let modelPath = bundle. path ( forResource: " gemma3 " , ofType: " pte " ) ,
109+ let tokenizerPath = bundle. path ( forResource: " gemma3_tokenizer " , ofType: " model " ) ,
110+ let imagePath = bundle. path ( forResource: " IMG_0005 " , ofType: " jpg " ) ,
111+ let uiImage = UIImage ( contentsOfFile: imagePath) else {
112+ XCTFail ( " Couldn't find model or tokenizer files " )
113+ return
114+ }
115+ let runner = MultimodalRunner ( modelPath: modelPath, tokenizerPath: tokenizerPath)
116+ var text = " "
117+
118+ do {
119+ try runner. generate ( [
120+ MultimodalInput ( systemPrompt) ,
121+ MultimodalInput ( uiImage. asNormalizedImage ( sideSize) ) ,
122+ MultimodalInput ( String ( format: chatTemplate, userPrompt) ) ,
123+ ] , Config {
124+ $0. sequenceLength = sequenceLength
125+ } ) { token in
126+ text += token
127+ if token == " <end_of_turn> " {
128+ runner. stop ( )
129+ }
130+ }
131+ } catch {
132+ XCTFail ( " Failed to generate text with error \( error) " )
133+ }
134+ XCTAssertTrue ( text. lowercased ( ) . contains ( " waterfall " ) )
135+
136+ text = " "
137+ runner. reset ( )
138+ do {
139+ try runner. generate ( [
140+ MultimodalInput ( systemPrompt) ,
141+ MultimodalInput ( uiImage. asNormalizedImage ( sideSize) ) ,
142+ MultimodalInput ( String ( format: chatTemplate, userPrompt) ) ,
143+ ] , Config {
144+ $0. sequenceLength = sequenceLength
145+ } ) { token in
146+ text += token
147+ if token == " <end_of_turn> " {
148+ runner. stop ( )
149+ }
150+ }
151+ } catch {
152+ XCTFail ( " Failed to generate text with error \( error) " )
153+ }
154+ XCTAssertTrue ( text. lowercased ( ) . contains ( " waterfall " ) )
155+ }
62156
63157 func testLLaVA( ) {
158+ let chatTemplate = " USER: %@ ASSISTANT: "
159+ let sideSize : CGFloat = 336
160+ let sequenceLength = 768
64161 let bundle = Bundle ( for: type ( of: self ) )
65162 guard let modelPath = bundle. path ( forResource: " llava " , ofType: " pte " ) ,
66- let tokenizerPath = bundle. path ( forResource: " tokenizer " , ofType: " bin " ) ,
163+ let tokenizerPath = bundle. path ( forResource: " llava_tokenizer " , ofType: " bin " ) ,
67164 let imagePath = bundle. path ( forResource: " IMG_0005 " , ofType: " jpg " ) ,
68165 let uiImage = UIImage ( contentsOfFile: imagePath) else {
69166 XCTFail ( " Couldn't find model or tokenizer files " )
@@ -75,8 +172,8 @@ class MultimodalRunnerTest: XCTestCase {
75172 do {
76173 try runner. generate ( [
77174 MultimodalInput ( systemPrompt) ,
78- MultimodalInput ( uiImage. asImage ( ) ) ,
79- MultimodalInput ( " \( userPrompt) \( assistantPrompt ) " ) ,
175+ MultimodalInput ( uiImage. asImage ( sideSize ) ) ,
176+ MultimodalInput ( String ( format : chatTemplate , userPrompt) ) ,
80177 ] , Config {
81178 $0. sequenceLength = sequenceLength
82179 } ) { token in
@@ -92,8 +189,8 @@ class MultimodalRunnerTest: XCTestCase {
92189 do {
93190 try runner. generate ( [
94191 MultimodalInput ( systemPrompt) ,
95- MultimodalInput ( uiImage. asImage ( ) ) ,
96- MultimodalInput ( " \( userPrompt) \( assistantPrompt ) " ) ,
192+ MultimodalInput ( uiImage. asImage ( sideSize ) ) ,
193+ MultimodalInput ( String ( format : chatTemplate , userPrompt) ) ,
97194 ] , Config {
98195 $0. sequenceLength = sequenceLength
99196 } ) { token in
0 commit comments