diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index dd75747fd73c..524a7fb62368 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -60,10 +60,14 @@ def __init__(self, cfg): else: kv_channels = cfg.kv_channels projection_size = kv_channels * cfg.num_attention_heads + num_query_groups = cfg.get("num_query_groups", None) + if num_query_groups is None: + num_query_groups = cfg.num_attention_heads + qkv_projection_size = projection_size + (2 * kv_channels * num_query_groups) config_args = { "in_features": cfg.hidden_size, - "out_features": 3 * projection_size, + "out_features": qkv_projection_size, "dim": lora_cfg.adapter_dim, "norm_position": None, "norm_type": None,