Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,104 changes: 823 additions & 281 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Large diffs are not rendered by default.

810 changes: 306 additions & 504 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp

Large diffs are not rendered by default.

137 changes: 65 additions & 72 deletions ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
Original file line number Diff line number Diff line change
@@ -1,82 +1,79 @@
#decl(BYTE_HELPERS)

#ifdef BYTE_HELPERS
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}

fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#endif

#enddecl(BYTE_HELPERS)

#decl(Q4_0_T)
#ifdef Q4_0_T
struct q4_0 {
d: f16,
qs: array<f16, 8>
};
#enddecl(Q4_0_T)
#endif

#decl(Q4_1_T)
#ifdef Q4_1_T
struct q4_1 {
d: f16,
m: f16,
qs: array<u32, 4>
};
#enddecl(Q4_1_T)
#endif

#decl(Q5_0_T)
#ifdef Q5_0_T
struct q5_0 {
d: f16,
qh: array<f16, 2>,
qs: array<f16, 8>
};
#enddecl(Q5_0_T)
#endif

#decl(Q5_1_T)
#ifdef Q5_1_T
struct q5_1 {
d: f16,
m: f16,
qh: u32,
qs: array<u32, 4>
};
#enddecl(Q5_1_T)
#endif

#decl(Q8_0_T)
#ifdef Q8_0_T
struct q8_0 {
d: f16,
qs: array<f16, 16>
};
#enddecl(Q8_0_T)
#endif

#decl(Q8_1_T)
#ifdef Q8_1_T
struct q8_1 {
d: f16,
m: f16,
qs: array<u32, 8>
};
#enddecl(Q8_1_T)
#endif

#decl(Q2_K_T)
struct q2_k {
#ifdef Q2_K_T
struct q2_K {
scales: array<u32, 4>,
qs: array<u32, 16>,
d: f16,
dmin: f16
};
#enddecl(Q2_K_T)
#endif

#decl(Q3_K_T)
struct q3_k {
#ifdef Q3_K_T
struct q3_K {
hmask: array<f16, 16>,
qs: array<f16, 32>,
scales: array<f16, 6>,
d: f16
};
#enddecl(Q3_K_T)

#decl(Q45_K_SCALE_MIN)
#endif

#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
if (is < 4) {
let sc_byte = get_byte(scales[is / 4], is % 4);
Expand All @@ -91,111 +88,109 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
return vec2(f32(sc), f32(m));
}
}

#enddecl(Q45_K_SCALE_MIN)

#decl(Q4_K_T)
struct q4_k {
#endif
#ifdef Q4_K_T
struct q4_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qs: array<u32, 32>
};
#enddecl(Q4_K_T)
#endif

#decl(Q5_K_T)
struct q5_k {
#ifdef Q5_K_T
struct q5_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qh: array<u32, 8>,
qs: array<u32, 32>
};
#enddecl(Q5_K_T)
#endif

#decl(Q6_K_T)
struct q6_k {
#ifdef Q6_K_T
struct q6_K {
ql: array<f16, 64>,
qh: array<f16, 32>,
scales: array<f16, 8>,
d: f16
};
#enddecl(Q6_K_T)
#endif

#decl(IQ2_XXS_T)
#ifdef IQ2_XXS_T
struct iq2_xxs {
d: f16,
qs: array<f16, 32>
};
#enddecl(IQ2_XXS_T)
#endif

#decl(IQ2_XS_T)
#ifdef IQ2_XS_T
struct iq2_xs {
d: f16,
qs: array<f16, 32>,
scales: array<f16, 4>
};
#enddecl(IQ2_XS_T)
#endif

#decl(IQ2_S_T)
#ifdef IQ2_S_T
struct iq2_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
scales: array<f16, 4>
};
#enddecl(IQ2_S_T)
#endif

#decl(IQ3_XSS_T)
#ifdef IQ3_XXS_T
struct iq3_xxs {
d: f16,
qs: array<f16, 48>
};
#enddecl(IQ3_XSS_T)
#endif

#decl(IQ3_S_T)
#ifdef IQ3_S_T
struct iq3_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
signs: array<f16, 16>,
scales: array<f16, 2>
};
#enddecl(IQ3_S_T)
#endif

#decl(IQ1_S_T)
#ifdef IQ1_S_T
struct iq1_s {
d: f16,
qs: array<f16, 16>,
qh: array<f16, 8>
};
#enddecl(IQ1_S_T)
#endif

#decl(IQ1_M_T)
#ifdef IQ1_M_T
struct iq1_m {
qs: array<u32, 8>,
qh: array<u32, 4>,
scales: array<u32, 2>
};
#enddecl(IQ1_M_T)
#endif

#decl(IQ4_NL_T)
#ifdef IQ4_NL_T
struct iq4_nl {
d: f16,
qs: array<f16, 8>,
};
#enddecl(IQ4_NL_T)
#endif

#decl(IQ4_XS_T)
#ifdef IQ4_XS_T
struct iq4_xs {
d: f16,
scales_h: f16,
scales_l: u32,
qs: array<u32, 32>
};
#enddecl(IQ4_XS_T)
#endif

#decl(IQ23_TABLES)
#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
0x08040201u, // 1, 2, 4, 8
0x80402010u // 16, 32, 64, 128
Expand All @@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
);
#enddecl(IQ23_TABLES)
#endif

#decl(IQ2_XXS_GRID)
#ifdef IQ2_XXS_GRID
const iq2xxs_grid = array<u32, 512>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
Expand Down Expand Up @@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
);
#enddecl(IQ2_XXS_GRID)
#endif

#decl(IQ2_XS_GRID)
#ifdef IQ2_XS_GRID
const iq2xs_grid = array<u32, 1024>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
Expand Down Expand Up @@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_XS_GRID)
#endif

#decl(IQ2_S_GRID)
#ifdef IQ2_S_GRID
const iq2s_grid = array<u32, 2048>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
Expand Down Expand Up @@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_S_GRID)

#decl(IQ3_XSS_GRID)
#endif

#ifdef IQ3_XXS_GRID
const iq3xxs_grid = array<u32, 256>(
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
Expand Down Expand Up @@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
);
#enddecl(IQ3_XSS_GRID)

#decl(IQ3_S_GRID)
#endif

#ifdef IQ3_S_GRID
const iq3s_grid = array<u32, 512>(
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
Expand Down Expand Up @@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
);
#enddecl(IQ3_S_GRID)
#endif

#decl(IQ1_GRID)
#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)

const IQ1_DELTA: f32 = 0.125;

Expand Down Expand Up @@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
);

#enddecl(IQ1_GRID)
#endif

#decl(IQ4_GRID)
#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)

const kvalues_iq4nl = array<i32, 16>(
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
);

#enddecl(IQ4_GRID)
#endif
8 changes: 5 additions & 3 deletions ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def replacer(match):
return include_pattern.sub(replacer, shader)


def write_shader(shader_name, shader_code, output_dir, outfile):
def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
shader_code = expand_includes(shader_code, input_dir)

if output_dir:
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
Expand All @@ -74,7 +76,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else:
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
Expand Down Expand Up @@ -123,7 +125,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile)
write_shader(output_name, final_shader, output_dir, outfile, input_dir)


def main():
Expand Down
Loading