
Commit e46e205 (parent: b7c0bfa)

mistral-nemo-12b from llama_8b

Signed-off-by: Alexandros Koumparoulis <[email protected]>

1 file changed: +108, -22 lines

nemo/collections/llm/recipes/mistral_nemo_12b.py (+108, -22)
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import Optional
+from typing import Callable, Optional
 
 import nemo_run as run
 import pytorch_lightning as pl
@@ -24,15 +24,17 @@
 from nemo import lightning as nl
 from nemo.collections.llm.api import finetune, pretrain
 from nemo.collections.llm.gpt.data.mock import MockDataModule
+from nemo.collections.llm.gpt.data.squad import SquadDataModule
 from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B
+from nemo.collections.llm.peft.lora import LoRA
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
+from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback
 
 NAME = "mistral_nemo_base_12b"
 
-
 @run.cli.factory(name=NAME)
 def model() -> run.Config[pl.LightningModule]:
     """
@@ -61,7 +63,7 @@ def trainer(
     sequence_parallelism: bool = False,
     num_nodes: int = 1,
     num_gpus_per_node: int = 8,
-    max_steps: int = 100,
+    max_steps: int = 1168251,
     callbacks: Optional[list[run.Config[Callback]]] = None,
 ) -> run.Config[nl.Trainer]:
     """
@@ -91,6 +93,10 @@ def trainer(
         Python API usage:
             >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
             >>> print(trainer_config)
+
+    Note:
+        For more information on distributed training strategies, refer to the
+        NeMo documentation on multi-GPU and multi-node training.
     """
     strategy = run.Config(
         nl.MegatronStrategy,
@@ -101,7 +107,6 @@ def trainer(
         context_parallel_size=context_parallelism,
         sequence_parallel=sequence_parallelism,
         gradient_as_bucket_view=True,
-        ckpt_include_optimizer=True,
         ckpt_async_save=True,
         ckpt_parallel_load=True,
         ddp=run.Config(
@@ -119,7 +124,6 @@ def trainer(
         accumulate_grad_batches=1,
         callbacks=callbacks,
         devices=num_gpus_per_node,
-        gradient_clip_val=1.0,
         limit_test_batches=50,
         limit_val_batches=32,
         log_every_n_steps=10,
@@ -157,37 +161,79 @@ def pretrain_recipe(
     Examples:
         CLI usage:
             $ nemo llm pretrain --factory mistral_nemo_base_12b
-            $ nemo llm pretrain --factory "mistral_nemo_base_12b(num_nodes=2, name='my_mistral_pretrain')"
+            $ nemo llm pretrain --factory "mistral_nemo_base_12b(num_nodes=2, name='my_pretrain')"
 
         Python API usage:
-            >>> recipe = pretrain_recipe(name="mistral_pretrain", num_nodes=2)
+            >>> recipe = pretrain_recipe(name="mistral_nemo_base_12b", num_nodes=2)
             >>> print(recipe)
+
+    Note:
+        For more details on pre-training LLMs with NeMo, see the pre-training
+        guide in the `examples/llm/pretrain/` directory.
     """
     return run.Partial(
         fn,
         model=model(),
         trainer=trainer(
-            tensor_parallelism=2,
-            pipeline_parallelism=1,
-            pipeline_parallelism_type=None,
-            virtual_pipeline_parallelism=None,
-            context_parallelism=2,
-            sequence_parallelism=False,
             num_nodes=num_nodes,
             num_gpus_per_node=num_gpus_per_node,
             callbacks=[run.Config(TimingCallback)],
         ),
-        data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1),
+        data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1),
         log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
         optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
         resume=default_resume(),
     )
 
 
-@run.cli.factory(name=NAME + "_hf")
-def hf_resume() -> run.Config[nl.AutoResume]:
+@run.cli.factory(target=pretrain, name=NAME + "_optimized")
+def pretrain_recipe_performance(
+    dir: Optional[str] = None,
+    name: str = "default",
+    num_nodes: int = 1,
+    num_gpus_per_node: int = 8,
+    fn: Callable = pretrain,
+) -> run.Partial:
     """
-    Configure automatic resumption from a Hugging Face checkpoint for Mistral-Nemo-Base-12B model.
+    Create a performance-optimized pre-training recipe for Mistral-Nemo-Base-12B model.
+
+    This recipe enables performance optimizations that may not be suitable for all use cases.
+    It builds upon the standard pre-training recipe and adds additional performance enhancements.
+
+    Args:
+        dir (Optional[str]): Directory for saving logs and checkpoints.
+        name (str): Name of the pre-training run.
+        num_nodes (int): Number of compute nodes to use.
+        num_gpus_per_node (int): Number of GPUs per node.
+        fn (Callable): The pre-training function to use.
+
+    Returns:
+        run.Partial: Partial configuration for performance-optimized pre-training.
+
+    Examples:
+        $ nemo llm pretrain --factory mistral_nemo_base_12b_optimized
+
+        Python API usage:
+            >>> recipe = pretrain_recipe_performance(name="mistral_nemo_base_12b_perf", num_nodes=4)
+            >>> print(recipe)
+
+    Note:
+        Use this recipe with caution and only when you need maximum performance.
+        It may not be suitable for all hardware configurations or use cases.
+    """
+    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)
+
+    recipe.trainer.callbacks.append(
+        run.Config(
+            MegatronCommOverlapCallback,
+            tp_comm_overlap=False,
+        )
+    )
+    return recipe
+
+
+def hf_resume() -> run.Config[nl.AutoResume]:
+    """Configure automatic resumption from a Hugging Face checkpoint.
 
     This function sets up the configuration to resume training from a pre-trained
     Hugging Face model checkpoint.
@@ -196,11 +242,51 @@ def hf_resume() -> run.Config[nl.AutoResume]:
 
     Returns:
         run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint.
-
-    Note:
-        This is particularly useful for fine-tuning scenarios where you want to
-        start from the pre-trained Mistral-Nemo-Base-12B model.
     """
     return run.Config(
-        nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mistral-Nemo-Base-2407")
+        nl.AutoResume,
+        restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mistral-Nemo-Base-2407"),
     )
+
+
+@run.cli.factory(target=finetune, name=NAME)
+def finetune_recipe(
+    dir: Optional[str] = None,
+    name: str = "default",
+    num_nodes: int = 1,
+    num_gpus_per_node: int = 8,
+) -> run.Partial:
+    """
+    Create a fine-tuning recipe for Mistral-Nemo-Base-12B model.
+
+    This function sets up a complete configuration for fine-tuning, including
+    model, trainer, data, logging, optimization, and resumption settings.
+    It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning.
+
+    Args:
+        dir (Optional[str]): Directory for saving logs and checkpoints.
+        name (str): Name of the fine-tuning run.
+        num_nodes (int): Number of compute nodes to use.
+        num_gpus_per_node (int): Number of GPUs per node.
+
+    Returns:
+        run.Partial: Partial configuration for fine-tuning.
+
+    Examples:
+        CLI usage:
+            $ nemo llm finetune --factory mistral_nemo_base_12b
+
+        Python API usage:
+            >>> recipe = finetune_recipe(name="mistral_nemo_base_12b_finetune", num_nodes=2)
+            >>> print(recipe)
+
+    Note:
+        This recipe uses the SQuAD dataset for fine-tuning. For more information
+        on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
+        `examples/llm/finetune/` directory.
+    """
+    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune)
+    recipe.resume = hf_resume()
+    recipe.peft = run.Config(LoRA)
+    recipe.data = run.Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1)
+    return recipe
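
For context, a minimal usage sketch (not taken from this commit) of how the new finetune_recipe factory could be launched with nemo_run, assuming NeMo with this recipes module and nemo_run are installed; the run name, output directory, and local executor below are illustrative assumptions only.

import nemo_run as run

from nemo.collections.llm.recipes import mistral_nemo_12b

# Build the LoRA fine-tuning recipe added in this commit; it resumes from the
# hf://mistralai/Mistral-Nemo-Base-2407 checkpoint via hf_resume() and trains
# on the SQuAD data module with seq_length=8192.
recipe = mistral_nemo_12b.finetune_recipe(
    name="mistral_nemo_12b_lora",       # hypothetical run name
    dir="/results/mistral_nemo_12b",    # hypothetical log/checkpoint directory
    num_nodes=1,
    num_gpus_per_node=8,
)

# Launch on the local machine; swap in a cluster executor for multi-node runs.
run.run(recipe, executor=run.LocalExecutor())

The performance-optimized pre-training factory registered here can likewise be invoked from the CLI as shown in its docstring: nemo llm pretrain --factory mistral_nemo_base_12b_optimized.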
