
Block sparse mm #1058

Merged: 10 commits merged into main from block-sparse-mm on May 2, 2024
Conversation

jagrit06
Member

Proposed changes

Adds an operation and primitive that gather matrices on the fly before the matmul
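
A minimal sketch of the intended semantics, going only off the PR title and description (the binding name `block_sparse_mm` and its signature are an assumption): the fused op should match an explicit gather followed by a batched matmul, without materializing the gathered copies.

```python
import mlx.core as mx

# Stacks of candidate matrices.
a = mx.random.normal((4, 8, 16))   # 4 left-hand matrices of shape (8, 16)
b = mx.random.normal((6, 16, 32))  # 6 right-hand matrices of shape (16, 32)

# Which matrices to pair up, chosen on the fly.
lhs_indices = mx.array([0, 3])
rhs_indices = mx.array([2, 5])

# Unfused reference: explicit gather, then batched matmul.
# For each i this computes a[lhs_indices[i]] @ b[rhs_indices[i]].
reference = mx.matmul(a[lhs_indices], b[rhs_indices])  # shape (2, 8, 32)

# The fused op added here should produce the same result while skipping
# the intermediate gathered copies (name and signature assumed):
# out = mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
```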

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

jagrit06 marked this pull request as ready for review April 30, 2024 17:29
jagrit06 requested a review from awni April 30, 2024 17:30
jagrit06 linked an issue Apr 30, 2024 that may be closed by this pull request
Outdated review threads on python/src/ops.cpp and mlx/ops.cpp (resolved)
awni (Member) commented May 1, 2024

Very cool!! Can't wait to try this in an MOE!

awni (Member) commented May 2, 2024

Some MOE benchmarks:

Generation

python -m mlx_lm.generate --model Qwen/Qwen1.5-MoE-A2.7B-Chat  --prompt "Write a story about Einstein" --max-tokens 256 --temp 0.0
Pre: 31.285 tokens-per-sec
Post: 72.387 tokens-per-sec

LoRA

python -m mlx_lm.lora --train --iters 50 --model Qwen/Qwen1.5-MoE-A2.7B-Chat --data ../lora/data

Pre: Iter 30: Train loss 1.475, Learning Rate 1.000e-05, It/sec 1.692, Tokens/sec 291.262, Trained Tokens 5325, Peak mem 29.248 GB
Post: Iter 30: Train loss 1.466, Learning Rate 1.000e-05, It/sec 2.724, Tokens/sec 468.749, Trained Tokens 5325, Peak mem 28.051 GB
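
For context, a rough sketch (not from the PR; all names and shapes here are hypothetical) of why this helps an MoE layer: the router picks one expert per token, so each token multiplies by a gathered expert weight matrix, and the unfused pattern materializes one copy of that matrix per token.

```python
import mlx.core as mx

num_experts, d_model, d_ff = 4, 64, 256
w_experts = mx.random.normal((num_experts, d_model, d_ff))  # expert weights
tokens = mx.random.normal((8, 1, d_model))                  # 8 tokens as row vectors
routes = mx.array([0, 2, 1, 3, 0, 0, 2, 1])                 # router choices (assumed)

# Unfused pattern: w_experts[routes] copies one (d_model, d_ff) matrix per
# token before the batched matmul; the gathered matmul avoids those copies.
hidden = mx.matmul(tokens, w_experts[routes])               # shape (8, 1, d_ff)
```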

awni (Member) left a comment

🚀

jagrit06 merged commit f390957 into main May 2, 2024
3 checks passed
jagrit06 deleted the block-sparse-mm branch May 2, 2024 21:03
Successfully merging this pull request may close these issues.

[BUG] Matmul gives wrong output for large sizes