[Do not merge] Hacks for the ROCm port #1314

Closed
pcmoritz wants to merge 17 commits into vllm-project:main from pcmoritz:port-to-rocm-hacks
@pcmoritz
Collaborator

@pcmoritz pcmoritz commented Oct 10, 2023

This PR gets the kernel tests working on top of #1313.

The following tests currently pass:

  • kernels/test_activation.py
  • kernels/test_cache.py
  • kernels/test_layernorm.py
  • kernels/test_pos_encoding.py

The following test currently fails:

  • kernels/test_attention.py, with this output:
_________________________________________________________________________________ test_single_query_cached_kv_attention[0-dtype0-8-False-64-num_heads0-7] __________________________________________________________________________________

kv_cache_factory = <function create_kv_caches at 0x7f359bcb6700>, num_seqs = 7, num_heads = (40, 40), head_size = 64, use_alibi = False, block_size = 8, dtype = torch.float16, seed = 0

    @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
    @pytest.mark.parametrize("num_heads", NUM_HEADS)
    @pytest.mark.parametrize("head_size", HEAD_SIZES)
    @pytest.mark.parametrize("use_alibi", USE_ALIBI)
    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("seed", SEEDS)
    @torch.inference_mode()
    def test_single_query_cached_kv_attention(
        kv_cache_factory,
        num_seqs: int,
        num_heads: Tuple[int, int],
        head_size: int,
        use_alibi: bool,
        block_size: int,
        dtype: torch.dtype,
        seed: int,
    ) -> None:
        random.seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    
        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_heads = num_heads
        query = torch.empty(num_seqs,
                            num_query_heads,
                            head_size,
                            dtype=dtype,
                            device="cuda")
        query.uniform_(-scale, scale)
    
        assert num_query_heads % num_kv_heads == 0
        num_queries_per_kv = num_query_heads // num_kv_heads
        head_mapping = torch.repeat_interleave(
            torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
            num_queries_per_kv)
        alibi_slopes = None
        if use_alibi:
            alibi_slopes = torch.randn(num_query_heads,
                                       dtype=torch.float,
                                       device="cuda")
    
        context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
        context_lens[-1] = MAX_SEQ_LEN
        max_context_len = max(context_lens)
        context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
    
        # Create the block tables.
        max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
        block_tables = []
        for _ in range(num_seqs):
            block_table = [
                random.randint(0, NUM_BLOCKS - 1)
                for _ in range(max_num_blocks_per_seq)
            ]
            block_tables.append(block_table)
        block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
    
        # Create the KV caches.
        key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                    num_kv_heads, head_size, dtype,
                                                    seed)
        key_cache, value_cache = key_caches[0], value_caches[0]
    
        # Call the paged attention kernel.
        output = torch.empty_like(query)
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    
        # Run the reference implementation.
        ref_output = torch.empty_like(query)
        ref_single_query_cached_kv_attention(
            ref_output,
            query,
            num_queries_per_kv,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            scale,
            alibi_slopes,
        )
    
        # NOTE(woosuk): Due to the kernel-level differences in the two
        # implementations, there is a small numerical difference in the two
        # outputs. Thus, we use a relaxed tolerance for the test.
>       assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7f37adc8c980>(tensor([[[inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         ...,\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf]],\n\n        [[inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         ...,\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf]],\n\n        [[inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         ...,\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf]],\n\n        ...,\n\n        [[inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         ...,\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf]],\n\n        [[inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         ...,\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf]],\n\n        [[inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         ...,\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf],\n         [inf, inf, inf,  ..., inf, inf, inf]]], device='cuda:0',\n       dtype=torch.float16), 
tensor([[[-2.9540e-04, -9.1505e-04, -5.2643e-03,  ..., -4.1389e-03,\n           3.7746e-03, -1.1806e-03],\n         [-2.3403e-03, -6.4545e-03,  1.9197e-03,  ..., -1.6556e-03,\n          -5.9319e-03, -4.0741e-03],\n         [-5.5981e-04,  2.6398e-03, -1.4648e-03,  ..., -7.0047e-04,\n           4.5547e-03, -1.5097e-03],\n         ...,\n         [-3.6411e-03, -2.8057e-03,  1.7796e-03,  ...,  1.0653e-03,\n          -1.0586e-03, -1.0653e-03],\n         [ 3.5915e-03,  2.6798e-03,  1.4706e-03,  ..., -2.3212e-03,\n           5.5008e-03,  3.4657e-03],\n         [-9.9659e-04,  8.4543e-04,  5.0163e-03,  ...,  1.5497e-03,\n          -4.0169e-03,  1.8406e-03]],\n\n        [[-1.2970e-04, -3.5572e-04, -2.5806e-03,  ..., -5.5542e-03,\n           3.1090e-03, -1.2016e-03],\n         [-2.1477e-03, -5.9204e-03,  1.9283e-03,  ..., -2.2907e-03,\n          -4.3106e-03, -1.1225e-03],\n         [-1.8883e-03,  1.8559e-03, -8.9836e-04,  ..., -1.3151e-03,\n           4.8218e-03,  3.4571e-04],\n         ...,\n         [-1.7996e-03, -5.6410e-04,  6.3360e-05,  ..., -1.0090e-03,\n          -4.7636e-04, -1.4601e-03],\n         [-1.9312e-04,  1.5726e-03, -6.9189e-04,  ..., -1.0176e-03,\n           4.1847e-03,  3.1605e-03],\n         [-2.1610e-03, -1.2455e-03,  5.3596e-03,  ...,  5.1069e-04,\n          -2.6073e-03, -1.6487e-04]],\n\n        [[-2.5043e-03, -4.9896e-03, -2.0828e-03,  ..., -5.8708e-03,\n           3.5019e-03, -3.8357e-03],\n         [ 2.5215e-03, -7.4043e-03,  1.6594e-04,  ..., -1.8339e-03,\n          -1.7347e-03, -2.3880e-03],\n         [ 4.2419e-03, -1.0729e-03, -4.2648e-03,  ..., -5.4512e-03,\n           1.0338e-02, -4.6959e-03],\n         ...,\n         [-6.6471e-04, -2.9659e-03, -5.2452e-04,  ...,  7.2908e-04,\n           5.1613e-03,  1.3485e-03],\n         [-1.4277e-03, -3.2883e-03,  3.8509e-03,  ..., -1.8845e-03,\n           5.3673e-03,  2.2583e-03],\n         [-2.8706e-03, -5.9938e-04,  8.8654e-03,  ...,  3.7251e-03,\n          -7.4425e-03, -5.9700e-03]],\n\n        ...,\n\n    
    [[ 7.9060e-04, -6.0892e-04, -4.1847e-03,  ..., -4.3831e-03,\n           3.0994e-03, -3.0003e-03],\n         [-2.2011e-03, -4.2000e-03,  6.2370e-04,  ..., -1.1024e-03,\n          -4.9896e-03, -2.7027e-03],\n         [-1.0033e-03,  3.2692e-03, -2.0065e-03,  ...,  1.0080e-03,\n           3.9978e-03, -1.6823e-03],\n         ...,\n         [-8.6689e-04, -1.4143e-03, -7.7438e-04,  ..., -2.0850e-04,\n           3.3438e-05, -4.0321e-03],\n         [ 1.0900e-03,  2.3079e-03,  1.1129e-03,  ...,  3.2353e-04,\n           3.8033e-03,  1.0681e-03],\n         [-1.4133e-03,  3.9697e-04,  4.8561e-03,  ...,  3.7289e-04,\n          -4.6120e-03,  1.6661e-03]],\n\n        [[-4.1604e-04, -1.8415e-03, -6.8817e-03,  ..., -4.1161e-03,\n           4.0970e-03, -4.7302e-03],\n         [-1.5364e-03, -6.6071e-03,  9.3699e-04,  ..., -2.3117e-03,\n          -5.2147e-03, -3.1834e-03],\n         [-1.6203e-03,  2.2907e-03, -2.0943e-03,  ...,  4.9496e-04,\n           3.9597e-03, -7.8249e-04],\n         ...,\n         [-1.4238e-03, -9.6035e-04,  2.8515e-04,  ..., -1.2674e-03,\n          -1.0672e-03, -1.9522e-03],\n         [ 1.9321e-03,  1.5774e-03,  8.1062e-06,  ..., -1.0405e-03,\n           5.6992e-03,  2.0828e-03],\n         [-2.5749e-03,  2.3186e-04,  5.7220e-03,  ...,  2.5730e-03,\n          -3.7003e-03,  2.2049e-03]],\n\n        [[-7.6294e-05,  4.0293e-05, -3.2806e-03,  ..., -4.9210e-03,\n           2.8820e-03, -2.0123e-03],\n         [-1.8377e-03, -6.0577e-03,  3.4370e-03,  ..., -1.4210e-03,\n          -5.4855e-03, -2.2049e-03],\n         [-1.9875e-03,  3.8471e-03, -2.9125e-03,  ..., -5.3453e-04,\n           6.4545e-03,  2.1338e-04],\n         ...,\n         [-1.2665e-03, -6.2466e-04,  1.8396e-03,  ..., -3.7932e-04,\n           6.3181e-04, -2.3403e-03],\n         [ 1.8797e-03,  1.2455e-03,  2.7514e-04,  ..., -9.2411e-04,\n           2.8858e-03,  2.9793e-03],\n         [-4.8923e-04,  5.1618e-05,  4.1428e-03,  ...,  1.1559e-03,\n          -2.8362e-03,  3.9363e-04]]], device='cuda:0', 
dtype=torch.float16), atol=0.001, rtol=1e-05)
E        +    where <built-in method allclose of type object at 0x7f37adc8c980> = torch.allclose
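
For reference, the final assertion applies the element-wise tolerance rule documented for torch.allclose, |output - ref| <= atol + rtol * |ref|, so an inf in the kernel output can never match a finite reference value. Below is a minimal pure-Python sketch of that rule on scalars. It is not the PyTorch implementation. The helper names is_close and summarize are hypothetical, and the sample values are chosen for illustration only (two inf entries as in the kernel output above, plus one made-up finite entry).

```python
import math


def is_close(a: float, b: float, atol: float = 1e-3, rtol: float = 1e-5) -> bool:
    """Scalar sketch of the tolerance rule: |a - b| <= atol + rtol * |b|."""
    if a == b:  # exact matches, including two equal infinities
        return True
    if math.isnan(a) or math.isnan(b):
        return False
    return abs(a - b) <= atol + rtol * abs(b)


def summarize(output, ref_output, atol: float = 1e-3, rtol: float = 1e-5) -> dict:
    """Count non-finite outputs and elements that fall outside the tolerance."""
    non_finite = sum(1 for x in output if not math.isfinite(x))
    mismatched = sum(
        1 for x, y in zip(output, ref_output) if not is_close(x, y, atol, rtol)
    )
    return {"total": len(output), "non_finite": non_finite, "mismatched": mismatched}


# Illustrative values only: two inf entries and one finite entry within tolerance.
output = [math.inf, math.inf, 2.0e-3]
ref_output = [-2.9540e-04, -9.1505e-04, 1.9197e-03]
print(summarize(output, ref_output))
```

Reporting the count of non-finite elements alongside the mismatch count makes it clear that this failure cannot be resolved by loosening atol or rtol.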

@iAmir97
Contributor

iAmir97 commented Oct 18, 2023

@pcmoritz I'd like to submit a pull request that replaces the HIP API calls used for Float16 type conversions with asm volatile instructions.

I've also modified setup.py to check whether ROCMHOME is set; if it is, -DUSE_ROCM is added to the flags.

I'd appreciate it if you could review the pull request and let me know of any issues or concerns. Thanks!
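
The setup.py check described above could look roughly like the following. This is a hedged sketch, not the code from the proposed pull request: the function name extra_compile_flags is hypothetical, and the environment variable name is taken as written in the comment (ROCMHOME), which is an assumption about the exact spelling used in the build.

```python
import os
from typing import List, Mapping, Optional


def extra_compile_flags(
    env: Optional[Mapping[str, str]] = None,
    var_name: str = "ROCMHOME",  # assumption: spelling as written in the comment
) -> List[str]:
    """Return extra compiler flags, adding -DUSE_ROCM when the variable is set."""
    env = os.environ if env is None else env
    flags: List[str] = []
    if env.get(var_name):  # an unset or empty variable counts as "not available"
        flags.append("-DUSE_ROCM")
    return flags


# Example with an explicit mapping instead of the real environment.
print(extra_compile_flags({"ROCMHOME": "/opt/rocm"}))
print(extra_compile_flags({}))
```

Passing the environment as a mapping keeps the check easy to test without modifying the real process environment.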

@WoosukKwon WoosukKwon added the rocm (Related to AMD ROCm) label on Dec 1, 2023
@WoosukKwon
Collaborator

WoosukKwon commented Dec 10, 2023

Closed because we merged #1836, which is a superset of this PR. @pcmoritz Thanks for the amazing work!

@WoosukKwon WoosukKwon closed this Dec 10, 2023
WeNeedMoreCode pushed a commit to WeNeedMoreCode/vllm that referenced this pull request Dec 15, 2025
Add codespell check test for doc only PR

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>