Skip to content
31 changes: 31 additions & 0 deletions sgl-kernel/csrc/cpu/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,29 @@ namespace {
} \
}()

#define AT_DISPATCH_BOOL2(BOOL_V1, BOOL_NAME1, BOOL_V2, BOOL_NAME2, ...) \
[&] { \
if (BOOL_V1) { \
constexpr bool BOOL_NAME1 = true; \
if (BOOL_V2) { \
constexpr bool BOOL_NAME2 = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool BOOL_NAME2 = false; \
return __VA_ARGS__(); \
} \
} else { \
constexpr bool BOOL_NAME1 = false; \
if (BOOL_V2) { \
constexpr bool BOOL_NAME2 = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool BOOL_NAME2 = false; \
return __VA_ARGS__(); \
} \
} \
}()

// dispatch: bfloat16, float16, int8_t, fp8_e4m3
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
Expand Down Expand Up @@ -105,6 +128,8 @@ namespace {

#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)

// [NB] Parallel Routines
//
// * at::parallel_for - applies for most of generic use cases, this will be compiled
Expand Down Expand Up @@ -321,4 +346,10 @@ struct Unroll<1> {
}
};

// conditional data ptr for optional tensor
template <typename T>
inline T* conditional_data_ptr(const std::optional<at::Tensor>& opt) {
return opt.has_value() ? opt.value().data_ptr<T>() : nullptr;
}

} // anonymous namespace
Loading
Loading