Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def __init__(self, **kwargs):
except Exception:
trl_version = Version("0.0.0")

# Get PyTorch version for feature detection
try:
torch_version = Version(torch.__version__.split("+")[0].split("a")[0].split("b")[0])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current method of parsing the PyTorch version string is a bit fragile. It handles a and b for alpha/beta releases, but it doesn't handle dev releases correctly. For example, a version like 2.8.0.dev20240715 would not be stripped down, and Version('2.8.0.dev...') is considered less than Version('2.8.0'), which might not be the intended behavior if dev releases should already include the new features.

A more robust approach would be to use a regular expression to extract just the base X.Y.Z version string. This would handle a wider variety of version formats gracefully.

Suggested change
torch_version = Version(torch.__version__.split("+")[0].split("a")[0].split("b")[0])
torch_version = Version(re.match(r"^\d+\.\d+\.\d+", torch.__version__).group(0))

except Exception:
torch_version = Version("0.0.0")


def vLLMSamplingParams(**kwargs):
from vllm import SamplingParams
Expand Down Expand Up @@ -1126,16 +1132,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
# Generate torch_compile_options based on device type
if DEVICE_TYPE == "cuda":
# CUDA-specific options (added to base options)
new_options = (
base_options
+ """
"triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,
cuda_options = """
"triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,"""
# cutlass options were added in PyTorch 2.8.0
if torch_version >= Version("2.8.0"):
cuda_options += """
"cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9,
"cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,
"cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,"""
cuda_options += """
"cuda.compile_opt_level" : "-O2",
"cuda.enable_cuda_lto" : True,
}"""
Comment on lines +1135 to 1145

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The way the cuda_options string is constructed introduces a syntax error. There's a trailing comma after "cuda.enable_cuda_lto": True, which is the last item in the dictionary being constructed. This will cause a SyntaxError when the string is evaluated as Python code.

I've provided a suggestion that fixes this bug by removing the comma and also makes the string construction slightly more readable by combining the unconditional string additions.

Suggested change
cuda_options = """
"triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,"""
# cutlass options were added in PyTorch 2.8.0
if torch_version >= Version("2.8.0"):
cuda_options += """
"cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9,
"cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,
"cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,"""
cuda_options += """
"cuda.compile_opt_level" : "-O2",
"cuda.enable_cuda_lto" : True,
}"""
cuda_options = '''
"triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,'''
# cutlass options were added in PyTorch 2.8.0
if torch_version >= Version("2.8.0"):
cuda_options += '''
"cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9,
"cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,'''
cuda_options += '''
"cuda.compile_opt_level" : "-O2",
"cuda.enable_cuda_lto" : True
}'''

)
new_options = base_options + cuda_options
else:
# XPU, HIP, and other device types use base options only
new_options = (
Expand Down