diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index 63f512a4948..ebd72acaab6 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -53,15 +53,15 @@ struct raw_data : raw_data_base friend Stream& operator<<(Stream& os, const Derived& d) { if(not d.empty()) - d.visit([&](auto x) { os << x; }, - [&](auto&& xs) { - for(auto&& x : xs) - { - os << "{ "; - os << x; - os << " }, "; - } - }); + d.fallback_visit([&](auto x) { os << x; }, + [&](auto&& xs) { + for(auto&& x : xs) + { + os << "{ "; + os << x; + os << " }, "; + } + }); return os; } @@ -123,9 +123,13 @@ struct raw_data : raw_data_base } else { - auto&& buffer = static_cast(*this).data(); + auto* buffer = static_cast(*this).data(); shape view_shape = {shape::uint8_type, {s.bytes()}}; - v(make_view(view_shape, reinterpret_cast(buffer))); + using byte_type = + std::conditional_t>{}, + const byte*, + byte*>; + v(make_view(view_shape, reinterpret_cast(buffer))); } } diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index 50e065520a9..8b44064fa2f 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -40,7 +40,7 @@ static bool skip_propagate(instruction_ref ins) { if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name())) return skip_propagate(ins->inputs().front()); - if(ins->name() == "unpack_int4") + if(contains({"unpack_int4", "unpack_fp4"}, ins->name())) return true; auto&& s = ins->get_shape(); if(s.broadcasted() and s.element_space() < s.elements()) diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index f4569fe7a36..4169964d9d0 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -49,32 +49,36 @@ std::unordered_set get_quantizable_op_names() return s; } -// Helper function to insert quantized versions of any broadcasts and transpose ops that -// occur between dequantizelinear and the quantized op -auto propagate_quantized_ins(module& m, - const instruction_ref dqins, - const instruction_ref qop_arg, - bool is_fp16_model = false) +std::vector get_between_ins(const instruction_ref dqins, + const instruction_ref qop_arg) { auto prev_ins = qop_arg; std::vector ins_between; - // matcher skips continguous, multi/broadcasts and transposes, collect all those - // instructions while(prev_ins != dqins) { ins_between.push_back(prev_ins); prev_ins = prev_ins->inputs().front(); } - auto qinp = dqins->inputs().front(); + return ins_between; +} + +// Helper function to insert quantized versions of any broadcasts and transpose ops that +// occur between dequantizelinear and the quantized op +auto propagate_quantized_ins(module& m, + const instruction_ref dqins, + instruction_ref input_ins, + std::vector ins_between, + bool is_fp16_model = false) +{ for(auto ins : reverse_iterator_for(ins_between)) { if((*ins)->name() == "convert" and is_fp16_model) { continue; } - qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp}); + input_ins = m.insert_instruction(dqins, (*ins)->get_operator(), {input_ins}); } - return qinp; + return input_ins; } struct match_find_quantizable_ops @@ -140,8 +144,13 @@ struct match_find_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); + + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); auto arg1_lens = qop_args[0]->get_shape().lens(); auto arg2_lens = qop_args[1]->get_shape().lens(); @@ -280,15 +289,13 @@ struct match_find_quantizable_ops } }; -// Note: scales are not constant b/c of dynamic quantization. // Checks for block quantized scales by checking scales are not scalar or 1D. -inline auto dynamic_block_dq(const std::string& scale) +inline auto block_dq(const std::string& scale) { // clang-format off return match::name("dequantizelinear")( match::nargs(2), match::arg(1)(match::skip_broadcasts(match::none_of( - match::is_constant(), match::scalar_shape, match::ndim(1) ).bind(scale)))); @@ -305,8 +312,8 @@ struct match_find_mx_quantizable_ops { auto matcher() const { - auto dq1 = match::arg(0)(skip_post_dq_ops(dynamic_block_dq("scale1").bind("dq1"))); - auto dq2 = match::arg(1)(skip_post_dq_ops(dynamic_block_dq("scale2").bind("dq2"))); + auto dq1 = match::arg(0)(skip_post_dq_ops(block_dq("scale1").bind("dq1"))); + auto dq2 = match::arg(1)(skip_post_dq_ops(block_dq("scale2").bind("dq2"))); return match::name("dot")(dq1, dq2); } @@ -328,10 +335,16 @@ struct match_find_mx_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); - qop_args.push_back(scale1); - qop_args.push_back(scale2); + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); + qop_args.push_back( + propagate_quantized_ins(m, dq1, scale1, qop_between_arg0, is_fp16_model)); + qop_args.push_back( + propagate_quantized_ins(m, dq2, scale2, qop_between_arg1, is_fp16_model)); if(qop->name() == "convolution") { diff --git a/test/propagate_constant_test.cpp b/test/propagate_constant_test.cpp index da7511e3482..72011fd8d83 100644 --- a/test/propagate_constant_test.cpp +++ b/test/propagate_constant_test.cpp @@ -535,4 +535,31 @@ TEST_CASE(block_dequantize_int4) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(pack_unpack_fp4) +{ + migraphx::shape s1{migraphx::shape::float_type, {4}}; + migraphx::shape s2{migraphx::shape::fp4x2_type, {2}}; + migraphx::module m1; + { + const std::vector vec = {1.f, 0.f, 2.f, 0.f}; + auto l = m1.add_literal(migraphx::literal(s1, vec)); + auto pack = m1.add_instruction(migraphx::make_op("pack_fp4"), l); + auto unpack = m1.add_instruction(migraphx::make_op("unpack_fp4"), pack); + m1.add_return({unpack}); + } + + run_pass(m1); + + migraphx::module m2; + { + using migraphx::shape; + const std::vector vec = {0x2, 0x4}; + auto l = m2.add_literal(migraphx::literal(s2, vec.data())); + auto unpack = m2.add_instruction(migraphx::make_op("unpack_fp4"), l); + m2.add_return({unpack}); + } + + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index bfc3d61386b..8d548224762 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -1815,6 +1815,104 @@ TEST_CASE(fp4x2_quant_dot_even) EXPECT(m1 == m2); } +TEST_CASE(fp4x2_quant_dot_trans_b) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 8, 12}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 8, 24}}; + + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + auto packed_b = m1.add_parameter("weights", shape_packed_b); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_parameter("scale_b", shape_scales_b); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto trans_b = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dq_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, trans_b); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_parameter("weights", shape_packed_b); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_parameter("scale_b", shape_scales_b); + + auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto trans_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unpack_b); + auto trans_scale_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), scale_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, trans_b, scale_a, trans_scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(fp4x2_quant_dot_const_b) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 24, 4}}; + migraphx::shape shape_packed_b_gen{migraphx::shape::uint8_type, {1, 3, 24, 4}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 24, 8}}; + unsigned long seed = 826; + migraphx::literal b_lit = generate_literal(shape_packed_b_gen, seed); + migraphx::literal scale_b_lit = generate_literal(shape_scales_b, seed); + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + // avoiding visit fp4x2_type + auto packed_b = m1.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_literal(scale_b_lit); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, dq_b); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_literal(scale_b_lit); + + auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, unpack_b, scale_a, scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + // Test that unused qdq with pack_fp4, unpack_fp4 are removed TEST_CASE(fp4x2_even_remove_qdq) {