Skip to content

Commit f4fe9e2

Browse files
committed
feat(migrate): Add model card generation and saving to migration script
- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.
1 parent 655303e commit f4fe9e2

File tree

1 file changed

+39
-6
lines changed

1 file changed

+39
-6
lines changed

src/lerobot/processor/migrate_policy_normalization.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def main():
411411
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
412412
]
413413
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
414+
414415
# Determine hub repo ID if pushing to hub
415416
if args.push_to_hub:
416417
if args.hub_repo_id:
@@ -424,12 +425,6 @@ def main():
424425
else:
425426
hub_repo_id = None
426427

427-
# Save model using the policy's save_pretrained method
428-
print(f"Saving model to {output_dir}...")
429-
policy.save_pretrained(
430-
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
431-
)
432-
433428
# Save preprocessor and postprocessor to root directory
434429
print(f"Saving preprocessor to {output_dir}...")
435430
preprocessor.save_pretrained(output_dir)
@@ -441,6 +436,44 @@ def main():
441436
if args.push_to_hub:
442437
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
443438

439+
# Save model using the policy's save_pretrained method
440+
print(f"Saving model to {output_dir}...")
441+
policy.save_pretrained(
442+
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
443+
)
444+
445+
# Generate and save model card
446+
print("Generating model card...")
447+
# Get metadata from original config
448+
dataset_repo_id = config.get("repo_id", "unknown")
449+
license = config.get("license", "apache-2.0")
450+
451+
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
452+
tags = set(tags).union({"robotics", "lerobot", policy_type})
453+
tags = list(tags)
454+
455+
# Generate model card
456+
card = policy.generate_model_card(
457+
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
458+
)
459+
460+
# Save model card locally
461+
card.save(str(output_dir / "README.md"))
462+
print(f"Model card saved to {output_dir / 'README.md'}")
463+
# Push model card to hub if requested
464+
if args.push_to_hub:
465+
from huggingface_hub import HfApi
466+
467+
api = HfApi()
468+
api.upload_file(
469+
path_or_fileobj=str(output_dir / "README.md"),
470+
path_in_repo="README.md",
471+
repo_id=hub_repo_id,
472+
repo_type="model",
473+
commit_message="Add model card for migrated model",
474+
)
475+
print("Model card pushed to hub")
476+
444477
print("\nMigration complete!")
445478
print(f"Migrated model saved to: {output_dir}")
446479
if args.push_to_hub:

0 commit comments

Comments
 (0)