@@ -11,47 +11,89 @@ import ExecuTorchLLM
1111import XCTest
1212
extension UIImage {
  /// Returns a square copy of the image, scaled with aspect-fill and center-cropped
  /// to `sideSize` × `sideSize` points (rendered at scale 1, so points == pixels).
  /// - Parameter sideSize: Side length of the square output; must be positive.
  func centerCropped(to sideSize: CGFloat) -> UIImage {
    precondition(sideSize > 0)
    let format = UIGraphicsImageRendererFormat.default()
    // Force a 1:1 point-to-pixel mapping so the produced bitmap is exactly
    // sideSize × sideSize pixels regardless of the device's screen scale.
    format.scale = 1
    format.opaque = false
    return UIGraphicsImageRenderer(size: CGSize(width: sideSize, height: sideSize), format: format)
      .image { _ in
        // Aspect-fill: scale so the smaller dimension covers the square,
        // then center the overflow so the crop is symmetric.
        let scaleFactor = max(sideSize / size.width, sideSize / size.height)
        let scaledWidth = size.width * scaleFactor
        let scaledHeight = size.height * scaleFactor
        let originX = (sideSize - scaledWidth) / 2
        let originY = (sideSize - scaledHeight) / 2
        draw(in: CGRect(x: originX, y: originY, width: scaledWidth, height: scaledHeight))
      }
  }

  /// Renders the image into an RGBA bitmap and returns its pixels as planar RGB
  /// (all red bytes, then all green, then all blue), dropping alpha.
  /// - Returns: `width * height * 3` bytes in planar channel order, or `nil` if the
  ///   image has no `cgImage` backing or the bitmap context cannot be created.
  func rgbBytes() -> [UInt8]? {
    guard let cgImage = cgImage else { return nil }
    let pixelWidth = cgImage.width
    let pixelHeight = cgImage.height
    let pixelCount = pixelWidth * pixelHeight
    let bytesPerPixel = 4
    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * bytesPerPixel)
    // Create the context and draw inside withUnsafeMutableBytes: a pointer
    // obtained via `&rgbaBuffer` is only guaranteed valid for the duration of
    // the call it is passed to, but CGContext keeps the pointer and writes to
    // it on every draw. Pinning the buffer for the whole draw avoids that
    // dangling-pointer hazard.
    let didDraw = rgbaBuffer.withUnsafeMutableBytes { rawBuffer -> Bool in
      guard let context = CGContext(
        data: rawBuffer.baseAddress,
        width: pixelWidth,
        height: pixelHeight,
        bitsPerComponent: 8,
        bytesPerRow: pixelWidth * bytesPerPixel,
        space: CGColorSpaceCreateDeviceRGB(),
        bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
      ) else { return false }
      context.draw(cgImage, in: CGRect(x: 0, y: 0, width: pixelWidth, height: pixelHeight))
      return true
    }
    guard didDraw else { return nil }
    // De-interleave RGBA → planar RGB, discarding the alpha byte.
    var rgbBytes = [UInt8](repeating: 0, count: pixelCount * 3)
    for pixelIndex in 0..<pixelCount {
      let sourceIndex = pixelIndex * bytesPerPixel
      rgbBytes[pixelIndex] = rgbaBuffer[sourceIndex]
      rgbBytes[pixelIndex + pixelCount] = rgbaBuffer[sourceIndex + 1]
      rgbBytes[pixelIndex + 2 * pixelCount] = rgbaBuffer[sourceIndex + 2]
    }
    return rgbBytes
  }

  /// Returns the planar RGB pixels scaled to [0, 1] and normalized per channel:
  /// `(value / 255 - mean[c]) / std[c]`.
  /// - Parameters:
  ///   - mean: Per-channel means (R, G, B); must contain exactly 3 values.
  ///   - std: Per-channel standard deviations (R, G, B); must contain exactly 3
  ///     non-zero values.
  /// - Returns: `width * height * 3` floats in planar channel order, or `nil` if
  ///   `rgbBytes()` fails.
  func rgbBytesNormalized(mean: [Float] = [0, 0, 0], std: [Float] = [1, 1, 1]) -> [Float]? {
    precondition(mean.count == 3 && std.count == 3)
    precondition(std[0] != 0 && std[1] != 0 && std[2] != 0)
    guard let rgbBytes = rgbBytes() else { return nil }
    let pixelCount = rgbBytes.count / 3
    var rgbBytesNormalized = [Float](repeating: 0, count: pixelCount * 3)
    for pixelIndex in 0..<pixelCount {
      rgbBytesNormalized[pixelIndex] =
        (Float(rgbBytes[pixelIndex]) / 255.0 - mean[0]) / std[0]
      rgbBytesNormalized[pixelIndex + pixelCount] =
        (Float(rgbBytes[pixelIndex + pixelCount]) / 255.0 - mean[1]) / std[1]
      rgbBytesNormalized[pixelIndex + 2 * pixelCount] =
        (Float(rgbBytes[pixelIndex + 2 * pixelCount]) / 255.0 - mean[2]) / std[2]
    }
    return rgbBytesNormalized
  }

  /// Center-crops to `sideSize` and wraps the planar RGB bytes in an
  /// ExecuTorchLLM `Image` (3 channels, UInt8).
  /// NOTE(review): if pixel extraction fails this still reports
  /// width/height == sideSize with empty data — presumably the runner rejects
  /// it downstream; verify against `MultimodalInput` handling.
  func asImage(_ sideSize: CGFloat) -> Image {
    return Image(
      data: Data(centerCropped(to: sideSize).rgbBytes() ?? []),
      width: Int(sideSize),
      height: Int(sideSize),
      channels: 3
    )
  }

  /// Center-crops to `sideSize`, normalizes with the given per-channel
  /// mean/std (defaults are the common ImageNet statistics), and wraps the
  /// planar float pixels in an ExecuTorchLLM `Image` (3 channels, Float).
  func asNormalizedImage(
    _ sideSize: CGFloat,
    mean: [Float] = [0.485, 0.456, 0.406],
    std: [Float] = [0.229, 0.224, 0.225]
  ) -> Image {
    return Image(
      float: (centerCropped(to: sideSize).rgbBytesNormalized(mean: mean, std: std) ?? []).withUnsafeBufferPointer { Data(buffer: $0) },
      width: Int(sideSize),
      height: Int(sideSize),
      channels: 3
    )
  }
}
5799
@@ -120,7 +162,7 @@ class MultimodalRunnerTest: XCTestCase {
120162 let sequenceLength = 768
121163 let bundle = Bundle ( for: type ( of: self ) )
122164 guard let modelPath = bundle. path ( forResource: " llava " , ofType: " pte " ) ,
123- let tokenizerPath = bundle. path ( forResource: " tokenizer " , ofType: " bin " ) ,
165+ let tokenizerPath = bundle. path ( forResource: " llava_tokenizer " , ofType: " bin " ) ,
124166 let imagePath = bundle. path ( forResource: " IMG_0005 " , ofType: " jpg " ) ,
125167 let uiImage = UIImage ( contentsOfFile: imagePath) else {
126168 XCTFail ( " Couldn't find model or tokenizer files " )
@@ -132,8 +174,8 @@ class MultimodalRunnerTest: XCTestCase {
132174 do {
133175 try runner. generate ( [
134176 MultimodalInput ( systemPrompt) ,
135- MultimodalInput ( uiImage. asImage ( ) ) ,
136- MultimodalInput ( " \( userPrompt) \( assistantPrompt ) " ) ,
177+ MultimodalInput ( uiImage. asImage ( sideSize ) ) ,
178+ MultimodalInput ( String ( format : chatTemplate , userPrompt) ) ,
137179 ] , Config {
138180 $0. sequenceLength = sequenceLength
139181 } ) { token in
@@ -149,8 +191,8 @@ class MultimodalRunnerTest: XCTestCase {
149191 do {
150192 try runner. generate ( [
151193 MultimodalInput ( systemPrompt) ,
152- MultimodalInput ( uiImage. asImage ( ) ) ,
153- MultimodalInput ( " \( userPrompt) \( assistantPrompt ) " ) ,
194+ MultimodalInput ( uiImage. asImage ( sideSize ) ) ,
195+ MultimodalInput ( String ( format : chatTemplate , userPrompt) ) ,
154196 ] , Config {
155197 $0. sequenceLength = sequenceLength
156198 } ) { token in
0 commit comments