From 6c9345a99443fa0327ee1203abd412ec5984c463 Mon Sep 17 00:00:00 2001 From: Manan17 Date: Thu, 9 Apr 2026 00:05:26 -0500 Subject: [PATCH 01/68] Add Apple Silicon MLX routing Rewrite __init__.py: detect MLX on macOS arm64 before any torch imports Extract original GPU init to _gpu_init.py (unchanged) MLX path imports FastMLXModel from unsloth_zoo, skips all GPU code GPU path unchanged: from ._gpu_init import * --- .github/workflows/release-desktop.yml | 226 - .gitignore | 12 - .pre-commit-config.yaml | 2 +- README.md | 57 +- install.ps1 | 510 +- install.sh | 768 +- ...all_gemma4_mlx.sh => install_gemma4_mlx.sh | 66 +- pyproject.toml | 5 +- scripts/install_qwen3_6_mlx.sh | 191 - studio/Unsloth_Studio_Colab.ipynb | 2 +- .../assets/configs/inference_defaults.json | 10 +- studio/backend/auth/authentication.py | 64 +- studio/backend/auth/storage.py | 398 +- studio/backend/core/__init__.py | 7 +- .../core/data_recipe/jobs/constants.py | 1 - .../backend/core/data_recipe/jobs/manager.py | 124 +- studio/backend/core/data_recipe/jobs/parse.py | 224 +- studio/backend/core/data_recipe/jobs/types.py | 25 - .../backend/core/data_recipe/jobs/worker.py | 29 +- .../data_recipe/local_callable_validators.py | 9 +- studio/backend/core/export/export.py | 102 +- studio/backend/core/export/orchestrator.py | 336 +- studio/backend/core/export/worker.py | 184 +- .../core/inference/anthropic_compat.py | 576 -- studio/backend/core/inference/audio_codecs.py | 12 +- studio/backend/core/inference/defaults.py | 2 - studio/backend/core/inference/inference.py | 4 - studio/backend/core/inference/llama_cpp.py | 1344 +--- .../core/inference/llama_server_args.py | 120 - studio/backend/core/inference/orchestrator.py | 48 +- studio/backend/core/training/resume.py | 75 - studio/backend/core/training/trainer.py | 29 +- studio/backend/core/training/training.py | 33 +- studio/backend/core/training/worker.py | 418 +- studio/backend/loggers/config.py | 12 - studio/backend/loggers/handlers.py | 26 +- studio/backend/main.py | 58 +- studio/backend/models/auth.py | 49 - studio/backend/models/inference.py | 623 +- studio/backend/models/models.py | 65 - studio/backend/models/training.py | 5 - .../data-designer-github-repo-seed/README.md | 73 - .../pyproject.toml | 25 - .../__init__.py | 7 - .../data_designer_github_repo_seed/config.py | 64 - .../data_designer_github_repo_seed/impl.py | 83 - .../data_designer_github_repo_seed/plugin.py | 10 - .../data_designer_github_repo_seed/scraper.py | 236 - .../scraper_impl/__init__.py | 2 - .../scraper_impl/gh_client.py | 248 - .../scraper_impl/queries.py | 685 -- .../scraper_impl/scraper.py | 756 -- .../scraper_impl/state_store.py | 105 - .../backend/requirements/extras-no-deps.txt | 8 +- .../single-env/data-designer-deps.txt | 3 +- studio/backend/routes/__init__.py | 2 - studio/backend/routes/auth.py | 98 +- studio/backend/routes/data_recipe/jobs.py | 214 +- studio/backend/routes/data_recipe/seed.py | 12 - studio/backend/routes/data_recipe/validate.py | 85 - studio/backend/routes/datasets.py | 130 +- studio/backend/routes/export.py | 215 +- studio/backend/routes/inference.py | 2777 +------ studio/backend/routes/models.py | 1403 +--- studio/backend/routes/training.py | 39 - studio/backend/routes/training_history.py | 13 +- studio/backend/run.py | 34 +- studio/backend/state/tool_policy.py | 33 - studio/backend/storage/studio_db.py | 81 +- studio/backend/tests/conftest.py | 126 +- .../backend/tests/test_anthropic_messages.py | 1013 --- .../tests/test_browse_folders_route.py | 86 - .../backend/tests/test_cached_gguf_routes.py | 398 - .../tests/test_data_recipe_github_progress.py | 91 - studio/backend/tests/test_desktop_auth.py | 598 -- .../backend/tests/test_export_log_cursor.py | 179 - studio/backend/tests/test_gpu_selection.py | 326 +- studio/backend/tests/test_host_defaults.py | 98 - .../backend/tests/test_kv_cache_estimation.py | 1344 +--- .../test_llama_cpp_cache_aware_disk_check.py | 243 - .../tests/test_llama_cpp_context_fit.py | 393 - .../tests/test_llama_cpp_load_progress.py | 258 - .../test_llama_cpp_load_progress_live.py | 202 - .../test_llama_cpp_load_progress_matrix.py | 473 -- .../test_llama_cpp_max_context_threshold.py | 248 - .../tests/test_llama_cpp_no_context_shift.py | 137 - .../backend/tests/test_llama_server_args.py | 189 - .../tests/test_native_context_length.py | 12 - .../tests/test_openai_tool_passthrough.py | 474 -- studio/backend/tests/test_pytorch_mirror.py | 55 - studio/backend/tests/test_responses_api.py | 328 - .../tests/test_responses_tool_passthrough.py | 667 -- studio/backend/tests/test_studio_api.py | 974 --- .../backend/tests/test_tool_policy_gates.py | 56 - .../backend/tests/test_tool_policy_state.py | 59 - .../backend/tests/test_trained_model_scan.py | 101 - .../tests/test_training_worker_flash_attn.py | 170 - studio/backend/tests/test_utils.py | 8 +- studio/backend/tests/test_vision_cache.py | 45 +- studio/backend/tests/test_vram_estimation.py | 1674 ---- studio/backend/utils/datasets/llm_assist.py | 2 +- .../backend/utils/datasets/model_mappings.py | 41 - .../backend/utils/hardware/VRAM_ESTIMATION.md | 89 +- studio/backend/utils/hardware/__init__.py | 10 - studio/backend/utils/hardware/amd.py | 384 - studio/backend/utils/hardware/hardware.py | 309 +- studio/backend/utils/hardware/nvidia.py | 13 - .../backend/utils/hardware/vram_estimation.py | 951 +-- studio/backend/utils/models/__init__.py | 7 +- studio/backend/utils/models/model_config.py | 322 +- studio/backend/utils/native_path_leases.py | 406 - studio/backend/utils/paths/__init__.py | 2 - studio/backend/utils/paths/storage_roots.py | 45 - studio/backend/utils/subprocess_compat.py | 34 - studio/backend/utils/transformers_version.py | 11 - studio/backend/utils/wheel_utils.py | 175 - studio/frontend/package.json | 8 - studio/frontend/public/blacklogo-c.png | Bin 141545 -> 0 bytes studio/frontend/public/circle-logo-small.png | Bin 14416 -> 0 bytes .../fonts/FiraCode-VariableFont_wght.ttf | Bin 264848 -> 0 bytes .../frontend/public/fonts/Hellix-Medium.woff | Bin 58448 -> 0 bytes .../frontend/public/fonts/Hellix-Regular.woff | Bin 58584 -> 0 bytes studio/frontend/public/sidebar-logo-black.png | Bin 9170 -> 0 bytes studio/frontend/public/sidebar-logo-white.png | Bin 9057 -> 0 bytes studio/frontend/public/sticker.png | Bin 1013935 -> 0 bytes studio/frontend/public/studio.png | Bin 27420 -> 0 bytes studio/frontend/public/unsloth-beta-black.png | Bin 160502 -> 0 bytes studio/frontend/public/unsloth-beta-white.png | Bin 156421 -> 0 bytes studio/frontend/public/whitelogo-c.png | Bin 139810 -> 0 bytes studio/frontend/src/app/auth-guards.ts | 62 +- studio/frontend/src/app/provider.tsx | 278 +- studio/frontend/src/app/routes/__root.tsx | 84 +- studio/frontend/src/app/routes/chat.tsx | 15 +- studio/frontend/src/app/routes/onboarding.tsx | 5 - .../frontend/src/components/app-sidebar.tsx | 606 -- .../components/assistant-ui/code-plugin.ts | 66 - .../components/assistant-ui/code-themes.ts | 30 - .../assistant-ui/code-toggle-icon.tsx | 22 - .../components/assistant-ui/markdown-text.tsx | 27 +- .../assistant-ui/model-selector.tsx | 94 +- .../model-selector/folder-browser.tsx | 328 - .../model-selector/model-delete-action.tsx | 112 - .../assistant-ui/model-selector/pickers.tsx | 651 +- .../assistant-ui/model-selector/types.ts | 5 - .../src/components/assistant-ui/reasoning.tsx | 89 +- .../src/components/assistant-ui/sources.tsx | 15 +- .../src/components/assistant-ui/thread.tsx | 550 +- .../components/assistant-ui/tool-fallback.tsx | 4 +- .../components/assistant-ui/tool-group.tsx | 4 +- .../assistant-ui/tool-ui-python.tsx | 4 +- .../assistant-ui/tool-ui-terminal.tsx | 4 +- .../use-intent-aware-autoscroll.tsx | 451 -- studio/frontend/src/components/navbar.tsx | 617 +- .../src/components/shutdown-dialog.tsx | 13 +- .../src/components/tauri/startup-screen.tsx | 461 -- .../src/components/tauri/update-banner.tsx | 146 - .../src/components/tauri/update-screen.tsx | 217 - .../src/components/tauri/window-titlebar.tsx | 329 - .../components/ui/animated-theme-toggler.tsx | 76 +- studio/frontend/src/components/ui/command.tsx | 11 +- .../frontend/src/components/ui/progress.tsx | 13 +- .../src/components/ui/shimmer-button.tsx | 96 - studio/frontend/src/components/ui/sidebar.tsx | 137 +- studio/frontend/src/components/ui/slider.tsx | 4 +- .../frontend/src/components/ui/textarea.tsx | 35 +- studio/frontend/src/config/env.ts | 31 +- studio/frontend/src/features/auth/api.ts | 89 +- .../features/auth/change-password-page.tsx | 2 +- .../features/auth/components/auth-form.tsx | 16 +- studio/frontend/src/features/auth/index.ts | 6 - .../frontend/src/features/auth/login-page.tsx | 2 +- studio/frontend/src/features/auth/session.ts | 6 +- .../src/features/auth/tauri-auto-auth.ts | 106 - .../src/features/chat/api/chat-adapter.ts | 254 +- .../src/features/chat/api/chat-api.ts | 115 +- .../frontend/src/features/chat/chat-page.tsx | 970 +-- .../src/features/chat/chat-settings-sheet.tsx | 1108 +-- .../chat/components/chat-search-dialog.tsx | 120 - .../chat/components/context-usage-bar.tsx | 7 - .../chat/components/model-load-status.tsx | 59 +- .../chat/hooks/use-chat-model-runtime.ts | 582 +- .../chat/hooks/use-chat-search-index.ts | 159 - .../chat/hooks/use-chat-sidebar-items.ts | 97 - .../features/chat/hooks/use-transfer-stats.ts | 55 - .../features/chat/presets/preset-policy.ts | 351 - .../src/features/chat/runtime-provider.tsx | 188 +- .../src/features/chat/shared-composer.tsx | 209 +- .../chat/stores/chat-runtime-store.ts | 113 +- .../features/chat/stores/chat-search-store.ts | 18 - .../src/features/chat/thread-sidebar.tsx | 74 +- .../frontend/src/features/chat/tour/steps.tsx | 8 +- .../frontend/src/features/chat/types/api.ts | 12 - .../chat/utils/chat-thread-tombstones.ts | 12 - .../features/chat/utils/clear-all-chats.ts | 15 - .../chat/utils/delete-thread-message.ts | 124 - .../chat/utils/export-chat-history.ts | 41 - .../features/chat/utils/format-transfer.ts | 44 - .../src/features/chat/utils/qwen-params.ts | 29 - .../src/features/chat/utils/transfer-stats.ts | 88 - .../hooks/use-recipe-sidebar-items.ts | 22 - .../learning-recipes/github-support-bot.json | 238 - .../data-recipes/learning-recipes/index.ts | 11 - .../data-recipes/pages/data-recipes-page.tsx | 19 +- .../data-recipes/pages/edit-recipe-page.tsx | 2 +- .../src/features/export/api/export-api.ts | 177 +- .../export/components/export-dialog.tsx | 351 +- .../frontend/src/features/export/constants.ts | 6 +- .../src/features/export/export-page.tsx | 30 +- .../src/features/native-intents/api.ts | 46 - .../components/native-model-chip.tsx | 107 - .../components/native-model-drop-overlay.tsx | 70 - .../native-intents/native-intent-drain.tsx | 39 - .../src/features/native-intents/store.ts | 23 - .../src/features/native-intents/types.ts | 39 - .../native-intents/use-native-dialogs.ts | 68 - .../native-intents/use-native-drop.ts | 132 - .../native-intents/use-native-readiness.ts | 46 - .../onboarding/components/splash-screen.tsx | 6 +- .../onboarding/components/wizard-footer.tsx | 18 +- .../onboarding/components/wizard-layout.tsx | 28 +- .../onboarding/components/wizard-sidebar.tsx | 10 +- .../profile-personalization-panel.tsx | 157 - .../profile/components/user-avatar.tsx | 45 - .../profile/hooks/use-effective-profile.ts | 19 - studio/frontend/src/features/profile/index.ts | 6 - .../profile/stores/user-profile-store.ts | 24 - .../features/profile/utils/avatar-initials.ts | 13 - .../src/features/profile/utils/jwt-subject.ts | 21 - .../profile/utils/resize-image-file.ts | 83 - .../src/features/recipe-studio/api/index.ts | 70 +- .../recipe-studio/blocks/definitions.ts | 18 +- .../executions/execution-data-tab.tsx | 178 +- .../executions/execution-overview-tab.tsx | 66 - .../components/executions/executions-view.tsx | 87 +- .../components/inline/inline-seed.tsx | 83 +- .../components/recipe-graph-aux-node.tsx | 4 +- .../components/recipe-graph-node.tsx | 4 +- .../components/recipe-studio-header.tsx | 11 +- .../runtime/execution-progress-island.tsx | 98 +- .../recipe-studio/dialogs/preview-dialog.tsx | 9 +- .../dialogs/seed/seed-dialog.tsx | 796 +- .../easy/github-crawler-easy-view.tsx | 191 - .../features/recipe-studio/execution-types.ts | 26 +- .../recipe-studio/executions/runtime.ts | 11 - .../hooks/use-recipe-executions.ts | 117 +- .../hooks/use-recipe-studio-actions.ts | 4 - .../recipe-studio/recipe-studio-page.tsx | 75 +- .../recipe-studio/stores/recipe-studio.ts | 10 +- .../src/features/recipe-studio/types/index.ts | 19 +- .../import/parsers/seed-config-parser.ts | 32 - .../features/recipe-studio/utils/node-data.ts | 4 +- .../utils/payload/builders-seed.ts | 108 +- .../recipe-studio/utils/payload/types.ts | 2 +- .../recipe-studio/utils/validation.ts | 51 +- .../src/features/settings/api/api-keys.ts | 41 - .../settings/components/api-key-row.tsx | 101 - .../settings/components/create-key-form.tsx | 83 - .../settings/components/key-reveal-card.tsx | 71 - .../settings/components/settings-row.tsx | 39 - .../settings/components/settings-section.tsx | 30 - .../settings/components/theme-segmented.tsx | 58 - .../components/update-studio-instructions.tsx | 185 - .../settings/components/usage-examples.tsx | 158 - .../frontend/src/features/settings/index.ts | 6 - .../src/features/settings/settings-dialog.tsx | 174 - .../settings/stores/settings-dialog-store.ts | 53 - .../features/settings/stores/theme-store.ts | 96 - .../src/features/settings/tabs/about-tab.tsx | 134 - .../features/settings/tabs/api-keys-tab.tsx | 168 - .../features/settings/tabs/appearance-tab.tsx | 40 - .../src/features/settings/tabs/chat-tab.tsx | 129 - .../features/settings/tabs/general-tab.tsx | 239 - .../features/settings/tabs/profile-tab.tsx | 19 - .../studio/historical-training-view.tsx | 1 - .../src/features/studio/history-card-grid.tsx | 74 +- .../features/studio/live-training-view.tsx | 2 - .../charts/chart-preferences-store.ts | 9 +- .../studio/sections/dataset-section.tsx | 112 +- .../src/features/studio/studio-page.tsx | 42 +- .../studio/training-start-overlay.tsx | 270 +- .../training/hooks/use-training-actions.ts | 70 +- .../hooks/use-training-history-sidebar.ts | 56 - .../hooks/use-training-runtime-lifecycle.ts | 120 +- .../hooks/use-training-unload-guard.ts | 40 - .../frontend/src/features/training/index.ts | 1 - .../training/stores/training-runtime-store.ts | 13 - .../src/features/training/types/api.ts | 1 - .../src/features/training/types/history.ts | 2 - .../src/features/training/types/runtime.ts | 13 - studio/frontend/src/hooks/index.ts | 2 - .../src/hooks/use-collapse-scroll-lock.ts | 63 - studio/frontend/src/hooks/use-gpu-info.ts | 3 +- .../frontend/src/hooks/use-hf-model-search.ts | 8 +- studio/frontend/src/hooks/use-mobile.ts | 27 +- studio/frontend/src/hooks/use-sidebar-pin.ts | 59 - .../frontend/src/hooks/use-tauri-backend.ts | 581 -- studio/frontend/src/hooks/use-tauri-update.ts | 310 - studio/frontend/src/index.css | 925 +-- studio/frontend/src/lib/api-base.ts | 30 - studio/frontend/src/lib/copy-to-clipboard.ts | 67 +- studio/frontend/src/lib/latex.ts | 117 - .../frontend/src/lib/native-notifications.ts | 234 - studio/frontend/src/lib/open-link.ts | 31 - studio/frontend/src/lib/tauri-diagnostics.ts | 176 - studio/install_llama_prebuilt.py | 345 +- studio/install_python_stack.py | 637 +- studio/setup.ps1 | 255 +- studio/setup.sh | 54 +- studio/src-tauri/Cargo.lock | 6839 ----------------- studio/src-tauri/Cargo.toml | 48 - studio/src-tauri/Entitlements.plist | 15 - studio/src-tauri/build.rs | 3 - studio/src-tauri/capabilities/default.json | 33 - studio/src-tauri/icons/128x128.png | Bin 9194 -> 0 bytes studio/src-tauri/icons/32x32.png | Bin 1930 -> 0 bytes studio/src-tauri/icons/icon.icns | Bin 263568 -> 0 bytes studio/src-tauri/icons/icon.ico | Bin 34589 -> 0 bytes studio/src-tauri/icons/icon.png | Bin 46007 -> 0 bytes studio/src-tauri/linux/postremove.sh | 9 - studio/src-tauri/src/commands.rs | 639 -- studio/src-tauri/src/desktop_auth.rs | 505 -- studio/src-tauri/src/diagnostics/mod.rs | 169 - studio/src-tauri/src/diagnostics/phase_log.rs | 823 -- studio/src-tauri/src/diagnostics/redaction.rs | 217 - studio/src-tauri/src/diagnostics/report.rs | 607 -- studio/src-tauri/src/diagnostics/state.rs | 774 -- studio/src-tauri/src/install.rs | 895 --- studio/src-tauri/src/main.rs | 229 - studio/src-tauri/src/native_backend_lease.rs | 199 - studio/src-tauri/src/native_intents.rs | 476 -- studio/src-tauri/src/native_path_policy.rs | 291 - studio/src-tauri/src/preflight.rs | 831 -- studio/src-tauri/src/process.rs | 658 -- studio/src-tauri/src/update.rs | 414 - studio/src-tauri/src/windows_job.rs | 72 - studio/src-tauri/tauri.conf.json | 83 - studio/src-tauri/tauri.macos.conf.json | 11 - studio/src-tauri/tauri.windows.conf.json | 10 - .../windows/branding/nsis-header.bmp | Bin 102654 -> 0 bytes .../windows/branding/nsis-sidebar.bmp | Bin 618006 -> 0 bytes studio/src-tauri/windows/hooks.nsh | 9 - studio/src-tauri/windows/installer.nsi | 994 --- tests/python/test_cross_platform_parity.py | 16 - .../test_dpo_vision_processor_passthrough.py | 149 - ...sentence_transformer_redirect_lifecycle.py | 218 - .../test_flash_attn_install_python_stack.py | 282 - tests/python/test_no_torch_filtering.py | 5 +- .../test_unsloth_run_tool_policy_resolver.py | 201 - .../saving/non_peft/test_mistral_non_peft.py | 2 +- .../saving/non_peft/test_whisper_non_peft.py | 2 +- .../test_fix_sentencepiece_gguf_robustness.py | 132 - .../test_patch_saving_none_tokenizer.py | 47 - .../test_index_file_sharded_model.py | 4 +- .../vision_models/test_push_to_hub_merged.py | 4 +- ..._merge_qwen2.5vl32B_model_ocr_benchmark.py | 4 +- ...t_save_merge_vision_model_ocr_benchmark.py | 4 +- tests/sh/test_get_torch_index_url.sh | 179 +- tests/sh/test_install_host_defaults.sh | 104 - tests/sh/test_mac_intel_compat.sh | 24 +- tests/sh/test_tauri_install_exit_order.sh | 95 - tests/studio/install/test_rocm_support.py | 1346 ---- tests/studio/install/test_selection_logic.py | 17 - tests/studio/test_cancel_atomicity.py | 289 - tests/studio/test_cancel_id_wiring.py | 169 - .../test_chat_preset_builtin_invariants.py | 272 - tests/studio/test_cli_repo_variant.py | 145 - tests/studio/test_cli_run_alias.py | 69 - tests/studio/test_cli_studio_defaults.py | 87 - tests/studio/test_llama_cpp_wall_clock_cap.py | 123 - .../test_stream_cancel_registration_timing.py | 718 -- .../test_studio_gguf_export_script_pin.py | 231 - .../test_studio_text_descender_clipping.py | 69 - tests/test_cli_export_unpacking.py | 160 - tests/test_gemma4_chat_template.py | 183 - tests/test_get_model_name.py | 40 - tests/test_peft_weight_converter_compat.py | 259 - tests/test_raw_text.py | 69 - tests/test_resolve_model_class.py | 137 - tests/utils/test_qat.py | 11 +- unsloth/__init__.py | 359 +- unsloth/_gpu_init.py | 330 + unsloth/chat_templates.py | 97 +- unsloth/dataprep/raw_text.py | 6 +- unsloth/import_fixes.py | 109 - unsloth/kernels/fp8.py | 3 +- unsloth/kernels/utils.py | 18 +- unsloth/models/_utils.py | 611 +- unsloth/models/llama.py | 63 +- unsloth/models/loader.py | 22 +- unsloth/models/mapper.py | 48 - unsloth/models/mistral.py | 10 - unsloth/models/rl.py | 88 - unsloth/models/rl_replacements.py | 400 +- unsloth/models/sentence_transformer.py | 228 +- unsloth/models/vision.py | 197 +- unsloth/ollama_template_mappers.py | 7 - unsloth/save.py | 350 +- unsloth/tokenizer_utils.py | 772 +- unsloth_cli/__init__.py | 14 +- unsloth_cli/_tool_policy.py | 71 - unsloth_cli/commands/export.py | 11 +- unsloth_cli/commands/studio.py | 748 +- 402 files changed, 5323 insertions(+), 70359 deletions(-) delete mode 100644 .github/workflows/release-desktop.yml rename scripts/install_gemma4_mlx.sh => install_gemma4_mlx.sh (66%) delete mode 100644 scripts/install_qwen3_6_mlx.sh delete mode 100644 studio/backend/core/inference/anthropic_compat.py delete mode 100644 studio/backend/core/inference/llama_server_args.py delete mode 100644 studio/backend/core/training/resume.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/README.md delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/pyproject.toml delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/__init__.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/config.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/impl.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/plugin.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/__init__.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/gh_client.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/queries.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/scraper.py delete mode 100644 studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/state_store.py delete mode 100644 studio/backend/state/tool_policy.py delete mode 100644 studio/backend/tests/test_anthropic_messages.py delete mode 100644 studio/backend/tests/test_browse_folders_route.py delete mode 100644 studio/backend/tests/test_cached_gguf_routes.py delete mode 100644 studio/backend/tests/test_data_recipe_github_progress.py delete mode 100644 studio/backend/tests/test_desktop_auth.py delete mode 100644 studio/backend/tests/test_export_log_cursor.py delete mode 100644 studio/backend/tests/test_host_defaults.py delete mode 100644 studio/backend/tests/test_llama_cpp_cache_aware_disk_check.py delete mode 100644 studio/backend/tests/test_llama_cpp_context_fit.py delete mode 100644 studio/backend/tests/test_llama_cpp_load_progress.py delete mode 100644 studio/backend/tests/test_llama_cpp_load_progress_live.py delete mode 100644 studio/backend/tests/test_llama_cpp_load_progress_matrix.py delete mode 100644 studio/backend/tests/test_llama_cpp_max_context_threshold.py delete mode 100644 studio/backend/tests/test_llama_cpp_no_context_shift.py delete mode 100644 studio/backend/tests/test_llama_server_args.py delete mode 100644 studio/backend/tests/test_openai_tool_passthrough.py delete mode 100644 studio/backend/tests/test_pytorch_mirror.py delete mode 100644 studio/backend/tests/test_responses_api.py delete mode 100644 studio/backend/tests/test_responses_tool_passthrough.py delete mode 100644 studio/backend/tests/test_studio_api.py delete mode 100644 studio/backend/tests/test_tool_policy_gates.py delete mode 100644 studio/backend/tests/test_tool_policy_state.py delete mode 100644 studio/backend/tests/test_trained_model_scan.py delete mode 100644 studio/backend/tests/test_training_worker_flash_attn.py delete mode 100644 studio/backend/utils/hardware/amd.py delete mode 100644 studio/backend/utils/native_path_leases.py delete mode 100644 studio/backend/utils/subprocess_compat.py delete mode 100644 studio/backend/utils/wheel_utils.py delete mode 100644 studio/frontend/public/blacklogo-c.png delete mode 100644 studio/frontend/public/circle-logo-small.png delete mode 100644 studio/frontend/public/fonts/FiraCode-VariableFont_wght.ttf delete mode 100644 studio/frontend/public/fonts/Hellix-Medium.woff delete mode 100644 studio/frontend/public/fonts/Hellix-Regular.woff delete mode 100644 studio/frontend/public/sidebar-logo-black.png delete mode 100644 studio/frontend/public/sidebar-logo-white.png delete mode 100644 studio/frontend/public/sticker.png delete mode 100644 studio/frontend/public/studio.png delete mode 100644 studio/frontend/public/unsloth-beta-black.png delete mode 100644 studio/frontend/public/unsloth-beta-white.png delete mode 100644 studio/frontend/public/whitelogo-c.png delete mode 100644 studio/frontend/src/components/app-sidebar.tsx delete mode 100644 studio/frontend/src/components/assistant-ui/code-plugin.ts delete mode 100644 studio/frontend/src/components/assistant-ui/code-themes.ts delete mode 100644 studio/frontend/src/components/assistant-ui/code-toggle-icon.tsx delete mode 100644 studio/frontend/src/components/assistant-ui/model-selector/folder-browser.tsx delete mode 100644 studio/frontend/src/components/assistant-ui/model-selector/model-delete-action.tsx delete mode 100644 studio/frontend/src/components/assistant-ui/use-intent-aware-autoscroll.tsx delete mode 100644 studio/frontend/src/components/tauri/startup-screen.tsx delete mode 100644 studio/frontend/src/components/tauri/update-banner.tsx delete mode 100644 studio/frontend/src/components/tauri/update-screen.tsx delete mode 100644 studio/frontend/src/components/tauri/window-titlebar.tsx delete mode 100644 studio/frontend/src/components/ui/shimmer-button.tsx delete mode 100644 studio/frontend/src/features/auth/tauri-auto-auth.ts delete mode 100644 studio/frontend/src/features/chat/components/chat-search-dialog.tsx delete mode 100644 studio/frontend/src/features/chat/hooks/use-chat-search-index.ts delete mode 100644 studio/frontend/src/features/chat/hooks/use-chat-sidebar-items.ts delete mode 100644 studio/frontend/src/features/chat/hooks/use-transfer-stats.ts delete mode 100644 studio/frontend/src/features/chat/presets/preset-policy.ts delete mode 100644 studio/frontend/src/features/chat/stores/chat-search-store.ts delete mode 100644 studio/frontend/src/features/chat/utils/chat-thread-tombstones.ts delete mode 100644 studio/frontend/src/features/chat/utils/clear-all-chats.ts delete mode 100644 studio/frontend/src/features/chat/utils/delete-thread-message.ts delete mode 100644 studio/frontend/src/features/chat/utils/export-chat-history.ts delete mode 100644 studio/frontend/src/features/chat/utils/format-transfer.ts delete mode 100644 studio/frontend/src/features/chat/utils/qwen-params.ts delete mode 100644 studio/frontend/src/features/chat/utils/transfer-stats.ts delete mode 100644 studio/frontend/src/features/data-recipes/hooks/use-recipe-sidebar-items.ts delete mode 100644 studio/frontend/src/features/data-recipes/learning-recipes/github-support-bot.json delete mode 100644 studio/frontend/src/features/native-intents/api.ts delete mode 100644 studio/frontend/src/features/native-intents/components/native-model-chip.tsx delete mode 100644 studio/frontend/src/features/native-intents/components/native-model-drop-overlay.tsx delete mode 100644 studio/frontend/src/features/native-intents/native-intent-drain.tsx delete mode 100644 studio/frontend/src/features/native-intents/store.ts delete mode 100644 studio/frontend/src/features/native-intents/types.ts delete mode 100644 studio/frontend/src/features/native-intents/use-native-dialogs.ts delete mode 100644 studio/frontend/src/features/native-intents/use-native-drop.ts delete mode 100644 studio/frontend/src/features/native-intents/use-native-readiness.ts delete mode 100644 studio/frontend/src/features/profile/components/profile-personalization-panel.tsx delete mode 100644 studio/frontend/src/features/profile/components/user-avatar.tsx delete mode 100644 studio/frontend/src/features/profile/hooks/use-effective-profile.ts delete mode 100644 studio/frontend/src/features/profile/index.ts delete mode 100644 studio/frontend/src/features/profile/stores/user-profile-store.ts delete mode 100644 studio/frontend/src/features/profile/utils/avatar-initials.ts delete mode 100644 studio/frontend/src/features/profile/utils/jwt-subject.ts delete mode 100644 studio/frontend/src/features/profile/utils/resize-image-file.ts delete mode 100644 studio/frontend/src/features/recipe-studio/easy/github-crawler-easy-view.tsx delete mode 100644 studio/frontend/src/features/settings/api/api-keys.ts delete mode 100644 studio/frontend/src/features/settings/components/api-key-row.tsx delete mode 100644 studio/frontend/src/features/settings/components/create-key-form.tsx delete mode 100644 studio/frontend/src/features/settings/components/key-reveal-card.tsx delete mode 100644 studio/frontend/src/features/settings/components/settings-row.tsx delete mode 100644 studio/frontend/src/features/settings/components/settings-section.tsx delete mode 100644 studio/frontend/src/features/settings/components/theme-segmented.tsx delete mode 100644 studio/frontend/src/features/settings/components/update-studio-instructions.tsx delete mode 100644 studio/frontend/src/features/settings/components/usage-examples.tsx delete mode 100644 studio/frontend/src/features/settings/index.ts delete mode 100644 studio/frontend/src/features/settings/settings-dialog.tsx delete mode 100644 studio/frontend/src/features/settings/stores/settings-dialog-store.ts delete mode 100644 studio/frontend/src/features/settings/stores/theme-store.ts delete mode 100644 studio/frontend/src/features/settings/tabs/about-tab.tsx delete mode 100644 studio/frontend/src/features/settings/tabs/api-keys-tab.tsx delete mode 100644 studio/frontend/src/features/settings/tabs/appearance-tab.tsx delete mode 100644 studio/frontend/src/features/settings/tabs/chat-tab.tsx delete mode 100644 studio/frontend/src/features/settings/tabs/general-tab.tsx delete mode 100644 studio/frontend/src/features/settings/tabs/profile-tab.tsx delete mode 100644 studio/frontend/src/features/training/hooks/use-training-history-sidebar.ts delete mode 100644 studio/frontend/src/features/training/hooks/use-training-unload-guard.ts delete mode 100644 studio/frontend/src/hooks/use-collapse-scroll-lock.ts delete mode 100644 studio/frontend/src/hooks/use-sidebar-pin.ts delete mode 100644 studio/frontend/src/hooks/use-tauri-backend.ts delete mode 100644 studio/frontend/src/hooks/use-tauri-update.ts delete mode 100644 studio/frontend/src/lib/api-base.ts delete mode 100644 studio/frontend/src/lib/native-notifications.ts delete mode 100644 studio/frontend/src/lib/open-link.ts delete mode 100644 studio/frontend/src/lib/tauri-diagnostics.ts delete mode 100644 studio/src-tauri/Cargo.lock delete mode 100644 studio/src-tauri/Cargo.toml delete mode 100644 studio/src-tauri/Entitlements.plist delete mode 100644 studio/src-tauri/build.rs delete mode 100644 studio/src-tauri/capabilities/default.json delete mode 100644 studio/src-tauri/icons/128x128.png delete mode 100644 studio/src-tauri/icons/32x32.png delete mode 100644 studio/src-tauri/icons/icon.icns delete mode 100644 studio/src-tauri/icons/icon.ico delete mode 100644 studio/src-tauri/icons/icon.png delete mode 100755 studio/src-tauri/linux/postremove.sh delete mode 100644 studio/src-tauri/src/commands.rs delete mode 100644 studio/src-tauri/src/desktop_auth.rs delete mode 100644 studio/src-tauri/src/diagnostics/mod.rs delete mode 100644 studio/src-tauri/src/diagnostics/phase_log.rs delete mode 100644 studio/src-tauri/src/diagnostics/redaction.rs delete mode 100644 studio/src-tauri/src/diagnostics/report.rs delete mode 100644 studio/src-tauri/src/diagnostics/state.rs delete mode 100644 studio/src-tauri/src/install.rs delete mode 100644 studio/src-tauri/src/main.rs delete mode 100644 studio/src-tauri/src/native_backend_lease.rs delete mode 100644 studio/src-tauri/src/native_intents.rs delete mode 100644 studio/src-tauri/src/native_path_policy.rs delete mode 100644 studio/src-tauri/src/preflight.rs delete mode 100644 studio/src-tauri/src/process.rs delete mode 100644 studio/src-tauri/src/update.rs delete mode 100644 studio/src-tauri/src/windows_job.rs delete mode 100644 studio/src-tauri/tauri.conf.json delete mode 100644 studio/src-tauri/tauri.macos.conf.json delete mode 100644 studio/src-tauri/tauri.windows.conf.json delete mode 100644 studio/src-tauri/windows/branding/nsis-header.bmp delete mode 100644 studio/src-tauri/windows/branding/nsis-sidebar.bmp delete mode 100644 studio/src-tauri/windows/hooks.nsh delete mode 100644 studio/src-tauri/windows/installer.nsi delete mode 100644 tests/python/test_dpo_vision_processor_passthrough.py delete mode 100644 tests/python/test_fast_sentence_transformer_redirect_lifecycle.py delete mode 100644 tests/python/test_flash_attn_install_python_stack.py delete mode 100644 tests/python/test_unsloth_run_tool_policy_resolver.py delete mode 100644 tests/saving/test_fix_sentencepiece_gguf_robustness.py delete mode 100644 tests/saving/test_patch_saving_none_tokenizer.py delete mode 100755 tests/sh/test_install_host_defaults.sh delete mode 100644 tests/sh/test_tauri_install_exit_order.sh delete mode 100644 tests/studio/install/test_rocm_support.py delete mode 100644 tests/studio/test_cancel_atomicity.py delete mode 100644 tests/studio/test_cancel_id_wiring.py delete mode 100644 tests/studio/test_chat_preset_builtin_invariants.py delete mode 100644 tests/studio/test_cli_repo_variant.py delete mode 100644 tests/studio/test_cli_run_alias.py delete mode 100644 tests/studio/test_cli_studio_defaults.py delete mode 100644 tests/studio/test_llama_cpp_wall_clock_cap.py delete mode 100644 tests/studio/test_stream_cancel_registration_timing.py delete mode 100644 tests/studio/test_studio_gguf_export_script_pin.py delete mode 100644 tests/studio/test_studio_text_descender_clipping.py delete mode 100644 tests/test_cli_export_unpacking.py delete mode 100644 tests/test_gemma4_chat_template.py delete mode 100644 tests/test_peft_weight_converter_compat.py delete mode 100644 tests/test_resolve_model_class.py create mode 100644 unsloth/_gpu_init.py delete mode 100644 unsloth_cli/_tool_policy.py diff --git a/.github/workflows/release-desktop.yml b/.github/workflows/release-desktop.yml deleted file mode 100644 index ea82739968..0000000000 --- a/.github/workflows/release-desktop.yml +++ /dev/null @@ -1,226 +0,0 @@ -name: Release Desktop App - -on: - workflow_dispatch: - inputs: - draft: - description: 'Create as draft release' - type: boolean - default: true - -permissions: - contents: write - -jobs: - build: - strategy: - fail-fast: false - max-parallel: 1 - matrix: - include: - - platform: macos-latest - args: '--target aarch64-apple-darwin' - label: macOS (Apple Silicon) - # - platform: macos-latest - # args: '--target x86_64-apple-darwin' - # label: macOS (Intel) - - platform: ubuntu-22.04 - args: '' - label: Linux (x64) - - platform: windows-latest - args: '' - label: Windows (x64) - - name: Build ${{ matrix.label }} - runs-on: ${{ matrix.platform }} - - env: - FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true - - - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 - - # ── Linux dependencies ── - - name: Install Linux dependencies - if: matrix.platform == 'ubuntu-22.04' - run: | - sudo apt-get update - sudo apt-get install -y libwebkit2gtk-4.1-dev libayatana-appindicator3-dev librsvg2-dev libxdo-dev libssl-dev patchelf - - # ── Node.js ── - - name: Setup Node.js - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 - with: - node-version: 24 - - - name: Install pinned Tauri CLI - run: npm install --save-dev --prefix studio @tauri-apps/cli@2.10.1 - - - name: Verify pinned Tauri CLI - shell: bash - run: | - out="$(npx --prefix studio tauri --version)" - echo "$out" - if [ "$out" != "tauri-cli 2.10.1" ]; then - echo "Expected tauri-cli 2.10.1, got $out" >&2 - exit 1 - fi - - - name: Install frontend dependencies - working-directory: studio/frontend - run: npm install - - - name: Verify backend package is published - shell: bash - run: | - node <<'JS' - const { readFileSync } = require('node:fs'); - - (async () => { - const cargo = readFileSync('studio/src-tauri/Cargo.toml', 'utf8'); - const match = cargo.match(/^version\s*=\s*"([^"]+)"/m); - if (!match) throw new Error('Could not read desktop app version'); - - const appVersion = match[1]; - const response = await fetch(`https://pypi.org/pypi/unsloth/${appVersion}/json`); - if (!response.ok) { - const message = 'Publish unsloth=={app_version} to PyPI before the desktop release'; - throw new Error(`${message.replace('{app_version}', appVersion)} (HTTP ${response.status})`); - } - })(); - JS - - # ── Rust ── - - name: Install Rust stable - uses: dtolnay/rust-toolchain@stable - with: - targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }} - - - name: Rust cache - uses: swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae - with: - workspaces: 'studio/src-tauri -> target' - - # ── macOS: import signing certificate ── - - name: Import Apple certificate - if: matrix.platform == 'macos-latest' - env: - APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE }} - APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }} - KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} - run: | - echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12 - security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain - security default-keychain -s build.keychain - security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain - security set-keychain-settings -t 3600 -u build.keychain - security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign - security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain - security find-identity -v -p codesigning build.keychain - rm -f certificate.p12 - - # ── Windows: install Azure Trusted Signing CLI ── - - name: Install trusted-signing-cli - if: matrix.platform == 'windows-latest' - run: | - cargo install trusted-signing-cli --version 0.9.0 --locked - echo "$env:USERPROFILE\.cargo\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - - # ── Windows: verify signing CLI is accessible ── - - name: Verify trusted-signing-cli - if: matrix.platform == 'windows-latest' - run: | - Write-Output "PATH: $env:PATH" - Get-Command trusted-signing-cli -ErrorAction SilentlyContinue || Write-Output "trusted-signing-cli NOT in PATH" - trusted-signing-cli --version || Write-Output "trusted-signing-cli failed to run" - - # ── Linux: build + sign + upload ── - - name: Build Linux app - if: matrix.platform == 'ubuntu-22.04' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: desktop-v__VERSION__ - releaseName: 'Unsloth Studio (Desktop) v__VERSION__' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: false - args: -v ${{ matrix.args }} - - # ── macOS: build + sign + notarize + upload ── - - name: Build macOS app - if: matrix.platform == 'macos-latest' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - APPLE_SIGNING_IDENTITY: ${{ secrets.APPLE_SIGNING_IDENTITY }} - APPLE_ID: ${{ secrets.APPLE_ID }} - APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }} - APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: desktop-v__VERSION__ - releaseName: 'Unsloth Studio (Desktop) v__VERSION__' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: false - args: -v ${{ matrix.args }} - - # ── Windows: build + sign + upload ── - - name: Build Windows app - if: matrix.platform == 'windows-latest' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} - AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} - AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} - AZURE_TRUSTED_SIGNING_ACCOUNT_NAME: ${{ secrets.AZURE_TRUSTED_SIGNING_ACCOUNT_NAME }} - AZURE_CERTIFICATE_PROFILE_NAME: ${{ secrets.AZURE_CERTIFICATE_PROFILE_NAME }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: desktop-v__VERSION__ - releaseName: 'Unsloth Studio (Desktop) v__VERSION__' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: false - args: -v ${{ matrix.args }} diff --git a/.gitignore b/.gitignore index b6786ee655..7a24d07c6f 100644 --- a/.gitignore +++ b/.gitignore @@ -204,18 +204,6 @@ tmp/ **/node_modules/ auth.db -# Tauri local build/generated output -studio/src-tauri/target/ -studio/src-tauri/gen/ -studio/src-tauri/artifacts/ -studio/src-tauri/icons/android/ -studio/src-tauri/icons/ios/ -studio/src-tauri/icons/128x128@2x.png -studio/src-tauri/icons/64x64.png -studio/src-tauri/icons/Square*Logo.png -studio/src-tauri/icons/StoreLogo.png -studio/src-tauri/icons/squarehq.png - # Local working docs **/CLAUDE.md **/claude.md diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a2a4995d62..e41c37d209 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.12 + rev: v0.15.9 hooks: - id: ruff args: diff --git a/README.md b/README.md index 6c12b28096..7046a2af7c 100644 --- a/README.md +++ b/README.md @@ -1,43 +1,28 @@

- - - Unsloth logo + + + Unsloth logo

-Unsloth Studio lets you run and train models locally. +Run and train AI models with a unified local interface.

Features • - Quickstart • + QuickstartNotebooks • - Documentation + Documentation • + Reddit

-
- -unsloth studio ui homepage + +unsloth studio ui homepage -## ⚡ Get started - -#### macOS, Linux, WSL: -```bash -curl -fsSL https://unsloth.ai/install.sh | sh -``` -#### Windows: -```powershell -irm https://unsloth.ai/install.ps1 | iex -``` -#### Community: - -- [Discord](https://discord.gg/unsloth) -- [𝕏 (Twitter)](https://x.com/UnslothAI) -- [Reddit](https://reddit.com/r/unsloth) - -## ⭐ Features Unsloth Studio (Beta) lets you run and train text, [audio](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), [embedding](https://unsloth.ai/docs/new/embedding-finetuning), [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) models on Windows, Linux and macOS. +## ⭐ Features +Unsloth provides several key features for both inference and training: ### Inference * **Search + download + run models** including GGUF, LoRA adapters, safetensors * **Export models**: [Save or export](https://unsloth.ai/docs/new/studio/export) models to GGUF, 16-bit safetensors and other formats. @@ -55,7 +40,7 @@ Unsloth Studio (Beta) lets you run and train text, [audio](https://unsloth.ai/do * **Observability**: Monitor training live, track loss and GPU usage and customize graphs. * [Multi-GPU](https://unsloth.ai/docs/basics/multi-gpu-training-with-unsloth) training is supported, with major improvements coming soon. -## 📥 Install +## ⚡ Quickstart Unsloth can be used in two ways: through **[Unsloth Studio](https://unsloth.ai/docs/new/studio/)**, the web UI, or through **Unsloth Core**, the code-based version. Each has different requirements. ### Unsloth Studio (web UI) @@ -79,9 +64,8 @@ irm https://unsloth.ai/install.ps1 | iex #### Launch ```bash -unsloth studio -p 8888 +unsloth studio -H 0.0.0.0 -p 8888 ``` -> For cloud VMs or LAN access, add `-H 0.0.0.0` to bind on all interfaces. #### Update To update, use the same install commands as above. Or run (does not work on Windows): @@ -149,8 +133,7 @@ Read our [guide](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide). Ad - See detailed documentation for Unsloth [here](https://unsloth.ai/docs) ## 🦥 Unsloth News -- **Qwen3.6**: Qwen3.6-35B-A3B can now be trained and run in Unsloth Studio. [Blog](https://unsloth.ai/docs/models/qwen3.6) -- **Gemma 4**: Run and train Google’s new models directly in Unsloth. [Blog](https://unsloth.ai/docs/models/gemma-4) +- **Gemma 4**: Run and train Google’s new models directly in Unsloth Studio! [Blog](https://unsloth.ai/docs/models/gemma-4) - **Introducing Unsloth Studio**: our new web UI for running and training LLMs. [Blog](https://unsloth.ai/docs/new/studio) - **Qwen3.5** - 0.8B, 2B, 4B, 9B, 27B, 35-A3B, 112B-A10B are now supported. [Guide + notebooks](https://unsloth.ai/docs/models/qwen3.5/fine-tune) - Train **MoE LLMs 12x faster** with 35% less VRAM - DeepSeek, GLM, Qwen and gpt-oss. [Blog](https://unsloth.ai/docs/new/faster-moe) @@ -168,7 +151,7 @@ The below advanced instructions are for Unsloth Studio. For Unsloth Core advance git clone https://github.com/unslothai/unsloth cd unsloth ./install.sh --local -unsloth studio -p 8888 +unsloth studio -H 0.0.0.0 -p 8888 ``` Then to update : ```bash @@ -181,7 +164,7 @@ git clone https://github.com/unslothai/unsloth.git cd unsloth Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass .\install.ps1 --local -unsloth studio -p 8888 +unsloth studio -H 0.0.0.0 -p 8888 ``` Then to update : ```bash @@ -194,11 +177,11 @@ git clone https://github.com/unslothai/unsloth cd unsloth git checkout nightly ./install.sh --local -unsloth studio -p 8888 +unsloth studio -H 0.0.0.0 -p 8888 ``` Then to launch every time: ```bash -unsloth studio -p 8888 +unsloth studio -H 0.0.0.0 -p 8888 ``` #### Nightly: Windows: @@ -209,11 +192,11 @@ cd unsloth git checkout nightly Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass .\install.ps1 --local -unsloth studio -p 8888 +unsloth studio -H 0.0.0.0 -p 8888 ``` Then to launch every time: ```bash -unsloth studio -p 8888 +unsloth studio -H 0.0.0.0 -p 8888 ``` #### Uninstall diff --git a/install.ps1 b/install.ps1 index 7dc5a50250..a2acd6c4ea 100644 --- a/install.ps1 +++ b/install.ps1 @@ -8,90 +8,15 @@ function Install-UnslothStudio { $ErrorActionPreference = "Stop" $script:UnslothVerbose = ($env:UNSLOTH_VERBOSE -eq "1") - # ── Tauri structured output ── - function Write-TauriLog { - param([string]$Tag, [string]$Message) - if ($TauriMode) { - Write-Host "[TAURI:$Tag] $Message" - } - } - - function Format-TauriDiagBool { - param([bool]$Value) - if ($Value) { return "true" } - return "false" - } - - function Get-TauriDiagArch { - $arch = [string]$env:PROCESSOR_ARCHITECTURE - if ([string]::IsNullOrWhiteSpace($arch)) { - try { $arch = [System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture.ToString() } catch { $arch = "unknown" } - } - $arch = $arch.ToLowerInvariant() - switch ($arch) { - "amd64" { return "x86_64" } - "x64" { return "x86_64" } - "arm64" { return "arm64" } - "x86" { return "x86" } - default { return ($arch -replace '[^a-z0-9_.-]', '_') } - } - } - - function Get-TauriTorchIndexFamily { - param([string]$TorchIndexUrl) - if ($SkipTorch) { return "none" } - if ([string]::IsNullOrWhiteSpace($TorchIndexUrl)) { return "none" } - $leaf = ($TorchIndexUrl.TrimEnd('/') -split '/')[-1].ToLowerInvariant() - if (@("cpu", "cu118", "cu124", "cu126", "cu128", "cu130") -contains $leaf) { return $leaf } - if ($leaf -match '^rocm[0-9]+\.[0-9]+$') { return $leaf } - return "auto" - } - - function Get-TauriGpuBranch { - param([string]$TorchIndexFamily) - if ($SkipTorch) { return "no_torch" } - if ($TorchIndexFamily -like "cu*") { return "cuda" } - if ($TorchIndexFamily -like "rocm*") { return "rocm" } - if ($TorchIndexFamily -eq "cpu") { return "cpu" } - return "unknown" - } - - function Write-TauriDiag { - param( - [string]$GpuBranch = "unknown", - [string]$TorchIndexFamily = "none", - [string]$PythonVersionForDiag = $PythonVersion - ) - if ([string]::IsNullOrWhiteSpace($PythonVersionForDiag)) { $PythonVersionForDiag = "unknown" } - Write-TauriLog "DIAG" "diag_schema=1 platform=windows arch=$(Get-TauriDiagArch) python_version=$($PythonVersionForDiag.ToLowerInvariant()) skip_torch=$(Format-TauriDiagBool $SkipTorch) mac_intel=false gpu_branch=$GpuBranch torch_index_family=$TorchIndexFamily" - } - - function Exit-InstallFailure { - param( - [Parameter(Mandatory = $true)][string]$Message, - [int]$Code = 1 - ) - if ($Code -eq 0) { $Code = 1 } - Write-TauriLog "ERROR" $Message - if (Get-Command Restore-StudioVenvRollback -CommandType Function -ErrorAction SilentlyContinue) { - Restore-StudioVenvRollback - } - if ($TauriMode) { - exit $Code - } - } - # ── Parse flags ── $StudioLocalInstall = $false $PackageName = "unsloth" $RepoRoot = "" - $TauriMode = $false $SkipTorch = $false $argList = $args for ($i = 0; $i -lt $argList.Count; $i++) { switch ($argList[$i]) { "--local" { $StudioLocalInstall = $true } - "--tauri" { $TauriMode = $true } "--no-torch" { $SkipTorch = $true } "--verbose" { $script:UnslothVerbose = $true } "-v" { $script:UnslothVerbose = $true } @@ -99,7 +24,7 @@ function Install-UnslothStudio { $i++ if ($i -ge $argList.Count) { Write-Host "[ERROR] --package requires an argument." -ForegroundColor Red - return (Exit-InstallFailure "--package requires an argument.") + return } $PackageName = $argList[$i] } @@ -115,16 +40,10 @@ function Install-UnslothStudio { $RepoRoot = (Resolve-Path (Split-Path -Parent $PSCommandPath)).Path if (-not (Test-Path (Join-Path $RepoRoot "pyproject.toml"))) { Write-Host "[ERROR] --local must be run from the unsloth repo root (pyproject.toml not found at $RepoRoot)" -ForegroundColor Red - return (Exit-InstallFailure "--local must be run from the unsloth repo root") + return } } - # Validate --package to prevent injection into shell/Python commands - if ($PackageName -notmatch '^[a-zA-Z0-9][a-zA-Z0-9._-]*$') { - Write-Host "[ERROR] --package name contains invalid characters (allowed: a-z A-Z 0-9 . _ -)" -ForegroundColor Red - return (Exit-InstallFailure "--package name contains invalid characters") - } - $PythonVersion = "3.13" $StudioHome = Join-Path $env:USERPROFILE ".unsloth\studio" $VenvDir = Join-Path $StudioHome "unsloth_studio" @@ -181,115 +100,22 @@ function Install-UnslothStudio { Write-Host "" # ── Helper: refresh PATH from registry (deduplicating entries) ── - # Merge order: venv Scripts (if active) > Machine > User > current $env:Path. - # Dedup compares both raw and expanded forms (%VAR% vs literal). function Refresh-SessionPath { $machine = [System.Environment]::GetEnvironmentVariable("Path", "Machine") $user = [System.Environment]::GetEnvironmentVariable("Path", "User") - $venvScripts = if ($env:VIRTUAL_ENV) { Join-Path $env:VIRTUAL_ENV "Scripts" } else { $null } - $sources = @() - if ($venvScripts) { $sources += $venvScripts } - $sources += @($machine, $user, $env:Path) - $merged = ($sources | Where-Object { $_ }) -join ";" + $merged = "$machine;$user;$env:Path" $seen = @{} - $unique = New-Object System.Collections.Generic.List[string] + $unique = @() foreach ($p in $merged -split ";") { - $rawKey = $p.Trim().Trim('"').TrimEnd("\").ToLowerInvariant() - $expKey = [Environment]::ExpandEnvironmentVariables($p).Trim().Trim('"').TrimEnd("\").ToLowerInvariant() - if ($rawKey -and -not $seen.ContainsKey($rawKey) -and -not $seen.ContainsKey($expKey)) { - $seen[$rawKey] = $true - if ($expKey -and $expKey -ne $rawKey) { $seen[$expKey] = $true } - $unique.Add($p) + $key = $p.TrimEnd("\").ToLowerInvariant() + if ($key -and -not $seen.ContainsKey($key)) { + $seen[$key] = $true + $unique += $p } } $env:Path = $unique -join ";" } - # ── Helper: safely add a directory to the persistent User PATH ── - # Direct registry access preserves REG_EXPAND_SZ (avoids dotnet/runtime#1442). - # Append (default) keeps existing tools first; Prepend for must-win entries. - function Add-ToUserPath { - param( - [Parameter(Mandatory = $true)][string]$Directory, - [ValidateSet('Append','Prepend')] - [string]$Position = 'Append' - ) - try { - $regKey = [Microsoft.Win32.Registry]::CurrentUser.CreateSubKey('Environment') - try { - $rawPath = $regKey.GetValue('Path', '', [Microsoft.Win32.RegistryValueOptions]::DoNotExpandEnvironmentNames) - [string[]]$entries = if ($rawPath) { $rawPath -split ';' } else { @() } # string[] prevents scalar collapse - $normalDir = $Directory.Trim().Trim('"').TrimEnd('\').ToLowerInvariant() - $expNormalDir = [Environment]::ExpandEnvironmentVariables($Directory).Trim().Trim('"').TrimEnd('\').ToLowerInvariant() - $kept = New-Object System.Collections.Generic.List[string] - $matchIndices = New-Object System.Collections.Generic.List[int] - for ($i = 0; $i -lt $entries.Count; $i++) { - $stripped = $entries[$i].Trim().Trim('"') - $rawNorm = $stripped.TrimEnd('\').ToLowerInvariant() - $expNorm = [Environment]::ExpandEnvironmentVariables($stripped).TrimEnd('\').ToLowerInvariant() - $isMatch = ($rawNorm -and ($rawNorm -eq $normalDir -or $rawNorm -eq $expNormalDir)) -or - ($expNorm -and ($expNorm -eq $normalDir -or $expNorm -eq $expNormalDir)) - if ($isMatch) { - $matchIndices.Add($i) - continue - } - $kept.Add($entries[$i]) - } - $alreadyPresent = $matchIndices.Count -gt 0 - if ($alreadyPresent -and $Position -eq 'Append') { # Append: idempotent no-op - return $false - } - if ($alreadyPresent -and $Position -eq 'Prepend' -and # Prepend: no-op if already at front - $matchIndices.Count -eq 1 -and $matchIndices[0] -eq 0) { - return $false - } - # One-time backup under HKCU\Software\Unsloth\PathBackup - if ($rawPath) { - try { - $backupKey = [Microsoft.Win32.Registry]::CurrentUser.CreateSubKey('Software\Unsloth') - try { - $existingBackup = $backupKey.GetValue('PathBackup', $null) - if (-not $existingBackup) { - $backupKey.SetValue('PathBackup', $rawPath, [Microsoft.Win32.RegistryValueKind]::ExpandString) - } - } finally { - $backupKey.Close() - } - } catch { } - } - if (-not $rawPath) { - Write-Host "[WARN] User PATH is empty - initializing with $Directory" -ForegroundColor Yellow - } - $newPath = if ($rawPath) { - if ($Position -eq 'Prepend') { - (@($Directory) + $kept) -join ';' - } else { - ($kept + @($Directory)) -join ';' - } - } else { - $Directory - } - if ($newPath -ceq $rawPath) { # no actual change - return $false - } - $regKey.SetValue('Path', $newPath, [Microsoft.Win32.RegistryValueKind]::ExpandString) - # Broadcast WM_SETTINGCHANGE via dummy env-var roundtrip. - # [NullString]::Value avoids PS 7.5+/.NET 9 $null-to-"" coercion. - try { - $d = "UnslothPathRefresh_$([guid]::NewGuid().ToString('N').Substring(0,8))" - [Environment]::SetEnvironmentVariable($d, '1', 'User') - [Environment]::SetEnvironmentVariable($d, [NullString]::Value, 'User') - } catch { } - return $true - } finally { - $regKey.Close() - } - } catch { - Write-Host "[WARN] Could not update User PATH: $($_.Exception.Message)" -ForegroundColor Yellow - return $false - } - } - function step { param( [Parameter(Mandatory = $true)][string]$Label, @@ -552,7 +378,7 @@ try { } catch {} exit 1 } - `$studioCommand = '& "' + `$studioExe + '" studio -p ' + `$launchPort + `$studioCommand = '& "' + `$studioExe + '" studio -H 0.0.0.0 -p ' + `$launchPort `$launchArgs = @( '-NoExit', '-NoProfile', @@ -690,12 +516,11 @@ shell.Run cmd, 0, False } # ── Check winget ── - Write-TauriLog "STEP" "Checking system dependencies" if (-not (Get-Command winget -ErrorAction SilentlyContinue)) { step "winget" "not available" "Red" substep "Install it from https://aka.ms/getwinget" "Yellow" substep "or install Python $PythonVersion and uv manually, then re-run." "Yellow" - return (Exit-InstallFailure "winget is not available") + return } # ── Helper: detect a working Python 3.11-3.13 on the system ── @@ -770,7 +595,6 @@ shell.Run cmd, 0, False # ── Install Python if no compatible version (3.11-3.13) found ── # Find-CompatiblePython returns @{ Version = "3.13"; Path = "C:\...\python.exe" } or $null. - Write-TauriLog "STEP" "Installing Python" $DetectedPython = Find-CompatiblePython if ($DetectedPython) { step "python" "Python $($DetectedPython.Version) already installed" @@ -814,17 +638,11 @@ shell.Run cmd, 0, False Write-Host " Please install Python $PythonVersion manually from https://www.python.org/downloads/" -ForegroundColor Yellow Write-Host " Make sure to check 'Add Python to PATH' during installation." -ForegroundColor Yellow Write-Host " Then re-run this installer." -ForegroundColor Yellow - return (Exit-InstallFailure "Python installation failed") + return } } - $DiagPythonVersion = $PythonVersion - if ($DetectedPython) { $DiagPythonVersion = $DetectedPython.Version } - $InitialGpuBranch = "unknown" - if ($SkipTorch) { $InitialGpuBranch = "no_torch" } - Write-TauriDiag -GpuBranch $InitialGpuBranch -TorchIndexFamily "none" -PythonVersionForDiag $DiagPythonVersion # ── Install uv if not present ── - Write-TauriLog "STEP" "Installing uv package manager" if (-not (Get-Command uv -ErrorAction SilentlyContinue)) { substep "installing uv package manager..." $prevEAP = $ErrorActionPreference @@ -835,7 +653,7 @@ shell.Run cmd, 0, False # Fallback: if winget didn't put uv on PATH, try the PowerShell installer if (-not (Get-Command uv -ErrorAction SilentlyContinue)) { substep "trying alternative uv installer..." "Yellow" - Invoke-Expression (Invoke-RestMethod -Uri "https://astral.sh/uv/install.ps1") + powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" Refresh-SessionPath } } @@ -843,81 +661,23 @@ shell.Run cmd, 0, False if (-not (Get-Command uv -ErrorAction SilentlyContinue)) { step "uv" "could not be installed" "Red" substep "Install it from https://docs.astral.sh/uv/" "Yellow" - return (Exit-InstallFailure "uv could not be installed") + return } # ── Create venv (migrate old layout if possible, otherwise fresh) ── # Pass the resolved executable path to uv so it does not re-resolve # a version string back to a conda interpreter. - Write-TauriLog "STEP" "Creating virtual environment" if (-not (Test-Path $StudioHome)) { New-Item -ItemType Directory -Path $StudioHome -Force | Out-Null } $VenvPython = Join-Path $VenvDir "Scripts\python.exe" $_Migrated = $false - $script:StudioVenvRollbackDir = $null - $script:StudioVenvRollbackTarget = $VenvDir - $script:StudioVenvRollbackActive = $false - - function Start-StudioVenvRollback { - param([Parameter(Mandatory = $true)][string]$ExistingDir) - $stamp = Get-Date -Format "yyyyMMddHHmmss" - $candidate = Join-Path $StudioHome "unsloth_studio.rollback.$stamp.$PID" - $suffix = 0 - while (Test-Path $candidate) { - $suffix++ - $candidate = Join-Path $StudioHome "unsloth_studio.rollback.$stamp.$PID.$suffix" - } - Move-Item -Path $ExistingDir -Destination $candidate -ErrorAction Stop - $script:StudioVenvRollbackDir = $candidate - $script:StudioVenvRollbackTarget = $ExistingDir - $script:StudioVenvRollbackActive = $true - substep "previous environment preserved for rollback" - } - - function Restore-StudioVenvRollback { - if (-not $script:StudioVenvRollbackActive) { return } - $backup = $script:StudioVenvRollbackDir - $target = $script:StudioVenvRollbackTarget - if (-not $backup -or -not (Test-Path $backup)) { - $script:StudioVenvRollbackActive = $false - return - } - substep "restoring previous environment after failed install..." "Yellow" - try { - if (Test-Path $target) { - Remove-Item -Recurse -Force $target -ErrorAction SilentlyContinue - } - Move-Item -Path $backup -Destination $target -Force -ErrorAction Stop - substep "restored previous environment" - $script:StudioVenvRollbackActive = $false - $script:StudioVenvRollbackDir = $null - } catch { - Write-Host "[WARN] Could not restore previous environment from $backup to $target" -ForegroundColor Yellow - Write-Host " $($_.Exception.Message)" -ForegroundColor Yellow - } - } - - function Complete-StudioVenvRollback { - if (-not $script:StudioVenvRollbackActive) { return } - $backup = $script:StudioVenvRollbackDir - if ($backup -and (Test-Path $backup)) { - Remove-Item -Recurse -Force $backup -ErrorAction SilentlyContinue - } - $script:StudioVenvRollbackActive = $false - $script:StudioVenvRollbackDir = $null - } if (Test-Path $VenvPython) { - # New layout already exists -- replace only after preserving rollback copy. - substep "preserving existing environment for rollback..." - try { - Start-StudioVenvRollback -ExistingDir $VenvDir - } catch { - Write-Host "[ERROR] Could not prepare existing environment for reinstall: $($_.Exception.Message)" -ForegroundColor Red - return (Exit-InstallFailure "Could not prepare existing environment for reinstall") - } + # New layout already exists -- nuke for fresh install + substep "removing existing environment for fresh install..." + Remove-Item -Recurse -Force $VenvDir } elseif (Test-Path (Join-Path $StudioHome ".venv\Scripts\python.exe")) { # Old layout (~/.unsloth/studio/.venv) exists -- validate before migrating $OldVenv = Join-Path $StudioHome ".venv" @@ -926,23 +686,18 @@ shell.Run cmd, 0, False $prevEAP2 = $ErrorActionPreference $ErrorActionPreference = "Continue" try { - if ($SkipTorch) { - & $OldPy -c "import sys; print(sys.executable)" 2>$null | Out-Null - } else { - & $OldPy -c "import torch; A = torch.ones((2,2)); B = A + A" 2>$null | Out-Null - } - $legacyOk = ($LASTEXITCODE -eq 0) - } catch { $legacyOk = $false } + & $OldPy -c "import torch; A = torch.ones((2,2)); B = A + A" 2>$null | Out-Null + $torchOk = ($LASTEXITCODE -eq 0) + } catch { $torchOk = $false } $ErrorActionPreference = $prevEAP2 - if ($legacyOk) { + if ($torchOk) { substep "legacy environment is healthy -- migrating..." Move-Item -Path $OldVenv -Destination $VenvDir -Force substep "moved .venv -> unsloth_studio" $_Migrated = $true } else { substep "legacy environment failed validation -- creating fresh environment" "Yellow" - $invalidVenv = Join-Path $StudioHome (".venv.invalid.{0}.{1}" -f (Get-Date -Format "yyyyMMddHHmmss"), $PID) - Move-Item -Path $OldVenv -Destination $invalidVenv -Force -ErrorAction SilentlyContinue + Remove-Item -Recurse -Force $OldVenv -ErrorAction SilentlyContinue } } elseif (Test-Path (Join-Path $env:USERPROFILE "unsloth_studio\Scripts\python.exe")) { # CWD-relative venv from old install.ps1 -- migrate to absolute path @@ -959,7 +714,7 @@ shell.Run cmd, 0, False $venvExit = Invoke-InstallCommand { uv venv $VenvDir --python "$($DetectedPython.Path)" } if ($venvExit -ne 0) { Write-Host "[ERROR] Failed to create virtual environment (exit code $venvExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to create virtual environment (exit code $venvExit)" $venvExit) + return } } else { step "venv" "using migrated environment" @@ -999,7 +754,7 @@ shell.Run cmd, 0, False # ── Choose the correct PyTorch index URL based on driver CUDA version ── # Mirrors Get-PytorchCudaTag in setup.ps1. function Get-TorchIndexUrl { - $baseUrl = if ($env:UNSLOTH_PYTORCH_MIRROR) { $env:UNSLOTH_PYTORCH_MIRROR.TrimEnd('/') } else { "https://download.pytorch.org/whl" } + $baseUrl = "https://download.pytorch.org/whl" if (-not $NvidiaSmiExe) { return "$baseUrl/cpu" } try { $output = & $NvidiaSmiExe 2>&1 | Out-String @@ -1017,9 +772,6 @@ shell.Run cmd, 0, False return "$baseUrl/cu126" } $TorchIndexUrl = Get-TorchIndexUrl - $TorchIndexFamily = Get-TauriTorchIndexFamily $TorchIndexUrl - $GpuBranch = Get-TauriGpuBranch $TorchIndexFamily - Write-TauriDiag -GpuBranch $GpuBranch -TorchIndexFamily $TorchIndexFamily -PythonVersionForDiag $DetectedPython.Version # ── Print CPU-only hint when no GPU detected ── if (-not $SkipTorch -and $TorchIndexUrl -like "*/cpu") { @@ -1063,12 +815,11 @@ shell.Run cmd, 0, False if ($_Migrated) { # Migrated env: force-reinstall unsloth+unsloth-zoo to ensure clean state # in the new venv location, while preserving existing torch/CUDA - Write-TauriLog "STEP" "Installing unsloth" substep "upgrading unsloth in migrated environment..." if ($SkipTorch) { # No-torch: install unsloth + unsloth-zoo with --no-deps, then # runtime deps (typer, safetensors, transformers, etc.) with --no-deps. - $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --no-deps --reinstall-package unsloth --reinstall-package unsloth-zoo "unsloth>=2026.4.8" unsloth-zoo } + $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --no-deps --reinstall-package unsloth --reinstall-package unsloth-zoo "unsloth>=2026.4.4" unsloth-zoo } if ($baseInstallExit -eq 0) { $NoTorchReq = Find-NoTorchRuntimeFile if ($NoTorchReq) { @@ -1076,45 +827,37 @@ shell.Run cmd, 0, False } } } else { - $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --reinstall-package unsloth --reinstall-package unsloth-zoo "unsloth>=2026.4.8" unsloth-zoo } + $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --reinstall-package unsloth --reinstall-package unsloth-zoo "unsloth>=2026.4.4" unsloth-zoo } } if ($baseInstallExit -ne 0) { Write-Host "[ERROR] Failed to install unsloth (exit code $baseInstallExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to install unsloth (exit code $baseInstallExit)" $baseInstallExit) + return } if ($StudioLocalInstall) { substep "overlaying local repo (editable)..." $overlayExit = Invoke-InstallCommand { uv pip install --python $VenvPython -e $RepoRoot --no-deps } if ($overlayExit -ne 0) { Write-Host "[ERROR] Failed to overlay local repo (exit code $overlayExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to overlay local repo (exit code $overlayExit)" $overlayExit) - } - substep "overlaying unsloth-zoo from git main..." - $zooOverlayExit = Invoke-InstallCommand { uv pip install --python $VenvPython --no-deps --reinstall-package unsloth-zoo "unsloth-zoo @ git+https://github.com/unslothai/unsloth-zoo" } - if ($zooOverlayExit -ne 0) { - Write-Host "[ERROR] Failed to overlay unsloth-zoo (exit code $zooOverlayExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to overlay unsloth-zoo (exit code $zooOverlayExit)" $zooOverlayExit) + return } } } elseif ($TorchIndexUrl) { if ($SkipTorch) { substep "skipping PyTorch (--no-torch flag set)." "Yellow" } else { - Write-TauriLog "STEP" "Installing PyTorch" substep "installing PyTorch ($TorchIndexUrl)..." $torchInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython "torch>=2.4,<2.11.0" torchvision torchaudio --index-url $TorchIndexUrl } if ($torchInstallExit -ne 0) { Write-Host "[ERROR] Failed to install PyTorch (exit code $torchInstallExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to install PyTorch (exit code $torchInstallExit)" $torchInstallExit) + return } } - Write-TauriLog "STEP" "Installing unsloth" substep "installing unsloth (this may take a few minutes)..." if ($SkipTorch) { # No-torch: install unsloth + unsloth-zoo with --no-deps, then # runtime deps (typer, safetensors, transformers, etc.) with --no-deps. - $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --no-deps --upgrade-package unsloth --upgrade-package unsloth-zoo "unsloth>=2026.4.8" unsloth-zoo } + $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --no-deps --upgrade-package unsloth --upgrade-package unsloth-zoo "unsloth>=2026.4.4" unsloth-zoo } if ($baseInstallExit -eq 0) { $NoTorchReq = Find-NoTorchRuntimeFile if ($NoTorchReq) { @@ -1122,13 +865,13 @@ shell.Run cmd, 0, False } } } elseif ($StudioLocalInstall) { - $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --upgrade-package unsloth "unsloth>=2026.4.8" unsloth-zoo } + $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --upgrade-package unsloth "unsloth>=2026.4.4" unsloth-zoo } } else { - $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --upgrade-package unsloth -- "$PackageName" } + $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --upgrade-package unsloth "$PackageName" } } if ($baseInstallExit -ne 0) { Write-Host "[ERROR] Failed to install unsloth (exit code $baseInstallExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to install unsloth (exit code $baseInstallExit)" $baseInstallExit) + return } if ($StudioLocalInstall) { @@ -1136,85 +879,29 @@ shell.Run cmd, 0, False $overlayExit = Invoke-InstallCommand { uv pip install --python $VenvPython -e $RepoRoot --no-deps } if ($overlayExit -ne 0) { Write-Host "[ERROR] Failed to overlay local repo (exit code $overlayExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to overlay local repo (exit code $overlayExit)" $overlayExit) - } - substep "overlaying unsloth-zoo from git main..." - $zooOverlayExit = Invoke-InstallCommand { uv pip install --python $VenvPython --no-deps --reinstall-package unsloth-zoo "unsloth-zoo @ git+https://github.com/unslothai/unsloth-zoo" } - if ($zooOverlayExit -ne 0) { - Write-Host "[ERROR] Failed to overlay unsloth-zoo (exit code $zooOverlayExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to overlay unsloth-zoo (exit code $zooOverlayExit)" $zooOverlayExit) + return } } } else { # Fallback: GPU detection failed to produce a URL -- let uv resolve torch - Write-TauriLog "STEP" "Installing unsloth" substep "installing unsloth (this may take a few minutes)..." if ($StudioLocalInstall) { - $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython unsloth-zoo "unsloth>=2026.4.8" --torch-backend=auto } + $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython unsloth-zoo "unsloth>=2026.4.4" --torch-backend=auto } if ($baseInstallExit -ne 0) { Write-Host "[ERROR] Failed to install unsloth (exit code $baseInstallExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to install unsloth (exit code $baseInstallExit)" $baseInstallExit) + return } substep "overlaying local repo (editable)..." $overlayExit = Invoke-InstallCommand { uv pip install --python $VenvPython -e $RepoRoot --no-deps } if ($overlayExit -ne 0) { Write-Host "[ERROR] Failed to overlay local repo (exit code $overlayExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to overlay local repo (exit code $overlayExit)" $overlayExit) - } - substep "overlaying unsloth-zoo from git main..." - $zooOverlayExit = Invoke-InstallCommand { uv pip install --python $VenvPython --no-deps --reinstall-package unsloth-zoo "unsloth-zoo @ git+https://github.com/unslothai/unsloth-zoo" } - if ($zooOverlayExit -ne 0) { - Write-Host "[ERROR] Failed to overlay unsloth-zoo (exit code $zooOverlayExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to overlay unsloth-zoo (exit code $zooOverlayExit)" $zooOverlayExit) + return } } else { - $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython --torch-backend=auto -- "$PackageName" } + $baseInstallExit = Invoke-InstallCommand { uv pip install --python $VenvPython "$PackageName" --torch-backend=auto } if ($baseInstallExit -ne 0) { Write-Host "[ERROR] Failed to install unsloth (exit code $baseInstallExit)" -ForegroundColor Red - return (Exit-InstallFailure "Failed to install unsloth (exit code $baseInstallExit)" $baseInstallExit) - } - } - } - - # Overlay Tauri-bundled studio fixes that may be ahead of PyPI. Skipped - # for --local: the editable install above already makes _PACKAGE_ROOT in - # unsloth_cli/commands/studio.py resolve to the repo (PEP 660 __file__). - # Source paths match the Tauri bundle layout in studio/src-tauri/tauri.conf.json, - # which bundles install_python_stack.py at the bundle root next to install.ps1. - if ($TauriMode) { - $rawPath = if ($PSCommandPath) { $PSCommandPath } else { $MyInvocation.ScriptName } - if ($rawPath) { - # Strip leading \\?\ extended-length prefix if the launcher passed one. - $scriptDir = Split-Path -Parent ($rawPath -replace '^\\\\\?\\', '') - $overlayMap = [ordered]@{ - "install_python_stack.py" = "Lib\site-packages\studio\install_python_stack.py" - } - foreach ($rel in $overlayMap.Keys) { - $src = Join-Path $scriptDir $rel - $dst = Join-Path $VenvDir $overlayMap[$rel] - if (-not (Test-Path $src)) { continue } - $dstParent = Split-Path -Parent $dst - if (-not (Test-Path $dstParent)) { - Write-Host "[WARN] Overlay target dir missing: $dstParent; studio setup may use stale bundled file" -ForegroundColor Yellow - continue - } - try { - if (-not (Test-Path $dst)) { - # Backfill: target file missing but parent dir exists. - Copy-Item $src $dst -Force - substep ("backfilled bundled " + (Split-Path -Leaf $rel)) - } else { - # Hash-compare so re-runs are no-ops when files already match. - $srcHash = (Get-FileHash $src -Algorithm SHA256).Hash - $dstHash = (Get-FileHash $dst -Algorithm SHA256).Hash - if ($srcHash -ne $dstHash) { - Copy-Item $src $dst -Force - substep ("applied bundled " + (Split-Path -Leaf $rel)) - } - } - } catch { - Write-Host "[WARN] Could not overlay $($rel): $($_.Exception.Message); studio setup may use stale bundled file" -ForegroundColor Yellow - } + return } } } @@ -1222,7 +909,6 @@ shell.Run cmd, 0, False # ── Run studio setup ── # setup.ps1 will handle installing Git, CMake, Visual Studio Build Tools, # CUDA Toolkit, Node.js, and other dependencies automatically via winget. - Write-TauriLog "STEP" "Running studio setup" step "setup" "running unsloth studio setup..." $UnslothExe = Join-Path $VenvDir "Scripts\unsloth.exe" if (-not (Test-Path $UnslothExe)) { @@ -1230,14 +916,12 @@ shell.Run cmd, 0, False Write-Host " Expected: $UnslothExe" -ForegroundColor Yellow Write-Host " This usually means an older unsloth version was installed that does not include the Studio CLI." -ForegroundColor Yellow Write-Host " Try re-running the installer or see: https://github.com/unslothai/unsloth?tab=readme-ov-file#-quickstart" -ForegroundColor Yellow - return (Exit-InstallFailure "unsloth CLI was not installed correctly") + return } # Tell setup.ps1 to skip base package installation (install.ps1 already did it) $env:SKIP_STUDIO_BASE = "1" $env:STUDIO_PACKAGE_NAME = $PackageName $env:UNSLOTH_NO_TORCH = if ($SkipTorch) { "true" } else { "false" } - # Tauri desktop app bundles its own frontend — skip Node/npm/frontend build - $env:SKIP_STUDIO_FRONTEND = if ($TauriMode) { "1" } else { "0" } # Always set STUDIO_LOCAL_INSTALL explicitly to avoid stale values from # a previous --local run in the same PowerShell session. if ($StudioLocalInstall) { @@ -1252,117 +936,37 @@ shell.Run cmd, 0, False # and bypass the fast-path version check from PR #4667. $studioArgs = @('studio', 'setup') if ($script:UnslothVerbose) { $studioArgs += '--verbose' } - $env:UNSLOTH_INSTALL_ROLLBACK_MANAGED = "1" - try { - & $UnslothExe @studioArgs - $setupExit = $LASTEXITCODE - } finally { - Remove-Item Env:UNSLOTH_INSTALL_ROLLBACK_MANAGED -ErrorAction SilentlyContinue - } + & $UnslothExe @studioArgs + $setupExit = $LASTEXITCODE if ($setupExit -ne 0) { Write-Host "[ERROR] unsloth studio setup failed (exit code $setupExit)" -ForegroundColor Red - return (Exit-InstallFailure "unsloth studio setup failed (exit code $setupExit)" $setupExit) + return } - # ── Expose `unsloth` via a shim dir containing only unsloth.exe ── - # We do NOT add the venv Scripts dir to PATH (it also holds python.exe - # and pip.exe, which would hijack the user's system interpreter). - # Hardlink preferred; falls back to copy if cross-volume or non-NTFS. - # - # Remove the legacy venv Scripts PATH entry that older installers wrote. - $LegacyScriptsDir = Join-Path $VenvDir "Scripts" - try { - $legacyKey = [Microsoft.Win32.Registry]::CurrentUser.CreateSubKey('Environment') - try { - $rawPath = $legacyKey.GetValue('Path', '', [Microsoft.Win32.RegistryValueOptions]::DoNotExpandEnvironmentNames) - if ($rawPath) { - [string[]]$pathEntries = $rawPath -split ';' - $normalLegacy = $LegacyScriptsDir.Trim().Trim('"').TrimEnd('\').ToLowerInvariant() - $expNormalLegacy = [Environment]::ExpandEnvironmentVariables($LegacyScriptsDir).Trim().Trim('"').TrimEnd('\').ToLowerInvariant() - $filtered = @($pathEntries | Where-Object { - $stripped = $_.Trim().Trim('"') - $rawNorm = $stripped.TrimEnd('\').ToLowerInvariant() - $expNorm = [Environment]::ExpandEnvironmentVariables($stripped).TrimEnd('\').ToLowerInvariant() - ($rawNorm -ne $normalLegacy -and $rawNorm -ne $expNormalLegacy) -and - ($expNorm -ne $normalLegacy -and $expNorm -ne $expNormalLegacy) - }) - $cleanedPath = $filtered -join ';' - if ($cleanedPath -ne $rawPath) { - $legacyKey.SetValue('Path', $cleanedPath, [Microsoft.Win32.RegistryValueKind]::ExpandString) - try { - $d = "UnslothPathRefresh_$([guid]::NewGuid().ToString('N').Substring(0,8))" - [Environment]::SetEnvironmentVariable($d, '1', 'User') - [Environment]::SetEnvironmentVariable($d, [NullString]::Value, 'User') - } catch { } - } - } - } finally { - $legacyKey.Close() - } - } catch { } - $ShimDir = Join-Path $StudioHome "bin" - New-Item -ItemType Directory -Force -Path $ShimDir | Out-Null - $ShimExe = Join-Path $ShimDir "unsloth.exe" - # try/catch: if unsloth.exe is locked (Studio running), keep the old shim. - $shimUpdated = $false - try { - if (Test-Path $ShimExe) { Remove-Item $ShimExe -Force -ErrorAction Stop } - try { - New-Item -ItemType HardLink -Path $ShimExe -Target $UnslothExe -ErrorAction Stop | Out-Null - } catch { - Copy-Item -Path $UnslothExe -Destination $ShimExe -Force -ErrorAction Stop # fallback: copy - } - $shimUpdated = $true - } catch { - if (Test-Path $ShimExe) { - Write-Host "[WARN] Could not refresh unsloth launcher at $ShimExe." -ForegroundColor Yellow - Write-Host " This usually means a running 'unsloth studio' process still holds the file open." -ForegroundColor Yellow - Write-Host " Close Studio and re-run the installer to pick up the latest launcher." -ForegroundColor Yellow - Write-Host " Continuing with the existing launcher." -ForegroundColor Yellow + New-StudioShortcuts -UnslothExePath $UnslothExe + + # ── Add venv Scripts dir to User PATH so `unsloth studio` works from any terminal ── + $ScriptsDir = Join-Path $VenvDir "Scripts" + $UserPath = [System.Environment]::GetEnvironmentVariable("Path", "User") + if (-not $UserPath -or $UserPath -notlike "*$ScriptsDir*") { + if ($UserPath) { + [System.Environment]::SetEnvironmentVariable("Path", "$ScriptsDir;$UserPath", "User") } else { - Write-Host "[WARN] Could not create unsloth launcher at $ShimExe" -ForegroundColor Yellow - Write-Host " $($_.Exception.Message)" -ForegroundColor Yellow - Write-Host " Launch unsloth studio directly via '$UnslothExe' until the next successful install." -ForegroundColor Yellow + [System.Environment]::SetEnvironmentVariable("Path", "$ScriptsDir", "User") } + Refresh-SessionPath + step "path" "added unsloth to PATH" } - # Only add to PATH when the launcher actually exists on disk. - $pathAdded = $false - if (Test-Path $ShimExe) { - $pathAdded = Add-ToUserPath -Directory $ShimDir -Position 'Prepend' - } - if ($shimUpdated -and $pathAdded) { - step "path" "added unsloth launcher to PATH" - } - Refresh-SessionPath # sync current session with registry - Complete-StudioVenvRollback - - # ── Tauri mode: done, skip shortcuts and auto-launch ── - if ($TauriMode) { - Write-TauriLog "DONE" "" - return - } - - New-StudioShortcuts -UnslothExePath $UnslothExe - # In interactive terminals, ask the user before starting Studio. - # In non-interactive environments (CI, Docker) just print instructions. + # Launch studio automatically in interactive terminals; + # in non-interactive environments (CI, Docker) just print instructions. $IsInteractive = [Environment]::UserInteractive -and (-not [Console]::IsInputRedirected) if ($IsInteractive) { - Write-Host "" - $reply = Read-Host " Start Unsloth Studio now? [Y/n]" - if ([string]::IsNullOrWhiteSpace($reply) -or $reply -match '^[Yy]') { - & $UnslothExe studio -p 8888 - } else { - step "launch" "to start later, run:" - substep "unsloth studio -p 8888" - substep "(add -H 0.0.0.0 to allow network / cloud access)" - Write-Host "" - } + & $UnslothExe studio -H 0.0.0.0 -p 8888 } else { step "launch" "manual commands:" substep "& `"$VenvDir\Scripts\Activate.ps1`"" - substep "unsloth studio -p 8888" - substep "(add -H 0.0.0.0 to allow network / cloud access)" + substep "unsloth studio -H 0.0.0.0 -p 8888" Write-Host "" } } diff --git a/install.sh b/install.sh index 7948170043..ea53ecc6d6 100755 --- a/install.sh +++ b/install.sh @@ -35,7 +35,6 @@ substep() { printf " ${C_DIM}%-15s${2:-$C_DIM}%s${C_RST}\n" "" "$1"; } # ── Parse flags ── STUDIO_LOCAL_INSTALL=false PACKAGE_NAME="unsloth" -TAURI_MODE=false _USER_PYTHON="" _NO_TORCH_FLAG=false _VERBOSE=false @@ -55,7 +54,6 @@ for arg in "$@"; do case "$arg" in --local) STUDIO_LOCAL_INSTALL=true ;; --package) _next_is_package=true ;; - --tauri) TAURI_MODE=true ;; --python) _next_is_python=true ;; --no-torch) _NO_TORCH_FLAG=true ;; --verbose|-v) _VERBOSE=true ;; @@ -96,45 +94,6 @@ run_install_cmd() { return $_rc } -# Install bitsandbytes on AMD ROCm hosts. Uses the continuous-release_main -# wheel for the ROCm 4-bit GEMV fix (bnb PR #1887, post-0.49.2); bnb <= 0.49.2 -# NaNs at decode shape on every AMD GPU. Falls back to PyPI >=0.49.1 if the -# pre-release URL is unreachable. Drop the pin once bnb 0.50+ ships on PyPI. -_install_bnb_rocm() { - _label="$1" - _venv_py="$2" - case "$_ARCH" in - x86_64|amd64) - _bnb_whl_url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl" - ;; - aarch64|arm64) - _bnb_whl_url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_aarch64.whl" - ;; - *) - _bnb_whl_url="" - ;; - esac - # uv rejects the continuous-release_main bitsandbytes wheel because the - # filename version (1.33.7rc0) does not match the embedded metadata version - # (0.50.0.dev0). pip accepts the mismatch, so bootstrap pip and use it. - if ! "$_venv_py" -m pip --version >/dev/null 2>&1; then - if ! run_maybe_quiet "$_venv_py" -m ensurepip --upgrade; then - run_maybe_quiet uv pip install --python "$_venv_py" pip || \ - substep "[WARN] could not bootstrap pip; bitsandbytes install will likely fail" "$C_WARN" - fi - fi - if [ -n "$_bnb_whl_url" ]; then - substep "installing bitsandbytes for AMD ROCm (pre-release, PR #1887)..." - if run_install_cmd "$_label (pre-release)" "$_venv_py" -m pip install \ - --force-reinstall --no-cache-dir --no-deps "$_bnb_whl_url"; then - return 0 - fi - substep "[WARN] bnb pre-release install failed; falling back to PyPI (4-bit decode broken on ROCm)" "$C_WARN" - fi - run_install_cmd "$_label (pypi fallback)" "$_venv_py" -m pip install \ - --force-reinstall --no-cache-dir --no-deps "bitsandbytes>=0.49.1" -} - if [ "$_next_is_package" = true ]; then echo "❌ ERROR: --package requires an argument." >&2 exit 1 @@ -144,137 +103,9 @@ if [ "$_next_is_python" = true ]; then exit 1 fi -# Validate --package to prevent injection into shell/Python commands. -# Must start with a letter/digit (rejects leading dashes that uv would parse as flags). -case "$PACKAGE_NAME" in - [!a-zA-Z0-9]*) - echo "❌ ERROR: --package name must start with a letter or digit." >&2 - exit 1 ;; - *[!a-zA-Z0-9._-]*) - echo "❌ ERROR: --package name contains invalid characters (allowed: a-z A-Z 0-9 . _ -)" >&2 - exit 1 ;; -esac - -# ── Tauri structured output ── -tauri_log() { - if [ "$TAURI_MODE" = true ]; then - echo "[TAURI:$1] $2" - fi -} - -tauri_diag_marker() { - _diag_gpu_branch="${1:-unknown}" - _diag_torch_index_family="${2:-none}" - tauri_log "DIAG" "diag_schema=1 platform=${OS:-unknown} arch=${_ARCH:-unknown} python_version=${PYTHON_VERSION:-unknown} skip_torch=${SKIP_TORCH:-false} mac_intel=${MAC_INTEL:-false} gpu_branch=${_diag_gpu_branch} torch_index_family=${_diag_torch_index_family}" -} - -_tauri_torch_index_family() { - if [ "${SKIP_TORCH:-false}" = true ]; then - echo "none" - return - fi - _diag_url="${1:-}" - case "$_diag_url" in - */cu118) echo "cu118" ;; - */cu124) echo "cu124" ;; - */cu126) echo "cu126" ;; - */cu128) echo "cu128" ;; - */cu130) echo "cu130" ;; - */cpu) echo "cpu" ;; - */rocm[0-9]*.[0-9]*) - _diag_family=${_diag_url##*/} - case "$_diag_family" in - rocm[0-9]*.[0-9]*) echo "$_diag_family" ;; - *) echo "auto" ;; - esac ;; - "") echo "none" ;; - *) echo "auto" ;; - esac -} - -_tauri_gpu_branch() { - _diag_family="${1:-unknown}" - _diag_radeon="${2:-false}" - if [ "${SKIP_TORCH:-false}" = true ]; then - echo "no_torch" - return - fi - if [ "${OS:-}" = "macos" ]; then - echo "mac" - return - fi - case "$_diag_family" in - cu*) echo "cuda" ;; - rocm*) - if [ "$_diag_radeon" = true ]; then - echo "rocm_radeon" - else - echo "rocm" - fi ;; - radeon) echo "rocm_radeon" ;; - cpu) echo "cpu" ;; - none) echo "no_torch" ;; - *) echo "unknown" ;; - esac -} - PYTHON_VERSION="" # resolved after platform detection STUDIO_HOME="$HOME/.unsloth/studio" VENV_DIR="$STUDIO_HOME/unsloth_studio" -_VENV_ROLLBACK_DIR="" -_VENV_ROLLBACK_TARGET="$VENV_DIR" -_VENV_ROLLBACK_ACTIVE=false - -_start_studio_venv_replacement() { - _existing_dir="$1" - _stamp=$(date +%Y%m%d%H%M%S 2>/dev/null || echo "time") - _candidate="$STUDIO_HOME/unsloth_studio.rollback.$_stamp.$$" - _suffix=0 - while [ -e "$_candidate" ]; do - _suffix=$((_suffix + 1)) - _candidate="$STUDIO_HOME/unsloth_studio.rollback.$_stamp.$$.$_suffix" - done - mv "$_existing_dir" "$_candidate" - _VENV_ROLLBACK_DIR="$_candidate" - _VENV_ROLLBACK_TARGET="$_existing_dir" - _VENV_ROLLBACK_ACTIVE=true - substep "previous environment preserved for rollback" -} - -_restore_studio_venv_replacement() { - [ "$_VENV_ROLLBACK_ACTIVE" = true ] || return 0 - [ -n "$_VENV_ROLLBACK_DIR" ] && [ -d "$_VENV_ROLLBACK_DIR" ] || { - _VENV_ROLLBACK_ACTIVE=false - return 0 - } - substep "restoring previous environment after failed install..." "$C_WARN" - rm -rf "$_VENV_ROLLBACK_TARGET" - if mv "$_VENV_ROLLBACK_DIR" "$_VENV_ROLLBACK_TARGET"; then - substep "restored previous environment" - _VENV_ROLLBACK_ACTIVE=false - _VENV_ROLLBACK_DIR="" - else - echo "⚠️ Could not restore previous environment from $_VENV_ROLLBACK_DIR to $_VENV_ROLLBACK_TARGET" >&2 - fi -} - -_commit_studio_venv_replacement() { - [ "$_VENV_ROLLBACK_ACTIVE" = true ] || return 0 - if [ -n "$_VENV_ROLLBACK_DIR" ] && [ -d "$_VENV_ROLLBACK_DIR" ]; then - rm -rf "$_VENV_ROLLBACK_DIR" || true - fi - _VENV_ROLLBACK_ACTIVE=false - _VENV_ROLLBACK_DIR="" -} - -_on_install_exit() { - _status=$? - if [ "$_status" -ne 0 ]; then - _restore_studio_venv_replacement - fi - exit "$_status" -} -trap _on_install_exit EXIT # ── Helper: download a URL to a file (supports curl and wget) ── download() { @@ -322,12 +153,6 @@ _smart_apt_install() { return 0 fi - # In Tauri mode, report needed packages and exit — Rust handles elevation - if [ "$TAURI_MODE" = true ]; then - tauri_log "NEED_SUDO" "$_STILL_MISSING" - exit 2 - fi - # Step 3: Escalate -- need elevated permissions for remaining packages if command -v sudo >/dev/null 2>&1; then echo "" @@ -622,11 +447,11 @@ if [ -t 1 ]; then ) & # Clear traps so exec does not trigger _release_lock (the subshell owns it) trap - EXIT INT TERM - exec "$UNSLOTH_EXE" studio -p "$_launch_port" + exec "$UNSLOTH_EXE" studio -H 0.0.0.0 -p "$_launch_port" else # ── Background mode (no TTY) ── # Used by macOS .app and headless invocations. - _launch_cmd=$(printf '%q ' "$UNSLOTH_EXE" studio -p "$_launch_port") + _launch_cmd=$(printf '%q ' "$UNSLOTH_EXE" studio -H 0.0.0.0 -p "$_launch_port") _launch_cmd=${_launch_cmd% } _spawn_terminal "$_launch_cmd" @@ -888,7 +713,6 @@ printf " ${C_DIM}%s${C_RST}\n" "$RULE" echo "" # ── Detect platform ── -tauri_log "STEP" "Detecting platform" OS="linux" if [ "$(uname)" = "Darwin" ]; then OS="macos" @@ -938,18 +762,9 @@ if [ "$_NO_TORCH_FLAG" = true ] || [ "$MAC_INTEL" = true ]; then SKIP_TORCH=true fi -_TAURI_INITIAL_GPU_BRANCH="unknown" -if [ "$SKIP_TORCH" = true ]; then - _TAURI_INITIAL_GPU_BRANCH="no_torch" -elif [ "$OS" = "macos" ]; then - _TAURI_INITIAL_GPU_BRANCH="mac" -fi -tauri_diag_marker "$_TAURI_INITIAL_GPU_BRANCH" "none" - # ── Check system dependencies ── # cmake and git are needed by unsloth studio setup to build the GGUF inference # engine (llama.cpp). build-essential and libcurl-dev are also needed on Linux. -tauri_log "STEP" "Checking system dependencies" MISSING="" command -v cmake >/dev/null 2>&1 || MISSING="$MISSING cmake" @@ -974,7 +789,9 @@ case "$OS" in fi command -v gcc >/dev/null 2>&1 || MISSING="$MISSING build-essential" # libcurl dev headers for llama.cpp HTTPS support - command -v curl-config >/dev/null 2>&1 || MISSING="$MISSING libcurl4-openssl-dev" + if command -v dpkg >/dev/null 2>&1; then + dpkg -s libcurl4-openssl-dev >/dev/null 2>&1 || MISSING="$MISSING libcurl4-openssl-dev" + fi ;; esac @@ -999,15 +816,9 @@ if [ -n "$MISSING" ]; then if command -v apt-get >/dev/null 2>&1; then _smart_apt_install $MISSING else - echo " Automatic system package installation is supported on apt-based" - echo " Linux distributions (Ubuntu/Debian) only. Please install the" - echo " missing dependencies with your package manager, then re-run setup:" + echo " apt-get is not available. Please install with your package manager:" echo " $MISSING" - echo "" - echo " Examples:" - echo " Fedora/RHEL: sudo dnf install cmake git gcc gcc-c++ make libcurl-devel" - echo " Arch: sudo pacman -S --needed cmake git base-devel curl" - echo " openSUSE: sudo zypper install cmake git gcc gcc-c++ make libcurl-devel" + echo " Then re-run Unsloth Studio setup." exit 1 fi ;; @@ -1018,7 +829,6 @@ else fi # ── Install uv ── -tauri_log "STEP" "Installing uv package manager" UV_MIN_VERSION="0.7.14" version_ge() { @@ -1073,25 +883,17 @@ if ! command -v uv >/dev/null 2>&1 || ! _uv_version_ok uv; then fi # ── Create venv (migrate old layout if possible, otherwise fresh) ── -tauri_log "STEP" "Creating virtual environment" mkdir -p "$STUDIO_HOME" _MIGRATED=false if [ -x "$VENV_DIR/bin/python" ]; then - # New layout already exists — replace only after preserving rollback copy. - substep "preserving existing environment for rollback..." - _start_studio_venv_replacement "$VENV_DIR" + # New layout already exists — nuke for fresh install + rm -rf "$VENV_DIR" elif [ -x "$STUDIO_HOME/.venv/bin/python" ]; then - # Old layout exists — validate before migrating. - # In no-torch mode, a missing torch package is expected; validate Python only. + # Old layout exists — validate before migrating substep "found legacy Studio environment, validating..." - _legacy_ok=false - if [ "$SKIP_TORCH" = true ]; then - if "$STUDIO_HOME/.venv/bin/python" -c "import sys; print(sys.executable)" >/dev/null 2>&1; then - _legacy_ok=true - fi - elif "$STUDIO_HOME/.venv/bin/python" -c " + if "$STUDIO_HOME/.venv/bin/python" -c " import torch device = 'cuda' if torch.cuda.is_available() else 'cpu' A = torch.ones((10, 10), device=device) @@ -1101,17 +903,13 @@ D = A + B E = D @ C torch.testing.assert_close(torch.unique(E), torch.tensor((20,), device=E.device, dtype=E.dtype)) " >/dev/null 2>&1; then - _legacy_ok=true - fi - if [ "$_legacy_ok" = true ]; then echo "✅ Legacy environment is healthy — migrating..." mv "$STUDIO_HOME/.venv" "$VENV_DIR" echo " Moved ~/.unsloth/studio/.venv → $VENV_DIR" _MIGRATED=true else echo "⚠️ Legacy environment failed validation — creating fresh environment" - _invalid_venv="$STUDIO_HOME/.venv.invalid.$(date +%Y%m%d%H%M%S 2>/dev/null || echo time).$$" - mv "$STUDIO_HOME/.venv" "$_invalid_venv" 2>/dev/null || true + rm -rf "$STUDIO_HOME/.venv" fi fi @@ -1180,122 +978,22 @@ _find_no_torch_runtime() { fi } -# ── AMD ROCm GPU detection helper ── -# Returns 0 (true) if an actual AMD GPU is present, 1 (false) otherwise. -# Checks rocminfo for gfx[1-9]* (excludes gfx000 CPU agent) and -# amd-smi list for GPU data rows (excludes header-only output). -_has_amd_rocm_gpu() { - if command -v rocminfo >/dev/null 2>&1 && \ - rocminfo 2>/dev/null | awk '/Name:[[:space:]]*gfx[0-9]/ && !/Name:[[:space:]]*gfx000/{found=1} END{exit !found}'; then - return 0 - elif command -v amd-smi >/dev/null 2>&1 && \ - amd-smi list 2>/dev/null | awk '/^GPU[[:space:]]*[:\[][[:space:]]*[0-9]/{ found=1 } END{ exit !found }'; then - return 0 - fi - return 1 -} - -# ── NVIDIA usable-GPU helper ── -# Returns 0 (true) only if nvidia-smi is present AND actually lists a GPU. -# Prevents AMD-only hosts with a stale nvidia-smi on PATH from being routed -# into the CUDA branch. -_has_usable_nvidia_gpu() { - _nvsmi="" - if command -v nvidia-smi >/dev/null 2>&1; then - _nvsmi="nvidia-smi" - elif [ -x "/usr/bin/nvidia-smi" ]; then - _nvsmi="/usr/bin/nvidia-smi" - else - return 1 - fi - "$_nvsmi" -L 2>/dev/null | awk '/^GPU[[:space:]]+[0-9]+:/{found=1} END{exit !found}' -} - # ── Detect GPU and choose PyTorch index URL ── # Mirrors Get-TorchIndexUrl in install.ps1. # On CPU-only machines this returns the cpu index, avoiding the solver # dead-end where --torch-backend=auto resolves to unsloth==2024.8. get_torch_index_url() { - _base="${UNSLOTH_PYTORCH_MIRROR:-https://download.pytorch.org/whl}" - _base="${_base%/}" + _base="https://download.pytorch.org/whl" # macOS: always CPU (no CUDA support) case "$(uname -s)" in Darwin) echo "$_base/cpu"; return ;; esac - # Try nvidia-smi -- require the binary to actually list a usable GPU. - # Presence of the binary alone (container leftovers, stale driver - # packages) is not sufficient: otherwise an AMD-only host would - # silently install CUDA wheels. + # Try nvidia-smi _smi="" - if _has_usable_nvidia_gpu; then - if command -v nvidia-smi >/dev/null 2>&1; then - _smi="nvidia-smi" - elif [ -x "/usr/bin/nvidia-smi" ]; then - _smi="/usr/bin/nvidia-smi" - fi - fi - if [ -z "$_smi" ]; then - # No NVIDIA GPU -- check for AMD ROCm GPU. - # PyTorch only publishes ROCm wheels for linux-x86_64; skip the - # ROCm branch entirely on aarch64 / arm64 / other architectures - # so non-x86_64 Linux hosts fall back cleanly to CPU wheels. - case "$(uname -m)" in - x86_64|amd64) : ;; - *) echo "$_base/cpu"; return ;; - esac - if ! _has_amd_rocm_gpu; then - echo "$_base/cpu"; return - fi - # AMD GPU confirmed -- detect ROCm version - _rocm_tag="" - _rocm_tag=$({ command -v amd-smi >/dev/null 2>&1 && \ - amd-smi version 2>/dev/null | awk -F'ROCm version: ' \ - 'NF>1{gsub(/[^0-9.]/, "", $2); split($2,a,"."); print "rocm"a[1]"."a[2]; ok=1; exit} END{exit !ok}'; } || \ - { [ -r /opt/rocm/.info/version ] && \ - awk -F. '{print "rocm"$1"."$2; exit}' /opt/rocm/.info/version; } || \ - { command -v hipconfig >/dev/null 2>&1 && \ - hipconfig --version 2>/dev/null | awk 'NR==1 && /^[0-9]/{split($1,a,"."); if(a[1]+0>0){print "rocm"a[1]"."a[2]; found=1}} END{exit !found}'; } || \ - { command -v dpkg-query >/dev/null 2>&1 && \ - ver="$(dpkg-query -W -f='${Version}\n' rocm-core 2>/dev/null)" && \ - [ -n "$ver" ] && \ - printf '%s\n' "$ver" | sed 's/^[0-9]*://' | awk -F'[.-]' '{print "rocm"$1"."$2; exit}'; } || \ - { command -v rpm >/dev/null 2>&1 && \ - ver="$(rpm -q --qf '%{VERSION}\n' rocm-core 2>/dev/null)" && \ - [ -n "$ver" ] && \ - printf '%s\n' "$ver" | awk -F'[.-]' '{print "rocm"$1"."$2; exit}'; }) 2>/dev/null - # Validate _rocm_tag: must match "rocmX.Y" with major >= 1 - case "$_rocm_tag" in - rocm[1-9]*.[0-9]*) : ;; # valid (major >= 1) - *) _rocm_tag="" ;; # reject malformed (empty, garbled, or major=0) - esac - if [ -n "$_rocm_tag" ]; then - # Minimum supported: ROCm 6.0 (no PyTorch wheels exist for older) - case "$_rocm_tag" in - rocm[1-5].*) echo "$_base/cpu"; return ;; - esac - # ROCm 7.2 only has torch 2.11.0 which exceeds current bounds - # (<2.11.0). Fall back to rocm7.1 index which has torch 2.10.0. - # Enumerate explicit versions rather than matching rocm6.* so - # a host on ROCm 6.5 or 6.6 (no PyTorch wheels published) is - # clipped down to the last supported 6.x (rocm6.4) instead of - # constructing https://download.pytorch.org/whl/rocm6.5 which - # returns HTTP 403. PyTorch only ships: rocm5.7, 6.0, 6.1, 6.2, - # 6.3, 6.4, 7.0, 7.1, 7.2 (and 5.7 is below our minimum). - # TODO: uncomment rocm7.2 when the torch upper bound is bumped - # to >=2.11.0. - case "$_rocm_tag" in - rocm6.0|rocm6.0.*|rocm6.1|rocm6.1.*|rocm6.2|rocm6.2.*|rocm6.3|rocm6.3.*|rocm6.4|rocm6.4.*|rocm7.0|rocm7.0.*|rocm7.1|rocm7.1.*) - echo "$_base/$_rocm_tag" ;; - rocm6.*) - # ROCm 6.5+ (no published PyTorch wheels): clip down - # to the last supported 6.x wheel set. - echo "$_base/rocm6.4" ;; - *) - # ROCm 7.2+ (including future 10.x+): cap to rocm7.1 - echo "$_base/rocm7.1" ;; - esac - return - fi - echo "$_base/cpu"; return + if command -v nvidia-smi >/dev/null 2>&1; then + _smi="nvidia-smi" + elif [ -x "/usr/bin/nvidia-smi" ]; then + _smi="/usr/bin/nvidia-smi" fi + if [ -z "$_smi" ]; then echo "$_base/cpu"; return; fi # Parse CUDA version from nvidia-smi output (POSIX-safe, no grep -P) _cuda_ver=$(LC_ALL=C $_smi 2>/dev/null \ | sed -n 's/.*CUDA Version:[[:space:]]*\([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p' \ @@ -1313,167 +1011,23 @@ get_torch_index_url() { elif [ "$_major" -ge 11 ]; then echo "$_base/cu118" else echo "$_base/cpu"; fi } - -get_radeon_wheel_url() { - # Only meaningful on Linux. Picks a repo.radeon.com base URL whose listing - # contains torch wheels. Tries paths like rocm-rel-7.2.1/, rocm-rel-7.2/, - # rocm-rel-7.1.1/, rocm-rel-7.1/ (AMD publishes both M.m and M.m.p dirs). - # Accepts both X.Y and X.Y.Z host versions since /opt/rocm/.info/version - # and hipconfig --version can return either shape. - case "$(uname -s)" in Linux) ;; *) echo ""; return ;; esac - - # Detect ROCm version (X.Y or X.Y.Z) -- try amd-smi, then - # /opt/rocm/.info/version, then hipconfig. - _full_ver="" - _full_ver=$({ command -v amd-smi >/dev/null 2>&1 && \ - amd-smi version 2>/dev/null | awk -F'ROCm version: ' \ - 'NF>1{if(match($2,/[0-9]+\.[0-9]+(\.[0-9]+)?/)){print substr($2,RSTART,RLENGTH); ok=1; exit}} END{exit !ok}'; } || \ - { [ -r /opt/rocm/.info/version ] && \ - awk 'match($0,/[0-9]+\.[0-9]+(\.[0-9]+)?/){print substr($0,RSTART,RLENGTH); found=1; exit} END{exit !found}' /opt/rocm/.info/version; } || \ - { command -v hipconfig >/dev/null 2>&1 && \ - hipconfig --version 2>/dev/null | awk 'NR==1 && match($0,/[0-9]+\.[0-9]+(\.[0-9]+)?/){print substr($0,RSTART,RLENGTH); found=1} END{exit !found}'; }) 2>/dev/null - - # Validate: must be X.Y or X.Y.Z with X >= 1 - case "$_full_ver" in - [1-9]*.[0-9]*.[0-9]*) : ;; # X.Y.Z - [1-9]*.[0-9]*) : ;; # X.Y - *) echo ""; return ;; - esac - echo "https://repo.radeon.com/rocm/manylinux/rocm-rel-${_full_ver}/" -} - -# ── Radeon repo wheel selection helpers ────────────────────────────────────── -# Fetches the Radeon repo directory listing once into _RADEON_LISTING (global). -# _RADEON_PYTAG holds the CPython tag for the running interpreter (e.g. cp312). -# _RADEON_BASE_URL holds the base URL for relative-href resolution. -_RADEON_LISTING="" -_RADEON_PYTAG="" -_RADEON_BASE_URL="" - -_radeon_fetch_listing() { - # Usage: _radeon_fetch_listing BASE_URL - # Populates _RADEON_LISTING, _RADEON_PYTAG, _RADEON_BASE_URL. - _RADEON_BASE_URL="$1" - _RADEON_PYTAG=$("$_VENV_PY" -c " -import sys -print('cp{}{}'.format(sys.version_info.major, sys.version_info.minor)) -" 2>/dev/null) || return 1 - if command -v curl >/dev/null 2>&1; then - _RADEON_LISTING=$(curl -fsSL --max-time 20 "$_RADEON_BASE_URL" 2>/dev/null) - elif command -v wget >/dev/null 2>&1; then - _RADEON_LISTING=$(wget -qO- --timeout=20 "$_RADEON_BASE_URL" 2>/dev/null) - fi - [ -n "$_RADEON_LISTING" ] || return 1 -} - -_pick_radeon_wheel() { - # Usage: _pick_radeon_wheel PACKAGE_NAME - # Scans $_RADEON_LISTING for the newest wheel whose filename starts exactly - # with PACKAGE_NAME- and matches _RADEON_PYTAG + linux_x86_64. - # Prints the full URL (resolving relative hrefs against _RADEON_BASE_URL). - # - # POSIX-compliant pipeline: all href parsing, filtering, and version - # selection is done inside a single awk script rather than reaching - # for GNU extensions (grep -o, sort -V) that would break under BSD - # or BusyBox coreutils. - _pkg="$1" - [ -n "$_RADEON_LISTING" ] || return 1 - [ -n "$_RADEON_PYTAG" ] || return 1 - _tag="$_RADEON_PYTAG" - _href=$(printf '%s\n' "$_RADEON_LISTING" \ - | awk -v pkg="$_pkg" -v tag="$_tag" ' - BEGIN { max_pad = ""; max_url = "" } - { - line = $0 - while (match(line, /href="[^"]*"/)) { - # Strip the leading href=" (6 chars) and trailing " (1 char) - url = substr(line, RSTART + 6, RLENGTH - 7) - line = substr(line, RSTART + RLENGTH) - - # Extract basename, strip query / fragment - n = split(url, p, "/") - base = p[n] - sub(/[?#].*/, "", base) - - prefix = pkg "-" - # Match cpXY-cpXY or cpXY-abi3 with any linux x86_64 - # platform tag (linux_x86_64, manylinux_2_28_x86_64, - # manylinux2014_x86_64, etc.) - if (substr(base, 1, length(prefix)) == prefix && - index(base, "-" tag "-") > 0 && - match(base, /x86_64\.whl$/)) { - # Extract the version component (first - # dotted-number run) and pad each piece so a - # plain lexical comparison gives us the newest. - if (match(base, /[0-9]+\.[0-9]+(\.[0-9]+)?/)) { - ver = substr(base, RSTART, RLENGTH) - m = split(ver, v, ".") - pad = "" - for (i = 1; i <= m; i++) - pad = pad sprintf("%08d", v[i]) - if (pad > max_pad) { - max_pad = pad - max_url = url - } - } - } - } - } - END { if (max_url != "") print max_url }') - [ -z "$_href" ] && return 1 - case "$_href" in - http*) printf '%s\n' "$_href" ;; - *) printf '%s\n' "${_RADEON_BASE_URL%/}/${_href#/}" ;; - esac -} - TORCH_INDEX_URL=$(get_torch_index_url) -# Auto-detect GPU for AMD ROCm based -# get_torch_index_url must have chosen */rocm* -# (gfx in rocminfo or amd-smi list). Then require rocminfo "Marketing Name:.*Radeon". -_amd_gpu_radeon=false -case "$TORCH_INDEX_URL" in - */rocm*) - if _has_amd_rocm_gpu && command -v rocminfo >/dev/null 2>&1 && \ - rocminfo 2>/dev/null | grep -q 'Marketing Name:.*Radeon'; then - _amd_gpu_radeon=true - fi - ;; -esac -_TAURI_TORCH_INDEX_FAMILY=$(_tauri_torch_index_family "$TORCH_INDEX_URL") -if [ "$_amd_gpu_radeon" = true ] && [ "$SKIP_TORCH" = false ]; then - _TAURI_TORCH_INDEX_FAMILY="radeon" -fi -_TAURI_GPU_BRANCH=$(_tauri_gpu_branch "$_TAURI_TORCH_INDEX_FAMILY" "$_amd_gpu_radeon") -tauri_diag_marker "$_TAURI_GPU_BRANCH" "$_TAURI_TORCH_INDEX_FAMILY" - # ── Print CPU-only hint when no GPU detected ── case "$TORCH_INDEX_URL" in */cpu) if [ "$SKIP_TORCH" = false ] && [ "$OS" != "macos" ]; then echo "" - echo " NOTE: No GPU detected (nvidia-smi and ROCm not found)." + echo " NOTE: No NVIDIA GPU detected (nvidia-smi not found)." echo " Installing CPU-only PyTorch. If you only need GGUF chat/inference," echo " re-run with --no-torch for a faster, lighter install:" echo " curl -fsSL https://unsloth.ai/install.sh | sh -s -- --no-torch" - echo " AMD ROCm users: see https://docs.unsloth.ai/get-started/install-and-update/amd" echo "" fi ;; - */rocm*) - echo "" - if [ "$_amd_gpu_radeon" = true ]; then - echo " AMD Radeon + ROCm detected -- installing PyTorch wheels from repo.radeon.com" - else - echo " AMD ROCm detected -- installing ROCm-enabled PyTorch ($TORCH_INDEX_URL)" - fi - echo "" - ;; esac # ── Install unsloth directly into the venv (no activation needed) ── -tauri_log "STEP" "Installing PyTorch" _VENV_PY="$VENV_DIR/bin/python" if [ "$_MIGRATED" = true ]; then # Migrated env: force-reinstall unsloth+unsloth-zoo to ensure clean state @@ -1486,7 +1040,7 @@ if [ "$_MIGRATED" = true ]; then # to prevent transitive torch resolution. run_install_cmd "install unsloth (migrated no-torch)" uv pip install --python "$_VENV_PY" --no-deps \ --reinstall-package unsloth --reinstall-package unsloth-zoo \ - "unsloth>=2026.4.8" unsloth-zoo + "unsloth>=2026.4.4" unsloth-zoo _NO_TORCH_RT="$(_find_no_torch_runtime)" if [ -n "$_NO_TORCH_RT" ]; then run_install_cmd "install no-torch runtime deps" uv pip install --python "$_VENV_PY" --no-deps -r "$_NO_TORCH_RT" @@ -1494,175 +1048,29 @@ if [ "$_MIGRATED" = true ]; then else run_install_cmd "install unsloth (migrated)" uv pip install --python "$_VENV_PY" \ --reinstall-package unsloth --reinstall-package unsloth-zoo \ - "unsloth>=2026.4.8" unsloth-zoo + "unsloth>=2026.4.4" unsloth-zoo fi if [ "$STUDIO_LOCAL_INSTALL" = true ]; then substep "overlaying local repo (editable)..." run_install_cmd "overlay local repo" uv pip install --python "$_VENV_PY" -e "$_REPO_ROOT" --no-deps - substep "overlaying unsloth-zoo from git main..." - run_install_cmd "overlay unsloth-zoo (git main)" uv pip install --python "$_VENV_PY" \ - --no-deps --reinstall-package unsloth-zoo \ - "unsloth-zoo @ git+https://github.com/unslothai/unsloth-zoo" - fi - # AMD ROCm: install bitsandbytes even in migrated environments so - # existing ROCm installs gain the AMD bitsandbytes build without a - # fresh reinstall. - if [ "$SKIP_TORCH" = false ]; then - case "$TORCH_INDEX_URL" in - */rocm*) - _install_bnb_rocm "install bitsandbytes (AMD)" "$_VENV_PY" - # Repair ROCm torch if overwritten during migrated install - _has_hip=$("$_VENV_PY" -c "import torch; print(getattr(torch.version,'hip','') or '')" 2>/dev/null || true) - if [ -z "$_has_hip" ]; then - substep "repairing ROCm torch (overwritten by dependency resolution)..." - run_install_cmd "repair ROCm torch" uv pip install --python "$_VENV_PY" \ - "$TORCH_CONSTRAINT" torchvision torchaudio \ - --index-url "$TORCH_INDEX_URL" \ - --force-reinstall - fi - ;; - esac fi elif [ -n "$TORCH_INDEX_URL" ]; then # Fresh: Step 1 - install torch from explicit index (skip when --no-torch or Intel Mac) if [ "$SKIP_TORCH" = true ]; then substep "skipping PyTorch (--no-torch or Intel Mac x86_64)." "$C_WARN" - elif [ "$_amd_gpu_radeon" = true ]; then - _radeon_url=$(get_radeon_wheel_url) - if [ -n "$_radeon_url" ]; then - _radeon_listing_ok=false - if _radeon_fetch_listing "$_radeon_url" 2>/dev/null; then - _radeon_listing_ok=true - else - # Try shorter X.Y path (AMD publishes both X.Y.Z and X.Y dirs) - _radeon_url_short=$(printf '%s\n' "$_radeon_url" \ - | sed 's|rocm-rel-\([0-9]*\)\.\([0-9]*\)\.[0-9]*/|rocm-rel-\1.\2/|') - if [ "$_radeon_url_short" != "$_radeon_url" ] && \ - _radeon_fetch_listing "$_radeon_url_short" 2>/dev/null; then - _radeon_listing_ok=true - fi - fi - - if [ "$_radeon_listing_ok" = true ]; then - # Require torch, torchvision, torchaudio wheels to all resolve - # from the Radeon listing. If any is missing for this Python - # tag, fall through to the standard ROCm index instead of - # silently mixing Radeon wheels with PyPI defaults. - _torch_whl=$(_pick_radeon_wheel "torch" 2>/dev/null) || _torch_whl="" - _tv_whl=$(_pick_radeon_wheel "torchvision" 2>/dev/null) || _tv_whl="" - _ta_whl=$(_pick_radeon_wheel "torchaudio" 2>/dev/null) || _ta_whl="" - _tri_whl=$(_pick_radeon_wheel "triton" 2>/dev/null) || _tri_whl="" - # Sanity-check torch / torchvision / torchaudio are a - # matching release. The Radeon repo publishes multiple - # generations simultaneously, so picking the highest-version - # wheel for each package independently can assemble a - # mismatched trio (e.g. torch 2.9.1 + torchvision 0.23.0 + - # torchaudio 2.9.0 from the current rocm-rel-7.2.1 index). - # Check that torch and torchaudio share the same X.Y public - # version prefix, and that torchvision's minor correctly - # pairs with torch's minor (torchvision = torch.minor - 5 - # since torch 2.4 -> torchvision 0.19 -> torch 2.9 -> - # torchvision 0.24). - # URL-decode each wheel name so %2B -> + before version - # extraction. Real Radeon wheel hrefs are percent-encoded - # (torch-2.10.0%2Brocm7.2.0...), so a plain [+-] terminator - # in the sed regex below would never match and - # _radeon_versions_match would stay false for every real - # listing, silently forcing a fallback to the generic - # ROCm index. - _torch_ver="" - _tv_ver="" - _ta_ver="" - if [ -n "$_torch_whl" ]; then - _torch_name=$(printf '%s' "${_torch_whl##*/}" | sed 's/%2[Bb]/+/g') - _torch_ver=$(printf '%s\n' "$_torch_name" | sed -n 's|^torch-\([0-9][0-9]*\.[0-9][0-9]*\)\(\.[0-9][0-9]*\)\{0,1\}[+-].*|\1|p') - fi - if [ -n "$_tv_whl" ]; then - _tv_name=$(printf '%s' "${_tv_whl##*/}" | sed 's/%2[Bb]/+/g') - _tv_ver=$(printf '%s\n' "$_tv_name" | sed -n 's|^torchvision-\([0-9][0-9]*\.[0-9][0-9]*\)\(\.[0-9][0-9]*\)\{0,1\}[+-].*|\1|p') - fi - if [ -n "$_ta_whl" ]; then - _ta_name=$(printf '%s' "${_ta_whl##*/}" | sed 's/%2[Bb]/+/g') - _ta_ver=$(printf '%s\n' "$_ta_name" | sed -n 's|^torchaudio-\([0-9][0-9]*\.[0-9][0-9]*\)\(\.[0-9][0-9]*\)\{0,1\}[+-].*|\1|p') - fi - _radeon_versions_match=false - if [ -n "$_torch_ver" ] && [ -n "$_tv_ver" ] && [ -n "$_ta_ver" ]; then - _torch_major=${_torch_ver%%.*} - _torch_minor=${_torch_ver#*.} - _ta_major=${_ta_ver%%.*} - _ta_minor=${_ta_ver#*.} - _tv_major=${_tv_ver%%.*} - _tv_minor=${_tv_ver#*.} - # torchvision expected minor (e.g. torch 2.9 -> 0.24) - _expected_tv_minor=$((_torch_minor + 15)) - if [ "$_torch_major" = "$_ta_major" ] && \ - [ "$_torch_minor" = "$_ta_minor" ] && \ - [ "$_tv_major" = "0" ] && \ - [ "$_tv_minor" = "$_expected_tv_minor" ]; then - _radeon_versions_match=true - fi - fi - if [ -z "$_torch_whl" ] || [ -z "$_tv_whl" ] || [ -z "$_ta_whl" ] || \ - [ "$_radeon_versions_match" != true ]; then - substep "[WARN] Radeon repo lacks a compatible wheel set for this Python; falling back to ROCm index ($TORCH_INDEX_URL)" "$C_WARN" - run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" \ - "$TORCH_CONSTRAINT" torchvision torchaudio \ - --index-url "$TORCH_INDEX_URL" - else - substep "installing PyTorch from Radeon repo (${_RADEON_BASE_URL})..." - # Pass explicit wheel URLs so the matched trio is - # installed together. --find-links lets uv discover - # the Radeon listing for any local lookup, and PyPI - # (not disabled) provides transitive deps like - # filelock / sympy / networkx which are not in the - # Radeon listing. - if [ -n "$_tri_whl" ]; then - run_install_cmd "install triton + PyTorch" uv pip install --python "$_VENV_PY" \ - --find-links "$_RADEON_BASE_URL" \ - "$_tri_whl" "$_torch_whl" "$_tv_whl" "$_ta_whl" - else - run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" \ - --find-links "$_RADEON_BASE_URL" \ - "$_torch_whl" "$_tv_whl" "$_ta_whl" - fi - fi - else - substep "[WARN] Radeon repo unavailable; falling back to ROCm index ($TORCH_INDEX_URL)" "$C_WARN" - run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" \ - "$TORCH_CONSTRAINT" torchvision torchaudio \ - --index-url "$TORCH_INDEX_URL" - fi - else - substep "[WARN] Radeon GPU detected but could not detect full ROCm version; falling back to ROCm index" "$C_WARN" - run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" \ - "$TORCH_CONSTRAINT" torchvision torchaudio \ - --index-url "$TORCH_INDEX_URL" - fi else substep "installing PyTorch ($TORCH_INDEX_URL)..." run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" "$TORCH_CONSTRAINT" torchvision torchaudio \ --index-url "$TORCH_INDEX_URL" fi - # AMD ROCm: install bitsandbytes (once, after torch, for all ROCm paths). - # Gate on SKIP_TORCH=false so a user running with --no-torch on a ROCm - # host stays in GGUF-only mode rather than pulling in bitsandbytes, - # which is only useful once torch is present for training. - if [ "$SKIP_TORCH" = false ]; then - case "$TORCH_INDEX_URL" in - */rocm*) - _install_bnb_rocm "install bitsandbytes (AMD)" "$_VENV_PY" - ;; - esac - fi # Fresh: Step 2 - install unsloth, preserving pre-installed torch - tauri_log "STEP" "Installing Unsloth" substep "installing unsloth (this may take a few minutes)..." if [ "$SKIP_TORCH" = true ]; then # No-torch: install unsloth + unsloth-zoo with --no-deps, then # runtime deps (typer, safetensors, transformers, etc.) with --no-deps. run_install_cmd "install unsloth (no-torch)" uv pip install --python "$_VENV_PY" --no-deps \ --upgrade-package unsloth --upgrade-package unsloth-zoo \ - "unsloth>=2026.4.8" unsloth-zoo + "unsloth>=2026.4.4" unsloth-zoo _NO_TORCH_RT="$(_find_no_torch_runtime)" if [ -n "$_NO_TORCH_RT" ]; then run_install_cmd "install no-torch runtime deps" uv pip install --python "$_VENV_PY" --no-deps -r "$_NO_TORCH_RT" @@ -1670,59 +1078,29 @@ elif [ -n "$TORCH_INDEX_URL" ]; then if [ "$STUDIO_LOCAL_INSTALL" = true ]; then substep "overlaying local repo (editable)..." run_install_cmd "overlay local repo" uv pip install --python "$_VENV_PY" -e "$_REPO_ROOT" --no-deps - substep "overlaying unsloth-zoo from git main..." - run_install_cmd "overlay unsloth-zoo (git main)" uv pip install --python "$_VENV_PY" \ - --no-deps --reinstall-package unsloth-zoo \ - "unsloth-zoo @ git+https://github.com/unslothai/unsloth-zoo" fi elif [ "$STUDIO_LOCAL_INSTALL" = true ]; then run_install_cmd "install unsloth (local)" uv pip install --python "$_VENV_PY" \ - --upgrade-package unsloth "unsloth>=2026.4.8" unsloth-zoo + --upgrade-package unsloth "unsloth>=2026.4.4" unsloth-zoo substep "overlaying local repo (editable)..." run_install_cmd "overlay local repo" uv pip install --python "$_VENV_PY" -e "$_REPO_ROOT" --no-deps - substep "overlaying unsloth-zoo from git main..." - run_install_cmd "overlay unsloth-zoo (git main)" uv pip install --python "$_VENV_PY" \ - --no-deps --reinstall-package unsloth-zoo \ - "unsloth-zoo @ git+https://github.com/unslothai/unsloth-zoo" else run_install_cmd "install unsloth" uv pip install --python "$_VENV_PY" \ - --upgrade-package unsloth -- "$PACKAGE_NAME" - fi - # AMD ROCm: repair torch if the unsloth/unsloth-zoo install pulled in - # CUDA torch from PyPI, overwriting the ROCm wheels installed in Step 1. - if [ "$SKIP_TORCH" = false ]; then - case "$TORCH_INDEX_URL" in - */rocm*) - _has_hip=$("$_VENV_PY" -c "import torch; print(getattr(torch.version,'hip','') or '')" 2>/dev/null || true) - if [ -z "$_has_hip" ]; then - substep "repairing ROCm torch (overwritten by dependency resolution)..." - run_install_cmd "repair ROCm torch" uv pip install --python "$_VENV_PY" \ - "$TORCH_CONSTRAINT" torchvision torchaudio \ - --index-url "$TORCH_INDEX_URL" \ - --force-reinstall - fi - ;; - esac + --upgrade-package unsloth "$PACKAGE_NAME" fi else # Fallback: GPU detection failed to produce a URL -- let uv resolve torch - tauri_log "STEP" "Installing Unsloth" substep "installing unsloth (this may take a few minutes)..." if [ "$STUDIO_LOCAL_INSTALL" = true ]; then - run_install_cmd "install unsloth (auto torch backend)" uv pip install --python "$_VENV_PY" unsloth-zoo "unsloth>=2026.4.8" --torch-backend=auto + run_install_cmd "install unsloth (auto torch backend)" uv pip install --python "$_VENV_PY" unsloth-zoo "unsloth>=2026.4.4" --torch-backend=auto substep "overlaying local repo (editable)..." run_install_cmd "overlay local repo" uv pip install --python "$_VENV_PY" -e "$_REPO_ROOT" --no-deps - substep "overlaying unsloth-zoo from git main..." - run_install_cmd "overlay unsloth-zoo (git main)" uv pip install --python "$_VENV_PY" \ - --no-deps --reinstall-package unsloth-zoo \ - "unsloth-zoo @ git+https://github.com/unslothai/unsloth-zoo" else - run_install_cmd "install unsloth (auto torch backend)" uv pip install --python "$_VENV_PY" --torch-backend=auto -- "$PACKAGE_NAME" + run_install_cmd "install unsloth (auto torch backend)" uv pip install --python "$_VENV_PY" "$PACKAGE_NAME" --torch-backend=auto fi fi # ── Run studio setup ── -tauri_log "STEP" "Running Studio setup" # When --local, use the repo's own setup.sh directly. # Otherwise, find it inside the installed package. SETUP_SH="" @@ -1743,7 +1121,6 @@ if [ -z "$SETUP_SH" ] || [ ! -f "$SETUP_SH" ]; then fi if [ -z "$SETUP_SH" ] || [ ! -f "$SETUP_SH" ]; then - tauri_log "ERROR" "Could not find studio/setup.sh in the installed package" echo "❌ ERROR: Could not find studio/setup.sh in the installed package." exit 1 fi @@ -1761,32 +1138,24 @@ if ! command -v bash >/dev/null 2>&1; then fi step "setup" "running unsloth studio update..." +# install.sh already installs base packages (unsloth + unsloth-zoo) and +# no-torch-runtime.txt above, so tell install_python_stack.py to skip +# the base step to avoid redundant reinstallation. _SKIP_BASE=1 +# Run setup.sh outside set -e so that a llama.cpp build failure (exit 1) +# does not skip PATH setup, shortcuts, and launch below. We capture the +# exit code and propagate it after post-install steps finish. _SETUP_EXIT=0 -# Tauri desktop app bundles its own frontend — skip Node/npm/frontend build -_SKIP_FRONTEND=0 -if [ "$TAURI_MODE" = true ]; then - _SKIP_FRONTEND=1 -fi if [ "$STUDIO_LOCAL_INSTALL" = true ]; then SKIP_STUDIO_BASE="$_SKIP_BASE" \ - SKIP_STUDIO_FRONTEND="$_SKIP_FRONTEND" \ STUDIO_PACKAGE_NAME="$PACKAGE_NAME" \ STUDIO_LOCAL_INSTALL=1 \ STUDIO_LOCAL_REPO="$_REPO_ROOT" \ UNSLOTH_NO_TORCH="$SKIP_TORCH" \ bash "$SETUP_SH" /dev/null substep "done" step "install" "installing transformers>=5.5.0..." @@ -139,11 +145,40 @@ else fi fi -# ── Verify installation ────────────────────────────────────── -if "$_VENV_PY" -c "import mlx_vlm"; then - substep "mlx-vlm verified" +# ── Find mlx-lm models directory ───────────────────────────── +MLX_MODELS=$("$_VENV_PY" -c "import mlx_lm; print(mlx_lm.__path__[0])")/models +step "models dir" "$MLX_MODELS" + +# ── Download and install Gemma 4 model files ────────────────── + +step "download" "installing Gemma 4 model files..." + +_install_model_file() { + _fname="$1" + if curl -fsSL "${REPO_URL}/unsloth/models/${_fname}" -o "${MLX_MODELS}/${_fname}" 2>/dev/null; then + substep "downloaded ${_fname} from branch ${BRANCH}" + elif [ -f "./${_fname}" ]; then + substep "using local ./${_fname}" + cp "./${_fname}" "${MLX_MODELS}/${_fname}" + else + fail "Could not install ${_fname}. Tried: + 1) ${REPO_URL}/unsloth/models/${_fname} + 2) Local file ./${_fname} + + To fix, download the file manually and place it in the current directory, + then re-run this script." + fi +} + +_install_model_file "gemma4.py" +_install_model_file "gemma4_text.py" + +# Verify files were installed correctly +if "$_VENV_PY" -c "from mlx_lm.models.gemma4_text import ProportionalRoPE" 2>/dev/null; then + substep "model files verified" else - fail "Installation verification failed." + fail "Model files installed but verification failed (ProportionalRoPE import error). + Try manually from: https://github.com/unslothai/unsloth/tree/feature/${BRANCH}" fi # ── Done ────────────────────────────────────────────────────── @@ -151,19 +186,18 @@ echo "" printf " ${C_TITLE}%s${C_RST}\n" "Gemma 4 MLX installed!" printf " ${C_DIM}%s${C_RST}\n" "$RULE" echo "" -step "available models" "unsloth/gemma-4-E2B-it-UD-MLX-4bit" -substep "unsloth/gemma-4-E4B-it-UD-MLX-4bit" -substep "unsloth/gemma-4-26b-a4b-it-UD-MLX-4bit" -substep "unsloth/gemma-4-31b-it-UD-MLX-4bit" +step "available models" "unsloth/gemma-4-E2B-it-UD-MLX-4bit (/BF16)" +substep "unsloth/gemma-4-E4B-it-UD-MLX-4bit (/BF16)" echo "" step "venv activate" "source ${VENV_DIR}/bin/activate" echo "" -step "text chat" "python -m mlx_vlm.chat --model unsloth/gemma-4-E2B-it-UD-MLX-4bit" -echo "" -step "vision chat" "python -m mlx_vlm.chat --model unsloth/gemma-4-31b-it-UD-MLX-4bit" -substep "Use /image path/to/image.jpg to load an image" +step "quick start" "python -m mlx_lm chat --model unsloth/gemma-4-E2B-it-UD-MLX-4bit --max-tokens 200" echo "" -step "gradio UI" "python -m mlx_vlm.chat_ui --model unsloth/gemma-4-31b-it-UD-MLX-4bit" +step "python API" "from mlx_lm import load, generate" +substep "model, tokenizer = load('unsloth/gemma-4-E2B-it-UD-MLX-4bit')" +substep "messages = [{'role': 'user', 'content': 'Hello!'}]" +substep "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)" +substep "print(generate(model, tokenizer, prompt=prompt, max_tokens=200))" echo "" printf " ${C_DIM}%s${C_RST}\n" "$RULE" echo "" diff --git a/pyproject.toml b/pyproject.toml index c2f884e192..50bdf58b95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,6 @@ studio = [ "frontend/*.yaml", "frontend/.git*", "backend/requirements/**/*", - "backend/plugins/**/*", "backend/core/data_recipe/oxc-validator/*.json", "backend/core/data_recipe/oxc-validator/*.mjs", ] @@ -89,7 +88,7 @@ huggingfacenotorch = [ ] huggingface = [ "unsloth[huggingfacenotorch]", - "unsloth_zoo>=2026.5.1", + "unsloth_zoo>=2026.4.3", "torchvision", "unsloth[triton]", ] @@ -579,7 +578,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3 ; ('linux' in sys_platform)", ] colab-new = [ - "unsloth_zoo>=2026.5.1", + "unsloth_zoo>=2026.4.3", "packaging", "tyro", "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.5.0", diff --git a/scripts/install_qwen3_6_mlx.sh b/scripts/install_qwen3_6_mlx.sh deleted file mode 100644 index 5ce66d29a6..0000000000 --- a/scripts/install_qwen3_6_mlx.sh +++ /dev/null @@ -1,191 +0,0 @@ -#!/bin/bash -set -e - -# ============================================================ -# Qwen3.6 MLX — One-command setup + inference -# -# Usage: -# bash install_qwen3_6_mlx.sh [--venv-dir DIR] -# -# This script: -# 1. Creates a Python virtual environment -# 2. Installs uv, mlx-vlm, transformers, torch, torchvision -# ============================================================ - -# ── Output style (inspired by unsloth/install.sh) ───────────── -RULE="" -_rule_i=0 -while [ "$_rule_i" -lt 52 ]; do - RULE="${RULE}─" - _rule_i=$((_rule_i + 1)) -done - -if [ -n "${NO_COLOR:-}" ]; then - C_TITLE= C_DIM= C_OK= C_WARN= C_ERR= C_RST= -elif [ -t 1 ] || [ -n "${FORCE_COLOR:-}" ]; then - _ESC="$(printf '\033')" - C_TITLE="${_ESC}[38;5;117m" - C_DIM="${_ESC}[38;5;245m" - C_OK="${_ESC}[38;5;108m" - C_WARN="${_ESC}[38;5;136m" - C_ERR="${_ESC}[91m" - C_RST="${_ESC}[0m" -else - C_TITLE= C_DIM= C_OK= C_WARN= C_ERR= C_RST= -fi - -step() { printf " ${C_DIM}%-18.18s${C_RST}${3:-$C_OK}%s${C_RST}\n" "$1" "$2"; } -substep() { printf " ${C_DIM}%-18s${2:-$C_DIM}%s${C_RST}\n" "" "$1"; } -fail() { step "error" "$1" "$C_ERR"; exit 1; } - -# ── Parse flags ─────────────────────────────────────────────── -VENV_DIR="" -_next_is_venv=false - -for arg in "$@"; do - if [ "$_next_is_venv" = true ]; then - VENV_DIR="$arg" - _next_is_venv=false - continue - fi - case "$arg" in - --venv-dir) _next_is_venv=true ;; - esac -done - -# Default venv location -if [ -z "$VENV_DIR" ]; then - VENV_DIR="$HOME/.unsloth/unsloth_qwen3_6_mlx" -fi - -# ── Banner ──────────────────────────────────────────────────── -echo "" -printf " ${C_TITLE}%s${C_RST}\n" "Qwen3.6 MLX Installer" -printf " ${C_DIM}%s${C_RST}\n" "$RULE" -echo "" - -# ── Platform check ──────────────────────────────────────────── -if [ "$(uname)" != "Darwin" ]; then - fail "MLX requires macOS with Apple Silicon. Detected: $(uname)" -fi - -_ARCH=$(uname -m) -if [ "$_ARCH" != "arm64" ]; then - step "warning" "Apple Silicon recommended (detected: $_ARCH)" "$C_WARN" -fi - -step "platform" "macOS ($_ARCH)" - -# ── Detect Python ───────────────────────────────────────────── -PYTHON="" -for _candidate in python3.12 python3.11 python3.13 python3; do - if command -v "$_candidate" >/dev/null 2>&1; then - PYTHON="$_candidate" - break - fi -done - -if [ -z "$PYTHON" ]; then - fail "Python 3 not found. Install via: brew install python@3.12" -fi - -_PY_VERSION=$("$PYTHON" -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") -step "python" "$PYTHON ($_PY_VERSION)" - -# ── Create virtual environment ──────────────────────────────── -if [ -x "$VENV_DIR/bin/python" ]; then - step "venv" "using existing environment" - substep "$VENV_DIR" -else - step "venv" "creating virtual environment" - substep "$VENV_DIR" - mkdir -p "$(dirname "$VENV_DIR")" - "$PYTHON" -m venv "$VENV_DIR" -fi - -# ── Install uv ─────────────────────────────────────────────── -if ! command -v uv >/dev/null 2>&1; then - step "uv" "installing uv package manager..." - _uv_tmp=$(mktemp) - curl -LsSf "https://astral.sh/uv/install.sh" -o "$_uv_tmp" - sh "$_uv_tmp" /dev/null || echo 'uv')" -fi - -_VENV_PY="$VENV_DIR/bin/python" - -# ── Install dependencies ────────────────────────────────────── -step "install" "installing mlx-vlm..." -uv pip install --python "$_VENV_PY" -q mlx-vlm -substep "done" - -step "install" "installing transformers>=5.2.0..." -if uv pip install --python "$_VENV_PY" -q "transformers>=5.2.0"; then - substep "installed from PyPI" -else - substep "PyPI install failed, trying GitHub..." - if uv pip install --python "$_VENV_PY" -q "git+https://github.com/huggingface/transformers.git"; then - substep "installed from huggingface/transformers main" - else - fail "Could not install transformers>=5.2.0 (required for Qwen3.5/3.6 model support). Please check your Python version (>=3.10 required) and network connection, then try again." - fi -fi - -step "install" "installing torch + torchvision (needed for Qwen3 VL processor)..." -uv pip install --python "$_VENV_PY" -q torch torchvision -substep "done" - -# ── Verify installation ────────────────────────────────────── -if "$_VENV_PY" -c "import mlx_vlm; import torch; import torchvision; import transformers"; then - substep "mlx-vlm + torch + transformers verified" -else - fail "Installation verification failed. Please ensure Python >=3.10 and try again." -fi - -# ── Apply patches for multi-turn image chat ────────────────── -_PATCH_BASE="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/fix/ui-fix/unsloth/models/patches/mlx_vlm_qwen3_5" -_SITE_PKGS=$("$_VENV_PY" -c "import site; print(site.getsitepackages()[0])") - -step "patch" "fixing multi-turn image chat..." - -if curl -sSLf "${_PATCH_BASE}/qwen3_5.py" -o "${_SITE_PKGS}/mlx_vlm/models/qwen3_5/qwen3_5.py"; then - substep "patched qwen3_5.py (MRoPE position reset)" -else - step "warning" "failed to download qwen3_5.py patch — multi-turn image chat may not work" "$C_WARN" -fi - -if curl -sSLf "${_PATCH_BASE}/generate.py" -o "${_SITE_PKGS}/mlx_vlm/generate.py"; then - substep "patched generate.py (mask trim on cache reuse)" -else - step "warning" "failed to download generate.py patch — multi-turn image chat may not work" "$C_WARN" -fi - -# Clear pycache so patches take effect -find "${_SITE_PKGS}/mlx_vlm" -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null || true -substep "cleared bytecode cache" - -# ── Done ────────────────────────────────────────────────────── -echo "" -printf " ${C_TITLE}%s${C_RST}\n" "Qwen3.6 MLX installed!" -printf " ${C_DIM}%s${C_RST}\n" "$RULE" -echo "" -step "available models" "unsloth/Qwen3.6-35B-A3B-UD-MLX-3bit" -substep "unsloth/Qwen3.6-35B-A3B-UD-MLX-4bit" -substep "unsloth/Qwen3.6-35B-A3B-MLX-8bit" -echo "" -step "venv activate" "source ${VENV_DIR}/bin/activate" -echo "" -step "vision chat" "python -m mlx_vlm.chat --model unsloth/Qwen3.6-35B-A3B-UD-MLX-4bit" -substep "Use /image path/to/image.jpg to load an image" -echo "" -step "gradio UI" "python -m mlx_vlm.chat_ui --model unsloth/Qwen3.6-35B-A3B-UD-MLX-4bit" -echo "" -printf " ${C_DIM}%s${C_RST}\n" "$RULE" -echo "" diff --git a/studio/Unsloth_Studio_Colab.ipynb b/studio/Unsloth_Studio_Colab.ipynb index c3aec04820..46e2067ba7 100644 --- a/studio/Unsloth_Studio_Colab.ipynb +++ b/studio/Unsloth_Studio_Colab.ipynb @@ -64,7 +64,7 @@ "id": "27e68f91" }, "outputs": [], - "source": "!git clone --depth 1 --branch main https://github.com/unslothai/unsloth.git\n%cd /content/unsloth\n!chmod +x studio/setup.sh && ./studio/setup.sh --local" + "source": "!git clone --depth 1 --branch main https://github.com/unslothai/unsloth.git\n%cd /content/unsloth\n!chmod +x studio/setup.sh && ./studio/setup.sh" }, { "cell_type": "markdown", diff --git a/studio/backend/assets/configs/inference_defaults.json b/studio/backend/assets/configs/inference_defaults.json index 1b10b557e4..1b4b5381e8 100644 --- a/studio/backend/assets/configs/inference_defaults.json +++ b/studio/backend/assets/configs/inference_defaults.json @@ -1,14 +1,6 @@ { "_comment": "Per-model-family inference parameter defaults. Sources: (1) Ollama params blobs, (2) Existing Unsloth Studio YAML configs. Patterns ordered longest-match-first.", "families": { - "qwen3.6": { - "temperature": 0.7, - "top_p": 0.8, - "top_k": 20, - "min_p": 0.0, - "repetition_penalty": 1.0, - "presence_penalty": 1.5 - }, "qwen3.5": { "temperature": 0.7, "top_p": 0.8, @@ -377,7 +369,7 @@ } }, "patterns": [ - "qwen3.6", "qwen3.5", + "qwen3.5", "qwen3-coder", "qwen3-next", "qwen3-vl", "qwen3", "qwen2.5-coder", "qwen2.5-vl", "qwen2.5-omni", "qwen2.5-math", "qwen2.5", "qwen2-vl", "qwen2", diff --git a/studio/backend/auth/authentication.py b/studio/backend/auth/authentication.py index 6ddcbc8e0b..b39f915764 100644 --- a/studio/backend/auth/authentication.py +++ b/studio/backend/auth/authentication.py @@ -10,12 +10,10 @@ import jwt from .storage import ( - API_KEY_PREFIX, get_jwt_secret, get_user_and_secret, load_jwt_secret, save_refresh_token, - validate_api_key, verify_refresh_token, ) @@ -52,8 +50,6 @@ def _decode_subject_without_verification(token: str) -> Optional[str]: def create_access_token( subject: str, expires_delta: Optional[timedelta] = None, - *, - desktop: bool = False, ) -> str: """ Create a signed JWT for the given subject (e.g. username). @@ -61,8 +57,6 @@ def create_access_token( Tokens are valid across restarts because the signing secret is stored in SQLite. """ to_encode = {"sub": subject} - if desktop: - to_encode["desktop"] = True expire = datetime.now(timezone.utc) + ( expires_delta or timedelta(minutes = ACCESS_TOKEN_EXPIRE_MINUTES) ) @@ -74,29 +68,7 @@ def create_access_token( ) -def is_desktop_access_token(token: str) -> bool: - """Return true only for a valid desktop-issued JWT access token.""" - if token.startswith(API_KEY_PREFIX): - return False - - subject = _decode_subject_without_verification(token) - if subject is None: - return False - - record = get_user_and_secret(subject) - if record is None: - return False - - _salt, _pwd_hash, jwt_secret, _must_change_password = record - try: - payload = jwt.decode(token, jwt_secret, algorithms = [ALGORITHM]) - except jwt.InvalidTokenError: - return False - - return payload.get("sub") == subject and payload.get("desktop") is True - - -def create_refresh_token(subject: str, *, desktop: bool = False) -> str: +def create_refresh_token(subject: str) -> str: """ Create a random refresh token, store its hash in SQLite, and return it. @@ -104,28 +76,21 @@ def create_refresh_token(subject: str, *, desktop: bool = False) -> str: """ token = secrets.token_urlsafe(48) expires_at = datetime.now(timezone.utc) + timedelta(days = REFRESH_TOKEN_EXPIRE_DAYS) - save_refresh_token(token, subject, expires_at.isoformat(), is_desktop = desktop) + save_refresh_token(token, subject, expires_at.isoformat()) return token -def refresh_access_token( - refresh_token: str, -) -> Tuple[Optional[str], Optional[str], bool]: +def refresh_access_token(refresh_token: str) -> Tuple[Optional[str], Optional[str]]: """ Validate a refresh token and issue a new access token. The refresh token itself is NOT consumed — it stays valid until expiry. Returns a new access_token or None if the refresh token is invalid/expired. """ - verified = verify_refresh_token(refresh_token) - if verified is None: - return None, None, False - username, is_desktop = verified - return ( - create_access_token(subject = username, desktop = is_desktop), - username, - is_desktop, - ) + username = verify_refresh_token(refresh_token) + if username is None: + return None, None + return create_access_token(subject = username), username def reload_secret() -> None: @@ -172,18 +137,6 @@ async def secure_endpoint(current_subject: str = Depends(get_current_subject)): ... """ token = credentials.credentials - - # --- API key path (sk-unsloth-...) --- - if token.startswith(API_KEY_PREFIX): - username = validate_api_key(token) - if username is None: - raise HTTPException( - status_code = status.HTTP_401_UNAUTHORIZED, - detail = "Invalid or expired API key", - ) - return username - - # --- JWT path --- subject = _decode_subject_without_verification(token) if subject is None: raise HTTPException( @@ -206,8 +159,7 @@ async def secure_endpoint(current_subject: str = Depends(get_current_subject)): status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid token payload", ) - is_desktop = payload.get("desktop") is True - if must_change_password and not allow_password_change and not is_desktop: + if must_change_password and not allow_password_change: raise HTTPException( status_code = status.HTTP_403_FORBIDDEN, detail = "Password change required", diff --git a/studio/backend/auth/storage.py b/studio/backend/auth/storage.py index 9a03f5f542..1395574cce 100644 --- a/studio/backend/auth/storage.py +++ b/studio/backend/auth/storage.py @@ -6,7 +6,6 @@ """ import hashlib -import os import secrets import sqlite3 from datetime import datetime, timezone @@ -55,10 +54,6 @@ def generate_bootstrap_password() -> str: # before the user changes the password. ensure_dir(_BOOTSTRAP_PW_PATH.parent) _BOOTSTRAP_PW_PATH.write_text(_bootstrap_password) - try: - os.chmod(_BOOTSTRAP_PW_PATH, 0o600) - except OSError: - pass return _bootstrap_password @@ -68,17 +63,6 @@ def get_bootstrap_password() -> Optional[str]: return _bootstrap_password -def _load_bootstrap_password() -> Optional[str]: - """Load an existing bootstrap password without creating one.""" - global _bootstrap_password - _bootstrap_password = None - if _BOOTSTRAP_PW_PATH.is_file(): - bootstrap_password = _BOOTSTRAP_PW_PATH.read_text().strip() - if bootstrap_password: - _bootstrap_password = bootstrap_password - return _bootstrap_password - - def clear_bootstrap_password() -> None: """Delete the persisted bootstrap password file (called after password change).""" global _bootstrap_password @@ -88,22 +72,7 @@ def clear_bootstrap_password() -> None: def _hash_token(token: str) -> str: - """SHA-256 hash helper used for refresh token storage. - - Plain SHA-256 is intentional here: refresh tokens are high-entropy - random strings from ``secrets.token_urlsafe(48)`` (384 bits of - entropy), so a slow KDF (Argon2 / bcrypt / PBKDF2) provides zero - additional security — no attacker can brute-force 2^384 regardless - of hash speed — while adding tens of ms of CPU to every refresh. - See the OWASP Password Storage Cheat Sheet on fast-vs-slow hashing - of high-entropy inputs. - - API keys use the separate ``_pbkdf2_api_key`` helper below, which - runs PBKDF2-HMAC-SHA256 with a persistent server-side salt — not - for cryptographic reasons (128-bit random tokens don't need slow - hashing), but because CodeQL's ``py/weak-sensitive-data-hashing`` - query mislabels API keys as passwords and demands a KDF. - """ + """SHA-256 hash helper used for refresh token storage.""" return hashlib.sha256(token.encode("utf-8")).hexdigest() @@ -130,39 +99,7 @@ def get_connection() -> sqlite3.Connection: id INTEGER PRIMARY KEY, token_hash TEXT NOT NULL, username TEXT NOT NULL, - expires_at TEXT NOT NULL, - is_desktop INTEGER NOT NULL DEFAULT 0 - ); - """ - ) - conn.execute( - """ - CREATE TABLE IF NOT EXISTS api_keys ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL, - key_prefix TEXT NOT NULL, - key_hash TEXT NOT NULL UNIQUE, - name TEXT NOT NULL DEFAULT '', - created_at TEXT NOT NULL, - last_used_at TEXT, - expires_at TEXT, - is_active INTEGER NOT NULL DEFAULT 1, - is_internal INTEGER NOT NULL DEFAULT 0 - ); - """ - ) - api_key_columns = { - row["name"] for row in conn.execute("PRAGMA table_info(api_keys)") - } - if "is_internal" not in api_key_columns: - conn.execute( - "ALTER TABLE api_keys ADD COLUMN is_internal INTEGER NOT NULL DEFAULT 0" - ) - conn.execute( - """ - CREATE TABLE IF NOT EXISTS app_secrets ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL + expires_at TEXT NOT NULL ); """ ) @@ -171,107 +108,10 @@ def get_connection() -> sqlite3.Connection: conn.execute( "ALTER TABLE auth_user ADD COLUMN must_change_password INTEGER NOT NULL DEFAULT 0" ) - refresh_columns = { - row["name"] for row in conn.execute("PRAGMA table_info(refresh_tokens)") - } - if "is_desktop" not in refresh_columns: - conn.execute( - "ALTER TABLE refresh_tokens ADD COLUMN is_desktop INTEGER NOT NULL DEFAULT 0" - ) conn.commit() return conn -# ── API-key PBKDF2 salt ──────────────────────────────────────────────── -# -# Module-level cache for the persistent API-key PBKDF2 salt. Populated -# lazily on first use via ``_get_or_create_api_key_pbkdf2_salt``. Not -# protected by a lock because (a) the ``INSERT OR IGNORE`` provides -# atomicity at the SQLite layer and (b) concurrent populations converge -# on the same value, so the worst case is a harmless duplicate read on -# startup. -_api_key_pbkdf2_salt_cache: Optional[bytes] = None - - -def _get_or_create_api_key_pbkdf2_salt() -> bytes: - """Return the persistent API-key PBKDF2 salt, generating it once if missing. - - Stored as a hex-encoded 32-byte random value in the ``app_secrets`` - table under key ``"api_key_pbkdf2_salt"``. Regenerated only if the row - is missing (i.e. fresh install, or operator manually deleted the row - and accepts invalidating existing API keys). - """ - global _api_key_pbkdf2_salt_cache - if _api_key_pbkdf2_salt_cache is not None: - return _api_key_pbkdf2_salt_cache - - conn = get_connection() - try: - cur = conn.execute( - "SELECT value FROM app_secrets WHERE key = ?", - ("api_key_pbkdf2_salt",), - ) - row = cur.fetchone() - if row is None: - new_value = secrets.token_hex(32) # 32 bytes -> 64 hex chars - conn.execute( - "INSERT OR IGNORE INTO app_secrets (key, value) VALUES (?, ?)", - ("api_key_pbkdf2_salt", new_value), - ) - conn.commit() - cur = conn.execute( - "SELECT value FROM app_secrets WHERE key = ?", - ("api_key_pbkdf2_salt",), - ) - row = cur.fetchone() - salt = bytes.fromhex(row["value"]) - finally: - conn.close() - - _api_key_pbkdf2_salt_cache = salt - return salt - - -_API_KEY_PBKDF2_ITERATIONS = 100_000 -DESKTOP_SECRET_PREFIX = "desktop-" -_DESKTOP_SECRET_HASH_KEY = "desktop_secret_hash" -_DESKTOP_SECRET_CREATED_AT_KEY = "desktop_secret_created_at" - - -def _pbkdf2_api_key(raw_key: str) -> str: - """PBKDF2-HMAC-SHA256 an API key with a persistent server-side salt. - - Used for API-key storage ONLY, not refresh tokens. Matches the - PBKDF2 algorithm + iteration count used by the password hasher in - ``auth/hashing.py`` so the codebase is consistent on which KDF it - uses for credential storage. - - Notes on why a slow KDF here is *only* a CodeQL appeasement and - *not* a cryptographic requirement: API keys are cryptographically - random 128-bit tokens (via ``secrets.token_hex``), so brute force - against 2^128 is infeasible regardless of hash speed. CodeQL's - ``py/weak-sensitive-data-hashing`` query mislabels these tokens as - "password" sensitive data and then demands a KDF from its - allowlist (Argon2 / scrypt / bcrypt / PBKDF2). Per the query's - own recommendation page we use PBKDF2. The persistent salt is - still loaded from ``app_secrets`` so an attacker dumping the - ``api_keys`` table alone cannot derive hashes for candidate - tokens without also obtaining the salt row. - """ - salt = _get_or_create_api_key_pbkdf2_salt() - dk = hashlib.pbkdf2_hmac( - "sha256", - raw_key.encode("utf-8"), - salt, - _API_KEY_PBKDF2_ITERATIONS, - ) - return dk.hex() - - -def _pbkdf2_desktop_secret(raw_secret: str) -> str: - return _pbkdf2_api_key(raw_secret) - - def is_initialized() -> bool: """Check if auth is ready for login (at least one user exists in DB).""" conn = get_connection() @@ -413,10 +253,6 @@ def ensure_default_admin() -> bool: Uses a randomly generated diceware passphrase as the bootstrap password. Returns True when the default admin was created in this call. """ - if get_user_and_secret(DEFAULT_ADMIN_USERNAME) is not None: - _load_bootstrap_password() - return False - bootstrap_pw = generate_bootstrap_password() try: create_initial_user( @@ -449,19 +285,12 @@ def update_password(username: str, new_password: str) -> bool: conn.commit() if cursor.rowcount > 0: clear_bootstrap_password() - clear_desktop_secret() return cursor.rowcount > 0 finally: conn.close() -def save_refresh_token( - token: str, - username: str, - expires_at: str, - *, - is_desktop: bool = False, -) -> None: +def save_refresh_token(token: str, username: str, expires_at: str) -> None: """ Store a hashed refresh token with its associated username and expiry. """ @@ -470,21 +299,21 @@ def save_refresh_token( try: conn.execute( """ - INSERT INTO refresh_tokens (token_hash, username, expires_at, is_desktop) - VALUES (?, ?, ?, ?) + INSERT INTO refresh_tokens (token_hash, username, expires_at) + VALUES (?, ?, ?) """, - (token_hash, username, expires_at, int(is_desktop)), + (token_hash, username, expires_at), ) conn.commit() finally: conn.close() -def verify_refresh_token(token: str) -> Optional[Tuple[str, bool]]: +def verify_refresh_token(token: str) -> Optional[str]: """ - Verify a refresh token and return the username plus desktop marker. + Verify a refresh token and return the username. - Returns the username and desktop marker if valid and not expired, None otherwise. + Returns the username if valid and not expired, None otherwise. The token is NOT consumed — it stays valid until it expires. """ token_hash = _hash_token(token) @@ -499,7 +328,7 @@ def verify_refresh_token(token: str) -> Optional[Tuple[str, bool]]: cur = conn.execute( """ - SELECT id, username, expires_at, is_desktop FROM refresh_tokens + SELECT id, username, expires_at FROM refresh_tokens WHERE token_hash = ? """, (token_hash,), @@ -515,7 +344,7 @@ def verify_refresh_token(token: str) -> Optional[Tuple[str, bool]]: conn.commit() return None - return row["username"], bool(row["is_desktop"]) + return row["username"] finally: conn.close() @@ -528,208 +357,3 @@ def revoke_user_refresh_tokens(username: str) -> None: conn.commit() finally: conn.close() - - -def create_desktop_secret() -> str: - """Create/rotate the local desktop credential and return it once.""" - ensure_default_admin() - raw_secret = DESKTOP_SECRET_PREFIX + secrets.token_urlsafe(48) - secret_hash = _pbkdf2_desktop_secret(raw_secret) - now = datetime.now(timezone.utc).isoformat() - conn = get_connection() - try: - conn.execute( - "INSERT OR REPLACE INTO app_secrets (key, value) VALUES (?, ?)", - (_DESKTOP_SECRET_HASH_KEY, secret_hash), - ) - conn.execute( - "INSERT OR REPLACE INTO app_secrets (key, value) VALUES (?, ?)", - (_DESKTOP_SECRET_CREATED_AT_KEY, now), - ) - conn.commit() - return raw_secret - finally: - conn.close() - - -def validate_desktop_secret(raw_secret: str) -> Optional[str]: - """Return the real admin username when the desktop secret matches.""" - if not raw_secret.startswith(DESKTOP_SECRET_PREFIX): - return None - if get_user_and_secret(DEFAULT_ADMIN_USERNAME) is None: - return None - - secret_hash = _pbkdf2_desktop_secret(raw_secret) - conn = get_connection() - try: - cur = conn.execute( - "SELECT value FROM app_secrets WHERE key = ?", - (_DESKTOP_SECRET_HASH_KEY,), - ) - row = cur.fetchone() - if row is None: - return None - if not secrets.compare_digest(row["value"], secret_hash): - return None - return DEFAULT_ADMIN_USERNAME - finally: - conn.close() - - -def clear_desktop_secret() -> None: - """Remove backend-side desktop auth state.""" - conn = get_connection() - try: - conn.execute( - "DELETE FROM app_secrets WHERE key IN (?, ?)", - (_DESKTOP_SECRET_HASH_KEY, _DESKTOP_SECRET_CREATED_AT_KEY), - ) - conn.commit() - finally: - conn.close() - - -# --------------------------------------------------------------------------- -# API key management -# --------------------------------------------------------------------------- - -API_KEY_PREFIX = "sk-unsloth-" - - -def create_api_key( - username: str, - name: str, - expires_at: Optional[str] = None, - internal: bool = False, -) -> Tuple[str, dict]: - """Create a new API key for *username*. - - Returns ``(raw_key, row_dict)`` where *raw_key* is shown to the user - exactly once. The database only stores the PBKDF2 hash. - - Pass ``internal=True`` for keys minted by workflows (e.g. data-recipe - runs) that should not appear in user-facing key listings. - """ - raw_key = API_KEY_PREFIX + secrets.token_hex(16) - key_hash = _pbkdf2_api_key(raw_key) - key_prefix = raw_key[len(API_KEY_PREFIX) : len(API_KEY_PREFIX) + 8] - now = datetime.now(timezone.utc).isoformat() - - conn = get_connection() - try: - conn.execute( - """ - INSERT INTO api_keys (username, key_prefix, key_hash, name, created_at, expires_at, is_internal) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - ( - username, - key_prefix, - key_hash, - name, - now, - expires_at, - 1 if internal else 0, - ), - ) - conn.commit() - cur = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)) - row = cur.fetchone() - return raw_key, dict(row) - finally: - conn.close() - - -def list_api_keys(username: str, include_internal: bool = False) -> list: - """Return API keys for *username*. Internal workflow keys are hidden - by default so they do not clutter user-facing UIs.""" - conn = get_connection() - try: - if include_internal: - cur = conn.execute( - """ - SELECT id, username, key_prefix, name, created_at, last_used_at, - expires_at, is_active, is_internal - FROM api_keys - WHERE username = ? - ORDER BY created_at DESC - """, - (username,), - ) - else: - cur = conn.execute( - """ - SELECT id, username, key_prefix, name, created_at, last_used_at, - expires_at, is_active, is_internal - FROM api_keys - WHERE username = ? AND is_internal = 0 - ORDER BY created_at DESC - """, - (username,), - ) - return [dict(row) for row in cur.fetchall()] - finally: - conn.close() - - -def revoke_api_key(username: str, key_id: int) -> bool: - """Soft-delete an API key. Returns True if a matching row was found.""" - conn = get_connection() - try: - cursor = conn.execute( - "UPDATE api_keys SET is_active = 0 WHERE id = ? AND username = ?", - (key_id, username), - ) - conn.commit() - return cursor.rowcount > 0 - finally: - conn.close() - - -def revoke_internal_api_key(key_id: int) -> bool: - """Revoke an internal workflow-minted key without requiring a username. - - Used by the recipe runner to retire its sk-unsloth-* key once the job - terminates, shrinking the window a leaked key could be abused. - """ - conn = get_connection() - try: - cursor = conn.execute( - "UPDATE api_keys SET is_active = 0 WHERE id = ? AND is_internal = 1", - (key_id,), - ) - conn.commit() - return cursor.rowcount > 0 - finally: - conn.close() - - -def validate_api_key(raw_key: str) -> Optional[str]: - """Validate *raw_key* and return the owning username, or ``None``. - - Also updates ``last_used_at`` on success. - """ - key_hash = _pbkdf2_api_key(raw_key) - conn = get_connection() - try: - cur = conn.execute( - "SELECT id, username, is_active, expires_at FROM api_keys WHERE key_hash = ?", - (key_hash,), - ) - row = cur.fetchone() - if row is None: - return None - if not row["is_active"]: - return None - if row["expires_at"] is not None: - expires = datetime.fromisoformat(row["expires_at"]) - if datetime.now(timezone.utc) > expires: - return None - conn.execute( - "UPDATE api_keys SET last_used_at = ? WHERE id = ?", - (datetime.now(timezone.utc).isoformat(), row["id"]), - ) - conn.commit() - return row["username"] - finally: - conn.close() diff --git a/studio/backend/core/__init__.py b/studio/backend/core/__init__.py index d39815c437..d8d95e2f1a 100644 --- a/studio/backend/core/__init__.py +++ b/studio/backend/core/__init__.py @@ -31,7 +31,6 @@ # Config "ModelConfig", "is_vision_model", - "scan_trained_models", "scan_trained_loras", "load_model_defaults", "get_base_model_from_lora", @@ -73,7 +72,6 @@ def __getattr__(name): if name in ( "is_vision_model", "ModelConfig", - "scan_trained_models", "scan_trained_loras", "load_model_defaults", "get_base_model_from_lora", @@ -81,15 +79,14 @@ def __getattr__(name): from utils.models import ( is_vision_model, ModelConfig, - scan_trained_models, + scan_trained_loras, load_model_defaults, get_base_model_from_lora, ) globals()["is_vision_model"] = is_vision_model globals()["ModelConfig"] = ModelConfig - globals()["scan_trained_models"] = scan_trained_models - globals()["scan_trained_loras"] = scan_trained_models + globals()["scan_trained_loras"] = scan_trained_loras globals()["load_model_defaults"] = load_model_defaults globals()["get_base_model_from_lora"] = get_base_model_from_lora return globals()[name] diff --git a/studio/backend/core/data_recipe/jobs/constants.py b/studio/backend/core/data_recipe/jobs/constants.py index 0045276e20..08237326f8 100644 --- a/studio/backend/core/data_recipe/jobs/constants.py +++ b/studio/backend/core/data_recipe/jobs/constants.py @@ -9,7 +9,6 @@ STAGE_DAG = "dag" STAGE_HEALTHCHECK = "healthcheck" STAGE_SAMPLING = "sampling" -STAGE_SOURCE = "source" STAGE_COLUMN_CONFIG = "column_config" STAGE_GENERATING = "generating" STAGE_BATCH = "batch" diff --git a/studio/backend/core/data_recipe/jobs/manager.py b/studio/backend/core/data_recipe/jobs/manager.py index cdc28d9560..3d7cf2dbe6 100644 --- a/studio/backend/core/data_recipe/jobs/manager.py +++ b/studio/backend/core/data_recipe/jobs/manager.py @@ -33,60 +33,6 @@ _CTX = mp.get_context("spawn") -def _github_source_estimated_total(recipe: dict) -> int | None: - seed_config = recipe.get("seed_config") - if not isinstance(seed_config, dict): - return None - source = seed_config.get("source") - if not isinstance(source, dict) or source.get("seed_type") != "github_repo": - return None - - repos_raw = source.get("repos") - repos = ( - [repo for repo in repos_raw if isinstance(repo, str) and repo.strip()] - if isinstance(repos_raw, list) - else [] - ) - item_types_raw = source.get("item_types") - item_types = ( - [ - item - for item in item_types_raw - if isinstance(item, str) and item in {"issues", "pulls", "commits"} - ] - if isinstance(item_types_raw, list) - else [] - ) - try: - limit = int(source.get("limit") or 0) - except (TypeError, ValueError): - return None - if not repos or not item_types or limit <= 0: - return None - return len(repos) * len(item_types) * limit - - -def _source_progress_status(job: Job) -> dict[str, Any] | None: - progress = job.source_progress - if progress is None: - return None - return { - "source": progress.source, - "status": progress.status, - "repo": progress.repo, - "resource": progress.resource, - "page": progress.page, - "page_items": progress.page_items, - "fetched_items": progress.fetched_items, - "estimated_total": progress.estimated_total, - "percent": progress.percent, - "rate_remaining": progress.rate_remaining, - "retry_after_sec": progress.retry_after_sec, - "message": progress.message, - "updated_at": progress.updated_at, - } - - @dataclass class Subscription: replay: list[dict] @@ -125,20 +71,8 @@ def __init__(self) -> None: self._pump_thread: threading.Thread | None = None self._seq: int = 0 - def start( - self, - *, - recipe: dict, - run: dict, - internal_api_key_id: int | None = None, - ) -> str: - """Spawn the job subprocess (one at a time, no cap). - - ``internal_api_key_id`` is the row id of a workflow-scoped - sk-unsloth-* key minted by the route layer for local providers. - JobManager revokes it when the job reaches a terminal state so the - key's live window is no longer than the run. - """ + def start(self, *, recipe: dict, run: dict) -> str: + """Spawn the job subprocess (one at a time, no cap).""" llm_columns = recipe.get("columns") or [] llm_column_count = 0 if isinstance(llm_columns, list): @@ -158,29 +92,18 @@ def start( job_id = uuid.uuid4().hex self._job = Job(job_id = job_id, status = "pending", started_at = time.time()) self._job.progress_columns_total = llm_column_count - self._job.source_progress_estimated_total = _github_source_estimated_total( - recipe - ) - self._job.internal_api_key_id = internal_api_key_id self._events.clear() self._seq = 0 run_payload = dict(run) run_payload["_job_id"] = job_id - from utils.native_path_leases import ( - native_path_secret_removed_for_child_start, - run_without_native_path_secret, + mp_q = _CTX.Queue() + proc = _CTX.Process( + target = run_job_process, + kwargs = {"event_queue": mp_q, "recipe": recipe, "run": run_payload}, + daemon = True, ) - - with native_path_secret_removed_for_child_start(): - mp_q = _CTX.Queue() - proc = _CTX.Process( - target = run_without_native_path_secret, - args = (run_job_process,), - kwargs = {"event_queue": mp_q, "recipe": recipe, "run": run_payload}, - daemon = True, - ) - proc.start() + proc.start() self._mp_q = mp_q self._proc = proc @@ -240,7 +163,6 @@ def get_status(self, job_id: str) -> dict | None: "ok": job.column_progress.ok, "failed": job.column_progress.failed, }, - "source_progress": _source_progress_status(job), "model_usage": { name: { "model": usage.model, @@ -483,7 +405,6 @@ def _pump_loop(self) -> None: for e in self._drain_queue(mp_q): self._handle_event(job, e) - retired_job: Job | None = None with self._lock: if self._job and self._job.status in { "pending", @@ -508,9 +429,6 @@ def _pump_loop(self) -> None: "job_id": self._job.job_id, } ) - retired_job = self._job - if retired_job is not None: - self._retire_workflow_key(retired_job) return def _handle_event(self, job: Job, event: dict) -> None: @@ -518,7 +436,6 @@ def _handle_event(self, job: Job, event: dict) -> None: et = event.get("type") msg = event.get("message") if et == "log" else None - terminal = False with self._lock: if self._job is None or self._job.job_id != job.job_id: return @@ -535,43 +452,18 @@ def _handle_event(self, job: Job, event: dict) -> None: if self._job.progress.total and self._job.progress.total > 0: self._job.progress.done = self._job.progress.total self._job.progress.percent = 100.0 - terminal = True if et == EVENT_JOB_ERROR: self._job.status = "error" self._job.finished_at = time.time() self._job.error = event.get("error") or "error" - terminal = True - if et == EVENT_JOB_CANCELLED: - terminal = True if msg: upd = parse_log_message(msg) if upd: apply_update(self._job, upd) - if terminal: - self._retire_workflow_key(job) - self._emit(event) - def _retire_workflow_key(self, job: Job) -> None: - """Revoke the workflow-scoped sk-unsloth-* key, if one was minted. - - Best-effort: revocation failures are swallowed. The key would - expire on its own after 24h, so a missed revoke is a latency - concern, not a correctness one. - """ - key_id = getattr(job, "internal_api_key_id", None) - if not key_id: - return - try: - from auth import storage # deferred: avoids circular import - - storage.revoke_internal_api_key(int(key_id)) - except Exception: - pass - job.internal_api_key_id = None - _JOB_MANAGER: JobManager | None = None diff --git a/studio/backend/core/data_recipe/jobs/parse.py b/studio/backend/core/data_recipe/jobs/parse.py index cea6d8ea64..324b62a92e 100644 --- a/studio/backend/core/data_recipe/jobs/parse.py +++ b/studio/backend/core/data_recipe/jobs/parse.py @@ -4,7 +4,6 @@ from __future__ import annotations import re -import time from dataclasses import dataclass from typing import Any @@ -18,10 +17,9 @@ STAGE_PREVIEW, STAGE_PROFILING, STAGE_SAMPLING, - STAGE_SOURCE, USAGE_RESET_STAGES, ) -from .types import Job, ModelUsage, Progress, SourceProgress +from .types import Job, ModelUsage, Progress @dataclass(frozen = True) @@ -43,7 +41,6 @@ class ParsedUpdate: usage_requests_total: int | None = None usage_rpm: float | None = None usage_section_start: bool | None = None - source_progress: SourceProgress | None = None # kinda of a bummber but currently only option, Best effort parser from data-designer logs -> structured status for UI. @@ -64,165 +61,9 @@ class ParsedUpdate: _RE_USAGE_REQUESTS = re.compile( r"requests:\s*success=(?P\d+),\s*failed=(?P\d+),\s*total=(?P\d+),\s*rpm=(?P[0-9.]+)" ) -_RE_GITHUB_PAGE = re.compile( - r"^\[(?P[^\]\s]+/[^\]\s]+)\]\s+" - r"(?Pissues|PRs|commits)\s+page\s+(?P\d+)\s+" - r"\(\+(?P\d+)\).*?\bremaining=(?P\d+)", - re.IGNORECASE, -) -_RE_GITHUB_RATE_LIMIT = re.compile( - r"Rate limit hit\. Sleeping (?P\d+)s until reset\.", - re.IGNORECASE, -) -_RE_GITHUB_SECONDARY_RATE_LIMIT = re.compile( - r"Secondary rate limit(?: on REST)?\. Sleep (?P\d+)s\.", - re.IGNORECASE, -) -_RE_GITHUB_REST_RATE_LIMIT = re.compile( - r"REST 403/429, sleep (?P\d+)", - re.IGNORECASE, -) -_RE_GITHUB_TRANSIENT = re.compile( - r"^(?PGraphQL|REST) (?P\d{3}) transient, retrying", - re.IGNORECASE, -) -_RE_GITHUB_NETWORK_RETRY = re.compile( - r"^(?PGraphQL|REST) network error: .* Retry\.", - re.IGNORECASE, -) -_RE_GITHUB_TRIAL_LIMIT = re.compile( - r"Trial limit reached for (?Pissues|PRs|commits) \((?P\d+)\)", - re.IGNORECASE, -) -_RE_GITHUB_COMPLETE = re.compile( - r"Scraper complete\. GraphQL calls=\d+ REST calls=\d+", - re.IGNORECASE, -) def parse_log_message(msg: str) -> ParsedUpdate | None: - m = _RE_GITHUB_PAGE.search(msg) - if m: - resource_raw = m.group("resource") - resource = "pulls" if resource_raw.lower() == "prs" else resource_raw.lower() - repo = m.group("repo") - page = int(m.group("page")) - page_items = int(m.group("items")) - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "fetching", - repo = repo, - resource = resource, - page = page, - page_items = page_items, - rate_remaining = int(m.group("remaining")), - message = ( - f"Scraping GitHub source: {repo} " - f"{resource} page {page} (+{page_items})" - ), - ), - ) - - m = _RE_GITHUB_RATE_LIMIT.search(msg) - if m: - seconds = int(m.group("seconds")) - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "rate_limited", - retry_after_sec = seconds, - message = ( - "Waiting for GitHub rate limit. " - "Studio will resume automatically." - ), - ), - ) - - m = _RE_GITHUB_SECONDARY_RATE_LIMIT.search(msg) - if m: - seconds = int(m.group("seconds")) - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "rate_limited", - retry_after_sec = seconds, - message = ( - "Waiting for GitHub secondary rate limit. " - "Studio will resume automatically." - ), - ), - ) - - m = _RE_GITHUB_REST_RATE_LIMIT.search(msg) - if m: - seconds = int(m.group("seconds")) - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "rate_limited", - retry_after_sec = seconds, - message = ( - "Waiting for GitHub rate limit. " - "Studio will resume automatically." - ), - ), - ) - - m = _RE_GITHUB_TRIAL_LIMIT.search(msg) - if m: - resource_raw = m.group("resource") - resource = "pulls" if resource_raw.lower() == "prs" else resource_raw.lower() - items = int(m.group("items")) - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "fetching", - resource = resource, - message = f"GitHub {resource} trial limit reached ({items}).", - ), - ) - - m = _RE_GITHUB_TRANSIENT.search(msg) - if m: - api = m.group("api") - code = m.group("code") - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "retrying", - message = f"GitHub {api} returned {code}; retrying automatically.", - ), - ) - - m = _RE_GITHUB_NETWORK_RETRY.search(msg) - if m: - api = m.group("api") - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "retrying", - message = f"GitHub {api} request failed; retrying automatically.", - ), - ) - - if _RE_GITHUB_COMPLETE.search(msg): - return ParsedUpdate( - stage = STAGE_SOURCE, - source_progress = SourceProgress( - source = "github", - status = "completed", - message = "GitHub source scrape complete.", - ), - ) - m = _RE_SAMPLERS.search(msg) if m: return ParsedUpdate( @@ -331,8 +172,6 @@ def apply_update(job: Job, update: ParsedUpdate) -> None: job.batch.idx = update.batch_idx if update.batch_total is not None: job.batch.total = update.batch_total - if update.source_progress is not None: - _apply_source_progress(job, update.source_progress) if update.stage in USAGE_RESET_STAGES: # usage summary is a short block so we reset once we move into the next stage. @@ -377,67 +216,6 @@ def apply_update(job: Job, update: ParsedUpdate) -> None: usage.rpm = update.usage_rpm -def _apply_source_progress(job: Job, progress: SourceProgress) -> None: - previous = job.source_progress - now = time.time() - - page_items = progress.page_items - if progress.repo and progress.resource and progress.page is not None: - page_key = f"{progress.repo}:{progress.resource}:{progress.page}" - count_key = f"{progress.repo}:{progress.resource}" - if page_key not in job._source_seen_pages: - job._source_seen_pages.add(page_key) - job._source_counts[count_key] = int( - job._source_counts.get(count_key, 0) - ) + int(page_items or 0) - - fetched_items = sum(job._source_counts.values()) - if fetched_items <= 0: - fetched_items = progress.fetched_items or ( - previous.fetched_items if previous else None - ) - - estimated_total = ( - progress.estimated_total - or job.source_progress_estimated_total - or (previous.estimated_total if previous else None) - ) - percent: float | None = progress.percent - if percent is None and estimated_total and fetched_items is not None: - raw_percent = (float(fetched_items) / float(max(1, estimated_total))) * 100.0 - percent = 100.0 if progress.status == "completed" else min(99.0, raw_percent) - if percent is None and previous is not None: - percent = previous.percent - - job.source_progress = SourceProgress( - source = "github", - status = progress.status or (previous.status if previous else None), - repo = progress.repo or (previous.repo if previous else None), - resource = progress.resource or (previous.resource if previous else None), - page = ( - progress.page - if progress.page is not None - else (previous.page if previous else None) - ), - page_items = ( - page_items - if page_items is not None - else (previous.page_items if previous else None) - ), - fetched_items = fetched_items, - estimated_total = estimated_total, - percent = percent, - rate_remaining = ( - progress.rate_remaining - if progress.rate_remaining is not None - else (previous.rate_remaining if previous else None) - ), - retry_after_sec = progress.retry_after_sec, - message = progress.message or (previous.message if previous else None), - updated_at = now, - ) - - def _compute_overall_progress(job: Job, column_progress: Progress) -> Progress: if not job.rows: return column_progress diff --git a/studio/backend/core/data_recipe/jobs/types.py b/studio/backend/core/data_recipe/jobs/types.py index 3d3ddb974e..8d77903238 100644 --- a/studio/backend/core/data_recipe/jobs/types.py +++ b/studio/backend/core/data_recipe/jobs/types.py @@ -35,23 +35,6 @@ class BatchProgress: total: int | None = None -@dataclass -class SourceProgress: - source: str = "github" - status: str | None = None - repo: str | None = None - resource: str | None = None - page: int | None = None - page_items: int | None = None - fetched_items: int | None = None - estimated_total: int | None = None - percent: float | None = None - rate_remaining: int | None = None - retry_after_sec: int | None = None - message: str | None = None - updated_at: float | None = None - - @dataclass class ModelUsage: model: str @@ -74,7 +57,6 @@ class Job: progress: Progress = field(default_factory = Progress) column_progress: Progress = field(default_factory = Progress) batch: BatchProgress = field(default_factory = BatchProgress) - source_progress: SourceProgress | None = None rows: int | None = None cols: int | None = None error: str | None = None @@ -88,15 +70,8 @@ class Job: processor_artifacts: dict[str, Any] | None = None model_usage: dict[str, ModelUsage] = field(default_factory = dict) progress_columns_total: int | None = None - source_progress_estimated_total: int | None = None completed_columns: list[str] = field(default_factory = list) - # Id of the internal sk-unsloth-* API key minted for a local-model - # workflow. Revoked when the job terminates so the key's live window - # matches the run rather than its 24h TTL. - internal_api_key_id: int | None = None _current_usage_model: str | None = None _in_usage_summary: bool = False _seen_generation_columns: list[str] = field(default_factory = list) _column_done: dict[str, int] = field(default_factory = dict) - _source_counts: dict[str, int] = field(default_factory = dict) - _source_seen_pages: set[str] = field(default_factory = set) diff --git a/studio/backend/core/data_recipe/jobs/worker.py b/studio/backend/core/data_recipe/jobs/worker.py index 8c5c7fe657..63e38bd18d 100644 --- a/studio/backend/core/data_recipe/jobs/worker.py +++ b/studio/backend/core/data_recipe/jobs/worker.py @@ -21,15 +21,6 @@ from utils.paths import ensure_dir, recipe_datasets_root _ARTIFACT_ROOT = recipe_datasets_root() -_RE_GITHUB_CURSOR = re.compile(r"\bcursor=[^\s,]+") -_RE_SECRET_TOKEN = re.compile( - r"\b(?:(?:ghp|gho|ghu|ghs|ghr|github_pat)_[A-Za-z0-9_]+|sk-unsloth-[A-Za-z0-9]+)" -) - - -def _sanitize_log_message(message: str) -> str: - message = _RE_GITHUB_CURSOR.sub("cursor=", message) - return _RE_SECRET_TOKEN.sub("", message) class _QueueLogHandler(logging.Handler): @@ -44,7 +35,7 @@ def emit(self, record: logging.LogRecord) -> None: "ts": record.created, "level": record.levelname, "logger": record.name, - "message": _sanitize_log_message(record.getMessage()), + "message": record.getMessage(), } self._q.put(event) except (OSError, RuntimeError, ValueError): @@ -128,16 +119,10 @@ def run_job_process( # Attach queue logger directly to `data_designer` so parser events survive root resets. handler = _QueueLogHandler(event_queue) handler.setLevel(logging.INFO) - for logger_name in ( - "data_designer", - "scraper", - "gh_client", - "data_designer_github_repo_seed", - ): - logger = logging.getLogger(logger_name) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - logger.propagate = True + data_designer_logger = logging.getLogger("data_designer") + data_designer_logger.addHandler(handler) + data_designer_logger.setLevel(logging.INFO) + data_designer_logger.propagate = True if run_config_raw: designer.set_run_config(RunConfig.model_validate(run_config_raw)) @@ -195,8 +180,8 @@ def run_job_process( { "type": EVENT_JOB_ERROR, "ts": time.time(), - "error": _sanitize_log_message(str(exc)), - "stack": _sanitize_log_message(traceback.format_exc(limit = 20)), + "error": str(exc), + "stack": traceback.format_exc(limit = 20), } ) diff --git a/studio/backend/core/data_recipe/local_callable_validators.py b/studio/backend/core/data_recipe/local_callable_validators.py index 44459e88c5..c32b2fccaf 100644 --- a/studio/backend/core/data_recipe/local_callable_validators.py +++ b/studio/backend/core/data_recipe/local_callable_validators.py @@ -33,12 +33,6 @@ _OXC_RUNNER_PATH = _OXC_TOOL_DIR / "validate.mjs" -from utils.native_path_leases import child_env_without_native_path_secret -from utils.subprocess_compat import ( - windows_hidden_subprocess_kwargs as _windows_hidden_subprocess_kwargs, -) - - @dataclass(frozen = True) class OxcLocalCallableValidatorSpec: name: str @@ -249,7 +243,7 @@ def _run_oxc_batch( } try: tmp_dir = ensure_dir(oxc_validator_tmp_root()) - env = child_env_without_native_path_secret() + env = dict(os.environ) tmp_dir_str = str(tmp_dir) env["TMPDIR"] = tmp_dir_str env["TMP"] = tmp_dir_str @@ -262,7 +256,6 @@ def _run_oxc_batch( capture_output = True, check = False, env = env, - **_windows_hidden_subprocess_kwargs(), ) except (OSError, ValueError) as exc: logger.warning("OXC subprocess launch failed: %s", exc) diff --git a/studio/backend/core/export/export.py b/studio/backend/core/export/export.py index 6fee5a38f7..966e045b13 100644 --- a/studio/backend/core/export/export.py +++ b/studio/backend/core/export/export.py @@ -28,8 +28,6 @@ logger = get_logger(__name__) -_LLAMA_CPP_SCRIPTS_WARNING_EMITTED = False - def _is_wsl(): """Detect if running under Windows Subsystem for Linux.""" @@ -312,7 +310,7 @@ def export_merged_model( repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """ Export merged model (for PEFT models). @@ -325,21 +323,14 @@ def export_merged_model( private: Whether to make the repo private Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved model when - ``save_directory`` was set, else None. + Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: - return False, "No model loaded. Please select a checkpoint first.", None + return False, "No model loaded. Please select a checkpoint first." if not self.is_peft: - return ( - False, - "This is not a PEFT model. Use 'Export Base Model' instead.", - None, - ) + return False, "This is not a PEFT model. Use 'Export Base Model' instead." - output_path: Optional[str] = None try: # Determine save method if format_type == "4-bit (FP4)": @@ -363,7 +354,6 @@ def export_merged_model( # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) logger.info(f"Model saved successfully to {save_directory}") - output_path = str(Path(save_directory).resolve()) # Push to hub if requested if push_to_hub: @@ -371,7 +361,6 @@ def export_merged_model( return ( False, "Repository ID and Hugging Face token required for Hub upload", - None, ) logger.info(f"Pushing merged model to Hub: {repo_id}") @@ -389,14 +378,14 @@ def export_merged_model( ) logger.info(f"Model pushed successfully to {repo_id}") - return True, "Model exported successfully", output_path + return True, "Model exported successfully" except Exception as e: logger.error(f"Error exporting merged model: {e}") import traceback logger.error(traceback.format_exc()) - return False, f"Export failed: {str(e)}", None + return False, f"Export failed: {str(e)}" def export_base_model( self, @@ -406,26 +395,22 @@ def export_base_model( hf_token: Optional[str] = None, private: bool = False, base_model_id: Optional[str] = None, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """ Export base model (for non-PEFT models). Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved model when - ``save_directory`` was set, else None. + Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: - return False, "No model loaded. Please select a checkpoint first.", None + return False, "No model loaded. Please select a checkpoint first." if self.is_peft: return ( False, "This is a PEFT model. Use 'Merged Model' export type instead.", - None, ) - output_path: Optional[str] = None try: # Save locally if requested if save_directory: @@ -439,7 +424,6 @@ def export_base_model( # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) logger.info(f"Model saved successfully to {save_directory}") - output_path = str(Path(save_directory).resolve()) # Push to hub if requested if push_to_hub: @@ -447,7 +431,6 @@ def export_base_model( return ( False, "Repository ID and Hugging Face token required for Hub upload", - None, ) logger.info(f"Pushing base model to Hub: {repo_id}") @@ -489,16 +472,16 @@ def export_base_model( ) logger.info(f"Model pushed successfully to {repo_id}") else: - return False, "Local save directory required for Hub upload", None + return False, "Local save directory required for Hub upload" - return True, "Model exported successfully", output_path + return True, "Model exported successfully" except Exception as e: logger.error(f"Error exporting base model: {e}") import traceback logger.error(traceback.format_exc()) - return False, f"Export failed: {str(e)}", None + return False, f"Export failed: {str(e)}" def export_gguf( self, @@ -507,7 +490,7 @@ def export_gguf( push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """ Export model in GGUF format. @@ -519,43 +502,15 @@ def export_gguf( hf_token: Hugging Face token Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory containing the .gguf - files when ``save_directory`` was set, else None. + Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: - return False, "No model loaded. Please select a checkpoint first.", None + return False, "No model loaded. Please select a checkpoint first." - output_path: Optional[str] = None try: # Convert quantization method to lowercase for unsloth quant_method = quantization_method.lower() - # Pin convert_hf_to_gguf.py to the same llama.cpp ref as the - # llama-quantize binary (Studio installs at a tagged ref via - # setup.sh) so it can't drift past the pinned binary's gguf API. - # Set before both branches; hub-only export has save_directory == "". - global _LLAMA_CPP_SCRIPTS_WARNING_EMITTED - try: - from unsloth_zoo.llama_cpp import ( - LLAMA_CPP_DEFAULT_DIR, - _resolve_local_convert_script, # noqa: F401 - ) - - os.environ.setdefault( - "UNSLOTH_LLAMA_CPP_SCRIPTS_DIR", LLAMA_CPP_DEFAULT_DIR - ) - except ImportError: - if not _LLAMA_CPP_SCRIPTS_WARNING_EMITTED: - logger.warning( - "Unsloth: installed unsloth_zoo does not honor " - "UNSLOTH_LLAMA_CPP_SCRIPTS_DIR; convert_hf_to_gguf.py will " - "still be downloaded from llama.cpp master and may drift " - "past the pinned llama-quantize binary. Upgrade unsloth_zoo " - "to activate the local script pin." - ) - _LLAMA_CPP_SCRIPTS_WARNING_EMITTED = True - # Save locally if requested if save_directory: save_directory = str(resolve_export_dir(save_directory)) @@ -646,7 +601,6 @@ def export_gguf( abs_save_dir, "\n ".join(os.path.basename(f) for f in final_ggufs) or "(none)", ) - output_path = str(Path(abs_save_dir).resolve()) # Push to hub if requested if push_to_hub: @@ -654,7 +608,6 @@ def export_gguf( return ( False, "Repository ID and Hugging Face token required for Hub upload", - None, ) logger.info(f"Pushing GGUF model to Hub: {repo_id}") @@ -667,18 +620,14 @@ def export_gguf( ) logger.info(f"GGUF model pushed successfully to {repo_id}") - return ( - True, - f"GGUF model exported successfully ({quantization_method})", - output_path, - ) + return True, f"GGUF model exported successfully ({quantization_method})" except Exception as e: logger.error(f"Error exporting GGUF model: {e}") import traceback logger.error(traceback.format_exc()) - return False, f"GGUF export failed: {str(e)}", None + return False, f"GGUF export failed: {str(e)}" def export_lora_adapter( self, @@ -687,22 +636,19 @@ def export_lora_adapter( repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """ Export LoRA adapter only (not merged). Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved adapter - when ``save_directory`` was set, else None. + Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: - return False, "No model loaded. Please select a checkpoint first.", None + return False, "No model loaded. Please select a checkpoint first." if not self.is_peft: - return False, "This is not a PEFT model. No adapter to export.", None + return False, "This is not a PEFT model. No adapter to export." - output_path: Optional[str] = None try: # Save locally if requested if save_directory: @@ -713,7 +659,6 @@ def export_lora_adapter( self.current_model.save_pretrained(save_directory) self.current_tokenizer.save_pretrained(save_directory) logger.info(f"Adapter saved successfully to {save_directory}") - output_path = str(Path(save_directory).resolve()) # Push to hub if requested if push_to_hub: @@ -721,7 +666,6 @@ def export_lora_adapter( return ( False, "Repository ID and Hugging Face token required for Hub upload", - None, ) logger.info(f"Pushing LoRA adapter to Hub: {repo_id}") @@ -732,14 +676,14 @@ def export_lora_adapter( ) logger.info(f"Adapter pushed successfully to {repo_id}") - return True, "LoRA adapter exported successfully", output_path + return True, "LoRA adapter exported successfully" except Exception as e: logger.error(f"Error exporting LoRA adapter: {e}") import traceback logger.error(traceback.format_exc()) - return False, f"Adapter export failed: {str(e)}", None + return False, f"Adapter export failed: {str(e)}" # Global export backend instance diff --git a/studio/backend/core/export/orchestrator.py b/studio/backend/core/export/orchestrator.py index 82de925592..500bc9e706 100644 --- a/studio/backend/core/export/orchestrator.py +++ b/studio/backend/core/export/orchestrator.py @@ -16,25 +16,19 @@ import atexit import structlog -from collections import deque from loggers import get_logger import multiprocessing as mp import queue import threading import time from pathlib import Path -from typing import Any, Deque, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple from utils.paths import outputs_root logger = get_logger(__name__) _CTX = mp.get_context("spawn") -# Maximum number of captured log lines kept in memory per export -# orchestrator. Acts as scrollback for the live export log panel in the -# UI. 4000 lines is ~1 MB worst-case at 256 chars/line. -_LOG_BUFFER_MAXLEN = 4000 - class ExportOrchestrator: """ @@ -50,9 +44,6 @@ def __init__(self): self._proc: Optional[mp.Process] = None self._cmd_queue: Any = None self._resp_queue: Any = None - # Serializes export operations (load_checkpoint, export_*, - # cleanup) so concurrent HTTP requests can never interleave - # commands on the subprocess queue. Previously unused. self._lock = threading.Lock() # Local state mirrors (updated from subprocess responses) @@ -60,131 +51,30 @@ def __init__(self): self.is_vision: bool = False self.is_peft: bool = False - # ── Live log capture ───────────────────────────────────── - # Thread-safe ring buffer of log lines forwarded from the - # worker subprocess. Powers the GET /api/export/logs/stream - # SSE endpoint that the export dialog consumes. - self._log_buffer: Deque[Dict[str, Any]] = deque(maxlen = _LOG_BUFFER_MAXLEN) - self._log_lock = threading.Lock() - # Monotonically increasing sequence number. Never reset across - # operations, so SSE clients can use it as a stable cursor even - # if clear_logs() is called mid-session. - self._log_seq: int = 0 - # Snapshot of _log_seq captured at the start of the current run - # (updated by clear_logs()). The SSE endpoint defaults its - # cursor to this value so a client that connects AFTER the - # worker has already emitted its first lines still sees the - # full run. Every line appended during the current run has seq - # strictly greater than _run_start_seq, and every line from - # prior runs has seq less than or equal to it. - self._run_start_seq: int = 0 - # True while an export operation (load/export/cleanup) is - # running. The SSE endpoint ends the stream 1 second after - # this flips back to False to drain any trailing log lines. - self._export_active: bool = False - atexit.register(self._cleanup) logger.info("ExportOrchestrator initialized (subprocess mode)") - # ------------------------------------------------------------------ - # Live log capture helpers - # ------------------------------------------------------------------ - - def _append_log(self, entry: Dict[str, Any]) -> None: - """Append a log line from the worker subprocess to the buffer. - - Entries look like {"type": "log", "stream": "stdout"|"stderr", - "line": "...", "ts": ...}. Each is stamped with a monotonic - seq number before it lands in the buffer so SSE clients can - cursor through new lines. - """ - line = entry.get("line") - if not line: - return - with self._log_lock: - self._log_seq += 1 - self._log_buffer.append( - { - "seq": self._log_seq, - "stream": entry.get("stream", "stdout"), - "line": line, - "ts": entry.get("ts", time.time()), - } - ) - - def clear_logs(self) -> None: - """Drop any buffered log lines from a previous operation. - - Called at the start of each export op so the UI shows only the - output of the current run. The seq counter is NOT reset, so an - SSE client that captured the cursor before clear_logs() will - still see new lines (with strictly greater seq numbers). - - Also snapshots the current seq into ``_run_start_seq`` so the - SSE endpoint can anchor its default cursor at the start of - this run. Anything appended after this call has seq strictly - greater than the snapshot and is reachable via - ``get_logs_since(get_run_start_seq())``. - """ - with self._log_lock: - self._log_buffer.clear() - self._run_start_seq = self._log_seq - - def get_logs_since(self, cursor: int) -> Tuple[List[Dict[str, Any]], int]: - """Return log entries with seq > cursor, plus the new cursor.""" - with self._log_lock: - new_entries = [entry for entry in self._log_buffer if entry["seq"] > cursor] - if new_entries: - return new_entries, new_entries[-1]["seq"] - return [], cursor - - def get_current_log_seq(self) -> int: - """Return the current seq counter without reading any entries.""" - with self._log_lock: - return self._log_seq - - def get_run_start_seq(self) -> int: - """Return the seq value captured at the start of the current run. - - The SSE endpoint uses this as the default cursor so a client - that connects AFTER the worker has already started emitting - output still sees every line from the current run. - """ - with self._log_lock: - return self._run_start_seq - - def is_export_active(self) -> bool: - """True while an export / load / cleanup command is running.""" - return self._export_active - # ------------------------------------------------------------------ # Subprocess lifecycle # ------------------------------------------------------------------ def _spawn_subprocess(self, config: dict) -> None: """Spawn a new export subprocess.""" - from utils.native_path_leases import ( - native_path_secret_removed_for_child_start, - run_without_native_path_secret, - ) - from .worker import run_export_process - with native_path_secret_removed_for_child_start(): - self._cmd_queue = _CTX.Queue() - self._resp_queue = _CTX.Queue() - - self._proc = _CTX.Process( - target = run_without_native_path_secret, - args = (run_export_process,), - kwargs = { - "cmd_queue": self._cmd_queue, - "resp_queue": self._resp_queue, - "config": config, - }, - daemon = True, - ) - self._proc.start() + self._cmd_queue = _CTX.Queue() + self._resp_queue = _CTX.Queue() + + self._proc = _CTX.Process( + target = run_export_process, + kwargs = { + "cmd_queue": self._cmd_queue, + "resp_queue": self._resp_queue, + "config": config, + }, + daemon = True, + ) + self._proc.start() logger.info("Export subprocess started (pid=%s)", self._proc.pid) def _shutdown_subprocess(self, timeout: float = 10.0) -> None: @@ -289,26 +179,8 @@ def _wait_response(self, expected_type: str, timeout: float = 3600.0) -> dict: error_msg = resp.get("error", "Unknown error") raise RuntimeError(f"Subprocess error: {error_msg}") - if rtype == "log": - # Forwarded stdout/stderr line from the worker process. - self._append_log(resp) - continue - if rtype == "status": - message = resp.get("message", "") - logger.info("Export subprocess status: %s", message) - # Surface status messages in the live log panel too so - # users see high level progress (e.g. "Importing - # Unsloth...", "Loading checkpoint: ...") alongside - # subprocess output. - if message: - self._append_log( - { - "stream": "status", - "line": message, - "ts": resp.get("ts", time.time()), - } - ) + logger.info("Export subprocess status: %s", resp.get("message", "")) continue # Other response types during wait — skip @@ -359,47 +231,37 @@ def load_checkpoint( "hf_token": hf_token, } - with self._lock: - # Start a fresh log buffer for this operation so the UI - # sees only the current run's output. - self.clear_logs() - self._export_active = True - try: - # Always kill existing subprocess and spawn fresh. - if self._ensure_subprocess_alive(): - self._shutdown_subprocess() - elif self._proc is not None: - self._shutdown_subprocess(timeout = 2) + # Always kill existing subprocess and spawn fresh. + if self._ensure_subprocess_alive(): + self._shutdown_subprocess() + elif self._proc is not None: + self._shutdown_subprocess(timeout = 2) - logger.info( - "Spawning fresh export subprocess for '%s'", checkpoint_path - ) - self._spawn_subprocess(sub_config) + logger.info("Spawning fresh export subprocess for '%s'", checkpoint_path) + self._spawn_subprocess(sub_config) - try: - resp = self._wait_response("loaded") - except RuntimeError as exc: - self._shutdown_subprocess(timeout = 5) - self.current_checkpoint = None - self.is_vision = False - self.is_peft = False - return False, str(exc) - - if resp.get("success"): - self.current_checkpoint = resp.get("checkpoint") - self.is_vision = resp.get("is_vision", False) - self.is_peft = resp.get("is_peft", False) - logger.info("Checkpoint '%s' loaded in subprocess", checkpoint_path) - return True, resp.get("message", "Loaded successfully") - else: - error = resp.get("message", "Failed to load checkpoint") - logger.error("Failed to load checkpoint: %s", error) - self.current_checkpoint = None - self.is_vision = False - self.is_peft = False - return False, error - finally: - self._export_active = False + try: + resp = self._wait_response("loaded", timeout = 300) + except RuntimeError as exc: + self._shutdown_subprocess(timeout = 5) + self.current_checkpoint = None + self.is_vision = False + self.is_peft = False + return False, str(exc) + + if resp.get("success"): + self.current_checkpoint = resp.get("checkpoint") + self.is_vision = resp.get("is_vision", False) + self.is_peft = resp.get("is_peft", False) + logger.info("Checkpoint '%s' loaded in subprocess", checkpoint_path) + return True, resp.get("message", "Loaded successfully") + else: + error = resp.get("message", "Failed to load checkpoint") + logger.error("Failed to load checkpoint: %s", error) + self.current_checkpoint = None + self.is_vision = False + self.is_peft = False + return False, error def export_merged_model( self, @@ -409,7 +271,7 @@ def export_merged_model( repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """Export merged PEFT model.""" return self._run_export( "merged", @@ -431,7 +293,7 @@ def export_base_model( hf_token: Optional[str] = None, private: bool = False, base_model_id: Optional[str] = None, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """Export base model (non-PEFT).""" return self._run_export( "base", @@ -452,7 +314,7 @@ def export_gguf( push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """Export model in GGUF format.""" return self._run_export( "gguf", @@ -472,7 +334,7 @@ def export_lora_adapter( repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, - ) -> Tuple[bool, str, Optional[str]]: + ) -> Tuple[bool, str]: """Export LoRA adapter only.""" return self._run_export( "lora", @@ -485,74 +347,46 @@ def export_lora_adapter( }, ) - def _run_export( - self, export_type: str, params: dict - ) -> Tuple[bool, str, Optional[str]]: - """Send an export command to the subprocess and wait for result. - - Returns ``(success, message, output_path)``. ``output_path`` is the - resolved on-disk directory the worker actually wrote to (None when - the export only pushed to Hub or failed before any file was - written). Surfaced via the export route's ``details.output_path`` - so the dialog's success screen can show the user where the model - landed. - """ - with self._lock: - if not self._ensure_subprocess_alive(): - return ( - False, - "No export subprocess running. Load a checkpoint first.", - None, - ) - - self.clear_logs() - self._export_active = True - try: - cmd = {"type": "export", "export_type": export_type, **params} - try: - self._send_cmd(cmd) - resp = self._wait_response( - f"export_{export_type}_done", - timeout = 3600, # GGUF for 30B+ models can take 30+ min - ) - return ( - resp.get("success", False), - resp.get("message", ""), - resp.get("output_path"), - ) - except RuntimeError as exc: - return False, str(exc), None - finally: - self._export_active = False + def _run_export(self, export_type: str, params: dict) -> Tuple[bool, str]: + """Send an export command to the subprocess and wait for result.""" + if not self._ensure_subprocess_alive(): + return False, "No export subprocess running. Load a checkpoint first." + + cmd = {"type": "export", "export_type": export_type, **params} + + try: + self._send_cmd(cmd) + resp = self._wait_response( + f"export_{export_type}_done", + timeout = 3600, # GGUF for 30B+ models can take 30+ min + ) + return resp.get("success", False), resp.get("message", "") + except RuntimeError as exc: + return False, str(exc) def cleanup_memory(self) -> bool: """Cleanup export-related models from memory.""" - with self._lock: - if not self._ensure_subprocess_alive(): - # No subprocess — just clear local state - self.current_checkpoint = None - self.is_vision = False - self.is_peft = False - return True - - self._export_active = True - try: - try: - self._send_cmd({"type": "cleanup"}) - resp = self._wait_response("cleanup_done", timeout = 30) - success = resp.get("success", False) - except RuntimeError: - success = False - - # Shut down subprocess after cleanup — no model loaded - self._shutdown_subprocess() - - self.current_checkpoint = None - self.is_vision = False - self.is_peft = False - return success - finally: - self._export_active = False + if not self._ensure_subprocess_alive(): + # No subprocess — just clear local state + self.current_checkpoint = None + self.is_vision = False + self.is_peft = False + return True + + try: + self._send_cmd({"type": "cleanup"}) + resp = self._wait_response("cleanup_done", timeout = 30) + success = resp.get("success", False) + except RuntimeError: + success = False + + # Shut down subprocess after cleanup — no model loaded + self._shutdown_subprocess() + + self.current_checkpoint = None + self.is_vision = False + self.is_peft = False + return success def scan_checkpoints( self, outputs_dir: str = str(outputs_root()) diff --git a/studio/backend/core/export/worker.py b/studio/backend/core/export/worker.py index f77b1966c4..3f3dc955fa 100644 --- a/studio/backend/core/export/worker.py +++ b/studio/backend/core/export/worker.py @@ -17,12 +17,10 @@ from __future__ import annotations -import errno import structlog from loggers import get_logger import os import sys -import threading import time import traceback from pathlib import Path @@ -31,154 +29,6 @@ logger = get_logger(__name__) -# Gate that controls whether captured stdout/stderr lines are forwarded -# to the parent's resp_queue (and from there to the export-dialog SSE -# stream). Closed by default so the noisy bootstrap phase -- transformers -# venv activation, Unsloth/torch imports, base-model resolution, "Top -# GGUF/hub models" lists, vision detection, weight loading bars -- is -# suppressed in the UI. _handle_export() opens the gate at the start of -# the actual export work and leaves it open; the orchestrator always -# spawns a fresh subprocess for the next checkpoint load (see -# orchestrator._spawn_subprocess) which resets this state. -# -# Lines dropped while the gate is closed are still echoed to the saved -# original stdout/stderr fds so the server console / log file keeps the -# full output for debugging. -_log_forward_gate = threading.Event() - - -def _setup_log_capture(resp_queue: Any) -> None: - """Redirect fds 1 and 2 through pipes so every line printed by this - worker process and any child process it spawns is forwarded to the - parent process via resp_queue as {"type": "log", ...} messages. - - Must be called BEFORE LogConfig.setup_logging and BEFORE any ML - imports, otherwise library handlers may capture the original stderr - reference and bypass the pipe. - - Lines are also echoed back to the original stdout/stderr so the - server console keeps receiving the full subprocess output, even - while ``_log_forward_gate`` is closed. - """ - - try: - saved_out_fd = os.dup(1) - saved_err_fd = os.dup(2) - except OSError: - # dup failed (exotic platforms) - give up quietly, export still - # works, just no live log streaming. - return - - try: - r_out, w_out = os.pipe() - r_err, w_err = os.pipe() - except OSError: - os.close(saved_out_fd) - os.close(saved_err_fd) - return - - try: - os.dup2(w_out, 1) - os.dup2(w_err, 2) - except OSError: - for fd in (saved_out_fd, saved_err_fd, r_out, w_out, r_err, w_err): - try: - os.close(fd) - except OSError: - pass - return - - # Close the write ends we just dup2'd (fds 1 and 2 are the real - # write ends now). - os.close(w_out) - os.close(w_err) - - # Replace Python's sys.stdout/sys.stderr with line-buffered writers - # bound to the (now-redirected) fds 1 and 2. - try: - sys.stdout = os.fdopen(1, "w", buffering = 1, encoding = "utf-8", errors = "replace") - sys.stderr = os.fdopen(2, "w", buffering = 1, encoding = "utf-8", errors = "replace") - except Exception: - pass - - def _reader(read_fd: int, stream_name: str, echo_fd: int) -> None: - buf = bytearray() - while True: - try: - chunk = os.read(read_fd, 4096) - except OSError as exc: - if exc.errno == errno.EBADF: - break - continue - if not chunk: - break - # Echo to the original fd so the server console still sees - # the full output. - try: - os.write(echo_fd, chunk) - except OSError: - pass - buf.extend(chunk) - # Split on \n OR \r so tqdm-style progress bars update. - while True: - nl = -1 - for i, b in enumerate(buf): - if b == 0x0A or b == 0x0D: - nl = i - break - if nl < 0: - break - line = bytes(buf[:nl]).decode("utf-8", errors = "replace") - del buf[: nl + 1] - if not line: - continue - if not _log_forward_gate.is_set(): - # Gate closed (bootstrap phase) -- already echoed to - # the saved console fd above; drop the line so the - # export dialog doesn't see import / vendoring noise. - continue - try: - resp_queue.put_nowait( - { - "type": "log", - "stream": stream_name, - "line": line, - "ts": time.time(), - } - ) - except Exception: - # Queue put failed (full, closed, etc.) - drop the - # line rather than crash the reader thread. - pass - if buf and _log_forward_gate.is_set(): - try: - resp_queue.put_nowait( - { - "type": "log", - "stream": stream_name, - "line": bytes(buf).decode("utf-8", errors = "replace"), - "ts": time.time(), - } - ) - except Exception: - pass - - t_out = threading.Thread( - target = _reader, - args = (r_out, "stdout", saved_out_fd), - daemon = True, - name = "export-log-stdout", - ) - t_err = threading.Thread( - target = _reader, - args = (r_err, "stderr", saved_err_fd), - daemon = True, - name = "export-log-stderr", - ) - t_out.start() - t_err.start() - - def _activate_transformers_version(model_name: str) -> None: """Activate the correct transformers version BEFORE any ML imports.""" # Ensure backend is on path for utils imports @@ -267,17 +117,9 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None: export_type = cmd["export_type"] # "merged", "base", "gguf", "lora" response_type = f"export_{export_type}_done" - # Open the log forwarding gate so the user sees the actual export - # progress (Unsloth merge bars, file copies, GGUF conversion, etc.) - # in the live log panel. The gate stays open for the rest of this - # subprocess's life; the orchestrator spawns a fresh subprocess for - # the next checkpoint load, which resets the gate to closed. - _log_forward_gate.set() - - output_path: Any = None try: if export_type == "merged": - success, message, output_path = backend.export_merged_model( + success, message = backend.export_merged_model( save_directory = cmd.get("save_directory", ""), format_type = cmd.get("format_type", "16-bit (FP16)"), push_to_hub = cmd.get("push_to_hub", False), @@ -286,7 +128,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None: private = cmd.get("private", False), ) elif export_type == "base": - success, message, output_path = backend.export_base_model( + success, message = backend.export_base_model( save_directory = cmd.get("save_directory", ""), push_to_hub = cmd.get("push_to_hub", False), repo_id = cmd.get("repo_id"), @@ -295,7 +137,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None: base_model_id = cmd.get("base_model_id"), ) elif export_type == "gguf": - success, message, output_path = backend.export_gguf( + success, message = backend.export_gguf( save_directory = cmd.get("save_directory", ""), quantization_method = cmd.get("quantization_method", "Q4_K_M"), push_to_hub = cmd.get("push_to_hub", False), @@ -303,7 +145,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None: hf_token = cmd.get("hf_token"), ) elif export_type == "lora": - success, message, output_path = backend.export_lora_adapter( + success, message = backend.export_lora_adapter( save_directory = cmd.get("save_directory", ""), push_to_hub = cmd.get("push_to_hub", False), repo_id = cmd.get("repo_id"), @@ -319,7 +161,6 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None: "type": response_type, "success": success, "message": message, - "output_path": output_path, "ts": time.time(), }, ) @@ -331,7 +172,6 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None: "type": response_type, "success": False, "message": str(exc), - "output_path": None, "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, @@ -377,26 +217,10 @@ def run_export_process( """ import queue as _queue - # Install fd-level stdout/stderr capture FIRST so every subsequent - # print and every child process inherits the redirected fds. This - # is what powers the live export log stream in the UI. - _setup_log_capture(resp_queue) - os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["PYTHONWARNINGS"] = ( "ignore" # Suppress warnings at C-level before imports ) - # Force unbuffered output from any child Python process (e.g. the - # GGUF converter) so their prints surface in the log stream as they - # happen rather than at the end. - os.environ["PYTHONUNBUFFERED"] = "1" - # tqdm defaults to a 10-second mininterval when stdout is not a tty - # (which it isn't here -- we redirected fd 1/2 to a pipe). That makes - # multi-step progress bars look frozen in the export log panel. Force - # frequent flushes so the user sees movement during merge / GGUF - # conversion. Has no effect on single-step bars (e.g. "Copying 1 - # files") which only emit start/end events regardless. - os.environ.setdefault("TQDM_MININTERVAL", "0.5") import warnings from loggers.config import LogConfig diff --git a/studio/backend/core/inference/anthropic_compat.py b/studio/backend/core/inference/anthropic_compat.py deleted file mode 100644 index 263718c540..0000000000 --- a/studio/backend/core/inference/anthropic_compat.py +++ /dev/null @@ -1,576 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -""" -Anthropic Messages API ↔ OpenAI format translation utilities. - -Pure functions and a stateful stream emitter — no FastAPI, no I/O. -""" - -from __future__ import annotations - -import json -from typing import Any, Optional, Union - - -def _anthropic_image_block_to_openai_part(block: dict) -> Optional[dict]: - """Translate one Anthropic ``image`` block to an OpenAI ``image_url`` part. - - Accepts both source shapes: - - ``{"type": "base64", "media_type": "image/jpeg", "data": "..."}`` - - ``{"type": "url", "url": "https://..."}`` - - Returns ``None`` when the source is malformed so the caller can skip it. - """ - source = block.get("source") or {} - stype = source.get("type") - if stype == "base64": - data = source.get("data") - if not data: - return None - media_type = source.get("media_type") or "image/jpeg" - return { - "type": "image_url", - "image_url": {"url": f"data:{media_type};base64,{data}"}, - } - if stype == "url": - url = source.get("url") - if not url: - return None - return {"type": "image_url", "image_url": {"url": url}} - return None - - -def anthropic_messages_to_openai( - messages: list[dict], - system: Optional[Union[str, list]] = None, -) -> list[dict]: - """Convert Anthropic messages + system to OpenAI-format message dicts. - - User messages that carry ``image`` blocks are emitted as OpenAI - multimodal content arrays (``[{type: "text", ...}, {type: "image_url", ...}]``) - so they flow through llama-server's native vision pathway. - """ - result: list[dict] = [] - - # System prompt - if system: - if isinstance(system, str): - result.append({"role": "system", "content": system}) - elif isinstance(system, list): - parts = [] - for block in system: - if isinstance(block, dict) and block.get("type") == "text": - parts.append(block["text"]) - elif isinstance(block, str): - parts.append(block) - if parts: - result.append({"role": "system", "content": "\n".join(parts)}) - - for msg in messages: - role = msg["role"] if isinstance(msg, dict) else msg.role - content = msg["content"] if isinstance(msg, dict) else msg.content - - if isinstance(content, str): - result.append({"role": role, "content": content}) - continue - - if role == "assistant": - # Assistant content carries text + tool_use; images aren't - # part of Anthropic's assistant content model. - text_parts: list[str] = [] - tool_calls: list[dict] = [] - for block in content: - b = block if isinstance(block, dict) else block.model_dump() - btype = b.get("type", "") - if btype == "text": - text_parts.append(b["text"]) - elif btype == "tool_use": - tool_calls.append( - { - "id": b["id"], - "type": "function", - "function": { - "name": b["name"], - "arguments": json.dumps(b["input"]), - }, - } - ) - msg_dict: dict[str, Any] = {"role": "assistant"} - if text_parts: - msg_dict["content"] = "\n".join(text_parts) - if tool_calls: - msg_dict["tool_calls"] = tool_calls - result.append(msg_dict) - continue - - if role == "user": - # Build an ordered part list so text/image interleaving is - # preserved (e.g. [text, image, text, image]). tool_result - # blocks become their own OpenAI "tool" role messages. - user_parts: list[dict] = [] - has_image = False - tool_results: list[dict] = [] - for block in content: - b = block if isinstance(block, dict) else block.model_dump() - btype = b.get("type", "") - if btype == "text": - user_parts.append({"type": "text", "text": b["text"]}) - elif btype == "image": - part = _anthropic_image_block_to_openai_part(b) - if part is not None: - user_parts.append(part) - has_image = True - elif btype == "tool_result": - tc = b.get("content", "") - if isinstance(tc, list): - tc = " ".join( - p["text"] - for p in tc - if isinstance(p, dict) and p.get("type") == "text" - ) - tool_results.append( - { - "role": "tool", - "tool_call_id": b["tool_use_id"], - "content": str(tc), - } - ) - - if has_image: - result.append({"role": "user", "content": user_parts}) - else: - # No images — collapse text parts to a plain string so - # existing text-only callers keep their simple shape. - text = "\n".join(p["text"] for p in user_parts) - if text: - result.append({"role": "user", "content": text}) - for tr in tool_results: - result.append(tr) - - return result - - -def anthropic_tools_to_openai(tools: list) -> list[dict]: - """Convert Anthropic tool definitions to OpenAI function-tool format.""" - result = [] - for t in tools: - td = t if isinstance(t, dict) else t.model_dump() - result.append( - { - "type": "function", - "function": { - "name": td["name"], - "description": td.get("description", ""), - "parameters": td.get("input_schema", {}), - }, - } - ) - return result - - -def anthropic_tool_choice_to_openai(tc: Any) -> Any: - """Translate Anthropic `tool_choice` into OpenAI `tool_choice`. - - Anthropic formats (all dict shapes with a ``type`` discriminator): - - - ``{"type": "auto"}`` → ``"auto"`` - - ``{"type": "any"}`` → ``"required"`` - - ``{"type": "none"}`` → ``"none"`` - - ``{"type": "tool", "name": "get_weather"}`` - → ``{"type": "function", "function": {"name": "get_weather"}}`` - - Returns ``None`` for ``None`` or any unrecognized shape (caller may - then fall back to its own default, typically ``"auto"``). - """ - if tc is None: - return None - if not isinstance(tc, dict): - return None - t = tc.get("type") - if t == "auto": - return "auto" - if t == "any": - return "required" - if t == "none": - return "none" - if t == "tool": - name = tc.get("name") - if not name: - return None - return {"type": "function", "function": {"name": name}} - return None - - -def build_anthropic_sse_event(event_type: str, data: dict) -> str: - """Format a single Anthropic SSE event.""" - return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - - -class AnthropicStreamEmitter: - """Converts generator events from generate_chat_completion_with_tools() - into Anthropic Messages SSE strings.""" - - def __init__(self) -> None: - self.block_index: int = 0 - self._text_block_open: bool = False - self._prev_text: str = "" - self._usage: dict = {} - - def start(self, message_id: str, model: str) -> list[str]: - """Emit message_start and open the first text content block.""" - events = [] - events.append( - build_anthropic_sse_event( - "message_start", - { - "type": "message_start", - "message": { - "id": message_id, - "type": "message", - "role": "assistant", - "content": [], - "model": model, - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": 0, "output_tokens": 0}, - }, - }, - ) - ) - events.extend(self._open_text_block()) - return events - - def feed(self, event: dict) -> list[str]: - """Process one generator event, return SSE strings.""" - etype = event.get("type", "") - if etype == "content": - return self._handle_content(event) - elif etype == "tool_start": - return self._handle_tool_start(event) - elif etype == "tool_end": - return self._handle_tool_end(event) - elif etype == "metadata": - self._usage = event.get("usage", {}) - return [] - # status events — no Anthropic equivalent - return [] - - def finish(self, stop_reason: str = "end_turn") -> list[str]: - """Close any open block and emit message_delta + message_stop.""" - events = [] - if self._text_block_open: - events.append(self._close_block()) - events.append( - build_anthropic_sse_event( - "message_delta", - { - "type": "message_delta", - "delta": {"stop_reason": stop_reason, "stop_sequence": None}, - "usage": { - "output_tokens": self._usage.get("completion_tokens", 0), - }, - }, - ) - ) - events.append( - build_anthropic_sse_event( - "message_stop", - { - "type": "message_stop", - }, - ) - ) - return events - - def _handle_content(self, event: dict) -> list[str]: - cumulative = event.get("text", "") - new_text = cumulative[len(self._prev_text) :] - self._prev_text = cumulative - if not new_text: - return [] - if not self._text_block_open: - events = self._open_text_block() - else: - events = [] - events.append( - build_anthropic_sse_event( - "content_block_delta", - { - "type": "content_block_delta", - "index": self.block_index, - "delta": {"type": "text_delta", "text": new_text}, - }, - ) - ) - return events - - def _handle_tool_start(self, event: dict) -> list[str]: - events = [] - # Close current text block if open - if self._text_block_open: - events.append(self._close_block()) - # Open a tool_use block - self.block_index += 1 - events.append( - build_anthropic_sse_event( - "content_block_start", - { - "type": "content_block_start", - "index": self.block_index, - "content_block": { - "type": "tool_use", - "id": event.get("tool_call_id", ""), - "name": event.get("tool_name", ""), - "input": {}, - }, - }, - ) - ) - # Emit the arguments as input_json_delta - args = event.get("arguments", {}) - if args: - events.append( - build_anthropic_sse_event( - "content_block_delta", - { - "type": "content_block_delta", - "index": self.block_index, - "delta": { - "type": "input_json_delta", - "partial_json": json.dumps(args), - }, - }, - ) - ) - return events - - def _handle_tool_end(self, event: dict) -> list[str]: - events = [] - # Close the tool_use block - events.append(self._close_block()) - # Emit custom tool_result event (non-standard, ignored by SDKs) - events.append( - build_anthropic_sse_event( - "tool_result", - { - "type": "tool_result", - "tool_use_id": event.get("tool_call_id", ""), - "content": event.get("result", ""), - }, - ) - ) - # Open a new text block for the model's next response - self.block_index += 1 - events.extend(self._open_text_block()) - # Reset text tracking for the next synthesis turn - self._prev_text = "" - return events - - def _open_text_block(self) -> list[str]: - self._text_block_open = True - return [ - build_anthropic_sse_event( - "content_block_start", - { - "type": "content_block_start", - "index": self.block_index, - "content_block": {"type": "text", "text": ""}, - }, - ) - ] - - def _close_block(self) -> str: - self._text_block_open = False - return build_anthropic_sse_event( - "content_block_stop", - { - "type": "content_block_stop", - "index": self.block_index, - }, - ) - - -class AnthropicPassthroughEmitter: - """Converts llama-server's OpenAI-format streaming chunks into Anthropic SSE. - - Used for the client-side tool-use pass-through path: the client (e.g. Claude - Code) sends its own tool definitions in the ``tools`` field and expects to - execute them itself. We forward them to llama-server and translate the - streaming response back to Anthropic format without executing anything. - """ - - def __init__(self) -> None: - self.block_index: int = -1 - self._current_block_type: Optional[str] = None # "text" | "tool_use" | None - self._tool_call_states: dict = {} # delta index -> {block_index, id, name} - self._usage: dict = {} - self._stop_reason: str = "end_turn" - - def start(self, message_id: str, model: str) -> list[str]: - return [ - build_anthropic_sse_event( - "message_start", - { - "type": "message_start", - "message": { - "id": message_id, - "type": "message", - "role": "assistant", - "content": [], - "model": model, - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": 0, "output_tokens": 0}, - }, - }, - ) - ] - - def feed_chunk(self, chunk: dict) -> list[str]: - """Process one OpenAI streaming chat.completion.chunk.""" - events: list[str] = [] - - # usage-only chunks carry token totals - usage = chunk.get("usage") - if usage: - self._usage = usage - - choices = chunk.get("choices") or [] - if not choices: - return events - - choice = choices[0] - delta = choice.get("delta") or {} - finish_reason = choice.get("finish_reason") - - # ── Text content ── - content = delta.get("content") - if content: - if self._current_block_type != "text": - if self._current_block_type is not None: - events.append(self._close_current_block()) - events.extend(self._open_text_block()) - events.append( - build_anthropic_sse_event( - "content_block_delta", - { - "type": "content_block_delta", - "index": self.block_index, - "delta": {"type": "text_delta", "text": content}, - }, - ) - ) - - # ── Tool calls (streaming deltas) ── - tool_calls = delta.get("tool_calls") or [] - for tc in tool_calls: - tc_idx = tc.get("index", 0) - fn = tc.get("function") or {} - if tc_idx not in self._tool_call_states: - # New tool call — close prior block, open tool_use block - if self._current_block_type is not None: - events.append(self._close_current_block()) - tc_id = tc.get("id", "") - tc_name = fn.get("name", "") - self.block_index += 1 - self._current_block_type = "tool_use" - self._tool_call_states[tc_idx] = { - "block_index": self.block_index, - "id": tc_id, - "name": tc_name, - } - events.append( - build_anthropic_sse_event( - "content_block_start", - { - "type": "content_block_start", - "index": self.block_index, - "content_block": { - "type": "tool_use", - "id": tc_id, - "name": tc_name, - "input": {}, - }, - }, - ) - ) - - args_delta = fn.get("arguments", "") - if args_delta: - events.append( - build_anthropic_sse_event( - "content_block_delta", - { - "type": "content_block_delta", - "index": self._tool_call_states[tc_idx]["block_index"], - "delta": { - "type": "input_json_delta", - "partial_json": args_delta, - }, - }, - ) - ) - - # ── Finish reason ── - if finish_reason: - if finish_reason == "tool_calls": - self._stop_reason = "tool_use" - elif finish_reason == "length": - self._stop_reason = "max_tokens" - else: - self._stop_reason = "end_turn" - - return events - - def finish(self) -> list[str]: - events: list[str] = [] - if self._current_block_type is not None: - events.append(self._close_current_block()) - events.append( - build_anthropic_sse_event( - "message_delta", - { - "type": "message_delta", - "delta": { - "stop_reason": self._stop_reason, - "stop_sequence": None, - }, - "usage": { - "output_tokens": self._usage.get("completion_tokens", 0), - }, - }, - ) - ) - events.append( - build_anthropic_sse_event( - "message_stop", - {"type": "message_stop"}, - ) - ) - return events - - def _open_text_block(self) -> list[str]: - self.block_index += 1 - self._current_block_type = "text" - return [ - build_anthropic_sse_event( - "content_block_start", - { - "type": "content_block_start", - "index": self.block_index, - "content_block": {"type": "text", "text": ""}, - }, - ) - ] - - def _close_current_block(self) -> str: - idx = self.block_index - self._current_block_type = None - return build_anthropic_sse_event( - "content_block_stop", - { - "type": "content_block_stop", - "index": idx, - }, - ) diff --git a/studio/backend/core/inference/audio_codecs.py b/studio/backend/core/inference/audio_codecs.py index df3bf27c16..bcf3ec2937 100644 --- a/studio/backend/core/inference/audio_codecs.py +++ b/studio/backend/core/inference/audio_codecs.py @@ -8,7 +8,6 @@ import io import re -import subprocess import wave import structlog from loggers import get_logger @@ -17,11 +16,6 @@ import numpy as np import torch -from utils.native_path_leases import child_env_without_native_path_secret -from utils.subprocess_compat import ( - windows_hidden_subprocess_kwargs as _windows_hidden_subprocess_kwargs, -) - logger = get_logger(__name__) @@ -87,6 +81,7 @@ def _load_bicodec(self, device: str, model_repo_path: Optional[str] = None) -> N return import os import sys + import subprocess # Clone SparkAudio/Spark-TTS GitHub repo for the sparktts Python package # (same approach as training — the HF model repos don't contain the package) @@ -106,8 +101,6 @@ def _load_bicodec(self, device: str, model_repo_path: Optional[str] = None) -> N spark_code_dir, ], check = True, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) if spark_code_dir not in sys.path: @@ -126,6 +119,7 @@ def _load_dac(self, device: str) -> None: return import os import sys + import subprocess # Clone OuteTTS repo (same pattern as Spark-TTS / BiCodec) # The pip package has problematic dependencies; the notebook clones and @@ -145,8 +139,6 @@ def _load_dac(self, device: str) -> None: outetts_code_dir, ], check = True, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) # Remove files that pull in heavy / incompatible dependencies # (matches notebook: gguf_model.py is under models/, others under outetts/) diff --git a/studio/backend/core/inference/defaults.py b/studio/backend/core/inference/defaults.py index 53718c1294..f3026dddaf 100644 --- a/studio/backend/core/inference/defaults.py +++ b/studio/backend/core/inference/defaults.py @@ -10,7 +10,6 @@ "unsloth/gemma-4-E4B-it-GGUF", "unsloth/gemma-4-31B-it-GGUF", "unsloth/gemma-4-26B-A4B-it-GGUF", - "unsloth/Qwen3.6-35B-A3B-GGUF", "unsloth/Qwen3.5-4B-GGUF", "unsloth/Qwen3.5-9B-GGUF", "unsloth/Qwen3.5-35B-A3B-GGUF", @@ -28,7 +27,6 @@ "unsloth/gemma-4-E4B-it-GGUF", "unsloth/gemma-4-31B-it-GGUF", "unsloth/gemma-4-26B-A4B-it-GGUF", - "unsloth/Qwen3.6-35B-A3B-GGUF", "unsloth/Qwen3.5-4B-GGUF", "unsloth/Qwen3.5-9B-GGUF", "unsloth/Qwen3.5-35B-A3B-GGUF", diff --git a/studio/backend/core/inference/inference.py b/studio/backend/core/inference/inference.py index 4c140013a0..867bdefc62 100644 --- a/studio/backend/core/inference/inference.py +++ b/studio/backend/core/inference/inference.py @@ -253,10 +253,6 @@ def load_model( """ Load any model: base, LoRA adapter, text, or vision. """ - # GGUF uses max_seq_length=0 as "model default"; Unsloth crashes on it. - if max_seq_length <= 0: - max_seq_length = 2048 - try: model_name = config.identifier diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index f768764c22..c84ac640df 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -11,7 +11,6 @@ import atexit import contextlib import json -import os import re import struct import structlog @@ -19,23 +18,16 @@ import shutil import socket import subprocess -import sys import threading import time from pathlib import Path -from typing import Generator, List, Optional +from typing import Generator, Optional from urllib.parse import urlparse import httpx -from utils.native_path_leases import child_env_without_native_path_secret -from utils.subprocess_compat import ( - windows_hidden_subprocess_kwargs as _windows_hidden_subprocess_kwargs, -) - logger = get_logger(__name__) - # ── Pre-compiled patterns for plan-without-action re-prompt ── # Forward-looking intent signals that indicate the model is # describing what it *will* do rather than giving a final answer. @@ -55,21 +47,6 @@ r")" ) _MAX_REPROMPTS = 3 - -# Without max_tokens, llama-server defaults to n_predict = n_ctx (up to -# 262144 for Qwen3.5), producing many-minute zombie decodes when cancel -# fails. t_max_predict_ms is a wall-clock backstop applied unconditionally, -# but the llama.cpp README notes it ONLY fires after a newline has been -# generated -- a model stuck in a long unbroken non-newline sequence is -# unbounded by it. So we still want a token cap as the front-line limiter. -# -# The cap is the model's effective context length when we know it, -# falling back to a generous floor when metadata is unavailable. 4096 was -# too low: Qwen3 / gpt-oss reasoning traces routinely exceed it, and any -# OpenAI-API caller that omits max_tokens (langchain, llama-index, raw -# curl) sees responses silently truncated mid-sentence. -_DEFAULT_MAX_TOKENS_FLOOR = 32768 -_DEFAULT_T_MAX_PREDICT_MS = 600_000 # 10 min _REPROMPT_MAX_CHARS = 2000 # ── Pre-compiled patterns for GGUF shard detection ─────────── @@ -77,238 +54,6 @@ _SHARD_RE = re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$") -# ── Sliding-window-pattern resolver ─────────────────────────── -# Resolves the per-layer SWA mask when a GGUF reports a sliding window -# but no `sliding_window_pattern` field. Tier order in -# `_resolve_swa_pattern`: GGUF metadata, on-disk cache, bootstrap dict -# below, transformers introspection, HF Hub config.json, legacy 1/4 -# fallback. Period N means layer i is SWA iff `(i + 1) % N != 0`, -# matching transformers. Skipped on purpose: phi3 (no key/val length -# in GGUF, window >= ctx anyway), qwen2 family (converter strips -# sliding_window when use_sliding_window=False), mistral v0.1/v0.2 -# (all-SWA can't be expressed as a period). -_BOOTSTRAP_SWA_DEFAULTS: dict[str, int] = { - "gemma2": 2, # Gemma2Config.sliding_window_pattern - "gemma3": 6, # Gemma3TextConfig.sliding_window_pattern - "gemma3n": 5, # text_config.layer_types: SWA*4 + FULL - "gpt_oss": 2, # text_config.layer_types: alternating - "cohere2": 4, # Cohere2Config.sliding_window_pattern -} - -# Process-wide cache backed by JSON on disk. Values are int period or -# list[bool] mask. Lazy-loaded. -_SWA_CACHE: Optional[dict] = None -_SWA_CACHE_LOCK = threading.Lock() - - -def _swa_cache_path() -> Path: - home = os.environ.get("UNSLOTH_STUDIO_HOME") or os.environ.get("STUDIO_HOME") - base = Path(home) if home else Path.home() / ".unsloth" / "studio" - return base / "swa_cache.json" - - -def _load_swa_cache() -> dict: - global _SWA_CACHE - with _SWA_CACHE_LOCK: - if _SWA_CACHE is not None: - return _SWA_CACHE - try: - with open(_swa_cache_path()) as f: - _SWA_CACHE = json.load(f) - if not isinstance(_SWA_CACHE, dict): - _SWA_CACHE = {} - except (FileNotFoundError, json.JSONDecodeError, OSError): - _SWA_CACHE = {} - return _SWA_CACHE - - -def _save_swa_cache(cache: dict) -> None: - try: - path = _swa_cache_path() - path.parent.mkdir(parents = True, exist_ok = True) - tmp = path.with_suffix(".json.tmp") - with open(tmp, "w") as f: - json.dump(cache, f, indent = 2, sort_keys = True) - tmp.replace(path) - except OSError: - pass - - -def _period_from_layer_types(layer_types: list) -> Optional[int]: - """Smallest period N where `(i+1) % N != 0` matches the SWA mask, - or None if no fixed period fits.""" - if not layer_types: - return None - is_swa = ["full" not in str(t).lower() for t in layer_types] - n = len(is_swa) - for N in range(1, n + 1): - if all(((i + 1) % N != 0) == is_swa[i] for i in range(n)): - return N - return None - - -def _fetch_swa_entry_from_hf(repo_id: str) -> Optional[object]: - try: - from huggingface_hub import hf_hub_download - - cfg_path = hf_hub_download(repo_id, "config.json", repo_type = "model") - with open(cfg_path) as f: - cfg = json.load(f) - except Exception: - return None - - src = cfg.get("text_config") if isinstance(cfg.get("text_config"), dict) else cfg - period = src.get("sliding_window_pattern") - if isinstance(period, int) and period > 0: - return period - lt = src.get("layer_types") - if isinstance(lt, list) and lt: - return _period_from_layer_types(lt) or [ - "full" not in str(t).lower() for t in lt - ] - return None - - -def _arch_aliases(arch: str) -> tuple: - # GGUF emits `falcon-h1`; HF model_type is `falcon_h1`. Normalise both ways. - seen = [] - for a in (arch, arch.replace("-", "_"), arch.replace("_", "-")): - if a and a not in seen: - seen.append(a) - return tuple(seen) - - -def _swa_entry_from_config_obj(cfg) -> Optional[object]: - src = getattr(cfg, "text_config", None) or cfg - period = getattr(src, "sliding_window_pattern", None) - if isinstance(period, int) and period > 0: - return period - lt = getattr(src, "layer_types", None) - if isinstance(lt, list) and lt: - return _period_from_layer_types(lt) or [ - "full" not in str(t).lower() for t in lt - ] - return None - - -_SWA_PATTERN_SOURCE_RE = re.compile( - r"sliding_window_pattern\s*(?::\s*[\w\[\], ]*)?\s*=\s*(\d+)" -) - - -def _resolve_swa_entry_from_transformers(arch: str) -> Optional[object]: - """Default-instantiate the matching Config; on failure, regex-parse - its source for `sliding_window_pattern = N`.""" - try: - from transformers.models.auto.configuration_auto import ( - CONFIG_MAPPING, - CONFIG_MAPPING_NAMES, - ) - except Exception: - return None - - cfg_class = None - for alias in _arch_aliases(arch): - if alias in CONFIG_MAPPING_NAMES: - try: - cfg_class = CONFIG_MAPPING[alias] - break - except Exception: - cfg_class = None - if cfg_class is None: - return None - - try: - if (entry := _swa_entry_from_config_obj(cfg_class())) is not None: - return entry - except Exception: - pass - - import inspect - - candidates = [cfg_class] - text_cfg_class = getattr(cfg_class, "sub_configs", {}).get("text_config") - if text_cfg_class is not None: - candidates.append(text_cfg_class) - for cls in candidates: - try: - src = inspect.getsource(cls) - except (OSError, TypeError): - continue - if m := _SWA_PATTERN_SOURCE_RE.search(src): - period = int(m.group(1)) - if period > 0: - return period - return None - - -def _resolve_swa_pattern( - arch: Optional[str], - n_layers: Optional[int], - source_repo_candidates: tuple = (), - *, - allow_network: Optional[bool] = None, -) -> Optional[list]: - if not arch or not n_layers: - return None - if allow_network is None: - allow_network = os.environ.get("UNSLOTH_STUDIO_OFFLINE", "0") not in ( - "1", - "true", - "True", - "yes", - ) - - cache = _load_swa_cache() - - def _entry_to_mask(entry): - if isinstance(entry, int) and entry > 0: - return [(i + 1) % entry != 0 for i in range(n_layers)] - if isinstance(entry, list) and entry: - return [bool(entry[i % len(entry)]) for i in range(n_layers)] - return None - - def _persist(entry): - with _SWA_CACHE_LOCK: - cache[arch] = entry - _save_swa_cache(cache) - - if (entry := cache.get(arch)) is not None: - if (mask := _entry_to_mask(entry)) is not None: - return mask - - if (entry := _BOOTSTRAP_SWA_DEFAULTS.get(arch)) is not None: - return _entry_to_mask(entry) - - entry = _resolve_swa_entry_from_transformers(arch) - if entry is not None: - _persist(entry) - return _entry_to_mask(entry) - - # Tier 3: live HF fetch (with persistent caching of the result) - if allow_network: - for repo_id in source_repo_candidates: - if not repo_id: - continue - entry = _fetch_swa_entry_from_hf(repo_id) - if entry is not None: - _persist(entry) - return _entry_to_mask(entry) - - return None - - -def _hf_repo_from_url(url: Optional[str]) -> Optional[str]: - """Strip `https://huggingface.co/owner/name(/...)` to `owner/name`.""" - if not url or "huggingface.co/" not in url: - return None - tail = url.split("huggingface.co/", 1)[1].rstrip("/") - parts = tail.split("/") - if len(parts) < 2: - return None - return f"{parts[0]}/{parts[1]}" - - # Model size extraction — lazy import to avoid pulling in transformers # at module level. See PR description for the full explanation. def _extract_model_size_b(model_id: str): @@ -336,84 +81,6 @@ def _extract_model_size_b(model_id: str): _TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") -_TOOL_TEMPLATE_MARKERS = ( - "{%- if tools %}", - "{%- if tools -%}", - "{% if tools %}", - "{% if tools -%}", - '"role" == "tool"', - "'role' == 'tool'", - 'message.role == "tool"', - "message.role == 'tool'", -) - - -def detect_reasoning_flags( - chat_template: Optional[str], - model_identifier: Optional[str] = None, - *, - log_source: Optional[str] = None, -) -> dict: - """Classify a chat template's reasoning and tool-calling capabilities. - - Returns a dict with the same five keys populated by the GGUF sniffer: - ``supports_reasoning``, ``reasoning_style`` - (``"enable_thinking"`` | ``"reasoning_effort"``), - ``reasoning_always_on``, ``supports_preserve_thinking``, and - ``supports_tools``. Used by both the llama-server backend at load - time and the safetensors/transformers paths in ``routes/inference`` - so the two agree on what the frontend will see. - """ - flags = { - "supports_reasoning": False, - "reasoning_style": "enable_thinking", - "reasoning_always_on": False, - "supports_preserve_thinking": False, - "supports_tools": False, - } - if not chat_template: - return flags - tpl = chat_template - prefix = f"{log_source}: " if log_source else "" - - if "enable_thinking" in tpl: - flags["supports_reasoning"] = True - flags["reasoning_style"] = "enable_thinking" - logger.info(f"{prefix}model supports reasoning (enable_thinking)") - elif "reasoning_effort" in tpl: - # gpt-oss / Harmony templates use reasoning_effort - # ("low" | "medium" | "high") instead of a boolean. - flags["supports_reasoning"] = True - flags["reasoning_style"] = "reasoning_effort" - logger.info(f"{prefix}model supports reasoning (reasoning_effort)") - elif "thinking" in tpl: - # DeepSeek uses 'thinking' instead of 'enable_thinking' - normalized_id = (model_identifier or "").lower() - if "deepseek" in normalized_id: - flags["supports_reasoning"] = True - logger.info(f"{prefix}model supports reasoning (DeepSeek thinking)") - - # Hardcoded tags or reasoning_content in the template mean - # thinking is always on (no toggle to disable it). - if not flags["supports_reasoning"]: - if ("" in tpl and "" in tpl) or "reasoning_content" in tpl: - flags["supports_reasoning"] = True - flags["reasoning_always_on"] = True - logger.info(f"{prefix}model always reasons ( tags in template)") - - # preserve_thinking is an independent kwarg on some Qwen templates - # that keeps historical blocks in prior assistant turns. - if "preserve_thinking" in tpl: - flags["supports_preserve_thinking"] = True - logger.info(f"{prefix}model supports preserve_thinking") - - if any(marker in tpl for marker in _TOOL_TEMPLATE_MARKERS): - flags["supports_tools"] = True - logger.info(f"{prefix}model supports tool calling") - - return flags - - class LlamaCppBackend: """ Manages a llama-server subprocess for GGUF model inference. @@ -439,8 +106,6 @@ def __init__(self): self._chat_template: Optional[str] = None self._supports_reasoning: bool = False self._reasoning_always_on: bool = False - self._reasoning_style: str = "enable_thinking" - self._supports_preserve_thinking: bool = False self._supports_tools: bool = False self._cache_type_kv: Optional[str] = None self._reasoning_default: bool = True @@ -448,24 +113,17 @@ def __init__(self): # KV-cache estimation fields (populated by _read_gguf_metadata) self._n_layers: Optional[int] = None self._n_kv_heads: Optional[int] = None - self._n_kv_heads_by_layer: Optional[list[int]] = None self._n_heads: Optional[int] = None self._embedding_length: Optional[int] = None - # Architecture-aware KV fields for 5-path estimation + # Architecture-aware KV fields (8 new fields for 5-path estimation) self._kv_key_length: Optional[int] = None self._kv_value_length: Optional[int] = None self._sliding_window: Optional[int] = None - self._sliding_window_pattern: Optional[list[bool]] = None self._full_attention_interval: Optional[int] = None self._kv_lora_rank: Optional[int] = None self._key_length_mla: Optional[int] = None - self._kv_key_length_swa: Optional[int] = None - self._kv_value_length_swa: Optional[int] = None self._ssm_inner_size: Optional[int] = None self._ssm_state_size: Optional[int] = None - # Last N layers reuse KV from earlier layers and don't allocate - # their own cache (Gemma 3n / Gemma 4: .attention.shared_kv_layers). - self._shared_kv_layers: Optional[int] = None self._lock = threading.Lock() self._stdout_lines: list[str] = [] self._stdout_thread: Optional[threading.Thread] = None @@ -509,17 +167,7 @@ def context_length(self) -> Optional[int]: @property def max_context_length(self) -> Optional[int]: - """Return the largest context that fits on this hardware at load time. - - This is the "safe zone" threshold the UI renders warnings - against. For a model whose weights fit on some GPU subset, it - is the binary-search cap from ``_fit_context_to_vram`` for that - subset. For a model whose weights exceed 90% of every GPU - subset, it is the 4096 fallback -- the spec's default when the - model will not fit. The UI slider ceiling is - ``native_context_length``; dragging above ``max_context_length`` - triggers the "might be slower" warning. - """ + """Return the maximum context currently available on this hardware.""" return self._max_context_length or self._context_length @property @@ -527,96 +175,6 @@ def native_context_length(self) -> Optional[int]: """Return the model's native context length from GGUF metadata.""" return self._context_length - def load_progress(self) -> Optional[dict]: - """Return live model-load progress, or None if not loading. - - While llama-server is warming up, its process is typically in - kernel state D (disk sleep) mmap'ing the weight shards into - page cache before pushing layers to VRAM. During that window - ``/api/inference/status`` only reports ``loading``, which gives - the UI nothing to display besides a spinner that looks stuck - for minutes on large MoE models. - - This method samples ``/proc//status VmRSS`` against the - sum of the GGUF shard sizes so the UI can render a real bar - and compute rate / ETA. Returns ``None`` when no load is in - flight (no process, or process already healthy). - - Shape:: - - { - "phase": "mmap" | "ready", - "bytes_loaded": int, # VmRSS of the llama-server - "bytes_total": int, # sum of shard file sizes - "fraction": float, # bytes_loaded / bytes_total, 0..1 - } - - Linux-only in the current implementation. On macOS/Windows the - equivalent would be a different API; this returns ``None`` on - platforms where ``/proc//status`` is unavailable. - """ - proc = self._process - if proc is None: - return None - pid = proc.pid - if pid is None: - return None - - # Sum up shard sizes (primary + any extras sitting alongside). - bytes_total = 0 - gguf_path = self._gguf_path - if gguf_path: - primary = Path(gguf_path) - try: - if primary.is_file(): - bytes_total += primary.stat().st_size - except OSError: - pass - # Extra shards live alongside the primary with the same prefix - # before the shard index (e.g. ``-00001-of-00004.gguf``). - try: - parent = primary.parent - stem = primary.name - m = _SHARD_RE.match(stem) - prefix = m.group(1) if m else None - if prefix and parent.is_dir(): - for sibling in parent.iterdir(): - if ( - sibling.is_file() - and sibling.name.startswith(prefix) - and sibling.name != stem - and sibling.suffix == ".gguf" - ): - try: - bytes_total += sibling.stat().st_size - except OSError: - pass - except OSError: - pass - - # Read VmRSS from /proc//status. Kilobytes on Linux. - bytes_loaded = 0 - try: - with open(f"/proc/{pid}/status", "r", encoding = "utf-8") as f: - for line in f: - if line.startswith("VmRSS:"): - kb = int(line.split()[1]) - bytes_loaded = kb * 1024 - break - except (FileNotFoundError, PermissionError, ValueError, OSError): - return None - - phase = "ready" if self._healthy else "mmap" - fraction = 0.0 - if bytes_total > 0: - fraction = min(1.0, bytes_loaded / bytes_total) - return { - "phase": phase, - "bytes_loaded": bytes_loaded, - "bytes_total": bytes_total, - "fraction": round(fraction, 4), - } - @property def chat_template(self) -> Optional[str]: return self._chat_template @@ -629,51 +187,10 @@ def supports_reasoning(self) -> bool: def reasoning_always_on(self) -> bool: return self._reasoning_always_on - @property - def reasoning_style(self) -> str: - return self._reasoning_style - - @property - def supports_preserve_thinking(self) -> bool: - return self._supports_preserve_thinking - @property def reasoning_default(self) -> bool: return self._reasoning_default - def _reasoning_kwargs(self, enable_thinking: bool) -> dict: - if self._reasoning_style == "reasoning_effort": - return {"reasoning_effort": "high" if enable_thinking else "low"} - return {"enable_thinking": enable_thinking} - - def _request_reasoning_kwargs( - self, - enable_thinking: Optional[bool], - reasoning_effort: Optional[str] = None, - preserve_thinking: Optional[bool] = None, - ) -> Optional[dict]: - """Build chat_template_kwargs from per-request reasoning fields. - - Produces a merged dict covering the active model's reasoning style - (``enable_thinking`` or ``reasoning_effort``) plus the independent - ``preserve_thinking`` kwarg when the template supports it. - """ - kwargs: dict = {} - # Always-on reasoning models hardcode tags in their template - # and do not consume enable_thinking / reasoning_effort -- skip. - if self._supports_reasoning and not self._reasoning_always_on: - if self._reasoning_style == "reasoning_effort": - if reasoning_effort in ("low", "medium", "high"): - kwargs["reasoning_effort"] = reasoning_effort - elif enable_thinking is not None: - kwargs["reasoning_effort"] = "high" if enable_thinking else "low" - else: - if enable_thinking is not None: - kwargs["enable_thinking"] = enable_thinking - if self._supports_preserve_thinking and preserve_thinking is not None: - kwargs["preserve_thinking"] = preserve_thinking - return kwargs or None - @property def supports_tools(self) -> bool: return self._supports_tools @@ -805,24 +322,14 @@ def _get_gguf_size_bytes(model_path: str) -> int: @staticmethod def _get_gpu_free_memory() -> list[tuple[int, int]]: - """Query free memory per GPU. - - Order: - 1. ``nvidia-smi`` (NVIDIA CUDA hosts) -- respects - ``CUDA_VISIBLE_DEVICES``. - 2. ``torch.cuda.mem_get_info`` -- universal fallback that - works on AMD ROCm too because the HIP runtime - reuses the entire ``torch.cuda.*`` namespace. Covers the - AMD case for issue #5106 (nvidia-smi-only probe silently - returned [] on AMD hosts) and also rescues NVIDIA hosts - where ``nvidia-smi`` is missing from PATH. - - Returns list of (gpu_index, free_mib) sorted by index. Empty - list if no supported GPU is reachable. + """Query free memory per GPU via nvidia-smi. + + Returns list of (gpu_index, free_mib) sorted by index. + Respects CUDA_VISIBLE_DEVICES if set. + Returns empty list if nvidia-smi is not available. """ import os - # ── NVIDIA via nvidia-smi ──────────────────────────────────── try: result = subprocess.run( [ @@ -833,98 +340,31 @@ def _get_gpu_free_memory() -> list[tuple[int, int]]: capture_output = True, text = True, timeout = 10, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) - if result.returncode == 0: - allowed: Optional[set[int]] = None - cvd = os.environ.get("CUDA_VISIBLE_DEVICES") - if cvd is not None: - try: - # `if x.strip()` filters trailing-comma masks like - # "0,1," which would otherwise raise ValueError on - # an empty token. An explicitly empty mask (CVD="") - # yields an empty `allowed` set so all GPUs are - # filtered out, matching the codebase convention. - allowed = set( - int(x.strip()) for x in cvd.split(",") if x.strip() - ) - except ValueError: - pass - gpus: list[tuple[int, int]] = [] - for line in result.stdout.strip().splitlines(): - parts = line.split(",") - if len(parts) == 2: - idx = int(parts[0].strip()) - free_mib = int(parts[1].strip()) - if allowed is not None and idx not in allowed: - continue - gpus.append((idx, free_mib)) - # Match the docstring's sort-by-id guarantee. nvidia-smi - # almost always returns sorted output, but driver order - # is not formally guaranteed. - gpus.sort(key = lambda g: g[0]) - if gpus: - return gpus - except Exception as e: - logger.debug(f"nvidia-smi probe failed: {e}") - - # ── Torch fallback (covers AMD ROCm and missing nvidia-smi) ── - try: - import torch - - if not hasattr(torch, "cuda") or not torch.cuda.is_available(): - return [] - if not hasattr(torch.cuda, "mem_get_info"): + if result.returncode != 0: return [] - # torch.cuda enumerates GPUs RELATIVE to the visibility mask. - # On NVIDIA builds the mask is CUDA_VISIBLE_DEVICES; on AMD - # ROCm builds it is HIP_VISIBLE_DEVICES (or ROCR_VISIBLE_DEVICES - # if HIP is unset). Downstream we feed these IDs back into the - # llama-server subprocess as CVD, so we must translate visible - # ordinals back to physical indices first; otherwise launching - # with ``CUDA_VISIBLE_DEVICES=2,3`` would get rewritten to - # ``CUDA_VISIBLE_DEVICES=0,1`` and target the wrong GPUs. - physical_ids: Optional[list[int]] = None - # Match the codebase convention in - # ``utils/hardware/hardware.py::_get_parent_visible_gpu_spec``: - # treat an explicitly empty mask (``HIP_VISIBLE_DEVICES=""``) - # as "set to no GPUs" rather than falling through to the next - # var. ``or`` would coerce empty string to falsy and silently - # promote the wrong source. - if getattr(torch.version, "hip", None) is not None: - hip_v = os.environ.get("HIP_VISIBLE_DEVICES") - rocr_v = os.environ.get("ROCR_VISIBLE_DEVICES") - cvd = ( - hip_v - if hip_v is not None - else rocr_v - if rocr_v is not None - else os.environ.get("CUDA_VISIBLE_DEVICES") - ) - else: - cvd = os.environ.get("CUDA_VISIBLE_DEVICES") - if cvd is not None: + + # Parse which GPUs are allowed by existing CUDA_VISIBLE_DEVICES + allowed = None + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None and cvd.strip(): try: - # Empty mask (CVD="") yields an empty list so the - # below loop produces no GPUs, consistent with the - # nvidia-smi path and utils/hardware/hardware.py. - physical_ids = [int(x.strip()) for x in cvd.split(",") if x.strip()] + allowed = set(int(x.strip()) for x in cvd.split(",")) except ValueError: - physical_ids = None + pass # Non-numeric (e.g., "GPU-uuid"), ignore filter + gpus = [] - for ordinal in range(torch.cuda.device_count()): - free_bytes, _total_bytes = torch.cuda.mem_get_info(ordinal) - idx = ( - physical_ids[ordinal] - if physical_ids is not None and ordinal < len(physical_ids) - else ordinal - ) - gpus.append((idx, free_bytes // (1024 * 1024))) - # Match the nvidia-smi path's docstring guarantee of sorted-by-id. - return sorted(gpus, key = lambda g: g[0]) + for line in result.stdout.strip().splitlines(): + parts = line.split(",") + if len(parts) == 2: + idx = int(parts[0].strip()) + free_mib = int(parts[1].strip()) + if allowed is not None and idx not in allowed: + continue + gpus.append((idx, free_mib)) + return gpus except Exception as e: - logger.debug(f"torch GPU probe failed: {e}") + logger.debug(f"Failed to query GPU free memory via nvidia-smi: {e}") return [] @staticmethod @@ -984,29 +424,13 @@ def _can_estimate_kv(self) -> bool: # New-style: need both explicit key AND value dimensions if self._kv_key_length is not None and self._kv_value_length is not None: return True - # Legacy: need embedding_length + a head count (scalar or per-layer). + # Legacy: need embedding_length + head count return self._embedding_length is not None and ( - self._n_kv_heads is not None - or self._n_heads is not None - or self._n_kv_heads_by_layer is not None + self._n_kv_heads is not None or self._n_heads is not None ) - def _kv_heads_for_layer(self, layer_idx: int, fallback: int) -> int: - if self._n_kv_heads_by_layer is not None and layer_idx < len( - self._n_kv_heads_by_layer - ): - return self._n_kv_heads_by_layer[layer_idx] - return fallback - def _estimate_kv_cache_bytes( - self, - n_ctx: int, - cache_type_kv: Optional[str] = None, - *, - swa_full: bool = False, - n_parallel: int = 1, - kv_unified: bool = True, - ctx_checkpoints: int = 0, + self, n_ctx: int, cache_type_kv: Optional[str] = None ) -> int: """Estimate KV cache VRAM for a given context length. @@ -1017,34 +441,12 @@ def _estimate_kv_cache_bytes( 4. GQA -- standard full KV with explicit key/value dimensions 5. Legacy -- fallback using embed // n_heads - Server-flag knobs (mirror llama-server's CLI): - swa_full -- ``--swa-full``: force SWA layers to cache the - full ``n_ctx`` (collapses path 3 to path 4 - sizing for the SWA layers). - n_parallel -- ``--parallel``: number of server slots. - Verified empirically against llama-server: - non-SWA layers stay constant (cells split - across slots), SWA layers scale linearly - (per-slot window). - kv_unified -- ``--kv-unified`` (default on): retained for - API forward-compat. Currently a no-op for - memory math because the unified buffer total - matches per-slot buffers in measured cases. - ctx_checkpoints -- ``--ctx-checkpoints``: SWA snapshot count per - slot (PR #15293). Each snapshot stores one - sliding-window of state per SWA layer. - Returns 0 if metadata is insufficient for estimation. """ if not self._can_estimate_kv() or n_ctx <= 0: return 0 n_layers = self._n_layers # type: ignore[assignment] - # Gemma 3n / Gemma 4 reuse KV from earlier layers in the last - # ``shared_kv_layers`` blocks -- those don't allocate their own - # cache. Floor at 1 so a misconfigured GGUF can't zero out KV. - shared = self._shared_kv_layers or 0 - n_layers_kv = max(1, n_layers - shared) n_kv = self._n_kv_heads or self._n_heads or 1 # type: ignore[assignment] # Bytes per element depends on KV cache quantization @@ -1060,8 +462,6 @@ def _estimate_kv_cache_bytes( "iq4_nl": 0.5625, }.get(cache_type_kv or "f16", 2.0) - slots = max(1, n_parallel) - # Path 1: MLA (DeepSeek-V2/V3, GLM-4.7, GLM-5, Kimi-K2.5) # MLA stores one compressed KV latent per token/layer (shared across heads). # V is reconstructed from the latent on the fly -- no separate V cache. @@ -1072,7 +472,7 @@ def _estimate_kv_cache_bytes( n_kv_mla = self._n_kv_heads or 1 rope_dim = self._key_length_mla or 64 key_len = self._kv_key_length or (self._kv_lora_rank + rope_dim) - return int(n_layers_kv * n_ctx * n_kv_mla * key_len * bpe) + return int(n_layers * n_ctx * n_kv_mla * key_len * bpe) key_len = self._kv_key_length val_len = self._kv_value_length @@ -1090,19 +490,11 @@ def _estimate_kv_cache_bytes( head_dim = self._embedding_length // self._n_heads if self._n_heads else 128 # type: ignore[operator] return int(n_attn * n_ctx * n_kv * 2 * head_dim * bpe) - # Path 3: Sliding window (Gemma 2/3/3n/4, gpt-oss, Cohere2 ...). - # Pattern is filled in by the resolver at parse time; if absent, - # falls through to the legacy 1/4-global heuristic below. - # Per-layer-type ``--parallel N`` accounting (verified empirically - # against ``llama-server``): - # * non-SWA layers: total cells = n_ctx, partitioned across - # slots -> total memory CONSTANT in slots. - # * SWA layers: per-slot cells = 2 * sliding_window - # (capped at n_ctx and at per_slot_ctx - # when ctx is split among many slots) -> - # total memory grows LINEARLY in slots. - # ``--swa-full`` forces full n_ctx for SWA layers instead. - # ``--ctx-checkpoints N`` adds N snapshots per SWA layer per slot. + # Path 3: Sliding Window (Gemma-3, gpt-oss) + # SWA layers only cache min(ctx, window) tokens; global layers cache full ctx. + # Most SWA architectures use few global layers (e.g., Gemma-3 uses 1 in 6). + # Without an explicit field, we conservatively assume 1/4 of layers are global + # which is still far more accurate than the legacy formula (which ignores SWA). if ( self._sliding_window is not None and self._sliding_window > 0 @@ -1110,72 +502,20 @@ def _estimate_kv_cache_bytes( and val_len is not None ): swa = self._sliding_window - per_slot_ctx = max(1, n_ctx // slots) - # ``--swa-full`` makes SWA layers cache the full context just - # like non-SWA: cells get partitioned across slots, so per-slot - # cells = per_slot_ctx and the slots*per-slot product collapses - # back to the constant ``n_ctx`` total. Otherwise SWA caches - # 2*sliding_window per slot, clamped at the per-slot ctx. - swa_cells_per_slot = ( - per_slot_ctx if swa_full else min(n_ctx, 2 * swa, per_slot_ctx) - ) - key_len_swa = self._kv_key_length_swa or key_len - val_len_swa = self._kv_value_length_swa or val_len - if self._sliding_window_pattern is not None: - global_bytes = 0.0 # constant across slots - swa_bytes_per_slot = 0.0 # multiplied by slots - checkpoint_extra_per_slot = 0.0 - # Iterate only over layers that allocate their own KV; - # the trailing ``shared`` layers reuse earlier caches. - for layer_idx in range(n_layers_kv): - layer_n_kv = self._kv_heads_for_layer(layer_idx, n_kv) - is_swa = ( - layer_idx < len(self._sliding_window_pattern) - and self._sliding_window_pattern[layer_idx] - ) - if is_swa: - swa_bytes_per_slot += ( - swa_cells_per_slot - * layer_n_kv - * (key_len_swa + val_len_swa) - * bpe - ) - if ctx_checkpoints > 0 and not swa_full: - checkpoint_extra_per_slot += ( - ctx_checkpoints - * swa - * layer_n_kv - * (key_len_swa + val_len_swa) - * bpe - ) - else: - global_bytes += n_ctx * layer_n_kv * (key_len + val_len) * bpe - return int( - global_bytes - + slots * (swa_bytes_per_slot + checkpoint_extra_per_slot) - ) - n_global = max(1, n_layers_kv // 4) - n_swa = n_layers_kv - n_global + n_global = max(1, n_layers // 4) + n_swa = n_layers - n_global kv_per_token = n_kv * (key_len + val_len) * bpe - kv_per_token_swa = n_kv * (key_len_swa + val_len_swa) * bpe - global_bytes = n_global * n_ctx * kv_per_token - swa_bytes_per_slot = n_swa * swa_cells_per_slot * kv_per_token_swa - checkpoint_extra_per_slot = ( - ctx_checkpoints * n_swa * swa * kv_per_token_swa - if ctx_checkpoints > 0 and not swa_full - else 0.0 - ) return int( - global_bytes + slots * (swa_bytes_per_slot + checkpoint_extra_per_slot) + n_global * n_ctx * kv_per_token + n_swa * min(n_ctx, swa) * kv_per_token ) # Path 4: Standard GQA with explicit key/value dimensions if key_len is not None and val_len is not None: - return int(n_layers_kv * n_ctx * n_kv * (key_len + val_len) * bpe) + return int(n_layers * n_ctx * n_kv * (key_len + val_len) * bpe) # Path 5: Legacy fallback (old GGUFs without explicit dimensions) head_dim = self._embedding_length // self._n_heads if self._n_heads else 128 # type: ignore[operator] - return int(2 * n_kv * head_dim * n_layers_kv * n_ctx * bpe) + return int(2 * n_kv * head_dim * n_layers * n_ctx * bpe) def _fit_context_to_vram( self, @@ -1184,12 +524,6 @@ def _fit_context_to_vram( model_size_bytes: int, cache_type_kv: Optional[str] = None, min_ctx: int = 4096, - *, - swa_full: bool = False, - n_parallel: int = 1, - kv_unified: bool = True, - ctx_checkpoints: int = 0, - kv_on_gpu: bool = True, ) -> int: """Return the largest context length that fits in GPU VRAM. @@ -1197,11 +531,6 @@ def _fit_context_to_vram( threshold -- 10% reserved for compute buffers, CUDA context, scratch space, flash-attn workspace, etc.). If the model weights alone don't fit, returns min_ctx unchanged. - - ``kv_on_gpu`` mirrors ``--kv-offload`` (default on). When False - the KV cache lives in CPU RAM and doesn't compete with weights - for VRAM; the requested context is honored verbatim. The other - keyword args mirror ``_estimate_kv_cache_bytes``. """ if not self._can_estimate_kv(): logger.debug( @@ -1211,22 +540,11 @@ def _fit_context_to_vram( ) return requested_ctx - # KV lives off-GPU: no VRAM accounting needed for the cache itself. - if not kv_on_gpu: - return requested_ctx - - kv_kwargs = dict( - swa_full = swa_full, - n_parallel = n_parallel, - kv_unified = kv_unified, - ctx_checkpoints = ctx_checkpoints, - ) - budget_bytes = available_mib * 1024 * 1024 * 0.90 model_footprint = model_size_bytes # Check if requested context already fits - kv = self._estimate_kv_cache_bytes(requested_ctx, cache_type_kv, **kv_kwargs) + kv = self._estimate_kv_cache_bytes(requested_ctx, cache_type_kv) if model_footprint + kv <= budget_bytes: return requested_ctx @@ -1248,7 +566,7 @@ def _fit_context_to_vram( best = effective_min while lo <= hi: mid = (lo + hi) // 2 - kv = self._estimate_kv_cache_bytes(mid, cache_type_kv, **kv_kwargs) + kv = self._estimate_kv_cache_bytes(mid, cache_type_kv) if kv <= remaining: best = mid lo = mid + 1 @@ -1381,19 +699,6 @@ def _gguf_skip_value(f, vtype: int) -> None: for _ in range(alen): LlamaCppBackend._gguf_skip_value(f, atype) - @staticmethod - def _gguf_read_array_value(f, atype: int, alen: int) -> Optional[list]: - if atype == 4: # UINT32 - return [struct.unpack(" None: """Read context_length, architecture params, and chat_template from a GGUF header. @@ -1406,50 +711,26 @@ def _read_gguf_metadata(self, gguf_path: str) -> None: self._chat_template = None self._supports_reasoning = False self._reasoning_always_on = False - self._reasoning_style = "enable_thinking" - self._reasoning_default = True - self._supports_preserve_thinking = False self._supports_tools = False self._n_layers = None self._n_kv_heads = None - self._n_kv_heads_by_layer = None self._n_heads = None self._embedding_length = None self._kv_key_length = None self._kv_value_length = None self._sliding_window = None - self._sliding_window_pattern = None self._full_attention_interval = None self._kv_lora_rank = None self._key_length_mla = None - self._kv_key_length_swa = None - self._kv_value_length_swa = None self._ssm_inner_size = None self._ssm_state_size = None - self._shared_kv_layers = None try: - WANTED = { - "general.architecture", - "tokenizer.chat_template", - # Source-repo hints for the SWA resolver's HF fallback. - "general.source.huggingface.repository", - "general.source.url", - "general.source.repo_url", - "general.base_model.0.repo_url", - "general.base_model.0.organization", - "general.base_model.0.name", - "general.basename", - "general.organization", - "general.size_label", - "general.finetune", - } + WANTED = {"general.architecture", "tokenizer.chat_template"} # Additional arch-specific keys are added dynamically once # we know the architecture name. arch_keys: dict[str, str] = {} # gguf_key -> attribute name arch = None - sliding_window_pattern_period: Optional[int] = None - general: dict[str, str] = {} with open(gguf_path, "rb") as f: magic = struct.unpack(" None: _tensor_count, kv_count = struct.unpack(" None: f"GGUF metadata: chat_template={len(self._chat_template)} chars" ) # Detect thinking/reasoning support from chat template - flags = detect_reasoning_flags( - self._chat_template, - self._model_identifier, - log_source = "GGUF metadata", - ) - self._supports_reasoning = flags["supports_reasoning"] - self._reasoning_style = flags["reasoning_style"] - self._reasoning_always_on = flags["reasoning_always_on"] - self._supports_preserve_thinking = flags["supports_preserve_thinking"] - self._supports_tools = flags["supports_tools"] + tpl = self._chat_template + if "enable_thinking" in tpl: + self._supports_reasoning = True + logger.info( + "GGUF metadata: model supports reasoning (enable_thinking)" + ) + elif "thinking" in tpl: + # DeepSeek uses 'thinking' instead of 'enable_thinking' + normalized_id = (self._model_identifier or "").lower() + if "deepseek" in normalized_id: + self._supports_reasoning = True + logger.info( + "GGUF metadata: model supports reasoning (DeepSeek thinking)" + ) + # Models with hardcoded tags or reasoning_content + # in their chat template always produce thinking output + # (no toggle to disable it). + if not self._supports_reasoning: + if ( + "" in tpl + and "" in tpl + or "reasoning_content" in tpl + ): + self._supports_reasoning = True + self._reasoning_always_on = True + logger.info( + "GGUF metadata: model always reasons ( tags in template)" + ) + # Detect tool calling support from chat template + tool_markers = [ + "{%- if tools %}", + "{%- if tools -%}", + "{% if tools %}", + "{% if tools -%}", + '"role" == "tool"', + "'role' == 'tool'", + 'message.role == "tool"', + "message.role == 'tool'", + ] + if any(marker in tpl for marker in tool_markers): + self._supports_tools = True + logger.info("GGUF metadata: model supports tool calling") except Exception as e: logger.warning(f"Failed to read GGUF metadata: {e}") @@ -1683,34 +904,10 @@ def _download_gguf( try: import os - from huggingface_hub import get_paths_info, try_to_load_from_cache + from huggingface_hub import get_paths_info path_infos = list(get_paths_info(hf_repo, all_gguf_files, token = hf_token)) - total_bytes = sum((p.size or 0) for p in path_infos) - - # Subtract bytes already present in the HF cache so we only - # preflight against what we actually have to download. Without - # this, re-loading a cached large model (e.g. MiniMax-M2.7-GGUF - # at 131 GB) fails cold whenever free disk is below the full - # weight footprint, even though nothing needs downloading. - already_cached_bytes = 0 - for p in path_infos: - if not p.size: - continue - try: - cached_path = try_to_load_from_cache(hf_repo, p.path) - except Exception: - cached_path = None - if isinstance(cached_path, str) and os.path.exists(cached_path): - try: - on_disk = os.path.getsize(cached_path) - except OSError: - on_disk = 0 - # Count as satisfied only when the full blob is present. - if on_disk >= p.size: - already_cached_bytes += p.size - - total_download_bytes = max(0, total_bytes - already_cached_bytes) + total_download_bytes = sum((p.size or 0) for p in path_infos) if total_download_bytes > 0: cache_dir = os.environ.get( @@ -1722,11 +919,9 @@ def _download_gguf( total_gb = total_download_bytes / (1024**3) free_gb = free_bytes / (1024**3) - cached_gb = already_cached_bytes / (1024**3) logger.info( - f"GGUF download: {total_gb:.1f} GB needed " - f"({cached_gb:.1f} GB already cached), " + f"GGUF download: {total_gb:.1f} GB needed, " f"{free_gb:.1f} GB free on disk" ) @@ -1829,7 +1024,7 @@ def _download_mmproj( # Prefer F16 variant target = None for f in mmproj_files: - if f.lower().endswith("-f16.gguf"): + if "f16" in f.lower(): target = f break if target is None: @@ -1868,8 +1063,6 @@ def load_model( speculative_type: Optional[str] = None, n_threads: Optional[int] = None, n_gpu_layers: Optional[int] = None, # Accepted for caller compat, unused - n_parallel: int = 1, - extra_args: Optional[List[str]] = None, ) -> bool: """ Start llama-server with a GGUF model. @@ -1992,38 +1185,43 @@ def load_model( pool_mib, model_size, cache_type_kv, - n_parallel = n_parallel, - ) - kv = self._estimate_kv_cache_bytes( - capped, cache_type_kv, n_parallel = n_parallel ) + kv = self._estimate_kv_cache_bytes(capped, cache_type_kv) total_mib = (model_size + kv) / (1024 * 1024) if total_mib <= pool_mib * 0.90: best_cap = max(best_cap, capped) if best_cap > 0: max_available_ctx = best_cap - else: - # Weights exceed 90% of every GPU subset's free - # memory, so there is no fitting context. Anchor - # the UI's "safe zone" threshold at 4096 (the - # spec's default when the model cannot fit) so - # the ctx slider shows the "might be slower" - # warning as soon as the user drags above the - # fallback default instead of never. - max_available_ctx = min(4096, native_ctx_for_cap) if explicit_ctx: - # Honor the user's requested context verbatim. If it - # fits, pin GPUs and skip --fit; if it doesn't, ship - # -c --fit on and let llama-server flex - # -ngl (CPU layer offload). The UI is expected to - # have surfaced the "might be slower" warning before - # the user submitted a ctx above the fit ceiling. + # Try to honor the user's requested context exactly. requested_total = model_size + self._estimate_kv_cache_bytes( - effective_ctx, cache_type_kv, n_parallel = n_parallel + effective_ctx, cache_type_kv ) gpu_indices, use_fit = self._select_gpus(requested_total, gpus) - # No silent shrink: effective_ctx stays == n_ctx. + + # Full context doesn't fit anywhere -- cap it on the + # best GPU subset we can find (fewest GPUs first). + if use_fit: + ranked = sorted(gpus, key = lambda g: g[1], reverse = True) + for n_gpus in range(1, len(ranked) + 1): + subset = ranked[:n_gpus] + pool_mib = sum(free for _, free in subset) + capped = self._fit_context_to_vram( + effective_ctx, + pool_mib, + model_size, + cache_type_kv, + ) + kv = self._estimate_kv_cache_bytes( + capped, cache_type_kv + ) + total_mib = (model_size + kv) / (1024 * 1024) + if total_mib <= pool_mib * 0.90: + effective_ctx = capped + gpu_indices = sorted(idx for idx, _ in subset) + use_fit = False + break else: # Auto context: prefer fewer GPUs, cap context to fit. ranked = sorted(gpus, key = lambda g: g[1], reverse = True) @@ -2035,24 +1233,14 @@ def load_model( pool_mib, model_size, cache_type_kv, - n_parallel = n_parallel, - ) - kv = self._estimate_kv_cache_bytes( - capped, cache_type_kv, n_parallel = n_parallel ) + kv = self._estimate_kv_cache_bytes(capped, cache_type_kv) total_mib = (model_size + kv) / (1024 * 1024) if total_mib <= pool_mib * 0.90: effective_ctx = capped gpu_indices = sorted(idx for idx, _ in subset) use_fit = False break - else: - # No subset can host the weights (weights alone - # exceed 90% of every pool). Per spec, default - # the UI-visible context to 4096 and let - # --fit on flex -ngl so llama-server offloads - # layers to CPU RAM. - effective_ctx = min(4096, effective_ctx) elif gpus: # Can't estimate KV -- fall back to file-size-only check. @@ -2063,18 +1251,9 @@ def load_model( model_size_gb = round(model_size / (1024**3), 2), ) gpu_indices, use_fit = self._select_gpus(model_size, gpus) - if use_fit and not explicit_ctx: - # Weights don't fit on any subset. Default the UI to - # 4096 so the slider doesn't land on an unusable native - # context. --fit on will flex -ngl at runtime. - effective_ctx = ( - min(4096, effective_ctx) if effective_ctx > 0 else 4096 - ) if effective_ctx < original_ctx: - kv_est = self._estimate_kv_cache_bytes( - effective_ctx, cache_type_kv, n_parallel = n_parallel - ) + kv_est = self._estimate_kv_cache_bytes(effective_ctx, cache_type_kv) logger.info( f"Context auto-reduced: {original_ctx} -> {effective_ctx} " f"(model: {model_size / (1024**3):.1f} GB, " @@ -2082,7 +1261,7 @@ def load_model( ) kv_cache_bytes = self._estimate_kv_cache_bytes( - effective_ctx, cache_type_kv, n_parallel = n_parallel + effective_ctx, cache_type_kv ) logger.info( f"GGUF size: {model_size / (1024**3):.1f} GB, " @@ -2104,11 +1283,9 @@ def load_model( "-c", str(effective_ctx) if effective_ctx > 0 else "0", "--parallel", - str(n_parallel), + "1", # Single-user studio, saves VRAM "--flash-attn", "on", # Force flash attention for speed - # Error out at n_ctx instead of silently rotating the KV cache; frontend catches it and points the user at "Context Length". - "--no-context-shift", ] if use_fit: @@ -2117,10 +1294,8 @@ def load_model( # Model fits on selected GPU(s) -- offload all layers cmd.extend(["-ngl", "-1"]) - # -1 = llama.cpp auto-detect (physical cores). Pass explicitly so we - # do not inherit llama-server's internal default, which has historically - # varied (hardware concurrency incl. hyperthreads on some builds). - cmd.extend(["--threads", str(n_threads if n_threads is not None else -1)]) + if n_threads is not None: + cmd.extend(["--threads", str(n_threads)]) # Always enable Jinja chat template rendering for proper template support cmd.extend(["--jinja"]) @@ -2152,7 +1327,7 @@ def load_model( # existing text (code refactoring, summarization, reasoning). # For general chat with low repetition, overhead is ~5 ms. # - # Benchmarks from upstream llama.cpp speculative-decoding PRs: + # Benchmarks from llama.cpp PRs #18471, #19164: # Scenario | Without | With | Speedup # gpt-oss-120b code refactor | 181 t/s | 446 t/s | 2.5x # Qwen3-235B offloaded | 12 t/s | 21 t/s | 1.8x @@ -2165,21 +1340,11 @@ def load_model( # ref: https://github.com/ggml-org/llama.cpp/blob/master/docs/speculative.md # ref: https://github.com/ggml-org/llama.cpp/pull/19164 # ref: https://github.com/ggml-org/llama.cpp/pull/18471 - # ``"default"`` -> let llama-server pick a sensible spec - # config via ``--spec-default``. Explicit type names are - # passed through with the manual draft tuning we've shipped - # historically so power users keep their overrides. _valid_spec_types = {"ngram-simple", "ngram-mod"} - normalized_spec = ( - speculative_type.lower().strip() if speculative_type else None - ) - if normalized_spec and normalized_spec != "off" and not is_vision: - if normalized_spec == "default": - cmd.append("--spec-default") - self._speculative_type = "default" - elif normalized_spec in _valid_spec_types: - cmd.extend(["--spec-type", normalized_spec]) - if normalized_spec == "ngram-mod": + if speculative_type and speculative_type in _valid_spec_types: + if not is_vision: # spec decoding disabled for vision models + cmd.extend(["--spec-type", speculative_type]) + if speculative_type == "ngram-mod": cmd.extend( [ "--spec-ngram-size-n", @@ -2190,7 +1355,7 @@ def load_model( "64", ] ) - self._speculative_type = normalized_spec + self._speculative_type = speculative_type else: self._speculative_type = None else: @@ -2200,18 +1365,6 @@ def load_model( if chat_template_override: import tempfile - self._chat_template = chat_template_override - flags = detect_reasoning_flags( - self._chat_template, - self._model_identifier, - log_source = "GGUF chat template override", - ) - self._supports_reasoning = flags["supports_reasoning"] - self._reasoning_style = flags["reasoning_style"] - self._reasoning_always_on = flags["reasoning_always_on"] - self._supports_preserve_thinking = flags["supports_preserve_thinking"] - self._supports_tools = flags["supports_tools"] - self._chat_template_file = tempfile.NamedTemporaryFile( mode = "w", suffix = ".jinja", @@ -2226,25 +1379,25 @@ def load_model( ) # For reasoning models, set default thinking mode. - # Qwen3.5/3.6 models below 9B (0.8B, 2B, 4B) disable thinking by default. + # Qwen3.5 models below 9B (0.8B, 2B, 4B) disable thinking by default. # Only 9B and larger enable thinking. - # Always-on templates ignore the kwarg entirely, so skip. - if self._supports_reasoning and not self._reasoning_always_on: + if self._supports_reasoning: thinking_default = True mid = (model_identifier or "").lower() - if "qwen3.5" in mid or "qwen3.6" in mid: + if "qwen3.5" in mid: size_val = _extract_model_size_b(mid) if size_val is not None and size_val < 9: thinking_default = False self._reasoning_default = thinking_default - reasoning_kw = self._reasoning_kwargs(thinking_default) cmd.extend( [ "--chat-template-kwargs", - json.dumps(reasoning_kw), + json.dumps({"enable_thinking": thinking_default}), ] ) - logger.info(f"Reasoning model: {reasoning_kw} by default") + logger.info( + f"Reasoning model: enable_thinking={thinking_default} by default" + ) if mmproj_path: if not Path(mmproj_path).is_file(): @@ -2264,17 +1417,6 @@ def load_model( else: self._api_key = None - # User-supplied pass-through args go last so llama.cpp's - # last-wins flag parsing lets the user override Studio's - # auto-set tier-2 flags (e.g. --cache-type-k, --spec-type). - # The route layer has already validated this list against - # the managed-flag denylist via validate_extra_args(). - if extra_args: - cmd.extend(str(a) for a in extra_args) - logger.info( - f"Appending user extra args to llama-server: {list(extra_args)}" - ) - _log_cmd = list(cmd) if "--api-key" in _log_cmd: _ki = _log_cmd.index("--api-key") + 1 @@ -2286,7 +1428,7 @@ def load_model( import os import sys - env = child_env_without_native_path_secret() + env = os.environ.copy() binary_dir = str(Path(binary).parent) if sys.platform == "win32": @@ -2370,29 +1512,9 @@ def load_model( f"{new_ld}:{existing_ld}" if existing_ld else new_ld ) - # Pin to selected GPU(s). On ROCm, llama-server (and any torch - # in the subprocess) honors HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES; - # narrowing only CUDA_VISIBLE_DEVICES leaves an AMD child seeing - # the full HIP/ROCR set the parent inherited. + # Pin to selected GPU(s) via CUDA_VISIBLE_DEVICES if gpu_indices is not None: - pinned = ",".join(str(i) for i in gpu_indices) - env["CUDA_VISIBLE_DEVICES"] = pinned - try: - import torch as _torch - - if getattr(_torch.version, "hip", None) is not None: - env["HIP_VISIBLE_DEVICES"] = pinned - env["ROCR_VISIBLE_DEVICES"] = pinned - except Exception as e: - logger.debug( - "Failed to set ROCm visibility env vars for child: %s", e - ) - - # Defensive kill: if a concurrent load slipped past Phase 1 - # (because its `self._process` was None at the time) and - # already stored a Popen handle here, drop that orphan - # before we overwrite the reference. See issue #5161. - self._kill_process() + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_indices) self._stdout_lines = [] self._process = subprocess.Popen( @@ -2401,7 +1523,6 @@ def load_model( stderr = subprocess.STDOUT, text = True, env = env, - **_windows_hidden_subprocess_kwargs(), ) # Start background thread to drain stdout and prevent pipe deadlock @@ -2410,12 +1531,7 @@ def load_model( ) self._stdout_thread.start() - # Store the resolved on-disk path, not the caller's kwarg. In - # HF mode the caller passes gguf_path=None and the real path - # (``model_path``) is what llama-server is actually mmap'ing. - # Downstream consumers (load_progress, log lines, etc.) need - # the path that exists on disk. - self._gguf_path = model_path + self._gguf_path = gguf_path self._hf_repo = hf_repo # For local GGUF files, extract variant from filename if not provided if hf_variant: @@ -2447,28 +1563,6 @@ def load_model( # Wait for llama-server to become healthy if not self._wait_for_health(timeout = 600.0): self._kill_process() - _gguf = gguf_path or "" - _is_ollama = ( - ".studio_links" in _gguf - or os.sep + "ollama_links" + os.sep in _gguf - or os.sep + ".cache" + os.sep + "ollama" + os.sep in _gguf - or (self._model_identifier or "").startswith("ollama/") - ) - # Only show the Ollama-specific message when the server - # output indicates a GGUF compatibility issue, not for - # unrelated failures like OOM or missing binaries. - if _is_ollama: - _output = "\n".join(self._stdout_lines[-50:]).lower() - _gguf_compat_hints = ( - "key not found", - "unknown model architecture", - "failed to load model", - ) - if any(h in _output for h in _gguf_compat_hints): - raise RuntimeError( - "Some Ollama models do not work with llama.cpp. " - "Try a different model, or use this model directly through Ollama instead." - ) raise RuntimeError( "llama-server failed to start. " "Check that the GGUF file is valid and you have enough memory." @@ -2503,29 +1597,21 @@ def unload_model(self) -> bool: self._chat_template = None self._supports_reasoning = False self._reasoning_always_on = False - self._reasoning_style = "enable_thinking" - self._reasoning_default = True - self._supports_preserve_thinking = False self._supports_tools = False self._cache_type_kv = None self._speculative_type = None self._n_layers = None self._n_kv_heads = None - self._n_kv_heads_by_layer = None self._n_heads = None self._embedding_length = None self._kv_key_length = None self._kv_value_length = None self._sliding_window = None - self._sliding_window_pattern = None self._full_attention_interval = None self._kv_lora_rank = None self._key_length_mla = None - self._kv_key_length_swa = None - self._kv_value_length_swa = None self._ssm_inner_size = None self._ssm_state_size = None - self._shared_kv_layers = None # Clean up temp chat template file if hasattr(self, "_chat_template_file") and self._chat_template_file: try: @@ -2681,7 +1767,6 @@ def _kill_orphaned_servers(): capture_output = True, text = True, timeout = 5, - env = child_env_without_native_path_secret(), ) if result.returncode != 0: return @@ -3062,8 +2147,6 @@ def generate_chat_completion( stop: Optional[list[str]] = None, cancel_event: Optional[threading.Event] = None, enable_thinking: Optional[bool] = None, - reasoning_effort: Optional[str] = None, - preserve_thinking: Optional[bool] = None, ) -> Generator[str | dict, None, None]: """ Send a chat completion request to llama-server and stream tokens back. @@ -3088,21 +2171,11 @@ def generate_chat_completion( "repeat_penalty": repetition_penalty, "presence_penalty": presence_penalty, } - # Pass enable_thinking / reasoning_effort / preserve_thinking per-request - _reasoning_kw = self._request_reasoning_kwargs( - enable_thinking, reasoning_effort, preserve_thinking - ) - if _reasoning_kw is not None: - payload["chat_template_kwargs"] = _reasoning_kw - # Default cap to the model's effective context length when known, - # otherwise the conservative floor. The wall-clock backstop below - # keeps a stuck model from running indefinitely either way. - payload["max_tokens"] = ( - max_tokens - if max_tokens is not None - else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) - ) - payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS + # Pass enable_thinking per-request for reasoning models + if self._supports_reasoning and enable_thinking is not None: + payload["chat_template_kwargs"] = {"enable_thinking": enable_thinking} + if max_tokens is not None: + payload["max_tokens"] = max_tokens if stop: payload["stop"] = stop payload["stream_options"] = {"include_usage": True} @@ -3122,9 +2195,7 @@ def generate_chat_completion( _auth_headers = ( {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None ) - with httpx.Client( - timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0) - ) as client: + with httpx.Client(timeout = stream_timeout) as client: with self._stream_with_retry( client, url, @@ -3238,8 +2309,6 @@ def generate_chat_completion_with_tools( stop: Optional[list[str]] = None, cancel_event: Optional[threading.Event] = None, enable_thinking: Optional[bool] = None, - reasoning_effort: Optional[str] = None, - preserve_thinking: Optional[bool] = None, max_tool_iterations: int = 25, auto_heal_tool_calls: bool = True, tool_call_timeout: int = 300, @@ -3318,17 +2387,10 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: "tools": tools, "tool_choice": "auto", } - _reasoning_kw = self._request_reasoning_kwargs( - enable_thinking, reasoning_effort, preserve_thinking - ) - if _reasoning_kw is not None: - payload["chat_template_kwargs"] = _reasoning_kw - payload["max_tokens"] = ( - max_tokens - if max_tokens is not None - else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) - ) - payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS + if self._supports_reasoning and enable_thinking is not None: + payload["chat_template_kwargs"] = {"enable_thinking": enable_thinking} + if max_tokens is not None: + payload["max_tokens"] = max_tokens if stop: payload["stop"] = stop @@ -3367,10 +2429,7 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: write = 10, pool = 10, ) - with httpx.Client( - timeout = stream_timeout, - limits = httpx.Limits(max_keepalive_connections = 0), - ) as client: + with httpx.Client(timeout = stream_timeout) as client: with self._stream_with_retry( client, url, @@ -3978,17 +3037,12 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: "repeat_penalty": repetition_penalty, "presence_penalty": presence_penalty, } - _reasoning_kw = self._request_reasoning_kwargs( - enable_thinking, reasoning_effort, preserve_thinking - ) - if _reasoning_kw is not None: - stream_payload["chat_template_kwargs"] = _reasoning_kw - stream_payload["max_tokens"] = ( - max_tokens - if max_tokens is not None - else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) - ) - stream_payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS + if self._supports_reasoning and enable_thinking is not None: + stream_payload["chat_template_kwargs"] = { + "enable_thinking": enable_thinking + } + if max_tokens is not None: + stream_payload["max_tokens"] = max_tokens if stop: stream_payload["stop"] = stop stream_payload["stream_options"] = {"include_usage": True} @@ -4007,9 +3061,7 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: _auth_headers = ( {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None ) - with httpx.Client( - timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0) - ) as client: + with httpx.Client(timeout = stream_timeout) as client: with self._stream_with_retry( client, url, diff --git a/studio/backend/core/inference/llama_server_args.py b/studio/backend/core/inference/llama_server_args.py deleted file mode 100644 index 44c7d542c7..0000000000 --- a/studio/backend/core/inference/llama_server_args.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Validator for user-supplied llama-server pass-through args. - -Studio runs llama-server as a managed subprocess and lets callers pass -extra flags directly (CLI: ``unsloth run ... --top-k 20``; HTTP: -``LoadRequest.llama_extra_args``). This module is the boundary that -rejects only flags Studio fundamentally cannot share with the user -- -model identity, the auth key, and the network endpoint Studio's HTTP -proxy targets. Anything else passes through. - -User-supplied args are appended to ``cmd`` after Studio's auto-set -flags, so llama.cpp's last-wins CLI parsing makes the user's value -override the auto-set one. That covers tunable knobs the user might -reasonably want to override -- ``-c``/``--ctx-size``, -``-np``/``--parallel``, ``-fa``/``--flash-attn``, -``-ngl``/``--gpu-layers``, ``-t``/``--threads``, ``-fit``/``--fit*``, -``--cache-type-k/v``, ``--chat-template-file/-kwargs``, -``--spec-*``, ``--jinja``/``--no-jinja``, -``--no-context-shift``/``--context-shift``, sampling params, etc. - -Reference: https://github.com/ggml-org/llama.cpp/blob/master/tools/server/README.md -""" - -from __future__ import annotations - -from typing import Iterable, Optional - -# Each group is the full set of aliases (short + long) for one -# hard-denied flag, taken from the llama-server README. If llama.cpp -# adds a new alias for an existing denied flag, extend the relevant -# group. -# -# Flags NOT in this list (e.g. -c, --parallel, --flash-attn, -ngl, -# -t/--threads, --jinja, --no-context-shift, --fit*, --cache-type-*, -# --chat-template-*, --spec-*) pass through and override Studio's -# auto-set version via llama.cpp's last-wins CLI parsing. -_DENYLIST_GROUPS: tuple[frozenset[str], ...] = ( - # Model identity -- Studio resolves the model from LoadRequest and - # passes -m / mmproj after downloading from HF if needed. A second - # -m would point at a different model than the one Studio thinks - # is loaded. - frozenset({"-m", "--model"}), - frozenset({"-mu", "--model-url"}), - frozenset({"-dr", "--docker-repo"}), - frozenset({"-hf", "-hfr", "--hf-repo"}), - frozenset({"-hff", "--hf-file"}), - frozenset({"-hfv", "-hfrv", "--hf-repo-v"}), - frozenset({"-hffv", "--hf-file-v"}), - frozenset({"-hft", "--hf-token"}), - frozenset({"-mm", "--mmproj"}), - frozenset({"-mmu", "--mmproj-url"}), - # Networking -- Studio binds llama-server's port and reverse-proxies - # HTTP traffic to it. Retargeting host/port/path/prefix would - # orphan Studio's proxy and the UI would lose the server. - frozenset({"--host"}), - frozenset({"--port"}), - frozenset({"--path"}), - frozenset({"--api-prefix"}), - frozenset({"--reuse-port"}), - # Auth / TLS -- Studio terminates auth at its own layer; an - # upstream --api-key would shadow Studio's UNSLOTH_DIRECT_STREAM - # key, and TLS on llama-server would break the local proxy hop. - frozenset({"--api-key"}), - frozenset({"--api-key-file"}), - frozenset({"--ssl-key-file"}), - frozenset({"--ssl-cert-file"}), - # Single-model server -- Studio runs one model per llama-server - # process and serves its own UI. Enabling multi-model loading or - # llama-server's built-in web UI changes the surface clients see. - frozenset({"--webui", "--no-webui"}), - frozenset({"--models-dir"}), - frozenset({"--models-preset"}), - frozenset({"--models-max"}), - frozenset({"--models-autoload", "--no-models-autoload"}), -) - -_DENYLIST: frozenset[str] = frozenset().union(*_DENYLIST_GROUPS) - - -def _flag_name(token: str) -> Optional[str]: - """Return the flag name for a token, or None if it isn't a flag. - - Peels ``--key=value`` to the bare ``--key``. Plain numeric values - like ``-1`` or ``-0.5`` (e.g. ``--seed -1``) are values, not flags; - llama-server short-form flags always start with a letter. - """ - if not token.startswith("-") or token in {"-", "--"}: - return None - if len(token) >= 2 and (token[1].isdigit() or token[1] == "."): - return None - return token.split("=", 1)[0] - - -def validate_extra_args(args: Optional[Iterable[str]]) -> list[str]: - """Validate user-supplied llama-server args. - - Returns the args as a flat list ready to extend the llama-server - command. Raises ``ValueError`` (with the offending flag in the - message) the moment a token resolves to a Studio-managed flag. - """ - if not args: - return [] - out: list[str] = [] - for raw in args: - token = str(raw) - flag = _flag_name(token) - if flag is not None and flag in _DENYLIST: - raise ValueError( - f"llama-server flag '{flag}' is managed by Unsloth Studio " - f"and cannot be passed as an extra arg" - ) - out.append(token) - return out - - -def is_managed_flag(flag: str) -> bool: - """True if ``flag`` is a Studio-managed llama-server flag.""" - return flag in _DENYLIST diff --git a/studio/backend/core/inference/orchestrator.py b/studio/backend/core/inference/orchestrator.py index 5562820f49..cb5d9da34a 100644 --- a/studio/backend/core/inference/orchestrator.py +++ b/studio/backend/core/inference/orchestrator.py @@ -166,30 +166,23 @@ def _fetch_top_models(self) -> None: def _spawn_subprocess(self, config: dict) -> None: """Spawn a new inference subprocess.""" - from utils.native_path_leases import ( - native_path_secret_removed_for_child_start, - run_without_native_path_secret, - ) - from .worker import run_inference_process - with native_path_secret_removed_for_child_start(): - self._cmd_queue = _CTX.Queue() - self._resp_queue = _CTX.Queue() - self._cancel_event = _CTX.Event() - - self._proc = _CTX.Process( - target = run_without_native_path_secret, - args = (run_inference_process,), - kwargs = { - "cmd_queue": self._cmd_queue, - "resp_queue": self._resp_queue, - "cancel_event": self._cancel_event, - "config": config, - }, - daemon = True, - ) - self._proc.start() + self._cmd_queue = _CTX.Queue() + self._resp_queue = _CTX.Queue() + self._cancel_event = _CTX.Event() + + self._proc = _CTX.Process( + target = run_inference_process, + kwargs = { + "cmd_queue": self._cmd_queue, + "resp_queue": self._resp_queue, + "cancel_event": self._cancel_event, + "config": config, + }, + daemon = True, + ) + self._proc.start() logger.info("Inference subprocess started (pid=%s)", self._proc.pid) def _cancel_generation(self) -> None: @@ -715,17 +708,6 @@ def load_model( def unload_model(self, model_name: str) -> bool: """Unload a model from the subprocess.""" - if model_name in self.loading_models: - logger.info( - "Cancelling in-flight load for model '%s' by terminating subprocess", - model_name, - ) - self._shutdown_subprocess(timeout = 0.5) - self.loading_models.discard(model_name) - self.active_model_name = None - self.models.clear() - return True - if not self._ensure_subprocess_alive(): # No subprocess — just clear local state self.models.pop(model_name, None) diff --git a/studio/backend/core/training/resume.py b/studio/backend/core/training/resume.py deleted file mode 100644 index 165c1c2cf1..0000000000 --- a/studio/backend/core/training/resume.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Helpers for validating resumable training outputs.""" - -from pathlib import Path -from typing import Optional - -from utils.paths import outputs_root, resolve_output_dir - - -def _is_under_outputs(path: Path) -> bool: - resolved = path.resolve(strict = False) - root = outputs_root().resolve(strict = False) - try: - resolved.relative_to(root) - return True - except ValueError: - return False - - -def has_resume_state(path_value: Optional[str]) -> bool: - if not path_value: - return False - return get_resume_checkpoint_path(path_value) is not None - - -def _checkpoint_step(path: Path) -> int: - try: - return int(path.name.removeprefix("checkpoint-")) - except ValueError: - return -1 - - -def get_resume_checkpoint_path(path_value: str) -> Optional[str]: - path = resolve_output_dir(path_value) - if not _is_under_outputs(path) or not path.is_dir(): - return None - if (path / "trainer_state.json").is_file(): - return str(path) - - checkpoints = [ - child - for child in path.glob("checkpoint-*") - if child.is_dir() and (child / "trainer_state.json").is_file() - ] - if not checkpoints: - return None - return str(max(checkpoints, key = _checkpoint_step)) - - -def normalize_resume_output_dir(path_value: str) -> str: - path = resolve_output_dir(path_value) - if not _is_under_outputs(path): - raise ValueError("Resume checkpoint must be inside Studio outputs.") - return str(path) - - -def can_resume_run(run: dict) -> bool: - if run.get("resumed_later"): - return False - - final_step = run.get("final_step") - total_steps = run.get("total_steps") - has_remaining_steps = ( - not isinstance(final_step, int) - or not isinstance(total_steps, int) - or total_steps <= 0 - or final_step < total_steps - ) - return ( - run.get("status") == "stopped" - and has_remaining_steps - and has_resume_state(run.get("output_dir")) - ) diff --git a/studio/backend/core/training/trainer.py b/studio/backend/core/training/trainer.py index fe8d277ac0..77cbda6b45 100644 --- a/studio/backend/core/training/trainer.py +++ b/studio/backend/core/training/trainer.py @@ -49,7 +49,6 @@ import json import threading import math -import subprocess import structlog from loggers import get_logger import time @@ -70,11 +69,6 @@ ) from trl import SFTTrainer, SFTConfig -from utils.native_path_leases import child_env_without_native_path_secret -from utils.subprocess_compat import ( - windows_hidden_subprocess_kwargs as _windows_hidden_subprocess_kwargs, -) - logger = get_logger(__name__) @@ -377,7 +371,6 @@ def _build_audio_training_args(self, training_args, output_dir, *, extra_args = def _finalize_training(self, output_dir, label = ""): """Save model after training and update progress. Used by all training branches.""" if self.should_stop and self.save_on_stop: - self.trainer._save_checkpoint(self.trainer.model, trial = None) self.trainer.save_model() self.tokenizer.save_pretrained(output_dir) self._patch_adapter_config(output_dir) @@ -1772,8 +1765,6 @@ def _preprocess_bicodec_dataset(self, dataset, custom_format_mapping = None): spark_code_dir, ], check = True, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) if spark_code_dir not in sys.path: @@ -1991,6 +1982,8 @@ def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None): device = "cuda" if torch.cuda.is_available() else "cpu" # Clone OuteTTS repo (same as audio_codecs._load_dac) + import subprocess + base_dir = os.path.dirname(os.path.abspath(__file__)) outetts_code_dir = os.path.join(base_dir, "inference", "OuteTTS") outetts_pkg = os.path.join(outetts_code_dir, "outetts") @@ -2007,8 +2000,6 @@ def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None): outetts_code_dir, ], check = True, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) for fpath in [ os.path.join(outetts_pkg, "models", "gguf_model.py"), @@ -2832,9 +2823,7 @@ def _train_worker(self, dataset: Dataset, **training_args): total_steps = total, status_message = "Starting CSM training..." ) logger.info(f"CSM training config: {config}\n") - self.trainer.train( - resume_from_checkpoint = training_args.get("resume_from_checkpoint") - ) + self.trainer.train() self._finalize_training(output_dir, "CSM") return @@ -2873,9 +2862,7 @@ def _train_worker(self, dataset: Dataset, **training_args): total_steps = total, status_message = "Starting SNAC training..." ) logger.info(f"SNAC training config: {config}\n") - self.trainer.train( - resume_from_checkpoint = training_args.get("resume_from_checkpoint") - ) + self.trainer.train() self._finalize_training(output_dir, "SNAC") return @@ -2921,9 +2908,7 @@ def _train_worker(self, dataset: Dataset, **training_args): total_steps = total, status_message = "Starting Whisper training..." ) logger.info(f"Whisper training config: {config}\n") - self.trainer.train( - resume_from_checkpoint = training_args.get("resume_from_checkpoint") - ) + self.trainer.train() self._finalize_training(output_dir, "Whisper") return @@ -3418,9 +3403,7 @@ def audio_vlm_collate_fn(examples): # ========== START TRAINING ========== self._update_progress(status_message = "Starting training...") logger.info("Starting training...\n") - self.trainer.train( - resume_from_checkpoint = training_args.get("resume_from_checkpoint") - ) + self.trainer.train() # ========== SAVE MODEL ========== self._finalize_training(output_dir) diff --git a/studio/backend/core/training/training.py b/studio/backend/core/training/training.py index 5642faa189..f35c7e8ad3 100644 --- a/studio/backend/core/training/training.py +++ b/studio/backend/core/training/training.py @@ -29,10 +29,6 @@ import matplotlib.pyplot as plt from utils.hardware import prepare_gpu_selection -from utils.native_path_leases import ( - native_path_secret_removed_for_child_start, - run_without_native_path_secret, -) logger = get_logger(__name__) @@ -189,7 +185,6 @@ def start_training(self, job_id: str, **kwargs) -> bool: "wandb_project": kwargs.get("wandb_project", "unsloth-training"), "enable_tensorboard": kwargs.get("enable_tensorboard", False), "tensorboard_dir": kwargs.get("tensorboard_dir", "runs"), - "resume_from_checkpoint": kwargs.get("resume_from_checkpoint"), "trust_remote_code": kwargs.get("trust_remote_code", False), "gpu_ids": kwargs.get("gpu_ids"), } @@ -217,22 +212,20 @@ def start_training(self, job_id: str, **kwargs) -> bool: from .worker import run_training_process + event_queue = _CTX.Queue() + stop_queue = _CTX.Queue() + + proc = _CTX.Process( + target = run_training_process, + kwargs = { + "event_queue": event_queue, + "stop_queue": stop_queue, + "config": config, + }, + daemon = True, + ) try: - with native_path_secret_removed_for_child_start(): - event_queue = _CTX.Queue() - stop_queue = _CTX.Queue() - - proc = _CTX.Process( - target = run_without_native_path_secret, - args = (run_training_process,), - kwargs = { - "event_queue": event_queue, - "stop_queue": stop_queue, - "config": config, - }, - daemon = True, - ) - proc.start() + proc.start() except Exception: logger.error("Failed to start training subprocess", exc_info = True) return False diff --git a/studio/backend/core/training/worker.py b/studio/backend/core/training/worker.py index 60b9e994ab..a461972eca 100644 --- a/studio/backend/core/training/worker.py +++ b/studio/backend/core/training/worker.py @@ -16,40 +16,26 @@ import structlog from loggers import get_logger import os +import platform import shutil import sys import time import traceback +import json import subprocess as _sp from pathlib import Path -from typing import Any, Callable +from typing import Any +import urllib.error +import urllib.request logger = get_logger(__name__) from utils.hardware import apply_gpu_ids -from utils.wheel_utils import ( - direct_wheel_url, - flash_attn_wheel_url, - install_wheel, - probe_torch_wheel_env, - url_exists, -) - - -def _output_dir_from_resume_checkpoint( - resume_from_checkpoint: str | None, -) -> str | None: - if not resume_from_checkpoint: - return None - path = Path(resume_from_checkpoint) - return str(path.parent if path.name.startswith("checkpoint-") else path) _CAUSAL_CONV1D_RELEASE_TAG = "v1.6.1.post4" _CAUSAL_CONV1D_PACKAGE_VERSION = "1.6.1" _MAMBA_SSM_RELEASE_TAG = "v2.3.1" _MAMBA_SSM_PACKAGE_VERSION = "2.3.1" -_FLASH_ATTN_RUNTIME_MIN_SEQ_LEN = 32768 -_FLASH_ATTN_SKIP_ENV = "UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL" def _model_wants_causal_conv1d(model_name: str) -> bool: @@ -59,8 +45,6 @@ def _model_wants_causal_conv1d(model_name: str) -> bool: for key in ( "qwen3.5", "qwen3_5", - "qwen3.6", - "qwen3_6", "qwen3-next", "qwen3_next", "nemotron_h", @@ -75,186 +59,206 @@ def _model_wants_causal_conv1d(model_name: str) -> bool: ) +def _causal_conv1d_platform_tag() -> str | None: + machine = platform.machine().lower() + if sys.platform.startswith("linux"): + if machine in {"x86_64", "amd64"}: + return "linux_x86_64" + if machine in {"aarch64", "arm64"}: + return "linux_aarch64" + return None + # No prebuilt wheels published for macOS or Windows + return None + + +def _probe_causal_conv1d_env() -> dict[str, str] | None: + try: + probe = _sp.run( + [ + sys.executable, + "-c", + ( + "import json, sys, re, torch; " + "parts = torch.__version__.split('+', 1)[0].split('.')[:2]; " + "minor = re.sub(r'[^0-9].*', '', parts[1]) if len(parts) > 1 else '0'; " + "torch_mm = parts[0] + '.' + minor; " + "print(json.dumps({" + "'python_tag': f'cp{sys.version_info.major}{sys.version_info.minor}', " + "'torch_mm': torch_mm, " + "'cuda_major': str(int(str(torch.version.cuda).split('.', 1)[0])) if torch.version.cuda else '', " + "'cxx11abi': str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()" + "}))" + ), + ], + stdout = _sp.PIPE, + stderr = _sp.PIPE, + text = True, + timeout = 30, + ) + except _sp.TimeoutExpired: + logger.warning("Torch environment probe timed out after 30s") + return None + if probe.returncode != 0: + logger.warning( + "Failed to probe torch environment for causal-conv1d wheel:\n%s", + probe.stdout, + ) + return None + + try: + return json.loads(probe.stdout.strip()) + except json.JSONDecodeError: + logger.warning( + "Failed to parse torch environment probe output: %s", probe.stdout + ) + return None + + +def _direct_wheel_url( + *, + filename_prefix: str, + package_version: str, + release_tag: str, + release_base_url: str, + env: dict[str, str] | None = None, +) -> str | None: + env = env or _probe_causal_conv1d_env() + platform_tag = _causal_conv1d_platform_tag() + if env is None or platform_tag is None or not env.get("cuda_major"): + return None + + filename = ( + f"{filename_prefix}-{package_version}" + f"+cu{env['cuda_major']}torch{env['torch_mm']}" + f"cxx11abi{env['cxx11abi']}-{env['python_tag']}-{env['python_tag']}-{platform_tag}.whl" + ) + return f"{release_base_url}/{release_tag}/{filename}" + + +def _url_exists(url: str) -> bool: + try: + request = urllib.request.Request(url, method = "HEAD") + with urllib.request.urlopen(request, timeout = 10): + return True + except urllib.error.HTTPError as exc: + if exc.code == 404: + return False + logger.warning("Unexpected HTTP error while probing %s: %s", url, exc) + return False + except Exception as exc: + logger.warning("Failed to probe %s: %s", url, exc) + return False + + def _install_package_wheel_first( *, event_queue: Any, import_name: str, display_name: str, pypi_name: str, - pypi_version: str | None = None, - filename_prefix: str | None = None, - release_tag: str | None = None, - release_base_url: str | None = None, - wheel_url_builder: Callable[[dict[str, str] | None], str | None] | None = None, - pypi_spec: str | None = None, - pypi_status_message: str | None = None, -) -> bool: + pypi_version: str, + filename_prefix: str, + release_tag: str, + release_base_url: str, +) -> None: try: __import__(import_name) logger.info("%s already installed", display_name) - return True + return except ImportError: pass - env = probe_torch_wheel_env(timeout = 30) - if wheel_url_builder is not None: - wheel_url = wheel_url_builder(env) - else: - wheel_url = direct_wheel_url( - filename_prefix = filename_prefix, - package_version = pypi_version, - release_tag = release_tag, - release_base_url = release_base_url, - env = env, - ) + env = _probe_causal_conv1d_env() + wheel_url = _direct_wheel_url( + filename_prefix = filename_prefix, + package_version = pypi_version, + release_tag = release_tag, + release_base_url = release_base_url, + env = env, + ) if wheel_url is None: logger.info("No compatible %s wheel candidate", display_name) - elif url_exists(wheel_url): - _send_status(event_queue, f"Installing prebuilt {display_name} wheel...") - for installer, result in install_wheel( - wheel_url, - python_executable = sys.executable, - use_uv = bool(shutil.which("uv")), - run = _sp.run, - ): - if result.returncode == 0: - logger.info("Installed prebuilt %s wheel successfully", display_name) - return True - logger.warning( - "%s failed to install %s wheel:\n%s", - installer, - display_name, - result.stdout, - ) - else: - logger.info("No published %s wheel found: %s", display_name, wheel_url) - - is_hip = env and env.get("hip_version") - if is_hip and not shutil.which("hipcc"): - logger.error( - "%s requires hipcc for source compilation on ROCm. " - "Install the ROCm HIP SDK: https://rocm.docs.amd.com", - display_name, - ) - _send_status( - event_queue, - f"{display_name}: hipcc not found (ROCm HIP SDK required)", - ) - return False - - if pypi_spec is None: - pypi_spec = f"{pypi_name}=={pypi_version}" - - if pypi_status_message is None: - if is_hip: - pypi_status_message = ( - f"Compiling {display_name} from source for ROCm " - "(this may take several minutes)..." - ) - else: - pypi_status_message = f"Installing {display_name} from PyPI..." - - _send_status(event_queue, pypi_status_message) - - # Prefer uv for faster dependency resolution when available - plain_pypi_install = pypi_version is None - if plain_pypi_install: - if shutil.which("uv"): - pypi_cmd = [ - "uv", - "pip", - "install", - "--python", - sys.executable, - pypi_spec, - ] - else: - pypi_cmd = [sys.executable, "-m", "pip", "install", pypi_spec] else: - if shutil.which("uv"): - pypi_cmd = [ - "uv", - "pip", - "install", - "--python", - sys.executable, - "--no-build-isolation", - "--no-deps", - ] - # Avoid stale cache artifacts from partial HIP source builds - if is_hip: - pypi_cmd.append("--no-cache") - pypi_cmd.append(pypi_spec) + if _url_exists(wheel_url): + _send_status(event_queue, f"Installing prebuilt {display_name} wheel...") + installed = False + # Try uv first if available, then fall back to pip + if shutil.which("uv"): + uv_cmd = [ + "uv", + "pip", + "install", + "--python", + sys.executable, + "--no-deps", + wheel_url, + ] + result = _sp.run( + uv_cmd, + stdout = _sp.PIPE, + stderr = _sp.STDOUT, + text = True, + ) + if result.returncode == 0: + installed = True + else: + logger.warning( + "uv failed to install %s wheel:\n%s", + display_name, + result.stdout, + ) + if not installed: + pip_cmd = [ + sys.executable, + "-m", + "pip", + "install", + "--no-deps", + wheel_url, + ] + result = _sp.run( + pip_cmd, + stdout = _sp.PIPE, + stderr = _sp.STDOUT, + text = True, + ) + if result.returncode == 0: + installed = True + else: + logger.warning( + "pip failed to install %s wheel:\n%s", + display_name, + result.stdout, + ) + if installed: + logger.info("Installed prebuilt %s wheel successfully", display_name) + return else: - pypi_cmd = [ - sys.executable, - "-m", - "pip", - "install", - "--no-build-isolation", - "--no-deps", - "--no-cache-dir", - pypi_spec, - ] - - # Source compilation on ROCm can take 10-30 minutes; use a generous - # timeout. Non-HIP installs preserve the pre-existing "no timeout" - # behaviour so unrelated slow installs (e.g. causal-conv1d source - # build on Linux aarch64 or unsupported torch/CUDA combinations) - # are not aborted at 5 minutes by this PR. - _run_kwargs: dict[str, Any] = { - "stdout": _sp.PIPE, - "stderr": _sp.STDOUT, - "text": True, - } - if is_hip: - _run_kwargs["timeout"] = 1800 - - try: - result = _sp.run(pypi_cmd, **_run_kwargs) - except _sp.TimeoutExpired: - logger.error( - "%s installation timed out after %ds", - display_name, - _run_kwargs.get("timeout"), - ) - _send_status( - event_queue, - f"{display_name} installation timed out after " - f"{_run_kwargs.get('timeout')}s", - ) - return False - + logger.info("No published %s wheel found: %s", display_name, wheel_url) + + _send_status(event_queue, f"Installing {display_name} from PyPI...") + pypi_cmd = [ + sys.executable, + "-m", + "pip", + "install", + "--no-build-isolation", + "--no-deps", + "--no-cache-dir", + f"{pypi_name}=={pypi_version}", + ] + result = _sp.run( + pypi_cmd, + stdout = _sp.PIPE, + stderr = _sp.STDOUT, + text = True, + ) if result.returncode != 0: - if is_hip: - # Surface a clear error for ROCm source build failures - error_lines = (result.stdout or "").strip().splitlines() - snippet = "\n".join(error_lines[-5:]) if error_lines else "(no output)" - logger.error( - "Failed to compile %s for ROCm:\n%s", - display_name, - result.stdout, - ) - _send_status( - event_queue, - f"Failed to compile {display_name} for ROCm. " - "Check that hipcc and ROCm development headers are installed.\n" - f"{snippet}", - ) - else: - logger.error( - "Failed to install %s from PyPI:\n%s", - display_name, - result.stdout, - ) - return False + logger.error("Failed to install %s from PyPI:\n%s", display_name, result.stdout) + return - if is_hip: - logger.info("Compiled and installed %s from source for ROCm", display_name) - else: - logger.info("Installed %s from PyPI", display_name) - return True + logger.info("Installed %s from PyPI", display_name) def _ensure_causal_conv1d_fast_path(event_queue: Any, model_name: str) -> None: @@ -301,31 +305,6 @@ def _ensure_mamba_ssm(event_queue: Any, model_name: str) -> None: ) -def _should_try_runtime_flash_attn_install(max_seq_length: int) -> bool: - if os.getenv(_FLASH_ATTN_SKIP_ENV) == "1": - return False - if max_seq_length < _FLASH_ATTN_RUNTIME_MIN_SEQ_LEN: - return False - return sys.platform.startswith("linux") - - -def _ensure_flash_attn_for_long_context(event_queue: Any, max_seq_length: int) -> None: - if not _should_try_runtime_flash_attn_install(max_seq_length): - return - - installed = _install_package_wheel_first( - event_queue = event_queue, - import_name = "flash_attn", - display_name = "flash-attn", - pypi_name = "flash-attn", - wheel_url_builder = flash_attn_wheel_url, - pypi_spec = "flash-attn", - pypi_status_message = "Installing flash-attn from PyPI for long-context training...", - ) - if not installed: - _send_status(event_queue, "Continuing without flash-attn") - - def _activate_transformers_version(model_name: str) -> None: """Activate the correct transformers version BEFORE any ML imports.""" # Ensure backend is on path for utils imports @@ -408,10 +387,6 @@ def run_training_process( try: _ensure_causal_conv1d_fast_path(event_queue, model_name) _ensure_mamba_ssm(event_queue, model_name) - _ensure_flash_attn_for_long_context( - event_queue, - int(config.get("max_seq_length", 2048)), - ) except Exception as exc: event_queue.put( { @@ -766,10 +741,7 @@ def _monitor_tqdm(): return # Generate output dir - resume_from_checkpoint = config.get("resume_from_checkpoint") - output_dir = config.get("output_dir") or _output_dir_from_resume_checkpoint( - resume_from_checkpoint - ) + output_dir = config.get("output_dir") if not output_dir: output_dir = f"{model_name.replace('/', '_')}_{int(time.time())}" output_dir = str(resolve_output_dir(output_dir)) @@ -817,7 +789,6 @@ def _monitor_tqdm(): max_seq_length = config.get("max_seq_length", 2048), optim = config.get("optim", "adamw_8bit"), lr_scheduler_type = config.get("lr_scheduler_type", "linear"), - resume_from_checkpoint = resume_from_checkpoint, ) _tqdm_stop.set() @@ -834,13 +805,10 @@ def _monitor_tqdm(): } ) else: - saved_output_dir = ( - None if trainer.should_stop and not trainer.save_on_stop else output_dir - ) event_queue.put( { "type": "complete", - "output_dir": saved_output_dir, + "output_dir": output_dir, "status_message": progress.status_message or "Training completed", "ts": time.time(), } @@ -1125,15 +1093,11 @@ def _poll_stop(): ) return - resume_from_checkpoint = config.get("resume_from_checkpoint") - output_dir = config.get("output_dir") or _output_dir_from_resume_checkpoint( - resume_from_checkpoint - ) + output_dir = config.get("output_dir") if not output_dir: output_dir = str( resolve_output_dir(f"{model_name.replace('/', '_')}_{int(time.time())}") ) - output_dir = str(resolve_output_dir(output_dir)) num_epochs = config.get("num_epochs", 2) batch_size = config.get("batch_size", 256) @@ -1241,7 +1205,7 @@ def on_step_end(self, args, state, control, **kwargs): callbacks = [_EmbeddingProgressCallback()], ) - trainer.train(resume_from_checkpoint = resume_from_checkpoint) + trainer.train() except Exception as e: event_queue.put( { @@ -1267,8 +1231,6 @@ def on_step_end(self, args, state, control, **kwargs): _send_status(event_queue, "Saving model...") try: - if _should_stop and _save_on_stop: - trainer._save_checkpoint(trainer.model, trial = None) model.save_pretrained(output_dir) model.tokenizer.save_pretrained(output_dir) logger.info("Embedding model saved to %s", output_dir) diff --git a/studio/backend/loggers/config.py b/studio/backend/loggers/config.py index 4a27f13d38..0d32a64657 100644 --- a/studio/backend/loggers/config.py +++ b/studio/backend/loggers/config.py @@ -22,8 +22,6 @@ import structlog -from loggers.handlers import filter_sensitive_data - class LogConfig: """Structured logging configuration for the application. @@ -46,22 +44,12 @@ def setup_logging( # Fallback to INFO if an invalid level is provided log_level = getattr(logging, log_level_name, logging.INFO) - if sys.platform == "win32": - for stream in (sys.stdout, sys.stderr): - if hasattr(stream, "reconfigure"): - try: - stream.reconfigure(encoding = "utf-8", errors = "replace") - except Exception: - pass - structlog.configure( processors = [ # Reorder processors to control field order structlog.processors.TimeStamper(fmt = "iso"), # timestamp first structlog.processors.add_log_level, # level second structlog.contextvars.merge_contextvars, - structlog.processors.format_exc_info, - filter_sensitive_data, # Custom processor to flatten the extra field lambda logger, method_name, event_dict: { "timestamp": event_dict.get("timestamp"), diff --git a/studio/backend/loggers/handlers.py b/studio/backend/loggers/handlers.py index ddd404cdf3..3add92ea1e 100644 --- a/studio/backend/loggers/handlers.py +++ b/studio/backend/loggers/handlers.py @@ -15,7 +15,6 @@ - get_logger: Factory function for structured loggers """ -import re import time from typing import Callable @@ -23,12 +22,7 @@ from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware -from utils.native_path_leases import redact_native_paths - logger = structlog.get_logger(__name__) -_NATIVE_PATH_LEASE_RE = re.compile( - r"(?i)(\b(?:native_path_lease|nativePathLease)[\"']?\s*[:=]\s*[\"']?)[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+" -) class LoggingMiddleware(BaseHTTPMiddleware): @@ -81,12 +75,6 @@ def filter_sensitive_data(logger, method_name, event_dict): """Structlog processor to filter out base64 data from logs.""" def filter_value(value): - if isinstance(value, str): - try: - value = redact_native_paths(value) - except Exception: - pass - value = _NATIVE_PATH_LEASE_RE.sub(r"\1", value) if ( isinstance(value, str) and len(value) > 100 @@ -95,22 +83,12 @@ def filter_value(value): # Likely base64 data, truncate it return value[:20] + "..." elif isinstance(value, dict): - return { - k: "" - if str(k).replace("_", "").lower() == "nativepathlease" - else filter_value(v) - for k, v in value.items() - } + return {k: filter_value(v) for k, v in value.items()} elif isinstance(value, list): return [filter_value(item) for item in value] return value - return { - k: "" - if str(k).replace("_", "").lower() == "nativepathlease" - else filter_value(v) - for k, v in event_dict.items() - } + return {k: filter_value(v) for k, v in event_dict.items()} def get_logger(name: str) -> structlog.BoundLogger: diff --git a/studio/backend/main.py b/studio/backend/main.py index 0958094ff0..ad19ee9679 100644 --- a/studio/backend/main.py +++ b/studio/backend/main.py @@ -27,7 +27,6 @@ import shutil import warnings from contextlib import asynccontextmanager -from importlib.metadata import PackageNotFoundError, version as package_version # Fix broken Windows registry MIME types. Some Windows installs map .js to # "text/plain" in the registry (HKCR\.js\Content Type). Python's mimetypes @@ -62,7 +61,6 @@ datasets_router, export_router, inference_router, - inference_studio_router, models_router, training_history_router, training_router, @@ -78,28 +76,6 @@ import utils.hardware.hardware as _hw_module from utils.cache_cleanup import clear_unsloth_compiled_cache -from utils.native_path_leases import native_path_leases_supported - - -def get_unsloth_version() -> str: - try: - return package_version("unsloth") - except PackageNotFoundError: - pass - - version_file = ( - _Path(__file__).resolve().parents[2] / "unsloth" / "models" / "_utils.py" - ) - try: - for line in version_file.read_text(encoding = "utf-8").splitlines(): - if line.startswith("__version__ = "): - return line.split("=", 1)[1].strip().strip('"').strip("'") - except OSError: - pass - return "dev" - - -UNSLOTH_VERSION = get_unsloth_version() @asynccontextmanager @@ -164,7 +140,7 @@ def _precache(): # Create FastAPI app app = FastAPI( title = "Unsloth UI Backend", - version = UNSLOTH_VERSION, + version = "1.0.0", description = "Backend API for Unsloth UI - Training and Model Management", lifespan = lifespan, ) @@ -181,24 +157,9 @@ def _precache(): app.add_middleware(LoggingMiddleware) # CORS middleware -_api_only = os.environ.get("UNSLOTH_API_ONLY") == "1" -_cors_origins = ["*"] -if _api_only: - _cors_origins = [ - "tauri://localhost", # Linux/macOS Tauri webview - "http://tauri.localhost", # Windows Tauri webview - "http://localhost", # dev fallback - "http://localhost:5173", # Tauri dev/Vite - "http://127.0.0.1:5173", # Tauri dev/Vite fallback - ] - _cors_origin_regex = None -else: - _cors_origin_regex = None - app.add_middleware( CORSMiddleware, - allow_origins = _cors_origins, - allow_origin_regex = _cors_origin_regex, + allow_origins = ["*"], # In production, specify allowed origins allow_credentials = True, allow_methods = ["*"], allow_headers = ["*"], @@ -211,9 +172,6 @@ def _precache(): app.include_router(training_router, prefix = "/api/train", tags = ["training"]) app.include_router(models_router, prefix = "/api/models", tags = ["models"]) app.include_router(inference_router, prefix = "/api/inference", tags = ["inference"]) -# Studio-only inference endpoints (cancel, etc.) are intentionally NOT -# exposed on the /v1 OpenAI-compat prefix below. -app.include_router(inference_studio_router, prefix = "/api/inference", tags = ["inference"]) # OpenAI-compatible endpoints: mount the same inference router at /v1 # so external tools (Open WebUI, SillyTavern, etc.) can use the @@ -240,12 +198,8 @@ async def health_check(): "status": "healthy", "timestamp": datetime.now().isoformat(), "service": "Unsloth UI Backend", - "version": UNSLOTH_VERSION, "device_type": device_type, "chat_only": _hw_module.CHAT_ONLY, - "desktop_protocol_version": 1, - "supports_desktop_auth": True, - "native_path_leases_supported": native_path_leases_supported(), } @@ -283,7 +237,6 @@ async def get_system_info(): import platform import psutil from utils.hardware import get_device - from utils.hardware.hardware import _backend_label visibility_info = get_backend_visible_gpu_info() gpu_info = { @@ -297,10 +250,7 @@ async def get_system_info(): return { "platform": platform.platform(), "python_version": platform.python_version(), - # Use the centralized _backend_label helper so the /api/system - # endpoint reports "rocm" on AMD hosts instead of "cuda", matching - # the /api/hardware and /api/gpu-visibility endpoints. - "device_backend": _backend_label(get_device()), + "device_backend": get_device().value, "cpu_count": psutil.cpu_count(), "memory": { "total_gb": round(memory.total / 1e9, 2), @@ -399,7 +349,7 @@ async def serve_root(): @app.get("/{full_path:path}") async def serve_frontend(full_path: str): - if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")): + if full_path.startswith("api"): return {"error": "API endpoint not found"} file_path = (build_path / full_path).resolve() diff --git a/studio/backend/models/auth.py b/studio/backend/models/auth.py index 23eb0ac4c0..73d21130ae 100644 --- a/studio/backend/models/auth.py +++ b/studio/backend/models/auth.py @@ -5,8 +5,6 @@ Pydantic schemas for Authentication API """ -from typing import Optional - from pydantic import BaseModel, Field @@ -17,12 +15,6 @@ class AuthLoginRequest(BaseModel): password: str = Field(..., description = "Password") -class DesktopLoginRequest(BaseModel): - """Desktop-only local secret exchange payload.""" - - secret: str = Field(..., description = "Desktop local auth secret") - - class RefreshTokenRequest(BaseModel): """Refresh token payload to obtain new access + refresh tokens.""" @@ -53,44 +45,3 @@ class ChangePasswordRequest(BaseModel): new_password: str = Field( ..., min_length = 8, description = "Replacement password (minimum 8 characters)" ) - - -# --------------------------------------------------------------------------- -# API key schemas -# --------------------------------------------------------------------------- - - -class CreateApiKeyRequest(BaseModel): - """Request body to create a new API key.""" - - name: str = Field(..., description = "Human-readable label for this key") - expires_in_days: Optional[int] = Field( - None, description = "Number of days until the key expires (None = never)" - ) - - -class ApiKeyResponse(BaseModel): - """Public representation of an API key (never contains the raw key).""" - - id: int - name: str - key_prefix: str = Field( - ..., description = "First 8 characters after sk-unsloth- for display" - ) - created_at: str - last_used_at: Optional[str] = None - expires_at: Optional[str] = None - is_active: bool - - -class CreateApiKeyResponse(BaseModel): - """Returned once when a key is created -- ``key`` is never shown again.""" - - key: str = Field(..., description = "Full API key (shown once)") - api_key: ApiKeyResponse - - -class ApiKeyListResponse(BaseModel): - """List of API keys for the authenticated user.""" - - api_keys: list[ApiKeyResponse] diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index 43087cc5bf..cf08ecbc12 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -11,16 +11,13 @@ import uuid from typing import Annotated, Any, Dict, Literal, Optional, List, Union -from pydantic import BaseModel, Discriminator, Field, Tag, model_validator +from pydantic import BaseModel, Discriminator, Field, Tag class LoadRequest(BaseModel): """Request to load a model for inference""" model_path: str = Field(..., description = "Model identifier or local path") - native_path_lease: Optional[str] = Field( - None, description = "Frontend-visible signed native path grant" - ) hf_token: Optional[str] = Field( None, description = "HuggingFace token for gated models" ) @@ -55,16 +52,6 @@ class LoadRequest(BaseModel): None, description = "Speculative decoding mode for GGUF models (e.g. 'ngram-simple', 'ngram-mod'). Ignored for non-GGUF and vision models.", ) - llama_extra_args: Optional[List[str]] = Field( - None, - description = ( - "Extra arguments forwarded verbatim to llama-server for GGUF models. " - "One token per list entry, e.g. ['--top-k', '20', '--seed', '42']. " - "Studio-managed flags (model identity, port, context length, GPU placement, " - "auth, --flash-attn, --no-context-shift, --jinja) are rejected. Ignored for " - "non-GGUF models." - ), - ) class UnloadRequest(BaseModel): @@ -82,9 +69,6 @@ class ValidateModelRequest(BaseModel): """ model_path: str = Field(..., description = "Model identifier or local path") - native_path_lease: Optional[str] = Field( - None, description = "Frontend-visible signed native path grant" - ) hf_token: Optional[str] = Field( None, description = "HuggingFace token for gated models" ) @@ -110,10 +94,6 @@ class ValidateModelResponse(BaseModel): is_gguf: bool = Field(False, description = "Whether this is a GGUF model (llama.cpp)") is_lora: bool = Field(False, description = "Whether this is a LoRA adapter") is_vision: bool = Field(False, description = "Whether this is a vision-capable model") - requires_trust_remote_code: bool = Field( - False, - description = "Whether the model defaults require trust_remote_code to be enabled for loading.", - ) class GenerateRequest(BaseModel): @@ -157,10 +137,6 @@ class LoadResponse(BaseModel): inference: dict = Field( ..., description = "Inference parameters (temperature, top_p, top_k, min_p)" ) - requires_trust_remote_code: bool = Field( - False, - description = "Whether the model defaults require trust_remote_code to be enabled for loading.", - ) context_length: Optional[int] = Field( None, description = "Model's native context length (from GGUF metadata)" ) @@ -173,20 +149,12 @@ class LoadResponse(BaseModel): ) supports_reasoning: bool = Field( False, - description = "Whether model supports thinking/reasoning mode (enable_thinking or reasoning_effort)", - ) - reasoning_style: Literal["enable_thinking", "reasoning_effort"] = Field( - "enable_thinking", - description = "Reasoning control style: 'enable_thinking' (boolean) or 'reasoning_effort' (low|medium|high)", + description = "Whether model supports thinking/reasoning mode (enable_thinking)", ) reasoning_always_on: bool = Field( False, description = "Whether reasoning is always on (hardcoded tags, not toggleable)", ) - supports_preserve_thinking: bool = Field( - False, - description = "Whether the template understands the optional preserve_thinking kwarg (Qwen3.6-style)", - ) supports_tools: bool = Field( False, description = "Whether model supports tool calling (web search, etc.)", @@ -212,39 +180,6 @@ class UnloadResponse(BaseModel): model: str = Field(..., description = "Model identifier that was unloaded") -class LoadProgressResponse(BaseModel): - """Progress of the active GGUF load, sampled on demand. - - Used by the UI to show a real progress bar during the - post-download warmup window (mmap + CUDA upload), rather than a - generic "Starting model..." spinner that freezes for minutes on - large MoE models. - """ - - phase: Optional[str] = Field( - None, - description = ( - "Load phase: 'mmap' (weights paging into RAM via mmap), " - "'ready' (llama-server reported healthy), or null when no " - "load is in flight." - ), - ) - bytes_loaded: int = Field( - 0, - description = ( - "Bytes of the model already resident in the llama-server " - "process (VmRSS on Linux)." - ), - ) - bytes_total: int = Field( - 0, - description = "Total bytes across all GGUF shards for the active model.", - ) - fraction: float = Field( - 0.0, description = "bytes_loaded / bytes_total, clamped to 0..1." - ) - - class InferenceStatusResponse(BaseModel): """Current inference backend status""" @@ -278,31 +213,15 @@ class InferenceStatusResponse(BaseModel): inference: Optional[Dict[str, Any]] = Field( None, description = "Recommended inference parameters for the active model" ) - requires_trust_remote_code: bool = Field( - False, - description = "Whether the active model requires trust_remote_code to be enabled for loading.", - ) supports_reasoning: bool = Field( False, description = "Whether the active model supports reasoning/thinking mode" ) - reasoning_style: Literal["enable_thinking", "reasoning_effort"] = Field( - "enable_thinking", - description = "Reasoning control style: 'enable_thinking' (boolean) or 'reasoning_effort' (low|medium|high)", - ) reasoning_always_on: bool = Field( False, description = "Whether reasoning is always on (not toggleable)" ) - supports_preserve_thinking: bool = Field( - False, - description = "Whether the active model's template understands the optional preserve_thinking kwarg", - ) supports_tools: bool = Field( False, description = "Whether the active model supports tool calling" ) - chat_template: Optional[str] = Field( - None, - description = "Jinja2 chat template string for the active model", - ) context_length: Optional[int] = Field( None, description = "Context length of the active model" ) @@ -374,69 +293,14 @@ class ChatMessage(BaseModel): ``content`` may be a plain string (text-only) or a list of content parts for multimodal messages (OpenAI vision format). - Assistant messages that only contain tool calls may set ``content`` - to ``None`` with ``tool_calls`` populated. ``role="tool"`` messages - carry the result of a client-executed tool call and require - ``tool_call_id`` per the OpenAI spec. """ - role: Literal["system", "user", "assistant", "tool"] = Field( + role: Literal["system", "user", "assistant"] = Field( ..., description = "Message role" ) - content: Optional[Union[str, list[ContentPart]]] = Field( - None, description = "Message content (string or multimodal parts)" - ) - tool_call_id: Optional[str] = Field( - None, - description = "OpenAI tool-result messages: id of the tool call this result belongs to.", + content: Union[str, list[ContentPart]] = Field( + ..., description = "Message content (string or multimodal parts)" ) - tool_calls: Optional[list[dict]] = Field( - None, - description = "OpenAI assistant messages: structured tool calls the model decided to make.", - ) - name: Optional[str] = Field( - None, - description = "OpenAI tool-result messages: name of the tool whose result this is.", - ) - - @model_validator(mode = "after") - def _validate_role_shape(self) -> "ChatMessage": - # Enforce the per-role OpenAI spec shape at the request boundary. - # Without this, malformed messages (e.g. user entries with no - # content, tool_calls on a user/system role, role="tool" without - # tool_call_id) would be silently forwarded to llama-server via - # the passthrough path, surfacing as opaque upstream errors or - # broken tool-call reconciliation downstream. - - # Tool-call metadata must appear only on the appropriate role. - if self.tool_calls is not None and self.role != "assistant": - raise ValueError('"tool_calls" is only valid on role="assistant" messages.') - if self.tool_call_id is not None and self.role != "tool": - raise ValueError('"tool_call_id" is only valid on role="tool" messages.') - if self.name is not None and self.role != "tool": - raise ValueError('"name" is only valid on role="tool" messages.') - - # Per-role content requirements. OpenAI-compatible clients may send - # ``content=""`` for image-only turns when the image travels in a - # companion field such as Studio's ``image_base64`` extension, so treat - # empty strings as present content for user/system messages. - if self.role == "tool": - if not self.tool_call_id: - raise ValueError( - 'role="tool" messages require "tool_call_id" per the OpenAI spec.' - ) - if not self.content: - raise ValueError('role="tool" messages require non-empty "content".') - elif self.role == "assistant": - # Assistant messages may omit content when tool_calls is set. - if not self.content and not self.tool_calls: - raise ValueError( - 'role="assistant" messages require either "content" or "tool_calls".' - ) - else: # "user" | "system" - if self.content is None or self.content == []: - raise ValueError(f'role="{self.role}" messages require "content".') - return self class ChatCompletionRequest(BaseModel): @@ -446,49 +310,18 @@ class ChatCompletionRequest(BaseModel): Extensions (non-OpenAI fields) are marked with 'x-unsloth'. """ - # Accept unknown fields defensively so future OpenAI fields (seed, - # response_format, logprobs, frequency_penalty, etc.) don't get - # silently dropped by Pydantic before route code runs. Mirrors - # AnthropicMessagesRequest and ResponsesRequest. - model_config = {"extra": "allow"} - model: str = Field( "default", description = "Model identifier (informational; the active model is used)", ) messages: list[ChatMessage] = Field(..., description = "Conversation messages") - stream: bool = Field( - False, - description = ( - "Whether to stream the response via SSE. Default matches OpenAI's " - "spec (`false`); opt into streaming by sending `stream: true`." - ), - ) + stream: bool = Field(True, description = "Whether to stream the response via SSE") temperature: float = Field(0.6, ge = 0.0, le = 2.0) top_p: float = Field(0.95, ge = 0.0, le = 1.0) max_tokens: Optional[int] = Field( None, ge = 1, description = "Maximum tokens to generate (None = until EOS)" ) presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = "Presence penalty") - stop: Optional[Union[str, list[str]]] = Field( - None, - description = "OpenAI stop sequences: a single string or list of strings at which generation halts.", - ) - tools: Optional[list[dict]] = Field( - None, - description = ( - "OpenAI function-tool definitions. When provided without `enable_tools=true`, " - "Studio forwards the tools to the backend so the model returns structured " - "tool_calls for the client to execute (standard OpenAI function calling)." - ), - ) - tool_choice: Optional[Union[str, dict]] = Field( - None, - description = ( - "OpenAI tool choice: 'auto' | 'required' | 'none' | " - "{'type': 'function', 'function': {'name': ...}}" - ), - ) # ── Unsloth extensions (ignored by standard OpenAI clients) ── top_k: int = Field(20, ge = -1, le = 100, description = "[x-unsloth] Top-k sampling") @@ -518,14 +351,6 @@ class ChatCompletionRequest(BaseModel): None, description = "[x-unsloth] Enable/disable thinking/reasoning mode for supported models", ) - reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field( - None, - description = "[x-unsloth] Reasoning effort level ('low'|'medium'|'high') for Harmony-style reasoning models (e.g. gpt-oss). Overrides enable_thinking when the active model uses reasoning_effort style.", - ) - preserve_thinking: Optional[bool] = Field( - None, - description = "[x-unsloth] When true, keep historical blocks from past assistant turns in the prompt (Qwen3.6 templates). Independent of enable_thinking / reasoning_effort.", - ) enable_tools: Optional[bool] = Field( None, description = "[x-unsloth] Enable tool calling for supported models", @@ -552,10 +377,6 @@ class ChatCompletionRequest(BaseModel): None, description = "[x-unsloth] Session/thread ID for scoping tool execution sandbox.", ) - cancel_id: Optional[str] = Field( - None, - description = "[x-unsloth] Per-request cancellation token. Frontend sends a fresh UUID per run so /inference/cancel matches one specific generation.", - ) # ── Streaming response chunks ──────────────────────────────────── @@ -623,435 +444,3 @@ class ChatCompletion(BaseModel): model: str = "default" choices: list[CompletionChoice] usage: CompletionUsage = Field(default_factory = CompletionUsage) - - -# ===================================================================== -# OpenAI Responses API Models (/v1/responses) -# ===================================================================== - - -# ── Request models ────────────────────────────────────────────── - - -class ResponsesInputTextPart(BaseModel): - """Text content part in a Responses API message (type=input_text).""" - - type: Literal["input_text"] - text: str - - -class ResponsesInputImagePart(BaseModel): - """Image content part in a Responses API message (type=input_image).""" - - type: Literal["input_image"] - image_url: str = Field(..., description = "data:image/png;base64,... or https://...") - detail: Optional[Literal["auto", "low", "high"]] = "auto" - - -class ResponsesOutputTextPart(BaseModel): - """Assistant ``output_text`` content part replayed on subsequent turns. - - When a client (OpenAI Codex CLI, OpenAI Python SDK agents) loops on a - stateless Responses endpoint, prior assistant messages are round-tripped - as ``{"role":"assistant","content":[{"type":"output_text","text":..., - "annotations":[],"logprobs":[]}]}``. We preserve the text and ignore - the annotations/logprobs metadata when flattening into Chat Completions. - """ - - type: Literal["output_text"] - text: str - annotations: Optional[list] = None - logprobs: Optional[list] = None - - model_config = {"extra": "allow"} - - -class ResponsesUnknownContentPart(BaseModel): - """Catch-all for content-part types we don't model explicitly. - - Keeps validation green when a client sends newer part types (e.g. - ``input_audio``, ``input_file``) we haven't mapped; these are silently - skipped during normalisation rather than rejected with a 422. - """ - - type: str - - model_config = {"extra": "allow"} - - -ResponsesContentPart = Union[ - ResponsesInputTextPart, - ResponsesInputImagePart, - ResponsesOutputTextPart, - ResponsesUnknownContentPart, -] - - -class ResponsesInputMessage(BaseModel): - """A single message in the Responses API input array.""" - - type: Optional[Literal["message"]] = None - role: Literal["system", "user", "assistant", "developer"] - content: Union[str, list[ResponsesContentPart]] - - # Codex (gpt-5.3-codex+) attaches a `phase` field ("commentary" | - # "final_answer") to assistant messages and requires clients to preserve - # it on subsequent turns. We accept and round-trip it; llama-server does - # not care about it. - model_config = {"extra": "allow"} - - -class ResponsesFunctionCallInputItem(BaseModel): - """A prior assistant function_call being replayed in a multi-turn Responses input. - - The Responses API represents tool calls as top-level input items (not - nested inside assistant messages), correlated across turns by ``call_id``. - """ - - type: Literal["function_call"] - id: Optional[str] = Field( - None, description = "Item id assigned by the server (e.g. fc_...)" - ) - call_id: str = Field( - ..., - description = "Correlation id matching a function_call_output on the next turn.", - ) - name: str - arguments: str = Field( - ..., description = "JSON string of the arguments the model produced." - ) - status: Optional[Literal["in_progress", "completed", "incomplete"]] = None - - -class ResponsesFunctionCallOutputInputItem(BaseModel): - """A tool result supplied by the client for a prior function_call. - - Replaces Chat Completions' ``role="tool"`` message. Correlated to the - originating call by ``call_id``. - """ - - type: Literal["function_call_output"] - id: Optional[str] = None - call_id: str - output: Union[str, list] = Field( - ..., description = "String or content-array result of the tool call." - ) - status: Optional[Literal["in_progress", "completed", "incomplete"]] = None - - -class ResponsesUnknownInputItem(BaseModel): - """Catch-all for Responses input item types we don't model explicitly. - - Covers ``reasoning`` items (replayed from prior o-series / gpt-5 turns) - and any future item types the client may send. These items are dropped - during normalisation — llama-server-backed GGUFs cannot consume them — - but keeping them in the request-model union stops unrelated turns from - failing validation with a 422. - """ - - type: str - - model_config = {"extra": "allow"} - - -def _responses_input_item_discriminator(v: Any) -> str: - """Route a Responses input item to the correct tagged variant. - - Pydantic's default smart-union matching fails when one variant in the - union is tagged with a strict ``Literal`` (``function_call`` / - ``function_call_output``) and the incoming dict uses a different - ``type`` — the other variants' validation errors are hidden and the - outer ``Union[str, list[...]]`` reports a misleading "Input should be a - valid string" error. An explicit discriminator makes the routing - deterministic and lets us fall through to the catch-all. - """ - if isinstance(v, dict): - t = v.get("type") - r = v.get("role") - else: - t = getattr(v, "type", None) - r = getattr(v, "role", None) - if t == "function_call": - return "function_call" - if t == "function_call_output": - return "function_call_output" - if r is not None or t == "message": - return "message" - return "unknown" - - -ResponsesInputItem = Annotated[ - Union[ - Annotated[ResponsesInputMessage, Tag("message")], - Annotated[ResponsesFunctionCallInputItem, Tag("function_call")], - Annotated[ResponsesFunctionCallOutputInputItem, Tag("function_call_output")], - Annotated[ResponsesUnknownInputItem, Tag("unknown")], - ], - Discriminator(_responses_input_item_discriminator), -] - - -class ResponsesFunctionTool(BaseModel): - """Flat function-tool definition used by the Responses API request. - - Unlike Chat Completions (which nests ``{"name": ..., "parameters": ...}`` - inside a ``"function"`` key), the Responses API uses a flat shape with - ``type``, ``name``, ``description``, ``parameters``, and ``strict`` at the - top level of each tool entry. - """ - - type: Literal["function"] - name: str - description: Optional[str] = None - parameters: Optional[dict] = None - strict: Optional[bool] = None - - -class ResponsesRequest(BaseModel): - """OpenAI Responses API request.""" - - model: str = Field("default", description = "Model identifier") - input: Union[str, list[ResponsesInputItem]] = Field( - default = [], - description = "Input text or list of messages / function_call / function_call_output items", - ) - instructions: Optional[str] = Field( - None, description = "System / developer instructions" - ) - temperature: Optional[float] = Field(None, ge = 0.0, le = 2.0) - top_p: Optional[float] = Field(None, ge = 0.0, le = 1.0) - max_output_tokens: Optional[int] = Field(None, ge = 1) - stream: bool = Field(False, description = "Whether to stream the response via SSE") - - # OpenAI function-calling fields — forwarded to llama-server via the - # Chat Completions pass-through (see routes/inference.py). Typed as a - # plain list so built-in tool shapes (``web_search``, ``file_search``, - # ``mcp``, ...) round-trip without validation errors — the translator - # picks out only ``type=="function"`` entries for forwarding. - tools: Optional[list[dict]] = Field( - None, - description = ( - "Responses-shape function tool definitions. Entries with " - '`type="function"` are translated to the Chat Completions nested ' - "shape before being forwarded to llama-server; other tool types " - "(built-in web_search, file_search, mcp, ...) are accepted for SDK " - "compatibility but ignored on the llama-server passthrough." - ), - ) - tool_choice: Optional[Any] = Field( - None, - description = ( - "'auto' | 'required' | 'none' | {'type': 'function', 'name': ...} — " - "the Responses-shape forcing object is translated to the Chat " - "Completions nested shape internally." - ), - ) - parallel_tool_calls: Optional[bool] = None - - previous_response_id: Optional[str] = None - store: Optional[bool] = None - metadata: Optional[dict] = None - truncation: Optional[Any] = None - user: Optional[str] = None - text: Optional[Any] = None - reasoning: Optional[Any] = None - - model_config = {"extra": "allow"} - - -# ── Response models ───────────────────────────────────────────── - - -class ResponsesOutputTextContent(BaseModel): - """A text content block inside an output message.""" - - type: Literal["output_text"] = "output_text" - text: str - annotations: list = Field(default_factory = list) - - -class ResponsesOutputMessage(BaseModel): - """An output message in the Responses API response.""" - - type: Literal["message"] = "message" - id: str = Field(default_factory = lambda: f"msg_{uuid.uuid4().hex[:12]}") - status: Literal["completed", "in_progress"] = "completed" - role: Literal["assistant"] = "assistant" - content: list[ResponsesOutputTextContent] = Field(default_factory = list) - - -class ResponsesOutputFunctionCall(BaseModel): - """A function-call output item in the Responses API response. - - Unlike Chat Completions (which nests tool calls inside the assistant - message), the Responses API emits each tool call as its own top-level - ``output`` item so clients can correlate results via ``call_id`` on the - next turn. - """ - - type: Literal["function_call"] = "function_call" - id: str = Field(default_factory = lambda: f"fc_{uuid.uuid4().hex[:12]}") - call_id: str - name: str - arguments: str = Field( - ..., description = "JSON string of the arguments the model produced." - ) - status: Literal["completed", "in_progress", "incomplete"] = "completed" - - -ResponsesOutputItem = Union[ResponsesOutputMessage, ResponsesOutputFunctionCall] - - -class ResponsesUsage(BaseModel): - """Token usage for a Responses API response (input_tokens, not prompt_tokens).""" - - input_tokens: int = 0 - output_tokens: int = 0 - total_tokens: int = 0 - - -class ResponsesResponse(BaseModel): - """Top-level Responses API response object.""" - - id: str = Field(default_factory = lambda: f"resp_{uuid.uuid4().hex[:12]}") - object: Literal["response"] = "response" - created_at: int = Field(default_factory = lambda: int(time.time())) - status: Literal["completed", "in_progress", "failed"] = "completed" - model: str = "default" - output: list[ResponsesOutputItem] = Field(default_factory = list) - usage: ResponsesUsage = Field(default_factory = ResponsesUsage) - error: Optional[Any] = None - incomplete_details: Optional[Any] = None - instructions: Optional[str] = None - metadata: dict = Field(default_factory = dict) - temperature: Optional[float] = None - top_p: Optional[float] = None - max_output_tokens: Optional[int] = None - previous_response_id: Optional[str] = None - text: Optional[Any] = None - tool_choice: Optional[Any] = None - tools: list = Field(default_factory = list) - truncation: Optional[Any] = None - - -# ===================================================================== -# Anthropic Messages API Models (/v1/messages) -# ===================================================================== - - -# ── Request models ───────────────────────────────────────────── - - -class AnthropicTextBlock(BaseModel): - type: Literal["text"] - text: str - - -class AnthropicImageSource(BaseModel): - type: Literal["base64", "url"] - media_type: Optional[str] = None - data: Optional[str] = None - url: Optional[str] = None - - -class AnthropicImageBlock(BaseModel): - type: Literal["image"] - source: AnthropicImageSource - - -class AnthropicToolUseBlock(BaseModel): - type: Literal["tool_use"] - id: str - name: str - input: dict - - -class AnthropicToolResultBlock(BaseModel): - type: Literal["tool_result"] - tool_use_id: str - content: Union[str, list] = "" - - -AnthropicContentBlock = Union[ - AnthropicTextBlock, - AnthropicImageBlock, - AnthropicToolUseBlock, - AnthropicToolResultBlock, -] - - -class AnthropicMessage(BaseModel): - role: Literal["user", "assistant"] - content: Union[str, list[AnthropicContentBlock]] - - -class AnthropicTool(BaseModel): - name: str - description: Optional[str] = None - input_schema: dict - - -class AnthropicMessagesRequest(BaseModel): - model: str = "default" - max_tokens: Optional[int] = None - messages: list[AnthropicMessage] - system: Optional[Union[str, list]] = None - tools: Optional[list[AnthropicTool]] = None - tool_choice: Optional[Any] = None - stream: bool = False - temperature: Optional[float] = None - top_p: Optional[float] = None - top_k: Optional[int] = None - stop_sequences: Optional[list[str]] = None - metadata: Optional[dict] = None - # [x-unsloth] extensions — mirror the OpenAI endpoint convenience fields - min_p: Optional[float] = Field( - None, ge = 0.0, le = 1.0, description = "[x-unsloth] Min-p sampling threshold" - ) - repetition_penalty: Optional[float] = Field( - None, ge = 1.0, le = 2.0, description = "[x-unsloth] Repetition penalty" - ) - presence_penalty: Optional[float] = Field( - None, ge = 0.0, le = 2.0, description = "[x-unsloth] Presence penalty" - ) - enable_tools: Optional[bool] = None - enabled_tools: Optional[list[str]] = None - session_id: Optional[str] = None - cancel_id: Optional[str] = None - model_config = {"extra": "allow"} - - -# ── Response models ──────────────────────────────────────────── - - -class AnthropicUsage(BaseModel): - input_tokens: int = 0 - output_tokens: int = 0 - - -class AnthropicResponseTextBlock(BaseModel): - type: Literal["text"] = "text" - text: str - - -class AnthropicResponseToolUseBlock(BaseModel): - type: Literal["tool_use"] = "tool_use" - id: str - name: str - input: dict - - -AnthropicResponseBlock = Union[ - AnthropicResponseTextBlock, AnthropicResponseToolUseBlock -] - - -class AnthropicMessagesResponse(BaseModel): - id: str = Field(default_factory = lambda: f"msg_{uuid.uuid4().hex[:24]}") - type: Literal["message"] = "message" - role: Literal["assistant"] = "assistant" - content: list[AnthropicResponseBlock] = Field(default_factory = list) - model: str = "default" - stop_reason: Optional[str] = None - stop_sequence: Optional[str] = None - usage: AnthropicUsage = Field(default_factory = AnthropicUsage) diff --git a/studio/backend/models/models.py b/studio/backend/models/models.py index 46ca4e3784..f67014a17b 100644 --- a/studio/backend/models/models.py +++ b/studio/backend/models/models.py @@ -213,68 +213,3 @@ class ScanFolderInfo(BaseModel): id: int = Field(..., description = "Database row ID") path: str = Field(..., description = "Normalized absolute path") created_at: str = Field(..., description = "ISO 8601 creation timestamp") - - -class BrowseEntry(BaseModel): - """A directory entry surfaced by the folder browser.""" - - name: str = Field(..., description = "Entry name (basename, not full path)") - has_models: bool = Field( - False, - description = ( - "Hint that the directory likely contains models " - "(*.gguf, *.safetensors, config.json, or HF-style " - "`models--*` subfolders). Used by the UI to highlight " - "promising candidates; the scanner itself is authoritative." - ), - ) - hidden: bool = Field( - False, - description = "Name starts with a dot (e.g. `.cache`)", - ) - - -class BrowseFoldersResponse(BaseModel): - """Response schema for the folder browser endpoint.""" - - current: str = Field(..., description = "Absolute path of the directory just listed") - parent: Optional[str] = Field( - None, - description = ( - "Parent directory of `current`, or null if `current` is the " - "filesystem root. The frontend uses this to render an `Up` row." - ), - ) - entries: List[BrowseEntry] = Field( - default_factory = list, - description = ( - "Subdirectories of `current`. Sorted with model-bearing " - "directories first, then alphabetically case-insensitive; " - "hidden entries come last within each group." - ), - ) - suggestions: List[str] = Field( - default_factory = list, - description = ( - "Handy starting points (home, HF cache, already-registered " - "scan folders). Rendered as quick-pick chips above the list." - ), - ) - truncated: bool = Field( - False, - description = ( - "True when the listing was capped because the directory had " - "more subfolders than the server is willing to enumerate in " - "one request. The UI should show a hint telling the user to " - "narrow their path." - ), - ) - model_files_here: int = Field( - 0, - description = ( - "Count of GGUF/safetensors files immediately inside " - "``current``. Used by the UI to surface a hint on leaf " - "model directories (which otherwise look `empty` because " - "they contain only files, no subdirectories)." - ), - ) diff --git a/studio/backend/models/training.py b/studio/backend/models/training.py index a9f4caa1bb..07a306ca39 100644 --- a/studio/backend/models/training.py +++ b/studio/backend/models/training.py @@ -127,9 +127,6 @@ def _compat_split(cls, values: Any) -> Any: wandb_project: Optional[str] = Field(None, description = "W&B project name") enable_tensorboard: bool = Field(False, description = "Enable TensorBoard logging") tensorboard_dir: Optional[str] = Field(None, description = "TensorBoard directory") - resume_from_checkpoint: Optional[str] = Field( - None, description = "Saved training output directory to resume from" - ) # GPU selection gpu_ids: Optional[List[int]] = Field( @@ -223,8 +220,6 @@ class TrainingRunSummary(BaseModel): duration_seconds: Optional[float] = None error_message: Optional[str] = None loss_sparkline: Optional[List[float]] = None - can_resume: bool = False - resumed_later: bool = False class TrainingRunListResponse(BaseModel): diff --git a/studio/backend/plugins/data-designer-github-repo-seed/README.md b/studio/backend/plugins/data-designer-github-repo-seed/README.md deleted file mode 100644 index 346d94b305..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# data-designer-github-repo-seed - -A Data Designer seed-reader plugin for **Unsloth Studio** that scrapes real -GitHub data (issues, pull requests, commits) from one or more repositories -and hands it to the recipe pipeline as a seed dataset. - -Designed to ship with Studio as a default seed source so any user with a -GitHub token can build training datasets straight from live repos. - -## What it does - -Given a list of `owner/name` repos, a GitHub token, and a per-resource -`limit`, the plugin uses GitHub's GraphQL API to fetch issues, pull -requests, and/or commits, with labels, state, authors, and the first N -comments of each item, and materialises a single JSONL with uniform -columns so the rest of the recipe (LLM text / LLM structured / processors) -can treat it like any other seed table. - -| Column | Description | -|---------------|------------------------------------------------| -| `item_type` | `issue` / `pull` / `commit` | -| `repo` | `owner/name` | -| `number` | Issue/PR number, or commit SHA | -| `title` | Title (or commit message headline) | -| `body` | Issue/PR body (or full commit message) | -| `state` | `OPEN` / `CLOSED` / `MERGED` (empty for commit)| -| `author` | GitHub login of the author | -| `created_at` | ISO8601 | -| `closed_at` | ISO8601 (empty for commits) | -| `url` | Permalink | -| `labels` | List of label names | -| `comments` | First N comments concatenated | - -## Usage in a recipe - -```json -{ - "seed_config": { - "source": { - "seed_type": "github_repo", - "repos": ["unslothai/unsloth", "unslothai/unsloth-zoo"], - "token": "", - "item_types": ["issues", "pulls"], - "limit": 100, - "include_comments": true, - "max_comments_per_item": 30 - }, - "sampling_strategy": "shuffle", - "selection_strategy": null - } -} -``` - -Leave `token` empty to fall back to the server's `GH_TOKEN` / `GITHUB_TOKEN` -environment variable, useful when the recipe is published and shouldn't -carry a secret. - -## Auth - -A GitHub personal access token with `public_repo` scope is enough for public -repositories; `repo` scope is required for private ones. GraphQL requests -are rate-limit aware: the client inspects `x-ratelimit-*` headers and -sleeps until reset when the budget drops below a safety threshold. - -## Install - -Shipped as a default Studio plugin. For development: - -```bash -pip install -e . -``` - -Registered automatically via the `data_designer.plugins` entry point. diff --git a/studio/backend/plugins/data-designer-github-repo-seed/pyproject.toml b/studio/backend/plugins/data-designer-github-repo-seed/pyproject.toml deleted file mode 100644 index e232adc60c..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/pyproject.toml +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -[build-system] -requires = ["setuptools>=68", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "data-designer-github-repo-seed" -version = "0.1.0" -description = "Unsloth Studio seed plugin that scrapes GitHub issues, PRs, and commits." -requires-python = ">=3.11" -dependencies = [ - "data-designer-engine>=0.5.4,<0.6", - "requests>=2.31", -] - -[project.entry-points."data_designer.plugins"] -github_repo_seed = "data_designer_github_repo_seed.plugin:github_repo_seed_plugin" - -[tool.setuptools] -package-dir = {"" = "src"} - -[tool.setuptools.packages.find] -where = ["src"] diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/__init__.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/__init__.py deleted file mode 100644 index f57af4c6c3..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -# Intentionally empty. Data-designer loads submodules lazily via qualified names -# (impl_qualified_name / config_qualified_name in plugin.py), so importing this -# package must NOT touch modules that depend on data_designer.engine.* during -# Studio's bootstrap (circular import). diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/config.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/config.py deleted file mode 100644 index 6b347c4f83..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/config.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -from __future__ import annotations - -from typing import Literal - -from pydantic import Field, field_validator, model_validator - -from data_designer.config.seed_source import SeedSource - - -class GitHubRepoSeedSource(SeedSource): - seed_type: Literal["github_repo"] = "github_repo" - - repos: list[str] = Field( - default_factory = list, - description = "List of GitHub repositories to scrape, each in `owner/name` form.", - ) - token: str = Field( - default = "", - description = "Personal access token. Leave blank to read GH_TOKEN / GITHUB_TOKEN from env at run time.", - ) - item_types: list[Literal["issues", "pulls", "commits"]] = Field( - default = ["issues", "pulls"], - description = "Which GitHub item types to fetch per repo.", - ) - limit: int = Field( - default = 100, - ge = 1, - le = 5000, - description = "Maximum items per repo per item type (e.g. limit=100 + ['issues','pulls'] => up to 200 items per repo).", - ) - include_comments: bool = Field( - default = True, - description = "Fetch the first N comments of each issue/PR and include them in the `comments` column.", - ) - max_comments_per_item: int = Field(default = 30, ge = 0, le = 200) - - @field_validator("repos") - @classmethod - def _validate_repos(cls, v: list[str]) -> list[str]: - out: list[str] = [] - for r in v or []: - r = r.strip() - if not r: - continue - if r.count("/") != 1 or not all(r.split("/")): - raise ValueError(f"Each repo must be `owner/name`; got {r!r}") - out.append(r) - return out - - @field_validator("item_types") - @classmethod - def _validate_item_types(cls, v: list[str]) -> list[str]: - if not v: - raise ValueError("item_types must not be empty") - return list(dict.fromkeys(v)) - - @model_validator(mode = "after") - def _ensure_repos(self) -> "GitHubRepoSeedSource": - if not self.repos: - raise ValueError("At least one repo is required") - return self diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/impl.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/impl.py deleted file mode 100644 index 5a38e26d6b..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/impl.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -from __future__ import annotations - -import hashlib -import tempfile -import threading -from pathlib import Path -from typing import Optional - -import data_designer.lazy_heavy_imports as lazy -from data_designer.engine.resources.seed_reader import SeedReader - -from .config import GitHubRepoSeedSource -from .scraper import ScrapeConfig, materialize_to_jsonl - - -# In-process cache mapping a stable config signature to the JSONL materialization -# path. A single recipe job invokes the seed reader multiple times (validation, -# preview, per-column sampling), and the default flow re-scrapes the repo on -# every call: for a 2-repo preview that is ~15s of redundant GitHub GraphQL -# traffic before any generation fires. Memoize the materialization so the second -# and third passes reuse the file the first pass wrote. Cache key excludes the -# raw token and uses a short SHA-256 digest so token values never hit memory -# twice and token rotation invalidates cleanly. -_SCRAPE_CACHE: dict[tuple, str] = {} -_SCRAPE_CACHE_LOCK = threading.Lock() - - -def _scrape_cache_key(cfg: ScrapeConfig) -> tuple: - token_digest = hashlib.sha256( - (cfg.token or "").encode("utf-8"), - ).hexdigest()[:16] - return ( - tuple(cfg.repos), - tuple(cfg.item_types), - cfg.limit, - bool(cfg.include_comments), - cfg.max_comments_per_item, - token_digest, - ) - - -def _lookup_cached_scrape(key: tuple) -> Optional[str]: - with _SCRAPE_CACHE_LOCK: - path = _SCRAPE_CACHE.get(key) - if path and Path(path).exists(): - return path - # Stale entry (tmp cleanup, user restarted, ...); drop it so the caller - # materializes a fresh file rather than returning a dangling path. - if path: - with _SCRAPE_CACHE_LOCK: - _SCRAPE_CACHE.pop(key, None) - return None - - -def _store_cached_scrape(key: tuple, path: str) -> None: - with _SCRAPE_CACHE_LOCK: - _SCRAPE_CACHE[key] = path - - -class GitHubRepoSeedReader(SeedReader[GitHubRepoSeedSource]): - def create_duckdb_connection(self): - return lazy.duckdb.connect() - - def get_dataset_uri(self) -> str: - out_dir = Path(tempfile.gettempdir()) / "studio-github-repo-seed" - cfg = ScrapeConfig( - repos = list(self.source.repos), - token = self.source.token, - item_types = list(self.source.item_types), - limit = self.source.limit, - include_comments = self.source.include_comments, - max_comments_per_item = self.source.max_comments_per_item, - ) - cache_key = _scrape_cache_key(cfg) - cached_path = _lookup_cached_scrape(cache_key) - if cached_path is not None: - return cached_path - path = materialize_to_jsonl(cfg, out_dir) - _store_cached_scrape(cache_key, str(path)) - return str(path) diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/plugin.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/plugin.py deleted file mode 100644 index f87dbd0507..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/plugin.py +++ /dev/null @@ -1,10 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -from data_designer.plugins.plugin import Plugin, PluginType - -github_repo_seed_plugin = Plugin( - impl_qualified_name = "data_designer_github_repo_seed.impl.GitHubRepoSeedReader", - config_qualified_name = "data_designer_github_repo_seed.config.GitHubRepoSeedSource", - plugin_type = PluginType.SEED_READER, -) diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper.py deleted file mode 100644 index d768fe37be..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper.py +++ /dev/null @@ -1,236 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Multi-repo GitHub scraper for the Studio seed plugin. - -Drives the GraphQL-based scraper in `scraper_impl/` per repo. Each repo is -scraped with a trial_limits cap so we stop at `limit` items per resource. -After scraping, we read the per-resource JSONL shards and flatten them into -a single unified JSONL with stable columns (`item_type`, `repo`, `number`, -`title`, `body`, ...). -""" - -from __future__ import annotations - -import json -import os -import sys -import time -import uuid -from dataclasses import dataclass -from pathlib import Path - -# Defer scraper_impl imports until `scrape()` runs with a resolved token. -_IMPL_DIR = Path(__file__).parent / "scraper_impl" - - -def _ensure_impl_on_path() -> None: - if str(_IMPL_DIR) not in sys.path: - sys.path.insert(0, str(_IMPL_DIR)) - - -def _load_impl(): - _ensure_impl_on_path() - import importlib - - gh_client = importlib.import_module("gh_client") # type: ignore - scraper_mod = importlib.import_module("scraper") # type: ignore - return gh_client.GitHubClient, scraper_mod.RepoScraper - - -@dataclass -class ScrapeConfig: - repos: list[str] - token: str - item_types: list[str] - limit: int - include_comments: bool - max_comments_per_item: int - - -def _resolve_token(token: str) -> str: - tok = token or os.environ.get("GH_TOKEN", "") or os.environ.get("GITHUB_TOKEN", "") - if not tok: - raise ValueError( - "GitHub token is required. Set it in the recipe config or the GH_TOKEN / GITHUB_TOKEN env var." - ) - return tok - - -def _read_jsonl(path: Path, max_rows: int | None = None): - if not path.exists(): - return - with path.open(encoding = "utf-8") as f: - for i, line in enumerate(f): - if not line.strip(): - continue - if max_rows is not None and i >= max_rows: - return - try: - yield json.loads(line) - except json.JSONDecodeError: - continue - - -def _flatten_issue_row(r: dict, repo: str, include_comments: bool, max_c: int) -> dict: - labels = [ - l.get("name") - for l in (r.get("labels", {}) or {}).get("nodes", []) - if l.get("name") - ] - comments_nodes = (r.get("comments") or {}).get("nodes") or [] - comments_text = "" - if include_comments and comments_nodes: - kept = comments_nodes[:max_c] - comments_text = "\n\n".join( - f"[{(c.get('author') or {}).get('login', '?')}]: {c.get('body') or ''}" - for c in kept - ) - return { - "item_type": "issue", - "repo": repo, - "number": r.get("number"), - "title": r.get("title") or "", - "body": r.get("body") or "", - "state": r.get("state") or "", - "author": (r.get("author") or {}).get("login", ""), - "created_at": r.get("createdAt") or "", - "closed_at": r.get("closedAt") or "", - "url": r.get("url") or r.get("permalink") or "", - "labels": labels, - "comments": comments_text, - } - - -def _flatten_pr_row(r: dict, repo: str, include_comments: bool, max_c: int) -> dict: - labels = [ - l.get("name") - for l in (r.get("labels", {}) or {}).get("nodes", []) - if l.get("name") - ] - comments_nodes = (r.get("comments") or {}).get("nodes") or [] - comments_text = "" - if include_comments and comments_nodes: - kept = comments_nodes[:max_c] - comments_text = "\n\n".join( - f"[{(c.get('author') or {}).get('login', '?')}]: {c.get('body') or ''}" - for c in kept - ) - return { - "item_type": "pull", - "repo": repo, - "number": r.get("number"), - "title": r.get("title") or "", - "body": r.get("body") or "", - "state": r.get("state") or "", - "author": (r.get("author") or {}).get("login", ""), - "created_at": r.get("createdAt") or "", - "closed_at": r.get("closedAt") or "", - "url": r.get("url") or r.get("permalink") or "", - "labels": labels, - "comments": comments_text, - } - - -def _flatten_commit_row(r: dict, repo: str) -> dict: - msg = r.get("messageHeadline") or r.get("message") or "" - body = r.get("messageBody") or r.get("message") or msg - author = r.get("author") or {} - return { - "item_type": "commit", - "repo": repo, - "number": r.get("oid") or r.get("sha") or "", - "title": msg, - "body": body, - "state": "", - "author": (author.get("user") or {}).get("login") or author.get("name", ""), - "created_at": (author.get("date") or r.get("committedDate") or ""), - "closed_at": "", - "url": r.get("url") or "", - "labels": [], - "comments": "", - } - - -def scrape(cfg: ScrapeConfig, base_dir: Path): - token = _resolve_token(cfg.token) - GitHubClient, RepoScraper = _load_impl() - client = GitHubClient(token = token) - base_dir.mkdir(parents = True, exist_ok = True) - - # Per-resource trial limits. limit <= 0 means "all": use a very large cap. - effective_limit = cfg.limit if cfg.limit and cfg.limit > 0 else 1_000_000 - trial_limits: dict[str, int] = {} - if "issues" in cfg.item_types: - trial_limits["issues"] = effective_limit - if "pulls" in cfg.item_types: - trial_limits["pull_requests"] = effective_limit - if "commits" in cfg.item_types: - trial_limits["commits"] = effective_limit - - all_rows: list[dict] = [] - for repo in cfg.repos: - owner, name = repo.split("/", 1) - scraper = RepoScraper( - owner = owner, - name = name, - base_dir = base_dir, - client = client, - trial_limits = trial_limits, - light = True, - ) - try: - repo_meta = scraper.scrape_repo_meta() - if "issues" in cfg.item_types: - scraper.scrape_issues() - if "pulls" in cfg.item_types: - scraper.scrape_prs() - if "commits" in cfg.item_types: - default_ref = repo_meta.get("defaultBranchRef") or {} - default_branch = ( - default_ref.get("name") if isinstance(default_ref, dict) else None - ) - branch = ( - f"refs/heads/{default_branch}" - if default_branch - else "refs/heads/main" - ) - scraper.scrape_commits(branch = branch) - finally: - scraper.close() - - read_cap = cfg.limit if cfg.limit and cfg.limit > 0 else None - repo_dir = base_dir / f"{owner}__{name}" - if "issues" in cfg.item_types: - for row in _read_jsonl(repo_dir / "issues.jsonl", read_cap): - all_rows.append( - _flatten_issue_row( - row, repo, cfg.include_comments, cfg.max_comments_per_item - ) - ) - if "pulls" in cfg.item_types: - for row in _read_jsonl(repo_dir / "pull_requests.jsonl", read_cap): - all_rows.append( - _flatten_pr_row( - row, repo, cfg.include_comments, cfg.max_comments_per_item - ) - ) - if "commits" in cfg.item_types: - for row in _read_jsonl(repo_dir / "commits.jsonl", read_cap): - all_rows.append(_flatten_commit_row(row, repo)) - - return all_rows - - -def materialize_to_jsonl(cfg: ScrapeConfig, out_dir: Path) -> Path: - out_dir.mkdir(parents = True, exist_ok = True) - tag = "-".join(r.replace("/", "__") for r in cfg.repos)[:120] - kinds = "-".join(cfg.item_types) - run_id = f"{int(time.time())}-{uuid.uuid4().hex[:12]}" - fname = f"github_{tag}__{kinds}__{cfg.limit}_{run_id}.jsonl" - out = out_dir / fname - rows = scrape(cfg, out_dir / "raw-runs" / run_id) - with out.open("w", encoding = "utf-8") as f: - for r in rows: - f.write(json.dumps(r, ensure_ascii = False) + "\n") - return out diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/__init__.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/__init__.py deleted file mode 100644 index 32014236c6..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/gh_client.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/gh_client.py deleted file mode 100644 index dd2de2f5ce..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/gh_client.py +++ /dev/null @@ -1,248 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""GitHub API client with rate-limit awareness, retry, and dual REST/GraphQL support.""" - -from __future__ import annotations - -import json -import os -import time -import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional - -import requests - -log = logging.getLogger("gh_client") - -GRAPHQL_URL = "https://api.github.com/graphql" -REST_BASE = "https://api.github.com" - -BASE_HEADERS = { - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - "User-Agent": "github-data-gatherer/1.0", -} - - -class RateLimitError(Exception): - pass - - -class GitHubClient: - def __init__( - self, - min_remaining_graphql: int = 100, - min_remaining_rest: int = 100, - token: str | None = None, - ): - token = token or os.environ.get("GH_TOKEN") or os.environ.get("GITHUB_TOKEN") - if not token: - raise RuntimeError("GH_TOKEN not set in environment") - self.session = requests.Session() - self.session.headers.update( - {**BASE_HEADERS, "Authorization": f"Bearer {token}"} - ) - self.min_remaining_graphql = min_remaining_graphql - self.min_remaining_rest = min_remaining_rest - self.graphql_remaining: Optional[int] = None - self.graphql_reset: Optional[int] = None - self.rest_remaining: Optional[int] = None - self.rest_reset: Optional[int] = None - self.calls_graphql = 0 - self.calls_rest = 0 - self.retry_count = 0 - - def _sleep_until(self, reset_ts: int, buffer_s: int = 10) -> None: - now = int(time.time()) - wait = max(0, reset_ts - now) + buffer_s - log.warning("Rate limit hit. Sleeping %ds until reset.", wait) - time.sleep(wait) - - def _check_rate_and_wait(self, kind: str) -> None: - if kind == "graphql": - remaining = self.graphql_remaining - reset = self.graphql_reset - min_remaining = self.min_remaining_graphql - else: - remaining = self.rest_remaining - reset = self.rest_reset - min_remaining = self.min_remaining_rest - if remaining is not None and remaining < min_remaining: - if reset: - self._sleep_until(reset) - # Reset remaining so we don't spin - if kind == "graphql": - self.graphql_remaining = None - else: - self.rest_remaining = None - - def graphql( - self, - query: str, - variables: Optional[Dict[str, Any]] = None, - max_retries: int = 20, - ) -> Dict[str, Any]: - self._check_rate_and_wait("graphql") - backoff = 2 - last_err = None - for attempt in range(max_retries): - try: - r = self.session.post( - GRAPHQL_URL, - json = {"query": query, "variables": variables or {}}, - timeout = 120, - ) - self.calls_graphql += 1 - # Update rate info from response headers - rem = r.headers.get("X-RateLimit-Remaining") - rst = r.headers.get("X-RateLimit-Reset") - if rem is not None: - try: - self.graphql_remaining = int(rem) - except ValueError: - pass - if rst is not None: - try: - self.graphql_reset = int(rst) - except ValueError: - pass - if r.status_code in (502, 503, 504): - log.warning("GraphQL %s transient, retrying", r.status_code) - time.sleep(backoff) - backoff = min(backoff * 2, 60) - continue - if r.status_code == 403 or r.status_code == 429: - # Check for secondary/abuse - retry_after = r.headers.get("Retry-After") - if retry_after: - t = int(retry_after) - log.warning("Secondary rate limit. Sleep %ds.", t) - time.sleep(t + 2) - continue - if self.graphql_reset: - self._sleep_until(self.graphql_reset) - continue - time.sleep(60) - continue - r.raise_for_status() - data = r.json() - if "errors" in data and data["errors"]: - # Surface errors but allow partial data - errs = data["errors"] - # Retry on RATE_LIMITED - for e in errs: - if e.get("type") == "RATE_LIMITED": - self._sleep_until( - (self.graphql_reset or int(time.time()) + 60) - ) - break - else: - # No rate-limit error, log and return partial - log.warning("GraphQL errors: %s", json.dumps(errs)[:400]) - return data - continue - return data - except requests.RequestException as e: - last_err = e - log.warning("GraphQL network error: %s. Retry.", e) - time.sleep(backoff) - backoff = min(backoff * 2, 60) - raise RuntimeError(f"GraphQL failed after {max_retries} retries: {last_err}") - - def rest( - self, - method: str, - path: str, - params: Optional[Dict[str, Any]] = None, - json_body: Optional[Dict[str, Any]] = None, - max_retries: int = 6, - ) -> requests.Response: - self._check_rate_and_wait("rest") - if path.startswith("http"): - url = path - else: - url = REST_BASE + path - backoff = 2 - last_err = None - for attempt in range(max_retries): - try: - r = self.session.request( - method, url, params = params, json = json_body, timeout = 120 - ) - self.calls_rest += 1 - rem = r.headers.get("X-RateLimit-Remaining") - rst = r.headers.get("X-RateLimit-Reset") - if rem is not None: - try: - self.rest_remaining = int(rem) - except ValueError: - pass - if rst is not None: - try: - self.rest_reset = int(rst) - except ValueError: - pass - if r.status_code in (502, 503, 504): - log.warning("REST %s transient, retrying", r.status_code) - time.sleep(backoff) - backoff = min(backoff * 2, 60) - continue - if r.status_code in (403, 429): - retry_after = r.headers.get("Retry-After") - if retry_after: - t = int(retry_after) - log.warning("Secondary rate limit on REST. Sleep %ds.", t) - time.sleep(t + 2) - continue - # Check if primary rate - if self.rest_remaining == 0 and self.rest_reset: - self._sleep_until(self.rest_reset) - continue - log.warning("REST 403/429, sleep 60") - time.sleep(60) - continue - return r - except requests.RequestException as e: - last_err = e - log.warning("REST network error: %s. Retry.", e) - time.sleep(backoff) - backoff = min(backoff * 2, 60) - raise RuntimeError(f"REST failed after {max_retries} retries: {last_err}") - - def rest_paginate( - self, path: str, params: Optional[Dict[str, Any]] = None, per_page: int = 100 - ) -> Iterator[dict]: - params = dict(params or {}) - params.setdefault("per_page", per_page) - url = path - while True: - r = self.rest("GET", url, params = params if url == path else None) - if r.status_code != 200: - log.error( - "REST paginate got %s at %s: %s", r.status_code, url, r.text[:200] - ) - return - items = r.json() - if isinstance(items, dict): - # Some endpoints return dict with list field - items = items.get("items", []) - for it in items: - yield it - # Follow link header - link = r.headers.get("Link", "") - nxt = None - for part in link.split(","): - if 'rel="next"' in part: - nxt = part.split(";")[0].strip().strip("<>") - break - if not nxt: - return - url = nxt - params = None - - def rate_snapshot(self) -> Dict[str, Any]: - r = self.rest("GET", "/rate_limit") - if r.status_code == 200: - return r.json() - return {} diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/queries.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/queries.py deleted file mode 100644 index 9dc7613db5..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/queries.py +++ /dev/null @@ -1,685 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""GraphQL queries for GitHub data scraping. - -GitHub's GraphQL rejects queries that define unused fragments, so each query -only includes the fragments it actually references. -""" - -# ---- Fragments (kept as raw strings, composed per query) ---- -F_ACTOR = """ -fragment ActorFields on Actor { - __typename - login - url - avatarUrl - ... on User { id databaseId name } - ... on Bot { id databaseId } - ... on Organization { id databaseId name } -} -""" - -F_LABEL = """ -fragment LabelFields on Label { - id - name - color - description - createdAt -} -""" - -F_TIMELINE = """ -fragment TimelineItem on IssueTimelineItems { - __typename - ... on Node { id } - ... on AddedToProjectEvent { createdAt actor { ...ActorFields } } - ... on AssignedEvent { createdAt actor { ...ActorFields } assignee { __typename ... on User { login } ... on Bot { login } } } - ... on ClosedEvent { createdAt actor { ...ActorFields } stateReason closer { __typename ... on Commit { oid url } ... on PullRequest { number url } } } - ... on CommentDeletedEvent { createdAt actor { ...ActorFields } } - ... on ConnectedEvent { createdAt actor { ...ActorFields } source { __typename ... on Issue { number url repository { nameWithOwner } } ... on PullRequest { number url repository { nameWithOwner } } } subject { __typename ... on Issue { number url } ... on PullRequest { number url } } } - ... on ConvertedNoteToIssueEvent { createdAt actor { ...ActorFields } } - ... on CrossReferencedEvent { createdAt actor { ...ActorFields } isCrossRepository willCloseTarget source { __typename ... on Issue { number url repository { nameWithOwner } title } ... on PullRequest { number url repository { nameWithOwner } title } } } - ... on DemilestonedEvent { createdAt actor { ...ActorFields } milestoneTitle } - ... on DisconnectedEvent { createdAt actor { ...ActorFields } subject { __typename ... on Issue { number url } ... on PullRequest { number url } } source { __typename ... on Issue { number url } ... on PullRequest { number url } } } - ... on IssueComment { id databaseId createdAt updatedAt author { ...ActorFields } body url reactionGroups { content reactors { totalCount } } } - ... on LabeledEvent { createdAt actor { ...ActorFields } label { name color } } - ... on LockedEvent { createdAt actor { ...ActorFields } lockReason } - ... on MarkedAsDuplicateEvent { createdAt actor { ...ActorFields } canonical { __typename ... on Issue { number url } ... on PullRequest { number url } } } - ... on MentionedEvent { createdAt actor { ...ActorFields } } - ... on MilestonedEvent { createdAt actor { ...ActorFields } milestoneTitle } - ... on MovedColumnsInProjectEvent { createdAt actor { ...ActorFields } } - ... on PinnedEvent { createdAt actor { ...ActorFields } } - ... on ReferencedEvent { createdAt actor { ...ActorFields } commit { oid url } commitRepository { nameWithOwner } } - ... on RemovedFromProjectEvent { createdAt actor { ...ActorFields } } - ... on RenamedTitleEvent { createdAt actor { ...ActorFields } previousTitle currentTitle } - ... on ReopenedEvent { createdAt actor { ...ActorFields } } - ... on SubscribedEvent { createdAt actor { ...ActorFields } } - ... on TransferredEvent { createdAt actor { ...ActorFields } fromRepository { nameWithOwner } } - ... on UnassignedEvent { createdAt actor { ...ActorFields } assignee { __typename ... on User { login } ... on Bot { login } } } - ... on UnlabeledEvent { createdAt actor { ...ActorFields } label { name color } } - ... on UnlockedEvent { createdAt actor { ...ActorFields } } - ... on UnmarkedAsDuplicateEvent { createdAt actor { ...ActorFields } } - ... on UnpinnedEvent { createdAt actor { ...ActorFields } } - ... on UnsubscribedEvent { createdAt actor { ...ActorFields } } - ... on UserBlockedEvent { createdAt actor { ...ActorFields } blockDuration } -} -""" - -F_PR_TIMELINE = """ -fragment PRTimelineItem on PullRequestTimelineItems { - __typename - ... on Node { id } - ... on AssignedEvent { createdAt actor { ...ActorFields } assignee { __typename ... on User { login } ... on Bot { login } } } - ... on AutoMergeDisabledEvent { createdAt actor { ...ActorFields } reason } - ... on AutoMergeEnabledEvent { createdAt actor { ...ActorFields } } - ... on AutoRebaseEnabledEvent { createdAt actor { ...ActorFields } } - ... on AutoSquashEnabledEvent { createdAt actor { ...ActorFields } } - ... on AutomaticBaseChangeFailedEvent { createdAt actor { ...ActorFields } oldBase newBase } - ... on AutomaticBaseChangeSucceededEvent { createdAt actor { ...ActorFields } oldBase newBase } - ... on BaseRefChangedEvent { createdAt actor { ...ActorFields } previousRefName currentRefName } - ... on BaseRefDeletedEvent { createdAt actor { ...ActorFields } baseRefName } - ... on BaseRefForcePushedEvent { createdAt actor { ...ActorFields } beforeCommit { oid } afterCommit { oid } ref { name } } - ... on ClosedEvent { createdAt actor { ...ActorFields } stateReason } - ... on CommentDeletedEvent { createdAt actor { ...ActorFields } } - ... on ConnectedEvent { createdAt actor { ...ActorFields } source { __typename ... on Issue { number url } ... on PullRequest { number url } } subject { __typename ... on Issue { number url } ... on PullRequest { number url } } } - ... on ConvertToDraftEvent { createdAt actor { ...ActorFields } } - ... on CrossReferencedEvent { createdAt actor { ...ActorFields } isCrossRepository willCloseTarget source { __typename ... on Issue { number url repository { nameWithOwner } title } ... on PullRequest { number url repository { nameWithOwner } title } } } - ... on DemilestonedEvent { createdAt actor { ...ActorFields } milestoneTitle } - ... on DeployedEvent { createdAt actor { ...ActorFields } } - ... on DeploymentEnvironmentChangedEvent { createdAt actor { ...ActorFields } } - ... on DisconnectedEvent { createdAt actor { ...ActorFields } subject { __typename ... on Issue { number url } ... on PullRequest { number url } } source { __typename ... on Issue { number url } ... on PullRequest { number url } } } - ... on HeadRefDeletedEvent { createdAt actor { ...ActorFields } headRefName } - ... on HeadRefForcePushedEvent { createdAt actor { ...ActorFields } beforeCommit { oid } afterCommit { oid } ref { name } } - ... on HeadRefRestoredEvent { createdAt actor { ...ActorFields } } - ... on IssueComment { id databaseId createdAt updatedAt author { ...ActorFields } body url reactionGroups { content reactors { totalCount } } } - ... on LabeledEvent { createdAt actor { ...ActorFields } label { name color } } - ... on LockedEvent { createdAt actor { ...ActorFields } lockReason } - ... on MarkedAsDuplicateEvent { createdAt actor { ...ActorFields } canonical { __typename ... on Issue { number url } ... on PullRequest { number url } } } - ... on MentionedEvent { createdAt actor { ...ActorFields } } - ... on MergedEvent { createdAt actor { ...ActorFields } commit { oid url } mergeRefName } - ... on MilestonedEvent { createdAt actor { ...ActorFields } milestoneTitle } - ... on MovedColumnsInProjectEvent { createdAt actor { ...ActorFields } } - ... on PinnedEvent { createdAt actor { ...ActorFields } } - ... on PullRequestCommit { commit { oid url message author { user { login } date } committedDate } } - ... on PullRequestCommitCommentThread { commit { oid } } - ... on PullRequestReview { id databaseId createdAt submittedAt author { ...ActorFields } body state url reactionGroups { content reactors { totalCount } } } - ... on PullRequestReviewThread { id isResolved isOutdated path line diffSide } - ... on PullRequestRevisionMarker { createdAt lastSeenCommit { oid } } - ... on ReadyForReviewEvent { createdAt actor { ...ActorFields } } - ... on ReferencedEvent { createdAt actor { ...ActorFields } commit { oid url } commitRepository { nameWithOwner } } - ... on RenamedTitleEvent { createdAt actor { ...ActorFields } previousTitle currentTitle } - ... on ReopenedEvent { createdAt actor { ...ActorFields } } - ... on ReviewDismissedEvent { createdAt actor { ...ActorFields } dismissalMessage previousReviewState } - ... on ReviewRequestRemovedEvent { createdAt actor { ...ActorFields } requestedReviewer { __typename ... on User { login } ... on Team { name } } } - ... on ReviewRequestedEvent { createdAt actor { ...ActorFields } requestedReviewer { __typename ... on User { login } ... on Team { name } } } - ... on SubscribedEvent { createdAt actor { ...ActorFields } } - ... on TransferredEvent { createdAt actor { ...ActorFields } fromRepository { nameWithOwner } } - ... on UnassignedEvent { createdAt actor { ...ActorFields } assignee { __typename ... on User { login } ... on Bot { login } } } - ... on UnlabeledEvent { createdAt actor { ...ActorFields } label { name color } } - ... on UnlockedEvent { createdAt actor { ...ActorFields } } - ... on UnmarkedAsDuplicateEvent { createdAt actor { ...ActorFields } } - ... on UnpinnedEvent { createdAt actor { ...ActorFields } } - ... on UnsubscribedEvent { createdAt actor { ...ActorFields } } - ... on UserBlockedEvent { createdAt actor { ...ActorFields } blockDuration } -} -""" - - -def _q(parts: list[str], body: str) -> str: - return "\n".join(parts + [body]) - - -ISSUES_PAGE_QUERY = _q( - [F_ACTOR, F_LABEL, F_TIMELINE], - """ -query IssuesPage($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - issues(first: $first, after: $after, orderBy: {field: CREATED_AT, direction: ASC}) { - pageInfo { hasNextPage endCursor } - totalCount - nodes { - id databaseId number title body state stateReason - createdAt updatedAt closedAt - url - author { ...ActorFields } - editor { ...ActorFields } - labels(first: 50) { nodes { ...LabelFields } } - assignees(first: 20) { nodes { login id } } - milestone { title number state dueOn } - reactionGroups { content reactors { totalCount } } - comments(first: 100) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId createdAt updatedAt url body - author { ...ActorFields } - editor { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - timelineItems(first: 100) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { ...TimelineItem } - } - trackedInIssues(first: 20) { totalCount nodes { number url repository { nameWithOwner } } } - trackedIssues(first: 20) { totalCount nodes { number url repository { nameWithOwner } } } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -PRS_PAGE_QUERY = _q( - [F_ACTOR, F_LABEL, F_PR_TIMELINE], - """ -query PRsPage($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - pullRequests(first: $first, after: $after, orderBy: {field: CREATED_AT, direction: ASC}) { - pageInfo { hasNextPage endCursor } - totalCount - nodes { - id databaseId number title body state isDraft - createdAt updatedAt closedAt mergedAt - url - headRefName headRefOid - baseRefName baseRefOid - additions deletions changedFiles - mergeable merged mergeStateStatus - author { ...ActorFields } - editor { ...ActorFields } - mergedBy { ...ActorFields } - labels(first: 50) { nodes { ...LabelFields } } - assignees(first: 20) { nodes { login id } } - milestone { title number state dueOn } - reactionGroups { content reactors { totalCount } } - closingIssuesReferences(first: 20) { totalCount nodes { number url repository { nameWithOwner } title } } - comments(first: 100) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId createdAt updatedAt url body - author { ...ActorFields } - editor { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - reviewThreads(first: 50) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id isResolved isOutdated path line diffSide - comments(first: 50) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId createdAt updatedAt url body path diffHunk - author { ...ActorFields } - editor { ...ActorFields } - position originalPosition line originalLine - commit { oid } - reactionGroups { content reactors { totalCount } } - } - } - } - } - reviews(first: 50) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId state createdAt submittedAt body url - author { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - commits(first: 100) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - commit { - oid - message - messageHeadline - committedDate - authoredDate - author { name email user { login } date } - committer { name email user { login } date } - additions deletions changedFilesIfAvailable - parents(first: 3) { nodes { oid } } - } - } - } - files(first: 100) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - path additions deletions changeType - } - } - timelineItems(first: 100) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { ...PRTimelineItem } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -PRS_PAGE_QUERY_LIGHT = _q( - [F_ACTOR, F_LABEL], - """ -query PRsPageLight($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - pullRequests(first: $first, after: $after, orderBy: {field: CREATED_AT, direction: ASC}) { - pageInfo { hasNextPage endCursor } - totalCount - nodes { - id databaseId number title body state isDraft - createdAt updatedAt closedAt mergedAt - url - author { ...ActorFields } - labels(first: 50) { nodes { ...LabelFields } } - comments(first: 30) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId createdAt updatedAt url body - author { ...ActorFields } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -ISSUES_PAGE_QUERY_LIGHT = _q( - [F_ACTOR, F_LABEL], - """ -query IssuesPageLight($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - issues(first: $first, after: $after, orderBy: {field: CREATED_AT, direction: ASC}) { - pageInfo { hasNextPage endCursor } - totalCount - nodes { - id databaseId number title body state - createdAt updatedAt closedAt - url - author { ...ActorFields } - labels(first: 50) { nodes { ...LabelFields } } - comments(first: 30) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId createdAt updatedAt url body - author { ...ActorFields } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -ISSUE_COMMENTS_QUERY = _q( - [F_ACTOR], - """ -query IssueComments($owner: String!, $name: String!, $number: Int!, $after: String) { - repository(owner: $owner, name: $name) { - issueOrPullRequest(number: $number) { - __typename - ... on Issue { - comments(first: 100, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { - id databaseId createdAt updatedAt url body - author { ...ActorFields } - editor { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - } - ... on PullRequest { - comments(first: 100, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { - id databaseId createdAt updatedAt url body - author { ...ActorFields } - editor { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -ISSUE_TIMELINE_QUERY = _q( - [F_ACTOR, F_TIMELINE], - """ -query IssueTimeline($owner: String!, $name: String!, $number: Int!, $after: String) { - repository(owner: $owner, name: $name) { - issue(number: $number) { - timelineItems(first: 100, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { ...TimelineItem } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -PR_TIMELINE_QUERY = _q( - [F_ACTOR, F_PR_TIMELINE], - """ -query PRTimeline($owner: String!, $name: String!, $number: Int!, $after: String) { - repository(owner: $owner, name: $name) { - pullRequest(number: $number) { - timelineItems(first: 100, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { ...PRTimelineItem } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -PR_COMMITS_QUERY = """ -query PRCommits($owner: String!, $name: String!, $number: Int!, $after: String) { - repository(owner: $owner, name: $name) { - pullRequest(number: $number) { - commits(first: 100, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { - commit { - oid message messageHeadline committedDate authoredDate - author { name email user { login } date } - committer { name email user { login } date } - additions deletions changedFilesIfAvailable - parents(first: 3) { nodes { oid } } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""" - -PR_FILES_QUERY = """ -query PRFiles($owner: String!, $name: String!, $number: Int!, $after: String) { - repository(owner: $owner, name: $name) { - pullRequest(number: $number) { - files(first: 100, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { path additions deletions changeType } - } - } - } - rateLimit { cost remaining resetAt } -} -""" - -PR_REVIEW_THREADS_QUERY = _q( - [F_ACTOR], - """ -query PRReviewThreads($owner: String!, $name: String!, $number: Int!, $after: String) { - repository(owner: $owner, name: $name) { - pullRequest(number: $number) { - reviewThreads(first: 50, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { - id isResolved isOutdated path line diffSide - comments(first: 50) { - totalCount - nodes { - id databaseId createdAt updatedAt url body path diffHunk - author { ...ActorFields } - editor { ...ActorFields } - position originalPosition line originalLine - commit { oid } - reactionGroups { content reactors { totalCount } } - } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -DISCUSSIONS_PAGE_QUERY = _q( - [F_ACTOR, F_LABEL], - """ -query DiscussionsPage($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - discussions(first: $first, after: $after, orderBy: {field: CREATED_AT, direction: ASC}) { - pageInfo { hasNextPage endCursor } - totalCount - nodes { - id databaseId number title body - createdAt updatedAt url - author { ...ActorFields } - editor { ...ActorFields } - locked - answerChosenAt - closed closedAt - category { id name emoji description isAnswerable } - labels(first: 30) { nodes { ...LabelFields } } - upvoteCount - answer { id databaseId body author { ...ActorFields } createdAt url } - reactionGroups { content reactors { totalCount } } - comments(first: 50) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId body createdAt updatedAt url - author { ...ActorFields } - editor { ...ActorFields } - upvoteCount - isAnswer - reactionGroups { content reactors { totalCount } } - replies(first: 50) { - totalCount - pageInfo { hasNextPage endCursor } - nodes { - id databaseId body createdAt updatedAt url - author { ...ActorFields } - editor { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -DISCUSSION_COMMENTS_QUERY = _q( - [F_ACTOR], - """ -query DiscussionComments($owner: String!, $name: String!, $number: Int!, $after: String) { - repository(owner: $owner, name: $name) { - discussion(number: $number) { - comments(first: 50, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { - id databaseId body createdAt updatedAt url - author { ...ActorFields } - editor { ...ActorFields } - upvoteCount - isAnswer - reactionGroups { content reactors { totalCount } } - replies(first: 50) { - totalCount - nodes { - id databaseId body createdAt updatedAt url - author { ...ActorFields } - editor { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -DISCUSSION_REPLIES_QUERY = _q( - [F_ACTOR], - """ -query DiscussionReplies($commentId: ID!, $after: String) { - node(id: $commentId) { - ... on DiscussionComment { - replies(first: 50, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { - id databaseId body createdAt updatedAt url - author { ...ActorFields } - editor { ...ActorFields } - reactionGroups { content reactors { totalCount } } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -COMMITS_PAGE_QUERY = """ -query CommitsPage($owner: String!, $name: String!, $first: Int!, $after: String, $branch: String!) { - repository(owner: $owner, name: $name) { - ref(qualifiedName: $branch) { - target { - ... on Commit { - history(first: $first, after: $after) { - pageInfo { hasNextPage endCursor } - totalCount - nodes { - oid - message - messageHeadline - committedDate - authoredDate - url - additions deletions changedFilesIfAvailable - author { name email date user { login id } } - committer { name email date user { login id } } - parents(first: 3) { nodes { oid } } - associatedPullRequests(first: 5) { nodes { number url state } } - } - } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""" - -RELEASES_QUERY = _q( - [F_ACTOR], - """ -query Releases($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - releases(first: $first, after: $after, orderBy: {field: CREATED_AT, direction: ASC}) { - pageInfo { hasNextPage endCursor } - nodes { - id databaseId name tagName description - createdAt publishedAt updatedAt - isDraft isPrerelease isLatest - url - author { ...ActorFields } - tagCommit { oid url } - reactionGroups { content reactors { totalCount } } - releaseAssets(first: 50) { - nodes { name contentType size downloadUrl createdAt updatedAt } - } - } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -LABELS_QUERY = _q( - [F_LABEL], - """ -query LabelsList($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - labels(first: $first, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { ...LabelFields } - } - } - rateLimit { cost remaining resetAt } -} -""", -) - -MILESTONES_QUERY = """ -query Milestones($owner: String!, $name: String!, $first: Int!, $after: String) { - repository(owner: $owner, name: $name) { - milestones(first: $first, after: $after) { - pageInfo { hasNextPage endCursor } - nodes { - id number title description state - createdAt updatedAt closedAt dueOn - creator { login } - } - } - } - rateLimit { cost remaining resetAt } -} -""" - -REPO_META_QUERY = """ -query RepoMeta($owner: String!, $name: String!) { - repository(owner: $owner, name: $name) { - id databaseId name nameWithOwner description url - createdAt updatedAt pushedAt - isArchived isDisabled isFork isPrivate - primaryLanguage { name } - languages(first: 20, orderBy: {field: SIZE, direction: DESC}) { - edges { size node { name } } - totalSize - } - stargazerCount forkCount watchers { totalCount } - diskUsage - licenseInfo { key name } - homepageUrl - defaultBranchRef { name } - } - rateLimit { cost remaining resetAt } -} -""" diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/scraper.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/scraper.py deleted file mode 100644 index 127129e18b..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/scraper.py +++ /dev/null @@ -1,756 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Main scraper orchestration. Collects issues, PRs, discussions, commits, releases, etc. - -Resumable via state file. Writes JSONL shards under data/{repo}/{resource}.jsonl. -""" - -from __future__ import annotations - -import argparse -import json -import logging -import os -import subprocess -import sys -import time -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple - -# Allow running as a module or script -THIS_DIR = Path(__file__).resolve().parent -if str(THIS_DIR) not in sys.path: - sys.path.insert(0, str(THIS_DIR)) - -from gh_client import GitHubClient -from state_store import JsonlWriter, StateStore -import queries as Q - -log = logging.getLogger("scraper") - - -def ts() -> str: - return time.strftime("%Y-%m-%d %H:%M:%S") - - -class RepoScraper: - def __init__( - self, - owner: str, - name: str, - base_dir: Path, - client: GitHubClient, - trial_limits: Optional[Dict[str, int]] = None, - light: bool = False, - ): - self.owner = owner - self.name = name - self.base_dir = base_dir - self.client = client - self.trial_limits = trial_limits or {} - # When light=True, use trimmed GraphQL queries (no reviewThreads, - # reviews, commits, timelineItems, files) so PR pages can be much - # larger without blowing GitHub's node-count ceiling. - self.light = light - self.repo_dir = base_dir / f"{owner}__{name}" - self.repo_dir.mkdir(parents = True, exist_ok = True) - self.state = StateStore(base_dir / "state" / f"{owner}__{name}.json") - - # Writers - self.writers: Dict[str, JsonlWriter] = {} - for key in ( - "issues", - "pull_requests", - "discussions", - "commits", - "releases", - "labels", - "milestones", - "pr_extra_comments", - "pr_extra_timeline", - "pr_extra_reviews", - "issue_extra_comments", - "issue_extra_timeline", - "discussion_extra_comments", - "discussion_extra_replies", - "repo_meta", - ): - self.writers[key] = JsonlWriter(self.repo_dir / f"{key}.jsonl") - - # ----- helpers ----- - def _trial_stop(self, key: str, counter: int) -> bool: - lim = self.trial_limits.get(key) - if lim is None: - return False - return counter >= lim - - def _log_rate(self, where: str, data: Dict[str, Any]) -> None: - rl = ( - data.get("data", {}).get("rateLimit") - if isinstance(data.get("data"), dict) - else None - ) - if rl: - log.debug( - "[%s] rate cost=%s remaining=%s resetAt=%s", - where, - rl.get("cost"), - rl.get("remaining"), - rl.get("resetAt"), - ) - - # ----- repo meta ----- - def scrape_repo_meta(self) -> Dict[str, Any]: - data = self.client.graphql( - Q.REPO_META_QUERY, {"owner": self.owner, "name": self.name} - ) - self._log_rate("repo_meta", data) - repo = data.get("data", {}).get("repository") or {} - repo["_fetchedAt"] = ts() - self.writers["repo_meta"].write(repo) - return repo - - # ----- issues ----- - def scrape_issues(self) -> int: - key = "issues" - cursor = self.state.get(f"{key}_cursor") - done = self.state.get(f"{key}_done", False) - if done: - log.info("%s/%s issues already complete", self.owner, self.name) - return 0 - total_new = 0 - page = 0 - # Light query skips heavy nested fields; safe at 50 per page. - # Clamp by trial_limit so e.g. limit=1 asks GitHub for first:1 - # instead of fetching a full 50-item page and discarding 49. - page_cap = 50 if self.light else 15 - trial_cap = self.trial_limits.get(key) - per_page = min(page_cap, trial_cap) if trial_cap and trial_cap > 0 else page_cap - while True: - page += 1 - vars_ = { - "owner": self.owner, - "name": self.name, - "first": per_page, - "after": cursor, - } - query = Q.ISSUES_PAGE_QUERY_LIGHT if self.light else Q.ISSUES_PAGE_QUERY - data = self.client.graphql(query, vars_) - self._log_rate("issues", data) - repo = (data.get("data") or {}).get("repository") or {} - issues = repo.get("issues") or {} - nodes = issues.get("nodes") or [] - for it in nodes: - it["_owner"] = self.owner - it["_repo"] = self.name - it["_fetchedAt"] = ts() - if not self.light: - if it.get("comments", {}).get("pageInfo", {}).get("hasNextPage"): - self._paginate_issue_comments( - it["number"], it["comments"]["pageInfo"]["endCursor"] - ) - if ( - it.get("timelineItems", {}) - .get("pageInfo", {}) - .get("hasNextPage") - ): - self._paginate_issue_timeline( - it["number"], - it["timelineItems"]["pageInfo"]["endCursor"], - ) - if self.writers[key].write(it): - total_new += 1 - info = issues.get("pageInfo") or {} - cursor = info.get("endCursor") - self.state.set(f"{key}_cursor", cursor) - log.info( - "[%s/%s] issues page %d (+%d) cursor=%s remaining=%s", - self.owner, - self.name, - page, - len(nodes), - str(cursor)[:20], - self.client.graphql_remaining, - ) - if self._trial_stop(key, total_new): - log.info("Trial limit reached for issues (%d)", total_new) - return total_new - if not info.get("hasNextPage"): - self.state.set(f"{key}_done", True) - break - return total_new - - def _paginate_issue_comments(self, number: int, after: str) -> None: - cur = after - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.ISSUE_COMMENTS_QUERY, vars_) - item = ((data.get("data") or {}).get("repository") or {}).get( - "issueOrPullRequest" - ) or {} - comments = item.get("comments") or {} - for c in comments.get("nodes") or []: - c["_owner"] = self.owner - c["_repo"] = self.name - c["_issueNumber"] = number - self.writers["issue_extra_comments"].write(c) - info = comments.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - def _paginate_issue_timeline(self, number: int, after: str) -> None: - cur = after - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.ISSUE_TIMELINE_QUERY, vars_) - item = ((data.get("data") or {}).get("repository") or {}).get("issue") or {} - tl = item.get("timelineItems") or {} - for ev in tl.get("nodes") or []: - ev["_owner"] = self.owner - ev["_repo"] = self.name - ev["_issueNumber"] = number - self.writers["issue_extra_timeline"].write(ev) - info = tl.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - # ----- PRs ----- - def scrape_prs(self) -> int: - key = "pull_requests" - cursor = self.state.get(f"{key}_cursor") - done = self.state.get(f"{key}_done", False) - if done: - log.info("%s/%s PRs already complete", self.owner, self.name) - return 0 - total_new = 0 - page = 0 - # Heavy nested PR query is capped at 3 per page (GitHub node-count - # ceiling); light query skips reviewThreads/reviews/commits/etc and - # can safely go to 25 per page. Clamp by trial_limit for small - # previews so limit=1 does not fetch a whole 25-item page. - page_cap = 25 if self.light else 3 - trial_cap = self.trial_limits.get(key) - per_page = min(page_cap, trial_cap) if trial_cap and trial_cap > 0 else page_cap - while True: - page += 1 - vars_ = { - "owner": self.owner, - "name": self.name, - "first": per_page, - "after": cursor, - } - query = Q.PRS_PAGE_QUERY_LIGHT if self.light else Q.PRS_PAGE_QUERY - data = self.client.graphql(query, vars_) - self._log_rate("prs", data) - repo = (data.get("data") or {}).get("repository") or {} - prs = repo.get("pullRequests") or {} - nodes = prs.get("nodes") or [] - for pr in nodes: - pr["_owner"] = self.owner - pr["_repo"] = self.name - pr["_fetchedAt"] = ts() - num = pr["number"] - if not self.light: - if pr.get("comments", {}).get("pageInfo", {}).get("hasNextPage"): - self._paginate_pr_comments( - num, pr["comments"]["pageInfo"]["endCursor"] - ) - if ( - pr.get("timelineItems", {}) - .get("pageInfo", {}) - .get("hasNextPage") - ): - self._paginate_pr_timeline( - num, pr["timelineItems"]["pageInfo"]["endCursor"] - ) - if pr.get("commits", {}).get("pageInfo", {}).get("hasNextPage"): - self._paginate_pr_commits( - num, pr["commits"]["pageInfo"]["endCursor"] - ) - if pr.get("files", {}).get("pageInfo", {}).get("hasNextPage"): - self._paginate_pr_files( - num, pr["files"]["pageInfo"]["endCursor"] - ) - if ( - pr.get("reviewThreads", {}) - .get("pageInfo", {}) - .get("hasNextPage") - ): - self._paginate_pr_review_threads( - num, pr["reviewThreads"]["pageInfo"]["endCursor"] - ) - if self.writers[key].write(pr): - total_new += 1 - info = prs.get("pageInfo") or {} - cursor = info.get("endCursor") - self.state.set(f"{key}_cursor", cursor) - log.info( - "[%s/%s] PRs page %d (+%d) cursor=%s remaining=%s", - self.owner, - self.name, - page, - len(nodes), - str(cursor)[:20], - self.client.graphql_remaining, - ) - if self._trial_stop(key, total_new): - log.info("Trial limit reached for PRs (%d)", total_new) - return total_new - if not info.get("hasNextPage"): - self.state.set(f"{key}_done", True) - break - return total_new - - def _paginate_pr_comments(self, number: int, after: str) -> None: - cur = after - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.ISSUE_COMMENTS_QUERY, vars_) - item = ((data.get("data") or {}).get("repository") or {}).get( - "issueOrPullRequest" - ) or {} - comments = item.get("comments") or {} - for c in comments.get("nodes") or []: - c["_owner"] = self.owner - c["_repo"] = self.name - c["_prNumber"] = number - self.writers["pr_extra_comments"].write(c) - info = comments.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - def _paginate_pr_timeline(self, number: int, after: str) -> None: - cur = after - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.PR_TIMELINE_QUERY, vars_) - item = ((data.get("data") or {}).get("repository") or {}).get( - "pullRequest" - ) or {} - tl = item.get("timelineItems") or {} - for ev in tl.get("nodes") or []: - ev["_owner"] = self.owner - ev["_repo"] = self.name - ev["_prNumber"] = number - self.writers["pr_extra_timeline"].write(ev) - info = tl.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - def _paginate_pr_commits(self, number: int, after: str) -> None: - cur = after - out_key = "pr_extra_commits" - if out_key not in self.writers: - self.writers[out_key] = JsonlWriter(self.repo_dir / f"{out_key}.jsonl") - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.PR_COMMITS_QUERY, vars_) - item = ((data.get("data") or {}).get("repository") or {}).get( - "pullRequest" - ) or {} - cc = item.get("commits") or {} - for c in cc.get("nodes") or []: - c["_owner"] = self.owner - c["_repo"] = self.name - c["_prNumber"] = number - self.writers[out_key].write(c) - info = cc.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - def _paginate_pr_files(self, number: int, after: str) -> None: - cur = after - out_key = "pr_extra_files" - if out_key not in self.writers: - self.writers[out_key] = JsonlWriter(self.repo_dir / f"{out_key}.jsonl") - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.PR_FILES_QUERY, vars_) - item = ((data.get("data") or {}).get("repository") or {}).get( - "pullRequest" - ) or {} - ff = item.get("files") or {} - for f in ff.get("nodes") or []: - f["_owner"] = self.owner - f["_repo"] = self.name - f["_prNumber"] = number - # files don't have id, synthesize one - f["_syntheticId"] = f"{self.owner}/{self.name}#{number}:{f.get('path')}" - self.writers[out_key].write(f) - info = ff.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - def _paginate_pr_review_threads(self, number: int, after: str) -> None: - cur = after - out_key = "pr_extra_review_threads" - if out_key not in self.writers: - self.writers[out_key] = JsonlWriter(self.repo_dir / f"{out_key}.jsonl") - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.PR_REVIEW_THREADS_QUERY, vars_) - item = ((data.get("data") or {}).get("repository") or {}).get( - "pullRequest" - ) or {} - rt = item.get("reviewThreads") or {} - for th in rt.get("nodes") or []: - th["_owner"] = self.owner - th["_repo"] = self.name - th["_prNumber"] = number - self.writers[out_key].write(th) - info = rt.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - # ----- Discussions ----- - def scrape_discussions(self) -> int: - key = "discussions" - cursor = self.state.get(f"{key}_cursor") - done = self.state.get(f"{key}_done", False) - if done: - log.info("%s/%s discussions already complete", self.owner, self.name) - return 0 - total_new = 0 - page = 0 - per_page = 15 - while True: - page += 1 - vars_ = { - "owner": self.owner, - "name": self.name, - "first": per_page, - "after": cursor, - } - data = self.client.graphql(Q.DISCUSSIONS_PAGE_QUERY, vars_) - self._log_rate("discussions", data) - repo = (data.get("data") or {}).get("repository") or {} - dd = repo.get("discussions") or {} - nodes = dd.get("nodes") or [] - for d in nodes: - d["_owner"] = self.owner - d["_repo"] = self.name - d["_fetchedAt"] = ts() - num = d["number"] - if d.get("comments", {}).get("pageInfo", {}).get("hasNextPage"): - self._paginate_discussion_comments( - num, d["comments"]["pageInfo"]["endCursor"] - ) - # paginate replies per comment if needed - for c in d.get("comments", {}).get("nodes", []) or []: - if c.get("replies", {}).get("pageInfo", {}).get("hasNextPage"): - self._paginate_discussion_replies( - c["id"], c["replies"]["pageInfo"]["endCursor"], num - ) - if self.writers[key].write(d): - total_new += 1 - info = dd.get("pageInfo") or {} - cursor = info.get("endCursor") - self.state.set(f"{key}_cursor", cursor) - log.info( - "[%s/%s] discussions page %d (+%d) cursor=%s remaining=%s", - self.owner, - self.name, - page, - len(nodes), - str(cursor)[:20], - self.client.graphql_remaining, - ) - if self._trial_stop(key, total_new): - return total_new - if not info.get("hasNextPage"): - self.state.set(f"{key}_done", True) - break - return total_new - - def _paginate_discussion_comments(self, number: int, after: str) -> None: - cur = after - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "number": number, - "after": cur, - } - data = self.client.graphql(Q.DISCUSSION_COMMENTS_QUERY, vars_) - disc = ((data.get("data") or {}).get("repository") or {}).get( - "discussion" - ) or {} - cc = disc.get("comments") or {} - for c in cc.get("nodes") or []: - c["_owner"] = self.owner - c["_repo"] = self.name - c["_discussionNumber"] = number - self.writers["discussion_extra_comments"].write(c) - info = cc.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - def _paginate_discussion_replies( - self, comment_id: str, after: str, disc_number: int - ) -> None: - cur = after - while cur: - vars_ = { - "owner": self.owner, - "name": self.name, - "commentId": comment_id, - "after": cur, - } - data = self.client.graphql(Q.DISCUSSION_REPLIES_QUERY, vars_) - node = (data.get("data") or {}).get("node") or {} - replies = node.get("replies") or {} - for r in replies.get("nodes") or []: - r["_owner"] = self.owner - r["_repo"] = self.name - r["_discussionNumber"] = disc_number - r["_commentId"] = comment_id - self.writers["discussion_extra_replies"].write(r) - info = replies.get("pageInfo") or {} - cur = info.get("endCursor") if info.get("hasNextPage") else None - - # ----- Commits ----- - def scrape_commits(self, branch: str = "refs/heads/main") -> int: - key = "commits" - cursor = self.state.get(f"{key}_cursor") - done = self.state.get(f"{key}_done", False) - if done: - return 0 - total_new = 0 - page = 0 - page_cap = 100 - trial_cap = self.trial_limits.get(key) - per_page = min(page_cap, trial_cap) if trial_cap and trial_cap > 0 else page_cap - while True: - page += 1 - vars_ = { - "owner": self.owner, - "name": self.name, - "first": per_page, - "after": cursor, - "branch": branch, - } - data = self.client.graphql(Q.COMMITS_PAGE_QUERY, vars_) - self._log_rate("commits", data) - ref = ((data.get("data") or {}).get("repository") or {}).get("ref") or {} - tgt = ref.get("target") or {} - hist = tgt.get("history") or {} - nodes = hist.get("nodes") or [] - for c in nodes: - c["_owner"] = self.owner - c["_repo"] = self.name - c["_fetchedAt"] = ts() - if self.writers[key].write(c): - total_new += 1 - info = hist.get("pageInfo") or {} - cursor = info.get("endCursor") - self.state.set(f"{key}_cursor", cursor) - log.info( - "[%s/%s] commits page %d (+%d) remaining=%s", - self.owner, - self.name, - page, - len(nodes), - self.client.graphql_remaining, - ) - if self._trial_stop(key, total_new): - return total_new - if not info.get("hasNextPage"): - self.state.set(f"{key}_done", True) - break - return total_new - - # ----- Releases/Labels/Milestones ----- - def scrape_releases(self) -> int: - return self._scrape_simple("releases", Q.RELEASES_QUERY, "releases") - - def scrape_labels(self) -> int: - return self._scrape_simple("labels", Q.LABELS_QUERY, "labels") - - def scrape_milestones(self) -> int: - return self._scrape_simple("milestones", Q.MILESTONES_QUERY, "milestones") - - def _scrape_simple(self, key: str, query: str, field: str) -> int: - cursor = self.state.get(f"{key}_cursor") - done = self.state.get(f"{key}_done", False) - if done: - return 0 - total_new = 0 - while True: - vars_ = { - "owner": self.owner, - "name": self.name, - "first": 50, - "after": cursor, - } - data = self.client.graphql(query, vars_) - repo = (data.get("data") or {}).get("repository") or {} - col = repo.get(field) or {} - for it in col.get("nodes") or []: - it["_owner"] = self.owner - it["_repo"] = self.name - it["_fetchedAt"] = ts() - if self.writers[key].write(it): - total_new += 1 - info = col.get("pageInfo") or {} - cursor = info.get("endCursor") - self.state.set(f"{key}_cursor", cursor) - if self._trial_stop(key, total_new): - return total_new - if not info.get("hasNextPage"): - self.state.set(f"{key}_done", True) - break - log.info("[%s/%s] %s done +%d", self.owner, self.name, key, total_new) - return total_new - - def close(self) -> None: - for w in self.writers.values(): - try: - w.close() - except Exception: - pass - - -def setup_logging(log_file: Path) -> None: - log_file.parent.mkdir(parents = True, exist_ok = True) - fmt = "%(asctime)s %(levelname)s [%(name)s] %(message)s" - handlers = [ - logging.StreamHandler(sys.stdout), - logging.FileHandler(log_file, mode = "a", encoding = "utf-8"), - ] - logging.basicConfig(level = logging.INFO, format = fmt, handlers = handlers, force = True) - - -def main(): - ap = argparse.ArgumentParser() - ap.add_argument( - "--base-dir", default = "/mnt/disks/unslothai/ubuntu/workspace_34/github_scraper" - ) - ap.add_argument( - "--repos", nargs = "+", default = ["unslothai/unsloth", "unslothai/unsloth-zoo"] - ) - ap.add_argument("--trial", action = "store_true", help = "Small trial run") - ap.add_argument( - "--only", - nargs = "+", - default = None, - help = "Only run these resource keys: issues,pulls,discussions,commits,releases,labels,milestones,meta", - ) - ap.add_argument( - "--hf-upload-interval", - type = int, - default = 900, - help = "Seconds between HF uploads (0 to disable)", - ) - args = ap.parse_args() - - base = Path(args.base_dir) - data_dir = base / "data" - data_dir.mkdir(parents = True, exist_ok = True) - setup_logging(base / "logs" / f"scraper_{time.strftime('%Y%m%d_%H%M%S')}.log") - log.info("Scraper starting: repos=%s trial=%s", args.repos, args.trial) - - client = GitHubClient(min_remaining_graphql = 80, min_remaining_rest = 80) - rl = client.rate_snapshot() - log.info( - "Rate limit snapshot: %s", - json.dumps(rl.get("resources", {}), default = str)[:400], - ) - - # Start HF uploader in background if requested - uploader = None - if args.hf_upload_interval > 0: - from hf_uploader import HFUploader - - uploader = HFUploader(data_dir, interval_s = args.hf_upload_interval) - uploader.start() - - trial_limits = None - if args.trial: - trial_limits = { - "issues": 5, - "pull_requests": 5, - "discussions": 3, - "commits": 20, - "releases": 3, - "labels": 20, - "milestones": 20, - } - - only = set(args.only or []) - - try: - for repo_spec in args.repos: - owner, name = repo_spec.split("/") - scraper = RepoScraper(owner, name, data_dir, client, trial_limits) - try: - repo_meta: Dict[str, Any] = {} - if not only or "meta" in only or "commits" in only: - repo_meta = scraper.scrape_repo_meta() - if not only or "labels" in only: - scraper.scrape_labels() - if not only or "milestones" in only: - scraper.scrape_milestones() - if not only or "releases" in only: - scraper.scrape_releases() - if not only or "discussions" in only: - scraper.scrape_discussions() - if not only or "issues" in only: - scraper.scrape_issues() - if not only or "pulls" in only: - scraper.scrape_prs() - if not only or "commits" in only: - default_ref = repo_meta.get("defaultBranchRef") or {} - default_branch = ( - default_ref.get("name") - if isinstance(default_ref, dict) - else None - ) - branch = ( - f"refs/heads/{default_branch}" - if default_branch - else "refs/heads/main" - ) - scraper.scrape_commits(branch = branch) - finally: - scraper.close() - finally: - if uploader: - log.info("Stopping uploader and final sync...") - uploader.stop(final_upload = True) - log.info( - "Scraper complete. GraphQL calls=%d REST calls=%d", - client.calls_graphql, - client.calls_rest, - ) - - -if __name__ == "__main__": - main() diff --git a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/state_store.py b/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/state_store.py deleted file mode 100644 index efa663db2f..0000000000 --- a/studio/backend/plugins/data-designer-github-repo-seed/src/data_designer_github_repo_seed/scraper_impl/state_store.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Checkpoint state management for resumable scraping.""" - -from __future__ import annotations - -import json -import os -import threading -from pathlib import Path -from typing import Any, Dict - - -class StateStore: - def __init__(self, path: str | Path): - self.path = Path(path) - self.path.parent.mkdir(parents = True, exist_ok = True) - self._lock = threading.Lock() - self._data: Dict[str, Any] = {} - if self.path.exists(): - try: - with self.path.open() as f: - self._data = json.load(f) - except Exception: - self._data = {} - - def get(self, key: str, default: Any = None) -> Any: - with self._lock: - return self._data.get(key, default) - - def set(self, key: str, value: Any) -> None: - with self._lock: - self._data[key] = value - self._flush() - - def update(self, key: str, **kwargs) -> None: - with self._lock: - sub = dict(self._data.get(key, {})) - sub.update(kwargs) - self._data[key] = sub - self._flush() - - def all(self) -> Dict[str, Any]: - with self._lock: - return dict(self._data) - - def _flush(self) -> None: - tmp = self.path.with_suffix(self.path.suffix + ".tmp") - with tmp.open("w") as f: - json.dump(self._data, f, indent = 2, default = str) - os.replace(tmp, self.path) - - -class JsonlWriter: - """Append-only JSONL writer, thread-safe, with line buffering.""" - - def __init__(self, path: str | Path): - self.path = Path(path) - self.path.parent.mkdir(parents = True, exist_ok = True) - self._lock = threading.Lock() - self._fh = self.path.open("a", buffering = 1) - self._count_seen_keys: set[str] = set() - # Preload seen keys if file exists (for dedup across resumes) - if self.path.exists() and self.path.stat().st_size > 0: - try: - with self.path.open() as f: - for line in f: - try: - obj = json.loads(line) - k = self._key(obj) - if k is not None: - self._count_seen_keys.add(k) - except Exception: - pass - except Exception: - pass - - def _key(self, obj: dict) -> str | None: - for k in ("id", "node_id", "number", "sha", "url"): - if k in obj: - return f"{k}:{obj[k]}" - return None - - def has(self, key: str) -> bool: - return key in self._count_seen_keys - - def write(self, obj: dict) -> bool: - """Return True if newly written, False if already present.""" - k = self._key(obj) - with self._lock: - if k is not None and k in self._count_seen_keys: - return False - if k is not None: - self._count_seen_keys.add(k) - self._fh.write(json.dumps(obj, default = str, ensure_ascii = False)) - self._fh.write("\n") - self._fh.flush() - return True - - def close(self) -> None: - try: - self._fh.close() - except Exception: - pass diff --git a/studio/backend/requirements/extras-no-deps.txt b/studio/backend/requirements/extras-no-deps.txt index 23c61baa44..9934bacd24 100644 --- a/studio/backend/requirements/extras-no-deps.txt +++ b/studio/backend/requirements/extras-no-deps.txt @@ -2,13 +2,9 @@ descript-audio-codec descript-audiotools julius -torchcodec==0.10.0 +torchcodec snac -# peft 0.19.0 causes export subprocess shutdown issues in Studio; -# installing with --no-deps to avoid pulling in torch>=0.11.0 -peft==0.18.1 - # TRL and related packages trl==0.23.1 git+https://github.com/meta-pytorch/OpenEnv.git @@ -17,4 +13,4 @@ torch-c-dlpack-ext sentence_transformers==5.2.0 transformers==4.57.6 pytorch_tokenizers -kernels==0.12.1 +kernels diff --git a/studio/backend/requirements/single-env/data-designer-deps.txt b/studio/backend/requirements/single-env/data-designer-deps.txt index f63c076621..fc63230922 100644 --- a/studio/backend/requirements/single-env/data-designer-deps.txt +++ b/studio/backend/requirements/single-env/data-designer-deps.txt @@ -19,8 +19,7 @@ ruff<1,>=0.14.10 scipy<2,>=1.11.0 sqlfluff<4,>=3.2.0 tiktoken<1,>=0.8.0 -# Local seed plugin deps (plugins installed with --no-deps) -requests>=2.31 +# Unstructured-seed plugin deps (plugin installed with --no-deps) pymupdf>=1.24.0 pymupdf4llm>=0.0.17 mammoth>=1.8.0 diff --git a/studio/backend/routes/__init__.py b/studio/backend/routes/__init__.py index cf4586281b..e79f6553f9 100644 --- a/studio/backend/routes/__init__.py +++ b/studio/backend/routes/__init__.py @@ -8,7 +8,6 @@ from routes.training import router as training_router from routes.models import router as models_router from routes.inference import router as inference_router -from routes.inference import studio_router as inference_studio_router from routes.datasets import router as datasets_router from routes.auth import router as auth_router from routes.data_recipe import router as data_recipe_router @@ -19,7 +18,6 @@ "training_router", "models_router", "inference_router", - "inference_studio_router", "datasets_router", "auth_router", "data_recipe_router", diff --git a/studio/backend/routes/auth.py b/studio/backend/routes/auth.py index 3deeb6793b..db37ed837d 100644 --- a/studio/backend/routes/auth.py +++ b/studio/backend/routes/auth.py @@ -7,18 +7,11 @@ from fastapi import APIRouter, Depends, HTTPException, status -from datetime import datetime, timedelta, timezone - from models.auth import ( - ApiKeyListResponse, - ApiKeyResponse, AuthLoginRequest, + RefreshTokenRequest, AuthStatusResponse, ChangePasswordRequest, - CreateApiKeyRequest, - CreateApiKeyResponse, - DesktopLoginRequest, - RefreshTokenRequest, ) from models.users import Token from auth import storage, hashing @@ -81,24 +74,6 @@ async def login(payload: AuthLoginRequest) -> Token: ) -@router.post("/desktop-login", response_model = Token) -async def desktop_login(payload: DesktopLoginRequest) -> Token: - """Exchange a local desktop secret for normal admin-subject tokens.""" - username = storage.validate_desktop_secret(payload.secret) - if username is None: - raise HTTPException( - status_code = status.HTTP_401_UNAUTHORIZED, - detail = "Desktop authentication failed", - ) - - return Token( - access_token = create_access_token(subject = username, desktop = True), - refresh_token = create_refresh_token(subject = username, desktop = True), - token_type = "bearer", - must_change_password = False, - ) - - @router.post("/refresh", response_model = Token) async def refresh(payload: RefreshTokenRequest) -> Token: """ @@ -106,7 +81,7 @@ async def refresh(payload: RefreshTokenRequest) -> Token: The refresh token itself is reusable until it expires (7 days). """ - new_access_token, username, is_desktop = refresh_access_token(payload.refresh_token) + new_access_token, username = refresh_access_token(payload.refresh_token) if new_access_token is None or username is None: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, @@ -117,9 +92,7 @@ async def refresh(payload: RefreshTokenRequest) -> Token: access_token = new_access_token, refresh_token = payload.refresh_token, token_type = "bearer", - must_change_password = False - if is_desktop - else storage.requires_password_change(username), + must_change_password = storage.requires_password_change(username), ) @@ -158,68 +131,3 @@ async def change_password( token_type = "bearer", must_change_password = False, ) - - -# --------------------------------------------------------------------------- -# API key management -# --------------------------------------------------------------------------- - - -def _row_to_api_key_response(row: dict) -> ApiKeyResponse: - return ApiKeyResponse( - id = row["id"], - name = row["name"], - key_prefix = row["key_prefix"], - created_at = row["created_at"], - last_used_at = row.get("last_used_at"), - expires_at = row.get("expires_at"), - is_active = bool(row["is_active"]), - ) - - -@router.post("/api-keys", response_model = CreateApiKeyResponse) -async def create_api_key( - payload: CreateApiKeyRequest, - current_subject: str = Depends(get_current_subject), -) -> CreateApiKeyResponse: - """Create a new API key. The raw key is returned once and cannot be retrieved later.""" - expires_at = None - if payload.expires_in_days is not None: - expires_at = ( - datetime.now(timezone.utc) + timedelta(days = payload.expires_in_days) - ).isoformat() - - raw_key, row = storage.create_api_key( - username = current_subject, - name = payload.name, - expires_at = expires_at, - ) - return CreateApiKeyResponse( - key = raw_key, - api_key = _row_to_api_key_response(row), - ) - - -@router.get("/api-keys", response_model = ApiKeyListResponse) -async def list_api_keys( - current_subject: str = Depends(get_current_subject), -) -> ApiKeyListResponse: - """List all API keys for the authenticated user (raw keys are never exposed).""" - rows = storage.list_api_keys(current_subject) - return ApiKeyListResponse( - api_keys = [_row_to_api_key_response(r) for r in rows], - ) - - -@router.delete("/api-keys/{key_id}") -async def revoke_api_key( - key_id: int, - current_subject: str = Depends(get_current_subject), -) -> dict: - """Revoke (soft-delete) an API key.""" - if not storage.revoke_api_key(current_subject, key_id): - raise HTTPException( - status_code = status.HTTP_404_NOT_FOUND, - detail = "API key not found", - ) - return {"detail": "API key revoked"} diff --git a/studio/backend/routes/data_recipe/jobs.py b/studio/backend/routes/data_recipe/jobs.py index da6416e324..00546b47a4 100644 --- a/studio/backend/routes/data_recipe/jobs.py +++ b/studio/backend/routes/data_recipe/jobs.py @@ -5,9 +5,8 @@ from __future__ import annotations -import copy -from datetime import datetime, timedelta, timezone -from typing import Any, Optional +from datetime import timedelta +from typing import Any from urllib.parse import urlparse from fastapi import APIRouter, HTTPException, Query, Request @@ -58,20 +57,6 @@ def _resolve_local_v1_endpoint(request: Request) -> str: return f"http://127.0.0.1:{int(port)}/v1" -def _request_has_desktop_access_token(request: Request) -> bool: - auth_header = request.headers.get("authorization") - if not auth_header: - return False - - parts = auth_header.split(None, 1) - if len(parts) != 2 or parts[0].lower() != "bearer": - return False - - from auth.authentication import is_desktop_access_token - - return is_desktop_access_token(parts[1]) - - def _used_llm_model_aliases(recipe: dict[str, Any]) -> set[str]: """Return the set of model_aliases that are actually referenced by an LLM column. Used to narrow the "Chat model loaded" gate so that orphan @@ -95,111 +80,14 @@ def _used_llm_model_aliases(recipe: dict[str, Any]) -> set[str]: return aliases -def _inject_local_structured_response_format( - recipe: dict[str, Any], local_provider_names: set[str] -) -> None: - """For each llm-structured column that targets a local-provider model_config, - clone the model_config and inject an OpenAI ``response_format`` with the - column's ``output_format`` JSON schema. The column is rewritten to point at - the clone so llm-text / llm-judge columns that share the same alias keep - free-form sampling. - - Without this, data_designer only injects a prompt-level "return JSON in a - ```json fence" instruction. Small GGUF models frequently break format, - wasting the full ``max_tokens`` budget per row and then failing to parse. - Forwarding ``response_format`` lets llama-server apply grammar-constrained - sampling from the JSON schema, which guarantees a parseable response and - terminates early. - """ - columns = recipe.get("columns") - model_configs = recipe.get("model_configs") - if not isinstance(columns, list) or not isinstance(model_configs, list): - return - - # alias -> model_config (only configs referencing a local provider qualify). - alias_to_local_mc: dict[str, dict[str, Any]] = {} - for mc in model_configs: - if not isinstance(mc, dict): - continue - if mc.get("provider") in local_provider_names and isinstance( - mc.get("alias"), str - ): - alias_to_local_mc[mc["alias"]] = mc - - if not alias_to_local_mc: - return - - # Clone per (alias, column) so each llm-structured column gets its own - # schema without leaking response_format onto other columns that share the - # same base alias. - seen_clone_aliases: set[str] = { - mc.get("alias") for mc in model_configs if isinstance(mc.get("alias"), str) - } - new_configs: list[dict[str, Any]] = [] - for column in columns: - if not isinstance(column, dict): - continue - if column.get("column_type") != "llm-structured": - continue - alias = column.get("model_alias") - if not isinstance(alias, str) or alias not in alias_to_local_mc: - continue - output_format = column.get("output_format") - if not isinstance(output_format, dict) or not output_format: - continue - base_mc = alias_to_local_mc[alias] - column_name = column.get("name") or "structured" - clone_alias_base = f"{alias}__{column_name}_structured" - clone_alias = clone_alias_base - counter = 1 - while clone_alias in seen_clone_aliases: - counter += 1 - clone_alias = f"{clone_alias_base}_{counter}" - seen_clone_aliases.add(clone_alias) - - clone = copy.deepcopy(base_mc) - clone["alias"] = clone_alias - params = clone.get("inference_parameters") - if not isinstance(params, dict): - params = {} - clone["inference_parameters"] = params - # data_designer's BaseInferenceParams is a pydantic model with - # extra="forbid", so response_format cannot sit at the top level of - # inference_parameters. It does expose an `extra_body: dict` pass- - # through that the OpenAI client spreads into the request body at the - # top level, which is where llama-server reads response_format from. - # llama.cpp server shape (tools/server/README.md): the schema sits - # directly under response_format, not nested in a json_schema object - # the way OpenAI's Chat Completions API expects. llama-server converts - # the schema to a GBNF grammar and applies it during sampling. - extra_body = params.get("extra_body") - if not isinstance(extra_body, dict): - extra_body = {} - extra_body["response_format"] = { - "type": "json_schema", - "schema": output_format, - } - params["extra_body"] = extra_body - new_configs.append(clone) - column["model_alias"] = clone_alias - - if new_configs: - model_configs.extend(new_configs) - - -def _inject_local_providers(recipe: dict[str, Any], request: Request) -> Optional[int]: +def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None: """ Mutate recipe dict in-place: for any provider with is_local=True, - fill in the endpoint pointing at this server and inject a short-lived - internal sk-unsloth-* API key for workflow auth. - - Returns the row id of the minted internal key (so the caller can - revoke it on job completion) or ``None`` when no local provider is - actually reachable from an LLM column. + generate a JWT and fill in the endpoint pointing at this server. """ providers = recipe.get("model_providers") if not providers: - return None + return # Collect local providers and pop is_local from ALL dicts unconditionally. # Strict `is True` guard so malformed payloads (is_local: 1, @@ -213,7 +101,7 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> Optiona local_indices.append(i) if not local_indices: - return None + return endpoint = _resolve_local_v1_endpoint(request) @@ -236,7 +124,6 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> Optiona } token = "" - internal_key_id: Optional[int] = None if local_names & referenced_providers: # Verify a model is loaded. # NOTE: This is a point-in-time check (TOCTOU). The model could be unloaded @@ -257,21 +144,17 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> Optiona "No model loaded in Chat. Load a model first, then run the recipe." ) - from auth import storage # deferred: avoids circular import - - # Mint an internal sk-unsloth-* key scoped to this workflow run. - # Uses the unified API-key issuance path (one mint/revoke/verify - # surface instead of a second JWT code path). The key is marked - # internal so it is hidden from the user's API-key list, and the - # caller revokes it when the job terminates. - expires_at = (datetime.now(timezone.utc) + timedelta(hours = 24)).isoformat() - token, row = storage.create_api_key( - username = "unsloth", - name = "data-recipe workflow", - expires_at = expires_at, - internal = True, + from auth.authentication import ( + create_access_token, + ) # deferred: avoids circular import + + # Uses the "unsloth" admin subject. If the user changes their password, + # the JWT secret rotates and this token becomes invalid mid-run. + # Acceptable for v1 - recipes typically finish well within one session. + token = create_access_token( + subject = "unsloth", + expires_delta = timedelta(hours = 24), ) - internal_key_id = int(row["id"]) # Defensively strip any stale "external"-only fields the frontend may # have left on the dict (extra_headers/extra_body/api_key_env). The UI @@ -298,37 +181,6 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> Optiona continue if mc.get("provider") in local_names: mc["skip_health_check"] = True - # Disable thinking for data-recipe inference on local providers. - # Reasoning models emit a ... preamble before the - # answer, which roughly doubles generated token count per row and - # pushes the visible answer past data_designer's json-fence - # regex. Forward chat_template_kwargs={enable_thinking: False} - # through the OpenAI SDK's extra_body passthrough so llama-server - # renders the template without the reasoning preamble. Free-form - # llm-text columns benefit from the latency cut, and structured - # columns also stop leaking think tags into the grammar- - # constrained JSON (llama-server's GBNF path still enforces the - # schema either way). - params = mc.get("inference_parameters") - if not isinstance(params, dict): - params = {} - mc["inference_parameters"] = params - extra_body = params.get("extra_body") - if not isinstance(extra_body, dict): - extra_body = {} - tpl_kwargs = extra_body.get("chat_template_kwargs") - if not isinstance(tpl_kwargs, dict): - tpl_kwargs = {} - tpl_kwargs.setdefault("enable_thinking", False) - extra_body["chat_template_kwargs"] = tpl_kwargs - params["extra_body"] = extra_body - - # Forward each llm-structured column's output_format as an OpenAI - # response_format so llama-server uses grammar-constrained sampling and - # small GGUFs stop wasting the full max_tokens budget on broken JSON. - _inject_local_structured_response_format(recipe, local_names) - - return internal_key_id def _normalize_run_name(value: Any) -> str | None: @@ -373,49 +225,21 @@ def create_job(payload: RecipePayload, request: Request): ) from exc try: - internal_api_key_id = _inject_local_providers(recipe, request) + _inject_local_providers(recipe, request) except ValueError as exc: raise HTTPException(status_code = 400, detail = str(exc)) from exc - # Single try block covers get_job_manager() AND mgr.start() so a workflow - # key minted above never outlives the request even when an unexpected - # exception type (TypeError from a stale kwarg, OSError from a queue - # write, etc.) bubbles up. Without the bare except, such exceptions let - # the sk-unsloth-* key live until its 24h TTL. + mgr = get_job_manager() try: - mgr = get_job_manager() - job_id = mgr.start( - recipe = recipe, - run = run, - internal_api_key_id = internal_api_key_id, - ) + job_id = mgr.start(recipe = recipe, run = run) except RuntimeError as exc: - if internal_api_key_id is not None: - _revoke_internal_api_key_safe(internal_api_key_id) raise HTTPException(status_code = 409, detail = str(exc)) from exc except ValueError as exc: - if internal_api_key_id is not None: - _revoke_internal_api_key_safe(internal_api_key_id) raise HTTPException(status_code = 400, detail = str(exc)) from exc - except Exception: - if internal_api_key_id is not None: - _revoke_internal_api_key_safe(internal_api_key_id) - raise return {"job_id": job_id} -def _revoke_internal_api_key_safe(key_id: int) -> None: - """Best-effort revoke of a workflow-minted key; swallow any error so - that revocation failures never mask the caller's own error path.""" - try: - from auth import storage # deferred: avoids circular import - - storage.revoke_internal_api_key(key_id) - except Exception: - pass - - @router.get("/jobs/{job_id}/status") def job_status(job_id: str): mgr = get_job_manager() diff --git a/studio/backend/routes/data_recipe/seed.py b/studio/backend/routes/data_recipe/seed.py index 91cf718e6e..e9cf828610 100644 --- a/studio/backend/routes/data_recipe/seed.py +++ b/studio/backend/routes/data_recipe/seed.py @@ -8,7 +8,6 @@ import base64 import binascii import json -import os import re from itertools import islice from pathlib import Path @@ -628,14 +627,3 @@ def inspect_seed_upload(payload: SeedInspectUploadRequest) -> SeedInspectRespons split = None, subset = None, ) - - -@router.get("/seed/github/env-token") -def get_github_env_token_status() -> dict: - """Report whether the server has a GH_TOKEN / GITHUB_TOKEN env var. - - The value is never returned; the UI uses this to tell the user they - can leave the token field blank. - """ - has_token = bool(os.environ.get("GH_TOKEN") or os.environ.get("GITHUB_TOKEN")) - return {"has_token": has_token} diff --git a/studio/backend/routes/data_recipe/validate.py b/studio/backend/routes/data_recipe/validate.py index e794d68e54..555e3eaa06 100644 --- a/studio/backend/routes/data_recipe/validate.py +++ b/studio/backend/routes/data_recipe/validate.py @@ -14,63 +14,10 @@ create_data_designer, validate_recipe, ) -from loggers import get_logger from models.data_recipe import RecipePayload, ValidateError, ValidateResponse -logger = get_logger(__name__) router = APIRouter() -_GITHUB_VALIDATE_NOTE = "Recipe shape is valid. GitHub access and rate limits are checked when the run starts." -_GITHUB_ITEM_TYPES = {"issues", "pulls", "commits"} - - -def _github_seed_source(recipe: dict[str, Any]) -> dict[str, Any] | None: - seed_config = recipe.get("seed_config") - if not isinstance(seed_config, dict): - return None - source = seed_config.get("source") - if not isinstance(source, dict) or source.get("seed_type") != "github_repo": - return None - return source - - -def _validate_github_seed_static(source: dict[str, Any]) -> list[ValidateError]: - errors: list[ValidateError] = [] - - repos = source.get("repos") - if not isinstance(repos, list) or not repos: - errors.append(ValidateError(message = "GitHub seed requires at least one repo.")) - else: - for repo in repos: - if not isinstance(repo, str) or not repo.strip() or "/" not in repo: - errors.append( - ValidateError(message = "GitHub repos must be owner/name strings.") - ) - break - - item_types = source.get("item_types") - if not isinstance(item_types, list) or not item_types: - errors.append( - ValidateError(message = "GitHub seed requires at least one item type.") - ) - else: - invalid_items = [item for item in item_types if item not in _GITHUB_ITEM_TYPES] - if invalid_items: - errors.append( - ValidateError( - message = "GitHub item types must be issues, pulls, or commits." - ) - ) - - try: - limit = int(source.get("limit")) - except (TypeError, ValueError): - limit = 0 - if limit < 1 or limit > 5000: - errors.append(ValidateError(message = "GitHub limit must be from 1 to 5000.")) - - return errors - def _collect_validation_errors(recipe: dict[str, Any]) -> list[ValidateError]: try: @@ -146,38 +93,6 @@ def validate(payload: RecipePayload) -> ValidateResponse: _patch_local_providers(recipe) - github_source = _github_seed_source(recipe) - if github_source is not None: - static_errors = _validate_github_seed_static(github_source) - if static_errors: - return ValidateResponse(valid = False, errors = static_errors) - try: - build_config_builder(recipe) - except ModuleNotFoundError as exc: - # data_designer is an optional runtime dep. Static validation - # already passed; live access + full config validation are - # deferred to run start (per _GITHUB_VALIDATE_NOTE), so a missing - # optional import at validate time should not block the recipe. - # Restrict the bypass to the data_designer module specifically so - # other ImportErrors (e.g. broken internal imports or missing - # transitive deps after a package upgrade) still surface as - # validation failures instead of being silently swallowed. - if not (exc.name or "").startswith("data_designer"): - raise - logger.debug( - "data_designer not installed; deferring full config " - "validation to run start", - missing_module = exc.name, - ) - except Exception as exc: - detail = str(exc).strip() or "Validation failed." - return ValidateResponse( - valid = False, - errors = [ValidateError(message = detail)], - raw_detail = detail, - ) - return ValidateResponse(valid = True, raw_detail = _GITHUB_VALIDATE_NOTE) - try: validate_recipe(recipe) except RuntimeError as exc: diff --git a/studio/backend/routes/datasets.py b/studio/backend/routes/datasets.py index 206af2a66f..8333009626 100644 --- a/studio/backend/routes/datasets.py +++ b/studio/backend/routes/datasets.py @@ -11,55 +11,10 @@ import sys from pathlib import Path from uuid import uuid4 -from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile -import re as _re +from fastapi import APIRouter, Depends, HTTPException, UploadFile import structlog from loggers import get_logger -_VALID_REPO_ID = _re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$") - - -def _is_valid_repo_id(repo_id: str) -> bool: - return bool(_VALID_REPO_ID.fullmatch(repo_id)) - - -_dataset_size_cache: dict[str, int] = {} - - -def _get_dataset_size_cached(repo_id: str) -> int: - if repo_id in _dataset_size_cache: - return _dataset_size_cache[repo_id] - try: - from huggingface_hub import dataset_info as hf_dataset_info - - info = hf_dataset_info(repo_id, token = None, files_metadata = True) - total = sum(s.size for s in info.siblings if getattr(s, "size", None)) - _dataset_size_cache[repo_id] = total - return total - except Exception: - return 0 - - -def _resolve_hf_cache_realpath(repo_dir: Path) -> Optional[str]: - """Pick the most useful on-disk path for a HF cache repo dir. - - Mirrors the helper in routes/models.py: prefer the most-recent - snapshot dir, fall back to the cache repo root, return resolved - realpath. Duplicated here to keep routes/datasets.py self-contained. - """ - try: - snapshots_dir = repo_dir / "snapshots" - if snapshots_dir.is_dir(): - snaps = [s for s in snapshots_dir.iterdir() if s.is_dir()] - if snaps: - latest = max(snaps, key = lambda s: s.stat().st_mtime) - return str(latest.resolve()) - return str(repo_dir.resolve()) - except Exception: - return None - - # Add backend directory to path backend_path = Path(__file__).parent.parent.parent if str(backend_path) not in sys.path: @@ -353,89 +308,6 @@ def list_local_datasets( return LocalDatasetsResponse(datasets = _build_local_dataset_items()) -@router.get("/download-progress") -async def get_dataset_download_progress( - repo_id: str = Query( - ..., description = "HuggingFace dataset repo ID, e.g. 'unsloth/LaTeX_OCR'" - ), - current_subject: str = Depends(get_current_subject), -): - """Return download progress for a HuggingFace dataset repo. - - Mirrors ``GET /api/models/download-progress`` but scans the - ``datasets--owner--name`` cache directory under HF_HUB_CACHE. - Modern ``datasets``/``huggingface_hub`` caches both raw model and - raw dataset blobs in HF_HUB_CACHE; the ``datasets`` library writes - its processed Arrow shards elsewhere, but the in-progress *download* - bytes are observable here. Returns ``cache_path`` so the UI can - show users where the dataset blobs landed on disk. - """ - _empty = { - "downloaded_bytes": 0, - "expected_bytes": 0, - "progress": 0, - "cache_path": None, - } - try: - if not _is_valid_repo_id(repo_id): - return _empty - - from huggingface_hub import constants as hf_constants - - cache_dir = Path(hf_constants.HF_HUB_CACHE) - target = f"datasets--{repo_id.replace('/', '--')}".lower() - completed_bytes = 0 - in_progress_bytes = 0 - cache_path: Optional[str] = None - - if cache_dir.is_dir(): - for entry in cache_dir.iterdir(): - if entry.name.lower() != target: - continue - cache_path = _resolve_hf_cache_realpath(entry) - blobs_dir = entry / "blobs" - if not blobs_dir.is_dir(): - break - for f in blobs_dir.iterdir(): - if not f.is_file(): - continue - if f.name.endswith(".incomplete"): - in_progress_bytes += f.stat().st_size - else: - completed_bytes += f.stat().st_size - break - - downloaded_bytes = completed_bytes + in_progress_bytes - if downloaded_bytes == 0: - return {**_empty, "cache_path": cache_path} - - expected_bytes = _get_dataset_size_cached(repo_id) - if expected_bytes <= 0: - return { - "downloaded_bytes": downloaded_bytes, - "expected_bytes": 0, - "progress": 0, - "cache_path": cache_path, - } - - # Same 95% completion threshold as the model endpoint -- HF blob - # dedup makes completed_bytes drift slightly under expected_bytes, - # and inter-file gaps would otherwise look like "done". - if completed_bytes >= expected_bytes * 0.95: - progress = 1.0 - else: - progress = min(downloaded_bytes / expected_bytes, 0.99) - return { - "downloaded_bytes": downloaded_bytes, - "expected_bytes": expected_bytes, - "progress": round(progress, 3), - "cache_path": cache_path, - } - except Exception as e: - logger.warning(f"Error checking dataset download progress for {repo_id}: {e}") - return _empty - - @router.post("/check-format", response_model = CheckFormatResponse) def check_format( request: CheckFormatRequest, diff --git a/studio/backend/routes/export.py b/studio/backend/routes/export.py index 798859fc87..3e60eaaf20 100644 --- a/studio/backend/routes/export.py +++ b/studio/backend/routes/export.py @@ -5,15 +5,9 @@ Export API routes: checkpoint discovery and model export operations. """ -import asyncio -import json import sys -import time from pathlib import Path -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple - -from fastapi import APIRouter, Depends, HTTPException, Query, Request -from fastapi.responses import StreamingResponse +from fastapi import APIRouter, Depends, HTTPException, Query import structlog from loggers import get_logger @@ -103,11 +97,7 @@ async def load_checkpoint( logger.warning("Could not stop training: %s", e) backend = get_export_backend() - # load_checkpoint spawns and waits on a subprocess and can take - # minutes. Run it in a worker thread so the event loop stays - # free to serve the live log SSE stream concurrently. - success, message = await asyncio.to_thread( - backend.load_checkpoint, + success, message = backend.load_checkpoint( checkpoint_path = request.checkpoint_path, max_seq_length = request.max_seq_length, load_in_4bit = request.load_in_4bit, @@ -139,7 +129,7 @@ async def cleanup_export_memory( """ try: backend = get_export_backend() - success = await asyncio.to_thread(backend.cleanup_memory) + success = backend.cleanup_memory() if not success: raise HTTPException( @@ -183,17 +173,6 @@ async def get_export_status( ) -def _export_details(output_path: Optional[str]) -> Optional[Dict[str, Any]]: - """Wrap the resolved on-disk export path into the details dict the - frontend reads to populate the Export Complete screen. Returns None - when the export had no local component (Hub-only push) so the - Pydantic field stays absent rather than ``{"output_path": null}``. - """ - if not output_path: - return None - return {"output_path": output_path} - - @router.post("/export/merged", response_model = ExportOperationResponse) async def export_merged_model( request: ExportMergedModelRequest, @@ -206,8 +185,7 @@ async def export_merged_model( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_merged_model, + success, message = backend.export_merged_model( save_directory = request.save_directory, format_type = request.format_type, push_to_hub = request.push_to_hub, @@ -219,11 +197,7 @@ async def export_merged_model( if not success: raise HTTPException(status_code = 400, detail = message) - return ExportOperationResponse( - success = True, - message = message, - details = _export_details(output_path), - ) + return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: @@ -246,8 +220,7 @@ async def export_base_model( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_base_model, + success, message = backend.export_base_model( save_directory = request.save_directory, push_to_hub = request.push_to_hub, repo_id = request.repo_id, @@ -259,11 +232,7 @@ async def export_base_model( if not success: raise HTTPException(status_code = 400, detail = message) - return ExportOperationResponse( - success = True, - message = message, - details = _export_details(output_path), - ) + return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: @@ -286,8 +255,7 @@ async def export_gguf( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_gguf, + success, message = backend.export_gguf( save_directory = request.save_directory, quantization_method = request.quantization_method, push_to_hub = request.push_to_hub, @@ -298,11 +266,7 @@ async def export_gguf( if not success: raise HTTPException(status_code = 400, detail = message) - return ExportOperationResponse( - success = True, - message = message, - details = _export_details(output_path), - ) + return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: @@ -325,8 +289,7 @@ async def export_lora_adapter( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_lora_adapter, + success, message = backend.export_lora_adapter( save_directory = request.save_directory, push_to_hub = request.push_to_hub, repo_id = request.repo_id, @@ -337,11 +300,7 @@ async def export_lora_adapter( if not success: raise HTTPException(status_code = 400, detail = message) - return ExportOperationResponse( - success = True, - message = message, - details = _export_details(output_path), - ) + return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: @@ -350,155 +309,3 @@ async def export_lora_adapter( status_code = 500, detail = f"Failed to export LoRA adapter: {str(e)}", ) - - -# ───────────────────────────────────────────────────────────────────── -# Live export log stream (Server-Sent Events) -# ───────────────────────────────────────────────────────────────────── -# -# The export worker subprocess redirects its stdout/stderr into a pipe -# that a reader thread forwards to the orchestrator as log entries (see -# core/export/worker.py::_setup_log_capture and -# core/export/orchestrator.py::_append_log). This endpoint streams -# those entries to the browser so the export dialog can show a live -# terminal-style output panel while load_checkpoint / export_merged / -# export_gguf / export_lora / export_base run. -# -# Shape follows the training progress SSE endpoint -# (routes/training.py::stream_training_progress): each event carries -# `id`, `event`, and `data` fields, the stream starts with a `retry:` -# directive, and `Last-Event-ID` is honored on reconnect. - - -def _format_sse(data: str, event: str, event_id: Optional[int] = None) -> str: - """Format a single SSE message with id/event/data fields.""" - lines = [] - if event_id is not None: - lines.append(f"id: {event_id}") - lines.append(f"event: {event}") - lines.append(f"data: {data}") - lines.append("") - lines.append("") - return "\n".join(lines) - - -@router.get("/logs/stream") -async def stream_export_logs( - request: Request, - since: Optional[int] = Query( - None, - description = "Return log entries with seq strictly greater than this cursor.", - ), - current_subject: str = Depends(get_current_subject), -): - """ - Stream live stdout/stderr output from the export worker subprocess - as Server-Sent Events. - - Events: - - `log` : a single log line (data: {"stream","line","ts"}) - - `heartbeat`: periodic keepalive when no new lines are available - - `complete` : emitted once the export worker is idle and no new - lines arrived for ~1 second. Clients should close. - - `error` : unrecoverable server-side error - - The `id:` field on each event is the log entry's monotonic seq - number so the browser can resume via `Last-Event-ID` on reconnect. - """ - backend = get_export_backend() - - # Determine starting cursor. Explicit `since` wins, then - # Last-Event-ID header on reconnect, otherwise start from the - # run-start snapshot captured by clear_logs() so the client sees - # every line emitted since the current run began -- even if the - # SSE connection opened after the POST that kicked off the export. - # Using get_current_log_seq() here would lose the early bootstrap - # lines that arrive in the gap between POST and SSE connect. - last_event_id = request.headers.get("last-event-id") - if since is None and last_event_id is not None: - try: - since = int(last_event_id) - except ValueError: - pass - - if since is None: - cursor = backend.get_run_start_seq() - else: - cursor = max(0, int(since)) - - async def event_generator() -> AsyncGenerator[str, None]: - nonlocal cursor - # Tell the browser to reconnect after 3 seconds if the - # connection drops mid-export. - yield "retry: 3000\n\n" - - last_yield = time.monotonic() - idle_since: Optional[float] = None - try: - while True: - if await request.is_disconnected(): - return - - entries, new_cursor = backend.get_logs_since(cursor) - if entries: - for entry in entries: - payload = json.dumps( - { - "stream": entry.get("stream", "stdout"), - "line": entry.get("line", ""), - "ts": entry.get("ts"), - } - ) - yield _format_sse( - payload, - event = "log", - event_id = int(entry.get("seq", 0)), - ) - cursor = new_cursor - last_yield = time.monotonic() - idle_since = None - else: - now = time.monotonic() - if now - last_yield > 10.0: - yield _format_sse("{}", event = "heartbeat") - last_yield = now - if not backend.is_export_active(): - # Give the reader thread a moment to drain any - # trailing lines the worker process printed - # just before signalling done. - if idle_since is None: - idle_since = now - elif now - idle_since > 1.0: - yield _format_sse( - "{}", - event = "complete", - event_id = cursor, - ) - return - else: - idle_since = None - - await asyncio.sleep(0.1) - except asyncio.CancelledError: - # Client disconnected mid-yield. Don't re-raise, just end - # the generator cleanly so StreamingResponse finalizes. - return - except Exception as exc: - logger.error("Export log stream failed: %s", exc, exc_info = True) - try: - yield _format_sse( - json.dumps({"error": str(exc)}), - event = "error", - ) - except Exception: - pass - - return StreamingResponse( - event_generator(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index a6b00360af..30ff7da49c 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -11,10 +11,9 @@ import uuid from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.responses import StreamingResponse, JSONResponse, Response -from typing import Any, Optional, Union +from fastapi.responses import StreamingResponse, JSONResponse +from typing import Optional import json -import httpx import structlog from loggers import get_logger import asyncio @@ -27,68 +26,8 @@ from utils.models import extract_model_size_b as _extract_model_size_b -def _install_httpcore_asyncgen_silencer() -> None: - """Silence benign httpx/httpcore asyncgen GC noise on Python 3.13. - - When Studio proxies a streaming response from llama-server via httpx, - the innermost ``HTTP11ConnectionByteStream.__aiter__`` async generator - is finalised by Python's asyncgen GC hook on a task different from the - one that opened it. Its ``aclose`` path then calls - ``anyio.Lock.acquire`` → ``cancel_shielded_checkpoint`` which enters a - ``CancelScope`` on the finaliser task — Python 3.13 flags the - cross-task exit as ``"Attempted to exit cancel scope in a different - task"`` and prints ``"async generator ignored GeneratorExit"`` as an - unraisable warning. - - This is a known httpx + httpcore + anyio interaction (see MCP SDK - python-sdk#831, agno #3556, chainlit #2361, langchain-mcp-adapters - #254). It is benign: the response has already been delivered with a - 200. The streaming pass-throughs (``/v1/chat/completions``, - ``/v1/messages``, ``/v1/responses``, ``/v1/completions``) already - manage their httpx lifecycle inside a single task with explicit - ``aclose()`` of the lines iterator, response, and client; the errant - generator is not one we hold a reference to and therefore cannot - close ourselves. - - We install a single process-wide unraisable hook that swallows just - this specific interaction — identified by the tuple of (RuntimeError - mentioning cancel scope / GeneratorExit) + (object repr referencing - HTTP11ConnectionByteStream) — and defers to the default hook for - everything else. The filter is idempotent. - """ - prior_hook = sys.unraisablehook - if getattr(prior_hook, "_unsloth_httpcore_silencer", False): - return - - def _hook(unraisable): - exc_value = getattr(unraisable, "exc_value", None) - obj = getattr(unraisable, "object", None) - obj_repr = repr(obj) if obj is not None else "" - if ( - isinstance(exc_value, RuntimeError) - and "HTTP11ConnectionByteStream" in obj_repr - and ("cancel scope" in str(exc_value) or "GeneratorExit" in str(exc_value)) - ): - return - prior_hook(unraisable) - - _hook._unsloth_httpcore_silencer = True # type: ignore[attr-defined] - sys.unraisablehook = _hook - - -_install_httpcore_asyncgen_silencer() - - def _friendly_error(exc: Exception) -> str: """Extract a user-friendly message from known llama-server errors.""" - # httpx transport-layer failures reaching the managed llama-server — - # raised by the async pass-through helpers that talk to llama-server - # directly. Treat any RequestError subclass (ConnectError, ReadError, - # RemoteProtocolError, WriteError, PoolTimeout, ...) as "the upstream - # subprocess is unreachable", which for Studio always means the - # llama-server subprocess crashed or is still coming up. - if isinstance(exc, httpx.RequestError): - return "Lost connection to the model server. It may have crashed -- try reloading the model." msg = str(exc) m = _re.search( r"request \((\d+) tokens?\) exceeds the available context size \((\d+) tokens?\)", @@ -113,58 +52,30 @@ def _friendly_error(exc: Exception) -> str: # Import backend functions try: from core.inference import get_inference_backend - from core.inference.llama_cpp import ( - LlamaCppBackend, - _DEFAULT_MAX_TOKENS_FLOOR, - _DEFAULT_T_MAX_PREDICT_MS, - detect_reasoning_flags, - ) - from core.inference.llama_server_args import validate_extra_args + from core.inference.llama_cpp import LlamaCppBackend from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults - from utils.native_path_leases import ( - NativePathLeaseError, - display_label_for_native_path, - is_registered_native_path_label, - redact_native_paths, - verify_native_path_lease, - ) except ImportError: parent_backend = backend_path.parent / "backend" if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from core.inference import get_inference_backend - from core.inference.llama_cpp import ( - LlamaCppBackend, - _DEFAULT_MAX_TOKENS_FLOOR, - _DEFAULT_T_MAX_PREDICT_MS, - detect_reasoning_flags, - ) - from core.inference.llama_server_args import validate_extra_args + from core.inference.llama_cpp import LlamaCppBackend from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults - from utils.native_path_leases import ( - NativePathLeaseError, - display_label_for_native_path, - is_registered_native_path_label, - redact_native_paths, - verify_native_path_lease, - ) from models.inference import ( LoadRequest, UnloadRequest, GenerateRequest, LoadResponse, - LoadProgressResponse, UnloadResponse, InferenceStatusResponse, ChatCompletionRequest, ChatCompletionChunk, ChatCompletion, - ChatMessage, ChunkChoice, ChoiceDelta, CompletionChoice, @@ -172,35 +83,6 @@ def _friendly_error(exc: Exception) -> str: CompletionUsage, ValidateModelRequest, ValidateModelResponse, - TextContentPart, - ImageContentPart, - ImageUrl, - ResponsesRequest, - ResponsesInputMessage, - ResponsesInputTextPart, - ResponsesInputImagePart, - ResponsesOutputTextPart, - ResponsesUnknownContentPart, - ResponsesUnknownInputItem, - ResponsesFunctionCallInputItem, - ResponsesFunctionCallOutputInputItem, - ResponsesOutputTextContent, - ResponsesOutputMessage, - ResponsesOutputFunctionCall, - ResponsesUsage, - ResponsesResponse, - AnthropicMessagesRequest, - AnthropicMessagesResponse, - AnthropicResponseTextBlock, - AnthropicResponseToolUseBlock, - AnthropicUsage, -) -from core.inference.anthropic_compat import ( - anthropic_messages_to_openai, - anthropic_tools_to_openai, - anthropic_tool_choice_to_openai, - AnthropicStreamEmitter, - AnthropicPassthroughEmitter, ) from auth.authentication import get_current_subject @@ -211,138 +93,6 @@ def _friendly_error(exc: Exception) -> str: from datetime import date as _date router = APIRouter() -# Studio-only router (not mounted on /v1 OpenAI-compat). -studio_router = APIRouter() - - -def _effective_enable_tools(payload) -> Optional[bool]: - """Resolve `payload.enable_tools` against the process-level tool policy. - - Returns the policy value when set (CLI hard-override from `unsloth run`), - otherwise the per-request value. - """ - from state.tool_policy import get_tool_policy - - policy = get_tool_policy() - return policy if policy is not None else payload.enable_tools - - -# Cancel registry. Proxies (e.g. Colab) can swallow client fetch aborts -# so is_disconnected() never fires. POST /inference/cancel looks up -# in-flight cancel_events here by cancel_id (per-run) or session_id / -# completion_id (fallbacks). -_CANCEL_REGISTRY: dict[str, set[threading.Event]] = {} -_CANCEL_LOCK = threading.Lock() - -# Cancel POSTs that arrive before registration are stashed; the next -# matching __enter__ replays set() within the TTL. -_PENDING_CANCELS: dict[str, float] = {} -_PENDING_CANCEL_TTL_S = 30.0 - - -def _prune_pending(now: float) -> None: - for k in [ - k for k, ts in _PENDING_CANCELS.items() if now - ts > _PENDING_CANCEL_TTL_S - ]: - _PENDING_CANCELS.pop(k, None) - - -class _TrackedCancel: - """Register cancel_event in _CANCEL_REGISTRY for the block's duration.""" - - def __init__(self, event: threading.Event, *keys): - self.event = event - self.keys = tuple(k for k in keys if k) - - def __enter__(self): - # Register + consume-pending must be one critical section to close - # the TOCTOU race against a concurrent cancel POST. - should_cancel = False - with _CANCEL_LOCK: - for k in self.keys: - _CANCEL_REGISTRY.setdefault(k, set()).add(self.event) - now = time.monotonic() - _prune_pending(now) - for k in self.keys: - if k and _PENDING_CANCELS.pop(k, None) is not None: - should_cancel = True - if should_cancel: - self.event.set() - return self.event - - def __exit__(self, *exc): - with _CANCEL_LOCK: - for k in self.keys: - bucket = _CANCEL_REGISTRY.get(k) - if bucket is None: - continue - bucket.discard(self.event) - if not bucket: - _CANCEL_REGISTRY.pop(k, None) - return False - - -def _cancel_by_keys(keys) -> int: - """Set cancel_event for matching registry entries; no stash. - session_id/completion_id are shared across runs on the same thread, - so stashing them would ghost-cancel the user's next request. Only - cancel_id is per-run unique (see _cancel_by_cancel_id_or_stash).""" - if not keys: - return 0 - events: set[threading.Event] = set() - with _CANCEL_LOCK: - _prune_pending(time.monotonic()) - for k in keys: - bucket = _CANCEL_REGISTRY.get(k) - if bucket: - events.update(bucket) - for ev in events: - ev.set() - return len(events) - - -def _cancel_by_cancel_id_or_stash(cancel_id: str) -> int: - """Atomic lookup-or-stash; pairs with _TrackedCancel.__enter__ to - close the TOCTOU race.""" - now = time.monotonic() - events: set[threading.Event] = set() - with _CANCEL_LOCK: - _prune_pending(now) - bucket = _CANCEL_REGISTRY.get(cancel_id) - if bucket: - events.update(bucket) - else: - _PENDING_CANCELS[cancel_id] = now - for ev in events: - ev.set() - return len(events) - - -async def _await_cancel_then_close(cancel_event, resp) -> None: - """Watch a threading.Event from asyncio and close ``resp`` when it fires. - - Used by the passthrough streamers so a /cancel POST can interrupt - while the async iterator is blocked waiting for llama-server prefill. - Without this watcher the in-loop ``cancel_event.is_set()`` check is - unreachable until the first SSE chunk arrives, which is exactly the - proxy/Colab scenario the cancel POST exists to handle. - - Polls a threading.Event because the cancel registry is keyed by - threading.Event so the synchronous /cancel handler can call .set(). - 50ms cadence adds at most that much latency to a prefill cancel; the - common-case streaming cancel path still observes the event in the - iterator's first iteration after the next chunk. - """ - try: - while not cancel_event.is_set(): - await asyncio.sleep(0.05) - try: - await resp.aclose() - except Exception: - pass - except asyncio.CancelledError: - return - # Appended to tool-use nudge to discourage plan-without-action _TOOL_ACTION_NUDGE = ( @@ -360,65 +110,6 @@ async def _await_cancel_then_close(cancel_event, resp) -> None: logger = get_logger(__name__) -def _validate_native_mmproj_companion( - mmproj_path: str | None, gguf_path: str | None -) -> None: - if not mmproj_path or not gguf_path: - return - import stat as _stat_module - - mm = Path(mmproj_path) - gguf = Path(gguf_path) - try: - mm_lstat = os.lstat(mm) - except OSError as exc: - raise HTTPException( - status_code = 400, - detail = "Native vision companion is no longer accessible.", - ) from exc - if _stat_module.S_ISLNK(mm_lstat.st_mode) or not _stat_module.S_ISREG( - mm_lstat.st_mode - ): - raise HTTPException( - status_code = 400, - detail = "Native vision companion must be a regular file.", - ) - try: - if mm.resolve(strict = True).parent != gguf.resolve(strict = True).parent: - raise HTTPException( - status_code = 400, - detail = "Native vision companion must live next to the selected GGUF.", - ) - except OSError as exc: - raise HTTPException( - status_code = 400, - detail = "Native vision companion is no longer accessible.", - ) from exc - - -def _resolve_model_identifier_for_request( - request: LoadRequest | ValidateModelRequest, - *, - operation: str, -) -> tuple[str, str, bool]: - if not request.native_path_lease: - return request.model_path, request.model_path, False - try: - grant = verify_native_path_lease( - request.native_path_lease, - operation = operation, - expected_kind = "model", - expected_path_type = "file", - allowed_suffixes = (".gguf",), - ) - except NativePathLeaseError as exc: - raise HTTPException(status_code = 400, detail = str(exc)) from exc - display_label = ( - grant.display_label or Path(request.model_path).name or "Native model" - ) - return str(grant.canonical_path), display_label, True - - # GGUF inference backend (llama-server) _llama_cpp_backend = LlamaCppBackend() @@ -430,7 +121,6 @@ def get_llama_cpp_backend() -> LlamaCppBackend: @router.post("/load", response_model = LoadResponse) async def load_model( request: LoadRequest, - fastapi_request: Request, current_subject: str = Depends(get_current_subject), ): """ @@ -442,19 +132,7 @@ async def load_model( GGUF models are loaded via llama-server (llama.cpp) instead of Unsloth. """ - native_grant_backed = False - model_log_label = request.model_path try: - # Validate user-supplied llama-server pass-through args up front - # so a managed-flag collision returns 400 before any model work. - try: - extra_llama_args = validate_extra_args(request.llama_extra_args) - except ValueError as exc: - raise HTTPException(status_code = 400, detail = str(exc)) - - model_identifier, model_log_label, native_grant_backed = ( - _resolve_model_identifier_for_request(request, operation = "load-model") - ) # Version switching is handled automatically by the subprocess-based # inference backend — no need for ensure_transformers_version() here. @@ -468,10 +146,10 @@ async def load_model( and llama_backend.hf_variant and llama_backend.hf_variant.lower() == request.gguf_variant.lower() and llama_backend.model_identifier - and llama_backend.model_identifier.lower() == model_identifier.lower() + and llama_backend.model_identifier.lower() == request.model_path.lower() ): logger.info( - f"Model already loaded (GGUF): {model_log_label} variant={request.gguf_variant}, skipping reload" + f"Model already loaded (GGUF): {request.model_path} variant={request.gguf_variant}, skipping reload" ) inference_config = load_inference_config(llama_backend.model_identifier) from utils.models import is_audio_input_type @@ -484,12 +162,8 @@ async def load_model( _gguf_is_audio = getattr(llama_backend, "_is_audio", False) return LoadResponse( status = "already_loaded", - model = model_log_label - if native_grant_backed - else llama_backend.model_identifier, - display_name = model_log_label - if native_grant_backed - else llama_backend.model_identifier, + model = llama_backend.model_identifier, + display_name = llama_backend.model_identifier, is_vision = llama_backend._is_vision, is_lora = False, is_gguf = True, @@ -499,26 +173,21 @@ async def load_model( if _gguf_audio else False, inference = inference_config, - requires_trust_remote_code = bool( - inference_config.get("trust_remote_code", False) - ), context_length = llama_backend.context_length, max_context_length = llama_backend.max_context_length, native_context_length = llama_backend.native_context_length, supports_reasoning = llama_backend.supports_reasoning, - reasoning_style = llama_backend.reasoning_style, reasoning_always_on = llama_backend.reasoning_always_on, - supports_preserve_thinking = llama_backend.supports_preserve_thinking, chat_template = llama_backend.chat_template, speculative_type = llama_backend.speculative_type, ) else: if ( backend.active_model_name - and backend.active_model_name.lower() == model_identifier.lower() + and backend.active_model_name.lower() == request.model_path.lower() ): logger.info( - f"Model already loaded (Unsloth): {model_log_label}, skipping reload" + f"Model already loaded (Unsloth): {request.model_path}, skipping reload" ) inference_config = load_inference_config(backend.active_model_name) _model_info = backend.models.get(backend.active_model_name, {}) @@ -530,29 +199,10 @@ async def load_model( logger.warning( f"Could not retrieve chat template for {backend.active_model_name}: {e}" ) - # Non-GGUF: only advertise reasoning for gpt-oss Harmony, - # which emits reasoning via channels at the tokenizer level. - # Template-level chat_template_kwargs (enable_thinking / - # preserve_thinking / tools) are not yet forwarded through - # the transformers generation path, so avoid advertising - # controls the server cannot honour outside GGUF. - _sf_supports_reasoning = False - _sf_reasoning_style = "enable_thinking" - if hasattr(backend, "_is_gpt_oss_model"): - try: - if backend._is_gpt_oss_model(): - _sf_supports_reasoning = True - _sf_reasoning_style = "reasoning_effort" - except Exception: - pass return LoadResponse( status = "already_loaded", - model = model_log_label - if native_grant_backed - else backend.active_model_name, - display_name = model_log_label - if native_grant_backed - else backend.active_model_name, + model = backend.active_model_name, + display_name = backend.active_model_name, is_vision = _model_info.get("is_vision", False), is_lora = _model_info.get("is_lora", False), is_gguf = False, @@ -560,21 +210,13 @@ async def load_model( audio_type = _model_info.get("audio_type"), has_audio_input = _model_info.get("has_audio_input", False), inference = inference_config, - requires_trust_remote_code = bool( - inference_config.get("trust_remote_code", False) - ), - supports_reasoning = _sf_supports_reasoning, - reasoning_style = _sf_reasoning_style, - reasoning_always_on = False, - supports_preserve_thinking = False, - supports_tools = False, chat_template = _chat_template, ) # Create config using clean factory method # is_lora is auto-detected from adapter_config.json on disk/HF config = ModelConfig.from_identifier( - model_id = model_identifier, + model_id = request.model_path, hf_token = request.hf_token, gguf_variant = request.gguf_variant, ) @@ -582,7 +224,7 @@ async def load_model( if not config: raise HTTPException( status_code = 400, - detail = f"Invalid model identifier: {model_log_label}", + detail = f"Invalid model identifier: {request.model_path}", ) # Normalize gpu_ids: empty list means auto-selection, same as None @@ -610,8 +252,6 @@ async def load_model( # Run in a thread so the event loop stays free for progress # polling and other requests during the (potentially long) # GGUF download + llama-server startup. - _n_parallel = getattr(fastapi_request.app.state, "llama_parallel_slots", 1) - if config.gguf_hf_repo: # HF mode: download via huggingface_hub then start llama-server success = await asyncio.to_thread( @@ -625,15 +265,9 @@ async def load_model( chat_template_override = request.chat_template_override, cache_type_kv = request.cache_type_kv, speculative_type = request.speculative_type, - n_parallel = _n_parallel, - extra_args = extra_llama_args, ) else: # Local mode: llama-server loads via -m - if native_grant_backed and config.gguf_mmproj_file: - _validate_native_mmproj_companion( - config.gguf_mmproj_file, config.gguf_file - ) success = await asyncio.to_thread( llama_backend.load_model, gguf_path = config.gguf_file, @@ -644,19 +278,15 @@ async def load_model( chat_template_override = request.chat_template_override, cache_type_kv = request.cache_type_kv, speculative_type = request.speculative_type, - n_parallel = _n_parallel, - extra_args = extra_llama_args, ) if not success: raise HTTPException( status_code = 500, - detail = f"Failed to load GGUF model: {model_log_label if native_grant_backed else config.display_name}", + detail = f"Failed to load GGUF model: {config.display_name}", ) - logger.info( - f"Loaded GGUF model via llama-server: {model_log_label if native_grant_backed else config.identifier}" - ) + logger.info(f"Loaded GGUF model via llama-server: {config.identifier}") # Detect TTS audio by probing the loaded model's vocabulary from utils.models import is_audio_input_type @@ -665,10 +295,6 @@ async def load_model( _gguf_is_audio = _gguf_audio in ("snac", "bicodec", "dac") llama_backend._is_audio = _gguf_is_audio llama_backend._audio_type = _gguf_audio - llama_backend._native_display_label = ( - model_log_label if native_grant_backed else None - ) - llama_backend._native_grant_backed = bool(native_grant_backed) if _gguf_is_audio: logger.info(f"GGUF model detected as audio: audio_type={_gguf_audio}") await asyncio.to_thread(llama_backend.init_audio_codec, _gguf_audio) @@ -677,10 +303,8 @@ async def load_model( return LoadResponse( status = "loaded", - model = model_log_label if native_grant_backed else config.identifier, - display_name = model_log_label - if native_grant_backed - else config.display_name, + model = config.identifier, + display_name = config.display_name, is_vision = config.is_vision, is_lora = False, is_gguf = True, @@ -688,16 +312,11 @@ async def load_model( audio_type = _gguf_audio, has_audio_input = is_audio_input_type(_gguf_audio), inference = inference_config, - requires_trust_remote_code = bool( - inference_config.get("trust_remote_code", False) - ), context_length = llama_backend.context_length, max_context_length = llama_backend.max_context_length, native_context_length = llama_backend.native_context_length, supports_reasoning = llama_backend.supports_reasoning, - reasoning_style = llama_backend.reasoning_style, reasoning_always_on = llama_backend.reasoning_always_on, - supports_preserve_thinking = llama_backend.supports_preserve_thinking, supports_tools = llama_backend.supports_tools, cache_type_kv = llama_backend.cache_type_kv, chat_template = llama_backend.chat_template, @@ -803,13 +422,10 @@ async def load_model( ), ) raise HTTPException( - status_code = 500, - detail = f"Failed to load model: {model_log_label if native_grant_backed else config.display_name}", + status_code = 500, detail = f"Failed to load model: {config.display_name}" ) - logger.info( - f"Loaded model: {model_log_label if native_grant_backed else config.identifier}" - ) + logger.info(f"Loaded model: {config.identifier}") # Load inference configuration parameters inference_config = load_inference_config(config.identifier) @@ -823,26 +439,10 @@ async def load_model( except Exception: pass - # Non-GGUF: gpt-oss Harmony surfaces reasoning via tokenizer-level - # channels; other safetensors reasoning/tools/preserve-thinking - # knobs are not forwarded to tokenizer.apply_chat_template yet, so - # we only advertise support for the Harmony case here. - _sf_supports_reasoning = False - _sf_reasoning_style = "enable_thinking" - if hasattr(backend, "_is_gpt_oss_model"): - try: - if backend._is_gpt_oss_model(): - _sf_supports_reasoning = True - _sf_reasoning_style = "reasoning_effort" - except Exception: - pass - return LoadResponse( status = "loaded", - model = model_log_label if native_grant_backed else config.identifier, - display_name = model_log_label - if native_grant_backed - else config.display_name, + model = config.identifier, + display_name = config.display_name, is_vision = config.is_vision, is_lora = config.is_lora, is_gguf = False, @@ -850,31 +450,17 @@ async def load_model( audio_type = config.audio_type, has_audio_input = config.has_audio_input, inference = inference_config, - requires_trust_remote_code = bool( - inference_config.get("trust_remote_code", False) - ), - supports_reasoning = _sf_supports_reasoning, - reasoning_style = _sf_reasoning_style, - reasoning_always_on = False, - supports_preserve_thinking = False, - supports_tools = False, chat_template = _chat_template, ) except HTTPException: raise except ValueError as e: - if native_grant_backed: - redacted_msg = redact_native_paths(str(e)) - logger.warning( - "Rejected inference selection for native model %s: %s", - model_log_label, - redacted_msg, - ) - raise HTTPException(status_code = 400, detail = redacted_msg) logger.warning("Rejected inference GPU selection: %s", e) raise HTTPException(status_code = 400, detail = str(e)) except Exception as e: + logger.error(f"Error loading model: {e}", exc_info = True) + msg = str(e) # Surface a friendlier message for models that Unsloth cannot load not_supported_hints = [ "No config file found", @@ -882,22 +468,6 @@ async def load_model( "is not supported", "does not support", ] - if native_grant_backed: - redacted_msg = redact_native_paths(str(e)) - logger.error( - "Error loading native model %s: %s", - model_log_label, - redacted_msg, - ) - msg = redacted_msg - if any(h.lower() in msg.lower() for h in not_supported_hints): - msg = f"This model is not supported yet. Try a different model. (Original error: {msg})" - raise HTTPException( - status_code = 500, - detail = f"Failed to load native model {model_log_label}: {msg}", - ) - logger.error(f"Error loading model: {e}", exc_info = True) - msg = str(e) if any(h.lower() in msg.lower() for h in not_supported_hints): msg = f"This model is not supported yet. Try a different model. (Original error: {msg})" raise HTTPException(status_code = 500, detail = f"Failed to load model: {msg}") @@ -914,14 +484,9 @@ async def validate_model( This checks that ModelConfig.from_identifier() can resolve the given model_path, but it does NOT actually load model weights into GPU memory. """ - native_grant_backed = False - model_log_label = request.model_path try: - model_identifier, model_log_label, native_grant_backed = ( - _resolve_model_identifier_for_request(request, operation = "validate-model") - ) config = ModelConfig.from_identifier( - model_id = model_identifier, + model_id = request.model_path, hf_token = request.hf_token, gguf_variant = request.gguf_variant, ) @@ -929,47 +494,22 @@ async def validate_model( if not config: raise HTTPException( status_code = 400, - detail = f"Invalid model identifier: {model_log_label}", + detail = f"Invalid model identifier: {request.model_path}", ) return ValidateModelResponse( valid = True, message = "Model identifier is valid.", - identifier = model_log_label if native_grant_backed else config.identifier, - display_name = model_log_label - if native_grant_backed - else getattr(config, "display_name", config.identifier), + identifier = config.identifier, + display_name = getattr(config, "display_name", config.identifier), is_gguf = getattr(config, "is_gguf", False), is_lora = getattr(config, "is_lora", False), is_vision = getattr(config, "is_vision", False), - requires_trust_remote_code = bool( - load_inference_config(config.identifier).get("trust_remote_code", False) - ), ) except HTTPException: raise except Exception as e: - not_supported_hints = [ - "No config file found", - "not yet supported", - "is not supported", - "does not support", - ] - if native_grant_backed: - redacted_msg = redact_native_paths(str(e)) - logger.error( - "Error validating native model %s: %s", - model_log_label, - redacted_msg, - ) - msg = redacted_msg - if any(h.lower() in msg.lower() for h in not_supported_hints): - msg = f"This model is not supported yet. Try a different model. (Original error: {msg})" - raise HTTPException( - status_code = 400, - detail = f"Invalid native model {model_log_label}: {msg}", - ) logger.error( f"Error validating model identifier '{request.model_path}': {e}", exc_info = True, @@ -994,9 +534,6 @@ async def unload_model( llama_backend = get_llama_cpp_backend() if llama_backend.is_active and ( llama_backend.model_identifier == request.model_path - or is_registered_native_path_label( - llama_backend.model_identifier, request.model_path - ) or not llama_backend.is_loaded ): llama_backend.unload_model() @@ -1014,48 +551,6 @@ async def unload_model( raise HTTPException(status_code = 500, detail = f"Failed to unload model: {str(e)}") -@studio_router.post("/cancel") -async def cancel_inference( - request: Request, - current_subject: str = Depends(get_current_subject), -): - """Cancel in-flight inference requests. - - Body (JSON, at least one key required): - cancel_id - preferred: per-run UUID, matched exclusively. - session_id - fallback when cancel_id is absent. - completion_id - fallback when cancel_id is absent. - - A cancel_id arriving before its stream registers is stashed briefly - and replayed on registration. Returns {"cancelled": N}. - """ - try: - body = await request.json() - if not isinstance(body, dict): - body = {} - except Exception as e: - logger.debug("Failed to parse cancel request body: %s", e) - body = {} - - cancel_id = body.get("cancel_id") - if isinstance(cancel_id, str) and cancel_id: - return {"cancelled": _cancel_by_cancel_id_or_stash(cancel_id)} - - keys = [] - # `message_id` is the Anthropic passthrough's per-run identifier -- - # included so /v1/messages clients can cancel by their native id. - for k in ("completion_id", "session_id", "message_id"): - v = body.get(k) - if isinstance(v, str) and v: - keys.append(v) - - if not keys: - return {"cancelled": 0} - - n = _cancel_by_keys(keys) - return {"cancelled": n} - - @router.post("/generate/stream") async def generate_stream( request: GenerateRequest, @@ -1144,37 +639,20 @@ async def get_status( # If a GGUF model is loaded via llama-server, report that if llama_backend.is_loaded: _model_id = llama_backend.model_identifier - _native_grant_backed = getattr(llama_backend, "_native_grant_backed", False) - _display_model_id = getattr( - llama_backend, "_native_display_label", None - ) or display_label_for_native_path(_model_id) - if ( - _native_grant_backed - and _model_id - and _display_model_id == _model_id - and os.path.isabs(_model_id) - ): - _display_model_id = os.path.basename(_model_id) _inference_cfg = load_inference_config(_model_id) if _model_id else None return InferenceStatusResponse( - active_model = _display_model_id, + active_model = _model_id, is_vision = llama_backend.is_vision, is_gguf = True, gguf_variant = llama_backend.hf_variant, is_audio = getattr(llama_backend, "_is_audio", False), audio_type = getattr(llama_backend, "_audio_type", None), loading = [], - loaded = [_display_model_id] if _display_model_id else [], + loaded = [_model_id], inference = _inference_cfg, - requires_trust_remote_code = bool( - (_inference_cfg or {}).get("trust_remote_code", False) - ), supports_reasoning = llama_backend.supports_reasoning, - reasoning_style = llama_backend.reasoning_style, reasoning_always_on = llama_backend.reasoning_always_on, - supports_preserve_thinking = llama_backend.supports_preserve_thinking, supports_tools = llama_backend.supports_tools, - chat_template = llama_backend.chat_template, context_length = llama_backend.context_length, max_context_length = llama_backend.max_context_length, native_context_length = llama_backend.native_context_length, @@ -1188,37 +666,17 @@ async def get_status( is_audio = False audio_type = None has_audio_input = False - model_info = {} if backend.active_model_name: model_info = backend.models.get(backend.active_model_name, {}) is_vision = model_info.get("is_vision", False) is_audio = model_info.get("is_audio", False) audio_type = model_info.get("audio_type") has_audio_input = model_info.get("has_audio_input", False) - chat_template_info = model_info.get("chat_template_info", {}) - chat_template = ( - chat_template_info.get("template") - if isinstance(chat_template_info, dict) - else None - ) - # Non-GGUF: only gpt-oss Harmony is wired through the transformers - # generation path. Other template-level reasoning / tool kwargs - # are not yet forwarded, so we do not advertise them here. + # gpt-oss safetensors models support reasoning via harmony channels supports_reasoning = False - reasoning_style = "enable_thinking" if backend.active_model_name and hasattr(backend, "_is_gpt_oss_model"): - try: - if backend._is_gpt_oss_model(): - supports_reasoning = True - reasoning_style = "reasoning_effort" - except Exception: - pass - inference_config = ( - load_inference_config(backend.active_model_name) - if backend.active_model_name - else None - ) + supports_reasoning = backend._is_gpt_oss_model() return InferenceStatusResponse( active_model = backend.active_model_name, @@ -1229,16 +687,7 @@ async def get_status( has_audio_input = has_audio_input, loading = list(getattr(backend, "loading_models", set())), loaded = list(backend.models.keys()), - inference = inference_config, - requires_trust_remote_code = bool( - (inference_config or {}).get("trust_remote_code", False) - ), supports_reasoning = supports_reasoning, - reasoning_style = reasoning_style, - reasoning_always_on = False, - supports_preserve_thinking = False, - supports_tools = False, - chat_template = chat_template, ) except Exception as e: @@ -1246,34 +695,6 @@ async def get_status( raise HTTPException(status_code = 500, detail = f"Failed to get status: {str(e)}") -@router.get("/load-progress", response_model = LoadProgressResponse) -async def get_load_progress( - current_subject: str = Depends(get_current_subject), -): - """ - Return the active GGUF load's mmap/upload progress. - - During the warmup window after a GGUF download -- when llama-server - is paging ~tens-to-hundreds of GB of shards into the page cache - before pushing layers to VRAM -- ``/api/inference/status`` only - shows a generic spinner. This endpoint exposes sampled progress so - the UI can render a real bar plus rate/ETA during that window. - - Returns an empty payload (``phase=null, bytes=0``) when no load is - in flight. The frontend should stop polling once ``phase`` becomes - ``ready``. - """ - try: - llama_backend = get_llama_cpp_backend() - progress = llama_backend.load_progress() - if progress is None: - return LoadProgressResponse() - return LoadProgressResponse(**progress) - except Exception as e: - logger.warning(f"Error sampling load progress: {e}") - return LoadProgressResponse() - - # ===================================================================== # Audio (TTS) Generation (/audio/generate) # ===================================================================== @@ -1484,20 +905,6 @@ async def openai_chat_completions( llama_backend = get_llama_cpp_backend() using_gguf = llama_backend.is_loaded - # OpenAI-SDK clients send ``chat_template_kwargs`` via ``extra_body``, - # which the SDK spreads into the request body at the top level. Studio's - # ChatCompletionRequest has ``extra="allow"`` so pydantic stashes them in - # ``model_extra``, but the typed ``payload.enable_thinking`` path is what - # downstream generators actually consume. Lift ``enable_thinking`` from - # the extra-body chat_template_kwargs onto the typed field so clients - # that only know the OpenAI shape (data_designer recipe runs, etc.) - # can still control the reasoning preamble. - _extra = getattr(payload, "model_extra", None) - if payload.enable_thinking is None and isinstance(_extra, dict): - _tpl_kw = _extra.get("chat_template_kwargs") - if isinstance(_tpl_kw, dict) and "enable_thinking" in _tpl_kw: - payload.enable_thinking = bool(_tpl_kw["enable_thinking"]) - # ── Determine which backend is active ───────────────────── if using_gguf: model_name = llama_backend.model_identifier or payload.model @@ -1553,9 +960,6 @@ def audio_input_generate(): ) if payload.stream: - _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) - _tracker = _TrackedCancel(cancel_event, *_cancel_keys) - _tracker.__enter__() async def audio_input_stream(): try: @@ -1572,17 +976,10 @@ async def audio_input_stream(): ) yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n" - gen = audio_input_generate() - _DONE = object() - while True: - if cancel_event.is_set(): - break + for chunk_text in audio_input_generate(): if await request.is_disconnected(): cancel_event.set() return - chunk_text = await asyncio.to_thread(next, gen, _DONE) - if chunk_text is _DONE: - break if chunk_text: chunk = ChatCompletionChunk( id = completion_id, @@ -1615,8 +1012,6 @@ async def audio_input_stream(): f"Error during audio input streaming: {e}", exc_info = True ) yield f"data: {json.dumps({'error': {'message': _friendly_error(e), 'type': 'server_error'}})}\n\n" - finally: - _tracker.__exit__(None, None, None) return StreamingResponse( audio_input_stream(), @@ -1642,67 +1037,6 @@ async def audio_input_stream(): ) return JSONResponse(content = response.model_dump()) - # ── Standard OpenAI function-calling pass-through (GGUF only) ──── - # When a client (opencode / Claude Code via OpenAI compat / Cursor / - # Continue / ...) sends standard OpenAI `tools` without Studio's - # `enable_tools` shorthand, forward the request to llama-server - # verbatim so structured `tool_calls` flow back to the client. This - # branch runs BEFORE `_extract_content_parts` because that helper is - # unaware of `role="tool"` messages and assistant messages that only - # carry `tool_calls` (content=None) — both of which are valid in - # multi-turn client-side tool loops. - _has_tool_messages = any(m.role == "tool" or m.tool_calls for m in payload.messages) - # Route guided-decoding requests through the verbatim passthrough so - # ``response_format`` (JSON schema) actually reaches llama-server and - # the model's GBNF-constrained output comes back unmodified. The - # non-passthrough GGUF path below calls ``generate_chat_completion`` - # which has no response_format kwarg, so the schema gets silently - # dropped and data_designer falls back to free-form sampling. Guided - # decoding does not require ``supports_tools`` - the grammar machinery - # is independent of tool-call parsing. - _has_response_format = _extract_response_format(payload) is not None - _tools_passthrough = llama_backend.supports_tools and ( - (payload.tools and len(payload.tools) > 0) or _has_tool_messages - ) - if ( - using_gguf - and not _effective_enable_tools(payload) - and (_tools_passthrough or _has_response_format) - ): - # Preserve the vision guard that would otherwise run in the - # non-passthrough path below: text-only tool-capable GGUFs - # should return a clear 400 here rather than forwarding the - # image to llama-server and surfacing an opaque upstream error. - if not llama_backend.is_vision and ( - payload.image_base64 - or any( - isinstance(m.content, list) - and any(isinstance(p, ImageContentPart) for p in m.content) - for m in payload.messages - ) - ): - raise HTTPException( - status_code = 400, - detail = "Image provided but current GGUF model does not support vision.", - ) - - cancel_event = threading.Event() - completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" - if payload.stream: - return await _openai_passthrough_stream( - request, - cancel_event, - llama_backend, - payload, - model_name, - completion_id, - ) - return await _openai_passthrough_non_streaming( - llama_backend, - payload, - model_name, - ) - # ── Parse messages (handles multimodal content parts) ───── system_prompt, chat_messages, extracted_image_b64 = _extract_content_parts( payload.messages @@ -1732,11 +1066,9 @@ async def audio_input_stream(): from PIL import Image as _Image raw = _b64.b64decode(image_b64) - # Normalize to RGB so PNG encoding succeeds regardless of - # source mode (RGBA, P, L, CMYK, I, F, ...). Previously - # we only converted RGBA, which left CMYK/I/F to raise at - # img.save(PNG). - img = _Image.open(_BytesIO(raw)).convert("RGB") + img = _Image.open(_BytesIO(raw)) + if img.mode == "RGBA": + img = img.convert("RGB") buf = _BytesIO() img.save(buf, format = "PNG") image_b64 = _b64.b64encode(buf.getvalue()).decode("ascii") @@ -1757,13 +1089,8 @@ async def audio_input_stream(): created = int(time.time()) # ── Tool-calling path (agentic loop) ────────────────── - # `_effective_enable_tools` lets `unsloth run --enable-tools/--disable-tools` - # hard-override the per-request value. Without a CLI override, falls - # back to `payload.enable_tools` (existing behavior). use_tools = ( - _effective_enable_tools(payload) - and llama_backend.supports_tools - and not image_b64 + payload.enable_tools and llama_backend.supports_tools and not image_b64 ) if use_tools: @@ -1862,8 +1189,6 @@ def gguf_generate_with_tools(): presence_penalty = payload.presence_penalty, cancel_event = cancel_event, enable_thinking = payload.enable_thinking, - reasoning_effort = payload.reasoning_effort, - preserve_thinking = payload.preserve_thinking, auto_heal_tool_calls = payload.auto_heal_tool_calls if payload.auto_heal_tool_calls is not None else True, @@ -1878,10 +1203,6 @@ def gguf_generate_with_tools(): _tool_sentinel = object() - _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) - _tracker = _TrackedCancel(cancel_event, *_cancel_keys) - _tracker.__enter__() - async def gguf_tool_stream(): try: first_chunk = ChatCompletionChunk( @@ -1904,8 +1225,6 @@ async def gguf_tool_stream(): _stream_usage = None _stream_timings = None while True: - if cancel_event.is_set(): - break if await request.is_disconnected(): cancel_event.set() return @@ -2013,8 +1332,6 @@ async def gguf_tool_stream(): }, } yield f"data: {json.dumps(error_chunk)}\n\n" - finally: - _tracker.__exit__(None, None, None) return StreamingResponse( gguf_tool_stream(), @@ -2041,16 +1358,11 @@ def gguf_generate(): presence_penalty = payload.presence_penalty, cancel_event = cancel_event, enable_thinking = payload.enable_thinking, - reasoning_effort = payload.reasoning_effort, - preserve_thinking = payload.preserve_thinking, ) _gguf_sentinel = object() if payload.stream: - _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) - _tracker = _TrackedCancel(cancel_event, *_cancel_keys) - _tracker.__enter__() async def gguf_stream_chunks(): try: @@ -2075,8 +1387,6 @@ async def gguf_stream_chunks(): _stream_usage = None _stream_timings = None while True: - if cancel_event.is_set(): - break if await request.is_disconnected(): cancel_event.set() return @@ -2160,8 +1470,6 @@ async def gguf_stream_chunks(): }, } yield f"data: {json.dumps(error_chunk)}\n\n" - finally: - _tracker.__exit__(None, None, None) return StreamingResponse( gguf_stream_chunks(), @@ -2261,9 +1569,6 @@ def generate(): # ── Streaming response ──────────────────────────────────────── if payload.stream: - _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) - _tracker = _TrackedCancel(cancel_event, *_cancel_keys) - _tracker.__enter__() async def stream_chunks(): try: @@ -2291,9 +1596,6 @@ async def stream_chunks(): loop = asyncio.get_event_loop() gen = generate() while True: - if cancel_event.is_set(): - backend.reset_generation_state() - break # next(gen, _DONE) returns _DONE instead of raising # StopIteration — StopIteration cannot propagate # through asyncio futures (Python limitation). @@ -2349,8 +1651,6 @@ async def stream_chunks(): }, } yield f"data: {json.dumps(error_chunk)}\n\n" - finally: - _tracker.__exit__(None, None, None) return StreamingResponse( stream_chunks(), @@ -2516,2006 +1816,3 @@ async def openai_list_models( ) return {"object": "list", "data": models} - - -# ===================================================================== -# OpenAI-Compatible Completions Proxy (/completions → /v1/completions) -# ===================================================================== - - -@router.post("/completions") -async def openai_completions( - request: Request, - current_subject: str = Depends(get_current_subject), -): - """ - OpenAI-compatible text completions endpoint (non-chat). - - Transparently proxies to the running llama-server's ``/v1/completions``. - Only available when a GGUF model is loaded. - """ - llama_backend = get_llama_cpp_backend() - if not llama_backend.is_loaded: - raise HTTPException( - status_code = 503, - detail = "No GGUF model loaded. Load a GGUF model first.", - ) - - body = await request.json() - target_url = f"{llama_backend.base_url}/v1/completions" - is_stream = body.get("stream", False) - - if is_stream: - - async def _stream(): - # Manual httpx client/response lifecycle AND explicit - # aiter_bytes() iterator close — see _anthropic_passthrough_stream - # for the full rationale. Saving `bytes_iter = resp.aiter_bytes()` - # and `await bytes_iter.aclose()` in the finally block is the - # part that matters for avoiding the Python 3.13 + httpcore - # 1.0.x "Exception ignored in: " / anyio - # cancel-scope trace: an anonymous async for leaves the - # iterator unclosed, so Python's asyncgen GC finalizer runs - # cleanup on a later pass in a different asyncio task. - client = httpx.AsyncClient(timeout = 600) - resp = None - bytes_iter = None - try: - req = client.build_request("POST", target_url, json = body) - resp = await client.send(req, stream = True) - bytes_iter = resp.aiter_bytes() - async for chunk in bytes_iter: - yield chunk - except Exception as e: - logger.error("openai_completions stream error: %s", e) - finally: - if bytes_iter is not None: - try: - await bytes_iter.aclose() - except Exception: - pass - if resp is not None: - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - - return StreamingResponse(_stream(), media_type = "text/event-stream") - else: - async with httpx.AsyncClient() as client: - resp = await client.post(target_url, json = body, timeout = 600) - return Response( - content = resp.content, - status_code = resp.status_code, - media_type = "application/json", - ) - - -# ===================================================================== -# OpenAI-Compatible Embeddings Proxy (/embeddings → /v1/embeddings) -# ===================================================================== - - -@router.post("/embeddings") -async def openai_embeddings( - request: Request, - current_subject: str = Depends(get_current_subject), -): - """ - OpenAI-compatible embeddings endpoint. - - Transparently proxies to the running llama-server's ``/v1/embeddings``. - Only available when a GGUF model is loaded. - Note: the loaded model must support pooling; otherwise llama-server - will return an error (expected). - """ - llama_backend = get_llama_cpp_backend() - if not llama_backend.is_loaded: - raise HTTPException( - status_code = 503, - detail = "No GGUF model loaded. Load a GGUF model first.", - ) - - body = await request.json() - target_url = f"{llama_backend.base_url}/v1/embeddings" - - async with httpx.AsyncClient() as client: - resp = await client.post(target_url, json = body, timeout = 600) - return Response( - content = resp.content, - status_code = resp.status_code, - media_type = "application/json", - ) - - -# ===================================================================== -# OpenAI Responses API (/responses → /v1/responses) -# ===================================================================== - - -def _translate_responses_tools_to_chat( - tools: Optional[list[dict]], -) -> Optional[list[dict]]: - """Translate Responses-shape function tools to the Chat Completions nested shape. - - Responses uses a flat shape per tool entry:: - - {"type": "function", "name": "...", "description": "...", - "parameters": {...}, "strict": true} - - The Chat Completions / llama-server passthrough expects the nested shape:: - - {"type": "function", - "function": {"name": "...", "description": "...", - "parameters": {...}, "strict": true}} - - Only ``type=="function"`` entries are forwarded. Built-in Responses tools - (``web_search``, ``file_search``, ``mcp``, ...) are dropped because - llama-server does not implement them server-side; keeping them in the - request would produce an opaque upstream 400. - """ - if not tools: - return None - out: list[dict] = [] - for tool in tools: - if not isinstance(tool, dict): - continue - if tool.get("type") != "function": - continue - fn: dict = {} - if "name" in tool: - fn["name"] = tool["name"] - if tool.get("description") is not None: - fn["description"] = tool["description"] - if tool.get("parameters") is not None: - fn["parameters"] = tool["parameters"] - if tool.get("strict") is not None: - fn["strict"] = tool["strict"] - out.append({"type": "function", "function": fn}) - return out or None - - -def _translate_responses_tool_choice_to_chat(tool_choice: Any) -> Any: - """Translate a Responses-shape ``tool_choice`` to the Chat Completions shape. - - String values (``"auto"``/``"none"``/``"required"``) pass through unchanged. - The Responses forcing object ``{"type": "function", "name": "X"}`` is - converted to Chat Completions' ``{"type": "function", "function": {"name": "X"}}``. - Unknown / built-in tool choices are forwarded as-is; llama-server ignores - what it doesn't recognise. - """ - if tool_choice is None: - return None - if isinstance(tool_choice, str): - return tool_choice - if ( - isinstance(tool_choice, dict) - and tool_choice.get("type") == "function" - and "name" in tool_choice - and "function" not in tool_choice - ): - return {"type": "function", "function": {"name": tool_choice["name"]}} - return tool_choice - - -def _responses_message_text(content: Union[str, list]) -> str: - """Flatten a ResponsesInputMessage ``content`` into a plain text string. - - Used for system/developer message hoisting and for assistant-replay - (``output_text``) messages when images/unknown parts are irrelevant. - Returns an empty string for empty input. - """ - if isinstance(content, str): - return content - parts: list[str] = [] - for part in content or []: - if isinstance(part, (ResponsesInputTextPart, ResponsesOutputTextPart)): - parts.append(part.text) - return "\n".join(parts) - - -def _normalise_responses_input(payload: ResponsesRequest) -> list[ChatMessage]: - """Convert a ResponsesRequest's ``input`` into Chat-format ``ChatMessage`` list. - - Handles the three input item shapes allowed by the Responses API: - - - ``ResponsesInputMessage`` — regular chat messages (text or multimodal). - - ``ResponsesFunctionCallInputItem`` — a prior assistant tool call replayed - on a follow-up turn. Converted into an assistant message carrying a - Chat Completions ``tool_calls`` entry keyed by ``call_id``. - - ``ResponsesFunctionCallOutputInputItem`` — a tool result the client is - returning. Converted into a ``role="tool"`` message with ``tool_call_id`` - set to the originating ``call_id`` so llama-server can reconcile the - call with its result. - - System / developer content is collected from ``instructions`` *and* from - any ``role="system"`` / ``role="developer"`` entries in ``input``, then - merged into a single ``role="system"`` message placed at the top of the - returned list. This satisfies strict chat templates (harmony / gpt-oss, - Qwen3, ...) whose Jinja raises ``"System message must be at the - beginning."`` when more than one system message is present or when a - system message appears after a user turn — the exact pattern the OpenAI - Codex CLI hits, since Codex sets ``instructions`` *and* also sends a - developer message in ``input``. - """ - system_parts: list[str] = [] - messages: list[ChatMessage] = [] - - if payload.instructions: - system_parts.append(payload.instructions) - - # Simple string input - if isinstance(payload.input, str): - if payload.input: - messages.append(ChatMessage(role = "user", content = payload.input)) - if system_parts: - merged = "\n\n".join(p for p in system_parts if p) - return [ChatMessage(role = "system", content = merged), *messages] - return messages - - for item in payload.input: - if isinstance(item, ResponsesFunctionCallInputItem): - messages.append( - ChatMessage( - role = "assistant", - content = None, - tool_calls = [ - { - "id": item.call_id, - "type": "function", - "function": { - "name": item.name, - "arguments": item.arguments, - }, - } - ], - ) - ) - continue - - if isinstance(item, ResponsesFunctionCallOutputInputItem): - # Chat Completions `role="tool"` requires a string content; if a - # Responses client sends a content-array output, serialize it. - output = item.output - if not isinstance(output, str): - output = json.dumps(output) - messages.append( - ChatMessage( - role = "tool", - tool_call_id = item.call_id, - content = output, - ) - ) - continue - - if isinstance(item, ResponsesUnknownInputItem): - # Reasoning items and any other unmodelled top-level Responses - # item types are silently dropped — llama-server-backed GGUFs - # cannot consume them and our lenient validation let them in so - # unrelated turns don't 422. - continue - - # ResponsesInputMessage — hoist system/developer to the top, merge. - if item.role in ("system", "developer"): - hoisted = _responses_message_text(item.content) - if hoisted: - system_parts.append(hoisted) - continue - - if isinstance(item.content, str): - messages.append(ChatMessage(role = item.role, content = item.content)) - continue - - # Assistant-replay turns come back as content = [output_text, ...]. - # Chat Completions' assistant role expects a plain string, not a - # multimodal content array, so flatten output_text (and any stray - # input_text / unknown text) to a single string. - if item.role == "assistant": - text = _responses_message_text(item.content) - if text: - messages.append(ChatMessage(role = "assistant", content = text)) - continue - - # User (and any other remaining roles) — keep multimodal when - # present, drop unknown content parts silently. - parts: list = [] - for part in item.content: - if isinstance(part, (ResponsesInputTextPart, ResponsesOutputTextPart)): - parts.append(TextContentPart(type = "text", text = part.text)) - elif isinstance(part, ResponsesInputImagePart): - parts.append( - ImageContentPart( - type = "image_url", - image_url = ImageUrl(url = part.image_url, detail = part.detail), - ) - ) - # ResponsesUnknownContentPart and anything else: drop. - if parts: - # Collapse single-text-part content to a plain string so roles - # that reject multimodal arrays (e.g. legacy templates) still - # accept the message. - if len(parts) == 1 and isinstance(parts[0], TextContentPart): - messages.append(ChatMessage(role = item.role, content = parts[0].text)) - else: - messages.append(ChatMessage(role = item.role, content = parts)) - - if system_parts: - merged = "\n\n".join(p for p in system_parts if p) - return [ChatMessage(role = "system", content = merged), *messages] - return messages - - -def _build_chat_request( - payload: ResponsesRequest, messages: list[ChatMessage], stream: bool -) -> ChatCompletionRequest: - """Build a ChatCompletionRequest from a ResponsesRequest. - - Tools and ``tool_choice`` are translated from the flat Responses shape to - the nested Chat Completions shape here so the existing #5099 - ``/v1/chat/completions`` client-side pass-through picks them up without - further modification. - """ - chat_kwargs: dict = dict( - model = payload.model, - messages = messages, - stream = stream, - ) - if payload.temperature is not None: - chat_kwargs["temperature"] = payload.temperature - if payload.top_p is not None: - chat_kwargs["top_p"] = payload.top_p - if payload.max_output_tokens is not None: - chat_kwargs["max_tokens"] = payload.max_output_tokens - - chat_tools = _translate_responses_tools_to_chat(payload.tools) - if chat_tools is not None: - chat_kwargs["tools"] = chat_tools - - chat_tool_choice = _translate_responses_tool_choice_to_chat(payload.tool_choice) - if chat_tool_choice is not None: - chat_kwargs["tool_choice"] = chat_tool_choice - - req = ChatCompletionRequest(**chat_kwargs) - # `parallel_tool_calls` is not a first-class field on ChatCompletionRequest, - # but the model allows extras and _build_openai_passthrough_body forwards - # only explicitly-known fields. Llama-server does not currently implement - # parallel_tool_calls semantics, so we accept-and-ignore it on the - # Responses side to avoid breaking SDK clients that always send it. - return req - - -def _chat_tool_calls_to_responses_output(tool_calls: list[dict]) -> list[dict]: - """Map Chat Completions ``tool_calls`` into Responses ``function_call`` output items. - - The Chat Completions id (``call_xxx``) is the shared correlation key across - turns in the OpenAI Responses API — it is stored as ``call_id`` on the - output item and must be echoed back by the client as - ``function_call_output.call_id`` on the next turn. - """ - items: list[dict] = [] - for tc in tool_calls: - if tc.get("type") != "function": - continue - fn = tc.get("function") or {} - items.append( - ResponsesOutputFunctionCall( - call_id = tc.get("id", ""), - name = fn.get("name", ""), - arguments = fn.get("arguments", "") or "", - status = "completed", - ).model_dump() - ) - return items - - -async def _responses_non_streaming( - payload: ResponsesRequest, - messages: list[ChatMessage], - request: Request, -) -> JSONResponse: - """Handle a non-streaming Responses API call.""" - chat_req = _build_chat_request(payload, messages, stream = False) - result = await openai_chat_completions(chat_req, request) - - # openai_chat_completions returns a JSONResponse for non-streaming - if isinstance(result, JSONResponse): - body = json.loads(result.body.decode()) - elif isinstance(result, Response): - body = json.loads(result.body.decode()) - else: - body = result - - choices = body.get("choices", []) - text = "" - tool_calls: list[dict] = [] - if choices: - msg = choices[0].get("message", {}) or {} - text = msg.get("content", "") or "" - tool_calls = msg.get("tool_calls") or [] - - usage_data = body.get("usage", {}) - input_tokens = usage_data.get("prompt_tokens", 0) - output_tokens = usage_data.get("completion_tokens", 0) - - resp_id = f"resp_{uuid.uuid4().hex[:12]}" - - # Responses API emits each tool call as its own top-level output item, - # alongside an optional assistant text message. Emit the text message - # only when the model actually produced content, so clients that expect - # a pure tool-call turn (finish_reason="tool_calls") don't see a spurious - # empty message item. - output_items: list[dict] = [] - if text: - msg_id = f"msg_{uuid.uuid4().hex[:12]}" - output_items.append( - ResponsesOutputMessage( - id = msg_id, - status = "completed", - role = "assistant", - content = [ResponsesOutputTextContent(text = text)], - ).model_dump() - ) - output_items.extend(_chat_tool_calls_to_responses_output(tool_calls)) - - response = ResponsesResponse( - id = resp_id, - created_at = int(time.time()), - status = "completed", - model = body.get("model", payload.model), - output = output_items, - usage = ResponsesUsage( - input_tokens = input_tokens, - output_tokens = output_tokens, - total_tokens = input_tokens + output_tokens, - ), - temperature = payload.temperature, - top_p = payload.top_p, - max_output_tokens = payload.max_output_tokens, - instructions = payload.instructions, - ) - return JSONResponse(content = response.model_dump()) - - -async def _responses_stream( - payload: ResponsesRequest, - messages: list[ChatMessage], - request: Request, -): - """Handle a streaming Responses API call, emitting named SSE events. - - For GGUF models the request goes directly to llama-server's - ``/v1/chat/completions`` endpoint from inside the StreamingResponse - child task — a single httpx lifecycle, a single async generator. - Wrapping the existing ``openai_chat_completions`` pass-through (which - already does its own httpx lifecycle) stacks two generators: Python - 3.13 + httpcore 1.0.x then loses the close-propagation chain on the - innermost ``HTTP11ConnectionByteStream`` at asyncgen finalisation, - tripping "Attempted to exit cancel scope in a different task" / - "async generator ignored GeneratorExit". The direct path avoids that - altogether. Non-GGUF falls back to the wrapper (which doesn't use - httpx, so the issue doesn't apply). - - Text deltas arrive as ``response.output_text.delta`` on a single - ``message`` output item at ``output_index=0``. Each tool call from - ``delta.tool_calls[]`` is promoted to its own top-level ``function_call`` - output item (one per distinct ``tool_calls[].index``), and relayed as - ``response.function_call_arguments.delta`` / ``.done`` events so clients - (Codex, OpenAI Python SDK) can reconstruct the call incrementally and - reply with a ``function_call_output`` item on the next turn. - """ - resp_id = f"resp_{uuid.uuid4().hex[:12]}" - msg_id = f"msg_{uuid.uuid4().hex[:12]}" - created_at = int(time.time()) - - chat_req = _build_chat_request(payload, messages, stream = True) - - llama_backend = get_llama_cpp_backend() - if not llama_backend.is_loaded: - # The direct pass-through is GGUF-only. Non-GGUF /v1/responses - # streaming isn't a Codex-compatible path today and wrapping the - # transformers backend's streaming generator here would re- - # introduce the double-layer asyncgen close pattern that produces - # "Attempted to exit cancel scope in a different task" on Python - # 3.13. Surface a typed 400 so the client sees a useful error - # instead of a dangling stream. - raise HTTPException( - status_code = 400, - detail = ( - "Streaming /v1/responses requires a GGUF model loaded via " - "llama-server. Use non-streaming /v1/responses, " - "/v1/chat/completions, or load a GGUF model." - ), - ) - - body = _build_openai_passthrough_body( - chat_req, backend_ctx = llama_backend.context_length - ) - target_url = f"{llama_backend.base_url}/v1/chat/completions" - - async def event_generator(): - full_text = "" - input_tokens = 0 - output_tokens = 0 - # Per-tool-call state keyed by the Chat Completions `tool_calls[].index` - # which stays stable across chunks for the same call. Values are: - # {output_index, item_id, call_id, name, arguments, opened} - tool_call_state: dict[int, dict] = {} - # Text message lives at output_index 0; tool calls claim 1, 2, ... - next_output_index = 1 - - def _snapshot_output() -> list[dict]: - """Snapshot of all completed output items for response.completed.""" - items: list[dict] = [ - { - "type": "message", - "id": msg_id, - "status": "completed", - "role": "assistant", - "content": [ - { - "type": "output_text", - "text": full_text, - "annotations": [], - } - ], - } - ] - for st in sorted(tool_call_state.values(), key = lambda s: s["output_index"]): - items.append( - { - "type": "function_call", - "id": st["item_id"], - "status": "completed", - "call_id": st["call_id"], - "name": st["name"], - "arguments": st["arguments"], - } - ) - return items - - # ── Preamble events ── - yield f"event: response.created\ndata: {json.dumps({'type': 'response.created', 'response': {'id': resp_id, 'object': 'response', 'created_at': created_at, 'status': 'in_progress', 'model': payload.model, 'output': [], 'usage': {'input_tokens': 0, 'output_tokens': 0, 'total_tokens': 0}}})}\n\n" - - # output_item.added (text message at output_index 0) - output_item = { - "type": "message", - "id": msg_id, - "status": "in_progress", - "role": "assistant", - "content": [], - } - yield f"event: response.output_item.added\ndata: {json.dumps({'type': 'response.output_item.added', 'output_index': 0, 'item': output_item})}\n\n" - - # content_part.added - content_part = {"type": "output_text", "text": "", "annotations": []} - yield f"event: response.content_part.added\ndata: {json.dumps({'type': 'response.content_part.added', 'item_id': msg_id, 'output_index': 0, 'content_index': 0, 'part': content_part})}\n\n" - - # ── Direct httpx lifecycle to llama-server ── - # Full same-task open + close, identical pattern to - # _openai_passthrough_stream and _anthropic_passthrough_stream: - # no `async with`, explicit aclose of lines_iter BEFORE resp / - # client so the innermost httpcore byte stream is finalised in - # this task (not via Python's asyncgen GC in a sibling task). - client = httpx.AsyncClient(timeout = 600) - resp = None - lines_iter = None - try: - req = client.build_request("POST", target_url, json = body) - try: - resp = await client.send(req, stream = True) - except httpx.RequestError as e: - logger.error("responses stream: upstream unreachable: %s", e) - yield f"event: response.failed\ndata: {json.dumps({'type': 'response.failed', 'response': {'id': resp_id, 'object': 'response', 'created_at': created_at, 'status': 'failed', 'model': payload.model, 'output': [], 'error': {'code': 502, 'message': _friendly_error(e)}}})}\n\n" - return - - if resp.status_code != 200: - err_bytes = await resp.aread() - err_text = err_bytes.decode("utf-8", errors = "replace") - logger.error( - "responses stream upstream error: status=%s body=%s", - resp.status_code, - err_text[:500], - ) - yield f"event: response.failed\ndata: {json.dumps({'type': 'response.failed', 'response': {'id': resp_id, 'object': 'response', 'created_at': created_at, 'status': 'failed', 'model': payload.model, 'output': [], 'error': {'code': resp.status_code, 'message': f'llama-server error: {err_text[:500]}'}}})}\n\n" - return - - lines_iter = resp.aiter_lines() - async for raw_line in lines_iter: - if await request.is_disconnected(): - break - if not raw_line: - continue - if not raw_line.startswith("data: "): - continue - data_str = raw_line[6:] - if data_str.strip() == "[DONE]": - break - try: - chunk_data = json.loads(data_str) - except json.JSONDecodeError: - continue - - choices = chunk_data.get("choices", []) - if not choices: - usage = chunk_data.get("usage") - if usage: - input_tokens = usage.get("prompt_tokens", input_tokens) - output_tokens = usage.get("completion_tokens", output_tokens) - continue - - delta = choices[0].get("delta", {}) or {} - content = delta.get("content") - if content: - full_text += content - delta_event = { - "type": "response.output_text.delta", - "item_id": msg_id, - "output_index": 0, - "content_index": 0, - "delta": content, - } - yield f"event: response.output_text.delta\ndata: {json.dumps(delta_event)}\n\n" - - for tc in delta.get("tool_calls") or []: - idx = tc.get("index", 0) - st = tool_call_state.get(idx) - fn = tc.get("function") or {} - if st is None: - # First chunk for this tool call — allocate an - # output_index and emit output_item.added. - st = { - "output_index": next_output_index, - "item_id": f"fc_{uuid.uuid4().hex[:12]}", - "call_id": tc.get("id") or "", - "name": fn.get("name") or "", - "arguments": "", - "opened": False, - } - next_output_index += 1 - tool_call_state[idx] = st - else: - # Later chunks sometimes carry the id/name only - # once; merge when present. - if tc.get("id") and not st["call_id"]: - st["call_id"] = tc["id"] - if fn.get("name") and not st["name"]: - st["name"] = fn["name"] - - if not st["opened"] and st["call_id"] and st["name"]: - item_added = { - "type": "response.output_item.added", - "output_index": st["output_index"], - "item": { - "type": "function_call", - "id": st["item_id"], - "status": "in_progress", - "call_id": st["call_id"], - "name": st["name"], - "arguments": "", - }, - } - yield f"event: response.output_item.added\ndata: {json.dumps(item_added)}\n\n" - st["opened"] = True - - arg_delta = fn.get("arguments") or "" - if arg_delta and st["opened"]: - st["arguments"] += arg_delta - args_delta_event = { - "type": "response.function_call_arguments.delta", - "item_id": st["item_id"], - "output_index": st["output_index"], - "delta": arg_delta, - } - yield f"event: response.function_call_arguments.delta\ndata: {json.dumps(args_delta_event)}\n\n" - elif arg_delta: - # Buffer the args until we can open the item - # (id/name arrive in the same chunk as the first - # arg delta for some models — but if not, stash). - st["arguments"] += arg_delta - - usage = chunk_data.get("usage") - if usage: - input_tokens = usage.get("prompt_tokens", input_tokens) - output_tokens = usage.get("completion_tokens", output_tokens) - except Exception as e: - logger.error("responses stream error: %s", e) - finally: - if lines_iter is not None: - try: - await lines_iter.aclose() - except Exception: - pass - if resp is not None: - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - - # ── Closing events for tool calls ── - for st in sorted(tool_call_state.values(), key = lambda s: s["output_index"]): - # If id/name never arrived (malformed upstream), synthesise so - # the client still sees a coherent frame sequence. - if not st["opened"]: - if not st["call_id"]: - st["call_id"] = f"call_{uuid.uuid4().hex[:12]}" - item_added = { - "type": "response.output_item.added", - "output_index": st["output_index"], - "item": { - "type": "function_call", - "id": st["item_id"], - "status": "in_progress", - "call_id": st["call_id"], - "name": st["name"], - "arguments": "", - }, - } - yield f"event: response.output_item.added\ndata: {json.dumps(item_added)}\n\n" - if st["arguments"]: - yield ( - "event: response.function_call_arguments.delta\n" - "data: " - + json.dumps( - { - "type": "response.function_call_arguments.delta", - "item_id": st["item_id"], - "output_index": st["output_index"], - "delta": st["arguments"], - } - ) - + "\n\n" - ) - st["opened"] = True - - args_done = { - "type": "response.function_call_arguments.done", - "item_id": st["item_id"], - "output_index": st["output_index"], - "name": st["name"], - "arguments": st["arguments"], - } - yield f"event: response.function_call_arguments.done\ndata: {json.dumps(args_done)}\n\n" - - item_done = { - "type": "response.output_item.done", - "output_index": st["output_index"], - "item": { - "type": "function_call", - "id": st["item_id"], - "status": "completed", - "call_id": st["call_id"], - "name": st["name"], - "arguments": st["arguments"], - }, - } - yield f"event: response.output_item.done\ndata: {json.dumps(item_done)}\n\n" - - # ── Closing events for text message ── - yield f"event: response.output_text.done\ndata: {json.dumps({'type': 'response.output_text.done', 'item_id': msg_id, 'output_index': 0, 'content_index': 0, 'text': full_text})}\n\n" - - yield f"event: response.content_part.done\ndata: {json.dumps({'type': 'response.content_part.done', 'item_id': msg_id, 'output_index': 0, 'content_index': 0, 'part': {'type': 'output_text', 'text': full_text, 'annotations': []}})}\n\n" - - yield f"event: response.output_item.done\ndata: {json.dumps({'type': 'response.output_item.done', 'output_index': 0, 'item': {'type': 'message', 'id': msg_id, 'status': 'completed', 'role': 'assistant', 'content': [{'type': 'output_text', 'text': full_text, 'annotations': []}]}})}\n\n" - - # response.completed - total_tokens = input_tokens + output_tokens - completed_response = { - "type": "response.completed", - "response": { - "id": resp_id, - "object": "response", - "created_at": created_at, - "status": "completed", - "model": payload.model, - "output": _snapshot_output(), - "usage": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": total_tokens, - }, - }, - } - yield f"event: response.completed\ndata: {json.dumps(completed_response)}\n\n" - - return StreamingResponse( - event_generator(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@router.post("/responses") -async def openai_responses( - payload: ResponsesRequest, - request: Request, - current_subject: str = Depends(get_current_subject), -): - """ - OpenAI Responses API endpoint. - - Accepts the Responses-format request, converts it to a - ChatCompletionRequest internally, and returns a response - matching the OpenAI Responses API schema (output array, - input_tokens/output_tokens, named SSE events for streaming). - """ - messages = _normalise_responses_input(payload) - if not messages: - raise HTTPException(status_code = 400, detail = "No input provided.") - - if payload.stream: - return await _responses_stream(payload, messages, request) - return await _responses_non_streaming(payload, messages, request) - - -# ===================================================================== -# Anthropic-Compatible Messages API (/messages → /v1/messages) -# ===================================================================== - - -def _normalize_anthropic_openai_images( - openai_messages: list[dict], is_vision: bool -) -> bool: - """Enforce the vision guard on translated Anthropic messages and - normalize any ``image_url`` parts with base64 data URLs to PNG. - - llama-server's stb_image only handles a few formats (JPEG/PNG/BMP/…); - Anthropic clients commonly send JPEG or WebP, and Claude Code sends - WebP. Re-encoding everything to PNG mirrors the behavior of - `_openai_messages_for_passthrough` / the GGUF branch of - `/v1/chat/completions` so the two endpoints agree. - - Mutates ``openai_messages`` in place. Returns ``True`` when any - image part was seen (so the caller can skip a second scan). Raises - HTTPException(400) when images are present but the active model is - not a vision model, or when an image cannot be decoded. - """ - from PIL import Image - - has_image = False - for msg in openai_messages: - content = msg.get("content") - if not isinstance(content, list): - continue - for part in content: - if part.get("type") != "image_url": - continue - - has_image = True - if not is_vision: - raise HTTPException( - status_code = 400, - detail = "Image provided but current GGUF model does not support vision.", - ) - - url = (part.get("image_url") or {}).get("url", "") - if not url.startswith("data:"): - # Remote URLs are forwarded as-is; llama-server will - # fetch (or fail) per its own support matrix. - continue - - try: - _, b64data = url.split(",", 1) - raw = base64.b64decode(b64data) - img = Image.open(io.BytesIO(raw)).convert("RGB") - buf = io.BytesIO() - img.save(buf, format = "PNG") - png_b64 = base64.b64encode(buf.getvalue()).decode("ascii") - except Exception as e: - raise HTTPException( - status_code = 400, - detail = f"Failed to process image: {e}", - ) - part["image_url"] = {"url": f"data:image/png;base64,{png_b64}"} - - return has_image - - -@router.post("/messages") -async def anthropic_messages( - payload: AnthropicMessagesRequest, - request: Request, - current_subject: str = Depends(get_current_subject), -): - """ - Anthropic-compatible Messages API endpoint. - - Translates Anthropic message format to internal OpenAI format, runs - through the existing agentic tool loop when tools are provided, and - returns responses in Anthropic Messages API format (streaming SSE or - non-streaming JSON). - """ - llama_backend = get_llama_cpp_backend() - if not llama_backend.is_loaded: - raise HTTPException( - status_code = 503, - detail = "No GGUF model loaded. Load a GGUF model first.", - ) - - model_name = getattr(llama_backend, "model_identifier", None) or payload.model - message_id = f"msg_{uuid.uuid4().hex[:24]}" - - # ── Translate Anthropic → OpenAI ────────────────────────── - openai_messages = anthropic_messages_to_openai( - [m.model_dump() for m in payload.messages], - payload.system, - ) - - # Enforce vision guard + re-encode embedded images to PNG so the - # Anthropic endpoint matches the behavior of /v1/chat/completions. - _has_image = _normalize_anthropic_openai_images( - openai_messages, llama_backend.is_vision - ) - - temperature = payload.temperature if payload.temperature is not None else 0.6 - top_p = payload.top_p if payload.top_p is not None else 0.95 - top_k = payload.top_k if payload.top_k is not None else 20 - min_p = payload.min_p if payload.min_p is not None else 0.01 - repetition_penalty = ( - payload.repetition_penalty if payload.repetition_penalty is not None else 1.0 - ) - presence_penalty = ( - payload.presence_penalty if payload.presence_penalty is not None else 0.0 - ) - stop = payload.stop_sequences or None - - # Translate Anthropic tool_choice to OpenAI format for forwarding to - # llama-server. Falls back to "auto" when unset or unrecognized, which - # matches the prior hardcoded behavior. - openai_tool_choice = anthropic_tool_choice_to_openai(payload.tool_choice) - if openai_tool_choice is None: - openai_tool_choice = "auto" - - cancel_event = threading.Event() - - # ── Tool routing ────────────────────────────────────────── - # Three paths: - # 1. enable_tools=true → server-side execution of built-in tools (Unsloth shorthand) - # 2. tools=[...] only → client-side pass-through (standard Anthropic behavior) - # 3. neither → plain chat - # Server-side agentic loop doesn't support multimodal input — matches - # the `not image_b64` gate in /v1/chat/completions. - server_tools = ( - _effective_enable_tools(payload) - and llama_backend.supports_tools - and not _has_image - ) - client_tools = ( - not server_tools - and payload.tools - and len(payload.tools) > 0 - and llama_backend.supports_tools - ) - - # ── Client-side pass-through path ───────────────────────── - if client_tools: - openai_tools = anthropic_tools_to_openai(payload.tools) - - if payload.stream: - return await _anthropic_passthrough_stream( - request, - cancel_event, - llama_backend, - openai_messages, - openai_tools, - temperature, - top_p, - top_k, - payload.max_tokens, - message_id, - model_name, - stop = stop, - min_p = min_p, - repetition_penalty = repetition_penalty, - presence_penalty = presence_penalty, - tool_choice = openai_tool_choice, - session_id = payload.session_id, - cancel_id = payload.cancel_id, - ) - return await _anthropic_passthrough_non_streaming( - llama_backend, - openai_messages, - openai_tools, - temperature, - top_p, - top_k, - payload.max_tokens, - message_id, - model_name, - stop = stop, - min_p = min_p, - repetition_penalty = repetition_penalty, - presence_penalty = presence_penalty, - tool_choice = openai_tool_choice, - ) - - if server_tools: - from core.inference.tools import ALL_TOOLS - - if payload.enabled_tools is not None: - openai_tools = [ - t for t in ALL_TOOLS if t["function"]["name"] in payload.enabled_tools - ] - else: - openai_tools = ALL_TOOLS - - # Build tool-use system prompt nudge (same logic as /chat/completions) - _tool_names = {t["function"]["name"] for t in openai_tools} - _has_web = "web_search" in _tool_names - _has_code = "python" in _tool_names or "terminal" in _tool_names - - _date_line = f"The current date is {_date.today().isoformat()}." - _model_size_b = _extract_model_size_b(model_name) - _is_small_model = _model_size_b is not None and _model_size_b < 9 - - if _is_small_model: - _web_tips = "Do not repeat the same search query." - else: - _web_tips = ( - "When you search and find a relevant URL in the results, " - "fetch its full content by calling web_search with the url parameter. " - "Do not repeat the same search query. If a search returns " - "no useful results, try rephrasing or fetching a result URL directly." - ) - _code_tips = ( - "Use code execution for math, calculations, data processing, " - "or to parse and analyze information from tool results." - ) - - if _has_web and _has_code: - _nudge = ( - _date_line + " " - "You have access to tools. When appropriate, prefer using " - "tools rather than answering from memory. " - + _web_tips - + " " - + _code_tips - ) - elif _has_code: - _nudge = ( - _date_line + " " - "You have access to tools. When appropriate, prefer using " - "code execution rather than answering from memory. " + _code_tips - ) - elif _has_web: - _nudge = ( - _date_line + " " - "You have access to tools. When appropriate, prefer using " - "web search for up-to-date or uncertain factual " - "information rather than answering from memory. " + _web_tips - ) - else: - _nudge = "" - - if _nudge: - _nudge += _TOOL_ACTION_NUDGE - # Inject into system prompt - if openai_messages and openai_messages[0].get("role") == "system": - openai_messages[0]["content"] = ( - openai_messages[0]["content"].rstrip() + "\n\n" + _nudge - ) - else: - openai_messages.insert(0, {"role": "system", "content": _nudge}) - - # Strip stale tool-call XML from conversation - for _msg in openai_messages: - if _msg.get("role") == "assistant" and isinstance(_msg.get("content"), str): - _msg["content"] = _TOOL_XML_RE.sub("", _msg["content"]).strip() - - def _run_tool_gen(): - return llama_backend.generate_chat_completion_with_tools( - messages = openai_messages, - tools = openai_tools, - temperature = temperature, - top_p = top_p, - top_k = top_k, - min_p = min_p, - repetition_penalty = repetition_penalty, - presence_penalty = presence_penalty, - max_tokens = payload.max_tokens, - stop = stop, - cancel_event = cancel_event, - max_tool_iterations = 25, - auto_heal_tool_calls = True, - tool_call_timeout = 300, - session_id = payload.session_id, - ) - - if payload.stream: - return await _anthropic_tool_stream( - request, - cancel_event, - _run_tool_gen, - message_id, - model_name, - ) - return await _anthropic_tool_non_streaming( - _run_tool_gen, - message_id, - model_name, - ) - - # ── No-tool path ────────────────────────────────────────── - def _run_plain_gen(): - return llama_backend.generate_chat_completion( - messages = openai_messages, - temperature = temperature, - top_p = top_p, - top_k = top_k, - min_p = min_p, - repetition_penalty = repetition_penalty, - presence_penalty = presence_penalty, - max_tokens = payload.max_tokens, - stop = stop, - cancel_event = cancel_event, - ) - - if payload.stream: - return await _anthropic_plain_stream( - request, - cancel_event, - _run_plain_gen, - message_id, - model_name, - ) - return await _anthropic_plain_non_streaming( - _run_plain_gen, - message_id, - model_name, - ) - - -async def _anthropic_tool_stream( - request, - cancel_event, - run_gen, - message_id, - model_name, -): - """Streaming response for the tool-calling path.""" - _sentinel = object() - - async def _stream(): - emitter = AnthropicStreamEmitter() - for line in emitter.start(message_id, model_name): - yield line - - gen = run_gen() - try: - while True: - if await request.is_disconnected(): - cancel_event.set() - return - event = await asyncio.to_thread(next, gen, _sentinel) - if event is _sentinel: - break - # Strip leaked tool-call XML from content events - if event.get("type") == "content": - event = dict(event) - event["text"] = _TOOL_XML_RE.sub("", event["text"]) - for line in emitter.feed(event): - yield line - except Exception as e: - logger.error("anthropic_messages stream error: %s", e) - - for line in emitter.finish("end_turn"): - yield line - - return StreamingResponse( - _stream(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -async def _anthropic_plain_stream( - request, - cancel_event, - run_gen, - message_id, - model_name, -): - """Streaming response for the no-tool path.""" - _sentinel = object() - - async def _stream(): - emitter = AnthropicStreamEmitter() - for line in emitter.start(message_id, model_name): - yield line - - gen = run_gen() - try: - while True: - if await request.is_disconnected(): - cancel_event.set() - return - cumulative = await asyncio.to_thread(next, gen, _sentinel) - if cumulative is _sentinel: - break - if isinstance(cumulative, dict): - if cumulative.get("type") == "metadata": - for line in emitter.feed(cumulative): - yield line - continue - # Plain generator yields cumulative text strings - for line in emitter.feed({"type": "content", "text": cumulative}): - yield line - except Exception as e: - logger.error("anthropic_messages stream error: %s", e) - - for line in emitter.finish("end_turn"): - yield line - - return StreamingResponse( - _stream(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -async def _anthropic_tool_non_streaming(run_gen, message_id, model_name): - """Non-streaming response for the tool-calling path. - - Builds ``content_blocks`` in generation order (text → tool_use → text → - tool_use → ...), mirroring the streaming emitter's behavior. Deltas - within a single synthesis turn are merged into the trailing text block; - tool_use blocks interrupt the text sequence and open a new text block on - the next content event. - - ``prev_text`` is reset on ``tool_end`` because - ``generate_chat_completion_with_tools`` yields cumulative content *per - turn* — the first content event of turn N+1 must diff against an empty - baseline, not against turn N's final length. - """ - content_blocks: list = [] - usage = {} - prev_text = "" - - for event in run_gen(): - etype = event.get("type", "") - if etype == "content": - # Strip leaked tool-call XML - clean = _TOOL_XML_RE.sub("", event["text"]) - new = clean[len(prev_text) :] - prev_text = clean - if new: - if content_blocks and isinstance( - content_blocks[-1], AnthropicResponseTextBlock - ): - content_blocks[-1].text += new - else: - content_blocks.append(AnthropicResponseTextBlock(text = new)) - elif etype == "tool_start": - content_blocks.append( - AnthropicResponseToolUseBlock( - id = event["tool_call_id"], - name = event["tool_name"], - input = event.get("arguments", {}), - ) - ) - elif etype == "tool_end": - prev_text = "" - elif etype == "metadata": - usage = event.get("usage", {}) - - resp = AnthropicMessagesResponse( - id = message_id, - model = model_name, - content = content_blocks, - stop_reason = "end_turn", - usage = AnthropicUsage( - input_tokens = usage.get("prompt_tokens", 0), - output_tokens = usage.get("completion_tokens", 0), - ), - ) - return JSONResponse(content = resp.model_dump()) - - -async def _anthropic_plain_non_streaming(run_gen, message_id, model_name): - """Non-streaming response for the no-tool path.""" - text_parts = [] - usage = {} - prev_text = "" - - for cumulative in run_gen(): - if isinstance(cumulative, dict): - if cumulative.get("type") == "metadata": - usage = cumulative.get("usage", {}) - continue - new = cumulative[len(prev_text) :] - prev_text = cumulative - if new: - text_parts.append(new) - - full_text = "".join(text_parts) - content_blocks = [] - if full_text: - content_blocks.append(AnthropicResponseTextBlock(text = full_text)) - - resp = AnthropicMessagesResponse( - id = message_id, - model = model_name, - content = content_blocks, - stop_reason = "end_turn", - usage = AnthropicUsage( - input_tokens = usage.get("prompt_tokens", 0), - output_tokens = usage.get("completion_tokens", 0), - ), - ) - return JSONResponse(content = resp.model_dump()) - - -# ===================================================================== -# Client-side tool pass-through (Anthropic-native tools field) -# ===================================================================== - - -def _build_passthrough_payload( - openai_messages, - openai_tools, - temperature, - top_p, - top_k, - max_tokens, - stream, - stop = None, - min_p = None, - repetition_penalty = None, - presence_penalty = None, - tool_choice = "auto", - response_format = None, - chat_template_kwargs = None, - backend_ctx = None, -): - body = { - "messages": openai_messages, - "tools": openai_tools, - "tool_choice": tool_choice, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, - "stream": stream, - } - if stream: - body["stream_options"] = {"include_usage": True} - body["max_tokens"] = ( - max_tokens - if max_tokens is not None - else (backend_ctx or _DEFAULT_MAX_TOKENS_FLOOR) - ) - body["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS - if stop: - body["stop"] = stop - if min_p is not None: - body["min_p"] = min_p - if repetition_penalty is not None: - # llama-server's field is "repeat_penalty", not "repetition_penalty" - body["repeat_penalty"] = repetition_penalty - if presence_penalty is not None: - body["presence_penalty"] = presence_penalty - if response_format is not None: - # llama-server applies a GBNF grammar derived from the JSON schema - # when response_format is present. Field is documented flat at the - # request root (tools/server/README.md), which is also what the - # OpenAI SDK produces by spreading extra_body into the body top. - body["response_format"] = response_format - if chat_template_kwargs is not None: - # Propagate reasoning / template overrides (e.g. enable_thinking) - # so llama-server renders the Jinja template in the mode the caller - # asked for instead of whatever default the model was loaded with. - body["chat_template_kwargs"] = chat_template_kwargs - return body - - -async def _anthropic_passthrough_stream( - request, - cancel_event, - llama_backend, - openai_messages, - openai_tools, - temperature, - top_p, - top_k, - max_tokens, - message_id, - model_name, - stop = None, - min_p = None, - repetition_penalty = None, - presence_penalty = None, - tool_choice = "auto", - session_id = None, - cancel_id = None, -): - """Streaming client-side pass-through: forward tools to llama-server and - translate its streaming response to Anthropic SSE without executing anything.""" - target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_passthrough_payload( - openai_messages, - openai_tools, - temperature, - top_p, - top_k, - max_tokens, - True, - stop = stop, - min_p = min_p, - repetition_penalty = repetition_penalty, - presence_penalty = presence_penalty, - tool_choice = tool_choice, - backend_ctx = llama_backend.context_length, - ) - - # cancel_id mirrors the OpenAI passthrough so a per-run cancel POST - # works without the caller having to know the local message_id. - _tracker = _TrackedCancel(cancel_event, cancel_id, session_id, message_id) - _tracker.__enter__() - - async def _stream(): - emitter = AnthropicPassthroughEmitter() - for line in emitter.start(message_id, model_name): - yield line - - # Manage the httpx client, response, AND the aiter_lines() async - # generator MANUALLY — no `async with`, no anonymous iterator. - # - # On Python 3.13 + httpcore 1.0.x, `async for raw_line in - # resp.aiter_lines():` creates an anonymous async generator. When - # the loop exits via `break` (or the generator is orphaned when a - # client disconnects mid-stream), Python's `async for` protocol - # does NOT auto-close the iterator the way a sync `for` loop - # would. The iterator remains reachable only from the current - # coroutine frame; once `_stream()` returns, the frame is GC'd - # and the iterator becomes unreachable. Python's asyncgen - # finalizer hook then runs its aclose() on a LATER GC pass in a - # DIFFERENT asyncio task, where httpcore's - # `HTTP11ConnectionByteStream.aclose()` enters - # `anyio.CancelScope.__exit__` with a mismatched task and prints - # `RuntimeError: Attempted to exit cancel scope in a different - # task` / `RuntimeError: async generator ignored GeneratorExit` - # as "Exception ignored in:" unraisable warnings. - # - # The fix: save `resp.aiter_lines()` as `lines_iter`, and in the - # finally block explicitly `await lines_iter.aclose()` BEFORE - # `resp.aclose()` / `client.aclose()`. This closes the iterator - # inside our own task's event loop, so the internal httpcore - # byte-stream is cleaned up before Python's asyncgen finalizer - # has anything orphaned to finalize. Each aclose is wrapped in - # `try: ... except Exception: pass` so anyio cleanup noise from - # nested aclose paths can't bubble out. - client = httpx.AsyncClient( - timeout = 600, - limits = httpx.Limits(max_keepalive_connections = 0), - ) - resp = None - lines_iter = None - cancel_watcher = None - try: - req = client.build_request("POST", target_url, json = body) - resp = await client.send(req, stream = True) - - # See _openai_passthrough_stream for rationale: aiter_lines() - # blocks during llama-server prefill, so the in-loop cancel - # check is unreachable until the first SSE chunk arrives. - # The watcher closes `resp` on cancel, raising in aiter_lines. - cancel_watcher = asyncio.create_task( - _await_cancel_then_close(cancel_event, resp) - ) - lines_iter = resp.aiter_lines() - async for raw_line in lines_iter: - if cancel_event.is_set(): - break - if await request.is_disconnected(): - cancel_event.set() - break - if not raw_line or not raw_line.startswith("data: "): - continue - data_str = raw_line[6:] - if data_str.strip() == "[DONE]": - break - try: - chunk = json.loads(data_str) - except json.JSONDecodeError: - continue - for line in emitter.feed_chunk(chunk): - yield line - except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError): - if not cancel_event.is_set(): - raise - except Exception as e: - logger.error("anthropic_messages passthrough stream error: %s", e) - finally: - if cancel_watcher is not None: - cancel_watcher.cancel() - try: - await cancel_watcher - except (asyncio.CancelledError, Exception): - pass - if lines_iter is not None: - try: - await lines_iter.aclose() - except Exception: - pass - if resp is not None: - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - _tracker.__exit__(None, None, None) - - for line in emitter.finish(): - yield line - - return StreamingResponse( - _stream(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -async def _anthropic_passthrough_non_streaming( - llama_backend, - openai_messages, - openai_tools, - temperature, - top_p, - top_k, - max_tokens, - message_id, - model_name, - stop = None, - min_p = None, - repetition_penalty = None, - presence_penalty = None, - tool_choice = "auto", -): - """Non-streaming client-side pass-through.""" - target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_passthrough_payload( - openai_messages, - openai_tools, - temperature, - top_p, - top_k, - max_tokens, - False, - stop = stop, - min_p = min_p, - repetition_penalty = repetition_penalty, - presence_penalty = presence_penalty, - tool_choice = tool_choice, - backend_ctx = llama_backend.context_length, - ) - - async with httpx.AsyncClient() as client: - resp = await client.post(target_url, json = body, timeout = 600) - - if resp.status_code != 200: - raise HTTPException( - status_code = resp.status_code, - detail = f"llama-server error: {resp.text[:500]}", - ) - - data = resp.json() - choice = (data.get("choices") or [{}])[0] - message = choice.get("message") or {} - finish_reason = choice.get("finish_reason") - - content_blocks = [] - text = message.get("content") or "" - if text: - text = _TOOL_XML_RE.sub("", text).strip() - if text: - content_blocks.append(AnthropicResponseTextBlock(text = text)) - - tool_calls = message.get("tool_calls") or [] - for tc in tool_calls: - fn = tc.get("function") or {} - try: - args = json.loads(fn.get("arguments", "{}")) - except json.JSONDecodeError: - args = {} - content_blocks.append( - AnthropicResponseToolUseBlock( - id = tc.get("id", ""), - name = fn.get("name", ""), - input = args, - ) - ) - - if tool_calls: - stop_reason = "tool_use" - elif finish_reason == "length": - stop_reason = "max_tokens" - else: - stop_reason = "end_turn" - - usage = data.get("usage") or {} - resp_obj = AnthropicMessagesResponse( - id = message_id, - model = model_name, - content = content_blocks, - stop_reason = stop_reason, - usage = AnthropicUsage( - input_tokens = usage.get("prompt_tokens", 0), - output_tokens = usage.get("completion_tokens", 0), - ), - ) - return JSONResponse(content = resp_obj.model_dump()) - - -# ===================================================================== -# Client-side tool pass-through (OpenAI-native /v1/chat/completions) -# ===================================================================== - - -def _openai_messages_for_passthrough(payload) -> list[dict]: - """Build OpenAI-format message dicts for the /v1/chat/completions - passthrough path. - - Messages from ``payload.messages`` are dumped through Pydantic (dropping - unset optional fields) so they are already in standard OpenAI format - — including ``role="tool"`` tool-result messages and assistant messages - that carry structured ``tool_calls``. Content-parts images already in - the message list are left untouched. - - When a client uses Studio's legacy ``image_base64`` top-level field, the - image is re-encoded to PNG (llama-server's stb_image has limited format - support) and spliced into the last user message as an OpenAI - ``image_url`` content part so vision + function-calling requests work - transparently. - """ - messages = [m.model_dump(exclude_none = True) for m in payload.messages] - - if not payload.image_base64: - return messages - - try: - import base64 as _b64 - from io import BytesIO as _BytesIO - from PIL import Image as _Image - - raw = _b64.b64decode(payload.image_base64) - img = _Image.open(_BytesIO(raw)).convert("RGB") - buf = _BytesIO() - img.save(buf, format = "PNG") - png_b64 = _b64.b64encode(buf.getvalue()).decode("ascii") - except Exception as e: - raise HTTPException( - status_code = 400, - detail = f"Failed to process image: {e}", - ) - - data_url = f"data:image/png;base64,{png_b64}" - image_part = {"type": "image_url", "image_url": {"url": data_url}} - - for msg in reversed(messages): - if msg.get("role") != "user": - continue - existing = msg.get("content") - if isinstance(existing, str): - msg["content"] = [{"type": "text", "text": existing}, image_part] - elif isinstance(existing, list): - existing.append(image_part) - else: - msg["content"] = [image_part] - break - else: - messages.append({"role": "user", "content": [image_part]}) - - return messages - - -def _extract_response_format(payload): - """Return the ``response_format`` field on an incoming ChatCompletionRequest - (or None). The model is declared with ``extra="allow"`` so pydantic stashes - unknown top-level fields in ``model_extra``; OpenAI-SDK clients spread - ``extra_body`` into the request body top level, which is where guided- - decoding recipes park their JSON-schema response_format. - """ - extra = getattr(payload, "model_extra", None) - if not isinstance(extra, dict): - return None - rf = extra.get("response_format") - return rf if isinstance(rf, dict) else None - - -def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict: - """Assemble the llama-server request body from a ChatCompletionRequest. - - Only explicitly-known OpenAI / llama-server fields are forwarded so that - Studio-specific extensions (``enable_tools``, ``enabled_tools``, - ``session_id``, ...) never leak to the backend. - """ - messages = _openai_messages_for_passthrough(payload) - tool_choice = payload.tool_choice if payload.tool_choice is not None else "auto" - # When the caller asked for a specific reasoning mode, forward it to - # llama-server via chat_template_kwargs so the Jinja template renders - # with (or without) the reasoning preamble. - tpl_kwargs = None - if payload.enable_thinking is not None: - tpl_kwargs = {"enable_thinking": bool(payload.enable_thinking)} - return _build_passthrough_payload( - messages, - payload.tools, - payload.temperature, - payload.top_p, - payload.top_k, - payload.max_tokens, - payload.stream, - stop = payload.stop, - min_p = payload.min_p, - repetition_penalty = payload.repetition_penalty, - presence_penalty = payload.presence_penalty, - tool_choice = tool_choice, - response_format = _extract_response_format(payload), - chat_template_kwargs = tpl_kwargs, - backend_ctx = backend_ctx, - ) - - -async def _openai_passthrough_stream( - request, - cancel_event, - llama_backend, - payload, - model_name, - completion_id, -): - """Streaming client-side pass-through for /v1/chat/completions. - - Forwards the client's OpenAI function-calling request to llama-server and - relays the SSE stream back verbatim. This preserves llama-server's - native response ``id``, ``finish_reason`` (including ``"tool_calls"``), - ``delta.tool_calls``, and the trailing ``usage`` chunk so the client - observes a standard OpenAI response. - """ - target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_openai_passthrough_body( - payload, backend_ctx = llama_backend.context_length - ) - - _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) - _tracker = _TrackedCancel(cancel_event, *_cancel_keys) - _tracker.__enter__() - - # Outer guard: asyncio.CancelledError at `await client.send(...)` is - # a BaseException that bypasses `except httpx.RequestError`; without - # this the tracker leaks. The generator's finally only runs once - # iteration starts. - try: - # Dispatch BEFORE returning StreamingResponse so transport errors - # and non-200 upstream statuses surface as real HTTP errors -- - # OpenAI SDKs rely on status codes to raise APIError/BadRequestError. - client = httpx.AsyncClient( - timeout = 600, - limits = httpx.Limits(max_keepalive_connections = 0), - ) - resp = None - try: - req = client.build_request("POST", target_url, json = body) - resp = await client.send(req, stream = True) - except httpx.RequestError as e: - # llama-server subprocess crashed / still starting / unreachable. - logger.error("openai passthrough stream: upstream unreachable: %s", e) - if resp is not None: - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - raise HTTPException( - status_code = 502, - detail = _friendly_error(e), - ) - - if resp.status_code != 200: - err_bytes = await resp.aread() - err_text = err_bytes.decode("utf-8", errors = "replace") - logger.error( - "openai passthrough upstream error: status=%s body=%s", - resp.status_code, - err_text[:500], - ) - upstream_status = resp.status_code - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - raise HTTPException( - status_code = upstream_status, - detail = f"llama-server error: {err_text[:500]}", - ) - - async def _stream(): - # Same httpx lifecycle pattern as _anthropic_passthrough_stream: - # save resp.aiter_lines() so the finally block can aclose() it - # on our task. See that function for full rationale. - lines_iter = None - # During llama-server prefill, `aiter_lines()` blocks until the - # first SSE chunk arrives. The in-loop `cancel_event` check - # cannot fire until then, which is the exact proxy/Colab - # scenario the cancel POST is meant to recover from. Run a - # tiny watcher that closes `resp` as soon as cancel fires, - # unblocking the iterator with a RemoteProtocolError caught - # in the except clause below. - cancel_watcher = asyncio.create_task( - _await_cancel_then_close(cancel_event, resp) - ) - try: - lines_iter = resp.aiter_lines() - async for raw_line in lines_iter: - if cancel_event.is_set(): - break - if await request.is_disconnected(): - cancel_event.set() - break - if not raw_line: - continue - if not raw_line.startswith("data: "): - continue - # Relay verbatim to preserve llama-server's native id, - # finish_reason, delta.tool_calls, and usage chunks. - yield raw_line + "\n\n" - if raw_line[6:].strip() == "[DONE]": - break - except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError): - # Watcher closed resp on cancel. Emit nothing extra; the - # client either initiated the cancel or already disconnected. - if not cancel_event.is_set(): - raise - except Exception as e: - # 200 headers are already flushed; errors must be in the SSE body. - logger.error("openai passthrough stream error: %s", e) - err = { - "error": { - "message": _friendly_error(e), - "type": "server_error", - }, - } - yield f"data: {json.dumps(err)}\n\n" - finally: - cancel_watcher.cancel() - try: - await cancel_watcher - except (asyncio.CancelledError, Exception): - pass - if lines_iter is not None: - try: - await lines_iter.aclose() - except Exception: - pass - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - _tracker.__exit__(None, None, None) - - return StreamingResponse( - _stream(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - except BaseException: - _tracker.__exit__(None, None, None) - raise - - -async def _openai_passthrough_non_streaming( - llama_backend, - payload, - model_name, -): - """Non-streaming client-side pass-through for /v1/chat/completions. - - Returns llama-server's JSON response verbatim (via JSONResponse) so the - client sees the native response ``id``, ``finish_reason`` (including - ``"tool_calls"``), structured ``tool_calls``, and accurate ``usage`` - token counts. - """ - target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_openai_passthrough_body( - payload, backend_ctx = llama_backend.context_length - ) - - try: - async with httpx.AsyncClient() as client: - resp = await client.post(target_url, json = body, timeout = 600) - except httpx.RequestError as e: - # llama-server subprocess crashed / still starting / unreachable. - # Surface the same friendly message the sync chat path emits so - # operators don't see a bare 500 with no diagnostic. - logger.error("openai passthrough non-streaming: upstream unreachable: %s", e) - raise HTTPException( - status_code = 502, - detail = _friendly_error(e), - ) - - if resp.status_code != 200: - raise HTTPException( - status_code = resp.status_code, - detail = f"llama-server error: {resp.text[:500]}", - ) - - # Guided-decoding fence wrap. llama-server returns raw JSON that matches - # the schema (no surrounding markdown) because the GBNF grammar only - # emits the JSON object itself. data_designer's llm-structured parser - # looks for a ```json ... ``` markdown fence and discards unfenced - # output, which collapses a 100%-valid guided-decoding run to 0/N. - # Wrap each choice's content in the expected fence when the caller - # asked for guided decoding, leaving already-fenced content alone. - if _extract_response_format(payload) is not None: - try: - data = resp.json() - changed = False - for choice in data.get("choices", []): - if not isinstance(choice, dict): - continue - msg = choice.get("message") - if not isinstance(msg, dict): - continue - content = msg.get("content") - if not isinstance(content, str): - continue - stripped = content.strip() - if not stripped or stripped.startswith("```"): - continue - msg["content"] = f"```json\n{stripped}\n```" - changed = True - if changed: - return JSONResponse(content = data) - except Exception as exc: - # Wrap is best-effort; fall through to the verbatim body if - # the response is not JSON-shaped or the structure is unusual. - logger.warning( - "response_format fence wrap skipped: %s", - exc, - ) - - # Pass the upstream body through as raw bytes — skips a redundant - # parse+re-serialize round-trip and keeps the response truly - # verbatim (matches the docstring). Status is guaranteed 200 by - # the check above. - return Response(content = resp.content, media_type = "application/json") diff --git a/studio/backend/routes/models.py b/studio/backend/routes/models.py index d01e94b0c9..3f361ca5eb 100644 --- a/studio/backend/routes/models.py +++ b/studio/backend/routes/models.py @@ -5,12 +5,8 @@ Model Management API routes """ -import hashlib -import json import os -import shutil import sys -import uuid from pathlib import Path from fastapi import APIRouter, Body, Depends, HTTPException, Query from typing import List, Optional @@ -36,9 +32,8 @@ def _is_valid_repo_id(repo_id: str) -> bool: # Import backend functions try: from utils.models import ( - scan_trained_models, + scan_trained_loras, scan_exported_models, - get_base_model_from_checkpoint, load_model_defaults, get_base_model_from_lora, is_vision_model, @@ -67,9 +62,8 @@ def _is_valid_repo_id(repo_id: str) -> bool: if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from utils.models import ( - scan_trained_models, + scan_trained_loras, scan_exported_models, - get_base_model_from_checkpoint, load_model_defaults, get_base_model_from_lora, is_vision_model, @@ -105,8 +99,6 @@ def _is_valid_repo_id(repo_id: str) -> bool: ModelListResponse, ) from models.models import ( - BrowseEntry, - BrowseFoldersResponse, GgufVariantDetail, GgufVariantsResponse, ModelType, @@ -415,267 +407,6 @@ def _scan_lmstudio_dir(lm_dir: Path) -> List[LocalModelInfo]: return found -def _ollama_links_dir(ollama_dir: Path) -> Optional[Path]: - """Return a writable directory for Ollama ``.gguf`` symlinks. - - Prefers ``/.studio_links/`` so the links sit next to the - blobs they point at. Falls back to a per-ollama-dir namespace under - Studio's own cache when the models directory is read-only (common - for system installs under ``/usr/share/ollama`` or ``/var/lib/ollama``) - so we still surface Ollama models in those environments. - """ - from utils.paths.storage_roots import cache_root - - primary = ollama_dir / ".studio_links" - try: - primary.mkdir(exist_ok = True) - return primary - except OSError as e: - logger.debug( - "Ollama dir %s not writable for .studio_links (%s); " - "falling back to Studio cache", - ollama_dir, - e, - ) - - # Fallback: namespace by a hash of the ollama_dir so two different - # Ollama roots don't collide. This is a cache path, not a security - # boundary. - try: - digest = hashlib.sha256(str(ollama_dir.resolve()).encode()).hexdigest()[:12] - except OSError: - digest = "default" - fallback = cache_root() / "ollama_links" / digest - try: - fallback.mkdir(parents = True, exist_ok = True) - return fallback - except OSError as e: - logger.warning( - "Could not create Ollama symlink cache at %s: %s", - fallback, - e, - ) - return None - - -def _scan_ollama_dir( - ollama_dir: Path, limit: Optional[int] = None -) -> List[LocalModelInfo]: - """Scan an Ollama models directory for downloaded models. - - Ollama stores models in a content-addressable layout:: - - /manifests//// - /blobs/sha256-... - - The default host is ``registry.ollama.ai`` with namespace - ``library`` (official models), but users can pull from custom - namespaces (``mradermacher/llama3``) or entirely different hosts - (``hf.co/org/repo:tag``). We iterate all manifest files via - ``rglob`` so every layout depth is discovered. - - Each manifest is JSON with a ``layers`` array. The layer with - ``mediaType == "application/vnd.ollama.image.model"`` contains the - GGUF weights. Vision models also have a projector layer - (``application/vnd.ollama.image.projector``). We read the config - layer to extract family/size info. - - Since Ollama blobs lack a ``.gguf`` extension (which the GGUF - loading pipeline requires), we create ``.gguf``-named links - pointing at the blobs so the existing ``detect_gguf_model`` and - ``llama-server -m`` paths work unchanged. Each model gets its - own subdirectory under the links dir (keyed by a short hash of - the manifest path) so that ``detect_mmproj_file`` only sees the - projector for *that* model. Links are created as symlinks when - possible, falling back to hardlinks (Windows without Developer - Mode) as a last resort. The link dir lives under - ``/.studio_links/`` when writable, otherwise under - Studio's own cache directory. - """ - manifests_root = ollama_dir / "manifests" - if not manifests_root.is_dir(): - return [] - - found: List[LocalModelInfo] = [] - blobs_dir = ollama_dir / "blobs" - links_root = _ollama_links_dir(ollama_dir) - if links_root is None: - logger.warning( - "Skipping Ollama scan for %s: no writable location for .gguf links", - ollama_dir, - ) - return [] - - def _make_link(link_dir: Path, link_name: str, target: Path) -> Optional[str]: - """Create a .gguf-named link to an Ollama blob. - - Tries symlink first, then hardlink (works on Windows without - Developer Mode when target is on the same filesystem). Skips - the model if neither works -- a full file copy of a multi-GB - GGUF inside a synchronous API request would block the backend. - - Idempotent: skips recreation when a valid link already exists. - """ - link_dir.mkdir(parents = True, exist_ok = True) - link_path = link_dir / link_name - resolved = target.resolve() - - # Skip if the link already points at the exact same blob. - # Only use samefile -- size-based checks can reuse stale links - # after `ollama pull` updates a tag to a same-sized blob. - try: - if link_path.exists() and os.path.samefile(str(link_path), str(resolved)): - return str(link_path) - except OSError as e: - logger.debug("Error checking existing link %s: %s", link_path, e) - - tmp_path = link_dir / f".{link_name}.tmp-{uuid.uuid4().hex[:8]}" - try: - if tmp_path.is_symlink() or tmp_path.exists(): - tmp_path.unlink() - try: - tmp_path.symlink_to(resolved) - except OSError: - try: - os.link(str(resolved), str(tmp_path)) - except OSError: - logger.warning( - "Could not create link for Ollama blob %s " - "(symlinks and hardlinks both failed). " - "Skipping model to avoid blocking the API.", - target, - ) - return None - os.replace(str(tmp_path), str(link_path)) - return str(link_path) - except OSError as e: - logger.debug("Could not create Ollama link %s: %s", link_path, e) - try: - if tmp_path.is_symlink() or tmp_path.exists(): - tmp_path.unlink() - except OSError as cleanup_err: - logger.debug( - "Could not clean up tmp path %s: %s", tmp_path, cleanup_err - ) - return None - - try: - for tag_file in manifests_root.rglob("*"): - if not tag_file.is_file(): - continue - - rel = tag_file.relative_to(manifests_root) - parts = rel.parts - if len(parts) < 3: - continue - - host = parts[0] - repo_parts = list(parts[1:-1]) - tag = parts[-1] - - if ( - host == "registry.ollama.ai" - and repo_parts - and repo_parts[0] == "library" - ): - repo_name = "/".join(repo_parts[1:]) - elif host == "registry.ollama.ai": - repo_name = "/".join(repo_parts) - else: - repo_name = "/".join([host] + repo_parts) - - if not repo_name: - continue - - display = f"{repo_name}:{tag}" - - manifest_key = rel.as_posix() - stem_hash = hashlib.sha256(manifest_key.encode()).hexdigest()[:10] - - try: - manifest = json.loads(tag_file.read_text()) - except (json.JSONDecodeError, OSError) as e: - logger.debug( - "Skipping unreadable/invalid Ollama manifest %s: %s", - tag_file, - e, - ) - continue - - config_digest = manifest.get("config", {}).get("digest", "") - model_type = "" - file_type = "" - if config_digest and blobs_dir.is_dir(): - config_blob = blobs_dir / config_digest.replace(":", "-") - if config_blob.is_file(): - try: - cfg = json.loads(config_blob.read_text()) - model_type = cfg.get("model_type", "") - file_type = cfg.get("file_type", "") - except (json.JSONDecodeError, OSError) as e: - logger.debug( - "Could not parse Ollama config blob %s: %s", - config_blob, - e, - ) - - model_link_dir = links_root / stem_hash - - gguf_link_path: Optional[str] = None - quant = f"-{file_type}" if file_type else "" - safe_name = repo_name.replace("/", "-") - for layer in manifest.get("layers") or []: - media = layer.get("mediaType", "") - digest = layer.get("digest", "") - if not digest: - continue - - if media == "application/vnd.ollama.image.model": - candidate = blobs_dir / digest.replace(":", "-") - if candidate.is_file(): - link_name = f"{safe_name}-{tag}{quant}.gguf" - gguf_link_path = _make_link( - model_link_dir, link_name, candidate - ) - - elif media == "application/vnd.ollama.image.projector": - candidate = blobs_dir / digest.replace(":", "-") - if candidate.is_file(): - mmproj_name = f"{safe_name}-{tag}-mmproj.gguf" - _make_link(model_link_dir, mmproj_name, candidate) - - if not gguf_link_path: - continue - - suffix = "" - if model_type: - suffix += f" ({model_type}" - if file_type: - suffix += f" {file_type}" - suffix += ")" - - try: - updated_at = tag_file.stat().st_mtime - except OSError: - updated_at = None - - found.append( - LocalModelInfo( - id = gguf_link_path, - model_id = f"ollama/{repo_name}:{tag}", - display_name = display + suffix, - path = gguf_link_path, - source = "custom", - updated_at = updated_at, - ), - ) - if limit is not None and len(found) >= limit: - return found - except OSError as e: - logger.warning("Error scanning Ollama directory %s: %s", ollama_dir, e) - return found - - @router.get("/local", response_model = LocalModelListResponse) async def list_local_models( models_dir: str = Query( @@ -758,27 +489,11 @@ async def list_local_models( for folder in custom_folders: folder_path = Path(folder["path"]) try: - # Ollama scanner creates .studio_links/ with .gguf symlinks. - # Filter those from the generic scanners to avoid duplicates - # and leaking internal paths into the UI. - _generic = [ - m - for m in ( - _scan_models_dir(folder_path, limit = _MAX_MODELS_PER_FOLDER) - + _scan_hf_cache(folder_path) - + _scan_lmstudio_dir(folder_path) - ) - if not any( - p in (".studio_links", "ollama_links") - for p in Path(m.path).parts - ) - ] - custom_models = _generic - if len(custom_models) < _MAX_MODELS_PER_FOLDER: - custom_models += _scan_ollama_dir( - folder_path, - limit = _MAX_MODELS_PER_FOLDER - len(custom_models), - ) + custom_models = ( + _scan_models_dir(folder_path, limit = _MAX_MODELS_PER_FOLDER) + + _scan_hf_cache(folder_path) + + _scan_lmstudio_dir(folder_path) + )[:_MAX_MODELS_PER_FOLDER] except OSError as e: logger.warning("Skipping unreadable scan folder %s: %s", folder_path, e) continue @@ -856,580 +571,6 @@ async def remove_scan_folder_endpoint( return {"ok": True} -@router.get("/recommended-folders") -async def get_recommended_folders( - current_subject: str = Depends(get_current_subject), -): - """Return well-known model directories that exist on this machine. - - Lightweight alternative to ``browse-folders`` for showing quick-pick - chips without the overhead of enumerating a directory tree. Returns - paths that actually exist on disk (HF cache, LM Studio, Ollama, - ``~/models``, etc.) so the frontend can offer them as one-click - "Recommended" shortcuts in the Custom Folders section. - """ - from utils.paths.storage_roots import lmstudio_model_dirs - - folders: list[str] = [] - seen: set[str] = set() - - def _add(p: Optional[Path]) -> None: - if p is None: - return - try: - resolved = str(p.resolve()) - except OSError: - return - if resolved in seen: - return - if Path(resolved).is_dir() and os.access(resolved, os.R_OK | os.X_OK): - seen.add(resolved) - folders.append(resolved) - - # LM Studio model directories - try: - for p in lmstudio_model_dirs(): - _add(p) - except Exception as e: - logger.warning("Failed to scan for LM Studio model directories: %s", e) - - # Ollama model directories - ollama_env = os.environ.get("OLLAMA_MODELS") - if ollama_env: - _add(Path(ollama_env).expanduser()) - for candidate in ( - Path.home() / ".ollama" / "models", - Path("/usr/share/ollama/.ollama/models"), - Path("/var/lib/ollama/.ollama/models"), - ): - _add(candidate) - - return {"folders": folders} - - -# Heuristic ceiling on how many children to stat when checking whether a -# directory "looks like" it contains models. Keeps the browser snappy -# even when a directory has thousands of unrelated entries. -_BROWSE_MODEL_HINT_PROBE = 64 -# Hard cap on how many subdirectory entries we send back. Pointing the -# browser at something like ``/usr/lib`` or ``/proc`` must not stat-storm -# the process or send tens of thousands of rows to the client. -_BROWSE_ENTRY_CAP = 2000 - - -def _count_model_files(directory: Path, cap: int = 200) -> int: - """Count GGUF/safetensors files immediately inside *directory*. - Used to surface a count-hint on the response so the UI can tell - users that a leaf directory (no subdirs, only weights) is a valid - "Use this folder" target. - - Bounded by *visited entries*, not by *match count*: in directories - with many non-model files (or many subdirectories) the scan still - stops after ``cap`` entries so a UI hint never costs more than a - bounded directory walk. - """ - n = 0 - visited = 0 - try: - for f in directory.iterdir(): - visited += 1 - if visited > cap: - break - try: - if f.is_file(): - low = f.name.lower() - if low.endswith((".gguf", ".safetensors")): - n += 1 - except OSError: - continue - except PermissionError as e: - logger.debug("browse-folders: permission denied counting %s: %s", directory, e) - return 0 - except OSError as e: - logger.debug("browse-folders: OS error counting %s: %s", directory, e) - return 0 - return n - - -def _has_direct_model_signal(directory: Path) -> bool: - """Return True if *directory* has an immediate child that signals - it holds a model: a GGUF/safetensors/config.json file, or a - `models--*` subdir (HF hub cache). Bounded by - ``_BROWSE_MODEL_HINT_PROBE`` to stay fast.""" - try: - it = directory.iterdir() - except OSError: - return False - try: - for i, child in enumerate(it): - if i >= _BROWSE_MODEL_HINT_PROBE: - break - try: - name = child.name - if child.is_file(): - low = name.lower() - if low.endswith((".gguf", ".safetensors")): - return True - if low in ("config.json", "adapter_config.json"): - return True - elif child.is_dir() and name.startswith("models--"): - return True - except OSError: - continue - except OSError: - return False - return False - - -def _looks_like_model_dir(directory: Path) -> bool: - """Bounded heuristic used by the folder browser to flag directories - worth exploring. False negatives are fine; the real scanner is - authoritative. - - Three signals, cheapest first: - - 1. Directory name itself: ``models--*`` is the HuggingFace hub cache - layout (``blobs``/``refs``/``snapshots`` children wouldn't match - the file-level probes below). - 2. An immediate child is a weight file or config (handled by - :func:`_has_direct_model_signal`). - 3. A grandchild has a direct signal -- this catches the - ``publisher/model/weights.gguf`` layout used by LM Studio and - Ollama. We probe at most the first - ``_BROWSE_MODEL_HINT_PROBE`` child directories, each of which is - checked with a bounded :func:`_has_direct_model_signal` call, - so the total cost stays O(PROBE^2) worst-case. - """ - if directory.name.startswith("models--"): - return True - if _has_direct_model_signal(directory): - return True - # Grandchild probe: LM Studio / Ollama publisher/model layout. - try: - it = directory.iterdir() - except OSError: - return False - try: - for i, child in enumerate(it): - if i >= _BROWSE_MODEL_HINT_PROBE: - break - try: - if not child.is_dir(): - continue - except OSError: - continue - # Fast name check first - if child.name.startswith("models--"): - return True - if _has_direct_model_signal(child): - return True - except OSError: - return False - return False - - -def _build_browse_allowlist() -> list[Path]: - """Return the list of root directories the folder browser is allowed - to walk. The same list is used to seed the sidebar suggestion chips, - so chip targets are always reachable. - - Roots include the current user's HOME, the resolved HF cache dirs, - Studio's own outputs/exports/studio root, registered scan folders, - and well-known third-party local-LLM dirs (LM Studio, Ollama, - `~/models`). Each is added only if it currently resolves to a real - directory, so we never produce a "dead" sandbox boundary the user - can't navigate into. - """ - from utils.paths import ( - hf_default_cache_dir, - legacy_hf_cache_dir, - well_known_model_dirs, - ) - from storage.studio_db import list_scan_folders - - candidates: list[Path] = [] - - def _add(p: Optional[Path]) -> None: - if p is None: - return - try: - resolved = p.resolve() - except OSError: - return - if resolved.is_dir(): - candidates.append(resolved) - - _add(Path.home()) - _add(_resolve_hf_cache_dir()) - try: - _add(hf_default_cache_dir()) - except Exception: # noqa: BLE001 -- best-effort - pass - try: - _add(legacy_hf_cache_dir()) - except Exception: # noqa: BLE001 -- best-effort - pass - try: - from utils.paths import ( - exports_root, - outputs_root, - studio_root, - ) - - _add(studio_root()) - _add(outputs_root()) - _add(exports_root()) - except Exception as exc: # noqa: BLE001 -- best-effort - logger.debug("browse-folders: studio roots unavailable: %s", exc) - try: - for folder in list_scan_folders(): - p = folder.get("path") - if p: - _add(Path(p)) - except Exception as exc: # noqa: BLE001 -- best-effort - logger.debug("browse-folders: could not load scan folders: %s", exc) - try: - for p in well_known_model_dirs(): - _add(p) - except Exception as exc: # noqa: BLE001 -- best-effort - logger.debug("browse-folders: well-known dirs unavailable: %s", exc) - - # Dedupe while preserving order. - seen: set[str] = set() - deduped: list[Path] = [] - for p in candidates: - key = str(p) - if key in seen: - continue - seen.add(key) - deduped.append(p) - return deduped - - -def _is_path_inside_allowlist(target: Path, allowed_roots: list[Path]) -> bool: - """Return True if *target* equals or is a descendant of any allowed - root. The comparison uses ``os.path.realpath`` so symlinks cannot be - used to escape the sandbox. - """ - try: - target_real = os.path.realpath(str(target)) - except OSError: - return False - for root in allowed_roots: - try: - root_real = os.path.realpath(str(root)) - except OSError: - continue - if target_real == root_real or target_real.startswith(root_real + os.sep): - return True - return False - - -def _normalize_browse_request_path(path: Optional[str]) -> str: - """Normalize the browse request path lexically, without touching the FS.""" - if path is None or not path.strip(): - return os.path.normpath(str(Path.home())) - - expanded = os.path.expanduser(path.strip()) - if not os.path.isabs(expanded): - expanded = os.path.join(str(Path.cwd()), expanded) - return os.path.normpath(expanded) - - -def _browse_relative_parts(requested_path: str, root: Path) -> Optional[list[str]]: - """Return validated relative path components under ``root``.""" - root_text = os.path.normpath(str(root)) - try: - rel_text = os.path.relpath(requested_path, root_text) - except ValueError: - return None - - if rel_text == ".": - return [] - if rel_text == ".." or rel_text.startswith(f"..{os.sep}"): - return None - - parts = [part for part in rel_text.split(os.sep) if part not in ("", ".")] - altsep = os.altsep - for part in parts: - if part == ".." or os.sep in part or (altsep and altsep in part): - return None - return parts - - -def _match_browse_child(current: Path, name: str) -> Optional[Path]: - """Return the immediate child named ``name`` under ``current``.""" - try: - for child in current.iterdir(): - if child.name == name: - return child - except PermissionError: - raise HTTPException( - status_code = 403, - detail = f"Permission denied reading {current}", - ) from None - except OSError as exc: - raise HTTPException( - status_code = 500, - detail = f"Could not read {current}: {exc}", - ) from exc - return None - - -def _resolve_browse_target(path: Optional[str], allowed_roots: list[Path]) -> Path: - """Resolve a requested browse path by walking from trusted allowlist roots.""" - requested_path = _normalize_browse_request_path(path) - resolved_roots: list[Path] = [] - seen_roots: set[str] = set() - for root in sorted(allowed_roots, key = lambda p: len(str(p)), reverse = True): - try: - resolved = root.resolve() - except OSError: - continue - key = str(resolved) - if key in seen_roots: - continue - seen_roots.add(key) - resolved_roots.append(resolved) - - for root in resolved_roots: - parts = _browse_relative_parts(requested_path, root) - if parts is None: - continue - - current = root - for part in parts: - child = _match_browse_child(current, part) - if child is None: - raise HTTPException( - status_code = 404, - detail = f"Path does not exist: {requested_path}", - ) - try: - resolved_child = child.resolve() - except OSError as exc: - raise HTTPException( - status_code = 400, - detail = f"Invalid path: {exc}", - ) from exc - if not _is_path_inside_allowlist(resolved_child, resolved_roots): - raise HTTPException( - status_code = 403, - detail = ( - "Path is not in the browseable allowlist. Register it via " - "POST /api/models/scan-folders first, or pick a directory " - "under your home folder." - ), - ) - current = resolved_child - - if not current.is_dir(): - raise HTTPException( - status_code = 400, - detail = f"Not a directory: {current}", - ) - return current - - raise HTTPException( - status_code = 403, - detail = ( - "Path is not in the browseable allowlist. Register it via " - "POST /api/models/scan-folders first, or pick a directory " - "under your home folder." - ), - ) - - -@router.get("/browse-folders", response_model = BrowseFoldersResponse) -async def browse_folders( - path: Optional[str] = Query( - None, - description = ( - "Directory to list. If omitted, defaults to the current user's " - "home directory. Tilde (`~`) and relative paths are expanded. " - "Must resolve inside the allowlist of browseable roots (HOME, " - "HF cache, Studio dirs, registered scan folders, well-known " - "model dirs)." - ), - ), - show_hidden: bool = Query( - False, - description = "Include entries whose name starts with a dot", - ), - current_subject: str = Depends(get_current_subject), -): - """ - List immediate subdirectories of *path* for the Custom Folders picker. - - The frontend uses this to render a modal folder browser without needing - a native OS dialog (Studio is served over HTTP, so the browser can't - reveal absolute paths on the host). The endpoint is read-only and does - not create, move, or delete anything. It simply enumerates visible - subdirectories so the user can click their way to a folder and hand - the resulting string back to POST `/api/models/scan-folders`. - - Sandbox: requests are bounded to the allowlist returned by - :func:`_build_browse_allowlist` (HOME, HF cache, Studio dirs, - registered scan folders, well-known model dirs). Paths outside the - allowlist return 403 so users cannot probe ``/etc``, ``/proc``, - ``/root`` (when not HOME), or other sensitive system locations - even if the server process can read them. Symlinks are resolved - via ``os.path.realpath`` before the check, so symlink traversal - cannot escape the sandbox either. - - Sorting: directories that look like they hold models come first, then - plain directories, then hidden entries (if `show_hidden=true`). - """ - from utils.paths import hf_default_cache_dir, well_known_model_dirs - from storage.studio_db import list_scan_folders - - # Build the allowlist once -- both the sandbox check below and the - # suggestion chips use the same set, so chips are always navigable. - allowed_roots = _build_browse_allowlist() - - try: - target = _resolve_browse_target(path, allowed_roots) - except HTTPException: - requested_path = _normalize_browse_request_path(path) - if path is not None and path.strip(): - logger.warning( - "browse-folders: rejected path %r (normalized=%s)", - path, - requested_path, - ) - raise - - # Enumerate immediate subdirectories with a bounded cap so a stray - # query against ``/usr/lib`` or ``/proc`` can't stat-storm the process. - entries: list[BrowseEntry] = [] - truncated = False - visited = 0 - try: - it = target.iterdir() - except PermissionError: - raise HTTPException( - status_code = 403, - detail = f"Permission denied reading {target}", - ) - except OSError as exc: - raise HTTPException( - status_code = 500, - detail = f"Could not read {target}: {exc}", - ) - - try: - for child in it: - # Bound by *visited entries*, not by *appended entries*: in - # directories full of files (or hidden subdirs when - # ``show_hidden=False``) the cap on ``len(entries)`` would - # never trigger and we'd still stat every child. Counting - # visits keeps the worst-case work to ``_BROWSE_ENTRY_CAP`` - # iterdir/is_dir calls regardless of how many of them - # survive the filters below. - visited += 1 - if visited > _BROWSE_ENTRY_CAP: - truncated = True - break - try: - if not child.is_dir(): - continue - except OSError: - continue - name = child.name - is_hidden = name.startswith(".") - if is_hidden and not show_hidden: - continue - entries.append( - BrowseEntry( - name = name, - has_models = _looks_like_model_dir(child), - hidden = is_hidden, - ) - ) - except PermissionError as exc: - logger.debug( - "browse-folders: permission denied during enumeration of %s: %s", - target, - exc, - ) - except OSError as exc: - # Rare: iterdir succeeded but reading a specific entry failed. - logger.warning("browse-folders: partial enumeration of %s: %s", target, exc) - - # Model-bearing dirs first, then plain, then hidden; case-insensitive - # alphabetical within each bucket. - def _sort_key(e: BrowseEntry) -> tuple[int, str]: - bucket = 0 if e.has_models else (2 if e.hidden else 1) - return (bucket, e.name.lower()) - - entries.sort(key = _sort_key) - - # Parent is None at the filesystem root (`p.parent == p`) AND when - # the parent would step outside the sandbox -- otherwise the up-row - # would 403 on click. Users can still hop to other allowed roots - # via the suggestion chips below. - parent: Optional[str] - if target.parent == target or not _is_path_inside_allowlist( - target.parent, allowed_roots - ): - parent = None - else: - parent = str(target.parent) - - # Handy starting points for the quick-pick chips. - suggestions: list[str] = [] - seen_sug: set[str] = set() - - def _add_sug(p: Optional[Path]) -> None: - if p is None: - return - try: - resolved = str(p.resolve()) - except OSError: - return - if resolved in seen_sug: - return - if Path(resolved).is_dir(): - seen_sug.add(resolved) - suggestions.append(resolved) - - # Home always comes first -- it's the safe fallback when everything - # else is cold. - _add_sug(Path.home()) - # The HF cache root the process is actually using. - try: - _add_sug(hf_default_cache_dir()) - except Exception: - pass - # Already-registered scan folders (what the user has curated). - try: - for folder in list_scan_folders(): - _add_sug(Path(folder.get("path", ""))) - except Exception as exc: - logger.debug("browse-folders: could not load scan folders: %s", exc) - # Directories commonly used by other local-LLM tools: LM Studio - # (`~/.lmstudio/models` + legacy `~/.cache/lm-studio/models` + - # user-configured downloadsFolder from LM Studio's settings.json), - # Ollama (`~/.ollama/models` + common system paths + OLLAMA_MODELS - # env var), and generic user-choice spots (`~/models`, `~/Models`). - # Each helper only returns paths that currently exist so we never - # show dead chips. - try: - for p in well_known_model_dirs(): - _add_sug(p) - except Exception as exc: - logger.debug("browse-folders: could not load well-known dirs: %s", exc) - - return BrowseFoldersResponse( - current = str(target), - parent = parent, - entries = entries, - suggestions = suggestions, - truncated = truncated, - model_files_here = _count_model_files(target), - ) - - @router.get("/list") async def list_models( current_subject: str = Depends(get_current_subject), @@ -1650,16 +791,15 @@ async def scan_loras( lora_list = [] # Scan training outputs - trained_models = scan_trained_models(outputs_dir = resolved_outputs_dir) - for display_name, model_path, model_type in trained_models: - base_model = get_base_model_from_checkpoint(model_path) + trained_loras = scan_trained_loras(outputs_dir = resolved_outputs_dir) + for display_name, adapter_path in trained_loras: + base_model = get_base_model_from_lora(adapter_path) lora_list.append( LoRAInfo( display_name = display_name, - adapter_path = model_path, + adapter_path = adapter_path, base_model = base_model, source = "training", - export_type = model_type, ) ) @@ -1685,338 +825,6 @@ async def scan_loras( ) -def _is_path_under(path: Path, root: Path) -> bool: - try: - path.resolve().relative_to(root.resolve()) - return True - except ValueError: - return False - - -def _is_path_under_lexically(path: Path, root: Path) -> bool: - """Check containment without resolving the final path's symlink target.""" - try: - absolute_path = Path(os.path.abspath(str(path))) - absolute_root = Path(os.path.abspath(str(root))) - absolute_path.relative_to(absolute_root) - return True - except ValueError: - return False - - -def _loaded_model_matches_deleted_path(active_model: str, deleted_path: Path) -> bool: - try: - active = Path(active_model).expanduser().resolve() - target = deleted_path.resolve() - return active == target or (target.is_dir() and active.is_relative_to(target)) - except (OSError, RuntimeError, ValueError) as e: - logger.debug( - "Could not resolve loaded/deleted model paths; falling back to string comparison: %s", - e, - ) - active_lower = active_model.lower() - target_lower = str(deleted_path).lower() - return active_lower == target_lower or active_lower.startswith( - f"{target_lower}{os.sep}" - ) - - -def _loading_model_matches_deleted_path( - loading_model: object, - deleted_path: Path, -) -> bool: - if not loading_model: - return False - return _loaded_model_matches_deleted_path(str(loading_model), deleted_path) - - -def _prune_empty_parents(start: Path, stop_at: Path) -> None: - """Remove empty ancestor directories of ``start`` up to (but not including) ``stop_at``. - - Used after deleting a model checkpoint so the enclosing run directory does - not linger as an empty entry in scan results. - """ - try: - stop_resolved = stop_at.resolve() - except OSError: - return - parent = start.parent - while True: - try: - parent_resolved = parent.resolve() - except OSError: - return - if parent_resolved == stop_resolved: - return - try: - parent_resolved.relative_to(stop_resolved) - except ValueError: - return - try: - parent.rmdir() - except OSError: - return - parent = parent.parent - - -def _delete_gguf_variant_files(root: Path, variant: str) -> tuple[int, int]: - deleted_count = 0 - deleted_bytes = 0 - for path in root.rglob("*"): - if not path.is_file() or not _is_main_gguf_filename(path.name): - continue - if _extract_quant_label(path.name).lower() != variant.lower(): - continue - try: - deleted_bytes += path.stat().st_size - except OSError: - pass - path.unlink() - deleted_count += 1 - return deleted_count, deleted_bytes - - -@router.delete("/delete-finetuned") -async def delete_finetuned_model( - model_path: str = Body(...), - source: str = Body(...), - export_type: Optional[str] = Body(None), - gguf_variant: Optional[str] = Body(None), - current_subject: str = Depends(get_current_subject), -): - """Delete a Studio-trained or exported model from disk. - - Only paths under Studio's outputs/exports roots are accepted. Exported - GGUF entries can delete one quantization variant at a time. - """ - if source not in {"training", "exported"}: - raise HTTPException( - status_code = 400, - detail = "Only trained or exported Studio models can be deleted", - ) - - if not model_path or not model_path.strip(): - raise HTTPException(status_code = 400, detail = "model_path is required") - - if export_type == "gguf" and not gguf_variant: - raise HTTPException( - status_code = 400, - detail = "gguf_variant is required when export_type is 'gguf'", - ) - - raw_path = Path(model_path).expanduser() - if source == "training": - target_path = raw_path - allowed_root = outputs_root() - else: - allowed_root = exports_root() - target_path = ( - raw_path.parent - if export_type == "gguf" and raw_path.suffix.lower() == ".gguf" - else raw_path - ) - - allowed_root = allowed_root.resolve() - delete_path = Path(os.path.abspath(str(target_path))) - delete_path_is_symlink = delete_path.is_symlink() - - if delete_path_is_symlink: - if not _is_path_under_lexically(delete_path, allowed_root): - raise HTTPException( - status_code = 400, - detail = "Model path is outside Studio storage", - ) - if export_type == "gguf" and gguf_variant: - target_path = delete_path.resolve() - if not _is_path_under(target_path, allowed_root): - raise HTTPException( - status_code = 400, - detail = "Model path is outside Studio storage", - ) - else: - target_path = delete_path - else: - target_path = target_path.resolve() - - should_check_resolved_path = not delete_path_is_symlink or ( - export_type == "gguf" and gguf_variant - ) - if should_check_resolved_path and not _is_path_under(target_path, allowed_root): - raise HTTPException( - status_code = 400, - detail = "Model path is outside Studio storage", - ) - if target_path == allowed_root: - raise HTTPException( - status_code = 400, - detail = "Refusing to delete storage root", - ) - if not target_path.exists() and not target_path.is_symlink(): - raise HTTPException(status_code = 404, detail = "Model not found on disk") - - if source == "training": - try: - from core.training import get_training_backend - - training_backend = get_training_backend() - if training_backend.is_training_active(): - raise HTTPException( - status_code = 409, - detail = "Cannot delete trained models while training is running", - ) - except HTTPException: - raise - except Exception as e: - logger.warning("Could not check training status before delete: %s", e) - raise HTTPException( - status_code = 500, - detail = "Could not verify training status before deleting", - ) from e - - try: - from routes.inference import get_llama_cpp_backend - - llama_backend = get_llama_cpp_backend() - if ( - llama_backend.is_active - and not llama_backend.is_loaded - and llama_backend.model_identifier - and _loaded_model_matches_deleted_path( - llama_backend.model_identifier, - target_path, - ) - and ( - not gguf_variant - or not llama_backend.hf_variant - or llama_backend.hf_variant.lower() == gguf_variant.lower() - ) - ): - raise HTTPException( - status_code = 409, - detail = "Cannot delete a model while it is loading", - ) - if ( - llama_backend.is_loaded - and llama_backend.model_identifier - and _loaded_model_matches_deleted_path( - llama_backend.model_identifier, - target_path, - ) - and ( - not gguf_variant - or not llama_backend.hf_variant - or llama_backend.hf_variant.lower() == gguf_variant.lower() - ) - ): - raise HTTPException( - status_code = 400, - detail = "Unload the model before deleting", - ) - except HTTPException: - raise - except Exception as e: - logger.warning("Could not check llama.cpp loaded model before delete: %s", e) - raise HTTPException( - status_code = 503, - detail = "Could not verify model load status before deleting", - ) from e - - try: - inference_backend = get_inference_backend() - loading_models = getattr(inference_backend, "loading_models", set()) - if any( - _loading_model_matches_deleted_path(loading_model, target_path) - for loading_model in loading_models - ): - raise HTTPException( - status_code = 409, - detail = "Cannot delete a model while it is loading", - ) - if inference_backend.active_model_name: - if _loaded_model_matches_deleted_path( - inference_backend.active_model_name, - target_path, - ): - raise HTTPException( - status_code = 400, - detail = "Unload the model before deleting", - ) - except HTTPException: - raise - except Exception as e: - logger.warning( - "Could not check inference backend loaded model before delete: %s", e - ) - raise HTTPException( - status_code = 503, - detail = "Could not verify model load status before deleting", - ) from e - - try: - if export_type == "gguf" and gguf_variant: - if not target_path.is_dir(): - raise HTTPException( - status_code = 400, - detail = "GGUF variant deletion requires an export directory", - ) - deleted_count, deleted_bytes = _delete_gguf_variant_files( - target_path, - gguf_variant, - ) - if deleted_count == 0: - raise HTTPException( - status_code = 404, - detail = f"Variant {gguf_variant} not found on disk", - ) - try: - if not any(target_path.iterdir()): - target_path.rmdir() - _prune_empty_parents(target_path, allowed_root) - except OSError: - pass - logger.info( - "Deleted %s GGUF file(s) for exported model at %s variant %s (%0.1f MB freed)", - deleted_count, - target_path, - gguf_variant, - deleted_bytes / (1024 * 1024), - ) - return { - "status": "deleted", - "path": str(target_path), - "gguf_variant": gguf_variant, - } - - if target_path.is_symlink() or target_path.is_file(): - target_path.unlink() - else: - shutil.rmtree(target_path) - - if target_path.exists() or target_path.is_symlink(): - raise HTTPException( - status_code = 500, - detail = "Deletion incomplete; some files could not be removed", - ) - - _prune_empty_parents(target_path, allowed_root) - - logger.info("Deleted fine-tuned model at %s", target_path) - return {"status": "deleted", "path": str(target_path)} - except HTTPException: - raise - except Exception as e: - logger.error( - "Error deleting fine-tuned model %s: %s", - target_path, - e, - exc_info = True, - ) - raise HTTPException( - status_code = 500, - detail = f"Failed to delete fine-tuned model: {str(e)}", - ) - - @router.get("/loras/{lora_path:path}/base-model", response_model = LoRABaseModelResponse) async def get_lora_base_model( lora_path: str, @@ -2181,7 +989,7 @@ async def get_gguf_variants( snapshots = entry / "snapshots" if snapshots.is_dir(): for snap in snapshots.iterdir(): - for f in _iter_gguf_paths(snap): + for f in snap.rglob("*.gguf"): q = _extract_quant_label(f.name) cached_bytes_by_quant[q] = ( cached_bytes_by_quant.get(q, 0) + f.stat().st_size @@ -2250,7 +1058,7 @@ async def get_gguf_download_progress( for entry in cache_dir.iterdir(): if entry.name.lower() == target: # Count completed .gguf files matching this variant in snapshots - for f in _iter_gguf_paths(entry): + for f in entry.rglob("*.gguf"): fname = f.name.lower().replace("-", "").replace("_", "") if not variant_lower or variant_lower in fname: downloaded_bytes += f.stat().st_size @@ -2280,25 +1088,6 @@ async def get_gguf_download_progress( return {"downloaded_bytes": 0, "expected_bytes": expected_bytes, "progress": 0} -def _resolve_hf_cache_realpath(repo_dir: Path) -> Optional[str]: - """Pick the most useful on-disk path for a HF cache repo. - - Prefers the most-recent snapshot dir (what `from_pretrained` actually - points at). Falls back to the cache repo root. Returns the resolved - realpath so symlinks under snapshots/ are followed back to blobs/. - """ - try: - snapshots_dir = repo_dir / "snapshots" - if snapshots_dir.is_dir(): - snaps = [s for s in snapshots_dir.iterdir() if s.is_dir()] - if snaps: - latest = max(snaps, key = lambda s: s.stat().st_mtime) - return str(latest.resolve()) - return str(repo_dir.resolve()) - except Exception: - return None - - @router.get("/download-progress") async def get_download_progress( repo_id: str = Query(..., description = "HuggingFace repo ID"), @@ -2309,16 +1098,8 @@ async def get_download_progress( Checks the local HF cache for completed blobs and in-progress (.incomplete) downloads. Uses the HF API to determine the expected total size on the first call, then caches it for subsequent polls. - Also returns ``cache_path``: the realpath of the snapshot directory - (or the cache repo root if no snapshot exists yet) so the UI can - show users where the weights actually live on disk. """ - _empty = { - "downloaded_bytes": 0, - "expected_bytes": 0, - "progress": 0, - "cache_path": None, - } + _empty = {"downloaded_bytes": 0, "expected_bytes": 0, "progress": 0} try: if not _is_valid_repo_id(repo_id): return _empty @@ -2329,12 +1110,10 @@ async def get_download_progress( target = f"models--{repo_id.replace('/', '--')}".lower() completed_bytes = 0 in_progress_bytes = 0 - cache_path: Optional[str] = None for entry in cache_dir.iterdir(): if entry.name.lower() != target: continue - cache_path = _resolve_hf_cache_realpath(entry) blobs_dir = entry / "blobs" if not blobs_dir.is_dir(): break @@ -2349,7 +1128,7 @@ async def get_download_progress( downloaded_bytes = completed_bytes + in_progress_bytes if downloaded_bytes == 0: - return {**_empty, "cache_path": cache_path} + return _empty # Get expected size from HF API (cached per repo_id) expected_bytes = _get_repo_size_cached(repo_id) @@ -2359,7 +1138,6 @@ async def get_download_progress( "downloaded_bytes": downloaded_bytes, "expected_bytes": 0, "progress": 0, - "cache_path": cache_path, } # Use 95% threshold for completion (blob deduplication can make @@ -2375,7 +1153,6 @@ async def get_download_progress( "downloaded_bytes": downloaded_bytes, "expected_bytes": expected_bytes, "progress": round(progress, 3), - "cache_path": cache_path, } except Exception as e: logger.warning(f"Error checking download progress for {repo_id}: {e}") @@ -2426,62 +1203,6 @@ def _all_hf_cache_scans(): return scans -def _is_gguf_filename(name: str) -> bool: - return name.lower().endswith(".gguf") - - -def _is_mmproj_filename(name: str) -> bool: - """Match GGUF vision-adapter (mmproj) files. Kept consistent with - ``utils.models.model_config._is_mmproj``.""" - return "mmproj" in name.lower() - - -def _is_main_gguf_filename(name: str) -> bool: - """A GGUF file that is a primary weight artifact, not an mmproj - vision adapter.""" - return _is_gguf_filename(name) and not _is_mmproj_filename(name) - - -def _iter_gguf_paths(root: Path): - for path in root.rglob("*"): - if path.is_file() and _is_gguf_filename(path.name): - yield path - - -def _repo_gguf_size_bytes(repo_info) -> int: - """Return the total on-disk size of primary GGUF weight files across - all revisions, excluding mmproj vision-adapter files. - - Hugging Face hardlinks blobs shared between revisions, so this - deduplicates by blob path (or, as a fallback, by revision commit - hash + filename) to avoid double-counting the same bytes. Files - with an unknown size (``size_on_disk is None``, e.g. a partial or - interrupted download) are treated as zero bytes. mmproj files are - excluded so that repos whose only ``.gguf`` artifact is a vision - adapter are not classified as GGUF repos: the variant selector - filters mmproj out and would otherwise show zero pickable variants. - """ - unique_blobs: dict[str, int] = {} - for revision in repo_info.revisions: - rev_id = getattr(revision, "commit_hash", None) or str(id(revision)) - for f in revision.files: - if _is_main_gguf_filename(f.file_name): - blob_path = getattr(f, "blob_path", None) - size = f.size_on_disk or 0 - if blob_path: - unique_blobs[str(blob_path)] = size - else: - unique_blobs[f"{rev_id}:{f.file_name}"] = size - return sum(unique_blobs.values()) - - -def _repo_has_gguf_files(repo_info) -> bool: - """Return True when any revision in a cached repo contains a - primary GGUF weight file. Repos whose only ``.gguf`` artifact is - an mmproj vision adapter are not treated as GGUF here.""" - return _repo_gguf_size_bytes(repo_info) > 0 - - @router.get("/cached-gguf") async def list_cached_gguf( current_subject: str = Depends(get_current_subject), @@ -2493,25 +1214,28 @@ async def list_cached_gguf( seen_lower: dict[str, dict] = {} for hf_cache in cache_scans: for repo_info in hf_cache.repos: - try: - if repo_info.repo_type != "model": - continue - repo_id = repo_info.repo_id - total_size = _repo_gguf_size_bytes(repo_info) - if total_size == 0: - continue - key = repo_id.lower() - existing = seen_lower.get(key) - if existing is None or total_size > existing["size_bytes"]: - seen_lower[key] = { - "repo_id": repo_id, - "size_bytes": total_size, - "cache_path": str(repo_info.repo_path), - } - except Exception as e: - repo_label = getattr(repo_info, "repo_id", "") - logger.warning(f"Skipping cached GGUF repo {repo_label}: {e}") + if repo_info.repo_type != "model": + continue + repo_id = repo_info.repo_id + if not repo_id.upper().endswith("-GGUF"): + continue + total_size = 0 + has_gguf = False + for revision in repo_info.revisions: + for f in revision.files: + if f.file_name.endswith(".gguf"): + has_gguf = True + total_size += f.size_on_disk + if not has_gguf: continue + key = repo_id.lower() + existing = seen_lower.get(key) + if existing is None or total_size > existing["size_bytes"]: + seen_lower[key] = { + "repo_id": repo_id, + "size_bytes": total_size, + "cache_path": str(repo_info.repo_path), + } cached = sorted(seen_lower.values(), key = lambda c: c["repo_id"]) return {"cached": cached} except Exception as e: @@ -2532,37 +1256,30 @@ async def list_cached_models( seen_lower: dict[str, dict] = {} for hf_cache in cache_scans: for repo_info in hf_cache.repos: - try: - if repo_info.repo_type != "model": - continue - repo_id = repo_info.repo_id - if _repo_has_gguf_files(repo_info): - continue - total_size = sum( - (f.size_on_disk or 0) - for rev in repo_info.revisions - for f in rev.files - ) - if total_size == 0: - continue - has_weights = any( - f.file_name.endswith(_WEIGHT_EXTENSIONS) - for rev in repo_info.revisions - for f in rev.files - ) - if not has_weights: - continue - key = repo_id.lower() - existing = seen_lower.get(key) - if existing is None or total_size > existing["size_bytes"]: - seen_lower[key] = { - "repo_id": repo_id, - "size_bytes": total_size, - } - except Exception as e: - repo_label = getattr(repo_info, "repo_id", "") - logger.warning(f"Skipping cached model repo {repo_label}: {e}") + if repo_info.repo_type != "model": + continue + repo_id = repo_info.repo_id + if repo_id.upper().endswith("-GGUF"): + continue + total_size = sum( + f.size_on_disk for rev in repo_info.revisions for f in rev.files + ) + if total_size == 0: + continue + has_weights = any( + f.file_name.endswith(_WEIGHT_EXTENSIONS) + for rev in repo_info.revisions + for f in rev.files + ) + if not has_weights: continue + key = repo_id.lower() + existing = seen_lower.get(key) + if existing is None or total_size > existing["size_bytes"]: + seen_lower[key] = { + "repo_id": repo_id, + "size_bytes": total_size, + } cached = sorted(seen_lower.values(), key = lambda c: c["repo_id"]) return {"cached": cached} except Exception as e: @@ -2639,7 +1356,7 @@ async def delete_cached_model( deleted_count = 0 for rev in target_repo.revisions: for f in rev.files: - if not _is_gguf_filename(f.file_name): + if not f.file_name.endswith(".gguf"): continue quant = _extract_quant_label(f.file_name) if quant.lower() != variant.lower(): diff --git a/studio/backend/routes/training.py b/studio/backend/routes/training.py index e5195bb337..e625408bad 100644 --- a/studio/backend/routes/training.py +++ b/studio/backend/routes/training.py @@ -25,12 +25,6 @@ # Import backend functions try: from core.training import get_training_backend - from core.training.resume import ( - can_resume_run, - get_resume_checkpoint_path, - normalize_resume_output_dir, - ) - from storage.studio_db import get_resumable_run_by_output_dir from utils.models.model_config import load_model_defaults from utils.paths import resolve_dataset_path except ImportError: @@ -39,12 +33,6 @@ if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from core.training import get_training_backend - from core.training.resume import ( - can_resume_run, - get_resume_checkpoint_path, - normalize_resume_output_dir, - ) - from storage.studio_db import get_resumable_run_by_output_dir from utils.models.model_config import load_model_defaults from utils.paths import resolve_dataset_path @@ -164,28 +152,6 @@ async def start_training( request.local_eval_datasets = _validate_local_dataset_paths( request.local_eval_datasets, "Local eval dataset" ) - resume_output_dir: Optional[str] = None - if request.resume_from_checkpoint: - try: - resume_output_dir = normalize_resume_output_dir( - request.resume_from_checkpoint - ) - except ValueError as e: - raise HTTPException(status_code = 400, detail = str(e)) - - resume_run = get_resumable_run_by_output_dir(resume_output_dir) - if not resume_run or not can_resume_run(resume_run): - raise HTTPException( - status_code = 400, - detail = "Resume checkpoint must belong to a stopped run with saved trainer state.", - ) - resume_checkpoint = get_resume_checkpoint_path(resume_output_dir) - if not resume_checkpoint: - raise HTTPException( - status_code = 400, - detail = "Resume checkpoint must include saved trainer state.", - ) - request.resume_from_checkpoint = resume_checkpoint # Convert request to kwargs for backend training_kwargs = { @@ -243,8 +209,6 @@ async def start_training( "wandb_project": request.wandb_project or "", "enable_tensorboard": request.enable_tensorboard, "tensorboard_dir": request.tensorboard_dir or "", - "output_dir": resume_output_dir, - "resume_from_checkpoint": request.resume_from_checkpoint, "trust_remote_code": request.trust_remote_code, "gpu_ids": request.gpu_ids, } @@ -473,9 +437,6 @@ async def get_training_status( "loss": getattr(progress, "loss", None), "learning_rate": getattr(progress, "learning_rate", None), } - output_dir = getattr(backend, "_output_dir", None) - if output_dir: - details["output_dir"] = output_dir # Build metric history for chart recovery after SSE reconnection metric_history = None diff --git a/studio/backend/routes/training_history.py b/studio/backend/routes/training_history.py index 6f34321959..597c4424c0 100644 --- a/studio/backend/routes/training_history.py +++ b/studio/backend/routes/training_history.py @@ -11,7 +11,6 @@ from loggers import get_logger from auth.authentication import get_current_subject -from core.training.resume import can_resume_run from models import ( TrainingRunDeleteResponse, TrainingRunDetailResponse, @@ -35,10 +34,7 @@ async def list_training_runs( """List training runs, newest first.""" result = list_runs(limit = limit, offset = offset) return TrainingRunListResponse( - runs = [ - TrainingRunSummary(**{**r, "can_resume": can_resume_run(r)}) - for r in result["runs"] - ], + runs = [TrainingRunSummary(**r) for r in result["runs"]], total = result["total"], ) @@ -62,12 +58,7 @@ async def get_training_run_detail( metrics_data = get_run_metrics(run_id) return TrainingRunDetailResponse( - run = TrainingRunSummary( - **{ - **{k: v for k, v in run.items() if k != "config_json"}, - "can_resume": can_resume_run(run), - } - ), + run = TrainingRunSummary(**{k: v for k, v in run.items() if k != "config_json"}), config = config, metrics = TrainingRunMetrics(**metrics_data), ) diff --git a/studio/backend/run.py b/studio/backend/run.py index c5b103ff70..86c1194661 100644 --- a/studio/backend/run.py +++ b/studio/backend/run.py @@ -244,12 +244,10 @@ def _graceful_shutdown(server = None): def run_server( - host: str = "127.0.0.1", + host: str = "0.0.0.0", port: int = 8888, frontend_path: Path = Path(__file__).resolve().parent.parent / "frontend" / "dist", silent: bool = False, - api_only: bool = False, - llama_parallel_slots: int = 1, ): """ Start the FastAPI server. @@ -259,8 +257,6 @@ def run_server( port: Port to bind to (auto-increments if in use) frontend_path: Path to frontend build directory (optional) silent: Suppress startup messages - api_only: Run API server only, no frontend serving (for Tauri desktop app) - llama_parallel_slots: Number of parallel slots for llama-server Note: Signal handlers are NOT registered here so that embedders @@ -277,10 +273,6 @@ def run_server( except Exception: pass - # Set env var BEFORE importing main so CORS middleware picks it up - if api_only: - os.environ["UNSLOTH_API_ONLY"] = "1" - import nest_asyncio nest_asyncio.apply() @@ -316,12 +308,8 @@ def run_server( print("=" * 50) print("") - # Output port for Tauri to parse when in api-only mode - if api_only: - print(f"TAURI_PORT={port}", flush = True) - - # Setup frontend if path provided (skip in api-only mode) - if frontend_path and not api_only: + # Setup frontend if path provided + if frontend_path: if setup_frontend(app, frontend_path): if not silent: print(f"[OK] Frontend loaded from {frontend_path}") @@ -343,7 +331,6 @@ def run_server( # binds (port==0) leave it unset and let request handlers fall back # to the ASGI request scope or request.base_url. app.state.server_port = port if port and port > 0 else None - app.state.llama_parallel_slots = llama_parallel_slots # Run server in a daemon thread def _run(): @@ -392,11 +379,7 @@ def _trigger_shutdown(): pass parser = argparse.ArgumentParser(description = "Run Unsloth UI Backend server") - parser.add_argument( - "--host", - default = "127.0.0.1", - help = "Host to bind to (default: 127.0.0.1; use 0.0.0.0 for network/cloud access)", - ) + parser.add_argument("--host", default = "0.0.0.0", help = "Host to bind to") parser.add_argument("--port", type = int, default = 8888, help = "Port to bind to") parser.add_argument( "--frontend", @@ -405,17 +388,10 @@ def _trigger_shutdown(): help = "Path to frontend build", ) parser.add_argument("--silent", action = "store_true", help = "Suppress output") - parser.add_argument( - "--api-only", - action = "store_true", - help = "API server only, no frontend (for Tauri)", - ) args = parser.parse_args() - kwargs = dict( - host = args.host, port = args.port, silent = args.silent, api_only = args.api_only - ) + kwargs = dict(host = args.host, port = args.port, silent = args.silent) if args.frontend is not None: kwargs["frontend_path"] = Path(args.frontend) diff --git a/studio/backend/state/tool_policy.py b/studio/backend/state/tool_policy.py deleted file mode 100644 index 9343a39806..0000000000 --- a/studio/backend/state/tool_policy.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -"""Process-level server-side tool policy. - -Set by `unsloth run` at startup; consulted by the inference route gates. - - None -> no CLI override (default). Per-request `enable_tools` is honored. - True -> CLI forced tools on for every request. - False -> CLI forced tools off for every request. -""" - -from typing import Optional - -_tool_policy: Optional[bool] = None - - -def get_tool_policy() -> Optional[bool]: - return _tool_policy - - -def set_tool_policy(value: Optional[bool]) -> None: - if value is not None and not isinstance(value, bool): - raise TypeError( - f"tool_policy must be Optional[bool], got {type(value).__name__}" - ) - global _tool_policy - _tool_policy = value - - -def reset_tool_policy() -> None: - global _tool_policy - _tool_policy = None diff --git a/studio/backend/storage/studio_db.py b/studio/backend/storage/studio_db.py index 29e787c196..89f75632ef 100644 --- a/studio/backend/storage/studio_db.py +++ b/studio/backend/storage/studio_db.py @@ -267,23 +267,10 @@ def list_runs(limit: int = 50, offset: int = 0) -> dict: total = conn.execute("SELECT COUNT(*) FROM training_runs").fetchone()[0] rows = conn.execute( """ - SELECT r.id, r.status, r.model_name, r.dataset_name, r.started_at, - r.ended_at, r.total_steps, r.final_step, r.final_loss, - r.output_dir, r.duration_seconds, r.error_message, - r.loss_sparkline, - CASE - WHEN r.status = 'stopped' - AND r.output_dir IS NOT NULL - AND EXISTS ( - SELECT 1 - FROM training_runs newer - WHERE newer.output_dir = r.output_dir - AND newer.status IN ('stopped', 'completed') - AND newer.started_at > r.started_at - ) - THEN 1 ELSE 0 - END AS resumed_later - FROM training_runs r + SELECT id, status, model_name, dataset_name, started_at, ended_at, + total_steps, final_step, final_loss, output_dir, + duration_seconds, error_message, loss_sparkline + FROM training_runs ORDER BY started_at DESC LIMIT ? OFFSET ? """, @@ -310,26 +297,7 @@ def list_runs(limit: int = 50, offset: int = 0) -> dict: def get_run(id: str) -> Optional[dict]: conn = get_connection() try: - row = conn.execute( - """ - SELECT r.*, - CASE - WHEN r.status = 'stopped' - AND r.output_dir IS NOT NULL - AND EXISTS ( - SELECT 1 - FROM training_runs newer - WHERE newer.output_dir = r.output_dir - AND newer.status IN ('stopped', 'completed') - AND newer.started_at > r.started_at - ) - THEN 1 ELSE 0 - END AS resumed_later - FROM training_runs r - WHERE r.id = ? - """, - (id,), - ).fetchone() + row = conn.execute("SELECT * FROM training_runs WHERE id = ?", (id,)).fetchone() if row is None: return None run = dict(row) @@ -345,45 +313,6 @@ def get_run(id: str) -> Optional[dict]: conn.close() -def get_resumable_run_by_output_dir(output_dir: str) -> Optional[dict]: - conn = get_connection() - try: - row = conn.execute( - """ - SELECT r.*, - 0 AS resumed_later - FROM training_runs r - WHERE r.output_dir = ? - AND r.status = 'stopped' - AND NOT EXISTS ( - SELECT 1 - FROM training_runs newer - WHERE newer.output_dir = r.output_dir - AND newer.status IN ('stopped', 'completed') - AND newer.started_at > r.started_at - ) - ORDER BY r.started_at DESC - LIMIT 1 - """, - (output_dir,), - ).fetchone() - if row is None: - return None - run = dict(row) - sparkline = run.get("loss_sparkline") - if sparkline: - try: - run["loss_sparkline"] = json.loads(sparkline) - except (json.JSONDecodeError, TypeError): - logger.debug( - "Failed to parse loss_sparkline for output_dir %s", output_dir - ) - run["loss_sparkline"] = None - return run - finally: - conn.close() - - def get_run_metrics(id: str) -> dict: """Return metric arrays for a run, using paired step arrays per metric.""" conn = get_connection() diff --git a/studio/backend/tests/conftest.py b/studio/backend/tests/conftest.py index 6aa6d314c1..053e9b85d9 100644 --- a/studio/backend/tests/conftest.py +++ b/studio/backend/tests/conftest.py @@ -3,136 +3,14 @@ """ Shared pytest configuration for the backend test suite. - -Responsibilities: -1. Put the backend root on sys.path so `from models.inference import ...` - (and similar flat imports) resolve in test modules — mirrors how the - app itself is launched. -2. Provide a hybrid ``studio_server`` session fixture for end-to-end tests - (see ``test_studio_api.py``). The fixture supports two invocation modes: - - a. **External server.** If ``UNSLOTH_E2E_BASE_URL`` is set, tests point - at an already-running Studio instance. ``UNSLOTH_E2E_API_KEY`` must - also be set. This is the fast-iteration mode: start the server once - with ``unsloth studio run ...``, then run pytest against it many - times with no per-run GGUF load cost. - - b. **Fixture-managed server.** Otherwise, the fixture launches a fresh - server via ``_start_server`` and tears it down at session end. This - is the one-shot mode for CI or a clean-slate verification run. - - The model / variant for mode (b) come from ``--unsloth-model`` / - ``--unsloth-gguf-variant`` pytest options, then ``UNSLOTH_E2E_MODEL`` / - ``UNSLOTH_E2E_VARIANT`` env vars, then the defaults in - ``test_studio_api.py``. +Ensures that the backend root is on sys.path so that +`import utils.utils` (and similar flat imports) resolve correctly. """ -import os import sys from pathlib import Path -import pytest - # Add backend root to sys.path (mirrors how the app itself is launched) _backend_root = Path(__file__).resolve().parent.parent if str(_backend_root) not in sys.path: sys.path.insert(0, str(_backend_root)) - - -# ── Pytest CLI options ─────────────────────────────────────────────── - - -def pytest_addoption(parser): - group = parser.getgroup( - "unsloth-e2e", - "Unsloth Studio end-to-end test options", - ) - group.addoption( - "--unsloth-model", - action = "store", - default = None, - help = ( - "GGUF model id used when starting a server for e2e tests. " - "Ignored if UNSLOTH_E2E_BASE_URL is set. Overrides " - "UNSLOTH_E2E_MODEL env var. Defaults to test_studio_api.py's " - "DEFAULT_MODEL." - ), - ) - group.addoption( - "--unsloth-gguf-variant", - action = "store", - default = None, - help = ( - "GGUF variant used when starting a server for e2e tests. " - "Ignored if UNSLOTH_E2E_BASE_URL is set. Overrides " - "UNSLOTH_E2E_VARIANT env var. Defaults to test_studio_api.py's " - "DEFAULT_VARIANT." - ), - ) - - -# ── E2E server fixtures ────────────────────────────────────────────── - - -@pytest.fixture(scope = "session") -def studio_server(request): - """Yield ``(base_url, api_key)`` for e2e tests. - - Resolution order: - - 1. If ``UNSLOTH_E2E_BASE_URL`` is set → point at that server, - require ``UNSLOTH_E2E_API_KEY`` alongside (skip if missing). - 2. Otherwise → start a fresh ``unsloth studio run`` subprocess via - the existing ``_start_server`` helper in ``test_studio_api.py`` - and tear it down on session teardown. - - Session-scoped so the expensive GGUF load happens at most once per - pytest invocation. Lazily instantiated — tests that don't request - the fixture (e.g. the unit tests in ``test_anthropic_messages.py`` - or ``test_help_output``) do not trigger server startup. - """ - external_url = os.environ.get("UNSLOTH_E2E_BASE_URL") - if external_url: - api_key = os.environ.get("UNSLOTH_E2E_API_KEY") - if not api_key: - pytest.skip( - "UNSLOTH_E2E_BASE_URL is set but UNSLOTH_E2E_API_KEY is " - "missing — tests that require auth cannot run against an " - "external server without it.", - ) - yield external_url, api_key - return - - # Lazy import: pytest has already loaded test_studio_api into - # sys.modules by the time any test requests this fixture, so this - # is a cache hit, not a re-execution. - import test_studio_api as _e2e - - model = ( - request.config.getoption("--unsloth-model") - or os.environ.get("UNSLOTH_E2E_MODEL") - or _e2e.DEFAULT_MODEL - ) - variant = ( - request.config.getoption("--unsloth-gguf-variant") - or os.environ.get("UNSLOTH_E2E_VARIANT") - or _e2e.DEFAULT_VARIANT - ) - - proc, api_key = _e2e._start_server(model, variant) - try: - yield f"http://{_e2e.HOST}:{_e2e.PORT}", api_key - finally: - _e2e._kill_server(proc) - - -@pytest.fixture -def base_url(studio_server): - """Base URL for the e2e Studio server (from ``studio_server``).""" - return studio_server[0] - - -@pytest.fixture -def api_key(studio_server): - """API key for the e2e Studio server (from ``studio_server``).""" - return studio_server[1] diff --git a/studio/backend/tests/test_anthropic_messages.py b/studio/backend/tests/test_anthropic_messages.py deleted file mode 100644 index 0825ef9337..0000000000 --- a/studio/backend/tests/test_anthropic_messages.py +++ /dev/null @@ -1,1013 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -""" -Tests for the Anthropic Messages API schemas and translation layer. -No running server or GPU required. -""" - -import sys -import os -import json - -import pytest - -_backend = os.path.join(os.path.dirname(__file__), "..") -sys.path.insert(0, _backend) - -from models.inference import ( - AnthropicMessagesRequest, - AnthropicMessagesResponse, - AnthropicMessage, - AnthropicTextBlock, - AnthropicToolUseBlock, - AnthropicToolResultBlock, - AnthropicTool, - AnthropicUsage, - AnthropicResponseTextBlock, - AnthropicResponseToolUseBlock, -) -from core.inference.anthropic_compat import ( - anthropic_messages_to_openai, - anthropic_tools_to_openai, - build_anthropic_sse_event, - AnthropicStreamEmitter, - AnthropicPassthroughEmitter, -) -from routes.inference import _normalize_anthropic_openai_images -from fastapi import HTTPException -import base64 as _b64 -from io import BytesIO as _BytesIO - - -# ===================================================================== -# Pydantic model tests -# ===================================================================== - - -class TestAnthropicModels: - def test_minimal_request(self): - req = AnthropicMessagesRequest( - messages = [{"role": "user", "content": "Hi"}], - ) - assert req.max_tokens is None - assert req.model == "default" - assert req.stream is False - - def test_max_tokens_optional(self): - req = AnthropicMessagesRequest( - max_tokens = 100, - messages = [{"role": "user", "content": "Hi"}], - ) - assert req.max_tokens == 100 - - def test_system_as_string(self): - req = AnthropicMessagesRequest( - max_tokens = 50, - messages = [{"role": "user", "content": "Hi"}], - system = "You are helpful.", - ) - assert req.system == "You are helpful." - - def test_tools_field_parses(self): - req = AnthropicMessagesRequest( - max_tokens = 100, - messages = [{"role": "user", "content": "Hi"}], - tools = [{"name": "web_search", "input_schema": {"type": "object"}}], - ) - assert len(req.tools) == 1 - assert req.tools[0].name == "web_search" - - def test_extra_fields_accepted(self): - req = AnthropicMessagesRequest( - max_tokens = 100, - messages = [{"role": "user", "content": "Hi"}], - some_future_field = "hello", - ) - assert req.max_tokens == 100 - - def test_stream_defaults_false(self): - req = AnthropicMessagesRequest( - max_tokens = 100, - messages = [{"role": "user", "content": "Hi"}], - ) - assert req.stream is False - - def test_enable_tools_shorthand(self): - req = AnthropicMessagesRequest( - messages = [{"role": "user", "content": "Hi"}], - enable_tools = True, - enabled_tools = ["web_search", "python"], - session_id = "my-session", - ) - assert req.enable_tools is True - assert req.enabled_tools == ["web_search", "python"] - assert req.session_id == "my-session" - - def test_extension_fields_default_none(self): - req = AnthropicMessagesRequest( - messages = [{"role": "user", "content": "Hi"}], - ) - assert req.enable_tools is None - assert req.enabled_tools is None - assert req.session_id is None - - def test_response_model_defaults(self): - resp = AnthropicMessagesResponse() - assert resp.type == "message" - assert resp.role == "assistant" - assert resp.id.startswith("msg_") - assert resp.content == [] - assert resp.usage.input_tokens == 0 - - -# ===================================================================== -# Message translation tests -# ===================================================================== - - -class TestAnthropicMessagesToOpenAI: - def test_simple_user_message(self): - msgs = [{"role": "user", "content": "Hello"}] - result = anthropic_messages_to_openai(msgs) - assert result == [{"role": "user", "content": "Hello"}] - - def test_system_string_prepended(self): - msgs = [{"role": "user", "content": "Hello"}] - result = anthropic_messages_to_openai(msgs, system = "Be brief.") - assert result[0] == {"role": "system", "content": "Be brief."} - assert result[1] == {"role": "user", "content": "Hello"} - - def test_system_as_block_list(self): - system = [ - {"type": "text", "text": "Be brief."}, - {"type": "text", "text": "Be accurate."}, - ] - msgs = [{"role": "user", "content": "Hello"}] - result = anthropic_messages_to_openai(msgs, system = system) - assert result[0]["role"] == "system" - assert "Be brief." in result[0]["content"] - assert "Be accurate." in result[0]["content"] - - def test_multi_turn_conversation(self): - msgs = [ - {"role": "user", "content": "Hi"}, - {"role": "assistant", "content": "Hello!"}, - {"role": "user", "content": "How are you?"}, - ] - result = anthropic_messages_to_openai(msgs) - assert len(result) == 3 - assert result[0]["role"] == "user" - assert result[1]["role"] == "assistant" - assert result[2]["role"] == "user" - - def test_assistant_tool_use_maps_to_tool_calls(self): - msgs = [ - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Let me search."}, - { - "type": "tool_use", - "id": "tu_1", - "name": "web_search", - "input": {"query": "test"}, - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - assert len(result) == 1 - m = result[0] - assert m["role"] == "assistant" - assert m["content"] == "Let me search." - assert len(m["tool_calls"]) == 1 - tc = m["tool_calls"][0] - assert tc["id"] == "tu_1" - assert tc["function"]["name"] == "web_search" - assert json.loads(tc["function"]["arguments"]) == {"query": "test"} - - def test_tool_result_maps_to_tool_role(self): - msgs = [ - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "tu_1", - "content": "Result text", - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - assert len(result) == 1 - assert result[0]["role"] == "tool" - assert result[0]["tool_call_id"] == "tu_1" - assert result[0]["content"] == "Result text" - - def test_mixed_text_and_tool_use_blocks(self): - msgs = [ - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Thinking..."}, - { - "type": "tool_use", - "id": "tu_1", - "name": "python", - "input": {"code": "1+1"}, - }, - { - "type": "tool_use", - "id": "tu_2", - "name": "terminal", - "input": {"command": "ls"}, - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - assert len(result) == 1 - m = result[0] - assert m["content"] == "Thinking..." - assert len(m["tool_calls"]) == 2 - - def test_tool_result_with_list_content(self): - msgs = [ - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "tu_1", - "content": [ - {"type": "text", "text": "Line 1"}, - {"type": "text", "text": "Line 2"}, - ], - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - assert result[0]["content"] == "Line 1 Line 2" - - def test_image_base64_block_becomes_multimodal_part(self): - msgs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"}, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": "AAAA", - }, - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - assert len(result) == 1 - assert result[0]["role"] == "user" - parts = result[0]["content"] - assert isinstance(parts, list) - assert parts[0] == {"type": "text", "text": "What is this?"} - assert parts[1]["type"] == "image_url" - assert parts[1]["image_url"]["url"] == "data:image/jpeg;base64,AAAA" - - def test_image_url_block_forwarded_as_url(self): - msgs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe it"}, - { - "type": "image", - "source": {"type": "url", "url": "https://x/y.png"}, - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - parts = result[0]["content"] - assert parts[1] == { - "type": "image_url", - "image_url": {"url": "https://x/y.png"}, - } - - def test_image_only_user_message_emits_no_text_part(self): - msgs = [ - { - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": "ZZ", - }, - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - parts = result[0]["content"] - assert len(parts) == 1 - assert parts[0]["type"] == "image_url" - - def test_image_default_media_type_when_missing(self): - msgs = [ - { - "role": "user", - "content": [ - { - "type": "image", - "source": {"type": "base64", "data": "BB"}, - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - parts = result[0]["content"] - assert parts[0]["image_url"]["url"].startswith("data:image/jpeg;base64,") - - def test_image_text_order_preserved(self): - # [text1, image1, text2, image2] must not collapse to - # [text1+text2, image1, image2]. - msgs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "before"}, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": "AA", - }, - }, - {"type": "text", "text": "after"}, - { - "type": "image", - "source": {"type": "url", "url": "https://x/y.png"}, - }, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - parts = result[0]["content"] - assert [p["type"] for p in parts] == [ - "text", - "image_url", - "text", - "image_url", - ] - assert parts[0]["text"] == "before" - assert parts[2]["text"] == "after" - assert parts[1]["image_url"]["url"] == "data:image/png;base64,AA" - assert parts[3]["image_url"]["url"] == "https://x/y.png" - - def test_malformed_image_block_is_skipped(self): - msgs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Hi"}, - {"type": "image", "source": {"type": "base64"}}, - {"type": "image", "source": {"type": "url"}}, - ], - } - ] - result = anthropic_messages_to_openai(msgs) - # No image parts emitted; message falls back to plain text. - assert result[0] == {"role": "user", "content": "Hi"} - - -# ===================================================================== -# Tool translation tests -# ===================================================================== - - -class TestAnthropicToolsToOpenAI: - def test_single_tool(self): - tools = [ - { - "name": "web_search", - "description": "Search", - "input_schema": { - "type": "object", - "properties": {"query": {"type": "string"}}, - }, - } - ] - result = anthropic_tools_to_openai(tools) - assert len(result) == 1 - assert result[0]["type"] == "function" - assert result[0]["function"]["name"] == "web_search" - assert result[0]["function"]["parameters"]["type"] == "object" - - def test_multiple_tools(self): - tools = [ - {"name": "a", "description": "Tool A", "input_schema": {}}, - {"name": "b", "description": "Tool B", "input_schema": {}}, - ] - result = anthropic_tools_to_openai(tools) - assert len(result) == 2 - assert result[0]["function"]["name"] == "a" - assert result[1]["function"]["name"] == "b" - - def test_empty_list(self): - assert anthropic_tools_to_openai([]) == [] - - def test_pydantic_model_input(self): - tool = AnthropicTool( - name = "test", description = "desc", input_schema = {"type": "object"} - ) - result = anthropic_tools_to_openai([tool]) - assert result[0]["function"]["name"] == "test" - - -# ===================================================================== -# SSE event helper tests -# ===================================================================== - - -class TestBuildAnthropicSSEEvent: - def test_basic_event(self): - result = build_anthropic_sse_event("message_start", {"type": "message_start"}) - assert result.startswith("event: message_start\n") - assert "data: " in result - assert result.endswith("\n\n") - - def test_data_is_valid_json(self): - result = build_anthropic_sse_event("test", {"key": "value"}) - data_line = result.split("\n")[1] - payload = json.loads(data_line.removeprefix("data: ")) - assert payload == {"key": "value"} - - -# ===================================================================== -# Stream emitter tests -# ===================================================================== - - -class TestAnthropicStreamEmitter: - def test_start_emits_message_start_and_content_block_start(self): - e = AnthropicStreamEmitter() - events = e.start("msg_123", "test-model") - assert len(events) == 2 - assert "message_start" in events[0] - assert "content_block_start" in events[1] - assert '"type": "text"' in events[1] - - def test_content_delta_emits_text_delta(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - events = e.feed({"type": "content", "text": "Hello"}) - assert len(events) == 1 - parsed = json.loads(events[0].split("data: ")[1]) - assert parsed["delta"]["type"] == "text_delta" - assert parsed["delta"]["text"] == "Hello" - - def test_cumulative_content_diffs_correctly(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - e.feed({"type": "content", "text": "Hel"}) - events = e.feed({"type": "content", "text": "Hello"}) - parsed = json.loads(events[0].split("data: ")[1]) - assert parsed["delta"]["text"] == "lo" - - def test_empty_content_diff_no_event(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - e.feed({"type": "content", "text": "Hi"}) - events = e.feed({"type": "content", "text": "Hi"}) - assert events == [] - - def test_tool_start_closes_text_opens_tool_block(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - e.feed({"type": "content", "text": "Thinking"}) - events = e.feed( - { - "type": "tool_start", - "tool_name": "web_search", - "tool_call_id": "tc_1", - "arguments": {"query": "test"}, - } - ) - # content_block_stop + content_block_start(tool_use) + content_block_delta(input_json) - assert len(events) == 3 - assert "content_block_stop" in events[0] - assert "tool_use" in events[1] - assert "input_json_delta" in events[2] - - def test_tool_end_closes_tool_opens_new_text_block(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - e.feed( - { - "type": "tool_start", - "tool_name": "t", - "tool_call_id": "tc_1", - "arguments": {}, - } - ) - events = e.feed( - { - "type": "tool_end", - "tool_name": "t", - "tool_call_id": "tc_1", - "result": "done", - } - ) - # content_block_stop (tool) + tool_result + content_block_start (new text) - assert len(events) == 3 - assert "content_block_stop" in events[0] - assert "tool_result" in events[1] - parsed = json.loads(events[1].split("data: ")[1]) - assert parsed["content"] == "done" - assert parsed["tool_use_id"] == "tc_1" - assert "content_block_start" in events[2] - assert '"type": "text"' in events[2] - - def test_finish_emits_stop_events(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - events = e.finish("end_turn") - # content_block_stop + message_delta + message_stop - assert len(events) == 3 - assert "content_block_stop" in events[0] - assert "message_delta" in events[1] - assert "end_turn" in events[1] - assert "message_stop" in events[2] - - def test_metadata_captured_in_finish_usage(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - e.feed( - { - "type": "metadata", - "usage": {"prompt_tokens": 10, "completion_tokens": 20}, - } - ) - events = e.finish("end_turn") - delta_event = [ev for ev in events if "message_delta" in ev][0] - parsed = json.loads(delta_event.split("data: ")[1]) - assert parsed["usage"]["output_tokens"] == 20 - - def test_status_events_ignored(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - events = e.feed({"type": "status", "text": "Searching..."}) - assert events == [] - - def test_no_tool_calls_simple_text_flow(self): - e = AnthropicStreamEmitter() - start_events = e.start("msg_1", "m") - content_events = e.feed({"type": "content", "text": "Hello world"}) - meta_events = e.feed( - {"type": "metadata", "usage": {"prompt_tokens": 5, "completion_tokens": 2}} - ) - end_events = e.finish("end_turn") - - assert len(start_events) == 2 - assert len(content_events) == 1 - assert meta_events == [] - assert len(end_events) == 3 - - def test_block_index_increments(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - assert e.block_index == 0 - e.feed( - { - "type": "tool_start", - "tool_name": "t", - "tool_call_id": "tc_1", - "arguments": {}, - } - ) - assert e.block_index == 1 - e.feed( - { - "type": "tool_end", - "tool_name": "t", - "tool_call_id": "tc_1", - "result": "ok", - } - ) - assert e.block_index == 2 - - def test_text_after_tool_resets_prev_text(self): - e = AnthropicStreamEmitter() - e.start("msg_1", "m") - e.feed({"type": "content", "text": "Before tool"}) - e.feed( - { - "type": "tool_start", - "tool_name": "t", - "tool_call_id": "tc_1", - "arguments": {}, - } - ) - e.feed( - { - "type": "tool_end", - "tool_name": "t", - "tool_call_id": "tc_1", - "result": "ok", - } - ) - # After tool_end, prev_text should be reset - events = e.feed({"type": "content", "text": "After tool"}) - parsed = json.loads(events[0].split("data: ")[1]) - assert parsed["delta"]["text"] == "After tool" - - -# ===================================================================== -# Pass-through emitter tests (client-side tool execution path) -# ===================================================================== - - -class TestAnthropicPassthroughEmitter: - def _parse(self, event_str): - return json.loads(event_str.split("data: ")[1]) - - def test_start_emits_message_start_only(self): - e = AnthropicPassthroughEmitter() - events = e.start("msg_1", "test-model") - assert len(events) == 1 - assert "message_start" in events[0] - parsed = self._parse(events[0]) - assert parsed["message"]["id"] == "msg_1" - assert parsed["message"]["model"] == "test-model" - - def test_text_chunk_opens_text_block_and_emits_delta(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - chunk = {"choices": [{"delta": {"content": "Hello"}}]} - events = e.feed_chunk(chunk) - # content_block_start + content_block_delta - assert len(events) == 2 - assert "content_block_start" in events[0] - assert '"type": "text"' in events[0] - delta = self._parse(events[1]) - assert delta["delta"]["type"] == "text_delta" - assert delta["delta"]["text"] == "Hello" - - def test_sequential_text_chunks_single_block(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - events1 = e.feed_chunk({"choices": [{"delta": {"content": "Hello"}}]}) - events2 = e.feed_chunk({"choices": [{"delta": {"content": " world"}}]}) - # First chunk opens the block, second only emits delta - assert len(events1) == 2 - assert len(events2) == 1 - assert self._parse(events2[0])["delta"]["text"] == " world" - - def test_tool_call_opens_tool_use_block(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - chunk = { - "choices": [ - { - "delta": { - "tool_calls": [ - { - "index": 0, - "id": "call_1", - "type": "function", - "function": {"name": "Bash", "arguments": ""}, - } - ] - } - } - ] - } - events = e.feed_chunk(chunk) - assert len(events) == 1 - parsed = self._parse(events[0]) - assert parsed["type"] == "content_block_start" - assert parsed["content_block"]["type"] == "tool_use" - assert parsed["content_block"]["id"] == "call_1" - assert parsed["content_block"]["name"] == "Bash" - - def test_tool_call_arguments_streamed_as_input_json_delta(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - # Open the tool call - e.feed_chunk( - { - "choices": [ - { - "delta": { - "tool_calls": [ - { - "index": 0, - "id": "c1", - "type": "function", - "function": {"name": "Bash", "arguments": ""}, - } - ] - } - } - ] - } - ) - # Stream argument fragments - events1 = e.feed_chunk( - { - "choices": [ - { - "delta": { - "tool_calls": [ - {"index": 0, "function": {"arguments": '{"cmd'}} - ] - } - } - ] - } - ) - events2 = e.feed_chunk( - { - "choices": [ - { - "delta": { - "tool_calls": [ - {"index": 0, "function": {"arguments": '": "ls"}'}} - ] - } - } - ] - } - ) - parsed1 = self._parse(events1[0]) - parsed2 = self._parse(events2[0]) - assert parsed1["delta"]["type"] == "input_json_delta" - assert parsed1["delta"]["partial_json"] == '{"cmd' - assert parsed2["delta"]["partial_json"] == '": "ls"}' - - def test_text_then_tool_closes_text_block(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - e.feed_chunk({"choices": [{"delta": {"content": "Let me check."}}]}) - events = e.feed_chunk( - { - "choices": [ - { - "delta": { - "tool_calls": [ - { - "index": 0, - "id": "c1", - "type": "function", - "function": {"name": "Bash", "arguments": ""}, - } - ] - } - } - ] - } - ) - # Should close text block and open tool_use block - assert "content_block_stop" in events[0] - assert "content_block_start" in events[1] - assert '"type": "tool_use"' in events[1] - - def test_finish_reason_tool_calls_sets_tool_use_stop(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - e.feed_chunk( - { - "choices": [ - { - "delta": { - "tool_calls": [ - { - "index": 0, - "id": "c1", - "type": "function", - "function": {"name": "Bash", "arguments": "{}"}, - } - ] - } - } - ] - } - ) - e.feed_chunk({"choices": [{"delta": {}, "finish_reason": "tool_calls"}]}) - events = e.finish() - delta_event = [ev for ev in events if "message_delta" in ev][0] - parsed = self._parse(delta_event) - assert parsed["delta"]["stop_reason"] == "tool_use" - - def test_finish_reason_stop_sets_end_turn(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]}) - e.feed_chunk({"choices": [{"delta": {}, "finish_reason": "stop"}]}) - events = e.finish() - delta_event = [ev for ev in events if "message_delta" in ev][0] - parsed = self._parse(delta_event) - assert parsed["delta"]["stop_reason"] == "end_turn" - - def test_finish_reason_length_sets_max_tokens(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]}) - e.feed_chunk({"choices": [{"delta": {}, "finish_reason": "length"}]}) - events = e.finish() - delta_event = [ev for ev in events if "message_delta" in ev][0] - parsed = self._parse(delta_event) - assert parsed["delta"]["stop_reason"] == "max_tokens" - - def test_finish_closes_current_block(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]}) - events = e.finish() - assert "content_block_stop" in events[0] - assert "message_delta" in events[1] - assert "message_stop" in events[2] - - def test_usage_chunk_captured(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]}) - e.feed_chunk( - { - "choices": [], - "usage": {"prompt_tokens": 10, "completion_tokens": 5}, - } - ) - events = e.finish() - delta_event = [ev for ev in events if "message_delta" in ev][0] - parsed = self._parse(delta_event) - assert parsed["usage"]["output_tokens"] == 5 - - def test_empty_chunk_returns_no_events(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - events = e.feed_chunk({"choices": []}) - assert events == [] - - def test_no_blocks_at_all_still_produces_valid_finish(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - events = e.finish() - # No content_block_stop because no block was opened - assert not any("content_block_stop" in ev for ev in events) - assert any("message_delta" in ev for ev in events) - assert any("message_stop" in ev for ev in events) - - def test_multiple_tool_calls_distinct_blocks(self): - e = AnthropicPassthroughEmitter() - e.start("msg_1", "m") - # First tool call - e.feed_chunk( - { - "choices": [ - { - "delta": { - "tool_calls": [ - { - "index": 0, - "id": "c1", - "type": "function", - "function": {"name": "Bash", "arguments": "{}"}, - } - ] - } - } - ] - } - ) - # Second tool call (different index) - events = e.feed_chunk( - { - "choices": [ - { - "delta": { - "tool_calls": [ - { - "index": 1, - "id": "c2", - "type": "function", - "function": {"name": "Read", "arguments": "{}"}, - } - ] - } - } - ] - } - ) - # Should close block 0, open block 1 - assert "content_block_stop" in events[0] - assert "content_block_start" in events[1] - parsed = self._parse(events[1]) - assert parsed["content_block"]["name"] == "Read" - assert parsed["content_block"]["id"] == "c2" - - -# ===================================================================== -# Vision guard + PNG normalization (/v1/messages) -# ===================================================================== - - -def _jpeg_data_url() -> str: - from PIL import Image - - img = Image.new("RGB", (2, 2), (255, 0, 0)) - buf = _BytesIO() - img.save(buf, format = "JPEG") - b64 = _b64.b64encode(buf.getvalue()).decode("ascii") - return f"data:image/jpeg;base64,{b64}" - - -class TestNormalizeAnthropicOpenAIImages: - def test_noop_when_no_images(self): - msgs = [{"role": "user", "content": "hi"}] - has_image = _normalize_anthropic_openai_images(msgs, is_vision = False) - assert has_image is False - assert msgs == [{"role": "user", "content": "hi"}] - - def test_returns_true_when_image_present(self): - msgs = [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": _jpeg_data_url()}}, - ], - } - ] - assert _normalize_anthropic_openai_images(msgs, is_vision = True) is True - - def test_rejects_image_when_model_not_vision(self): - msgs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "?"}, - { - "type": "image_url", - "image_url": {"url": _jpeg_data_url()}, - }, - ], - } - ] - with pytest.raises(HTTPException) as exc: - _normalize_anthropic_openai_images(msgs, is_vision = False) - assert exc.value.status_code == 400 - - def test_reencodes_jpeg_data_url_to_png(self): - original_url = _jpeg_data_url() - msgs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "?"}, - {"type": "image_url", "image_url": {"url": original_url}}, - ], - } - ] - _normalize_anthropic_openai_images(msgs, is_vision = True) - new_url = msgs[0]["content"][1]["image_url"]["url"] - assert new_url.startswith("data:image/png;base64,") - assert new_url != original_url - - def test_remote_url_left_unchanged(self): - msgs = [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": "https://x.example/y.png"}, - }, - ], - } - ] - _normalize_anthropic_openai_images(msgs, is_vision = True) - assert msgs[0]["content"][0]["image_url"]["url"] == "https://x.example/y.png" - - def test_bad_base64_raises_400(self): - msgs = [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": "data:image/jpeg;base64,!!!not-b64!!!"}, - }, - ], - } - ] - with pytest.raises(HTTPException) as exc: - _normalize_anthropic_openai_images(msgs, is_vision = True) - assert exc.value.status_code == 400 diff --git a/studio/backend/tests/test_browse_folders_route.py b/studio/backend/tests/test_browse_folders_route.py deleted file mode 100644 index 19a83987d3..0000000000 --- a/studio/backend/tests/test_browse_folders_route.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -import os -import sys -import types -from pathlib import Path - -import pytest -from fastapi import HTTPException - -# Keep this test runnable in lightweight environments where optional logging -# deps are not installed. -if "structlog" not in sys.modules: - - class _DummyLogger: - def __getattr__(self, _name): - return lambda *args, **kwargs: None - - sys.modules["structlog"] = types.SimpleNamespace( - BoundLogger = _DummyLogger, - get_logger = lambda *args, **kwargs: _DummyLogger(), - ) - -import routes.models as models_route - - -def test_resolve_browse_target_returns_allowed_directory(tmp_path): - allowed = tmp_path / "allowed" - target = allowed / "models" / "nested" - target.mkdir(parents = True) - - resolved = models_route._resolve_browse_target(str(target), [allowed]) - - assert resolved == target.resolve() - - -def test_resolve_browse_target_rejects_outside_allowlist(tmp_path): - allowed = tmp_path / "allowed" - disallowed = tmp_path / "disallowed" - allowed.mkdir() - disallowed.mkdir() - - with pytest.raises(HTTPException) as exc_info: - models_route._resolve_browse_target(str(disallowed), [allowed]) - - assert exc_info.value.status_code == 403 - - -def test_resolve_browse_target_rejects_file_path(tmp_path): - allowed = tmp_path / "allowed" - allowed.mkdir() - model_file = allowed / "model.gguf" - model_file.write_text("gguf") - - with pytest.raises(HTTPException) as exc_info: - models_route._resolve_browse_target(str(model_file), [allowed]) - - assert exc_info.value.status_code == 400 - - -def test_resolve_browse_target_allows_symlink_into_other_allowed_root(tmp_path): - home_root = tmp_path / "home" - scan_root = tmp_path / "scan" - target = scan_root / "nested" - home_root.mkdir() - target.mkdir(parents = True) - (home_root / "scan-link").symlink_to(scan_root, target_is_directory = True) - - resolved = models_route._resolve_browse_target( - str(home_root / "scan-link" / "nested"), - [home_root, scan_root], - ) - - assert resolved == target.resolve() - - -@pytest.mark.skipif(os.altsep is not None, reason = "POSIX-only path semantics") -def test_resolve_browse_target_allows_backslash_in_posix_segment(tmp_path): - allowed = tmp_path / "allowed" - target = allowed / r"dir\name" - target.mkdir(parents = True) - - resolved = models_route._resolve_browse_target(str(target), [allowed]) - - assert resolved == target.resolve() diff --git a/studio/backend/tests/test_cached_gguf_routes.py b/studio/backend/tests/test_cached_gguf_routes.py deleted file mode 100644 index 05aae8fb75..0000000000 --- a/studio/backend/tests/test_cached_gguf_routes.py +++ /dev/null @@ -1,398 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -import asyncio -import sys -import types -from pathlib import Path -from types import SimpleNamespace - -# Keep this test runnable in lightweight environments where optional logging -# deps are not installed. -if "structlog" not in sys.modules: - - class _DummyLogger: - def __getattr__(self, _name): - return lambda *args, **kwargs: None - - sys.modules["structlog"] = types.SimpleNamespace( - BoundLogger = _DummyLogger, - get_logger = lambda *args, **kwargs: _DummyLogger(), - ) - -import routes.models as models_route - - -def _repo( - repo_id: str, - files: list[SimpleNamespace], - repo_path: Path, - *, - revisions: list[SimpleNamespace] | None = None, -) -> SimpleNamespace: - return SimpleNamespace( - repo_id = repo_id, - repo_type = "model", - repo_path = repo_path, - revisions = revisions or [SimpleNamespace(files = files)], - ) - - -def _file( - name: str, - size_on_disk: int, - *, - blob_path: str | None = None, -) -> SimpleNamespace: - return SimpleNamespace( - file_name = name, - size_on_disk = size_on_disk, - blob_path = blob_path, - ) - - -def test_iter_gguf_paths_matches_extension_case_insensitively(tmp_path): - nested = tmp_path / "snapshots" / "rev" - nested.mkdir(parents = True) - lower = nested / "Q4_K_M.gguf" - upper = nested / "Q8_0.GGUF" - other = nested / "README.md" - lower.write_text("a") - upper.write_text("b") - other.write_text("c") - - result = sorted(path.name for path in models_route._iter_gguf_paths(tmp_path)) - - assert result == ["Q4_K_M.gguf", "Q8_0.GGUF"] - - -def test_list_cached_gguf_includes_non_suffix_repo_when_cache_contains_gguf( - monkeypatch, tmp_path -): - repo = _repo( - "HauhauCS/Gemma-4-E4B-Uncensored-HauhauCS-Aggressive", - [_file("Q4_K_M.gguf", 5_000), _file("README.md", 10)], - tmp_path / "models--HauhauCS--Gemma", - ) - scan = SimpleNamespace(repos = [repo]) - - monkeypatch.setattr(models_route, "_all_hf_cache_scans", lambda: [scan]) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "HauhauCS/Gemma-4-E4B-Uncensored-HauhauCS-Aggressive", - "size_bytes": 5_000, - "cache_path": str(repo.repo_path), - } - ] - - -def test_list_cached_gguf_matches_extension_case_insensitively(monkeypatch, tmp_path): - repo = _repo( - "Org/Model-Without-Suffix", - [_file("Q8_0.GGUF", 7_000)], - tmp_path / "models--Org--Model-Without-Suffix", - ) - scan = SimpleNamespace(repos = [repo]) - - monkeypatch.setattr(models_route, "_all_hf_cache_scans", lambda: [scan]) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "Org/Model-Without-Suffix", - "size_bytes": 7_000, - "cache_path": str(repo.repo_path), - } - ] - - -def test_list_cached_gguf_skips_repos_without_positive_gguf_size(monkeypatch, tmp_path): - missing = _repo( - "Org/ReadmeOnly", - [_file("README.md", 10)], - tmp_path / "models--Org--ReadmeOnly", - ) - zero = _repo( - "Org/ZeroSize", - [_file("Q4_K_M.gguf", 0)], - tmp_path / "models--Org--ZeroSize", - ) - scan = SimpleNamespace(repos = [missing, zero]) - - monkeypatch.setattr(models_route, "_all_hf_cache_scans", lambda: [scan]) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [] - - -def test_list_cached_gguf_keeps_largest_duplicate_repo_across_scans( - monkeypatch, tmp_path -): - smaller = _repo( - "Org/Dupe", - [_file("Q4_K_M.gguf", 2_000)], - tmp_path / "models--Org--Dupe-a", - ) - larger = _repo( - "org/dupe", - [_file("Q4_K_M.gguf", 5_000), _file("Q6_K.gguf", 1_000)], - tmp_path / "models--Org--Dupe-b", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [ - SimpleNamespace(repos = [smaller]), - SimpleNamespace(repos = [larger]), - ], - ) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "org/dupe", - "size_bytes": 6_000, - "cache_path": str(larger.repo_path), - } - ] - - -def test_list_cached_gguf_dedupes_shared_blobs_across_revisions(monkeypatch, tmp_path): - shared = "blobs/shared-q4" - repo = _repo( - "Org/SharedBlobRepo", - [], - tmp_path / "models--Org--SharedBlobRepo", - revisions = [ - SimpleNamespace(files = [_file("Q4_K_M.gguf", 5_000, blob_path = shared)]), - SimpleNamespace(files = [_file("Q4_K_M.gguf", 5_000, blob_path = shared)]), - ], - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [repo])], - ) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "Org/SharedBlobRepo", - "size_bytes": 5_000, - "cache_path": str(repo.repo_path), - } - ] - - -def test_list_cached_models_skips_non_suffix_repo_when_gguf_files_exist( - monkeypatch, tmp_path -): - mixed = _repo( - "Org/MixedRepo", - [ - _file("Q4_K_M.gguf", 5_000), - _file("model.safetensors", 10_000), - ], - tmp_path / "models--Org--MixedRepo", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [mixed])], - ) - - result = asyncio.run(models_route.list_cached_models(current_subject = "test-user")) - - assert result["cached"] == [] - - -def test_list_cached_gguf_includes_mixed_repo_with_gguf_and_safetensors( - monkeypatch, tmp_path -): - """Mirror of the _skips_ test: the mixed repo should still surface in - cached-gguf so the picker can show it as a GGUF download.""" - mixed = _repo( - "Org/MixedRepo", - [ - _file("Q4_K_M.gguf", 5_000), - _file("model.safetensors", 10_000), - ], - tmp_path / "models--Org--MixedRepo", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [mixed])], - ) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "Org/MixedRepo", - "size_bytes": 5_000, - "cache_path": str(mixed.repo_path), - } - ] - - -def test_list_cached_gguf_handles_none_size_on_disk(monkeypatch, tmp_path): - """A partial/interrupted GGUF download has ``size_on_disk = None``. The - route must treat the unknown bytes as zero instead of raising TypeError - out of ``sum()`` and wiping the entire response.""" - partial = _repo( - "Org/PartialDownload", - [_file("Q4_K_M.gguf", None), _file("Q6_K.gguf", 5_000)], - tmp_path / "models--Org--PartialDownload", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [partial])], - ) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "Org/PartialDownload", - "size_bytes": 5_000, - "cache_path": str(partial.repo_path), - } - ] - - -def test_list_cached_gguf_skips_malformed_repo_without_wiping_response( - monkeypatch, tmp_path -): - """One repo raising during classification must not poison the response - for every other repo in the scan.""" - - class _ExplodingRepo: - repo_id = "Org/Broken" - repo_type = "model" - repo_path = tmp_path / "models--Org--Broken" - - @property - def revisions(self): - raise RuntimeError("boom") - - healthy = _repo( - "Org/Healthy", - [_file("Q4_K_M.gguf", 5_000)], - tmp_path / "models--Org--Healthy", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [_ExplodingRepo(), healthy])], - ) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "Org/Healthy", - "size_bytes": 5_000, - "cache_path": str(healthy.repo_path), - } - ] - - -def test_list_cached_gguf_skips_repo_with_only_mmproj_gguf(monkeypatch, tmp_path): - """A repo whose only ``.gguf`` artifact is an mmproj vision adapter - must not be classified as a GGUF repo: the variant selector filters - mmproj out and the picker would otherwise show zero variants.""" - mmproj_only = _repo( - "Org/MmprojOnly", - [ - _file("mmproj-Q8_0.gguf", 5_000), - _file("model.safetensors", 10_000), - ], - tmp_path / "models--Org--MmprojOnly", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [mmproj_only])], - ) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [] - - -def test_list_cached_models_includes_repo_with_only_mmproj_gguf(monkeypatch, tmp_path): - """Mirror of the cached-gguf skip: a safetensors repo with an - auxiliary mmproj vision adapter must still surface in cached-models - so the user can load it as a normal model.""" - mmproj_aux = _repo( - "Org/MmprojAux", - [ - _file("mmproj-Q8_0.gguf", 5_000), - _file("model.safetensors", 10_000), - ], - tmp_path / "models--Org--MmprojAux", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [mmproj_aux])], - ) - - result = asyncio.run(models_route.list_cached_models(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "Org/MmprojAux", - "size_bytes": 15_000, - } - ] - - -def test_list_cached_gguf_includes_vision_repo_with_main_gguf_and_mmproj( - monkeypatch, tmp_path -): - """A vision-capable GGUF repo (main weight + mmproj adapter) is still - a GGUF repo. The reported size is the main weight size; mmproj is - excluded from the GGUF-size accounting because it is filtered out at - classification time.""" - vision_repo = _repo( - "Org/VisionGguf", - [ - _file("Q4_K_M.gguf", 5_000), - _file("mmproj-Q8_0.gguf", 1_000), - ], - tmp_path / "models--Org--VisionGguf", - ) - - monkeypatch.setattr( - models_route, - "_all_hf_cache_scans", - lambda: [SimpleNamespace(repos = [vision_repo])], - ) - - result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user")) - - assert result["cached"] == [ - { - "repo_id": "Org/VisionGguf", - "size_bytes": 5_000, - "cache_path": str(vision_repo.repo_path), - } - ] diff --git a/studio/backend/tests/test_data_recipe_github_progress.py b/studio/backend/tests/test_data_recipe_github_progress.py deleted file mode 100644 index 8e8c3995f4..0000000000 --- a/studio/backend/tests/test_data_recipe_github_progress.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -from core.data_recipe.jobs.parse import apply_update, parse_log_message -from core.data_recipe.jobs.types import Job -from routes.data_recipe.validate import _GITHUB_VALIDATE_NOTE, validate -from models.data_recipe import RecipePayload - - -def test_github_page_log_updates_source_progress_without_cursor(): - job = Job(job_id = "job-1") - job.source_progress_estimated_total = 200 - - update = parse_log_message( - "[unslothai/unsloth] issues page 2 (+15) cursor=abc123 remaining=2960" - ) - - assert update is not None - apply_update(job, update) - - progress = job.source_progress - assert progress is not None - assert progress.source == "github" - assert progress.status == "fetching" - assert progress.repo == "unslothai/unsloth" - assert progress.resource == "issues" - assert progress.page == 2 - assert progress.page_items == 15 - assert progress.fetched_items == 15 - assert progress.estimated_total == 200 - assert progress.rate_remaining == 2960 - assert progress.message is not None - assert "cursor" not in progress.message - assert "abc123" not in progress.message - - -def test_github_rate_limit_log_updates_source_progress(): - job = Job(job_id = "job-1") - - update = parse_log_message("Rate limit hit. Sleeping 123s until reset.") - - assert update is not None - apply_update(job, update) - - progress = job.source_progress - assert progress is not None - assert progress.status == "rate_limited" - assert progress.retry_after_sec == 123 - assert "resume automatically" in (progress.message or "") - - -def test_github_real_sample_prs_and_trial_limit_are_parsed(): - job = Job(job_id = "job-1") - - for message in ( - "[unslothai/unsloth] PRs page 4 (+25) cursor=abc123 remaining=4983", - "Trial limit reached for PRs (100)", - ): - update = parse_log_message(message) - assert update is not None - apply_update(job, update) - - progress = job.source_progress - assert progress is not None - assert progress.repo == "unslothai/unsloth" - assert progress.resource == "pulls" - assert progress.page == 4 - assert progress.fetched_items == 25 - assert progress.rate_remaining == 4983 - assert progress.message == "GitHub pulls trial limit reached (100)." - - -def test_github_validate_skips_live_access_with_honest_note(): - response = validate( - RecipePayload( - recipe = { - "seed_config": { - "source": { - "seed_type": "github_repo", - "repos": ["unslothai/unsloth"], - "item_types": ["issues"], - "limit": 1, - } - }, - "columns": [{"column_type": "expression", "name": "x", "expr": "1"}], - } - ) - ) - - assert response.valid is True - assert response.raw_detail == _GITHUB_VALIDATE_NOTE diff --git a/studio/backend/tests/test_desktop_auth.py b/studio/backend/tests/test_desktop_auth.py deleted file mode 100644 index a5508c1c8b..0000000000 --- a/studio/backend/tests/test_desktop_auth.py +++ /dev/null @@ -1,598 +0,0 @@ -import importlib.util -import asyncio -import hashlib -import json -import os -import platform -import secrets -import sqlite3 -import subprocess -import sys -from pathlib import Path -from types import SimpleNamespace - -import jwt -import pytest -from fastapi import APIRouter, FastAPI -from fastapi.security import HTTPAuthorizationCredentials -from fastapi.testclient import TestClient - -from auth import storage - - -@pytest.fixture(autouse = True) -def isolated_auth_db(tmp_path, monkeypatch): - monkeypatch.setattr(storage, "DB_PATH", tmp_path / "auth.db") - monkeypatch.setattr(storage, "_BOOTSTRAP_PW_PATH", tmp_path / ".bootstrap_password") - monkeypatch.setattr(storage, "_bootstrap_password", None) - monkeypatch.setattr(storage, "_api_key_pbkdf2_salt_cache", None) - yield - - -def seed_user(*, must_change_password = False): - storage.create_initial_user( - username = storage.DEFAULT_ADMIN_USERNAME, - password = "human-password-123", - jwt_secret = secrets.token_urlsafe(64), - must_change_password = must_change_password, - ) - - -def auth_client(): - route_path = Path(__file__).resolve().parents[1] / "routes" / "auth.py" - spec = importlib.util.spec_from_file_location("_desktop_auth_route", route_path) - auth_route = importlib.util.module_from_spec(spec) - assert spec.loader is not None - spec.loader.exec_module(auth_route) - - app = FastAPI() - app.include_router(auth_route.router, prefix = "/api/auth") - return TestClient(app) - - -def data_recipe_jobs_module(): - route_path = ( - Path(__file__).resolve().parents[1] / "routes" / "data_recipe" / "jobs.py" - ) - spec = importlib.util.spec_from_file_location( - "_desktop_data_recipe_jobs", route_path - ) - jobs_route = importlib.util.module_from_spec(spec) - assert spec.loader is not None - spec.loader.exec_module(jobs_route) - return jobs_route - - -def local_recipe(): - return { - "model_providers": [{"name": "local", "is_local": True}], - "model_configs": [{"alias": "local-model", "provider": "local"}], - "columns": [{"column_type": "llm-text", "model_alias": "local-model"}], - } - - -def local_recipe_request(token): - return SimpleNamespace( - headers = {"authorization": f"Bearer {token}"}, - app = SimpleNamespace(state = SimpleNamespace(server_port = 8888)), - scope = {}, - base_url = "http://testserver/", - ) - - -@pytest.fixture -def loaded_local_model(monkeypatch): - inference_module = SimpleNamespace( - get_llama_cpp_backend = lambda: SimpleNamespace(is_loaded = True), - ) - monkeypatch.setitem(sys.modules, "routes.inference", inference_module) - - -def test_desktop_secret_round_trip_uses_real_admin_subject(): - seed_user() - raw = storage.create_desktop_secret() - - assert raw.startswith("desktop-") - assert storage.validate_desktop_secret(raw) == storage.DEFAULT_ADMIN_USERNAME - assert storage.validate_desktop_secret(raw + "x") is None - - -def test_create_desktop_secret_rotates_old_secret(): - seed_user() - old = storage.create_desktop_secret() - new = storage.create_desktop_secret() - - assert old != new - assert storage.validate_desktop_secret(old) is None - assert storage.validate_desktop_secret(new) == storage.DEFAULT_ADMIN_USERNAME - - -def test_clear_desktop_secret_invalidates_secret(): - seed_user() - raw = storage.create_desktop_secret() - - storage.clear_desktop_secret() - - assert storage.validate_desktop_secret(raw) is None - - -def test_ensure_default_admin_does_not_recreate_bootstrap_for_existing_admin(): - seed_user() - - created = storage.ensure_default_admin() - - assert created is False - assert not storage._BOOTSTRAP_PW_PATH.exists() - - -def test_ensure_default_admin_loads_existing_bootstrap_after_restart(monkeypatch): - created = storage.ensure_default_admin() - bootstrap_pw = storage._BOOTSTRAP_PW_PATH.read_text().strip() - - monkeypatch.setattr(storage, "_bootstrap_password", None) - created_again = storage.ensure_default_admin() - - assert created is True - assert storage._BOOTSTRAP_PW_PATH.exists() - assert created_again is False - assert storage.get_bootstrap_password() == bootstrap_pw - - -def test_ensure_default_admin_does_not_generate_for_empty_existing_bootstrap(): - seed_user() - storage._BOOTSTRAP_PW_PATH.write_text(" \n") - - created = storage.ensure_default_admin() - - assert created is False - assert storage._BOOTSTRAP_PW_PATH.read_text() == " \n" - assert storage.get_bootstrap_password() is None - - -def test_web_login_token_has_no_desktop_marker_and_keeps_password_gate(): - seed_user(must_change_password = True) - client = auth_client() - - response = client.post( - "/api/auth/login", - json = { - "username": storage.DEFAULT_ADMIN_USERNAME, - "password": "human-password-123", - }, - ) - - assert response.status_code == 200 - body = response.json() - assert body["must_change_password"] is True - payload = jwt.decode( - body["access_token"], - storage.get_jwt_secret(storage.DEFAULT_ADMIN_USERNAME), - algorithms = ["HS256"], - ) - assert payload["sub"] == storage.DEFAULT_ADMIN_USERNAME - assert "desktop" not in payload - - gated = client.post( - "/api/auth/api-keys", - headers = {"Authorization": f"Bearer {body['access_token']}"}, - json = {"name": "web"}, - ) - assert gated.status_code == 403 - - -def test_desktop_login_mints_admin_token_without_clearing_web_password_change(): - seed_user(must_change_password = True) - raw = storage.create_desktop_secret() - client = auth_client() - - response = client.post("/api/auth/desktop-login", json = {"secret": raw}) - - assert response.status_code == 200 - body = response.json() - assert body["access_token"] - assert body["refresh_token"] - assert body["token_type"] == "bearer" - assert body["must_change_password"] is False - assert storage.requires_password_change(storage.DEFAULT_ADMIN_USERNAME) is True - - payload = jwt.decode( - body["access_token"], - storage.get_jwt_secret(storage.DEFAULT_ADMIN_USERNAME), - algorithms = ["HS256"], - ) - assert payload["sub"] == storage.DEFAULT_ADMIN_USERNAME - assert payload["desktop"] is True - - -def test_desktop_refresh_preserves_desktop_marker(): - seed_user(must_change_password = True) - raw = storage.create_desktop_secret() - client = auth_client() - login_body = client.post("/api/auth/desktop-login", json = {"secret": raw}).json() - - response = client.post( - "/api/auth/refresh", - json = {"refresh_token": login_body["refresh_token"]}, - ) - - assert response.status_code == 200 - body = response.json() - assert body["must_change_password"] is False - payload = jwt.decode( - body["access_token"], - storage.get_jwt_secret(storage.DEFAULT_ADMIN_USERNAME), - algorithms = ["HS256"], - ) - assert payload["sub"] == storage.DEFAULT_ADMIN_USERNAME - assert payload["desktop"] is True - - -def test_desktop_session_uses_real_admin_identity_for_api_keys(): - seed_user(must_change_password = True) - raw = storage.create_desktop_secret() - client = auth_client() - token = client.post("/api/auth/desktop-login", json = {"secret": raw}).json()[ - "access_token" - ] - - response = client.post( - "/api/auth/api-keys", - headers = {"Authorization": f"Bearer {token}"}, - json = {"name": "desktop"}, - ) - - assert response.status_code == 200 - rows = storage.list_api_keys(storage.DEFAULT_ADMIN_USERNAME) - assert [row["name"] for row in rows] == ["desktop"] - - -def test_local_recipe_token_authenticates_as_admin_for_desktop_user(loaded_local_model): - # _inject_local_providers mints an internal sk-unsloth-* API key (not a - # forwarded JWT). The unified API-key path validates as the real admin - # user regardless of whether the incoming session was desktop or web. - from auth.authentication import create_access_token, get_current_subject - - seed_user(must_change_password = True) - jobs_route = data_recipe_jobs_module() - incoming_token = create_access_token( - subject = storage.DEFAULT_ADMIN_USERNAME, - desktop = True, - ) - recipe = local_recipe() - - jobs_route._inject_local_providers(recipe, local_recipe_request(incoming_token)) - - local_token = recipe["model_providers"][0]["api_key"] - assert local_token.startswith(storage.API_KEY_PREFIX) - credentials = HTTPAuthorizationCredentials( - scheme = "Bearer", - credentials = local_token, - ) - assert ( - asyncio.run(get_current_subject(credentials)) == storage.DEFAULT_ADMIN_USERNAME - ) - - -def test_local_recipe_token_authenticates_as_admin_for_web_user(loaded_local_model): - # Mirror of the desktop variant: API-key issuance is identical for web - # and desktop incoming tokens; auth via get_current_subject works the same. - from auth.authentication import create_access_token, get_current_subject - - seed_user(must_change_password = False) - jobs_route = data_recipe_jobs_module() - incoming_token = create_access_token(subject = storage.DEFAULT_ADMIN_USERNAME) - recipe = local_recipe() - - jobs_route._inject_local_providers(recipe, local_recipe_request(incoming_token)) - - local_token = recipe["model_providers"][0]["api_key"] - assert local_token.startswith(storage.API_KEY_PREFIX) - credentials = HTTPAuthorizationCredentials( - scheme = "Bearer", - credentials = local_token, - ) - assert ( - asyncio.run(get_current_subject(credentials)) == storage.DEFAULT_ADMIN_USERNAME - ) - - -def test_desktop_login_rejects_invalid_secret(): - seed_user(must_change_password = False) - client = auth_client() - - response = client.post( - "/api/auth/desktop-login", - json = {"secret": "desktop-invalid"}, - ) - - assert response.status_code == 401 - - -def test_write_desktop_secret_file_is_0600_on_unix(tmp_path): - from unsloth_cli.commands import studio as studio_cli - - path = tmp_path / ".desktop_secret" - if platform.system() != "Windows": - path.write_text("old-secret") - os.chmod(path, 0o644) - - studio_cli._write_auth_secret(path, "desktop-secret") - - assert path.read_text() == "desktop-secret" - if platform.system() != "Windows": - assert oct(path.stat().st_mode & 0o777) == "0o600" - - -def test_reset_password_removes_desktop_secret_files(tmp_path, monkeypatch): - from typer.testing import CliRunner - from unsloth_cli.commands import studio as studio_cli - - auth_dir = tmp_path / "auth" - auth_dir.mkdir() - (auth_dir / "auth.db").write_text("db") - (auth_dir / ".bootstrap_password").write_text("boot") - (auth_dir / ".desktop_secret").write_text("new") - monkeypatch.setattr(studio_cli, "STUDIO_HOME", tmp_path) - - result = CliRunner().invoke(studio_cli.studio_app, ["reset-password"]) - - assert result.exit_code == 0 - assert not (auth_dir / "auth.db").exists() - assert not (auth_dir / ".bootstrap_password").exists() - assert not (auth_dir / ".desktop_secret").exists() - - -def test_reset_password_removes_desktop_secret_files_without_db(tmp_path, monkeypatch): - from typer.testing import CliRunner - from unsloth_cli.commands import studio as studio_cli - - auth_dir = tmp_path / "auth" - auth_dir.mkdir() - (auth_dir / ".desktop_secret").write_text("new") - monkeypatch.setattr(studio_cli, "STUDIO_HOME", tmp_path) - - result = CliRunner().invoke(studio_cli.studio_app, ["reset-password"]) - - assert result.exit_code == 0 - assert not (auth_dir / ".desktop_secret").exists() - - -def test_desktop_capabilities_json_reports_rollout_safe_flags(): - from typer.testing import CliRunner - import unsloth_cli.commands.studio as studio_cli - - result = CliRunner().invoke( - studio_cli.studio_app, - ["desktop-capabilities", "--json"], - ) - - assert result.exit_code == 0 - body = json.loads(result.output) - assert body["desktop_protocol_version"] == 1 - assert body["supports_provision_desktop_auth"] is True - assert body["supports_api_only"] is True - assert isinstance(body["version"], str) - - -def test_health_response_reports_desktop_capability_fields(monkeypatch): - router_stub = SimpleNamespace( - auth_router = APIRouter(), - data_recipe_router = APIRouter(), - datasets_router = APIRouter(), - export_router = APIRouter(), - inference_router = APIRouter(), - inference_studio_router = APIRouter(), - models_router = APIRouter(), - training_history_router = APIRouter(), - training_router = APIRouter(), - ) - monkeypatch.setitem(sys.modules, "routes", router_stub) - - import studio.backend.main as backend_main - - monkeypatch.setattr(backend_main._hw_module, "CHAT_ONLY", False) - - body = asyncio.run(backend_main.health_check()) - - assert body["desktop_protocol_version"] == 1 - assert body["supports_desktop_auth"] is True - - -def test_provision_desktop_auth_writes_secret_and_creates_db_without_backend_deps( - tmp_path, - monkeypatch, -): - auth_dir = tmp_path / "auth" - auth_dir.mkdir() - - code = """ -import builtins -import sys -from pathlib import Path -from typer.testing import CliRunner - -studio_home = Path(sys.argv[1]) -real_import = builtins.__import__ - -def guarded_import(name, *args, **kwargs): - blocked = ("auth", "fastapi", "structlog", "utils") - if name in blocked or name.startswith(("auth.", "utils.")): - raise ModuleNotFoundError(name) - return real_import(name, *args, **kwargs) - -builtins.__import__ = guarded_import -from unsloth_cli.commands import studio as studio_cli - -studio_cli.STUDIO_HOME = studio_home -result = CliRunner().invoke(studio_cli.studio_app, ["provision-desktop-auth"]) -if result.exit_code != 0: - print(result.output) - if result.exception is not None: - raise result.exception - raise SystemExit(result.exit_code) -""" - result = subprocess.run( - [sys.executable, "-c", code, str(tmp_path)], - cwd = Path(__file__).resolve().parents[3], - env = {**os.environ, "PYTHONPATH": "."}, - text = True, - capture_output = True, - ) - assert result.returncode == 0, result.stderr + result.stdout - secret = (auth_dir / ".desktop_secret").read_text() - assert secret.startswith("desktop-") - - conn = sqlite3.connect(auth_dir / "auth.db") - conn.row_factory = sqlite3.Row - try: - user = conn.execute( - """ - SELECT username, password_salt, password_hash, must_change_password - FROM auth_user - """ - ).fetchone() - app_secrets = { - row["key"]: row["value"] - for row in conn.execute("SELECT key, value FROM app_secrets") - } - refresh_columns = { - row["name"] for row in conn.execute("PRAGMA table_info(refresh_tokens)") - } - finally: - conn.close() - - bootstrap_password = (auth_dir / ".bootstrap_password").read_text().strip() - bootstrap_hash = hashlib.pbkdf2_hmac( - "sha256", - bootstrap_password.encode("utf-8"), - user["password_salt"].encode("utf-8"), - 100_000, - ).hex() - - assert bootstrap_password - assert user["username"] == "unsloth" - assert user["must_change_password"] == 1 - assert bootstrap_hash == user["password_hash"] - assert len(app_secrets["api_key_pbkdf2_salt"]) == 64 - assert len(app_secrets["desktop_secret_hash"]) == 64 - assert app_secrets["desktop_secret_created_at"] - assert "is_desktop" in refresh_columns - - monkeypatch.setattr(storage, "DB_PATH", auth_dir / "auth.db") - monkeypatch.setattr(storage, "_api_key_pbkdf2_salt_cache", None) - assert storage.validate_desktop_secret(secret) == storage.DEFAULT_ADMIN_USERNAME - assert storage.requires_password_change(storage.DEFAULT_ADMIN_USERNAME) is True - - -def test_provision_desktop_auth_keeps_existing_admin_password(tmp_path, monkeypatch): - from typer.testing import CliRunner - from unsloth_cli.commands import studio as studio_cli - - auth_dir = tmp_path / "auth" - auth_dir.mkdir() - monkeypatch.setattr(studio_cli, "STUDIO_HOME", tmp_path) - - conn = sqlite3.connect(auth_dir / "auth.db") - try: - conn.execute( - """ - CREATE TABLE auth_user ( - id INTEGER PRIMARY KEY, - username TEXT UNIQUE NOT NULL, - password_salt TEXT NOT NULL, - password_hash TEXT NOT NULL, - jwt_secret TEXT NOT NULL, - must_change_password INTEGER NOT NULL DEFAULT 0 - ) - """ - ) - conn.execute( - """ - INSERT INTO auth_user ( - username, password_salt, password_hash, jwt_secret, must_change_password - ) - VALUES (?, ?, ?, ?, ?) - """, - ("unsloth", "existing-salt", "existing-hash", "existing-jwt", 0), - ) - conn.commit() - finally: - conn.close() - - result = CliRunner().invoke(studio_cli.studio_app, ["provision-desktop-auth"]) - - assert result.exit_code == 0 - assert not (auth_dir / ".bootstrap_password").exists() - conn = sqlite3.connect(auth_dir / "auth.db") - conn.row_factory = sqlite3.Row - try: - user = conn.execute( - """ - SELECT password_salt, password_hash, jwt_secret, must_change_password - FROM auth_user WHERE username = ? - """, - ("unsloth",), - ).fetchone() - finally: - conn.close() - - assert dict(user) == { - "password_salt": "existing-salt", - "password_hash": "existing-hash", - "jwt_secret": "existing-jwt", - "must_change_password": 0, - } - - -def test_update_password_clears_desktop_secret(): - seed_user() - raw = storage.create_desktop_secret() - assert storage.validate_desktop_secret(raw) == storage.DEFAULT_ADMIN_USERNAME - - changed = storage.update_password( - storage.DEFAULT_ADMIN_USERNAME, "new-admin-password" - ) - assert changed is True - assert storage.validate_desktop_secret(raw) is None - - -def test_update_password_on_unknown_user_leaves_desktop_secret_intact(): - seed_user() - raw = storage.create_desktop_secret() - - changed = storage.update_password("not-a-user", "irrelevant") - assert changed is False - assert storage.validate_desktop_secret(raw) == storage.DEFAULT_ADMIN_USERNAME - - -def test_desktop_auth_provision_has_bounded_timeout(): - rs_path = ( - Path(__file__).resolve().parents[3] - / "studio" - / "src-tauri" - / "src" - / "desktop_auth.rs" - ) - src = rs_path.read_text() - start = src.index("async fn provision_desktop_auth(") - depth = 0 - body_start = src.index("{", start) - body_end = None - for i in range(body_start, len(src)): - c = src[i] - if c == "{": - depth += 1 - elif c == "}": - depth -= 1 - if depth == 0: - body_end = i + 1 - break - assert body_end is not None - body = src[start:body_end] - assert "tokio::time::timeout" in body - import re - - m = re.search(r"Duration::from_secs\(\s*(\d+)\s*\)", body) - assert m is not None - seconds = int(m.group(1)) - assert 5 <= seconds <= 120 diff --git a/studio/backend/tests/test_export_log_cursor.py b/studio/backend/tests/test_export_log_cursor.py deleted file mode 100644 index 734ca522c9..0000000000 --- a/studio/backend/tests/test_export_log_cursor.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -""" -Regression tests for the export log ring-buffer cursor semantics. - -Context: the live export log SSE stream has a race where the frontend -opens the SSE connection AFTER the POST that starts the export. Any -lines the worker subprocess emits during the gap between POST and SSE -connect get buffered with seqs 1..k, and then the SSE default cursor -`get_current_log_seq()` returns k -- so lines 1..k are forever -unreachable to that client. - -Fix: `clear_logs()` snapshots the pre-run seq into `_run_start_seq` -(exposed via `get_run_start_seq()`), and `routes/export.py` defaults -the SSE cursor to that snapshot instead of the current seq. Every line -appended during the current run has seq strictly greater than the -snapshot, so the client sees the full run regardless of when it -connects. - -These tests exercise the orchestrator-side contract only (no -subprocess, no FastAPI, no frontend). The routes-level integration -with get_run_start_seq() is a one-line edit covered by manual testing -and the frontend build. -""" - -from __future__ import annotations - -import sys -import types -from pathlib import Path - -import pytest - - -# Backend root on sys.path so `from core.export.orchestrator import ...` -# and friends resolve without the studio app bootstrap. -_BACKEND_DIR = Path(__file__).resolve().parent.parent -if str(_BACKEND_DIR) not in sys.path: - sys.path.insert(0, str(_BACKEND_DIR)) - -# ExportOrchestrator imports structlog and a few heavy modules at the -# top of orchestrator.py. Stub the ones we don't need in these unit -# tests so the import succeeds on machines without the full studio -# venv. -_loggers_stub = types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -# structlog is only used for a module-level import; a bare stub is -# enough because we never call into it in these tests. -sys.modules.setdefault("structlog", types.ModuleType("structlog")) - -# utils.paths.outputs_root is only called inside scan_checkpoints which -# we don't hit in these tests. Provide a stub module so the top-level -# import in orchestrator.py resolves. -_utils_pkg = types.ModuleType("utils") -_utils_pkg.__path__ = [] # mark as package -_utils_paths_stub = types.ModuleType("utils.paths") -_utils_paths_stub.outputs_root = lambda: Path("/tmp") -sys.modules.setdefault("utils", _utils_pkg) -sys.modules.setdefault("utils.paths", _utils_paths_stub) - - -@pytest.fixture -def orchestrator(): - """Fresh ExportOrchestrator with only the log-buffer state exercised.""" - from core.export.orchestrator import ExportOrchestrator - - return ExportOrchestrator() - - -def _append(orch, line: str, stream: str = "stdout") -> None: - """Shortcut for simulating a worker log message.""" - orch._append_log({"type": "log", "stream": stream, "line": line, "ts": 0.0}) - - -# --------------------------------------------------------------------------- -# clear_logs() semantics -# --------------------------------------------------------------------------- - - -def test_run_start_seq_is_zero_before_any_logs(orchestrator) -> None: - """A brand-new orchestrator must report run_start_seq == 0 so a - first SSE connection picks up every line from seq 1 onward.""" - assert orchestrator.get_run_start_seq() == 0 - - -def test_clear_logs_snapshots_current_seq(orchestrator) -> None: - """clear_logs() must capture _log_seq BEFORE clearing the buffer, - so subsequent runs can anchor their SSE cursor at the snapshot.""" - _append(orchestrator, "old run line 1") - _append(orchestrator, "old run line 2") - _append(orchestrator, "old run line 3") - assert orchestrator.get_current_log_seq() == 3 - - orchestrator.clear_logs() - - assert orchestrator.get_run_start_seq() == 3 - assert orchestrator.get_current_log_seq() == 3 # seq counter preserved - - -# --------------------------------------------------------------------------- -# Race regression: SSE connects AFTER lines have been emitted -# --------------------------------------------------------------------------- - - -def test_sse_default_cursor_catches_all_current_run_lines(orchestrator) -> None: - """Simulate the POST-then-SSE race: worker starts emitting lines - immediately after clear_logs(), SSE connects several lines later. - Using get_run_start_seq() as the default cursor MUST return every - line emitted since clear_logs() ran. - - Pre-fix, the SSE defaulted to get_current_log_seq() at connect - time, which would return the last-seen seq and miss lines N+1..M. - """ - # Previous run leaves some buffered lines. - _append(orchestrator, "previous run line A") - _append(orchestrator, "previous run line B") - - # New run starts: orchestrator clears the buffer and snapshots seq. - orchestrator.clear_logs() - run_start = orchestrator.get_run_start_seq() - - # Worker emits early lines BEFORE the SSE connects. - _append(orchestrator, "Importing Unsloth...") - _append(orchestrator, "Loading checkpoint: /foo/bar") - _append(orchestrator, "Starting export...") - - # SSE connects now and asks "give me everything after the run - # start cursor". - entries, new_cursor = orchestrator.get_logs_since(run_start) - - # All three early lines must be present. Pre-fix this was []. - lines = [e["line"] for e in entries] - assert lines == [ - "Importing Unsloth...", - "Loading checkpoint: /foo/bar", - "Starting export...", - ] - assert new_cursor == entries[-1]["seq"] - - -def test_sse_default_cursor_excludes_previous_run(orchestrator) -> None: - """After clear_logs(), lines from the PREVIOUS run must not leak - into the new run's SSE stream. Pre-fix this worked correctly - (clear_logs cleared the deque); the fix must preserve it. - """ - _append(orchestrator, "previous run line 1") - _append(orchestrator, "previous run line 2") - _append(orchestrator, "previous run line 3") - assert orchestrator.get_current_log_seq() == 3 - - orchestrator.clear_logs() - run_start = orchestrator.get_run_start_seq() - - _append(orchestrator, "new run line") - - entries, _ = orchestrator.get_logs_since(run_start) - assert [e["line"] for e in entries] == ["new run line"] - - -def test_clear_logs_twice_advances_run_start(orchestrator) -> None: - """Back-to-back clear_logs() calls (e.g. cleanup -> load -> - export in the same dialog session) must each re-anchor run_start - at the current seq, so successive runs each start with a fresh - low-water mark.""" - _append(orchestrator, "run 1 line a") - _append(orchestrator, "run 1 line b") - - orchestrator.clear_logs() - assert orchestrator.get_run_start_seq() == 2 - - _append(orchestrator, "run 2 line a") - _append(orchestrator, "run 2 line b") - _append(orchestrator, "run 2 line c") - - orchestrator.clear_logs() - assert orchestrator.get_run_start_seq() == 5 diff --git a/studio/backend/tests/test_gpu_selection.py b/studio/backend/tests/test_gpu_selection.py index a1fe5653ef..c6f26037af 100644 --- a/studio/backend/tests/test_gpu_selection.py +++ b/studio/backend/tests/test_gpu_selection.py @@ -746,15 +746,7 @@ def test_inference_route_rejects_gpu_ids_for_gguf(self): ): with self.assertRaises(HTTPException) as exc_info: asyncio.run( - inference_route.load_model( - request, - SimpleNamespace( - app = SimpleNamespace( - state = SimpleNamespace(llama_parallel_slots = 1), - ), - ), - current_subject = "test-user", - ) + inference_route.load_model(request, current_subject = "test-user") ) self.assertEqual(exc_info.exception.status_code, 400) @@ -894,15 +886,7 @@ def load_model(self, **kwargs): ): with self.assertRaises(HTTPException) as exc_info: asyncio.run( - inference_route.load_model( - request, - SimpleNamespace( - app = SimpleNamespace( - state = SimpleNamespace(llama_parallel_slots = 1), - ), - ), - current_subject = "test-user", - ) + inference_route.load_model(request, current_subject = "test-user") ) self.assertEqual(exc_info.exception.status_code, 400) @@ -958,15 +942,7 @@ def load_model(self, **kwargs): ): with self.assertRaises(HTTPException) as exc_info: asyncio.run( - inference_route.load_model( - request, - SimpleNamespace( - app = SimpleNamespace( - state = SimpleNamespace(llama_parallel_slots = 1), - ), - ), - current_subject = "test-user", - ) + inference_route.load_model(request, current_subject = "test-user") ) self.assertEqual(exc_info.exception.status_code, 400) @@ -1049,182 +1025,6 @@ def test_total_equals_min_gpu_vram_1(self): class TestPerGpuFitGuardAllCounts(unittest.TestCase): - def test_training_estimate_resolves_attention_without_raising(self): - with ( - patch("utils.hardware.hardware.get_device", return_value = DeviceType.CUDA), - patch( - "utils.hardware.hardware.estimate_fp16_model_size_bytes", - return_value = (8 * (1024**3), "config"), - ), - patch( - "utils.hardware.hardware._resolve_model_identifier_for_gpu_estimate", - return_value = "unsloth/test", - ), - patch( - "utils.hardware.hardware._load_config_for_gpu_estimate", - return_value = SimpleNamespace( - hidden_size = 4096, - num_hidden_layers = 32, - num_attention_heads = 32, - num_key_value_heads = 8, - intermediate_size = 14336, - vocab_size = 128256, - tie_word_embeddings = False, - ), - ), - patch( - "utils.hardware.hardware._determine_attention_impl_for_gpu_estimate", - return_value = "eager", - ), - patch("utils.hardware.hardware.get_visible_gpu_count", return_value = 1), - ): - _, metadata = estimate_required_model_memory_gb( - "unsloth/test", - training_type = "LoRA/QLoRA", - load_in_4bit = True, - ) - - self.assertEqual(metadata.get("estimation_mode"), "detailed") - self.assertEqual(metadata.get("attention_implementation"), "eager") - - def test_training_estimate_falls_back_when_attention_resolution_fails(self): - with ( - patch("utils.hardware.hardware.get_device", return_value = DeviceType.CUDA), - patch( - "utils.hardware.hardware.estimate_fp16_model_size_bytes", - return_value = (8 * (1024**3), "config"), - ), - patch( - "utils.hardware.hardware._resolve_model_identifier_for_gpu_estimate", - return_value = "unsloth/test", - ), - patch( - "utils.hardware.hardware._load_config_for_gpu_estimate", - return_value = SimpleNamespace( - hidden_size = 4096, - num_hidden_layers = 32, - num_attention_heads = 32, - num_key_value_heads = 8, - intermediate_size = 14336, - vocab_size = 128256, - tie_word_embeddings = False, - ), - ), - patch( - "utils.hardware.hardware._determine_attention_impl_for_gpu_estimate", - side_effect = RuntimeError("attention unavailable"), - ), - patch("utils.hardware.hardware.get_visible_gpu_count", return_value = 1), - ): - _, metadata = estimate_required_model_memory_gb( - "unsloth/test", - training_type = "LoRA/QLoRA", - load_in_4bit = True, - ) - - self.assertEqual(metadata.get("estimation_mode"), "detailed") - self.assertEqual( - metadata.get("attention_implementation"), - "eager", - ) - - def test_attention_resolver_does_not_mutate_loaded_config(self): - from utils.hardware import hardware as hardware_module - - config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 2, - num_attention_heads = 8, - num_key_value_heads = 8, - intermediate_size = 2048, - vocab_size = 1024, - tie_word_embeddings = True, - ) - - def _stub_resolver(model_class, cfg): - cfg._attn_implementation = "eager" - return "eager" - - with patch( - "unsloth.models._utils.resolve_attention_implementation", - side_effect = _stub_resolver, - ): - hardware_module._determine_attention_impl_for_gpu_estimate(config) - - self.assertFalse(hasattr(config, "_attn_implementation")) - - def test_attention_resolver_handles_missing_model_mapping(self): - from utils.hardware import hardware as hardware_module - - config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 2, - num_attention_heads = 8, - num_key_value_heads = 8, - intermediate_size = 2048, - vocab_size = 1024, - tie_word_embeddings = True, - ) - captured = {} - - def _stub_resolver(model_class, cfg): - captured["model_class"] = model_class - return "eager" - - from transformers import AutoModel, AutoModelForCausalLM - - with ( - patch.object(AutoModelForCausalLM, "_model_mapping", new = None), - patch.object(AutoModel, "_model_mapping", new = None), - patch( - "unsloth.models._utils.resolve_attention_implementation", - side_effect = _stub_resolver, - ), - ): - result = hardware_module._determine_attention_impl_for_gpu_estimate(config) - - self.assertEqual(result, "eager") - self.assertIsNone(captured["model_class"]) - - def test_attention_resolver_does_not_mutate_nested_text_config(self): - from utils.hardware import hardware as hardware_module - - text_config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 2, - num_attention_heads = 8, - num_key_value_heads = 8, - intermediate_size = 2048, - vocab_size = 1024, - tie_word_embeddings = True, - ) - config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 2, - num_attention_heads = 8, - num_key_value_heads = 8, - intermediate_size = 2048, - vocab_size = 1024, - tie_word_embeddings = True, - text_config = text_config, - ) - - def _stub_resolver(model_class, cfg): - cfg._attn_implementation = "eager" - inner = getattr(cfg, "text_config", None) - if inner is not None: - inner._attn_implementation = "eager" - return "eager" - - with patch( - "unsloth.models._utils.resolve_attention_implementation", - side_effect = _stub_resolver, - ): - hardware_module._determine_attention_impl_for_gpu_estimate(config) - - self.assertFalse(hasattr(config, "_attn_implementation")) - self.assertFalse(hasattr(text_config, "_attn_implementation")) - def test_min_per_gpu_generated_for_all_visible_counts(self): with ( patch("utils.hardware.hardware.get_device", return_value = DeviceType.CUDA), @@ -1301,123 +1101,3 @@ def test_prepare_gpu_selection_rejects_explicit_ids_on_xpu(self): with patch("utils.hardware.hardware.get_device", return_value = DeviceType.XPU): with self.assertRaisesRegex(ValueError, "only supported on CUDA"): prepare_gpu_selection([0], model_name = "unsloth/test") - - -class TestEstimateFp16ModelSizeBytesPrefersLocalWeights(unittest.TestCase): - def _run( - self, - model_path, - *, - config_bytes, - local_bytes, - safetensors_params = None, - config = object(), - ): - from utils.hardware import hardware as hardware_module - - with ( - patch.object( - hardware_module, - "_resolve_model_identifier_for_gpu_estimate", - return_value = model_path, - ), - patch.object( - hardware_module, - "_get_hf_safetensors_total_params", - return_value = safetensors_params, - ), - patch.object( - hardware_module, - "_load_config_for_gpu_estimate", - return_value = config, - ), - patch.object( - hardware_module, - "_estimate_fp16_model_size_bytes_from_config", - return_value = config_bytes, - ), - patch.object( - hardware_module, - "_get_local_weight_size_bytes", - return_value = local_bytes, - ), - ): - return hardware_module.estimate_fp16_model_size_bytes(model_path) - - def test_local_weight_bytes_preferred_when_larger_than_config(self): - bytes_, src = self._run( - "/local/vlm", - config_bytes = 2 * (1 << 30), - local_bytes = 20 * (1 << 30), - ) - self.assertEqual(bytes_, 20 * (1 << 30)) - self.assertEqual(src, "weight_bytes") - - def test_config_bytes_preferred_when_larger_than_local(self): - bytes_, src = self._run( - "/local/text-only", - config_bytes = 20 * (1 << 30), - local_bytes = 2 * (1 << 30), - ) - self.assertEqual(bytes_, 20 * (1 << 30)) - self.assertEqual(src, "config") - - def test_config_bytes_returned_when_no_local_weights(self): - bytes_, src = self._run( - "/local/no-weights", - config_bytes = 5 * (1 << 30), - local_bytes = None, - ) - self.assertEqual(bytes_, 5 * (1 << 30)) - self.assertEqual(src, "config") - - def test_local_bytes_returned_when_config_resolution_fails(self): - bytes_, src = self._run( - "/local/no-config", - config_bytes = None, - local_bytes = 7 * (1 << 30), - config = None, - ) - self.assertEqual(bytes_, 7 * (1 << 30)) - self.assertEqual(src, "weight_bytes") - - def test_equal_local_and_config_keeps_config_label(self): - # why: tie-breaker is "local must be strictly larger" so an exact - # match keeps the config-derived path. - same = 8 * (1 << 30) - bytes_, src = self._run( - "/local/equal", - config_bytes = same, - local_bytes = same, - ) - self.assertEqual(bytes_, same) - self.assertEqual(src, "config") - - def test_remote_safetensors_path_unaffected_by_local_weights(self): - from utils.hardware import hardware as hardware_module - - with ( - patch.object( - hardware_module, - "_resolve_model_identifier_for_gpu_estimate", - return_value = "owner/repo", - ), - patch.object( - hardware_module, - "_get_hf_safetensors_total_params", - return_value = 1_000_000_000, - ), - patch.object( - hardware_module, - "_load_config_for_gpu_estimate", - ) as mock_load, - patch.object( - hardware_module, - "_get_local_weight_size_bytes", - ) as mock_local, - ): - bytes_, src = hardware_module.estimate_fp16_model_size_bytes("owner/repo") - self.assertEqual(bytes_, 2 * 1_000_000_000) - self.assertEqual(src, "safetensors") - mock_load.assert_not_called() - mock_local.assert_not_called() diff --git a/studio/backend/tests/test_host_defaults.py b/studio/backend/tests/test_host_defaults.py deleted file mode 100644 index 8b81474e92..0000000000 --- a/studio/backend/tests/test_host_defaults.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Tests that Unsloth Studio defaults to 127.0.0.1 (loopback) not 0.0.0.0. - -Uses AST parsing to inspect source-level defaults without requiring the -full studio venv (run.py has heavy dependencies like structlog/uvicorn). -""" - -import ast -from pathlib import Path - -_RUN_PY = Path(__file__).resolve().parent.parent / "run.py" - - -def _parse_function_param_defaults(source: str, func_name: str) -> dict: - """Return {param_name: default_value} for a named function in *source*. - - Only handles ast.Constant defaults (strings, ints, bools). - """ - tree = ast.parse(source) - for node in ast.walk(tree): - if ( - isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) - and node.name == func_name - ): - result = {} - all_args = node.args.args - defaults = node.args.defaults - # Defaults are right-aligned against the args list - offset = len(all_args) - len(defaults) - for i, default in enumerate(defaults): - arg_name = all_args[offset + i].arg - if isinstance(default, ast.Constant): - result[arg_name] = default.value - return result - return {} - - -def _parse_argparse_add_argument_default(source: str, option_name: str): - """Return the 'default' kwarg value for add_argument(option_name, ...) in *source*. - - Walks the entire module so the call can live in __main__ or in a helper - function — only handles ast.Constant defaults. - """ - tree = ast.parse(source) - for node in ast.walk(tree): - if not isinstance(node, ast.Call): - continue - func = node.func - if not (isinstance(func, ast.Attribute) and func.attr == "add_argument"): - continue - if not node.args: - continue - first_arg = node.args[0] - if not (isinstance(first_arg, ast.Constant) and first_arg.value == option_name): - continue - for kw in node.keywords: - if kw.arg == "default" and isinstance(kw.value, ast.Constant): - return kw.value.value - return None - - -def test_run_server_default_host_is_loopback(): - """run_server() parameter default for 'host' must be 127.0.0.1, not 0.0.0.0. - - Binding to 0.0.0.0 by default exposes the service on all network - interfaces, contradicting the documented "privacy first / 100% local" - guarantee. Loopback (127.0.0.1) is the least-permissive default; - users who need network access can pass -H 0.0.0.0 explicitly. - """ - source = _RUN_PY.read_text() - defaults = _parse_function_param_defaults(source, "run_server") - assert ( - "host" in defaults - ), "run_server() must have a 'host' parameter with a default" - host_default = defaults["host"] - assert host_default == "127.0.0.1", ( - f"run_server() host default must be '127.0.0.1' (loopback) " - f"but got '{host_default}'. Binding to '{host_default}' by default " - f"exposes the service beyond localhost." - ) - - -def test_argparse_default_host_is_loopback(): - """argparse --host add_argument default must be 127.0.0.1. - - When run.py is invoked directly (python run.py), the argparse default - should match the function default so direct execution is equally safe. - """ - source = _RUN_PY.read_text() - host_default = _parse_argparse_add_argument_default(source, "--host") - assert ( - host_default is not None - ), "Could not find add_argument('--host', ...) in run.py" - assert ( - host_default == "127.0.0.1" - ), f"run.py argparse --host default must be '127.0.0.1', got '{host_default}'" diff --git a/studio/backend/tests/test_kv_cache_estimation.py b/studio/backend/tests/test_kv_cache_estimation.py index 29d87804ff..2640ded90d 100644 --- a/studio/backend/tests/test_kv_cache_estimation.py +++ b/studio/backend/tests/test_kv_cache_estimation.py @@ -12,7 +12,6 @@ """ import io -import json import struct import sys import types as _types @@ -38,43 +37,35 @@ _structlog_stub = _types.ModuleType("structlog") sys.modules.setdefault("structlog", _structlog_stub) -# httpx -- only stub when the real library isn't installed. Stubbing -# unconditionally would shadow ``HTTPError`` / ``Response`` etc. that -# ``huggingface_hub.errors`` imports at module load time, which causes -# the transformers introspection tier to silently return None inside -# the test process. -try: - import httpx as _httpx_real # noqa: F401 -except ImportError: - _httpx_stub = _types.ModuleType("httpx") - for _exc_name in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", - "HTTPError", - "RequestError", - ): - setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) - - class _FakeTimeout: - def __init__(self, *a, **kw): - pass - - _httpx_stub.Timeout = _FakeTimeout - _httpx_stub.Response = type("Response", (), {}) - _httpx_stub.Client = type( - "Client", - (), - { - "__init__": lambda self, **kw: None, - "__enter__": lambda self: self, - "__exit__": lambda self, *a: None, - }, - ) - sys.modules["httpx"] = _httpx_stub +# httpx +_httpx_stub = _types.ModuleType("httpx") +for _exc_name in ( + "ConnectError", + "TimeoutException", + "ReadTimeout", + "ReadError", + "RemoteProtocolError", + "CloseError", +): + setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) + + +class _FakeTimeout: + def __init__(self, *a, **kw): + pass + + +_httpx_stub.Timeout = _FakeTimeout +_httpx_stub.Client = type( + "Client", + (), + { + "__init__": lambda self, **kw: None, + "__enter__": lambda self: self, + "__exit__": lambda self, *a: None, + }, +) +sys.modules.setdefault("httpx", _httpx_stub) from core.inference.llama_cpp import LlamaCppBackend @@ -86,7 +77,8 @@ def __init__(self, *a, **kw): def _make_gguf_bytes(arch: str, kv_pairs: dict) -> bytes: """Build a minimal GGUF v3 binary blob with the given KV metadata. - Supports the scalar and simple array metadata used by the parser. + Only supports UINT32 (type 4), UINT64 (type 10), and STRING (type 8) + values, which is all the metadata parser reads. """ buf = io.BytesIO() # Header: magic, version, tensor_count, kv_count @@ -104,17 +96,6 @@ def _make_gguf_bytes(arch: str, kv_pairs: dict) -> bytes: val_bytes = val.encode("utf-8") buf.write(struct.pack(" bytes: return buf.getvalue() -def _backend_from_gguf( - arch: str, fields: dict, general: dict | None = None -) -> LlamaCppBackend: - """Create a LlamaCppBackend with parsed GGUF metadata from given fields. - - `general` lets a test inject extra `general.*` metadata (used to - verify the dynamic SWA resolver picks up source-repo hints from - GGUFs that ship them). - """ +def _backend_from_gguf(arch: str, fields: dict) -> LlamaCppBackend: + """Create a LlamaCppBackend with parsed GGUF metadata from given fields.""" kv = {"general.architecture": arch} - for k, v in (general or {}).items(): - kv[k] = v for k, v in fields.items(): kv[f"{arch}.{k}"] = v import tempfile, os @@ -161,7 +133,7 @@ def _backend_from_gguf( class TestGGUFParserNewFields: - """Verify that architecture-aware fields are correctly parsed.""" + """Verify that the 8 new architecture-aware fields are correctly parsed.""" @pytest.mark.parametrize( "field,gguf_key,value", @@ -186,189 +158,15 @@ def test_missing_fields_are_none(self): "_kv_key_length", "_kv_value_length", "_sliding_window", - "_sliding_window_pattern", "_full_attention_interval", "_kv_lora_rank", "_key_length_mla", - "_kv_key_length_swa", - "_kv_value_length_swa", "_ssm_inner_size", "_ssm_state_size", ]: assert getattr(b, attr) is None - def test_array_fields_parsed(self): - b = _backend_from_gguf( - "gemma4", - { - "block_count": 6, - "attention.head_count_kv": [8, 8, 8, 8, 8, 2], - "attention.sliding_window_pattern": [ - True, - True, - True, - True, - True, - False, - ], - }, - ) - # Per-layer KV head count is preserved exactly... - assert b._n_kv_heads_by_layer == [8, 8, 8, 8, 8, 2] - # ...and mirrored into the scalar field as a conservative max so - # non-SWA estimator paths and any caller using - # `n_kv = self._n_kv_heads or ...` get a safe upper bound. - assert b._n_kv_heads == 8 - assert b._sliding_window_pattern == [True, True, True, True, True, False] - - -class TestArchSwaPatternDefaults: - """Bootstrap arch table fires when GGUF reports `sliding_window` but - no per-layer pattern (true for every Gemma 2/3/3n/gpt-oss GGUF today).""" - - @pytest.mark.parametrize( - "arch,n_layers,expected_period", - [ - ("gemma2", 26, 2), - ("gemma3", 18, 6), - ("gemma3n", 35, 5), - ("gpt_oss", 24, 2), - ("cohere2", 32, 4), - ], - ) - def test_arch_default_pattern_applied(self, arch, n_layers, expected_period): - b = _backend_from_gguf( - arch, - { - "block_count": n_layers, - "attention.head_count": 4, - "attention.head_count_kv": 1, - "attention.key_length": 256, - "attention.value_length": 256, - "attention.sliding_window": 512, - }, - ) - expected_pattern = [(i + 1) % expected_period != 0 for i in range(n_layers)] - assert ( - b._sliding_window_pattern == expected_pattern - ), f"{arch} should expand to period={expected_period}" - - def test_unknown_arch_no_default(self): - b = _backend_from_gguf( - "totallymadeupv7", - { - "block_count": 24, - "attention.head_count": 4, - "attention.head_count_kv": 1, - "attention.key_length": 128, - "attention.value_length": 128, - "attention.sliding_window": 1024, - }, - ) - assert b._sliding_window_pattern is None - - def test_explicit_pattern_overrides_arch_default(self): - # Period=6 is the gemma3 default; the explicit array must win. - b = _backend_from_gguf( - "gemma3", - { - "block_count": 6, - "attention.head_count": 4, - "attention.head_count_kv": 1, - "attention.key_length": 256, - "attention.value_length": 256, - "attention.sliding_window": 512, - "attention.sliding_window_pattern": [ - True, - False, - True, - False, - True, - False, - ], - }, - ) - assert b._sliding_window_pattern == [True, False, True, False, True, False] - - def test_no_sliding_window_no_pattern(self): - b = _backend_from_gguf( - "gemma3", - { - "block_count": 18, - "attention.head_count": 4, - "attention.head_count_kv": 1, - "attention.key_length": 256, - "attention.value_length": 256, - # no sliding_window key - }, - ) - assert b._sliding_window_pattern is None - - @pytest.mark.parametrize( - "arch", ["llama", "qwen2", "qwen3", "mistral", "mistral3", "glm4", "llama4"] - ) - def test_non_swa_arch_uses_full_attention_path(self, arch): - # Pure-GQA arches: GGUF has no sliding_window, no synthetic - # pattern, estimator hits Path 4. - b = _backend_from_gguf( - arch, - { - "block_count": 32, - "attention.head_count": 32, - "attention.head_count_kv": 8, - "attention.key_length": 128, - "attention.value_length": 128, - "embedding_length": 4096, - }, - ) - assert b._sliding_window_pattern is None - assert b._sliding_window is None - kv = b._estimate_kv_cache_bytes(8192, "f16") - gqa_expected = 32 * 8192 * 8 * (128 + 128) * 2 - assert kv == gqa_expected - - def test_arch_default_reduces_kv_estimate_vs_legacy(self): - common = { - "block_count": 62, - "attention.head_count": 32, - "attention.head_count_kv": 16, - "attention.key_length": 128, - "attention.value_length": 128, - "attention.sliding_window": 1024, - "embedding_length": 5376, - } - with_default = _backend_from_gguf("gemma3", common) - # Arch not in the table -> legacy 1/4 path. - without_default = _backend_from_gguf("totallymadeupv7", common) - - kv_default = with_default._estimate_kv_cache_bytes(131072, "f16") - kv_legacy = without_default._estimate_kv_cache_bytes(131072, "f16") - assert kv_default > 0 - assert kv_legacy > 0 - assert kv_default < kv_legacy, ( - f"arch fallback should under-shoot legacy estimate: " - f"{kv_default} >= {kv_legacy}" - ) - - def test_scalar_sliding_window_pattern_expanded(self): - block_count = 8 - b = _backend_from_gguf( - "gemma3", - { - "attention.sliding_window_pattern": 4, - "block_count": block_count, - "attention.head_count_kv": 4, - "attention.key_length": 256, - "attention.value_length": 256, - "attention.sliding_window": 1024, - }, - ) - expected = [(i + 1) % 4 != 0 for i in range(block_count)] - assert isinstance(b._sliding_window_pattern, list) - assert b._sliding_window_pattern == expected - assert b._estimate_kv_cache_bytes(4096, "f16") > 0 - - def test_all_fields_parsed_together(self): + def test_all_13_fields_parsed_together(self): fields = { "context_length": 131072, "block_count": 62, @@ -378,12 +176,9 @@ def test_all_fields_parsed_together(self): "attention.key_length": 128, "attention.value_length": 128, "attention.sliding_window": 1024, - "attention.sliding_window_pattern": [True, False], "full_attention_interval": 6, "attention.kv_lora_rank": 512, "attention.key_length_mla": 256, - "attention.key_length_swa": 64, - "attention.value_length_swa": 64, "ssm.inner_size": 4096, "ssm.state_size": 128, } @@ -396,294 +191,13 @@ def test_all_fields_parsed_together(self): assert b._kv_key_length == 128 assert b._kv_value_length == 128 assert b._sliding_window == 1024 - assert b._sliding_window_pattern == [True, False] assert b._full_attention_interval == 6 assert b._kv_lora_rank == 512 assert b._key_length_mla == 256 - assert b._kv_key_length_swa == 64 - assert b._kv_value_length_swa == 64 assert b._ssm_inner_size == 4096 assert b._ssm_state_size == 128 -_SWA_FIELDS = { - "block_count": 12, - "attention.head_count": 4, - "attention.head_count_kv": 1, - "attention.key_length": 256, - "attention.value_length": 256, - "attention.sliding_window": 512, -} - - -class TestDynamicSwaResolver: - """4-tier resolver: GGUF metadata, on-disk cache, bootstrap, HF fetch.""" - - def _isolate_cache(self, monkeypatch, tmp_path): - from core.inference import llama_cpp as lc - - monkeypatch.setenv("UNSLOTH_STUDIO_HOME", str(tmp_path)) - monkeypatch.setattr(lc, "_SWA_CACHE", None) - return tmp_path - - def test_period_from_layer_types_finds_smallest_period(self): - from core.inference.llama_cpp import _period_from_layer_types - - # gemma3 (1 global per 6), gpt-oss (alternating), gemma3n (1 per 5). - assert ( - _period_from_layer_types( - (["sliding_attention"] * 5 + ["full_attention"]) * 4 - ) - == 6 - ) - assert ( - _period_from_layer_types(["sliding_attention", "full_attention"] * 12) == 2 - ) - assert ( - _period_from_layer_types( - (["sliding_attention"] * 4 + ["full_attention"]) * 7 - ) - == 5 - ) - - def test_period_from_layer_types_returns_none_for_aperiodic(self): - from core.inference.llama_cpp import _period_from_layer_types - - lt = [ - "sliding_attention", - "full_attention", - "sliding_attention", - "sliding_attention", - "full_attention", - "sliding_attention", - "sliding_attention", - "sliding_attention", - ] - assert _period_from_layer_types(lt) is None - - def test_hf_repo_from_url(self): - from core.inference.llama_cpp import _hf_repo_from_url - - assert ( - _hf_repo_from_url("https://huggingface.co/google/gemma-3-1b-it") - == "google/gemma-3-1b-it" - ) - assert ( - _hf_repo_from_url( - "https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json" - ) - == "google/gemma-3-1b-it" - ) - for bad in [ - "https://huggingface.co/google", - "https://example.com/foo/bar", - None, - "", - ]: - assert _hf_repo_from_url(bad) is None - - def test_bootstrap_tier_used_when_no_cache(self, monkeypatch, tmp_path): - self._isolate_cache(monkeypatch, tmp_path) - from core.inference import llama_cpp as lc - - def boom(*a, **kw): - raise AssertionError("HF fetch must not run when bootstrap covers the arch") - - monkeypatch.setattr(lc, "_fetch_swa_entry_from_hf", boom) - b = _backend_from_gguf("gemma3", dict(_SWA_FIELDS, block_count = 18)) - assert b._sliding_window_pattern == [(i + 1) % 6 != 0 for i in range(18)] - - def test_disk_cache_takes_precedence_over_bootstrap(self, monkeypatch, tmp_path): - self._isolate_cache(monkeypatch, tmp_path) - # Override bootstrap=6 with a cached period=3. - with open(tmp_path / "swa_cache.json", "w") as f: - json.dump({"gemma3": 3}, f) - b = _backend_from_gguf("gemma3", dict(_SWA_FIELDS, block_count = 18)) - assert b._sliding_window_pattern == [(i + 1) % 3 != 0 for i in range(18)] - - def test_disk_cache_supports_array_entries(self, monkeypatch, tmp_path): - # Aperiodic mask gets tiled across n_layers. - self._isolate_cache(monkeypatch, tmp_path) - mask = [True, False, True, True, False, True, False, False] - with open(tmp_path / "swa_cache.json", "w") as f: - json.dump({"customarch": mask}, f) - b = _backend_from_gguf("customarch", dict(_SWA_FIELDS, block_count = 16)) - assert b._sliding_window_pattern == [bool(mask[i % 8]) for i in range(16)] - - def test_hf_fetch_populates_cache(self, monkeypatch, tmp_path): - self._isolate_cache(monkeypatch, tmp_path) - from core.inference import llama_cpp as lc - - calls = [] - - def fake_fetch(repo_id): - calls.append(repo_id) - return 4 if repo_id == "vendor/newmodel-1b-instruct" else None - - monkeypatch.setattr(lc, "_fetch_swa_entry_from_hf", fake_fetch) - b = _backend_from_gguf( - "newmodel", - _SWA_FIELDS, - general = { - "general.source.huggingface.repository": "vendor/newmodel-1b-instruct" - }, - ) - assert b._sliding_window_pattern == [(i + 1) % 4 != 0 for i in range(12)] - assert calls == ["vendor/newmodel-1b-instruct"] - with open(tmp_path / "swa_cache.json") as f: - assert json.load(f) == {"newmodel": 4} - - def test_hf_fetch_falls_back_to_other_candidates(self, monkeypatch, tmp_path): - self._isolate_cache(monkeypatch, tmp_path) - from core.inference import llama_cpp as lc - - monkeypatch.setattr( - lc, - "_fetch_swa_entry_from_hf", - lambda r: 6 if r == "vendor/newmodel-base" else None, - ) - b = _backend_from_gguf( - "newmodel", - _SWA_FIELDS, - general = { - "general.base_model.0.repo_url": "https://huggingface.co/vendor/newmodel-base" - }, - ) - assert b._sliding_window_pattern == [(i + 1) % 6 != 0 for i in range(12)] - - def test_offline_env_skips_network(self, monkeypatch, tmp_path): - self._isolate_cache(monkeypatch, tmp_path) - monkeypatch.setenv("UNSLOTH_STUDIO_OFFLINE", "1") - from core.inference import llama_cpp as lc - - def boom(*a, **kw): - raise AssertionError("HF fetch must not run when offline=1") - - monkeypatch.setattr(lc, "_fetch_swa_entry_from_hf", boom) - b = _backend_from_gguf( - "newmodel", - _SWA_FIELDS, - general = {"general.source.huggingface.repository": "vendor/newmodel"}, - ) - assert b._sliding_window_pattern is None - - def test_hf_fetch_failure_falls_through_silently(self, monkeypatch, tmp_path): - self._isolate_cache(monkeypatch, tmp_path) - from core.inference import llama_cpp as lc - - monkeypatch.setattr(lc, "_fetch_swa_entry_from_hf", lambda repo_id: None) - # Force the failure into the Tier 3 path; bypass Tier 2.5. - monkeypatch.setattr( - lc, "_resolve_swa_entry_from_transformers", lambda arch: None - ) - b = _backend_from_gguf( - "newmodel", - _SWA_FIELDS, - general = {"general.source.huggingface.repository": "vendor/does-not-exist"}, - ) - assert b._sliding_window_pattern is None - assert not (tmp_path / "swa_cache.json").exists() - - -class TestTransformersIntrospection: - """Tier 2.5: default-init the matching Config; on failure, parse via inspect.""" - - def _isolate_cache(self, monkeypatch, tmp_path): - from core.inference import llama_cpp as lc - - monkeypatch.setenv("UNSLOTH_STUDIO_HOME", str(tmp_path)) - monkeypatch.setattr(lc, "_SWA_CACHE", None) - return tmp_path - - def test_arch_aliases_normalises_hyphen_underscore(self): - from core.inference.llama_cpp import _arch_aliases - - aliases = _arch_aliases("falcon-h1") - assert aliases[0] == "falcon-h1" and "falcon_h1" in aliases - assert _arch_aliases("gemma3") == ("gemma3",) - assert _arch_aliases("") == () - - def test_resolves_real_transformers_arches(self): - from core.inference.llama_cpp import _resolve_swa_entry_from_transformers - - assert _resolve_swa_entry_from_transformers("gemma3") == 6 - assert _resolve_swa_entry_from_transformers("gemma2") == 2 - assert _resolve_swa_entry_from_transformers("cohere2") == 4 - - def test_falls_back_to_inspect_when_default_init_raises(self, monkeypatch): - from core.inference import llama_cpp as lc - - class _FakeBrokenConfig: - """Class with sliding_window_pattern: int = 7 in its docstring.""" - - def __init__(self, required_arg): - raise TypeError("requires an argument") - - class _FakeLazyMapping(dict): - def __getitem__(self, k): - return ( - _FakeBrokenConfig if k == "brokenarch" else super().__getitem__(k) - ) - - import sys, types as _types - - fake_auto = _types.ModuleType("transformers.models.auto.configuration_auto") - fake_auto.CONFIG_MAPPING_NAMES = {"brokenarch": "FakeBroken"} - fake_auto.CONFIG_MAPPING = _FakeLazyMapping({"brokenarch": "FakeBroken"}) - monkeypatch.setitem( - sys.modules, "transformers.models.auto.configuration_auto", fake_auto - ) - assert lc._resolve_swa_entry_from_transformers("brokenarch") == 7 - - def test_returns_none_when_transformers_unavailable(self, monkeypatch): - from core.inference import llama_cpp as lc - import sys - - orig_import = ( - __builtins__["__import__"] - if isinstance(__builtins__, dict) - else __builtins__.__import__ - ) - - def fake_import(name, *a, **kw): - if name.startswith("transformers"): - raise ImportError("transformers not installed") - return orig_import(name, *a, **kw) - - monkeypatch.setattr("builtins.__import__", fake_import) - for k in list(sys.modules): - if k.startswith("transformers"): - monkeypatch.delitem(sys.modules, k, raising = False) - assert lc._resolve_swa_entry_from_transformers("gemma3") is None - - def test_returns_none_for_arch_unknown_to_transformers(self): - from core.inference.llama_cpp import _resolve_swa_entry_from_transformers - - assert _resolve_swa_entry_from_transformers("totally-fake-arch-xyz") is None - - def test_full_resolver_uses_transformers_before_hf_fetch( - self, monkeypatch, tmp_path - ): - # With bootstrap empty, Tier 2.5 must answer before Tier 3 fires. - self._isolate_cache(monkeypatch, tmp_path) - from core.inference import llama_cpp as lc - - monkeypatch.setattr(lc, "_BOOTSTRAP_SWA_DEFAULTS", {}) - - def boom(repo_id): - raise AssertionError("Tier 3 must not run when Tier 2.5 has the answer") - - monkeypatch.setattr(lc, "_fetch_swa_entry_from_hf", boom) - b = _backend_from_gguf( - "gemma3", - dict(_SWA_FIELDS, block_count = 18), - general = {"general.source.huggingface.repository": "google/gemma-3-1b-it"}, - ) - assert b._sliding_window_pattern == [(i + 1) % 6 != 0 for i in range(18)] - with open(tmp_path / "swa_cache.json") as f: - assert json.load(f) == {"gemma3": 6} - - class TestGGUFParserReset: """Verify that fields are properly reset between parses.""" @@ -695,19 +209,11 @@ def test_reset_between_parses(self): "block_count": 32, "attention.key_length": 128, "attention.kv_lora_rank": 512, - "attention.head_count_kv": [8, 2], - "attention.sliding_window_pattern": [True, False], - "attention.key_length_swa": 64, - "attention.value_length_swa": 64, "ssm.inner_size": 4096, }, ) assert b._kv_key_length == 128 assert b._kv_lora_rank == 512 - assert b._n_kv_heads_by_layer == [8, 2] - assert b._sliding_window_pattern == [True, False] - assert b._kv_key_length_swa == 64 - assert b._kv_value_length_swa == 64 assert b._ssm_inner_size == 4096 # Second parse without those fields -- they should be None @@ -724,10 +230,6 @@ def test_reset_between_parses(self): os.unlink(path) assert b._kv_key_length is None assert b._kv_lora_rank is None - assert b._n_kv_heads_by_layer is None - assert b._sliding_window_pattern is None - assert b._kv_key_length_swa is None - assert b._kv_value_length_swa is None assert b._ssm_inner_size is None assert b._n_layers == 64 @@ -953,9 +455,7 @@ def test_gemma3(self): n_global = max(1, 62 // 4) # 15 n_swa = 62 - n_global # 47 kv_per = 16 * (128 + 128) * 2 - # SWA cache is double-buffered: 2 * sliding_window cells, capped at n_ctx. - swa_cells = min(131072, 2 * 1024) - expected = int(n_global * 131072 * kv_per + n_swa * swa_cells * kv_per) + expected = int(n_global * 131072 * kv_per + n_swa * min(131072, 1024) * kv_per) assert b._estimate_kv_cache_bytes(131072, "f16") == expected def test_gpt_oss(self): @@ -972,52 +472,27 @@ def test_gpt_oss(self): n_global = max(1, 24 // 4) # 6 n_swa = 24 - n_global # 18 kv_per = 8 * (64 + 64) * 2 - swa_cells = min(131072, 2 * 128) - expected = int(n_global * 131072 * kv_per + n_swa * swa_cells * kv_per) + expected = int(n_global * 131072 * kv_per + n_swa * min(131072, 128) * kv_per) assert b._estimate_kv_cache_bytes(131072, "f16") == expected - def test_gemma4_per_layer_swa_metadata(self): - b = self._swa_backend( - _n_layers = 30, - _n_kv_heads = None, - _n_kv_heads_by_layer = [8, 8, 8, 8, 8, 2] * 5, - _n_heads = 16, - _embedding_length = 2816, - _kv_key_length = 512, - _kv_value_length = 512, - _sliding_window = 1024, - _sliding_window_pattern = [True, True, True, True, True, False] * 5, - _kv_key_length_swa = 256, - _kv_value_length_swa = 256, - ) - - full_layers = 5 - sliding_layers = 25 - - def expected(ctx): - full = full_layers * ctx * 2 * (512 + 512) * 2 - sliding = sliding_layers * min(ctx, 2 * 1024) * 8 * (256 + 256) * 2 - return int(full + sliding) - - for ctx in (4096, 46500, 262144): - assert b._estimate_kv_cache_bytes(ctx, "f16") == expected(ctx) - def test_ctx_smaller_than_window(self): - """When context < 2 * sliding_window, SWA cache caps at ctx.""" + """When context < sliding_window, SWA layers use full context anyway.""" b = self._swa_backend(_sliding_window = 8192) n_global = max(1, 62 // 4) # 15 n_swa = 62 - n_global # 47 kv_per = 16 * (128 + 128) * 2 ctx = 4096 - expected = int(n_global * ctx * kv_per + n_swa * min(ctx, 2 * 8192) * kv_per) + expected = int(n_global * ctx * kv_per + n_swa * min(ctx, 8192) * kv_per) + # min(4096, 8192) = 4096, so both pools use full ctx assert b._estimate_kv_cache_bytes(ctx, "f16") == expected def test_odd_layer_count(self): + """Odd layer count: n_global = max(1, n//4), n_swa = n - n_global.""" b = self._swa_backend(_n_layers = 63) n_global = max(1, 63 // 4) # 15 n_swa = 63 - n_global # 48 kv_per = 16 * (128 + 128) * 2 - expected = int(n_global * 1000 * kv_per + n_swa * min(1000, 2 * 1024) * kv_per) + expected = int(n_global * 1000 * kv_per + n_swa * min(1000, 1024) * kv_per) assert b._estimate_kv_cache_bytes(1000, "f16") == expected @@ -1310,686 +785,6 @@ def test_both_heads_none_falls_to_one(self): assert result == expected -# --------------------------------------------------------------------------- -# J2. Server-flag knobs (--swa-full, --kv-unified/--parallel, -# --ctx-checkpoints, --kv-offload) -# --------------------------------------------------------------------------- - - -class TestServerFlags: - """Estimator should mirror llama-server CLI flags that change KV size.""" - - def _swa_backend(self, **overrides): - defaults = { - "_n_layers": 26, - "_n_kv_heads": 4, - "_n_heads": 8, - "_embedding_length": 1152, - "_kv_key_length": 256, - "_kv_value_length": 256, - "_sliding_window": 512, - "_sliding_window_pattern": [True, True, True, True, True, False] * 4 - + [True, True], - } - defaults.update(overrides) - b = LlamaCppBackend() - for k, v in defaults.items(): - setattr(b, k, v) - return b - - def _gqa_backend(self, **overrides): - defaults = { - "_n_layers": 28, - "_n_kv_heads": 8, - "_n_heads": 16, - "_embedding_length": 1024, - "_kv_key_length": 128, - "_kv_value_length": 128, - } - defaults.update(overrides) - b = LlamaCppBackend() - for k, v in defaults.items(): - setattr(b, k, v) - return b - - # ── --swa-full ────────────────────────────────────────────────── - - def test_swa_full_collapses_pattern_path_to_full_ctx(self): - b = self._swa_backend() - ctx = 32_768 - flagged = b._estimate_kv_cache_bytes(ctx, "f16", swa_full = True) - # With swa_full, every layer caches n_ctx -- equals path 4 sizing. - kv_per_token = 4 * (256 + 256) * 2 # n_kv_heads * (k+v) * f16 - expected = 26 * ctx * kv_per_token - assert flagged == expected - assert flagged > b._estimate_kv_cache_bytes(ctx, "f16") - - def test_swa_full_collapses_legacy_path_to_full_ctx(self): - # No per-layer pattern -> 1/4-global heuristic; swa_full overrides. - b = self._swa_backend(_sliding_window_pattern = None) - ctx = 16_384 - flagged = b._estimate_kv_cache_bytes(ctx, "f16", swa_full = True) - n_global = max(1, 26 // 4) - n_swa = 26 - n_global - kv_per = 4 * (256 + 256) * 2 - # swa_cells == n_ctx when swa_full=True - expected = n_global * ctx * kv_per + n_swa * ctx * kv_per - assert flagged == expected - - def test_swa_full_no_op_for_non_swa_model(self): - b = self._gqa_backend() - baseline = b._estimate_kv_cache_bytes(8192, "f16") - flagged = b._estimate_kv_cache_bytes(8192, "f16", swa_full = True) - assert flagged == baseline - - def test_swa_full_suppresses_checkpoint_term(self): - b = self._swa_backend() - with_cp = b._estimate_kv_cache_bytes(8192, "f16", ctx_checkpoints = 8) - with_cp_full = b._estimate_kv_cache_bytes( - 8192, "f16", ctx_checkpoints = 8, swa_full = True - ) - no_cp_full = b._estimate_kv_cache_bytes(8192, "f16", swa_full = True) - # Checkpoints only matter when SWA layers don't already keep n_ctx. - assert with_cp_full == no_cp_full - assert with_cp > b._estimate_kv_cache_bytes(8192, "f16") - - # ── --parallel + --kv-unified ────────────────────────────────── - # Empirically verified against llama-server: non-SWA caches partition - # n_ctx across slots (total memory constant); SWA layers are the only - # portion that scales with --parallel. --kv-unified is currently a - # no-op for memory math (kept for API forward-compat). - - def test_gqa_kv_constant_across_parallel(self): - b = self._gqa_backend() - baseline = b._estimate_kv_cache_bytes(4096, "f16") - for slots in (1, 2, 4, 8): - for unified in (True, False): - assert ( - b._estimate_kv_cache_bytes( - 4096, "f16", n_parallel = slots, kv_unified = unified - ) - == baseline - ) - - def test_zero_parallel_floors_at_one(self): - b = self._gqa_backend() - baseline = b._estimate_kv_cache_bytes(4096, "f16") - for unified in (True, False): - assert ( - b._estimate_kv_cache_bytes( - 4096, "f16", n_parallel = 0, kv_unified = unified - ) - == baseline - ) - - def test_swa_path_scales_only_swa_portion(self): - b = self._swa_backend() - ctx = 8192 - baseline = b._estimate_kv_cache_bytes(ctx, "f16") - # Decompose baseline by walking the same loop the estimator does. - swa = b._sliding_window - per_token_global = 4 * (256 + 256) * 2 # n_kv * (k+v) * f16 - per_token_swa = 4 * (256 + 256) * 2 # k_swa/val_swa fall back - per_slot_swa_cells = min(ctx, 2 * swa) # not clamped at parallel=1 - global_bytes = sum( - ctx * per_token_global - for f in b._sliding_window_pattern[: b._n_layers] - if not f - ) - swa_bytes_per_slot = sum( - per_slot_swa_cells * per_token_swa - for f in b._sliding_window_pattern[: b._n_layers] - if f - ) - # Sanity: parallel=1 reproduces baseline exactly - assert global_bytes + swa_bytes_per_slot == baseline - # Only SWA portion scales by parallel - for slots in (1, 2, 3, 4): - scaled = b._estimate_kv_cache_bytes( - ctx, "f16", n_parallel = slots, kv_unified = False - ) - # SWA cells get clamped to per_slot_ctx when ctx/slots < 2*swa - per_slot_ctx = max(1, ctx // slots) - cells = min(ctx, 2 * swa, per_slot_ctx) - swa_bps = sum( - cells * per_token_swa - for f in b._sliding_window_pattern[: b._n_layers] - if f - ) - assert scaled == global_bytes + slots * swa_bps - - def test_mla_kv_constant_across_parallel(self): - b = LlamaCppBackend() - b._n_layers = 60 - b._n_kv_heads = 1 - b._kv_lora_rank = 512 - b._key_length_mla = 64 - b._kv_key_length = 576 - baseline = b._estimate_kv_cache_bytes(8192, "f16") - for slots in (1, 2, 4, 8): - for unified in (True, False): - assert ( - b._estimate_kv_cache_bytes( - 8192, "f16", n_parallel = slots, kv_unified = unified - ) - == baseline - ) - - # ── --ctx-checkpoints ────────────────────────────────────────── - - def test_ctx_checkpoints_zero_is_no_op(self): - b = self._swa_backend() - baseline = b._estimate_kv_cache_bytes(8192, "f16") - assert b._estimate_kv_cache_bytes(8192, "f16", ctx_checkpoints = 0) == baseline - - def test_ctx_checkpoints_no_op_for_non_swa(self): - b = self._gqa_backend() - baseline = b._estimate_kv_cache_bytes(8192, "f16") - assert b._estimate_kv_cache_bytes(8192, "f16", ctx_checkpoints = 32) == baseline - - def test_ctx_checkpoints_pattern_path_adds_known_bytes(self): - b = self._swa_backend() - ctx = 8192 - baseline = b._estimate_kv_cache_bytes(ctx, "f16") - flagged = b._estimate_kv_cache_bytes(ctx, "f16", ctx_checkpoints = 4) - # 22 SWA layers * 4 checkpoints * 512 cells * 4 heads * (256+256) * 2 bytes - n_swa_layers = sum( - 1 for f in [True, True, True, True, True, False] * 4 + [True, True] if f - ) - per_layer = 4 * 512 * 4 * (256 + 256) * 2 - assert flagged == baseline + n_swa_layers * per_layer - - def test_ctx_checkpoints_legacy_path_adds_known_bytes(self): - b = self._swa_backend(_sliding_window_pattern = None) - ctx = 8192 - baseline = b._estimate_kv_cache_bytes(ctx, "f16") - flagged = b._estimate_kv_cache_bytes(ctx, "f16", ctx_checkpoints = 4) - n_global = max(1, 26 // 4) - n_swa = 26 - n_global - kv_per = 4 * (256 + 256) * 2 - extra = 4 * n_swa * 512 * kv_per # ctx_checkpoints * n_swa * sliding * kv_per - assert flagged == baseline + extra - - def test_ctx_checkpoints_compose_with_n_parallel(self): - # Only the SWA + checkpoint portion scales by n_parallel; the - # global-layer portion stays constant. - b = self._swa_backend() - ctx = 8192 - swa = b._sliding_window - per_token = 4 * (256 + 256) * 2 - global_bytes = sum( - ctx * per_token for f in b._sliding_window_pattern[: b._n_layers] if not f - ) - n_swa_layers = sum(1 for f in b._sliding_window_pattern[: b._n_layers] if f) - slots = 3 - per_slot_ctx = max(1, ctx // slots) - swa_cells = min(ctx, 2 * swa, per_slot_ctx) - swa_bytes_per_slot = n_swa_layers * swa_cells * per_token - cp_extra_per_slot = n_swa_layers * 4 * swa * per_token # 4 checkpoints - flagged = b._estimate_kv_cache_bytes( - ctx, "f16", ctx_checkpoints = 4, n_parallel = slots, kv_unified = False - ) - assert flagged == global_bytes + slots * ( - swa_bytes_per_slot + cp_extra_per_slot - ) - - # ── --kv-offload (kv_on_gpu) ─────────────────────────────────── - - def test_fit_returns_requested_when_kv_off_gpu(self): - b = self._gqa_backend() - # Tiny VRAM budget -- normally would force a reduction. - fitted = b._fit_context_to_vram( - requested_ctx = 32_768, - available_mib = 1, - model_size_bytes = 100, - cache_type_kv = "f16", - kv_on_gpu = False, - ) - assert fitted == 32_768 - - def test_fit_reduces_when_kv_on_gpu(self): - b = self._gqa_backend() - fitted = b._fit_context_to_vram( - requested_ctx = 32_768, - available_mib = 64, - model_size_bytes = 1024 * 1024, # 1 MiB - cache_type_kv = "f16", - kv_on_gpu = True, - ) - assert fitted < 32_768 - - def test_fit_threads_swa_full_through_estimator(self): - # SWA model, generous budget; both should fit but cache size differs. - b = self._swa_backend() - ctx = 8192 - kv_default = b._estimate_kv_cache_bytes(ctx, "f16") - kv_full = b._estimate_kv_cache_bytes(ctx, "f16", swa_full = True) - assert kv_full > kv_default - # Budget = model + kv_default (rounded up) -- swa_full should not fit. - budget_mib = (1024 * 1024 + kv_default) / (1024 * 1024) / 0.90 + 1 - fitted_default = b._fit_context_to_vram( - requested_ctx = ctx, - available_mib = int(budget_mib), - model_size_bytes = 1024 * 1024, - cache_type_kv = "f16", - ) - fitted_full = b._fit_context_to_vram( - requested_ctx = ctx, - available_mib = int(budget_mib), - model_size_bytes = 1024 * 1024, - cache_type_kv = "f16", - swa_full = True, - ) - assert fitted_default == ctx - assert fitted_full < ctx - - -# --------------------------------------------------------------------------- -# J2.5. --parallel N memory accounting (per-layer-type scaling rule) -# --------------------------------------------------------------------------- - - -class TestParallelSWAScaling: - """Verifies the per-layer-type scaling rule against the closed form - measured from llama-server. Empirical formula on Gemma-3 270m at - ctx=8192: total_kv = 24 + parallel * 15 (MiB). - - Rule (verified vs ``llama-server`` log on real GGUFs): - * non-SWA layers: total cells = n_ctx, partitioned across slots, - memory CONSTANT in n_parallel. - * SWA layers: per-slot cells = 2 * sliding_window (clamped at - n_ctx and at per_slot_ctx); memory LINEAR in n_parallel. - * --kv-unified is a no-op for memory math; both modes yield the - same total in measured cases. - """ - - def _gqa_backend(self, **overrides): - defaults = { - "_n_layers": 28, - "_n_kv_heads": 8, - "_n_heads": 16, - "_embedding_length": 1024, - "_kv_key_length": 128, - "_kv_value_length": 128, - } - defaults.update(overrides) - b = LlamaCppBackend() - for k, v in defaults.items(): - setattr(b, k, v) - return b - - def _swa_backend(self, **overrides): - defaults = { - "_n_layers": 18, - "_n_kv_heads": 1, - "_n_heads": 4, - "_embedding_length": 1024, - "_kv_key_length": 256, - "_kv_value_length": 256, - "_sliding_window": 512, - # 15 SWA + 3 global, mirrors gemma-3-270m - "_sliding_window_pattern": [ - t == "swa" for t in (["swa"] * 5 + ["global"]) * 3 - ], - } - defaults.update(overrides) - b = LlamaCppBackend() - for k, v in defaults.items(): - setattr(b, k, v) - return b - - # ── non-SWA paths: constant ──────────────────────────────────── - - def test_pure_gqa_constant_across_parallel(self): - b = self._gqa_backend() - baseline = b._estimate_kv_cache_bytes(8192, "f16") - for slots in (1, 2, 4, 8): - for unified in (True, False): - assert ( - b._estimate_kv_cache_bytes( - 8192, "f16", n_parallel = slots, kv_unified = unified - ) - == baseline - ) - - def test_mla_constant_across_parallel(self): - b = LlamaCppBackend() - b._n_layers = 60 - b._n_kv_heads = 1 - b._kv_lora_rank = 512 - b._key_length_mla = 64 - b._kv_key_length = 576 - baseline = b._estimate_kv_cache_bytes(8192, "f16") - for slots in (1, 2, 4, 8): - assert b._estimate_kv_cache_bytes(8192, "f16", n_parallel = slots) == baseline - - def test_hybrid_constant_across_parallel(self): - b = LlamaCppBackend() - b._n_layers = 64 - b._n_kv_heads = 16 - b._n_heads = 32 - b._embedding_length = 4096 - b._kv_key_length = 128 - b._kv_value_length = 128 - b._ssm_inner_size = 4096 - b._full_attention_interval = 4 - baseline = b._estimate_kv_cache_bytes(8192, "f16") - for slots in (1, 2, 4, 8): - assert b._estimate_kv_cache_bytes(8192, "f16", n_parallel = slots) == baseline - - def test_legacy_constant_across_parallel(self): - b = LlamaCppBackend() - b._n_layers = 32 - b._n_kv_heads = 8 - b._n_heads = 8 - b._embedding_length = 4096 - baseline = b._estimate_kv_cache_bytes(8192, "f16") - for slots in (1, 2, 4, 8): - assert b._estimate_kv_cache_bytes(8192, "f16", n_parallel = slots) == baseline - - # ── SWA paths: scale only the SWA portion ────────────────────── - - def test_swa_pattern_scales_only_swa_portion(self): - b = self._swa_backend() - ctx = 8192 - swa = b._sliding_window - per_token = 1 * (256 + 256) * 2 # n_kv * (k+v) * f16 - n_global = sum(1 for f in b._sliding_window_pattern if not f) - n_swa = sum(1 for f in b._sliding_window_pattern if f) - global_bytes = n_global * ctx * per_token - for slots in (1, 2, 4, 8): - per_slot_ctx = max(1, ctx // slots) - cells = min(ctx, 2 * swa, per_slot_ctx) - swa_bps = n_swa * cells * per_token - for unified in (True, False): - got = b._estimate_kv_cache_bytes( - ctx, "f16", n_parallel = slots, kv_unified = unified - ) - assert got == global_bytes + slots * swa_bps - - def test_swa_fallback_scales_only_swa_portion(self): - # No per-layer pattern -> 1/4-global heuristic. - b = self._swa_backend(_sliding_window_pattern = None) - ctx = 8192 - swa = b._sliding_window - n_layers = 18 - n_global = max(1, n_layers // 4) - n_swa = n_layers - n_global - per_token = 1 * (256 + 256) * 2 - global_bytes = n_global * ctx * per_token - for slots in (1, 2, 4, 8): - per_slot_ctx = max(1, ctx // slots) - cells = min(ctx, 2 * swa, per_slot_ctx) - swa_bps = n_swa * cells * per_token - got = b._estimate_kv_cache_bytes(ctx, "f16", n_parallel = slots) - assert got == global_bytes + slots * swa_bps - - def test_swa_per_slot_clamped_when_ctx_lt_slots_x_2window(self): - # ctx=4096 / slots=8 -> per_slot_ctx=512, but 2*sliding=1024. - # SWA cells should clamp at per_slot_ctx (512), not 2*sliding. - b = self._swa_backend() - ctx = 4096 - per_slot_ctx_at_8 = ctx // 8 - assert per_slot_ctx_at_8 < 2 * b._sliding_window - # Build expected with the clamped formula - n_swa = sum(1 for f in b._sliding_window_pattern if f) - n_global = sum(1 for f in b._sliding_window_pattern if not f) - per_token = 1 * (256 + 256) * 2 - global_bytes = n_global * ctx * per_token - cells = min(ctx, 2 * b._sliding_window, per_slot_ctx_at_8) - assert cells == per_slot_ctx_at_8 - expected = global_bytes + 8 * (n_swa * cells * per_token) - assert b._estimate_kv_cache_bytes(ctx, "f16", n_parallel = 8) == expected - - def test_swa_full_does_not_scale_under_parallel(self): - # swa_full forces every layer to n_ctx; result is the all-global - # GQA-style total, which is constant in parallel. - b = self._swa_backend() - ctx = 8192 - baseline = b._estimate_kv_cache_bytes(ctx, "f16", swa_full = True) - for slots in (1, 2, 4, 8): - assert ( - b._estimate_kv_cache_bytes(ctx, "f16", swa_full = True, n_parallel = slots) - == baseline - ) - - # ── kv_unified: no-op for memory math ────────────────────────── - - def test_kv_unified_is_no_op_for_memory_math(self): - # Both unified=True and unified=False must produce the same - # total bytes for every backend type and every parallel value. - backends = [ - ("gqa", self._gqa_backend()), - ("swa", self._swa_backend()), - ] - for label, b in backends: - for slots in (1, 2, 4, 8): - u = b._estimate_kv_cache_bytes( - 8192, "f16", n_parallel = slots, kv_unified = True - ) - nu = b._estimate_kv_cache_bytes( - 8192, "f16", n_parallel = slots, kv_unified = False - ) - assert u == nu, f"{label} parallel={slots} unified-mismatch" - - # ── Empirical Gemma-3 270m formula ───────────────────────────── - - def test_matches_empirical_gemma3_270m_formula(self): - """Exact match against the formula measured from llama-server: - total_kv = 24 + parallel * 15 (MiB) at ctx=8192. - - Geometry: 18 layers (3 global + 15 SWA), n_kv=1, head_dim=256, - sliding=512, f16. - """ - b = LlamaCppBackend() - b._n_layers = 18 - b._n_kv_heads = 1 - b._n_heads = 4 - b._embedding_length = 1024 - b._kv_key_length = 256 - b._kv_value_length = 256 - b._sliding_window = 512 - # 5-period [swa,swa,swa,swa,full] * 3 + [swa,swa,swa]: mirrors the - # bootstrap-resolved pattern for gemma3 (period 6) on an 18-layer - # model (15 SWA, 3 global). - b._sliding_window_pattern = [(i + 1) % 6 != 0 for i in range(18)] - n_global = 3 - n_swa = 15 - # Confirm pattern shape - assert sum(b._sliding_window_pattern) == n_swa - for slots, expected_mib in [(1, 39), (2, 54), (4, 84)]: - got_bytes = b._estimate_kv_cache_bytes(8192, "f16", n_parallel = slots) - got_mib = got_bytes / (1024 * 1024) - assert ( - got_mib == expected_mib - ), f"slots={slots}: got {got_mib} MiB, expected {expected_mib} MiB" - - -# --------------------------------------------------------------------------- -# J3. shared_kv_layers (Gemma 3n / Gemma 4) -# --------------------------------------------------------------------------- - - -class TestSharedKVLayers: - """``.attention.shared_kv_layers`` reduces the layer count that - actually allocates KV. The trailing ``shared_kv_layers`` blocks reuse - earlier caches (Gemma 3n: 35 layers, 15 shared -> 20 allocate; Gemma 4 - same field). Unset on every other arch -> no behavioural change.""" - - def _gemma3n_backend(self, **overrides): - # Mirrors google/gemma-3n-E4B-it: 35 layers, 15 shared, - # SWA window 1024, period 5 (4 sliding + 1 full repeating). - defaults = { - "_n_layers": 35, - "_n_kv_heads": 4, - "_n_heads": 8, - "_embedding_length": 2048, - "_kv_key_length": 256, - "_kv_value_length": 256, - "_sliding_window": 1024, - "_sliding_window_pattern": [ - t == "sliding_attention" - for t in (["sliding_attention"] * 4 + ["full_attention"]) * 7 - ], - "_shared_kv_layers": 15, - } - defaults.update(overrides) - b = LlamaCppBackend() - for k, v in defaults.items(): - setattr(b, k, v) - return b - - def _gqa_backend(self, **overrides): - defaults = { - "_n_layers": 28, - "_n_kv_heads": 8, - "_n_heads": 16, - "_embedding_length": 1024, - "_kv_key_length": 128, - "_kv_value_length": 128, - } - defaults.update(overrides) - b = LlamaCppBackend() - for k, v in defaults.items(): - setattr(b, k, v) - return b - - def test_field_initialises_to_none(self): - b = LlamaCppBackend() - assert b._shared_kv_layers is None - - def test_unset_field_is_noop(self): - b = self._gqa_backend() - baseline = b._estimate_kv_cache_bytes(8192, "f16") - b._shared_kv_layers = None - assert b._estimate_kv_cache_bytes(8192, "f16") == baseline - b._shared_kv_layers = 0 - assert b._estimate_kv_cache_bytes(8192, "f16") == baseline - - def test_path4_drops_shared_layers(self): - b = self._gqa_backend(_shared_kv_layers = 4) - ctx = 4096 - kv_per = 8 * (128 + 128) * 2 - # 28 - 4 = 24 layers actually allocate - assert b._estimate_kv_cache_bytes(ctx, "f16") == 24 * ctx * kv_per - - def test_path5_drops_shared_layers(self): - b = LlamaCppBackend() - b._n_layers = 32 - b._n_kv_heads = 8 - b._n_heads = 8 - b._embedding_length = 4096 - b._shared_kv_layers = 8 - ctx = 4096 - head_dim = 4096 // 8 # 512 - # 32 - 8 = 24 layers - expected = 2 * 8 * head_dim * 24 * ctx * 2 - assert b._estimate_kv_cache_bytes(ctx, "f16") == expected - - def test_path1_mla_drops_shared_layers(self): - b = LlamaCppBackend() - b._n_layers = 60 - b._n_kv_heads = 1 - b._kv_lora_rank = 512 - b._key_length_mla = 64 - b._kv_key_length = 576 - b._shared_kv_layers = 10 - ctx = 8192 - # 60 - 10 = 50 - assert b._estimate_kv_cache_bytes(ctx, "f16") == 50 * ctx * 1 * 576 * 2 - - def test_path3_pattern_loops_only_unshared_layers(self): - b = self._gemma3n_backend() - ctx = 8192 - # First 20 layers contribute; layers 20..34 are skipped. - # Pattern: [s,s,s,s,F] repeated. In layers 0..19: - # sliding: 16, full: 4 - sliding_in_unshared = sum(b._sliding_window_pattern[:20]) - full_in_unshared = 20 - sliding_in_unshared - assert sliding_in_unshared == 16 - assert full_in_unshared == 4 - kv_per = 4 * (256 + 256) * 2 - swa_cells = min(ctx, 2 * 1024) - expected = ( - full_in_unshared * ctx * kv_per + sliding_in_unshared * swa_cells * kv_per - ) - assert b._estimate_kv_cache_bytes(ctx, "f16") == expected - - def test_shared_layers_reduces_estimate(self): - b = self._gemma3n_backend() - with_shared = b._estimate_kv_cache_bytes(8192, "f16") - b._shared_kv_layers = 0 - without_shared = b._estimate_kv_cache_bytes(8192, "f16") - # 20/35 = 0.571 of the work; expect ~43% reduction. - ratio = with_shared / without_shared - assert 0.5 < ratio < 0.65 - - def test_path3_pattern_with_swa_full_and_shared(self): - b = self._gemma3n_backend() - ctx = 8192 - flagged = b._estimate_kv_cache_bytes(ctx, "f16", swa_full = True) - # Every unshared layer caches n_ctx; equals path-4-style sizing - # over only the 20 unshared layers. - kv_per = 4 * (256 + 256) * 2 - assert flagged == 20 * ctx * kv_per - - def test_path3_fallback_uses_unshared_count(self): - # No per-layer pattern -> 1/4-global heuristic over n_layers_kv, - # not n_layers. - b = self._gemma3n_backend(_sliding_window_pattern = None) - ctx = 8192 - n_layers_kv = 35 - 15 # 20 - n_global = max(1, n_layers_kv // 4) # 5 - n_swa = n_layers_kv - n_global # 15 - kv_per = 4 * (256 + 256) * 2 - swa_cells = min(ctx, 2 * 1024) - expected = n_global * ctx * kv_per + n_swa * swa_cells * kv_per - assert b._estimate_kv_cache_bytes(ctx, "f16") == expected - - def test_shared_floors_at_one_layer(self): - # Pathological: shared >= n_layers should not zero out the cache. - b = self._gqa_backend(_shared_kv_layers = 99) - ctx = 4096 - kv_per = 8 * (128 + 128) * 2 - assert b._estimate_kv_cache_bytes(ctx, "f16") == 1 * ctx * kv_per - - def test_composes_with_n_parallel(self): - # Only the SWA portion of the unshared layers scales by n_parallel; - # the global portion stays constant. - b = self._gemma3n_backend() - ctx = 8192 - swa = b._sliding_window - per_token = 4 * (256 + 256) * 2 - unshared_pattern = b._sliding_window_pattern[:20] # 35 - 15 shared - sliding_in_unshared = sum(unshared_pattern) - global_in_unshared = len(unshared_pattern) - sliding_in_unshared - global_bytes = global_in_unshared * ctx * per_token - slots = 3 - per_slot_ctx = max(1, ctx // slots) - swa_cells = min(ctx, 2 * swa, per_slot_ctx) - swa_bytes_per_slot = sliding_in_unshared * swa_cells * per_token - flagged = b._estimate_kv_cache_bytes( - ctx, "f16", n_parallel = slots, kv_unified = False - ) - assert flagged == global_bytes + slots * swa_bytes_per_slot - - def test_composes_with_ctx_checkpoints(self): - b = self._gemma3n_backend() - ctx = 8192 - baseline = b._estimate_kv_cache_bytes(ctx, "f16") - with_cp = b._estimate_kv_cache_bytes(ctx, "f16", ctx_checkpoints = 4) - # Checkpoints only count over UNSHARED SWA layers (16 of them). - sliding_in_unshared = sum(b._sliding_window_pattern[:20]) - per_cp_layer = 4 * 1024 * 4 * (256 + 256) * 2 # cps * swa * heads * (k+v) * bpe - assert with_cp == baseline + sliding_in_unshared * per_cp_layer - - def test_unload_resets_shared_kv_layers(self): - b = LlamaCppBackend() - b._shared_kv_layers = 12 - b.unload_model() - assert b._shared_kv_layers is None - - # --------------------------------------------------------------------------- # K. Lifecycle Tests # --------------------------------------------------------------------------- @@ -2004,18 +799,13 @@ def test_init_fields_none(self): "_kv_key_length", "_kv_value_length", "_sliding_window", - "_sliding_window_pattern", "_full_attention_interval", "_kv_lora_rank", "_key_length_mla", - "_kv_key_length_swa", - "_kv_value_length_swa", "_ssm_inner_size", "_ssm_state_size", - "_shared_kv_layers", ]: assert getattr(b, attr) is None - assert b._n_kv_heads_by_layer is None def test_unload_resets_fields(self): b = LlamaCppBackend() @@ -2023,30 +813,20 @@ def test_unload_resets_fields(self): b._kv_key_length = 128 b._kv_lora_rank = 512 b._sliding_window = 1024 - b._sliding_window_pattern = [True, False] - b._n_kv_heads_by_layer = [8, 2] - b._kv_key_length_swa = 64 - b._kv_value_length_swa = 64 b._ssm_inner_size = 4096 b._full_attention_interval = 4 - b._shared_kv_layers = 8 b.unload_model() for attr in [ "_kv_key_length", "_kv_value_length", "_sliding_window", - "_sliding_window_pattern", "_full_attention_interval", "_kv_lora_rank", "_key_length_mla", - "_kv_key_length_swa", - "_kv_value_length_swa", "_ssm_inner_size", "_ssm_state_size", - "_shared_kv_layers", ]: assert getattr(b, attr) is None - assert b._n_kv_heads_by_layer is None def test_end_to_end_synthetic_mla(self): """Full round-trip: write GGUF -> parse -> estimate.""" @@ -2107,46 +887,12 @@ def test_end_to_end_synthetic_swa(self): ) assert b._can_estimate_kv() result = b._estimate_kv_cache_bytes(131072, "f16") - # gemma3 -> period 6 from the bootstrap table, SWA cache - # double-buffered to 2 * sliding_window cells. - period = 6 + n_global = max(1, 62 // 4) # 15 + n_swa = 62 - n_global # 47 kv_per = 16 * 256 * 2 - expected = 0 - for i in range(62): - is_swa = (i + 1) % period != 0 - layer_ctx = min(131072, 2 * 1024) if is_swa else 131072 - expected += layer_ctx * kv_per + expected = int(n_global * 131072 * kv_per + n_swa * 1024 * kv_per) assert result == expected - def test_end_to_end_synthetic_shared_kv_round_trip(self): - # Mirrors gemma3n_text: 35 layers, 15 shared, sliding_window=1024. - b = _backend_from_gguf( - "gemma3n_text", - { - "context_length": 32768, - "block_count": 35, - "attention.head_count_kv": 4, - "attention.head_count": 8, - "embedding_length": 2048, - "attention.key_length": 256, - "attention.value_length": 256, - "attention.sliding_window": 1024, - "attention.shared_kv_layers": 15, - }, - ) - assert b._can_estimate_kv() - assert b._shared_kv_layers == 15 - # Bootstrap table for gemma3n_text -> period 5; the resolver - # synthesises a 35-entry bool array. The first 20 entries - # (n_layers - shared) are the only ones that allocate KV. - result = b._estimate_kv_cache_bytes(8192, "f16") - assert result > 0 - # Sanity: setting shared back to 0 must produce a strictly larger - # estimate (more layers allocate). - b._shared_kv_layers = 0 - unshared = b._estimate_kv_cache_bytes(8192, "f16") - assert unshared > result - def test_end_to_end_synthetic_gqa(self): b = _backend_from_gguf( "qwen3", diff --git a/studio/backend/tests/test_llama_cpp_cache_aware_disk_check.py b/studio/backend/tests/test_llama_cpp_cache_aware_disk_check.py deleted file mode 100644 index 255c04a956..0000000000 --- a/studio/backend/tests/test_llama_cpp_cache_aware_disk_check.py +++ /dev/null @@ -1,243 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Tests for the cache-aware disk-space preflight in -``LlamaCppBackend.load_model``. - -The preflight used to compare the repo's total GGUF download size against -free disk without accounting for bytes already present in the Hugging -Face cache. That made re-loading a cached large model (e.g. -``unsloth/MiniMax-M2.7-GGUF`` at 131 GB) fail cold whenever free disk was -below the full weight footprint, even though nothing needed -downloading. - -These tests exercise the preflight arithmetic in isolation by driving -``get_paths_info`` and ``try_to_load_from_cache`` through ``mock.patch``. -No network, GPU, or subprocess use. - -Cross-platform: Linux, macOS, Windows, WSL. -""" - -from __future__ import annotations - -import sys -import tempfile -import types as _types -from pathlib import Path -from unittest.mock import patch - -import pytest - -# --------------------------------------------------------------------------- -# Stub heavy / unavailable external dependencies before importing the -# module under test. Same pattern as test_kv_cache_estimation.py. -# --------------------------------------------------------------------------- - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -# loggers -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -# structlog -_structlog_stub = _types.ModuleType("structlog") -sys.modules.setdefault("structlog", _structlog_stub) - -# httpx -_httpx_stub = _types.ModuleType("httpx") -for _exc_name in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", -): - setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) - - -class _FakeTimeout: - def __init__(self, *a, **kw): - pass - - -_httpx_stub.Timeout = _FakeTimeout -_httpx_stub.Client = type( - "Client", - (), - { - "__init__": lambda self, **kw: None, - "__enter__": lambda self: self, - "__exit__": lambda self, *a: None, - }, -) -sys.modules.setdefault("httpx", _httpx_stub) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -GIB = 1024**3 - - -class _FakePathInfo: - """Mimics huggingface_hub's RepoFile-ish return type from get_paths_info.""" - - def __init__(self, path: str, size: int): - self.path = path - self.size = size - - -def _preflight( - repo_files, - cached_files, - free_bytes, - hf_repo = "unsloth/Example-GGUF", - hf_token = None, -): - """Run the preflight arithmetic as written in llama_cpp.py and return - the decision outcome as a dict. - - ``repo_files``: list of (filename, remote_bytes). - ``cached_files``: dict {filename: on_disk_bytes} for files already in cache. - ``free_bytes``: value returned by shutil.disk_usage(cache_dir).free. - """ - import os - import shutil - - path_infos = [_FakePathInfo(name, size) for name, size in repo_files] - - with tempfile.TemporaryDirectory() as tmp: - # Create SPARSE files for the cached ones so os.path.exists / - # os.path.getsize pass without actually allocating bytes on disk. - # This is critical when simulating multi-GB models. - cache_paths = {} - for name, sz in cached_files.items(): - p = Path(tmp) / name.replace("/", "_") - with open(p, "wb") as fh: - if sz > 0: - fh.truncate(sz) # sparse allocation: no data blocks written - cache_paths[name] = str(p) - - def fake_try_to_load_from_cache(repo_id, filename): - return cache_paths.get(filename) - - # Mirror the same variable names and control flow as the real code - # so behavioral drift is caught immediately. - total_bytes = sum((p.size or 0) for p in path_infos) - already_cached_bytes = 0 - for p in path_infos: - if not p.size: - continue - cached_path = fake_try_to_load_from_cache(hf_repo, p.path) - if isinstance(cached_path, str) and os.path.exists(cached_path): - try: - on_disk = os.path.getsize(cached_path) - except OSError: - on_disk = 0 - if on_disk >= p.size: - already_cached_bytes += p.size - - total_download_bytes = max(0, total_bytes - already_cached_bytes) - needed_download = total_download_bytes > free_bytes - return { - "total_bytes": total_bytes, - "already_cached_bytes": already_cached_bytes, - "total_download_bytes": total_download_bytes, - "would_raise_disk_error": (needed_download and total_download_bytes > 0), - } - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -class TestCacheAwarePreflight: - def test_fully_cached_model_does_not_require_disk(self): - """The MiniMax case: 131 GB weights cached, only 36 GB free. - Preflight must not raise.""" - shards = [(f"UD-Q4_K_XL/shard-{i}.gguf", 35 * GIB) for i in range(4)] - cached = {name: size for name, size in shards} - out = _preflight( - repo_files = shards, - cached_files = cached, - free_bytes = 36 * GIB, - ) - assert out["total_download_bytes"] == 0 - assert out["already_cached_bytes"] == 140 * GIB - assert out["would_raise_disk_error"] is False - - def test_partial_cache_only_counts_remaining_bytes(self): - """Two of four shards cached: preflight against remaining 70 GB.""" - shards = [(f"UD-Q4_K_XL/shard-{i}.gguf", 35 * GIB) for i in range(4)] - cached = { - shards[0][0]: shards[0][1], - shards[1][0]: shards[1][1], - } - out = _preflight( - repo_files = shards, - cached_files = cached, - free_bytes = 80 * GIB, - ) - assert out["already_cached_bytes"] == 70 * GIB - assert out["total_download_bytes"] == 70 * GIB - assert out["would_raise_disk_error"] is False - - def test_partial_cache_insufficient_disk_for_rest_still_raises(self): - """Two of four shards cached; remaining 70 GB still bigger than - free disk -> preflight correctly wants to raise.""" - shards = [(f"UD-Q4_K_XL/shard-{i}.gguf", 35 * GIB) for i in range(4)] - cached = { - shards[0][0]: shards[0][1], - shards[1][0]: shards[1][1], - } - out = _preflight( - repo_files = shards, - cached_files = cached, - free_bytes = 50 * GIB, - ) - assert out["total_download_bytes"] == 70 * GIB - assert out["would_raise_disk_error"] is True - - def test_nothing_cached_preserves_existing_behavior(self): - """Cold-cache path still compares full download vs free disk.""" - shards = [("UD-Q4_K_XL/shard-0.gguf", 40 * GIB)] - out = _preflight( - repo_files = shards, - cached_files = {}, - free_bytes = 50 * GIB, - ) - assert out["already_cached_bytes"] == 0 - assert out["total_download_bytes"] == 40 * GIB - assert out["would_raise_disk_error"] is False - - def test_incomplete_cached_blob_is_not_credited(self): - """A partial file on disk (e.g. interrupted download) is not - counted as cached -- we still require bytes for it.""" - shards = [("UD-Q4_K_XL/shard-0.gguf", 40 * GIB)] - partial = {"UD-Q4_K_XL/shard-0.gguf": 10 * GIB} - out = _preflight( - repo_files = shards, - cached_files = partial, - free_bytes = 50 * GIB, - ) - assert out["already_cached_bytes"] == 0 - assert out["total_download_bytes"] == 40 * GIB - assert out["would_raise_disk_error"] is False - - def test_zero_size_path_infos_do_not_crash(self): - """A path_info with size=0 should not be credited or break the - arithmetic.""" - shards = [("mmproj.gguf", 0), ("UD-Q4_K_XL/shard-0.gguf", 40 * GIB)] - out = _preflight( - repo_files = shards, - cached_files = {}, - free_bytes = 50 * GIB, - ) - assert out["already_cached_bytes"] == 0 - assert out["total_bytes"] == 40 * GIB diff --git a/studio/backend/tests/test_llama_cpp_context_fit.py b/studio/backend/tests/test_llama_cpp_context_fit.py deleted file mode 100644 index caa6397901..0000000000 --- a/studio/backend/tests/test_llama_cpp_context_fit.py +++ /dev/null @@ -1,393 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Tests for the GGUF load-time context auto-fit decision. - -Guards two regressions in ``LlamaCppBackend.load_model``: - -1. **Auto mode on weights-exceed-VRAM** (``n_ctx == 0``): when the model - weights alone exceed 90% of every GPU subset's free memory, the - auto-pick loop used to exit without matching, leaving - ``effective_ctx`` at the model's native context (e.g. 196608 for - MiniMax-M2.7). The intended default per Studio's UI spec is 4096 so - the slider lands on a usable value; the user can still drag higher - and trigger ``--fit on`` with a warning. - -2. **Explicit ctx silently shrunk when KV overflows**: with fittable - weights but a requested ctx whose KV cache pushes total memory over - 90% of VRAM, the old code binary-searched a smaller ctx and emitted - ``-c -ngl -1`` without informing the caller. The UI had - already surfaced its "might be slower" warning and expects the user's - explicit ctx to be honored with ``--fit on`` flexing ``-ngl`` instead. - -Tests avoid GPU probing, subprocess spawning, and GGUF I/O by driving the -post-metadata decision block directly against a stubbed instance. - -Requires no GPU, network, or external libraries beyond pytest. -Cross-platform: Linux, macOS, Windows, WSL. -""" - -from __future__ import annotations - -import sys -import types as _types -from pathlib import Path - -import pytest - -# --------------------------------------------------------------------------- -# Stub heavy / unavailable external dependencies before importing the -# module under test. Same pattern as test_kv_cache_estimation.py. -# --------------------------------------------------------------------------- - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -# loggers -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -# structlog -_structlog_stub = _types.ModuleType("structlog") -sys.modules.setdefault("structlog", _structlog_stub) - -# httpx -_httpx_stub = _types.ModuleType("httpx") -for _exc_name in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", -): - setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) - - -class _FakeTimeout: - def __init__(self, *a, **kw): - pass - - -_httpx_stub.Timeout = _FakeTimeout -_httpx_stub.Client = type( - "Client", - (), - { - "__init__": lambda self, **kw: None, - "__enter__": lambda self: self, - "__exit__": lambda self, *a: None, - }, -) -sys.modules.setdefault("httpx", _httpx_stub) - -from core.inference.llama_cpp import LlamaCppBackend - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -GIB = 1024**3 -FALLBACK_CTX = 4096 - - -def _make_backend( - native_ctx = 131072, - n_layers = 80, - n_kv_heads = 8, - n_heads = 64, - kv_key_length = 128, - kv_value_length = 128, -): - """Create a LlamaCppBackend instance with GGUF metadata fields set and - the helpers used by the decision block stubbed out.""" - inst = LlamaCppBackend.__new__(LlamaCppBackend) - inst._context_length = native_ctx - inst._n_layers = n_layers - inst._n_kv_heads = n_kv_heads - inst._n_heads = n_heads - inst._embedding_length = 8192 - inst._kv_key_length = kv_key_length - inst._kv_value_length = kv_value_length - inst._kv_lora_rank = None - inst._sliding_window = None - inst._sliding_window_pattern = None - inst._ssm_inner_size = None - inst._full_attention_interval = None - inst._key_length_mla = None - inst._n_kv_heads_by_layer = None - inst._kv_key_length_swa = None - inst._kv_value_length_swa = None - return inst - - -def _drive( - n_ctx, - model_gib, - gpus, - native_ctx = 131072, - kv_per_token_bytes = 325_000, - can_estimate_kv = True, -): - """Drive the post-metadata portion of load_model with stubbed inputs. - - Mirrors the decision block at llama_cpp.py:1137-1296 so we can assert - the command that would be built, without subprocesses or GPU probes. - """ - inst = _make_backend(native_ctx = native_ctx) - model_size = int(model_gib * GIB) - cache_type_kv = None - - def fake_estimate(n_ctx_, _type = None, **_kwargs): - return 0 if n_ctx_ <= 0 else n_ctx_ * kv_per_token_bytes - - inst._estimate_kv_cache_bytes = fake_estimate - inst._can_estimate_kv = lambda: can_estimate_kv - - context_length = inst._context_length - - effective_ctx = n_ctx if n_ctx > 0 else (context_length or 0) - max_available_ctx = context_length or effective_ctx - if n_ctx > 0: - effective_ctx = n_ctx - elif context_length is not None: - effective_ctx = context_length - else: - effective_ctx = 0 - original_ctx = effective_ctx - max_available_ctx = context_length or effective_ctx - - gpu_indices, use_fit = None, True - explicit_ctx = n_ctx > 0 - - if gpus and inst._can_estimate_kv() and effective_ctx > 0: - native_ctx_for_cap = context_length or effective_ctx - if native_ctx_for_cap > 0: - ranked_for_cap = sorted(gpus, key = lambda g: g[1], reverse = True) - best_cap = 0 - for n_gpus in range(1, len(ranked_for_cap) + 1): - subset = ranked_for_cap[:n_gpus] - pool_mib = sum(free for _, free in subset) - capped = inst._fit_context_to_vram( - native_ctx_for_cap, - pool_mib, - model_size, - cache_type_kv, - ) - kv = inst._estimate_kv_cache_bytes(capped, cache_type_kv) - total_mib = (model_size + kv) / (1024 * 1024) - if total_mib <= pool_mib * 0.90: - best_cap = max(best_cap, capped) - if best_cap > 0: - max_available_ctx = best_cap - - if explicit_ctx: - requested_total = model_size + inst._estimate_kv_cache_bytes( - effective_ctx, cache_type_kv - ) - gpu_indices, use_fit = inst._select_gpus(requested_total, gpus) - else: - ranked = sorted(gpus, key = lambda g: g[1], reverse = True) - matched = False - for n_gpus in range(1, len(ranked) + 1): - subset = ranked[:n_gpus] - pool_mib = sum(free for _, free in subset) - capped = inst._fit_context_to_vram( - effective_ctx, - pool_mib, - model_size, - cache_type_kv, - ) - kv = inst._estimate_kv_cache_bytes(capped, cache_type_kv) - total_mib = (model_size + kv) / (1024 * 1024) - if total_mib <= pool_mib * 0.90: - effective_ctx = capped - gpu_indices = sorted(idx for idx, _ in subset) - use_fit = False - matched = True - break - if not matched: - effective_ctx = min(FALLBACK_CTX, effective_ctx) - elif gpus: - gpu_indices, use_fit = inst._select_gpus(model_size, gpus) - if use_fit and not explicit_ctx: - effective_ctx = ( - min(FALLBACK_CTX, effective_ctx) if effective_ctx > 0 else FALLBACK_CTX - ) - - return { - "c_arg": effective_ctx if effective_ctx > 0 else 0, - "use_fit": use_fit, - "gpu_indices": gpu_indices, - "max_available_ctx": max_available_ctx, - "original_ctx": original_ctx, - } - - -# --------------------------------------------------------------------------- -# Auto mode, model weights exceed VRAM (Bug A guard) -# --------------------------------------------------------------------------- - - -class TestAutoModeWeightsExceedVRAM: - """``n_ctx == 0`` on a model whose weights don't fit anywhere.""" - - def test_minimax_like_single_gpu(self): - plan = _drive( - n_ctx = 0, - model_gib = 131, - gpus = [(0, 97_000)], - native_ctx = 196608, - ) - assert plan["c_arg"] == FALLBACK_CTX - assert plan["use_fit"] is True - assert plan["gpu_indices"] is None - # UI slider ceiling stays at native: user can still drag higher - # and get the "might be slower" path. - assert plan["max_available_ctx"] == 196608 - - def test_multi_gpu_all_subsets_fail(self): - plan = _drive( - n_ctx = 0, - model_gib = 400, - gpus = [(0, 80_000), (1, 80_000), (2, 80_000), (3, 80_000)], - native_ctx = 131072, - ) - assert plan["c_arg"] == FALLBACK_CTX - assert plan["use_fit"] is True - assert plan["gpu_indices"] is None - - def test_no_kv_metadata_auto(self): - """File-size-only fallback path also defaults to 4096.""" - plan = _drive( - n_ctx = 0, - model_gib = 131, - gpus = [(0, 97_000)], - native_ctx = 196608, - can_estimate_kv = False, - ) - assert plan["c_arg"] == FALLBACK_CTX - assert plan["use_fit"] is True - - -# --------------------------------------------------------------------------- -# Explicit ctx, KV overflows fittable weights (Bug B guard) -# --------------------------------------------------------------------------- - - -class TestExplicitCtxRespectsUser: - """``n_ctx > 0`` must never be silently shrunk.""" - - def test_fittable_weights_oversized_kv(self): - # 8 GB weights + 131k ctx KV on 24 GB VRAM. - # Budget = 21.6 GB, KV at 131k >> 13.6 GB remaining, so - # _select_gpus flips use_fit=True. - plan = _drive( - n_ctx = 131072, - model_gib = 8, - gpus = [(0, 24_000)], - native_ctx = 131072, - ) - assert plan["c_arg"] == 131072 - assert plan["use_fit"] is True - assert plan["gpu_indices"] is None - - def test_explicit_that_fits_uses_ngl(self): - plan = _drive( - n_ctx = 8192, - model_gib = 8, - gpus = [(0, 24_000)], - native_ctx = 131072, - ) - assert plan["c_arg"] == 8192 - assert plan["use_fit"] is False - assert plan["gpu_indices"] == [0] - - def test_explicit_on_weights_exceed_vram(self): - # User drags the slider to 32k on a too-big model: honored. - plan = _drive( - n_ctx = 32768, - model_gib = 131, - gpus = [(0, 97_000)], - native_ctx = 196608, - ) - assert plan["c_arg"] == 32768 - assert plan["use_fit"] is True - - def test_explicit_at_fallback_on_too_big(self): - plan = _drive( - n_ctx = FALLBACK_CTX, - model_gib = 131, - gpus = [(0, 97_000)], - native_ctx = 196608, - ) - assert plan["c_arg"] == FALLBACK_CTX - assert plan["use_fit"] is True - - def test_explicit_below_floor_honored(self): - # 2048 is below --fit-ctx default; still honored since user set it. - plan = _drive( - n_ctx = 2048, - model_gib = 8, - gpus = [(0, 24_000)], - ) - assert plan["c_arg"] == 2048 - - -# --------------------------------------------------------------------------- -# Non-regression: fittable + auto still auto-picks largest fitting ctx -# --------------------------------------------------------------------------- - - -class TestFittableAutoPickRegressions: - def test_small_model_one_gpu(self): - plan = _drive( - n_ctx = 0, - model_gib = 8, - gpus = [(0, 24_000)], - native_ctx = 131072, - kv_per_token_bytes = 8192, - ) - assert plan["use_fit"] is False - assert plan["gpu_indices"] == [0] - assert plan["c_arg"] > FALLBACK_CTX - - def test_medium_model_needs_multi_gpu(self): - plan = _drive( - n_ctx = 0, - model_gib = 60, - gpus = [(0, 40_000), (1, 40_000)], - native_ctx = 131072, - kv_per_token_bytes = 8192, - ) - assert plan["use_fit"] is False - assert plan["gpu_indices"] == [0, 1] - - def test_no_kv_metadata_fittable_auto(self): - plan = _drive( - n_ctx = 0, - model_gib = 8, - gpus = [(0, 24_000)], - native_ctx = 131072, - can_estimate_kv = False, - ) - assert plan["use_fit"] is False - assert plan["gpu_indices"] == [0] - - -# --------------------------------------------------------------------------- -# Platform-agnostic input shape -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize("platform_tag", ["linux", "windows", "mac", "rocm"]) -def test_identical_decision_across_platforms(platform_tag): - """The decision function takes ``[(gpu_idx, free_mib), ...]`` regardless - of how upstream (nvidia-smi / nvidia-smi.exe / Metal / rocm-smi) produced - it. Identical inputs must yield identical plans.""" - plan_a = _drive(n_ctx = 0, model_gib = 8, gpus = [(0, 24_000)]) - plan_b = _drive(n_ctx = 0, model_gib = 8, gpus = [(0, 24_000)]) - assert plan_a == plan_b, platform_tag diff --git a/studio/backend/tests/test_llama_cpp_load_progress.py b/studio/backend/tests/test_llama_cpp_load_progress.py deleted file mode 100644 index f46751b798..0000000000 --- a/studio/backend/tests/test_llama_cpp_load_progress.py +++ /dev/null @@ -1,258 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Tests for ``LlamaCppBackend.load_progress()``. - -The chat settings flow and the training overlay both show a generic -"Starting model..." spinner during the window after a GGUF download -finishes and before llama-server reports healthy. For small models -that window is a second or two and nobody notices. For large MoE GGUFs -(MiniMax-M2.7, Qwen3.5-397B-A17B, etc.) the llama-server process spends -minutes in kernel state D, paging tens or hundreds of GB of shards -into the page cache. The UI has no way to show a real progress bar, -rate, or ETA during that window. - -``load_progress()`` samples ``/proc//status VmRSS`` (what the -kernel has actually paged in) against the total shard file size on -disk, so the frontend can render a real bar plus rate/ETA. This -module pins that contract: - - * returns ``None`` when no load is in flight - * returns ``{"phase": "mmap", ...}`` while the subprocess is alive - but ``_healthy`` is False - * returns ``{"phase": "ready", ...}`` once ``_healthy`` flips - * ``bytes_total`` is derived from the resolved on-disk path - (which the paired fix assigns to ``self._gguf_path`` on both the - local-GGUF and HF-download code paths) - * ``bytes_loaded`` is VmRSS in bytes, capped by total, rounded - * ``fraction`` is clamped to 0..1 and rounded to 4 decimal places - -Linux-only via ``/proc``. On platforms without ``/proc`` the method -returns ``None`` instead of raising. -Cross-platform test: skips cleanly on macOS / Windows if ``/proc`` is -not available. -""" - -from __future__ import annotations - -import os -import sys -import tempfile -import types as _types -from pathlib import Path -from unittest.mock import patch - -import pytest - -# --------------------------------------------------------------------------- -# Stub heavy / unavailable external dependencies before importing the -# module under test. Same pattern as test_kv_cache_estimation.py. -# --------------------------------------------------------------------------- - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -_structlog_stub = _types.ModuleType("structlog") -sys.modules.setdefault("structlog", _structlog_stub) - -_httpx_stub = _types.ModuleType("httpx") -for _exc_name in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", -): - setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) - - -class _FakeTimeout: - def __init__(self, *a, **kw): - pass - - -_httpx_stub.Timeout = _FakeTimeout -_httpx_stub.Client = type( - "Client", - (), - { - "__init__": lambda self, **kw: None, - "__enter__": lambda self: self, - "__exit__": lambda self, *a: None, - }, -) -sys.modules.setdefault("httpx", _httpx_stub) - -from core.inference.llama_cpp import LlamaCppBackend - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_instance(): - inst = LlamaCppBackend.__new__(LlamaCppBackend) - inst._process = None - inst._gguf_path = None - inst._healthy = False - return inst - - -class _FakeProc: - """Minimal stand-in for subprocess.Popen that just carries a pid.""" - - def __init__(self, pid: int): - self.pid = pid - - -def _write_sparse_file(path: Path, size_bytes: int) -> None: - """Create a sparse file of the given size without allocating blocks.""" - with open(path, "wb") as fh: - if size_bytes > 0: - fh.truncate(size_bytes) - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -class TestLoadProgressEmptyStates: - def test_returns_none_when_no_process(self): - inst = _make_instance() - assert inst.load_progress() is None - - def test_returns_none_when_process_has_no_pid(self): - inst = _make_instance() - inst._process = _FakeProc(pid = None) # type: ignore[arg-type] - assert inst.load_progress() is None - - -class TestLoadProgressSingleShard: - def test_mmap_phase_for_alive_but_unhealthy(self, tmp_path): - """VmRSS below total -> phase='mmap', fraction reflects progress.""" - gguf = tmp_path / "model.gguf" - _write_sparse_file(gguf, 40 * 1024**3) # 40 GB - - inst = _make_instance() - inst._process = _FakeProc(pid = os.getpid()) # use our own pid - inst._gguf_path = str(gguf) - inst._healthy = False - - # Patch /proc read to claim 10 GB RSS. - def fake_open(path, *args, **kwargs): - if str(path).startswith("/proc/"): - import io - - return io.StringIO(f"Name:\ttest\nVmRSS:\t{10 * 1024 ** 2}\tkB\n") - return open(path, *args, **kwargs) # fall through - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - - assert out is not None - assert out["phase"] == "mmap" - assert out["bytes_total"] == 40 * 1024**3 - assert out["bytes_loaded"] == 10 * 1024**3 - assert 0.24 < out["fraction"] < 0.26 # ~25% - - def test_ready_phase_when_healthy(self, tmp_path): - gguf = tmp_path / "model.gguf" - _write_sparse_file(gguf, 8 * 1024**3) - - inst = _make_instance() - inst._process = _FakeProc(pid = os.getpid()) - inst._gguf_path = str(gguf) - inst._healthy = True - - def fake_open(path, *args, **kwargs): - if str(path).startswith("/proc/"): - import io - - return io.StringIO(f"VmRSS:\t{8 * 1024 ** 2}\tkB\n") - return open(path, *args, **kwargs) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - - assert out is not None - assert out["phase"] == "ready" - assert out["bytes_total"] == 8 * 1024**3 - assert out["bytes_loaded"] == 8 * 1024**3 - assert out["fraction"] == 1.0 - - -class TestLoadProgressMultiShard: - """Shard-aware total: for ``*-00001-of-00004.gguf`` primaries the - method sums sibling files with the same prefix.""" - - def test_sharded_total_aggregates_siblings(self, tmp_path): - for i in range(1, 5): - _write_sparse_file( - tmp_path / f"model-{i:05d}-of-00004.gguf", - size_bytes = 20 * 1024**3, - ) - # Drop an unrelated .gguf in the same folder -- must not be counted. - _write_sparse_file(tmp_path / "mmproj-BF16.gguf", 2 * 1024**3) - - inst = _make_instance() - inst._process = _FakeProc(pid = os.getpid()) - inst._gguf_path = str(tmp_path / "model-00001-of-00004.gguf") - inst._healthy = False - - def fake_open(path, *args, **kwargs): - if str(path).startswith("/proc/"): - import io - - return io.StringIO("VmRSS:\t0\tkB\n") - return open(path, *args, **kwargs) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - - assert out is not None - assert out["bytes_total"] == 80 * 1024**3 # 4 x 20 GB, no mmproj - - -class TestLoadProgressDegradation: - """Broken / unusual inputs never raise; they produce best-effort output.""" - - def test_missing_gguf_path_still_reports_rss(self, tmp_path): - inst = _make_instance() - inst._process = _FakeProc(pid = os.getpid()) - inst._gguf_path = None - inst._healthy = False - - def fake_open(path, *args, **kwargs): - if str(path).startswith("/proc/"): - import io - - return io.StringIO("VmRSS:\t1024\tkB\n") - return open(path, *args, **kwargs) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - - assert out is not None - assert out["phase"] == "mmap" - assert out["bytes_total"] == 0 - assert out["bytes_loaded"] == 1024 * 1024 - assert out["fraction"] == 0.0 - - def test_unreadable_proc_returns_none(self, tmp_path): - inst = _make_instance() - # Pid that doesn't exist -> /proc read fails. - inst._process = _FakeProc(pid = 999_999_999) - inst._gguf_path = str(tmp_path / "model.gguf") # doesn't need to exist - inst._healthy = False - - out = inst.load_progress() - # FileNotFoundError on /proc path -> load_progress returns None. - assert out is None diff --git a/studio/backend/tests/test_llama_cpp_load_progress_live.py b/studio/backend/tests/test_llama_cpp_load_progress_live.py deleted file mode 100644 index beed8713c1..0000000000 --- a/studio/backend/tests/test_llama_cpp_load_progress_live.py +++ /dev/null @@ -1,202 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Live, no-mock integration test for ``LlamaCppBackend.load_progress()``. - -The companion files (``test_llama_cpp_load_progress.py`` and -``test_llama_cpp_load_progress_matrix.py``) patch ``builtins.open`` to -feed synthetic VmRSS values. This file is the opposite: it uses **real** -subprocesses, **real** file sizes, and the **real** ``/proc`` -interface. It is the sanity check that the contract we keep in the -mocked tests still maps to what the kernel actually returns on a live -Linux system. - -Why both: the mocked tests can be fooled by a buggy implementation that -parses ``/proc`` output in a format the kernel no longer uses, or that -makes assumptions about ``Path.stat()`` vs ``os.path.getsize``. This -file hits the real APIs so any format drift gets caught. - -Skipped cleanly on non-Linux (no ``/proc``). -""" - -from __future__ import annotations - -import os -import subprocess -import sys -import time -import types as _types -from pathlib import Path - -import pytest - -# --------------------------------------------------------------------------- -# Same stubs as the matrix file (keep self-contained so the file can be -# run standalone as well as via the full suite). -# --------------------------------------------------------------------------- - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) -_structlog_stub = _types.ModuleType("structlog") -sys.modules.setdefault("structlog", _structlog_stub) -_httpx_stub = _types.ModuleType("httpx") -for _exc in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", -): - setattr(_httpx_stub, _exc, type(_exc, (Exception,), {})) -_httpx_stub.Timeout = type("Timeout", (), {"__init__": lambda self, *a, **k: None}) -_httpx_stub.Client = type( - "Client", - (), - { - "__init__": lambda self, **kw: None, - "__enter__": lambda self: self, - "__exit__": lambda self, *a: None, - }, -) -sys.modules.setdefault("httpx", _httpx_stub) - -from core.inference.llama_cpp import LlamaCppBackend - - -pytestmark = pytest.mark.skipif( - not Path("/proc").exists(), - reason = "live /proc test is Linux-only", -) - - -def _make_backend(pid: int, gguf_path: str, healthy: bool = False): - inst = LlamaCppBackend.__new__(LlamaCppBackend) - inst._process = type("P", (), {"pid": pid})() - inst._gguf_path = gguf_path - inst._healthy = healthy - return inst - - -def test_live_rss_matches_kernel_vmrss(tmp_path): - """Spawn a real child, let it allocate real bytes, confirm - ``bytes_loaded`` tracks the kernel's VmRSS within a sane tolerance.""" - # Child that allocates ~100 MB of zero'd bytes and then idles. - script = tmp_path / "burn.py" - script.write_text( - "import time, sys\n" - "buf = bytearray(100 * 1024 * 1024)\n" # 100 MB - "# touch every page so RSS actually grows\n" - "for i in range(0, len(buf), 4096):\n" - " buf[i] = 1\n" - "sys.stdout.write('ready\\n')\n" - "sys.stdout.flush()\n" - "time.sleep(10)\n" - ) - proc = subprocess.Popen( - [sys.executable, str(script)], - stdout = subprocess.PIPE, - stderr = subprocess.PIPE, - ) - try: - # Wait for the child to finish touching pages. - ready = proc.stdout.readline() - assert ready.strip() == b"ready" - - # Create a fake 200 MB sparse gguf so bytes_total is concrete. - gguf = tmp_path / "model.gguf" - with open(gguf, "wb") as f: - f.truncate(200 * 1024 * 1024) - - inst = _make_backend(proc.pid, str(gguf), healthy = False) - out = inst.load_progress() - - assert out is not None, "load_progress returned None for live pid" - assert out["phase"] == "mmap" - assert out["bytes_total"] == 200 * 1024 * 1024 - # VmRSS for the Python child includes the interpreter + the 100MB - # buffer, so a realistic floor is 50 MB and ceiling is 200 MB. - assert ( - out["bytes_loaded"] >= 50 * 1024 * 1024 - ), f"bytes_loaded unexpectedly low: {out['bytes_loaded']}" - assert out["bytes_loaded"] <= 200 * 1024 * 1024 - assert 0.0 < out["fraction"] <= 1.0 - finally: - proc.terminate() - try: - proc.wait(timeout = 5) - except subprocess.TimeoutExpired: - proc.kill() - - -def test_live_ready_phase_when_healthy(tmp_path): - gguf = tmp_path / "m.gguf" - with open(gguf, "wb") as f: - f.truncate(1 * 1024 * 1024) - - inst = _make_backend(os.getpid(), str(gguf), healthy = True) - out = inst.load_progress() - assert out is not None - assert out["phase"] == "ready" - assert out["bytes_total"] == 1 * 1024 * 1024 - # Self-pid RSS is well above 1 MiB for CPython; fraction caps at 1. - assert out["fraction"] == 1.0 - - -def test_live_dead_pid_returns_none(tmp_path): - """A recently-dead pid may linger in /proc for ms; use a clearly - invalid id so the read reliably fails.""" - gguf = tmp_path / "m.gguf" - gguf.touch() - - inst = _make_backend(9_999_999_999, str(gguf), healthy = False) - out = inst.load_progress() - assert out is None - - -def test_live_shard_aggregation_counts_real_files(tmp_path): - """With 4 real sibling shards on disk, ``bytes_total`` equals their - summed size to the byte.""" - shard_size = 7 * 1024 * 1024 # 7 MB each - for i in range(1, 5): - f = tmp_path / f"model-{i:05d}-of-00004.gguf" - with open(f, "wb") as fh: - fh.truncate(shard_size) - # Unrelated file in same dir -- must not be counted. - with open(tmp_path / "config.json", "wb") as fh: - fh.truncate(123) - - inst = _make_backend( - os.getpid(), - str(tmp_path / "model-00001-of-00004.gguf"), - healthy = False, - ) - out = inst.load_progress() - assert out is not None - assert out["bytes_total"] == 4 * shard_size - - -def test_live_repeated_polling_stays_sane(tmp_path): - """Sampling the same backend 20 times should not raise or produce - non-numeric output, even under normal kernel RSS jitter.""" - gguf = tmp_path / "m.gguf" - with open(gguf, "wb") as f: - f.truncate(500 * 1024 * 1024) - - inst = _make_backend(os.getpid(), str(gguf), healthy = False) - seen = [] - for _ in range(20): - out = inst.load_progress() - assert out is not None - assert isinstance(out["bytes_loaded"], int) - assert isinstance(out["bytes_total"], int) - assert 0.0 <= out["fraction"] <= 1.0 - seen.append(out["bytes_loaded"]) - time.sleep(0.01) - # RSS of a healthy Python process doesn't go below ~5 MB. - assert min(seen) > 1 * 1024 * 1024 diff --git a/studio/backend/tests/test_llama_cpp_load_progress_matrix.py b/studio/backend/tests/test_llama_cpp_load_progress_matrix.py deleted file mode 100644 index a88450ec0b..0000000000 --- a/studio/backend/tests/test_llama_cpp_load_progress_matrix.py +++ /dev/null @@ -1,473 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Extended test matrix for ``LlamaCppBackend.load_progress()``. - -Companion to ``test_llama_cpp_load_progress.py`` (which pins the basic -contract). This file widens coverage to the edge cases that bit users -or were hypothesized to bite them on cross-platform installs: - - * Platform matrix — macOS/Windows simulation via ``/proc`` absence. - * ``VmRSS`` parsing — tab vs space delimiter, missing line, malformed - integer. - * Filesystem edges — HF-cache symlinks, broken symlinks, nonexistent - paths, relative paths. - * Shard aggregation — partial multi-shard downloads where some shards - are still ``.incomplete``, two shard series in the same dir, - ``mmproj-*.gguf`` sibling exclusion for non-sharded primaries, - single-file models. - * Lifecycle races — process set before ``_gguf_path`` is assigned, - process dead mid-sample, ``_healthy`` flipped to True. - * Concurrent sampling — 10 threads × 50 iterations against a single - backend, hitting real ``/proc`` (no mocks — see the note in - ``TestConcurrentSampling`` for why). - * Fraction bounds — capped at 1.0 when RSS exceeds total; 0.0 when - total is zero. - -All tests are Linux-only in practice (we stub ``/proc`` where needed). -The stable subset runs in well under a second. -""" - -from __future__ import annotations - -import io -import os -import sys -import threading -import types as _types -from pathlib import Path -from unittest.mock import patch - -import pytest - -# --------------------------------------------------------------------------- -# Stub heavy / unavailable external dependencies before importing the -# module under test. Same pattern as test_llama_cpp_load_progress.py. -# --------------------------------------------------------------------------- - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -_structlog_stub = _types.ModuleType("structlog") -sys.modules.setdefault("structlog", _structlog_stub) - -_httpx_stub = _types.ModuleType("httpx") -for _exc_name in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", -): - setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) - - -class _FakeTimeout: - def __init__(self, *a, **kw): - pass - - -_httpx_stub.Timeout = _FakeTimeout -_httpx_stub.Client = type( - "Client", - (), - { - "__init__": lambda self, **kw: None, - "__enter__": lambda self: self, - "__exit__": lambda self, *a: None, - }, -) -sys.modules.setdefault("httpx", _httpx_stub) - -from core.inference.llama_cpp import LlamaCppBackend - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make(): - inst = LlamaCppBackend.__new__(LlamaCppBackend) - inst._process = None - inst._gguf_path = None - inst._healthy = False - return inst - - -class _Proc: - def __init__(self, pid): - self.pid = pid - - -def _sparse(path, size): - with open(path, "wb") as f: - if size > 0: - f.truncate(size) - - -def _fake_proc_reader(rss_kb): - """Return an ``open()`` replacement that fakes /proc reads with a VmRSS line.""" - - def fake_open(path, *args, **kwargs): - if str(path).startswith("/proc/"): - return io.StringIO(f"VmRSS:\t{rss_kb}\tkB\n") - return open(path, *args, **kwargs) - - return fake_open - - -# --------------------------------------------------------------------------- -# A. Platform matrix -# --------------------------------------------------------------------------- - - -class TestPlatformMatrix: - """The method is Linux-first via /proc. On macOS/Windows it must - degrade to None rather than crash.""" - - def test_linux_live_proc_is_self_pid(self, tmp_path): - """Self-pid /proc read uses the real kernel interface.""" - gguf = tmp_path / "m.gguf" - _sparse(gguf, 1 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(gguf) - inst._healthy = False - out = inst.load_progress() - assert out is not None - assert out["phase"] == "mmap" - assert out["bytes_total"] == 1 * 1024**3 - # Our Python process has some RSS -- just sanity-check positive. - assert out["bytes_loaded"] > 0 - - def test_macos_no_proc_returns_none(self, tmp_path): - """Simulate macOS: /proc open fails with FileNotFoundError.""" - gguf = tmp_path / "m.gguf" - _sparse(gguf, 1 * 1024**3) - inst = _make() - inst._process = _Proc(pid = 12345) - inst._gguf_path = str(gguf) - - def fake_open(path, *args, **kwargs): - if str(path).startswith("/proc/"): - raise FileNotFoundError(f"No such file: {path}") - return open(path, *args, **kwargs) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - assert out is None - - def test_windows_no_proc_returns_none(self, tmp_path): - """Simulate Windows: opening /proc raises PermissionError or OSError.""" - gguf = tmp_path / "m.gguf" - _sparse(gguf, 1 * 1024**3) - inst = _make() - inst._process = _Proc(pid = 4567) - inst._gguf_path = str(gguf) - - def fake_open(path, *args, **kwargs): - if str(path).startswith("/proc/"): - raise PermissionError("access denied") - return open(path, *args, **kwargs) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - assert out is None - - -# --------------------------------------------------------------------------- -# B. VmRSS parsing edge cases -# --------------------------------------------------------------------------- - - -class TestVmRSSParsing: - def test_standard_tab_delimited(self, tmp_path): - gguf = tmp_path / "m.gguf" - _sparse(gguf, 4 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(gguf) - with patch("builtins.open", side_effect = _fake_proc_reader(2 * 1024**2)): - out = inst.load_progress() - assert out["bytes_loaded"] == 2 * 1024**3 - - def test_space_separated_fallback(self, tmp_path): - """Some kernels emit single-space rather than tab.""" - gguf = tmp_path / "m.gguf" - _sparse(gguf, 4 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(gguf) - - def fake_open(path, *a, **kw): - if str(path).startswith("/proc/"): - return io.StringIO("VmRSS: 4194304 kB\n") - return open(path, *a, **kw) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - assert out["bytes_loaded"] == 4 * 1024**3 - - def test_missing_vmrss_line(self, tmp_path): - """Kernel with VmRSS stripped (zombie / kthread) -> 0.""" - gguf = tmp_path / "m.gguf" - _sparse(gguf, 1 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(gguf) - - def fake_open(path, *a, **kw): - if str(path).startswith("/proc/"): - return io.StringIO("Name:\ttest\nState:\tZ (zombie)\n") - return open(path, *a, **kw) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - assert out is not None - assert out["bytes_loaded"] == 0 - assert out["fraction"] == 0.0 - - def test_malformed_vmrss_value(self, tmp_path): - """Non-integer VmRSS value should be treated as if the line were - absent (early ValueError caught).""" - gguf = tmp_path / "m.gguf" - _sparse(gguf, 1 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(gguf) - - def fake_open(path, *a, **kw): - if str(path).startswith("/proc/"): - return io.StringIO("VmRSS:\tXXXX\tkB\n") - return open(path, *a, **kw) - - with patch("builtins.open", side_effect = fake_open): - out = inst.load_progress() - # The implementation catches ValueError on int() and returns None. - assert out is None - - -# --------------------------------------------------------------------------- -# C. Filesystem edge cases -# --------------------------------------------------------------------------- - - -class TestFilesystemEdges: - def test_symlink_primary_follows_to_blob(self, tmp_path): - """HF cache stores blobs under blobs/ and symlinks them from - snapshots/. The method must follow the symlink.""" - blob = tmp_path / "blob" - _sparse(blob, 12 * 1024**3) - snap = tmp_path / "snap" - snap.mkdir() - link = snap / "m.gguf" - link.symlink_to(blob) - - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(link) - with patch("builtins.open", side_effect = _fake_proc_reader(6 * 1024**2)): - out = inst.load_progress() - assert out["bytes_total"] == 12 * 1024**3 - - def test_broken_symlink_skipped(self, tmp_path): - snap = tmp_path / "snap" - snap.mkdir() - link = snap / "m.gguf" - link.symlink_to(tmp_path / "missing-blob") - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(link) - with patch("builtins.open", side_effect = _fake_proc_reader(1024)): - out = inst.load_progress() - assert out["bytes_total"] == 0 - assert out["bytes_loaded"] == 1024 * 1024 - - def test_nonexistent_path_skipped(self, tmp_path): - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "ghost.gguf") - with patch("builtins.open", side_effect = _fake_proc_reader(1024)): - out = inst.load_progress() - assert out["bytes_total"] == 0 - - def test_relative_gguf_path(self, tmp_path): - """Relative paths shouldn't crash; behaviour depends on CWD but - the method must not raise.""" - cwd = os.getcwd() - try: - os.chdir(tmp_path) - _sparse(Path("rel.gguf"), 8 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = "rel.gguf" - with patch("builtins.open", side_effect = _fake_proc_reader(0)): - out = inst.load_progress() - assert out is not None - assert out["bytes_total"] == 8 * 1024**3 - finally: - os.chdir(cwd) - - -# --------------------------------------------------------------------------- -# D. Shard aggregation -# --------------------------------------------------------------------------- - - -class TestShardAggregation: - def test_partial_multi_shard_download(self, tmp_path): - """Primary present but shards 2..N still downloading as - ``.incomplete``. Sums only the fully-arrived ``.gguf`` files.""" - _sparse(tmp_path / "m-00001-of-00004.gguf", 30 * 1024**3) - _sparse(tmp_path / "m-00002-of-00004.gguf", 30 * 1024**3) - # 3 and 4 still downloading as .incomplete - _sparse(tmp_path / "m-00003-of-00004.gguf.incomplete", 5 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "m-00001-of-00004.gguf") - with patch("builtins.open", side_effect = _fake_proc_reader(0)): - out = inst.load_progress() - assert out["bytes_total"] == 60 * 1024**3 # only the .gguf siblings - - def test_two_shard_series_in_same_dir(self, tmp_path): - """Defensive: if two quant series share a dir, prefix filter - only sums siblings of the chosen primary.""" - for i in range(1, 3): - _sparse(tmp_path / f"m_q4-{i:05d}-of-00002.gguf", 10 * 1024**3) - _sparse(tmp_path / f"m_q8-{i:05d}-of-00002.gguf", 20 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "m_q8-00001-of-00002.gguf") - with patch("builtins.open", side_effect = _fake_proc_reader(0)): - out = inst.load_progress() - assert out["bytes_total"] == 40 * 1024**3 # just q8 series - - def test_mmproj_sibling_not_counted(self, tmp_path): - """Vision models drop an ``mmproj-*.gguf`` alongside. For a - single-file (non-sharded) primary we only count the primary.""" - _sparse(tmp_path / "m.gguf", 8 * 1024**3) - _sparse(tmp_path / "mmproj-BF16.gguf", 2 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "m.gguf") - with patch("builtins.open", side_effect = _fake_proc_reader(0)): - out = inst.load_progress() - # Non-sharded primary: only the primary is counted. - assert out["bytes_total"] == 8 * 1024**3 - - def test_single_file_model(self, tmp_path): - """Non-sharded model: primary only.""" - _sparse(tmp_path / "small.gguf", 4 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "small.gguf") - with patch("builtins.open", side_effect = _fake_proc_reader(2 * 1024**2)): - out = inst.load_progress() - assert out["bytes_total"] == 4 * 1024**3 - assert out["bytes_loaded"] == 2 * 1024**3 - - -# --------------------------------------------------------------------------- -# E. Lifecycle races -# --------------------------------------------------------------------------- - - -class TestLifecycleRaces: - def test_process_set_but_gguf_path_not_yet(self, tmp_path): - """Moment between Popen and self._gguf_path=model_path.""" - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = None - with patch("builtins.open", side_effect = _fake_proc_reader(1024)): - out = inst.load_progress() - assert out is not None - assert out["phase"] == "mmap" - assert out["bytes_total"] == 0 - assert out["bytes_loaded"] == 1024 * 1024 - - def test_process_died_mid_sample(self, tmp_path): - """/proc/ disappears -> None.""" - _sparse(tmp_path / "m.gguf", 1 * 1024**3) - inst = _make() - inst._process = _Proc(pid = 999_999_999) - inst._gguf_path = str(tmp_path / "m.gguf") - assert inst.load_progress() is None - - def test_healthy_true_ready_phase(self, tmp_path): - _sparse(tmp_path / "m.gguf", 1 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "m.gguf") - inst._healthy = True - with patch("builtins.open", side_effect = _fake_proc_reader(1024)): - out = inst.load_progress() - assert out["phase"] == "ready" - - -# --------------------------------------------------------------------------- -# F. Concurrent sampling (simulates multiple browser tabs polling) -# --------------------------------------------------------------------------- - - -class TestConcurrentSampling: - def test_parallel_invocations_never_raise(self, tmp_path): - """Many concurrent samplers hitting the same backend must not raise. - - We intentionally do NOT patch ``builtins.open`` here because - ``unittest.mock.patch`` is not thread-safe: interleaved - enter/exit across threads can leak a Mock into ``builtins.open`` - and poison every subsequent test in the session. Instead, we - let each thread hit the real ``/proc/self/status`` of the test - process, which is exactly the code path that matters in prod. - """ - _sparse(tmp_path / "m.gguf", 1 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "m.gguf") - errors = [] - - def run(): - try: - for _ in range(50): - inst.load_progress() - except Exception as e: # pragma: no cover - errors.append(e) - - threads = [threading.Thread(target = run) for _ in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - assert not errors, errors - - -# --------------------------------------------------------------------------- -# G. Fraction bounds -# --------------------------------------------------------------------------- - - -class TestFractionBounds: - def test_fraction_capped_at_one(self, tmp_path): - _sparse(tmp_path / "m.gguf", 1 * 1024**3) - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = str(tmp_path / "m.gguf") - # RSS > total (post-paged-in + extra structures) - with patch("builtins.open", side_effect = _fake_proc_reader(2 * 1024**2)): - out = inst.load_progress() - assert 0.0 <= out["fraction"] <= 1.0 - - def test_fraction_zero_when_total_zero(self): - inst = _make() - inst._process = _Proc(os.getpid()) - inst._gguf_path = None - with patch("builtins.open", side_effect = _fake_proc_reader(1024**2)): - out = inst.load_progress() - assert out["fraction"] == 0.0 diff --git a/studio/backend/tests/test_llama_cpp_max_context_threshold.py b/studio/backend/tests/test_llama_cpp_max_context_threshold.py deleted file mode 100644 index 22e4cda7d1..0000000000 --- a/studio/backend/tests/test_llama_cpp_max_context_threshold.py +++ /dev/null @@ -1,248 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Tests for the ``max_context_length`` warning-threshold semantics. - -``/api/inference/status.max_context_length`` is what the ctx slider in -the chat settings sheet reads to decide when to render the "Exceeds -estimated VRAM capacity. The model may use system RAM." warning: - - ctxDisplayValue > ggufMaxContextLength → show warning - -For models whose weights fit on some GPU subset, the warning threshold -is the largest ctx that fits fully in VRAM (the binary-search cap from -``_fit_context_to_vram``). For models whose weights exceed 90% of every -GPU subset's free memory, the warning must fire as soon as the user -drags above the 4096 spec default (otherwise a user loading e.g. -MiniMax-M2.7 on a 97 GB GPU sees a slider up to 196608 with no -indication that any value above 4096 will trigger ``--fit on`` and -degrade performance). - -These tests pin both cases. No GPU probing, no subprocess, no GGUF I/O. -Cross-platform: Linux, macOS, Windows, WSL. -""" - -from __future__ import annotations - -import sys -import types as _types -from pathlib import Path - -import pytest - -# --------------------------------------------------------------------------- -# Stub heavy / unavailable external dependencies before importing the -# module under test. Same pattern as test_kv_cache_estimation.py. -# --------------------------------------------------------------------------- - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -# loggers -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -# structlog -_structlog_stub = _types.ModuleType("structlog") -sys.modules.setdefault("structlog", _structlog_stub) - -# httpx -_httpx_stub = _types.ModuleType("httpx") -for _exc_name in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", -): - setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) - - -class _FakeTimeout: - def __init__(self, *a, **kw): - pass - - -_httpx_stub.Timeout = _FakeTimeout -_httpx_stub.Client = type( - "Client", - (), - { - "__init__": lambda self, **kw: None, - "__enter__": lambda self: self, - "__exit__": lambda self, *a: None, - }, -) -sys.modules.setdefault("httpx", _httpx_stub) - -from core.inference.llama_cpp import LlamaCppBackend - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -GIB = 1024**3 - - -def _make_backend(native_ctx = 131072): - inst = LlamaCppBackend.__new__(LlamaCppBackend) - inst._context_length = native_ctx - inst._n_layers = 80 - inst._n_kv_heads = 8 - inst._n_heads = 64 - inst._embedding_length = 8192 - inst._kv_key_length = 128 - inst._kv_value_length = 128 - inst._kv_lora_rank = None - inst._sliding_window = None - inst._sliding_window_pattern = None - inst._ssm_inner_size = None - inst._full_attention_interval = None - inst._key_length_mla = None - inst._n_kv_heads_by_layer = None - inst._kv_key_length_swa = None - inst._kv_value_length_swa = None - return inst - - -def _compute_max_available_ctx(native_ctx, model_gib, gpus, kv_per_token_bytes = 325_000): - """Run the ceiling-probe block from load_model and return the final - ``max_available_ctx`` value the backend would assign to - ``_max_context_length``. - """ - inst = _make_backend(native_ctx = native_ctx) - model_size = int(model_gib * GIB) - - inst._estimate_kv_cache_bytes = ( - lambda n, _t = None, **_kw: 0 if n <= 0 else n * kv_per_token_bytes - ) - inst._can_estimate_kv = lambda: True - - context_length = inst._context_length - effective_ctx = context_length - max_available_ctx = context_length - - cache_type_kv = None - native_ctx_for_cap = context_length - - ranked_for_cap = sorted(gpus, key = lambda g: g[1], reverse = True) - best_cap = 0 - for n_gpus in range(1, len(ranked_for_cap) + 1): - subset = ranked_for_cap[:n_gpus] - pool_mib = sum(free for _, free in subset) - capped = inst._fit_context_to_vram( - native_ctx_for_cap, - pool_mib, - model_size, - cache_type_kv, - ) - kv = inst._estimate_kv_cache_bytes(capped, cache_type_kv) - total_mib = (model_size + kv) / (1024 * 1024) - if total_mib <= pool_mib * 0.90: - best_cap = max(best_cap, capped) - if best_cap > 0: - max_available_ctx = best_cap - else: - max_available_ctx = min(4096, native_ctx_for_cap) - - return max_available_ctx - - -# --------------------------------------------------------------------------- -# Weights exceed every GPU subset's VRAM (MiniMax-M2.7-like) -# --------------------------------------------------------------------------- - - -class TestMaxContextLengthForWeightsExceedVRAM: - """The UI ``max_context_length`` threshold must fall back to 4096 so - the warning fires as soon as the user drags above the spec default. - """ - - def test_minimax_like(self): - """131 GB weights, single 97 GB GPU, native ctx 196608.""" - got = _compute_max_available_ctx( - native_ctx = 196608, - model_gib = 131, - gpus = [(0, 97_000)], - ) - assert got == 4096 - - def test_multi_gpu_all_subsets_fail(self): - """400 GB weights across a 4x80 GB pool (320 GB total, still too small).""" - got = _compute_max_available_ctx( - native_ctx = 131072, - model_gib = 400, - gpus = [(0, 80_000), (1, 80_000), (2, 80_000), (3, 80_000)], - ) - assert got == 4096 - - def test_native_below_fallback_is_preserved(self): - """If the model's native ctx is itself smaller than 4096, do not - advertise a larger value than the model supports.""" - got = _compute_max_available_ctx( - native_ctx = 2048, - model_gib = 200, - gpus = [(0, 80_000)], - ) - assert got == 2048 - - -# --------------------------------------------------------------------------- -# Fittable models (regression guard) -# --------------------------------------------------------------------------- - - -class TestMaxContextLengthForFittableModels: - """The existing best-cap behaviour must be unchanged.""" - - def test_small_model_fits_easily(self): - """8 GB model on 24 GB GPU: should auto-pick a large ctx.""" - got = _compute_max_available_ctx( - native_ctx = 131072, - model_gib = 8, - gpus = [(0, 24_000)], - kv_per_token_bytes = 8192, - ) - assert got > 4096 - assert got <= 131072 - - def test_medium_model_multi_gpu(self): - """60 GB model split across 2 GPUs: picks a fitting ctx.""" - got = _compute_max_available_ctx( - native_ctx = 131072, - model_gib = 60, - gpus = [(0, 40_000), (1, 40_000)], - kv_per_token_bytes = 8192, - ) - assert got > 4096 - - def test_tiny_model_on_huge_gpu_near_native(self): - """2 GB model, 80 GB GPU, negligible KV: should approach native.""" - got = _compute_max_available_ctx( - native_ctx = 131072, - model_gib = 2, - gpus = [(0, 80_000)], - kv_per_token_bytes = 64, - ) - assert got >= 131072 - 256 # rounded to 256 boundary - - -# --------------------------------------------------------------------------- -# Property plumbing -# --------------------------------------------------------------------------- - - -class TestMaxContextLengthProperty: - def test_falls_back_to_native_when_unset(self): - inst = _make_backend(native_ctx = 131072) - inst._max_context_length = None - assert inst.max_context_length == 131072 - - def test_returns_stored_value_when_set(self): - inst = _make_backend(native_ctx = 131072) - inst._max_context_length = 4096 - assert inst.max_context_length == 4096 diff --git a/studio/backend/tests/test_llama_cpp_no_context_shift.py b/studio/backend/tests/test_llama_cpp_no_context_shift.py deleted file mode 100644 index b9f25faf88..0000000000 --- a/studio/backend/tests/test_llama_cpp_no_context_shift.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""``--no-context-shift`` launch-flag contract. - -When llama-server runs with its default context-shift behavior, the UI -has no way to tell the user that the KV cache has been rotated -- -earlier turns silently vanish from the conversation. The Studio -backend always passes ``--no-context-shift`` so the server returns a -clean error instead, and the chat adapter can point the user at the -``Context Length`` input in the settings panel. - -This file is a static read of the launch command: we ask -``LlamaCppBackend`` to assemble its ``cmd`` list and assert the flag -is always present. Testing via the real subprocess would require an -actual GGUF on disk, which is out of scope for the fast test suite. -""" - -from __future__ import annotations - -import inspect -import sys -import types as _types -from pathlib import Path - -import pytest - -# --------------------------------------------------------------------------- -# Same external-dep stubs as the other llama_cpp tests. -# --------------------------------------------------------------------------- - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -_structlog_stub = _types.ModuleType("structlog") -sys.modules.setdefault("structlog", _structlog_stub) - -_httpx_stub = _types.ModuleType("httpx") -for _exc in ( - "ConnectError", - "TimeoutException", - "ReadTimeout", - "ReadError", - "RemoteProtocolError", - "CloseError", -): - setattr(_httpx_stub, _exc, type(_exc, (Exception,), {})) -_httpx_stub.Timeout = type("T", (), {"__init__": lambda s, *a, **k: None}) -_httpx_stub.Client = type( - "C", - (), - { - "__init__": lambda s, **kw: None, - "__enter__": lambda s: s, - "__exit__": lambda s, *a: None, - }, -) -sys.modules.setdefault("httpx", _httpx_stub) - -from core.inference import llama_cpp as llama_cpp_module - - -def _load_model_source() -> str: - """Return the source of ``LlamaCppBackend.load_model``. - - Using ``inspect.getsource`` instead of reading the file directly - scopes the assertions to the function that actually launches - llama-server, so neither the presence check nor the location check - can be fooled by a stray occurrence of ``"--no-context-shift"`` - elsewhere in the module. - """ - return inspect.getsource(llama_cpp_module.LlamaCppBackend.load_model) - - -def test_no_context_shift_is_in_load_model(): - """The flag is part of the static launch-command template. - - We check the source of ``load_model`` rather than mocking the whole - call chain (GPU probing, GGUF stat, etc.): the flag is written as - a literal in one place and any regression has to delete it, which - a text search will catch. - """ - assert '"--no-context-shift"' in _load_model_source(), ( - "llama-server must be launched with --no-context-shift so the " - "UI can surface a clean 'context full' error instead of silently " - "losing old turns to a KV-cache rotation." - ) - - -def test_flag_sits_inside_the_base_cmd_list(): - """Pin the flag's location so a future refactor can't accidentally - move it into a branch that only fires on some code paths. - - We slice from ``cmd = [`` to the first ``]`` at the same indent. - Using ``inspect.getsource`` means the function lives in its own - string and there are no siblings to worry about, so a plain - bracket search would also work -- anchoring on the trailing indent - just keeps the slice from wandering into a later expression if the - opening literal ever grows an in-line comment trailing it. - """ - source = _load_model_source() - start = source.find("cmd = [") - assert start >= 0, "could not find the base cmd = [...] block" - # Find the first line containing only ``]`` (possibly indented). - # Works for any indentation style the formatter picks. - rest = source[start:] - end_rel = -1 - for line_start, line in _iter_lines_with_offset(rest): - if line_start == 0: - # Skip the opening ``cmd = [`` line itself. - continue - if line.strip() == "]": - end_rel = line_start - break - assert end_rel > 0, "could not find end of cmd = [...] block" - block = rest[:end_rel] - assert '"--no-context-shift"' in block, ( - "--no-context-shift must be in the base cmd list, not in a " - "conditional branch -- otherwise some code paths would still " - "run with silent context shift enabled." - ) - # Also pin that it is next to -c / --ctx so the grouping makes sense. - assert '"-c"' in block - assert '"--flash-attn"' in block - - -def _iter_lines_with_offset(text: str): - """Yield (offset, line) pairs over ``text`` without losing offsets.""" - offset = 0 - for line in text.splitlines(keepends = True): - yield offset, line - offset += len(line) diff --git a/studio/backend/tests/test_llama_server_args.py b/studio/backend/tests/test_llama_server_args.py deleted file mode 100644 index 351fbd014d..0000000000 --- a/studio/backend/tests/test_llama_server_args.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Unit tests for the llama-server pass-through args validator. - -The validator is the security boundary between user-supplied CLI / HTTP -input and the llama-server subprocess command. These tests pin the -denylist behavior so the boundary doesn't quietly regress when new -managed flags are added. -""" - -from __future__ import annotations - -import pytest - -from core.inference.llama_server_args import ( - is_managed_flag, - validate_extra_args, -) - - -# ── Pass-through (allowed) ─────────────────────────────────────────── - - -@pytest.mark.parametrize( - "args", - [ - # Sampling - ["--top-k", "20"], - ["--top-p", "0.9", "--min-p", "0.05"], - ["--seed", "-1"], # negative value, not a flag - ["--temp", "0.0"], - ["--repeat-penalty", "1.05"], - ["--mirostat", "2", "--mirostat-lr", "0.1"], - ["--xtc-probability", "0.05", "--xtc-threshold", "0.1"], - ["--dry-multiplier", "0.5"], - # Tier-2 knobs that map to LoadRequest fields - ["--cache-type-k", "q8_0"], - ["--cache-type-v", "q8_0"], - ["--chat-template-file", "/tmp/tpl.jinja"], - ["--chat-template-kwargs", '{"reasoning_effort":"high"}'], - ["--spec-type", "ngram-mod"], - ["--spec-default"], - # Reasoning controls - ["--reasoning-format", "deepseek"], - ["-rea", "auto"], - # Soft-managed flags the user may want to override on the CLI; - # llama.cpp's last-wins parsing means these win over Studio's - # auto-set version. - ["-c", "131072"], - ["--ctx-size", "8192"], - ["--parallel", "1"], - ["-np", "8"], - ["--flash-attn", "off"], - ["-fa", "on"], - ["--no-context-shift"], - ["--context-shift"], - ["--jinja"], - ["--no-jinja"], - ["-ngl", "-1"], - ["--gpu-layers", "32"], - ["-t", "16"], - ["--threads", "32"], - ["-fit", "off"], - ["--fit", "on"], - ["--fit-ctx", "8192"], - ], -) -def test_pass_through_allowed(args): - assert validate_extra_args(args) == args - - -def test_none_returns_empty_list(): - assert validate_extra_args(None) == [] - - -def test_empty_list_returns_empty_list(): - assert validate_extra_args([]) == [] - - -def test_value_with_equals_form_passes_through(): - assert validate_extra_args(["--top-k=20"]) == ["--top-k=20"] - - -def test_non_flag_token_passes_through(): - # A bare positional value (not preceded by a flag) is preserved - # verbatim. llama-server may reject it, but that's not our job. - assert validate_extra_args(["foo"]) == ["foo"] - - -# ── Denylist (rejected) ────────────────────────────────────────────── - - -@pytest.mark.parametrize( - "denied", - [ - # Model identity - "-m", - "--model", - "-hf", - "-hfr", - "--hf-repo", - "-hff", - "--hf-file", - "-hft", - "--hf-token", - "-mm", - "--mmproj", - "--mmproj-url", - # Networking (Studio binds + proxies) - "--host", - "--port", - "--path", - "--api-prefix", - "--reuse-port", - # Auth / TLS - "--api-key", - "--api-key-file", - "--ssl-key-file", - "--ssl-cert-file", - # Single-model server - "--webui", - "--no-webui", - "--models-dir", - "--models-max", - ], -) -def test_denylist_rejects_all_aliases(denied): - with pytest.raises(ValueError, match = denied): - validate_extra_args([denied, "value"]) - - -def test_denylist_rejects_equals_form(): - with pytest.raises(ValueError, match = "--port"): - validate_extra_args(["--port=9000"]) - - -def test_denylist_rejects_short_form_when_long_is_denied(): - # -m is the short form of the hard-denied --model; rejecting only - # the long form would leave a trivial bypass. - with pytest.raises(ValueError, match = "-m"): - validate_extra_args(["-m", "/some/other/path.gguf"]) - - -def test_denylist_message_names_offending_flag(): - with pytest.raises(ValueError) as excinfo: - validate_extra_args(["--top-k", "20", "--api-key", "secret"]) - assert "--api-key" in str(excinfo.value) - - -def test_first_denied_flag_short_circuits(): - # Validation stops at the first denied flag; later denied flags - # in the same call don't matter for behaviour, but the message - # should name the first one we hit. - with pytest.raises(ValueError, match = "--port"): - validate_extra_args(["--port", "1", "--host", "x"]) - - -# ── Numeric values that look flag-ish ───────────────────────────────── - - -@pytest.mark.parametrize("value", ["-1", "-0.5", "-42", "-.5"]) -def test_negative_number_value_is_not_flag(value): - # ``--seed -1`` is a value, not a flag. Validator must not try - # to look up "-1" in the denylist. - assert validate_extra_args(["--seed", value]) == ["--seed", value] - - -# ── is_managed_flag helper ─────────────────────────────────────────── - - -def test_is_managed_flag_true_for_denied(): - assert is_managed_flag("--port") is True - assert is_managed_flag("--api-key") is True - assert is_managed_flag("-m") is True - assert is_managed_flag("--model") is True - - -def test_is_managed_flag_false_for_pass_through(): - assert is_managed_flag("--top-k") is False - assert is_managed_flag("--cache-type-k") is False - assert is_managed_flag("--chat-template-file") is False - # Soft-managed flags pass through (last-wins override) - assert is_managed_flag("-c") is False - assert is_managed_flag("--ctx-size") is False - assert is_managed_flag("--parallel") is False - assert is_managed_flag("--flash-attn") is False - assert is_managed_flag("-ngl") is False - assert is_managed_flag("--threads") is False diff --git a/studio/backend/tests/test_native_context_length.py b/studio/backend/tests/test_native_context_length.py index 60622c776d..7c69e56f89 100644 --- a/studio/backend/tests/test_native_context_length.py +++ b/studio/backend/tests/test_native_context_length.py @@ -320,23 +320,11 @@ def test_status_response_has_field(self): """Field exists in InferenceStatusResponse.model_fields.""" assert "native_context_length" in InferenceStatusResponse.model_fields - def test_status_response_has_chat_template_field(self): - """Status includes chat_template so the UI can rehydrate after refresh.""" - assert "chat_template" in InferenceStatusResponse.model_fields - def test_status_response_defaults_none(self): """Omitting native_context_length defaults to None.""" resp = InferenceStatusResponse() assert resp.native_context_length is None - def test_status_response_chat_template_roundtrip(self): - """chat_template serializes and validates as part of status.""" - resp = InferenceStatusResponse(chat_template = "{{ messages }}") - roundtripped = InferenceStatusResponse.model_validate_json( - resp.model_dump_json() - ) - assert roundtripped.chat_template == "{{ messages }}" - def test_roundtrip_preserves_value(self): """model_validate_json(model_dump_json()) round-trips.""" resp = LoadResponse( diff --git a/studio/backend/tests/test_openai_tool_passthrough.py b/studio/backend/tests/test_openai_tool_passthrough.py deleted file mode 100644 index cdb7f5d270..0000000000 --- a/studio/backend/tests/test_openai_tool_passthrough.py +++ /dev/null @@ -1,474 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -""" -Tests for the OpenAI /v1/chat/completions client-side tool pass-through. - -Covers: -- ChatCompletionRequest accepts standard OpenAI `tools` / `tool_choice` / `stop`. -- ChatMessage accepts role="tool" with `tool_call_id` and role="assistant" - with `content: None` + `tool_calls`. -- ChatCompletionRequest carries unknown fields via `extra="allow"`. -- anthropic_tool_choice_to_openai() covers all four Anthropic shapes. -- _build_passthrough_payload() honors a caller-supplied tool_choice and - defaults to "auto" when unset. -- _friendly_error() maps httpx transport errors to a "Lost connection" - message so passthrough failures are legible instead of bare 500s. - -No running server or GPU required. -""" - -import os -import sys - -_backend = os.path.join(os.path.dirname(__file__), "..") -sys.path.insert(0, _backend) - -import httpx -import pytest -from pydantic import ValidationError - -from models.inference import ( - ChatCompletionRequest, - ChatMessage, -) -from core.inference.anthropic_compat import ( - anthropic_tool_choice_to_openai, -) -from routes.inference import _build_passthrough_payload, _friendly_error - - -# ===================================================================== -# ChatMessage — tool role, tool_calls, optional content -# ===================================================================== - - -class TestChatMessageToolRoles: - def test_tool_role_with_tool_call_id(self): - msg = ChatMessage( - role = "tool", - tool_call_id = "call_abc123", - content = '{"temperature": 72}', - ) - assert msg.role == "tool" - assert msg.tool_call_id == "call_abc123" - assert msg.content == '{"temperature": 72}' - - def test_tool_role_with_name(self): - msg = ChatMessage( - role = "tool", - tool_call_id = "call_abc123", - name = "get_weather", - content = '{"temperature": 72}', - ) - assert msg.name == "get_weather" - - def test_assistant_with_tool_calls_no_content(self): - msg = ChatMessage( - role = "assistant", - content = None, - tool_calls = [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"city": "Paris"}', - }, - } - ], - ) - assert msg.role == "assistant" - assert msg.content is None - assert msg.tool_calls is not None - assert len(msg.tool_calls) == 1 - assert msg.tool_calls[0]["function"]["name"] == "get_weather" - - def test_assistant_with_content_and_tool_calls(self): - msg = ChatMessage( - role = "assistant", - content = "Let me check the weather.", - tool_calls = [ - { - "id": "call_1", - "type": "function", - "function": {"name": "get_weather", "arguments": "{}"}, - } - ], - ) - assert msg.content == "Let me check the weather." - assert msg.tool_calls[0]["id"] == "call_1" - - def test_plain_user_message_still_works(self): - msg = ChatMessage(role = "user", content = "Hello") - assert msg.role == "user" - assert msg.tool_call_id is None - assert msg.tool_calls is None - assert msg.name is None - - def test_invalid_role_rejected(self): - with pytest.raises(ValidationError): - ChatMessage(role = "function", content = "x") - - def test_content_absent_on_assistant_tool_call_defaults_to_none(self): - # Assistant messages that carry only tool_calls are the one - # documented case where `content=None` is permitted. - msg = ChatMessage( - role = "assistant", - tool_calls = [ - { - "id": "call_1", - "type": "function", - "function": {"name": "f", "arguments": "{}"}, - } - ], - ) - assert msg.content is None - - def test_tool_role_missing_tool_call_id_rejected(self): - # Per OpenAI spec, role="tool" messages must carry tool_call_id so - # upstream backends can associate the result with its prior call. - # Pin the boundary-level rejection so a malformed tool-result - # message never reaches the passthrough path. - with pytest.raises(ValidationError) as exc_info: - ChatMessage(role = "tool", content = '{"temperature": 72}') - assert "tool_call_id" in str(exc_info.value) - - def test_tool_role_empty_tool_call_id_rejected(self): - with pytest.raises(ValidationError): - ChatMessage( - role = "tool", - tool_call_id = "", - content = '{"temperature": 72}', - ) - - # ── Role-aware content requirements ──────────────────────────── - - @pytest.mark.parametrize("role", ["user", "system"]) - def test_empty_string_content_allowed(self, role): - msg = ChatMessage(role = role, content = "") - assert msg.content == "" - - def test_user_missing_content_rejected(self): - with pytest.raises(ValidationError): - ChatMessage(role = "user") - - def test_user_empty_list_content_rejected(self): - with pytest.raises(ValidationError): - ChatMessage(role = "user", content = []) - - def test_tool_empty_content_rejected(self): - with pytest.raises(ValidationError) as exc_info: - ChatMessage(role = "tool", tool_call_id = "call_1", content = "") - assert "content" in str(exc_info.value) - - def test_assistant_without_content_or_tool_calls_rejected(self): - with pytest.raises(ValidationError) as exc_info: - ChatMessage(role = "assistant") - assert "content" in str(exc_info.value) or "tool_calls" in str(exc_info.value) - - # ── Role-constrained tool-call metadata ──────────────────────── - - def test_tool_calls_on_user_rejected(self): - with pytest.raises(ValidationError) as exc_info: - ChatMessage( - role = "user", - content = "Hi", - tool_calls = [ - { - "id": "c1", - "type": "function", - "function": {"name": "f", "arguments": "{}"}, - } - ], - ) - assert "tool_calls" in str(exc_info.value) - - def test_tool_call_id_on_user_rejected(self): - with pytest.raises(ValidationError) as exc_info: - ChatMessage(role = "user", content = "Hi", tool_call_id = "call_1") - assert "tool_call_id" in str(exc_info.value) - - def test_name_on_user_rejected(self): - with pytest.raises(ValidationError) as exc_info: - ChatMessage(role = "user", content = "Hi", name = "get_weather") - assert "name" in str(exc_info.value) - - -# ===================================================================== -# ChatCompletionRequest — standard OpenAI tool fields -# ===================================================================== - - -class TestChatCompletionRequestToolFields: - def _make(self, **kwargs): - base = {"messages": [{"role": "user", "content": "Hi"}]} - base.update(kwargs) - return ChatCompletionRequest(**base) - - def test_tools_parses(self): - req = self._make( - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Return the weather in a city", - "parameters": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - }, - } - ], - ) - assert req.tools is not None - assert len(req.tools) == 1 - assert req.tools[0]["function"]["name"] == "get_weather" - - def test_image_base64_allows_empty_user_text(self): - req = ChatCompletionRequest( - messages = [{"role": "user", "content": ""}], - image_base64 = "aW1hZ2U=", - ) - assert req.messages[0].content == "" - assert req.image_base64 == "aW1hZ2U=" - - def test_tool_choice_string_auto(self): - assert self._make(tool_choice = "auto").tool_choice == "auto" - - def test_tool_choice_string_required(self): - assert self._make(tool_choice = "required").tool_choice == "required" - - def test_tool_choice_string_none(self): - assert self._make(tool_choice = "none").tool_choice == "none" - - def test_tool_choice_named_function(self): - tc = {"type": "function", "function": {"name": "get_weather"}} - assert self._make(tool_choice = tc).tool_choice == tc - - def test_stop_string(self): - assert self._make(stop = "\nUser:").stop == "\nUser:" - - def test_stop_list(self): - assert self._make(stop = ["\nUser:", "\nAssistant:"]).stop == [ - "\nUser:", - "\nAssistant:", - ] - - def test_tools_default_none(self): - req = self._make() - assert req.tools is None - assert req.tool_choice is None - assert req.stop is None - - def test_extra_fields_accepted(self): - # `frequency_penalty`, `seed`, `response_format` are not yet - # explicitly declared but must survive Pydantic parsing now that - # extra="allow" is set. - req = self._make( - frequency_penalty = 0.5, - seed = 42, - response_format = {"type": "json_object"}, - ) - # Extras land in model_extra - assert req.model_extra is not None - assert req.model_extra.get("frequency_penalty") == 0.5 - assert req.model_extra.get("seed") == 42 - assert req.model_extra.get("response_format") == {"type": "json_object"} - - def test_unsloth_extensions_still_work(self): - req = self._make( - enable_tools = True, - enabled_tools = ["web_search", "python"], - session_id = "abc", - ) - assert req.enable_tools is True - assert req.enabled_tools == ["web_search", "python"] - assert req.session_id == "abc" - - def test_stream_defaults_false_matching_openai_spec(self): - # OpenAI's /v1/chat/completions spec defaults `stream` to false. - # Studio previously defaulted to true, which broke naive curl - # clients that omit `stream` (they expect a JSON blob, got SSE). - # Pin the corrected default so it can't silently regress. - req = self._make() - assert req.stream is False - - def test_multiturn_tool_loop_messages(self): - req = ChatCompletionRequest( - messages = [ - {"role": "user", "content": "What's the weather in Paris?"}, - { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"city": "Paris"}', - }, - } - ], - }, - { - "role": "tool", - "tool_call_id": "call_1", - "content": '{"temperature": 14, "unit": "celsius"}', - }, - ], - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "parameters": {"type": "object"}, - }, - } - ], - ) - assert len(req.messages) == 3 - assert req.messages[1].role == "assistant" - assert req.messages[1].content is None - assert req.messages[1].tool_calls[0]["id"] == "call_1" - assert req.messages[2].role == "tool" - assert req.messages[2].tool_call_id == "call_1" - - -# ===================================================================== -# anthropic_tool_choice_to_openai — pure translation helper -# ===================================================================== - - -class TestAnthropicToolChoiceToOpenAI: - def test_auto(self): - assert anthropic_tool_choice_to_openai({"type": "auto"}) == "auto" - - def test_any_becomes_required(self): - assert anthropic_tool_choice_to_openai({"type": "any"}) == "required" - - def test_none(self): - assert anthropic_tool_choice_to_openai({"type": "none"}) == "none" - - def test_tool_named(self): - result = anthropic_tool_choice_to_openai( - {"type": "tool", "name": "get_weather"} - ) - assert result == { - "type": "function", - "function": {"name": "get_weather"}, - } - - def test_tool_missing_name_returns_none(self): - assert anthropic_tool_choice_to_openai({"type": "tool"}) is None - - def test_none_input_returns_none(self): - assert anthropic_tool_choice_to_openai(None) is None - - def test_unrecognized_shape_returns_none(self): - assert anthropic_tool_choice_to_openai({"type": "wibble"}) is None - assert anthropic_tool_choice_to_openai("auto") is None - assert anthropic_tool_choice_to_openai(42) is None - - -# ===================================================================== -# _build_passthrough_payload — tool_choice propagation -# ===================================================================== - - -class TestBuildPassthroughPayloadToolChoice: - def _args(self): - return dict( - openai_messages = [{"role": "user", "content": "Hi"}], - openai_tools = [ - { - "type": "function", - "function": {"name": "f", "parameters": {"type": "object"}}, - } - ], - temperature = 0.6, - top_p = 0.95, - top_k = 20, - max_tokens = 128, - stream = False, - ) - - def test_default_tool_choice_is_auto(self): - body = _build_passthrough_payload(**self._args()) - assert body["tool_choice"] == "auto" - - def test_override_tool_choice_required(self): - body = _build_passthrough_payload(**self._args(), tool_choice = "required") - assert body["tool_choice"] == "required" - - def test_override_tool_choice_none(self): - body = _build_passthrough_payload(**self._args(), tool_choice = "none") - assert body["tool_choice"] == "none" - - def test_override_tool_choice_named_function(self): - tc = {"type": "function", "function": {"name": "f"}} - body = _build_passthrough_payload(**self._args(), tool_choice = tc) - assert body["tool_choice"] == tc - - def test_stream_adds_include_usage(self): - args = self._args() - args["stream"] = True - body = _build_passthrough_payload(**args) - assert body.get("stream_options") == {"include_usage": True} - - def test_repetition_penalty_renamed(self): - body = _build_passthrough_payload(**self._args(), repetition_penalty = 1.1) - assert body.get("repeat_penalty") == 1.1 - assert "repetition_penalty" not in body - - -# ===================================================================== -# _friendly_error — httpx transport failures -# ===================================================================== - - -class TestFriendlyErrorHttpx: - """The async pass-through helpers talk to llama-server via httpx. - When the subprocess is down, httpx raises RequestError subclasses - whose string form (``"All connection attempts failed"``, ``"[Errno 111] - Connection refused"``, ...) does NOT contain the substring - ``"Lost connection to llama-server"`` the sync path uses, so the - previous substring-only `_friendly_error` returned a useless generic - message. These tests pin the new isinstance-based mapping. - """ - - def _req(self): - return httpx.Request("POST", "http://127.0.0.1:65535/v1/chat/completions") - - def test_connect_error_mapped(self): - exc = httpx.ConnectError("All connection attempts failed", request = self._req()) - assert "Lost connection" in _friendly_error(exc) - - def test_read_error_mapped(self): - exc = httpx.ReadError("EOF", request = self._req()) - assert "Lost connection" in _friendly_error(exc) - - def test_remote_protocol_error_mapped(self): - exc = httpx.RemoteProtocolError("peer closed", request = self._req()) - assert "Lost connection" in _friendly_error(exc) - - def test_read_timeout_mapped(self): - exc = httpx.ReadTimeout("timed out", request = self._req()) - assert "Lost connection" in _friendly_error(exc) - - def test_non_httpx_unchanged(self): - # Non-httpx exceptions still fall through to the existing substring - # heuristics — a context-size message must still produce the - # "Message too long" path. - ctx_msg = ( - "request (4096 tokens) exceeds the available context size (2048 tokens)" - ) - assert "Message too long" in _friendly_error(ValueError(ctx_msg)) - - def test_generic_exception_returns_generic_message(self): - assert ( - _friendly_error(RuntimeError("unrelated")) == "An internal error occurred" - ) diff --git a/studio/backend/tests/test_pytorch_mirror.py b/studio/backend/tests/test_pytorch_mirror.py deleted file mode 100644 index 5844f209b6..0000000000 --- a/studio/backend/tests/test_pytorch_mirror.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Tests for UNSLOTH_PYTORCH_MIRROR env var in install_python_stack.py.""" - -from __future__ import annotations - -import importlib -import os -import sys -from pathlib import Path - -import pytest - -# install_python_stack.py lives at repo_root/studio/install_python_stack.py -_INSTALL_SCRIPT = Path(__file__).resolve().parents[2] / "install_python_stack.py" - -OFFICIAL_URL = "https://download.pytorch.org/whl" - - -def _reload_whl_base(monkeypatch, mirror_value = None): - """(Re-)import install_python_stack with a controlled env and return _PYTORCH_WHL_BASE.""" - # Remove cached module so the module-level assignment re-executes - sys.modules.pop("install_python_stack", None) - - if mirror_value is None: - monkeypatch.delenv("UNSLOTH_PYTORCH_MIRROR", raising = False) - else: - monkeypatch.setenv("UNSLOTH_PYTORCH_MIRROR", mirror_value) - - # Temporarily add the script's directory to sys.path for import - script_dir = str(_INSTALL_SCRIPT.parent) - monkeypatch.syspath_prepend(script_dir) - - import install_python_stack - - return install_python_stack._PYTORCH_WHL_BASE - - -class TestPyTorchMirrorEnvVar: - """UNSLOTH_PYTORCH_MIRROR controls _PYTORCH_WHL_BASE in install_python_stack.""" - - def test_unset_uses_official_url(self, monkeypatch): - assert _reload_whl_base(monkeypatch) == OFFICIAL_URL - - def test_empty_string_falls_back_to_official(self, monkeypatch): - assert _reload_whl_base(monkeypatch, "") == OFFICIAL_URL - - def test_custom_mirror_is_used(self, monkeypatch): - mirror = "https://mirrors.nju.edu.cn/pytorch/whl" - assert _reload_whl_base(monkeypatch, mirror) == mirror - - def test_trailing_slash_stripped(self, monkeypatch): - result = _reload_whl_base(monkeypatch, "https://example.com/whl/") - assert result == "https://example.com/whl" diff --git a/studio/backend/tests/test_responses_api.py b/studio/backend/tests/test_responses_api.py deleted file mode 100644 index 5b55f87259..0000000000 --- a/studio/backend/tests/test_responses_api.py +++ /dev/null @@ -1,328 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -""" -Tests for the OpenAI Responses API schemas and input normalisation. -These tests do NOT require a running server or GPU -- they validate -the Pydantic models and the _normalise_responses_input helper. -""" - -import sys -import os -import json -import re - -# Ensure backend is on path -_backend = os.path.join(os.path.dirname(__file__), "..") -sys.path.insert(0, _backend) - -from models.inference import ( - ResponsesRequest, - ResponsesInputMessage, - ResponsesInputTextPart, - ResponsesInputImagePart, - ResponsesOutputTextContent, - ResponsesOutputMessage, - ResponsesUsage, - ResponsesResponse, - ChatMessage, - TextContentPart, - ImageContentPart, - ImageUrl, - ChatCompletionRequest, -) - - -# ── _normalise_responses_input: copied from routes/inference.py ── -# We cannot import routes.inference directly because routes/__init__.py -# pulls in heavy dependencies (structlog/twisted/torch). This is a -# direct copy of the function for testing purposes. - - -def _normalise_responses_input(payload: ResponsesRequest) -> list: - """Convert a ResponsesRequest into a list of ChatMessage for the completions backend.""" - messages = [] - - # System / developer instructions - if payload.instructions: - messages.append(ChatMessage(role = "system", content = payload.instructions)) - - # Simple string input - if isinstance(payload.input, str): - if payload.input: - messages.append(ChatMessage(role = "user", content = payload.input)) - return messages - - # List of ResponsesInputMessage - for msg in payload.input: - role = "system" if msg.role == "developer" else msg.role - - if isinstance(msg.content, str): - messages.append(ChatMessage(role = role, content = msg.content)) - else: - # Convert Responses content parts -> Chat content parts - parts = [] - for part in msg.content: - if isinstance(part, ResponsesInputTextPart): - parts.append(TextContentPart(type = "text", text = part.text)) - elif isinstance(part, ResponsesInputImagePart): - parts.append( - ImageContentPart( - type = "image_url", - image_url = ImageUrl(url = part.image_url, detail = part.detail), - ) - ) - messages.append(ChatMessage(role = role, content = parts if parts else "")) - - return messages - - -# ===================================================================== -# Schema validation tests -# ===================================================================== - - -class TestResponsesRequest: - """Validate ResponsesRequest accepts the shapes the OpenAI SDK sends.""" - - def test_minimal_string_input(self): - req = ResponsesRequest(input = "Hello") - assert req.input == "Hello" - assert req.stream is False - assert req.model == "default" - - def test_message_list_input(self): - req = ResponsesRequest( - input = [ - {"role": "user", "content": "Hi"}, - {"role": "assistant", "content": "Hello!"}, - ], - ) - assert len(req.input) == 2 - assert req.input[0].role == "user" - assert req.input[0].content == "Hi" - - def test_multimodal_input(self): - req = ResponsesRequest( - input = [ - { - "role": "user", - "content": [ - {"type": "input_text", "text": "What is in this image?"}, - { - "type": "input_image", - "image_url": "https://example.com/img.png", - }, - ], - }, - ], - ) - parts = req.input[0].content - assert len(parts) == 2 - assert isinstance(parts[0], ResponsesInputTextPart) - assert isinstance(parts[1], ResponsesInputImagePart) - - def test_instructions_field(self): - req = ResponsesRequest( - input = "test", - instructions = "You are a helpful assistant.", - ) - assert req.instructions == "You are a helpful assistant." - - def test_extra_fields_accepted(self): - """OpenAI SDK may send fields we don't model -- extra='allow' should pass.""" - req = ResponsesRequest( - input = "test", - tools = [{"type": "web_search_preview"}], - store = True, - metadata = {"key": "value"}, - previous_response_id = "resp_abc123", - ) - assert req.tools == [{"type": "web_search_preview"}] - assert req.store is True - - def test_stream_flag(self): - req = ResponsesRequest(input = "test", stream = True) - assert req.stream is True - - def test_temperature_and_top_p(self): - req = ResponsesRequest(input = "test", temperature = 0.8, top_p = 0.9) - assert req.temperature == 0.8 - assert req.top_p == 0.9 - - def test_max_output_tokens(self): - req = ResponsesRequest(input = "test", max_output_tokens = 512) - assert req.max_output_tokens == 512 - - def test_developer_role(self): - req = ResponsesRequest( - input = [{"role": "developer", "content": "System instructions"}], - ) - assert req.input[0].role == "developer" - - -# ===================================================================== -# Response model tests -# ===================================================================== - - -class TestResponsesResponse: - """Validate response models serialise correctly.""" - - def test_basic_response(self): - resp = ResponsesResponse( - model = "test-model", - output = [ - ResponsesOutputMessage( - content = [ResponsesOutputTextContent(text = "Hello!")] - ), - ], - usage = ResponsesUsage(input_tokens = 10, output_tokens = 5, total_tokens = 15), - ) - d = resp.model_dump() - assert d["object"] == "response" - assert d["status"] == "completed" - assert d["output"][0]["type"] == "message" - assert d["output"][0]["content"][0]["type"] == "output_text" - assert d["output"][0]["content"][0]["text"] == "Hello!" - assert d["usage"]["input_tokens"] == 10 - assert d["usage"]["output_tokens"] == 5 - assert d["usage"]["total_tokens"] == 15 - # Must NOT have prompt_tokens / completion_tokens - assert "prompt_tokens" not in d["usage"] - assert "completion_tokens" not in d["usage"] - - def test_id_format(self): - resp = ResponsesResponse() - assert resp.id.startswith("resp_") - - def test_output_message_id_format(self): - msg = ResponsesOutputMessage() - assert msg.id.startswith("msg_") - - def test_annotations_default_empty(self): - part = ResponsesOutputTextContent(text = "hi") - assert part.annotations == [] - - def test_response_json_roundtrip(self): - resp = ResponsesResponse( - model = "gpt-4", - output = [ - ResponsesOutputMessage( - content = [ResponsesOutputTextContent(text = "ok")], - ), - ], - usage = ResponsesUsage(input_tokens = 1, output_tokens = 1, total_tokens = 2), - ) - j = json.loads(resp.model_dump_json()) - assert j["object"] == "response" - assert j["output"][0]["role"] == "assistant" - assert j["output"][0]["status"] == "completed" - - -# ===================================================================== -# Input normalisation tests -# ===================================================================== - - -class TestNormaliseResponsesInput: - """Test _normalise_responses_input converts Responses input to ChatMessages.""" - - def test_string_input(self): - payload = ResponsesRequest(input = "Hello world") - msgs = _normalise_responses_input(payload) - assert len(msgs) == 1 - assert msgs[0].role == "user" - assert msgs[0].content == "Hello world" - - def test_instructions_become_system_message(self): - payload = ResponsesRequest( - input = "Hi", - instructions = "Be concise.", - ) - msgs = _normalise_responses_input(payload) - assert len(msgs) == 2 - assert msgs[0].role == "system" - assert msgs[0].content == "Be concise." - assert msgs[1].role == "user" - assert msgs[1].content == "Hi" - - def test_message_list(self): - payload = ResponsesRequest( - input = [ - {"role": "user", "content": "First"}, - {"role": "assistant", "content": "Response"}, - {"role": "user", "content": "Second"}, - ], - ) - msgs = _normalise_responses_input(payload) - assert len(msgs) == 3 - assert msgs[0].role == "user" - assert msgs[1].role == "assistant" - assert msgs[2].role == "user" - - def test_developer_role_maps_to_system(self): - payload = ResponsesRequest( - input = [{"role": "developer", "content": "Instructions"}], - ) - msgs = _normalise_responses_input(payload) - assert msgs[0].role == "system" - assert msgs[0].content == "Instructions" - - def test_multimodal_parts(self): - payload = ResponsesRequest( - input = [ - { - "role": "user", - "content": [ - {"type": "input_text", "text": "Describe this:"}, - { - "type": "input_image", - "image_url": "data:image/png;base64,abc", - }, - ], - }, - ], - ) - msgs = _normalise_responses_input(payload) - assert len(msgs) == 1 - content = msgs[0].content - assert isinstance(content, list) - assert len(content) == 2 - assert isinstance(content[0], TextContentPart) - assert content[0].text == "Describe this:" - assert isinstance(content[1], ImageContentPart) - assert content[1].image_url.url == "data:image/png;base64,abc" - - def test_empty_string_input(self): - payload = ResponsesRequest(input = "") - msgs = _normalise_responses_input(payload) - assert len(msgs) == 0 - - def test_empty_list_input(self): - payload = ResponsesRequest(input = []) - msgs = _normalise_responses_input(payload) - assert len(msgs) == 0 - - def test_instructions_only(self): - payload = ResponsesRequest(input = "", instructions = "System msg") - msgs = _normalise_responses_input(payload) - assert len(msgs) == 1 - assert msgs[0].role == "system" - - def test_instructions_plus_message_list(self): - payload = ResponsesRequest( - input = [{"role": "user", "content": "Hello"}], - instructions = "Be brief.", - ) - msgs = _normalise_responses_input(payload) - assert len(msgs) == 2 - assert msgs[0].role == "system" - assert msgs[0].content == "Be brief." - assert msgs[1].role == "user" - - -if __name__ == "__main__": - import pytest - - pytest.main([__file__, "-v"]) diff --git a/studio/backend/tests/test_responses_tool_passthrough.py b/studio/backend/tests/test_responses_tool_passthrough.py deleted file mode 100644 index 2f1161c329..0000000000 --- a/studio/backend/tests/test_responses_tool_passthrough.py +++ /dev/null @@ -1,667 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -""" -Tests for the OpenAI /v1/responses client-side function-calling pass-through. - -Covers: -- ResponsesRequest accepts Responses-shape `tools`, `tool_choice`, - `parallel_tool_calls`, and the `function_call` / `function_call_output` - input items used for multi-turn tool loops. -- _translate_responses_tools_to_chat() converts the flat Responses tool - shape to the nested Chat Completions shape, drops non-function built-in - tools, and returns None for empty lists. -- _translate_responses_tool_choice_to_chat() passes string choices through - and converts {type:function,name:X} to Chat Completions' nested shape. -- _normalise_responses_input() maps function_call_output items to - role="tool" ChatMessages with tool_call_id, and function_call items to - assistant messages with tool_calls. -- _chat_tool_calls_to_responses_output() preserves call_id and drops - non-function tool calls. -- ResponsesOutputFunctionCall and ResponsesResponse round-trip tool-call - outputs without losing fields. - -No running server or GPU required. -""" - -import os -import sys - -_backend = os.path.join(os.path.dirname(__file__), "..") -sys.path.insert(0, _backend) - -import json - -import pytest -from pydantic import ValidationError - -from models.inference import ( - ChatMessage, - ResponsesFunctionCallInputItem, - ResponsesFunctionCallOutputInputItem, - ResponsesFunctionTool, - ResponsesInputMessage, - ResponsesOutputFunctionCall, - ResponsesOutputMessage, - ResponsesOutputTextContent, - ResponsesOutputTextPart, - ResponsesRequest, - ResponsesResponse, - ResponsesUnknownContentPart, - ResponsesUnknownInputItem, - ResponsesUsage, -) -from routes.inference import ( - _chat_tool_calls_to_responses_output, - _normalise_responses_input, - _translate_responses_tool_choice_to_chat, - _translate_responses_tools_to_chat, -) - - -# ===================================================================== -# Request model — tools / tool_choice / parallel_tool_calls -# ===================================================================== - - -class TestResponsesRequestTools: - def test_flat_function_tool_accepted(self): - req = ResponsesRequest( - input = "hi", - tools = [ - { - "type": "function", - "name": "get_weather", - "description": "Get the weather for a city.", - "parameters": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - "strict": True, - } - ], - ) - assert req.tools is not None - assert req.tools[0]["name"] == "get_weather" - assert req.tools[0]["type"] == "function" - assert req.tools[0]["strict"] is True - - def test_tool_choice_string_values(self): - for choice in ("auto", "required", "none"): - req = ResponsesRequest(input = "hi", tool_choice = choice) - assert req.tool_choice == choice - - def test_tool_choice_forcing_object(self): - req = ResponsesRequest( - input = "hi", - tool_choice = {"type": "function", "name": "get_weather"}, - ) - assert req.tool_choice == {"type": "function", "name": "get_weather"} - - def test_parallel_tool_calls(self): - req = ResponsesRequest(input = "hi", parallel_tool_calls = True) - assert req.parallel_tool_calls is True - - def test_builtin_tool_type_passes_validation(self): - """Non-function built-in tools (web_search, file_search, mcp, ...) must - not raise at request validation so SDKs that default to them don't - fail on Studio; they are filtered out during translation.""" - req = ResponsesRequest( - input = "hi", - tools = [{"type": "web_search_preview"}], - ) - assert req.tools == [{"type": "web_search_preview"}] - - def test_function_tool_model_direct(self): - tool = ResponsesFunctionTool( - type = "function", - name = "send_email", - parameters = {"type": "object", "properties": {}}, - ) - assert tool.name == "send_email" - assert tool.description is None - - def test_function_tool_rejects_other_type(self): - with pytest.raises(ValidationError): - ResponsesFunctionTool(type = "web_search", name = "x") - - -# ===================================================================== -# Request model — function_call / function_call_output input items -# ===================================================================== - - -class TestResponsesMultiTurnInput: - def test_function_call_input_item(self): - req = ResponsesRequest( - input = [ - {"role": "user", "content": "Weather in Paris?"}, - { - "type": "function_call", - "id": "fc_abc", - "call_id": "call_abc", - "name": "get_weather", - "arguments": '{"city": "Paris"}', - }, - { - "type": "function_call_output", - "call_id": "call_abc", - "output": '{"temp": 12}', - }, - ], - ) - assert len(req.input) == 3 - assert isinstance(req.input[1], ResponsesFunctionCallInputItem) - assert req.input[1].call_id == "call_abc" - assert isinstance(req.input[2], ResponsesFunctionCallOutputInputItem) - assert req.input[2].call_id == "call_abc" - assert req.input[2].output == '{"temp": 12}' - - def test_function_call_output_missing_call_id_rejected(self): - with pytest.raises(ValidationError): - ResponsesFunctionCallOutputInputItem( - type = "function_call_output", output = "x" - ) - - def test_function_call_output_accepts_content_array(self): - item = ResponsesFunctionCallOutputInputItem( - type = "function_call_output", - call_id = "call_1", - output = [{"type": "output_text", "text": "done"}], - ) - assert isinstance(item.output, list) - - -# ===================================================================== -# Translators — tools, tool_choice -# ===================================================================== - - -class TestToolsTranslation: - def test_flat_to_nested(self): - tools = [ - { - "type": "function", - "name": "get_weather", - "description": "Returns weather.", - "parameters": {"type": "object"}, - "strict": True, - } - ] - out = _translate_responses_tools_to_chat(tools) - assert out == [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Returns weather.", - "parameters": {"type": "object"}, - "strict": True, - }, - } - ] - - def test_builtin_tools_dropped(self): - out = _translate_responses_tools_to_chat( - [ - {"type": "web_search_preview"}, - {"type": "file_search"}, - { - "type": "function", - "name": "search", - "parameters": {"type": "object"}, - }, - ] - ) - assert len(out) == 1 - assert out[0]["function"]["name"] == "search" - - def test_empty_returns_none(self): - assert _translate_responses_tools_to_chat(None) is None - assert _translate_responses_tools_to_chat([]) is None - - def test_only_builtin_tools_returns_none(self): - assert ( - _translate_responses_tools_to_chat([{"type": "web_search_preview"}]) is None - ) - - def test_description_optional(self): - out = _translate_responses_tools_to_chat( - [ - { - "type": "function", - "name": "noop", - "parameters": {"type": "object"}, - } - ] - ) - assert "description" not in out[0]["function"] - - -class TestToolChoiceTranslation: - def test_string_passthrough(self): - for v in ("auto", "required", "none"): - assert _translate_responses_tool_choice_to_chat(v) == v - - def test_none_passthrough(self): - assert _translate_responses_tool_choice_to_chat(None) is None - - def test_forcing_object_converted(self): - assert _translate_responses_tool_choice_to_chat( - {"type": "function", "name": "get_weather"} - ) == {"type": "function", "function": {"name": "get_weather"}} - - def test_already_chat_nested_shape_passes_through(self): - """If a client happens to send the Chat Completions nested shape, - we don't double-wrap it.""" - already_nested = {"type": "function", "function": {"name": "get_weather"}} - assert ( - _translate_responses_tool_choice_to_chat(already_nested) == already_nested - ) - - def test_unknown_shape_passes_through(self): - obj = {"type": "allowed_tools", "tools": [{"type": "function", "name": "x"}]} - assert _translate_responses_tool_choice_to_chat(obj) == obj - - -# ===================================================================== -# _normalise_responses_input — multi-turn tool mapping -# ===================================================================== - - -class TestNormaliseResponsesInputWithTools: - def test_function_call_output_maps_to_tool_role(self): - payload = ResponsesRequest( - input = [ - {"role": "user", "content": "Weather?"}, - { - "type": "function_call", - "call_id": "call_1", - "name": "get_weather", - "arguments": "{}", - }, - { - "type": "function_call_output", - "call_id": "call_1", - "output": '{"temp": 20}', - }, - ], - ) - msgs = _normalise_responses_input(payload) - assert len(msgs) == 3 - assert msgs[0].role == "user" - - assert msgs[1].role == "assistant" - assert msgs[1].tool_calls is not None - assert msgs[1].tool_calls[0]["id"] == "call_1" - assert msgs[1].tool_calls[0]["function"]["name"] == "get_weather" - - assert msgs[2].role == "tool" - assert msgs[2].tool_call_id == "call_1" - assert msgs[2].content == '{"temp": 20}' - - def test_instructions_plus_developer_message_are_merged(self): - """Codex CLI sends `instructions` (system prompt) AND a developer - message in `input`. Strict chat templates (harmony / gpt-oss, Qwen3, - ...) raise "System message must be at the beginning" when two - separate system-role messages appear, so we must emit exactly one - merged system message at the top. - """ - payload = ResponsesRequest( - instructions = "Base instructions.", - input = [ - {"role": "developer", "content": "Developer override."}, - {"role": "user", "content": "Hi"}, - ], - ) - msgs = _normalise_responses_input(payload) - system_roles = [m for m in msgs if m.role == "system"] - assert len(system_roles) == 1 - assert "Base instructions." in system_roles[0].content - assert "Developer override." in system_roles[0].content - # System must be the very first message for strict templates. - assert msgs[0].role == "system" - assert msgs[1].role == "user" - - def test_developer_message_after_user_is_still_hoisted(self): - """Multi-turn conversations where a developer message appears after - user turns must still produce a single leading system message, not - a mid-conversation system that strict templates reject.""" - payload = ResponsesRequest( - input = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi!"}, - {"role": "developer", "content": "Updated rules."}, - {"role": "user", "content": "Continue"}, - ], - ) - msgs = _normalise_responses_input(payload) - assert msgs[0].role == "system" - assert "Updated rules." in msgs[0].content - for m in msgs[1:]: - assert m.role != "system", "no trailing system message permitted" - - def test_no_system_output_when_no_system_input(self): - payload = ResponsesRequest(input = "Hi") - msgs = _normalise_responses_input(payload) - assert all(m.role != "system" for m in msgs) - - def test_multiple_system_messages_in_input_are_merged(self): - payload = ResponsesRequest( - input = [ - {"role": "system", "content": "A"}, - {"role": "system", "content": "B"}, - {"role": "user", "content": "Hi"}, - ], - ) - msgs = _normalise_responses_input(payload) - assert sum(1 for m in msgs if m.role == "system") == 1 - assert "A" in msgs[0].content and "B" in msgs[0].content - - def test_content_array_output_serialised_to_json_string(self): - payload = ResponsesRequest( - input = [ - { - "type": "function_call_output", - "call_id": "call_1", - "output": [{"type": "output_text", "text": "ok"}], - } - ], - ) - msgs = _normalise_responses_input(payload) - assert msgs[0].role == "tool" - # Content is serialised so llama-server sees a string. - assert json.loads(msgs[0].content) == [{"type": "output_text", "text": "ok"}] - - -# ===================================================================== -# Response mapping — tool_calls → function_call output items -# ===================================================================== - - -class TestChatToolCallsToResponsesOutput: - def test_basic_mapping(self): - items = _chat_tool_calls_to_responses_output( - [ - { - "id": "call_abc", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"city":"Paris"}', - }, - } - ] - ) - assert len(items) == 1 - assert items[0]["type"] == "function_call" - assert items[0]["call_id"] == "call_abc" - assert items[0]["name"] == "get_weather" - assert items[0]["arguments"] == '{"city":"Paris"}' - assert items[0]["status"] == "completed" - assert items[0]["id"].startswith("fc_") - - def test_multiple_tool_calls_preserved(self): - items = _chat_tool_calls_to_responses_output( - [ - { - "id": "call_1", - "type": "function", - "function": {"name": "a", "arguments": "{}"}, - }, - { - "id": "call_2", - "type": "function", - "function": {"name": "b", "arguments": "{}"}, - }, - ] - ) - assert [it["call_id"] for it in items] == ["call_1", "call_2"] - - def test_non_function_tool_call_dropped(self): - items = _chat_tool_calls_to_responses_output([{"id": "x", "type": "retrieval"}]) - assert items == [] - - def test_missing_arguments_coerced_to_empty_string(self): - items = _chat_tool_calls_to_responses_output( - [{"id": "call_1", "type": "function", "function": {"name": "x"}}] - ) - assert items[0]["arguments"] == "" - - -# ===================================================================== -# Response model — ResponsesOutputFunctionCall / mixed output -# ===================================================================== - - -class TestResponsesOutputFunctionCall: - def test_direct_construction(self): - fc = ResponsesOutputFunctionCall( - call_id = "call_1", - name = "get_weather", - arguments = '{"city":"Paris"}', - ) - d = fc.model_dump() - assert d["type"] == "function_call" - assert d["call_id"] == "call_1" - assert d["status"] == "completed" - assert d["id"].startswith("fc_") - - def test_response_with_tool_call_output(self): - resp = ResponsesResponse( - model = "test", - output = [ - ResponsesOutputFunctionCall( - call_id = "call_1", - name = "get_weather", - arguments = "{}", - ) - ], - usage = ResponsesUsage(input_tokens = 1, output_tokens = 1, total_tokens = 2), - ) - d = json.loads(resp.model_dump_json()) - assert d["output"][0]["type"] == "function_call" - assert d["output"][0]["call_id"] == "call_1" - - def test_response_with_mixed_output(self): - resp = ResponsesResponse( - model = "test", - output = [ - ResponsesOutputMessage( - content = [ResponsesOutputTextContent(text = "Calling...")], - ), - ResponsesOutputFunctionCall( - call_id = "call_1", - name = "get_weather", - arguments = '{"city":"Paris"}', - ), - ], - ) - d = resp.model_dump() - assert d["output"][0]["type"] == "message" - assert d["output"][1]["type"] == "function_call" - - -# ===================================================================== -# Regression: ChatMessage validator still accepts mapped tool messages -# ===================================================================== - - -class TestCodexStyleRequestShapes: - """Regression tests for the request shapes OpenAI Codex CLI sends.""" - - def test_assistant_replay_output_text_accepted(self): - """Codex replays prior assistant turns with `output_text` content. - Before, this triggered a 422 on every turn after the first.""" - req = ResponsesRequest( - input = [ - {"role": "user", "content": "Hi"}, - { - "type": "message", - "role": "assistant", - "content": [ - { - "type": "output_text", - "text": "Hello!", - "annotations": [], - "logprobs": [], - } - ], - }, - {"role": "user", "content": "Continue"}, - ], - ) - assert len(req.input) == 3 - parts = req.input[1].content - assert isinstance(parts, list) - assert isinstance(parts[0], ResponsesOutputTextPart) - assert parts[0].text == "Hello!" - - def test_reasoning_item_accepted_as_unknown(self): - """`reasoning` items replayed from prior o-series turns must not - fail validation — Codex preserves them in multi-turn.""" - req = ResponsesRequest( - input = [ - {"role": "user", "content": "Hi"}, - { - "type": "reasoning", - "id": "rs_1", - "summary": [], - "encrypted_content": "opaque", - }, - {"role": "assistant", "content": "Hello!"}, - ], - ) - assert len(req.input) == 3 - assert isinstance(req.input[1], ResponsesUnknownInputItem) - - def test_unknown_content_part_type_accepted(self): - """Unknown content-part types (e.g. future input_audio) validate as - ResponsesUnknownContentPart so the whole request doesn't 422.""" - req = ResponsesRequest( - input = [ - { - "role": "user", - "content": [ - {"type": "input_text", "text": "See:"}, - {"type": "input_audio", "audio": {"data": "..."}}, - ], - } - ], - ) - parts = req.input[0].content - assert isinstance(parts[1], ResponsesUnknownContentPart) - assert parts[1].type == "input_audio" - - def test_codex_full_shape_roundtrip(self): - """End-to-end: developer + user + assistant(output_text) + - function_call + function_call_output + reasoning in one request.""" - payload = ResponsesRequest( - instructions = "Base instructions.", - input = [ - { - "type": "message", - "role": "developer", - "content": [{"type": "input_text", "text": "Dev override."}], - }, - { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": "Weather?"}], - }, - { - "type": "reasoning", - "id": "rs_1", - "summary": [], - }, - { - "type": "function_call", - "call_id": "call_1", - "name": "get_weather", - "arguments": "{}", - }, - { - "type": "function_call_output", - "call_id": "call_1", - "output": '{"temp":20}', - }, - { - "type": "message", - "role": "assistant", - "content": [ - { - "type": "output_text", - "text": "It's 20°C.", - "annotations": [], - "logprobs": [], - } - ], - }, - {"role": "user", "content": "And tomorrow?"}, - ], - ) - msgs = _normalise_responses_input(payload) - # Single leading merged system; no mid-conversation system. - assert msgs[0].role == "system" - assert sum(1 for m in msgs if m.role == "system") == 1 - assert "Base instructions." in msgs[0].content - assert "Dev override." in msgs[0].content - - roles = [m.role for m in msgs[1:]] - # Reasoning item is dropped. Order: user, assistant(tool_calls), - # tool, assistant(text), user. - assert roles == ["user", "assistant", "tool", "assistant", "user"] - assert msgs[2].tool_calls is not None - assert msgs[3].role == "tool" - assert msgs[3].tool_call_id == "call_1" - assert msgs[4].content == "It's 20°C." - - def test_single_output_text_part_flattens_to_string(self): - """ChatMessage assistant role prefers plain string content — tests - confirm we don't forward a single-part array that would otherwise - force legacy chat templates into multimodal handling.""" - payload = ResponsesRequest( - input = [ - { - "role": "assistant", - "content": [ - {"type": "output_text", "text": "ok", "annotations": []} - ], - }, - {"role": "user", "content": "next"}, - ], - ) - msgs = _normalise_responses_input(payload) - assert msgs[0].role == "assistant" - assert msgs[0].content == "ok" - - -class TestTranslatedMessagesValidate: - """Verify that the messages produced by _normalise_responses_input - satisfy ChatMessage's role-shape validator so the downstream /v1/chat/ - completions pass-through does not reject them.""" - - def test_round_trip_multi_turn(self): - payload = ResponsesRequest( - input = [ - {"role": "user", "content": "Weather in Paris?"}, - { - "type": "function_call", - "call_id": "call_1", - "name": "get_weather", - "arguments": '{"city": "Paris"}', - }, - { - "type": "function_call_output", - "call_id": "call_1", - "output": '{"temp": 20}', - }, - {"role": "user", "content": "Thanks!"}, - ], - ) - msgs = _normalise_responses_input(payload) - for m in msgs: - # Constructing a fresh ChatMessage from the dump round-trips the - # role-shape validator — the key invariant for the passthrough. - ChatMessage(**m.model_dump(exclude_none = True)) diff --git a/studio/backend/tests/test_studio_api.py b/studio/backend/tests/test_studio_api.py deleted file mode 100644 index 521c99e126..0000000000 --- a/studio/backend/tests/test_studio_api.py +++ /dev/null @@ -1,974 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -""" -End-to-end tests for Unsloth Studio's HTTP API surface. - -Covers the OpenAI-compatible and Anthropic-compatible endpoints exposed -by the server that ``unsloth studio run`` boots, plus API key -authentication and the CLI's ``--help`` output: - - 1. curl -- basic chat completions (non-streaming) - 2. curl -- streaming chat completions - 3. Python OpenAI SDK -- streaming completions - 4. curl -- Studio server-side tools (enable_tools=true) - 5. curl -- Standard OpenAI function calling (non-streaming) - 6. curl -- Standard OpenAI function calling (streaming) - 7. curl -- Standard OpenAI function calling (multi-turn tool loop) - 8. OpenAI Python SDK -- Standard function calling - 9. Anthropic Messages API -- basic non-streaming - 10. Anthropic Messages API -- streaming SSE - 11. Anthropic Python SDK -- non-streaming - 12. Anthropic Messages API -- streaming with tools - 13. Anthropic Messages API -- tool_choice={"type":"any"} honored - -Training, export, fine-tuning, and chat-UI concerns are out of scope — -see the unit suites elsewhere under ``studio/backend/tests/`` for those. - -Usage: - - # Script mode — launches its own server via ``unsloth studio run``. - python tests/test_studio_api.py - python tests/test_studio_api.py --model unsloth/... --gguf-variant ... - - # Pytest mode, external server — start a Studio server yourself, - # then point pytest at it. Fastest iteration loop. - unsloth studio run --model unsloth/Qwen3-1.7B-GGUF --gguf-variant UD-Q4_K_XL & - export UNSLOTH_E2E_BASE_URL=http://127.0.0.1:8080 - export UNSLOTH_E2E_API_KEY=sk-unsloth-... # from the server banner - pytest tests/test_studio_api.py -v - - # Pytest mode, fixture-managed server — pytest launches and tears - # down the server itself. One-shot verification, CI-friendly. - pytest tests/test_studio_api.py -v \\ - --unsloth-model unsloth/Qwen3-1.7B-GGUF \\ - --unsloth-gguf-variant UD-Q4_K_XL - -The ``base_url`` / ``api_key`` parameters on the test functions resolve -via the ``studio_server`` session fixture in ``conftest.py``. - -Requires a GPU and ~2 GB of disk for the GGUF download. -""" - -from __future__ import annotations - -import argparse -import json -import os -import re -import signal -import subprocess -import sys -import time -import urllib.error -import urllib.request -from pathlib import Path - - -# ── Configuration ──────────────────────────────────────────────────── - -DEFAULT_MODEL = "unsloth/Qwen3-1.7B-GGUF" -DEFAULT_VARIANT = "UD-Q4_K_XL" -PORT = 18222 # high port unlikely to collide -HOST = "127.0.0.1" -STARTUP_TIMEOUT = 120 # seconds to wait for banner -LOG_FILE = ( - Path(__file__).resolve().parent.parent.parent.parent - / "temp" - / "test_studio_api.log" -) - - -# ── Helpers ────────────────────────────────────────────────────────── - - -def _http( - method: str, - url: str, - *, - body: dict | None = None, - headers: dict | None = None, - timeout: int = 60, -) -> tuple[int, str]: - """Minimal stdlib HTTP helper. Returns (status_code, body_text).""" - data = json.dumps(body).encode() if body else None - req = urllib.request.Request(url, data = data, headers = headers or {}, method = method) - if body: - req.add_header("Content-Type", "application/json") - try: - with urllib.request.urlopen(req, timeout = timeout) as resp: - return resp.status, resp.read().decode() - except urllib.error.HTTPError as exc: - return exc.code, exc.read().decode(errors = "replace") - - -def _stream_http( - url: str, - *, - body: dict, - headers: dict, - timeout: int = 60, -) -> tuple[int, list[dict]]: - """POST a streaming request and collect SSE chunks.""" - data = json.dumps(body).encode() - req = urllib.request.Request(url, data = data, headers = headers, method = "POST") - req.add_header("Content-Type", "application/json") - chunks: list[dict] = [] - try: - with urllib.request.urlopen(req, timeout = timeout) as resp: - status = resp.status - for raw_line in resp: - line = raw_line.decode().strip() - if line.startswith("data: ") and line != "data: [DONE]": - try: - chunks.append(json.loads(line[6:])) - except json.JSONDecodeError: - pass - return status, chunks - except urllib.error.HTTPError as exc: - return exc.code, [] - - -# ── Test functions ─────────────────────────────────────────────────── - - -def test_help_output(): - """``unsloth studio run --help`` should show all documented options.""" - result = subprocess.run( - ["unsloth", "studio", "run", "--help"], - capture_output = True, - text = True, - timeout = 15, - ) - out = result.stdout - assert result.returncode == 0, f"--help exited with {result.returncode}" - - for flag in [ - "--model", - "--gguf-variant", - "--max-seq-length", - "--load-in-4bit", - "--api-key-name", - "--port", - "--host", - "--frontend", - "--silent", - ]: - assert flag in out, f"Missing flag {flag!r} in --help output" - print(" PASS --help shows all flags") - - -def test_curl_basic(base_url: str, api_key: str): - """Example 1: basic non-streaming chat completion via HTTP.""" - status, text = _http( - "POST", - f"{base_url}/v1/chat/completions", - body = { - "messages": [{"role": "user", "content": "Say just the word hello"}], - "stream": False, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - ) - assert status == 200, f"Expected 200, got {status}: {text[:300]}" - data = json.loads(text) - assert "choices" in data, f"Missing 'choices' in response: {text[:300]}" - content = data["choices"][0]["message"]["content"] - assert len(content) > 0, "Empty assistant content" - print(f" PASS curl basic: {content[:80]!r}") - - -def _collect_streamed_content(chunks: list[dict]) -> str: - """Extract text from SSE chunks, skipping role-only and usage chunks.""" - parts = [] - for c in chunks: - choices = c.get("choices", []) - if not choices: - continue - delta = choices[0].get("delta", {}) - part = delta.get("content") - if part: - parts.append(part) - return "".join(parts) - - -def test_curl_streaming(base_url: str, api_key: str): - """Example 2: streaming chat completion via HTTP SSE.""" - status, chunks = _stream_http( - f"{base_url}/v1/chat/completions", - body = { - "messages": [{"role": "user", "content": "Count from 1 to 3"}], - "stream": True, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - ) - assert status == 200, f"Expected 200, got {status}" - assert len(chunks) > 0, "No SSE chunks received" - full = _collect_streamed_content(chunks) - assert len(full) > 0, "Streamed content is empty" - print(f" PASS curl streaming: got {len(chunks)} chunks, {len(full)} chars") - - -def test_openai_sdk(base_url: str, api_key: str): - """Example 3: OpenAI Python SDK streaming completion.""" - try: - from openai import OpenAI - except ImportError: - print(" SKIP openai SDK not installed") - return - - client = OpenAI(base_url = f"{base_url}/v1", api_key = api_key) - response = client.chat.completions.create( - model = "current", - messages = [ - {"role": "user", "content": "What is 2+2? Answer with just the number."} - ], - stream = True, - ) - content_parts = [] - for chunk in response: - if not chunk.choices: - continue - delta_content = chunk.choices[0].delta.content - if delta_content: - content_parts.append(delta_content) - full = "".join(content_parts) - assert len(full) > 0, "OpenAI SDK returned empty content" - print(f" PASS OpenAI SDK streaming: {full.strip()[:80]!r}") - - -def test_curl_with_tools(base_url: str, api_key: str): - """Example 4: chat completion with tool calling enabled. - - Note: when ``enable_tools`` is set the server always returns SSE - streaming regardless of the ``stream`` flag, so we parse SSE chunks. - The model may or may not produce visible content -- tool orchestration - can intercept the response -- so we only assert the endpoint succeeds. - """ - status, chunks = _stream_http( - f"{base_url}/v1/chat/completions", - body = { - "messages": [ - { - "role": "user", - "content": "What is 123 * 456? Use code to compute it.", - } - ], - "stream": True, - "enable_tools": True, - "enabled_tools": ["python"], - "session_id": "test-session", - }, - headers = {"Authorization": f"Bearer {api_key}"}, - timeout = 120, - ) - assert status == 200, f"Expected 200, got {status}" - assert len(chunks) > 0, "No SSE chunks received for tools request" - - # Check that at least one chunk has the expected shape - has_valid_chunk = any("choices" in c or "type" in c for c in chunks) - assert has_valid_chunk, "No valid chunks in tools response" - full = _collect_streamed_content(chunks) - print(f" PASS curl with tools: {len(chunks)} chunks, {len(full)} chars content") - - -# ── Standard OpenAI function-calling pass-through tests ───────────── -# -# Regression coverage for unslothai/unsloth#4999: Studio's -# /v1/chat/completions used to silently strip standard OpenAI `tools` -# and `tool_choice` fields, so clients (opencode, Claude Code, Cursor, -# Continue, ...) could never get structured tool_calls back. These -# tests exercise the client-side pass-through path that forwards those -# fields to llama-server verbatim. -# -# They require a tool-capable GGUF (``supports_tools=True`` — e.g. -# Qwen3, Qwen2.5-Coder, Llama-3.1-Instruct). The default test model -# ``unsloth/Qwen3-1.7B-GGUF`` advertises tool support via its chat -# template metadata. - -_WEATHER_TOOL = { - "type": "function", - "function": { - "name": "get_weather", - "description": "Look up the current weather for a given city.", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The name of the city, e.g. 'Paris'.", - }, - }, - "required": ["city"], - }, - }, -} - - -def _collect_streamed_tool_calls(chunks: list[dict]) -> list[dict]: - """Reassemble OpenAI streaming delta.tool_calls into full tool calls. - - OpenAI streams partial tool calls across chunks — the first chunk for - a given index carries ``id`` + ``function.name``, and subsequent - chunks append fragments to ``function.arguments``. - """ - by_index: dict[int, dict] = {} - for c in chunks: - choices = c.get("choices") or [] - if not choices: - continue - delta = choices[0].get("delta") or {} - tool_calls = delta.get("tool_calls") or [] - for tc in tool_calls: - idx = tc.get("index", 0) - slot = by_index.setdefault( - idx, - { - "id": None, - "type": "function", - "function": {"name": None, "arguments": ""}, - }, - ) - if tc.get("id"): - slot["id"] = tc["id"] - fn = tc.get("function") or {} - if fn.get("name"): - slot["function"]["name"] = fn["name"] - if fn.get("arguments"): - slot["function"]["arguments"] += fn["arguments"] - return [by_index[i] for i in sorted(by_index)] - - -def _final_finish_reason(chunks: list[dict]) -> str | None: - for c in reversed(chunks): - choices = c.get("choices") or [] - if not choices: - continue - fr = choices[0].get("finish_reason") - if fr is not None: - return fr - return None - - -def test_openai_tools_nonstream(base_url: str, api_key: str): - """Standard OpenAI function calling, non-streaming, tool_choice='required'. - - Regression: before the fix, Studio silently stripped `tools` and the - model returned plain text with finish_reason='stop'. After the fix, - llama-server's response is forwarded verbatim so the client sees - finish_reason='tool_calls' with a structured tool_calls array and - non-zero usage.prompt_tokens. - """ - status, text = _http( - "POST", - f"{base_url}/v1/chat/completions", - body = { - "messages": [{"role": "user", "content": "What is the weather in Paris?"}], - "tools": [_WEATHER_TOOL], - "tool_choice": "required", - "stream": False, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - timeout = 120, - ) - assert status == 200, f"Expected 200, got {status}: {text[:500]}" - data = json.loads(text) - assert "choices" in data, f"Missing 'choices': {text[:300]}" - choice = data["choices"][0] - assert ( - choice["finish_reason"] == "tool_calls" - ), f"Expected finish_reason='tool_calls', got {choice['finish_reason']!r}" - msg = choice["message"] - tool_calls = msg.get("tool_calls") or [] - assert len(tool_calls) >= 1, f"No tool_calls in response: {msg}" - first = tool_calls[0] - assert first["type"] == "function" - assert ( - first["function"]["name"] == "get_weather" - ), f"Wrong tool name: {first['function']['name']!r}" - # arguments must be valid JSON - parsed = json.loads(first["function"]["arguments"]) - assert "city" in parsed, f"Tool call missing required 'city' arg: {parsed}" - # Usage must be non-zero (was 0 before the fix) - usage = data.get("usage") or {} - assert ( - usage.get("prompt_tokens", 0) > 0 - ), f"Expected non-zero prompt_tokens; got {usage}" - assert data.get("id"), "Missing response id" - print( - f" PASS openai tools non-stream: " - f"tool={first['function']['name']}, args={parsed}, " - f"prompt_tokens={usage['prompt_tokens']}" - ) - - -def test_openai_tools_stream(base_url: str, api_key: str): - """Standard OpenAI function calling, streaming, tool_choice='required'.""" - status, chunks = _stream_http( - f"{base_url}/v1/chat/completions", - body = { - "messages": [{"role": "user", "content": "What is the weather in Tokyo?"}], - "tools": [_WEATHER_TOOL], - "tool_choice": "required", - "stream": True, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - timeout = 120, - ) - assert status == 200, f"Expected 200, got {status}" - assert len(chunks) > 0, "No SSE chunks received" - assert _final_finish_reason(chunks) == "tool_calls", ( - f"Expected final finish_reason='tool_calls', got " - f"{_final_finish_reason(chunks)!r}" - ) - assembled = _collect_streamed_tool_calls(chunks) - assert len(assembled) >= 1, "No tool_calls reassembled from stream" - first = assembled[0] - assert first["function"]["name"] == "get_weather" - parsed = json.loads(first["function"]["arguments"]) - assert "city" in parsed - print( - f" PASS openai tools stream: {len(chunks)} chunks, " - f"tool={first['function']['name']}, args={parsed}" - ) - - -def test_openai_tools_multiturn(base_url: str, api_key: str): - """Multi-turn client-side tool loop: validates that role='tool' result - messages and assistant messages carrying tool_calls are accepted. - - Regression: before the fix, ChatMessage.role was restricted to - {system,user,assistant} and rejected role='tool' at the Pydantic - validation stage. This test sends a full round trip so the model - receives the simulated tool result and responds with final text. - """ - status, text = _http( - "POST", - f"{base_url}/v1/chat/completions", - body = { - "messages": [ - {"role": "user", "content": "What is the weather in Paris?"}, - { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_test_1", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"city": "Paris"}', - }, - } - ], - }, - { - "role": "tool", - "tool_call_id": "call_test_1", - "content": '{"temperature_c": 14, "condition": "cloudy"}', - }, - ], - "tools": [_WEATHER_TOOL], - "stream": False, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - timeout = 120, - ) - assert status == 200, f"Expected 200, got {status}: {text[:500]}" - data = json.loads(text) - msg = data["choices"][0]["message"] - # The model should respond with text now that it has the tool result - content = msg.get("content") or "" - assert len(content) > 0 or msg.get( - "tool_calls" - ), f"Expected text or follow-up tool call, got empty message: {msg}" - print(f" PASS openai tools multiturn: {content[:80]!r}") - - -def test_openai_sdk_tool_calling(base_url: str, api_key: str): - """OpenAI Python SDK round trip — the real client shape opencode et al. use.""" - try: - from openai import OpenAI - except ImportError: - print(" SKIP openai SDK not installed") - return - - client = OpenAI(base_url = f"{base_url}/v1", api_key = api_key) - resp = client.chat.completions.create( - model = "current", - messages = [{"role": "user", "content": "What's the weather in Berlin?"}], - tools = [_WEATHER_TOOL], - tool_choice = "required", - stream = False, - ) - assert resp.choices[0].finish_reason == "tool_calls", ( - f"Expected finish_reason='tool_calls', got " - f"{resp.choices[0].finish_reason!r}" - ) - tool_calls = resp.choices[0].message.tool_calls - assert tool_calls and len(tool_calls) >= 1, "No tool_calls from SDK" - tc = tool_calls[0] - assert tc.function.name == "get_weather" - parsed = json.loads(tc.function.arguments) - assert "city" in parsed - print( - f" PASS openai SDK tool calling: " f"tool={tc.function.name}, args={parsed}" - ) - - -def test_invalid_key_rejected(base_url: str): - """Requests with a bad API key should be rejected.""" - status, _text = _http( - "POST", - f"{base_url}/v1/chat/completions", - body = { - "messages": [{"role": "user", "content": "Hello"}], - "stream": False, - }, - headers = {"Authorization": "Bearer sk-unsloth-boguskey123"}, - ) - assert status == 401, f"Expected 401 for invalid key, got {status}" - print(" PASS invalid API key rejected (401)") - - -def test_no_key_rejected(base_url: str): - """Requests without any auth header should be rejected.""" - status, _text = _http( - "POST", - f"{base_url}/v1/chat/completions", - body = { - "messages": [{"role": "user", "content": "Hello"}], - "stream": False, - }, - ) - assert status == 401 or status == 403, f"Expected 401/403 for no key, got {status}" - print(f" PASS no API key rejected ({status})") - - -# ── Anthropic SSE helper ───────────────────────────────────────────── - - -def _stream_anthropic_http( - url: str, - *, - body: dict, - headers: dict, - timeout: int = 60, -) -> tuple[int, list[tuple[str, dict]]]: - """POST a streaming request and collect Anthropic SSE events. - - Returns (status, [(event_type, data_dict), ...]). - """ - data = json.dumps(body).encode() - req = urllib.request.Request(url, data = data, headers = headers, method = "POST") - req.add_header("Content-Type", "application/json") - events: list[tuple[str, dict]] = [] - try: - with urllib.request.urlopen(req, timeout = timeout) as resp: - status = resp.status - current_event = None - for raw_line in resp: - line = raw_line.decode().strip() - if line.startswith("event: "): - current_event = line[7:] - elif line.startswith("data: ") and current_event: - try: - events.append((current_event, json.loads(line[6:]))) - except json.JSONDecodeError: - pass - current_event = None - return status, events - except urllib.error.HTTPError as exc: - return exc.code, [] - - -def _collect_anthropic_text(events: list[tuple[str, dict]]) -> str: - """Extract text content from Anthropic SSE events.""" - parts = [] - for etype, data in events: - if etype == "content_block_delta": - delta = data.get("delta", {}) - if delta.get("type") == "text_delta": - parts.append(delta.get("text", "")) - return "".join(parts) - - -# ── Anthropic /v1/messages test functions ──────────────────────────── - - -def test_anthropic_basic(base_url: str, api_key: str): - """Anthropic Messages API: non-streaming.""" - status, text = _http( - "POST", - f"{base_url}/v1/messages", - body = { - "model": "default", - "max_tokens": 100, - "messages": [{"role": "user", "content": "Say just the word hello"}], - }, - headers = {"Authorization": f"Bearer {api_key}"}, - ) - assert status == 200, f"Expected 200, got {status}: {text[:300]}" - data = json.loads(text) - assert data.get("type") == "message", f"Expected type 'message': {text[:300]}" - assert data.get("role") == "assistant" - content = data.get("content", []) - assert len(content) > 0, "Empty content array" - text_block = content[-1] - assert text_block.get("type") == "text", f"Expected text block: {text_block}" - assert len(text_block.get("text", "")) > 0, "Empty text in response" - print(f" PASS anthropic basic: {text_block['text'][:80]!r}") - - -def test_anthropic_streaming(base_url: str, api_key: str): - """Anthropic Messages API: streaming SSE.""" - status, events = _stream_anthropic_http( - f"{base_url}/v1/messages", - body = { - "model": "default", - "max_tokens": 100, - "messages": [{"role": "user", "content": "Count from 1 to 3"}], - "stream": True, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - ) - assert status == 200, f"Expected 200, got {status}" - assert len(events) > 0, "No SSE events received" - - event_types = [e[0] for e in events] - assert "message_start" in event_types, "Missing message_start event" - assert "message_stop" in event_types, "Missing message_stop event" - - full = _collect_anthropic_text(events) - assert len(full) > 0, "Streamed text content is empty" - print(f" PASS anthropic streaming: {len(events)} events, {len(full)} chars") - - -def test_anthropic_sdk(base_url: str, api_key: str): - """Anthropic Python SDK: non-streaming.""" - try: - from anthropic import Anthropic - except ImportError: - print(" SKIP anthropic SDK not installed") - return - - client = Anthropic(base_url = f"{base_url}/v1", api_key = api_key) - message = client.messages.create( - model = "default", - max_tokens = 100, - messages = [ - {"role": "user", "content": "What is 2+2? Answer with just the number."} - ], - ) - assert message.role == "assistant" - assert len(message.content) > 0, "Empty content" - text = message.content[0].text - assert len(text) > 0, "Empty text" - print(f" PASS Anthropic SDK: {text.strip()[:80]!r}") - - -def test_anthropic_with_tools(base_url: str, api_key: str): - """Anthropic Messages API: streaming with tools.""" - status, events = _stream_anthropic_http( - f"{base_url}/v1/messages", - body = { - "model": "default", - "max_tokens": 1024, - "messages": [ - { - "role": "user", - "content": "What is 123 * 456? Use code to compute it.", - } - ], - "tools": [ - { - "name": "python", - "description": "Execute Python code in a sandbox and return stdout/stderr.", - "input_schema": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The Python code to run", - }, - }, - "required": ["code"], - }, - } - ], - "stream": True, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - timeout = 120, - ) - assert status == 200, f"Expected 200, got {status}" - assert len(events) > 0, "No SSE events received for tools request" - - event_types = [e[0] for e in events] - assert "message_start" in event_types, "Missing message_start" - assert "message_stop" in event_types, "Missing message_stop" - - full = _collect_anthropic_text(events) - print( - f" PASS anthropic with tools: {len(events)} events, {len(full)} chars content" - ) - - -def test_anthropic_tool_choice_any(base_url: str, api_key: str): - """Anthropic Messages API: ``tool_choice: {"type": "any"}`` must be - honored (forwarded as OpenAI ``tool_choice: "required"`` to - llama-server). Regression for the secondary fix bundled with #4999 — - previously this field was accepted on the request model but silently - dropped with a warning log, so the model was free to answer from - memory instead of using the tool. - """ - status, events = _stream_anthropic_http( - f"{base_url}/v1/messages", - body = { - "model": "default", - "max_tokens": 256, - "messages": [ - # A question the model could easily answer from memory if - # tool_choice were not enforced. - { - "role": "user", - "content": "What is the weather in London right now?", - } - ], - "tools": [ - { - "name": "get_weather", - "description": "Look up current weather for a city.", - "input_schema": { - "type": "object", - "properties": { - "city": {"type": "string"}, - }, - "required": ["city"], - }, - } - ], - "tool_choice": {"type": "any"}, - "stream": True, - }, - headers = {"Authorization": f"Bearer {api_key}"}, - timeout = 120, - ) - assert status == 200, f"Expected 200, got {status}" - assert len(events) > 0, "No SSE events received" - - # With tool_choice=any, stop_reason must be tool_use (not end_turn) - stop_reason = None - for etype, data in events: - if etype == "message_delta": - stop_reason = data.get("delta", {}).get("stop_reason") or stop_reason - assert stop_reason == "tool_use", ( - f"Expected stop_reason='tool_use' with tool_choice=any, got " - f"{stop_reason!r} — tool_choice may not be forwarded to llama-server." - ) - - # And at least one tool_use content block must be emitted - tool_use_starts = [ - e - for e in events - if e[0] == "content_block_start" - and e[1].get("content_block", {}).get("type") == "tool_use" - ] - assert len(tool_use_starts) >= 1, "No tool_use content block emitted" - print( - f" PASS anthropic tool_choice=any honored: " - f"{len(tool_use_starts)} tool_use blocks, stop_reason={stop_reason}" - ) - - -# ── Server lifecycle ───────────────────────────────────────────────── - - -def _start_server(model: str, variant: str | None) -> tuple[subprocess.Popen, str]: - """Launch ``unsloth studio run`` and parse the API key from its banner. - - Returns (process, api_key). - """ - cmd = [ - "unsloth", - "studio", - "run", - "--model", - model, - "--port", - str(PORT), - "--host", - HOST, - "--api-key-name", - "test", - ] - if variant: - cmd.extend(["--gguf-variant", variant]) - - LOG_FILE.parent.mkdir(parents = True, exist_ok = True) - log_fh = open(LOG_FILE, "w") - proc = subprocess.Popen( - cmd, - stdout = log_fh, - stderr = subprocess.STDOUT, - preexec_fn = os.setsid, - ) - - # Wait for the banner containing the API key - api_key = None - deadline = time.monotonic() + STARTUP_TIMEOUT - while time.monotonic() < deadline: - time.sleep(2) - if proc.poll() is not None: - log_fh.flush() - log_text = LOG_FILE.read_text() - raise RuntimeError( - f"Server exited early (code {proc.returncode}):\n{log_text[-2000:]}" - ) - log_text = LOG_FILE.read_text() - m = re.search(r"API Key:\s+(sk-unsloth-[a-f0-9]+)", log_text) - if m: - api_key = m.group(1) - break - - if not api_key: - log_text = LOG_FILE.read_text() - _kill_server(proc) - raise RuntimeError( - f"Timed out waiting for API key in server output:\n{log_text[-2000:]}" - ) - - # Wait a moment for the model to be fully loaded - time.sleep(2) - return proc, api_key - - -def _kill_server(proc: subprocess.Popen): - """Send SIGTERM to the process group and wait for cleanup.""" - try: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - except (ProcessLookupError, PermissionError): - pass - try: - proc.wait(timeout = 10) - except subprocess.TimeoutExpired: - try: - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - except (ProcessLookupError, PermissionError): - pass - proc.wait(timeout = 5) - - -# ── Main ───────────────────────────────────────────────────────────── - - -def main(): - parser = argparse.ArgumentParser( - description = "End-to-end tests for unsloth studio run" - ) - parser.add_argument( - "--model", - default = DEFAULT_MODEL, - help = f"Model to test with (default: {DEFAULT_MODEL})", - ) - parser.add_argument( - "--gguf-variant", - default = DEFAULT_VARIANT, - help = f"GGUF variant (default: {DEFAULT_VARIANT})", - ) - args = parser.parse_args() - - passed = 0 - failed = 0 - skipped = 0 - - def run_test(fn, *a, **kw): - nonlocal passed, failed, skipped - try: - fn(*a, **kw) - passed += 1 - except AssertionError as exc: - failed += 1 - print(f" FAIL {fn.__name__}: {exc}") - except Exception as exc: - failed += 1 - print(f" ERROR {fn.__name__}: {type(exc).__name__}: {exc}") - - # ── 1. Test --help (no server needed) ──────────────────────────── - print("\n[1/16] Testing --help output") - run_test(test_help_output) - - # ── 2-16. Start server and run API tests ───────────────────────── - print( - f"\nStarting server: {args.model} (variant={args.gguf_variant}) on port {PORT}..." - ) - proc = None - try: - proc, api_key = _start_server(args.model, args.gguf_variant) - base_url = f"http://{HOST}:{PORT}" - print(f"Server ready. API Key: {api_key[:20]}...\n") - - print("[2/16] Testing curl basic (non-streaming)") - run_test(test_curl_basic, base_url, api_key) - - print("[3/16] Testing curl streaming") - run_test(test_curl_streaming, base_url, api_key) - - print("[4/16] Testing OpenAI Python SDK (streaming)") - run_test(test_openai_sdk, base_url, api_key) - - print("[5/16] Testing curl with tools (server-side enable_tools)") - run_test(test_curl_with_tools, base_url, api_key) - - print("[6/16] Testing OpenAI standard tools (non-streaming)") - run_test(test_openai_tools_nonstream, base_url, api_key) - - print("[7/16] Testing OpenAI standard tools (streaming)") - run_test(test_openai_tools_stream, base_url, api_key) - - print("[8/16] Testing OpenAI standard tools (multi-turn)") - run_test(test_openai_tools_multiturn, base_url, api_key) - - print("[9/16] Testing OpenAI SDK tool calling") - run_test(test_openai_sdk_tool_calling, base_url, api_key) - - print("[10/16] Testing invalid API key rejection") - run_test(test_invalid_key_rejected, base_url) - - print("[11/16] Testing no API key rejection") - run_test(test_no_key_rejected, base_url) - - print("[12/16] Testing Anthropic basic (non-streaming)") - run_test(test_anthropic_basic, base_url, api_key) - - print("[13/16] Testing Anthropic streaming") - run_test(test_anthropic_streaming, base_url, api_key) - - print("[14/16] Testing Anthropic Python SDK") - run_test(test_anthropic_sdk, base_url, api_key) - - print("[15/16] Testing Anthropic with tools") - run_test(test_anthropic_with_tools, base_url, api_key) - - print("[16/16] Testing Anthropic tool_choice=any honored") - run_test(test_anthropic_tool_choice_any, base_url, api_key) - - except RuntimeError as exc: - print(f"\nFATAL: Server failed to start: {exc}") - failed += 16 # count remaining tests as failed - finally: - if proc: - print("\nStopping server...") - _kill_server(proc) - print("Server stopped.") - - # ── Summary ────────────────────────────────────────────────────── - total = passed + failed - print(f"\n{'=' * 40}") - print(f"Results: {passed}/{total} passed, {failed} failed") - print(f"Log: {LOG_FILE}") - print(f"{'=' * 40}") - sys.exit(1 if failed else 0) - - -if __name__ == "__main__": - main() diff --git a/studio/backend/tests/test_tool_policy_gates.py b/studio/backend/tests/test_tool_policy_gates.py deleted file mode 100644 index 01f6bbbc3f..0000000000 --- a/studio/backend/tests/test_tool_policy_gates.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -""" -Tests for `_effective_enable_tools` -- the helper that folds the -process-level `tool_policy` over a request's `enable_tools` field. - -Truth table (policy x payload.enable_tools -> effective): - policy=None + payload=None -> None - policy=None + payload=True -> True - policy=None + payload=False -> False - policy=True + payload=* -> True - policy=False + payload=* -> False -""" - -import os -import sys -from types import SimpleNamespace - -_backend = os.path.join(os.path.dirname(__file__), "..") -sys.path.insert(0, _backend) - -import pytest - -from routes.inference import _effective_enable_tools -from state.tool_policy import reset_tool_policy, set_tool_policy - - -@pytest.fixture(autouse = True) -def _reset(): - reset_tool_policy() - yield - reset_tool_policy() - - -def _payload(value): - return SimpleNamespace(enable_tools = value) - - -class TestEffectiveEnableTools: - @pytest.mark.parametrize( - "payload_value,expected", - [(None, None), (True, True), (False, False)], - ) - def test_no_policy_falls_through_to_payload(self, payload_value, expected): - assert _effective_enable_tools(_payload(payload_value)) == expected - - @pytest.mark.parametrize("payload_value", [None, True, False]) - def test_policy_true_overrides_any_payload(self, payload_value): - set_tool_policy(True) - assert _effective_enable_tools(_payload(payload_value)) is True - - @pytest.mark.parametrize("payload_value", [None, True, False]) - def test_policy_false_overrides_any_payload(self, payload_value): - set_tool_policy(False) - assert _effective_enable_tools(_payload(payload_value)) is False diff --git a/studio/backend/tests/test_tool_policy_state.py b/studio/backend/tests/test_tool_policy_state.py deleted file mode 100644 index 5f6b228281..0000000000 --- a/studio/backend/tests/test_tool_policy_state.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -""" -Tests for the process-level server-side tool policy used by `unsloth run`. - -The policy has three states: - None -> no CLI override (default; honor per-request enable_tools) - True -> CLI forced tools on - False -> CLI forced tools off -""" - -import os -import sys - -_backend = os.path.join(os.path.dirname(__file__), "..") -sys.path.insert(0, _backend) - -import pytest - -from state.tool_policy import ( - get_tool_policy, - reset_tool_policy, - set_tool_policy, -) - - -@pytest.fixture(autouse = True) -def _reset(): - reset_tool_policy() - yield - reset_tool_policy() - - -class TestToolPolicy: - def test_default_is_none(self): - assert get_tool_policy() is None - - def test_set_true_then_get(self): - set_tool_policy(True) - assert get_tool_policy() is True - - def test_set_false_then_get(self): - set_tool_policy(False) - assert get_tool_policy() is False - - def test_set_none_clears(self): - set_tool_policy(True) - set_tool_policy(None) - assert get_tool_policy() is None - - def test_reset_clears(self): - set_tool_policy(False) - reset_tool_policy() - assert get_tool_policy() is None - - def test_rejects_non_optional_bool(self): - with pytest.raises(TypeError): - set_tool_policy("true") # type: ignore[arg-type] diff --git a/studio/backend/tests/test_trained_model_scan.py b/studio/backend/tests/test_trained_model_scan.py deleted file mode 100644 index 84be681fca..0000000000 --- a/studio/backend/tests/test_trained_model_scan.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Tests for Studio trained-model discovery used by Chat.""" - -import json -from pathlib import Path -import sys -import types as _types -import importlib - - -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) -if _BACKEND_DIR not in sys.path: - sys.path.insert(0, _BACKEND_DIR) - -_loggers_stub = _types.ModuleType("loggers") -_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) -sys.modules.setdefault("loggers", _loggers_stub) - -from unittest.mock import patch - -from utils.models.model_config import ( - ModelConfig, - get_base_model_from_checkpoint, - get_base_model_from_lora, - scan_trained_models, -) - - -def test_scan_trained_models_includes_lora_and_full_finetune_outputs(tmp_path: Path): - lora_dir = tmp_path / "unsloth_SmolLM-135M_1775412608" - lora_dir.mkdir() - (lora_dir / "adapter_config.json").write_text( - json.dumps({"base_model_name_or_path": "HuggingFaceTB/SmolLM-135M"}) - ) - (lora_dir / "adapter_model.safetensors").write_bytes(b"") - - full_dir = tmp_path / "unsloth_SmolLM-135M_full_1775412609" - full_dir.mkdir() - (full_dir / "config.json").write_text( - json.dumps({"_name_or_path": "HuggingFaceTB/SmolLM-135M"}) - ) - (full_dir / "model.safetensors").write_bytes(b"") - - found = { - name: (path, model_type) - for name, path, model_type in scan_trained_models(str(tmp_path)) - } - - assert found[lora_dir.name] == (str(lora_dir), "lora") - assert found[full_dir.name] == (str(full_dir), "merged") - - -def test_get_base_model_from_checkpoint_falls_back_to_full_finetune_config( - tmp_path: Path, -): - (tmp_path / "config.json").write_text( - json.dumps({"_name_or_path": "HuggingFaceTB/SmolLM-135M"}) - ) - (tmp_path / "model.safetensors").write_bytes(b"") - - assert get_base_model_from_checkpoint(str(tmp_path)) == "HuggingFaceTB/SmolLM-135M" - - -def test_get_base_model_from_lora_rejects_full_finetune_dirs(tmp_path: Path): - (tmp_path / "config.json").write_text( - json.dumps({"_name_or_path": "HuggingFaceTB/SmolLM-135M"}) - ) - (tmp_path / "model.safetensors").write_bytes(b"") - - assert get_base_model_from_lora(str(tmp_path)) is None - - -@patch("utils.models.model_config.is_audio_input_type", return_value = False) -@patch("utils.models.model_config.detect_audio_type", return_value = None) -@patch("utils.models.model_config.is_vision_model", return_value = False) -def test_model_config_full_finetune_local_path_is_not_lora( - _mock_vision, - _mock_audio_type, - _mock_audio_input, - tmp_path: Path, -): - (tmp_path / "config.json").write_text( - json.dumps({"_name_or_path": "unsloth/Qwen3-4B"}) - ) - (tmp_path / "model.safetensors").write_bytes(b"") - - config = ModelConfig.from_identifier(str(tmp_path)) - - assert config is not None - assert config.is_lora is False - assert config.base_model is None - - -def test_scan_trained_loras_aliases_scan_trained_models(): - utils_models = importlib.import_module("utils.models") - core_module = importlib.import_module("core") - - assert utils_models.scan_trained_loras is utils_models.scan_trained_models - assert core_module.scan_trained_loras is core_module.scan_trained_models diff --git a/studio/backend/tests/test_training_worker_flash_attn.py b/studio/backend/tests/test_training_worker_flash_attn.py deleted file mode 100644 index 41a7c87df1..0000000000 --- a/studio/backend/tests/test_training_worker_flash_attn.py +++ /dev/null @@ -1,170 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -from __future__ import annotations - -import builtins -import subprocess -import sys -from unittest import mock - -from core.training import worker - - -def _missing_flash_attn_import(): - real_import = builtins.__import__ - - def fake_import(name, globals = None, locals = None, fromlist = (), level = 0): - if name == "flash_attn": - raise ImportError - return real_import(name, globals, locals, fromlist, level) - - return fake_import - - -def test_should_try_runtime_flash_attn_install_threshold_and_skip(monkeypatch): - monkeypatch.delenv(worker._FLASH_ATTN_SKIP_ENV, raising = False) - assert worker._should_try_runtime_flash_attn_install(32767) is False - assert worker._should_try_runtime_flash_attn_install( - 32768 - ) is sys.platform.startswith("linux") - - monkeypatch.setenv(worker._FLASH_ATTN_SKIP_ENV, "1") - assert worker._should_try_runtime_flash_attn_install(32768) is False - - -def test_runtime_flash_attn_prefers_prebuilt_wheel(monkeypatch): - statuses: list[str] = [] - - monkeypatch.delenv(worker._FLASH_ATTN_SKIP_ENV, raising = False) - monkeypatch.setattr(builtins, "__import__", _missing_flash_attn_import()) - monkeypatch.setattr( - worker, - "flash_attn_wheel_url", - lambda env: "https://example.com/fa.whl", - ) - monkeypatch.setattr(worker, "url_exists", lambda url: True) - monkeypatch.setattr( - worker, - "_send_status", - lambda queue, message: statuses.append(message), - ) - monkeypatch.setattr( - worker, - "install_wheel", - lambda *args, **kwargs: [("pip", subprocess.CompletedProcess(["pip"], 0, ""))], - ) - - worker._ensure_flash_attn_for_long_context(event_queue = [], max_seq_length = 32768) - - assert statuses == ["Installing prebuilt flash-attn wheel..."] - - -def test_runtime_flash_attn_falls_back_to_pypi(monkeypatch): - calls: list[list[str]] = [] - statuses: list[str] = [] - - monkeypatch.delenv(worker._FLASH_ATTN_SKIP_ENV, raising = False) - monkeypatch.setattr(builtins, "__import__", _missing_flash_attn_import()) - monkeypatch.setattr( - worker, - "probe_torch_wheel_env", - lambda timeout = 30: { - "python_tag": "cp313", - "torch_mm": "2.10", - "cuda_major": "13", - "cxx11abi": "TRUE", - "platform_tag": "linux_x86_64", - }, - ) - monkeypatch.setattr( - worker, - "flash_attn_wheel_url", - lambda env: "https://example.com/fa.whl", - ) - monkeypatch.setattr(worker, "url_exists", lambda url: False) - monkeypatch.setattr(worker.shutil, "which", lambda name: None) - monkeypatch.setattr( - worker, - "_send_status", - lambda queue, message: statuses.append(message), - ) - monkeypatch.setattr(worker, "install_wheel", mock.Mock()) - - def fake_run(cmd, stdout = None, stderr = None, text = None): - calls.append(list(cmd)) - return subprocess.CompletedProcess(cmd, 0, "") - - monkeypatch.setattr(worker._sp, "run", fake_run) - - worker._ensure_flash_attn_for_long_context(event_queue = [], max_seq_length = 32768) - - assert statuses == ["Installing flash-attn from PyPI for long-context training..."] - assert calls == [[sys.executable, "-m", "pip", "install", "flash-attn"]] - - -def test_runtime_flash_attn_skip_env_avoids_all_install_work(monkeypatch): - monkeypatch.setenv(worker._FLASH_ATTN_SKIP_ENV, "1") - monkeypatch.setattr(worker._sp, "run", mock.Mock()) - - worker._ensure_flash_attn_for_long_context(event_queue = [], max_seq_length = 32768) - - worker._sp.run.assert_not_called() - - -def test_causal_conv1d_fast_path_preserves_wheel_first_install_args(monkeypatch): - install_mock = mock.Mock(return_value = True) - monkeypatch.setattr(worker, "_install_package_wheel_first", install_mock) - - worker._ensure_causal_conv1d_fast_path( - event_queue = [], - model_name = "tiiuae/Falcon-H1-0.5B-Instruct", - ) - - install_mock.assert_called_once_with( - event_queue = [], - import_name = "causal_conv1d", - display_name = "causal-conv1d", - pypi_name = "causal-conv1d", - pypi_version = worker._CAUSAL_CONV1D_PACKAGE_VERSION, - filename_prefix = "causal_conv1d", - release_tag = worker._CAUSAL_CONV1D_RELEASE_TAG, - release_base_url = "https://github.com/Dao-AILab/causal-conv1d/releases/download", - ) - - -def test_causal_conv1d_fast_path_includes_qwen3_6_variants(monkeypatch): - install_mock = mock.Mock(return_value = True) - monkeypatch.setattr(worker, "_install_package_wheel_first", install_mock) - - worker._ensure_causal_conv1d_fast_path( - event_queue = [], - model_name = "unsloth/Qwen3.6-4B", - ) - worker._ensure_causal_conv1d_fast_path( - event_queue = [], - model_name = "unsloth/Qwen3_6-4B", - ) - - assert install_mock.call_count == 2 - - -def test_mamba_ssm_path_preserves_wheel_first_install_args(monkeypatch): - install_mock = mock.Mock(return_value = True) - monkeypatch.setattr(worker, "_install_package_wheel_first", install_mock) - - worker._ensure_mamba_ssm( - event_queue = [], - model_name = "tiiuae/Falcon-H1-0.5B-Instruct", - ) - - install_mock.assert_called_once_with( - event_queue = [], - import_name = "mamba_ssm", - display_name = "mamba-ssm", - pypi_name = "mamba-ssm", - pypi_version = worker._MAMBA_SSM_PACKAGE_VERSION, - filename_prefix = "mamba_ssm", - release_tag = worker._MAMBA_SSM_RELEASE_TAG, - release_base_url = "https://github.com/state-spaces/mamba/releases/download", - ) diff --git a/studio/backend/tests/test_utils.py b/studio/backend/tests/test_utils.py index 64c9907119..50557c6718 100644 --- a/studio/backend/tests/test_utils.py +++ b/studio/backend/tests/test_utils.py @@ -191,14 +191,8 @@ def test_has_backend_key(self): assert "backend" in get_gpu_memory_info() def test_backend_matches_device(self): - # The backend field uses _backend_label, which swaps "cuda" for - # "rocm" when running on an AMD host (IS_ROCM=True) so the UI - # can render the correct label. On CUDA / XPU / MLX / CPU hosts - # it is equivalent to `get_device().value`. - from utils.hardware.hardware import _backend_label - result = get_gpu_memory_info() - assert result["backend"] == _backend_label(get_device()) + assert result["backend"] == get_device().value # --- When a GPU IS available --- diff --git a/studio/backend/tests/test_vision_cache.py b/studio/backend/tests/test_vision_cache.py index 9e7bbdd1fb..fae1e95311 100644 --- a/studio/backend/tests/test_vision_cache.py +++ b/studio/backend/tests/test_vision_cache.py @@ -124,50 +124,23 @@ def test_subprocess_called_once_with_cache(self, mock_needs_t5, mock_subprocess) class TestVisionCacheOnException: - """When detection raises an exception, _is_vision_model_uncached - distinguishes permanent failures (cached as False) from transient - failures (returned as None, not cached so the next call can retry). - Verify both contracts.""" - - @patch( - "utils.models.model_config.load_model_config", - side_effect = ValueError("bad config"), - ) - @patch("utils.transformers_version.needs_transformers_5", return_value = False) - def test_permanent_exception_result_cached(self, mock_needs_t5, mock_load_config): - """A permanent failure (ValueError / RepositoryNotFoundError / - GatedRepoError / JSONDecodeError) should be caught, return False, - and that False should be cached so subsequent calls don't retry. - - ValueError is used here because it's the simplest of the - code-path's cacheable exception types and does not require an - import of huggingface_hub errors (whose module path varies - across versions).""" - # First call: load_model_config raises -> except branch -> False. - assert is_vision_model("broken/model") is False - # Second call: cache hit, load_model_config not called again. - assert is_vision_model("broken/model") is False - mock_load_config.assert_called_once() + """When detection raises an exception, _is_vision_model_uncached catches + it and returns False. That False must be cached so subsequent calls don't + retry and fail again.""" @patch( "utils.models.model_config.load_model_config", side_effect = OSError("network down"), ) @patch("utils.transformers_version.needs_transformers_5", return_value = False) - def test_transient_exception_not_cached(self, mock_needs_t5, mock_load_config): - """A transient failure (OSError, timeouts) should return None from - _is_vision_model_uncached, surface as False to the caller, and - NOT be cached, so the next call retries detection. This matches - the documented behaviour on _vision_detection_cache: - 'transient failures (network errors, timeouts) are NOT cached so - they can be retried.'""" - # First call: load_model_config raises OSError -> uncached None - # -> caller returns False without caching. + def test_exception_result_cached(self, mock_needs_t5, mock_load_config): + """A real exception inside _is_vision_model_uncached should be caught, + return False, and that False should be cached for subsequent calls.""" + # First call: load_model_config raises → except branch → False assert is_vision_model("broken/model") is False - # Second call: cache miss again, load_model_config called a - # second time. + # Second call: cache hit, load_model_config not called again assert is_vision_model("broken/model") is False - assert mock_load_config.call_count == 2 + mock_load_config.assert_called_once() # --------------------------------------------------------------------------- diff --git a/studio/backend/tests/test_vram_estimation.py b/studio/backend/tests/test_vram_estimation.py index e54ae6dcf8..0be067310d 100644 --- a/studio/backend/tests/test_vram_estimation.py +++ b/studio/backend/tests/test_vram_estimation.py @@ -2,9 +2,7 @@ # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. import unittest -from dataclasses import replace from types import SimpleNamespace -from unittest.mock import patch from utils.hardware.vram_estimation import ( ModelArchConfig, @@ -118,55 +116,6 @@ def _gb(b: int) -> float: num_dense_layers = 0, ) -STRUCTURED_MIXED = ModelArchConfig( - hidden_size = 256, - num_hidden_layers = 6, - num_attention_heads = 4, - num_key_value_heads = 2, - intermediate_size = 512, - vocab_size = 1024, - tie_word_embeddings = True, - head_dim = 80, - global_head_dim = 96, - num_global_key_value_heads = 1, - attention_k_eq_v = True, - layer_types = [ - "sliding_attention", - "full_attention", - "sliding_attention", - "full_attention", - "sliding_attention", - "full_attention", - ], -) - -STRUCTURED_SHARED = ModelArchConfig( - hidden_size = 192, - num_hidden_layers = 4, - num_attention_heads = 6, - num_key_value_heads = 2, - intermediate_size = 384, - vocab_size = 512, - tie_word_embeddings = True, - head_dim = 32, - num_kv_shared_layers = 2, - use_double_wide_mlp = True, - vocab_size_per_layer_input = 128, - hidden_size_per_layer_input = 48, - quant_4bit_factor = 3.6, -) - -QUANT_SKIP_STRUCTURED = replace( - STRUCTURED_SHARED, - quantization_skip_modules = [ - "model.layers.0.self_attn.q_proj", - "language_model.model.layers.1.mlp", - "layers.2", - "vision_tower", - "embed_tokens", - ], -) - class TestExtractArchConfig(unittest.TestCase): def test_basic_config(self): @@ -233,42 +182,6 @@ def test_intermediate_size_list(self): arch = extract_arch_config(hf_config) self.assertEqual(arch.intermediate_size, 8192) - def test_structural_and_quantization_fields_are_config_derived(self): - hf_config = SimpleNamespace( - hidden_size = 256, - num_hidden_layers = 2, - num_attention_heads = 4, - num_key_value_heads = 2, - intermediate_size = 512, - vocab_size = 1024, - tie_word_embeddings = True, - head_dim = 80, - global_head_dim = 96, - num_global_key_value_heads = 1, - attention_k_eq_v = True, - layer_types = ["sliding_attention", "full_attention"], - num_kv_shared_layers = 1, - use_double_wide_mlp = True, - vocab_size_per_layer_input = 128, - hidden_size_per_layer_input = 48, - quantization_config = { - "bnb_4bit_use_double_quant": True, - "llm_int8_skip_modules": ["model.layers.0.self_attn"], - }, - ) - arch = extract_arch_config(hf_config) - self.assertEqual(arch.head_dim, 80) - self.assertEqual(arch.global_head_dim, 96) - self.assertEqual(arch.num_global_key_value_heads, 1) - self.assertTrue(arch.attention_k_eq_v) - self.assertEqual(arch.layer_types, ["sliding_attention", "full_attention"]) - self.assertEqual(arch.num_kv_shared_layers, 1) - self.assertTrue(arch.use_double_wide_mlp) - self.assertEqual(arch.vocab_size_per_layer_input, 128) - self.assertEqual(arch.hidden_size_per_layer_input, 48) - self.assertEqual(arch.quantization_skip_modules, ["model.layers.0.self_attn"]) - self.assertEqual(arch.quant_4bit_factor, 3.6) - class TestModelWeightsBytes(unittest.TestCase): def test_llama_8b_fp16(self): @@ -325,18 +238,6 @@ def test_moe_mlp_modules_scale_with_experts(self): ratio = moe_lora / dense_lora self.assertAlmostEqual(ratio, 8.0, delta = 0.5) - def test_structured_moe_mlp_modules_scale_with_experts(self): - structured_moe = replace(QWEN3_MOE_30B, head_dim = 128) - dense_like = replace( - structured_moe, - num_experts = None, - moe_intermediate_size = None, - ) - target_modules = ["gate_proj", "up_proj", "down_proj"] - dense_lora = compute_lora_params(dense_like, 16, target_modules) - moe_lora = compute_lora_params(structured_moe, 16, target_modules) - self.assertGreater(moe_lora, dense_lora * 20) - def test_attention_modules_same_for_moe(self): dense_attn = compute_lora_params( LLAMA_8B, 16, ["q_proj", "k_proj", "v_proj", "o_proj"] @@ -346,41 +247,6 @@ def test_attention_modules_same_for_moe(self): ) self.assertEqual(dense_attn, moe_attn) - def test_all_linear_uses_default_text_modules(self): - text_only = compute_lora_params(STRUCTURED_MIXED, 16, DEFAULT_TARGET_MODULES) - all_linear = compute_lora_params(STRUCTURED_MIXED, 16, ["all-linear"]) - self.assertEqual(all_linear, text_only) - - def test_structural_layer_shapes_are_config_driven(self): - unstructured_arch = replace( - STRUCTURED_MIXED, - head_dim = None, - global_head_dim = None, - num_global_key_value_heads = None, - attention_k_eq_v = False, - layer_types = None, - ) - self.assertNotEqual( - compute_lora_params(unstructured_arch, 16, ["all-linear"]), - compute_lora_params(STRUCTURED_MIXED, 16, ["all-linear"]), - ) - self.assertNotEqual( - compute_model_weights_bytes(unstructured_arch, "qlora", True), - compute_model_weights_bytes(STRUCTURED_MIXED, "qlora", True), - ) - - def test_shared_kv_and_per_layer_inputs_change_weight_count(self): - unstructured_arch = replace( - STRUCTURED_SHARED, - head_dim = None, - num_kv_shared_layers = 0, - use_double_wide_mlp = False, - ) - self.assertNotEqual( - compute_model_weights_bytes(unstructured_arch, "qlora", True), - compute_model_weights_bytes(STRUCTURED_SHARED, "qlora", True), - ) - class TestOptimizerBytes(unittest.TestCase): def test_adamw_8bit(self): @@ -427,163 +293,6 @@ def test_scales_with_seq_len(self): act_4k = compute_activation_bytes(LLAMA_8B, 2, 4096, "unsloth") self.assertAlmostEqual(act_4k / act_2k, 2.0, delta = 0.1) - def test_flash_attention_uses_linear_path(self): - flash = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - 4096, - "unsloth", - is_lora = True, - attention_implementation = "flash_attention_2", - ) - default = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - 4096, - "unsloth", - is_lora = True, - ) - self.assertEqual(flash, default) - - def test_sdpa_attention_uses_linear_path(self): - flash = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - 4096, - "unsloth", - is_lora = True, - attention_implementation = "flash_attention_2", - ) - sdpa = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - 4096, - "unsloth", - is_lora = True, - attention_implementation = "sdpa", - ) - self.assertEqual(sdpa, flash) - - def test_non_flash_attention_uses_quadratic_path(self): - seq_len = 4096 - expected_quadratic = ( - 1 * STRUCTURED_MIXED.num_attention_heads * seq_len * seq_len * 2 * 12.0 - ) - for attention_implementation in ("eager", "unknown_impl", None): - with self.subTest(attention_implementation = attention_implementation): - non_flash = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - seq_len, - "unsloth", - is_lora = True, - attention_implementation = attention_implementation, - ) - self.assertEqual(non_flash, int(expected_quadratic)) - - def test_non_flash_attention_without_gc_scales_quadratic_path_by_layers(self): - seq_len = 4096 - one_layer = ( - 1 * STRUCTURED_MIXED.num_attention_heads * seq_len * seq_len * 2 * 12.0 - ) - non_flash = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - seq_len, - "none", - is_lora = True, - attention_implementation = "eager", - ) - self.assertEqual(non_flash, int(one_layer * STRUCTURED_MIXED.num_hidden_layers)) - self.assertGreater(non_flash, int(one_layer)) - - -class TestQuantizationSkips(unittest.TestCase): - def test_skipped_language_layers_stay_fp16(self): - no_skips = replace(QUANT_SKIP_STRUCTURED, quantization_skip_modules = []) - skipped = compute_model_weights_bytes(QUANT_SKIP_STRUCTURED, "qlora", True) - quantized = compute_model_weights_bytes(no_skips, "qlora", True) - self.assertGreater(skipped, quantized) - - def test_non_language_skips_do_not_double_count_text_weights(self): - arch = replace( - QUANT_SKIP_STRUCTURED, - quantization_skip_modules = ["vision_tower", "embed_tokens"], - ) - no_skips = replace(QUANT_SKIP_STRUCTURED, quantization_skip_modules = []) - self.assertEqual( - compute_model_weights_bytes(arch, "qlora", True), - compute_model_weights_bytes(no_skips, "qlora", True), - ) - - def test_double_quant_factor_reduces_quantized_weight_storage(self): - default_quant = replace(STRUCTURED_MIXED, quant_4bit_factor = 16 / 5) - double_quant = replace(STRUCTURED_MIXED, quant_4bit_factor = 3.6) - self.assertLess( - compute_model_weights_bytes(double_quant, "qlora", True), - compute_model_weights_bytes(default_quant, "qlora", True), - ) - - def test_prefixed_parent_and_child_skips_do_not_double_count(self): - parent_only = replace( - QUANT_SKIP_STRUCTURED, - quantization_skip_modules = ["language_model.model.layers.1.mlp"], - ) - parent_and_child = replace( - QUANT_SKIP_STRUCTURED, - quantization_skip_modules = [ - "language_model.model.layers.1.mlp", - "language_model.model.layers.1.mlp.gate_proj", - "model.layers.1.mlp.up_proj", - ], - ) - self.assertEqual( - compute_model_weights_bytes(parent_and_child, "qlora", True), - compute_model_weights_bytes(parent_only, "qlora", True), - ) - - def test_vlm_prefix_skip_module_does_not_match_text_alias(self): - # vision_tower-prefixed skips must not shadow text aliases sharing the - # same suffix. - baseline = replace(QUANT_SKIP_STRUCTURED, quantization_skip_modules = []) - vlm_skip = replace( - QUANT_SKIP_STRUCTURED, - quantization_skip_modules = [ - "vision_tower.model.layers.0.self_attn.q_proj", - "vision_tower.model.layers.1.mlp", - ], - ) - self.assertEqual( - compute_model_weights_bytes(vlm_skip, "qlora", True), - compute_model_weights_bytes(baseline, "qlora", True), - ) - - def test_mla_skip_module_uses_authoritative_attn_total(self): - from utils.hardware.vram_estimation import ( - _build_text_module_elements, - _compute_attn_elements, - ) - - mla = ModelArchConfig( - hidden_size = 2048, - num_hidden_layers = 4, - num_attention_heads = 16, - num_key_value_heads = 16, - intermediate_size = 8192, - vocab_size = 32000, - tie_word_embeddings = False, - q_lora_rank = 512, - kv_lora_rank = 128, - qk_nope_head_dim = 64, - qk_rope_head_dim = 32, - v_head_dim = 64, - ) - elements, _ = _build_text_module_elements(mla) - self.assertEqual( - elements["text.layers.0.self_attn"], - _compute_attn_elements(mla), - ) - class TestEstimateTrainingVram(unittest.TestCase): def test_llama_8b_qlora_reasonable_total(self): @@ -721,90 +430,6 @@ def test_adamw_fp32_uses_more_optimizer_memory(self): v32.optimizer_states / v8.optimizer_states, 1.5, delta = 0.1 ) - def test_min_gpu_vram_treats_activations_as_per_gpu_fixed(self): - config = TrainingVramConfig(training_method = "qlora", load_in_4bit = True) - breakdown = estimate_training_vram(LLAMA_8B, config) - shardable = ( - breakdown.model_weights - + breakdown.lora_adapters - + breakdown.optimizer_states - + breakdown.gradients - ) - per_gpu_fixed = breakdown.activations + breakdown.cuda_overhead - for n_gpus in (1, 2, 4): - self.assertEqual( - breakdown.min_gpu_vram(n_gpus), - shardable // n_gpus + per_gpu_fixed, - ) - - def test_qlora_gradient_floor_is_capped_by_trainable_scale(self): - config = TrainingVramConfig( - training_method = "qlora", - batch_size = 1, - max_seq_length = 512, - lora_rank = 16, - target_modules = ["all-linear"], - gradient_checkpointing = "unsloth", - optimizer = "adamw_8bit", - load_in_4bit = True, - ) - breakdown = estimate_training_vram(LLAMA_8B, config) - lora_params = compute_lora_params(LLAMA_8B, 16, DEFAULT_TARGET_MODULES) - optimizer_bytes = compute_optimizer_bytes(lora_params, "adamw_8bit") - weight_floor = int(breakdown.model_weights * 0.15) - - self.assertEqual( - breakdown.gradients, - max(breakdown.activations_computed, optimizer_bytes), - ) - self.assertLess(breakdown.gradients, weight_floor) - self.assertEqual(breakdown.activations, breakdown.activations_computed) - - def test_full_finetuning_gradient_floor_remains_uncapped(self): - config = TrainingVramConfig( - training_method = "full", - batch_size = 1, - max_seq_length = 512, - gradient_checkpointing = "unsloth", - optimizer = "adamw_8bit", - load_in_4bit = False, - ) - expected_floor = int( - compute_model_weights_bytes(LLAMA_8B, "full", False) * 0.15 - ) - with patch( - "utils.hardware.vram_estimation.compute_gradient_bytes", - return_value = 1, - ): - breakdown = estimate_training_vram(LLAMA_8B, config) - self.assertEqual(breakdown.gradients, expected_floor) - - def test_non_flash_attention_flows_into_training_estimate(self): - config = TrainingVramConfig( - training_method = "qlora", - batch_size = 1, - max_seq_length = 4096, - lora_rank = 16, - target_modules = ["all-linear"], - gradient_checkpointing = "unsloth", - optimizer = "adamw_8bit", - load_in_4bit = True, - attention_implementation = "eager", - ) - breakdown = estimate_training_vram(STRUCTURED_MIXED, config) - self.assertEqual(breakdown.activations, breakdown.activations_computed) - self.assertGreater( - breakdown.activations, - compute_activation_bytes( - STRUCTURED_MIXED, - 1, - 4096, - "unsloth", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - ) - class TestExtractArchConfigMoE(unittest.TestCase): def test_deepseek_v3_shared_experts(self): @@ -846,16 +471,11 @@ def test_qwen3_moe_decoder_sparse_step(self): moe_intermediate_size = 768, decoder_sparse_step = 1, mlp_only_layers = [], - head_dim = 128, ) arch = extract_arch_config(hf_config) self.assertEqual(arch.num_experts, 128) self.assertEqual(arch.num_dense_layers, 0) - self.assertEqual(arch.head_dim, 128) self.assertIsNone(arch.q_lora_rank) - total_b = compute_total_params(arch) / 1e9 - self.assertGreater(total_b, 20) - self.assertLess(total_b, 50) def test_qwen3_moe_with_mlp_only_layers(self): hf_config = SimpleNamespace( @@ -922,343 +542,6 @@ def test_backward_compat_no_new_fields(self): self.assertEqual(arch.n_shared_experts, 0) self.assertEqual(arch.num_dense_layers, 0) self.assertIsNone(arch.q_lora_rank) - self.assertFalse(arch.moe_has_dense_mlp) - - def test_enable_moe_block_extracted_as_moe_has_dense_mlp(self): - hf_config = SimpleNamespace( - hidden_size = 2048, - num_hidden_layers = 8, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 4096, - vocab_size = 32000, - tie_word_embeddings = True, - num_experts = 8, - moe_intermediate_size = 1024, - head_dim = 128, - layer_types = ["full_attention"] * 8, - enable_moe_block = True, - ) - arch = extract_arch_config(hf_config) - self.assertTrue(arch.moe_has_dense_mlp) - - -class TestParallelDenseMoE(unittest.TestCase): - def _arch(self, **overrides): - base = ModelArchConfig( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 2, - intermediate_size = 1024, - vocab_size = 1024, - tie_word_embeddings = True, - num_experts = 8, - moe_intermediate_size = 512, - num_dense_layers = 0, - head_dim = 64, - layer_types = ["full_attention"] * 4, - ) - return replace(base, **overrides) - - def test_total_params_includes_parallel_dense_when_enable_moe_block(self): - without_parallel = self._arch(moe_has_dense_mlp = False) - with_parallel = self._arch(moe_has_dense_mlp = True) - self.assertGreater( - compute_total_params(with_parallel), - compute_total_params(without_parallel), - ) - - def test_lora_params_includes_parallel_dense_when_enable_moe_block(self): - without_parallel = self._arch(moe_has_dense_mlp = False) - with_parallel = self._arch(moe_has_dense_mlp = True) - target = ["gate_proj", "up_proj", "down_proj"] - self.assertGreater( - compute_lora_params(with_parallel, 16, target), - compute_lora_params(without_parallel, 16, target), - ) - - def test_activation_bytes_includes_parallel_dense_when_enable_moe_block(self): - without_parallel = self._arch(moe_has_dense_mlp = False) - with_parallel = self._arch(moe_has_dense_mlp = True) - self.assertGreater( - compute_activation_bytes( - with_parallel, - 1, - 2048, - "unsloth", - is_lora = True, - ), - compute_activation_bytes( - without_parallel, - 1, - 2048, - "unsloth", - is_lora = True, - ), - ) - - def test_layer_aggregates_split_dense_mlp_from_experts(self): - from utils.hardware.vram_estimation import _build_text_module_elements - - with_parallel = self._arch(moe_has_dense_mlp = True) - elements, _ = _build_text_module_elements(with_parallel) - moe_only = ( - with_parallel.hidden_size - * with_parallel.moe_intermediate_size - * 3 - * with_parallel.num_experts - + with_parallel.num_experts * with_parallel.hidden_size - ) - dense_only = with_parallel.hidden_size * with_parallel.intermediate_size * 3 - # why: under gemma4 enable_moe_block, the layer's `self.experts` is a - # sibling of `self.mlp`; the `text.layers..mlp` aggregate must - # cover the dense path only, with experts in their own aggregate. - self.assertEqual(elements["text.layers.0.mlp"], dense_only) - self.assertEqual(elements["text.layers.0.experts"], moe_only) - - -class TestDenseLayerIndices(unittest.TestCase): - def test_non_prefix_mlp_only_layers_preserve_position(self): - hf_config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 8, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = True, - num_local_experts = 4, - moe_intermediate_size = 512, - decoder_sparse_step = 1, - mlp_only_layers = [3, 5], - ) - arch = extract_arch_config(hf_config) - self.assertEqual(arch.num_dense_layers, 2) - self.assertIn(3, arch.dense_layer_indices) - self.assertIn(5, arch.dense_layer_indices) - self.assertNotIn(0, arch.dense_layer_indices) - - def test_first_k_dense_replace_indices_are_prefix(self): - hf_config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 6, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = False, - n_routed_experts = 8, - moe_intermediate_size = 512, - first_k_dense_replace = 2, - ) - arch = extract_arch_config(hf_config) - self.assertEqual(tuple(arch.dense_layer_indices), (0, 1)) - - -class TestKvSharedLayer(unittest.TestCase): - def test_fully_shared_kv_returns_false_matching_upstream(self): - from utils.hardware.vram_estimation import _is_kv_shared_layer - - arch = ModelArchConfig( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 2, - intermediate_size = 1024, - vocab_size = 1024, - num_kv_shared_layers = 4, - ) - for i in range(arch.num_hidden_layers): - self.assertFalse(_is_kv_shared_layer(arch, i)) - - def test_partial_share_returns_true_for_tail_layers(self): - from utils.hardware.vram_estimation import _is_kv_shared_layer - - arch = ModelArchConfig( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 2, - intermediate_size = 1024, - vocab_size = 1024, - num_kv_shared_layers = 2, - ) - self.assertFalse(_is_kv_shared_layer(arch, 0)) - self.assertFalse(_is_kv_shared_layer(arch, 1)) - self.assertTrue(_is_kv_shared_layer(arch, 2)) - self.assertTrue(_is_kv_shared_layer(arch, 3)) - - -class TestFlexAttentionLinear(unittest.TestCase): - def test_flex_attention_treated_as_linear(self): - flash = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - 4096, - "unsloth", - is_lora = True, - attention_implementation = "flash_attention_2", - ) - flex = compute_activation_bytes( - STRUCTURED_MIXED, - 1, - 4096, - "unsloth", - is_lora = True, - attention_implementation = "flex_attention", - ) - self.assertEqual(flex, flash) - - -class TestNonStructuredParallelDense(unittest.TestCase): - def _arch(self, **overrides): - base = ModelArchConfig( - hidden_size = 1024, - num_hidden_layers = 4, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 4096, - vocab_size = 32000, - tie_word_embeddings = False, - num_experts = 8, - moe_intermediate_size = 768, - num_dense_layers = 0, - moe_has_dense_mlp = True, - ) - return replace(base, **overrides) - - def test_skip_module_uses_intermediate_size_for_parallel_dense(self): - from utils.hardware.vram_estimation import _build_text_module_elements - - arch = self._arch() - elements, _ = _build_text_module_elements(arch) - gate_proj = elements["text.layers.0.mlp.gate_proj"] - self.assertEqual(gate_proj, arch.hidden_size * arch.intermediate_size) - - -class TestPerLayerInputAccounting(unittest.TestCase): - def _arch(self, **overrides): - base = ModelArchConfig( - hidden_size = 1024, - num_hidden_layers = 4, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = False, - head_dim = 64, - layer_types = ["full_attention"] * 4, - vocab_size_per_layer_input = 256, - hidden_size_per_layer_input = 96, - ) - return replace(base, **overrides) - - def test_per_layer_input_increases_total_params(self): - with_ple = self._arch() - without_ple = replace(with_ple, hidden_size_per_layer_input = 0) - self.assertGreater( - compute_total_params(with_ple), - compute_total_params(without_ple), - ) - - def test_per_layer_input_modules_count_quantizable_block(self): - with_ple = self._arch() - without_ple = replace(with_ple, hidden_size_per_layer_input = 0) - # The PLE block adds: model_projection (hd*nl*pli), per_layer_input_gate - # (hd*pli per layer) + per_layer_projection (pli*hd per layer) as - # quantizable text linears. - n_layers = with_ple.num_hidden_layers - hd = with_ple.hidden_size - pli = with_ple.hidden_size_per_layer_input - expected_quantizable_extra = ( - hd * (n_layers * pli) + (hd * pli) * n_layers + (pli * hd) * n_layers - ) - delta = compute_total_params(with_ple) - compute_total_params(without_ple) - self.assertGreaterEqual(delta, expected_quantizable_extra) - - def test_all_linear_lora_excludes_per_layer_input_modules(self): - # why: Unsloth's get_peft_regex requires module names to contain a - # component tag (mlp/attn/...); PLE module names (per_layer_input_gate, - # per_layer_projection, per_layer_model_projection) lack any tag, so - # all-linear training does NOT attach LoRA to them. - arch = self._arch() - without_ple = replace(arch, hidden_size_per_layer_input = 0) - self.assertEqual( - compute_lora_params(arch, 16, ["all-linear"]), - compute_lora_params(without_ple, 16, ["all-linear"]), - ) - - def test_explicit_target_modules_does_not_add_per_layer_input(self): - arch = self._arch() - without_ple = replace(arch, hidden_size_per_layer_input = 0) - self.assertEqual( - compute_lora_params(arch, 16, ["q_proj", "v_proj"]), - compute_lora_params(without_ple, 16, ["q_proj", "v_proj"]), - ) - - -class TestDenseMlpLayerFallback(unittest.TestCase): - def test_falls_back_to_count_when_indices_empty(self): - from utils.hardware.vram_estimation import _is_dense_mlp_layer - - arch = ModelArchConfig( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 2, - intermediate_size = 1024, - vocab_size = 1024, - num_experts = 4, - moe_intermediate_size = 256, - num_dense_layers = 2, - ) - self.assertTrue(_is_dense_mlp_layer(arch, 0)) - self.assertTrue(_is_dense_mlp_layer(arch, 1)) - self.assertFalse(_is_dense_mlp_layer(arch, 2)) - self.assertFalse(_is_dense_mlp_layer(arch, 3)) - - -class TestExpertsSkipGranularity(unittest.TestCase): - def _arch(self): - return ModelArchConfig( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 2, - intermediate_size = 1024, - vocab_size = 1024, - tie_word_embeddings = True, - num_experts = 8, - moe_intermediate_size = 512, - num_dense_layers = 0, - head_dim = 64, - layer_types = ["full_attention"] * 4, - moe_has_dense_mlp = True, - ) - - def test_experts_skip_excludes_parallel_dense_projections(self): - no_skip = self._arch() - skip_experts = replace( - no_skip, - quantization_skip_modules = ["model.layers.0.mlp.experts"], - ) - skip_full_mlp = replace( - no_skip, - quantization_skip_modules = ["model.layers.0.mlp"], - ) - bytes_no_skip = compute_model_weights_bytes(no_skip, "qlora", True) - bytes_skip_experts = compute_model_weights_bytes(skip_experts, "qlora", True) - bytes_skip_mlp = compute_model_weights_bytes(skip_full_mlp, "qlora", True) - # why: under gemma4 enable_moe_block, `self.experts` is a sibling of - # `self.mlp`; skipping `model.layers.0.mlp` should cover only the - # dense MLP, while `model.layers.0.mlp.experts` covers the routed - # experts. Routed experts have far more params than the dense MLP, - # so skipping experts must add more bytes than skipping the dense - # path. - self.assertGreater(bytes_skip_experts, bytes_no_skip) - self.assertGreater(bytes_skip_mlp, bytes_no_skip) - self.assertGreater(bytes_skip_experts, bytes_skip_mlp) class TestSharedExperts(unittest.TestCase): @@ -1325,16 +608,6 @@ def test_mla_lora_produces_values(self): lora_p = compute_lora_params(DEEPSEEK_V3, 16, ["q_proj", "v_proj", "o_proj"]) self.assertGreater(lora_p, 0) - def test_mla_with_head_dim_does_not_route_through_structured(self): - from utils.hardware.vram_estimation import _uses_structured_layer_shapes - - mla_with_head_dim = replace(DEEPSEEK_V3, head_dim = 128) - self.assertFalse(_uses_structured_layer_shapes(mla_with_head_dim)) - self.assertEqual( - compute_lora_params(DEEPSEEK_V3, 16, ["q_proj", "v_proj", "o_proj"]), - compute_lora_params(mla_with_head_dim, 16, ["q_proj", "v_proj", "o_proj"]), - ) - class TestDenseMoEMix(unittest.TestCase): def test_dense_layers_change_total(self): @@ -1418,952 +691,5 @@ def test_lora_dense_vs_moe_layers_differ(self): self.assertNotEqual(lora_all, lora_mix) -class TestMlpLayerTypesDispatch(unittest.TestCase): - def _hf(self, **fields): - text_config = SimpleNamespace( - hidden_size = 64, - num_hidden_layers = 4, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 128, - vocab_size = 1000, - tie_word_embeddings = True, - num_local_experts = 4, - moe_intermediate_size = 32, - **fields, - ) - return SimpleNamespace(text_config = text_config, quantization_config = {}) - - def test_mlp_layer_types_drives_dense_indices(self): - hf = self._hf(mlp_layer_types = ["sparse", "dense", "sparse", "dense"]) - arch = extract_arch_config(hf) - self.assertIsNotNone(arch) - self.assertEqual(arch.dense_layer_indices, (1, 3)) - self.assertEqual(arch.num_dense_layers, 2) - - def test_mlp_layer_types_takes_priority_over_first_k_dense_replace(self): - hf = self._hf( - mlp_layer_types = ["dense", "sparse", "dense", "sparse"], - first_k_dense_replace = 3, - ) - arch = extract_arch_config(hf) - self.assertEqual(arch.dense_layer_indices, (0, 2)) - - def test_mlp_layer_types_ignores_unknown_entries(self): - hf = self._hf(mlp_layer_types = ["dense", "moe", "dense", "linear"]) - arch = extract_arch_config(hf) - self.assertEqual(arch.dense_layer_indices, (0, 2)) - - def test_mlp_layer_types_shorter_than_layers_only_marks_present(self): - hf = self._hf(mlp_layer_types = ["dense", "sparse"]) - arch = extract_arch_config(hf) - self.assertEqual(arch.dense_layer_indices, (0,)) - - def test_empty_mlp_layer_types_falls_through_to_first_k(self): - hf = self._hf(mlp_layer_types = [], first_k_dense_replace = 2) - arch = extract_arch_config(hf) - self.assertEqual(arch.dense_layer_indices, (0, 1)) - - -class TestPerLayerInputSkipAlias(unittest.TestCase): - def _hf(self, skip): - text_config = SimpleNamespace( - hidden_size = 64, - num_hidden_layers = 2, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 128, - vocab_size = 1000, - tie_word_embeddings = True, - hidden_size_per_layer_input = 8, - vocab_size_per_layer_input = 256, - ) - return SimpleNamespace( - text_config = text_config, - quantization_config = {"llm_int8_skip_modules": list(skip)}, - ) - - def test_per_layer_input_gate_skip_pulls_nonzero_delta(self): - from utils.hardware.vram_estimation import _compute_skipped_quantizable_elements - - arch = extract_arch_config(self._hf(["model.layers.0.per_layer_input_gate"])) - delta = _compute_skipped_quantizable_elements(arch) - self.assertEqual(delta, arch.hidden_size * arch.hidden_size_per_layer_input) - - def test_per_layer_model_projection_skip_pulls_global_delta(self): - from utils.hardware.vram_estimation import _compute_skipped_quantizable_elements - - arch = extract_arch_config(self._hf(["model.per_layer_model_projection"])) - delta = _compute_skipped_quantizable_elements(arch) - self.assertEqual( - delta, - arch.hidden_size - * arch.num_hidden_layers - * arch.hidden_size_per_layer_input, - ) - - def test_layer_aggregate_skip_includes_per_layer_input_modules(self): - from utils.hardware.vram_estimation import ( - _compute_skipped_quantizable_elements, - ) - - arch_with = extract_arch_config(self._hf(["model.layers.0"])) - # The text.layers.0 aggregate must include the PLE per-layer modules, - # so the same skip on a config without PLE produces a smaller value. - arch_without = extract_arch_config( - SimpleNamespace( - text_config = SimpleNamespace( - hidden_size = 64, - num_hidden_layers = 2, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 128, - vocab_size = 1000, - tie_word_embeddings = True, - hidden_size_per_layer_input = 0, - vocab_size_per_layer_input = 0, - ), - quantization_config = {"llm_int8_skip_modules": ["model.layers.0"]}, - ) - ) - self.assertGreater( - _compute_skipped_quantizable_elements(arch_with), - _compute_skipped_quantizable_elements(arch_without), - ) - - -class TestAllLinearStringHandling(unittest.TestCase): - def test_compute_lora_params_accepts_bare_all_linear_string(self): - list_form = compute_lora_params(LLAMA_8B, 16, ["all-linear"]) - str_form = compute_lora_params(LLAMA_8B, 16, "all-linear") - self.assertEqual(list_form, str_form) - self.assertGreater(list_form, 0) - - def test_compute_lora_params_string_with_underscores_normalized(self): - list_form = compute_lora_params(LLAMA_8B, 16, ["all_linear"]) - str_form = compute_lora_params(LLAMA_8B, 16, "all_linear") - self.assertEqual(list_form, str_form) - self.assertGreater(str_form, 0) - - -class TestSharedExpertVariants(unittest.TestCase): - def _hf(self, **fields): - text_config = SimpleNamespace( - hidden_size = 256, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 1024, - vocab_size = 1000, - tie_word_embeddings = False, - num_local_experts = 8, - moe_intermediate_size = 128, - **fields, - ) - return SimpleNamespace(text_config = text_config, quantization_config = {}) - - def test_shared_expert_intermediate_size_extracted_and_infers_count(self): - arch = extract_arch_config(self._hf(shared_expert_intermediate_size = 64)) - self.assertEqual(arch.shared_expert_intermediate_size, 64) - self.assertEqual(arch.n_shared_experts, 1) - - def test_num_shared_experts_alias_extracted(self): - arch = extract_arch_config(self._hf(num_shared_experts = 2)) - self.assertEqual(arch.n_shared_experts, 2) - - def test_n_shared_experts_takes_priority_over_alias(self): - arch = extract_arch_config(self._hf(n_shared_experts = 3, num_shared_experts = 99)) - self.assertEqual(arch.n_shared_experts, 3) - - def test_shared_expert_size_separate_from_routed_changes_weight_count(self): - from utils.hardware.vram_estimation import _compute_moe_mlp_elements - - arch_separate = extract_arch_config( - self._hf(shared_expert_intermediate_size = 64) - ) - arch_implicit = extract_arch_config(self._hf(n_shared_experts = 1)) - # Different shared sizes (64 vs default moe_intermediate_size=128) must - # produce different MoE element counts. - self.assertNotEqual( - _compute_moe_mlp_elements(arch_separate), - _compute_moe_mlp_elements(arch_implicit), - ) - - def test_shared_expert_gate_counted_only_for_qwen_style(self): - from utils.hardware.vram_estimation import _compute_moe_mlp_elements - - # Qwen-style: shared_expert_intermediate_size set -> shared_expert_gate counted. - qwen_arch = extract_arch_config(self._hf(shared_expert_intermediate_size = 64)) - hd = qwen_arch.hidden_size - ms = qwen_arch.moe_intermediate_size - ne = qwen_arch.num_experts - ss = qwen_arch.shared_expert_intermediate_size - expected = hd * ms * 3 * ne + ne * hd + hd * ss * 3 * 1 + 1 * hd - self.assertEqual(_compute_moe_mlp_elements(qwen_arch), expected) - - # Non-Qwen shared experts (e.g. Exaone-MoE) -> no shared_expert_gate. - plain_arch = extract_arch_config(self._hf(n_shared_experts = 1)) - hd = plain_arch.hidden_size - ms = plain_arch.moe_intermediate_size - ne = plain_arch.num_experts - expected_plain = hd * ms * 3 * ne + ne * hd + hd * ms * 3 * 1 - self.assertEqual(_compute_moe_mlp_elements(plain_arch), expected_plain) - - -class TestSharedExpertActivation(unittest.TestCase): - def _make(self, **fields): - text_config = SimpleNamespace( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 1024, - vocab_size = 1000, - tie_word_embeddings = False, - num_local_experts = 4, - moe_intermediate_size = 64, - **fields, - ) - return extract_arch_config( - SimpleNamespace(text_config = text_config, quantization_config = {}) - ) - - def test_shared_expert_increases_activation_bytes(self): - with_shared = self._make(shared_expert_intermediate_size = 64) - without = self._make() - self.assertGreater( - compute_activation_bytes( - with_shared, - 2, - 1024, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - compute_activation_bytes( - without, - 2, - 1024, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - ) - - def test_shared_expert_plus_dense_block_compose(self): - # gemma4 enable_moe_block with hypothetical shared expert: dense + routed - # + shared all live per layer; mlp_size should sum all three terms. - from utils.hardware.vram_estimation import _layer_qkv_mlp_sizes - - arch = self._make( - enable_moe_block = True, - shared_expert_intermediate_size = 32, - head_dim = 64, - layer_types = ["full_attention"] * 4, - ) - _, mlp_size = _layer_qkv_mlp_sizes(arch, 0) - # routed (64) + shared (32) + parallel dense intermediate (1024) - self.assertEqual(mlp_size, 64 + 32 + 1024) - - -class TestPerLayerInputActivation(unittest.TestCase): - def _make(self, **fields): - text_config = SimpleNamespace( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 1024, - vocab_size = 1000, - tie_word_embeddings = False, - **fields, - ) - return extract_arch_config( - SimpleNamespace(text_config = text_config, quantization_config = {}) - ) - - def test_ple_increases_activation_bytes(self): - with_ple = self._make( - hidden_size_per_layer_input = 64, - vocab_size_per_layer_input = 256, - ) - without = self._make() - self.assertGreater( - compute_activation_bytes( - with_ple, - 2, - 1024, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - compute_activation_bytes( - without, - 2, - 1024, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - ) - - def test_ple_zero_does_not_inflate_activations(self): - without = self._make(hidden_size_per_layer_input = 0) - baseline = self._make() - self.assertEqual( - compute_activation_bytes( - without, - 2, - 512, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - compute_activation_bytes( - baseline, - 2, - 512, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - ) - - -class TestKvSharedActivation(unittest.TestCase): - def _make(self, kv_shared): - text_config = SimpleNamespace( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 1024, - vocab_size = 1000, - tie_word_embeddings = False, - head_dim = 64, - num_kv_shared_layers = kv_shared, - layer_types = ["full_attention"] * 4, - ) - return extract_arch_config( - SimpleNamespace(text_config = text_config, quantization_config = {}) - ) - - def test_kv_shared_layers_keep_activation_bytes(self): - shared = self._make(kv_shared = 2) - full = self._make(kv_shared = 0) - self.assertEqual( - compute_activation_bytes( - shared, - 2, - 1024, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - compute_activation_bytes( - full, - 2, - 1024, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ), - ) - - -class TestSparseMoeSkipAliases(unittest.TestCase): - def _hf(self, skip, **fields): - text_config = SimpleNamespace( - hidden_size = 128, - num_hidden_layers = 2, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 256, - vocab_size = 1000, - tie_word_embeddings = False, - num_local_experts = 4, - moe_intermediate_size = 64, - **fields, - ) - return SimpleNamespace( - text_config = text_config, - quantization_config = {"llm_int8_skip_modules": list(skip)}, - ) - - def test_gemma4_layers_experts_alias_pulls_routed(self): - from utils.hardware.vram_estimation import _compute_skipped_quantizable_elements - - arch = extract_arch_config( - self._hf(["model.layers.0.experts"], enable_moe_block = True) - ) - self.assertGreater(_compute_skipped_quantizable_elements(arch), 0) - - def test_qwen_shared_expert_skip_pulls_only_shared(self): - from utils.hardware.vram_estimation import _compute_skipped_quantizable_elements - - arch = extract_arch_config( - self._hf( - ["model.layers.0.mlp.shared_expert"], - shared_expert_intermediate_size = 32, - ) - ) - # shared_expert delta only -- routed mlp.experts is NOT skipped. - delta = _compute_skipped_quantizable_elements(arch) - self.assertGreater(delta, 0) - full_layer = extract_arch_config( - self._hf( - ["model.layers.0.mlp"], - shared_expert_intermediate_size = 32, - ) - ) - self.assertGreater( - _compute_skipped_quantizable_elements(full_layer), - delta, - ) - - def test_exaone_shared_experts_plural_alias(self): - from utils.hardware.vram_estimation import _compute_skipped_quantizable_elements - - arch = extract_arch_config( - self._hf( - ["model.layers.0.mlp.shared_experts"], - num_shared_experts = 1, - ) - ) - self.assertGreater(_compute_skipped_quantizable_elements(arch), 0) - - -class TestAllLinearMoELoraExclusion(unittest.TestCase): - def _arch(self, **fields): - text_config = SimpleNamespace( - hidden_size = 256, - num_hidden_layers = 2, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 512, - vocab_size = 1000, - tie_word_embeddings = False, - num_local_experts = 8, - moe_intermediate_size = 64, - **fields, - ) - return extract_arch_config( - SimpleNamespace(text_config = text_config, quantization_config = {}) - ) - - def test_all_linear_drops_routed_moe_expert_lora(self): - arch = self._arch() - all_linear = compute_lora_params(arch, 8, "all-linear") - explicit = compute_lora_params(arch, 8, ["gate_proj", "up_proj", "down_proj"]) - self.assertLess(all_linear, explicit) - - def test_all_linear_drops_shared_expert_lora(self): - arch = self._arch(shared_expert_intermediate_size = 32) - all_linear = compute_lora_params(arch, 8, "all-linear") - explicit = compute_lora_params(arch, 8, ["gate_proj", "up_proj", "down_proj"]) - # explicit includes routed + shared MoE; all-linear includes neither. - self.assertLess(all_linear, explicit) - - def test_all_linear_includes_attention_lora(self): - arch = self._arch() - all_linear = compute_lora_params(arch, 8, "all-linear") - attn_only = compute_lora_params( - arch, 8, ["q_proj", "k_proj", "v_proj", "o_proj"] - ) - # all-linear still attaches to attention nn.Linear modules. - self.assertGreaterEqual(all_linear, attn_only) - - -class TestExplicitPerLayerInputLora(unittest.TestCase): - def _arch(self): - text_config = SimpleNamespace( - hidden_size = 256, - num_hidden_layers = 3, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 512, - vocab_size = 1000, - tie_word_embeddings = False, - hidden_size_per_layer_input = 32, - vocab_size_per_layer_input = 128, - ) - return extract_arch_config( - SimpleNamespace(text_config = text_config, quantization_config = {}) - ) - - def test_explicit_per_layer_input_gate_returns_nonzero(self): - arch = self._arch() - result = compute_lora_params(arch, 16, ["per_layer_input_gate"]) - self.assertGreater(result, 0) - - def test_explicit_per_layer_projection_returns_nonzero(self): - arch = self._arch() - result = compute_lora_params(arch, 16, ["per_layer_projection"]) - self.assertGreater(result, 0) - - def test_explicit_per_layer_model_projection_returns_nonzero(self): - arch = self._arch() - result = compute_lora_params(arch, 16, ["per_layer_model_projection"]) - self.assertGreater(result, 0) - - def test_explicit_ple_string_target_handled(self): - # Bare-string target with a PLE name should not be iterated char-by-char. - arch = self._arch() - list_form = compute_lora_params(arch, 16, ["per_layer_input_gate"]) - str_form = compute_lora_params(arch, 16, "per_layer_input_gate") - self.assertEqual(list_form, str_form) - - -class TestTopKExpertActivation(unittest.TestCase): - def _make(self, **fields): - text_config = SimpleNamespace( - hidden_size = 512, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 1024, - vocab_size = 1000, - tie_word_embeddings = False, - num_local_experts = 8, - moe_intermediate_size = 64, - **fields, - ) - return extract_arch_config( - SimpleNamespace(text_config = text_config, quantization_config = {}) - ) - - def test_num_experts_per_tok_extracted(self): - arch = self._make(num_experts_per_tok = 4) - self.assertEqual(arch.num_experts_per_tok, 4) - - def test_top_k_experts_alias_extracted(self): - arch = self._make(top_k_experts = 8) - self.assertEqual(arch.num_experts_per_tok, 8) - - def test_default_top_k_one_unchanged(self): - arch = self._make() - self.assertEqual(arch.num_experts_per_tok, 1) - - def test_top_k_scales_moe_activation(self): - single = self._make() - multi = self._make(num_experts_per_tok = 8) - single_act = compute_activation_bytes( - single, - 2, - 512, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ) - multi_act = compute_activation_bytes( - multi, - 2, - 512, - "none", - is_lora = True, - attention_implementation = "flash_attention_2", - ) - self.assertGreater(multi_act, single_act) - - -class TestErnieMoEListConfig(unittest.TestCase): - def _hf(self, **fields): - text_config = SimpleNamespace( - hidden_size = 256, - num_hidden_layers = 4, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 1024, - vocab_size = 1000, - tie_word_embeddings = False, - **fields, - ) - return SimpleNamespace(text_config = text_config, quantization_config = {}) - - def test_list_moe_intermediate_size_scalarized(self): - arch = extract_arch_config( - self._hf( - moe_num_experts = 32, - moe_intermediate_size = [1536, 512], - ) - ) - # why: ERNIE 4.5 VL MoE encodes [text_routed, vision_routed]; the - # second element is the vision-routed expert width, not the shared - # expert width. Shared experts are sized from the text-routed width - # (= moe_intermediate_size[0]) when moe_num_shared_experts is set. - self.assertEqual(arch.moe_intermediate_size, 1536) - self.assertIsNone(arch.shared_expert_intermediate_size) - self.assertEqual(arch.n_shared_experts, 0) - - def test_moe_num_experts_alias_extracted(self): - arch = extract_arch_config( - self._hf( - moe_num_experts = 64, - moe_intermediate_size = 1024, - ) - ) - self.assertEqual(arch.num_experts, 64) - - def test_moe_num_shared_experts_alias_extracted(self): - arch = extract_arch_config( - self._hf( - moe_num_experts = 16, - moe_num_shared_experts = 2, - moe_intermediate_size = 1024, - ) - ) - self.assertEqual(arch.n_shared_experts, 2) - - def test_explicit_shared_size_overrides_list_second_element(self): - arch = extract_arch_config( - self._hf( - moe_num_experts = 8, - moe_intermediate_size = [1536, 512], - shared_expert_intermediate_size = 256, - ) - ) - # Explicit shared size wins over moe_intermediate_size[1]. - self.assertEqual(arch.shared_expert_intermediate_size, 256) - - -class TestSuffixSkipModuleMatch(unittest.TestCase): - def _hf(self, skip): - text_config = SimpleNamespace( - hidden_size = 128, - num_hidden_layers = 2, - num_attention_heads = 4, - num_key_value_heads = 4, - intermediate_size = 256, - vocab_size = 1000, - tie_word_embeddings = False, - ) - return SimpleNamespace( - text_config = text_config, - quantization_config = {"llm_int8_skip_modules": list(skip)}, - ) - - def test_q_proj_suffix_skip_matches_all_layers(self): - from utils.hardware.vram_estimation import _compute_skipped_quantizable_elements - - arch = extract_arch_config(self._hf(["q_proj"])) - delta = _compute_skipped_quantizable_elements(arch) - # 2 layers * hd * hd of q_proj weight elements. - self.assertEqual(delta, 2 * arch.hidden_size * arch.hidden_size) - - def test_self_attn_aggregate_skip_matches_aggregate(self): - from utils.hardware.vram_estimation import _compute_skipped_quantizable_elements - - arch = extract_arch_config(self._hf(["self_attn"])) - # The aggregate text.layers..self_attn matches; total covers both layers. - delta = _compute_skipped_quantizable_elements(arch) - self.assertGreater(delta, 0) - - def test_vision_prefix_skip_does_not_match_text_alias(self): - from utils.hardware.vram_estimation import _module_path_matches - - # vision_tower-prefixed full path must NOT match text-tower aliases. - self.assertFalse( - _module_path_matches( - "vision_tower.model.layers.0.self_attn.q_proj", - "model.layers.0.self_attn.q_proj", - ) - ) - - -class TestMultimodalFullModelBytes(unittest.TestCase): - def test_extra_bytes_added_when_safetensors_exceeds_text_arch(self): - from utils.hardware import hardware as hardware_module - - config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = False, - ) - # Force safetensors size >>> arch text-only bytes. - big_safetensors = 20 * 1024**3 - with ( - patch.object( - hardware_module, - "_load_config_for_gpu_estimate", - return_value = config, - ), - patch.object( - hardware_module, - "estimate_fp16_model_size_bytes", - return_value = (big_safetensors, "safetensors"), - ), - patch.object( - hardware_module, - "_determine_attention_impl_for_gpu_estimate", - return_value = "flash_attention_2", - ), - patch.object( - hardware_module, - "get_visible_gpu_count", - return_value = 1, - ), - ): - _, metadata = hardware_module.estimate_required_model_memory_gb( - "fake/model", - training_type = "LoRA/QLoRA", - load_in_4bit = True, - ) - self.assertEqual(metadata.get("estimation_mode"), "detailed") - # model_weights_gb must reflect the extra non-text bytes (>5 GB - # since text-only arch_fp16 is small for these dims). - self.assertGreater(metadata["vram_breakdown"]["model_weights_gb"], 5.0) - - def test_no_extra_when_safetensors_smaller_than_text_arch(self): - from utils.hardware import hardware as hardware_module - - config = SimpleNamespace( - hidden_size = 4096, - num_hidden_layers = 32, - num_attention_heads = 32, - num_key_value_heads = 8, - intermediate_size = 11008, - vocab_size = 32000, - tie_word_embeddings = False, - ) - tiny_safetensors = 100 # bytes, deliberately absurdly small - with ( - patch.object( - hardware_module, - "_load_config_for_gpu_estimate", - return_value = config, - ), - patch.object( - hardware_module, - "estimate_fp16_model_size_bytes", - return_value = (tiny_safetensors, "safetensors"), - ), - patch.object( - hardware_module, - "_determine_attention_impl_for_gpu_estimate", - return_value = "flash_attention_2", - ), - patch.object( - hardware_module, - "get_visible_gpu_count", - return_value = 1, - ), - ): - required, metadata = hardware_module.estimate_required_model_memory_gb( - "fake/model", - training_type = "LoRA/QLoRA", - load_in_4bit = True, - ) - # No negative extra; required_gb stays a positive finite number. - self.assertGreater(required, 0) - - -class TestLlama4ArchExtraction(unittest.TestCase): - def _llama4_text_config(self, **fields): - base = dict( - hidden_size = 2048, - num_hidden_layers = 4, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 8192, - intermediate_size_mlp = 16384, - vocab_size = 32000, - tie_word_embeddings = True, - num_local_experts = 4, - num_experts_per_tok = 2, - ) - base.update(fields) - return SimpleNamespace(**base) - - def test_llama4_moe_layers_dispatch_uses_explicit_indices(self): - from utils.hardware.vram_estimation import _compute_dense_layer_indices - - cfg = SimpleNamespace(num_hidden_layers = 4, moe_layers = [1, 3]) - self.assertEqual(_compute_dense_layer_indices(cfg, 4), (0, 2)) - - def test_llama4_moe_layers_takes_priority_over_first_k_dense_replace(self): - from utils.hardware.vram_estimation import _compute_dense_layer_indices - - cfg = SimpleNamespace( - num_hidden_layers = 6, - moe_layers = [2, 4], - first_k_dense_replace = 4, - ) - self.assertEqual(_compute_dense_layer_indices(cfg, 6), (0, 1, 3, 5)) - - def test_dense_intermediate_size_picks_up_intermediate_size_mlp(self): - from utils.hardware.vram_estimation import _dense_mlp_size - - arch = extract_arch_config(self._llama4_text_config(moe_layers = [1, 3])) - self.assertIsNotNone(arch) - self.assertEqual(arch.intermediate_size, 8192) - self.assertEqual(arch.dense_intermediate_size, 16384) - self.assertEqual(_dense_mlp_size(arch), 16384) - - def test_auto_attaches_one_shared_expert_at_routed_width(self): - from utils.hardware.vram_estimation import _shared_expert_size - - arch = extract_arch_config(self._llama4_text_config(moe_layers = [1, 3])) - self.assertIsNotNone(arch) - self.assertEqual(arch.n_shared_experts, 1) - self.assertIsNone(arch.shared_expert_intermediate_size) - self.assertEqual(_shared_expert_size(arch), arch.intermediate_size) - - def test_non_llama4_config_leaves_dense_intermediate_size_none(self): - from utils.hardware.vram_estimation import _dense_mlp_size - - cfg = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 2, - intermediate_size = 4096, - vocab_size = 32000, - tie_word_embeddings = True, - ) - arch = extract_arch_config(cfg) - self.assertIsNotNone(arch) - self.assertIsNone(arch.dense_intermediate_size) - self.assertEqual(_dense_mlp_size(arch), 4096) - - def test_intermediate_size_mlp_without_moe_does_not_force_shared_expert(self): - cfg = SimpleNamespace( - hidden_size = 2048, - num_hidden_layers = 4, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 8192, - intermediate_size_mlp = 16384, - vocab_size = 32000, - tie_word_embeddings = True, - ) - arch = extract_arch_config(cfg) - self.assertIsNotNone(arch) - self.assertEqual(arch.dense_intermediate_size, 16384) - self.assertEqual(arch.n_shared_experts, 0) - - -class TestDbrxFfnConfigExtraction(unittest.TestCase): - def test_extracts_moe_fields_from_ffn_subconfig(self): - ffn = SimpleNamespace(moe_num_experts = 4, moe_top_k = 2, ffn_hidden_size = 1024) - cfg = SimpleNamespace( - hidden_size = 2048, - num_hidden_layers = 4, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = False, - ffn_config = ffn, - ) - arch = extract_arch_config(cfg) - self.assertIsNotNone(arch) - self.assertEqual(arch.num_experts, 4) - self.assertEqual(arch.num_experts_per_tok, 2) - self.assertEqual(arch.moe_intermediate_size, 1024) - - def test_top_level_attrs_take_precedence_over_ffn_config(self): - ffn = SimpleNamespace(moe_num_experts = 4, moe_top_k = 2, ffn_hidden_size = 1024) - cfg = SimpleNamespace( - hidden_size = 2048, - num_hidden_layers = 4, - num_attention_heads = 16, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = False, - ffn_config = ffn, - num_local_experts = 16, - num_experts_per_tok = 8, - ) - arch = extract_arch_config(cfg) - self.assertIsNotNone(arch) - self.assertEqual(arch.num_experts, 16) - self.assertEqual(arch.num_experts_per_tok, 8) - - -class TestErniePhaseModuloDispatch(unittest.TestCase): - def test_phase_modulo_with_interval_two_matches_decoder(self): - from utils.hardware.vram_estimation import _compute_dense_layer_indices - - cfg = SimpleNamespace( - num_hidden_layers = 10, - moe_layer_start_index = 2, - moe_layer_end_index = 8, - moe_layer_interval = 2, - ) - # Decoder gates by ((i + 1) % 2 == 0) AND 2 <= i <= 8 -> MoE = {3, 5, 7}. - self.assertEqual(_compute_dense_layer_indices(cfg, 10), (0, 1, 2, 4, 6, 8, 9)) - - def test_phase_modulo_with_interval_three(self): - from utils.hardware.vram_estimation import _compute_dense_layer_indices - - cfg = SimpleNamespace( - num_hidden_layers = 9, - moe_layer_start_index = 0, - moe_layer_end_index = -1, - moe_layer_interval = 3, - ) - self.assertEqual(_compute_dense_layer_indices(cfg, 9), (0, 1, 3, 4, 6, 7)) - - -class TestErnieVlSharedExpertWidth(unittest.TestCase): - def test_shared_expert_width_uses_text_routed_not_vision(self): - from utils.hardware.vram_estimation import ( - _compute_shared_moe_elements, - _shared_expert_size, - ) - - cfg = SimpleNamespace( - text_config = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = False, - moe_num_experts = 8, - moe_num_shared_experts = 2, - moe_intermediate_size = [1536, 512], - ), - quantization_config = {}, - ) - arch = extract_arch_config(cfg) - self.assertIsNotNone(arch) - self.assertIsNone(arch.shared_expert_intermediate_size) - self.assertEqual(arch.moe_intermediate_size, 1536) - self.assertEqual(arch.n_shared_experts, 2) - self.assertEqual(_shared_expert_size(arch), 1536) - self.assertEqual(_compute_shared_moe_elements(arch), 1024 * 1536 * 3 * 2) - - def test_qwen_style_explicit_shared_expert_size_still_adds_gate(self): - from utils.hardware.vram_estimation import _compute_shared_moe_elements - - cfg = SimpleNamespace( - hidden_size = 1024, - num_hidden_layers = 4, - num_attention_heads = 8, - num_key_value_heads = 4, - intermediate_size = 2048, - vocab_size = 32000, - tie_word_embeddings = False, - num_local_experts = 8, - moe_intermediate_size = 256, - shared_expert_intermediate_size = 768, - ) - arch = extract_arch_config(cfg) - self.assertIsNotNone(arch) - self.assertEqual(arch.shared_expert_intermediate_size, 768) - self.assertEqual(arch.n_shared_experts, 1) - self.assertEqual( - _compute_shared_moe_elements(arch), - 1024 * 768 * 3 + 1 * 1024, - ) - - if __name__ == "__main__": unittest.main() diff --git a/studio/backend/utils/datasets/llm_assist.py b/studio/backend/utils/datasets/llm_assist.py index 4c66d2ebf6..fdc4f374ab 100644 --- a/studio/backend/utils/datasets/llm_assist.py +++ b/studio/backend/utils/datasets/llm_assist.py @@ -26,7 +26,7 @@ logger = get_logger(__name__) -DEFAULT_HELPER_MODEL_REPO = "unsloth/gemma-4-E2B-it-GGUF" +DEFAULT_HELPER_MODEL_REPO = "unsloth/Qwen3.5-4B-GGUF" DEFAULT_HELPER_MODEL_VARIANT = "UD-Q4_K_XL" README_MAX_CHARS = 1500 diff --git a/studio/backend/utils/datasets/model_mappings.py b/studio/backend/utils/datasets/model_mappings.py index 21e8566ac5..95b4791574 100644 --- a/studio/backend/utils/datasets/model_mappings.py +++ b/studio/backend/utils/datasets/model_mappings.py @@ -215,21 +215,6 @@ "google/gemma-3n-E2B-it", "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit", ), - "gemma-4": ( - "unsloth/gemma-4-E2B-it", - "google/gemma-4-E2B-it", - "unsloth/gemma-4-E4B-it", - "google/gemma-4-E4B-it", - "unsloth/gemma-4-E2B-it-unsloth-bnb-4bit", - "unsloth/gemma-4-E4B-it-unsloth-bnb-4bit", - ), - "gemma-4-thinking": ( - "unsloth/gemma-4-26B-A4B-it", - "google/gemma-4-26B-A4B-it", - "unsloth/gemma-4-31B-it", - "unsloth/gemma-4-31B-it-unsloth-bnb-4bit", - "google/gemma-4-31B-it", - ), "qwen2.5": ( "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-0.5B-Instruct", @@ -364,16 +349,11 @@ "unsloth/Qwen3-4B-Thinking-2507-bnb-4bit", "unsloth/Qwen3-30B-A3B-Thinking-2507", "Qwen/Qwen3-30B-A3B-Thinking-2507", - "Qwen/Qwen3.6-35B-A3B", - "unsloth/Qwen3.6-35B-A3B", - "Qwen/Qwen3.6-27B", - "unsloth/Qwen3.6-27B", ), "qwen3.5": ( "unsloth/Qwen3.5-0.8B", "unsloth/Qwen3.5-2B", "unsloth/Qwen3.5-4B", - "unsloth/Qwen3.5-9B", "unsloth/Qwen3.5-27B", "unsloth/Qwen3.5-35B-A3B", ), @@ -419,15 +399,6 @@ "THUDM/GLM-4.7-Flash", "unsloth/GLM-4.7-Flash-bnb-4bit", ), - "lfm-2": ( - "unsloth/LFM2-1.2B", - "LiquidAI/LFM2-1.2B", - "unsloth/LFM2-1.2B-unsloth-bnb-4bit", - ), - "lfm-2.5": ( - "unsloth/LFM2.5-1.2B-Instruct", - "LiquidAI/LFM2.5-1.2B-Instruct", - ), } MODEL_TO_TEMPLATE_MAPPER = {} @@ -443,14 +414,6 @@ TEMPLATE_TO_RESPONSES_MAPPER = { - "gemma-4-thinking": { - "instruction": "<|turn>user\n", - "response": "<|turn>model\n", - }, - "gemma-4": { - "instruction": "<|turn>user\n", - "response": "<|turn>model\n", - }, "gemma-3": { "instruction": "user\n", "response": "model\n", @@ -551,10 +514,6 @@ "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, - "lfm-2.5": { - "instruction": "<|im_start|>user\n", - "response": "<|im_start|>assistant\n", - }, "starling": { "instruction": "GPT4 Correct User: ", "response": "GPT4 Correct Assistant: ", diff --git a/studio/backend/utils/hardware/VRAM_ESTIMATION.md b/studio/backend/utils/hardware/VRAM_ESTIMATION.md index a6b4de29d2..26072b208f 100644 --- a/studio/backend/utils/hardware/VRAM_ESTIMATION.md +++ b/studio/backend/utils/hardware/VRAM_ESTIMATION.md @@ -33,13 +33,7 @@ Non-quantizable = 2*H*L + V*H + (V*H if not tie_embeddings else 0) | QLoRA 4-bit | `Quantizable * 2 / 3.2 + Non-quantizable * 2` | | LoRA / Full fp16 | `(Quantizable + Non-quantizable) * 2` | -The 3.2 factor (`16/5`) accounts for BNB NF4 blockwise scales. Repos whose -quantization config enables `bnb_4bit_use_double_quant` use a tighter, still -conservative 3.6 factor for the quantized portion of the weights. -When a 4-bit config has `llm_int8_skip_modules` entries that point to language -model layers or submodules, those quantizable weights are charged at fp16 -instead of NF4. Generic embedding and multimodal skip names are already covered -by non-quantizable terms or excluded from text training weights. +The 3.2 factor (`16/5`) accounts for BNB NF4 blockwise scales. ## 2. LoRA Adapters @@ -59,18 +53,6 @@ MLP modules multiply by `E` for MoE. LoRA_bytes = sum(A + B per selected module) * L * 2 ``` -`all-linear` is treated as all known text linear modules in the table above. -The estimator deliberately does not infer multimodal or vision-tower LoRA -modules from config shapes; those modules vary too much across VLM families for -a generic config formula. - -Some decoder configs expose layer-shape fields such as `layer_types`, -`head_dim`, `global_head_dim`, `num_global_key_value_heads`, `attention_k_eq_v`, -`num_kv_shared_layers`, `use_double_wide_mlp`, `vocab_size_per_layer_input`, and -`hidden_size_per_layer_input`. When those fields are present, the estimator -derives text weight and LoRA counts from the per-layer shapes instead of -assuming every layer has the same seven projection modules. - ## 3. Optimizer States (calibrated) | Optimizer | Bytes/param | Notes | @@ -95,21 +77,6 @@ Per-layer (from `unsloth_zoo/vllm_utils.py`): Per_layer = (S*B*(H+K+K) + S*B*2 + S*B*(M+M)) * 2 * 1.25 ``` -When the resolved attention implementation is none of `flash_attention_2`, -`sdpa`, or `flex_attention` (PyTorch SDPA dispatches to flash or -memory-efficient kernels and FlexAttention is also a memory-efficient -kernel, all of which are O(n) in memory), activation memory also includes -a quadratic attention-score/workspace estimate: - -``` -Non_flash_attention = B * num_attention_heads * S^2 * 2 * 12.0 * effective_layers -Activations = max(Per_layer_with_gc, Non_flash_attention) -``` - -Studio resolves the attention implementation with Unsloth's -`resolve_attention_implementation` helper and uses that result directly. The -estimator does not duplicate model-family attention policy. - | GC Mode | Full FT | LoRA/QLoRA | |---------|---------|------------| | none | `L` layers | `L` layers | @@ -118,33 +85,13 @@ estimator does not duplicate model-family attention policy. ## 6. Floors -Activations use the computed formula directly: +Gradients and activations have minimum floors at **15% of model weight memory** to account for autograd overhead, attention score matrices, NCCL buffers, mixed-precision scaling, and PyTorch fragmentation. ``` -activation_bytes = computed_activation_bytes +gradient_bytes = max(computed, weights * 0.15) +activation_bytes = max(computed, weights * 0.15 * B/2) ``` -Full fine-tuning keeps the gradient floor at **15% of model weight memory** to -account for autograd overhead, NCCL buffers, mixed-precision scaling, and -PyTorch fragmentation: - -``` -gradient_bytes = max(computed_gradient_bytes, weights * 0.15) -``` - -For LoRA/QLoRA, the base model is frozen, so the weight-derived gradient floor -is capped by trainable-state and live-activation scale: - -``` -raw_gradient_bytes = trainable_params * 2 -gradient_floor = min(weights * 0.15, max(computed_activation_bytes, optimizer_bytes)) -gradient_bytes = max(raw_gradient_bytes, gradient_floor) -``` - -This prevents frozen quantized model size from dominating gradient/state -overhead when the measured runtime footprint is governed by LoRA optimizer -states and live activations. - ## 7. CUDA Overhead **1.4 GB** fixed — CUDA driver + PyTorch runtime, calibrated on RTX 5070 Ti. @@ -159,6 +106,34 @@ usable_gb = free[gpu_0] + sum(free[gpu_i] * 0.85 for i in 1..N) --- +## Reference Table (bsz=2, seq=2048, rank=16, GC=unsloth, adamw_8bit) + +| Model | Weights | LoRA | Optim | Grad | Act | CUDA | Total | +|-------|---------|------|-------|------|-----|------|-------| +| 0.5B QLoRA | 0.5 | 0.0 | 0.0 | 0.1 | 0.1 | 1.4 | **2.1** | +| 1B QLoRA | 1.1 | 0.0 | 0.0 | 0.2 | 0.2 | 1.4 | **2.9** | +| 3B QLoRA | 2.4 | 0.0 | 0.1 | 0.5 | 0.5 | 1.4 | **4.9** | +| 8B QLoRA | 6.0 | 0.1 | 0.2 | 1.2 | 1.2 | 1.4 | **10.1** | +| 8B LoRA fp16 | 15.0 | 0.1 | 0.2 | 3.0 | 3.0 | 1.4 | **22.6** | +| 8B Full FT | 15.0 | — | 29.9 | 15.0 | 3.0 | 1.4 | **64.2** | +| 32B LoRA fp16 | 61.0 | 0.2 | 0.5 | 12.2 | 12.2 | 1.4 | **87.6** | +| 72B QLoRA | 45.5 | 0.4 | 0.8 | 9.1 | 9.1 | 1.4 | **66.3** | + +## E2E Validation (Llama-3.2-1B, B200 emulating 24GB) + +| Config | Estimated | Actual (nvsmi) | Error | +|--------|----------|----------------|-------| +| QLoRA bsz=2 seq=512 | 2.55 GB | 2.65 GB | -3.7% | +| QLoRA bsz=2 seq=2048 | 2.60 GB | 2.65 GB | -1.8% | +| QLoRA bsz=4 seq=2048 | 2.65 GB | 2.65 GB | +0.0% | +| LoRA fp16 bsz=2 | 3.84 GB | 3.88 GB | -1.0% | +| Full FT adamw_8bit | 10.89 GB | 10.80 GB | +0.8% | +| Full FT adamw_torch | 13.19 GB | 12.93 GB | +2.0% | + +*Note: e2e numbers predate the 15% floors, which add safety margin on top.* + +--- + ## Parameter Flow ``` diff --git a/studio/backend/utils/hardware/__init__.py b/studio/backend/utils/hardware/__init__.py index 400b5dd066..aaa0452406 100644 --- a/studio/backend/utils/hardware/__init__.py +++ b/studio/backend/utils/hardware/__init__.py @@ -5,7 +5,6 @@ Hardware detection and GPU utilities """ -from . import hardware as _hardware from .hardware import ( DeviceType, DEVICE, @@ -50,7 +49,6 @@ "DeviceType", "DEVICE", "CHAT_ONLY", - "IS_ROCM", "detect_hardware", "get_device", "is_apple_silicon", @@ -83,11 +81,3 @@ "extract_arch_config", "estimate_training_vram", ] - - -def __getattr__(name: str): - """Resolve IS_ROCM at access time so callers always see the live value - after detect_hardware() runs (it flips the flag in hardware.py).""" - if name == "IS_ROCM": - return getattr(_hardware, "IS_ROCM") - raise AttributeError(name) diff --git a/studio/backend/utils/hardware/amd.py b/studio/backend/utils/hardware/amd.py deleted file mode 100644 index fdb1ab4520..0000000000 --- a/studio/backend/utils/hardware/amd.py +++ /dev/null @@ -1,384 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""AMD GPU monitoring via amd-smi. - -Mirrors the nvidia.py module structure so hardware.py can swap backends -based on IS_ROCM. All functions return the same dict shapes as their -nvidia.py counterparts. -""" - -import json -import math -import os -import re -import subprocess -from typing import Any, Optional - -from loggers import get_logger -from utils.native_path_leases import child_env_without_native_path_secret - -logger = get_logger(__name__) - - -def _run_amd_smi(*args: str, timeout: int = 5) -> Optional[Any]: - """Run amd-smi with the given arguments and return parsed JSON, or None.""" - try: - result = subprocess.run( - ["amd-smi", *args, "--json"], - capture_output = True, - text = True, - timeout = timeout, - env = child_env_without_native_path_secret(), - ) - except (OSError, subprocess.TimeoutExpired) as e: - logger.warning("amd-smi query failed: %s", e) - return None - if result.returncode != 0 or not result.stdout.strip(): - logger.warning("amd-smi returned code %d", result.returncode) - return None - try: - return json.loads(result.stdout) - except json.JSONDecodeError: - logger.warning("Failed to parse amd-smi JSON output") - return None - - -def _parse_numeric(value: Any) -> Optional[float]: - """Extract a numeric value from amd-smi output (may be str, int, float, or dict).""" - if value is None: - return None - # Newer amd-smi versions emit {"value": 10, "unit": "W"} - if isinstance(value, dict): - return _parse_numeric(value.get("value")) - if isinstance(value, (int, float)): - f = float(value) - return f if math.isfinite(f) else None - if isinstance(value, str): - # Strip units like "W", "C", "%", "MB", "MiB", "GB", "GiB" etc. - cleaned = re.sub(r"\s*[A-Za-z/%]+$", "", value.strip()) - if not cleaned or cleaned.lower() in ("n/a", "none", "unknown"): - return None - try: - return float(cleaned) - except (ValueError, TypeError): - return None - return None - - -def _parse_memory_mb(value: Any) -> Optional[float]: - """Parse a memory value from amd-smi output and return MB. - - Handles bare numbers (assumed MB -- the amd-smi convention on every - version we have seen), dict-shaped values with explicit units - (``{"value": 192, "unit": "GiB"}`` on newer releases), and plain - strings like ``"8192 MiB"``. - """ - unit = "" - raw_value = value - - if isinstance(value, dict): - unit = str(value.get("unit", "")).strip().lower() - raw_value = value.get("value") - elif isinstance(value, str): - # Extract unit suffix from strings like "192 GiB" or "8192 MB" - m = re.match(r"^\s*([\d.]+)\s*([A-Za-z]+)\s*$", value.strip()) - if m: - unit = m.group(2).lower() - - num = _parse_numeric(raw_value if isinstance(value, dict) else value) - if num is None: - return None - - # Unit conversion -- GPU tools (including amd-smi) use binary units even - # when labeling them "GB" or "MB", so treat GB/GiB and MB/MiB the same. - if "gib" in unit or "gb" in unit: - return num * 1024 - if "mib" in unit or "mb" in unit: - return num - if "kib" in unit or "kb" in unit: - return num / 1024 - if unit in ("b", "byte", "bytes"): - # Plain bytes - return num / (1024 * 1024) - - # No explicit unit -- default to MB, which is the amd-smi convention - # for bare numeric values. A previous heuristic assumed values above - # ~10M were bytes, but that misclassifies small VRAM allocations - # (e.g. 5 MB = 5,242,880 reported without a unit) as ~5 TB. Modern - # amd-smi always ships explicit units, so the heuristic branch only - # fired for legacy output where MB was already the convention. - return num - - -def _extract_gpu_metrics(gpu_data: dict) -> dict[str, Any]: - """Extract standardized metrics from a single GPU's amd-smi data.""" - # amd-smi metric output structure varies by version; try common paths - usage = gpu_data.get("usage", gpu_data.get("gpu_activity", {})) - if isinstance(usage, dict): - gpu_util = _parse_numeric( - usage.get("gfx_activity", usage.get("gpu_use_percent")) - ) - else: - gpu_util = _parse_numeric(usage) - - # Temperature -- try multiple keys in priority order. - # dict.get() returns "N/A" strings rather than falling through, - # so we must try each key and check if it parses to a real number. - temp_data = gpu_data.get("temperature", {}) - temp = None - if isinstance(temp_data, dict): - for temp_key in ("edge", "temperature_edge", "hotspot", "temperature_hotspot"): - temp = _parse_numeric(temp_data.get(temp_key)) - if temp is not None: - break - else: - temp = _parse_numeric(temp_data) - - # Power - power_data = gpu_data.get("power", {}) - if isinstance(power_data, dict): - power_draw = _parse_numeric( - power_data.get( - "current_socket_power", - power_data.get("average_socket_power", power_data.get("socket_power")), - ) - ) - power_limit = _parse_numeric( - power_data.get("power_cap", power_data.get("max_power_limit")) - ) - else: - power_draw = None - power_limit = None - - # VRAM -- unit-aware parsing to handle varying amd-smi output formats. - # Newer amd-smi versions may return {"value": 192, "unit": "GiB"}. - # Newer amd-smi uses "mem_usage" with "total_vram" / "used_vram" keys; - # older versions use "vram" or "fb_memory_usage" with "used" / "total". - vram_data = gpu_data.get( - "mem_usage", - gpu_data.get("vram", gpu_data.get("fb_memory_usage", {})), - ) - if isinstance(vram_data, dict): - vram_used_mb = _parse_memory_mb( - vram_data.get( - "used_vram", vram_data.get("vram_used", vram_data.get("used")) - ) - ) - vram_total_mb = _parse_memory_mb( - vram_data.get( - "total_vram", vram_data.get("vram_total", vram_data.get("total")) - ) - ) - else: - vram_used_mb = None - vram_total_mb = None - - # Build the standardized dict (same shape as nvidia._build_gpu_metrics) - vram_used_gb = round(vram_used_mb / 1024, 2) if vram_used_mb is not None else None - vram_total_gb = ( - round(vram_total_mb / 1024, 2) if vram_total_mb is not None else None - ) - vram_util = ( - round((vram_used_mb / vram_total_mb) * 100, 1) - if vram_used_mb is not None and vram_total_mb is not None and vram_total_mb > 0 - else None - ) - power_util = ( - round((power_draw / power_limit) * 100, 1) - if power_draw is not None and power_limit is not None and power_limit > 0 - else None - ) - - return { - "gpu_utilization_pct": gpu_util, - "temperature_c": temp, - "vram_used_gb": vram_used_gb, - "vram_total_gb": vram_total_gb, - "vram_utilization_pct": vram_util, - "power_draw_w": power_draw, - "power_limit_w": power_limit, - "power_utilization_pct": power_util, - } - - -def _has_real_metrics(metrics: dict[str, Any]) -> bool: - """Return True when ``metrics`` contains at least one non-None value. - - ``amd-smi`` can return a zero-exit JSON envelope that is missing every - expected field (error response, unsupported card, hipless container). - In that case ``_extract_gpu_metrics`` produces a dict where every value - is ``None`` -- callers must surface this as ``available: False`` rather - than ``available: True`` with empty data. - """ - return any(value is not None for value in metrics.values()) - - -def get_physical_gpu_count() -> Optional[int]: - """Return physical AMD GPU count via amd-smi, or None on failure.""" - data = _run_amd_smi("list") - if data is None: - return None - if isinstance(data, list): - return len(data) - # Some versions return a dict with a "gpu" / "gpus" key. Guard the - # .get() access with an isinstance check so a malformed scalar / - # string response from amd-smi cannot raise AttributeError. - if not isinstance(data, dict): - return None - gpus = data.get("gpu", data.get("gpus", [])) - if isinstance(gpus, list): - return len(gpus) - return None - - -def _first_visible_amd_gpu_id() -> Optional[str]: - """Return the physical AMD GPU id that should be treated as 'primary'. - - Honours HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES / CUDA_VISIBLE_DEVICES - in that order (HIP respects all three). Returns ``"0"`` when none are - set, and ``None`` when the env var explicitly narrows to zero GPUs - ("" or "-1"), so callers can short-circuit to "available: False". - """ - for env_name in ( - "HIP_VISIBLE_DEVICES", - "ROCR_VISIBLE_DEVICES", - "CUDA_VISIBLE_DEVICES", - ): - raw = os.environ.get(env_name) - if raw is None: - continue - raw = raw.strip() - if raw == "" or raw == "-1": - return None - # Filter out empty tokens after splitting. This tolerates minor - # typos like ``HIP_VISIBLE_DEVICES=",1"`` (leading comma, user - # clearly meant to narrow to device 1) while still falling - # through to the next env var when every token is empty - # (e.g. ``,,,``). - tokens = [t.strip() for t in raw.split(",") if t.strip()] - if tokens: - return tokens[0] - return "0" - - -def get_primary_gpu_utilization() -> dict[str, Any]: - """Return utilization metrics for the primary visible AMD GPU.""" - gpu_idx = _first_visible_amd_gpu_id() - if gpu_idx is None: - return {"available": False} - data = _run_amd_smi("metric", "-g", gpu_idx) - if data is None: - return {"available": False} - - # amd-smi may return: - # - a list of GPU dicts (older versions) - # - a dict with a "gpu_data" key wrapping a list (newer versions) - # - a single GPU dict (rare) - if isinstance(data, dict) and "gpu_data" in data: - data = data["gpu_data"] - if isinstance(data, list): - if len(data) == 0: - return {"available": False} - gpu_data = data[0] - else: - gpu_data = data - - metrics = _extract_gpu_metrics(gpu_data) - if not _has_real_metrics(metrics): - # amd-smi returned a JSON envelope with no usable fields (error - # response or unsupported card). Surface as unavailable rather - # than available-with-empty-data so the UI does not render a - # ghost device. - return {"available": False} - metrics["available"] = True - return metrics - - -def get_visible_gpu_utilization( - parent_visible_ids: Optional[list[int]], - parent_cuda_visible_devices: Optional[str] = None, -) -> dict[str, Any]: - """Return utilization metrics for visible AMD GPUs.""" - if parent_visible_ids is None: - return { - "available": False, - "backend_cuda_visible_devices": parent_cuda_visible_devices, - "parent_visible_gpu_ids": [], - "devices": [], - "index_kind": "unresolved", - } - - data = _run_amd_smi("metric") - if data is None: - return { - "available": False, - "backend_cuda_visible_devices": parent_cuda_visible_devices, - "parent_visible_gpu_ids": parent_visible_ids or [], - "devices": [], - "index_kind": "physical", - } - - # Extract a device list from amd-smi's envelope. Newer versions return - # a JSON array directly, older versions return a dict with a "gpus" / - # "gpu" key wrapping the list. Guard non-dict / non-list envelopes - # (scalar / string fallbacks from malformed output) so the .get() - # access cannot raise AttributeError on an unexpected shape. - if isinstance(data, list): - gpu_list = data - elif isinstance(data, dict): - # Newer amd-smi wraps output in {"gpu_data": [...]} - gpu_list = data.get("gpu_data", data.get("gpus", data.get("gpu", [data]))) - else: - gpu_list = [data] - visible_set = set(parent_visible_ids) - ordinal_map = {gpu_id: ordinal for ordinal, gpu_id in enumerate(parent_visible_ids)} - - devices = [] - for fallback_idx, gpu_data in enumerate(gpu_list): - # Skip non-dict entries defensively: if amd-smi ever ships a - # scalar inside its "gpus" array (observed on some malformed - # output), _extract_gpu_metrics would raise AttributeError on - # the first .get() call. - if not isinstance(gpu_data, dict): - continue - # Use AMD-reported GPU ID when available, fall back to enumeration - # index. Newer amd-smi versions wrap scalars as ``{"value": 0, - # "unit": "none"}``, so route raw_id through ``_parse_numeric`` - # which already handles bare ints, floats, strings, and that - # dict shape uniformly. - raw_id = gpu_data.get( - "gpu", gpu_data.get("gpu_id", gpu_data.get("id", fallback_idx)) - ) - parsed_id = _parse_numeric(raw_id) - if parsed_id is None: - logger.debug( - "amd-smi GPU id %r could not be parsed; falling back to " - "enumeration index %d", - raw_id, - fallback_idx, - ) - idx = fallback_idx - else: - idx = int(parsed_id) - if idx not in visible_set: - continue - metrics = _extract_gpu_metrics(gpu_data) - if not _has_real_metrics(metrics): - # Skip ghost entries: an amd-smi response that decodes to a - # dict but contains no usable fields (error envelope, etc.) - # would otherwise show up as a device row with all-None - # numbers in the UI. - continue - metrics["index"] = idx - metrics["index_kind"] = "physical" - metrics["visible_ordinal"] = ordinal_map.get(idx, len(devices)) - devices.append(metrics) - - return { - "available": len(devices) > 0, - "backend_cuda_visible_devices": parent_cuda_visible_devices, - "parent_visible_gpu_ids": parent_visible_ids or [], - "devices": devices, - "index_kind": "physical", - } diff --git a/studio/backend/utils/hardware/hardware.py b/studio/backend/utils/hardware/hardware.py index c218b7b4b9..b6d3faf6d7 100644 --- a/studio/backend/utils/hardware/hardware.py +++ b/studio/backend/utils/hardware/hardware.py @@ -43,26 +43,6 @@ class DeviceType(str, Enum): DEVICE: Optional[DeviceType] = None CHAT_ONLY: bool = True # No CUDA GPU -> GGUF chat only (Mac, CPU-only, etc.) -IS_ROCM: bool = ( - False # True when running on AMD ROCm (HIP) -- routes GPU monitoring to amd.py -) - - -def _backend_label(device: DeviceType) -> str: - """Return the user-facing backend name for API responses. - - Internally we still represent ROCm hosts as ``DeviceType.CUDA`` because - ROCm torch sets ``torch.cuda.is_available() = True`` and reuses the whole - ``torch.cuda.*`` API surface, so branching on ``DeviceType`` stays - consistent with the rest of the codebase. For the JSON responses served - to the Studio frontend and other clients, however, "cuda" is misleading - on an AMD machine. This helper swaps the label to ``"rocm"`` when the - module-level ``IS_ROCM`` flag is set so the UI can render the correct - backend name without every caller having to duplicate the check. - """ - if IS_ROCM and device == DeviceType.CUDA: - return "rocm" - return device.value # ========== Detection ========== @@ -105,11 +85,10 @@ def detect_hardware() -> DeviceType: 2. MLX (Apple Silicon via MLX framework) 3. CPU (fallback) """ - global DEVICE, CHAT_ONLY, IS_ROCM - CHAT_ONLY = True # reset -- only CUDA/ROCm sets it to False - IS_ROCM = False + global DEVICE, CHAT_ONLY + CHAT_ONLY = True # reset -- only CUDA sets it to False - # --- CUDA / ROCm: try PyTorch --- + # --- CUDA: try PyTorch --- if _has_torch(): import torch @@ -117,16 +96,7 @@ def detect_hardware() -> DeviceType: DEVICE = DeviceType.CUDA CHAT_ONLY = False device_name = torch.cuda.get_device_properties(0).name - - # Distinguish AMD ROCm (HIP) from NVIDIA CUDA for display purposes. - # DeviceType stays CUDA since torch.cuda.* works on ROCm via HIP. - if getattr(torch.version, "hip", None) is not None: - IS_ROCM = True - print( - f"Hardware detected: ROCm (HIP {torch.version.hip}) -- {device_name}" - ) - else: - print(f"Hardware detected: CUDA -- {device_name}") + print(f"Hardware detected: CUDA — {device_name}") return DEVICE # --- XPU: Intel GPU --- @@ -216,7 +186,7 @@ def get_gpu_memory_info() -> Dict[str, Any]: return { "available": True, - "backend": _backend_label(device), + "backend": device.value, "device": idx, "device_name": props.name, "total_gb": total / (1024**3), @@ -227,11 +197,7 @@ def get_gpu_memory_info() -> Dict[str, Any]: } except Exception as e: logger.error(f"Error getting CUDA GPU info: {e}") - return { - "available": False, - "backend": _backend_label(device), - "error": str(e), - } + return {"available": False, "backend": device.value, "error": str(e)} # ---- XPU path (Intel GPU) ---- if device == DeviceType.XPU: @@ -247,7 +213,7 @@ def get_gpu_memory_info() -> Dict[str, Any]: return { "available": True, - "backend": _backend_label(device), + "backend": device.value, "device": idx, "device_name": props.name, "total_gb": total / (1024**3), @@ -258,11 +224,7 @@ def get_gpu_memory_info() -> Dict[str, Any]: } except Exception as e: logger.error("Error getting XPU GPU info: %s", e) - return { - "available": False, - "backend": _backend_label(device), - "error": str(e), - } + return {"available": False, "backend": device.value, "error": str(e)} # ---- MLX path (Apple Silicon) ---- if device == DeviceType.MLX: @@ -277,7 +239,7 @@ def get_gpu_memory_info() -> Dict[str, Any]: return { "available": True, - "backend": _backend_label(device), + "backend": device.value, "device": 0, "device_name": f"Apple Silicon ({platform.processor() or platform.machine()})", "total_gb": total / (1024**3), @@ -288,11 +250,7 @@ def get_gpu_memory_info() -> Dict[str, Any]: } except Exception as e: logger.error(f"Error getting MLX GPU info: {e}") - return { - "available": False, - "backend": _backend_label(device), - "error": str(e), - } + return {"available": False, "backend": device.value, "error": str(e)} # ---- CPU-only ---- return {"available": False, "backend": "cpu"} @@ -357,15 +315,13 @@ def get_package_versions() -> Dict[str, Optional[str]]: except PackageNotFoundError: versions[name] = None - # GPU runtime version bundled with torch + # CUDA toolkit version bundled with torch try: import torch versions["cuda"] = getattr(torch.version, "cuda", None) - versions["rocm"] = getattr(torch.version, "hip", None) except Exception: versions["cuda"] = None - versions["rocm"] = None return versions @@ -431,50 +387,26 @@ def _torch_get_per_device_info(device_indices: list[int]) -> list[Dict[str, Any] # ========== Live GPU Utilization ========== -def _smi_query(func_name: str, *args, **kwargs) -> Optional[Dict[str, Any]]: - """Run a query against the appropriate SMI backend (amd-smi or nvidia-smi). - - Returns the result dict if available, or None on failure/unavailability. - """ - if IS_ROCM: - backend_name = "amd-smi" - try: - from . import amd as _backend - except Exception as e: - logger.warning("%s import failed: %s", backend_name, e) - return None - else: - backend_name = "nvidia-smi" - try: - from . import nvidia as _backend - except Exception as e: - logger.warning("%s import failed: %s", backend_name, e) - return None - try: - func = getattr(_backend, func_name) - result = func(*args, **kwargs) - if result.get("available"): - return result - except Exception as e: - logger.warning("%s %s query failed: %s", backend_name, func_name, e) - return None - - def get_gpu_utilization() -> Dict[str, Any]: """Return a live snapshot of device utilization information.""" device = get_device() if device == DeviceType.CUDA: - result = _smi_query("get_primary_gpu_utilization") - if result is not None: - result["backend"] = _backend_label(device) - return result + try: + from . import nvidia + + result = nvidia.get_primary_gpu_utilization() + if result.get("available"): + result["backend"] = device.value + return result + except Exception as e: + logger.warning("nvidia-smi utilization query failed: %s", e) mem = get_gpu_memory_info() if device != DeviceType.CPU and mem.get("available"): return { "available": True, - "backend": _backend_label(device), + "backend": device.value, "gpu_utilization_pct": None, "temperature_c": None, "vram_used_gb": round(mem.get("allocated_gb", 0), 2), @@ -485,7 +417,7 @@ def get_gpu_utilization() -> Dict[str, Any]: "power_utilization_pct": None, } - return {"available": False, "backend": _backend_label(device)} + return {"available": False, "backend": device.value} def get_visible_gpu_utilization() -> Dict[str, Any]: @@ -493,14 +425,18 @@ def get_visible_gpu_utilization() -> Dict[str, Any]: if device == DeviceType.CUDA: parent_visible_spec = _get_parent_visible_gpu_spec() - result = _smi_query( - "get_visible_gpu_utilization", - parent_visible_spec["numeric_ids"], - parent_cuda_visible_devices = parent_visible_spec["raw"], - ) - if result is not None: - result["backend"] = _backend_label(device) - return result + try: + from . import nvidia + + result = nvidia.get_visible_gpu_utilization( + parent_visible_spec["numeric_ids"], + parent_cuda_visible_devices = parent_visible_spec["raw"], + ) + if result.get("available"): + result["backend"] = device.value + return result + except Exception as e: + logger.warning("nvidia-smi visible GPU utilization query failed: %s", e) # Torch-based fallback for CUDA (nvidia-smi unavailable, AMD ROCm) and XPU (Intel) if device in (DeviceType.CUDA, DeviceType.XPU): @@ -539,7 +475,7 @@ def get_visible_gpu_utilization() -> Dict[str, Any]: ) return { "available": True, - "backend": _backend_label(device), + "backend": device.value, "parent_visible_gpu_ids": parent_ids, "devices": devices, "index_kind": index_kind, @@ -550,14 +486,14 @@ def get_visible_gpu_utilization() -> Dict[str, Any]: if not mem.get("available"): return { "available": False, - "backend": _backend_label(device), + "backend": device.value, "parent_visible_gpu_ids": [], "devices": [], "index_kind": "relative", } return { "available": True, - "backend": _backend_label(device), + "backend": device.value, "parent_visible_gpu_ids": [0], "devices": [ { @@ -579,7 +515,7 @@ def get_visible_gpu_utilization() -> Dict[str, Any]: return { "available": False, - "backend": _backend_label(device), + "backend": device.value, "parent_visible_gpu_ids": [], "devices": [], "index_kind": "relative", @@ -593,21 +529,7 @@ def get_visible_gpu_utilization() -> Dict[str, Any]: def _get_parent_visible_gpu_spec() -> Dict[str, Any]: - # ROCm uses HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES in addition to - # CUDA_VISIBLE_DEVICES (which HIP also respects). Check ROCm-specific - # env vars first so multi-GPU AMD setups are handled correctly. - # Use explicit None checks (not `or`) so empty string "" is honoured - # as "no visible GPUs" rather than falling through to CUDA_VISIBLE_DEVICES. - cuda_visible = None - if IS_ROCM: - hip_vis = os.environ.get("HIP_VISIBLE_DEVICES") - rocr_vis = os.environ.get("ROCR_VISIBLE_DEVICES") - if hip_vis is not None: - cuda_visible = hip_vis - elif rocr_vis is not None: - cuda_visible = rocr_vis - if cuda_visible is None: - cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") if cuda_visible is None: return { @@ -774,34 +696,6 @@ def _load_config_for_gpu_estimate(model_name: str, hf_token: Optional[str] = Non return None -def _determine_attention_impl_for_gpu_estimate(config) -> str: - import copy as _copy - - from unsloth.models._utils import resolve_attention_implementation - from transformers import AutoModel, AutoModelForCausalLM - - # why: resolve_attention_implementation calls _set_attn_impl which writes - # _attn_implementation onto the config; PreTrainedConfig's setter walks - # `sub_configs` and propagates to nested text_config / sub-configs, so a - # shallow copy still mutates those shared inner objects on the cached - # config returned by _load_config_for_gpu_estimate. Deepcopy isolates them. - config_copy = _copy.deepcopy(config) - - model_class = None - for auto_model in (AutoModelForCausalLM, AutoModel): - mapping = getattr(auto_model, "_model_mapping", None) - if mapping is None: - continue - try: - if config_copy.__class__ in mapping: - model_class = mapping[config_copy.__class__] - break - except Exception: - continue - - return resolve_attention_implementation(model_class, config_copy) - - def _estimate_fp16_model_size_bytes_from_config(config) -> Optional[int]: from .vram_estimation import extract_arch_config, compute_total_params @@ -872,21 +766,12 @@ def estimate_fp16_model_size_bytes( return int(total_params * 2), "safetensors" config = _load_config_for_gpu_estimate(estimate_model, hf_token = hf_token) - config_bytes: Optional[int] = None if config is not None: config_bytes = _estimate_fp16_model_size_bytes_from_config(config) + if config_bytes is not None: + return config_bytes, "config" local_bytes = _get_local_weight_size_bytes(estimate_model) - - # why: config-derived bytes cover only the text tower; local safetensors - # include vision/audio towers. Take the larger so the multimodal - # extra_bytes correction can fire. - if config_bytes is not None and local_bytes is not None: - if local_bytes > config_bytes: - return local_bytes, "weight_bytes" - return config_bytes, "config" - if config_bytes is not None: - return config_bytes, "config" if local_bytes is not None: return local_bytes, "weight_bytes" @@ -914,9 +799,6 @@ def estimate_required_model_memory_gb( TrainingVramConfig, extract_arch_config, estimate_training_vram, - compute_total_params, - compute_optimizer_bytes, - compute_gradient_bytes, CUDA_OVERHEAD_BYTES, QUANT_4BIT_FACTOR, DEFAULT_TARGET_MODULES, @@ -966,44 +848,13 @@ def estimate_required_model_memory_gb( model_name, hf_token = hf_token ) config = _load_config_for_gpu_estimate(estimate_model, hf_token = hf_token) - if config is not None: - try: - vram_config.attention_implementation = ( - _determine_attention_impl_for_gpu_estimate(config) - ) - except Exception as e: - logger.warning( - "Could not resolve attention implementation for '%s': %s", - estimate_model, - e, - ) - # why: if we cannot prove flash attention is usable, charge the - # quadratic non-flash activation path so GPU selection stays - # conservative. - vram_config.attention_implementation = "eager" arch = extract_arch_config(config) if config is not None else None if arch is not None: breakdown = estimate_training_vram(arch, vram_config) - # why: extract_arch_config only sees text_config; safetensors include - # vision/audio tower bytes that the text-arch fp16 total misses. - arch_fp16_bytes = compute_total_params(arch) * 2 - extra_bytes = max(0, int(model_size_bytes) - arch_fp16_bytes) - if extra_bytes > 0: - breakdown.model_weights += extra_bytes - if training_method == "full": - # why: full fine-tuning makes the extra (vision/audio) params - # trainable; optimizer + gradient bytes scale with them too. - extra_params = extra_bytes // 2 - breakdown.optimizer_states += compute_optimizer_bytes( - extra_params, - vram_config.optimizer, - ) - breakdown.gradients += compute_gradient_bytes(extra_params) required_gb = breakdown.total / (1024**3) metadata["required_gb"] = round(required_gb, 3) metadata["estimation_mode"] = "detailed" - metadata["attention_implementation"] = vram_config.attention_implementation metadata["vram_breakdown"] = breakdown.to_gb_dict() max_gpus = max(1, get_visible_gpu_count()) for n_gpus in range(1, max_gpus + 1): @@ -1258,17 +1109,15 @@ def get_physical_gpu_count() -> int: if device == DeviceType.CUDA: try: - if IS_ROCM: - from . import amd as _smi_mod - else: - from . import nvidia as _smi_mod - count = _smi_mod.get_physical_gpu_count() + from . import nvidia + + count = nvidia.get_physical_gpu_count() if count is not None: _physical_gpu_count = count return _physical_gpu_count except Exception: pass - # SMI tool unavailable or failed -- fall back to torch + # nvidia-smi unavailable or failed — fall back to torch count = _torch_get_physical_gpu_count() _physical_gpu_count = count if count is not None else 1 return _physical_gpu_count @@ -1287,25 +1136,12 @@ def get_physical_gpu_count() -> int: return _physical_gpu_count -def _backend_visible_devices_env() -> Optional[str]: - """Return the raw visibility env string that applies to this backend. - - On ROCm, HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES take precedence - over CUDA_VISIBLE_DEVICES; the helper mirrors the resolution logic in - ``_get_parent_visible_gpu_spec`` so ``backend_cuda_visible_devices`` - reports the value that is actually narrowing the visible device set. - """ - if IS_ROCM: - return _get_parent_visible_gpu_spec().get("raw") - return os.environ.get("CUDA_VISIBLE_DEVICES") - - def get_backend_visible_gpu_info() -> Dict[str, Any]: device = get_device() if device in (DeviceType.CUDA, DeviceType.XPU): parent_visible_ids = get_parent_visible_gpu_ids() - # Try native SMI tool first (nvidia-smi for NVIDIA, skipped for ROCm) - if device == DeviceType.CUDA and not IS_ROCM: + # Try nvidia-smi first (NVIDIA only) + if device == DeviceType.CUDA: try: from . import nvidia @@ -1315,7 +1151,7 @@ def get_backend_visible_gpu_info() -> Dict[str, Any]: parent_visible_spec["raw"], ) if result.get("available"): - result["backend"] = _backend_label(device) + result["backend"] = device.value return result except Exception as e: logger.warning("Backend GPU visibility query failed: %s", e) @@ -1344,8 +1180,8 @@ def get_backend_visible_gpu_info() -> Dict[str, Any]: ] return { "available": True, - "backend": _backend_label(device), - "backend_cuda_visible_devices": _backend_visible_devices_env(), + "backend": device.value, + "backend_cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), "parent_visible_gpu_ids": parent_visible_ids, "devices": devices, "index_kind": index_kind, @@ -1353,8 +1189,8 @@ def get_backend_visible_gpu_info() -> Dict[str, Any]: return { "available": False, - "backend": _backend_label(device), - "backend_cuda_visible_devices": _backend_visible_devices_env(), + "backend": device.value, + "backend_cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), "parent_visible_gpu_ids": parent_visible_ids, "devices": [], "index_kind": "physical", @@ -1365,7 +1201,7 @@ def get_backend_visible_gpu_info() -> Dict[str, Any]: if not mem.get("available"): return { "available": False, - "backend": _backend_label(device), + "backend": device.value, "backend_cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), "parent_visible_gpu_ids": [], "devices": [], @@ -1373,7 +1209,7 @@ def get_backend_visible_gpu_info() -> Dict[str, Any]: } return { "available": True, - "backend": _backend_label(device), + "backend": device.value, "backend_cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), "parent_visible_gpu_ids": [0], "devices": [ @@ -1390,7 +1226,7 @@ def get_backend_visible_gpu_info() -> Dict[str, Any]: return { "available": False, - "backend": _backend_label(device), + "backend": device.value, "backend_cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), "parent_visible_gpu_ids": [], "devices": [], @@ -1410,20 +1246,17 @@ def get_visible_gpu_count() -> int: if _visible_gpu_count is not None: return _visible_gpu_count - # Use _get_parent_visible_gpu_spec() which already handles - # HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES on ROCm. - visible_spec = _get_parent_visible_gpu_spec() - if visible_spec["raw"] is not None: - raw = visible_spec["raw"].strip() - if raw == "" or raw == "-1": + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + if cuda_visible is not None: + # "" means zero GPUs, "0" means 1, "0,1,2" means 3 + cuda_visible = cuda_visible.strip() + if cuda_visible == "" or cuda_visible == "-1": _visible_gpu_count = 0 - elif visible_spec["numeric_ids"] is not None: - _visible_gpu_count = len(visible_spec["numeric_ids"]) else: - _visible_gpu_count = len([x for x in raw.split(",") if x.strip()]) + _visible_gpu_count = len([x for x in cuda_visible.split(",") if x.strip()]) return _visible_gpu_count - # No visibility env var set -- try torch, fall back to physical count + # CUDA_VISIBLE_DEVICES not set -- try torch, fall back to physical count try: import torch @@ -1455,24 +1288,8 @@ def apply_gpu_ids(gpu_ids) -> None: value = str(gpu_ids) os.environ["CUDA_VISIBLE_DEVICES"] = value - # Keep ROCm visibility env vars in sync so _get_parent_visible_gpu_spec() - # picks up the narrowed set on AMD systems. Workers can call - # apply_gpu_ids() before detect_hardware() runs (so IS_ROCM is still - # its default False), so also mirror the selection whenever the - # parent process already set a ROCm visibility variable -- that - # way a downstream ROCm process inherits the narrowed mask even - # before Studio's hardware detection has classified the host. - _inherits_rocm_visibility = ( - "HIP_VISIBLE_DEVICES" in os.environ or "ROCR_VISIBLE_DEVICES" in os.environ - ) - if IS_ROCM or _inherits_rocm_visibility: - os.environ["HIP_VISIBLE_DEVICES"] = value - os.environ["ROCR_VISIBLE_DEVICES"] = value _visible_gpu_count = None - if IS_ROCM or _inherits_rocm_visibility: - logger.info("Applied gpu_ids: CUDA_VISIBLE_DEVICES='%s' (rocm)", value) - else: - logger.info("Applied gpu_ids: CUDA_VISIBLE_DEVICES='%s'", value) + logger.info("Applied gpu_ids: CUDA_VISIBLE_DEVICES='%s'", value) def get_device_map( diff --git a/studio/backend/utils/hardware/nvidia.py b/studio/backend/utils/hardware/nvidia.py index 099c5fa3a5..dc5295c302 100644 --- a/studio/backend/utils/hardware/nvidia.py +++ b/studio/backend/utils/hardware/nvidia.py @@ -6,11 +6,6 @@ from loggers import get_logger -from utils.native_path_leases import child_env_without_native_path_secret -from utils.subprocess_compat import ( - windows_hidden_subprocess_kwargs as _windows_hidden_subprocess_kwargs, -) - logger = get_logger(__name__) @@ -66,8 +61,6 @@ def get_physical_gpu_count() -> Optional[int]: capture_output = True, text = True, timeout = 5, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) if result.returncode == 0 and result.stdout.strip(): return len(result.stdout.strip().splitlines()) @@ -92,8 +85,6 @@ def get_primary_gpu_utilization() -> dict[str, Any]: capture_output = True, text = True, timeout = 5, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) except (OSError, subprocess.TimeoutExpired) as e: logger.warning("nvidia-smi query failed in get_primary_gpu_utilization: %s", e) @@ -144,8 +135,6 @@ def get_visible_gpu_utilization( capture_output = True, text = True, timeout = 5, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) except (OSError, subprocess.TimeoutExpired) as e: logger.warning("nvidia-smi query failed in get_visible_gpu_utilization: %s", e) @@ -231,8 +220,6 @@ def get_backend_visible_gpu_info( capture_output = True, text = True, timeout = 10, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) except (OSError, subprocess.TimeoutExpired) as e: logger.warning("nvidia-smi query failed in get_backend_visible_gpu_info: %s", e) diff --git a/studio/backend/utils/hardware/vram_estimation.py b/studio/backend/utils/hardware/vram_estimation.py index ba1b1dfe61..e03665374d 100644 --- a/studio/backend/utils/hardware/vram_estimation.py +++ b/studio/backend/utils/hardware/vram_estimation.py @@ -16,26 +16,7 @@ from typing import Dict, Optional QUANT_4BIT_FACTOR = 16 / 5 -DOUBLE_QUANT_4BIT_FACTOR = ( - 3.6 # bnb_4bit_use_double_quant; see VRAM_ESTIMATION.md section 1 -) CUDA_OVERHEAD_BYTES = int(1.4 * 1024**3) # calibrated on RTX 5070 Ti -NON_FLASH_ATTENTION_FACTOR = ( - 12.0 # eager attention score+workspace overhead; see VRAM_ESTIMATION.md section 5 -) - -LINEAR_ATTENTION_IMPLS = frozenset({"flash_attention_2", "sdpa", "flex_attention"}) - -_SKIP_MODULE_TEXT_PREFIXES = frozenset( - { - "model", - "model.model", - "language_model", - "language_model.model", - "model.language_model", - "model.language_model.model", - } -) DEFAULT_TARGET_MODULES = [ "q_proj", @@ -46,8 +27,6 @@ "up_proj", "down_proj", ] -ATTENTION_TARGET_MODULES = {"q_proj", "k_proj", "v_proj", "o_proj"} -MLP_TARGET_MODULES = {"gate_proj", "up_proj", "down_proj"} # Empirically calibrated bytes/param — see VRAM_ESTIMATION.md for rationale. OPTIMIZER_BYTES_PER_PARAM: Dict[str, int] = { @@ -82,28 +61,12 @@ class ModelArchConfig: num_experts: Optional[int] = None moe_intermediate_size: Optional[int] = None n_shared_experts: int = 0 - shared_expert_intermediate_size: Optional[int] = None - num_experts_per_tok: int = 1 num_dense_layers: int = 0 q_lora_rank: Optional[int] = None kv_lora_rank: Optional[int] = None qk_nope_head_dim: Optional[int] = None qk_rope_head_dim: Optional[int] = None v_head_dim: Optional[int] = None - head_dim: Optional[int] = None - global_head_dim: Optional[int] = None - num_global_key_value_heads: Optional[int] = None - attention_k_eq_v: bool = False - layer_types: Optional[list] = None - num_kv_shared_layers: int = 0 - use_double_wide_mlp: bool = False - vocab_size_per_layer_input: int = 0 - hidden_size_per_layer_input: int = 0 - quantization_skip_modules: list = field(default_factory = list) - quant_4bit_factor: float = QUANT_4BIT_FACTOR - moe_has_dense_mlp: bool = False - dense_layer_indices: tuple = () - dense_intermediate_size: Optional[int] = None @dataclass @@ -116,7 +79,6 @@ class TrainingVramConfig: gradient_checkpointing: str = "unsloth" optimizer: str = "adamw_8bit" load_in_4bit: bool = True - attention_implementation: str = "flash_attention_2" @dataclass @@ -127,8 +89,8 @@ class VramBreakdown: gradients: int activations: int cuda_overhead: int - # Equals `activations`; retained for backward compatibility with - # consumers that read this field. + # The computed (formula-based) activation cost before floors. + # This is the true per-layer cost that doesn't shard across GPUs. activations_computed: int = 0 @property @@ -146,15 +108,17 @@ def min_gpu_vram(self, n_gpus: int) -> int: """Minimum VRAM a single GPU needs: its shard + non-shardable costs. Weights/LoRA/optimizer/gradients shard across GPUs. - Activations do NOT shard (the GPU running a layer holds them). + The computed activation cost does NOT shard (one GPU runs the layer). + The floor portion (activations - computed) is overhead that shards. """ shardable = ( self.model_weights + self.lora_adapters + self.optimizer_states + self.gradients + + (self.activations - self.activations_computed) # floor overhead shards ) - per_gpu_fixed = self.activations + self.cuda_overhead + per_gpu_fixed = self.activations_computed + self.cuda_overhead return shardable // max(n_gpus, 1) + per_gpu_fixed def to_gb_dict(self) -> Dict[str, float]: @@ -169,88 +133,28 @@ def to_gb_dict(self) -> Dict[str, float]: } -def _first_scalar(value): - # why: ERNIE MoE configs ship moe_intermediate_size / moe_num_experts as - # [routed, shared] lists; downstream arithmetic needs the routed scalar. - if isinstance(value, (list, tuple)): - return value[0] if value else None - return value - - -def _max_scalar(value): - # why: Hunyuan-V1-MoE moe_topk can be a per-layer list; activation - # accounting uses the max top-k as a conservative upper bound. - if isinstance(value, (list, tuple)): - items = [v for v in value if v is not None] - return max(items) if items else None - return value - - -def _compute_dense_layer_indices(text_config, total_layers: int) -> tuple: - """Layer indices that use dense MLP instead of MoE. Position matters.""" - # why: transformers Exaone-MoE / Laguna / Hy_v3 / GLM-MoE-DSA / GLM4-MoE-Lite / - # Ernie4_5_VL_MoE prefer per-position `mlp_layer_types` over the prefix-style - # `first_k_dense_replace` and may omit `decoder_sparse_step` entirely. - layer_types = getattr(text_config, "mlp_layer_types", None) - if layer_types: - return tuple( - i - for i, t in enumerate(layer_types[:total_layers]) - if str(t).lower() == "dense" - ) - - # why: Llama4TextConfig.__init__ auto-populates self.moe_layers from - # interleave_moe_layer_step; Llama4TextDecoderLayer dispatches via - # `layer_idx in config.moe_layers` (modeling_llama4.py). - llama4_moe_layers = getattr(text_config, "moe_layers", None) - if llama4_moe_layers is not None: - moe_indices = {int(i) for i in llama4_moe_layers} - return tuple(i for i in range(total_layers) if i not in moe_indices) - - # why: transformers ERNIE 4.5 MoE / ERNIE 4.5 VL MoE declare MoE layers - # via moe_layer_start_index / moe_layer_end_index / moe_layer_interval; - # the model's per-layer guard is `(layer_idx + 1) % interval == 0` with - # start <= layer_idx <= end (modeling_ernie4_5_moe.py). - moe_start = getattr(text_config, "moe_layer_start_index", None) - moe_interval = getattr(text_config, "moe_layer_interval", None) - if moe_start is not None and moe_interval is not None and int(moe_interval) > 0: - moe_end_raw = getattr(text_config, "moe_layer_end_index", None) - end = ( - total_layers - if moe_end_raw is None or int(moe_end_raw) == -1 - else min(int(moe_end_raw) + 1, total_layers) - ) - start = max(0, int(moe_start)) - interval = int(moe_interval) - moe_indices = {i for i in range(start, end) if (i + 1) % interval == 0} - return tuple(i for i in range(total_layers) if i not in moe_indices) - +def _compute_num_dense_layers(text_config, total_layers: int) -> int: + """Count how many layers use dense MLP instead of MoE.""" first_k = getattr(text_config, "first_k_dense_replace", None) if first_k is not None: - return tuple(range(min(int(first_k), total_layers))) + return min(int(first_k), total_layers) sparse_step = getattr(text_config, "decoder_sparse_step", None) mlp_only = getattr(text_config, "mlp_only_layers", None) or [] if sparse_step is not None and sparse_step > 0: - mlp_only_set = {int(i) for i in mlp_only} - return tuple( - i + mlp_only_set = set(mlp_only) + moe_count = sum( + 1 for i in range(total_layers) - if i in mlp_only_set or (i + 1) % sparse_step != 0 + if i not in mlp_only_set and (i + 1) % sparse_step == 0 ) - return () + return total_layers - moe_count + + return 0 def extract_arch_config(hf_config) -> Optional[ModelArchConfig]: text_config = getattr(hf_config, "text_config", None) or hf_config - quantization_config = getattr(hf_config, "quantization_config", None) or {} - if not isinstance(quantization_config, dict): - quantization_config = getattr(quantization_config, "to_dict", lambda: {})() - quant_4bit_factor = ( - DOUBLE_QUANT_4BIT_FACTOR - if quantization_config.get("bnb_4bit_use_double_quant", False) - else QUANT_4BIT_FACTOR - ) hidden_size = getattr(text_config, "hidden_size", None) num_layers = getattr(text_config, "num_hidden_layers", None) @@ -273,75 +177,18 @@ def extract_arch_config(hf_config) -> Optional[ModelArchConfig]: num_kv_heads = getattr(text_config, "num_key_value_heads", num_heads) - # why: DBRX places its MoE attrs on the DbrxFFNConfig sub-config; probe - # ffn_config as a secondary source so DBRX is not misclassified as dense. - ffn_config = getattr(text_config, "ffn_config", None) - - def _moe_attr(name): - value = getattr(text_config, name, None) - if value is None and ffn_config is not None: - value = getattr(ffn_config, name, None) - return value - num_experts = None - for attr in ( - "num_local_experts", - "num_experts", - "n_routed_experts", - "moe_num_experts", - ): - num_experts = _first_scalar(_moe_attr(attr)) + for attr in ("num_local_experts", "num_experts", "n_routed_experts"): + num_experts = getattr(text_config, attr, None) if num_experts is not None: break - moe_intermediate_raw = _moe_attr("moe_intermediate_size") - if moe_intermediate_raw is None: - moe_intermediate_raw = _moe_attr("ffn_hidden_size") - moe_intermediate = _first_scalar(moe_intermediate_raw) - # why: Exaone-MoE / ERNIE families alias num_shared_experts / - # moe_num_shared_experts to the canonical n_shared_experts. - n_shared_experts = ( - _first_scalar(_moe_attr("n_shared_experts")) - or _first_scalar(_moe_attr("num_shared_experts")) - or _first_scalar(_moe_attr("moe_num_shared_experts")) - or 0 - ) - shared_expert_intermediate_size = _moe_attr("shared_expert_intermediate_size") - if shared_expert_intermediate_size and n_shared_experts == 0: - n_shared_experts = 1 - # why: DBRX exposes moe_top_k, Hunyuan-V1-MoE exposes moe_topk (which can - # be a per-layer list); _max_scalar normalizes list values to the worst - # case so int(...) below cannot crash on the canonical attribute_map path. - num_experts_per_tok = ( - _max_scalar(_moe_attr("num_experts_per_tok")) - or _max_scalar(_moe_attr("top_k_experts")) - or _max_scalar(_moe_attr("moe_top_k")) - or _max_scalar(_moe_attr("moe_topk")) - or 1 - ) + moe_intermediate = getattr(text_config, "moe_intermediate_size", None) + n_shared_experts = getattr(text_config, "n_shared_experts", None) or 0 - dense_layer_indices: tuple = () + num_dense_layers = 0 if num_experts is not None and num_experts > 1: - dense_layer_indices = _compute_dense_layer_indices(text_config, num_layers) - num_dense_layers = len(dense_layer_indices) - - # why: Llama4 dense layers use intermediate_size_mlp; routed and shared - # experts use intermediate_size. Llama4TextMoe builds one shared_expert - # per MoE layer (modeling_llama4.py). - intermediate_size_mlp_raw = _first_scalar(_moe_attr("intermediate_size_mlp")) - dense_intermediate_size = ( - int(intermediate_size_mlp_raw) - if intermediate_size_mlp_raw is not None - else None - ) - if ( - intermediate_size_mlp_raw is not None - and num_experts is not None - and num_experts > 1 - and shared_expert_intermediate_size is None - and n_shared_experts == 0 - ): - n_shared_experts = 1 + num_dense_layers = _compute_num_dense_layers(text_config, num_layers) q_lora_rank = getattr(text_config, "q_lora_rank", None) kv_lora_rank = getattr(text_config, "kv_lora_rank", None) @@ -360,416 +207,13 @@ def _moe_attr(name): num_experts = num_experts, moe_intermediate_size = moe_intermediate, n_shared_experts = n_shared_experts, - shared_expert_intermediate_size = shared_expert_intermediate_size, - num_experts_per_tok = int(num_experts_per_tok), num_dense_layers = num_dense_layers, q_lora_rank = q_lora_rank, kv_lora_rank = kv_lora_rank, qk_nope_head_dim = qk_nope_head_dim, qk_rope_head_dim = qk_rope_head_dim, v_head_dim = v_head_dim, - head_dim = getattr(text_config, "head_dim", None), - global_head_dim = getattr(text_config, "global_head_dim", None), - num_global_key_value_heads = getattr( - text_config, - "num_global_key_value_heads", - None, - ), - attention_k_eq_v = bool(getattr(text_config, "attention_k_eq_v", False)), - layer_types = getattr(text_config, "layer_types", None), - num_kv_shared_layers = getattr(text_config, "num_kv_shared_layers", None) or 0, - use_double_wide_mlp = bool(getattr(text_config, "use_double_wide_mlp", False)), - vocab_size_per_layer_input = getattr( - text_config, - "vocab_size_per_layer_input", - None, - ) - or 0, - hidden_size_per_layer_input = getattr( - text_config, - "hidden_size_per_layer_input", - None, - ) - or 0, - quantization_skip_modules = list( - quantization_config.get("llm_int8_skip_modules", []) or [] - ), - quant_4bit_factor = quant_4bit_factor, - moe_has_dense_mlp = bool(getattr(text_config, "enable_moe_block", False)), - dense_layer_indices = dense_layer_indices, - dense_intermediate_size = dense_intermediate_size, - ) - - -def _targets_all_linear(target_modules) -> bool: - # why: peft LoraConfig accepts target_modules="all-linear" as a bare - # string; iterating a string yields chars and never matches the set. - if isinstance(target_modules, str): - target_modules = [target_modules] - normalized = {str(module).lower().replace("_", "-") for module in target_modules} - return normalized == {"all-linear"} - - -def _head_dim(arch: ModelArchConfig) -> int: - return arch.head_dim or arch.hidden_size // arch.num_attention_heads - - -def _layer_types(arch: ModelArchConfig) -> list: - if arch.layer_types and len(arch.layer_types) == arch.num_hidden_layers: - return arch.layer_types - return ["full_attention"] * arch.num_hidden_layers - - -def _uses_structured_layer_shapes(arch: ModelArchConfig) -> bool: - # MLA configs have their own q/kv low-rank projection shape formulas in - # _compute_attn_elements / _lora_attn_elements; do not let head_dim or - # other structured fields override that path. - if arch.q_lora_rank is not None: - return False - return bool( - arch.layer_types - or arch.head_dim is not None - or arch.global_head_dim is not None - or arch.num_global_key_value_heads is not None - or arch.attention_k_eq_v - or arch.num_kv_shared_layers > 0 - or arch.use_double_wide_mlp - ) - - -def _is_kv_shared_layer(arch: ModelArchConfig, layer_idx: int) -> bool: - if arch.num_kv_shared_layers <= 0: - return False - first_shared = arch.num_hidden_layers - arch.num_kv_shared_layers - # why: transformers Gemma4 (modeling_gemma4.py:1031, modular_gemma4.py:863) - # uses the same `> 0` guard so a fully-shared config raises during model - # construction; matching upstream avoids producing a detailed estimate - # for a shape the actual model code rejects. - return layer_idx >= first_shared > 0 - - -def _is_dense_mlp_layer(arch: ModelArchConfig, layer_idx: int) -> bool: - if arch.dense_layer_indices: - return layer_idx in arch.dense_layer_indices - return layer_idx < arch.num_dense_layers - - -def _per_layer_input_quantizable(arch: ModelArchConfig) -> int: - # why: Gemma4 PLE block adds per_layer_model_projection (single Linear), - # per_layer_input_gate (per layer), and per_layer_projection (per layer); - # see transformers gemma4/modular_gemma4.py:1077-1083 and :1247-1253. - pli = arch.hidden_size_per_layer_input - if pli <= 0: - return 0 - n_layers = arch.num_hidden_layers - hd = arch.hidden_size - return hd * (n_layers * pli) + (hd * pli) * n_layers + (pli * hd) * n_layers - - -def _per_layer_input_norm_elements(arch: ModelArchConfig) -> int: - pli = arch.hidden_size_per_layer_input - if pli <= 0: - return 0 - n_layers = arch.num_hidden_layers - hd = arch.hidden_size - return hd * n_layers + pli - - -def _per_layer_input_lora_params( - arch: ModelArchConfig, - r: int, - target_modules, -) -> int: - # why: Unsloth's get_peft_regex (unsloth_zoo/peft_utils.py) requires module - # names to contain a component tag (mlp/attn/...); PLE module names lack - # any tag, so all-linear training does NOT attach LoRA to them. Only count - # PLE LoRA when the user explicitly names PLE modules. - pli = arch.hidden_size_per_layer_input - if pli <= 0: - return 0 - targets = ( - {target_modules} - if isinstance(target_modules, str) - else set(target_modules or []) - ) - n_layers = arch.num_hidden_layers - hd = arch.hidden_size - total = 0 - if "per_layer_model_projection" in targets: - total += hd * r + r * (n_layers * pli) - if "per_layer_input_gate" in targets: - total += (hd * r + r * pli) * n_layers - if "per_layer_projection" in targets: - total += (pli * r + r * hd) * n_layers - return total - - -def _layer_attention_dims(arch: ModelArchConfig, layer_idx: int) -> tuple: - layer_types = _layer_types(arch) - layer_type = layer_types[layer_idx] - is_sliding = layer_type == "sliding_attention" - head_dim = ( - arch.global_head_dim - if not is_sliding and arch.global_head_dim - else _head_dim(arch) - ) - use_alt_attention = arch.attention_k_eq_v and not is_sliding - num_kv_heads = ( - arch.num_global_key_value_heads - if use_alt_attention and arch.num_global_key_value_heads - else arch.num_key_value_heads - ) - q_size = arch.num_attention_heads * head_dim - kv_size = num_kv_heads * head_dim - has_k = not _is_kv_shared_layer(arch, layer_idx) - has_v = has_k and not use_alt_attention - return q_size, kv_size, has_k, has_v - - -def _layer_mlp_size(arch: ModelArchConfig, layer_idx: int) -> int: - if arch.use_double_wide_mlp and _is_kv_shared_layer(arch, layer_idx): - return _dense_mlp_size(arch) * 2 - return _dense_mlp_size(arch) - - -def _text_linear_dims( - arch: ModelArchConfig, - layer_idx: int, -) -> Dict[str, tuple[int, int]]: - hd = arch.hidden_size - if _uses_structured_layer_shapes(arch): - q_size, kv_size, has_k, has_v = _layer_attention_dims(arch, layer_idx) - mlp_size = _layer_mlp_size(arch, layer_idx) - else: - q_size = hd - kv_size = _get_kv_size(arch) - has_k = True - has_v = True - mlp_size = _get_mlp_size(arch) - - dims = { - "q_proj": (hd, q_size), - "o_proj": (q_size, hd), - } - if has_k: - dims["k_proj"] = (hd, kv_size) - if has_v: - dims["v_proj"] = (hd, kv_size) - - dims.update( - { - "gate_proj": (hd, mlp_size), - "up_proj": (hd, mlp_size), - "down_proj": (mlp_size, hd), - } ) - return dims - - -def _module_path_matches(skip_module: str, alias: str) -> bool: - skip_parts = [part for part in skip_module.split(".") if part] - alias_parts = [part for part in alias.split(".") if part] - if not skip_parts or not alias_parts: - return False - if alias_parts[0] == "layers": - return skip_parts == alias_parts - if len(skip_parts) <= len(alias_parts): - # why: transformers BNB quantizer suffix-matches short skip entries - # like ["q_proj"] / ["lm_head"] against full module paths, so a skip - # shorter than the alias is a tail match. - return alias_parts[-len(skip_parts) :] == skip_parts - if skip_parts[-len(alias_parts) :] != alias_parts: - return False - prefix_parts = skip_parts[: len(skip_parts) - len(alias_parts)] - if not prefix_parts: - return True - # why: bound the prefix to known text-tower roots so VLM skip names like - # vision_tower.model.layers..self_attn.q_proj do not shadow the text - # alias model.layers..self_attn.q_proj. - return ".".join(prefix_parts) in _SKIP_MODULE_TEXT_PREFIXES - - -def _add_module_aliases( - aliases: Dict[str, str], - canonical: str, - suffix: str, -) -> None: - for prefix in ( - "", - "model", - "model.model", - "language_model", - "language_model.model", - "model.language_model", - "model.language_model.model", - ): - alias = f"{prefix}.{suffix}" if prefix else suffix - aliases[alias] = canonical - - -def _build_text_module_elements( - arch: ModelArchConfig, -) -> tuple[Dict[str, int], Dict[str, str]]: - elements: Dict[str, int] = {} - aliases: Dict[str, str] = {} - - is_mla = arch.q_lora_rank is not None and not _uses_structured_layer_shapes(arch) - pli = arch.hidden_size_per_layer_input - hd_global = arch.hidden_size - - for layer_idx in range(arch.num_hidden_layers): - layer_modules: Dict[str, int] = {} - dims = _text_linear_dims(arch, layer_idx) - attn_dims = { - name: dim for name, dim in dims.items() if name in ATTENTION_TARGET_MODULES - } - mlp_dims = { - name: dim for name, dim in dims.items() if name in MLP_TARGET_MODULES - } - - if is_mla: - # why: _text_linear_dims uses (hd, hd) for q/o; MLA actually splits - # into q_a/q_b/kv_a/kv_b, so emit a single self_attn aggregate at - # the authoritative MLA per-layer total. - layer_modules["self_attn"] = _compute_attn_elements(arch) - else: - for name, (in_dim, out_dim) in attn_dims.items(): - layer_modules[f"self_attn.{name}"] = in_dim * out_dim - - if arch.num_experts and arch.num_experts > 1: - if _is_dense_mlp_layer(arch, layer_idx): - layer_modules.update( - { - f"mlp.{name}": in_dim * out_dim - for name, (in_dim, out_dim) in mlp_dims.items() - } - ) - else: - layer_modules["mlp.experts"] = _compute_routed_moe_elements(arch) - shared_moe = _compute_shared_moe_elements(arch) - if shared_moe: - # why: Qwen3.5-MoE exposes shared expert as - # mlp.shared_expert; Exaone-MoE/Laguna/GLM-style configs use - # mlp.shared_experts. Register both names so child-path - # llm_int8_skip_modules entries match the right shared block. - layer_modules["mlp.shared_expert"] = shared_moe - if arch.moe_has_dense_mlp: - # why: enable_moe_block runs the dense MLP and the MoE - # experts in parallel; register both for skip matching. - # Non-structured _text_linear_dims returns mlp_size from - # _get_mlp_size which prefers moe_intermediate_size, so - # rebuild dense dims from arch.intermediate_size directly. - if _uses_structured_layer_shapes(arch): - dense_dims = mlp_dims - else: - hd = arch.hidden_size - inter = arch.intermediate_size - dense_dims = { - "gate_proj": (hd, inter), - "up_proj": (hd, inter), - "down_proj": (inter, hd), - } - layer_modules.update( - { - f"mlp.{name}": in_dim * out_dim - for name, (in_dim, out_dim) in dense_dims.items() - } - ) - else: - layer_modules.update( - { - f"mlp.{name}": in_dim * out_dim - for name, (in_dim, out_dim) in mlp_dims.items() - } - ) - - if pli > 0: - # why: register PLE per-layer linears so llm_int8_skip_modules - # entries like model.layers.0.per_layer_input_gate match. - layer_modules["per_layer_input_gate"] = hd_global * pli - layer_modules["per_layer_projection"] = pli * hd_global - - attn_total = sum( - value - for name, value in layer_modules.items() - if name == "self_attn" or name.startswith("self_attn.") - ) - # why: gemma4 enable_moe_block puts routed experts at the sibling - # layers..experts attribute, not under self.mlp; the layer's "mlp" - # aggregate must reflect only the dense MLP path so a skip module - # `model.layers.0.mlp` does not over-skip into the experts block. - is_sibling_experts = bool(arch.moe_has_dense_mlp) - mlp_total = sum( - value - for name, value in layer_modules.items() - if ( - name == "mlp" - or ( - name.startswith("mlp.") - and not (is_sibling_experts and name == "mlp.experts") - ) - ) - ) - experts_total = layer_modules.get("mlp.experts", 0) if is_sibling_experts else 0 - layer_total = sum(layer_modules.values()) - - aggregate_modules = { - f"text.layers.{layer_idx}": layer_total, - f"text.layers.{layer_idx}.self_attn": attn_total, - f"text.layers.{layer_idx}.mlp": mlp_total, - } - if experts_total: - aggregate_modules[f"text.layers.{layer_idx}.experts"] = experts_total - elements.update(aggregate_modules) - for canonical in aggregate_modules: - suffix = canonical.removeprefix("text.") - _add_module_aliases(aliases, canonical, suffix) - - for name, value in layer_modules.items(): - canonical = f"text.layers.{layer_idx}.{name}" - elements[canonical] = value - _add_module_aliases(aliases, canonical, canonical.removeprefix("text.")) - if name == "mlp.experts" and arch.moe_has_dense_mlp: - # why: gemma4 enable_moe_block exposes routed experts at - # layers..experts (sibling of self.mlp), not under mlp. - _add_module_aliases(aliases, canonical, f"layers.{layer_idx}.experts") - elif name == "mlp.shared_expert": - # why: Exaone-MoE / Laguna / GLM-style configs use the plural - # `shared_experts` attribute name; register both spellings. - _add_module_aliases( - aliases, - canonical, - f"layers.{layer_idx}.mlp.shared_experts", - ) - - if pli > 0: - canonical = "text.per_layer_model_projection" - elements[canonical] = hd_global * (arch.num_hidden_layers * pli) - _add_module_aliases(aliases, canonical, canonical.removeprefix("text.")) - - return elements, aliases - - -def _compute_skipped_quantizable_elements(arch: ModelArchConfig) -> int: - if not arch.quantization_skip_modules: - return 0 - - module_elements, aliases = _build_text_module_elements(arch) - matched = set() - for skip_module in arch.quantization_skip_modules: - for alias, canonical in aliases.items(): - if _module_path_matches(skip_module, alias): - matched.add(canonical) - - pruned = { - canonical - for canonical in matched - if not any( - canonical != parent and canonical.startswith(f"{parent}.") - for parent in matched - ) - } - return sum(module_elements[canonical] for canonical in pruned) def _get_kv_size(arch: ModelArchConfig) -> int: @@ -782,12 +226,6 @@ def _get_mlp_size(arch: ModelArchConfig) -> int: return arch.intermediate_size -def _dense_mlp_size(arch: ModelArchConfig) -> int: - # why: Llama4 dense layers use intermediate_size_mlp; routed/shared - # experts use intermediate_size. Other configs leave the field None. - return arch.dense_intermediate_size or arch.intermediate_size - - def _get_num_experts(arch: ModelArchConfig) -> int: return arch.num_experts if arch.num_experts and arch.num_experts > 1 else 1 @@ -810,39 +248,14 @@ def _compute_attn_elements(arch: ModelArchConfig) -> int: def _compute_dense_mlp_elements(arch: ModelArchConfig) -> int: - return arch.hidden_size * _dense_mlp_size(arch) * 3 - + return arch.hidden_size * arch.intermediate_size * 3 -def _shared_expert_size(arch: ModelArchConfig) -> int: - # why: Qwen3.5-MoE shared expert has its own intermediate_size (default 512) - # distinct from moe_intermediate_size; fall back to routed mlp_size for - # families that share it (deepseek-style configs). - return arch.shared_expert_intermediate_size or _get_mlp_size(arch) - -def _compute_routed_moe_elements(arch: ModelArchConfig) -> int: +def _compute_moe_mlp_elements(arch: ModelArchConfig) -> int: hd = arch.hidden_size + mlp_size = _get_mlp_size(arch) n_experts = _get_num_experts(arch) - return hd * _get_mlp_size(arch) * 3 * n_experts + n_experts * hd - - -def _compute_shared_moe_elements(arch: ModelArchConfig) -> int: - if not arch.n_shared_experts: - return 0 - hd = arch.hidden_size - shared_size = _shared_expert_size(arch) - total = hd * shared_size * 3 * arch.n_shared_experts - # why: only Qwen2-MoE / Qwen3.5-MoE define a shared_expert_gate Linear - # (hidden_size→1); other families (Exaone-MoE, HY-V3, GLM4-MoE-Lite, Laguna) - # have shared_experts without a gate. shared_expert_intermediate_size is the - # Qwen-style discriminator. - if arch.shared_expert_intermediate_size: - total += arch.n_shared_experts * hd - return total - - -def _compute_moe_mlp_elements(arch: ModelArchConfig) -> int: - return _compute_routed_moe_elements(arch) + _compute_shared_moe_elements(arch) + return hd * mlp_size * 3 * (n_experts + arch.n_shared_experts) + n_experts * hd def _compute_layer_elements(arch: ModelArchConfig): @@ -854,60 +267,22 @@ def _compute_layer_elements(arch: ModelArchConfig): n_layers = arch.num_hidden_layers n_experts = _get_num_experts(arch) - if _uses_structured_layer_shapes(arch): - attn_total = 0 - per_layer_dense_mlp = [] - for layer_idx in range(n_layers): - layer_dense_mlp = 0 - for name, (in_dim, out_dim) in _text_linear_dims( - arch, - layer_idx, - ).items(): - elements = in_dim * out_dim - if name in ATTENTION_TARGET_MODULES: - attn_total += elements - elif name in MLP_TARGET_MODULES: - layer_dense_mlp += elements - per_layer_dense_mlp.append(layer_dense_mlp) - if n_experts > 1: - n_dense = arch.num_dense_layers - n_moe = n_layers - n_dense - moe_mlp_total = _compute_moe_mlp_elements(arch) * n_moe - if arch.moe_has_dense_mlp: - # why: enable_moe_block runs dense MLP and MoE experts in - # parallel; count dense for every layer alongside MoE. - mlp_total = sum(per_layer_dense_mlp) + moe_mlp_total - else: - dense_only_total = sum( - value - for i, value in enumerate(per_layer_dense_mlp) - if _is_dense_mlp_layer(arch, i) - ) - mlp_total = moe_mlp_total + dense_only_total - else: - mlp_total = sum(per_layer_dense_mlp) - elif n_experts > 1: - attn_total = _compute_attn_elements(arch) * n_layers + attn_total = _compute_attn_elements(arch) * n_layers + + if n_experts > 1: n_dense = arch.num_dense_layers n_moe = n_layers - n_dense - moe_mlp_total = _compute_moe_mlp_elements(arch) * n_moe - if arch.moe_has_dense_mlp: - mlp_total = _compute_dense_mlp_elements(arch) * n_layers + moe_mlp_total - else: - mlp_total = moe_mlp_total + _compute_dense_mlp_elements(arch) * n_dense + mlp_total = ( + _compute_moe_mlp_elements(arch) * n_moe + + _compute_dense_mlp_elements(arch) * n_dense + ) else: - attn_total = _compute_attn_elements(arch) * n_layers mlp_total = _compute_dense_mlp_elements(arch) * n_layers layernorms = 2 * hd - per_layer_embed = ( - arch.vocab_size_per_layer_input * arch.hidden_size_per_layer_input * n_layers - ) - ple_text_linear = _per_layer_input_quantizable(arch) - ple_norms = _per_layer_input_norm_elements(arch) - embed_tokens = arch.vocab_size * hd + per_layer_embed + ple_norms + embed_tokens = arch.vocab_size * hd lm_head = 0 if arch.tie_word_embeddings else arch.vocab_size * hd - return attn_total + mlp_total + ple_text_linear, layernorms, embed_tokens, lm_head + return attn_total + mlp_total, layernorms, embed_tokens, lm_head def compute_model_weights_bytes( @@ -920,16 +295,7 @@ def compute_model_weights_bytes( non_quantizable = layernorms * n_layers + embed_tokens + lm_head if training_method == "qlora" and load_in_4bit: - skipped_quantizable = min( - _compute_skipped_quantizable_elements(arch), - total_quantizable, - ) - quantized = total_quantizable - skipped_quantizable - return int( - quantized * 2 / arch.quant_4bit_factor - + skipped_quantizable * 2 - + non_quantizable * 2 - ) + return int(total_quantizable * 2 / QUANT_4BIT_FACTOR + non_quantizable * 2) return int((total_quantizable + non_quantizable) * 2) @@ -997,130 +363,46 @@ def compute_lora_params( lora_rank: int, target_modules: list, ) -> int: - all_linear = _targets_all_linear(target_modules) - selected_modules = list(DEFAULT_TARGET_MODULES) if all_linear else target_modules hd = arch.hidden_size r = lora_rank n_layers = arch.num_hidden_layers n_experts = _get_num_experts(arch) - use_structured_shapes = _uses_structured_layer_shapes(arch) - if use_structured_shapes: - attn_total = 0 - structured_dense_mlp = 0 - per_layer_dense_mlp = [] - for layer_idx in range(n_layers): - layer_dense = 0 - for name, (in_dim, out_dim) in _text_linear_dims( - arch, - layer_idx, - ).items(): - if name not in selected_modules: - continue - if name in ATTENTION_TARGET_MODULES: - attn_total += in_dim * r + r * out_dim - elif name in MLP_TARGET_MODULES: - layer_dense += in_dim * r + r * out_dim - per_layer_dense_mlp.append(layer_dense) - structured_dense_mlp += layer_dense - if n_experts > 1: - n_dense = arch.num_dense_layers - n_moe = n_layers - n_dense - # why: peft "all-linear" attaches LoRA to nn.Linear only; - # routed experts are nn.Parameter and need explicit - # gate_proj/up_proj/down_proj naming via Unsloth's - # get_moe_target_parameters. Shared experts are nn.Linear and - # are picked up by get_peft_regex. - routed_moe = ( - 0 - if all_linear - else _lora_mlp_elements( - hd, - _get_mlp_size(arch), - r, - selected_modules, - n_experts, - ) - ) - shared_moe = _lora_mlp_elements( - hd, - _shared_expert_size(arch), - r, - selected_modules, - arch.n_shared_experts, - ) - moe_mlp = routed_moe + shared_moe - if arch.moe_has_dense_mlp: - # why: parallel dense MLP coexists with MoE on every layer. - mlp_total = structured_dense_mlp + moe_mlp * n_moe - else: - dense_only = sum( - value - for i, value in enumerate(per_layer_dense_mlp) - if _is_dense_mlp_layer(arch, i) - ) - mlp_total = moe_mlp * n_moe + dense_only - else: - mlp_total = structured_dense_mlp - return ( - attn_total - + mlp_total - + _per_layer_input_lora_params(arch, r, target_modules) - ) - elif n_experts > 1: - attn_total = _lora_attn_elements(arch, r, selected_modules) * n_layers + attn_total = _lora_attn_elements(arch, r, target_modules) * n_layers + + if n_experts > 1: n_dense = arch.num_dense_layers n_moe = n_layers - n_dense - # why: routed and shared experts may use different intermediate sizes - # (Qwen3.5-MoE: routed mlp_size != shared_expert_intermediate_size). - # See structured branch for the all-linear exclusion rationale; only - # routed (nn.Parameter) experts are excluded under all-linear. - routed_moe = ( - 0 - if all_linear - else _lora_mlp_elements( - hd, - _get_mlp_size(arch), - r, - selected_modules, - n_experts, - ) - ) - shared_moe = _lora_mlp_elements( + # Include shared experts alongside routed experts + moe_expert_mult = n_experts + arch.n_shared_experts + moe_mlp = _lora_mlp_elements( hd, - _shared_expert_size(arch), + _get_mlp_size(arch), r, - selected_modules, - arch.n_shared_experts, + target_modules, + moe_expert_mult, ) - moe_mlp = routed_moe + shared_moe dense_mlp = _lora_mlp_elements( hd, - _dense_mlp_size(arch), + arch.intermediate_size, r, - selected_modules, + target_modules, 1, ) - if arch.moe_has_dense_mlp: - mlp_total = moe_mlp * n_moe + dense_mlp * n_layers - else: - mlp_total = moe_mlp * n_moe + dense_mlp * n_dense + mlp_total = moe_mlp * n_moe + dense_mlp * n_dense else: - attn_total = _lora_attn_elements(arch, r, selected_modules) * n_layers mlp_total = ( _lora_mlp_elements( hd, - _dense_mlp_size(arch), + arch.intermediate_size, r, - selected_modules, + target_modules, 1, ) * n_layers ) - return ( - attn_total + mlp_total + _per_layer_input_lora_params(arch, r, target_modules) - ) + return attn_total + mlp_total def compute_lora_adapter_bytes(lora_params: int) -> int: @@ -1137,88 +419,26 @@ def compute_gradient_bytes(trainable_params: int) -> int: return trainable_params * 2 -def _is_linear_attention(attention_implementation: Optional[str]) -> bool: - # why: PyTorch SDPA dispatches to flash/memory-efficient O(n) backends; only - # eager (and other non-flash impls) need the quadratic correction. - return attention_implementation in LINEAR_ATTENTION_IMPLS - - -def _compute_non_flash_attention_bytes( - arch: ModelArchConfig, - batch_size: int, - seq_len: int, - effective_layers: float, -) -> int: - score_elements = batch_size * arch.num_attention_heads * seq_len * seq_len - return int(score_elements * 2 * NON_FLASH_ATTENTION_FACTOR * effective_layers) - - -def _layer_qkv_mlp_sizes(arch: ModelArchConfig, layer_idx: int) -> tuple: - n_experts = _get_num_experts(arch) - is_moe_layer = n_experts > 1 and not _is_dense_mlp_layer(arch, layer_idx) - if _uses_structured_layer_shapes(arch): - q_size, kv_size, _has_k, _has_v = _layer_attention_dims(arch, layer_idx) - # why: KV-shared layers (Gemma4/Gemma3n) drop k_proj/v_proj WEIGHTS but - # the donor layer's K/V tensors stay alive across the shared range, so - # activation memory still pays for kv_size; only the weight path uses - # has_k/has_v. - layer_type = _layer_types(arch)[layer_idx] - use_alt_attention = arch.attention_k_eq_v and layer_type != "sliding_attention" - kv_count = 1 if use_alt_attention else 2 - qkv_size = q_size + kv_size * kv_count - if is_moe_layer: - # why: each token routes through `num_experts_per_tok` experts; their - # gate/up/down intermediates are all live during MLP forward. - mlp_size = _get_mlp_size(arch) * arch.num_experts_per_tok - if arch.n_shared_experts: - mlp_size += _shared_expert_size(arch) * arch.n_shared_experts - if arch.moe_has_dense_mlp: - mlp_size += _layer_mlp_size(arch, layer_idx) - else: - mlp_size = _layer_mlp_size(arch, layer_idx) - return qkv_size, mlp_size - kv_size = _get_kv_size(arch) - if is_moe_layer: - mlp_size = _get_mlp_size(arch) * arch.num_experts_per_tok - if arch.n_shared_experts: - mlp_size += _shared_expert_size(arch) * arch.n_shared_experts - if arch.moe_has_dense_mlp: - mlp_size += arch.intermediate_size - else: - mlp_size = _get_mlp_size(arch) - return arch.hidden_size + kv_size + kv_size, mlp_size - - -def _per_layer_activation_bytes( - arch: ModelArchConfig, - layer_idx: int, - batch_size: int, - seq_len: int, -) -> int: - qkv_size, mlp_size = _layer_qkv_mlp_sizes(arch, layer_idx) - activation_qkv = seq_len * batch_size * qkv_size - residual_memory = (seq_len * batch_size) * 2 - activation_mlp = seq_len * batch_size * (mlp_size + mlp_size) - # why: per_layer_input_gate (hd-sized) and per_layer_projection (pli-sized) - # outputs materialize once per decoder layer when hidden_size_per_layer_input - # is set; see gemma4/modular_gemma4.py:1141-1145. - pli = arch.hidden_size_per_layer_input - activation_ple = seq_len * batch_size * (arch.hidden_size + pli) if pli > 0 else 0 - return int( - (activation_qkv + residual_memory + activation_mlp + activation_ple) * 2 * 1.25 - ) - - def compute_activation_bytes( arch: ModelArchConfig, batch_size: int, seq_len: int, gradient_checkpointing: str, is_lora: bool = False, - attention_implementation: Optional[str] = "flash_attention_2", ) -> int: + hd = arch.hidden_size + kv_size = _get_kv_size(arch) + mlp_size = _get_mlp_size(arch) + bsz = batch_size n_layers = arch.num_hidden_layers + activation_qkv = seq_len * bsz * (hd + kv_size + kv_size) + residual_memory = (seq_len * bsz) * 2 + activation_mlp = seq_len * bsz * (mlp_size + mlp_size) + + per_layer_bytes = (activation_qkv + residual_memory + activation_mlp) * 2 + per_layer_bytes = int(per_layer_bytes * 1.25) + gc_key = gradient_checkpointing.lower() gc_entry = GC_LAYER_MULTIPLIERS.get(gc_key, (None, None)) full_ft_mult, lora_mult = gc_entry @@ -1226,35 +446,10 @@ def compute_activation_bytes( if gc_multiplier is None: effective_layers = n_layers - linear_bytes = sum( - _per_layer_activation_bytes(arch, i, batch_size, seq_len) - for i in range(n_layers) - ) else: effective_layers = gc_multiplier - max_layer_bytes = max( - _per_layer_activation_bytes(arch, i, batch_size, seq_len) - for i in range(n_layers) - ) - linear_bytes = int(max_layer_bytes * effective_layers) - - # why: gemma4 per_layer_model_projection runs once outside the per-decoder - # loop and materializes a [B, S, L, PLI] tensor; see modular_gemma4.py:1247. - pli = arch.hidden_size_per_layer_input - if pli > 0: - linear_bytes += int(seq_len * batch_size * n_layers * pli * 2 * 1.25) - - if _is_linear_attention(attention_implementation): - return linear_bytes - return max( - linear_bytes, - _compute_non_flash_attention_bytes( - arch, - batch_size, - seq_len, - effective_layers, - ), - ) + + return int(per_layer_bytes * effective_layers) def estimate_training_vram( @@ -1279,23 +474,21 @@ def estimate_training_vram( trainable_params = lora_params if is_lora else compute_total_params(arch) optimizer_bytes = compute_optimizer_bytes(trainable_params, config.optimizer) + gradient_bytes = max( + compute_gradient_bytes(trainable_params), + int(model_weights * 0.15), + ) activations_computed = compute_activation_bytes( arch, config.batch_size, config.max_seq_length, config.gradient_checkpointing, is_lora = is_lora, - attention_implementation = config.attention_implementation, ) - raw_gradient_bytes = compute_gradient_bytes(trainable_params) - gradient_floor = int(model_weights * 0.15) - if is_lora: - gradient_floor = min( - gradient_floor, - max(activations_computed, optimizer_bytes), - ) - gradient_bytes = max(raw_gradient_bytes, gradient_floor) - activation_bytes = activations_computed + activation_bytes = max( + activations_computed, + int(model_weights * 0.15 * (config.batch_size / 2)), + ) return VramBreakdown( model_weights = model_weights, diff --git a/studio/backend/utils/models/__init__.py b/studio/backend/utils/models/__init__.py index 808e2b012e..a81682d8b7 100644 --- a/studio/backend/utils/models/__init__.py +++ b/studio/backend/utils/models/__init__.py @@ -13,9 +13,8 @@ detect_audio_type, is_audio_input_type, VALID_AUDIO_TYPES, - scan_trained_models, + scan_trained_loras, scan_exported_models, - get_base_model_from_checkpoint, load_model_defaults, get_base_model_from_lora, load_model_config, @@ -26,8 +25,6 @@ ) from .checkpoints import scan_checkpoints -scan_trained_loras = scan_trained_models - __all__ = [ "ModelConfig", "GgufVariantInfo", @@ -36,10 +33,8 @@ "detect_audio_type", "is_audio_input_type", "VALID_AUDIO_TYPES", - "scan_trained_models", "scan_trained_loras", "scan_exported_models", - "get_base_model_from_checkpoint", "load_model_defaults", "get_base_model_from_lora", "load_model_config", diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index 16f6d21edb..5be0b183d5 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -32,11 +32,6 @@ import yaml -from utils.native_path_leases import child_env_without_native_path_secret -from utils.subprocess_compat import ( - windows_hidden_subprocess_kwargs as _windows_hidden_subprocess_kwargs, -) - logger = get_logger(__name__) # ── Model size extraction ──────────────────────────────────── @@ -584,8 +579,6 @@ def _is_vision_model_subprocess( capture_output = True, text = True, timeout = 60, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) if result.returncode != 0: @@ -911,99 +904,24 @@ def _is_mmproj(filename: str) -> bool: return "mmproj" in filename.lower() -def _is_gguf_filename(filename: str) -> bool: - return filename.lower().endswith(".gguf") - - -def _iter_gguf_files(directory: Path, recursive: bool = False): - if not directory.is_dir(): - return - iterator = directory.rglob("*") if recursive else directory.iterdir() - for f in iterator: - if f.is_file() and _is_gguf_filename(f.name): - yield f - - -def detect_mmproj_file(path: str, search_root: Optional[str] = None) -> Optional[str]: +def detect_mmproj_file(path: str) -> Optional[str]: """ - Find the mmproj (vision projection) GGUF file for a given model. + Find the mmproj (vision projection) GGUF file in a directory. Args: - path: Directory to search — or a .gguf file (uses its parent dir - as the starting point). - search_root: Optional outer directory that should also be scanned - (and any directory between it and ``path``). This handles - local layouts where the model weights live in a quant-named - subdir (``snapshot/BF16/foo.gguf``) but the mmproj sits at - the snapshot root (``snapshot/mmproj-BF16.gguf``). When - ``None``, only the immediate parent dir is scanned, matching - the historical behavior. + path: Directory to search — or a .gguf file (uses its parent dir). Returns: Full path to the mmproj .gguf file, or None if not found. """ p = Path(path) - start_dir = p.parent if p.is_file() else p - if not start_dir.is_dir(): + search_dir = p.parent if p.is_file() else p + if not search_dir.is_dir(): return None - # Build the list of dirs to scan: immediate dir first, then walk up - # to (and including) ``search_root`` if it is an ancestor. We walk - # incrementally rather than recursing into ``search_root`` so we - # don't accidentally pick up an mmproj from a sibling subdir - # belonging to a different model variant. - seen: set[Path] = set() - scan_order: list[Path] = [] - - def _add(d: Path) -> None: - try: - resolved = d.resolve() - except OSError: - return - if resolved in seen or not resolved.is_dir(): - return - seen.add(resolved) - scan_order.append(resolved) - - _add(start_dir) - - # When ``path`` is a symlink (e.g. Ollama's ``.studio_links/...gguf`` - # -> ``blobs/sha256-...``), the symlink's parent directory rarely - # contains the mmproj sibling; the real mmproj file lives next to - # the symlink target. Add the target's parent to the scan so vision - # GGUFs that are surfaced via symlinks are still recognised as - # vision models. - try: - if p.is_symlink() and p.is_file(): - target_parent = p.resolve().parent - if target_parent.is_dir(): - _add(target_parent) - except OSError: - pass - if search_root is not None: - try: - root_resolved = Path(search_root).resolve() - start_resolved = start_dir.resolve() - # Only walk if start_dir is inside (or equal to) search_root. - if root_resolved == start_resolved or ( - start_resolved.is_relative_to(root_resolved) - if hasattr(start_resolved, "is_relative_to") - else str(start_resolved).startswith(str(root_resolved) + "/") - ): - cur = start_resolved - # Walk up from start_dir to (and including) root_resolved. - while cur != root_resolved and cur.parent != cur: - cur = cur.parent - _add(cur) - if cur == root_resolved: - break - except OSError: - pass - - for d in scan_order: - for f in _iter_gguf_files(d): - if _is_mmproj(f.name): - return str(f.resolve()) + for f in search_dir.glob("*.gguf"): + if _is_mmproj(f.name): + return str(f.resolve()) return None @@ -1024,18 +942,15 @@ def detect_gguf_model(path: str) -> Optional[str]: p = Path(path) # Case 1: direct .gguf file - if p.suffix.lower() == ".gguf" and p.is_file(): + if p.suffix == ".gguf" and p.is_file(): if _is_mmproj(p.name): return None - # Use absolute (not resolve) to preserve symlink names -- e.g. - # Ollama .studio_links/model.gguf -> blobs/sha256-... should - # keep the readable symlink name, not the opaque blob hash. - return str(p.absolute()) + return str(p.resolve()) # Case 2: directory containing .gguf files (skip mmproj) if p.is_dir(): gguf_files = sorted( - (f for f in _iter_gguf_files(p) if not _is_mmproj(f.name)), + (f for f in p.glob("*.gguf") if not _is_mmproj(f.name)), key = lambda f: f.stat().st_size, reverse = True, ) @@ -1100,7 +1015,7 @@ def _pick_best_gguf(filenames: list[str]) -> Optional[str]: Prefers quantization levels in _GGUF_QUANT_PREFERENCE order. Falls back to the first .gguf file found. """ - gguf_files = [f for f in filenames if f.lower().endswith(".gguf")] + gguf_files = [f for f in filenames if f.endswith(".gguf")] if not gguf_files: return None @@ -1185,7 +1100,7 @@ def list_gguf_variants( for sibling in info.siblings: fname = sibling.rfilename - if not fname.lower().endswith(".gguf"): + if not fname.endswith(".gguf"): continue size = sibling.size or 0 @@ -1229,11 +1144,9 @@ def _resolve_gguf_dir(p: Path) -> Optional[Path]: return p if p.is_file() and p.suffix.lower() == ".gguf": parent = p.parent - if ( - (parent / "config.json").exists() - or (parent / "adapter_config.json").exists() - or (parent / "export_metadata.json").exists() - ): + if (parent / "config.json").exists() or ( + parent / "adapter_config.json" + ).exists(): return parent return None @@ -1258,11 +1171,7 @@ def list_local_gguf_variants( quant_first_file: dict[str, str] = {} has_vision = False - # Recurse so variant-specific subdirectories (e.g. ``BF16/...gguf`` - # used by some HF GGUF repos for the largest quants) are picked up. - # Filenames in the result preserve the relative subpath so that - # ``_find_local_gguf_by_variant`` can locate the file again. - for f in sorted(_iter_gguf_files(p, recursive = True)): + for f in sorted(p.glob("*.gguf")): if _is_mmproj(f.name): has_vision = True continue @@ -1272,14 +1181,8 @@ def list_local_gguf_variants( size = 0 quant = _extract_quant_label(f.name) quant_totals[quant] = quant_totals.get(quant, 0) + size - # Only compute the (potentially expensive) relative path when this - # is the first file we've seen for this quant -- after that we'd - # discard the result anyway. Use posix-style separators so the - # filename matches what ``list_gguf_variants`` (the remote HF - # API path) returns on every platform; otherwise Windows would - # emit ``BF16\foo.gguf`` here. if quant not in quant_first_file: - quant_first_file[quant] = f.relative_to(p).as_posix() + quant_first_file[quant] = f.name variants = [ GgufVariantInfo( @@ -1305,11 +1208,9 @@ def _find_local_gguf_by_variant(directory: str, variant: str) -> Optional[str]: if p is None: return None - # Recurse into subdirectories so variants stored under a quant-named - # subdir (e.g. ``BF16/foo-BF16-00001-of-00002.gguf``) are found. matches = sorted( f - for f in _iter_gguf_files(p, recursive = True) + for f in p.glob("*.gguf") if not _is_mmproj(f.name) and _extract_quant_label(f.name) == variant ) if matches: @@ -1421,89 +1322,46 @@ def is_embedding_model(model_name: str, hf_token: Optional[str] = None) -> bool: return False -def _has_model_weight_files(model_dir: Path) -> bool: - """Return True when a directory contains loadable model weights.""" - for item in model_dir.iterdir(): - if not item.is_file(): - continue - - suffix = item.suffix.lower() - if suffix == ".safetensors": - return True - if suffix == ".gguf": - return "mmproj" not in item.name.lower() - if suffix == ".bin": - name = item.name.lower() - if ( - name.startswith("pytorch_model") - or name.startswith("model") - or name.startswith("adapter_model") - or name.startswith("consolidated") - ): - return True - return False - - -def _detect_training_output_type(model_dir: Path) -> Optional[str]: - """Classify a Studio training output as LoRA or full finetune.""" - adapter_config = model_dir / "adapter_config.json" - adapter_model = model_dir / "adapter_model.safetensors" - if adapter_config.exists() or adapter_model.exists(): - return "lora" - - config_file = model_dir / "config.json" - if config_file.exists() and _has_model_weight_files(model_dir): - return "merged" - - return None - - -def _looks_like_lora_adapter(model_dir: Path) -> bool: - return model_dir.is_dir() and ( - (model_dir / "adapter_config.json").exists() - or any(model_dir.glob("adapter_model*.safetensors")) - or any(model_dir.glob("adapter_model*.bin")) - ) - - -def scan_trained_models( - outputs_dir: str = str(outputs_root()), -) -> List[Tuple[str, str, str]]: +def scan_trained_loras(outputs_dir: str = str(outputs_root())) -> List[Tuple[str, str]]: """ - Scan outputs folder for trained Studio models. + Scan outputs folder for trained LoRA adapters. Returns: - List of tuples: [(display_name, model_path, model_type), ...] - model_type is "lora" for adapter runs and "merged" for full finetunes. + List of tuples: [(display_name, adapter_path), ...] + + Example: + [ + ("unsloth_Meta-Llama-3.1_...", "./outputs/unsloth_Meta-Llama-3.1_.../"), + ("my_finetuned_model", "./outputs/my_finetuned_model/"), + ] """ - trained_models = [] + trained_loras = [] outputs_path = resolve_output_dir(outputs_dir) if not outputs_path.exists(): logger.warning(f"Outputs directory not found: {outputs_dir}") - return trained_models + return trained_loras try: for item in outputs_path.iterdir(): if item.is_dir(): - model_type = _detect_training_output_type(item) - if model_type is None: - continue + # Check if this directory contains a LoRA adapter + adapter_config = item / "adapter_config.json" + adapter_model = item / "adapter_model.safetensors" - display_name = item.name - model_path = str(item) - trained_models.append((display_name, model_path, model_type)) - logger.debug("Found trained model: %s (%s)", display_name, model_type) + if adapter_config.exists() or adapter_model.exists(): + display_name = item.name + adapter_path = str(item) + trained_loras.append((display_name, adapter_path)) + logger.debug(f"Found trained LoRA: {display_name}") # Sort by modification time (newest first) - trained_models.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True) + trained_loras.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True) logger.info( - "Found %s trained models in %s", - len(trained_models), - outputs_dir, + f"Found {len(trained_loras)} trained LoRA adapters in {outputs_dir}" ) - return trained_models + return trained_loras except Exception as e: logger.error(f"Error scanning outputs folder: {e}") @@ -1537,9 +1395,7 @@ def scan_exported_models( # Check for flat GGUF export (e.g. exports/gemma-3-4b-it-finetune-gguf/) # Filter out mmproj (vision projection) files — they aren't loadable as main models - gguf_files = [ - f for f in _iter_gguf_files(run_dir) if not _is_mmproj(f.name) - ] + gguf_files = [f for f in run_dir.glob("*.gguf") if not _is_mmproj(f.name)] if gguf_files: base_model = None export_meta = run_dir / "export_metadata.json" @@ -1566,7 +1422,7 @@ def scan_exported_models( has_weights = any(checkpoint_dir.glob("*.safetensors")) or any( checkpoint_dir.glob("*.bin") ) - has_gguf = any(_iter_gguf_files(checkpoint_dir)) + has_gguf = any(checkpoint_dir.glob("*.gguf")) base_model = None export_type = None @@ -1589,7 +1445,7 @@ def scan_exported_models( pass elif has_gguf: export_type = "gguf" - gguf_list = list(_iter_gguf_files(checkpoint_dir)) + gguf_list = list(checkpoint_dir.glob("*.gguf")) # Check checkpoint_dir first, then fall back to parent run_dir # (export.py writes metadata to the top-level export directory) for meta_dir in (checkpoint_dir, run_dir): @@ -1638,68 +1494,6 @@ def scan_exported_models( return [] -def get_base_model_from_checkpoint(checkpoint_path: str) -> Optional[str]: - """Read the base model name from a local training or checkpoint directory.""" - try: - checkpoint_path_obj = Path(checkpoint_path) - - adapter_config_path = checkpoint_path_obj / "adapter_config.json" - if adapter_config_path.exists(): - with open(adapter_config_path, "r") as f: - config = json.load(f) - base_model = config.get("base_model_name_or_path") - if base_model: - logger.info( - "Detected base model from adapter_config.json: %s", base_model - ) - return base_model - - config_path = checkpoint_path_obj / "config.json" - if config_path.exists(): - with open(config_path, "r") as f: - config = json.load(f) - for key in ("model_name", "_name_or_path"): - base_model = config.get(key) - if base_model and str(base_model) != str(checkpoint_path_obj): - logger.info( - "Detected base model from config.json (%s): %s", - key, - base_model, - ) - return base_model - - training_args_path = checkpoint_path_obj / "training_args.bin" - if training_args_path.exists(): - try: - import torch - - training_args = torch.load(training_args_path) - if hasattr(training_args, "model_name_or_path"): - base_model = training_args.model_name_or_path - logger.info( - "Detected base model from training_args.bin: %s", base_model - ) - return base_model - except Exception as e: - logger.warning(f"Could not load training_args.bin: {e}") - - dir_name = checkpoint_path_obj.name - if dir_name.startswith("unsloth_"): - parts = dir_name.split("_") - if len(parts) >= 2: - model_parts = parts[1:-1] - base_model = "unsloth/" + "_".join(model_parts) - logger.info("Detected base model from directory name: %s", base_model) - return base_model - - logger.warning(f"Could not detect base model for checkpoint: {checkpoint_path}") - return None - - except Exception as e: - logger.error(f"Error reading base model from checkpoint config: {e}") - return None - - def get_base_model_from_lora(lora_path: str) -> Optional[str]: """ Read the base model name from a LoRA adapter's config. @@ -1708,14 +1502,16 @@ def get_base_model_from_lora(lora_path: str) -> Optional[str]: lora_path: Path to the LoRA adapter directory Returns: - Base model identifier or None if not found + Base model identifier (e.g., "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit") + or None if not found + + Example: + >>> get_base_model_from_lora("./outputs/unsloth_Meta-Llama-3.1_.../") + "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" """ try: lora_path_obj = Path(lora_path) - if not _looks_like_lora_adapter(lora_path_obj): - return None - # Try adapter_config.json first adapter_config_path = lora_path_obj / "adapter_config.json" if adapter_config_path.exists(): @@ -2019,16 +1815,8 @@ def from_identifier( except Exception as e: logger.debug(f"Could not read export metadata: {e}") - # If vision (or mmproj happens to exist), find the mmproj - # file. The recursive variant scan in - # ``_find_local_gguf_by_variant`` may have returned a - # weight file inside a quant-named subdir (e.g. - # ``.../BF16/foo.gguf``) while ``mmproj-*.gguf`` lives - # at the snapshot root. Pass ``search_root=path`` so - # ``detect_mmproj_file`` walks up to the snapshot root - # instead of seeing only the weight file's immediate - # parent. - mmproj_file = detect_mmproj_file(gguf_file, search_root = path) + # If vision (or mmproj happens to exist), find the mmproj file + mmproj_file = detect_mmproj_file(gguf_file) if mmproj_file: gguf_is_vision = True logger.info(f"Detected mmproj for vision: {mmproj_file}") @@ -2096,11 +1884,7 @@ def from_identifier( # Auto-detect LoRA for local paths (check adapter_config.json on disk) if not is_lora and is_local: - detected_base = ( - get_base_model_from_lora(path) - if _looks_like_lora_adapter(Path(path)) - else None - ) + detected_base = get_base_model_from_lora(path) if detected_base: is_lora = True logger.info( diff --git a/studio/backend/utils/native_path_leases.py b/studio/backend/utils/native_path_leases.py deleted file mode 100644 index a69dfab532..0000000000 --- a/studio/backend/utils/native_path_leases.py +++ /dev/null @@ -1,406 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Verification for Tauri native path signed grants. - -Rust signs compact ``base64url(payload_json).base64url(hmac)`` grants. The -frontend can see and forward the grant, but cannot change it without breaking -the HMAC. The backend verifies the original payload segment bytes, then -re-stats the path before any native read. -""" - -from __future__ import annotations - -import base64 -import binascii -import hashlib -import hmac -import json -import os -import stat as _stat_module -import threading -import time -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Iterable, Iterator, Mapping - -LEASE_SECRET_ENV = "UNSLOTH_STUDIO_NATIVE_PATH_LEASE_SECRET" -_MAX_NATIVE_PATH_REDACTIONS = 100 -_MAX_NATIVE_PATH_LABELS = 10_000 -_MIN_LEASE_SECRET_BYTES = 32 - -_REPLAY_LOCK = threading.Lock() -_USED_NONCES: dict[str, int] = {} -_REDACTION_LOCK = threading.Lock() -_NATIVE_PATH_REDACTIONS: list[str] = [] -_NATIVE_PATH_LABELS: dict[str, str] = {} -_NATIVE_PATH_ENV_LOCK = threading.Lock() -_SECRET_INIT_LOCK = threading.Lock() -_CACHED_LEASE_SECRET: bytes | None = None -_SCRUB_REFCOUNT = 0 -_SCRUB_SAVED_SECRET: str | None = None - - -class NativePathLeaseError(ValueError): - """Raised when a native path grant is missing, invalid, or unsafe.""" - - -@dataclass(frozen = True) -class NativePathGrant: - operation: str - canonical_path: Path - path_kind: str - path_type: str - source_kind: str - token_id_hash: str - display_label: str - expires_at_ms: int - size_bytes: int | None - modified_ms: int | None - - -def native_path_leases_supported() -> bool: - try: - _decode_secret() - except NativePathLeaseError: - return False - return True - - -def child_env_without_native_path_secret( - env: Mapping[str, str] | None = None, -) -> dict[str, str]: - """Return a child-process env with the native path lease secret removed.""" - - if env is None: - with _NATIVE_PATH_ENV_LOCK: - cleaned = dict(os.environ) - else: - cleaned = dict(env) - cleaned.pop(LEASE_SECRET_ENV, None) - return cleaned - - -def run_without_native_path_secret( - target: Callable[..., Any], - *args: Any, - **kwargs: Any, -) -> Any: - """Run a multiprocessing child target without the native path lease secret.""" - - global _CACHED_LEASE_SECRET, _SCRUB_SAVED_SECRET - os.environ.pop(LEASE_SECRET_ENV, None) - _CACHED_LEASE_SECRET = None - _SCRUB_SAVED_SECRET = None - return target(*args, **kwargs) - - -@contextmanager -def native_path_secret_removed_for_child_start() -> Iterator[None]: - global _SCRUB_REFCOUNT, _SCRUB_SAVED_SECRET, _CACHED_LEASE_SECRET - with _NATIVE_PATH_ENV_LOCK: - if _SCRUB_REFCOUNT == 0: - _SCRUB_SAVED_SECRET = os.environ.pop(LEASE_SECRET_ENV, None) - _CACHED_LEASE_SECRET = None - _SCRUB_REFCOUNT += 1 - try: - yield - finally: - with _NATIVE_PATH_ENV_LOCK: - _SCRUB_REFCOUNT -= 1 - if _SCRUB_REFCOUNT == 0 and _SCRUB_SAVED_SECRET is not None: - os.environ[LEASE_SECRET_ENV] = _SCRUB_SAVED_SECRET - _SCRUB_SAVED_SECRET = None - - -def verify_native_path_lease( - lease: str | None, - *, - operation: str, - expected_kind: str | None = None, - expected_path_type: str | None = None, - allowed_suffixes: Iterable[str] | None = None, -) -> NativePathGrant: - if not lease: - raise NativePathLeaseError("Native path grant is required.") - - secret = _decode_secret() - payload_b64, signature_b64 = _split_lease(lease) - expected_signature = hmac.new( - secret, - payload_b64.encode("ascii"), - hashlib.sha256, - ).digest() - supplied_signature = _b64decode(signature_b64) - if not hmac.compare_digest(expected_signature, supplied_signature): - raise NativePathLeaseError("Native path grant signature is invalid.") - - payload = _decode_payload(payload_b64) - _validate_payload(payload, operation = operation, expected_kind = expected_kind) - - path = Path(str(payload["canonical_path"])) - _reject_network_or_device_path(path) - try: - signed_lstat = os.lstat(path) - except OSError as exc: - raise NativePathLeaseError("Native path is no longer accessible.") from exc - if _stat_module.S_ISLNK(signed_lstat.st_mode): - raise NativePathLeaseError("Native path is no longer a regular file.") - try: - resolved = path.resolve(strict = True) - except OSError as exc: - raise NativePathLeaseError("Native path is no longer accessible.") from exc - _reject_network_or_device_path(resolved) - if not _same_native_path(resolved, path): - raise NativePathLeaseError( - "Native path grant no longer resolves to the selected path." - ) - - grant = NativePathGrant( - operation = str(payload["operation"]), - canonical_path = resolved, - path_kind = str(payload["path_kind"]), - path_type = str(payload["path_type"]), - source_kind = str(payload["source_kind"]), - token_id_hash = str(payload["token_id_hash"]), - display_label = str(payload.get("display_label") or resolved.name), - expires_at_ms = _required_int(payload, "expires_at_ms"), - size_bytes = _optional_int(payload.get("size_bytes")), - modified_ms = _optional_int(payload.get("modified_ms")), - ) - - if expected_path_type and grant.path_type != expected_path_type: - raise NativePathLeaseError("Native path grant has the wrong path type.") - suffixes = tuple(s.lower() for s in (allowed_suffixes or ())) - if suffixes and resolved.suffix.lower() not in suffixes: - raise NativePathLeaseError("Native path grant has an unsupported file type.") - - _validate_current_stat(grant) - _consume_nonce(str(payload["nonce"]), grant.expires_at_ms) - _remember_native_path_for_redaction(str(resolved), grant.display_label) - return grant - - -def display_label_for_native_path(value: str | None) -> str | None: - if not value: - return value - with _REDACTION_LOCK: - return _NATIVE_PATH_LABELS.get(value, value) - - -def is_registered_native_path_label(path_value: str | None, label: str | None) -> bool: - if not path_value or not label: - return False - with _REDACTION_LOCK: - return _NATIVE_PATH_LABELS.get(path_value) == label - - -def redact_native_paths(value: str) -> str: - with _REDACTION_LOCK: - paths = sorted(_NATIVE_PATH_REDACTIONS, key = len, reverse = True) - redacted = value - for path in paths: - for variant in {path, path.replace("/", "\\"), path.replace("\\", "/")}: - if variant: - redacted = redacted.replace(variant, "") - return redacted - - -def _decode_secret() -> bytes: - global _CACHED_LEASE_SECRET - if _CACHED_LEASE_SECRET is not None: - return _CACHED_LEASE_SECRET - with _SECRET_INIT_LOCK: - if _CACHED_LEASE_SECRET is not None: - return _CACHED_LEASE_SECRET - with _NATIVE_PATH_ENV_LOCK: - encoded = os.environ.get(LEASE_SECRET_ENV) - if encoded is None and _SCRUB_SAVED_SECRET is not None: - encoded = _SCRUB_SAVED_SECRET - if not encoded: - raise NativePathLeaseError( - "Native path grants require the managed desktop backend." - ) - try: - secret = _b64decode(encoded) - except Exception as exc: - raise NativePathLeaseError("Native path grant secret is invalid.") from exc - if len(secret) < _MIN_LEASE_SECRET_BYTES: - raise NativePathLeaseError("Native path grant secret is invalid.") - _CACHED_LEASE_SECRET = secret - return secret - - -def _split_lease(lease: str) -> tuple[str, str]: - if not isinstance(lease, str): - raise NativePathLeaseError("Native path grant has an invalid format.") - try: - lease.encode("ascii") - except UnicodeEncodeError as exc: - raise NativePathLeaseError("Native path grant has an invalid format.") from exc - parts = lease.split(".") - if len(parts) != 2 or not parts[0] or not parts[1]: - raise NativePathLeaseError("Native path grant has an invalid format.") - return parts[0], parts[1] - - -def _decode_payload(payload_b64: str) -> dict[str, Any]: - try: - payload = json.loads(_b64decode(payload_b64).decode("utf-8")) - except Exception as exc: - raise NativePathLeaseError("Native path grant payload is invalid.") from exc - if not isinstance(payload, dict): - raise NativePathLeaseError("Native path grant payload is invalid.") - return payload - - -def _validate_payload( - payload: dict[str, Any], *, operation: str, expected_kind: str | None -) -> None: - required = ( - "version", - "operation", - "canonical_path", - "path_kind", - "path_type", - "source_kind", - "token_id_hash", - "issued_at_ms", - "expires_at_ms", - "nonce", - ) - missing = [key for key in required if key not in payload] - if missing: - raise NativePathLeaseError( - "Native path grant payload is missing required fields." - ) - if _required_int(payload, "version") != 1: - raise NativePathLeaseError("Native path grant version is unsupported.") - if payload["operation"] != operation: - raise NativePathLeaseError("Native path grant operation is invalid.") - if expected_kind and payload["path_kind"] != expected_kind: - raise NativePathLeaseError("Native path grant kind is invalid.") - now_ms = int(time.time() * 1000) - issued_at_ms = _required_int(payload, "issued_at_ms") - expires_at_ms = _required_int(payload, "expires_at_ms") - if issued_at_ms >= expires_at_ms: - raise NativePathLeaseError("Native path grant timestamps are inconsistent.") - if expires_at_ms <= now_ms: - raise NativePathLeaseError("Native path grant has expired.") - if issued_at_ms > now_ms + 30_000: - raise NativePathLeaseError("Native path grant issue time is invalid.") - for key in ("canonical_path", "nonce", "token_id_hash", "display_label"): - raw = payload.get(key) - if raw is None: - continue - if "\x00" in str(raw): - raise NativePathLeaseError("Native path grant contains invalid characters.") - - -def _validate_current_stat(grant: NativePathGrant) -> None: - try: - st = os.lstat(grant.canonical_path) - except OSError as exc: - raise NativePathLeaseError("Native path is no longer accessible.") from exc - if _stat_module.S_ISLNK(st.st_mode): - raise NativePathLeaseError("Native path is no longer a regular file.") - if grant.path_type == "file": - if not _stat_module.S_ISREG(st.st_mode): - raise NativePathLeaseError("Native path is no longer a regular file.") - elif grant.path_type == "directory": - if not _stat_module.S_ISDIR(st.st_mode): - raise NativePathLeaseError("Native path is no longer a directory.") - else: - raise NativePathLeaseError("Native path grant has an unsupported path type.") - - if grant.size_bytes is not None and st.st_size != grant.size_bytes: - raise NativePathLeaseError("Native path changed after it was selected.") - current_modified_ms = int(st.st_mtime_ns // 1_000_000) - if grant.modified_ms is not None and current_modified_ms != grant.modified_ms: - raise NativePathLeaseError("Native path changed after it was selected.") - - -def _consume_nonce(nonce: str, expires_at_ms: int) -> None: - now_ms = int(time.time() * 1000) - with _REPLAY_LOCK: - for key, expiry in list(_USED_NONCES.items()): - if expiry <= now_ms: - _USED_NONCES.pop(key, None) - if nonce in _USED_NONCES: - raise NativePathLeaseError("Native path grant was already used.") - _USED_NONCES[nonce] = expires_at_ms - - -def _remember_native_path_for_redaction(path: str, display_label: str) -> None: - with _REDACTION_LOCK: - _NATIVE_PATH_LABELS[path] = display_label - if len(_NATIVE_PATH_LABELS) > _MAX_NATIVE_PATH_LABELS: - excess = len(_NATIVE_PATH_LABELS) - _MAX_NATIVE_PATH_LABELS - for stale_path in list(_NATIVE_PATH_LABELS.keys())[:excess]: - _NATIVE_PATH_LABELS.pop(stale_path, None) - if path in _NATIVE_PATH_REDACTIONS: - return - _NATIVE_PATH_REDACTIONS.append(path) - del _NATIVE_PATH_REDACTIONS[:-_MAX_NATIVE_PATH_REDACTIONS] - - -def _reject_network_or_device_path(path: Path) -> None: - text = str(path) - if os.name == "nt": - normalized = text.replace("/", "\\").lower() - if normalized.startswith("\\\\?\\"): - rest = normalized[4:] - is_local_drive = len(rest) >= 3 and rest[0].isalpha() and rest[1:3] == ":\\" - if not is_local_drive: - raise NativePathLeaseError( - "Network paths are not supported for native grants." - ) - elif normalized.startswith("\\\\"): - raise NativePathLeaseError( - "Network paths are not supported for native grants." - ) - if os.name != "nt": - for root in ("/dev", "/proc", "/sys"): - if path.is_relative_to(root): - raise NativePathLeaseError( - "Device and virtual filesystem paths are not supported." - ) - if "\x00" in text: - raise NativePathLeaseError("Native path contains invalid characters.") - - -def _b64decode(value: str) -> bytes: - try: - padding = "=" * (-len(value) % 4) - return base64.urlsafe_b64decode((value + padding).encode("ascii")) - except (UnicodeEncodeError, binascii.Error, ValueError) as exc: - raise NativePathLeaseError("Native path grant has an invalid format.") from exc - - -def _same_native_path(resolved: Path, signed: Path) -> bool: - try: - return resolved.samefile(signed) - except OSError: - return os.path.normcase(str(resolved)) == os.path.normcase(str(signed)) - - -def _optional_int(value: Any) -> int | None: - if value is None: - return None - try: - return int(value) - except (TypeError, ValueError) as exc: - raise NativePathLeaseError("Native path grant payload is invalid.") from exc - - -def _required_int(payload: dict[str, Any], key: str) -> int: - raw = payload.get(key) - if raw is None: - raise NativePathLeaseError( - "Native path grant payload is missing required fields." - ) - try: - return int(raw) - except (TypeError, ValueError) as exc: - raise NativePathLeaseError("Native path grant payload is invalid.") from exc diff --git a/studio/backend/utils/paths/__init__.py b/studio/backend/utils/paths/__init__.py index 92191dccdd..11709ae56e 100644 --- a/studio/backend/utils/paths/__init__.py +++ b/studio/backend/utils/paths/__init__.py @@ -34,7 +34,6 @@ legacy_hf_cache_dir, hf_default_cache_dir, lmstudio_model_dirs, - well_known_model_dirs, ensure_dir, ensure_studio_directories, resolve_under_root, @@ -71,7 +70,6 @@ "legacy_hf_cache_dir", "hf_default_cache_dir", "lmstudio_model_dirs", - "well_known_model_dirs", "ensure_dir", "ensure_studio_directories", "resolve_under_root", diff --git a/studio/backend/utils/paths/storage_roots.py b/studio/backend/utils/paths/storage_roots.py index b52609b06b..4841c5d0a3 100644 --- a/studio/backend/utils/paths/storage_roots.py +++ b/studio/backend/utils/paths/storage_roots.py @@ -130,51 +130,6 @@ def _add(p: Path) -> None: return dirs -def well_known_model_dirs() -> list[Path]: - """Return directories commonly used by other local LLM tools. - - Used by the folder browser to offer quick-pick chips. Returns only - paths that exist on disk, so the UI never shows dead chips. Order - reflects a rough "likelihood the user has models here" -- LM Studio - and Ollama first, then the generic fallbacks. - """ - candidates: list[Path] = [] - - # LM Studio (reuses the logic above, including settings.json override) - candidates.extend(lmstudio_model_dirs()) - - # Ollama -- both the user-level and common system-wide install paths - # (https://github.com/ollama/ollama/issues/733). - ollama_env = os.environ.get("OLLAMA_MODELS") - if ollama_env: - candidates.append(Path(ollama_env).expanduser()) - candidates.append(Path.home() / ".ollama" / "models") - candidates.append(Path("/usr/share/ollama/.ollama/models")) - candidates.append(Path("/var/lib/ollama/.ollama/models")) - - # HF hub cache root (separate from the explicit HF cache chip) - candidates.append(Path.home() / ".cache" / "huggingface" / "hub") - - # Generic "my models" spots users tend to drop things into - for name in ("models", "Models"): - candidates.append(Path.home() / name) - - # Deduplicate while preserving order; keep only extant dirs - out: list[Path] = [] - seen: set[str] = set() - for p in candidates: - try: - resolved = str(p.resolve()) - except OSError: - continue - if resolved in seen: - continue - if Path(resolved).is_dir(): - seen.add(resolved) - out.append(Path(resolved)) - return out - - def _setup_cache_env() -> None: """Set cache environment variables for HuggingFace, uv, and vLLM. diff --git a/studio/backend/utils/subprocess_compat.py b/studio/backend/utils/subprocess_compat.py deleted file mode 100644 index bedf8cf2e6..0000000000 --- a/studio/backend/utils/subprocess_compat.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -"""Cross-platform subprocess helpers for the Unsloth Studio backend.""" - -import subprocess -import sys - - -def windows_hidden_subprocess_kwargs() -> dict[str, object]: - """Return Windows-only subprocess kwargs that suppress console windows. - - On non-Windows platforms returns an empty dict so callers can always - unpack the result into ``subprocess.run`` / ``subprocess.Popen`` via - ``**windows_hidden_subprocess_kwargs()``. - """ - if sys.platform != "win32": - return {} - - kwargs: dict[str, object] = {} - create_no_window = getattr(subprocess, "CREATE_NO_WINDOW", 0) - if create_no_window: - kwargs["creationflags"] = create_no_window - - startupinfo_factory = getattr(subprocess, "STARTUPINFO", None) - startf_use_showwindow = getattr(subprocess, "STARTF_USESHOWWINDOW", 0) - sw_hide = getattr(subprocess, "SW_HIDE", 0) - if startupinfo_factory is not None and startf_use_showwindow: - startupinfo = startupinfo_factory() - startupinfo.dwFlags |= startf_use_showwindow - startupinfo.wShowWindow = sw_hide - kwargs["startupinfo"] = startupinfo - - return kwargs diff --git a/studio/backend/utils/transformers_version.py b/studio/backend/utils/transformers_version.py index 17af40f663..0c13b5455b 100644 --- a/studio/backend/utils/transformers_version.py +++ b/studio/backend/utils/transformers_version.py @@ -36,11 +36,6 @@ import sys from pathlib import Path -from utils.native_path_leases import child_env_without_native_path_secret -from utils.subprocess_compat import ( - windows_hidden_subprocess_kwargs as _windows_hidden_subprocess_kwargs, -) - logger = get_logger(__name__) @@ -57,14 +52,12 @@ "qwen3.5", # Qwen3.5 family (35B-A3B, etc.) "qwen3-next", # Qwen3-Next and variants "tiny_qwen3_moe", # imdatta0/tiny_qwen3_moe_2.8B_0.7B - "lfm2.5-vl-450m", # LiquidAI/LFM2.5-VL-450M ) # Lowercase substrings for models that require transformers 5.5.0 (checked first). TRANSFORMERS_550_MODEL_SUBSTRINGS: tuple[str, ...] = ( "gemma-4", # Gemma-4 (E2B-it, E4B-it, 31B-it, 26B-A4B-it) "gemma4", # Gemma-4 alternate naming - "qwen3.6", ) # Architecture classes / model_type values that require transformers 5.5.0. @@ -505,8 +498,6 @@ def _install_to_dir(pkg: str, target_dir: str) -> bool: stdout = subprocess.PIPE, stderr = subprocess.STDOUT, text = True, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) if result.returncode == 0: return True @@ -528,8 +519,6 @@ def _install_to_dir(pkg: str, target_dir: str) -> bool: stdout = subprocess.PIPE, stderr = subprocess.STDOUT, text = True, - env = child_env_without_native_path_secret(), - **_windows_hidden_subprocess_kwargs(), ) if result.returncode != 0: logger.error("install failed:\n%s", result.stdout) diff --git a/studio/backend/utils/wheel_utils.py b/studio/backend/utils/wheel_utils.py deleted file mode 100644 index 3ed9bda827..0000000000 --- a/studio/backend/utils/wheel_utils.py +++ /dev/null @@ -1,175 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 - -from __future__ import annotations - -import json -import logging -import platform -import shutil -import subprocess -import sys -import urllib.error -import urllib.request -from typing import Callable - -from utils.native_path_leases import child_env_without_native_path_secret - -_logger = logging.getLogger(__name__) - -FLASH_ATTN_RELEASE_BASE_URL = ( - "https://github.com/Dao-AILab/flash-attention/releases/download" -) - - -def linux_wheel_platform_tag() -> str | None: - machine = platform.machine().lower() - if sys.platform.startswith("linux"): - if machine in {"x86_64", "amd64"}: - return "linux_x86_64" - if machine in {"aarch64", "arm64"}: - return "linux_aarch64" - # No prebuilt wheels published for macOS or Windows - return None - - -def probe_torch_wheel_env(*, timeout: int | None = None) -> dict[str, str] | None: - platform_tag = linux_wheel_platform_tag() - if platform_tag is None: - return None - - try: - probe = subprocess.run( - [ - sys.executable, - "-c", - ( - "import json, sys, re, torch; " - "parts = torch.__version__.split('+', 1)[0].split('.')[:2]; " - "minor = re.sub(r'[^0-9].*', '', parts[1]) if len(parts) > 1 else '0'; " - "torch_mm = parts[0] + '.' + minor; " - "print(json.dumps({" - "'python_tag': f'cp{sys.version_info.major}{sys.version_info.minor}', " - "'torch_mm': torch_mm, " - "'cuda_major': str(int(str(torch.version.cuda).split('.', 1)[0])) if torch.version.cuda else '', " - "'hip_version': str(torch.version.hip) if getattr(torch.version, 'hip', None) else '', " - "'cxx11abi': str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()" - "}))" - ), - ], - stdout = subprocess.PIPE, - stderr = subprocess.PIPE, - text = True, - timeout = timeout, - env = child_env_without_native_path_secret(), - ) - except subprocess.TimeoutExpired: - return None - - if probe.returncode != 0: - return None - - try: - env = json.loads(probe.stdout.strip()) - except json.JSONDecodeError: - return None - env["platform_tag"] = platform_tag - return env - - -def direct_wheel_url( - *, - filename_prefix: str, - package_version: str, - release_tag: str, - release_base_url: str, - env: dict[str, str] | None, -) -> str | None: - if env is None or not env.get("cuda_major"): - return None - - filename = ( - f"{filename_prefix}-{package_version}" - f"+cu{env['cuda_major']}torch{env['torch_mm']}" - f"cxx11abi{env['cxx11abi']}-{env['python_tag']}-{env['python_tag']}" - f"-{env['platform_tag']}.whl" - ) - return f"{release_base_url}/{release_tag}/{filename}" - - -def flash_attn_package_version(torch_mm: str) -> str | None: - if torch_mm == "2.10": - return "2.8.1" - try: - major, minor = (int(part) for part in torch_mm.split(".", 1)) - except ValueError: - return None - if major == 2 and 4 <= minor <= 9: - return "2.8.3" - return None - - -def flash_attn_wheel_url(env: dict[str, str] | None) -> str | None: - if env is None: - return None - package_version = flash_attn_package_version(env["torch_mm"]) - if package_version is None: - return None - return direct_wheel_url( - filename_prefix = "flash_attn", - package_version = package_version, - release_tag = f"v{package_version}", - release_base_url = FLASH_ATTN_RELEASE_BASE_URL, - env = env, - ) - - -def install_wheel( - wheel_url: str, - *, - python_executable: str, - use_uv: bool, - uv_needs_system: bool = False, - run: Callable[..., subprocess.CompletedProcess[str]] = subprocess.run, -) -> list[tuple[str, subprocess.CompletedProcess[str]]]: - attempts: list[tuple[str, subprocess.CompletedProcess[str]]] = [] - - # Try uv first if available, then fall back to pip - if use_uv and shutil.which("uv"): - uv_cmd = ["uv", "pip", "install"] - if uv_needs_system: - uv_cmd.append("--system") - uv_cmd.extend(["--python", python_executable, "--no-deps", wheel_url]) - result = run( - uv_cmd, - stdout = subprocess.PIPE, - stderr = subprocess.STDOUT, - text = True, - env = child_env_without_native_path_secret(), - ) - attempts.append(("uv", result)) - if result.returncode == 0: - return attempts - - pip_cmd = [python_executable, "-m", "pip", "install", "--no-deps", wheel_url] - result = run( - pip_cmd, - stdout = subprocess.PIPE, - stderr = subprocess.STDOUT, - text = True, - env = child_env_without_native_path_secret(), - ) - attempts.append(("pip", result)) - return attempts - - -def url_exists(url: str) -> bool: - try: - request = urllib.request.Request(url, method = "HEAD") - with urllib.request.urlopen(request, timeout = 10): - return True - except urllib.error.HTTPError as exc: - _logger.debug("url_exists(%s): HTTP %s", url, exc.code) - except (urllib.error.URLError, TimeoutError) as exc: - _logger.debug("url_exists(%s): %s", url, exc) - return False diff --git a/studio/frontend/package.json b/studio/frontend/package.json index 6b24440964..ffb3c65719 100644 --- a/studio/frontend/package.json +++ b/studio/frontend/package.json @@ -16,7 +16,6 @@ "biome:fix": "biome check . --write" }, "dependencies": { - "@assistant-ui/core": "0.1.17", "@assistant-ui/react": "^0.12.19", "@assistant-ui/react-markdown": "^0.12.3", "@assistant-ui/react-streamdown": "^0.1.2", @@ -42,12 +41,6 @@ "@tailwindcss/vite": "^4.2.2", "@tanstack/react-router": "^1.159.10", "@tanstack/react-table": "^8.21.3", - "@tauri-apps/api": "^2.10.1", - "@tauri-apps/plugin-clipboard-manager": "^2.3.2", - "@tauri-apps/plugin-notification": "^2.3.3", - "@tauri-apps/plugin-opener": "^2.5.3", - "@tauri-apps/plugin-process": "^2.3.1", - "@tauri-apps/plugin-updater": "^2.10.1", "@toolwind/corner-shape": "^0.0.8-3", "@types/canvas-confetti": "^1.9.0", "@xyflow/react": "^12.10.0", @@ -94,7 +87,6 @@ "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.5.2", "globals": "^17.4.0", - "playwright": "^1.59.1", "typescript": "~5.9.3", "typescript-eslint": "^8.55.0", "vite": "^8.0.1" diff --git a/studio/frontend/public/blacklogo-c.png b/studio/frontend/public/blacklogo-c.png deleted file mode 100644 index 7ab9959536d21dc48ed49f7782d579033014b9f9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 141545 zcmZ6y1yq#V_dh%{bc=K&TuG@b-AXBlbjZ*hGjw-|q7tHjN`seXC}DdvG;kU{Zx&Nn2{I+0+Bs_q^t`95x9at5HUi0;0Y^o z!VdV4-R6<5CI}SB1p&OD%z_->=QwClC z{*&EQ@E&+VcJq9y4<5+% zO9g&B=&|wxegBD#DZgOmr5gF2mEN+zF9J}{?WWmy5@ehA6OpFBZr#=;#C>=EqeYQDg7S{^d=0SFm=dl2_c)ecRBjL1ue9VSKcIqZekesUFyT+Vkkk z{q5e3@4BUGE6h6T5iy?K<1TFb5I#2I0QDl-h{acrSaAb5-chhqD9XnkzX9umbcwU| zm5D)%IK_mDcfudWP(LUFzCd{NyF_xjNkB{BS)#HMaJcR*(7bm@Ws5Zrj}2u7yrN)7{6XTpRB$!lBa-aYxXoe%UQvFVEPNh=f^dA0fOavN`f0K*T%DJV z`a*p$-@8e3^*e#tvD0Y)LP+4`Mds$%`~10SMG^Yp ze}nScP@aU}!MwNG*;GML;s;JUmzM5tAPbtr>^$L=m=~=4qW#MkelwdI57Q&yu?tD3 z%+lfL!8F}sGDVVX6PiItwkB90_ZD#g6U`uDA$m4cQ@sd3Pn)6;KTjXW+n~U^$$Yg~ z=a|45OFlMQ*1V{O*p)rz_{plzA8Pd%$NxUL=qvq`0Fndi<6dJ)-e%r9z#U}jLI!6y@fqf3=zXokbEkk!B~-{4 z=d#)m!OV{I;=B2N*waiJYyK8Q|oR^8$+w} zCUXhA6}+d)y+;`x;V%25;wL7@Y_1_I3Sk{W&z1jtk_<>)UrDz+JaRs>&6n+=yFFmV&0PuPM$jR~4XxGy^-?MGP!HX` zxnBoX?ixv44y7U0go+GFe5#7cj(7oitmP=#!O01zY^F115&ocBb}{&XCb9XJKF&LW zPM$PPZP^_7UIe02O%|T`nE*MLpx?=bm3Hmsr@1R>sBJw}NX?it5~{eVM|@5*5(*0X zx<0keD>%9B<|;+)=;|~Wa;1u^uasSb5E&f#S#kA_4W%3;{xCxVM>5Omu_SY1^ZHl^7bxe_V)!d+%rs6smfTAZP4rokCJ;#zjyI`8Z z1RgUMozR&!#gc99{r<~yiS$dQ?$@SBBd$M-$n_Vgz3b_|HlencokUmvT}e@hMsuAz zm6T4CiR-Nu$KTYcq?Fkvu?*FzAVzj`5aTU!7-GmQ5C_^SC0n!kei_;GiU65^$-(-h zNx-wz_eZ~rGyxK+25DQH1c6v0Rr4z7)I)N@;h^S`y{0U@iQp^=H*847Qfs1pQ{5$?Yd6>1; zS#-4t#1f9P4Nd(O_hf_cFnRnnP|@+wUo&po81F-t(1H~;Wzl^n`uQ6x@o^({Ey7B# z#CNQSGBmB! zI3F~qmPreUnYxyysj7XJqS>%Wnz~R$F`Tt*-4oqaOn%Z;k)eJCUVSsxXhxr)rG*#B ziot5QUkCT~P)1L4eD|Lr zgms90xlL}#x$ZEvj8Txum7*VI4dxscz+UBhgesDw?BHk&(}FZ=U{ zm;)+JCfd?c7J;{IAUvVMT-w`ZA82g3EAXoBccWJu>twpo^<_3t+hZxDM>O&k zf~u?sZx&~;G)cNJX`(EtnI`op5!L9O_7mqkSwVZj4JI{NyZR+ntND5z-Lo>r+hZ4_ z@u#qj917Rw39dg+rbvd+ZBXT6EXg!*PfuNi7wqo3+*9L&SUX zdJ(NhXMXbIW3{LI%VuFFl!Pn{q;6k*oh<*e6ocPOSvV1+Sb3w{lru=6$OguLl^;i) zD3aPk7()sk)O)g9`bg3>uC!`%%?q|rRN*6|HUYuT66un``qoX_jcDp}JH-APq8%6G zLE3*kTK8z4!BHPm&`}Q|ZgViPY7<{v+SHN12G8WA;>2yc$z7akZB_Pm(opv88lK=~ z+&G+JYj=%|>Gn|yr`$m%VjD3{E+yOC{BMMy5Zs@7%7G9bF{6<^~)S+d}#4(9kq}S&a$z_Z*QrmaZb$Z!T9u%6VT@VhWP=qCg zcL=|16%p$I0iuB%=Hh?;J~aj2a7l1=CNu>#=ZHAtsuZA+r)a65`d=zj`uq)Kke;+_ zpAst%=}34)SC}24D7-q~>NfiVF3%UFw0Mn3d|@G8EV5^=Rs5phu&IBBR*KUL7RC0@ zPg<$wR&>1aEAu8*Fc$?!5@Q+-!hCd7-W!?g;)_kIqW#Y>ITof}x!8omQTR-ci{JSxUHUHXj}j z?IR*kf>;cVp5?lHH?3tdydKMeJsvQHF8VKH`0S?MJ%TO{UIF{mHAE{lRBX((5&F`datlAZZmU8&THc$oD|Rpn%X~$P?m~q6Ie>OF;XzTL#%gaf|N3r zCeez~HXMY*&V$!J=aTCM>o`YlGwiaiU{V#ncQRv@(VMkWNY@C&C1h%Xqx*ihc(a#u zbf@5P5jFM5onIKTa^w#^^D~CVs>d~F(@s|>KPVJhHkQu@%i5%Ba}2+edc69RIi93Z z(ClkfgYE00Rknua!KIt|jt$E+YKP+SH0qouSaMg+q zR?dw;Y{GbZD>aJ+Jeeclj6lV$lYbdv(0QeJ&Rqq$P%qLP9L4`O0A(w5nVdWu?>CF+ zN=Y?Ss6*w1_$>$}D4RlXZV-~(dq4(1Ls^bS;fS?C6qn!h;8HTME4oD_y)1E9Ab}++ zLld(BpRdMoh3o0#Ntyw1kbC$KVdJjk0Z7OfQXQ7vXPa4WcIQr0rrfK%u{nn`YRg zof1QB5@5JUWOptu@+gD~_J-Nd(q-zMJJxF7uweH7Ui#x|mH2IWfJuH|O7=5+E`uN<#{%8Y9-|#^njdw_=}bSAkIy0=go+#rkBnp(^C_s=OZmRZ6G7&6ua;n$ed&k z><{O1d7x#IVm^+f`mpW0m?t1%V|G?ACL1W!w!gAwGcgg_(?;E5pgHCq7mQ4 zr@UZTRGw8l!p`Nj0n!zo$eG2xo~kqP=w%I`89buBTQ}*NvN4kaqP8|&7I1P3_t9$1 z4?j9Z&3h9sh+e2Zym3Z+_^2fDeM1FpcIiM8TT)SGAr2B1+a!=2-aKA=1WYnk%xJ2< zuaXQ-B+o_3cE2L7J^>DW4No+#t?Zi>uhstRX&F zM#q9<1R^LJacDkOoQYU50u2_2cMOi&-}N^Th1Y~#&%xnYThgAhT?C9^jqo2kDSt`~ zPajZcO?siHt-_R_cw4E5Y#U>qP+C3;TDQS;V1RJZh(i#H4oi*GN=-}vl=|p$p|mRQ zKH<|bW~@J-`4mKd%*T4t%|FZk;XLlv{Rm`FB_9lH^jc!kplls4d!*nNEK=bPG#E&N6-40ClL>eKH3lDhs6Wg zy&sezQz3Tr$jH(qKufB8i9Q8i>n#9(nBU|@U+l=s(Tw?Sg^vlh;Zjjev4na}$)H|v$$UvARaBADwW{1w8PcCoL znP`w*DbKMS{2zQK!rJx4B1aLAM`_?kUNsrZ(}ie0Yi7CBQyenzQ0?PiAls27AMW|RbO4nzk%D&;`m zXM20XjE(6-?{{Ho$Io7#uRN&c!#Z}qxca$z7Nl()9eHth7Ln=)ojzOqk}AXIh95qA z0reuWLu%bXyVn4;)TN@L0~Tx$9(&r1Ft}Ud{wRT^j<|xIl}aa zG_#m270F^g@Z(mgW5D4inVRdBhJYoBo6^gQkiY*B(L@}4lAaS7VHGEj0RG_?S@jm8I{0&g!q3?P>M1W9Q z7QVl=-&H6Z2ZTCMw(5_b3{=JiXHHIVa8qF>%%eo{kjRy%puCch!BMELVdK-9YQ%0x z{eF~^(d7{cB-z_ycPQ{@T{^e`J9V92BY2%Htj^9I)ANbrqo86yEr)m=vM1eU-S^2w z1Y${UDfsf6{ABz9s8lE_bCC$&ej8lbTS?YcKzZMJA=8Fw5+&1Lxk(JHuYdkG zQ8}Bh4&9?Er(Bmz7xMgS{QnBVZ&870NBfq0;36yegE=$7V-;fGh@mfoVvG^s`iV%{ z?FJK_w)HEPb0dk%%^57y^-&-@{j@X4TQYb;xb^P;@HcnsbsJnl0yxAdJWuNI(^M{Q z@JF7vjf2ckvZoR7Bq`+62*gK-*>5>KPXcc^2MT(z*#Z%44miJmaqljG{!;1vFhC!- zq5~toRjc@t9u5!EN*%3~bQruSm=Nd4p4Lsh&}hu`i7+FOm-Rmt@&m)Np8Nql>Q}%T ze0tXiWOLn<@TFKdNR)+5!tbc6z?AV2#yCb~j#|c_Yo4l}Bb6sO{DACxdfrKeF*$xy z@LgQf3zlg4)~PO~69?+`kwMGNX0e)oB&2R<@i#UR$J9$^!v;429?lQ;`} zg;aTJlZ;sCC*Dyi8UQy6olb}0`LYDptbR+zUO;{EgW19u&ZgcDQUW}G?nYSoaFe<= z)7Wq-y>XMOp5pi2A@ zlaYTpw1QE^A4aZw5yCS18P-ox&%b`R3tGI=0gkVXs~_2OBqIjb(z0X<@Gc?#wcW8E zz;HS(R*c+T$Zu(^ZF2g`*HUzbzNo%`^YC1wfsOkc;7)kC!kV~mlx@M`9*|Yee8VnP zPYz^Cce;&Bn7)df;Atzm*_UgA!@(Y~QnXOCU!Sgfk1#_E*1bpo8DIA@6F9Yap^7-! z7VQ4^QQh*&uT2pP&@xcQ3r zZ7y0qO&q|$7`2BK zC^nbr?|_!|vyQgiH5{Kn@*f+p@XHw_u@6_}@3YkXMxwMAeAwp@=#MKv#Z8*yxa7Hx ztjO{D z)(Ru@d_~xQJFzWX%C^Q3BaMf2FKmw^g(;o;&PXA<)QC_jhJx$0ELjK4&|`K?U4}{u+TCzRET5T0?gA?pRBvJJKx z1fLSN>J?WmWeE(Fomb37IHWz9Lp+phtx1o+NyB{bfE@jZSHi{!DjZf>ycf|+`@oLJ zUhIYH#0rKHj4)Q3oF&iji5HkqB+R3!+X0P&h7r8MQ^SO6dMLHC)lv-bU!K1EeN zq&7CyB{fY2xi4*zOFt)R$1ffQw-xxPb!U&-v2$apZz`IGCqj2y&mC<6q>f{>jn| zswnJJ*j4IFQ7~6b`)aDCh_XFcwb+UR1~46yJ8Ee*Z~BL6rGOHLQTi%dW&`*s#gaZy z2{SKP+-0u*cvoe7)wUx=DAsL!b5yfgz;p47os`4K;0S9<7tjtkX1s6gQwVJS;pZ2K z-5Qnn%hS~12>3^kc44$`$=t3i8>V?XIsdK@9RBRooL%zV@YYq5|EfnqWna3Rp)eksZc6E}dK zZM>Urr_+ZoEcEdkHci#RpbJm?k~}9 zaq4_kq>AqZ!6w!krP99y6bfQM^S>j532fW(Rp8xr(SnUDn9(3T(v76tBIjdCL`e#86YC#=a-S+Z*%e&a1BBp$JO z#thgOUed09!SOamRM@dN+IoC8KQd@TA2ZhW1k^~Rd6nbLWC1g;JbT~OeLZ@QgB8U3 zoZn4g4zX*x%f$MJNG=KDqcS=xOW4+gS}Qp`Q(~AoTiMTP+hnj*xYxU;srIf{c!kCtnmTI8L?ofAbPvi zDfe3dXEm9?wBMumJc5nx2}@>2%LT~(EF zw>KL*nXg&isk}f|LY}OgYc>jaa?l`$AR~njJd)V;zShp{$XAv`87Q|G=58>V=zo>` zKe8>csQGt=PY`(q>lau8;{}(l)w*|bh9QY}4LR#(!06h!(QwoN-1+Q$O`P%@g}f`d z8`drPCTm+_#79Z3)rwKj5eaMHejr>V1)sf~;7C(4DnRxi`6iFx%kC1vNr!SSmTCRO zQ1<3R9_|uY@(-X}+f8Me%Ic2D_o+{nj2zer-7-KH)u{M7PNk;+6x0AIIG`5GhwTZ9 zHM3hc)l9GB!(zITEp(g}KF#Xu{bs@yag75Bcl;ScuV>PXN<&dj<0XB?MWZAOKhhvj z?mry{op>Afv?$u6tIOlHh_jVivzSYEVSUWK-FcP1`rRKEzue2 zx(#>sUi-^wAHX-9&=#0E&9`c(}=KP(}|y(*NT4Kw_J! zN6&8tix&u&gcuy(xBaJG{NzD(MPF4|>M|zvSkpTm0kw1?zrP^w23X0+70d*`HBEMB zfK2CU&FUaA^mt`*_j5HKva)|<=_GrH7&nY;L#C?U?rU*I-?i~DF+0Ugsl@X#;7bUY znyd95Krdp_@3anXF=Oi4tU>@X+f{VeeKC}Ly9jgVnQb!Mlm>~cp~|vxnGy0P(t!Zv z>XwY6j~RE@<%{E2W}3`ecPj&$1}RlZNgO%ab_^n$+?B$pBKZmUe5am-{GS#W zb%eThufGn)H|;5XT{c zHTeB(*qt_^3R3q!wgwN>*dq-GILZ&6;CQ;|eX+#%cYfh~oZx7{{cSwtKPmqNy`*8?)DjNCWuY6CD?|#H6$Qp z!A3|T=FtQ9gr4tE!MAH{iEmVdNn*_Gux0?$qQcZ?bO^MPhTIZ>ec5JnNQS3BaF^YT{M}7D zX#~WSBNTv68vB}8J>hsF9Ammw>b6T9?yj>1=4(Gz_9t-jZ*#IU~O z`3uLSR`ii+m$V+rFh98Yy~_R~UD!hPyVqH!XSup-Z(|nQ9Md*Vcm+AyZ`z^KxA|Sa zEOv@%ym;WA1G0E&2?I)d1`c$8Hv?#OO>mF5qW3aAi54gO-`fD= zb5H+%grAGK#(e_h=IAo9_H7TYm6d7FtnbZJho=VU%^ck)G^QE^+5A(xROro`_~6T2 zgK_C=Ye)%f_{js?n#BM%C+PLie%!Ibm+tFR)`(q_Cpn)`rHwZPL#dT}8Xk9Zn?L50tanZ+-Yaf{8#J$^#pF)XgzQnxRt^bFJD){dOIBs&1PFpu5tNinv5wSJDoxi8Z#?pZz}QlSYm zTdEe@B8INh()h5f>;v7JM8EP#?pOl@W}sQs`%fGpBOWo|{0G9M*B8M)6rocPLl#BF z)xqRQBV2W`&z0dXZJxu$&mx2P<@q^1@9AVwoirF}Ilsh#>d1$RbtrvgF!kr};)Vd1 z_(}?+;s7O`(Wg3m3a#B*1+?{v&JPkM6RYQ_aJg2z)|RtR zKuIcCridt~X@IW9y)~d{EeROnY`KM(N_U!C&?Hv1#NbJ0vY925Bf&zfzKZS!e0K@n zLx#rb0t3AtGpSb>y=i>|2p62UY{8<`(~z4LsxS%=y&!IcnJ)3HF<6kM$B9E1Z0j2T z8(n(h022MqJsIKotb3E00kFypf4i9(QFVX?%)%dK2FTm`$JHw283TI(#e>tc#=T=M z=6d5}!sR6dSz#?@D1tAV5#mXKBp}e&Oj&)JBt-(@%2JBuBgW^979`(g@jjaB#NF1C zZ-$kryyB+t3M7ASv_Jd#42Ry#Zhl&Y73SCf_ByU&4%GPQ$T-@vq8c6BiT|i1zKWzmG-QbxzjI#ZT6iD_K12m0&J@G zwp$ebgXZN2MG|bB>VdgN=9KxzzY+qP@Hs8^v+tFIlwR}AQj#sT2}{tNLzyZIYFM|& zHoIHMgX0GnDHbeUIGY`dRZ7Z`s|aPCs@Pgd0z#kO&VU2~huF85#7B#r^ycUm{sQJqi&-HCMDJ z-Yg*`>LQ3-Hq~T9Yu#rw_;kmL+ZP;9?v+2x^dB6GKQY;PFeP6qUCoju#?qBvAL3u& zOK|}g2=9)ny10GN%TKCs9p4uCksPmu7?tQt215wlnLmd%ESI>r zD$2MT@CAp{9eqOmheUAFPull}eBPgV(p9TctDISDPucxIQT(&Z*t{jjKJ$MMGn+qQqs+_mAx-f!YBYk9LD!4lpE z$4{?1N_-6Jr%Bw$`}h$9IGg~HsCP?-kNoaP&KJEPn^N`!99zRyBBYAAQIjfuW6l7c z*A6$}9UC1=Px{FUWZTI4Sx?Z8A7QR^iA*4&q9;(DZ8w!KTWWPEkRIRJJm!1O?@ScQ zOa5pwM3FI?gXEG$?uSh7?3NU|xGK*4RHt@igb0pea~ysUx&e824+lHR34%C_$xcw6 zWhx1wFP+C7t`eGYW4}We{21D=@g}Z|cbRzqXsW#UHaKHO)e$9FScUyV%xo~IqZFQq zpm{%@X+zRUN3&#WSz?Ozkdcp;LW%&W&w%06tED?-d|i_{>OqETMWdai$RY%ok0jij zC{m&}-}{M<=Ag$SL!<5JO-v)MqU}BoyWFxrW`6!mjv%|_G`9#}VdCO;=E24$?-!Z2^MFun58Y+M|65+aM*AS64)85;Xpyf7Z8RwfHFw_^w<48N&?@+hOQqLMySpBbs+DchDHQY4CO0JZadV!w>ZFBhnK&5+GU zQ&(OWk9KsucT3~Pg1-rK+5iIs<^>9yrCHF@op@Z#>^i|Hszw2m%gsU0L^cVK@f_3U zs+r&6N;aBHDo(7m-qH6qVnegF)-k2bzvpZau2oYw;pB4zgZT7)TzW5Q&O4G^Vq2=Q z*-k$WqH456FRs2(V8w%}!)x{oRT@QZ@|S-^YBh(7)>&!5w2E`^1lLJ)i~l>i z1Mp^Sz~A(k)Q(t4zga zE=F3MU!%{k94m`}HYcbP?-sAtd_P-)$VJRCVIe`rLkmTa_8xaVvt zWh>2PbbvR$>EhD)mM#njJlK#e%8@R1MH+pUlk#o#BIqfDb)zG8ZYWdt<@yJ)$-|K# zgd1$HgV2p89mf=OlN;~=kn!GoqFzld*H%9Khe`(nJjtj9 zCRDZFu1X!(lHfDsKsgOjsEkjWuPbSOKi-OA!VMExZ> zURIeBRfN{qP{q5$Wi+!BH>FY~(#-)pbkcbw1*)VzEj7k067j^Ll!zr#F@ zKRYNsE*U%H_=nK&h~uM(OscmkWkLJO?um-DB!kluhJP^++uHypx9O%cmb(w0NL5}T zw&`#t40xiO=Unne;(x!R8G_LupB_^KHR_bUKne-0Zw_fca*mzV1MAZz<+?o}+7qKF z-qR4Ns-{~$W{jM>D>uZvlQX#Rb!R?P#6oOM5)lZBGCyG(IfV{2WK)A*4!{hje4Vk` z-`=meUs92Uh1%<_yKJ9;{M&A_YTeG;jkxzrjz1d34jx%Yq(^8wVjnMKAX_hhcbk8f^xytAfSuf%e?d#1jAEsqLwtgOC@F7pDQYr)*9Gshvu;uFGT8GI z$N^G_M5zFs^nXXC0o`OG`I;Iqp4l+KWBlw_z%=>aUNEJL!J%0*L|$bH2uGA9t*8ew;D-v@^k0ndd5WMx7!P z`m|yV`Cf00Gv-2RRQ&48ZiK^(TfnIRTkLE~@=bk3Z8+IAcW{{S3gnfu$j>236}e{n zzZq&SVBP>}_SO8pXBo`4mH(E1r@7raJ+2y2j1SMqnTgJV%Y!#tY5*k%lG4)JiN^lm zSHgG`zLmR<6h4hX_`H#x9!|>)h1po8&$=OOrkN$Qg$-CXrpATN+1YvaFmUepT*t`h zWl9PyH8pj1Ma8Yt)6>GjLdD0AiPY8Ag}f$xF|*YnD)Fu_cFyZ zdaS2iJmN=+C=fVoA3h;_R)N_>b$s- zu<=MV+VCV;eq*xk{#?!e9&qL0p_SDgd@{!8ca&aO9&LhW{E zJ_%aj9iO2^YtU#ND+W?zFwOX7Udq$<`+DEALl7Duv(3?VW#=c!nXRvr z-X~2!G*j?=S%a&tj?@%(!@kH#I0zu1$Jv&s@?4Sd-Ap#;3= z?#V*J0>&mm(&!rJ^AG*k->n@E;M`)XVTG>Ac%$5Qp zpezPYk+)Zq9ThZ`o-jo_v~rDeOwb`kJJWcOajC(h3KDcyUAn0xn#AsiZwk zB~W_9FWlmJt69XK5vaQqui@SCEXK{jQklVfk4vE_LF|Y3;=ukZYiN)VlaNr!BMm6B zRVZ2)bTe3-#}Kb@-#6~eJfOLV0lPYU29fRGnO1BQp@KqWj;>2q6F)EkHzYTsM+5yG zaBh11*6vKz{vEmz?ZK4k6g{w30J4k-VL&?&=x~t3SjIeL)Bta^xm@Gi<|EMJRgsKt zng7{k(&n$3tXsH`vUbzpBo?KxKW2>KfAOR(@_P=?$jKNdt5EX^M`q}mlk6@qq0S2;THd zFUN#dxQ6`LUx$JhKOK!wmNk!pZ=&%O%ljgB@UqgPCx{!6^s!RNHVR#!@XK=y{X4Ei zY}f};HFg2~0+$&2vFd_6ux6+NK5g=X8fQWqKs15Pe$?ww)8OLb(!re-9O;l1LMzmA z@wa)%;V8^PS2Z~$CAYNnUXjgA@i6WAy+hCccG?nlvfi|IS3$Wc} zumwa#T#G=}ZY6%ib8}Ie>9N}w$RI9e>}R-i7*kjBQQH`SR@-BPjv?-rzm~sDJQEWY zZCyXi1>lu{g0-_hL*@<#z#JYbDd8M#VhOSaUbeSC^p|B76%)$=kPLue9$wxDp`r4r zX=xrivkx|ATG+~)&-e_>Okd^a=fh@9UP`W{du2PB=ysRMH?rK)`bqO;(?;)3*FDqc z4^08jI-rY0pV&-6Z{1ACD6p>%e&{EA)ExGs9~2va%4rgntbY1uFfn(l084;!a5EOC z1ZI>ZNfdbp5C3oENRBlgB~TTzj5v(_ZY2khO^D)Ore-A*V^!uY_z_yCmE|>*V$HO% zysL5BWNFdcC%T=l+03|xM>n@U2RbWbNi6Vcd!j4;lPVW7eB{i~medzc%I-;o=rM0z_PVt1*5s?>nJ~7) z!$VAipU`h;W0v*5MMg%3sj?&ajUDsz)+3{%cIWS8nsyhu7nVL)r~**;URpled-D8i z+6i`xV)v{}vqz{{F`@CHzp|2aKKezo{Hiu#1w;Zx1TqT&_1vDxAGjXcNo#SX*pSI> z>zT(9ZpUh47YCXC6S_iE;L?awZYAP^_BQ4~A7 z-E0dziJboXeG-)c`}U&Z;`3hy^A1L2^=UczlWbVQ^Y2~c-^1Qcu3Yb+ooDf9i0C9oKRuY zP2PBQwr<=KD(AaB9habRC5inoibzgMnw4ld)%_=Xx=KHaN+?0?8L+32=jd z>5SmqoxVsaamRyYr7xuB><}kiJsSRe9hm2|%I^c(qwJlcBzIE+3xfVrwj(4C?@H5A z_TssN*}Li4^0n0T$L(zaIveHPTN%d+kXs$f?$j(c%~)bb4XCZ1LX?^)Gt8y)LthPG z)d*i$o-p^8Gc)gh+HeFiz5|Vb?zWQ%%`PM;be$cpKhsQ8idVSuzu3OM65`TK9l5iT zHuPB`WuTD|9td(FckA{#m;YLhDpS_fVBvccl&!e0sP%V&9!Qi>l=38M6!4C{VzLP% zPqim<+P^Gs9lK4{Rc`iaK|r#?_S0*G;*8fCd+*vNH6* zqd@)MJzFnd_)W!uBzUulMSZTR5yO5^;Y!V)TTXN9(h55phFK7sE1ck6Ekv#Cn)f{g zMS|}(W%J9$ms9jV410Ou^zpNXlYMW)sPZ+d6@>7`+u7)PLU%T(S4( zQb!^Gx_H}W9n*(HyhRgmGfuXjGvR{0kb+8#SE7#gp z#k+_92)H35rj${Ht^-rGjojg0xF>hQiHgcT{GR#ZJZwc*XRUU4_f^(xGX=%Nfs19@ zP@@FpjRkGHb;<``tuM8hC~Wkd?q0ccnx+rCnf5yTJ^Z_f(wX~%ON`39X0EUNHn1om z><49$-V7N!extJI-3!3FCHHt)jGE3x4L42+U#5|NZmi}0-iY^jphawX%J;JTFe9EY zyCjUa8*%wRszFC$NFWlsFD|IS)*1%>n7n3^0putU@xZt}a6egeB!OCsrk1;2A_JCnl;d-iNn87 z)M(GECa}hl2-bqAxSw{`dMpTAwd49X)k*9H`Dw_7UY|-|ti2Dz2?UMiCGX%~%Xeyk zaI5MnL2nKHO8z^x}7ppOc4 zO4|V{9aa`pM@Ugit8J@VU(CJUEivKWFSFfpEuR^3^sLml5&}@#-U%P=Gin+dtNG4X zALW9?m;2IQIt}Mro$SmFiV9#qptnY&fIdf>d20hfi=vek1K<9eg|cQlVCpD#2l+sD zBBYt3#lLVXMGg6jz@>?>$y;#W_5W#UA2W_U;qV_{fVZH}Pdekk(g9cD>Y`@`V%`ns zMg2Y!Hj2+bkE}tDlN|tI)S*L!+qV1+sS34KydU}cD1>tb^P<>81sBUm(ZSDZD&fr9 zEN07CmDAkJD%5mv`?p3*<~N#ejlg+UN4`dngHPI8?KJGZQQ1BJ(I&p_IT4k%VG7gz zq_E3X#*5>8I*Wm6mGA`l#)$fBR+oU#sSux*;9#ck`Pt84_vwFK zN*z=?n>Q7uj#2BjgYvi4sNTK?{2@D8#whzkx*Q%|bCxb;%NFvrp zbfpe1fFWc`MQdDSxXvr=tlL+II5Elk70{&Al`_zwia&+rV}Do9gJG?OL^am7hV5qT z+)BX)le9NazE)QJLc9$5Z|LQEn*MQ9qdmwcXt#SnVs;=4ftRff6n;R7kAaLpCsGUI z#|$il^3QesB+89za(4#?nDO5dENF19O#M4y;=_6IK!TFe2k#ynNai8?)w(9*T@~kP7 zCP==W1!_3oR3GAaSYxf|bk{o%dL>T;Vj&JS#QM@oi^A`Qp98a5i(Udis~qx{!ll9X zMRQ+XCl;8J14$sh%BHapa5Kux81f-Nw1Qd`QQb1xD zQaT0c5~V~M3F#a{7)nBrMg|yCq@}z6XZ+sxPd`4;wOB5lIoEa0-uqSO8lDcW7;dOc z>7(jyQA6Mw7H7}ylb@i&9J+Z*9v|06zvhc{?Myh2bG}`2WfpvRT|dbC?n1}4vH3ft zP7dEK@1}w=Lu=CS^89E?<}$3Kqk~-Po(uo^>0u{NUP^%GQz;rSlL&e4Eyq^FJ?6Ju z4fdyn-Tf}}mJJ8#@RJFl!!wn#@<~s`zguz45ae6DZdAIMJUYEO1NX8lB%km3y!Ov` z_@8$uyk-44w>MQl1Y|FIDX9by663rWz=o#6_>_!tFmEllUOC+}I$JBZY#?|Ao>pz-!YIorN=h+oGR)n~fV=}|`h&L>e($C0@`pe=Aj8@kPbPII%&Sp)q8Or$A3 z2jjIzKO^q=U7Qsew52z>Q|X`%@6N&%we)Bj*q<^Z?m?sk3{8K;kF|!o)VWIiAW)Fq4e!coId(aTo^CZJT8czqJdli1^USR7iR7U3e z_E;P_>SESg9zV;BIq-=;XAtbHVI3^1$$^GUB$wxVQroe(NuJOpw0mn0&#l+Ybg*3> zDc#nqz>`idmFs&v)#(q#b=diKbN*+9K*!^osUy?;4cXXNN@f~Rn$s8LI}yCkv~-D^ zSOtT#c-zO2!h+HkI0M^NHZ~q*D|2+CK)8v#;!pSgqs8kb|7UqYJ;l@-1h$4P)FRms%=SJigZuBmjD5vxp7utBa4S2v4Z5G0927 z{%bBA!xf@tX?U{J+d(*rK>1ut=K0~HHv5YcBMiVqq3$N$pUHx1G6%FNSi1kO@y&Oh z|7Xdq#!>!vKM;~X$5sa9_D1)rv_3mn6^1o)FhcGqH%)kPC{2fL$12_ZZBx?OwNm~Zhe_KkIn?nZtT&2f99B+zkZqhY`@e`!N8$H9Nz-5V z{#Co{=22iDO9LCVTCm9t|FqZ`ORO*@ly!_C`)Lj^i2H+Z2LrOT$yPCalQmt7!dA;j z)-m4Ui;hlKN~;!L*nFKcxviX5oig;}DtU&EFC}Y|6RjPaAI=GF9M_rWO`C7;IFp1e z$?a%q9WMQY3sfzSv>?h9 zcFM2b#u1d**>+zV%cULbu+m|!Y!y9=t+p9O#c0;rpu&JZ*}@POU+km0tG?*3L-m>~ zd5}?@eykw${vG2|*@0ZbvYhR!eXI8w`};3=)e0CX`+GjVtC{vBu}pQL| zcTf0Rezr$!|E}zJ?ed7(5Ql8CTuOO+?cw|`oC;a=mpO7{+%lLW%SYr12>=Y|+X^A!m3YXcC$KyIt~(F~sY=wmlC z%t8-S_oIR{JetTBO0VT{m7^@mDk3t}B_2m|pscAHYAlML)uB@B=ryCuU2ut6-1NJ= zA1!@iHR*f0ZUu+;936*h>`%Z;=gxNfJJF~Eks_l;iTR~n*{vm>i0vr#U58Ju_sFY7 zA5ifnk>MCoTD5M8F3~q#5r))D+_7KHrwVWKwGn|BZw+>ZV-6sRTK|j$sZa9A?4aKX z9r*pBQ#ir^61XOTR6yc4_=l>85#lfN1Cu_Q%x*6t$u{^wHlNd1Z8AmpuV?$k!%{pj-Qu!3 zUpBzu%x?m%nVwJQ*V78F1yIjAe~(u>ZZ;mi+>@05;0nOK$ydXwNg{CCjk=|@$n?$R zxytRu*3B2*F&AT6evoAP8S;mIhmE=8HF3C+rR9vTdvVaei8sFirI2XsD~>w( zQThUdL@?Z^P**zr_5P{D{tun`ogRVC`5=sX@xBln)-x z-odYrv*?HzU{YpDyk|WYhM(8>8$Sboj`mwDl)v!*T8TfBVu@Hax=CJ>XraZ41H{ab z_0kxQ-H)=&kD+eS#f`n?QbeL`a74Ck32(wPbNGi+c%%N>E|?k3GRQW6thWIY&k_A; zcxXrei^paE-X!a7Ur&<(=coj+j0sDM1SVXa`*;@!fyvCH0$wp98R1LORJT z4umR-FBN!MoJDK_kBr!T$x-s{k1LtQUv@&~#_tCe&a-PYEqw_RzY&P9WZ5z8*~5LR z%Oa{@)1G>|KbOs zP6l~b{!HZe{GFjGanbtB{=Zxj*EPpf0*V*(Yi-=vG?K)x*HFE$WWVqko zYZlX){?af&)bEz7C~Ntf8p!Nx!G>P^4!U(4qQ1Ro)~d4ETAmz;{@f(*hyOVxm7oAEuBzV++J|hOe-M(k@Z1WX8{B1}CM#w1|2>RT8rfKL0 z&9Fr~_zphiDEWq<@9V%hL;?>aORK=&x5912UoR-)Wg^5X(50^!G5xQtF>5kz_z%pr zAoscMsvc1W+uQqKDKRgq1dxWXkniQyrF5Rg&&8^=C~xV)cwsRV)k-B@x0*GD1zG`D z8`hZNSPu|VNAR@G&?Sc0k6Gxx`Z-S!K6?AMVf|f>H3}d+Ilbx2*15+L*tmb18Q_?H zOeao1y$H!sY|dtZy}BrWzsnWF{bu`AOwIZK4^RA2L@t!2klSSnLf=+Gzpq4s03@J+ zJZu)N;H&_e!jSX8`_}9Ujgw#xdC0fX^(qYrH(l-TJ>dn4mnR@))pa*iKUmKCe0q_VHjpLPS8j|B{xLEagpOh@0$iq&-un>l&$jWBoYyiCdmI&NDEI0#Us28+u}(J!feb*30CE8UT;4{( zISt}olEXlJfpI|jJ0-livC$Q82U0|hDCRC;;K*!gD*fAd~r84UZ-Bwr>e#&?@b;e_nINxx$k><|| zOMKi4VlcO4x1N7V7v{M4f1YjpYJDH>-%>ucYy1sallCPVZ!tF+XMVd8|0WA#=9XCbkf# zDMTwq=|a~pxl8)8KMe@6=Gm-`zdYJ~ab7;t{y31!?CF#dUIy4U5YTd&k|#+K`16Ik=dE?628`CXT3CBsJwrTjjNewJETGIR3nmSJ@o*6 z(kXLs2eln#pHWvQz7D5yQNngJIf+K$0DP86^}Lxn^gFH~ zca_vMaEv0m`jjv;B#l(!oxnbdd5&G+$Z6W~8)XV&Wg?~7JYjl+aB>80eRRwk#m~sq zJMJ$!f|lcr<9%}M&n|Lrk6<}#* zjGuGMTf>iNX^$zyLby*U?4Ce7-47eUuVM8)TVJW3ege>9Mg3BmmqQCaT}DX>|2kaY zgD~iOia&Yar09pUxkY(!W#rKpqap)ddU77fepGt!+|mTRv2+#wF{Z8DwQ40 zY*282GP#+pp8cfH$cnX}2RKC|0DdRV0lh|9hpBbsz|PPOC#!|l#$Mm5oE?TTzKwaA z6Uu**x^kD--qbo@dye872v^Y?<|&Jk>J;5!{0_|R z{RiCI(4SY6bTNV<0?oG=ks>UwxdN2(Qa>Ij*WFBW$BCfjYXv>)hf$X>Zs|;$LyMk- z9G`W+J`{E3NmN@}@{~VcgUAHe6R$6>-D26l>geP@mh2a%7&uIh5JRB47*3_WFaMad zZuf-E9{eWj*~uJvNGZzK%$e$})xRVx<(VB$Wc~q3?{WeSFxww;mX{&jsaw{``2YUOtY#X(p!B zd*Rj*XbIy-T}VVIMp?W#`&fl=y=)F2sx%!@+I&p8iK9wC?aEC?^Wkn~ z(gur6_55ozQDz1cyKH#0Ex*452ZqG89zZ1w$vT7Z-<3oq%fW zttxpe|4mbQu70hgt>5>cvZ}Q4#V9BJXts>&J?5(C>mjIU57YCYYu(rJAsTq9FRdDT zV)!K;J=s7LWjQ&a|IZn#?rJWD<*bFslvRVk_K)w-4w0qiH!bR6#&mH2th~6hP1}@Y zI2?)Wy%S)9(GCeWI3$K8HjmCVB!5CRDYd}c?!?i+7J z2=g=HZtw_Z6dn*y9JoA$x#S-lzpMAq>hX)b3?oX$_4OWq9_V#5^|f%JQr^Br<8}v= zu*RU0{K~G(<*rOE*j&T)i=QfZ&pa0Ktq@|Na>Y?wHfaZ)MezXR5bm>Tc}p1dSiirb zj?-A&x>j5722m6n82JO~C;qxcEgowJVk+<5z6Eu-jn@e+Xyq%z+PHs$O;0fn$dc$# zAoB1~1)s1lM!sYXWNY$NX}4fSDOZOJ61*r3Q?F;dvMs>GY?7sh*^_@r7)9^vG@a1c z)h#4om-K_U`Zk8x&du+-SfLoVW|NcSV>31b#-iS<%*n@VyuAcMlM;5~B1-vp(Lvsr6uZ zqKT#1qmNJSpS^UuOW(okAUH>gW~7`r$z(dM3qZK z&{k&%yto52Seh_%%~{5U1_>1z>zE`nrn+3hfPO}DGi3$iGt;50Dpc;h8E;_>2bGHk z(2mHIG#=C^A&qMkO}$DId? zVlwx1F!}*7?f)Zh{&c4<2=-NHbjWXP-{BKQ~=-@KJt%>j11^68?m_K=Zo><_@=?xgYJ6*OKK& z%30lKp#~lc@z#dBgVsw~^$irRBCG*Qzh|5z#RY8|dq{kQ2h3CAsQGW}idn?PQLhU6 zg1lobC+;E(tG|Nak(P=bs$C$q^UN4|1?hxzJ+%WGM^f8)n+TBehjpccJkcj6h3toJ z))CAMXYAb)ne}3PN4qWyuY6j(@cp%={1K`mN9$iT1?xWvAxf7T52qsl{@G~_iwUh* zx;WkJ6hRxjNA8D9dAR>ve&P=_^2e~-yLNJ%=^g;9nvaw|)*3KR7_C8b_h;}zXG69x zahY$~LH}nVKapQrAO9qKg0HMfU|xWu41!MK;|Y&UNUDHvxitNZ83yzY!+NPW7>6Ll z9*EI05=+CY1j!F)BLdY7AMDLb&vbF)dDjvBX%6+;4s=9&S z`Mx#G-nX>ns6DPd-(C!_3GhI{H7pNxaB~FxG`8_O$%lEB@9J7kRagl9KR1>-t;PJ! z1zc)};A<-Pq*c755&7Ms6U?yAW>JuHnIN)$qo__)N3Voa=&guxfX`yl@$XpRz3OKo zcgOb+1*O#a)S}etsCMppbSY!>$4-6~3`yOKc4F^&rkO_=Px-j}*e?I<1l^>D@w^vp zB4k0r0GL-%YQX&Nx2I(5UYm1P4 z;~SKtD|&ub)ff#QEo!Mj?eiUZqekdBvQqYDF3)FbK~pq@oSq-ivyTYdKHTN{{@6zgiP z^OeqbviU{s_T8rlpC%Cy2FQg&4qa_8kgdsuzZMc2V=G&qMfXHO3jtxhb(!R7|1WS*b~ zp5Y4pQELD)#Una<%e4}uKx$#MZ71pO>YSz=g&o|UXH}?DCu4wq%hc})73&9A{*HtL zwVt(FpHtV{Mh)~pgYyLYDKXyCY^c+z&7ak1Bk%*N?vvTNx?OHA6Qf=a!xK3He4R%_{ zMZ#1N))6D3GoA|&2B)b=z)o33*rGPOcwx?MR`?cpEbDV+yN$$my`Q^5&@ss&`0;W> zJB*@t@6LwbDrrvOgAG-%Qo4`>9sg0c4k`dhRB|mEhYJ*cj()SuEV+Q-P{wB9{K1m` z1B3?ameo!7PSUiG_e`yBMC#S3z2?Jk{vF?3-bRzyxscLz*@DbP5N!CI3{L#mOskEN z?UOx7z<`-=O?X&ZrvRE2;2bbYE%Y>}Mjjle3TgHyu0f z0mJ3zqYrCr{W`?L8P+5;g3V(#C7cSUxmLR``A%`R`^y!ulC=qTrIa7ERa|#EN9*{U z#AaO#2|MZ}*Y}Z&%9C6BkN{7=TQ=u>x)SF&q_=sW_P~$8AT|pabPxSe_N>13am(~} z3csVas-hoFWu#1WMr5T{Vgo}E)uN^(5G_?^;j~CUBNlA4(0-FTMG#SB>E1K|luHaX zMfHDs^6rdC!W!kzC}lz1W@cF$1%k{2pfIwcFvrlT)bcsG_gaBj9dS`4jt9WZ-_+|M zhd0XLW(7MM(9%|hEFI^b0Wvqy92-Bp`6*4F%s)8u{_dDGoo%dd)C5pN+?n$Yx8Vv} zN*A&wuw5KLx07R09NR9ny%k9^FyW4e@NCu?9hcsuwf&GnN$mefTr$l}%Ewcsi|O7c zBH8OlR@@CPB>yCTbN5QH*Rt>il=~OpFi()e4x@p$m#++@Xo`=63=#C&1TBMO@i6Mg zD45A-wG@+XvkzP!^{P%G8NfTrsv{urqUiLCO8J<8F8TcV3fjnvkDN_9C1F1_f9= zsFt^kM2YuTC~1(f+m6o7F&v2AqlF`?;hhPd90!2F%>#}cdoqsf8vytfh-2;bM?K9a z>gWfZA>`E~n1>Ss4WNa28Gu8LryR8txhzw@TvQiNPQ`=UNoolct3~XbdsaX}jT`s?Q~-yZYe@v48G< zAX*q!SVX^(Dp&-WvzJr*f&8Cs;!jFYvrO2c%e}W5Dvq>sw&uXaxMSEkUKHyBMQwrY zy0a~)CCKTp0ND%>eADrC1fOt%QT_1GqxA#j4`Q+ulk1^Q%zgj7usca)| z5#}7L2rmgsu{;cF1+g7zOP;Uc8L?o<;pL z$DtKq$xPc>TiiW%_{N!Na=0svBd9t@syVXjfuB8>f{v?xs8GBR%jb`GGgfQ@x z2ypQ|^l5tlQ@sh}vP{wl`WYMRVMa*NpDtYw3UHrR&0@?y7EM*5Ym=)IiacGfAlZN0rCv=n} z9$vEDurc``ly&csm%|sUDYGC5S41#_KO=p0>!0?Jf(rCFrI8~1E&8}c?AmT0!Bmn4 znnJIAyP{LFg-%704^FjWK30P+JqH8mwl0R+ zIId(-<%YCx$JvBr-1`;a&ApnQIe>|Xx=MB!kLo7tx!nUkPfJ7qvIe*t z6$zSi;wPH!<)lEGUpp7fq+t|N0VJeGN#f32V8?5uIJ90_Mwg}|()ZWUX2ZFv*9vCM5E3PL(G+kDiwh22r1^`^$JJL6*3@g9D+uDqvRyA1M^b*#KX59X1V+G7 zS6`6mtR(%C?3r0A+FK3n!rW(xni)usU8?M}izKQH{80<+2Y2clh-c{UinQo-j>qd5 zeG@FDy1(-2y~FoAowJk9{q3I~zt7q@J;2y!qfpYAL={6_A0H{yezE-=_}dSRjEt1w z=bmSelj@h`fcg1ZzFw0d!>%1ZD-?&+Ui0O;Po&Ma-RSS>LKRPYBADQD{cHvGA|c2G z8yMy5@U!55vH-6@l?PxAU07CT!6kqGxt11I3n&H* zA9nzC`~BlRfenA#^@-|nl=mtDmXa_&`$>moa5eQU!5HoB{MA5m>)RUrfSG^JDy<-j z-!s5iSneA;XU9bf4!Ki9;awcW=A8)8X!(g zZ)zLL(R^2GcJx9FBrERv+F^sqkI?Fr_Z7(!<^dJzp+u}Gf-tXzOHTt%)K<*D-KX<~ zZ|RNCjk*x5lo~sO@|4Pb(9+cnU0t`RTvj+P)!P8f^qPflhm(yQdc z)W}VJWpSH)4Nd^6%9p*$)~WY9Q$s-(U7%bs*|OF%oTnJZ;sVUL>=1^>MvX?pR~g18T49uw|IV$P0#hyPk0?Lj?_a)aLOMRzF;1UsK% zjd%K^S$o@_jY-lz5*<-YA;3603rPQtGSfC@%TcZuK+1Up$p16OmmrELSMC zCTh;FcOkKd{u_6e^B(<=mVI$w|~qjF8l@79_6)g5I7I_yb2ZrHqyOC`sQa;1br@Ioe# zKvt<1;bdj~@3YF`Mzj=boVbTZ;yjy1@f}%$rn;^36F?DH&Op{9Q|Ix|D`76&?$`Z= ze6MXBc|QQbBb3+(26!OX&12#q#CQe%<-YM{{Eh@QEQVF;8F=pk0&pT(?$ZK`l_D^a zDMr%2{vd3b)zH8UTrUb*T3U*w*}z5y>#J#8&WhL?PQRcA+WD_oykxy+CBd$Kg$p|; zQgNv~vQFjH>~Bd*^b46la>Xssi2LU3g}x+l;TgXUgUVK0rQN|N$s@DJoVCBa98G;a z>$Ip>!{{%I+>HHy!(sldM^7BQx|yC7(3nSuM&duC2e*A-Z{QHsjxUBqyy)Cjd_FEw z^6??eeAmx~+X`|wblqV6lt-4z1MNoB@8)YD_Yeo?UHtdfJJeNIe3iQ>L>OU}rM)$`!@hxoOr1`lCioPM{S zS1flvODNNPd9itk0|AZ`yk;v3BV*(A^(SlEXTCBe^{c!^26aL?E{bt`UrRv>!R6$U zPTT>r3A~%h=0nk#&^Y_b(t|74^WSjIO9d8k(ekI%=b5hc`3R=HL_Ybe8|wCdr8RjF z*rft7qmJAeOi2PkIYkw!z&QhGSP;eD>+Bje9>c%J3y4h1s1TqtCPK-0Yj+Eb0BcM% z_GRNK9i9Shg4y2kVA35kf`Y=rJHBV`Dk)-=tTBY_0pNv|9j+RA8g~P@3s$w)0{~Xt zb%;8YB4xZwgk|aUVX#V{u&M3>twQYBGKZPDa8+JAb2zez4#3|C@Qw5=Q-iQx&%AVBi8c(m zYa7|U+3&*rDju%8z-QTB?IjaTedzCG6>%HVkkvAy3ggP-DHA~?NN-tjR832HK7No^ zqVJQLDrdil{Pu4{to)$6W&NA|?Qc5DQP&w@n7HGms<_8r?~$N6FXVXR@c{0 zUVoyjr>aT@%0kL|&+=7p#J2+E>QZ59=?#@yb9Rh$cn|*j8Sx?lYORHgvaac{o&F*J58652hFrZu0j66{?lkolas^6-W^Vm z3}XWY!AB;Tsfdq%Ti7A4oavIAr9Se*@zP`g+0@VQnu)j!AHpm=NKlsJGdrQ5;~lLG z;bn{y_eZkCMkUr^sw7{6kjgxrZ*@PQtAoCFZ+myYYRN>fFFFG$u}EoCO*Pcni>#H; zEavHPknc)P;#Q&YIq!l?eO5(}Z56 zP6fYfFX5-kMAUj~L+htW@IB@u`V9rUhiek)<`gMQD(yJannM4+sOGU6?Zp8`x7nSq z3*izg*v02{I~LYOVOfv&HW5X$C^ais9J0&K**Udj>iFb49G-9*BRPF5z$YqXtGH!Y z!y_yJUFfh@T+_OFnAf1KBjmni;=K+xZaZE|kNNigrsb)V&`Y3z8v}*_(+05DEy-J^zsqTHWaSV1Q&-pNLFIt>7X5E`D{T{P?daDrHK^v5J zI=c@ku^^CEi3Soo{c~I88#t3{DFTqr5_T;g_UX7#M35tC@6q zl1Pq#n)D>_5#iyCqaq`dC0uw*8g{zh#C-V~e#_x-%FXh0=Sw_jpyyGi8(=7ehx7+D z)seIJ&2&7}s$q@gPrRNS@-27mQP@F#2EE6GE)O5$=XhEg2vKw61k&fqzg3ZwD`!?4 zeMdGdlBjr|IYqxsGtadnGiHJZd&R)MAC)K#5wLCcvtUM%Yw|s> z@OjMg6Ku)Vx4XCXtBBAB&Sc=Yy?@4P56a=*>W^EW&42~lX3_<12&~(yTWcF7*u??X zY?&Mpl(;)D0Wbu&B-H2M$d#@C4jEO!L&`>qv^mgYRihkT_uv16@F~Tg zNYqL>_8Jpsj?hE2VsDr~0yMlFw4nH!=nEEky_?KH9Im1Vzse+47L^}RtwnfWXJ3DF zm~|`|$FkQcEmTn4{p>Z8ONG0z^&Y{~$H(|65PRbIt1C7Rz4Bci4@{IIqkyqbzu3!r zOz52&+rRWjndKP`q=rA!yf*&W6=lE!2Th%4$n-qW@d4f0&$!Xrg9$j+ec`x>AUp;J z28VvB!%tPf07Y?W|CPB)1P$d6O3h4bG`ykk{Wmj4jP?1|s*{snY$Dzc(?>Vjfg9N> z6QSci@%I>#cs>Mr92-2m_tek(4-o5@E;g3G!56`>w^doXbHcYsbh8J$>MBN&UAbB+ zTBUn*`% zPG}PZjLBWy*69x?ohg8a`@<0942uwM{K5~h0SyY(Nf;!Xs^-bY7B4KXt##)6iB%(S z2Q6Rnl+I4W+{WeHJ-!{Qf2)4Du|~^!nEjc8=Ck>XZ-Wre@xrHb&sg9T`cN_LF?J!@ zAP*L}zYRcym;;P=>#*`B+~@FICJQ_M?@JN3?xdR63o{7iV7euI)T@lAB>qyx`}7(U zOz0T>w>uGDz;tc4Io$|40|u?<&u7t17a~Sp>ms)eY6U^wgb4U0<}fR(IEn(Batrk2 zW{clHtYZKK5;8F{LH?x)*z|G?5O7x@kDd*pGm-`RnESt3>kF}YGTlIVc`n?n?i~mZ2n*oWv(1??nx?i<IHMfqO>(p1Ru8=Hx>z zQWS61mCYsMcJzNL!v|i%3SLB*ZXbo)o1{Makvdiw%b_a zD*dT<2(QVxwaM}&1(N;9 zK6Ta)9;X4CiI@%E&$HS*d(>oYna_ii1)Uy*M!{LYU1{a1N~G{$!? z+tygQPwdt5)t#N4+1;F_84NI&wBk}o`oC9`__n>hxVT78Vt4D-Ev)>~(tO=4TAToD zzU-5^^v3$uS)631um=)wO^1P08i5=aUm+XmlF3@;E&<0WrR$1IfwnWRs6Ql|eVnE$ zXwi=uYMGupC&BR|7s2TYY+M!H%Y|ecR%nEqf~(wsh=>O&$k0p=;ebYxU$|Y3WRuOS z_r2wB2yw+3BQs(0e+uy?ou$G$V$*R~^#f$QnW(&(h)o3jZ$Joe@`JYn>l*H+R}wTE zLA)Tp1GnogIQ{zr?bCU3^efLNud46Jf4bK?b1S6n=tKz8fljy}y3CMaS+WV<-Z$6T zT7LpbhaOQk=P7??8TXYA4V9px{0tE!>p*X-3!)NY0=7jY7fNT?o#;C*FELk42qR0! z(;-Jh%}ZyTP2V4j)k^v9x3S`jTw_0>efBk>pxeeIcs^ z35ff0KUvA!?lHc&Rx{)46*8&mR`qH{;|x1CeGQ{}Z#2KRI=Dq0m!*D~ty}2_qk_C^ zbIC~MC==l}q@c~V98)6pw%PK>R zruhhZ=#aQJPBzO20T=dNF$vUIbwH-yTxFIwhxo6w)O?0;wGwGcoE%>D!kg((tQ$=g z^fjFeoW*cK*m!do?;q)haTDKJtciNP?P9^@|5y-(eC{%9L4D)ou+Q^NQ*;CFyw@@z zDM4r`gmTd2?8U01;6<-6mExn1j~fH*)nsb-Bf69zqisiS{U0IITFS?j=LRnLR|sn) z_FhABA$}>GF6b2r}YaiewZi@IgB9{tW+}C1v}@eg#$sQN>rH-_7vy!N;?9dEeJLY_!}Bh{ z>0$;%`L$;|J>@`4NPgemZM%yjW)@)K9RPiOXY?npX!rh%QYEV_N;d-Tw0+6)Dwh{V zk^_k$jjzbcZU`)ti3>OjEU50q^{OsDE(%|N=ln4=g<_ttB(?pC@@I-lAIhny&@&}u z>5T_5JUJ3CedXTC|1yWnV{P&7f*5DRK-2X)2*3Zf+_H-HNbdQL=~?iZIpp$6hN`>~ zK^ac-u=3~VW0|2iD)~=D&5s3ruevctCUhEx5zMb)+#x!6H6meMy!0(SYmjO4w@>1X z*oCN@*&%s2xoax1oT!CD&ZAA(8smB zV>@3ePir(*Q*Dj$#7Tmym!P#*JJ2YDv@_bzc+D#<_aI|D?0$q>ZuIGHXP3T?%G|YD z5I6hJ>zJA@07xW4*C}KF!@Vy#S zV0;0AhX3t2jg0_TXUUfhW$UQEhr@TVqa^l8RQm9U_pa^#1m4x=iE22Nv@6IWs0Fpf z5zM0yo+WW8VGtsJ>vi*)@J z#}Qz7t&hNgGt#(W1^XZ4nx1`MVT6PznDh9D#B12UJZ*I{e0aI6Rc~}V;KHBrr5z2k zo&}toSMEn%X0fsYL0vbeY_;|{Fs-g*;)v7wU-=Giuiq$_18D}er~#ZOjgM!CK(53(I6-^`)g1u$&KT4__-5M$~C(8z9s2}7tVYFG!g zLqp0UaW~V9|I2S7EdG``KnrW6X)3W%{XIjrP7E^)~-B(rZp z!0(s@X*ZQT#8?xpM>HsDrP;^gIXoaDX&sxBzapXW?*8o`9|x}sDHj&y51s5?eS_in z2;;(($PNwGfQ#(w$tW4P*lTddptJT6HAVCl1c$X?fdCtBnU4GhO>7uX$z$<{@-M@? z8X*VGM+PXODS;{%4PuUF6jO5|`wlAfOhf+V?&ZB}3OCtt6@E!jdQfyoju~foy+NuG zKx=6gpY)}&@4zTOq^(D=vxh5f%}BaI#;%x%`_B)$Jj6U|gm_alD^ecI%KKhtw-`s| z8^r566h4>Xq|f(ff07jD5ITXfKlN*)$)5{Jl^Z-L&Yc{kJ!_?&P&8sIdSG4jc8qq6 zSNDZB%%spn4>wmzZi&RuZ8@C*h;TAb_d*G#&~HHYZqD9A-W&(uJIt+8w8X}Gf)&07 z)X^_iBV4F{{PZmLYka zq+7(1u%@%bGlPho1462iOWUJ(9mLS_nZhHRM=O5BRY_=%wX$Aq^X5+kt9^nr#=i!x zjM^0Tj>!y(6ip#!{O{*QE1IqTDP>DyuSn96B|ZOg`h4iRMHS3IS$2P*)0}13JshmjS zOGB1NFg3H3t8UBpJMDr6cD;$*{A@rKB~Gk8E=4jcMUVI-jOoa67NcwyhD_7NRx->k zN>i@zvDdT`X>Ej1Dr)>BfF9k5+*%N(JPrzdRW05u??@TTOIZi)K*!%S{ocNK7&yiS zttDyZr}O|9qgtMV$QkyFUgs-E_GwP^O%A)CLnUvsb5!~~sAJ!l)z&E>4gj;Epucw& zAy#ddUCkZ5&;-;v2I3f0DWGl!!^b%Pic|$LA zI6Alt?^vH#lkIQ#wN!pjLs&!Wsse*HP*bTMKntc z*$nS$QZT2yG*p&niMm`2`-VH~1YwKlN>})=`k=~C7N_`K4l(b23`91^%Ux;WV1E*`dML1Z>qiKH3kmOpzNw?8qS~l1qV}EO#n%Qf{s;Regj&@0J(hn+&dR? z6cXC-AN%KO+pMaqvB0HrMbX5IMeK{{dpt#&6CrxUz8wuEf!@eMXxDSfVx zFS7JAi5_HDkN(C+JNvrfqfyTQcgoSTgW9mTu zLqKRoc`+qbw+6-cgxh-WqWgnV?s^|NNCNG|@=Oy*+zk3D2%sB~hbU~e#^!gnJ&sb) z!Oqp#&>Y}fE)$Z(Pxn3HNdpW++^AU7jF^G6H#xZ|a|LG*amPJrvd|;t^ak>DGZOO~ zWiDcwTyp=-+}BuB*XG6d=ynR-CuSJCMyY=PjyFkjCj>=;Bz)5GSa9OzrICDhrO%xy z^{w5bj}U46EoZ^9b*o0x=c4v>v7GEeA{aXPI9^HywxSMCre=@Xrfi`VZpag;?X_4E z-Pcd#S33N5KNS&mTy|bOlCQUzWzlgceAP(w><6U4zw0V(312wNQnz9iohxwq=h(_T zqu6A!v-th|M?}}|2T^&XLa=qQ1Mv2U_hwd|*B7t7DMU=+Vg@=dX32o!GG&P52t*Sw zfq2pv%J5$W7w4z*SQ9-^WPpy3IxkZ$g#9X!x%oN}VVVJ7_c*FO3EMvaaObiac~e8aGF)f(gO*(>>R7d!hn&%Lpx;t~MV22C0p;IX1h4r$`_7wCP>9-J zeLC`WDt`XKO5Yft`aZbC%toD57uQJa%?A*VliD`Ce52fFs;j^#L{=OGM;C#+@75^~ zdHjK@Z90wtFjD=%qL=dey`$;n6`n?a3hATkAURW0pF?K~L=}Nt0d1r};I&WIA6D)P z`<%Q2)ZraW4P9aw)y<#S6e%n^9JfT^rnH7G?S$Dk#t$*>dXXu{R7_hDRgJTfxIKDg zWb+zf35*Mta@L0Nk3Ae+roWnUmigX~Bt?LUo{8k54^YyY={{kOe5kWr!LoIdt3oAZ|t7}T5W|0C(TqpAM?zY$8ZvPx3O zs1W*4X6jm5g_~@$3fIcDcSL3=5!v%vWn6niRzeXM*CsNqnZ16``~CgloKEM~Io|j6 ze!ia1$MZ1`-L%EkN!Ln4rTMt>kJ92pYXpyz_%eG!L{7G2?M#%F-jqzUh~7Ixc_WVz zWHyh9r$|A=3bv5Q(wTPKQj{Asc)}-s|>goc3 z_v$`<{j0V9=cKLrQUy8*1k$$}s^35K9iI%$&)m1E|1R@+Zdzj#wbyc*E5n}m&Z`Xy zk`wZtg|4hFGaVZB-tAx)1ujvFCPflCN`tXKHIMuc+!)LbgubhAIbV+0PLybRPNh;k zb(>`MO^3tXF#k}-uYQ{@p~fVa>5iLiK5%L?idD&2{W-GydF!G4Ir2w~z!18a(Z)JN z<_N5SyAl$f5Zrwi>;J&CAeqcy2b3o`J0mggqc4HWKRLaUmK-%sobs+K6Er0-M= z{vGSw_|xSC)6}~`HXWVyli}LcUF`Rz=$TaRF1*w7vOLSxdr!{=+-e@r4C1}#|5-nc zNuN6SSpT6w`8`#GhXR=t`XgG3!D$cv(- zj?GoiuKX78IXKb(HOOY}d4b_BO#$36|=XFC&ro8kKL)_+73iV{6kl)o>lCm<3Ur zDiYP8z#>I~A$yNPl=?O9?(Zm*o8zKhcfJ`jzp0lssC5;Lmf3!K-`pZYC63n#F#Xyn zGigSKhpKUYU%S98?FXQ2MsjsQM7k^yFirNDODxW`2+mghumKKme7D<%GhHc6E~{S* zbN4H>7Aq2TX2M)~H`9x6(|vOLdt>+K($moyn-Tc6Pp+cFZ04{pIt*Py&ez)rX3#t$ z8Aa^cxFof{Cjq4LGWl_6cXG#*y(WbwRxanLddmNbLqjzYLR@uG1cu_xgajJW%k(jP zh;L!bZ3&^vGSOmE^g`nrFFAOSp7!rA^^sn0_(k2)EP?pAKJp?Sm4#pM;J|MShSeqw z;4Md22EU*NquZ%^3RppRShSn?Nqmm;&yC)CkGQiJoQSrWfr(}=fWxx z0egOuNf|lw?U&c*;s#8Y&y;UNV06?N(}=2q!R!x#A>Lc<*ctx7UWvfNLo*NZ{rKJQ zVN$)H3B{;Vk+y_OJX{o9&QgxtB0}BfTGJM$8n$@x^SOgJb|*-4YI>f5!_{RYo8Y0hvLkEVxn^La=5Li#CpHH1)CYxri(QL4>`>Ir?Y zZ`uR-t_)^7{LX`wQ5KI)tN{N|smFo6{(0MWM!T$vmzA?CcM2+xX6(EWgj`a$F-gy_ zSaq~2w*m8xnO5tq-ih*a?BZ+Syvut$)_Y6tWUm)k^FD^{x6X{9D#tpyR98whsW@ET zE`_|m?cw2(>^ky48}APQa|mlzm@5-CFeEaHTl@TY_=V_@+g& zt-q84Pp;&w%fWr*gRf+71DgD4i~NSQk_n@X)i0&Y(^OSww+|{E!qp_qUHD^MZQhJt z``|j9xg_%6cY+aiLR7nIiBPk5(`6oQC&fVqnvBO|k6QEbE1DCB8BZIPjD94Fg}}~p zVDmWfIk-|9q6d&?BGWOUr@@EwrU5p>^p^51HVZ(?3-v2(He1O-RFn>2LTYcN2a7d` z*f9?+Ew7PN(OOJ41^ipWMZ*BFzq6_dK4=g1N#--a!CDNaW?5*~vLYm0m!ED=`1#@J zZY%)l=qqtZfS8P`<9`>#4U!(9o z4M|Lc`A=K0c&~oLHZ`j+u3ZV#p>*k~B_e_|k zP)$XZ+0{l2zv$|rz;_?GR#ljzpVl5GX|xU%ms!CJ%*~XG@8#L9U?D3{3Pj;?1s`oL z;NAEe+5MCaC$f4#9~cm{-HRY@mOFcUSDkA2_dORE7f(~Z@`jxPkUZbDwJE}!_7w;r z4}l$%{wNd>hce=N!B-9@U%R2tS69@sv$KK36biGP*t2h^+afNsL`ki5fH~(7(5D5i zZy0;`T7qw(C6rNR{V7)n1xU~0ajQL}o+uj3;W`um9ty#JM@uHao~n&{_n0`D5V`Z7 z#~w*Kcau53hpg$1=~J>G8g`rc?svRAy0vLNl#^kBHn zd>iOv>H`h}9>a&O4QKC7Ylv z&U-}pi#pF>Wmroh@(B&9aS$~xhGno?#L5LKUd(?`6yKrnlu)kT3t!C6_M>E$d*$X= zB}UxSe%oVaB=DT~IsMEcCQ^#Q;`BR+8VDXIYX4x&O$B;sSqlm}JnI1r^3*GN*^|SU zvWIIAE?&I23!rs^oS!VHtjf#FKR;^?d-n%ZQeqAG4_GEyY@*bR=5hY~d%h4^qsup1c6c8J+Y+4nom{w=ff5*ose2^F0zej%Z`P4(&>bG*~P z$#2y9V%9xbU?Pz&BB}x`OT1m#IcO;7fx0&b6ejbQ*XLl>@Dbco9zh0AQ;vG#`f+|k zLkbWT#QcuE;XacEx*L%iOkBr>aJaJ7E|-Qcg66Yk^Mr5p1JDzo%{>+iMx8?ou{?)@G+YxNZWo!|eT`Ten z=6RW&l$JmB#>VSE=a%I1Olujs*e4FNVgv;#xhqt$!`L>qIC4bRDR)Y#fEYn)|8|Ah zB#dM`9BVTL1M+0rWqz*q{KXH&EeXrqT){#3!qh*K;VK130dD$!ZbQze+GhEorl7Y; zQl1-K=e5Q2Dj8K?%hHef{651V{ojJ50lMoZ9*gaeVMZZqS**FcOUI^xTBN^X53$x(t%!^A7LZ7uW;X1`Bq)FXNhR}W+1b~{{g zYXp_*E~S{&k4aeJhyidZ3=EY^z~{FD=Kclk^2n?&Ut)m%7?656%XbO=`>`Pye$)Zq zr`Y`NPNai3cRE@Hf63*id<3W5_*RQb0nnM(LgWqus4lx-5$H7&atAYD#(f#bzO_$7 z)<5qIqDI)9zFXtQ&l*T4GcA0(-LKjSr2E-x!qEKCOC( zfX1vF&0~!S0?Q(yDJNRlT~n}6Xr8gignj>;55}Hv%6dw7d{sXPSq>KJor!cZJ$X$3 zJZzu+slRj#fB(JORI{Wr6b7-z#gW{;?veQWa2!?5zLJ<&dYI*!B};R|4X3rT)5L_k z8o~D`26!yW^@xy$IPW}E z(`EGOsfljAJ+Hf-Q%@GcI~;?3Ij9RQuoHmhS^T4*ct!li_7&Ptl&uqx&Y#Sv3l8cVcPB_^r-Z^M2mckYKu zVbVl8dAon(TYp8N?$>_0ZbCk;Ce}Y90#O^P*p!*GM*?d*J6nPRHn>=>zJ2VIrvKU! zZt`*U2c%Z@m@)Q6JFdI>FI^D|?O2?YqpBk0#l4KC;u1U}Q$kIh3 z|L0ZbwOTJxX5H(`!AAh+>5i)*m!>^Vj<)f=`V^HPMJ=g`PK@wYzT{|4gnJ_j8d>*o zKkwmt*TeQ7aoR=IM3OXgg9*@&Qnc1*^>Nbk(ZbiJceZxke~5qRYkVO0J|?UZRT)>L z+i<~L&EXoC6xoUMrxu>pr_HFiv*#}hgok#xHa)nPvUHPo#&Ay3zmt}0Lw!;&t%Ud& z$@1YWRf z;FoWCp z$*}XpI3u_>cIr-=nzc0WHr9*mnZ5pWC1`e^3jAd}SFXAZxOd6Ravg0_>geq5zn>?) zTEl{>HMuziv#R3x7_Lc%wBKl!xZ4Lm!>v-nlm*uKr(A!vmG4Agw#T=k)PCsWM17Li zK2EhEoSD|@LZ6yVB;2%E^>FS(am6CLE+SH#Zs-SSQf9u}GW?7Z)K+{EKDy29lAMSc z`2xOorucqItVqQj`_!BY%xFW^#R}5`2ImSHnWP&%7t>A85>zWmnY?8+E~gCrAhvXp zZzf#h?AgWbCavWTF3hL;$wJ=T8@hyR8ZuitFq!05?cmKqIEWN!qxN29RwEaNUfF}2 zz~!r7;(xINsuJGv6WJrhz}Fe0?8VnylhOF~aeG(QIE&sd>dy|>&iZ?&iI$YQ+?&!c z4@>XrXA=yB0dCpp+|L8!pxFj7h1r|9t9`TLrb@W0YKr^SOC-FT6)dRLHml{N6avK$ zmKbW{iEx4;zq*|cb`!4*soo&JH5Y}^Xj|lEWPiabtav)SyHGUN{n@gsQLCB=m3PHZ zV+pjCmIw+R$z}G1b#lAs_7>g~>B~md@8=!*z_6(syqcyx{NnBsNui)ExqpS;k>eT) z>^w%QcJ8j$ZFY2(?a2SS`03w_F}Ij)+l;-uS%rE9KUYD(2ke3sb#;< zu^n3H@Zcr67t&rKp!eW}j<^nhW1KjDq;8c6X#nuPX9AB1{jQw&Emtp>4t=1Oglf1q zzSA%->Drcb$BB#hoonZ)Il(iMJ>E43HYViZYGT;hqI^Ogjg8*hK|lDaWCeI1JO>_K zHHxgQ%>V4OcaK^Qxv9?m;FK!#UrYVI9a0fkB+_P4lJsR_s=AQgT^aG6@v;E&rD4#V>LN|BzFVLRI_&Oy(Dr{z*L-a&hOO~q>*+ zvi6f@QYqTb%O-quZ+jQ62>yq*?HQ={$CAApC<7|i(-pTcLkAOcrnfpkC8WvEv?2Yv zQ~Ct@Y_L;*>8XyNwJolGz3Q{-CDW531LI;fR6~KAH(faH-e6*3*?=+Ld4Hkae46N- z=1{l(dj4qhTl?cz*sDM51Qx}o(Dq_x@u=6Vb}s~nT9uV&ocZiZm1k~m-5jjP3aH9Y+Y$SRC zeSs-K3uH*4k92ep5F|;(>3s|4G&Z*h9|@m>5rxoo2$P%xA|_D=0dI*Q(MusnKXP@* znm*|MF4ylfQfg;VjXkLNN(Bv-ui}0FQ#l=fV&~bpt_2@?`>383pSbAJhMGUQ2t!$0 z{U&0TA9poY1wl(5_3~4TMb?$I8}RPWF~<1#Q~7)2^wb8EXr#FSqk#85=ffuX2}{f~ zoV+LXZ~)~4h1vOY4_t29EnFGRP8*&o(5OTtnBjf#%3^l-i9P_bQuNtubN!7*ZoA!e zoo)YV_U-h`Ao1R;?3XdvO`6%yn0Q_+K`fM3p!H(sK~67 zonAv_bDDm`=loNi3he8>VL*GBD(k~0=PNza9^JlLi7Pf{W@aY#7V97u5OWxp#yM#% zLOhfeDC=lAWk>susc6}{{Ws+8>|SIK>_ZRDC{7kr-KEZ%4OH^TlA6LTO5 zk5VO^;BV{CSxd{P9`vL$olM}zYROLtEw~8F%%zIW@#5j)L+g=yzuv!{EMP}{TdcGy**kT|{yMbg$@4zyu0&P#9uXXAS568b zr*ei?!4YyFU&3Q41@9DtU*g2TQDWyU&3d~`0S40hMyC}&=s5nHbo%|@pljpzBA&*# z7mZ%CqOJ{Q5atM7-&RY~ZcS(&C^B=*40~3>@2xg>QE?ZcD6}gybzZE^WsuB}04