
Commit db35980

Hossein Kavianihamedani authored and committed
Add multi-dataset evaluation support
- Add eval_utils.py with run_evaluation() function for multi-dataset evaluation
- Update main.py to support multi-dataset configuration and evaluation
- Add validation config settings (enabled, eval_interval, eval_steps)
- Refactor setup() to support dataset_val.datasets structure
- Add unified forward() method with compute_gradients flag
- Add evaluate() method that calls run_evaluation()
- Update llama3_8b.yaml with multi-dataset configuration
1 parent feff83c commit db35980
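
The main.py side of this change is not included in the excerpt below, so purely as an illustration, a unified forward pass gated by a compute_gradients flag (as described in the commit message) could look something like the following sketch; every name in it is an assumption, not code from the commit:

# Illustrative sketch only -- not the commit's actual main.py code.
import contextlib
from typing import Callable

import torch
from torch import nn


def forward_pass(
    model: nn.Module,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    batch: dict,
    labels: torch.Tensor,
    compute_gradients: bool = True,
) -> torch.Tensor:
    """Shared forward pass for training and evaluation.

    With compute_gradients=False the pass runs under torch.no_grad(), which is
    what an evaluation-time forward_fn callback would want.
    """
    grad_ctx = contextlib.nullcontext() if compute_gradients else torch.no_grad()
    with grad_ctx:
        logits = model(**batch)          # assumed model call signature
        loss = loss_fn(logits, labels)   # assumed loss helper
    return loss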

File tree: 3 files changed (+452, -35 lines)


apps/sft/eval_utils.py

Lines changed: 326 additions & 0 deletions
@@ -0,0 +1,326 @@
"""Utility functions for evaluation to make main.py more concise."""

import logging
from typing import Any, Callable, Iterator

import torch
from torch import nn

logger = logging.getLogger(__name__)


def move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]:
    """Move all tensors in batch to specified device.

    Args:
        batch: Dictionary containing batch data
        device: Target device

    Returns:
        Batch with tensors moved to device (modifies in-place and returns)
    """
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
    return batch


def extract_epoch_from_batch(batch: dict) -> int | None:
    """Extract epoch number from batch metrics.

    Args:
        batch: Batch dictionary with 'metrics' field

    Returns:
        Epoch number from metrics, or None if not found
    """
    if "metrics" in batch:
        for metric in batch["metrics"]:
            if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
                return metric.value
    return None


def start_epoch_sync(
    epoch_increment: int,
    device: torch.device,
    dp_process_group: Any = None,
) -> tuple[torch.Tensor | None, Any]:
    """Start async all_reduce for epoch synchronization across ranks.

    Args:
        epoch_increment: Difference between current and starting epoch
        device: Device for tensor
        dp_process_group: Data parallel process group (None = default group)

    Returns:
        Tuple of (epoch_tensor, pending_work) for async operation, or (None, None) if not initialized
    """
    if not torch.distributed.is_initialized():
        return None, None

    epoch_tensor = torch.tensor([epoch_increment], dtype=torch.long, device=device)
    pending_work = torch.distributed.all_reduce(
        epoch_tensor,
        op=torch.distributed.ReduceOp.MAX,
        group=dp_process_group,
        async_op=True,
    )
    return epoch_tensor, pending_work


def check_epoch_complete(
    pending_work: Any,
    epoch_tensor: torch.Tensor | None,
) -> bool:
    """Wait for async epoch sync and check if epoch completed.

    Args:
        pending_work: Pending async all_reduce work
        epoch_tensor: Tensor containing epoch increment

    Returns:
        True if any rank completed an epoch, False otherwise
    """
    if pending_work is None:
        return False

    pending_work.wait()
    if epoch_tensor is not None:
        return bool((epoch_tensor > 0).any().item())
    return False


def eval_loop(
    dataloader_iter: Iterator,
    forward_fn: Callable[[dict, torch.Tensor], torch.Tensor],
    device: torch.device,
    eval_steps: int,
    dataset_name: str,
    dp_process_group: Any = None,
    extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch,
    log_interval: int = 10,
) -> tuple[float, int]:
    """Run evaluation loop with epoch synchronization.

    Args:
        dataloader_iter: Iterator over validation data
        forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss tensor
        device: Device for computation
        eval_steps: Maximum number of eval steps (0 = no limit)
        dataset_name: Name for logging
        dp_process_group: Data parallel process group for epoch sync
        extract_epoch_fn: Function to extract epoch from batch
        log_interval: Log every N batches

    Returns:
        Tuple of (avg_loss, num_batches)
    """
    total_loss = torch.tensor(0.0, device=device)
    num_batches, starting_epoch = 0, None

    # Prefetch first batch
    next_batch = next(dataloader_iter)
    should_break, pending_work, epoch_tensor = False, None, None

    with torch.no_grad():
        while True:
            # Check if previous epoch sync completed
            if pending_work is not None:
                should_break = check_epoch_complete(pending_work, epoch_tensor)
                pending_work = None

            if should_break:
                logger.info(
                    f"[{dataset_name}] Epoch completed across all ranks - stopping evaluation"
                )
                break

            if eval_steps > 0 and num_batches >= eval_steps:
                logger.info(f"[{dataset_name}] Reached eval_steps cap of {eval_steps}")
                break

            batch = next_batch

            # Track starting epoch
            current_epoch = extract_epoch_fn(batch)
            if starting_epoch is None:
                starting_epoch = current_epoch

            # Prefetch next batch and start async epoch check
            try:
                next_batch = next(dataloader_iter)
                next_epoch = extract_epoch_fn(next_batch)

                # Only check epochs if both are available
                if next_epoch is not None and starting_epoch is not None:
                    epoch_increment = next_epoch - starting_epoch
                    if torch.distributed.is_initialized():
                        epoch_tensor, pending_work = start_epoch_sync(
                            epoch_increment, device, dp_process_group
                        )
                    else:
                        should_break = epoch_increment > 0
            except StopIteration:
                should_break = True

            # Process current batch (overlaps with async all_reduce)
            move_batch_to_device(batch, device)
            labels = batch.pop("labels")
            loss = forward_fn(batch, labels)
            total_loss += loss
            num_batches += 1

            if num_batches % log_interval == 0:
                logger.info(
                    f" [{dataset_name}] Eval batch {num_batches} | Loss: {loss:.4f}"
                )

    avg_loss = (total_loss / max(num_batches, 1)).item()
    logger.info(
        f"[{dataset_name}] COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}"
    )

    return avg_loss, num_batches


async def evaluate_single_dataset(
    val_dataloader: Any,
    dataset_name: str,
    forward_fn: Callable[[dict, torch.Tensor], torch.Tensor],
    device: torch.device,
    eval_steps: int,
    dp_process_group: Any = None,
    extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch,
) -> dict[str, float]:
    """Evaluate on a single validation dataset with epoch synchronization.

    Args:
        val_dataloader: DataLoader for this validation dataset
        dataset_name: Name of the dataset (for logging)
        forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss
        device: Device for computation
        eval_steps: Maximum number of eval steps
        dp_process_group: Data parallel process group
        extract_epoch_fn: Function to extract epoch from batch

    Returns:
        Dict with metrics: {"val_loss": float, "val_batches": int}
    """
    avg_loss, num_batches = eval_loop(
        dataloader_iter=iter(val_dataloader),
        forward_fn=forward_fn,
        device=device,
        eval_steps=eval_steps,
        dataset_name=dataset_name,
        dp_process_group=dp_process_group,
        extract_epoch_fn=extract_epoch_fn,
        log_interval=10,
    )

    return {"val_loss": avg_loss, "val_batches": num_batches}


async def run_evaluation(
    val_dataloaders: dict[str, Any],
    model_parts: list[nn.Module],
    forward_fn: Callable[[dict, torch.Tensor], torch.Tensor],
    device: torch.device,
    eval_steps: int,
    dp_process_group: Any = None,
    extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch,
) -> dict[str, dict[str, float]]:
    """Run evaluation on multiple validation datasets.

    Evaluates on all configured validation datasets and returns per-dataset metrics.
    Sets models to eval mode before evaluation and back to train mode after.

    Args:
        val_dataloaders: Dict mapping dataset names to dataloaders
        model_parts: List of model parts (for setting eval/train mode)
        forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss
        device: Device for computation
        eval_steps: Maximum number of eval steps per dataset
        dp_process_group: Data parallel process group
        extract_epoch_fn: Function to extract epoch from batch

    Returns:
        Dict mapping dataset name to metrics dict, e.g.:
        {
            "val_in_domain": {"val_loss": 2.5, "val_batches": 100},
            "val_out_domain": {"val_loss": 3.1, "val_batches": 100}
        }
    """
    logger.info("=" * 50)
    logger.info("STARTING EVALUATION")
    logger.info("=" * 50)

    # Set models to eval mode
    for model_part in model_parts:
        model_part.eval()

    all_metrics = {}

    # Evaluate on each dataset
    for dataset_name, val_dataloader in val_dataloaders.items():
        logger.info(f"\n{'='*50}")
        logger.info(f"Evaluating on dataset: {dataset_name}")
        logger.info(f"{'='*50}")

        dataset_metrics = await evaluate_single_dataset(
            val_dataloader=val_dataloader,
            dataset_name=dataset_name,
            forward_fn=forward_fn,
            device=device,
            eval_steps=eval_steps,
            dp_process_group=dp_process_group,
            extract_epoch_fn=extract_epoch_fn,
        )
        all_metrics[dataset_name] = dataset_metrics

    # Set models back to train mode
    for model_part in model_parts:
        model_part.train()

    logger.info("\n" + "=" * 50)
    logger.info("EVALUATION COMPLETE - Summary:")
    for dataset_name, metrics in all_metrics.items():
        logger.info(
            f" {dataset_name}: Loss={metrics['val_loss']:.4f}, Batches={metrics['val_batches']}"
        )
    logger.info("=" * 50)

    return all_metrics


def get_dp_process_group(parallel_dims: Any) -> Any:
    """Get the Data Parallel process group for epoch synchronization.

    Returns the DP process group if DP parallelism is enabled, otherwise None.
    This ensures all_reduce only happens across ranks with different data.

    Args:
        parallel_dims: ParallelDims object containing parallel configuration

    Returns:
        DP process group or None if not available/needed
    """
    if not torch.distributed.is_initialized():
        return None

    if parallel_dims is None:
        return None

    # Check if DP is enabled
    if not parallel_dims.dp_enabled:
        # No DP parallelism, use default group (all ranks)
        return None

    try:
        # Get the "dp" submesh which contains only DP dimensions (dp_replicate + dp_shard)
        # This excludes TP and PP ranks which should already be synchronized
        dp_mesh = parallel_dims.world_mesh.get_group("dp")
        return dp_mesh
    except Exception as e:
        logger.warning(f"Could not get DP process group, using default: {e}")
        return None
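
For orientation, here is a minimal sketch of how the evaluate() method mentioned in the commit message might drive run_evaluation() from this file. The corresponding main.py changes are not shown in this excerpt, so the trainer attribute names below (self.val_dataloaders, self.model_parts, self.device, self.eval_steps, self.parallel_dims) are assumptions, not code from the commit:

# Hypothetical usage sketch; assumes it lives on the SFT trainer class in main.py.
async def evaluate(self) -> dict[str, dict[str, float]]:
    """Run validation on every configured dataset and return per-dataset metrics."""
    return await run_evaluation(
        val_dataloaders=self.val_dataloaders,  # dict mapping dataset name -> dataloader (assumed attribute)
        model_parts=self.model_parts,          # list of model parts; toggled eval/train by run_evaluation
        forward_fn=lambda batch, labels: self.forward(
            batch, labels, compute_gradients=False  # unified forward() per the commit message
        ),
        device=self.device,
        eval_steps=self.eval_steps,            # eval_steps from the validation config (assumed wiring)
        dp_process_group=get_dp_process_group(self.parallel_dims),
    )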

apps/sft/llama3_8b.yaml

Lines changed: 13 additions & 0 deletions
@@ -26,6 +26,18 @@ optimizer:
lr_scheduler:
  warmup_steps: 200

+# Unified dataset configuration
+# First dataset with split='train' is used for training
+dataset_val:
+  datasets:
+    - name: "train"
+      path: "yahma/alpaca-cleaned"
+      split: "train[:95%]"
+
+    - name: "val"
+      path: "yahma/alpaca-cleaned"
+      split: "train[95%:]"
+
training:
  local_batch_size: 1
  seq_len: 2048
@@ -62,6 +74,7 @@ metric_logging:
  group: sft_exp_${oc.env:USER}
  logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce

+
# profiling:
#   enable_profiling: false
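
The dataset_val block above wires up the train/val splits, while the validation settings named in the commit message (enabled, eval_interval, eval_steps) fall outside the hunks shown here. Purely as an assumption about the config layout, that block might look like:

# Hypothetical sketch; only the setting names (enabled, eval_interval, eval_steps)
# come from the commit message -- the block name and values are assumed.
validation:
  enabled: true
  eval_interval: 500   # run evaluation every N training steps (assumed value)
  eval_steps: 50       # max batches per validation dataset; 0 = no cap, eval runs until the epoch completes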