Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
61c32ba
[Metax] add keyword filter in CI CMakeLists.txt
StareAtYou Sep 24, 2025
642eb37
Merge branch 'metax666:develop' into develop
StareAtYou Sep 25, 2025
b2ddc81
[Metax] add ignore case list
StareAtYou Sep 25, 2025
041e585
Merge branch 'metax666:develop' into develop
StareAtYou Sep 25, 2025
087a9c1
[Metax] fix phi::backends::gpu::DnnVersion() symbol not found
StareAtYou Sep 26, 2025
73710c5
Revert "[Metax] fix phi::backends::gpu::DnnVersion() symbol not found"
StareAtYou Sep 26, 2025
404ff3d
[Metax] fix index_elementwise_get kernel
StareAtYou Sep 26, 2025
739c5c7
Merge branch 'metax666:develop' into develop
StareAtYou Sep 28, 2025
35a4e49
Merge branch 'metax666:develop' into develop
StareAtYou Sep 29, 2025
8f91b94
Merge branch 'metax666:develop' into develop
StareAtYou Oct 9, 2025
b533149
Merge branch 'metax666:develop' into develop
StareAtYou Oct 11, 2025
3c6bcd2
Merge branch 'metax666:develop' into develop
StareAtYou Oct 15, 2025
a786d0a
Merge branch 'metax666:develop' into develop
StareAtYou Oct 17, 2025
eb32ae3
Merge branch 'metax666:develop' into develop
StareAtYou Oct 22, 2025
342ff81
[Metax] fix weight_quant & weight_only_linear bug
StareAtYou Oct 23, 2025
9bc5cd4
Merge branch 'metax666:develop' into develop
StareAtYou Oct 23, 2025
e9d0d72
Merge branch 'metax666:develop' into develop
StareAtYou Oct 25, 2025
f507479
[Metax] fix 'WeightQuantizeKernel' wint4 branch
StareAtYou Oct 28, 2025
2c0d6f4
Merge branch 'metax666:develop' into develop
StareAtYou Oct 28, 2025
b3c816b
[Metax] add quanted weight layout transformation using CPU programming
StareAtYou Oct 29, 2025
181772d
[Metax] adjust quanted weight layout transformation
StareAtYou Oct 29, 2025
1c42f4a
Merge branch 'metax666:develop' into develop
StareAtYou Oct 29, 2025
6e0d1eb
[Metax] add quanted weight layout transformation using GPU programming
StareAtYou Oct 29, 2025
dc95578
Merge branch 'metax666:develop' into develop
StareAtYou Oct 30, 2025
165e524
[Metax] optimize wint4 quantization implementation
StareAtYou Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,26 @@
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

// Debug helper: print the top-left `row_num` x `col_num` corner of a 2-D
// CPU-resident tensor, one hex value per element.
//
// @tparam DataType element type the tensor was allocated with; data<DataType>()
//                  asserts/misbehaves on a mismatch, so callers must pass the
//                  true element type.
// @param tensor   2-D tensor living in host memory (rank is assumed to be 2;
//                 dims()[0]/dims()[1] are read unconditionally).
// @param row_num  number of leading rows to print (default 3).
// @param col_num  number of leading columns to print (default 3).
template <typename DataType>
void show_2d_cpu_tensor(const DenseTensor& tensor,
                        const int64_t row_num = 3,
                        const int64_t col_num = 3) {
  const int64_t rows = tensor.dims()[0];
  const int64_t cols = tensor.dims()[1];
  // int64_t does not match %d; cast to long long and use %lld to avoid UB.
  printf("\nTensor shape = [%lld, %lld]\n",
         static_cast<long long>(rows),
         static_cast<long long>(cols));

  const DataType* cpu_ptr = tensor.data<DataType>();

  // Clamp the requested window to the real extent so an over-sized debug
  // request can never read past the end of the buffer.
  for (int64_t r = 0; r < row_num && r < rows; r++) {
    for (int64_t c = 0; c < col_num && c < cols; c++) {
      const DataType val = *(cpu_ptr + r * cols + c);
      // Widen explicitly to unsigned so %#x gets the type it expects even
      // for narrow signed DataType (e.g. int8_t).
      printf("%#x ", static_cast<unsigned int>(
                         static_cast<std::make_unsigned_t<DataType>>(val)));
    }
    printf("\n");
  }
  printf("\n\n");
}

template <typename DataType>
void show_2d_gpu_tensor(const CustomContext& dev_ctx,
const DenseTensor& tensor,
const int64_t row_num = 3,
Expand All @@ -58,18 +58,39 @@ void show_2d_gpu_tensor(const CustomContext& dev_ctx,
const int64_t cols = cpu_tensor.dims()[1];
printf("\nTensor shape = [%d, %d]\n", rows, cols);

const int8_t* cpu_ptr = cpu_tensor.data<int8_t>();
const DataType* cpu_ptr = cpu_tensor.data<DataType>();

for (int r = 0; r < row_num; r++) {
for (int c = 0; c < col_num; c++) {
int8_t val = *(cpu_ptr + r * cols + c);
printf("%d ", val);
DataType val = *(cpu_ptr + r * cols + c);
printf("%#x ", val);
}
printf("\n");
}
printf("\n\n");
}

// Debug helper: copy a device tensor to host and print its first `num`
// elements in hex. Mirrors show_2d_cpu_tensor/show_2d_gpu_tensor formatting.
//
// @tparam DataType element type the tensor was allocated with.
// @param dev_ctx  device context used for the D2H copy.
// @param tensor   device-resident tensor; copied synchronously (blocking=true)
//                 so the host buffer is valid before printing.
// @param num      number of leading elements to print (default 3); clamped to
//                 the tensor's element count.
template <typename DataType>
void show_1d_gpu_tensor(const CustomContext& dev_ctx,
                        const DenseTensor& tensor,
                        const int64_t num = 3) {
  phi::CPUPlace cpu_place;

  DenseTensor cpu_tensor;
  phi::Copy(dev_ctx, tensor, cpu_place, true, &cpu_tensor);

  const int64_t nums = cpu_tensor.numel();
  // int64_t does not match %d; cast to long long and use %lld to avoid UB.
  printf("\nTensor shape = [%lld]\n", static_cast<long long>(nums));

  const DataType* cpu_ptr = cpu_tensor.data<DataType>();

  // Clamp to numel() so asking for more elements than exist cannot read
  // past the end of the copied buffer.
  for (int64_t n = 0; n < num && n < nums; n++) {
    const DataType val = *(cpu_ptr + n);
    // Widen explicitly to unsigned so %#x gets the type it expects even
    // for narrow signed DataType (e.g. int8_t).
    printf("%#x ", static_cast<unsigned int>(
                       static_cast<std::make_unsigned_t<DataType>>(val)));
  }
  printf("\n\n");
}

void cpu_2d_tensor_transpose(const DenseTensor& input_data,
DenseTensor* transposed_data) {
const int64_t input_data_rows = input_data.dims()[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void WeightQuantizeKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(scale);
weight_quant_gpu<T, Context>(dev_ctx,
x.data<T>(),
out->data<int8_t>(),
quanted_x.data<int8_t>(),
scale->data<T>(),
weight_shape,
arch,
Expand All @@ -141,7 +141,13 @@ void WeightQuantizeKernel(const Context& dev_ctx,
// arch,
// algo);
#endif
MetaxQuantizedWeightLayoutTrans<Context>(dev_ctx, algo, weight_shape, out);
quanted_x.Resize({m / 2, n});

std::vector<int> axis = {1, 0};
funcs::Transpose<Context, int8_t, 2> trans;
trans(dev_ctx, quanted_x, out, axis);

out->Resize({n / 2, m});
} else if (algo == "w4a8") {
weight_permute_gpu_w4a8<Context>(dev_ctx,
x.data<int8_t>(),
Expand Down
Loading