Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4

# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/

- name: Setup Python
uses: actions/setup-python@v5
with:
Expand Down Expand Up @@ -111,6 +112,9 @@ jobs:
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures

- name: Show HF cache
run: hf cache scan

- name: Run tests
run: |
df -h
Expand All @@ -122,6 +126,9 @@ jobs:
df -h
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml

- name: Show HF cache
run: hf cache scan

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
Expand Down Expand Up @@ -149,12 +156,13 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4

# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/

- name: Setup Python
uses: actions/setup-python@v5
with:
Expand Down Expand Up @@ -200,6 +208,9 @@ jobs:
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/

- name: Show HF cache
run: hf cache scan

gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion cicd/multigpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -e

# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 \
pytest -v --durations=10 -n2 --maxfail=4 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ huggingface_hub>=0.36.0
peft>=0.18.0
tokenizers>=0.22.1
transformers==4.57.1
accelerate==1.11.0
datasets==4.4.1
deepspeed>=0.17.0
trl==0.25.0
accelerate==1.12.0
datasets==4.4.2
deepspeed>=0.18.3
trl==0.25.1
hf_xet==1.2.0
kernels>=0.9.0
kernels==0.11.5
trackio>=0.13.0
typing_extensions>=4.14.0

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_package_version():
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]",
"ray[train]>=2.52.1",
],
"vllm": [
"vllm==0.10.0",
Expand Down
5 changes: 2 additions & 3 deletions src/axolotl/core/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def compute_loss(
inputs_key = "labels" if "labels" in inputs else "input_ids"
trainable_tokens = (inputs[inputs_key] != -100).sum()
total_tokens = inputs[inputs_key].numel()
total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)

if is_distributed():
torch.distributed.all_reduce(
Expand All @@ -375,9 +376,7 @@ def compute_loss(
self.state.tokens["trainable"] = (
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
)
self.state.tokens["total"] = (
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
)
self.state.tokens["total"] = self.state.tokens["total"] + total_tokens.cpu()
# Store per-step trainable tokens for throughput calculation
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()

Expand Down
30 changes: 30 additions & 0 deletions src/axolotl/monkeypatch/accelerate/parallelism_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,33 @@ def patch_parallelism_config():

ParallelismConfig._validate_accelerator = _validate_accelerator
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)


def patch_prepare_cp():
    """Install a replacement for ``Accelerator._prepare_cp``.

    The replacement returns its arguments untouched when the parallelism
    config reports a ``"deepspeed"`` context-parallel backend. Otherwise it
    configures torch's context-parallel rotate method, stores a
    ``context_parallel`` partial bound to the ``"cp"`` device mesh on the
    accelerator, and attaches context-parallel hooks to every
    ``torch.nn.Module`` it was given.
    """
    import functools

    import torch
    from accelerate import Accelerator

    def patched_prepare_cp(self, *args):
        config = self.parallelism_config

        # With the deepspeed backend none of the torch context-parallel
        # setup below is performed; the inputs are handed back as-is.
        if config.cp_backend == "deepspeed":
            return args

        # Imported lazily so these modules are only loaded when the torch
        # context-parallel path is actually taken.
        from accelerate.big_modeling import _attach_context_parallel_hooks
        from torch.distributed.tensor.experimental import context_parallel
        from torch.distributed.tensor.experimental._attention import set_rotate_method

        set_rotate_method(config.cp_handler.cp_comm_strategy)

        # Keep a ready-to-call context factory bound to the "cp" sub-mesh.
        cp_mesh = self.torch_device_mesh["cp"]
        self._cp_context = functools.partial(context_parallel, mesh=cp_mesh)

        # Only nn.Module arguments receive hooks; everything else passes through.
        modules = (obj for obj in args if isinstance(obj, torch.nn.Module))
        for module in modules:
            _attach_context_parallel_hooks(module)

        return args

    Accelerator._prepare_cp = patched_prepare_cp
3 changes: 3 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,9 @@ def setup_parallelism_envs(cfg):
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp

patch_prepare_cp()
if set_accelerate_parallelism_config:
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def snapshot_download_w_retry(*args, **kwargs):
"""
with hf_offline_context(True):
try:
return snapshot_download(*args, **kwargs)
return snapshot_download(*args, local_files_only=True, **kwargs)
except LocalEntryNotFoundError:
pass
with hf_offline_context(False):
Expand Down
3 changes: 0 additions & 3 deletions tests/hf_offline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from contextlib import contextmanager
from functools import wraps

from huggingface_hub.utils import reset_sessions


def reload_modules(hf_hub_offline):
# Force reload of the modules that check this variable
Expand All @@ -21,7 +19,6 @@ def reload_modules(hf_hub_offline):
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config)
datasets.config.HF_HUB_OFFLINE = hf_hub_offline
reset_sessions()


def enable_hf_offline(test_func):
Expand Down
Loading