update to latest Cadence decompression code.
ddavis-2015 committed Oct 22, 2024
1 parent b43c16c commit 7dc34a9
Showing 1 changed file with 250 additions and 45 deletions.
295 changes: 250 additions & 45 deletions tensorflow/lite/micro/kernels/xtensa/decompress.cc
@@ -43,16 +43,15 @@ struct DecompressionStateXtensa : DecompressionState {
: DecompressionState(other) {}

void DecompressToBufferWidth4_Xtensa(int8_t* buffer);
void DecompressToBufferWidth4_Xtensa_Old(int8_t* buffer);
void DecompressToBufferWidth3_Xtensa(int8_t* buffer);
void DecompressToBufferWidth2_Xtensa(int8_t* buffer);

void DecompressToBufferWidthAnyInt8_Xtensa(int8_t* buffer);
void DecompressToBufferWidthAnyInt16_Xtensa(int16_t* buffer);
void DecompressToBufferWidthAnyInt32_Xtensa(int32_t* buffer);
void DecompressToBufferWidthAnyInt64_Xtensa(int64_t* buffer);
};

// TODO(ddavis-2015): unaligned/stride code has an error; method not
// currently used.
void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);

@@ -76,6 +75,8 @@ void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {

const uint8_t* __restrict value_table_t = value_table;

ae_valignx2 align_store = AE_ZALIGN128();

for (size_t i = 0; i < num_channels_; i++) {
value_table_t = value_table;
ae_valignx2 align_vtab = AE_LA128_PP(value_table_t);
@@ -84,7 +85,6 @@ void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t,
d_shuffle_value_t);

ae_valignx2 align_store = AE_ZALIGN128();
ae_valign align_load = AE_LA64_PP(pIn_tmp);

for (j = 0; j < elements_per_channel_t_by_4; j++) {
@@ -95,57 +95,257 @@ void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
}

value_table += stride;

ae_valignx2 align_index = AE_LA128_PP(pIn_tmp);
AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
(elements_per_channel_t_rem >>
1)); /* Loading 48 bits for decoding 16 weight values */
AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index);
AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t);
AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
elements_per_channel_t_rem);
AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp);
if (elements_per_channel_t_rem) {
ae_valignx2 align_index = AE_LA128_PP(pIn_tmp);
AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
(elements_per_channel_t_rem >>
1)); /* Load (rem / 2) bytes to decode the remaining 4-bit weight values */
AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index);
AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t);
AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
elements_per_channel_t_rem);
}
}
AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp);
}
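
For reference, the kernel above is a vectorized per-channel 4-bit
lookup-table decode: each compressed byte packs two 4-bit indices into a
per-channel value table, and the update hoists the store-alignment state out
of the channel loop, guards the remainder path, and flushes the stores once
with AE_SA128POS_FP at the end. A minimal portable sketch of the same decode
follows; the nibble order and the function name are illustrative assumptions,
not the kernel's actual layout or API.

#include <cstddef>
#include <cstdint>

// Portable reference for per-channel 4-bit LUT decompression (sketch only;
// MSB-first nibble order is an assumption).
void DecompressLut4Ref(const uint8_t* indices, const uint8_t* value_table,
                       size_t value_table_stride, size_t num_channels,
                       size_t elements_per_channel, int8_t* out) {
  size_t bit_pos = 0;  // bit offset into the packed index stream
  for (size_t c = 0; c < num_channels; ++c) {
    const uint8_t* table = value_table + c * value_table_stride;
    for (size_t e = 0; e < elements_per_channel; ++e) {
      const size_t byte = bit_pos >> 3;
      // First index in the high nibble, second in the low (assumed packing).
      const uint8_t idx =
          (bit_pos & 7) ? (indices[byte] & 0x0F) : (indices[byte] >> 4);
      *out++ = static_cast<int8_t>(table[idx]);
      bit_pos += 4;
    }
  }
}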

void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa_Old(
int8_t* buffer) {
void DecompressionStateXtensa::DecompressToBufferWidth3_Xtensa(int8_t* buffer) {
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);

char shuffle_pattern_1[8] = {0x08, 0x19, 0x2A, 0x3B, 0x4C, 0x5D, 0x6E, 0x7F};
ae_int8x8 d_shuffle_t = *(ae_int8x8*)&shuffle_pattern_1[0];
int i, j;
ae_int8* __restrict p_out_tmp = (ae_int8*)buffer;
ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_;
const uint8_t* __restrict value_table =
static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);

const uint8_t* __restrict value_table_t = value_table;

char shuffle_pattern_2[8] = {0xFB, 0x73, 0xEA, 0x62, 0xD9, 0x51, 0xC8, 0x40};
ae_int8x8 d_d_shuffle_t2 = *(ae_int8x8*)&shuffle_pattern_2[0];
int num_channels_t = num_channels_;
const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;

int elements_per_channel_t_by_4 = elements_per_channel_ >> 4;
int elements_per_channel_t_rem = elements_per_channel_ & 0xF;

ae_int8x8 d_index, d_dummy;
ae_int8x8 d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11;
ae_int8x8 d_out1, d_out2;
ae_int8x8 d_value_0, d_value_1;
ae_int8x8 d_index;

int elements_per_channel_t = elements_per_channel_;
int num_channels_t = num_channels_;
ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_;
ae_int8* __restrict p_out_tmp = (ae_int8*)buffer;
ae_valignx2 align_index = AE_LA128_PP(pIn_tmp);

const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL);
ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0x0F00050C00020000LL);
ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0x000E00040B000100LL);
ae_int8x8 d_shuffle_t3 = AE_MOVINT8X8_FROMINT64(0x0F060D040C030A01LL);
ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL);

ae_valignx2 align_store = AE_ZALIGN128();

for (i = 0; i < num_channels_t; i++) {
ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO());
ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO());

value_table_t = value_table;

ae_valign align_vtab = AE_LA64_PP(value_table_t);
AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t);
AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t);

for (j = 0; j < elements_per_channel_t_by_4; j++) {
AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
6); /* Loading 48 bits for decoding 16 weight values */

d1 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1));
d2 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));
d3 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3));
d4 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4));

d1 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL));
d2 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL));
d3 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL));
d4 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL));

d5 = d1 | d2;
d6 = d3 | d4;

d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4));
d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4));

d9 = AE_SEL8X8(d5, d7, d_shuffle_t1);
d10 = AE_SEL8X8(d6, d8, d_shuffle_t2);
d11 = AE_SEL8X8(d9, d10, d_shuffle_t3);

AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11);
AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t);

AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp);
}
if (elements_per_channel_t_rem) {
AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
3); /* Load 3 bytes (24 bits) to decode up to 8 remaining weight values */

d1 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1));
d2 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));
d3 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3));
d4 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4));

d1 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL));
d2 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL));
d3 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL));
d4 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL));

d5 = d1 | d2;
d6 = d3 | d4;

d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4));
d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4));

d9 = AE_SEL8X8(d5, d7, d_shuffle_t1);
d10 = AE_SEL8X8(d6, d8, d_shuffle_t2);
d11 = AE_SEL8X8(d9, d10, d_shuffle_t3);

AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11);
AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t);

AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
elements_per_channel_t_rem);
}

value_table = value_table + stride;
}
AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp);
}
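
The 3-bit kernel above consumes 6 bytes (48 bits) per 16 indices, isolating
the packed fields with shift-and-mask pairs (the 0x700700... masks) before
the AE_DSEL8X8 table lookups. A portable bit-stream sketch of the same
decode follows; MSB-first bit order and the function name are assumptions.

#include <cstddef>
#include <cstdint>

// Portable reference for per-channel 3-bit LUT decompression (sketch only;
// MSB-first bit order within each byte is an assumption).
void DecompressLut3Ref(const uint8_t* indices, const uint8_t* value_table,
                       size_t value_table_stride, size_t num_channels,
                       size_t elements_per_channel, int8_t* out) {
  size_t bit_pos = 0;
  for (size_t c = 0; c < num_channels; ++c) {
    const uint8_t* table = value_table + c * value_table_stride;
    for (size_t e = 0; e < elements_per_channel; ++e) {
      uint8_t idx = 0;
      for (int b = 0; b < 3; ++b) {  // gather 3 bits, possibly across bytes
        const size_t pos = bit_pos + b;
        idx = static_cast<uint8_t>(
            (idx << 1) | ((indices[pos >> 3] >> (7 - (pos & 7))) & 1u));
      }
      *out++ = static_cast<int8_t>(table[idx]);
      bit_pos += 3;
    }
  }
}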

void DecompressionStateXtensa::DecompressToBufferWidth2_Xtensa(int8_t* buffer) {
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);

int i, j;
ae_int8* __restrict p_out_tmp = (ae_int8*)buffer;
ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_;
const uint8_t* __restrict value_table =
static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);

for (int i = 0; i < num_channels_t; i++) {
ae_int8x8 d_value_0_t = *(ae_int8x8*)&value_table[0];
ae_int8x8 d_value_1_t = *(ae_int8x8*)&value_table[8];
const uint8_t* __restrict value_table_t = value_table;

AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, d_shuffle_t);
int num_channels_t = num_channels_;
const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;

for (int j = 0; j < elements_per_channel_t; j += 16) {
AE_L8X8_IP(d_index, pIn_tmp, 8);
AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index);
AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_d_shuffle_t2);
AE_S8X8X2_IP(d_out1, d_out2, (ae_int8x16*)p_out_tmp, 16);
int elements_per_channel_t_by_5 = elements_per_channel_ >> 5;
int elements_per_channel_t_rem = elements_per_channel_ & 0x1F;
int elements_per_channel_t_rem_minus_16 = 0;
if (elements_per_channel_t_rem > 16) {
elements_per_channel_t_rem_minus_16 = elements_per_channel_t_rem - 16;
}

ae_int8x8 d_index, d_dummy;
ae_int8x8 d0, d1, d2, d3, d4, d5;
ae_int8x8 q0, q1, q2, q3;
ae_int8x8 d_out1, d_out2;

ae_valignx2 align_index = AE_LA128_PP(pIn_tmp);

ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL);
ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL);
ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0xFBEA7362D9C85140LL);

ae_valignx2 align_store = AE_ZALIGN128();

for (i = 0; i < num_channels_t; i++) {
ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO());
ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO());

value_table_t = value_table;

ae_valign align_vtab = AE_LA64_PP(value_table_t);
AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t);
AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t);

for (j = 0; j < elements_per_channel_t_by_5; j++) {
// AE_LA8X8_IP( d_index, align_index, pIn_tmp ); /* Loading 64 bits
// for decoding 32 weight values */

AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
8); /* Loading 64 bits for decoding 32 weight values */
d0 = d_index;
d1 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));

d2 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d0),
0x3333333333333333LL)); // i1,i3,i5, ....
d3 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d1),
0x3333333333333333LL)); // i0,i2,i4, ....

AE_DSEL8X8(d4, d5, d3, d2,
d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 =
// i16,i18, i17,i19, ....

AE_DSEL8X8(q0, q1, d_value_0, d_value_1,
d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15
AE_DSEL8X8(
q2, q3, d_value_0, d_value_1,
d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31

AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2);
AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp);

AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2);
AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp);
}
if (elements_per_channel_t_rem) {
AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
(elements_per_channel_t_rem >>
2)); /* Load (rem / 4) bytes to decode the remaining 2-bit weight values */
d0 = d_index;
d1 =
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));
d2 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d0),
0x3333333333333333LL)); // i1,i3,i5, ....
d3 = AE_MOVINT8X8_FROMINT64(
AE_AND64(AE_MOVINT64_FROMINT8X8(d1),
0x3333333333333333LL)); // i0,i2,i4, ....

AE_DSEL8X8(d4, d5, d3, d2,
d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 =
// i16,i18, i17,i19, ....

AE_DSEL8X8(q0, q1, d_value_0, d_value_1,
d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15
AE_DSEL8X8(
q2, q3, d_value_0, d_value_1,
d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31

AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2);

AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
elements_per_channel_t_rem);

AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2);

AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
elements_per_channel_t_rem_minus_16);
}

value_table += stride;
value_table = value_table + stride;
}
AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp);
}
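
Each compressed byte in the 2-bit kernel carries four indices; the vector
code splits even and odd fields with the 0x3333... masks and re-interleaves
them through AE_DSEL8X8. The scalar equivalent is sketched below; the field
order within a byte and the function name are assumptions.

#include <cstddef>
#include <cstdint>

// Portable reference for per-channel 2-bit LUT decompression (sketch only;
// MSB-first field order within each byte is an assumption).
void DecompressLut2Ref(const uint8_t* indices, const uint8_t* value_table,
                       size_t value_table_stride, size_t num_channels,
                       size_t elements_per_channel, int8_t* out) {
  size_t bit_pos = 0;
  for (size_t c = 0; c < num_channels; ++c) {
    const uint8_t* table = value_table + c * value_table_stride;
    for (size_t e = 0; e < elements_per_channel; ++e) {
      const unsigned shift = 6u - (bit_pos & 7u);  // 6, 4, 2, 0 within a byte
      const uint8_t idx = (indices[bit_pos >> 3] >> shift) & 0x3u;
      *out++ = static_cast<int8_t>(table[idx]);
      bit_pos += 2;
    }
  }
}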

void DecompressionStateXtensa::DecompressToBufferWidthAnyInt8_Xtensa(
@@ -407,20 +607,25 @@ int8_t* DecompressionState::DecompressToBuffer<int8_t>(void* buffer) {

if (comp_data_.data.lut_data->compressed_bit_width == 4 &&
!comp_data_.data.lut_data->use_alternate_axis) {
if (!(elements_per_channel_ & 0x0F) &&
comp_data_.data.lut_data->value_table_channel_stride == 16) {
dsx.DecompressToBufferWidth4_Xtensa_Old(static_cast<int8_t*>(buffer));
if (!(elements_per_channel_ & 0x01)) {
dsx.DecompressToBufferWidth4_Xtensa(static_cast<int8_t*>(buffer));
} else {
dsx.DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
}
} else if (comp_data_.data.lut_data->compressed_bit_width == 3 &&
!comp_data_.data.lut_data->use_alternate_axis) {
// TODO(ddavis-2015): placeholder
dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
if (!(elements_per_channel_ & 0x07)) {
dsx.DecompressToBufferWidth3_Xtensa(static_cast<int8_t*>(buffer));
} else {
dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
}
} else if (comp_data_.data.lut_data->compressed_bit_width == 2 &&
!comp_data_.data.lut_data->use_alternate_axis) {
// TODO(ddavis-2015): placeholder
dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
if (!(elements_per_channel_ & 0x03)) {
dsx.DecompressToBufferWidth2_Xtensa(static_cast<int8_t*>(buffer));
} else {
dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
}
} else {
dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
}
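
The dispatch above now takes a fixed-width fast path only when the
per-channel element count divides evenly by the fast kernel's granularity,
falling back to DecompressToBufferWidthAnyInt8_Xtensa otherwise. Restated as
a standalone predicate (a sketch; the helper name is hypothetical, and the
use_alternate_axis == false requirement is checked separately above):

#include <cstddef>

// Hypothetical restatement of the fast-path conditions in the dispatch above.
bool HasFixedWidthFastPath(int compressed_bit_width,
                           size_t elements_per_channel) {
  switch (compressed_bit_width) {
    case 4: return (elements_per_channel & 0x01) == 0;  // even count
    case 3: return (elements_per_channel & 0x07) == 0;  // multiple of 8
    case 2: return (elements_per_channel & 0x03) == 0;  // multiple of 4
    default: return false;  // no fixed-width kernel for other widths
  }
}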