Make Quant-LLM compatible with BF16 #998

Closed · gau-nernst opened this issue on Oct 3, 2024 · 3 comments · Fixed by #1147
Labels: enhancement, good first issue, inference

@gau-nernst (Collaborator)

Quant-LLM code: https://github.com/pytorch/ao/tree/main/torchao/csrc/cuda/fp6_llm

Currently the Quant-LLM kernel (which backs FPx in torchao) only works with FP16. This creates a small divergence from the other quantization methods, which all work with BF16. Since all recent models are trained and released in BF16, adding BF16 support could potentially improve accuracy for FPx models.

Might be over-simplifying, but I think it's just a matter of modifying the dequant logic and the MMA instructions (as well as updating the dtypes in the function signatures appropriately).

template<int EXPONENT, int MANTISSA>
__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
    // Shift that aligns the FPx exponent field with FP16's 5-bit exponent field.
    constexpr int RIGHT_SHIFT = 5 - EXPONENT;
    // MASK ends up covering the EXPONENT+MANTISSA bits sitting just below the
    // sign bit in both 16-bit lanes of the input register.
    constexpr int MASK1 = 0x80000000;
    constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
    constexpr int MASK3 = MASK2 & 0x7fffffff;
    constexpr int MASK = MASK3 | MASK3 >> 16;
    // First two outputs: keep the sign bits, then drop the shifted
    // exponent/mantissa bits into FP16 position.
    *Out1 = *In & 0x80008000;
    *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
    // Move the next two FPx values into the high bytes and repeat.
    *In = (*In) << 8;
    *Out2 = *In & 0x80008000;
    *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
}
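
On the dequant side, here is a rough, untested sketch of what a BF16 variant could look like. The name FPx_BF16_Cast_4Way and the shift constant are my guesses, not necessarily what a final PR would use; the idea is that BF16 has an 8-bit exponent and 7-bit mantissa, and the exponent-bias correction applied later would also need to change from FP16's bias of 15 to BF16's 127.

template<int EXPONENT, int MANTISSA>
__device__ __forceinline__ void FPx_BF16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
    // Sketch only: align the FPx exponent field with BF16's 8-bit exponent field.
    constexpr int RIGHT_SHIFT = 8 - EXPONENT;
    // The masks are unchanged: they select the FPx sign/exponent/mantissa bits
    // of the input, which do not depend on the output format.
    constexpr int MASK1 = 0x80000000;
    constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
    constexpr int MASK3 = MASK2 & 0x7fffffff;
    constexpr int MASK = MASK3 | MASK3 >> 16;
    // First pair: sign bits plus exponent/mantissa shifted into BF16 position.
    *Out1 = *In & 0x80008000;
    *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
    // Second pair: move the next two FPx values into the high bytes and repeat.
    *In = (*In) << 8;
    *Out2 = *In & 0x80008000;
    *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
}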

asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5, %6, %7 },"
"{ %8, %9 },"
"{ %10, %11, %12, %13 };"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
"r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));

cc @msaroufim @HDCharles

I might try to do it myself, but I think it would also make an interesting good first issue. @tobiasvanderwerff Would you be interested?

@gau-nernst added the enhancement, good first issue, and inference labels on Oct 3, 2024
@tobiasvanderwerff (Contributor)

I'd love to work on this @gau-nernst :)

So if you're OK with me doing this, I'll probably get started on it in the next few days.

@DevyRuxpin

Please assign, I may have a solution for this.

@tobiasvanderwerff (Contributor)

Hi @DevyRuxpin, I can understand that you want to work on this, but I have already made most of the necessary changes for this feature (see the diff here). Since your solution is not complete yet (it removes FP16 support), it's probably best to avoid doing further duplicate work.

yanbing-j pushed a commit to yanbing-j/ao that referenced this issue on Dec 9, 2024: "Read max_seq_length from model config instead of hard-coding it"