Skip to content

Commit

Permalink
cleanup models dependencies 1/n (#2948)
Browse files (browse the repository at this point in the history)
  • Loading branch information
zhyncs authored Jan 17, 2025
1 parent d06c1ab commit 033c715
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 46 deletions.
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
Expand All @@ -25,6 +24,7 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.utils import is_hip, set_weight_attrs

logger = logging.getLogger(__name__)
Expand Down
10 changes: 1 addition & 9 deletions python/sglang/srt/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,18 @@
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py


import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import safetensors.torch
import torch
from torch import nn
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding

from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.loader import DefaultModelLoader


Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.distributed import (
Expand All @@ -37,6 +32,11 @@
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
Expand Down
3 changes: 1 addition & 2 deletions python/sglang/srt/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
import torch
from torch import nn
from transformers import GPT2Config
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding

from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn

# from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
Expand Down
12 changes: 6 additions & 6 deletions python/sglang/srt/models/minicpm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/models/olmo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
Expand All @@ -45,6 +44,7 @@
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers


Expand Down
11 changes: 5 additions & 6 deletions python/sglang/srt/models/olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.distributed import (
Expand All @@ -37,6 +31,11 @@
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type, TypedDict

Expand All @@ -30,7 +31,6 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU

from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
Expand All @@ -50,7 +50,7 @@
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model

logger = init_logger(__name__)
logger = logging.getLogger(__name__)

# === Vision Inputs === #

Expand Down
12 changes: 6 additions & 6 deletions python/sglang/srt/models/xverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
Expand Down
16 changes: 8 additions & 8 deletions python/sglang/srt/models/xverse_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Expand Down

0 comments on commit 033c715

Please sign in to comment.