|
21 | 21 |
|
22 | 22 | PRESET_MAP = { |
23 | 23 | "qwen3_moe_30b_a3b_en": "Qwen/Qwen3-30B-A3B", |
| 24 | + "qwen3_moe_235b_a22b_en": "Qwen/Qwen3-235B-A22B", |
24 | 25 | } |
25 | 26 |
|
26 | 27 | FLAGS = flags.FLAGS |
@@ -85,21 +86,11 @@ def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): |
85 | 86 | np.testing.assert_equal(keras_hub_output, hf_output) |
86 | 87 |
|
87 | 88 |
|
88 | | -def validate_output( |
89 | | - keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer |
90 | | -): |
| 89 | +def validate_output(qwen3_moe_lm, hf_model, hf_tokenizer): |
91 | 90 | input_str = "What is Keras?" |
92 | 91 | length = 32 |
93 | 92 |
|
94 | | - # KerasHub |
95 | | - preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( |
96 | | - keras_hub_tokenizer |
97 | | - ) |
98 | | - qwen_moe_lm = keras_hub.models.Qwen3MoeCausalLM( |
99 | | - backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy" |
100 | | - ) |
101 | | - |
102 | | - keras_output = qwen_moe_lm.generate([input_str], max_length=length) |
| 93 | + keras_output = qwen3_moe_lm.generate([input_str], max_length=length) |
103 | 94 | keras_output = keras_output[0] |
104 | 95 | print("🔶 KerasHub output:", keras_output) |
105 | 96 |
|
@@ -150,11 +141,16 @@ def main(_): |
150 | 141 | test_tokenizer(keras_hub_tokenizer, hf_tokenizer) |
151 | 142 | test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) |
152 | 143 |
|
153 | | - # == Validate model.generate output == |
154 | | - validate_output( |
155 | | - keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer |
| 144 | + preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( |
| 145 | + keras_hub_tokenizer |
156 | 146 | ) |
| 147 | + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM( |
| 148 | + backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy" |
| 149 | + ) |
| 150 | + # == Validate model.generate output == |
| 151 | + validate_output(qwen3_moe_lm, hf_model, hf_tokenizer) |
157 | 152 | print("\n-> Tests passed!") |
| 153 | + qwen3_moe_lm.save_to_preset(f"./{preset}") |
158 | 154 |
|
159 | 155 |
|
160 | 156 | if __name__ == "__main__": |
|
0 commit comments