Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/xpu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ inline float dDequantizeNF4(unsigned char val) {
}

template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE>
SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::operator()(sycl::and_item<1> item) const {
SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::operator()(sycl::nd_item<1> item) const {
const int base_idx = item.get_group(0) * TILE_SIZE;
size_t local_idx = item.get_local_id(0) * NUM_PER_TH;
float local_abs_max = -FLT_MAX;
Expand Down Expand Up @@ -172,7 +172,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::op

template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD, size_t SUBG_SIZE, int BITS>
SYCL_EXTERNAL void
kgemv_4bit_inference<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS>::operator()(sycl::and_item<1> item) const {
kgemv_4bit_inference<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS>::operator()(sycl::nd_item<1> item) const {
size_t idx = item.get_local_id();
const int sg_idx = idx / SUBG_SIZE;
const int sg_lane = idx % SUBG_SIZE;
Expand Down
4 changes: 2 additions & 2 deletions csrc/xpu_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE> class kDequantizeBlockwise {
public:
SYCL_EXTERNAL void operator()(sycl::and_item<1> item) const;
SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;

kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_)
: code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {}
Expand All @@ -22,7 +22,7 @@ template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE> class kDequa

template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD, size_t SUBG_SIZE, int BITS> class kgemv_4bit_inference {
public:
SYCL_EXTERNAL void operator()(sycl::and_item<1> item) const;
SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;

kgemv_4bit_inference(
int M_, int N_, int K_, T* A_, unsigned char* B_, float* absmax_, const float* datatype_, T* out_, int lda_,
Expand Down
6 changes: 3 additions & 3 deletions csrc/xpu_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ void dequantizeBlockwise(
sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize / 2, n);
sycl_kernel_submit<decltype(kfn), 1, 32>(
sycl::and_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
);
} else {
const int workgroup_num = (n + tile_size - 1) / tile_size;
sycl::range<1> local_range{(size_t)workgroup_size};
sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize, n);
sycl_kernel_submit<decltype(kfn), 1, 32>(
sycl::and_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
);
}
}
Expand All @@ -47,7 +47,7 @@ void gemv_4bit_inference(
);

sycl_comp_kernel_submit<decltype(kfn), 1, SUBG_SIZE>(
sycl::and_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn
sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn
);
}

Expand Down
4 changes: 2 additions & 2 deletions csrc/xpu_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
#include <sycl/sycl.hpp>

template <typename ker_t, int dim, int subgroup_size>
static inline void sycl_kernel_submit(sycl::and_range<dim> range, sycl::queue q, ker_t ker) {
static inline void sycl_kernel_submit(sycl::nd_range<dim> range, sycl::queue q, ker_t ker) {
auto cgf = [&](::sycl::handler& cgh)
[[sycl::reqd_sub_group_size(subgroup_size)]] { cgh.parallel_for<ker_t>(range, ker); };
q.submit(cgf);
}

template <typename ker_t, int dim, int subgroup_size>
static inline void sycl_comp_kernel_submit(sycl::and_range<dim> range, sycl::queue q, ker_t ker) {
static inline void sycl_comp_kernel_submit(sycl::nd_range<dim> range, sycl::queue q, ker_t ker) {
auto cgf = [&](::sycl::handler& cgh) [[sycl::reqd_sub_group_size(subgroup_size)]] {
ker.sycl_ker_local_memory_creation(cgh);
cgh.parallel_for<ker_t>(range, ker);
Expand Down
4 changes: 2 additions & 2 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,8 +1238,8 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
max_errs3 = []

# Large number of iterations is excessive and slow on CPU.
# Keep for CUDA for now.
iters = 100 if device == "cuda" else 10
# Keep for CUDA/XPU for now.
iters = 10 if device == "cpu" else 100

for i in range(iters):
if kind == "fc1":
Expand Down
Loading