diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index eacfecc6c3..7d512bff1a 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -70,6 +70,12 @@ def __init__(self, **kwargs): except Exception: trl_version = Version("0.0.0") +# Get PyTorch version for feature detection +try: + torch_version = Version(torch.__version__.split("+")[0].split("a")[0].split("b")[0]) +except Exception: + torch_version = Version("0.0.0") + def vLLMSamplingParams(**kwargs): from vllm import SamplingParams @@ -1126,16 +1132,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Generate torch_compile_options based on device type if DEVICE_TYPE == "cuda": # CUDA-specific options (added to base options) - new_options = ( - base_options - + """ - "triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9, + cuda_options = """ + "triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,""" + # cutlass options were added in PyTorch 2.8.0 + if torch_version >= Version("2.8.0"): + cuda_options += """ "cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9, - "cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9, + "cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,""" + cuda_options += """ "cuda.compile_opt_level" : "-O2", "cuda.enable_cuda_lto" : True, }""" - ) + new_options = base_options + cuda_options else: # XPU, HIP, and other device types use base options only new_options = (