From 9a65c0ed885649be624889f7a7b58a8304585ff4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 20 Oct 2025 02:22:40 -0400 Subject: [PATCH 1/6] upd --- tests/utils/test_green_ctx.py | 55 +++++++++++++++++++++++++++++++++ tests/utils/test_jit_example.py | 6 +++- tests/utils/test_sampling.py | 2 +- 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_green_ctx.py b/tests/utils/test_green_ctx.py index 4863dd5c51..1701fd8646 100644 --- a/tests/utils/test_green_ctx.py +++ b/tests/utils/test_green_ctx.py @@ -2,6 +2,26 @@ import torch import flashinfer.green_ctx as green_ctx +from flashinfer.utils import get_compute_capability, get_device_sm_count, round_up + + +def calculate_required_sms(num_groups: int, min_count: int, device: str) -> int: + """Calculate total SM count required for the test.""" + dev = torch.device(device) + min_sm, alignment = green_ctx.get_sm_count_constraint(*get_compute_capability(dev)) + rounded_min = round_up(max(min_count, min_sm), alignment) + return num_groups * rounded_min + + +def calculate_required_sms_by_counts(sm_counts: list, device: str) -> int: + """Calculate total SM count required for the test with specific SM counts.""" + dev = torch.device(device) + min_sm, alignment = green_ctx.get_sm_count_constraint(*get_compute_capability(dev)) + total = 0 + for sm_count in sm_counts: + rounded = round_up(max(sm_count, min_sm), alignment) + total += rounded + return total @pytest.mark.parametrize("device", ["cuda:0"]) @@ -12,6 +32,13 @@ def test_green_ctx_creation( num_groups: int, min_count: int, ): + required_sms = calculate_required_sms(num_groups, min_count, device) + available_sms = get_device_sm_count(torch.device(device)) + if required_sms > available_sms: + pytest.skip( + f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + ) + streams, resources = green_ctx.split_device_green_ctx( torch.device(device), num_groups, min_count ) @@ -30,6 +57,13 @@ def test_green_ctx_kernel_execution( num_groups: int, min_count: int, ): + required_sms = calculate_required_sms(num_groups, min_count, device) + available_sms = get_device_sm_count(torch.device(device)) + if required_sms > available_sms: + pytest.skip( + f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + ) + streams, resources = green_ctx.split_device_green_ctx( torch.device(device), num_groups, min_count ) @@ -59,6 +93,13 @@ def test_split_device_green_ctx_by_sm_count_creation( device: str, sm_counts: list, ): + required_sms = calculate_required_sms_by_counts(sm_counts, device) + available_sms = get_device_sm_count(torch.device(device)) + if required_sms > available_sms: + pytest.skip( + f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + ) + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( torch.device(device), sm_counts ) @@ -85,6 +126,13 @@ def test_split_device_green_ctx_by_sm_count_kernel_execution( device: str, sm_counts: list, ): + required_sms = calculate_required_sms_by_counts(sm_counts, device) + available_sms = get_device_sm_count(torch.device(device)) + if required_sms > available_sms: + pytest.skip( + f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + ) + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( torch.device(device), sm_counts ) @@ -113,6 +161,13 @@ def test_split_device_green_ctx_by_sm_count_alignment( device: str, sm_counts: list, ): + required_sms = calculate_required_sms_by_counts(sm_counts, device) + available_sms = get_device_sm_count(torch.device(device)) + if required_sms > available_sms: + pytest.skip( + f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + ) + _, resources = green_ctx.split_device_green_ctx_by_sm_count( torch.device(device), sm_counts ) diff --git a/tests/utils/test_jit_example.py b/tests/utils/test_jit_example.py index fb169f1a7f..959f303914 100644 --- a/tests/utils/test_jit_example.py +++ b/tests/utils/test_jit_example.py @@ -11,7 +11,7 @@ gen_customize_single_prefill_module, ) from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module -from flashinfer.utils import MaskMode, is_sm90a_supported +from flashinfer.utils import MaskMode, is_sm90a_supported, get_compute_capability def test_single_decode_mask(): @@ -166,6 +166,10 @@ def test_flash_sigmoid(): torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2) +@pytest.mark.xfail( + get_compute_capability(torch.device("cuda:0")) == (12, 1), + reason="Numerical accuracy issue on SM 121 (Spark)", +) def test_dump_logits(): torch.manual_seed(42) variant_decl = r""" diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index 20df72b55d..89cc997a11 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -72,7 +72,7 @@ def test_softmax( probs_ref = torch.softmax(logits_scaled, dim=-1) - assert torch.allclose(probs, probs_ref, atol=1e-5) + assert torch.allclose(probs, probs_ref, rtol=1e-5, atol=1e-5) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) From db585e5b412e8a4747e22b6aaab2ff9a3c551a8e Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 23 Oct 2025 14:19:53 -0700 Subject: [PATCH 2/6] Update tests/utils/test_green_ctx.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/utils/test_green_ctx.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_green_ctx.py b/tests/utils/test_green_ctx.py index 1701fd8646..f85098df93 100644 --- a/tests/utils/test_green_ctx.py +++ b/tests/utils/test_green_ctx.py @@ -32,15 +32,16 @@ def test_green_ctx_creation( num_groups: int, min_count: int, ): + dev = torch.device(device) required_sms = calculate_required_sms(num_groups, min_count, device) - available_sms = get_device_sm_count(torch.device(device)) + available_sms = get_device_sm_count(dev) if required_sms > available_sms: pytest.skip( f"Test requires {required_sms} SMs but device only has {available_sms} SMs" ) streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count + dev, num_groups, min_count ) assert len(resources) == num_groups + 1 From 89eac51f00ec559e884e7f56eea006d238144945 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 24 Oct 2025 07:10:25 -0400 Subject: [PATCH 3/6] upd --- flashinfer/green_ctx.py | 28 ++++++++-- tests/utils/test_green_ctx.py | 100 ++++++++++++---------------------- 2 files changed, 59 insertions(+), 69 deletions(-) diff --git a/flashinfer/green_ctx.py b/flashinfer/green_ctx.py index 0555cec212..ae48bbc740 100644 --- a/flashinfer/green_ctx.py +++ b/flashinfer/green_ctx.py @@ -28,7 +28,7 @@ ) from e from .cuda_utils import checkCudaErrors -from .utils import get_compute_capability, round_up +from .utils import get_compute_capability, get_device_sm_count, round_up def get_sm_count_constraint(major: int, minor: int) -> Tuple[int, int]: @@ -170,6 +170,18 @@ def split_device_green_ctx( RuntimeError: when requested SM allocation exceeds device capacity: ``num_groups * rounded_min_count > total_device_sms`` """ + # Check if device has enough SMs + min_sm, alignment = get_sm_count_constraint(*get_compute_capability(dev)) + rounded_min = round_up(max(min_count, min_sm), alignment) + required_sms = num_groups * rounded_min + available_sms = get_device_sm_count(dev) + + if required_sms > available_sms: + raise RuntimeError( + f"Insufficient SMs: requested {num_groups} groups with {rounded_min} SMs each " + f"(total: {required_sms} SMs), but device only has {available_sms} SMs available" + ) + cu_dev = get_cudevice(dev) resource = get_device_resource(cu_dev) results, remaining = split_resource(resource, num_groups, min_count) @@ -246,14 +258,22 @@ def split_device_green_ctx_by_sm_count( # Round sm counts to meet the alignment and granularity requirements rounded_sm_counts = [] + min_sm_count, sm_alignment = get_sm_count_constraint(*get_compute_capability(dev)) for sm_count in sm_counts: - min_sm_count, sm_alignment = get_sm_count_constraint( - *get_compute_capability(dev) - ) if sm_count <= 0: raise ValueError(f"SM count must be positive, got {sm_count}") rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment)) + # Check if device has enough SMs + required_sms = sum(rounded_sm_counts) + available_sms = get_device_sm_count(dev) + + if required_sms > available_sms: + raise RuntimeError( + f"Insufficient SMs: requested {rounded_sm_counts} SMs " + f"(total: {required_sms} SMs), but device only has {available_sms} SMs available" + ) + # Split the device into multiple green contexts results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts) resources = results + [remaining] diff --git a/tests/utils/test_green_ctx.py b/tests/utils/test_green_ctx.py index f85098df93..16ce162577 100644 --- a/tests/utils/test_green_ctx.py +++ b/tests/utils/test_green_ctx.py @@ -2,26 +2,6 @@ import torch import flashinfer.green_ctx as green_ctx -from flashinfer.utils import get_compute_capability, get_device_sm_count, round_up - - -def calculate_required_sms(num_groups: int, min_count: int, device: str) -> int: - """Calculate total SM count required for the test.""" - dev = torch.device(device) - min_sm, alignment = green_ctx.get_sm_count_constraint(*get_compute_capability(dev)) - rounded_min = round_up(max(min_count, min_sm), alignment) - return num_groups * rounded_min - - -def calculate_required_sms_by_counts(sm_counts: list, device: str) -> int: - """Calculate total SM count required for the test with specific SM counts.""" - dev = torch.device(device) - min_sm, alignment = green_ctx.get_sm_count_constraint(*get_compute_capability(dev)) - total = 0 - for sm_count in sm_counts: - rounded = round_up(max(sm_count, min_sm), alignment) - total += rounded - return total @pytest.mark.parametrize("device", ["cuda:0"]) @@ -33,16 +13,14 @@ def test_green_ctx_creation( min_count: int, ): dev = torch.device(device) - required_sms = calculate_required_sms(num_groups, min_count, device) - available_sms = get_device_sm_count(dev) - if required_sms > available_sms: - pytest.skip( - f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + try: + streams, resources = green_ctx.split_device_green_ctx( + dev, num_groups, min_count ) - - streams, resources = green_ctx.split_device_green_ctx( - dev, num_groups, min_count - ) + except RuntimeError as e: + if "Insufficient SMs" in str(e): + pytest.skip(str(e)) + raise assert len(resources) == num_groups + 1 for resource in resources[:-1]: @@ -58,16 +36,14 @@ def test_green_ctx_kernel_execution( num_groups: int, min_count: int, ): - required_sms = calculate_required_sms(num_groups, min_count, device) - available_sms = get_device_sm_count(torch.device(device)) - if required_sms > available_sms: - pytest.skip( - f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + try: + streams, resources = green_ctx.split_device_green_ctx( + torch.device(device), num_groups, min_count ) - - streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count - ) + except RuntimeError as e: + if "Insufficient SMs" in str(e): + pytest.skip(str(e)) + raise num_partitions = num_groups + 1 assert len(streams) == num_partitions assert len(resources) == num_partitions @@ -94,16 +70,14 @@ def test_split_device_green_ctx_by_sm_count_creation( device: str, sm_counts: list, ): - required_sms = calculate_required_sms_by_counts(sm_counts, device) - available_sms = get_device_sm_count(torch.device(device)) - if required_sms > available_sms: - pytest.skip( - f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + try: + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts ) - - streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) + except RuntimeError as e: + if "Insufficient SMs" in str(e): + pytest.skip(str(e)) + raise num_partitions = len(sm_counts) + 1 assert len(resources) == num_partitions assert len(streams) == num_partitions @@ -127,16 +101,14 @@ def test_split_device_green_ctx_by_sm_count_kernel_execution( device: str, sm_counts: list, ): - required_sms = calculate_required_sms_by_counts(sm_counts, device) - available_sms = get_device_sm_count(torch.device(device)) - if required_sms > available_sms: - pytest.skip( - f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + try: + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts ) - - streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) + except RuntimeError as e: + if "Insufficient SMs" in str(e): + pytest.skip(str(e)) + raise num_partitions = len(sm_counts) + 1 assert len(streams) == num_partitions assert len(resources) == num_partitions @@ -162,16 +134,14 @@ def test_split_device_green_ctx_by_sm_count_alignment( device: str, sm_counts: list, ): - required_sms = calculate_required_sms_by_counts(sm_counts, device) - available_sms = get_device_sm_count(torch.device(device)) - if required_sms > available_sms: - pytest.skip( - f"Test requires {required_sms} SMs but device only has {available_sms} SMs" + try: + _, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts ) - - _, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) + except RuntimeError as e: + if "Insufficient SMs" in str(e): + pytest.skip(str(e)) + raise for resource in resources[:-1]: # Exclude remaining SMs sm_count = resource.sm.smCount From 9424eef3bdc95f1799173fb7ca2153f4f65a826b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 Oct 2025 06:21:33 +0000 Subject: [PATCH 4/6] upd --- flashinfer/green_ctx.py | 92 +++++++++++------------ tests/utils/test_green_ctx.py | 134 ++++++++++++++++++++-------------- 2 files changed, 126 insertions(+), 100 deletions(-) diff --git a/flashinfer/green_ctx.py b/flashinfer/green_ctx.py index ae48bbc740..fa69ba4c15 100644 --- a/flashinfer/green_ctx.py +++ b/flashinfer/green_ctx.py @@ -28,7 +28,7 @@ ) from e from .cuda_utils import checkCudaErrors -from .utils import get_compute_capability, get_device_sm_count, round_up +from .utils import get_compute_capability, round_up def get_sm_count_constraint(major: int, minor: int) -> Tuple[int, int]: @@ -170,24 +170,23 @@ def split_device_green_ctx( RuntimeError: when requested SM allocation exceeds device capacity: ``num_groups * rounded_min_count > total_device_sms`` """ - # Check if device has enough SMs - min_sm, alignment = get_sm_count_constraint(*get_compute_capability(dev)) - rounded_min = round_up(max(min_count, min_sm), alignment) - required_sms = num_groups * rounded_min - available_sms = get_device_sm_count(dev) - - if required_sms > available_sms: - raise RuntimeError( - f"Insufficient SMs: requested {num_groups} groups with {rounded_min} SMs each " - f"(total: {required_sms} SMs), but device only has {available_sms} SMs available" - ) - - cu_dev = get_cudevice(dev) - resource = get_device_resource(cu_dev) - results, remaining = split_resource(resource, num_groups, min_count) - resources = results + [remaining] - streams = create_green_ctx_streams(cu_dev, resources) - return streams, resources + try: + cu_dev = get_cudevice(dev) + resource = get_device_resource(cu_dev) + results, remaining = split_resource(resource, num_groups, min_count) + resources = results + [remaining] + streams = create_green_ctx_streams(cu_dev, resources) + return streams, resources + except RuntimeError as e: + if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ + "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + raise RuntimeError( + f"{e}\n" + f"Failed to split device into {num_groups} groups with min_count={min_count}. " + f"This is likely due to insufficient number of SMs available on the device. " + f"Please reduce the number of groups or the minimum SM count per group." + ) from e + raise def split_device_green_ctx_by_sm_count( @@ -253,29 +252,32 @@ def split_device_green_ctx_by_sm_count( See `CUDA Green Contexts `_ for more details. """ - cu_dev = get_cudevice(dev) - resource = get_device_resource(cu_dev) - - # Round sm counts to meet the alignment and granularity requirements - rounded_sm_counts = [] - min_sm_count, sm_alignment = get_sm_count_constraint(*get_compute_capability(dev)) - for sm_count in sm_counts: - if sm_count <= 0: - raise ValueError(f"SM count must be positive, got {sm_count}") - rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment)) - - # Check if device has enough SMs - required_sms = sum(rounded_sm_counts) - available_sms = get_device_sm_count(dev) - - if required_sms > available_sms: - raise RuntimeError( - f"Insufficient SMs: requested {rounded_sm_counts} SMs " - f"(total: {required_sms} SMs), but device only has {available_sms} SMs available" - ) - - # Split the device into multiple green contexts - results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts) - resources = results + [remaining] - streams = create_green_ctx_streams(cu_dev, resources) - return streams, resources + try: + cu_dev = get_cudevice(dev) + resource = get_device_resource(cu_dev) + + # Round sm counts to meet the alignment and granularity requirements + rounded_sm_counts = [] + for sm_count in sm_counts: + min_sm_count, sm_alignment = get_sm_count_constraint( + *get_compute_capability(dev) + ) + if sm_count <= 0: + raise ValueError(f"SM count must be positive, got {sm_count}") + rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment)) + + # Split the device into multiple green contexts + results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts) + resources = results + [remaining] + streams = create_green_ctx_streams(cu_dev, resources) + return streams, resources + except RuntimeError as e: + if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ + "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + raise RuntimeError( + f"{e}\n" + f"Failed to split device with SM counts {sm_counts} (rounded to {rounded_sm_counts}). " + f"This is likely due to insufficient number of SMs available on the device. " + f"Please reduce the requested SM counts or use fewer partitions." + ) from e + raise diff --git a/tests/utils/test_green_ctx.py b/tests/utils/test_green_ctx.py index 16ce162577..81b72eb252 100644 --- a/tests/utils/test_green_ctx.py +++ b/tests/utils/test_green_ctx.py @@ -12,21 +12,25 @@ def test_green_ctx_creation( num_groups: int, min_count: int, ): - dev = torch.device(device) try: streams, resources = green_ctx.split_device_green_ctx( - dev, num_groups, min_count + torch.device(device), num_groups, min_count ) + + assert len(resources) == num_groups + 1 + for resource in resources[:-1]: + sm_count = resource.sm.smCount + assert sm_count >= min_count except RuntimeError as e: - if "Insufficient SMs" in str(e): - pytest.skip(str(e)) + if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ + "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}") raise - assert len(resources) == num_groups + 1 - for resource in resources[:-1]: - sm_count = resource.sm.smCount - assert sm_count >= min_count - @pytest.mark.parametrize("device", ["cuda:0"]) @pytest.mark.parametrize("num_groups", [1, 2, 3]) @@ -40,20 +44,25 @@ def test_green_ctx_kernel_execution( streams, resources = green_ctx.split_device_green_ctx( torch.device(device), num_groups, min_count ) + num_partitions = num_groups + 1 + assert len(streams) == num_partitions + assert len(resources) == num_partitions + + for stream in streams: + with torch.cuda.stream(stream): + x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) + y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) + z = x @ y + print(z.shape) except RuntimeError as e: - if "Insufficient SMs" in str(e): - pytest.skip(str(e)) + if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ + "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}") raise - num_partitions = num_groups + 1 - assert len(streams) == num_partitions - assert len(resources) == num_partitions - - for stream in streams: - with torch.cuda.stream(stream): - x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) - y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) - z = x @ y - print(z.shape) @pytest.mark.parametrize("device", ["cuda:0"]) @@ -74,18 +83,23 @@ def test_split_device_green_ctx_by_sm_count_creation( streams, resources = green_ctx.split_device_green_ctx_by_sm_count( torch.device(device), sm_counts ) + num_partitions = len(sm_counts) + 1 + assert len(resources) == num_partitions + assert len(streams) == num_partitions + + # Check that each partition has the expected SM count + for i, expected_sm_count in enumerate(sm_counts): + actual_sm_count = resources[i].sm.smCount + assert actual_sm_count >= expected_sm_count except RuntimeError as e: - if "Insufficient SMs" in str(e): - pytest.skip(str(e)) + if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ + "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}") raise - num_partitions = len(sm_counts) + 1 - assert len(resources) == num_partitions - assert len(streams) == num_partitions - - # Check that each partition has the expected SM count - for i, expected_sm_count in enumerate(sm_counts): - actual_sm_count = resources[i].sm.smCount - assert actual_sm_count >= expected_sm_count @pytest.mark.parametrize("device", ["cuda:0"]) @@ -105,20 +119,25 @@ def test_split_device_green_ctx_by_sm_count_kernel_execution( streams, resources = green_ctx.split_device_green_ctx_by_sm_count( torch.device(device), sm_counts ) + num_partitions = len(sm_counts) + 1 + assert len(streams) == num_partitions + assert len(resources) == num_partitions + + for i, stream in enumerate(streams): + with torch.cuda.stream(stream): + x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + z = x @ y + print(f"Partition {i}: {z.shape}") except RuntimeError as e: - if "Insufficient SMs" in str(e): - pytest.skip(str(e)) + if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ + "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}") raise - num_partitions = len(sm_counts) + 1 - assert len(streams) == num_partitions - assert len(resources) == num_partitions - - for i, stream in enumerate(streams): - with torch.cuda.stream(stream): - x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) - y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) - z = x @ y - print(f"Partition {i}: {z.shape}") @pytest.mark.parametrize("device", ["cuda:0"]) @@ -138,17 +157,22 @@ def test_split_device_green_ctx_by_sm_count_alignment( _, resources = green_ctx.split_device_green_ctx_by_sm_count( torch.device(device), sm_counts ) - except RuntimeError as e: - if "Insufficient SMs" in str(e): - pytest.skip(str(e)) - raise - for resource in resources[:-1]: # Exclude remaining SMs - sm_count = resource.sm.smCount - assert sm_count > 0 + for resource in resources[:-1]: # Exclude remaining SMs + sm_count = resource.sm.smCount + assert sm_count > 0 - min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint( - *green_ctx.get_compute_capability(torch.device(device)) - ) - assert sm_count >= min_sm_count - assert sm_count % sm_alignment == 0 + min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint( + *green_ctx.get_compute_capability(torch.device(device)) + ) + assert sm_count >= min_sm_count + assert sm_count % sm_alignment == 0 + except RuntimeError as e: + if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ + "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}") + raise From a6ec87fd9fc15f632fe23638e98580ecce37c8e0 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Tue, 28 Oct 2025 17:41:58 -0700 Subject: [PATCH 5/6] pre-commit --- flashinfer/green_ctx.py | 24 ++++++++++---- tests/utils/test_green_ctx.py | 60 ++++++++++++++++++++++++++--------- 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/flashinfer/green_ctx.py b/flashinfer/green_ctx.py index fa69ba4c15..09962fd467 100644 --- a/flashinfer/green_ctx.py +++ b/flashinfer/green_ctx.py @@ -178,8 +178,12 @@ def split_device_green_ctx( streams = create_green_ctx_streams(cu_dev, resources) return streams, resources except RuntimeError as e: - if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ - "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): raise RuntimeError( f"{e}\n" f"Failed to split device into {num_groups} groups with min_count={min_count}. " @@ -264,16 +268,24 @@ def split_device_green_ctx_by_sm_count( ) if sm_count <= 0: raise ValueError(f"SM count must be positive, got {sm_count}") - rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment)) + rounded_sm_counts.append( + round_up(max(sm_count, min_sm_count), sm_alignment) + ) # Split the device into multiple green contexts - results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts) + results, remaining = split_resource_by_sm_count( + cu_dev, resource, rounded_sm_counts + ) resources = results + [remaining] streams = create_green_ctx_streams(cu_dev, resources) return streams, resources except RuntimeError as e: - if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ - "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): raise RuntimeError( f"{e}\n" f"Failed to split device with SM counts {sm_counts} (rounded to {rounded_sm_counts}). " diff --git a/tests/utils/test_green_ctx.py b/tests/utils/test_green_ctx.py index 81b72eb252..99d6dc97bc 100644 --- a/tests/utils/test_green_ctx.py +++ b/tests/utils/test_green_ctx.py @@ -22,13 +22,19 @@ def test_green_ctx_creation( sm_count = resource.sm.smCount assert sm_count >= min_count except RuntimeError as e: - if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ - "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): # Get total SM count on the device cu_dev = green_ctx.get_cudevice(torch.device(device)) device_resource = green_ctx.get_device_resource(cu_dev) total_sms = device_resource.sm.smCount - pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}") + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}" + ) raise @@ -55,13 +61,19 @@ def test_green_ctx_kernel_execution( z = x @ y print(z.shape) except RuntimeError as e: - if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ - "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): # Get total SM count on the device cu_dev = green_ctx.get_cudevice(torch.device(device)) device_resource = green_ctx.get_device_resource(cu_dev) total_sms = device_resource.sm.smCount - pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}") + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}" + ) raise @@ -92,13 +104,19 @@ def test_split_device_green_ctx_by_sm_count_creation( actual_sm_count = resources[i].sm.smCount assert actual_sm_count >= expected_sm_count except RuntimeError as e: - if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ - "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): # Get total SM count on the device cu_dev = green_ctx.get_cudevice(torch.device(device)) device_resource = green_ctx.get_device_resource(cu_dev) total_sms = device_resource.sm.smCount - pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}") + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) raise @@ -130,13 +148,19 @@ def test_split_device_green_ctx_by_sm_count_kernel_execution( z = x @ y print(f"Partition {i}: {z.shape}") except RuntimeError as e: - if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ - "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): # Get total SM count on the device cu_dev = green_ctx.get_cudevice(torch.device(device)) device_resource = green_ctx.get_device_resource(cu_dev) total_sms = device_resource.sm.smCount - pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}") + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) raise @@ -168,11 +192,17 @@ def test_split_device_green_ctx_by_sm_count_alignment( assert sm_count >= min_sm_count assert sm_count % sm_alignment == 0 except RuntimeError as e: - if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \ - "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e): + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): # Get total SM count on the device cu_dev = green_ctx.get_cudevice(torch.device(device)) device_resource = green_ctx.get_device_resource(cu_dev) total_sms = device_resource.sm.smCount - pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}") + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) raise From 60ffe137ac793e450c67bb6308205298432ce821 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 4 Nov 2025 20:08:56 -0500 Subject: [PATCH 6/6] revert changes to test_sampling.py --- tests/utils/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index 89cc997a11..20df72b55d 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -72,7 +72,7 @@ def test_softmax( probs_ref = torch.softmax(logits_scaled, dim=-1) - assert torch.allclose(probs, probs_ref, rtol=1e-5, atol=1e-5) + assert torch.allclose(probs, probs_ref, atol=1e-5) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256])