diff --git a/chatglm_test.cpp b/chatglm_test.cpp index 3900eb5..1d800d7 100644 --- a/chatglm_test.cpp +++ b/chatglm_test.cpp @@ -6,22 +6,19 @@ namespace chatglm { namespace fs = std::filesystem; -void expect_all_close(float *a, float *b, size_t n, float atol = 2e-4) { - for (size_t i = 0; i < n; i++) { - EXPECT_LT(std::abs(a[i] - b[i]), atol); +static inline void expect_all_close(ggml_tensor *a, ggml_tensor *b, float atol = 1e-5) { + ASSERT_EQ(a->type, b->type); + ASSERT_EQ(a->type, GGML_TYPE_F32); + ASSERT_EQ(ggml_nelements(a), ggml_nelements(b)); + int64_t numel = ggml_nelements(a); + for (int64_t i = 0; i < numel; i++) { + EXPECT_LT(std::abs(((float *)a->data)[i] - ((float *)b->data)[i]), atol); } } -bool has_shape(ggml_tensor *tensor, const std::vector &shape) { - if (tensor->n_dims != (int)shape.size()) { - return false; - } - for (int i = 0; i < tensor->n_dims; i++) { - if (tensor->ne[tensor->n_dims - 1 - i] != shape[i]) { - return false; - } - } - return true; +static inline char *map_tensor_data(char *ptr, ggml_tensor *tensor) { + tensor->data = ptr; + return ptr + ggml_nbytes(tensor); } class ChatGLMTest : public ::testing::Test { @@ -38,6 +35,11 @@ class ChatGLMTest : public ::testing::Test { ctx.gctx = GGMLContext({1024 * 1024, nullptr, false}); ctx.scratch = {0, scratch_buf.size(), scratch_buf.data()}; + + reset_cgraph(); + } + + void reset_cgraph() { ctx.gf = {}; ctx.gf.n_threads = 1; } @@ -54,616 +56,263 @@ TEST_F(ChatGLMTest, Embedding) { memcpy(x->data, x_data, sizeof(x_data)); Embedding model(&ictx, 4, 3); memcpy(model.weight->data, w_data, sizeof(w_data)); - ggml_tensor *y = model.forward(&ctx, x); + ggml_tensor *ref = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 3, 5); + ref->data = y_data; - ggml_build_forward_expand(&ctx.gf, y); + ggml_tensor *out = model.forward(&ctx, x); + + ggml_build_forward_expand(&ctx.gf, out); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref, out); } TEST_F(ChatGLMTest, Linear) { - float w_data[]{-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, - 0.4681, -0.1577, 1.4437, 0.2660, 0.1665, 0.8744, -0.1435, -0.1116, - 0.9318, 1.2590, 2.0050, 0.0537, 0.6181, -0.4128, -0.8411, -2.3160}; - float b_data[]{0.3704, 1.4565, 0.9398, 0.7748, 0.1919, 1.2638, -1.2904, -0.7911}; - float x_data[]{-0.0209, -0.7185, 0.5186, -1.3125, 0.1920, 0.5428}; - float y_data[]{1.0919, 1.2147, 2.7089, -0.1211, -0.5142, 1.2496, -1.0504, -1.3794, - 1.4908, 2.5645, 1.2025, 1.4034, 0.0634, 2.2725, -3.5762, -1.6679}; - - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 3, 2); - memcpy(x->data, x_data, sizeof(x_data)); + fs::path test_path = fs::path(__FILE__).parent_path() / "tests/data/linear.data"; + MappedFile mapped_file(test_path.string()); + char *ptr = mapped_file.data; + + ggml_tensor *w = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 32, 16); + ptr = map_tensor_data(ptr, w); + ggml_tensor *b = ggml_new_tensor_1d(ctx.gctx.get(), GGML_TYPE_F32, 16); + ptr = map_tensor_data(ptr, b); + ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 32, 2); + ptr = map_tensor_data(ptr, x); + ggml_tensor *ref = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 16, 2); + ptr = map_tensor_data(ptr, ref); + ASSERT_EQ(ptr, mapped_file.data + mapped_file.size); // fp32 { ictx.dtype = GGML_TYPE_F32; - Linear model(&ictx, 3, 8); - memcpy(model.weight->data, w_data, sizeof(w_data)); - memcpy(model.bias->data, b_data, sizeof(b_data)); + Linear 
model(&ictx, 32, 16); + model.weight->data = w->data; + model.bias->data = b->data; - ggml_tensor *y = model.forward(&ctx, x); + ggml_tensor *out = model.forward(&ctx, x); - ggml_build_forward_expand(&ctx.gf, y); + ggml_build_forward_expand(&ctx.gf, out); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref, out); } // fp16 { + reset_cgraph(); + ictx.dtype = GGML_TYPE_F16; - Linear model(&ictx, 3, 8); - ggml_fp32_to_fp16_row(w_data, (ggml_fp16_t *)model.weight->data, ggml_nelements(model.weight)); - memcpy(model.bias->data, b_data, sizeof(b_data)); + Linear model(&ictx, 32, 16); + ggml_fp32_to_fp16_row((float *)w->data, (ggml_fp16_t *)model.weight->data, ggml_nelements(model.weight)); + model.bias->data = b->data; - ggml_tensor *y = model.forward(&ctx, x); + ggml_tensor *out = model.forward(&ctx, x); - ggml_build_forward_expand(&ctx.gf, y); + ggml_build_forward_expand(&ctx.gf, out); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - EXPECT_EQ(y->type, GGML_TYPE_F32); - expect_all_close((float *)y->data, y_data, ggml_nelements(y), 5e-3); + EXPECT_EQ(out->type, GGML_TYPE_F32); + expect_all_close(ref, out, 5e-3); } } TEST_F(ChatGLMTest, LayerNorm) { - float w_data[]{1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986, 0.4033, 0.8380, -0.7193}; - float b_data[]{-0.4033, -0.5966, 0.1820, -0.8567, 1.1006, -1.0712, 0.1227, -0.5663, 0.3731}; - float x_data[]{0.4397, 0.1124, 0.5433, -0.3952, 0.2055, -0.4503, -0.5731, -0.5554, 0.5943, - 1.5419, 1.8197, -0.5515, -1.3253, 0.1886, -0.0691, -0.4949, -1.4959, -0.1938}; - float y_data[]{1.1039, -0.6742, -2.4414, -1.3358, 0.5938, 0.2759, -0.3738, -1.5655, -0.5730, - 1.9137, -1.1141, 1.1753, -1.5275, 0.8437, -1.0652, -0.0398, -1.6891, 0.4602}; - - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 9, 2); - memcpy(x->data, x_data, sizeof(x_data)); - LayerNorm model(&ictx, 9); - memcpy(model.weight->data, w_data, sizeof(w_data)); - memcpy(model.bias->data, b_data, sizeof(b_data)); - ggml_tensor *y = model.forward(&ctx, x); + fs::path test_path = fs::path(__FILE__).parent_path() / "tests/data/layer_norm.data"; + MappedFile mapped_file(test_path.string()); + char *ptr = mapped_file.data; - ggml_build_forward_expand(&ctx.gf, y); + LayerNorm model(&ictx, 64); + ptr = map_tensor_data(ptr, model.weight); + ptr = map_tensor_data(ptr, model.bias); + + ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 64, 3); + ptr = map_tensor_data(ptr, x); + x = ggml_dup(ctx.gctx.get(), x); + + ggml_tensor *ref = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 64, 3); + ptr = map_tensor_data(ptr, ref); + + ASSERT_EQ(ptr, mapped_file.data + mapped_file.size); + + ggml_tensor *out = model.forward(&ctx, x); + ggml_build_forward_expand(&ctx.gf, out); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref, out); } TEST_F(ChatGLMTest, RMSNorm) { - float w_data[]{0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901}; - float x_data[]{0.1642, 0.3058, 0.2100, 0.9056, 0.6035, 0.8110, -0.0451, - 0.8797, 1.0482, -0.0445, -0.7229, 2.8663, -0.5655, 0.1604}; - float y_data[]{0.1521, 0.4385, 0.0347, 0.2232, 0.3464, 0.9599, -0.0412, - 0.3489, 0.6436, -0.0031, -0.0763, 0.7043, -0.2866, 0.0628}; + fs::path test_path = fs::path(__FILE__).parent_path() / "tests/data/rms_norm.data"; + MappedFile mapped_file(test_path.string()); + char *ptr = mapped_file.data; - RMSNorm model(&ictx, 7); - memcpy(model.weight->data, 
w_data, sizeof(w_data)); + RMSNorm model(&ictx, 64); + ptr = map_tensor_data(ptr, model.weight); - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 7, 2); - memcpy(x->data, x_data, sizeof(x_data)); - ggml_tensor *y = model.forward(&ctx, x); + ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 64, 3); + ptr = map_tensor_data(ptr, x); + x = ggml_dup(ctx.gctx.get(), x); + + ggml_tensor *ref = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 64, 3); + ptr = map_tensor_data(ptr, ref); - ggml_build_forward_expand(&ctx.gf, y); + ASSERT_EQ(ptr, mapped_file.data + mapped_file.size); + + ggml_tensor *out = model.forward(&ctx, x); + + ggml_build_forward_expand(&ctx.gf, out); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref, out); } -TEST_F(ChatGLMTest, GLMSelfAttention) { - float query_key_value_weight_data[]{ - -1.8717e-03, 1.3411e-01, -2.0576e-01, -1.8398e-01, -9.6289e-02, 6.7039e-02, -4.9533e-03, 1.9822e-01, - -2.2186e-02, 6.6153e-02, -7.5553e-02, -4.9141e-02, -2.3884e-01, -1.6557e-01, -1.0306e-01, 9.2609e-03, - 9.8834e-02, 1.5001e-01, -1.6949e-01, -1.0887e-01, 9.0804e-02, 2.0760e-01, -5.1450e-02, 1.8708e-01, - -4.0296e-02, 2.6454e-02, 2.2637e-01, -2.3192e-01, -1.5738e-01, -6.3291e-02, -9.7450e-02, 2.1600e-01, - -1.6204e-01, -1.1508e-01, -1.7466e-01, -2.3414e-01, -1.4594e-01, 2.1490e-01, 1.1155e-01, 1.2117e-01, - 1.3148e-02, -1.2817e-01, 4.2296e-02, -2.3342e-01, -1.8064e-01, -1.2888e-01, 1.5773e-01, 1.4658e-01, - -1.1087e-01, -9.0206e-03, 1.5989e-01, 2.4853e-01, 9.9221e-02, 3.3773e-02, 1.6762e-01, -1.4720e-01, - 4.6586e-02, -1.9383e-01, -1.7327e-01, -1.2915e-01, 1.1312e-01, 1.0054e-01, -1.4809e-01, 7.5527e-02, - 1.3724e-01, -3.1554e-02, 9.5454e-03, 5.7926e-02, 1.5509e-01, 2.4005e-01, -1.9266e-01, -9.1617e-02, - 9.8252e-02, 2.0714e-01, 2.1755e-01, 2.2059e-01, 4.9754e-02, -2.1740e-01, 2.2998e-02, -1.5640e-01, - -2.3299e-01, 2.2212e-01, 1.9009e-01, -2.4938e-01, 4.6793e-02, -4.2115e-02, -4.1140e-02, -1.1444e-01, - 9.6139e-02, -1.4808e-01, 9.1648e-02, 1.2643e-01, 1.7897e-01, 9.3478e-02, -2.4743e-01, -1.6217e-01, - 1.2483e-01, 5.2325e-02, -1.9502e-01, -1.4395e-01, 2.3519e-01, 1.6845e-01, -1.0901e-01, -6.2921e-02, - -2.3815e-01, -4.4935e-03, -1.8826e-01, -1.9284e-01, -1.3775e-02, 3.7536e-02, -1.0238e-01, 1.4834e-01, - -1.5213e-01, 2.2684e-01, 1.7132e-01, -2.1082e-01, -6.2221e-02, 1.1281e-02, 3.6475e-02, 5.9294e-02, - 9.8107e-02, 1.4975e-02, -1.2198e-01, 1.1830e-01, -2.3981e-01, -1.4818e-01, -6.2582e-02, -1.2178e-01, - -8.7458e-02, -2.0491e-01, -5.3179e-02, 5.3439e-02, -1.6287e-01, -1.2830e-02, 1.7896e-01, -2.5700e-02, - 6.9481e-03, -2.1567e-02, 5.0595e-02, 1.5896e-01, 2.3681e-01, 1.5876e-01, 2.3735e-01, -1.8080e-02, - -2.2458e-01, -1.1852e-01, 1.7023e-01, -1.6206e-03, -1.2426e-01, -1.9158e-01, -2.3396e-01, -2.1100e-01, - -5.0709e-02, 1.3710e-01, 1.3516e-01, -2.4111e-01, 1.5595e-01, -1.9563e-01, -5.2853e-02, -1.0137e-01, - -4.8154e-02, -4.9086e-02, -2.2434e-01, -2.1586e-01, -3.9120e-02, 3.2330e-03, -1.1357e-01, 9.4175e-02, - -2.2501e-01, -1.6872e-02, 2.1985e-01, -1.0197e-01, 2.2575e-01, 9.0538e-02, -2.2562e-01, 1.5817e-01, - -2.8849e-02, -1.1160e-01, 1.9991e-01, -2.0202e-01, 2.6826e-02, -5.2342e-02, 1.7853e-01, 6.9786e-02, - 1.2013e-01, 8.8290e-02, -6.0119e-02, -5.2576e-02, -2.0602e-01, 1.3546e-01, 1.9849e-01, 1.7106e-01, - -1.7634e-01, 1.1150e-02, -1.7623e-01, -1.3762e-01, -1.4568e-01, 8.5436e-02, -1.4898e-01, -5.4543e-03, - 1.0517e-02, 1.6116e-01, -1.8898e-01, -1.7163e-01, -1.4517e-01, 
1.7498e-01, -8.9866e-02, 2.1087e-01, - 9.0402e-02, 3.1657e-02, -1.8610e-03, -4.9420e-02, 3.1367e-02, -5.7086e-02, -1.7566e-03, 3.1898e-02, - -1.9555e-01, -1.3103e-01, 2.0187e-01, -2.0289e-01, -1.7952e-02, 2.4731e-01, 9.0309e-02, 7.0783e-03, - -2.1665e-01, 1.2384e-01, -1.7807e-01, -7.0966e-02, -8.3879e-02, -3.7022e-02, 2.7346e-03, 2.0620e-01, - 3.1210e-02, 2.2392e-01, 1.5293e-01, -1.5805e-01, 1.1213e-01, -1.7672e-01, -1.0596e-01, 7.3531e-02, - 8.2548e-02, 1.8756e-01, -8.0479e-02, 4.0022e-04, 1.2871e-01, -2.4177e-01, 1.8075e-01, -2.0673e-01, - 3.4456e-03, -4.2504e-02, -1.3167e-01, 3.3043e-02, 2.0673e-01, -7.3080e-02, -1.4842e-01, -9.2460e-02, - -2.4779e-01, 1.1285e-01, -1.2007e-01, -1.6684e-01, -1.4403e-01, 1.4374e-01, 1.3239e-01, 1.9188e-01, - 9.0681e-02, -8.3490e-02, -6.9862e-02, 7.3857e-02, 2.0551e-01, 6.7973e-02, -1.1829e-01, -1.1752e-01, - -2.3637e-01, 5.4012e-02, -1.4030e-01, -2.2289e-01, 2.1922e-01, -1.6235e-01, -2.8440e-02, 7.1623e-02, - 7.9646e-03, -1.6822e-01, -2.0208e-01, 1.9927e-01, 4.0710e-02, 2.0741e-01, -8.3810e-02, 7.3639e-02, - -5.7169e-02, -1.1118e-02, -1.5226e-01, 8.4550e-02, 7.9041e-02, -5.1571e-03, -5.6226e-02, -1.5411e-01, - 1.7289e-01, -1.8611e-01, 1.0242e-01, -8.4063e-02, -1.2062e-01, 4.4911e-02, -1.2986e-01, 5.7601e-02, - 4.9095e-02, -1.8562e-01, 4.1625e-02, 1.0648e-01, 9.8958e-02, -3.1470e-02, -2.0495e-01, -3.8539e-02, - 8.6827e-02, -9.1219e-02, 9.4895e-02, 1.6649e-01, -1.3054e-01, 2.4655e-03, 1.0337e-01, 1.9594e-02, - 2.0883e-02, 3.1216e-02, -1.9654e-01, 1.9647e-02, 1.7312e-01, 2.2528e-01, 1.4694e-01, 3.3501e-02, - 1.1675e-01, -1.2162e-01, -2.0717e-01, -2.1498e-01, 2.4940e-01, 1.5870e-01, -1.7281e-01, 9.7811e-02, - 1.8879e-01, 2.4990e-01, 2.1860e-01, 1.9369e-01, -5.7315e-02, -8.7735e-02, 2.0526e-01, 1.4009e-01, - -1.5044e-01, 2.2475e-01, 1.2079e-01, 1.3628e-01, -1.5669e-01, 7.1725e-02, -8.7641e-02, 1.9534e-01, - -4.4985e-02, 9.7328e-02, 4.4405e-02, 1.0637e-01, -8.4957e-02, 1.2189e-01, -1.7462e-01, 5.6464e-02, - -1.6915e-01, -2.4663e-01, -2.0076e-01, 1.9737e-01, 1.3526e-01, 2.3455e-01, 2.0028e-01, -2.2326e-01, - -1.7061e-01, -4.0396e-02, -1.6236e-01, 1.7360e-01, -1.8900e-01, -1.2198e-01, -2.4152e-01, -1.4194e-01, - 2.0562e-01, 2.0469e-01, 1.7896e-01, 1.9303e-01, 2.2230e-01, -6.4016e-02, 1.1000e-01, 2.2273e-01, - 8.2705e-02, 2.4992e-01, 1.2967e-01, 1.5541e-01, -8.7498e-02, 1.1996e-01, 2.8727e-02, -5.9704e-02, - -1.4093e-01, -1.4028e-01, -1.9237e-01, 1.6783e-01, 1.7773e-01, -2.8453e-02, -1.4467e-01, 1.9323e-01, - 1.5987e-01, 1.8584e-02, -1.1803e-01, 2.2976e-01, 1.0224e-01, -1.8979e-01, 2.3927e-01, 1.8984e-01, - -9.1123e-02, 1.4054e-01, -1.4205e-01, -3.9179e-02, 2.1228e-01, 1.0330e-02, -1.7680e-01, -8.3558e-02, - -6.7861e-02, -4.8232e-02, 2.3925e-02, 2.3121e-01, 1.3385e-02, -1.5436e-01, 1.2814e-02, 1.1987e-01, - 1.2401e-01, -2.2848e-01, -4.4736e-02, -1.8579e-01, -1.0667e-01, 9.0073e-02, -1.7753e-01, 9.2932e-02, - 2.1219e-01, 1.6397e-02, -1.6662e-01, -8.9571e-02, 5.4591e-02, -1.9058e-01, 1.2420e-01, -2.2697e-01, - -2.4032e-01, -2.4292e-01, -5.0716e-02, 1.6811e-01, -2.3662e-01, 2.0780e-01, -1.0001e-01, 7.3221e-02, - 1.1401e-02, -2.2543e-01, 2.0733e-01, 1.3461e-01, 2.4850e-01, 1.2630e-01, -1.6502e-01, 2.0865e-01, - 1.3436e-02, 1.1855e-01, -2.0046e-01, -7.1907e-02, -2.4547e-01, -9.7373e-02, 5.3933e-02, -1.9629e-01, - 7.9691e-02, 1.3420e-01, 3.4827e-02, -1.6727e-01, -1.9383e-01, -7.7129e-02, 1.0974e-01, 2.4660e-01, - 1.4376e-01, -2.8152e-02, 8.7654e-02, -2.4527e-01, -2.1353e-01, 1.1665e-01, -1.4160e-01, 1.2027e-01, - -1.7648e-01, -1.2383e-01, -2.0592e-01, 1.3046e-01, 
-2.5474e-02, 1.9240e-01, 1.5472e-01, 1.3834e-01, - 8.0390e-03, -7.7294e-02, -5.4358e-02, 3.3225e-02, 1.2393e-01, -1.7515e-01, 2.0982e-01, -2.7183e-02, - -2.0949e-01, -1.3526e-01, 2.2120e-01, 2.2863e-01, -2.3157e-01, 1.7632e-01, 1.2529e-01, 1.4798e-01, - 2.1163e-01, -1.3474e-01, 7.8944e-02, 1.0231e-01, -7.3873e-02, 8.3663e-02, -7.1928e-02, 1.5457e-01, - -6.9364e-02, -9.3199e-02, 6.2937e-02, 8.8673e-02, -1.2214e-01, 2.2100e-02, 1.4488e-01, -2.4874e-02, - 7.6083e-02, -6.0297e-02, 8.7625e-02, -1.8110e-01, -1.4701e-01, -1.2690e-01, 2.2975e-01, -6.7271e-02, - -6.8259e-04, -1.2112e-01, 2.4957e-01, 2.4417e-01, -1.8855e-01, -2.0267e-01, -1.8950e-01, -1.2057e-03, - -6.3726e-02, -1.6364e-01, -8.9668e-02, 4.7233e-02, -1.3062e-01, 5.5396e-02, -5.7329e-02, -1.2114e-01, - 3.4346e-02, 2.0556e-01, -1.6902e-01, 1.1609e-02, -9.2192e-02, 2.4533e-01, -2.3719e-01, -2.3967e-01, - 2.4635e-01, -1.5817e-01, 4.7931e-02, -2.1577e-02, -5.2664e-02, -5.5837e-02, 1.5886e-01, 1.1949e-02, - -2.4340e-01, -1.4759e-01, -8.5225e-02, 1.2580e-01, -1.6178e-01, 2.3573e-01, -5.5682e-02, -4.4890e-02, - 1.9589e-01, 1.2565e-01, 2.1203e-01, 1.4461e-01, -7.5838e-02, -1.6587e-01, -1.8601e-02, 2.0691e-01, - -8.3905e-02, -2.3184e-01, 1.0248e-01, 2.4337e-01, -7.1171e-02, -2.0701e-01, -2.2676e-01, 6.2650e-02, - -1.8930e-02, -1.2625e-01, 5.0533e-02, 9.4940e-02, 1.9883e-01, 1.9410e-01, -3.7421e-02, -2.2044e-01, - -2.2591e-01, 2.3342e-01, 1.1051e-01, 1.0898e-01, -2.1631e-01, 2.3150e-01, 2.3684e-01, 2.2572e-01, - -2.1090e-01, -9.4331e-02, -1.7194e-01, 2.3674e-01, -1.0742e-01, -1.1414e-01, 1.3098e-01, -1.1565e-01, - -1.2313e-01, -2.1868e-02, -2.4030e-02, -1.9474e-01, 2.0842e-01, -1.1028e-01, 8.8679e-02, 2.1745e-01, - 1.2608e-01, 3.5395e-02, 2.1271e-01, 3.3607e-02, -1.1566e-01, 2.3650e-01, 5.9170e-02, -2.4392e-01, - -7.1173e-02, -1.7029e-01, 2.1919e-01, -4.1286e-02, -2.2788e-01, -1.5716e-02, 1.5700e-01, 6.4950e-02, - 7.9051e-02, 2.3218e-02, 9.3205e-02, -6.0913e-02, -9.9450e-02, -2.3368e-01, -1.8833e-01, 1.0835e-01, - -1.4803e-01, 3.5867e-02, 7.9798e-02, 1.7701e-02, -1.6209e-01, 2.3908e-01, -1.4537e-01, 2.0563e-01, - -1.9888e-01, -6.0137e-02, 1.3599e-01, -1.0215e-01, 2.1001e-01, -1.7204e-01, -2.0994e-01, -1.1273e-01, - 4.0413e-02, 2.3019e-01, -1.1935e-01, 8.9407e-02, -6.2677e-02, -5.4216e-02, 1.8383e-01, -1.9375e-01, - 2.6553e-02, 2.3510e-01, -3.4353e-02, 1.9411e-01, -7.6989e-02, 2.0124e-01, -2.4183e-01, -3.6031e-02, - -4.3912e-02, 8.1021e-02, 9.8082e-02, 1.9195e-01, -3.7246e-02, -9.9138e-03, 1.7120e-01, -6.7643e-02, - 2.1916e-01, -1.6645e-01, -2.7052e-02, -1.3428e-02, 1.1155e-01, 1.7092e-01, -3.9626e-02, -2.0712e-01, - 1.2387e-01, 7.4783e-02, 1.0042e-01, -1.5422e-01, 1.6090e-01, 2.3679e-01, 2.1693e-02, -2.3351e-01, - 1.7550e-01, -1.8536e-01, 5.7468e-02, 3.6329e-02, -1.1701e-01, 8.7049e-02, -2.2361e-01, 5.6940e-02, - -1.5849e-01, -2.7033e-02, 3.2157e-02, 2.1299e-01, -1.1928e-01, 1.6016e-01, -3.1758e-02, -1.1873e-01, - -2.1770e-01, -2.2938e-01, 2.4415e-01, -6.2350e-02, 1.2487e-02, 6.7778e-02, 1.6993e-01, 2.1337e-01, - 2.0275e-01, -1.8522e-01, -4.0054e-02, -1.4793e-01, -1.4284e-01, 5.9302e-02, 2.3466e-01, -2.0028e-01, - 1.5130e-01, -1.2962e-01, -4.8694e-02, 1.9844e-01, -5.6543e-02, 2.2764e-02, -1.7476e-01, 2.1281e-01, - -3.2317e-02, -1.8285e-01, 7.3203e-02, -1.7775e-01, -1.9838e-01, 1.5230e-02, 1.9821e-01, -7.0746e-02, - 1.1767e-01, 2.1483e-01, 1.6582e-01, -1.3114e-01, -2.7405e-02, -7.8671e-02, -2.0103e-01, 1.0329e-04, - 1.8811e-01, 2.1063e-01, 2.3325e-02, 5.6781e-02, -1.0823e-01, 1.8871e-01, -1.0403e-01, -1.7366e-01, - 3.8498e-02, 1.4985e-01, 
-2.2540e-01, 2.2599e-01, 8.9933e-02, -1.7515e-01, -5.3852e-02, 2.1689e-01}; - float query_key_value_bias_data[]{ - -1.9181e-01, -7.3069e-02, 8.1995e-02, -2.1902e-01, 1.3705e-01, 1.3014e-01, 1.5505e-01, -1.5939e-01, - 2.4900e-01, -1.4819e-01, 2.4959e-01, -2.3992e-01, -2.2274e-01, 1.5355e-01, 2.6128e-02, 1.4422e-02, - -1.3844e-01, -1.0487e-01, -7.3097e-02, -2.4354e-01, 1.2992e-02, 4.4215e-02, -2.1166e-04, 8.0733e-02, - 2.3721e-01, 6.6471e-02, -9.1518e-02, -1.0289e-01, -1.5995e-01, -1.7330e-01, -4.0262e-02, -4.4216e-02, - 1.1122e-01, -1.0686e-01, 1.9930e-01, -1.7542e-01, 7.1055e-04, 2.2473e-01, 2.4860e-01, -1.4482e-01, - 4.4532e-02, 2.9531e-02, -1.1721e-01, -8.6374e-02, 6.7713e-02, -1.7384e-01, 4.1249e-02, 1.0818e-01}; - float dense_weight_data[]{ - -0.0985, 0.2077, -0.0165, 0.1134, 0.2476, -0.0764, 0.1359, -0.0715, -0.0365, -0.0424, -0.0016, -0.0944, - 0.0586, 0.0094, 0.1585, -0.0506, 0.0251, -0.0930, -0.2094, 0.1012, 0.0320, -0.1001, -0.0845, 0.0654, - -0.0452, 0.1634, 0.0142, 0.0944, 0.1089, -0.0613, 0.1082, 0.1845, 0.0115, 0.0489, 0.0091, 0.1731, - -0.1055, -0.1329, 0.1089, -0.2176, 0.0040, -0.1150, 0.1650, -0.2328, 0.1516, 0.2473, 0.0506, -0.0167, - 0.2461, -0.1068, -0.0200, -0.1095, -0.0345, 0.0543, 0.0334, -0.0472, -0.2500, 0.0128, -0.0074, 0.0376, - 0.1916, 0.2430, -0.1483, -0.0156, -0.1002, -0.2304, -0.1819, 0.2346, 0.1458, 0.1343, 0.1389, -0.1986, - 0.0263, 0.2327, -0.1395, 0.2224, 0.0657, 0.1759, -0.1071, 0.1153, -0.2216, -0.0159, 0.0834, 0.0750, - 0.2092, 0.2457, 0.2274, 0.1678, 0.1758, -0.0323, -0.1940, -0.0928, -0.0281, -0.1351, 0.1278, 0.0878, - 0.0403, 0.0604, 0.2196, 0.0911, -0.2192, -0.1815, 0.1102, 0.0341, 0.1219, -0.2497, -0.2307, 0.1533, - 0.1599, -0.2262, 0.0949, -0.1952, 0.1893, 0.0784, 0.2472, -0.2465, -0.2407, 0.1690, 0.2220, 0.1532, - 0.1662, -0.2106, 0.1810, -0.2360, 0.0559, -0.1638, -0.0993, 0.0443, -0.0983, -0.0396, -0.2370, -0.2357, - 0.0968, -0.1058, -0.1521, -0.0986, -0.2044, -0.1999, 0.1496, 0.1271, 0.0333, -0.1366, 0.0832, 0.0112, - -0.1787, 0.0538, 0.2276, -0.2459, -0.2486, 0.0320, 0.1883, -0.1521, 0.0550, 0.1757, 0.0771, 0.0541, - 0.2494, 0.1629, -0.1406, 0.0216, 0.1277, 0.2205, -0.2487, 0.1354, -0.0726, -0.2448, 0.0555, -0.1361, - 0.0354, -0.1623, -0.1881, -0.0212, -0.0840, 0.1462, 0.0216, 0.1951, 0.0469, -0.0804, 0.1693, 0.0137, - 0.0469, -0.0804, 0.0009, -0.0481, -0.0213, -0.1950, -0.0015, -0.2075, -0.0538, 0.1449, -0.1738, -0.1685, - -0.0610, -0.0685, 0.0423, 0.0415, 0.1268, -0.1722, -0.0176, 0.1398, 0.2162, -0.0182, -0.1447, 0.0719, - 0.1424, -0.1562, -0.1451, 0.1105, -0.0175, -0.2361, -0.1441, 0.1014, -0.0848, 0.1726, 0.1976, 0.0364, - -0.0198, -0.0794, -0.0126, 0.0455, -0.1910, -0.0597, -0.2080, 0.1534, -0.1592, 0.2284, -0.0644, -0.1432, - 0.1201, 0.0373, 0.1731, 0.1044, -0.2408, 0.1581, -0.0471, -0.1105, 0.1588, 0.1823, -0.2197, -0.0226, - 0.2053, 0.0968, 0.2106, -0.0857, -0.1379, 0.2150, 0.1042, 0.2400, -0.1044, -0.1605, -0.0293, -0.2354, - 0.0980, 0.1844, 0.0600, -0.0247}; - float dense_bias_data[]{0.1240, -0.1587, 0.2446, -0.2486, -0.2395, -0.0591, 0.2042, 0.0250, - 0.0960, -0.1833, 0.0912, -0.0279, 0.1002, 0.1766, 0.1087, -0.0213}; - - GLMSelfAttention model(&ictx, 16, 2, 16); - memcpy(model.query_key_value.weight->data, query_key_value_weight_data, sizeof(query_key_value_weight_data)); - memcpy(model.query_key_value.bias->data, query_key_value_bias_data, sizeof(query_key_value_bias_data)); - memcpy(model.dense.weight->data, dense_weight_data, sizeof(dense_weight_data)); - memcpy(model.dense.bias->data, dense_bias_data, sizeof(dense_bias_data)); 
+TEST_F(ChatGLMTest, GLMBlock) { + fs::path test_path = fs::path(__FILE__).parent_path() / "tests/data/glm_block.data"; + MappedFile mapped_file(test_path.string()); + char *ptr = mapped_file.data; + + constexpr int hidden_size = 32; + constexpr int num_attention_heads = 8; + constexpr int num_hidden_layers = 28; + constexpr int max_length = 16; + constexpr int seq_len = 4; + GLMBlock model(&ictx, hidden_size, num_attention_heads, num_hidden_layers, max_length); + + ptr = map_tensor_data(ptr, model.input_layernorm.weight); + ptr = map_tensor_data(ptr, model.input_layernorm.bias); + ptr = map_tensor_data(ptr, model.attention.query_key_value.weight); + ptr = map_tensor_data(ptr, model.attention.query_key_value.bias); + ptr = map_tensor_data(ptr, model.attention.dense.weight); + ptr = map_tensor_data(ptr, model.attention.dense.bias); + ptr = map_tensor_data(ptr, model.post_attention_layernorm.weight); + ptr = map_tensor_data(ptr, model.post_attention_layernorm.bias); + ptr = map_tensor_data(ptr, model.mlp.dense_h_to_4h.weight); + ptr = map_tensor_data(ptr, model.mlp.dense_h_to_4h.bias); + ptr = map_tensor_data(ptr, model.mlp.dense_4h_to_h.weight); + ptr = map_tensor_data(ptr, model.mlp.dense_4h_to_h.bias); + + ggml_tensor *x1 = ggml_new_tensor_2d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size, seq_len); + ptr = map_tensor_data(ptr, x1); + x1 = ggml_dup(ictx.gctx.get(), x1); + + ggml_tensor *ref_y1 = ggml_new_tensor_2d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size, seq_len); + ptr = map_tensor_data(ptr, ref_y1); + + ggml_tensor *x2 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, x2); + x2 = ggml_dup(ictx.gctx.get(), x2); + + ggml_tensor *ref_y2 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, ref_y2); + + ggml_tensor *x3 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, x3); + x3 = ggml_dup(ictx.gctx.get(), x3); + + ggml_tensor *ref_y3 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, ref_y3); + + ASSERT_EQ(ptr, mapped_file.data + mapped_file.size); // self attention { - float x_data[]{-1.1230, 0.6210, -0.8764, -1.7740, -0.8104, -1.4760, 0.6670, 0.6688, -0.0752, 0.1643, - 0.0275, 0.5813, 0.2350, 1.0717, -0.5229, -0.7213, -0.2058, -0.8088, 0.4647, 0.3915, - -1.3406, 0.7596, 0.7288, -0.0667, 0.5559, -1.4076, 0.1106, 0.0352, -0.5646, 0.9159, - 0.2990, 0.1771, -0.3439, 0.2291, 0.4366, 0.2702, -0.2002, 1.0184, -0.4169, -0.0308, - -0.9846, 0.2024, 0.4156, -1.2037, -0.1333, -0.2934, 0.4649, 0.5762, -0.0578, 1.1715, - -1.5185, 0.0475, -0.0103, 1.4353, 0.2518, 0.7408, -1.0982, 0.1548, -1.3989, 0.0408, - -1.1174, 0.7029, -0.6675, 0.1254}; - float y_data[]{ - 1.6484e-01, -1.9052e-01, 2.8198e-01, -1.7810e-01, -4.1363e-01, -3.2469e-02, 3.9480e-01, 3.2204e-01, - 1.3954e-01, -4.9466e-02, 5.7194e-02, -5.4678e-02, 1.5006e-01, -2.2709e-04, 1.1909e-01, 1.2991e-01, - 1.4828e-01, -1.7710e-01, 3.0013e-01, -1.6764e-01, -3.2587e-01, -2.7857e-02, 3.3097e-01, 2.3543e-01, - 1.5086e-01, -3.0025e-02, -1.9906e-02, -3.6032e-02, 1.2470e-01, 4.0979e-02, 1.1201e-01, 8.7082e-02, - 1.2724e-01, -1.7805e-01, 2.5121e-01, -1.9400e-01, -3.5495e-01, -3.3908e-02, 2.7882e-01, 2.1870e-01, - 1.6772e-01, -6.0175e-02, -3.4353e-02, -5.1714e-02, 9.2530e-02, 4.2405e-02, 1.1574e-01, 7.0774e-02, - 1.5804e-01, -1.7552e-01, 2.2854e-01, -1.2723e-01, -2.5867e-01, -3.2931e-02, 2.1397e-01, 2.2805e-01, - 1.5255e-01, 1.1970e-02, 4.6676e-02, -9.0893e-02, 1.3114e-01, 
5.0670e-02, 2.1247e-01, 4.2616e-02}; - - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 16, 4); - memcpy(x->data, x_data, sizeof(x_data)); - ggml_tensor *y = model.forward(&ctx, x, 0, 4); - - ggml_build_forward_expand(&ctx.gf, y); + ggml_tensor *out_y1 = model.forward(&ctx, x1, 0, seq_len); + + ggml_build_forward_expand(&ctx.gf, out_y1); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref_y1, out_y1, 5e-4); } // cross attention { - ctx.gf = {}; - ctx.gf.n_threads = 1; - - float x_data[]{0.4047, -0.6549, 0.0521, 0.3401, -0.2124, 1.5629, -0.9072, -1.5662, - 0.0485, 0.9935, 2.1180, -0.1827, -0.7768, 1.7764, -0.5033, 0.0566}; - float y_data[]{0.1011, -0.1058, 0.1434, -0.1980, -0.1304, 0.0632, 0.2109, 0.0619, - 0.0893, 0.0146, -0.1291, -0.1430, 0.1655, 0.0916, 0.1488, -0.0282}; - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 16, 1); - memcpy(x->data, x_data, sizeof(x_data)); + ggml_tensor *out_y2 = model.forward(&ctx, x2, seq_len, seq_len); - ggml_tensor *y = model.forward(&ctx, x, 4, 4); - - ggml_build_forward_expand(&ctx.gf, y); + ggml_build_forward_expand(&ctx.gf, out_y2); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref_y2, out_y2, 5e-4); } { - ctx.gf = {}; - ctx.gf.n_threads = 1; - - float x_data[]{-0.5481, -0.6917, -0.6559, 0.6949, -0.4671, -0.1680, 0.7585, 1.1881, - -0.6305, -0.0654, 0.6188, 2.0020, 0.2952, 0.5314, -0.5227, 0.0995}; - float y_data[]{0.1479, -0.1224, 0.2174, -0.1259, -0.1662, 0.0400, 0.2235, 0.1599, - 0.0633, 0.0484, -0.0541, -0.1475, 0.1395, 0.0448, 0.2216, 0.0838}; - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 16, 1); - memcpy(x->data, x_data, sizeof(x_data)); + ggml_tensor *out_y3 = model.forward(&ctx, x3, seq_len + 1, seq_len); - ggml_tensor *y = model.forward(&ctx, x, 5, 4); - - ggml_build_forward_expand(&ctx.gf, y); + ggml_build_forward_expand(&ctx.gf, out_y3); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref_y3, out_y3, 5e-4); } } -TEST_F(ChatGLMTest, GLMBlock) { - float x_data[]{-1.0159, -1.5296, -1.2180, 0.8683, -0.3850, 0.0964, -0.3182, 1.8511, -0.1479, -1.0800, -0.1396, - -1.2574, 1.5435, 0.9686, -1.5552, 0.7989, 1.1896, -0.1795, -1.8232, -0.0514, -1.6361, -2.5106, - -1.2383, -0.9296, -0.0585, 0.7729, 0.4689, 0.2599, 0.7576, -2.0418, -0.5524, 0.6270}; - float query_key_value_weight_data[]{ - -0.0026, 0.1897, -0.2910, -0.2602, -0.1362, 0.0948, -0.0070, 0.2803, -0.0314, 0.0936, -0.1068, -0.0695, - -0.3378, -0.2342, -0.1457, 0.0131, 0.1398, 0.2121, -0.2397, -0.1540, 0.1284, 0.2936, -0.0728, 0.2646, - -0.0570, 0.0374, 0.3201, -0.3280, -0.2226, -0.0895, -0.1378, 0.3055, -0.2292, -0.1628, -0.2470, -0.3311, - -0.2064, 0.3039, 0.1578, 0.1714, 0.0186, -0.1813, 0.0598, -0.3301, -0.2555, -0.1823, 0.2231, 0.2073, - -0.1568, -0.0128, 0.2261, 0.3515, 0.1403, 0.0478, 0.2371, -0.2082, 0.0659, -0.2741, -0.2450, -0.1826, - 0.1600, 0.1422, -0.2094, 0.1068, 0.1941, -0.0446, 0.0135, 0.0819, 0.2193, 0.3395, -0.2725, -0.1296, - 0.1389, 0.2929, 0.3077, 0.3120, 0.0704, -0.3074, 0.0325, -0.2212, -0.3295, 0.3141, 0.2688, -0.3527, - 0.0662, -0.0596, -0.0582, -0.1618, 0.1360, -0.2094, 0.1296, 0.1788, 0.2531, 0.1322, -0.3499, -0.2293, - 0.1765, 0.0740, -0.2758, -0.2036, 0.3326, 0.2382, -0.1542, -0.0890, -0.3368, -0.0064, -0.2662, -0.2727, - -0.0195, 0.0531, -0.1448, 0.2098, -0.2152, 
0.3208, 0.2423, -0.2981, -0.0880, 0.0160, 0.0516, 0.0839, - 0.1387, 0.0212, -0.1725, 0.1673, -0.3391, -0.2096, -0.0885, -0.1722, -0.1237, -0.2898, -0.0752, 0.0756, - -0.2303, -0.0181, 0.2531, -0.0363, 0.0098, -0.0305, 0.0716, 0.2248, 0.3349, 0.2245, 0.3357, -0.0256, - -0.3176, -0.1676, 0.2407, -0.0023, -0.1757, -0.2709, -0.3309, -0.2984, -0.0717, 0.1939, 0.1911, -0.3410, - 0.2205, -0.2767, -0.0747, -0.1434, -0.0681, -0.0694, -0.3173, -0.3053, -0.0553, 0.0046, -0.1606, 0.1332, - -0.3182, -0.0239, 0.3109, -0.1442, 0.3193, 0.1280, -0.3191, 0.2237, -0.0408, -0.1578, 0.2827, -0.2857, - 0.0379, -0.0740, 0.2525, 0.0987, 0.1699, 0.1249, -0.0850, -0.0744, -0.2914, 0.1916, 0.2807, 0.2419}; - float query_key_value_bias_data[]{-0.2494, 0.0158, -0.2492, -0.1946, -0.2060, 0.1208, -0.2107, -0.0077, - 0.0149, 0.2279, -0.2673, -0.2427, -0.2053, 0.2475, -0.1271, 0.2982, - 0.1278, 0.0448, -0.0026, -0.0699, 0.0444, -0.0807, -0.0025, 0.0451}; - float dense_weight_data[]{-0.2766, -0.1853, 0.2855, -0.2869, -0.0254, 0.3497, 0.1277, 0.0100, -0.3064, 0.1751, - -0.2518, -0.1004, -0.1186, -0.0524, 0.0039, 0.2916, 0.0441, 0.3167, 0.2163, -0.2235, - 0.1586, -0.2499, -0.1498, 0.1040, 0.1167, 0.2652, -0.1138, 0.0006, 0.1820, -0.3419, - 0.2556, -0.2924, 0.0049, -0.0601, -0.1862, 0.0467, 0.2924, -0.1034, -0.2099, -0.1308, - -0.3504, 0.1596, -0.1698, -0.2359, -0.2037, 0.2033, 0.1872, 0.2714, 0.1282, -0.1181, - -0.0988, 0.1045, 0.2906, 0.0961, -0.1673, -0.1662, -0.3343, 0.0764, -0.1984, -0.3152, - 0.3100, -0.2296, -0.0402, 0.1013}; - float dense_bias_data[]{0.0113, -0.2379, -0.2858, 0.2818, 0.0576, 0.2933, -0.1185, 0.1041}; - - float dense_h_to_4h_weight_data[]{ - -0.0808, -0.0157, -0.2153, 0.1196, 0.1118, -0.0073, -0.0795, -0.2179, 0.2445, -0.2632, 0.1448, -0.1189, - -0.1706, 0.0635, -0.1837, 0.0815, 0.0694, -0.2625, 0.0589, 0.1506, 0.1399, -0.0445, -0.2898, -0.0545, - 0.1228, -0.1290, 0.1342, 0.2355, -0.1846, 0.0035, 0.1462, 0.0277, 0.0295, 0.0441, -0.2779, 0.0278, - 0.2448, 0.3186, 0.2078, 0.0474, 0.1651, -0.1720, -0.2930, -0.3040, 0.3527, 0.2244, -0.2444, 0.1383, - 0.2670, 0.3534, 0.3091, 0.2739, -0.0811, -0.1241, 0.2903, 0.1981, -0.2128, 0.3179, 0.1708, 0.1927, - -0.2216, 0.1014, -0.1239, 0.2763, -0.0636, 0.1376, 0.0628, 0.1504, -0.1201, 0.1724, -0.2469, 0.0799, - -0.2392, -0.3488, -0.2839, 0.2791, 0.1913, 0.3317, 0.2832, -0.3157, -0.2413, -0.0571, -0.2296, 0.2455, - -0.2673, -0.1725, -0.3416, -0.2007, 0.2908, 0.2895, 0.2531, 0.2730, 0.3144, -0.0905, 0.1556, 0.3150, - 0.1170, 0.3534, 0.1834, 0.2198, -0.1237, 0.1696, 0.0406, -0.0844, -0.1993, -0.1984, -0.2721, 0.2374, - 0.2514, -0.0402, -0.2046, 0.2733, 0.2261, 0.0263, -0.1669, 0.3249, 0.1446, -0.2684, 0.3384, 0.2685, - -0.1289, 0.1988, -0.2009, -0.0554, 0.3002, 0.0146, -0.2500, -0.1182, -0.0960, -0.0682, 0.0338, 0.3270, - 0.0189, -0.2183, 0.0181, 0.1695, 0.1754, -0.3231, -0.0633, -0.2627, -0.1509, 0.1274, -0.2511, 0.1314, - 0.3001, 0.0232, -0.2356, -0.1267, 0.0772, -0.2695, 0.1756, -0.3210, -0.3399, -0.3435, -0.0717, 0.2377, - -0.3346, 0.2939, -0.1414, 0.1036, 0.0161, -0.3188, 0.2932, 0.1904, 0.3514, 0.1786, -0.2334, 0.2951, - 0.0190, 0.1677, -0.2835, -0.1017, -0.3471, -0.1377, 0.0763, -0.2776, 0.1127, 0.1898, 0.0493, -0.2366, - -0.2741, -0.1091, 0.1552, 0.3487, 0.2033, -0.0398, 0.1240, -0.3469, -0.3020, 0.1650, -0.2003, 0.1701, - -0.2496, -0.1751, -0.2912, 0.1845, -0.0360, 0.2721, 0.2188, 0.1956, 0.0114, -0.1093, -0.0769, 0.0470, - 0.1753, -0.2477, 0.2967, -0.0384, -0.2963, -0.1913, 0.3128, 0.3233, -0.3275, 0.2494, 0.1772, 0.2093, - 0.2993, -0.1905, 0.1116, 0.1447, 
-0.1045, 0.1183, -0.1017, 0.2186, -0.0981, -0.1318, 0.0890, 0.1254, - -0.1727, 0.0313, 0.2049, -0.0352, 0.1076, -0.0853, 0.1239, -0.2561, -0.2079, -0.1795, 0.3249, -0.0951, - -0.0010, -0.1713, 0.3529, 0.3453, -0.2666, -0.2866, -0.2680, -0.0017, -0.0901, -0.2314, -0.1268, 0.0668, - -0.1847, 0.0783, -0.0811, -0.1713}; - float dense_h_to_4h_bias_data[]{0.0486, 0.2907, -0.2390, 0.0164, -0.1304, 0.3469, -0.3354, -0.3389, - 0.3484, -0.2237, 0.0678, -0.0305, -0.0745, -0.0790, 0.2247, 0.0169, - -0.3442, -0.2087, -0.1205, 0.1779, -0.2288, 0.3334, -0.0787, -0.0635, - 0.2770, 0.1777, 0.2999, 0.2045, -0.1073, -0.2346, -0.0263, 0.2926}; - float dense_4h_to_h_weight_data[]{ - -5.9330e-02, -1.6394e-01, 7.2466e-02, 1.7209e-01, -5.0325e-02, -1.4638e-01, -1.6035e-01, 4.4300e-02, - -1.3385e-02, -8.9270e-02, 3.5732e-02, 6.7133e-02, 1.4059e-01, 1.3725e-01, -2.6461e-02, -1.5588e-01, - -1.5974e-01, 1.6505e-01, 7.8145e-02, 7.7058e-02, -1.5295e-01, 1.6370e-01, 1.6747e-01, 1.5960e-01, - -1.4913e-01, -6.6702e-02, -1.2158e-01, 1.6740e-01, -7.5957e-02, -8.0708e-02, 9.2616e-02, -8.1776e-02, - -8.7066e-02, -1.5463e-02, -1.6992e-02, -1.3771e-01, 1.4737e-01, -7.7980e-02, 6.2705e-02, 1.5376e-01, - 8.9151e-02, 2.5028e-02, 1.5041e-01, 2.3763e-02, -8.1787e-02, 1.6723e-01, 4.1840e-02, -1.7248e-01, - -5.0327e-02, -1.2041e-01, 1.5499e-01, -2.9194e-02, -1.6114e-01, -1.1113e-02, 1.1102e-01, 4.5927e-02, - 5.5898e-02, 1.6418e-02, 6.5906e-02, -4.3072e-02, -7.0322e-02, -1.6523e-01, -1.3317e-01, 7.6615e-02, - -1.0467e-01, 2.5362e-02, 5.6426e-02, 1.2516e-02, -1.1461e-01, 1.6906e-01, -1.0280e-01, 1.4540e-01, - -1.4063e-01, -4.2523e-02, 9.6161e-02, -7.2228e-02, 1.4850e-01, -1.2165e-01, -1.4845e-01, -7.9712e-02, - 2.8576e-02, 1.6277e-01, -8.4393e-02, 6.3220e-02, -4.4319e-02, -3.8336e-02, 1.2998e-01, -1.3700e-01, - 1.8776e-02, 1.6624e-01, -2.4291e-02, 1.3726e-01, -5.4440e-02, 1.4230e-01, -1.7100e-01, -2.5478e-02, - -3.1050e-02, 5.7291e-02, 6.9354e-02, 1.3573e-01, -2.6337e-02, -7.0101e-03, 1.2106e-01, -4.7831e-02, - 1.5497e-01, -1.1770e-01, -1.9129e-02, -9.4948e-03, 7.8880e-02, 1.2086e-01, -2.8020e-02, -1.4646e-01, - 8.7591e-02, 5.2880e-02, 7.1011e-02, -1.0905e-01, 1.1377e-01, 1.6743e-01, 1.5339e-02, -1.6512e-01, - 1.2410e-01, -1.3107e-01, 4.0636e-02, 2.5689e-02, -8.2738e-02, 6.1553e-02, -1.5811e-01, 4.0263e-02, - -1.1207e-01, -1.9116e-02, 2.2739e-02, 1.5060e-01, -8.4344e-02, 1.1325e-01, -2.2457e-02, -8.3956e-02, - -1.5394e-01, -1.6220e-01, 1.7264e-01, -4.4088e-02, 8.8298e-03, 4.7926e-02, 1.2016e-01, 1.5087e-01, - 1.4337e-01, -1.3097e-01, -2.8323e-02, -1.0461e-01, -1.0100e-01, 4.1933e-02, 1.6593e-01, -1.4162e-01, - 1.0699e-01, -9.1653e-02, -3.4432e-02, 1.4032e-01, -3.9982e-02, 1.6097e-02, -1.2357e-01, 1.5048e-01, - -2.2852e-02, -1.2929e-01, 5.1762e-02, -1.2569e-01, -1.4027e-01, 1.0769e-02, 1.4016e-01, -5.0025e-02, - 8.3203e-02, 1.5191e-01, 1.1725e-01, -9.2730e-02, -1.9378e-02, -5.5629e-02, -1.4215e-01, 7.3040e-05, - 1.3301e-01, 1.4894e-01, 1.6493e-02, 4.0150e-02, -7.6530e-02, 1.3344e-01, -7.3558e-02, -1.2280e-01, - 2.7222e-02, 1.0596e-01, -1.5938e-01, 1.5980e-01, 6.3592e-02, -1.2385e-01, -3.8079e-02, 1.5337e-01, - -1.3563e-01, -5.1668e-02, 5.7979e-02, -1.5487e-01, 9.6909e-02, 9.2021e-02, 1.0964e-01, -1.1270e-01, - 1.7607e-01, -1.0479e-01, 1.7648e-01, -1.6965e-01, -1.5750e-01, 1.0858e-01, 1.8475e-02, 1.0198e-02, - -9.7892e-02, -7.4153e-02, -5.1687e-02, -1.7221e-01, 9.1869e-03, 3.1265e-02, -1.4966e-04, 5.7087e-02, - 1.6773e-01, 4.7002e-02, -6.4713e-02, -7.2752e-02, -1.1310e-01, -1.2254e-01, -2.8470e-02, -3.1266e-02, - 7.8641e-02, -7.5561e-02, 
1.4093e-01, -1.2404e-01, 5.0243e-04, 1.5891e-01, 1.7578e-01, -1.0240e-01, - 3.1489e-02, 2.0881e-02, -8.2882e-02, -6.1075e-02, 4.7881e-02, -1.2292e-01, 2.9167e-02, 7.6496e-02, - -6.9662e-02, 1.4684e-01, -1.1634e-02, 8.0206e-02, 1.7506e-01, -5.4035e-02, 9.6062e-02, -5.0563e-02, - -2.5822e-02, -2.9959e-02, -1.0982e-03, -6.6781e-02, 4.1434e-02, 6.6610e-03, 1.1206e-01, -3.5781e-02}; - float dense_4h_to_h_bias_data[]{0.0177, -0.0658, -0.1480, 0.0715, 0.0226, -0.0708, -0.0598, 0.0462}; - float y_data[]{-6.1799, -9.0367, -7.6170, 7.9863, -0.7231, 2.1669, -0.6415, 14.4362, -1.0757, -7.2829, -0.6062, - -7.4960, 11.8125, 7.5786, -9.6781, 7.1442, 14.2660, 4.8352, -7.2181, 6.3933, -4.4885, -10.8612, - -2.3747, 0.0401, -0.9771, 6.1194, 2.9300, 2.2652, 6.4303, -17.4357, -5.0949, 5.3586}; - - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, 8, 4); - memcpy(x->data, x_data, sizeof(x_data)); - - GLMBlock model(&ictx, 8, 2, 28, 16); - ggml_set_f32(model.input_layernorm.weight, 1); - ggml_set_f32(model.input_layernorm.bias, 0); - memcpy(model.attention.query_key_value.weight->data, query_key_value_weight_data, - sizeof(query_key_value_weight_data)); - memcpy(model.attention.query_key_value.bias->data, query_key_value_bias_data, sizeof(query_key_value_bias_data)); - memcpy(model.attention.dense.weight->data, dense_weight_data, sizeof(dense_weight_data)); - memcpy(model.attention.dense.bias->data, dense_bias_data, sizeof(dense_bias_data)); - ggml_set_f32(model.post_attention_layernorm.weight, 1); - ggml_set_f32(model.post_attention_layernorm.bias, 0); - memcpy(model.mlp.dense_h_to_4h.weight->data, dense_h_to_4h_weight_data, sizeof(dense_h_to_4h_weight_data)); - memcpy(model.mlp.dense_h_to_4h.bias->data, dense_h_to_4h_bias_data, sizeof(dense_h_to_4h_bias_data)); - memcpy(model.mlp.dense_4h_to_h.weight->data, dense_4h_to_h_weight_data, sizeof(dense_4h_to_h_weight_data)); - memcpy(model.mlp.dense_4h_to_h.bias->data, dense_4h_to_h_bias_data, sizeof(dense_4h_to_h_bias_data)); - - ggml_tensor *y = model.forward(&ctx, x, 0, 4); - - ggml_build_forward_expand(&ctx.gf, y); - ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - - expect_all_close((float *)y->data, y_data, ggml_nelements(y), 5e-4); -} - TEST_F(ChatGLMTest, GLM2Block) { - float input_layernorm_weight[]{0.4692, 0.1864, 0.3191, 0.8249, 0.2995, 0.8105, 0.3017, 0.3836, - 0.5106, 0.0412, 0.4950, 0.4496, 0.4551, 0.4000, 0.8942, 0.8690}; - float attn_qkv_weight[]{ - -0.0019, 0.1341, -0.2058, -0.1840, -0.0963, 0.0670, -0.0050, 0.1982, -0.0222, 0.0662, -0.0756, -0.0491, - -0.2388, -0.1656, -0.1031, 0.0093, 0.0988, 0.1500, -0.1695, -0.1089, 0.0908, 0.2076, -0.0515, 0.1871, - -0.0403, 0.0265, 0.2264, -0.2319, -0.1574, -0.0633, -0.0974, 0.2160, -0.1620, -0.1151, -0.1747, -0.2341, - -0.1459, 0.2149, 0.1116, 0.1212, 0.0131, -0.1282, 0.0423, -0.2334, -0.1806, -0.1289, 0.1577, 0.1466, - -0.1109, -0.0090, 0.1599, 0.2485, 0.0992, 0.0338, 0.1676, -0.1472, 0.0466, -0.1938, -0.1733, -0.1291, - 0.1131, 0.1005, -0.1481, 0.0755, 0.1372, -0.0316, 0.0095, 0.0579, 0.1551, 0.2400, -0.1927, -0.0916, - 0.0983, 0.2071, 0.2176, 0.2206, 0.0498, -0.2174, 0.0230, -0.1564, -0.2330, 0.2221, 0.1901, -0.2494, - 0.0468, -0.0421, -0.0411, -0.1144, 0.0961, -0.1481, 0.0916, 0.1264, 0.1790, 0.0935, -0.2474, -0.1622, - 0.1248, 0.0523, -0.1950, -0.1440, 0.2352, 0.1685, -0.1090, -0.0629, -0.2381, -0.0045, -0.1883, -0.1928, - -0.0138, 0.0375, -0.1024, 0.1483, -0.1521, 0.2268, 0.1713, -0.2108, -0.0622, 0.0113, 0.0365, 0.0593, - 0.0981, 0.0150, -0.1220, 0.1183, -0.2398, -0.1482, -0.0626, 
-0.1218, -0.0875, -0.2049, -0.0532, 0.0534, - -0.1629, -0.0128, 0.1790, -0.0257, 0.0069, -0.0216, 0.0506, 0.1590, 0.2368, 0.1588, 0.2374, -0.0181, - -0.2246, -0.1185, 0.1702, -0.0016, -0.1243, -0.1916, -0.2340, -0.2110, -0.0507, 0.1371, 0.1352, -0.2411, - 0.1559, -0.1956, -0.0529, -0.1014, -0.0482, -0.0491, -0.2243, -0.2159, -0.0391, 0.0032, -0.1136, 0.0942, - -0.2250, -0.0169, 0.2199, -0.1020, 0.2258, 0.0905, -0.2256, 0.1582, -0.0288, -0.1116, 0.1999, -0.2020, - 0.0268, -0.0523, 0.1785, 0.0698, 0.1201, 0.0883, -0.0601, -0.0526, -0.2060, 0.1355, 0.1985, 0.1711, - -0.1763, 0.0111, -0.1762, -0.1376, -0.1457, 0.0854, -0.1490, -0.0055, 0.0105, 0.1612, -0.1890, -0.1716, - -0.1452, 0.1750, -0.0899, 0.2109, 0.0904, 0.0317, -0.0019, -0.0494, 0.0314, -0.0571, -0.0018, 0.0319, - -0.1956, -0.1310, 0.2019, -0.2029, -0.0180, 0.2473, 0.0903, 0.0071, -0.2167, 0.1238, -0.1781, -0.0710, - -0.0839, -0.0370, 0.0027, 0.2062, 0.0312, 0.2239, 0.1529, -0.1581, 0.1121, -0.1767, -0.1060, 0.0735, - 0.0825, 0.1876, -0.0805, 0.0004, 0.1287, -0.2418, 0.1807, -0.2067, 0.0034, -0.0425, -0.1317, 0.0330, - 0.2067, -0.0731, -0.1484, -0.0925, -0.2478, 0.1128, -0.1201, -0.1668, -0.1440, 0.1437, 0.1324, 0.1919, - 0.0907, -0.0835, -0.0699, 0.0739, 0.2055, 0.0680, -0.1183, -0.1175, -0.2364, 0.0540, -0.1403, -0.2229, - 0.2192, -0.1624, -0.0284, 0.0716, 0.0080, -0.1682, -0.2021, 0.1993, 0.0407, 0.2074, -0.0838, 0.0736, - -0.0572, -0.0111, -0.1523, 0.0846, 0.0790, -0.0052, -0.0562, -0.1541, 0.1729, -0.1861, 0.1024, -0.0841, - -0.1206, 0.0449, -0.1299, 0.0576, 0.0491, -0.1856, 0.0416, 0.1065, 0.0990, -0.0315, -0.2049, -0.0385, - 0.0868, -0.0912, 0.0949, 0.1665, -0.1305, 0.0025, 0.1034, 0.0196, 0.0209, 0.0312, -0.1965, 0.0196, - 0.1731, 0.2253, 0.1469, 0.0335, 0.1168, -0.1216, -0.2072, -0.2150, 0.2494, 0.1587, -0.1728, 0.0978, - 0.1888, 0.2499, 0.2186, 0.1937, -0.0573, -0.0877, 0.2053, 0.1401, -0.1504, 0.2248, 0.1208, 0.1363, - -0.1567, 0.0717, -0.0876, 0.1953, -0.0450, 0.0973, 0.0444, 0.1064, -0.0850, 0.1219, -0.1746, 0.0565, - -0.1691, -0.2466, -0.2008, 0.1974, 0.1353, 0.2346, 0.2003, -0.2233, -0.1706, -0.0404, -0.1624, 0.1736, - -0.1890, -0.1220, -0.2415, -0.1419, 0.2056, 0.2047, 0.1790, 0.1930, 0.2223, -0.0640, 0.1100, 0.2227, - 0.0827, 0.2499, 0.1297, 0.1554, -0.0875, 0.1200, 0.0287, -0.0597, -0.1409, -0.1403, -0.1924, 0.1678, - 0.1777, -0.0285, -0.1447, 0.1932, 0.1599, 0.0186, -0.1180, 0.2298, 0.1022, -0.1898, 0.2393, 0.1898, - -0.0911, 0.1405, -0.1420, -0.0392, 0.2123, 0.0103, -0.1768, -0.0836, -0.0679, -0.0482, 0.0239, 0.2312, - 0.0134, -0.1544, 0.0128, 0.1199, 0.1240, -0.2285, -0.0447, -0.1858, -0.1067, 0.0901, -0.1775, 0.0929, - 0.2122, 0.0164, -0.1666, -0.0896, 0.0546, -0.1906, 0.1242, -0.2270, -0.2403, -0.2429, -0.0507, 0.1681, - -0.2366, 0.2078, -0.1000, 0.0732, 0.0114, -0.2254, 0.2073, 0.1346, 0.2485, 0.1263, -0.1650, 0.2086, - 0.0134, 0.1186, -0.2005, -0.0719, -0.2455, -0.0974, 0.0539, -0.1963, 0.0797, 0.1342, 0.0348, -0.1673, - -0.1938, -0.0771, 0.1097, 0.2466, 0.1438, -0.0282, 0.0877, -0.2453, -0.2135, 0.1167, -0.1416, 0.1203, - -0.1765, -0.1238, -0.2059, 0.1305, -0.0255, 0.1924, 0.1547, 0.1383, 0.0080, -0.0773, -0.0544, 0.0332, - 0.1239, -0.1751, 0.2098, -0.0272, -0.2095, -0.1353, 0.2212, 0.2286, -0.2316, 0.1763, 0.1253, 0.1480, - 0.2116, -0.1347, 0.0789, 0.1023, -0.0739, 0.0837, -0.0719, 0.1546}; - float attn_qkv_bias[]{-0.0694, -0.0932, 0.0629, 0.0887, -0.1221, 0.0221, 0.1449, -0.0249, - 0.0761, -0.0603, 0.0876, -0.1811, -0.1470, -0.1269, 0.2298, -0.0673, - -0.0007, -0.1211, 0.2496, 0.2442, -0.1885, -0.2027, 
-0.1895, -0.0012, - -0.0637, -0.1636, -0.0897, 0.0472, -0.1306, 0.0554, -0.0573, -0.1211}; - float attn_dense_weight[]{ - 3.4346e-02, 2.0556e-01, -1.6902e-01, 1.1609e-02, -9.2192e-02, 2.4533e-01, -2.3719e-01, -2.3967e-01, - 2.4635e-01, -1.5817e-01, 4.7931e-02, -2.1577e-02, -5.2664e-02, -5.5837e-02, 1.5886e-01, 1.1949e-02, - -2.4340e-01, -1.4759e-01, -8.5225e-02, 1.2580e-01, -1.6178e-01, 2.3573e-01, -5.5682e-02, -4.4890e-02, - 1.9589e-01, 1.2565e-01, 2.1203e-01, 1.4461e-01, -7.5838e-02, -1.6587e-01, -1.8601e-02, 2.0691e-01, - -8.3905e-02, -2.3184e-01, 1.0248e-01, 2.4337e-01, -7.1171e-02, -2.0701e-01, -2.2676e-01, 6.2650e-02, - -1.8930e-02, -1.2625e-01, 5.0533e-02, 9.4940e-02, 1.9883e-01, 1.9410e-01, -3.7421e-02, -2.2044e-01, - -2.2591e-01, 2.3342e-01, 1.1051e-01, 1.0898e-01, -2.1631e-01, 2.3150e-01, 2.3684e-01, 2.2572e-01, - -2.1090e-01, -9.4331e-02, -1.7194e-01, 2.3674e-01, -1.0742e-01, -1.1414e-01, 1.3098e-01, -1.1565e-01, - -1.2313e-01, -2.1868e-02, -2.4030e-02, -1.9474e-01, 2.0842e-01, -1.1028e-01, 8.8679e-02, 2.1745e-01, - 1.2608e-01, 3.5395e-02, 2.1271e-01, 3.3607e-02, -1.1566e-01, 2.3650e-01, 5.9170e-02, -2.4392e-01, - -7.1173e-02, -1.7029e-01, 2.1919e-01, -4.1286e-02, -2.2788e-01, -1.5716e-02, 1.5700e-01, 6.4950e-02, - 7.9051e-02, 2.3218e-02, 9.3205e-02, -6.0913e-02, -9.9450e-02, -2.3368e-01, -1.8833e-01, 1.0835e-01, - -1.4803e-01, 3.5867e-02, 7.9798e-02, 1.7701e-02, -1.6209e-01, 2.3908e-01, -1.4537e-01, 2.0563e-01, - -1.9888e-01, -6.0137e-02, 1.3599e-01, -1.0215e-01, 2.1001e-01, -1.7204e-01, -2.0994e-01, -1.1273e-01, - 4.0413e-02, 2.3019e-01, -1.1935e-01, 8.9407e-02, -6.2677e-02, -5.4216e-02, 1.8383e-01, -1.9375e-01, - 2.6553e-02, 2.3510e-01, -3.4353e-02, 1.9411e-01, -7.6989e-02, 2.0124e-01, -2.4183e-01, -3.6031e-02, - -4.3912e-02, 8.1021e-02, 9.8082e-02, 1.9195e-01, -3.7246e-02, -9.9138e-03, 1.7120e-01, -6.7643e-02, - 2.1916e-01, -1.6645e-01, -2.7052e-02, -1.3428e-02, 1.1155e-01, 1.7092e-01, -3.9626e-02, -2.0712e-01, - 1.2387e-01, 7.4783e-02, 1.0042e-01, -1.5422e-01, 1.6090e-01, 2.3679e-01, 2.1693e-02, -2.3351e-01, - 1.7550e-01, -1.8536e-01, 5.7468e-02, 3.6329e-02, -1.1701e-01, 8.7049e-02, -2.2361e-01, 5.6940e-02, - -1.5849e-01, -2.7033e-02, 3.2157e-02, 2.1299e-01, -1.1928e-01, 1.6016e-01, -3.1758e-02, -1.1873e-01, - -2.1770e-01, -2.2938e-01, 2.4415e-01, -6.2350e-02, 1.2487e-02, 6.7778e-02, 1.6993e-01, 2.1337e-01, - 2.0275e-01, -1.8522e-01, -4.0054e-02, -1.4793e-01, -1.4284e-01, 5.9302e-02, 2.3466e-01, -2.0028e-01, - 1.5130e-01, -1.2962e-01, -4.8694e-02, 1.9844e-01, -5.6543e-02, 2.2764e-02, -1.7476e-01, 2.1281e-01, - -3.2317e-02, -1.8285e-01, 7.3203e-02, -1.7775e-01, -1.9838e-01, 1.5230e-02, 1.9821e-01, -7.0746e-02, - 1.1767e-01, 2.1483e-01, 1.6582e-01, -1.3114e-01, -2.7405e-02, -7.8671e-02, -2.0103e-01, 1.0329e-04, - 1.8811e-01, 2.1063e-01, 2.3325e-02, 5.6781e-02, -1.0823e-01, 1.8871e-01, -1.0403e-01, -1.7366e-01, - 3.8498e-02, 1.4985e-01, -2.2540e-01, 2.2599e-01, 8.9933e-02, -1.7515e-01, -5.3852e-02, 2.1689e-01, - -1.9181e-01, -7.3069e-02, 8.1995e-02, -2.1902e-01, 1.3705e-01, 1.3014e-01, 1.5505e-01, -1.5939e-01, - 2.4900e-01, -1.4819e-01, 2.4959e-01, -2.3992e-01, -2.2274e-01, 1.5355e-01, 2.6128e-02, 1.4422e-02, - -1.3844e-01, -1.0487e-01, -7.3097e-02, -2.4354e-01, 1.2992e-02, 4.4215e-02, -2.1166e-04, 8.0733e-02, - 2.3721e-01, 6.6471e-02, -9.1518e-02, -1.0289e-01, -1.5995e-01, -1.7330e-01, -4.0262e-02, -4.4216e-02}; - float post_attention_layernorm_weight[]{0.1611, 0.7323, 0.1078, 0.0743, 0.6528, 0.5073, 0.2667, 0.0177, - 0.3064, 0.6670, 0.0372, 0.0143, 0.5634, 0.1398, 0.0620, 
0.3074}; - float mlp_dense_h_to_4h_weight[]{ - 0.1112, -0.1069, 0.1993, -0.1754, 0.0007, 0.2247, 0.2486, -0.1448, 0.0445, 0.0295, -0.1172, -0.0864, - 0.0677, -0.1738, 0.0412, 0.1082, -0.0985, 0.2077, -0.0165, 0.1134, 0.2476, -0.0764, 0.1359, -0.0715, - -0.0365, -0.0424, -0.0016, -0.0944, 0.0586, 0.0094, 0.1585, -0.0506, 0.0251, -0.0930, -0.2094, 0.1012, - 0.0320, -0.1001, -0.0845, 0.0654, -0.0452, 0.1634, 0.0142, 0.0944, 0.1089, -0.0613, 0.1082, 0.1845, - 0.0115, 0.0489, 0.0091, 0.1731, -0.1055, -0.1329, 0.1089, -0.2176, 0.0040, -0.1150, 0.1650, -0.2328, - 0.1516, 0.2473, 0.0506, -0.0167, 0.2461, -0.1068, -0.0200, -0.1095, -0.0345, 0.0543, 0.0334, -0.0472, - -0.2500, 0.0128, -0.0074, 0.0376, 0.1916, 0.2430, -0.1483, -0.0156, -0.1002, -0.2304, -0.1819, 0.2346, - 0.1458, 0.1343, 0.1389, -0.1986, 0.0263, 0.2327, -0.1395, 0.2224, 0.0657, 0.1759, -0.1071, 0.1153, - -0.2216, -0.0159, 0.0834, 0.0750, 0.2092, 0.2457, 0.2274, 0.1678, 0.1758, -0.0323, -0.1940, -0.0928, - -0.0281, -0.1351, 0.1278, 0.0878, 0.0403, 0.0604, 0.2196, 0.0911, -0.2192, -0.1815, 0.1102, 0.0341, - 0.1219, -0.2497, -0.2307, 0.1533, 0.1599, -0.2262, 0.0949, -0.1952, 0.1893, 0.0784, 0.2472, -0.2465, - -0.2407, 0.1690, 0.2220, 0.1532, 0.1662, -0.2106, 0.1810, -0.2360, 0.0559, -0.1638, -0.0993, 0.0443, - -0.0983, -0.0396, -0.2370, -0.2357, 0.0968, -0.1058, -0.1521, -0.0986, -0.2044, -0.1999, 0.1496, 0.1271, - 0.0333, -0.1366, 0.0832, 0.0112, -0.1787, 0.0538, 0.2276, -0.2459, -0.2486, 0.0320, 0.1883, -0.1521, - 0.0550, 0.1757, 0.0771, 0.0541, 0.2494, 0.1629, -0.1406, 0.0216, 0.1277, 0.2205, -0.2487, 0.1354, - -0.0726, -0.2448, 0.0555, -0.1361, 0.0354, -0.1623, -0.1881, -0.0212, -0.0840, 0.1462, 0.0216, 0.1951}; - float mlp_dense_4h_to_h_weight[]{ - 0.0765, -0.1313, 0.2765, 0.0223, 0.0766, -0.1313, 0.0015, -0.0785, -0.0347, -0.3184, -0.0025, -0.3389, - -0.0878, 0.2366, -0.2839, -0.2752, -0.0997, -0.1119, 0.0691, 0.0678, 0.2070, -0.2812, -0.0288, 0.2283, - 0.3531, -0.0298, -0.2362, 0.1174, 0.2325, -0.2551, -0.2369, 0.1804, -0.0286, -0.3856, -0.2354, 0.1656, - -0.1385, 0.2818, 0.3227, 0.0594, -0.0323, -0.1296, -0.0206, 0.0743, -0.3118, -0.0976, -0.3396, 0.2506, - -0.2600, 0.3730, -0.1052, -0.2339, 0.1961, 0.0609, 0.2827, 0.1704, -0.3933, 0.2582, -0.0769, -0.1804, - 0.2593, 0.2977, -0.3588, -0.0369, 0.3353, 0.1581, 0.3439, -0.1399, -0.2252, 0.3511, 0.1702, 0.3919, - -0.1705, -0.2621, -0.0479, -0.3844, 0.1600, 0.3011, 0.0980, -0.0403, 0.2024, -0.2591, 0.3994, -0.4059, - -0.3911, -0.0965, 0.3335, 0.0409, 0.1568, -0.2992, 0.1489, -0.0456, 0.1636, 0.2883, 0.1775, -0.0347}; + fs::path test_path = fs::path(__FILE__).parent_path() / "tests/data/glm2_block.data"; + MappedFile mapped_file(test_path.string()); + char *ptr = mapped_file.data; constexpr int seq_len = 3; - constexpr int hidden_size = 16; - constexpr int num_attention_heads = 4; + constexpr int hidden_size = 32; + constexpr int num_attention_heads = 8; constexpr int num_kv_heads = 2; constexpr int ffn_hidden_size = 6; constexpr int max_length = 8; GLM2Block model(&ictx, hidden_size, num_attention_heads, num_kv_heads, ffn_hidden_size, max_length); - memcpy(model.input_layernorm.weight->data, input_layernorm_weight, sizeof(input_layernorm_weight)); - memcpy(model.attention.query_key_value.weight->data, attn_qkv_weight, sizeof(attn_qkv_weight)); - memcpy(model.attention.query_key_value.bias->data, attn_qkv_bias, sizeof(attn_qkv_bias)); - memcpy(model.attention.dense.weight->data, attn_dense_weight, sizeof(attn_dense_weight)); - memcpy(model.post_attention_layernorm.weight->data, 
post_attention_layernorm_weight, - sizeof(post_attention_layernorm_weight)); - memcpy(model.mlp.dense_h_to_4h.weight->data, mlp_dense_h_to_4h_weight, sizeof(mlp_dense_h_to_4h_weight)); - memcpy(model.mlp.dense_4h_to_h.weight->data, mlp_dense_4h_to_h_weight, sizeof(mlp_dense_4h_to_h_weight)); + ptr = map_tensor_data(ptr, model.input_layernorm.weight); + ptr = map_tensor_data(ptr, model.attention.query_key_value.weight); + ptr = map_tensor_data(ptr, model.attention.query_key_value.bias); + ptr = map_tensor_data(ptr, model.attention.dense.weight); + ptr = map_tensor_data(ptr, model.post_attention_layernorm.weight); + ptr = map_tensor_data(ptr, model.mlp.dense_h_to_4h.weight); + ptr = map_tensor_data(ptr, model.mlp.dense_4h_to_h.weight); + + ggml_tensor *x1 = ggml_new_tensor_2d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size, seq_len); + ptr = map_tensor_data(ptr, x1); + x1 = ggml_dup(ictx.gctx.get(), x1); + + ggml_tensor *ref_y1 = ggml_new_tensor_2d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size, seq_len); + ptr = map_tensor_data(ptr, ref_y1); + + ggml_tensor *x2 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, x2); + x2 = ggml_dup(ictx.gctx.get(), x2); + + ggml_tensor *ref_y2 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, ref_y2); + + ggml_tensor *x3 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, x3); + x3 = ggml_dup(ictx.gctx.get(), x3); + + ggml_tensor *ref_y3 = ggml_new_tensor_1d(ictx.gctx.get(), GGML_TYPE_F32, hidden_size); + ptr = map_tensor_data(ptr, ref_y3); + + ASSERT_EQ(ptr, mapped_file.data + mapped_file.size); // self attention { - float x_data[]{-0.3439, 0.2291, 0.4366, 0.2702, -0.2002, 1.0184, -0.4169, -0.0308, -0.9846, 0.2024, - 0.4156, -1.2037, -0.1333, -0.2934, 0.4649, 0.5762, -0.0578, 1.1715, -1.5185, 0.0475, - -0.0103, 1.4353, 0.2518, 0.7408, -1.0982, 0.1548, -1.3989, 0.0408, -1.1174, 0.7029, - -0.6675, 0.1254, 0.4047, -0.6549, 0.0521, 0.3401, -0.2124, 1.5629, -0.9072, -1.5662, - 0.0485, 0.9935, 2.1180, -0.1827, -0.7768, 1.7764, -0.5033, 0.0566}; - float y_data[]{-0.3964, 0.2281, 0.6048, -0.3200, -0.0622, 0.9805, -0.7734, -0.1910, -1.1877, 0.1297, - 0.4010, -1.0275, -0.0449, -0.5022, 0.5744, 0.6155, -0.1103, 1.1213, -1.3659, -0.3552, - 0.1080, 1.3184, 0.0149, 0.6601, -1.1811, 0.0690, -1.4537, 0.0565, -1.1463, 0.5818, - -0.6814, 0.0777, 0.2416, -0.6671, 0.2555, -0.0109, -0.1029, 1.5266, -1.0859, -1.6493, - -0.0626, 0.8874, 2.0539, -0.1481, -0.8140, 1.6529, -0.5594, 0.0603}; - - ggml_tensor *x = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, hidden_size, seq_len); - memcpy(x->data, x_data, sizeof(x_data)); - - ggml_tensor *y = model.forward(&ctx, x, 0); - ggml_build_forward_expand(&ctx.gf, y); + ggml_tensor *out_y1 = model.forward(&ctx, x1, 0); + ggml_build_forward_expand(&ctx.gf, out_y1); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y->data, y_data, ggml_nelements(y)); + expect_all_close(ref_y1, out_y1, 5e-5); } // cross attention { - ctx.gf = {}; - ctx.gf.n_threads = 1; + reset_cgraph(); - float x2_data[]{-0.5481, -0.6917, -0.6559, 0.6949, -0.4671, -0.1680, 0.7585, 1.1881, - -0.6305, -0.0654, 0.6188, 2.0020, 0.2952, 0.5314, -0.5227, 0.0995}; - float y2_data[]{-0.6584, -0.6923, -0.5611, 0.5060, -0.4226, -0.2116, 0.5962, 1.1427, - -0.7516, -0.1508, 0.6161, 2.0421, 0.2349, 0.4791, -0.6062, 0.0516}; - - ggml_tensor *x2 = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, hidden_size, 1); - memcpy(x2->data, 
x2_data, sizeof(x2_data)); - - ggml_tensor *y2 = model.forward(&ctx, x2, seq_len); - ggml_build_forward_expand(&ctx.gf, y2); + ggml_tensor *out_y2 = model.forward(&ctx, x2, seq_len); + ggml_build_forward_expand(&ctx.gf, out_y2); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y2->data, y2_data, ggml_nelements(y2)); + expect_all_close(ref_y2, out_y2, 5e-5); } { - ctx.gf = {}; - ctx.gf.n_threads = 1; - - float x3_data[]{-0.1169, 0.3728, -1.0211, 1.0795, 1.1469, -0.1733, 0.0637, -1.2699, - -0.6212, -0.2381, 0.0892, 1.8008, -2.0627, 0.3222, -1.1390, 1.2418}; - float y3_data[]{-0.2959, 0.3463, -0.9365, 0.9941, 1.1817, -0.1974, -0.0675, -1.3237, - -0.7128, -0.3565, 0.0678, 1.8228, -2.1314, 0.3017, -1.2696, 1.1742}; - - ggml_tensor *x3 = ggml_new_tensor_2d(ctx.gctx.get(), GGML_TYPE_F32, hidden_size, 1); - memcpy(x3->data, x3_data, sizeof(x3_data)); + reset_cgraph(); - ggml_tensor *y3 = model.forward(&ctx, x3, seq_len + 1); - ggml_build_forward_expand(&ctx.gf, y3); + ggml_tensor *out_y3 = model.forward(&ctx, x3, seq_len + 1); + ggml_build_forward_expand(&ctx.gf, out_y3); ggml_graph_compute(ctx.gctx.get(), &ctx.gf); - expect_all_close((float *)y3->data, y3_data, ggml_nelements(y3)); + expect_all_close(ref_y3, out_y3, 5e-5); } } diff --git a/tests/data/glm2_block.data b/tests/data/glm2_block.data new file mode 100644 index 0000000..0c18db5 Binary files /dev/null and b/tests/data/glm2_block.data differ diff --git a/tests/data/glm_block.data b/tests/data/glm_block.data new file mode 100644 index 0000000..a7d3e9b Binary files /dev/null and b/tests/data/glm_block.data differ diff --git a/tests/data/layer_norm.data b/tests/data/layer_norm.data new file mode 100644 index 0000000..b72b7d5 Binary files /dev/null and b/tests/data/layer_norm.data differ diff --git a/tests/data/linear.data b/tests/data/linear.data new file mode 100644 index 0000000..febc77d Binary files /dev/null and b/tests/data/linear.data differ diff --git a/tests/data/rms_norm.data b/tests/data/rms_norm.data new file mode 100644 index 0000000..7822cde Binary files /dev/null and b/tests/data/rms_norm.data differ diff --git a/tests/test_convert.py b/tests/test_convert.py index 39abb68..e8be3ad 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -5,6 +5,8 @@ import torch.nn.functional as F from convert import quantize_q4_0, quantize_q4_1, quantize_q5_0, quantize_q5_1, quantize_q8_0 +HERE = Path(__file__).resolve().parent + # generated by: # torch.manual_seed(0) # weight = torch.randn(2, 128) @@ -167,11 +169,11 @@ def test_quantize_q5_1(): CHATGLM_MODEL_PATH = Path( - "~/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/a70fe6b0a3cf1675b3aec07e3b7bb7a8ce73c6ae" + "~/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/294cb13118a1e08ad8449ca542624a5c6aecc401" ).expanduser() CHATGLM2_MODEL_PATH = Path( - "~/.cache/huggingface/hub/models--THUDM--chatglm2-6b/snapshots/fc442f7e7cf3ac073433cef0f301b4744c25edb6" + "~/.cache/huggingface/hub/models--THUDM--chatglm2-6b/snapshots/0ecfe0b857efd00836a4851b3dd2ed04bd4b197f" ).expanduser() @@ -185,47 +187,58 @@ def make_data_embedding(): def make_data_linear(): - w = torch.randn(8, 3) - b = torch.randn(8) - x = torch.randn(2, 3) + w = torch.randn(16, 32) + b = torch.randn(16) + x = torch.randn(2, 32) y = F.linear(x, w, b) - print("w", w.flatten()) - print("b", b.flatten()) - print("x", x.flatten()) - print("y", y.flatten()) + with open(HERE / "data/linear.data", "wb") as f: + w.numpy().tofile(f) + b.numpy().tofile(f) + x.numpy().tofile(f) + 
y.numpy().tofile(f) def make_data_layernorm(): - w = torch.randn(9) - b = torch.randn(9) - x = torch.randn(2, 9) - y = F.layer_norm(x, [9], w, b) - print("w", w.flatten()) - print("b", b.flatten()) - print("x", x.flatten()) - print("y", y.flatten()) + w = torch.randn(64) + b = torch.randn(64) + x = torch.randn(3, 64) + y = F.layer_norm(x, [64], w, b) + + with open(HERE / "data/layer_norm.data", "wb") as f: + w.numpy().tofile(f) + b.numpy().tofile(f) + x.numpy().tofile(f) + y.numpy().tofile(f) def make_data_rms_norm(): from modeling_chatglm import RMSNorm - m = RMSNorm(7, eps=1e-6).eval() + m = RMSNorm(64, eps=1e-6).eval() m.weight.data.uniform_() - x = torch.randn(2, 7) + x = torch.randn(3, 64) with torch.no_grad(): y = m(x) - print("weight", m.weight.data.flatten()) - print("x", x.flatten()) - print("y", y.flatten()) + + with open(HERE / "data/rms_norm.data", "wb") as f: + m.weight.data.numpy().tofile(f) + x.numpy().tofile(f) + y.numpy().tofile(f) -def make_data_glm_self_attention(): - from modeling_chatglm import SelfAttention +def make_data_glm_block(): + from modeling_chatglm import GLMBlock - m = SelfAttention(16, 2, layer_id=3, empty_init=False).float().eval() - x = torch.randn(4, 1, 16) # [seqlen, bs, hidden] + m = ( + GLMBlock( + hidden_size=32, num_attention_heads=8, layernorm_epsilon=1e-5, layer_id=3, num_layers=28, empty_init=False + ) + .float() + .eval() + ) + x1 = torch.randn(4, 1, 32) # [seqlen, bs, hidden] position_ids = torch.tensor([[[0, 1, 2, 2], [0, 0, 0, 1]]]) attention_mask = torch.tensor( [ @@ -237,23 +250,15 @@ def make_data_glm_self_attention(): [0, 0, 0, 0], ] ] - ] - ).bool() - y, layer_past = m( - x, - position_ids=position_ids, - attention_mask=attention_mask, - layer_id=m.layer_id, - use_cache=True, + ], + dtype=torch.bool, + ) + y1, layer_past = m( + x1, position_ids=position_ids, attention_mask=attention_mask, layer_id=m.layer_id, use_cache=True ) - print("x", x.flatten()) - print("query_key_value.weight", m.query_key_value.weight.flatten()) - print("query_key_value.bias", m.query_key_value.bias.flatten()) - print("dense.weight", m.dense.weight.flatten()) - print("dense.bias", m.dense.bias.flatten()) - print("y", y.flatten()) - x2 = torch.randn(1, 1, 16) + # cross attention + x2 = torch.randn(1, 1, 32) position_ids = torch.tensor([[[2], [2]]]) attention_mask = torch.zeros(1, 1, dtype=torch.bool) y2, layer_past = m( @@ -264,10 +269,8 @@ def make_data_glm_self_attention(): layer_past=layer_past, use_cache=True, ) - print("x2", x2.flatten()) - print("y2", y2.flatten()) - x3 = torch.randn(1, 1, 16) + x3 = torch.randn(1, 1, 32) position_ids = torch.tensor([[[2], [3]]]) attention_mask = torch.zeros(1, 1, dtype=torch.bool) y3, layer_past = m( @@ -278,47 +281,29 @@ def make_data_glm_self_attention(): layer_past=layer_past, use_cache=True, ) - print("x3", x3.flatten()) - print("y3", y3.flatten()) + print(m) -def make_data_glm_block(): - from modeling_chatglm import GLMBlock - - m = ( - GLMBlock(hidden_size=8, num_attention_heads=2, layernorm_epsilon=1e-5, layer_id=3, empty_init=False) - .float() - .eval() - ) - x = torch.randn(4, 1, 8) # [seqlen, bs, hidden] - position_ids = torch.tensor([[[0, 1, 2, 2], [0, 0, 0, 1]]]) - attention_mask = torch.tensor( - [ - [ - [ - [0, 0, 0, 1], - [0, 0, 0, 1], - [0, 0, 0, 1], - [0, 0, 0, 0], - ] - ] - ] - ).bool() - (y,) = m(x, position_ids=position_ids, attention_mask=attention_mask, layer_id=m.layer_id) - print("x", x.flatten()) - print("input_layernorm.weight", m.input_layernorm.weight.data.flatten()) - 
print("input_layernorm.bias", m.input_layernorm.bias.data.flatten()) - print("query_key_value.weight", m.attention.query_key_value.weight.data.flatten()) - print("query_key_value.bias", m.attention.query_key_value.bias.data.flatten()) - print("dense.weight", m.attention.dense.weight.data.flatten()) - print("dense.bias", m.attention.dense.bias.data.flatten()) - print("post_attention_layernorm.weight", m.post_attention_layernorm.weight.data.flatten()) - print("post_attention_layernorm.bias", m.post_attention_layernorm.bias.data.flatten()) - print("dense_h_to_4h.weight", m.mlp.dense_h_to_4h.weight.data.flatten()) - print("dense_h_to_4h.bias", m.mlp.dense_h_to_4h.bias.data.flatten()) - print("dense_4h_to_h.weight", m.mlp.dense_4h_to_h.weight.data.flatten()) - print("dense_4h_to_h.bias", m.mlp.dense_4h_to_h.bias.data.flatten()) - print("y", y.flatten()) + with open(HERE / "data/glm_block.data", "wb") as f: + m.input_layernorm.weight.data.numpy().tofile(f) + m.input_layernorm.bias.data.numpy().tofile(f) + m.attention.query_key_value.weight.data.numpy().tofile(f) + m.attention.query_key_value.bias.data.numpy().tofile(f) + m.attention.dense.weight.data.numpy().tofile(f) + m.attention.dense.bias.data.numpy().tofile(f) + m.post_attention_layernorm.weight.data.numpy().tofile(f) + m.post_attention_layernorm.bias.data.numpy().tofile(f) + m.mlp.dense_h_to_4h.weight.data.numpy().tofile(f) + m.mlp.dense_h_to_4h.bias.data.numpy().tofile(f) + m.mlp.dense_4h_to_h.weight.data.numpy().tofile(f) + m.mlp.dense_4h_to_h.bias.data.numpy().tofile(f) + + x1.numpy().tofile(f) + y1.data.numpy().tofile(f) + x2.numpy().tofile(f) + y2.data.numpy().tofile(f) + x3.numpy().tofile(f) + y3.data.numpy().tofile(f) def make_data_glm2_block(): @@ -327,8 +312,8 @@ def make_data_glm2_block(): config = AutoConfig.from_pretrained(CHATGLM2_MODEL_PATH, trust_remote_code=True) config.layernorm_epsilon = 1e-6 - config.hidden_size = 16 - config.num_attention_heads = 4 + config.hidden_size = 32 + config.num_attention_heads = 8 config.multi_query_group_num = 2 config.ffn_hidden_size = 6 config.kv_channels = config.hidden_size // config.num_attention_heads @@ -343,45 +328,52 @@ def make_data_glm2_block(): rotary_pos_emb = rotary_pos_emb_module(8)[None, :seq_length].transpose(0, 1).contiguous() # self attention - x = torch.randn(seq_length, 1, config.hidden_size) + x1 = torch.randn(seq_length, 1, config.hidden_size) with torch.no_grad(): - y, kv_cache = m(x, attention_mask=None, rotary_pos_emb=rotary_pos_emb) - - print(m) - - print("input_layernorm.weight", m.input_layernorm.weight.data.flatten()) - print("attn.qkv.weight", m.self_attention.query_key_value.weight.data.flatten()) - print("attn.qkv.bias", m.self_attention.query_key_value.bias.data.flatten()) - print("attn.dense.weight", m.self_attention.dense.weight.data.flatten()) - print("post_attention_layernorm.weight", m.post_attention_layernorm.weight.data.flatten()) - print("mlp.dense_h_to_4h.weight", m.mlp.dense_h_to_4h.weight.data.flatten()) - print("mlp.dense_4h_to_h.weight", m.mlp.dense_4h_to_h.weight.data.flatten()) - - print("x", x.flatten()) - print("y", y.flatten()) + y1, kv_cache = m(x1, attention_mask=None, rotary_pos_emb=rotary_pos_emb) # cross attention position_ids = torch.tensor([[seq_length]]) rotary_pos_emb = rotary_pos_emb_module(8)[position_ids].transpose(0, 1).contiguous() - x = torch.randn(1, 1, config.hidden_size) + x2 = torch.randn(1, 1, config.hidden_size) with torch.no_grad(): - y, kv_cache = m(x, attention_mask=None, rotary_pos_emb=rotary_pos_emb, 
kv_cache=kv_cache) - print("x2", x.flatten()) - print("y2", y.flatten()) + y2, kv_cache = m(x2, attention_mask=None, rotary_pos_emb=rotary_pos_emb, kv_cache=kv_cache) + # cross attention position_ids = torch.tensor([[seq_length + 1]]) rotary_pos_emb = rotary_pos_emb_module(8)[position_ids].transpose(0, 1).contiguous() - x = torch.randn(1, 1, config.hidden_size) + x3 = torch.randn(1, 1, config.hidden_size) with torch.no_grad(): - y, kv_cache = m(x, attention_mask=None, rotary_pos_emb=rotary_pos_emb, kv_cache=kv_cache) - print("x3", x.flatten()) - print("y3", y.flatten()) + y3, kv_cache = m(x3, attention_mask=None, rotary_pos_emb=rotary_pos_emb, kv_cache=kv_cache) + + print(m) + + with open(HERE / "data/glm2_block.data", "wb") as f: + m.input_layernorm.weight.data.numpy().tofile(f) + m.self_attention.query_key_value.weight.data.numpy().tofile(f) + m.self_attention.query_key_value.bias.data.numpy().tofile(f) + m.self_attention.dense.weight.data.numpy().tofile(f) + m.post_attention_layernorm.weight.data.numpy().tofile(f) + m.mlp.dense_h_to_4h.weight.data.numpy().tofile(f) + m.mlp.dense_4h_to_h.weight.data.numpy().tofile(f) + + x1.numpy().tofile(f) + y1.numpy().tofile(f) + x2.numpy().tofile(f) + y2.numpy().tofile(f) + x3.numpy().tofile(f) + y3.numpy().tofile(f) def main(): - sys.path.append(str(CHATGLM_MODEL_PATH)) + sys.path.append(str(CHATGLM2_MODEL_PATH)) torch.manual_seed(0) - make_data_glm2_block() + (HERE / "data").mkdir(parents=True, exist_ok=True) + # make_data_linear() + make_data_layernorm() + # make_data_rms_norm() + # make_data_glm_block() + # make_data_glm2_block() if __name__ == "__main__":
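For reference, each tests/data/*.data file added above is a raw concatenation of float32 tensors written by numpy's tofile(), in the same order the C++ tests map them with map_tensor_data. The following is a minimal illustrative sketch (not part of the patch) of reading one file back for inspection, assuming the layout produced by make_data_linear above: w (16, 32), b (16,), x (2, 32), y (2, 16); read_tensor is just a local helper for this sketch.

    import numpy as np

    def read_tensor(f, shape):
        # tofile()/fromfile() exchange raw native-endian float32 in C order, with no header
        count = int(np.prod(shape))
        return np.fromfile(f, dtype=np.float32, count=count).reshape(shape)

    with open("tests/data/linear.data", "rb") as f:
        w = read_tensor(f, (16, 32))
        b = read_tensor(f, (16,))
        x = read_tensor(f, (2, 32))
        y = read_tensor(f, (2, 16))

    # sanity check mirroring F.linear: y = x @ w.T + b
    assert np.allclose(y, x @ w.T + b, atol=1e-5)

Any new reference data should follow the same convention: weights first, then inputs and expected outputs, matching the order of the map_tensor_data calls in the corresponding C++ test.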