Skip to content

Commit c322b9a

Browse files
authored
Merge pull request #686 from allenai/fix-from-checkpoint
Fixes for OLMo.from_checkpoint
2 parents c482df7 + cb51bb1 commit c322b9a

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2525

2626
- Fixed restarting a training run in later epochs so that we no longer need to set the flag `--epoch=INT`.
2727
- Fix bug where the attention norm, when applied before the attention block, was modifying the residual stream.
28+
- Fixed `OLMo.from_checkpoint()` so that it correctly loads `olmo_core` and `torch_new` style checkpoints.
2829

2930
## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11
3031

olmo/model.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
InitFnType,
4343
LayerNormType,
4444
ModelConfig,
45+
ShardedCheckpointerType,
46+
TrainConfig,
4547
)
4648
from .exceptions import OLMoConfigurationError
4749
from .initialization import init_normal
@@ -1740,15 +1742,26 @@ def from_checkpoint(
17401742
model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
17411743
model = model.to(torch.device(device))
17421744
else:
1743-
from .checkpoint import load_model_state
1745+
train_config = TrainConfig.load(config_path)
1746+
if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core:
1747+
from olmo_core.distributed.checkpoint import ( # type: ignore
1748+
load_model_and_optim_state,
1749+
)
17441750

1745-
# Initialize model on target device. In this case the state dict is loaded in-place
1746-
# so it's not necessary to start on CPU if the target device is a GPU.
1747-
model_config.init_device = device
1748-
model = OLMo(model_config)
1751+
model_config.init_device = device
1752+
model = OLMo(model_config)
1753+
load_model_and_optim_state(checkpoint_dir, model)
1754+
else:
1755+
# train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
1756+
from .checkpoint import load_model_state
1757+
1758+
# Initialize model on target device. In this case the state dict is loaded in-place
1759+
# so it's not necessary to start on CPU if the target device is a GPU.
1760+
model_config.init_device = device
1761+
model = OLMo(model_config)
17491762

1750-
# Load state dict in place.
1751-
load_model_state(checkpoint_dir, model)
1763+
# Load state dict in place.
1764+
load_model_state(checkpoint_dir, model)
17521765

17531766
return model.eval()
17541767

0 commit comments

Comments (0)