Skip to content

Fixes the incorrect argument in the prefix-prefill test cases#3246

Merged
simon-mo merged 1 commit intovllm-project:mainfrom
sighingnow:ht/fixes-prefix-prefill-tests
Mar 16, 2024
Merged

Fixes the incorrect argument in the prefix-prefill test cases#3246
simon-mo merged 1 commit intovllm-project:mainfrom
sighingnow:ht/fixes-prefix-prefill-tests

Conversation

@sighingnow
Copy link
Copy Markdown
Collaborator

See also comment in #3007

Signed-off-by: Tao He <sighingnow@gmail.com>
Copy link
Copy Markdown
Member

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the fix!

Comment on lines +39 to +43
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused why do we need this? Can you give a more detailed example?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There would be an error if we run the test case in environments with 2 GPU card, the test case test_contexted_kv_attention[cuda:0-dtype0-128-64-64] passed, but when run
test_contexted_kv_attention[cuda:1-dtype0-128-64-64] (note now it uses cuda:1), it failed and complains:

        bin = self.cache[device][key]
        if not warmup:
>           bin.c_wrapper(
                grid_0,
                grid_1,
                grid_2,
                bin.num_warps,
                bin.num_ctas,
                bin.clusterDims[0],
                bin.clusterDims[1],
                bin.clusterDims[2],
                bin.shared,
                stream,
                bin.cu_function,
                CompiledKernel.launch_enter_hook,
                CompiledKernel.launch_exit_hook,
                bin,
                *bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
            )
E           ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:550: ValueError

I'm not very clear about the root causes, but I found the same issue report in flash-attention and the fix from here: Dao-AILab/flash-attention#523 (comment), and confirmed it works.

@sighingnow
Copy link
Copy Markdown
Collaborator Author

Hi @zhuohan123 any further comments on this patch?

Thanks!

@sighingnow
Copy link
Copy Markdown
Collaborator Author

Hi @zhuohan123 @simon-mo, could you please take another look at this PR?

Thanks!

@simon-mo simon-mo merged commit 3123f15 into vllm-project:main Mar 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants