Skip to content

Commit

Permalink
revert to original Cadence bit width 4 code
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Oct 11, 2024
1 parent 2388549 commit 99c6e35
Showing 1 changed file with 51 additions and 1 deletion.
52 changes: 51 additions & 1 deletion tensorflow/lite/micro/kernels/xtensa/decompress.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@ struct DecompressionStateXtensa : DecompressionState {
: DecompressionState(other) {}

void DecompressToBufferWidth4_Xtensa(int8_t* buffer);
void DecompressToBufferWidth4_Xtensa_Old(int8_t* buffer);

template <size_t N>
void DecompressToBufferWidthAny_Xtensa(int8_t* buffer);
};

// TODO(ddavis-2015): unaligned/stride code has error, method not currently
// used.
void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
MicroProfiler* profiler =
static_cast<MicroProfiler*>(micro_context_->external_context());
Expand Down Expand Up @@ -105,6 +108,48 @@ void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
}
}

void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa_Old(
int8_t* buffer) {
MicroProfiler* profiler =
static_cast<MicroProfiler*>(micro_context_->external_context());
ScopedMicroProfiler scoped_profiler(__func__, profiler);

char shuffle_pattern_1[8] = {0x08, 0x19, 0x2A, 0x3B, 0x4C, 0x5D, 0x6E, 0x7F};
ae_int8x8 d_shuffle_t = *(ae_int8x8*)&shuffle_pattern_1[0];

char shuffle_pattern_2[8] = {0xFB, 0x73, 0xEA, 0x62, 0xD9, 0x51, 0xC8, 0x40};
ae_int8x8 d_d_shuffle_t2 = *(ae_int8x8*)&shuffle_pattern_2[0];

ae_int8x8 d_out1, d_out2;
ae_int8x8 d_value_0, d_value_1;
ae_int8x8 d_index;

int elements_per_channel_t = elements_per_channel_;
int num_channels_t = num_channels_;
ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_;
ae_int8* __restrict p_out_tmp = (ae_int8*)buffer;

const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
const uint8_t* __restrict value_table =
static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);

for (int i = 0; i < num_channels_t; i++) {
ae_int8x8 d_value_0_t = *(ae_int8x8*)&value_table[0];
ae_int8x8 d_value_1_t = *(ae_int8x8*)&value_table[8];

AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, d_shuffle_t);

for (int j = 0; j < elements_per_channel_t; j += 16) {
AE_L8X8_IP(d_index, pIn_tmp, 8);
AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index);
AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_d_shuffle_t2);
AE_S8X8X2_IP(d_out1, d_out2, (ae_int8x16*)p_out_tmp, 16);
}

value_table += stride;
}
}

template <size_t N>
void DecompressionStateXtensa::DecompressToBufferWidthAny_Xtensa(
int8_t* buffer) {
Expand Down Expand Up @@ -163,7 +208,12 @@ T* DecompressionState::DecompressToBuffer(void* buffer) {
if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 4 &&
!comp_data_.data.lut_data->use_alternate_axis) {
dsx.DecompressToBufferWidth4_Xtensa(static_cast<int8_t*>(buffer));
if (!(elements_per_channel_ & 0x0F) &&
comp_data_.data.lut_data->value_table_channel_stride == 16) {
dsx.DecompressToBufferWidth4_Xtensa_Old(static_cast<int8_t*>(buffer));
} else {
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
}
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 3 &&
!comp_data_.data.lut_data->use_alternate_axis) {
Expand Down

0 comments on commit 99c6e35

Please sign in to comment.