2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_moe.py
@@ -10,12 +10,12 @@

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser

FP8_DTYPE = current_platform.fp8_dtype()
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_rmsnorm.py
@@ -4,11 +4,11 @@
from typing import Optional, Union

import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


class HuggingFaceRMSNorm(nn.Module):
@@ -6,13 +6,13 @@
# Import DeepGEMM functions
import deep_gemm
import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.triton_utils import triton


# Copied from
2 changes: 1 addition & 1 deletion tests/kernels/attention/test_flashmla.py
@@ -5,11 +5,11 @@

import pytest
import torch
import triton

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.triton_utils import triton


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


def blocksparse_flash_attn_varlen_fwd(
3 changes: 2 additions & 1 deletion vllm/attention/ops/blocksparse_attention/utils.py
@@ -8,7 +8,8 @@

import numpy as np
import torch
import triton

from vllm.triton_utils import triton


class csr_matrix:
3 changes: 1 addition & 2 deletions vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -7,11 +7,10 @@
# - Thomas Parnell <[email protected]>

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.triton_utils import tl, triton

from .prefix_prefill import context_attention_fwd

3 changes: 1 addition & 2 deletions vllm/attention/ops/prefix_prefill.py
@@ -4,10 +4,9 @@
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
4 changes: 1 addition & 3 deletions vllm/attention/ops/triton_decode_attention.py
@@ -30,10 +30,8 @@

import logging

import triton
import triton.language as tl

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

is_hip_ = current_platform.is_rocm()

3 changes: 1 addition & 2 deletions vllm/attention/ops/triton_flash_attention.py
@@ -25,11 +25,10 @@
from typing import Optional

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

4 changes: 2 additions & 2 deletions vllm/attention/ops/triton_merge_attn_states.py
@@ -2,8 +2,8 @@
from typing import Optional

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
3 changes: 1 addition & 2 deletions vllm/lora/ops/triton_ops/kernel_utils.py
@@ -2,8 +2,7 @@
"""
Utilities for Punica kernel construction.
"""
import triton
import triton.language as tl
from vllm.triton_utils import tl, triton


@triton.jit
3 changes: 1 addition & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -6,8 +6,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl

import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@@ -2,11 +2,10 @@
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.triton_utils import tl, triton
from vllm.utils import round_up


4 changes: 2 additions & 2 deletions vllm/model_executor/layers/lightning_attn.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from einops import rearrange

from vllm.triton_utils import tl, triton


@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
4 changes: 1 addition & 3 deletions vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -4,13 +4,11 @@
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py

import torch
import triton
import triton.language as tl
from packaging import version

from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON
from vllm.triton_utils import HAS_TRITON, tl, triton

TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
>= version.parse("3.0.0"))
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_bmm.py
@@ -8,8 +8,8 @@
import math

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


@triton.autotune(
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -6,10 +6,10 @@
# ruff: noqa: E501,SIM102

import torch
import triton
import triton.language as tl
from packaging import version

from vllm.triton_utils import tl, triton

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')


4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
@@ -8,8 +8,8 @@
import math

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton

from .mamba_ssm import softplus

3 changes: 2 additions & 1 deletion vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -6,10 +6,11 @@
# ruff: noqa: E501

import torch
import triton
from einops import rearrange
from packaging import version

from vllm.triton_utils import triton

from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
@@ -6,8 +6,8 @@
# ruff: noqa: E501

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


@triton.autotune(
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/awq_triton.py
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

@@ -3,8 +3,8 @@
from typing import Optional, Type

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


def is_weak_contiguous(x: torch.Tensor):
@@ -7,8 +7,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -17,6 +15,7 @@
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)
@@ -8,10 +8,9 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

logger = logging.getLogger(__name__)

12 changes: 10 additions & 2 deletions vllm/triton_utils/__init__.py
@@ -1,5 +1,13 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.triton_utils.importing import HAS_TRITON
from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
Review comment (Collaborator): can we add some unittest for this placeholder logic?

Reply (Contributor Author): Sure, I'll add it soon
TritonPlaceholder)

__all__ = ["HAS_TRITON"]
if HAS_TRITON:
import triton
import triton.language as tl
else:
triton = TritonPlaceholder()
tl = TritonLanguagePlaceholder()

__all__ = ["HAS_TRITON", "triton", "tl"]
60 changes: 31 additions & 29 deletions vllm/triton_utils/importing.py
@@ -16,32 +16,34 @@
logger.info("Triton not installed or not compatible; certain GPU-related"
" functions will not be available.")

class TritonPlaceholder(types.ModuleType):

def __init__(self):
super().__init__("triton")
self.jit = self._dummy_decorator("jit")
self.autotune = self._dummy_decorator("autotune")
self.heuristics = self._dummy_decorator("heuristics")
self.language = TritonLanguagePlaceholder()
logger.warning_once(
"Triton is not installed. Using dummy decorators. "
"Install it via `pip install triton` to enable kernel"
"compilation.")

def _dummy_decorator(self, name):

def decorator(func=None, **kwargs):
if func is None:
return lambda f: f
return func

return decorator

class TritonLanguagePlaceholder(types.ModuleType):

def __init__(self):
super().__init__("triton.language")
self.constexpr = None
self.dtype = None
self.int64 = None

class TritonPlaceholder(types.ModuleType):

def __init__(self):
super().__init__("triton")
self.jit = self._dummy_decorator("jit")
self.autotune = self._dummy_decorator("autotune")
self.heuristics = self._dummy_decorator("heuristics")
self.language = TritonLanguagePlaceholder()
logger.warning_once(
"Triton is not installed. Using dummy decorators. "
"Install it via `pip install triton` to enable kernel"
" compilation.")

def _dummy_decorator(self, name):

def decorator(func=None, **kwargs):
if func is None:
return lambda f: f
return func

return decorator


class TritonLanguagePlaceholder(types.ModuleType):

def __init__(self):
super().__init__("triton.language")
self.constexpr = None
self.dtype = None
self.int64 = None
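
Regarding the review comment above asking for a unit test of the placeholder logic, here is a minimal sketch of what such a test could look like. It is not part of this PR; the test names and the pytest-style layout are assumptions:

```python
# Hypothetical test sketch for the placeholder classes (not part of this PR).
import types

from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
                                         TritonPlaceholder)


def test_placeholders_look_like_modules():
    triton = TritonPlaceholder()
    tl = TritonLanguagePlaceholder()
    # Both placeholders subclass types.ModuleType and mimic the real names.
    assert isinstance(triton, types.ModuleType)
    assert isinstance(tl, types.ModuleType)
    assert triton.__name__ == "triton"
    assert tl.__name__ == "triton.language"


def test_dummy_decorators_return_function_unchanged():
    triton = TritonPlaceholder()

    @triton.jit
    def plain(x):
        return x + 1

    @triton.autotune(configs=[], key=[])
    def configured(x):
        return x * 2

    # The dummy decorators must hand back the wrapped function as-is.
    assert plain(1) == 2
    assert configured(3) == 6
```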