diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h
index dfbc062cf..a42aa1bd0 100644
--- a/src/tl_templates/cuda/common.h
+++ b/src/tl_templates/cuda/common.h
@@ -93,12 +93,28 @@ TL_DEVICE unsigned __pack_nv_bfloat162(const bfloat16_t x, const bfloat16_t y) {
   return (v1 << 16) | v0;
 }
 
-// Pack four char values
+// Pack four char values into one int: x0 occupies bits 0-7, x1 bits 8-15,
+// x2 bits 16-23 and x3 bits 24-31.
 TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2,
                        signed char x3) {
-  return (x3 << 24) | (x2 << 16) | (x1 << 8) | x0;
+  // Reduce every operand to its low 8 bits as an unsigned value before
+  // shifting. Otherwise integer promotion sign-extends a negative char and
+  // the OR overwrites the bytes packed above it.
+  return static_cast<int>(((x3 & 0xffu) << 24) | ((x2 & 0xffu) << 16) |
+                          ((x1 & 0xffu) << 8) | (x0 & 0xffu));
 }
 
+// Pack eight char values into an int2: x0..x3 form result.x and y0..y3 form
+// result.y, each packed by make_int.
+TL_DEVICE int2 make_int2(signed char x0, signed char x1, signed char x2,
+                         signed char x3, signed char y0, signed char y1,
+                         signed char y2, signed char y3) {
+  int2 result;
+  result.x = make_int(x0, x1, x2, x3);
+  result.y = make_int(y0, y1, y2, y3);
+  return result;
+}
+
 // Pack sixteen char values.
 TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
                            signed char x3, signed char y0, signed char y1,
@@ -114,6 +130,20 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
   return result;
 }
 
+// Pack eight int values into a longlong4; each 64-bit lane holds one
+// (first, second) int pair written through an int2 view of the lane.
+TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0,
+                                   int z1, int w0, int w1) {
+  longlong4 result;
+  // The two-argument make_int2 calls resolve to CUDA's built-in
+  // make_int2(int, int), not to the eight-char overload defined above.
+  *((int2 *)&result.x) = make_int2(x0, x1);
+  *((int2 *)&result.y) = make_int2(y0, y1);
+  *((int2 *)&result.z) = make_int2(z0, z1);
+  *((int2 *)&result.w) = make_int2(w0, w1);
+  return result;
+}
+
 // Helper to cast SMEM pointer to unsigned
 TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
   return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));