diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py index 13aaafe77ffd..38e1224cbb12 100644 --- a/python/sglang/srt/layers/attention/wave_backend.py +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -6,6 +6,7 @@ import torch import triton import triton.language as tl +import logging from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton @@ -18,6 +19,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +logger = logging.getLogger(__name__) @triton.jit def get_num_kv_splits_triton( @@ -102,6 +104,13 @@ def __init__( super().__init__() + # Set unique cache dir for each process to avoid cache write races + import iree.turbine.kernel.wave.cache as wave_cache + base_cache_dir = wave_cache.CACHE_BASE_DIR + new_dir = base_cache_dir / f"worker_{model_runner.tp_rank}" + logger.info(f"Setting Wave cache dir: {new_dir}") + wave_cache.CACHE_BASE_DIR = new_dir + self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_wave