-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix torchvision compatibility check for source builds and future torch versions #3978
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6edc44a
99a0616
d771634
61132cc
133deda
d75ebdf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -503,46 +503,138 @@ def make_inputs_require_grads(module, input, output): | |
| ) | ||
|
|
||
|
|
||
| def _is_custom_torch_build(raw_version_str): | ||
| """Check if a raw version string indicates a custom or source build. | ||
| Must operate on the raw string from importlib_version(), not the parsed | ||
| Version object, since our custom Version() strips local identifiers. | ||
|
|
||
| Standard PyTorch releases use: +cu124, +rocm6.3, +cpu, +xpu | ||
| Source/custom builds use: +gitXXXXXXX, +HEXHASH, or other suffixes. | ||
| """ | ||
| if "+" not in raw_version_str: | ||
| return False | ||
| local = raw_version_str.split("+", 1)[1] | ||
| if not local: | ||
| return False | ||
| # Use fullmatch so the entire local identifier must match, not just a prefix. | ||
| # cu/rocm require a trailing digit (e.g. cu124, rocm6.3). cpu/xpu are exact. | ||
| # Case-insensitive since some builds may use uppercase. | ||
| return not re.fullmatch(r"cu\d[\d.]*|rocm\d[\d.]*|cpu|xpu", local, re.IGNORECASE) | ||
|
|
||
|
|
||
| def _infer_required_torchvision(torch_major, torch_minor): | ||
| """Infer the minimum required torchvision minor version from torch version. | ||
|
|
||
| The torch -> torchvision minor version mapping follows a consistent formula: | ||
| torch 1.x -> torchvision 0.(x + 1) (verified: torch 1.7 through 1.13) | ||
| torch 2.x -> torchvision 0.(x + 15) (verified: torch 2.0 through 2.9) | ||
|
|
||
| Returns (tv_major, tv_minor) or None if the major version is unrecognized. | ||
| """ | ||
| if torch_major == 1 and torch_minor >= 7: | ||
| return (0, torch_minor + 1) | ||
| if torch_major == 2: | ||
| return (0, torch_minor + 15) | ||
| return None | ||
|
|
||
|
|
||
| def torchvision_compatibility_check(): | ||
| # Allow skipping via environment variable for custom environments | ||
| if os.environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0").lower() in ("1", "true"): | ||
| return | ||
|
|
||
| if importlib.util.find_spec("torch") is None: | ||
| raise ImportError("Unsloth: torch not found. Please install torch first.") | ||
| if importlib.util.find_spec("torchvision") is None: | ||
| return | ||
| torch_version = importlib_version("torch") | ||
| torchvision_version = importlib_version("torchvision") | ||
|
|
||
| # Torch version -> minimum required torchvision version | ||
| try: | ||
| torch_version_raw = importlib_version("torch") | ||
| torchvision_version_raw = importlib_version("torchvision") | ||
| except Exception: | ||
| return | ||
|
|
||
| try: | ||
| torch_v = Version(torch_version_raw) | ||
| tv_v = Version(torchvision_version_raw) | ||
| except Exception: | ||
| return | ||
|
Comment on lines
+560
to
+561
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the previous block, this except Exception as e:
logger.warning(
f"Unsloth: Could not parse torch/torchvision versions, skipping compatibility check. "
f"Versions: torch='{torch_version_raw}', torchvision='{torchvision_version_raw}'. Error: {e}"
)
return |
||
|
|
||
| # Known compatibility table (ground truth, takes precedence over formula). | ||
| # See https://pytorch.org/get-started/previous-versions/ | ||
| TORCH_TORCHVISION_COMPAT = [ | ||
| ("2.9.0", "0.24.0"), | ||
| ("2.8.0", "0.23.0"), | ||
| ("2.7.0", "0.22.0"), | ||
| ("2.6.0", "0.21.0"), | ||
| ("2.5.0", "0.20.0"), | ||
| ("2.4.0", "0.19.0"), | ||
| ] | ||
|
|
||
| required_torchvision = None | ||
| for min_torch, min_torchvision in TORCH_TORCHVISION_COMPAT: | ||
| if Version(torch_version) >= Version(min_torch): | ||
| required_torchvision = min_torchvision | ||
| break | ||
|
|
||
| if required_torchvision is None: | ||
| # Torch version not in compatibility table, skip check | ||
| return | ||
|
|
||
| if Version(torchvision_version) < Version(required_torchvision): | ||
| raise ImportError( | ||
| f"Unsloth: torch=={torch_version} requires torchvision>={required_torchvision}, " | ||
| f"but found torchvision=={torchvision_version}. " | ||
| f"Please refer to https://pytorch.org/get-started/previous-versions/ for more information." | ||
| TORCH_TORCHVISION_COMPAT = { | ||
| (2, 9): (0, 24), | ||
| (2, 8): (0, 23), | ||
| (2, 7): (0, 22), | ||
| (2, 6): (0, 21), | ||
| (2, 5): (0, 20), | ||
| (2, 4): (0, 19), | ||
| } | ||
|
|
||
| # Extract major.minor from the parsed version | ||
| torch_release = torch_v.release | ||
| if len(torch_release) < 2: | ||
| return | ||
| torch_major, torch_minor = torch_release[0], torch_release[1] | ||
|
Comment on lines
+574
to
+578
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This logic derives the required torchvision version purely from Useful? React with 👍 / 👎. |
||
|
|
||
| # Try known table first, then fall back to formula for forward compatibility | ||
| required = TORCH_TORCHVISION_COMPAT.get((torch_major, torch_minor)) | ||
| is_in_known_table = required is not None | ||
|
|
||
| if required is None: | ||
| required = _infer_required_torchvision(torch_major, torch_minor) | ||
|
|
||
| if required is None: | ||
| return | ||
|
|
||
| required_tv_str = f"{required[0]}.{required[1]}.0" | ||
|
|
||
| if tv_v >= Version(required_tv_str): | ||
| logger.info( | ||
| f"Unsloth: torch=={torch_version_raw} and " | ||
| f"torchvision=={torchvision_version_raw} are compatible." | ||
| ) | ||
| return | ||
|
|
||
| logger.info( | ||
| f"Unsloth: torch=={torch_version} and torchvision=={torchvision_version} are compatible." | ||
| # Version mismatch detected | ||
| message = ( | ||
| f"Unsloth: torch=={torch_version_raw} requires " | ||
| f"torchvision>={required_tv_str}, " | ||
| f"but found torchvision=={torchvision_version_raw}. " | ||
| f"Please refer to https://pytorch.org/get-started/previous-versions/ " | ||
| f"for more information." | ||
| ) | ||
|
|
||
| is_custom = _is_custom_torch_build(torch_version_raw) or _is_custom_torch_build( | ||
| torchvision_version_raw | ||
| ) | ||
|
|
||
| # Detect nightly/dev/alpha/beta/rc builds from the raw version string. | ||
| # These often have version mismatches that are expected. | ||
| _pre_tags = (".dev", "a0", "b0", "rc", "alpha", "beta", "nightly") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| is_prerelease = any(t in torch_version_raw for t in _pre_tags) or any( | ||
| t in torchvision_version_raw for t in _pre_tags | ||
| ) | ||
|
|
||
| # Downgrade to warning for custom/source/pre-release builds or formula-predicted | ||
| if is_custom or is_prerelease or not is_in_known_table: | ||
| reason = ( | ||
| "custom/source build" | ||
| if is_custom | ||
| else "pre-release build" | ||
| if is_prerelease | ||
| else "newer torch version" | ||
| ) | ||
| logger.warning( | ||
| f"{message}\n" | ||
| f"Detected a {reason}. " | ||
| f"Continuing with a warning. " | ||
| f"Set UNSLOTH_SKIP_TORCHVISION_CHECK=1 to silence this." | ||
| ) | ||
| return | ||
|
|
||
| raise ImportError(message) | ||
|
|
||
|
|
||
| # Fix TRL OpenEnv 0.26 NameError: name 'SamplingParams' is not defined | ||
| def fix_openenv_no_vllm(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
try...except Exceptionblock silently ignores errors when fetching package versions. While this prevents the application from crashing during import, it can hide underlying environment issues, such as a corrupted installation. Logging a warning here would provide valuable feedback to the user that the compatibility check was skipped and why.