Skip to content
4 changes: 4 additions & 0 deletions csrc/cpu/cpu_types_arm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,10 @@ struct FP32Vec16 : public VectorizedRegWrapper<FP32Vec16, 4, float> {

// Convert a BF16Vec8 to FP32Vec16 by delegating through the FP32Vec8
// conversion constructor. (Removed the redundant semicolon that followed
// the body; it was an empty member declaration.)
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

// FP8 stub: dead code on ARM (fp8 KV cache is x86-only), needed for the
// load_b_pair_vec template to compile on all platforms. Both arguments are
// ignored; the unnamed int exists only to disambiguate this overload.
explicit FP32Vec16(const BF16Vec32&, int) : Base() {}

explicit FP32Vec16(const FP16Vec16& v) {
reg.val[0] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v.reg.val[0])));
reg.val[1] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v.reg.val[0])));
Expand Down
10 changes: 10 additions & 0 deletions csrc/cpu/cpu_types_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

namespace vec_op {

struct fp8_e4m3_tag {};
struct fp8_e5m2_tag {};

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
Expand Down Expand Up @@ -145,6 +148,9 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
}

// Store the full register contents to ptr via a type-punned assignment.
// NOTE(review): assumes ptr is suitably aligned for f16x32_t — confirm at
// call sites.
void save(void* ptr) const { *reinterpret_cast<f16x32_t*>(ptr) = reg; }

// FP8 stubs: dead code on the scalar path (fp8 KV cache is x86-only),
// needed so fp8 template code compiles on all platforms. The tag types
// only select the overload; arguments are ignored and reg is zero-init.
explicit BF16Vec32(const uint8_t*, fp8_e4m3_tag) : reg{} {}
explicit BF16Vec32(const uint8_t*, fp8_e5m2_tag) : reg{} {}
};

struct FP32Vec4 : public Vec<FP32Vec4> {
Expand Down Expand Up @@ -302,6 +308,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

// Convert a BF16Vec8 to FP32Vec16 by delegating through the FP32Vec8
// conversion constructor. (Removed the redundant trailing semicolon.)
// NOTE(review): unlike the ARM/VSX/VXE backends, this overload is not
// marked `explicit`; confirm whether any caller relies on the implicit
// conversion before tightening it to match the other platforms.
FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

// FP8 stub: dead code on the scalar path (fp8 KV cache is x86-only),
// needed for the load_b_pair_vec template to compile on all platforms.
// Both arguments are ignored; reg is zero-initialized.
explicit FP32Vec16(const BF16Vec32&, int) : reg{} {}

FP32Vec16 operator*(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
Expand Down
7 changes: 7 additions & 0 deletions csrc/cpu/cpu_types_vsx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}

// Store all four 16-bit vector halves to ptr via a type-punned assignment.
// NOTE(review): assumes ptr is suitably aligned for ss16x8x4_t — confirm
// at call sites.
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }

// FP8 stubs: dead code on PowerPC (fp8 KV cache is x86-only), needed so
// fp8 template code compiles on all platforms. The tag types only select
// the overload; arguments are ignored and reg is zero-initialized.
explicit BF16Vec32(const uint8_t*, fp8_e4m3_tag) : reg{} {}
explicit BF16Vec32(const uint8_t*, fp8_e5m2_tag) : reg{} {}
};

struct FP32Vec4 : public Vec<FP32Vec4> {
Expand Down Expand Up @@ -408,6 +411,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

// Convert a BF16Vec8 to FP32Vec16 by delegating through the FP32Vec8
// conversion constructor.
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

// FP8 stub: dead code on PowerPC (fp8 KV cache is x86-only), needed for
// the load_b_pair_vec template to compile on all platforms. Both
// arguments are ignored; reg is zero-initialized.
explicit FP32Vec16(const BF16Vec32&, int) : reg{} {}

explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vec_ctf(v.reg.val[0], 0);
reg.val[1] = vec_ctf(v.reg.val[1], 0);
Expand Down
4 changes: 4 additions & 0 deletions csrc/cpu/cpu_types_vxe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

// Convert a BF16Vec8 to FP32Vec16 by delegating through the FP32Vec8
// conversion constructor.
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

// FP8 stub: dead code on s390x (fp8 KV cache is x86-only), needed for
// the load_b_pair_vec template to compile on all platforms. Both
// arguments are ignored; reg is zero-initialized.
explicit FP32Vec16(const BF16Vec32&, int) : reg{} {}

FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
Expand Down
Loading