From 607f6ab6bab24b3d08fbcaf13fe9097f623be30f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 14 Apr 2025 15:19:28 +0800 Subject: [PATCH 1/4] more --- python/sglang/srt/layers/dp_attention.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index bf60641195b..c6eeb3edbe4 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -191,10 +191,11 @@ def _dp_gather( assert global_tokens.is_contiguous() if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0): - assert ( - global_tokens.untyped_storage().data_ptr() - != local_tokens.untyped_storage().data_ptr() - ), "aliasing between global_tokens and local_tokens not allowed" + if not torch.compiler.is_compiling(): + assert ( + global_tokens.untyped_storage().data_ptr() + != local_tokens.untyped_storage().data_ptr() + ), "aliasing between global_tokens and local_tokens not allowed" memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False ) @@ -242,10 +243,11 @@ def dp_scatter( assert local_tokens.is_contiguous() assert global_tokens.is_contiguous() if local_tokens.shape[0] > 0: - assert ( - local_tokens.untyped_storage().data_ptr() - != global_tokens.untyped_storage().data_ptr() - ), "aliasing between local_tokens and global_tokens not allowed" + if not torch.compiler.is_compiling(): + assert ( + local_tokens.untyped_storage().data_ptr() + != global_tokens.untyped_storage().data_ptr() + ), "aliasing between local_tokens and global_tokens not allowed" memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True ) From d50c6cbf4e8cdf6e51477b99f25ddcafab7f2ccb Mon Sep 17 00:00:00 2001 From: ispobock Date: Mon, 14 Apr 2025 07:29:40 +0000 Subject: [PATCH 2/4] fix and add ut --- python/sglang/srt/layers/dp_attention.py | 6 ++---- test/srt/test_dp_attention.py | 3 +++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index c6eeb3edbe4..059ebc371c4 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -193,8 +193,7 @@ def _dp_gather( if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0): if not torch.compiler.is_compiling(): assert ( - global_tokens.untyped_storage().data_ptr() - != local_tokens.untyped_storage().data_ptr() + local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between global_tokens and local_tokens not allowed" memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False @@ -245,8 +244,7 @@ def dp_scatter( if local_tokens.shape[0] > 0: if not torch.compiler.is_compiling(): assert ( - local_tokens.untyped_storage().data_ptr() - != global_tokens.untyped_storage().data_ptr() + local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between local_tokens and global_tokens not allowed" memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py index 5fb5e223b72..b47fe2c460e 100644 --- a/test/srt/test_dp_attention.py +++ b/test/srt/test_dp_attention.py @@ -28,6 +28,9 @@ def setUpClass(cls): "--enable-dp-attention", "--dp", "2", + "--enable-torch-compile", + "--torch-compile-max-bs", + "2", ], ) From d5f50c31f94fec64ce21e2308baaef0b0fb4774b Mon Sep 17 00:00:00 2001 From: ispobock Date: Mon, 14 Apr 2025 07:30:48 +0000 Subject: [PATCH 3/4] fix --- python/sglang/srt/layers/dp_attention.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 059ebc371c4..c1b9e05ecd7 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -191,10 +191,9 @@ def _dp_gather( assert global_tokens.is_contiguous() if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0): - if not torch.compiler.is_compiling(): - assert ( - local_tokens.untyped_storage() is not global_tokens.untyped_storage() - ), "aliasing between global_tokens and local_tokens not allowed" + assert ( + local_tokens.untyped_storage() is not global_tokens.untyped_storage() + ), "aliasing between global_tokens and local_tokens not allowed" memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False ) @@ -242,10 +241,9 @@ def dp_scatter( assert local_tokens.is_contiguous() assert global_tokens.is_contiguous() if local_tokens.shape[0] > 0: - if not torch.compiler.is_compiling(): - assert ( - local_tokens.untyped_storage() is not global_tokens.untyped_storage() - ), "aliasing between local_tokens and global_tokens not allowed" + assert ( + local_tokens.untyped_storage() is not global_tokens.untyped_storage() + ), "aliasing between local_tokens and global_tokens not allowed" memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True ) From 7282abedf4521f11fae436b519d9a956a3821d80 Mon Sep 17 00:00:00 2001 From: ispobock Date: Mon, 14 Apr 2025 07:33:54 +0000 Subject: [PATCH 4/4] lint --- test/srt/parse_results.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/srt/parse_results.py b/test/srt/parse_results.py index 8389a4b9c2e..de1d5cf2740 100644 --- a/test/srt/parse_results.py +++ b/test/srt/parse_results.py @@ -1,7 +1,8 @@ -import json -import pandas as pd import argparse +import json import os + +import pandas as pd from tabulate import tabulate # Parse command-line arguments