Merged · Changes from 12 commits
35 changes: 26 additions & 9 deletions nemo_rl/algorithms/dpo.py
@@ -417,15 +417,16 @@ def dpo_train(

with timer.time("total_step_time"):
print("▶ Taking a training step...")
train_results = policy.train(
batch,
loss_fn,
eval_mode=False,
## NOTE: we double the batch size here because each preference example corresponds to a pair of
## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
gbs=master_config["policy"]["train_global_batch_size"] * 2,
mbs=master_config["policy"]["train_micro_batch_size"] * 2,
)
with timer.time("policy_training"):
train_results = policy.train(
batch,
loss_fn,
eval_mode=False,
## NOTE: we double the batch size here because each preference example corresponds to a pair of
## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
gbs=master_config["policy"]["train_global_batch_size"] * 2,
mbs=master_config["policy"]["train_micro_batch_size"] * 2,
)

is_last_step = total_steps + 1 >= master_config["dpo"][
"max_num_steps"
@@ -519,6 +520,22 @@ def dpo_train(

print("\n📊 Training Results:")
print(f" • Loss: {float(metrics['loss']):.4f}")
if "total_flops" in train_results:
total_tflops = (
train_results["total_flops"]
/ timing_metrics["policy_training"]
/ 1e12
)
num_ranks = train_results["num_ranks"]
print(
f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)"
)
if "theoretical_tflops" in train_results:
theoretical_tflops = train_results["theoretical_tflops"]
print(
f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%"
)
metrics["train_fp_utilization"] = total_tflops / theoretical_tflops
print("\n⏱️ Timing:")
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)
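For intuition, the arithmetic behind the new reporting block is sketched below with made-up numbers; the real values come from `train_results` and the `policy_training` timer, and the peak figure is only an assumed example.

```python
# Minimal sketch of the reported numbers, using made-up inputs.
total_flops = 3.2e16            # FLOPs tracked for the step ("total_flops")
policy_training_time = 12.5     # seconds recorded under timer.time("policy_training")
num_ranks = 8                   # workers that reported metrics ("num_ranks")
theoretical_tflops = 8 * 989.0  # assumed summed peak TFLOP/s (e.g. BF16 on H100-class GPUs)

achieved_tflops = total_flops / policy_training_time / 1e12  # aggregate TFLOP/s
per_rank_tflops = achieved_tflops / num_ranks
fp_utilization = achieved_tflops / theoretical_tflops         # logged as "train_fp_utilization"

print(f"  • Training FLOPS: {achieved_tflops:.2f} TFLOPS ({per_rank_tflops:.2f} TFLOPS per rank)")
print(f"  • Training Model Floating Point Utilization: {100 * fp_utilization:.2f}%")
```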
14 changes: 14 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -786,6 +786,20 @@ def grpo_train(
print(
f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
)
if "total_flops" in train_results:
total_tflops = (
train_results["total_flops"] / timing_metrics["policy_training"] / 1e12
)
num_ranks = train_results["num_ranks"]
print(
f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)"
)
if "theoretical_tflops" in train_results:
theoretical_tflops = train_results["theoretical_tflops"]
print(
f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%"
)
metrics["train_fp_utilization"] = total_tflops / theoretical_tflops

print("\n⏱️ Timing:")
# Display total time first, separately
19 changes: 18 additions & 1 deletion nemo_rl/algorithms/sft.py
@@ -405,7 +405,8 @@ def sft_train(
)

print("▶ Taking a training step...")
train_results = policy.train(train_data, loss_fn)
with timer.time("policy_training"):
train_results = policy.train(train_data, loss_fn)

is_last_step = total_steps + 1 >= master_config["sft"][
"max_num_steps"
@@ -502,6 +503,22 @@ def sft_train(

print("\n📊 Training Results:")
print(f" • Loss: {float(metrics['loss']):.4f}")
if "total_flops" in train_results:
total_tflops = (
train_results["total_flops"]
/ timing_metrics["policy_training"]
/ 1e12
)
num_ranks = train_results["num_ranks"]
print(
f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)"
)
if "theoretical_tflops" in train_results:
theoretical_tflops = train_results["theoretical_tflops"]
print(
f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%"
)
metrics["train_fp_utilization"] = total_tflops / theoretical_tflops
print("\n⏱️ Timing:")
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)
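The `timer.time("policy_training")` wrapper added in dpo.py and sft.py is used as a context manager whose elapsed time later appears in `timing_metrics`. A rough stand-in under that assumption (not the actual nemo_rl Timer):

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class SimpleTimer:
    """Stand-in for the timer assumed above: accumulates wall-clock seconds per label."""

    def __init__(self):
        self.timing_metrics = defaultdict(float)

    @contextmanager
    def time(self, label):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.timing_metrics[label] += time.perf_counter() - start

timer = SimpleTimer()
with timer.time("policy_training"):
    time.sleep(0.1)  # stands in for policy.train(train_data, loss_fn)
print(timer.timing_metrics["policy_training"])
```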
2 changes: 2 additions & 0 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -803,6 +803,8 @@ def train(
"global_loss": global_loss.cpu(),
"grad_norm": grad_norm,
"rank": torch.distributed.get_rank(),
"gpu_name": torch.cuda.get_device_name(),
"model_dtype": self.dtype,
"all_mb_metrics": dict(mb_metrics),
}

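The two fields added to the worker metrics, `gpu_name` and `model_dtype`, let the driver look up a per-rank peak throughput via `get_theoretical_tflops`. A hypothetical version of such a lookup is sketched below; the keys and numbers are illustrative dense-peak values, not the actual table in `nemo_rl.utils.flops_tracker`.

```python
import torch

# Illustrative dense peak TFLOP/s by (GPU name substring, dtype); keys and values are
# assumptions for this sketch, not the real nemo_rl lookup.
_PEAK_TFLOPS = {
    ("H100", torch.bfloat16): 989.0,
    ("H100", torch.float32): 67.0,
    ("A100", torch.bfloat16): 312.0,
    ("A100", torch.float32): 19.5,
}

def lookup_theoretical_tflops(gpu_name: str, dtype: torch.dtype) -> float:
    """Return an assumed peak TFLOP/s for the given device name and dtype."""
    for (name_key, dtype_key), peak in _PEAK_TFLOPS.items():
        if name_key in gpu_name and dtype_key == dtype:
            return peak
    raise ValueError(f"No peak TFLOPS entry for {gpu_name!r} with dtype {dtype}")

# e.g. lookup_theoretical_tflops("NVIDIA H100 80GB HBM3", torch.bfloat16) -> 989.0
```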
33 changes: 33 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from collections import defaultdict
from typing import Any, Optional, Union

@@ -41,6 +42,11 @@
LogprobOutputSpec,
ReferenceLogprobOutputSpec,
)
from nemo_rl.utils.flops_tracker import (
FLOPTracker,
get_default_hf_config,
get_theoretical_tflops,
)

PathLike = Union[str, "os.PathLike[Any]"]

@@ -147,6 +153,15 @@ def __init__(
else:
self.use_dynamic_batches = False

# initialize FLOPs tracker
try:
self.flops_tracker = FLOPTracker.from_config(
config["model_name"], get_default_hf_config(config["model_name"])
)
except ValueError as e:
self.flops_tracker = None
print(f"FLOPS tracker not supported for model {config['model_name']}: {e}")

if config["sequence_packing"]["enabled"]:
assert (
config["megatron_cfg"]["enabled"] or config["dtensor_cfg"]["enabled"]
@@ -349,6 +364,12 @@ def train(
batch_size=batch_size,
)

if self.flops_tracker is not None:
self.flops_tracker.reset()
for shard in sharded_data:
input_lengths = shard["input_lengths"]
self.flops_tracker.track_batch(input_lengths.tolist())

# Train each shard in parallel
futures = self.worker_group.run_all_workers_sharded_data(
"train",
@@ -379,6 +400,18 @@ def train(
"grad_norm": results[0]["grad_norm"],
}

if self.flops_tracker is not None:
aggregated_results["total_flops"] = self.flops_tracker.total_flops
aggregated_results["num_ranks"] = len(results)

try:
aggregated_results["theoretical_tflops"] = sum(
get_theoretical_tflops(r["gpu_name"], r["model_dtype"])
for r in results
)
except Exception as e:
warnings.warn(f"Error getting theoretical flops: {e}")

# Aggregate metrics across all workers
all_mb_metrics = defaultdict(list)
for r in results:
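The call sites above only rely on `FLOPTracker` exposing `reset()`, `track_batch(input_lengths)`, and a `total_flops` accumulator (plus the `from_config` constructor). A sketch of that implied interface follows; how the real class derives FLOPs per token from the HF config is an assumption here, with the common ~6 FLOPs per parameter per token estimate used as a placeholder.

```python
from dataclasses import dataclass

@dataclass
class FLOPTrackerSketch:
    """Interface implied by the lm_policy.py call sites; internals are placeholders."""

    flops_per_token: float   # assumed to be derived from the model's HF config
    total_flops: float = 0.0

    def reset(self) -> None:
        self.total_flops = 0.0

    def track_batch(self, input_lengths: list[int]) -> None:
        # Accumulate an estimated FLOP count for one sharded batch.
        self.total_flops += self.flops_per_token * sum(input_lengths)

tracker = FLOPTrackerSketch(flops_per_token=6 * 8e9)  # placeholder: ~8B-parameter model
tracker.reset()
tracker.track_batch([512, 1024, 2048])
print(f"{tracker.total_flops / 1e12:.2f} TFLOPs tracked")
```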
2 changes: 2 additions & 0 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -951,6 +951,8 @@ def train(
metrics = {
"global_loss": global_loss.cpu(),
"rank": torch.distributed.get_rank(),
"gpu_name": torch.cuda.get_device_name(),
"model_dtype": self.dtype,
"all_mb_metrics": dict(mb_metrics),
"grad_norm": torch.tensor(
mb_metrics["grad_norm"][-1]