🐛 Describe the bug
File: OLMo/olmo/train.py
In the following training loop, will pre-training break after only 1 epoch?
@property
def max_epochs(self) -> int:
    if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
        return int(self.cfg.max_duration[:-2].strip())
    else:
        return 1
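
For reference, the property above can be exercised in isolation. A minimal standalone sketch (the function name and test values are mine, for illustration only):

def parse_max_epochs(max_duration) -> int:
    # Mirrors the property above: only a string ending in "ep" yields an
    # epoch count; any other duration falls through to the default of 1.
    if isinstance(max_duration, str) and max_duration.endswith("ep"):
        return int(max_duration[:-2].strip())
    return 1

print(parse_max_epochs("2ep"))  # 2
print(parse_max_epochs("30B"))  # 1 -- not an "ep" string, so only 1 epoch

So a token-style max_duration like "30B" is not recognized here, and the epoch loop is capped at a single pass over the data.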
with torch_profiler as p:
    for epoch in range(self.epoch or 0, self.max_epochs):
        for batch in self.train_loader:
            # Bookkeeping.
            # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
            # batches see the same number of tokens, which should be the case for language model pre-training
            # (at least when drop_last=True).
            # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
            # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
            # fail loudly.
            batch_size, seq_len = batch["input_ids"].shape
            assert seq_len == self.cfg.model.max_sequence_length
            assert batch_size == self.cfg.device_train_batch_size
            global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
            self.global_step += 1
            self.global_train_examples_seen_this_epoch += global_batch_size
            self.global_train_tokens_seen += global_batch_size * seq_len
            speed_monitor.batch_start(
                self.global_train_tokens_seen,
                batch_size * seq_len,  # num tokens in batch for this device
                # We start monitoring speed after the first batch since the first
                # batch might be an outlier due to compiling and other initialization overhead.
                record=not first_batch,
            )

            should_log_this_step = self.should_log_this_step()

            # Run train step on batch.
            metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)

            # Maybe collect other metrics.
            if should_log_this_step:
                # Speed metrics.
                metrics.update(speed_monitor.check())
                # System metrics.
                metrics.update(self.system_metrics())
                # Learning rate metrics.
                metrics.update(lr_monitor.check())

            # Log metrics to console.
            if self.global_step % self.cfg.console_log_interval == 0:
                self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

            # Log metrics to W&B.
            if (
                wandb.run is not None
                and self.cfg.wandb is not None
                and self.global_step % self.cfg.wandb.log_interval == 0
            ):
                wandb.log(metrics, step=self.global_step)

            # Check if/when run should be canceled.
            if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                cancel_initiated, extra_steps = self.check_if_cancelled()
                if cancel_initiated:
                    stop_at = (
                        self.global_step + extra_steps
                        if stop_at is None
                        else min(self.global_step + extra_steps, stop_at)
                    )

            # Maybe save sharded checkpoint.
            if save_checkpoints and (
                cancel_initiated
                or (
                    self.global_step % self.cfg.save_interval == 0
                    and self.cfg.save_num_checkpoints_to_keep != 0
                )
            ):
                log.info("Saving checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                log.info(f"Checkpoint saved to {checkpoint_path}")

                # Remove any ephemeral checkpoints.
                while self.ephemeral_checkpoints:
                    self.remove_ephemeral_checkpoint()

                # Reset speed monitor so that we don't count the time taken to save checkpoints.
                speed_monitor.reset()

                # If the run was just canceled this will be the final checkpoint.
                if cancel_initiated:
                    save_checkpoints = False
            elif (
                self.cfg.save_interval_ephemeral is not None
                and self.global_step % self.cfg.save_interval_ephemeral == 0
            ):
                log.info("Saving ephemeral checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
                log.info(f"Checkpoint saved to {checkpoint_path}")

                # Reset speed monitor so that we don't count the time taken to save checkpoints.
                speed_monitor.reset()

            # Maybe save unsharded checkpoint.
            if (
                save_checkpoints
                and self.cfg.save_interval_unsharded is not None
                and self.global_step % self.cfg.save_interval_unsharded == 0
                and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
            ):
                log.info("Saving unsharded checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

                # Reset speed monitor so that we don't count the time taken to save checkpoints.
                speed_monitor.reset()

            # Maybe run evaluations.
            if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                eval_metrics = self.eval()

                # Log metrics to W&B.
                if wandb.run is not None:
                    wandb.log(eval_metrics, step=self.global_step)

                # Reset speed monitor so that we don't count the time taken to run evaluations.
                speed_monitor.reset()

                # Reset model to 'train' mode.
                self.fsdp_model.train()

            # End of batch.
            first_batch = False
            if p is not None:
                p.step()

            if stop_at is not None and self.global_step >= stop_at:
                break

            # Python Profiler stuff
            # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
            if python_profiler is not None:
                if self.global_step == 5:
                    python_profiler.enable()
                elif self.global_step == 8:
                    python_profiler.disable()
                    python_profiler.print_stats(sort=SortKey.CUMULATIVE)
                    python_profiler = None
        else:
            log.info("Training epoch complete")
            self.epoch = epoch + 1
            self.global_train_examples_seen_this_epoch = 0
            if self.epoch < self.max_epochs:
                self.dataset.reshuffle()
            continue
        break
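
Note that the epoch loop above hinges on Python's for/else semantics: the inner loop's else block runs only when the batch loop finishes without a break, in which case the epoch counter advances and continue moves on to the next epoch; a break inside the batch loop skips the else and falls through to the outer break, ending training entirely. A minimal sketch of that control flow (toy values, not from the repo):

for epoch in range(2):
    for step in range(3):
        if epoch == 1 and step == 1:
            break  # early stop: the inner loop's else is skipped
    else:
        print(f"epoch {epoch} complete")
        continue  # start the next epoch
    break  # reached only after an inner break, so the outer loop ends

# Prints "epoch 0 complete", then stops partway through epoch 1.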
@Xuekai-Zhu, what is the value of "max_duration" in the config that you're using?
If you want it to be more than 1 epoch, say 2 epochs, the config should have max_duration: 2ep.
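For example, in the training YAML (abbreviated; only the relevant key is shown):

max_duration: 2ep  # two full passes over the data; without the "ep" suffix, max_epochs falls back to 1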
Yes, I found that if I want it to run for more than 1 epoch, the config should have max_duration: 2ep.
But when I want to use a max token count to control the training process, I can't reach the max tokens because training is capped at the default 1 epoch.
Source tokens: 8B, max_duration: 30B -> training completes at 8B tokens (1 epoch);
❌ the max_duration set in the config can't be reached.
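
One possible workaround, purely my own sketch and not the actual OLMo implementation: make the epoch cap permissive when max_duration is not an epoch count, so that a token budget (assuming it is enforced elsewhere, e.g. via max_steps or a stop condition checked against global_train_tokens_seen) can actually be reached by re-iterating over the dataset:

def max_epochs_token_friendly(max_duration) -> int:
    # Hypothetical adjustment (illustration only): keep the "Nep" behavior,
    # but return an effectively unbounded epoch count for any other duration,
    # e.g. a token budget like "30B", so the token/step limit ends training
    # instead of the implicit 1-epoch cap.
    if isinstance(max_duration, str) and max_duration.endswith("ep"):
        return int(max_duration[:-2].strip())
    return int(1e9)  # "unbounded": rely on the token/step limit to stop

With an 8B-token dataset and a 30B-token budget, this would allow roughly 4 passes (ceil(30/8)) before the token limit ends the run, provided that limit is actually checked in the loop.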
Versions
Python 3.10.13
WARNING: Could not find a Python project for directory /scratch2/nlp/zhuxuekai/scaling_law4AI_data/OLMo (tried all parent directories)
-e git+ssh://[email protected]/Xuekai-Zhu/scaling_law4AI_data.git@a15301e68a4dd616e3971c54370cb4a957e4d14c#egg=ai2_olmo
aiohttp==3.9.3
aiosignal==1.3.1
aniso8601==9.0.1
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
anykeystore==0.2
appdirs==1.4.4
async-timeout==4.0.3
asyncio==3.4.3
attrs==23.2.0
backports.tarfile==1.1.0
beaker-gantry==0.22.2
beaker-py==1.26.4
black==23.12.1
blinker==1.7.0
boltons==24.0.0
boto3==1.34.86
botocore==1.34.86
build==1.2.1
cached_path==1.6.2
cachetools==5.3.3
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
click-help-colors==0.9.4
cmake==3.28.3
contourpy==1.2.0
cryptacular==1.6.2
cryptography==42.0.5
cycler==0.12.1
datasets==2.18.0
deepspeed==0.14.0
deepspeed-kernels==0.0.1.dev1698255861
deepspeed-mii==0.2.3
defusedxml==0.7.1
dill==0.3.8
docker==6.1.3
docker-pycreds==0.4.0
docutils==0.21.1
exceptiongroup==1.2.0
face==20.1.1
filelock==3.9.0
Flask==3.0.2
Flask-RESTful==0.3.10
fonttools==4.50.0
frozenlist==1.4.1
fsspec==2024.2.0
ftfy==6.2.0
gitdb==4.0.11
GitPython==3.1.42
glom==23.5.0
google-api-core==2.18.0
google-auth==2.29.0
google-cloud-core==2.4.1
google-cloud-storage==2.16.0
google-crc32c==1.5.0
google-resumable-media==2.7.0
googleapis-common-protos==1.63.0
greenlet==3.0.3
grpcio==1.62.1
grpcio-tools==1.62.1
hjson==3.1.0
huggingface-hub==0.21.4
hupper==1.12.1
idna==3.6
importlib_metadata==7.1.0
iniconfig==2.0.0
isort==5.12.0
itsdangerous==2.1.2
jaraco.classes==3.4.0
jaraco.context==5.3.0
jaraco.functools==4.0.0
jeepney==0.8.0
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.4.0
keyring==25.1.0
kiwisolver==1.4.5
lightning-utilities==0.11.2
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.3
mdurl==0.1.2
Megatron==0.5.1
megatron_core==0.5.0
more-itertools==10.2.0
mpmath==1.3.0
msgspec==0.18.6
multidict==6.0.5
multiprocess==0.70.16
mypy==1.3.0
mypy-extensions==1.0.0
necessary==0.4.3
networkx==3.2.1
nh3==0.2.17
ninja==1.11.1.1
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.7.0.84
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.19.3
nvidia-nvtx-cu11==11.8.86
oauthlib==3.2.2
omegaconf==2.3.0
packaging==24.0
pandas==2.2.1
PasteDeploy==3.1.0
pathspec==0.12.1
pbkdf2==1.3
petname==2.6
pillow==10.2.0
pkginfo==1.10.0
plaster==1.1.2
plaster-pastedeploy==1.0.1
platformdirs==4.2.0
pluggy==1.4.0
proto-plus==1.23.0
protobuf==4.25.3
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==15.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==2.6.4
pydantic_core==2.16.3
Pygments==2.17.2
pynvml==11.5.0
pyparsing==3.1.2
pyproject_hooks==1.0.0
pyramid==2.0.2
pyramid-mailer==0.15.1
pytest==8.1.1
pytest-sphinx==0.6.3
python-dateutil==2.9.0.post0
python3-openid==3.2.0
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
readme_renderer==43.0
regex==2023.12.25
repoze.sendmail==4.4.1
requests==2.31.0
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
requirements-parser==0.9.0
rfc3986==2.0.0
rich==13.7.1
rsa==4.9
ruff==0.3.7
s3transfer==0.10.1
safetensors==0.4.2
scikit-learn==1.4.2
scipy==1.13.0
seaborn==0.13.2
SecretStorage==3.3.3
sentry-sdk==1.43.0
setproctitle==1.3.3
six==1.16.0
smart-open==7.0.4
smashed==0.21.5
smmap==5.0.1
SQLAlchemy==2.0.29
sympy==1.12
threadpoolctl==3.4.0
tokenizers==0.15.2
tomli==2.0.1
torch==2.2.1+cu118
torchmetrics==1.3.2
tqdm==4.66.2
transaction==4.0
transformers==4.38.2
translationstring==1.4
triton==2.2.0
trouting==0.3.3
twine==5.0.0
types-setuptools==69.5.0.20240415
typing_extensions==4.8.0
tzdata==2024.1
ujson==5.9.0
urllib3==2.2.1
velruse==1.1.1
venusian==3.1.0
wandb==0.16.4
wcwidth==0.2.13
WebOb==1.8.7
websocket-client==1.7.0
Werkzeug==3.0.1
wrapt==1.16.0
WTForms==3.1.2
wtforms-recaptcha==0.3.2
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1
zmq==0.0.0
zope.deprecation==5.0
zope.interface==6.3
zope.sqlalchemy==3.1