diff --git a/unsloth/import_fixes.py b/unsloth/import_fixes.py index 46df719a57..4fbab6a94c 100644 --- a/unsloth/import_fixes.py +++ b/unsloth/import_fixes.py @@ -123,6 +123,47 @@ def __getattr__(self, name): warnings.filterwarnings("ignore", message = "`int4_weight_only` is deprecated") warnings.filterwarnings("ignore", message = "`int8_weight_only` is deprecated") + # TorchAO deprecated import paths (https://github.com/pytorch/ao/issues/2752) + warnings.filterwarnings( + "ignore", + message = r"Importing.*from torchao\.dtypes.*is deprecated", + category = DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message = r"Importing BlockSparseLayout from torchao\.dtypes is deprecated", + category = DeprecationWarning, + ) + + # SWIG builtin type warnings (from bitsandbytes/triton SWIG bindings) + warnings.filterwarnings( + "ignore", + message = r"builtin type Swig.*has no __module__ attribute", + category = DeprecationWarning, + ) + + # Triton autotuner deprecation (https://github.com/triton-lang/triton/pull/4496) + warnings.filterwarnings( + "ignore", + message = r"warmup, rep, and use_cuda_graph parameters are deprecated", + category = DeprecationWarning, + ) + + # Python 3.12+ multiprocessing fork warning in multi-threaded processes + warnings.filterwarnings( + "ignore", + message = r".*multi-threaded.*use of fork\(\) may lead to deadlocks", + category = DeprecationWarning, + ) + + # Resource warnings from internal socket/file operations + warnings.filterwarnings( + "ignore", message = r"unclosed.*socket", category = ResourceWarning + ) + warnings.filterwarnings( + "ignore", message = r"unclosed file.*dev/null", category = ResourceWarning + ) + # Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype' # MUST do this at the start primarily due to tensorflow causing issues diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e5f5dfe68e..be04279aa2 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1152,8 +1152,12 @@ def has_internet(host = "8.8.8.8", port = 53, timeout = 3): return False try: socket.setdefaulttimeout(timeout) - socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect((host, port)) - return True + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.connect((host, port)) + return True + finally: + sock.close() except socket.error as ex: return False