diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index 09cc7504224..40044db7428 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -37,7 +37,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
         self.dim = config.dim
         self.n_heads = config.n_heads
-        self.head_dim = config.dim // config.n_heads
+        self.head_dim = config.head_dim
         self.n_kv_heads = config.n_kv_heads
         self.num_key_value_groups = config.n_heads // self.n_kv_heads
         self.max_seq_len = config.max_seq_len
@@ -304,7 +304,7 @@ def __init__(
     ):
         super().__init__()
         self.dim = config.dim
-        self.head_dim = config.dim // config.n_heads
+        self.head_dim = config.head_dim
         self.max_batch_size = config.max_batch_size
         self.max_seq_len = config.max_seq_len
         self.n_heads = config.n_heads
@@ -328,9 +328,11 @@ def __init__(
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
         freqs_cos, freqs_sin = precompute_freqs_cis(
-            config.dim // config.n_heads,
+            config.head_dim,
             config.max_seq_len,
             config.rope_freq_base,
+            config.use_scaled_rope,
+            config.rope_scale_factor,
         )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
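
Note on the change: reading head_dim from the config instead of computing dim // n_heads supports checkpoints where the per-head width is set independently of the embedding width, and the two extra positional arguments thread scaled-RoPE settings into the frequency table. Below is a minimal sketch of the kind of precomputation these arguments feed; it is not the repo's precompute_freqs_cis, and the parameter names (use_scaled, scale_factor) only mirror the config fields in the hunk above for illustration.

    import torch

    def precompute_freqs_cis_sketch(
        head_dim: int,
        max_seq_len: int,
        theta: float = 10000.0,
        use_scaled: bool = False,
        scale_factor: float = 8.0,
    ):
        # Standard RoPE: one inverse frequency per pair of channels.
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        if use_scaled:
            # Simplest position-interpolation variant: stretch every
            # wavelength by scale_factor (equivalent to dividing positions).
            # Llama 3.x uses a smoother wavelength-dependent rule; this
            # only illustrates where the new config knobs take effect.
            freqs = freqs / scale_factor
        t = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(t, freqs)  # (max_seq_len, head_dim // 2)
        return torch.cos(angles), torch.sin(angles)

Under this sketch, the cos/sin tables keep the same shapes as before, so registering them as non-persistent buffers (as the hunk does) is unaffected by whether scaling is enabled.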