Skip to content

Commit a0def4a

Browse files
authored
feat(processor): implement migration detection and error handling for processor configurations (#1968)
* feat(processor): implement migration detection and error handling for processor configurations - Added ProcessorMigrationError to handle migration requirements for old model formats. - Enhanced DataProcessorPipeline.from_pretrained to include robust migration detection logic. - Implemented methods for resolving configuration sources, validating loaded configs, and checking for valid processor configurations. - Introduced comprehensive tests for migration detection and configuration validation to ensure correct behavior. * refactor(processor): simplify loading logic and enhance migration detection - Refactored DataProcessorPipeline to implement a simplified three-way loading strategy for configuration files. - Introduced explicit config_filename parameter to avoid ambiguity during loading. - Updated ProcessorMigrationError to provide clearer error messages for migration requirements. - Enhanced tests to cover new loading logic and ensure proper migration detection. - Removed deprecated methods related to config source resolution.
1 parent c31ce3f commit a0def4a

15 files changed

+1425
-203
lines changed

src/lerobot/processor/pipeline.py

Lines changed: 700 additions & 152 deletions
Large diffs are not rendered by default.

tests/processor/test_act_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def test_act_processor_save_and_load():
239239
preprocessor.save_pretrained(tmpdir)
240240

241241
# Load preprocessor
242-
loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir)
242+
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
243+
tmpdir, config_filename="policy_preprocessor.json"
244+
)
243245

244246
# Test that loaded processor works
245247
observation = {OBS_STATE: torch.randn(7)}

tests/processor/test_batch_processor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,10 @@ def test_save_and_load_pretrained():
290290

291291
# Load pipeline
292292
loaded_pipeline = DataProcessorPipeline.from_pretrained(
293-
tmp_dir, to_transition=identity_transition, to_output=identity_transition
293+
tmp_dir,
294+
config_filename="batchpipeline.json",
295+
to_transition=identity_transition,
296+
to_output=identity_transition,
294297
)
295298

296299
assert loaded_pipeline.name == "BatchPipeline"
@@ -325,7 +328,10 @@ def test_registry_based_save_load():
325328
with tempfile.TemporaryDirectory() as tmp_dir:
326329
pipeline.save_pretrained(tmp_dir)
327330
loaded_pipeline = DataProcessorPipeline.from_pretrained(
328-
tmp_dir, to_transition=identity_transition, to_output=identity_transition
331+
tmp_dir,
332+
config_filename="dataprocessorpipeline.json",
333+
to_transition=identity_transition,
334+
to_output=identity_transition,
329335
)
330336

331337
# Verify the loaded processor works

tests/processor/test_classifier_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ def test_classifier_processor_save_and_load():
250250
preprocessor.save_pretrained(tmpdir)
251251

252252
# Load preprocessor
253-
loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir)
253+
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
254+
tmpdir, config_filename="classifier_preprocessor.json"
255+
)
254256

255257
# Test that loaded processor works
256258
observation = {

tests/processor/test_device_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ def test_save_and_load_pretrained():
324324
robot_processor.save_pretrained(tmpdir)
325325

326326
# Load
327-
loaded_processor = DataProcessorPipeline.from_pretrained(tmpdir)
327+
loaded_processor = DataProcessorPipeline.from_pretrained(
328+
tmpdir, config_filename="device_test_processor.json"
329+
)
328330

329331
assert len(loaded_processor.steps) == 1
330332
loaded_device_processor = loaded_processor.steps[0]

tests/processor/test_diffusion_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ def test_diffusion_processor_save_and_load():
258258
preprocessor.save_pretrained(tmpdir)
259259

260260
# Load preprocessor
261-
loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir)
261+
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
262+
tmpdir, config_filename="policy_preprocessor.json"
263+
)
262264

263265
# Test that loaded processor works
264266
observation = {

0 commit comments

Comments
 (0)