diff --git a/.gitignore b/.gitignore index 0a63ae1..262e324 100644 --- a/.gitignore +++ b/.gitignore @@ -136,6 +136,10 @@ dmypy.json # Xcode *.xcworkspace +xcuserdata/ +*.xcuserstate +*.xcuserdatad/ +project.pbxproj.orig # FastVLM models app/FastVLM/model \ No newline at end of file diff --git a/app/FastVLM/FastVLM.swift b/app/FastVLM/FastVLM.swift index 731b0ef..dc643b1 100644 --- a/app/FastVLM/FastVLM.swift +++ b/app/FastVLM/FastVLM.swift @@ -213,10 +213,20 @@ private enum Language { fatalError("one of inputs or inputEmbedding must be non-nil") } - let mask = createAttentionMask(h: h, cache: cache) + let mask = createAttentionMask(h: h, cache: cache, returnArray: true) + let maskArray: MLXArray? = { + switch mask { + case .array(let array): + return array + case .causal, .none: + return nil + case .arrays(_): + return nil + } + }() for (i, layer) in layers.enumerated() { - h = layer(h, mask: mask, cache: cache?[i]) + h = layer(h, mask: maskArray, cache: cache?[i]) } return norm(h) @@ -361,13 +371,21 @@ public class FastVLMProcessor: UserInputProcessor { } public func prepare(prompt: UserInput.Prompt, imageTHW: THW?) -> String { - var messages = prompt.asMessages() - if messages[0]["role"] != "system" { + var messages: [Message] + switch prompt { + case .text(let text): + messages = [["role": "user", "content": text]] + case .messages(let msgs): + messages = msgs + case .chat(let chatMsgs): + messages = chatMsgs.map { ["role": $0.role.rawValue, "content": $0.content] } + } + if (messages[0]["role"] as? String) != "system" { messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0) } let lastIndex = messages.count - 1 - var lastMessage = messages[lastIndex]["content"] ?? "" + var lastMessage: String = (messages[lastIndex]["content"] as? String) ?? "" // processing_llava.py if let imageTHW { @@ -382,8 +400,8 @@ public class FastVLMProcessor: UserInputProcessor { numImageTokens -= 1 } - lastMessage += Array(repeating: config.imageToken, count: numImageTokens) - .joined() + let imageTokens = String(repeating: config.imageToken, count: numImageTokens) + lastMessage += imageTokens } messages[lastIndex]["content"] = lastMessage @@ -411,7 +429,7 @@ public class FastVLMProcessor: UserInputProcessor { let (pixels, thw) = try preprocess( image: input.images[0].asCIImage(), processing: input.processing) - let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: [thw]) + let image = LMInput.ProcessedImage(pixels: pixels, frames: [thw]) let prompt = prepare(prompt: input.prompt, imageTHW: thw) let promptTokens = tokenizer.encode(text: prompt) @@ -537,7 +555,7 @@ public class FastVLM: Module, VLMModel, KVCacheDimensionProvider { public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws -> PrepareResult { - let gridThw = input.image?.imageGridThw + let gridThw = input.image?.frames let dtype = DType.float32 let pixels = input.image?.pixels.asType(dtype)