Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ebpf/PSA: Checksum support for fields wider than 64 bits #3801

Merged
merged 3 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 168 additions & 92 deletions backends/ebpf/psa/externs/ebpfPsaHashAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,84 +51,124 @@ void EBPFHashAlgorithmPSA::emitAddData(CodeBuilder* builder, int dataPos,

void CRCChecksumAlgorithm::emitUpdateMethod(CodeBuilder* builder, int crcWidth) {
// Note that this update method is optimized for our CRC16 and CRC32, custom
// version may require other method of update. To deal with byte order data
// is read from the end of buffer.
// version may require other method of update. When data_size <= 64 bits,
// applies host byte order for input data, otherwise network byte order is expected.
if (crcWidth == 16) {
// This function calculates CRC16 by definition, it is bit by bit. If input data has more
// than 64 bit, the outer loop process bytes in network byte order - data pointer is
// incremented. For data shorter than or equal 64 bits, bytes are processed in little endian
// byte order - data pointer is decremented by outer loop in this case.
// There is no need for lookup table.
cstring code =
"static __always_inline\n"
"void crc%w%_update(u%w% * reg, const u8 * data, u16 data_size, const u%w% poly) {\n"
" for (u16 i = data_size; i > 0; i--) {\n"
" bpf_trace_message(\"CRC%w%: data byte: %x\\n\", data[i-1]);\n"
" *reg ^= (u16) data[i-1];\n"
"void crc16_update(u16 * reg, const u8 * data, "
"u16 data_size, const u16 poly) {\n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused by this code, some documentation would help.

" if (data_size <= 8)\n"
" data += data_size - 1;\n"
" #pragma clang loop unroll(full)\n"
" for (u16 i = 0; i < data_size; i++) {\n"
" bpf_trace_message(\"CRC16: data byte: %x\\n\", *data);\n"
" *reg ^= *data;\n"
" for (u8 bit = 0; bit < 8; bit++) {\n"
" *reg = (*reg) & 1 ? ((*reg) >> 1) ^ poly : (*reg) >> 1;\n"
" }\n"
" if (data_size <= 8)\n"
" data--;\n"
" else\n"
" data++;\n"
" }\n"
"}";
code = code.replace("%w%", Util::printf_format("%d", crcWidth));
builder->appendLine(code);
} else if (crcWidth == 32) {
// This function calculates CRC32 using two optimisations: slice-by-8 and Standard
// Implementation. Both algorithms have to properly handle byte order depending on data
// length. There are four cases which must be handled:
// 1. Data size below 8 bytes - calculated using Standard Implementation in little endian
// byte order.
// 2. Data size equal to 8 bytes - calculated using slice-by-8 in little endian byte order.
// 3. Data size more than 8 bytes and multiply of 8 bytes - calculated using slice-by-8 in
// big endian byte order.
// 4. Data size more than 8 bytes and not multiply of 8 bytes - calculated using slice-by-8
// and Standard Implementation both in big endian byte order.
// Lookup table is necessary for both algorithms.
cstring code =
"static __always_inline\n"
"void crc%w%_update(u%w% * reg, const u8 * data, u16 data_size, "
"const u%w% poly) {\n"
" data += data_size - 4;\n"
" u32* current = (u32*) data;\n"
" struct lookup_tbl_val* lookup_table;\n"
" u32 index = 0;\n"
" lookup_table = BPF_MAP_LOOKUP_ELEM(crc_lookup_tbl, &index);\n"
" u32 lookup_key = 0;\n"
" u32 lookup_value = 0;\n"
" u32 lookup_value1 = 0;\n"
" u32 lookup_value2 = 0;\n"
" u32 lookup_value3 = 0;\n"
" u32 lookup_value4 = 0;\n"
" u32 lookup_value5 = 0;\n"
" u32 lookup_value6 = 0;\n"
" u32 lookup_value7 = 0;\n"
" u32 lookup_value8 = 0;\n"
" u16 tmp = 0;\n"
" if (lookup_table != NULL) {\n"
" for (u16 i = data_size; i >= 8; i -= 8) {\n"
" bpf_trace_message(\"CRC32: data byte: %x\", *current);\n"
" u32 one = __builtin_bswap32(*current--) ^ *reg;\n"
" bpf_trace_message(\"CRC32: data byte: %x\", *current);\n"
" u32 two = __builtin_bswap32(*current--);\n"
" lookup_key = (one & 0x000000FF);\n"
" lookup_value8 = lookup_table->table[(u16)(1792 + (u8)lookup_key)];\n"
" lookup_key = (one >> 8) & 0x000000FF;\n"
" lookup_value7 = lookup_table->table[(u16)(1536 + (u8)lookup_key)];\n"
" lookup_key = (one >> 16) & 0x000000FF;\n"
" lookup_value6 = lookup_table->table[(u16)(1280 + (u8)lookup_key)];\n"
" lookup_key = one >> 24;\n"
" lookup_value5 = lookup_table->table[(u16)(1024 + (u8)(lookup_key))];\n"
" lookup_key = (two & 0x000000FF);\n"
" lookup_value4 = lookup_table->table[(u16)(768 + (u8)lookup_key)];\n"
" lookup_key = (two >> 8) & 0x000000FF;\n"
" lookup_value3 = lookup_table->table[(u16)(512 + (u8)lookup_key)];\n"
" lookup_key = (two >> 16) & 0x000000FF;\n"
" lookup_value2 = lookup_table->table[(u16)(256 + (u8)lookup_key)];\n"
" lookup_key = two >> 24;\n"
" lookup_value1 = lookup_table->table[(u8)(lookup_key)];\n"
" *reg = lookup_value8 ^ lookup_value7 ^ lookup_value6 ^ lookup_value5 ^\n"
" lookup_value4 ^ lookup_value3 ^ lookup_value2 ^ lookup_value1;\n"
" tmp += 8;\n"
" }\n"
" unsigned char *currentChar = (unsigned char *) current;\n"
" currentChar+= 3;\n"
" volatile int std_algo_lookup_key = 0;\n"
" for (u16 i = tmp; i < data_size; i++) {\n"
" bpf_trace_message(\"CRC32: data byte: %x\\n\", *current);\n"
" std_algo_lookup_key = (u32)(((*reg) & 0xFF) ^ *currentChar--);\n"
" if (std_algo_lookup_key >= 0) {\n"
" lookup_value = "
"void crc32_update(u32 * reg, const u8 * data, u16 data_size, const u32 poly) {\n"
" u32* current = (u32*) data;\n"
" struct lookup_tbl_val* lookup_table;\n"
" u32 index = 0;\n"
" lookup_table = BPF_MAP_LOOKUP_ELEM(crc_lookup_tbl, &index);\n"
" u32 lookup_key = 0;\n"
" u32 lookup_value = 0;\n"
" u32 lookup_value1 = 0;\n"
" u32 lookup_value2 = 0;\n"
" u32 lookup_value3 = 0;\n"
" u32 lookup_value4 = 0;\n"
" u32 lookup_value5 = 0;\n"
" u32 lookup_value6 = 0;\n"
" u32 lookup_value7 = 0;\n"
" u32 lookup_value8 = 0;\n"
" u16 tmp = 0;\n"
" if (lookup_table != NULL) {\n"
" for (u16 i = data_size; i >= 8; i -= 8) {\n"
" /* Vars one and two will have swapped byte order if data_size == 8 */\n"
" if (data_size == 8) current = data + 4;\n"
" bpf_trace_message(\"CRC32: data dword: %x\\n\", *current);\n"
" u32 one = (data_size == 8 ? __builtin_bswap32(*current--) : *current++) ^ "
"*reg;\n"
" bpf_trace_message(\"CRC32: data dword: %x\\n\", *current);\n"
" u32 two = (data_size == 8 ? __builtin_bswap32(*current--) : *current++);\n"
" lookup_key = (one & 0x000000FF);\n"
" lookup_value8 = lookup_table->table[(u16)(1792 + (u8)lookup_key)];\n"
" lookup_key = (one >> 8) & 0x000000FF;\n"
" lookup_value7 = lookup_table->table[(u16)(1536 + (u8)lookup_key)];\n"
" lookup_key = (one >> 16) & 0x000000FF;\n"
" lookup_value6 = lookup_table->table[(u16)(1280 + (u8)lookup_key)];\n"
" lookup_key = one >> 24;\n"
" lookup_value5 = lookup_table->table[(u16)(1024 + (u8)(lookup_key))];\n"
" lookup_key = (two & 0x000000FF);\n"
" lookup_value4 = lookup_table->table[(u16)(768 + (u8)lookup_key)];\n"
" lookup_key = (two >> 8) & 0x000000FF;\n"
" lookup_value3 = lookup_table->table[(u16)(512 + (u8)lookup_key)];\n"
" lookup_key = (two >> 16) & 0x000000FF;\n"
" lookup_value2 = lookup_table->table[(u16)(256 + (u8)lookup_key)];\n"
" lookup_key = two >> 24;\n"
" lookup_value1 = lookup_table->table[(u8)(lookup_key)];\n"
" *reg = lookup_value8 ^ lookup_value7 ^ lookup_value6 ^ lookup_value5 ^\n"
" lookup_value4 ^ lookup_value3 ^ lookup_value2 ^ lookup_value1;\n"
" tmp += 8;\n"
" }\n"
" volatile int std_algo_lookup_key = 0;\n"
" if (data_size < 8) {\n"
// Standard Implementation for little endian byte order
" unsigned char *currentChar = (unsigned char *) current;\n"
" currentChar += data_size - 1;\n"
" for (u16 i = tmp; i < data_size; i++) {\n"
" bpf_trace_message(\"CRC32: data byte: %x\\n\", *currentChar);\n"
" std_algo_lookup_key = (u32)(((*reg) & 0xFF) ^ *currentChar--);\n"
" if (std_algo_lookup_key >= 0) {\n"
" lookup_value = "
"lookup_table->table[(u8)(std_algo_lookup_key & 255)];\n"
" }\n"
" *reg = ((*reg) >> 8) ^ lookup_value;\n"
" }\n"
" } else {\n"
// Standard Implementation for big endian byte order
" /* Consume data not processed by slice-by-8 algorithm above, "
"these data are in network byte order */\n"
" unsigned char *currentChar = (unsigned char *) current;\n"
" for (u16 i = tmp; i < data_size; i++) {\n"
" bpf_trace_message(\"CRC32: data byte: %x\\n\", *currentChar);\n"
" std_algo_lookup_key = (u32)(((*reg) & 0xFF) ^ *currentChar++);\n"
" if (std_algo_lookup_key >= 0) {\n"
" lookup_value = "
"lookup_table->table[(u8)(std_algo_lookup_key & 255)];\n"
" }\n"
" *reg = ((*reg) >> 8) ^ lookup_value;\n"
" }\n"
" *reg = ((*reg) >> 8) ^ lookup_value;\n"
" }\n"
" }\n"
" }\n"
" }\n"
"}";
code = code.replace("%w%", Util::printf_format("%d", crcWidth));
builder->appendLine(code);
}
}
Expand Down Expand Up @@ -340,38 +380,28 @@ void InternetChecksumAlgorithm::updateChecksum(CodeBuilder* builder, const Argum
bitsToRead = width;

if (width > 64) {
BUG("Fields wider than 64 bits are not supported yet", field);
}

while (bitsToRead > 0) {
if (remainingBits == 16) {
builder->emitIndent();
builder->appendFormat("%s = ", tmpVar.c_str());
} else {
builder->append(" | ");
if (remainingBits != 16) {
::error(ErrorType::ERR_UNSUPPORTED,
"%1%: field wider than 64 bits must be aligned to 16 bits in input data",
field);
continue;
}

// TODO: add masks for fields, however they should not exceed declared width
if (bitsToRead < remainingBits) {
remainingBits -= bitsToRead;
builder->append("(");
visitor->visit(field);
builder->appendFormat(" << %d)", remainingBits);
bitsToRead = 0;
} else if (bitsToRead == remainingBits) {
remainingBits = 0;
visitor->visit(field);
bitsToRead = 0;
} else if (bitsToRead > remainingBits) {
bitsToRead -= remainingBits;
remainingBits = 0;
builder->append("(");
visitor->visit(field);
builder->appendFormat(" >> %d)", bitsToRead);
if (width % 16 != 0) {
::error(ErrorType::ERR_UNSUPPORTED,
"%1%: field wider than 64 bits must have size in bits multiply of 16 bits",
field);
continue;
}

if (remainingBits == 0) {
remainingBits = 16;
// Let's convert internal array into an array of u16 and calc csum for such entries.
// Byte order conversion is required, because csum is calculated in host byte order
// but data is preserved in network byte order
const unsigned arrayEntries = width / 16;
for (unsigned i = 0; i < arrayEntries; ++i) {
builder->emitIndent();
builder->appendFormat("%s = htons(((u16 *)(", tmpVar.c_str());
visitor->visit(field);
builder->appendFormat("))[%u])", i);
builder->endOfStatement(true);

// update checksum
Expand All @@ -387,6 +417,52 @@ void InternetChecksumAlgorithm::updateChecksum(CodeBuilder* builder, const Argum
}
builder->endOfStatement(true);
}
} else { // fields smaller or equal than 64 bits
while (bitsToRead > 0) {
if (remainingBits == 16) {
builder->emitIndent();
builder->appendFormat("%s = ", tmpVar.c_str());
} else {
builder->append(" | ");
}

// TODO: add masks for fields, however they should not exceed declared width
if (bitsToRead < remainingBits) {
remainingBits -= bitsToRead;
builder->append("(");
visitor->visit(field);
builder->appendFormat(" << %d)", remainingBits);
bitsToRead = 0;
} else if (bitsToRead == remainingBits) {
remainingBits = 0;
visitor->visit(field);
bitsToRead = 0;
} else if (bitsToRead > remainingBits) {
bitsToRead -= remainingBits;
remainingBits = 0;
builder->append("(");
visitor->visit(field);
builder->appendFormat(" >> %d)", bitsToRead);
}

if (remainingBits == 0) {
remainingBits = 16;
builder->endOfStatement(true);

// update checksum
builder->target->emitTraceMessage(builder, "InternetChecksum: word=0x%llx", 1,
tmpVar.c_str());
builder->emitIndent();
if (addData) {
builder->appendFormat("%s = csum16_add(%s, %s)", stateVar.c_str(),
stateVar.c_str(), tmpVar.c_str());
} else {
builder->appendFormat("%s = csum16_sub(%s, %s)", stateVar.c_str(),
stateVar.c_str(), tmpVar.c_str());
}
builder->endOfStatement(true);
}
}
}
}

Expand Down
Loading