Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ForwardBatch,
ForwardMode,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils import get_available_gpu_memory, is_hip

_is_hip = is_hip()
Expand Down Expand Up @@ -108,6 +109,8 @@ def set_torch_compile_config():
if hasattr(torch._dynamo.config, "cache_size_limit"):
torch._dynamo.config.cache_size_limit = 1024

monkey_patch_torch_compile()


def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args
Expand Down
7 changes: 1 addition & 6 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import (
monkey_patch_torch_compile,
monkey_patch_torch_reductions,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
Expand All @@ -92,8 +89,6 @@
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

monkey_patch_torch_compile()


class ModelRunner:
"""ModelRunner runs the forward passes of the models."""
Expand Down
Loading