Skip to content
Merged
38 changes: 36 additions & 2 deletions python/sglang/srt/hardware_backend/npu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@ def init_npu_backend():
torch_npu.npu.set_compile_mode(jit_compile=False)


def _is_nz_aligned(tensor: torch.Tensor) -> bool:
"""Check whether the last two dims satisfy FRACTAL_NZ alignment rules.

Ascend FRACTAL_NZ requires:
BF16 / FP16 : both dims divisible by 16
INT8 : k % 16 == 0 and n % 32 == 0
INT4 : k % 16 == 0 and n % 64 == 0
FP4 : both dims divisible by 64
"""
if tensor.dim() < 2:
return False
k, n = tensor.shape[-2], tensor.shape[-1]
if tensor.dtype in (torch.bfloat16, torch.float16):
return k % 16 == 0 and n % 16 == 0
if tensor.dtype == torch.int8:
return k % 16 == 0 and n % 32 == 0
if tensor.dtype in (torch.uint8, torch.int32):
# INT4 is typically packed into uint8/int32; be conservative
return k % 16 == 0 and n % 64 == 0
return True


def npu_format_cast(
tensor: torch.Tensor,
acl_format: NPUACLFormat = NPUACLFormat.ACL_FORMAT_FRACTAL_NZ,
Expand Down Expand Up @@ -135,8 +157,20 @@ def npu_format_cast(
"significantly reduced."
)
return tensor
else:
return torch.ops.npu.npu_format_cast(tensor, acl_format.value)

if acl_format == NPUACLFormat.ACL_FORMAT_FRACTAL_NZ and not _is_nz_aligned(tensor):
k, n = tensor.shape[-2], tensor.shape[-1]
logger.warning_once(
"Skipping FRACTAL_NZ format cast: tensor shape (%d, %d) dtype %s "
"is not aligned to NZ requirements. Falling back to 'ND' format, "
"which may reduce NPU performance.",
k,
n,
tensor.dtype,
)
return tensor

return torch.ops.npu.npu_format_cast(tensor, acl_format.value)


def get_indexer_weight_stream():
Expand Down
4 changes: 1 addition & 3 deletions python/sglang/srt/layers/quantization/unquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
for weight_name in ["w13_weight", "w2_weight"]:
weight = getattr(layer, weight_name)
weight.data = weight.data.transpose(1, 2)
weight.data = npu_format_cast(
weight.data,
)
weight.data = npu_format_cast(weight.data)

return

Expand Down
Loading