@@ -109,55 +109,74 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
109109 self .intermediate_size = intermediate_size_per_partition_after_pad
110110 self .hidden_size = hidden_size
111111 # Fused gate_up_proj (column parallel)
112- w13_weight = torch .nn .Parameter (torch .zeros (
113- num_experts ,
114- 2 * intermediate_size_per_partition_after_pad ,
115- hidden_size // 2 ,
116- dtype = weight_dtype ),
117- requires_grad = False )
112+ w13_weight = torch .nn .Parameter (
113+ torch .zeros (
114+ num_experts ,
115+ 2 * intermediate_size_per_partition_after_pad ,
116+ hidden_size // 2 ,
117+ dtype = weight_dtype ,
118+ ),
119+ requires_grad = False ,
120+ )
118121 layer .register_parameter ("w13_weight" , w13_weight )
119122 set_weight_attrs (w13_weight , extra_weight_attrs )
120123
121- w13_weight_scale = torch .nn .Parameter (torch .zeros (
122- num_experts ,
123- 2 * intermediate_size_per_partition_after_pad ,
124- hidden_size // mxfp4_block ,
125- dtype = scale_dtype ),
126- requires_grad = False )
124+ w13_weight_scale = torch .nn .Parameter (
125+ torch .zeros (
126+ num_experts ,
127+ 2 * intermediate_size_per_partition_after_pad ,
128+ hidden_size // mxfp4_block ,
129+ dtype = scale_dtype ,
130+ ),
131+ requires_grad = False ,
132+ )
127133 layer .register_parameter ("w13_weight_scale" , w13_weight_scale )
128134 set_weight_attrs (w13_weight_scale , extra_weight_attrs )
129135
130- w13_bias = torch .nn .Parameter (torch .zeros (
131- num_experts ,
132- 2 * intermediate_size_per_partition_after_pad ,
133- dtype = torch .bfloat16 ),
134- requires_grad = False )
136+ w13_bias = torch .nn .Parameter (
137+ torch .zeros (
138+ num_experts ,
139+ 2 * intermediate_size_per_partition_after_pad ,
140+ dtype = torch .bfloat16 ,
141+ ),
142+ requires_grad = False ,
143+ )
135144 layer .register_parameter ("w13_bias" , w13_bias )
136145 set_weight_attrs (w13_bias , extra_weight_attrs )
137146
138147 # down_proj (row parallel)
139- w2_weight = torch .nn .Parameter (torch .zeros (
140- num_experts ,
141- hidden_size ,
142- intermediate_size_per_partition_after_pad // 2 ,
143- dtype = weight_dtype ),
144- requires_grad = False )
148+ w2_weight = torch .nn .Parameter (
149+ torch .zeros (
150+ num_experts ,
151+ hidden_size ,
152+ intermediate_size_per_partition_after_pad // 2 ,
153+ dtype = weight_dtype ,
154+ ),
155+ requires_grad = False ,
156+ )
145157 layer .register_parameter ("w2_weight" , w2_weight )
146158 set_weight_attrs (w2_weight , extra_weight_attrs )
147159
148- w2_weight_scale = torch .nn .Parameter (torch .zeros (
149- num_experts ,
150- hidden_size ,
151- intermediate_size_per_partition_after_pad // mxfp4_block ,
152- dtype = scale_dtype ),
153- requires_grad = False )
160+ w2_weight_scale = torch .nn .Parameter (
161+ torch .zeros (
162+ num_experts ,
163+ hidden_size ,
164+ intermediate_size_per_partition_after_pad // mxfp4_block ,
165+ dtype = scale_dtype ,
166+ ),
167+ requires_grad = False ,
168+ )
154169 layer .register_parameter ("w2_weight_scale" , w2_weight_scale )
155170 set_weight_attrs (w2_weight_scale , extra_weight_attrs )
156171
157- w2_bias = torch .nn .Parameter (torch .zeros (num_experts ,
158- hidden_size ,
159- dtype = torch .bfloat16 ),
160- requires_grad = False )
172+ w2_bias = torch .nn .Parameter (
173+ torch .zeros (
174+ num_experts ,
175+ hidden_size ,
176+ dtype = torch .bfloat16 ,
177+ ),
178+ requires_grad = False ,
179+ )
161180 layer .register_parameter ("w2_bias" , w2_bias )
162181 set_weight_attrs (w2_bias , extra_weight_attrs )
163182
0 commit comments