-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix cutlass inductor options for PyTorch < 2.8.0 #3988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -70,6 +70,12 @@ def __init__(self, **kwargs): | |||||||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||||||||
| trl_version = Version("0.0.0") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Get PyTorch version for feature detection | ||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||
| torch_version = Version(torch.__version__.split("+")[0].split("a")[0].split("b")[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||||||||
| torch_version = Version("0.0.0") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def vLLMSamplingParams(**kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm import SamplingParams | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1126,16 +1132,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): | |||||||||||||||||||||||||||||||||||||||||||||||
| # Generate torch_compile_options based on device type | ||||||||||||||||||||||||||||||||||||||||||||||||
| if DEVICE_TYPE == "cuda": | ||||||||||||||||||||||||||||||||||||||||||||||||
| # CUDA-specific options (added to base options) | ||||||||||||||||||||||||||||||||||||||||||||||||
| new_options = ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| base_options | ||||||||||||||||||||||||||||||||||||||||||||||||
| + """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| "triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9, | ||||||||||||||||||||||||||||||||||||||||||||||||
| cuda_options = """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| "triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,""" | ||||||||||||||||||||||||||||||||||||||||||||||||
| # cutlass options were added in PyTorch 2.8.0 | ||||||||||||||||||||||||||||||||||||||||||||||||
| if torch_version >= Version("2.8.0"): | ||||||||||||||||||||||||||||||||||||||||||||||||
| cuda_options += """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| "cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9, | ||||||||||||||||||||||||||||||||||||||||||||||||
| "cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9, | ||||||||||||||||||||||||||||||||||||||||||||||||
| "cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,""" | ||||||||||||||||||||||||||||||||||||||||||||||||
| cuda_options += """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| "cuda.compile_opt_level" : "-O2", | ||||||||||||||||||||||||||||||||||||||||||||||||
| "cuda.enable_cuda_lto" : True, | ||||||||||||||||||||||||||||||||||||||||||||||||
| }""" | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1135
to
1145
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The way the I've provided a suggestion that fixes this bug by removing the comma and also makes the string construction slightly more readable by combining the unconditional string additions.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| new_options = base_options + cuda_options | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # XPU, HIP, and other device types use base options only | ||||||||||||||||||||||||||||||||||||||||||||||||
| new_options = ( | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current method of parsing the PyTorch version string is a bit fragile. It handles
aandbfor alpha/beta releases, but it doesn't handledevreleases correctly. For example, a version like2.8.0.dev20240715would not be stripped down, andVersion('2.8.0.dev...')is considered less thanVersion('2.8.0'), which might not be the intended behavior if dev releases should already include the new features.A more robust approach would be to use a regular expression to extract just the base
X.Y.Zversion string. This would handle a wider variety of version formats gracefully.