diff --git a/flashinfer/green_ctx.py b/flashinfer/green_ctx.py index 0555cec212..09962fd467 100644 --- a/flashinfer/green_ctx.py +++ b/flashinfer/green_ctx.py @@ -170,12 +170,27 @@ def split_device_green_ctx( RuntimeError: when requested SM allocation exceeds device capacity: ``num_groups * rounded_min_count > total_device_sms`` """ - cu_dev = get_cudevice(dev) - resource = get_device_resource(cu_dev) - results, remaining = split_resource(resource, num_groups, min_count) - resources = results + [remaining] - streams = create_green_ctx_streams(cu_dev, resources) - return streams, resources + try: + cu_dev = get_cudevice(dev) + resource = get_device_resource(cu_dev) + results, remaining = split_resource(resource, num_groups, min_count) + resources = results + [remaining] + streams = create_green_ctx_streams(cu_dev, resources) + return streams, resources + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + raise RuntimeError( + f"{e}\n" + f"Failed to split device into {num_groups} groups with min_count={min_count}. " + f"This is likely due to insufficient number of SMs available on the device. " + f"Please reduce the number of groups or the minimum SM count per group." + ) from e + raise def split_device_green_ctx_by_sm_count( @@ -241,21 +256,40 @@ def split_device_green_ctx_by_sm_count( See `CUDA Green Contexts `_ for more details. """ - cu_dev = get_cudevice(dev) - resource = get_device_resource(cu_dev) + try: + cu_dev = get_cudevice(dev) + resource = get_device_resource(cu_dev) + + # Round sm counts to meet the alignment and granularity requirements + rounded_sm_counts = [] + for sm_count in sm_counts: + min_sm_count, sm_alignment = get_sm_count_constraint( + *get_compute_capability(dev) + ) + if sm_count <= 0: + raise ValueError(f"SM count must be positive, got {sm_count}") + rounded_sm_counts.append( + round_up(max(sm_count, min_sm_count), sm_alignment) + ) - # Round sm counts to meet the alignment and granularity requirements - rounded_sm_counts = [] - for sm_count in sm_counts: - min_sm_count, sm_alignment = get_sm_count_constraint( - *get_compute_capability(dev) + # Split the device into multiple green contexts + results, remaining = split_resource_by_sm_count( + cu_dev, resource, rounded_sm_counts ) - if sm_count <= 0: - raise ValueError(f"SM count must be positive, got {sm_count}") - rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment)) - - # Split the device into multiple green contexts - results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts) - resources = results + [remaining] - streams = create_green_ctx_streams(cu_dev, resources) - return streams, resources + resources = results + [remaining] + streams = create_green_ctx_streams(cu_dev, resources) + return streams, resources + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + raise RuntimeError( + f"{e}\n" + f"Failed to split device with SM counts {sm_counts} (rounded to {rounded_sm_counts}). " + f"This is likely due to insufficient number of SMs available on the device. " + f"Please reduce the requested SM counts or use fewer partitions." + ) from e + raise diff --git a/tests/utils/test_green_ctx.py b/tests/utils/test_green_ctx.py index 4863dd5c51..99d6dc97bc 100644 --- a/tests/utils/test_green_ctx.py +++ b/tests/utils/test_green_ctx.py @@ -12,14 +12,30 @@ def test_green_ctx_creation( num_groups: int, min_count: int, ): - streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count - ) + try: + streams, resources = green_ctx.split_device_green_ctx( + torch.device(device), num_groups, min_count + ) - assert len(resources) == num_groups + 1 - for resource in resources[:-1]: - sm_count = resource.sm.smCount - assert sm_count >= min_count + assert len(resources) == num_groups + 1 + for resource in resources[:-1]: + sm_count = resource.sm.smCount + assert sm_count >= min_count + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -30,19 +46,35 @@ def test_green_ctx_kernel_execution( num_groups: int, min_count: int, ): - streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count - ) - num_partitions = num_groups + 1 - assert len(streams) == num_partitions - assert len(resources) == num_partitions - - for stream in streams: - with torch.cuda.stream(stream): - x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) - y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) - z = x @ y - print(z.shape) + try: + streams, resources = green_ctx.split_device_green_ctx( + torch.device(device), num_groups, min_count + ) + num_partitions = num_groups + 1 + assert len(streams) == num_partitions + assert len(resources) == num_partitions + + for stream in streams: + with torch.cuda.stream(stream): + x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) + y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) + z = x @ y + print(z.shape) + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -59,17 +91,33 @@ def test_split_device_green_ctx_by_sm_count_creation( device: str, sm_counts: list, ): - streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) - num_partitions = len(sm_counts) + 1 - assert len(resources) == num_partitions - assert len(streams) == num_partitions - - # Check that each partition has the expected SM count - for i, expected_sm_count in enumerate(sm_counts): - actual_sm_count = resources[i].sm.smCount - assert actual_sm_count >= expected_sm_count + try: + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts + ) + num_partitions = len(sm_counts) + 1 + assert len(resources) == num_partitions + assert len(streams) == num_partitions + + # Check that each partition has the expected SM count + for i, expected_sm_count in enumerate(sm_counts): + actual_sm_count = resources[i].sm.smCount + assert actual_sm_count >= expected_sm_count + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -85,19 +133,35 @@ def test_split_device_green_ctx_by_sm_count_kernel_execution( device: str, sm_counts: list, ): - streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) - num_partitions = len(sm_counts) + 1 - assert len(streams) == num_partitions - assert len(resources) == num_partitions - - for i, stream in enumerate(streams): - with torch.cuda.stream(stream): - x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) - y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) - z = x @ y - print(f"Partition {i}: {z.shape}") + try: + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts + ) + num_partitions = len(sm_counts) + 1 + assert len(streams) == num_partitions + assert len(resources) == num_partitions + + for i, stream in enumerate(streams): + with torch.cuda.stream(stream): + x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + z = x @ y + print(f"Partition {i}: {z.shape}") + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -113,16 +177,32 @@ def test_split_device_green_ctx_by_sm_count_alignment( device: str, sm_counts: list, ): - _, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) - - for resource in resources[:-1]: # Exclude remaining SMs - sm_count = resource.sm.smCount - assert sm_count > 0 - - min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint( - *green_ctx.get_compute_capability(torch.device(device)) + try: + _, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts ) - assert sm_count >= min_sm_count - assert sm_count % sm_alignment == 0 + + for resource in resources[:-1]: # Exclude remaining SMs + sm_count = resource.sm.smCount + assert sm_count > 0 + + min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint( + *green_ctx.get_compute_capability(torch.device(device)) + ) + assert sm_count >= min_sm_count + assert sm_count % sm_alignment == 0 + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) + raise diff --git a/tests/utils/test_jit_example.py b/tests/utils/test_jit_example.py index fb169f1a7f..959f303914 100644 --- a/tests/utils/test_jit_example.py +++ b/tests/utils/test_jit_example.py @@ -11,7 +11,7 @@ gen_customize_single_prefill_module, ) from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module -from flashinfer.utils import MaskMode, is_sm90a_supported +from flashinfer.utils import MaskMode, is_sm90a_supported, get_compute_capability def test_single_decode_mask(): @@ -166,6 +166,10 @@ def test_flash_sigmoid(): torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2) +@pytest.mark.xfail( + get_compute_capability(torch.device("cuda:0")) == (12, 1), + reason="Numerical accuracy issue on SM 121 (Spark)", +) def test_dump_logits(): torch.manual_seed(42) variant_decl = r"""