diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index 4dcf0613bd91..158538f34609 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -344,6 +344,7 @@ class Envs:
     # TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
     # transport issue on GB200/GB300 platforms is fixed and verified resolved.
     SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
+    SGLANG_ENABLE_FLASHINFER_ROPE_FUSION = EnvBool(False)
 
     # Triton
     SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)
diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
index 09f3f409a1fe..4b04d360af78 100644
--- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -5,6 +5,7 @@
 The kernel supports sm100 only, with sliding window and attention sink features.
 """
 
+import bisect
 import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional
@@ -19,7 +20,10 @@ from sglang.srt.layers.attention.triton_ops.trtllm_fp8_kv_kernel import (
     fused_fp8_set_kv_buffer,
 )
-from sglang.srt.layers.attention.utils import canonicalize_stride
+from sglang.srt.layers.attention.utils import (
+    canonicalize_stride,
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
@@ -61,6 +65,18 @@ class TRTLLMMHAMetadata:
     # Page table for SWA layers (translated from full pool indices to SWA pool indices)
     swa_page_table: torch.Tensor = None
 
+    # The following fields are used for the FlashInfer KV update
+    # kv_indices: page indices for all requests (full-attention)
+    kv_indices: torch.Tensor = None
+    # kv_indices for SWA layers
+    swa_kv_indices: torch.Tensor = None
+    # kv_indptr: cumulative page count per request
+    kv_indptr: torch.Tensor = None
+    # batch_indices: which request each token belongs to
+    batch_indices: torch.Tensor = None
+    # positions: position of each token within its sequence
+    positions: torch.Tensor = None
+
 
 class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     """TRTLLM MHA attention kernel from flashinfer."""
@@ -149,6 +165,19 @@ def __init__(
         # KV fp8: q_type = fp8, out_type=model_runner.dtype
         self.is_xqa_impl = is_sm90_supported() or is_sm120_supported()
 
+        self.enable_rope_fusion = envs.SGLANG_ENABLE_FLASHINFER_ROPE_FUSION.get()
+        if self.enable_rope_fusion:
+            logger.info("FlashInfer RoPE+Quant+cache update fusion is enabled")
+            assert (
+                self.data_type == torch.float8_e4m3fn
+            ), "RoPE+Quant+cache update fusion is only supported for FP8 KV cache dtype"
+
+        self.piecewise_cuda_graph_tokens = (
+            model_runner.server_args.piecewise_cuda_graph_tokens
+            if not model_runner.server_args.disable_piecewise_cuda_graph
+            else None
+        )
+
     def _maybe_translate_swa(
         self, token_indices: torch.Tensor
     ) -> Optional[torch.Tensor]:
@@ -211,6 +240,15 @@ def _get_layer_page_table(
             return swa_pt
         return self.forward_metadata.page_table
 
+    def _get_layer_kv_indices(self, layer: RadixAttention) -> torch.Tensor:
+        """Return the correct kv_indices for the RoPE fusion kernel (SWA or full)."""
+        swa_kv_indices = self.forward_metadata.swa_kv_indices
+        if swa_kv_indices is not None:
+            _, is_swa = self._swa_kv_pool.layers_mapping[layer.layer_id]
+            if is_swa:
+                return swa_kv_indices
+        return self.forward_metadata.kv_indices
+
     def init_cuda_graph_state(
         self,
         max_bs: int,
@@ -233,6 +271,14 @@ def init_cuda_graph_state(
             ),
         }
 
+        if self.enable_rope_fusion:
+            self.decode_cuda_graph_metadata["kv_indices"] = torch.zeros(
+                max_bs * max_num_pages, dtype=torch.int32, device=self.device
+            )
+            if self.use_sliding_window_kv_pool:
+                self.decode_cuda_graph_metadata["swa_kv_indices"] = torch.zeros(
+                    max_bs * max_num_pages, dtype=torch.int32, device=self.device
+                )
         if (
             self.speculative_num_draft_tokens is not None
             and self.speculative_num_draft_tokens > 0
@@ -303,6 +349,21 @@ def init_cuda_graph_state(
             ),
         }
 
+        if self.enable_rope_fusion:
+            self.target_verify_metadata["kv_indices"] = torch.zeros(
+                max_bs * max_num_pages, dtype=torch.int32, device=self.device
+            )
+            self.draft_extend_metadata["kv_indices"] = torch.zeros(
+                max_bs * max_num_pages, dtype=torch.int32, device=self.device
+            )
+            if self.use_sliding_window_kv_pool:
+                self.target_verify_metadata["swa_kv_indices"] = torch.zeros(
+                    max_bs * max_num_pages, dtype=torch.int32, device=self.device
+                )
+                self.draft_extend_metadata["swa_kv_indices"] = torch.zeros(
+                    max_bs * max_num_pages, dtype=torch.int32, device=self.device
+                )
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -438,6 +499,25 @@ def init_forward_metadata_capture_cuda_graph(
             )
 
             self.draft_extend_metadata[bs] = metadata
+
+        if self.enable_rope_fusion:
+            if forward_mode.is_decode_or_idle():
+                metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"]
+                metadata.swa_kv_indices = self.decode_cuda_graph_metadata.get(
+                    "swa_kv_indices"
+                )
+            elif forward_mode.is_target_verify():
+                metadata.kv_indices = self.target_verify_metadata["kv_indices"]
+                metadata.swa_kv_indices = self.target_verify_metadata.get(
+                    "swa_kv_indices"
+                )
+            else:
+                metadata.kv_indices = self.draft_extend_metadata["kv_indices"]
+                metadata.swa_kv_indices = self.draft_extend_metadata.get(
+                    "swa_kv_indices"
+                )
+            self._init_forward_metadata_for_rope_fusion(metadata, bs, num_tokens)
+
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
@@ -543,42 +623,39 @@
             ]
             metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
             self._copy_swa_page_table(metadata, page_indices, max_seq_pages)
+
+        if self.enable_rope_fusion:
+            # Compute total number of tokens
+            if forward_mode.is_decode_or_idle():
+                if spec_info is not None:
+                    total_num_tokens = bs * self.topk
+                else:
+                    total_num_tokens = bs
+                metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"]
+                metadata.swa_kv_indices = self.decode_cuda_graph_metadata.get(
+                    "swa_kv_indices"
+                )
+            elif forward_mode.is_target_verify():
+                total_num_tokens = bs * self.speculative_num_draft_tokens
+                metadata.kv_indices = self.target_verify_metadata["kv_indices"]
+                metadata.swa_kv_indices = self.target_verify_metadata.get(
+                    "swa_kv_indices"
+                )
+            else:  # draft_extend
+                total_num_tokens = metadata.cu_seqlens_q[-1].item()
+                metadata.kv_indices = self.draft_extend_metadata["kv_indices"]
+                metadata.swa_kv_indices = self.draft_extend_metadata.get(
+                    "swa_kv_indices"
+                )
+
+            self._init_forward_metadata_for_rope_fusion(metadata, bs, total_num_tokens)
+
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
 
-    def _should_use_fused_fp8_path(self, save_kv_cache: bool, k: torch.Tensor) -> bool:
-        """Check if we should use the fused FP8 KV cache write path."""
-        return save_kv_cache and k is not None and self.data_type == torch.float8_e4m3fn
-
-    def _fused_fp8_set_kv_buffer(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer: RadixAttention,
-        forward_batch: ForwardBatch,
-        **kwargs,
-    ):
-        """Fused FP8 quantization and KV cache write."""
-        cache_loc = self._get_layer_cache_loc(layer, forward_batch.out_cache_loc)
-
-        # Get K/V cache buffers from token_to_kv_pool
-        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-
-        fused_fp8_set_kv_buffer(
-            k=k,
-            v=v,
-            k_cache=k_cache,
-            v_cache=v_cache,
-            cache_loc=cache_loc,
-            k_scale=layer.k_scale,  # May be None
-            v_scale=layer.v_scale,  # May be None
-            page_size=self.page_size,
-        )
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
@@ -646,7 +723,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-
         else:
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -690,8 +766,114 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     metadata.swa_page_table[:, self.strided_indices] // self.page_size
                 )
 
+        if self.enable_rope_fusion:
+            # Compute total number of tokens
+            if forward_batch.forward_mode.is_decode_or_idle():
+                if forward_batch.spec_info is not None:
+                    total_num_tokens = batch_size * self.topk
+                else:
+                    total_num_tokens = batch_size
+            elif forward_batch.forward_mode.is_target_verify():
+                total_num_tokens = batch_size * self.speculative_num_draft_tokens
+            else:
+                total_num_tokens = forward_batch.extend_num_tokens
+                # For piecewise CUDA graph, we need to find the padded token count
+                # to initialize the RoPE fusion metadata
+                if self.piecewise_cuda_graph_tokens is not None:
+                    idx = bisect.bisect_left(
+                        self.piecewise_cuda_graph_tokens, total_num_tokens
+                    )
+                    if idx < len(self.piecewise_cuda_graph_tokens):
+                        total_num_tokens = self.piecewise_cuda_graph_tokens[idx]
+
+            self._init_forward_metadata_for_rope_fusion(
+                metadata, batch_size, total_num_tokens
+            )
+
         self.forward_metadata = metadata
 
+    def _init_forward_metadata_for_rope_fusion(
+        self,
+        metadata: TRTLLMMHAMetadata,
+        batch_size: int,
+        total_num_tokens: int,
+    ):
+        """
+        Initialize the following metadata for the RoPE + FP8 quantization + KV cache update kernel.
+        - kv_indptr: cumulative page count per request
+        - batch_indices: which request each token belongs to
+        - positions: position of each token within its sequence
+        - kv_indices: page indices for all requests (full pool)
+        - swa_kv_indices: page indices for SWA layers
+        """
+
+        # Compute number of pages per request
+        num_pages_per_req = (
+            metadata.cache_seqlens_int32 + self.page_size - 1
+        ) // self.page_size
+
+        # kv_indptr
+        if metadata.kv_indptr is None:
+            metadata.kv_indptr = torch.nn.functional.pad(
+                torch.cumsum(num_pages_per_req, dim=0, dtype=torch.int32), (1, 0)
+            )
+        else:
+            metadata.kv_indptr[1:].copy_(
+                torch.cumsum(num_pages_per_req, dim=0, dtype=torch.int32)
+            )
+
+        # batch_indices and positions
+        metadata.batch_indices, metadata.positions = (
+            flashinfer.page.get_batch_indices_positions(
+                metadata.cu_seqlens_q,
+                metadata.cache_seqlens_int32,
+                total_num_tokens,
+                metadata.batch_indices,
+                metadata.positions,
+            )
+        )
+
+        # kv_indices (full-attention)
+        device = metadata.kv_indptr.device
+        batch_idx = torch.arange(batch_size, dtype=torch.int32, device=device)
+        if metadata.kv_indices is None:
+            # The max number of pages is kv_indptr[-1].
+            # Use an upper bound to avoid a D2H transfer from accessing kv_indptr[-1].
+            max_pages_per_req = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            metadata.kv_indices = torch.zeros(
+                batch_size * max_pages_per_req, dtype=torch.int32, device=device
+            )
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            metadata.page_table,
+            batch_idx,
+            num_pages_per_req,
+            metadata.kv_indptr,
+            None,
+            metadata.kv_indices,
+            metadata.page_table.stride(0),
+        )
+
+        # swa_kv_indices (SWA)
+        if metadata.swa_page_table is not None:
+            if metadata.swa_kv_indices is None:
+                max_pages_per_req = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                metadata.swa_kv_indices = torch.zeros(
+                    batch_size * max_pages_per_req, dtype=torch.int32, device=device
+                )
+            create_flashinfer_kv_indices_triton[(batch_size,)](
+                metadata.swa_page_table,
+                batch_idx,
+                num_pages_per_req,
+                metadata.kv_indptr,
+                None,
+                metadata.swa_kv_indices,
+                metadata.swa_page_table.stride(0),
+            )
+
     def forward_decode(
         self,
         q: torch.Tensor,
@@ -703,58 +885,107 @@
         **kwargs,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MHA kernel."""
-        cache_loc = forward_batch.out_cache_loc
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
 
-        use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k)
+        q_scale_float = 1.0
+        k_scale_float = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+        v_scale_float = (
+            layer.v_scale_float
+            if getattr(layer, "v_scale_float", None) is not None
+            else 1.0
+        )
 
-        if use_fused_fp8_path:
-            # Use fused FP8 quantization + KV cache write path
-            self._fused_fp8_set_kv_buffer(
-                q=q,
-                k=k,
-                v=v,
-                layer=layer,
-                forward_batch=forward_batch,
-            )
-            k = None
-            v = None
-        else:
-            # Use original set_kv_buffer path
-            if save_kv_cache and k is not None:
+        if save_kv_cache and k is not None:
+            if (
+                self.enable_rope_fusion
+                and kwargs.get("cos_sin_cache") is not None
+                and kwargs.get("is_neox_style") is not None
+                and (not self.is_xqa_impl)
+            ):
+                # Use fused RoPE + FP8 quantization + KV cache update path
+                cos_sin_cache = kwargs.get("cos_sin_cache")
+                is_neox_style = kwargs.get("is_neox_style")
+
+                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                k = k.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim)
+
+                k_cache = k_cache.view(
+                    -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim
+                )
+                v_cache = v_cache.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+                )
+
+                q = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(
+                    q_rope=q,
+                    k_rope=k,
+                    q_nope=None,
+                    k_nope=None,
+                    v=v,
+                    cos_sin_cache=cos_sin_cache,
+                    pos_ids=forward_batch.positions,
+                    paged_kv_cache=(k_cache, v_cache),
+                    kv_indices=self._get_layer_kv_indices(layer),
+                    kv_indptr=self.forward_metadata.kv_indptr,
+                    batch_indices=self.forward_metadata.batch_indices,
+                    positions=self.forward_metadata.positions,
+                    is_neox=is_neox_style,
+                    quantize_dtype=torch.float8_e4m3fn,
+                    quant_scale_q=q_scale_float,
+                    quant_scale_kv=k_scale_float,
+                    page_size=self.page_size,
+                    kv_layout="NHD",
+                )[0]
+            elif self.data_type == torch.float8_e4m3fn:
+                # Use fused KV FP8 quantization + KV cache update path
+                cache_loc = self._get_layer_cache_loc(
+                    layer, forward_batch.out_cache_loc
+                )
+                fused_fp8_set_kv_buffer(
+                    k=k,
+                    v=v,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    cache_loc=cache_loc,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                    page_size=self.page_size,
+                )
+            else:
+                # Use original set_kv_buffer path
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer,
+                    forward_batch.out_cache_loc,
+                    k,
+                    v,
+                    layer.k_scale,
+                    layer.v_scale,
                 )
 
         # For XQA, q_dtype should be bf16
         if self.data_type == torch.float8_e4m3fn and (not self.is_xqa_impl):
             q = q.to(torch.float8_e4m3fn)
 
-        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        # shape conversion:
-        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+
         k_cache = k_cache.view(
-            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
-        ).permute(0, 2, 1, 3)
+            -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim
+        )
         v_cache = v_cache.view(
-            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
-        ).permute(0, 2, 1, 3)
-
+            -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+        )
         if layer.tp_k_head_num == 1:
             k_cache = canonicalize_stride(k_cache)
         if layer.tp_v_head_num == 1:
             v_cache = canonicalize_stride(v_cache)
-
         kv_cache = (k_cache, v_cache)
 
-        # TODO: add support for quantization
-        q_scale = 1.0
-        k_scale = (
-            layer.k_scale_float
-            if getattr(layer, "k_scale_float", None) is not None
-            else 1.0
-        )
-        bmm1_scale = q_scale * k_scale * layer.scaling
-        bmm2_scale = 1.0
+        bmm1_scale = q_scale_float * k_scale_float * layer.scaling
+        bmm2_scale = v_scale_float
 
         # sink: additional value per head in the denominator of the softmax.
         attention_sink = kwargs.get("sinks", None)
@@ -774,9 +1005,10 @@
             window_left=layer.sliding_window_size,
             sinks=attention_sink,
             out_dtype=self.q_data_type,  # model_runner.dtype
+            kv_layout="NHD",
         )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_extend(
         self,
@@ -785,61 +1017,111 @@
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
         **kwargs,
-    ):
-        cache_loc = forward_batch.out_cache_loc
+    ) -> torch.Tensor:
+        """Run forward for extend using TRTLLM MHA kernel."""
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
 
-        use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k)
+        q_scale_float = 1.0
+        k_scale_float = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+        v_scale_float = (
+            layer.v_scale_float
+            if getattr(layer, "v_scale_float", None) is not None
+            else 1.0
+        )
 
-        if use_fused_fp8_path:
-            # Use fused FP8 quantization + KV cache write path
-            self._fused_fp8_set_kv_buffer(
-                q=q,
-                k=k,
-                v=v,
-                layer=layer,
-                forward_batch=forward_batch,
-            )
-            k = None
-            v = None
-        else:
-            # Use original set_kv_buffer path
-            if save_kv_cache and k is not None:
+        if save_kv_cache and k is not None:
+            if (
+                self.enable_rope_fusion
+                and kwargs.get("cos_sin_cache") is not None
+                and kwargs.get("is_neox_style") is not None
+            ):
+                # Use fused RoPE + FP8 quantization + KV cache update path
+                cos_sin_cache = kwargs.get("cos_sin_cache")
+                is_neox_style = kwargs.get("is_neox_style")
+
+                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                k = k.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim)
+
+                k_cache = k_cache.view(
+                    -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim
+                )
+                v_cache = v_cache.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+                )
+
+                q = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(
+                    q_rope=q,
+                    k_rope=k,
+                    q_nope=None,
+                    k_nope=None,
+                    v=v,
+                    cos_sin_cache=cos_sin_cache,
+                    pos_ids=forward_batch.positions,
+                    paged_kv_cache=(k_cache, v_cache),
+                    kv_indices=self._get_layer_kv_indices(layer),
+                    kv_indptr=self.forward_metadata.kv_indptr,
+                    batch_indices=self.forward_metadata.batch_indices,
+                    positions=self.forward_metadata.positions,
+                    is_neox=is_neox_style,
+                    quantize_dtype=torch.float8_e4m3fn,
+                    quant_scale_q=q_scale_float,
+                    quant_scale_kv=k_scale_float,
+                    page_size=self.page_size,
+                    kv_layout="NHD",
+                )[0]
+            elif self.data_type == torch.float8_e4m3fn:
+                # Use fused KV FP8 quantization + KV cache update path
+                cache_loc = self._get_layer_cache_loc(
+                    layer, forward_batch.out_cache_loc
+                )
+                fused_fp8_set_kv_buffer(
+                    k=k,
+                    v=v,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    cache_loc=cache_loc,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                    page_size=self.page_size,
+                )
+            else:
+                # Use original set_kv_buffer path
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer,
+                    forward_batch.out_cache_loc,
+                    k,
+                    v,
+                    layer.k_scale,
+                    layer.v_scale,
                 )
 
         if self.data_type == torch.float8_e4m3fn:
             q = q.to(torch.float8_e4m3fn)
 
-        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
-        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+
         k_cache = k_cache.view(
-            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
-        ).permute(0, 2, 1, 3)
+            -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim
+        )
         v_cache = v_cache.view(
-            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
-        ).permute(0, 2, 1, 3)
-
+            -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+        )
         if layer.tp_k_head_num == 1:
             k_cache = canonicalize_stride(k_cache)
         if layer.tp_v_head_num == 1:
             v_cache = canonicalize_stride(v_cache)
-
         kv_cache = (k_cache, v_cache)
+        bmm1_scale = q_scale_float * k_scale_float * layer.scaling
+        bmm2_scale = v_scale_float
 
         # sink: additional value per head in the denominator of the softmax.
         attention_sink = kwargs.get("sinks", None)
-        # TODO: add support for quantization
-        q_scale = 1.0
-        k_scale = (
-            layer.k_scale_float
-            if getattr(layer, "k_scale_float", None) is not None
-            else 1.0
-        )
-        bmm1_scale = q_scale * k_scale * layer.scaling
-        bmm2_scale = 1.0
 
         page_table = self._get_layer_page_table(layer, forward_batch)
@@ -857,6 +1139,7 @@ def forward_extend(
                 sinks=attention_sink,
                 out_dtype=self.q_data_type,  # model_runner.dtype
                 q_len_per_req=self.forward_metadata.max_seq_len_q,
+                kv_layout="NHD",
             )
         else:
             o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
@@ -875,9 +1158,10 @@
                 window_left=layer.sliding_window_size,
                 sinks=attention_sink,
                 out_dtype=self.q_data_type,  # model_runner.dtype
+                kv_layout="NHD",
             )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
 
 class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 449bc867067e..89a3044d7f15 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -148,6 +148,8 @@ def unified_attention_with_output(
     q_rope: Optional[torch.Tensor] = None,
     k_rope: Optional[torch.Tensor] = None,
     sinks: Optional[torch.Tensor] = None,
+    cos_sin_cache: Optional[torch.Tensor] = None,
+    is_neox_style: Optional[bool] = None,
 ) -> None:
     context = get_forward_context()
     forward_batch = context.forward_batch
@@ -161,6 +163,10 @@ def unified_attention_with_output(
         kwargs["k_rope"] = k_rope
     if sinks is not None:
         kwargs["sinks"] = sinks
+    if cos_sin_cache is not None:
+        kwargs["cos_sin_cache"] = cos_sin_cache
+    if is_neox_style is not None:
+        kwargs["is_neox_style"] = is_neox_style
 
     ret = forward_batch.attn_backend.forward(
         query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 593ef4b9f932..451cf8f27748 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -39,6 +39,7 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
@@ -395,7 +396,9 @@ def forward_prepare(
                 else None
             ),
         }
-        q, k = self.rotary_emb(positions, q, k, **extra_args)
+        # Defer RoPE to the attn backend if RoPE fusion is enabled
+        if not envs.SGLANG_ENABLE_FLASHINFER_ROPE_FUSION.get():
+            q, k = self.rotary_emb(positions, q, k, **extra_args)
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -403,10 +406,20 @@ def forward_core(self, intermediate_state):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
            return hidden_states
+
+        # Pass required RoPE args into the attn backend if RoPE fusion is enabled
+        extra_args = {}
+        if envs.SGLANG_ENABLE_FLASHINFER_ROPE_FUSION.get():
+            extra_args = {
+                "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+                "is_neox_style": self.rotary_emb.is_neox_style,
+            }
+
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
             save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+            **extra_args,
         )
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/python/sglang/test/attention/test_trtllm_mha_backend.py b/python/sglang/test/attention/test_trtllm_mha_backend.py
new file mode 100644
index 000000000000..61bf552afa39
--- /dev/null
+++ b/python/sglang/test/attention/test_trtllm_mha_backend.py
@@ -0,0 +1,572 @@
+"""
+Tests for the TRTLLM MHA attention backend.
+
+- test_basic_decode_output_match / test_basic_extend_output_match:
+  Compare TRTLLM MHA vs FlashInfer reference across kv_cache_dtype variants.
+- test_rope_fusion_decode_output_match / test_rope_fusion_extend_output_match:
+  Compare the fused (RoPE + FP8 quant + KV cache) path vs the unfused path.
+"""
+
+import os
+import unittest
+
+import torch
+
+from sglang.srt.layers import dp_attention as _dp_attn
+
+_dp_attn.get_attention_tp_size = lambda: 1
+
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
+from sglang.srt.utils import is_flashinfer_available
+from sglang.test.test_utils import CustomTestCase
+
+DEFAULT_CONFIG = {
+    "device": "cuda",
+    "dtype": torch.bfloat16,
+    "kv_cache_dtype": torch.bfloat16,
+    "context_len": 4096,
+    "max_bs": 64,
+    "page_size": 64,
+    "num_attention_heads": 64,
+    "num_kv_heads": 8,
+    "head_dim": 64,
+    "hidden_size": 2880,
+    "layer_num": 2,
+    "layer_id": 0,
+    "rope_theta": 150000,
+    "architectures": ["GptOssForCausalLM"],
+    "seed": 42,
+    "rtol": 0.01,
+    "atol": 0.01,
+}
+
+TEST_CASES = {
+    "basic_decode": [
+        {
+            "name": "bf16_kv_cache",
+            "kv_cache_dtype": torch.bfloat16,
+            "batch_size": 4,
+            "max_seq_len": 128,
+        },
+        {
+            "name": "fp8_kv_cache",
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "batch_size": 4,
+            "max_seq_len": 128,
+            "atol": 0.05,
+        },
+    ],
+    "basic_extend": [
+        {
+            "name": "bf16_kv_cache",
+            "kv_cache_dtype": torch.bfloat16,
+            "seq_lens_list": [64, 100, 80],
+        },
+        {
+            "name": "fp8_kv_cache",
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "seq_lens_list": [64, 100, 80],
+            "atol": 0.25,
+        },
+    ],
+    "rope_fusion_decode": [
+        {
+            "name": "rope_fusion",
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "batch_size": 4,
+            "max_seq_len": 128,
+            "atol": 0.05,
+        },
+    ],
+    "rope_fusion_extend": [
+        {
+            "name": "rope_fusion",
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "seq_lens_list": [64, 100, 80],
+            "atol": 0.15,
+        },
+    ],
+}
+
+
+class MockModelRunner:
+    """Minimal ModelRunner for testing MHA backends."""
+
+    def __init__(self, config, enable_rope_fusion=False):
+        self.device = config["device"]
+        self.dtype = config["dtype"]
+        self.kv_cache_dtype = config["kv_cache_dtype"]
+        self.page_size = config["page_size"]
+        self.sliding_window_size = None
+
+        server_args = ServerArgs(model_path="dummy")
+        server_args.enable_dp_attention = False
+        server_args.disable_piecewise_cuda_graph = False
+        server_args.piecewise_cuda_graph_tokens = [4, 8, 16, 32, 64, 128, 256, 512]
+        set_global_server_args_for_scheduler(server_args)
+        self.server_args = server_args
+
+        if enable_rope_fusion:
+            os.environ["SGLANG_ENABLE_FLASHINFER_ROPE_FUSION"] = "1"
+        else:
+            os.environ["SGLANG_ENABLE_FLASHINFER_ROPE_FUSION"] = "0"
+
+        hf_config = type("HFConfig", (), {"architectures": config["architectures"]})
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": config["context_len"],
+                "num_attention_heads": config["num_attention_heads"],
+                "hidden_size": config["hidden_size"],
+                "head_dim": config["head_dim"],
+                "get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
+                "is_multimodal": False,
+                "is_encoder_decoder": False,
+                "hf_config": hf_config,
+            },
+        )
+
+        max_bs = config["max_bs"]
+        max_ctx = config["context_len"]
+        req_to_token = torch.arange(
+            max_bs * max_ctx, dtype=torch.int32, device=self.device
+        ).reshape(max_bs, max_ctx)
+        self.req_to_token_pool = type(
+            "TokenPool", (), {"size": max_bs, "req_to_token": req_to_token}
+        )
+
+        self.token_to_kv_pool = MHATokenToKVPool(
+            size=max_bs * max_ctx,
+            page_size=config["page_size"],
+            dtype=config["kv_cache_dtype"],
+            head_num=config["num_kv_heads"],
+            head_dim=config["head_dim"],
+            layer_num=config["layer_num"],
+            device=config["device"],
+            enable_memory_saver=False,
+        )
+
+        self.token_to_kv_pool_allocator = type(
+            "MockAllocator",
+            (),
+            {"get_kvcache": lambda self_: self.token_to_kv_pool},
+        )()
+
+
+def _create_layer(config):
+    return RadixAttention(
+        num_heads=config["num_attention_heads"],
+        head_dim=config["head_dim"],
+        scaling=1.0 / (config["head_dim"] ** 0.5),
+        num_kv_heads=config["num_kv_heads"],
+        layer_id=config["layer_id"],
+        v_head_dim=config["head_dim"],
+        prefix="test_attn",
+    )
+
+
+def _create_rotary_emb(config):
+    from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+
+    rotary = get_rope_wrapper(
+        head_size=config["head_dim"],
+        rotary_dim=config["head_dim"],
+        max_position=config["context_len"],
+        base=config.get("rope_theta", 10000),
+        is_neox_style=True,
+        device=config["device"],
+    )
+    rotary.cos_sin_cache = rotary.cos_sin_cache.to(config["device"])
+    return rotary
+
+
+def _populate_kv_cache(batch_size, seq_lens, model_runner, layer, config):
+    torch.manual_seed(config["seed"])
+    for b in range(batch_size):
+        sl = int(seq_lens[b].item())
+        for t in range(sl - 1):
+            cache_k = torch.randn(
+                1,
+                config["num_kv_heads"],
+                config["head_dim"],
+                dtype=config["dtype"],
+                device=config["device"],
+            )
+            cache_v = torch.randn(
+                1,
+                config["num_kv_heads"],
+                config["head_dim"],
+                dtype=config["dtype"],
+                device=config["device"],
+            )
+            loc = model_runner.req_to_token_pool.req_to_token[b, t].unsqueeze(0).long()
+            model_runner.token_to_kv_pool.set_kv_buffer(layer, loc, cache_k, cache_v)
+
+
+def _create_decode_forward_batch(batch_size, seq_lens, backend, model_runner, config):
+    out_cache_loc = torch.tensor(
+        [
+            model_runner.req_to_token_pool.req_to_token[b, int(seq_lens[b].item()) - 1]
+            for b in range(batch_size)
+        ],
+        dtype=torch.int64,
+        device=config["device"],
+    )
+    fb = ForwardBatch(
+        batch_size=batch_size,
+        input_ids=torch.zeros(batch_size, dtype=torch.int64, device=config["device"]),
+        out_cache_loc=out_cache_loc,
+        seq_lens_sum=int(seq_lens.sum().item()),
+        forward_mode=ForwardMode.DECODE,
+        req_pool_indices=torch.arange(batch_size, device=config["device"]),
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens.cpu(),
+        attn_backend=backend,
+    )
+    fb.req_to_token_pool = model_runner.req_to_token_pool
+    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+    fb.positions = (seq_lens - 1).to(torch.int64)
+    return fb
+
+
+def _create_extend_forward_batch(batch_size, seq_lens, backend, model_runner, config):
+    total_tokens = int(seq_lens.sum().item())
+    out_cache_loc = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[b, : int(seq_lens[b].item())]
+            for b in range(batch_size)
+        ]
+    ).to(torch.int64)
+
+    fb = ForwardBatch(
+        batch_size=batch_size,
+        input_ids=torch.zeros(total_tokens, dtype=torch.int64, device=config["device"]),
+        out_cache_loc=out_cache_loc,
+        seq_lens_sum=int(seq_lens.sum().item()),
+        forward_mode=ForwardMode.EXTEND,
+        req_pool_indices=torch.arange(batch_size, device=config["device"]),
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens.cpu(),
+        extend_num_tokens=total_tokens,
+        extend_seq_lens=seq_lens.clone(),
+        extend_seq_lens_cpu=seq_lens.cpu().tolist(),
+        extend_prefix_lens=torch.zeros(
+            batch_size, dtype=torch.int32, device=config["device"]
+        ),
+        extend_prefix_lens_cpu=[0] * batch_size,
+        attn_backend=backend,
+    )
+    fb.req_to_token_pool = model_runner.req_to_token_pool
+    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+    fb.positions = torch.cat(
+        [torch.arange(s, dtype=torch.int64, device=config["device"]) for s in seq_lens]
+    )
+    return fb
+
+
+def _compare_outputs(test_case, out_a, out_b, rtol, atol, label=""):
+    test_case.assertEqual(out_a.shape, out_b.shape)
+    test_case.assertFalse(torch.isnan(out_a).any(), f"{label} output A has NaN")
+    test_case.assertFalse(torch.isnan(out_b).any(), f"{label} output B has NaN")
+
+    diff = (out_a.float() - out_b.float()).abs()
+    max_diff = diff.max().item()
+    mean_diff = diff.mean().item()
+
+    test_case.assertTrue(
+        torch.allclose(out_a.float(), out_b.float(), rtol=rtol, atol=atol),
+        f"{label} outputs differ: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}",
+    )
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available() or not is_flashinfer_available(),
+    "CUDA + flashinfer required",
+)
+class TestTRTLLMMHA(CustomTestCase):
+    """Test suite for the TRTLLM MHA backend."""
+
+    def _merge_config(self, overrides):
+        config = DEFAULT_CONFIG.copy()
+        config.update(overrides)
+        return config
+
+    def _build_trtllm_backend(self, config, enable_rope_fusion=False):
+        model_runner = MockModelRunner(config, enable_rope_fusion=enable_rope_fusion)
+        backend = TRTLLMHAAttnBackend(model_runner, skip_prefill=True)
+        return backend, model_runner
+
+    def _build_reference_backend(self, config):
+        model_runner = MockModelRunner(config)
+        backend = FlashInferAttnBackend(model_runner)
+        return backend, model_runner
+
+    # ------------------------------------------------------------------ #
+    # Basic: TRTLLM MHA vs FlashInfer reference                           #
+    # ------------------------------------------------------------------ #
+
+    def test_basic_decode_output_match(self):
+        """TRTLLM MHA decode should match FlashInfer decode output."""
+        for tc in TEST_CASES["basic_decode"]:
+            with self.subTest(name=tc["name"]):
+                config = self._merge_config(tc)
+                bs = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+                num_q, num_kv, hdim = (
+                    config["num_attention_heads"],
+                    config["num_kv_heads"],
+                    config["head_dim"],
+                )
+
+                torch.manual_seed(config["seed"])
+                seq_lens = torch.randint(
+                    max_seq_len // 2, max_seq_len + 1, (bs,), device=config["device"]
+                )
+                q = torch.randn(
+                    bs, num_q * hdim, dtype=config["dtype"], device=config["device"]
+                )
+                k = torch.randn(
+                    bs, num_kv * hdim, dtype=config["dtype"], device=config["device"]
+                )
+                v = torch.randn(
+                    bs, num_kv * hdim, dtype=config["dtype"], device=config["device"]
+                )
+
+                def run(build_fn):
+                    backend, model_runner = build_fn(config)
+                    layer = _create_layer(config)
+                    _populate_kv_cache(bs, seq_lens, model_runner, layer, config)
+                    forward_batch = _create_decode_forward_batch(
+                        bs, seq_lens, backend, model_runner, config
+                    )
+                    backend.init_forward_metadata(forward_batch)
+                    return backend.forward_decode(
+                        q.clone(), k.clone(), v.clone(), layer, forward_batch
+                    )
+
+                out_trtllm = run(self._build_trtllm_backend)
+                out_ref = run(self._build_reference_backend)
+                _compare_outputs(
+                    self,
+                    out_trtllm,
+                    out_ref,
+                    rtol=config["rtol"],
+                    atol=config["atol"],
+                    label=f"[basic_decode/{config['name']}]",
+                )
+
+    def test_basic_extend_output_match(self):
+        """TRTLLM MHA extend should match FlashInfer extend output."""
+        for tc in TEST_CASES["basic_extend"]:
+            with self.subTest(name=tc["name"]):
+                config = self._merge_config(tc)
+                seq_lens_list = config["seq_lens_list"]
+                bs = len(seq_lens_list)
+                total_num_tokens = sum(seq_lens_list)
+                num_q, num_kv, hdim = (
+                    config["num_attention_heads"],
+                    config["num_kv_heads"],
+                    config["head_dim"],
+                )
+
+                torch.manual_seed(config["seed"])
+                seq_lens = torch.tensor(
+                    seq_lens_list, dtype=torch.int32, device=config["device"]
+                )
+                q = torch.randn(
+                    total_num_tokens,
+                    num_q * hdim,
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+                k = torch.randn(
+                    total_num_tokens,
+                    num_kv * hdim,
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+                v = torch.randn(
+                    total_num_tokens,
+                    num_kv * hdim,
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+
+                def run(build_fn):
+                    backend, model_runner = build_fn(config)
+                    layer = _create_layer(config)
+                    forward_batch = _create_extend_forward_batch(
+                        bs, seq_lens, backend, model_runner, config
+                    )
+                    backend.init_forward_metadata(forward_batch)
+                    return backend.forward_extend(
+                        q.clone(), k.clone(), v.clone(), layer, forward_batch
+                    )
+
+                out_trtllm = run(self._build_trtllm_backend)
+                out_ref = run(self._build_reference_backend)
+                _compare_outputs(
+                    self,
+                    out_trtllm,
+                    out_ref,
+                    rtol=config["rtol"],
+                    atol=config["atol"],
+                    label=f"[basic_extend/{config['name']}]",
+                )
+
+    # ------------------------------------------------------------------ #
+    # RoPE fusion: fused vs unfused path                                  #
+    # ------------------------------------------------------------------ #
+
+    def test_rope_fusion_decode_output_match(self):
+        """Fused vs unfused decode should produce the same attention output."""
+        for tc in TEST_CASES["rope_fusion_decode"]:
+            with self.subTest(name=tc["name"]):
+                config = self._merge_config(tc)
+                bs = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+                num_q, num_kv, hdim = (
+                    config["num_attention_heads"],
+                    config["num_kv_heads"],
+                    config["head_dim"],
+                )
+
+                torch.manual_seed(config["seed"])
+                seq_lens = torch.randint(
+                    max_seq_len // 2, max_seq_len + 1, (bs,), device=config["device"]
+                )
+                q = torch.randn(
+                    bs, num_q * hdim, dtype=config["dtype"], device=config["device"]
+                )
+                k = torch.randn(
+                    bs, num_kv * hdim, dtype=config["dtype"], device=config["device"]
+                )
+                v = torch.randn(
+                    bs, num_kv * hdim, dtype=config["dtype"], device=config["device"]
+                )
+
+                def run(enable_rope_fusion):
+                    backend, model_runner = self._build_trtllm_backend(
+                        config, enable_rope_fusion=enable_rope_fusion
+                    )
+                    layer = _create_layer(config)
+                    rotary = _create_rotary_emb(config)
+                    _populate_kv_cache(bs, seq_lens, model_runner, layer, config)
+                    forward_batch = _create_decode_forward_batch(
+                        bs, seq_lens, backend, model_runner, config
+                    )
+                    backend.init_forward_metadata(forward_batch)
+                    if enable_rope_fusion:
+                        return backend.forward_decode(
+                            q.clone(),
+                            k.clone(),
+                            v.clone(),
+                            layer,
+                            forward_batch,
+                            cos_sin_cache=rotary.cos_sin_cache,
+                            is_neox_style=rotary.is_neox_style,
+                        )
+                    else:
+                        q_rope, k_rope = rotary(
+                            forward_batch.positions, q.clone(), k.clone()
+                        )
+                        return backend.forward_decode(
+                            q_rope, k_rope, v.clone(), layer, forward_batch
+                        )
+
+                out_fused = run(enable_rope_fusion=True)
+                out_unfused = run(enable_rope_fusion=False)
+                _compare_outputs(
+                    self,
+                    out_fused,
+                    out_unfused,
+                    rtol=config["rtol"],
+                    atol=config["atol"],
+                    label=f"[rope_fusion_decode/{config['name']}]",
+                )
+
+    def test_rope_fusion_extend_output_match(self):
+        """Fused vs unfused extend should produce the same attention output."""
+        for tc in TEST_CASES["rope_fusion_extend"]:
+            with self.subTest(name=tc["name"]):
+                config = self._merge_config(tc)
+                seq_lens_list = config["seq_lens_list"]
+                bs = len(seq_lens_list)
+                total_num_tokens = sum(seq_lens_list)
+                num_q, num_kv, hdim = (
+                    config["num_attention_heads"],
+                    config["num_kv_heads"],
+                    config["head_dim"],
+                )
+
+                torch.manual_seed(config["seed"])
+                seq_lens = torch.tensor(
+                    seq_lens_list, dtype=torch.int32, device=config["device"]
+                )
+                q = torch.randn(
+                    total_num_tokens,
+                    num_q * hdim,
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+                k = torch.randn(
+                    total_num_tokens,
+                    num_kv * hdim,
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+                v = torch.randn(
+                    total_num_tokens,
+                    num_kv * hdim,
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+
+                def run(enable_rope_fusion):
+                    backend, model_runner = self._build_trtllm_backend(
+                        config, enable_rope_fusion=enable_rope_fusion
+                    )
+                    layer = _create_layer(config)
+                    rotary = _create_rotary_emb(config)
+                    forward_batch = _create_extend_forward_batch(
+                        bs, seq_lens, backend, model_runner, config
+                    )
+                    backend.init_forward_metadata(forward_batch)
+                    if enable_rope_fusion:
+                        return backend.forward_extend(
+                            q.clone(),
+                            k.clone(),
+                            v.clone(),
+                            layer,
+                            forward_batch,
+                            cos_sin_cache=rotary.cos_sin_cache,
+                            is_neox_style=rotary.is_neox_style,
+                        )
+                    else:
+                        q_rope, k_rope = rotary(
+                            forward_batch.positions, q.clone(), k.clone()
+                        )
+                        return backend.forward_extend(
+                            q_rope, k_rope, v.clone(), layer, forward_batch
+                        )
+
+                out_fused = run(enable_rope_fusion=True)
+                out_unfused = run(enable_rope_fusion=False)
+                _compare_outputs(
+                    self,
+                    out_fused,
+                    out_unfused,
+                    rtol=config["rtol"],
+                    atol=config["atol"],
+                    label=f"[rope_fusion_extend/{config['name']}]",
+                )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/registered/models/test_gpt_oss_models_rope_fusion.py b/test/registered/models/test_gpt_oss_models_rope_fusion.py
new file mode 100644
index 000000000000..79509e7be0c5
--- /dev/null
+++ b/test/registered/models/test_gpt_oss_models_rope_fusion.py
@@ -0,0 +1,121 @@
+"""
+GPT-OSS RoPE + FP8 quant + KV cache fusion tests.
+
+Tests the fused FlashInfer kernel (rope_quantize_fp8_append_paged_kv_cache) in the
+TRTLLM MHA backend with piecewise CUDA graph and, optionally, EAGLE speculative decoding.
+"""
+
+import os
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.ci.ci_register import register_cuda_ci
+from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+register_cuda_ci(
+    est_time=600,
+    suite="stage-b-test-2-gpu-large",
+)
+
+GPT_OSS_MODEL = "openai/gpt-oss-120b"
+GPT_OSS_EAGLE3_DRAFT_MODEL = "nvidia/gpt-oss-120b-Eagle3"
+
+ACC_THRESHOLDS = {
+    "gsm8k": 0.81,
+}
+
+BASE_ARGS = [
+    "--tp",
+    "2",
+    "--trust-remote-code",
+    "--reasoning-parser",
+    "gpt-oss",
+    "--kv-cache-dtype",
+    "fp8_e4m3",
+]
+
+
+def _run_gsm8k(base_url):
+    args = SimpleNamespace(
+        num_shots=5,
+        data_path=None,
+        num_questions=200,
+        max_new_tokens=512,
+        parallel=128,
+        host="http://127.0.0.1",
+        port=int(base_url.split(":")[-1]),
+    )
+    return run_eval(args)
+
+
+class TestGptOssRopeFusion(CustomTestCase):
+    """Test GPT-OSS accuracy with FlashInfer RoPE fusion enabled."""
+
+    def _launch_and_eval(self, extra_args=None, extra_env=None):
+        env = {"SGLANG_ENABLE_FLASHINFER_ROPE_FUSION": "1"}
+        if extra_env:
+            env.update(extra_env)
+        for k, v in env.items():
+            os.environ[k] = v
+
+        server_args = BASE_ARGS + (extra_args or [])
+        process = popen_launch_server(
+            GPT_OSS_MODEL,
+            DEFAULT_URL_FOR_TEST,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=server_args,
+        )
+        try:
+            metrics = _run_gsm8k(DEFAULT_URL_FOR_TEST)
+            print(f"{metrics=}")
+            self.assertGreaterEqual(
+                metrics["accuracy"],
+                ACC_THRESHOLDS["gsm8k"],
+            )
+        finally:
+            kill_process_tree(process.pid)
+            for k in env:
+                os.environ.pop(k, None)
+
+    def test_rope_fusion_pcg(self):
+        """RoPE fusion with piecewise CUDA graph (default)."""
+        self._launch_and_eval()
+
+    def test_rope_fusion_no_pcg(self):
+        """RoPE fusion without piecewise CUDA graph."""
+        self._launch_and_eval(extra_args=["--disable-piecewise-cuda-graph"])
+
+    def test_rope_fusion_eagle(self):
+        """RoPE fusion with EAGLE3 speculative decoding."""
+        eagle_args = [
+            "--speculative-algorithm",
+            "EAGLE3",
+            "--speculative-draft-model-path",
+            GPT_OSS_EAGLE3_DRAFT_MODEL,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "1",
+            "--speculative-num-draft-tokens",
+            "4",
+            "--cuda-graph-max-bs",
+            "100",
+            "--mem-fraction-static",
+            "0.85",
+        ]
+        eagle_env = {
+            "SGLANG_ENABLE_SPEC_V2": "1",
+            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1",
+        }
+        self._launch_and_eval(extra_args=eagle_args, extra_env=eagle_env)
+
+
+if __name__ == "__main__":
+    unittest.main()