Skip to content

Commit 5585bf0

Browse files
mridul-sahu and Orbax Authors
authored and committed
Refactor benchmark metrics to support multiple measurements per operation without needing a context manager for each one.
Also:
* Adding `IOBytesMetric` to track I/O read/write bytes and throughput.
* Introducing a `METRIC_REGISTRY` to map metric keys to classes.

PiperOrigin-RevId: 831668357
1 parent 6e7ac96 commit 5585bf0

18 files changed

+332
-209
lines changed

checkpoint/orbax/checkpoint/_src/testing/benchmarks/array_handler_benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ def test_fn(self, test_context: core.TestContext) -> core.TestResult:
174174
save_args = type_handlers.SaveArgs()
175175

176176
# --- Serialization ---
177-
with metrics.time('serialize'):
177+
with metrics.measure('serialize'):
178178

179179
async def serialize_and_wait():
180180
serialize_futures = await handler.serialize(
@@ -191,7 +191,7 @@ async def serialize_and_wait():
191191
logging.info('Serialization complete for %s', param_info.name)
192192

193193
if options.use_ocdbt:
194-
with metrics.time('merge_ocdbt'):
194+
with metrics.measure('merge_ocdbt'):
195195
asyncio.run(
196196
ocdbt_utils.merge_ocdbt_per_process_files(
197197
test_context.path,
@@ -203,7 +203,7 @@ async def serialize_and_wait():
203203
logging.info('OCDBT merge complete for %s', test_context.path)
204204

205205
# --- Metadata Validation ---
206-
with metrics.time('metadata_validation'):
206+
with metrics.measure('metadata_validation'):
207207
metadata = asyncio.run(handler.metadata([param_info]))[0]
208208
self._validate_metadata(metadata, sharded_array, param_info.name)
209209
multihost.sync_global_processes('metadata validation complete')
@@ -215,7 +215,7 @@ async def serialize_and_wait():
215215
global_shape=sharded_array.shape,
216216
dtype=sharded_array.dtype,
217217
)
218-
with metrics.time('deserialize'):
218+
with metrics.measure('deserialize'):
219219
restored_array = asyncio.run(
220220
handler.deserialize([param_info], args=[restore_args])
221221
)[0]
@@ -224,7 +224,7 @@ async def serialize_and_wait():
224224
logging.info('Deserialization complete for %s', param_info.name)
225225

226226
# --- Restored Array Validation ---
227-
with metrics.time('correctness_check'):
227+
with metrics.measure('correctness_check'):
228228
pytree_utils.assert_pytree_equal(sharded_array, restored_array)
229229
logging.info('Correctness check passed for %s', param_info.name)
230230

checkpoint/orbax/checkpoint/_src/testing/benchmarks/array_handler_benchmark_test.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -138,10 +138,10 @@ def test_benchmark_returns_test_result_with_basic_metrics(self, options):
138138
result = self._run_benchmark_workflow_test(options)
139139

140140
self.assertIsInstance(result, benchmarks_core.TestResult)
141-
self.assertIn('serialize_time', result.metrics.results)
142-
self.assertIn('metadata_validation_time', result.metrics.results)
143-
self.assertIn('deserialize_time', result.metrics.results)
144-
self.assertIn('correctness_check_time', result.metrics.results)
141+
self.assertIn('serialize_time_duration', result.metrics.results)
142+
self.assertIn('metadata_validation_time_duration', result.metrics.results)
143+
self.assertIn('deserialize_time_duration', result.metrics.results)
144+
self.assertIn('correctness_check_time_duration', result.metrics.results)
145145

146146
@parameterized.named_parameters(
147147
dict(
@@ -156,14 +156,14 @@ def test_benchmark_returns_test_result_with_basic_metrics(self, options):
156156
def test_benchmark_ocdbt_enabled_calls_merge(self, options):
157157
result = self._run_benchmark_workflow_test(options)
158158

159-
self.assertIn('merge_ocdbt_time', result.metrics.results)
159+
self.assertIn('merge_ocdbt_time_duration', result.metrics.results)
160160
self.mock_merge_ocdbt.assert_called_once()
161161

162162
def test_benchmark_ocdbt_disabled_does_not_merge(self):
163163
options = ArrayHandlerBenchmarkOptions(use_ocdbt=False, use_zarr3=True)
164164
result = self._run_benchmark_workflow_test(options)
165165

166-
self.assertNotIn('merge_ocdbt_time', result.metrics.results)
166+
self.assertNotIn('merge_ocdbt_time_duration', result.metrics.results)
167167
self.mock_merge_ocdbt.assert_not_called()
168168

169169
@parameterized.named_parameters(

checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -90,10 +90,10 @@ def test_fn(
9090
step_saved = -1
9191
for step in range(options.train_steps):
9292
logging.info('Saving checkpoint at step %d', step)
93-
with metrics.time(f'save_{step}'):
93+
with metrics.measure(f'save_{step}'):
9494
if mngr.save(step, args=composite_args):
9595
step_saved = step
96-
with metrics.time(f'wait_until_finished_{step}'):
96+
with metrics.measure(f'wait_until_finished_{step}'):
9797
mngr.wait_until_finished()
9898
logging.info('Finished saving checkpoint at step %d', step)
9999

@@ -105,12 +105,12 @@ def test_fn(
105105
f'Expected latest step to be {step_saved}, got {latest_step}'
106106
)
107107

108-
with metrics.time(f'restore_{latest_step}'):
108+
with metrics.measure(f'restore_{latest_step}'):
109109
logging.info('Restoring checkpoint at step %d', latest_step)
110110
restored = mngr.restore(latest_step, args=restore_args)
111111
logging.info('Finished restoring checkpoint at step %d', latest_step)
112112

113-
with metrics.time('correctness_check'):
113+
with metrics.measure('correctness_check'):
114114
pytree_utils.assert_pytree_equal(pytree, restored['pytree'])
115115
assert (
116116
json_data == restored['json_item']

checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark_test.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -110,10 +110,10 @@ def test_benchmark_test_fn_succeeds(self):
110110
self.assertIsInstance(result, benchmarks_core.TestResult)
111111
self.assertContainsSubset(
112112
{
113-
'save_0_time',
114-
'wait_until_finished_0_time',
115-
'restore_0_time',
116-
'correctness_check_time',
113+
'save_0_time_duration',
114+
'wait_until_finished_0_time_duration',
115+
'restore_0_time_duration',
116+
'correctness_check_time_duration',
117117
},
118118
result.metrics.results.keys(),
119119
)

checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_perf_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def test_fn(
110110
save_times = []
111111
total_save_times = []
112112
for i in range(options.train_steps):
113-
with metrics.time(f'save_{i}'):
113+
with metrics.measure(f'save_{i}'):
114114
save_start = time.time()
115115
mngr.save(
116116
i,
@@ -121,7 +121,7 @@ def test_fn(
121121
save_times.append(time.time() - save_start)
122122
mngr.wait_until_finished()
123123
total_save_times.append(time.time() - save_start)
124-
with metrics.time(f'train_step_{i}'):
124+
with metrics.measure(f'train_step_{i}'):
125125
pytree = self._train_step(pytree)
126126

127127
save_times = np.array(save_times)
@@ -150,7 +150,7 @@ def test_fn(
150150
)
151151
context.pytree = self._clear_pytree(context.pytree)
152152

153-
with metrics.time('restore'):
153+
with metrics.measure('restore'):
154154
mngr.restore(
155155
mngr.latest_step(),
156156
args=ocp.args.Composite(

checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_policy_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -237,7 +237,7 @@ def test_fn(
237237
),
238238
)
239239
for step in range(options.num_checkpoints):
240-
with metrics.time(f'saving step {step}'):
240+
with metrics.measure(f'saving step {step}'):
241241
checkpointer_manager.save(step, args=ocp.args.PyTreeSave(pytree))
242242
checkpointer_manager.wait_until_finished()
243243
all_steps = checkpointer_manager.all_steps()

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -143,12 +143,12 @@ def run(self) -> TestResult:
143143
)
144144

145145
benchmark_metrics = metric_lib.Metrics(name=f"{self.name} Internal")
146-
with benchmark_metrics.time("sync_global_processes:benchmark:run"):
146+
with benchmark_metrics.measure("sync_global_processes:benchmark:run"):
147147
multihost.sync_global_processes("benchmark:run")
148148

149149
path = directory_setup.setup_test_directory(self.name, self.output_dir)
150150

151-
with benchmark_metrics.time(
151+
with benchmark_metrics.measure(
152152
"sync_global_processes:benchmark:setup_test_directory"
153153
):
154154
multihost.sync_global_processes("benchmark:setup_test_directory")
@@ -160,7 +160,9 @@ def run(self) -> TestResult:
160160
else:
161161
data = checkpoint_generation.load_checkpoint(self.checkpoint_config.path)
162162

163-
with benchmark_metrics.time("sync_global_processes:benchmark:setup_pytree"):
163+
with benchmark_metrics.measure(
164+
"sync_global_processes:benchmark:setup_pytree"
165+
):
164166
multihost.sync_global_processes("benchmark:setup_pytree")
165167

166168
context = TestContext(

0 commit comments

Comments
 (0)