diff --git a/python/sglang/srt/weight_sync/tensor_bucket.py b/python/sglang/srt/weight_sync/tensor_bucket.py index 44273713fb88..c1d592ddbb0d 100644 --- a/python/sglang/srt/weight_sync/tensor_bucket.py +++ b/python/sglang/srt/weight_sync/tensor_bucket.py @@ -22,6 +22,9 @@ class FlattenedTensorBucket: while preserving all metadata needed for reconstruction. """ + # This field is solely for users of this class to check whether it supports multiple dtypes + supports_multi_dtypes = True + def __init__( self, named_tensors: List[Tuple[str, torch.Tensor]] = None, @@ -48,7 +51,7 @@ def __init__( flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) for i, (name, tensor) in enumerate(named_tensors): - flattened = tensor.flatten() + flattened = tensor.flatten().view(torch.uint8) flattened_tensors[i] = flattened # Store metadata @@ -93,14 +96,12 @@ def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: reconstructed = [None] * len(self.metadata) for i, meta in enumerate(self.metadata): - tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape( - meta.shape + tensor = ( + self.flattened_tensor[meta.start_idx : meta.end_idx] + .view(meta.dtype) + .reshape(meta.shape) ) - # batch dtype conversion (if needed) - if tensor.dtype != meta.dtype: - tensor = tensor.to(meta.dtype) - reconstructed[i] = (meta.name, tensor) return reconstructed