diff --git a/unsloth/import_fixes.py b/unsloth/import_fixes.py index fa13835883..46df719a57 100644 --- a/unsloth/import_fixes.py +++ b/unsloth/import_fixes.py @@ -503,46 +503,138 @@ def make_inputs_require_grads(module, input, output): ) +def _is_custom_torch_build(raw_version_str): + """Check if a raw version string indicates a custom or source build. + Must operate on the raw string from importlib_version(), not the parsed + Version object, since our custom Version() strips local identifiers. + + Standard PyTorch releases use: +cu124, +rocm6.3, +cpu, +xpu + Source/custom builds use: +gitXXXXXXX, +HEXHASH, or other suffixes. + """ + if "+" not in raw_version_str: + return False + local = raw_version_str.split("+", 1)[1] + if not local: + return False + # Use fullmatch so the entire local identifier must match, not just a prefix. + # cu/rocm require a trailing digit (e.g. cu124, rocm6.3). cpu/xpu are exact. + # Case-insensitive since some builds may use uppercase. + return not re.fullmatch(r"cu\d[\d.]*|rocm\d[\d.]*|cpu|xpu", local, re.IGNORECASE) + + +def _infer_required_torchvision(torch_major, torch_minor): + """Infer the minimum required torchvision minor version from torch version. + + The torch -> torchvision minor version mapping follows a consistent formula: + torch 1.x -> torchvision 0.(x + 1) (verified: torch 1.7 through 1.13) + torch 2.x -> torchvision 0.(x + 15) (verified: torch 2.0 through 2.9) + + Returns (tv_major, tv_minor) or None if the major version is unrecognized. + """ + if torch_major == 1 and torch_minor >= 7: + return (0, torch_minor + 1) + if torch_major == 2: + return (0, torch_minor + 15) + return None + + def torchvision_compatibility_check(): + # Allow skipping via environment variable for custom environments + if os.environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0").lower() in ("1", "true"): + return + if importlib.util.find_spec("torch") is None: raise ImportError("Unsloth: torch not found. Please install torch first.") if importlib.util.find_spec("torchvision") is None: return - torch_version = importlib_version("torch") - torchvision_version = importlib_version("torchvision") - # Torch version -> minimum required torchvision version + try: + torch_version_raw = importlib_version("torch") + torchvision_version_raw = importlib_version("torchvision") + except Exception: + return + + try: + torch_v = Version(torch_version_raw) + tv_v = Version(torchvision_version_raw) + except Exception: + return + + # Known compatibility table (ground truth, takes precedence over formula). # See https://pytorch.org/get-started/previous-versions/ - TORCH_TORCHVISION_COMPAT = [ - ("2.9.0", "0.24.0"), - ("2.8.0", "0.23.0"), - ("2.7.0", "0.22.0"), - ("2.6.0", "0.21.0"), - ("2.5.0", "0.20.0"), - ("2.4.0", "0.19.0"), - ] - - required_torchvision = None - for min_torch, min_torchvision in TORCH_TORCHVISION_COMPAT: - if Version(torch_version) >= Version(min_torch): - required_torchvision = min_torchvision - break - - if required_torchvision is None: - # Torch version not in compatibility table, skip check - return - - if Version(torchvision_version) < Version(required_torchvision): - raise ImportError( - f"Unsloth: torch=={torch_version} requires torchvision>={required_torchvision}, " - f"but found torchvision=={torchvision_version}. " - f"Please refer to https://pytorch.org/get-started/previous-versions/ for more information." + TORCH_TORCHVISION_COMPAT = { + (2, 9): (0, 24), + (2, 8): (0, 23), + (2, 7): (0, 22), + (2, 6): (0, 21), + (2, 5): (0, 20), + (2, 4): (0, 19), + } + + # Extract major.minor from the parsed version + torch_release = torch_v.release + if len(torch_release) < 2: + return + torch_major, torch_minor = torch_release[0], torch_release[1] + + # Try known table first, then fall back to formula for forward compatibility + required = TORCH_TORCHVISION_COMPAT.get((torch_major, torch_minor)) + is_in_known_table = required is not None + + if required is None: + required = _infer_required_torchvision(torch_major, torch_minor) + + if required is None: + return + + required_tv_str = f"{required[0]}.{required[1]}.0" + + if tv_v >= Version(required_tv_str): + logger.info( + f"Unsloth: torch=={torch_version_raw} and " + f"torchvision=={torchvision_version_raw} are compatible." ) + return - logger.info( - f"Unsloth: torch=={torch_version} and torchvision=={torchvision_version} are compatible." + # Version mismatch detected + message = ( + f"Unsloth: torch=={torch_version_raw} requires " + f"torchvision>={required_tv_str}, " + f"but found torchvision=={torchvision_version_raw}. " + f"Please refer to https://pytorch.org/get-started/previous-versions/ " + f"for more information." + ) + + is_custom = _is_custom_torch_build(torch_version_raw) or _is_custom_torch_build( + torchvision_version_raw + ) + + # Detect nightly/dev/alpha/beta/rc builds from the raw version string. + # These often have version mismatches that are expected. + _pre_tags = (".dev", "a0", "b0", "rc", "alpha", "beta", "nightly") + is_prerelease = any(t in torch_version_raw for t in _pre_tags) or any( + t in torchvision_version_raw for t in _pre_tags ) + # Downgrade to warning for custom/source/pre-release builds or formula-predicted + if is_custom or is_prerelease or not is_in_known_table: + reason = ( + "custom/source build" + if is_custom + else "pre-release build" + if is_prerelease + else "newer torch version" + ) + logger.warning( + f"{message}\n" + f"Detected a {reason}. " + f"Continuing with a warning. " + f"Set UNSLOTH_SKIP_TORCHVISION_CHECK=1 to silence this." + ) + return + + raise ImportError(message) + # Fix TRL OpenEnv 0.26 NameError: name 'SamplingParams' is not defined def fix_openenv_no_vllm():