From d96946e65694aa78b397d14dd895c9a651704a8b Mon Sep 17 00:00:00 2001
From: VarunGumma
Date: Thu, 20 Jun 2024 07:42:54 +0000
Subject: [PATCH] bug fixes

---
 fairseq/models/transformer/transformer_decoder.py | 6 +++++-
 fairseq/models/transformer/transformer_encoder.py | 6 ++++--
 fairseq/models/transformer/transformer_legacy.py  | 7 +++----
 fairseq_cli/train.py                              | 8 +++-----
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py
index 890bf48cad..5129cb7277 100755
--- a/fairseq/models/transformer/transformer_decoder.py
+++ b/fairseq/models/transformer/transformer_decoder.py
@@ -27,6 +27,8 @@
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 # rewrite name for backward compatibility in `make_generation_fast_`
 def module_name_fordropout(module_name: str) -> str:
@@ -135,7 +137,9 @@ def __init__(
         self.num_layers = len(self.layers)
 
         if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
-            self.layer_norm = self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)
+            self.layer_norm = self.normalization(
+                self.embed_dim, rms=cfg.decoder.use_rmsnorm
+            )
         else:
             self.layer_norm = None
 
diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py
index ab5cd0c437..0a7445995c 100755
--- a/fairseq/models/transformer/transformer_encoder.py
+++ b/fairseq/models/transformer/transformer_encoder.py
@@ -25,6 +25,8 @@
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 # rewrite name for backward compatibility in `make_generation_fast_`
 def module_name_fordropout(module_name: str) -> str:
@@ -76,8 +78,8 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
         )
 
         self.layernorm_embedding = (
-            self.normalization(embed_dim, rms=cfg.encoder.use_rmsnorm) 
-            if cfg.layernorm_embedding 
+            self.normalization(embed_dim, rms=cfg.encoder.use_rmsnorm)
+            if cfg.layernorm_embedding
             else None
         )
 
diff --git a/fairseq/models/transformer/transformer_legacy.py b/fairseq/models/transformer/transformer_legacy.py
index 7b44101d30..eb54e577d2 100755
--- a/fairseq/models/transformer/transformer_legacy.py
+++ b/fairseq/models/transformer/transformer_legacy.py
@@ -301,7 +301,7 @@ def _transformer_base18L(args):
     args.activation_fn = getattr(args, "activation_fn", "gelu")
     args.encoder_layers = getattr(args, "encoder_layers", 18)
     args.decoder_layers = getattr(args, "decoder_layers", 18)
-    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
+    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
     base_architecture(args)
@@ -309,9 +309,8 @@ def _transformer_base18L(args):
 
 @register_model_architecture("transformer", "transformer_IT2_dist")
 def transformer_base18L(args):
-    args.share_decoder_input_output_embed = getattr(
-        args, "share_decoder_input_output_embed", True
-    )
+    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
+    args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", True)
     _transformer_base18L(args)
 
 
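Note on the transformer_legacy.py hunks above: fairseq architecture functions fill in defaults with getattr(args, name, default), so any value already present on args (from the command line or from an outer architecture function) always wins. The snippet below is a minimal, standalone sketch of that precedence, using a plain argparse.Namespace instead of fairseq's real args object; _base and _dist_variant are illustrative stand-ins for _transformer_base18L and transformer_IT2_dist.

from argparse import Namespace

def _base(args):
    # mirrors _transformer_base18L after this patch: the base default is now False
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)

def _dist_variant(args):
    # mirrors transformer_IT2_dist: pin its own default *before* delegating
    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
    _base(args)

args = Namespace()                           # nothing set by the user
_dist_variant(args)
print(args.layernorm_embedding)              # True: the variant's default survives the base call

args = Namespace(layernorm_embedding=False)  # explicit user override
_dist_variant(args)
print(args.layernorm_embedding)              # False: getattr never clobbers an existing value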
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index 10563ee2b6..6941cc4935 100755
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -45,10 +45,9 @@
 
 
 # copied from: https://github.com/microsoft/LoRA/blob/main/loralib/utils.py
-def mark_only_lora_as_trainable(model: torch.nn.Module, bias: str = "none") -> None:
+def mark_only_lora_as_trainable(model, bias) -> None:
     for n, p in model.named_parameters():
-        if "lora_" not in n:
-            p.requires_grad = False
+        p.requires_grad = "lora_" in n
     if bias == "none":
         return
     elif bias == "all":
@@ -226,11 +225,10 @@ def main(cfg: FairseqConfig) -> None:
             "merge_weights": True,
         }
 
-        lora_modules = lora_config["target_modules"].split(",")
+        lora_modules = set(lora_config["target_modules"].split(","))
         lora_bias = lora_config.get("bias", "none")
         replace_with_lora(model, lora_modules, lora_params)
         mark_only_lora_as_trainable(model, bias=lora_bias)
-        print(model)
 
     ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ###
     logger.info(
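For reference, the train.py hunk rewrites the freezing rule in the helper copied from microsoft/LoRA: after the change, a parameter's requires_grad is set directly from whether its name contains "lora_", so base weights are frozen and LoRA factors are (re)enabled in one pass. Below is a minimal, self-contained check of that rule; ToyLoRALinear and its lora_A/lora_B parameters are hypothetical stand-ins for whatever replace_with_lora actually injects, and the helper is re-sketched with only the "none" and "all" bias modes (the "lora_only" mode of the original loralib utility is omitted).

import torch

def mark_only_lora_as_trainable(model, bias="none"):
    # core rule from the patched helper: a parameter trains iff "lora_" is in its name
    for n, p in model.named_parameters():
        p.requires_grad = "lora_" in n
    if bias == "all":
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True

class ToyLoRALinear(torch.nn.Module):
    # hypothetical stand-in for a module produced by replace_with_lora
    def __init__(self, d=4, r=2):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(d, d))
        self.bias = torch.nn.Parameter(torch.zeros(d))
        self.lora_A = torch.nn.Parameter(torch.randn(r, d))
        self.lora_B = torch.nn.Parameter(torch.zeros(d, r))

model = ToyLoRALinear()
mark_only_lora_as_trainable(model, bias="none")
print([n for n, p in model.named_parameters() if p.requires_grad])
# ['lora_A', 'lora_B'] -- the base weight and bias stay frozen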