Skip to content
150 changes: 121 additions & 29 deletions unsloth/import_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,46 +503,138 @@ def make_inputs_require_grads(module, input, output):
)


def _is_custom_torch_build(raw_version_str):
"""Check if a raw version string indicates a custom or source build.
Must operate on the raw string from importlib_version(), not the parsed
Version object, since our custom Version() strips local identifiers.

Standard PyTorch releases use: +cu124, +rocm6.3, +cpu, +xpu
Source/custom builds use: +gitXXXXXXX, +HEXHASH, or other suffixes.
"""
if "+" not in raw_version_str:
return False
local = raw_version_str.split("+", 1)[1]
if not local:
return False
# Use fullmatch so the entire local identifier must match, not just a prefix.
# cu/rocm require a trailing digit (e.g. cu124, rocm6.3). cpu/xpu are exact.
# Case-insensitive since some builds may use uppercase.
return not re.fullmatch(r"cu\d[\d.]*|rocm\d[\d.]*|cpu|xpu", local, re.IGNORECASE)


def _infer_required_torchvision(torch_major, torch_minor):
"""Infer the minimum required torchvision minor version from torch version.

The torch -> torchvision minor version mapping follows a consistent formula:
torch 1.x -> torchvision 0.(x + 1) (verified: torch 1.7 through 1.13)
torch 2.x -> torchvision 0.(x + 15) (verified: torch 2.0 through 2.9)

Returns (tv_major, tv_minor) or None if the major version is unrecognized.
"""
if torch_major == 1 and torch_minor >= 7:
return (0, torch_minor + 1)
if torch_major == 2:
return (0, torch_minor + 15)
return None


def torchvision_compatibility_check():
# Allow skipping via environment variable for custom environments
if os.environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0").lower() in ("1", "true"):
return

if importlib.util.find_spec("torch") is None:
raise ImportError("Unsloth: torch not found. Please install torch first.")
if importlib.util.find_spec("torchvision") is None:
return
torch_version = importlib_version("torch")
torchvision_version = importlib_version("torchvision")

# Torch version -> minimum required torchvision version
try:
torch_version_raw = importlib_version("torch")
torchvision_version_raw = importlib_version("torchvision")
except Exception:
return
Comment on lines +554 to +555
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The try...except Exception block silently ignores errors when fetching package versions. While this prevents the application from crashing during import, it can hide underlying environment issues, such as a corrupted installation. Logging a warning here would provide valuable feedback to the user that the compatibility check was skipped and why.

Suggested change
except Exception:
return
except Exception as e:
logger.warning(f"Unsloth: Could not determine torch/torchvision versions, skipping compatibility check. Error: {e}")
return


try:
torch_v = Version(torch_version_raw)
tv_v = Version(torchvision_version_raw)
except Exception:
return
Comment on lines +560 to +561
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous block, this try...except silently fails when parsing version strings. This can make it difficult to diagnose problems related to malformed version numbers. Logging a warning that includes the problematic version strings would greatly improve debuggability for users.

    except Exception as e:
        logger.warning(
            f"Unsloth: Could not parse torch/torchvision versions, skipping compatibility check. "
            f"Versions: torch='{torch_version_raw}', torchvision='{torchvision_version_raw}'. Error: {e}"
        )
        return


# Known compatibility table (ground truth, takes precedence over formula).
# See https://pytorch.org/get-started/previous-versions/
TORCH_TORCHVISION_COMPAT = [
("2.9.0", "0.24.0"),
("2.8.0", "0.23.0"),
("2.7.0", "0.22.0"),
("2.6.0", "0.21.0"),
("2.5.0", "0.20.0"),
("2.4.0", "0.19.0"),
]

required_torchvision = None
for min_torch, min_torchvision in TORCH_TORCHVISION_COMPAT:
if Version(torch_version) >= Version(min_torch):
required_torchvision = min_torchvision
break

if required_torchvision is None:
# Torch version not in compatibility table, skip check
return

if Version(torchvision_version) < Version(required_torchvision):
raise ImportError(
f"Unsloth: torch=={torch_version} requires torchvision>={required_torchvision}, "
f"but found torchvision=={torchvision_version}. "
f"Please refer to https://pytorch.org/get-started/previous-versions/ for more information."
TORCH_TORCHVISION_COMPAT = {
(2, 9): (0, 24),
(2, 8): (0, 23),
(2, 7): (0, 22),
(2, 6): (0, 21),
(2, 5): (0, 20),
(2, 4): (0, 19),
}

# Extract major.minor from the parsed version
torch_release = torch_v.release
if len(torch_release) < 2:
return
torch_major, torch_minor = torch_release[0], torch_release[1]
Comment on lines +574 to +578
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Allow pre-release/nightly pairs without false mismatch

This logic derives the required torchvision version purely from torch_v.release (e.g., 2.8.0.dev20240301 becomes major/minor 2,8) and later compares against a final release string like 0.23.0. For nightly/dev builds, torchvision is typically 0.23.0.dev…, which Version() ranks below 0.23.0, so the check will raise ImportError even when the matching nightly pair is installed. This is a regression for standard nightly builds with local tags like +cu121 (not treated as custom). Consider detecting pre-releases (torch_v.is_prerelease or tv_v.is_prerelease) and downgrading to a warning or skipping the strict >= comparison for nightly pairs.

Useful? React with 👍 / 👎.


# Try known table first, then fall back to formula for forward compatibility
required = TORCH_TORCHVISION_COMPAT.get((torch_major, torch_minor))
is_in_known_table = required is not None

if required is None:
required = _infer_required_torchvision(torch_major, torch_minor)

if required is None:
return

required_tv_str = f"{required[0]}.{required[1]}.0"

if tv_v >= Version(required_tv_str):
logger.info(
f"Unsloth: torch=={torch_version_raw} and "
f"torchvision=={torchvision_version_raw} are compatible."
)
return

logger.info(
f"Unsloth: torch=={torch_version} and torchvision=={torchvision_version} are compatible."
# Version mismatch detected
message = (
f"Unsloth: torch=={torch_version_raw} requires "
f"torchvision>={required_tv_str}, "
f"but found torchvision=={torchvision_version_raw}. "
f"Please refer to https://pytorch.org/get-started/previous-versions/ "
f"for more information."
)

is_custom = _is_custom_torch_build(torch_version_raw) or _is_custom_torch_build(
torchvision_version_raw
)

# Detect nightly/dev/alpha/beta/rc builds from the raw version string.
# These often have version mismatches that are expected.
_pre_tags = (".dev", "a0", "b0", "rc", "alpha", "beta", "nightly")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _pre_tags tuple is a constant collection of strings. It's a good practice to define such constants at the module level (e.g., as _TORCHVISION_PRE_RELEASE_TAGS) rather than inside a function. This improves readability and maintainability by making it clear that this is a fixed set of values used for the check.

is_prerelease = any(t in torch_version_raw for t in _pre_tags) or any(
t in torchvision_version_raw for t in _pre_tags
)

# Downgrade to warning for custom/source/pre-release builds or formula-predicted
if is_custom or is_prerelease or not is_in_known_table:
reason = (
"custom/source build"
if is_custom
else "pre-release build"
if is_prerelease
else "newer torch version"
)
logger.warning(
f"{message}\n"
f"Detected a {reason}. "
f"Continuing with a warning. "
f"Set UNSLOTH_SKIP_TORCHVISION_CHECK=1 to silence this."
)
return

raise ImportError(message)


# Fix TRL OpenEnv 0.26 NameError: name 'SamplingParams' is not defined
def fix_openenv_no_vllm():
Expand Down