diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp index c25713052725..b408731f40d1 100644 --- a/csrc/cpu/cpu_types_arm.hpp +++ b/csrc/cpu/cpu_types_arm.hpp @@ -486,6 +486,10 @@ struct FP32Vec16 : public VectorizedRegWrapper { explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; + // FP8 stub: dead code on ARM (fp8 KV cache is x86-only), needed for + // load_b_pair_vec template to compile on all platforms. + explicit FP32Vec16(const BF16Vec32&, int) : Base() {} + explicit FP32Vec16(const FP16Vec16& v) { reg.val[0] = Vectorized(vcvt_f32_f16(vget_low_f16(v.reg.val[0]))); reg.val[1] = Vectorized(vcvt_f32_f16(vget_high_f16(v.reg.val[0]))); diff --git a/csrc/cpu/cpu_types_scalar.hpp b/csrc/cpu/cpu_types_scalar.hpp index f9da78283da5..d1c2fc85933a 100644 --- a/csrc/cpu/cpu_types_scalar.hpp +++ b/csrc/cpu/cpu_types_scalar.hpp @@ -6,6 +6,9 @@ namespace vec_op { +struct fp8_e4m3_tag {}; +struct fp8_e5m2_tag {}; + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ @@ -145,6 +148,9 @@ struct BF16Vec32 : public Vec { } void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } + + explicit BF16Vec32(const uint8_t*, fp8_e4m3_tag) : reg{} {} + explicit BF16Vec32(const uint8_t*, fp8_e5m2_tag) : reg{} {} }; struct FP32Vec4 : public Vec { @@ -302,6 +308,10 @@ struct FP32Vec16 : public Vec { FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; + // FP8 stub: dead code on scalar path (fp8 KV cache is x86-only), needed for + // load_b_pair_vec template to compile on all platforms. + explicit FP32Vec16(const BF16Vec32&, int) : reg{} {} + FP32Vec16 operator*(const FP32Vec16& b) const { f32x16_t ret; unroll_loop( diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index fbd0767df896..87c7a9dd51f4 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -146,6 +146,9 @@ struct BF16Vec32 : public Vec { : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {} void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } + + explicit BF16Vec32(const uint8_t*, fp8_e4m3_tag) : reg{} {} + explicit BF16Vec32(const uint8_t*, fp8_e5m2_tag) : reg{} {} }; struct FP32Vec4 : public Vec { @@ -408,6 +411,10 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + // FP8 stub: dead code on PowerPC (fp8 KV cache is x86-only), needed for + // load_b_pair_vec template to compile on all platforms. + explicit FP32Vec16(const BF16Vec32&, int) : reg{} {} + explicit FP32Vec16(const INT32Vec16& v) { reg.val[0] = vec_ctf(v.reg.val[0], 0); reg.val[1] = vec_ctf(v.reg.val[1], 0); diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp index 90a2dd918bd9..2e0af466b649 100644 --- a/csrc/cpu/cpu_types_vxe.hpp +++ b/csrc/cpu/cpu_types_vxe.hpp @@ -688,6 +688,10 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + // FP8 stub: dead code on s390x (fp8 KV cache is x86-only), needed for + // load_b_pair_vec template to compile on all platforms. + explicit FP32Vec16(const BF16Vec32&, int) : reg{} {} + FP32Vec16 operator*(const FP32Vec16& b) const { return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1]),