@@ -479,5 +479,49 @@ inline int64_t multiply_integers(Iter begin, Iter end) {
479479 begin, end, static_cast <int64_t >(1 ), std::multiplies<>());
480480}
481481
482+ class WorkgroupSize final {
483+ uint32_t val;
484+
485+ public:
486+ explicit WorkgroupSize () : val(0 ) {}
487+ explicit WorkgroupSize (const uint32_t x, const uint32_t y, const uint32_t z) {
488+ // shift numbers by multiple of 11 bits, since each local workgroup axis can
489+ // be 1024 at most and which is 0x400. only z axis can't store 1024, because
490+ // it would overflow uint32_t storage.
491+ if (z == 1024 ) {
492+ throw std::runtime_error (
493+ " Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage" );
494+ }
495+ val = x | (y << 11 ) | (z << 22 );
496+ }
497+
498+ explicit WorkgroupSize (const uvec3& vec) {
499+ // shift numbers by multiple of 11 bits, since each local workgroup axis can
500+ // be 1024 at most and which is 0x400. only z axis can't store 1024, because
501+ // it would overflow uint32_t storage.
502+ if (vec[2u ] == 1024 ) {
503+ throw std::runtime_error (
504+ " Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage" );
505+ }
506+ val = vec[0u ] | (vec[1u ] << 11 ) | (vec[2u ] << 22 );
507+ }
508+
509+ explicit inline operator uvec3 () const {
510+ return {
511+ val & 0x7ffu ,
512+ (val >> 11 ) & 0x7ffu ,
513+ (val >> 22 ),
514+ };
515+ }
516+
517+ explicit inline operator uint32_t () const {
518+ return val;
519+ }
520+
521+ inline constexpr uint32_t operator [](const int idx) const {
522+ return (val >> (11 * idx)) & 0x7ffu ;
523+ }
524+ };
525+
482526} // namespace utils
483527} // namespace vkcompute
0 commit comments