@@ -42,6 +42,8 @@
     InitFnType,
     LayerNormType,
     ModelConfig,
+    ShardedCheckpointerType,
+    TrainConfig,
 )
 from .exceptions import OLMoConfigurationError
 from .initialization import init_normal
@@ -1740,15 +1742,26 @@ def from_checkpoint(
             model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
             model = model.to(torch.device(device))
         else:
-            from .checkpoint import load_model_state
+            train_config = TrainConfig.load(config_path)
+            if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core:
+                from olmo_core.distributed.checkpoint import (  # type: ignore
+                    load_model_and_optim_state,
+                )

-            # Initialize model on target device. In this case the state dict is loaded in-place
-            # so it's not necessary to start on CPU if the target device is a GPU.
-            model_config.init_device = device
-            model = OLMo(model_config)
+                model_config.init_device = device
+                model = OLMo(model_config)
+                load_model_and_optim_state(checkpoint_dir, model)
+            else:
+                # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
+                from .checkpoint import load_model_state
+
+                # Initialize model on target device. In this case the state dict is loaded in-place
+                # so it's not necessary to start on CPU if the target device is a GPU.
+                model_config.init_device = device
+                model = OLMo(model_config)

-            # Load state dict in place.
-            load_model_state(checkpoint_dir, model)
+                # Load state dict in place.
+                load_model_state(checkpoint_dir, model)

         return model.eval()

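For callers nothing changes at the call site: `from_checkpoint` now reads the checkpoint's own `TrainConfig` and picks the matching sharded loader itself. The sketch below is illustrative only; it assumes `from_checkpoint` is the classmethod on `OLMo` named in the hunk header, taking a checkpoint directory and a target device, and the import path and checkpoint path are placeholders.

```python
import torch

from olmo.model import OLMo  # import path assumed; OLMo is the class patched above

# Hypothetical sharded checkpoint directory. from_checkpoint loads the
# TrainConfig stored with the checkpoint and dispatches on
# train_config.sharded_checkpointer: ShardedCheckpointerType.olmo_core uses
# olmo_core.distributed.checkpoint.load_model_and_optim_state, otherwise
# (e.g. ShardedCheckpointerType.torch_new) it falls back to the repo's own
# load_model_state, exactly as in the hunk above.
model = OLMo.from_checkpoint(
    "/path/to/sharded/checkpoint",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
```

The returned model is already in eval mode (`return model.eval()` above), so no extra call is needed before running inference.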