Commit 6a035bc

fix: Fix process_weights_after_loading for fp8 dense (#1432)
Signed-off-by: Guyue Huang <[email protected]>
1 parent 3350ba2 commit 6a035bc

3 files changed: +41 -57 lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 2 additions & 0 deletions
@@ -1321,6 +1321,7 @@ def grpo_train(
         print("\n📊 Training Results:")

         print(f" • Loss: {metrics['loss']:.4f}")
+        print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
         if master_config["grpo"]["use_dynamic_sampling"]:
             print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}")
             print(
@@ -2184,6 +2185,7 @@ def async_grpo_train(

         print("\n📊 Training Results:")
         print(f" • Loss: {metrics['loss']:.4f}")
+        print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
         print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
         print(f" • Buffer Size: {buffer_size_current}")
         print(f" • Avg Trajectory Age: {avg_trajectory_age:.2f} steps")

nemo_rl/models/generation/fp8.py

Lines changed: 38 additions & 57 deletions
@@ -302,18 +302,8 @@ def load_weights(weights, model_runner):
             param_scale = torch.squeeze(param_scale, dim=-1)
             weights_quantized.append([k, param_lp])
             weights_quantized.append([k + "_scale_inv", param_scale])
-    # Monkey patch the param class to their subclass, as certain models
-    # will check the param type to call the proper weightloader
-    for name, param in model.named_parameters():
-        if hasattr(param, "subclass_type"):
-            param.orig_type = param.__class__
-            param.__class__ = param.subclass_type
     # Finally load the weights into vllm
     model.load_weights(weights_quantized)
-    # Undo the type change above to the original type
-    for name, param in model.named_parameters():
-        if hasattr(param, "subclass_type"):
-            param.__class__ = param.orig_type


 def cast_tensor_to_fp8_blockwise(
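
With the monkey-patch round trip removed, the load path reduces to building (name, tensor) pairs for each fp8 weight and its "_scale_inv" companion and handing them straight to vLLM. A minimal sketch of that flow; the wrapper name and the injected quantize_fn (standing in for cast_tensor_to_fp8_blockwise) are illustrative, not the commit's exact code:

    import torch

    def load_fp8_weights(weights, model, quantize_fn):
        # Hypothetical wrapper around the flow shown above: quantize_fn must return
        # (fp8_data, descale) for each high-precision tensor.
        weights_quantized = []
        for k, v in weights:
            param_lp, param_scale = quantize_fn(v)
            param_scale = torch.squeeze(param_scale, dim=-1)
            weights_quantized.append([k, param_lp])
            weights_quantized.append([k + "_scale_inv", param_scale])
        # Finally load the weights into vLLM; no parameter-class swapping is needed,
        # since the reworked process_weights_after_loading below no longer relies on it.
        model.load_weights(weights_quantized)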
@@ -324,12 +314,25 @@ def cast_tensor_to_fp8_blockwise(

     block_size1 = weight_block_size[1]
     block_size0 = weight_block_size[0]
-    assert data_hp.shape[1] % block_size1 == 0, (
-        f"data_hp.shape[1] {data_hp.shape[1]} must be a multiple of block_size1: {block_size1}."
-    )
-    assert data_hp.shape[0] % block_size0 == 0, (
-        f"data_hp.shape[0] {data_hp.shape[0]} must be a multiple of block_size0: {block_size0}."
-    )
+    shape_before_padding = data_hp.shape
+    # pad data_hp to make its shape a multiple of weight_block_size with the last element of data_hp
+    if data_hp.shape[1] % block_size1 != 0 or data_hp.shape[0] % block_size0 != 0:
+        pad1 = (
+            0
+            if data_hp.shape[1] % block_size1 == 0
+            else block_size1 - data_hp.shape[1] % block_size1
+        )
+        pad0 = (
+            0
+            if data_hp.shape[0] % block_size0 == 0
+            else block_size0 - data_hp.shape[0] % block_size0
+        )
+        print(
+            f"Padding data_hp from {data_hp.shape} to {(data_hp.shape[0] + pad0, data_hp.shape[1] + pad1)}"
+        )
+        data_hp = torch.nn.functional.pad(
+            data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1]
+        )

     # FP8
     max_dtype = torch.finfo(torch.float8_e4m3fn).max
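
This padding path replaces the hard multiple-of-block-size requirement: odd-shaped weights are padded up to the next block boundary with the tensor's last element, quantized, and cropped back afterwards (the crop appears in the next hunk). A small self-contained sketch of just that pad/crop step; the 300x520 shape and 128x128 block size are illustrative:

    import torch
    import torch.nn.functional as F

    def pad_to_block_multiple(data_hp: torch.Tensor, block_size0: int, block_size1: int):
        # Pad trailing rows/columns with the tensor's last element so both
        # dimensions become multiples of the quantization block size.
        pad0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0
        pad1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1
        if pad0 or pad1:
            # The commit passes the element tensor directly; .item() is used here
            # to keep the sketch version-agnostic.
            data_hp = F.pad(
                data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1].item()
            )
        return data_hp

    # Example: a 300x520 weight padded to 384x640 for 128x128 blocks,
    # then cropped back to the original shape after quantization.
    w = torch.randn(300, 520)
    padded = pad_to_block_multiple(w, 128, 128)   # torch.Size([384, 640])
    cropped = padded[: w.shape[0], : w.shape[1]]  # torch.Size([300, 520])
    assert cropped.shape == w.shape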
@@ -385,57 +388,35 @@ def cast_tensor_to_fp8_blockwise(
         .reshape(original_shape)
     )

+    # remove the padding
+    if data_hp.shape != shape_before_padding:
+        fp_data = fp_data[: shape_before_padding[0], : shape_before_padding[1]]
+
     # Convert to target format, but still in original precision container
     return fp_data, descale_fp


 def process_weights_after_loading(self, layer) -> None:
-    from torch.nn import Parameter
-    from vllm.model_executor.parameter import (
-        BlockQuantScaleParameter,
-        ModelWeightParameter,
+    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+        maybe_post_process_fp8_weight_block,
+        process_fp8_weight_block_strategy,
     )

     assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
     assert self.quant_config.activation_scheme == "dynamic"

-    def _create_param_from_subclass_attributes(custom_param):
-        param = Parameter(custom_param.data, requires_grad=False)
-        base_param_dir = dir(torch.nn.Parameter)
-        custom_param_dir = dir(custom_param)
-        # Find the attributes that are unique to the custom parameter
-        custom_attributes = [
-            attr
-            for attr in custom_param_dir
-            if attr not in base_param_dir and not attr.startswith("__")
-        ]
-        # Set the custom attributes into the base parameter object
-        for attr in custom_attributes:
-            setattr(param, attr, getattr(custom_param, attr))
-
-        param.subclass_type = type(custom_param)
-        return param
-
-    weight = layer.weight.data
-    weight_scale_inv = layer.weight_scale_inv.data
-    weight = self._maybe_pad_weight(weight)
-
-    layer.weight = _create_param_from_subclass_attributes(
-        ModelWeightParameter(
-            data=weight,
-            output_dim=0,
-            input_dim=1,
-            weight_loader=layer.weight.weight_loader,
-        )
-    )
-    layer.weight_scale_inv = _create_param_from_subclass_attributes(
-        BlockQuantScaleParameter(
-            data=weight_scale_inv,
-            output_dim=0,
-            input_dim=1,
-            weight_loader=layer.weight_scale_inv.weight_loader,
-        )
-    )
+    weight_scale = layer.weight_scale_inv
+    weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale)
+    layer.weight.data = weight.data
+    if hasattr(layer, "weight_scale"):
+        # Not the first time to call this function, just need to update the data
+        layer.weight_scale.data = weight_scale.data
+    else:
+        # The first time to call this function, create a new parameter and update the tp status
+        layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
+        layer.update_param_tp_status()
+
+    maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)


 @triton.jit
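
The rewritten method delegates the layout work to vLLM's fp8 block-quant helpers and, per the in-diff comments, is safe to call more than once, presumably because it can run again after weights are refit into the generation engine. A hedged sketch of that call order, with the vLLM helpers injected as arguments so the snippet stands alone; layer is assumed to already carry weight, weight_scale_inv, and update_param_tp_status:

    import torch

    def refresh_fp8_block_weights(layer, process_block_fn, post_process_fn, cutlass_supported):
        # process_block_fn stands in for vLLM's process_fp8_weight_block_strategy and
        # post_process_fn for maybe_post_process_fp8_weight_block, as used in the diff.
        weight, weight_scale = process_block_fn(layer.weight, layer.weight_scale_inv)
        layer.weight.data = weight.data
        if hasattr(layer, "weight_scale"):
            # Subsequent calls: the Parameter already exists, only its data is refreshed.
            layer.weight_scale.data = weight_scale.data
        else:
            # First call: create the Parameter and refresh tensor-parallel metadata.
            layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
            layer.update_param_tp_status()
        post_process_fn(layer, cutlass_supported)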

tests/unit/algorithms/test_grpo.py

Lines changed: 1 addition & 0 deletions
@@ -826,6 +826,7 @@ def mock_grpo_components():
             "token_mult_prob_error": [
                 1.0
             ],  # Must be <= 1.05 to avoid logging extra plots
+            "gen_kl_error": [0.0001],
         },
     }
     policy.generate.return_value = {
