
Commit d96946e

bug fixes
VarunGumma committed Jun 20, 2024
1 parent c96ae13 commit d96946e
Showing 4 changed files with 15 additions and 12 deletions.
6 changes: 5 additions & 1 deletion fairseq/models/transformer/transformer_decoder.py
@@ -27,6 +27,8 @@
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 # rewrite name for backward compatibility in `make_generation_fast_`
 def module_name_fordropout(module_name: str) -> str:
@@ -135,7 +137,9 @@ def __init__(
         self.num_layers = len(self.layers)
 
         if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
-            self.layer_norm = self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)
+            self.layer_norm = self.normalization(
+                self.embed_dim, rms=cfg.decoder.use_rmsnorm
+            )
         else:
             self.layer_norm = None
 
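Note: both hunks in this file route layer construction through a `normalization(dim, rms=...)` helper that is not shown in the diff. A minimal sketch of what such a factory could look like, assuming it simply dispatches between a standard LayerNorm and an RMSNorm variant (the RMSNorm class below is illustrative, not fairseq's implementation):

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        """Scale by the root-mean-square of the features; no mean-centering, no bias."""
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            # x / sqrt(mean(x^2) + eps), then a learned per-feature scale.
            rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
            return x * rms * self.weight

    def normalization(dim: int, rms: bool = False) -> nn.Module:
        # Hypothetical stand-in for the decoder's `self.normalization` helper.
        return RMSNorm(dim) if rms else nn.LayerNorm(dim)

Under that assumption, the reformatted call above simply returns whichever norm `cfg.decoder.use_rmsnorm` selects.
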
6 changes: 4 additions & 2 deletions fairseq/models/transformer/transformer_encoder.py
@@ -25,6 +25,8 @@
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 # rewrite name for backward compatibility in `make_generation_fast_`
 def module_name_fordropout(module_name: str) -> str:
@@ -76,8 +78,8 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
         )
 
         self.layernorm_embedding = (
-            self.normalization(embed_dim, rms=cfg.encoder.use_rmsnorm)
-            if cfg.layernorm_embedding
+            self.normalization(embed_dim, rms=cfg.encoder.use_rmsnorm)
+            if cfg.layernorm_embedding
             else None
         )
 
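Note: both the encoder and decoder now define a module-level `device` constant. It is resolved once, at import time, so everything in the module shares the same choice; a tiny sketch of the pattern (the tensor below is only illustrative, not from this commit):

    import torch

    # Chosen once, when the module is first imported.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Anything allocated with it lands on the GPU when one is visible, else on the CPU.
    positions = torch.arange(10, device=device)
    print(positions.device)
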
7 changes: 3 additions & 4 deletions fairseq/models/transformer/transformer_legacy.py
@@ -301,17 +301,16 @@ def _transformer_base18L(args):
     args.activation_fn = getattr(args, "activation_fn", "gelu")
     args.encoder_layers = getattr(args, "encoder_layers", 18)
     args.decoder_layers = getattr(args, "decoder_layers", 18)
-    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
+    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
     base_architecture(args)
 
 
 @register_model_architecture("transformer", "transformer_IT2_dist")
 def transformer_base18L(args):
-    args.share_decoder_input_output_embed = getattr(
-        args, "share_decoder_input_output_embed", True
-    )
+    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
+    args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", True)
     _transformer_base18L(args)
 
 
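Note: these architecture helpers rely on `getattr(args, name, default)` only filling in a value when the flag is still unset, so the order of assignments decides which default wins: the shared base now defaults `layernorm_embedding` to False, while the "transformer_IT2_dist" variant re-enables it before delegating. A stripped-down sketch keeping only that flag (function bodies reduced purely for illustration):

    from argparse import Namespace

    def _transformer_base18L(args):
        # Fills in False only when the caller has not already set the flag.
        args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
        return args

    def transformer_base18L(args):  # registered as "transformer_IT2_dist"
        # Sets True first, so the base default of False never applies.
        args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
        return _transformer_base18L(args)

    print(_transformer_base18L(Namespace()).layernorm_embedding)  # False
    print(transformer_base18L(Namespace()).layernorm_embedding)   # True
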
8 changes: 3 additions & 5 deletions fairseq_cli/train.py
@@ -45,10 +45,9 @@
 
 
 # copied from: https://github.com/microsoft/LoRA/blob/main/loralib/utils.py
-def mark_only_lora_as_trainable(model: torch.nn.Module, bias: str = "none") -> None:
+def mark_only_lora_as_trainable(model, bias) -> None:
     for n, p in model.named_parameters():
-        if "lora_" not in n:
-            p.requires_grad = False
+        p.requires_grad = "lora_" in n
     if bias == "none":
         return
     elif bias == "all":
@@ -226,11 +225,10 @@ def main(cfg: FairseqConfig) -> None:
         "merge_weights": True,
     }
 
-    lora_modules = lora_config["target_modules"].split(",")
+    lora_modules = set(lora_config["target_modules"].split(","))
     lora_bias = lora_config.get("bias", "none")
     replace_with_lora(model, lora_modules, lora_params)
     mark_only_lora_as_trainable(model, bias=lora_bias)
-    print(model)
     ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ###
 
     logger.info(
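
Note: the old helper only set `requires_grad = False` on non-LoRA parameters and left the LoRA parameters at whatever flag they already had; the patched one-liner makes the name-based rule explicit in both directions. A self-contained sketch of that core rule, using a made-up ToyLoRALinear module purely for illustration (the real helper's bias handling is omitted):

    import torch
    import torch.nn as nn

    class ToyLoRALinear(nn.Module):
        """Toy layer: a base weight plus low-rank `lora_` factors, names chosen to match the rule."""
        def __init__(self, dim: int, rank: int = 4):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(dim, dim))
            self.lora_A = nn.Parameter(torch.zeros(rank, dim))
            self.lora_B = nn.Parameter(torch.zeros(dim, rank))

    def mark_only_lora_as_trainable(model, bias) -> None:
        # Same core rule as the patched helper: trainable iff the parameter name contains "lora_".
        for n, p in model.named_parameters():
            p.requires_grad = "lora_" in n
        # The "all" / "lora_only" bias branches of the original helper are not reproduced here.

    model = nn.Sequential(ToyLoRALinear(8), ToyLoRALinear(8))
    mark_only_lora_as_trainable(model, bias="none")
    print(sorted(n for n, p in model.named_parameters() if p.requires_grad))
    # -> ['0.lora_A', '0.lora_B', '1.lora_A', '1.lora_B']

The switch to `set(lora_config["target_modules"].split(","))` presumably just deduplicates repeated module names and makes membership checks cheap when `replace_with_lora` walks the model.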
