Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use shard_as in scan to ensure that inputs and their gradients have the same sharding #8879

Merged
merged 8 commits into from
Mar 29, 2025

Conversation

tengyifei
Copy link
Collaborator

@tengyifei tengyifei commented Mar 24, 2025

Fixes #8883. See #8883 for explanation.

@tengyifei tengyifei marked this pull request as ready for review March 24, 2025 21:08
@tengyifei tengyifei changed the base branch from master to yifeit/call-jax-cache March 24, 2025 21:08
@tengyifei tengyifei force-pushed the yifeit/scan-shard-as branch 2 times, most recently from 705814f to b2140fa Compare March 24, 2025 23:42
@tengyifei tengyifei changed the base branch from yifeit/call-jax-cache to master March 24, 2025 23:42
@tengyifei tengyifei force-pushed the yifeit/scan-shard-as branch from b2140fa to ba4b6ab Compare March 24, 2025 23:43
@tengyifei tengyifei requested review from qihqi and bhavya01 March 26, 2025 20:58
@tengyifei tengyifei force-pushed the yifeit/scan-shard-as branch from 5f84ac1 to 9dcfe1c Compare March 29, 2025 00:09
@tengyifei tengyifei merged commit 6d88c08 into master Mar 29, 2025
22 of 23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[RFC] Use shard_as to improve sharding and avoid OOM
2 participants