Skip to content

Commit

Permalink
Fix int4 conversion (openvinotoolkit#4983)
Browse files Browse the repository at this point in the history
* Fix int4 conversion

* Simplify logic
  • Loading branch information
ilyachur authored Mar 29, 2021
1 parent 007078d commit de50ecf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ struct EnumClassHash {
* int4 [1011] -> int8 [00001011]
* * For signed types we copy all bits (except sign bit) to destination type and after
* that for negative values we set to 1 all higher bits:
* int4 [1011] -> int8 [11110011]
* int4 [1011] -> int8 [11111011]
*
* @param src source value !!! the type must be unsigned !!!
* @param dst destination value !!! the type must be unsigned !!!
Expand All @@ -427,42 +427,29 @@ void convert_lp_value(const SRC& src, DST& dst, size_t src_offset, size_t src_si
// dst [10001111 00000100] offset 5 size 9
// new_val [00000000 00000000]
DST new_val = 0;
// If source type is signed
if (is_signed) {
// Get the sign of value
// sign [00000000]
// invert value in order to use XOR
SRC sign = (~(val >> (src_size - 1))) & 0b1;
// Calculate diff in order to clean bits which don't exist in the source value
// diff 5
size_t diff = sizeof(SRC)*8 - src_size + 1;
// Clean unnecessary bits
// val [10100000]
val = val << diff;
// val [00000101]
val = (val >> diff);

// Negative number
if (!sign) {
// val [11110101]
val |= (src_max << (diff - 1));
// new_val [00000001 11111111]
new_val = (sign << (dst_size - 1)) ^ (dst_max >> (sizeof(DST) * 8 - dst_size));
// new_val [00000001 11110101]
new_val &= (dst_max << sizeof(SRC)*8) | val;
} else {
// new_val [00000000 00000101]
new_val = val;
}

// Calculate diff in order to clean bits which don't exist in the source value
// diff 4
size_t diff = sizeof(SRC)*8 - src_size;
// Clean unnecessary bits
// val [11010000]
val = val << diff;
// val [00001101]
val = val >> diff;

// Get the sign of value
// sign [00000001]
SRC sign = (val >> (src_size - 1)) & 0b1;

// If source type is signed and negative
if (is_signed && sign) {
// val [11111101]
val |= src_max << diff;
// new_val [00000001 11111111]
new_val = dst_max >> (sizeof(DST) * 8 - dst_size);
// new_val [00000001 11111101]
new_val &= (dst_max << sizeof(SRC)*8) | val;
} else {
// Calculate diff in order to clean bits which don't exist in the source value
// diff 4
size_t diff = sizeof(SRC)*8 - src_size;
// Clean unnecessary bits
// val [11010000]
val = val << diff;
// val [00001101]
val = val >> diff;
// new_val [00000000 00001101]
new_val = val;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,11 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI8) {
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU8_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u8, 171, 242);
constant_convert_test(element::Type_t::i4, element::Type_t::u8, 171, 250);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI8_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i8, 171, -14);
constant_convert_test(element::Type_t::i4, element::Type_t::i8, 171, -6);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI32) {
Expand All @@ -680,11 +680,11 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI32) {
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU32_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u32, 171, -14);
constant_convert_test(element::Type_t::i4, element::Type_t::u32, 171, -6);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI32_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i32, 171, -14);
constant_convert_test(element::Type_t::i4, element::Type_t::i32, 171, -6);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI16) {
Expand All @@ -704,11 +704,11 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI16) {
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU16_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u16, 171, 65522);
constant_convert_test(element::Type_t::i4, element::Type_t::u16, 171, 65530);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI16_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i16, 171, -14);
constant_convert_test(element::Type_t::i4, element::Type_t::i16, 171, -6);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI64) {
Expand All @@ -728,11 +728,11 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI64) {
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU64_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u64, 171, -14);
constant_convert_test(element::Type_t::i4, element::Type_t::u64, 171, -6);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI64_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i64, 171, -14);
constant_convert_test(element::Type_t::i4, element::Type_t::i64, 171, -6);
}

TEST(TransformationTests, ConvertPrecision_ConstantConversion_U1ToU8) {
Expand Down

0 comments on commit de50ecf

Please sign in to comment.