diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md index f3a5db67fc5..87213c1883c 100644 --- a/flash_attn/flash_attn_triton_amd/README.md +++ b/flash_attn/flash_attn_triton_amd/README.md @@ -81,10 +81,10 @@ docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE ``` ###### FP8 -In our fork We have created the following api functions that use fp8 to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. To use these functions just call them with like the other api functions, the casting will be handled internally. For example +In our fork We have created the following api functions that use fp8 internally to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. Here is a usage example ``` -from flash_attn import flash_attn_qkvpacked_fp8_func +from flash_attn.flash_attn_triton_amd.fp8 import flash_attn_qkvpacked_fp8_func # forward pass out, lse, S_dmask = flash_attn_qkvpacked_fp8_func(