Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/sglang/srt/layers/attention/wave_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import triton
import triton.language as tl
import logging

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
Expand All @@ -18,6 +19,7 @@
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

logger = logging.getLogger(__name__)

@triton.jit
def get_num_kv_splits_triton(
Expand Down Expand Up @@ -102,6 +104,13 @@ def __init__(

super().__init__()

# Set unique cache dir for each process to avoid cache write races
import iree.turbine.kernel.wave.cache as wave_cache
base_cache_dir = wave_cache.CACHE_BASE_DIR
new_dir = base_cache_dir / f"worker_{model_runner.tp_rank}"
logger.info(f"Setting Wave cache dir: {new_dir}")
wave_cache.CACHE_BASE_DIR = new_dir

self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_wave

Expand Down
Loading