Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/axolotl/integrations/liger/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,12 @@ def check_deprecated_swiglu(cls, data):
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data

@model_validator(mode="before")
@classmethod
def check_tiled_mlp_conflict(cls, data):
if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
raise ValueError(
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
)
return data
5 changes: 5 additions & 0 deletions src/axolotl/loaders/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,9 @@ def _load_mistral_common_tokenizer(cfg: DictDefault):
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)

# make the tokenizer.pad call quieter 🤐
if hasattr(tokenizer, "deprecation_warnings"):
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

return tokenizer
7 changes: 7 additions & 0 deletions src/axolotl/monkeypatch/tiled_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import torch.distributed as dist

from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
Expand Down Expand Up @@ -63,6 +66,10 @@ def tiled_mlp_forward(self, x):

mlp_cls.forward = tiled_mlp_forward
mlp_cls._compute_params = [] # pylint: disable=protected-access
LOG.info(
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
main_process_only=True,
)
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import MLP class for model_type: {model_type}. "
Expand Down
7 changes: 4 additions & 3 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# pylint: disable=too-many-boolean-expressions

import json
import logging
import tempfile
from pathlib import Path

Expand All @@ -13,11 +12,12 @@
)
from transformers.utils.import_utils import is_torch_npu_available

from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType

# pylint: disable=too-many-lines

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)

SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}

Expand Down Expand Up @@ -116,7 +116,8 @@ def check_eval_packing(cls, data):
and not data.get("eval_table_size")
):
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
"explicitly setting `eval_sample_packing` to match `sample_packing`",
main_process_only=True,
Comment thread
winglian marked this conversation as resolved.
)
data["eval_sample_packing"] = True

Expand Down