diff --git a/Transformer/PythonCheckpointReader.swift b/Transformer/PythonCheckpointReader.swift
index 08fc30795fe..2a175c10f36 100644
--- a/Transformer/PythonCheckpointReader.swift
+++ b/Transformer/PythonCheckpointReader.swift
@@ -14,32 +14,22 @@
 
 import TensorFlow
 
-struct Config {
+struct Config : Codable {
     let vocabSize: Int
     let contextSize: Int
    let embeddingSize: Int
     let headCount: Int
     let layerCount: Int
-}
 
-extension Config {
-    init(dictionary: [String: Int]) {
-        vocabSize = dictionary["n_vocab"]!
-        contextSize = dictionary["n_ctx"]!
-        embeddingSize = dictionary["n_embd"]!
-        headCount = dictionary["n_head"]!
-        layerCount = dictionary["n_layer"]!
+    enum CodingKeys: String, CodingKey {
+        case vocabSize = "n_vocab"
+        case contextSize = "n_ctx"
+        case embeddingSize = "n_embd"
+        case headCount = "n_head"
+        case layerCount = "n_layer"
     }
 }
 
-let config = Config(dictionary: [
-    "n_vocab": 50257,
-    "n_ctx": 1024,
-    "n_embd": 768,
-    "n_head": 12,
-    "n_layer": 12
-])
-
 func readTensor(
     fromPath path: String,
     name: String,
@@ -55,18 +45,23 @@ func readTensor(
 }
 
 protocol InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String)
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String)
 }
 
 extension Dense: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
         self.init(
             weight: kernel.squeezingShape(at: 0),
             bias: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
             activation: identity)
     }
-    init(contentsOfPythonCheckpointFile path: String, scope: String, activation: String) {
+    init(
+        contentsOfPythonCheckpointFile path: String,
+        config: Config,
+        scope: String,
+        activation: String
+    ) {
         let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
         self.init(
             weight: kernel.squeezingShape(at: 0),
@@ -76,7 +71,7 @@ extension Dense: InitializableFromPythonCheckpoint {
 }
 
 extension LayerNorm: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         self.init(
             offset: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
             scale: readTensor(fromPath: path, name: scope + "/g", scalarType: Scalar.self),
@@ -86,48 +81,55 @@ extension LayerNorm: InitializableFromPythonCheckpoint {
 }
 
 extension MultiHeadAttention: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         attention = Attention(
             size: config.embeddingSize / config.headCount,
             causal: true,
             dropProbability: 0.2)
         wqkv = TimeDistributed(Dense(
             contentsOfPythonCheckpointFile: path,
+            config: config,
             scope: scope + "/c_attn"))
         wo = TimeDistributed(Dense(
             contentsOfPythonCheckpointFile: path,
+            config: config,
             scope: scope + "/c_proj"))
-        headCount = 12
+        headCount = config.headCount
     }
 }
 
 extension FeedForward: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         dense1 = TimeDistributed(Dense(
             contentsOfPythonCheckpointFile: path,
-            scope: scope + "/c_fc", activation: "gelu"))
+            config: config,
+            scope: scope + "/c_fc",
+            activation: "gelu"))
         dense2 = TimeDistributed(Dense(
             contentsOfPythonCheckpointFile: path,
+            config: config,
             scope: scope + "/c_proj"))
         dropout = Dropout(probability: 0.2)
     }
 }
 
 extension EncoderLayer: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         selfAttention = MultiHeadAttention(
-            contentsOfPythonCheckpointFile: path,
-            scope: scope + "/attn")
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/attn")
         selfAttentionDropout = Dropout(probability: 0.2)
-        selfAttentionNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_1")
-        feedForward = FeedForward(contentsOfPythonCheckpointFile: path, scope: scope + "/mlp")
+        selfAttentionNorm = LayerNorm(
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_1")
+        feedForward = FeedForward(
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/mlp")
         feedForwardDropout = Dropout(probability: 0.2)
-        feedForwardNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_2")
+        feedForwardNorm = LayerNorm(
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_2")
     }
 }
 
 extension TransformerLM: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         embedding = Embedding(
             weight: readTensor(fromPath: path, name: scope + "/wte", scalarType: Float.self))
         positionalEmbeddings = readTensor(
@@ -135,8 +137,10 @@ extension TransformerLM: InitializableFromPythonCheckpoint {
             name: scope + "/wpe",
             scalarType: Float.self)
         layers = (0.."])!
 var tokens = Tensor(shape: [1, 1], scalars: [start_token])