I tried to print the output of a quantized EBC layer in a model. However, when I call .to_dict() on the layer output, I get the error: RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]. The bug occurs when output_dtype is set to torch.qint8 or torch.quint8, but not when it is torch.float32.
Below are the reproduction code and the error log. The code builds a model (a dense layer, a sparse layer, a weighted sparse layer, and an over layer), quantizes the model, and runs a forward pass on random inputs. My environment is Python 3.10.14, torch 2.3.0+cu121, torchrec 0.7.0.
I suspect the extra columns (8 per key in both cases: 920 vs. 912, and 812 vs. 804) are parameters appended to support integer quantization, e.g. a per-row scale and zero point. If that is the case, could to_dict() be fixed to account for the extra columns and produce the correct output dictionary? Thanks!
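In the meantime, here is a workaround sketch I put together (untested beyond the single-key case below, and assuming the extra 8 columns per key are quantization parameters appended after each key's embedding block; to_dict_with_padding and pad_per_key are names I made up for illustration):
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

def to_dict_with_padding(kt: KeyedTensor, pad_per_key: int = 8):
    # Assumption: each key's block in the values tensor is its
    # length_per_key embedding columns followed by pad_per_key extra
    # columns (e.g. fused scale/zero-point bytes).
    padded_lengths = [length + pad_per_key for length in kt.length_per_key()]
    # Split along dim=1, the key dimension used by the repro below.
    splits = kt.values().split(padded_lengths, dim=1)
    # Drop the trailing pad columns per key. Note the kept columns are
    # still raw quantized bytes, not dequantized float embeddings.
    return {
        key: split[:, :length]
        for key, split, length in zip(kt.keys(), splits, kt.length_per_key())
    }
With the repro below, to_dict_with_padding(local_sparse_weighted_r) should split the [2400, 812] values tensor into 804 embedding columns plus 8 pad columns for weighted_feature_0, instead of raising.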
Reproduction code:
import traceback
from typing import Protocol, cast, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
EmbeddingShardingPlanner,
Topology,
)
from torchrec.distributed.test_utils.test_model import (
ModelInput,
)
from torchrec.distributed.types import (
ModuleSharder,
ShardingEnv,
ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder
from torchrec.inference.modules import quantize_embeddings
from torchrec.modules.embedding_modules import EmbeddingBagCollection
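# A small test model mirroring torchrec's test models: a dense arch, an
# unweighted EmbeddingBagCollection, a weighted EmbeddingBagCollection, and
# an over arch consuming their concatenated outputs.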
class TestModel(nn.Module):
def __init__(self):
super().__init__()
table_params = [
[777, 912],
]
weighted_table_params = [
[941, 804],
]
self.tables = [
EmbeddingBagConfig(
num_embeddings=table_params[i][0],
embedding_dim=table_params[i][1],
name="table_" + str(i),
feature_names=["feature_" + str(i)],
data_type=torch.int64,
)
for i in range(len(table_params))
]
self.weighted_tables = [
EmbeddingBagConfig(
num_embeddings=weighted_table_params[i][0],
embedding_dim=weighted_table_params[i][1],
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(len(weighted_table_params))
]
self.dense = nn.Linear(in_features=914, out_features=930, bias=True)
self.sparse = EmbeddingBagCollection(
tables=self.tables,
is_weighted=False,
)
self.sparse_weighted = EmbeddingBagCollection(
tables=self.weighted_tables,
is_weighted=True,
)
in_features_concat = (
self.dense.out_features
+ sum(
[
table.embedding_dim * len(table.feature_names)
for table in self.tables
]
)
+ sum(
[
table.embedding_dim * len(table.feature_names)
for table in self.weighted_tables
]
)
)
self.over = nn.Linear(in_features=in_features_concat, out_features=224, bias=True)
def forward(
self,
input,
):
dense_r = self.dense(input.float_features)
sparse_r = self.sparse(input.idlist_features)
sparse_weighted_r = self.sparse_weighted(input.idscore_features)
result = KeyedTensor(
keys=sparse_r.keys() + sparse_weighted_r.keys(),
length_per_key=sparse_r.length_per_key()
+ sparse_weighted_r.length_per_key(),
values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
)
_features = [
feature for table in self.tables for feature in table.feature_names
]
_weighted_features = [
feature for table in self.weighted_tables for feature in table.feature_names
]
ret_list = []
ret_list.append(dense_r)
for feature_name in _features:
ret_list.append(result[feature_name])
for feature_name in _weighted_features:
ret_list.append(result[feature_name])
ret_concat = torch.cat(ret_list, dim=1)
over_r = self.over(ret_concat)
pred = torch.sigmoid(torch.mean(over_r, dim=1))
if self.training:
return (
torch.nn.functional.binary_cross_entropy_with_logits(pred, input.label),
pred, (dense_r, sparse_r, sparse_weighted_r, over_r),
)
else:
return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)
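# Quantizes the model, shards it on a single rank, runs one forward pass,
# and then attempts KeyedTensor.to_dict() on both EBC outputs.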
def sharding_single_rank_test(
world_size: int,
model,
inputs,
sharders: List[ModuleSharder[nn.Module]],
quant_dtype = None,
quant_output_dtype = None,
) -> None:
device = torch.device("cuda:0")
model = model.to(device)
model = quantize_embeddings(model, dtype=quant_dtype, inplace=True, output_dtype=quant_output_dtype)
global_input_train = inputs[0][0].to(device)
local_model = model
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size, device.type
)
)
plan: ShardingPlan = planner.plan(local_model, sharders)
local_model = DistributedModelParallel(
local_model,
env=ShardingEnv.from_local(world_size=world_size, rank=0),
plan=plan,
sharders=sharders,
device=device,
init_data_parallel=False,
)
local_pred, (local_dense_r, local_sparse_r, local_sparse_weighted_r, local_over_r) = local_model(global_input_train)
print("local_sparse_r: ")
print(local_sparse_r.values().shape)
print(local_sparse_r.keys())
print(local_sparse_r.length_per_key())
try:
print(local_sparse_r.to_dict())
except Exception as e:
print(e)
traceback.print_exc()
print(local_sparse_weighted_r.values().shape)
print(local_sparse_weighted_r.keys())
print(local_sparse_weighted_r.length_per_key())
try:
print(local_sparse_weighted_r.to_dict())
except Exception as e:
print(e)
traceback.print_exc()
class ModelInputCallable(Protocol):
def __call__(
self,
batch_size: int,
world_size: int,
num_float_features: int,
tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
pooling_avg: int = 10,
dedup_tables: Optional[
Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
] = None,
variable_batch_size: bool = False,
long_indices: bool = True,
) -> Tuple["ModelInput", List["ModelInput"]]: ...
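# Builds the model and random inputs, then runs the single-rank test.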
def main_test(
sharders: List[ModuleSharder[nn.Module]],
world_size: int = 2,
quant_dtype = torch.qint8,
quant_output_dtype = torch.float32,
) -> None:
model = TestModel()
batch_size=2400
world_size=1
num_float_features=model.dense.in_features
tables = model.tables
weighted_tables=model.weighted_tables
inputs = [
(
cast(ModelInputCallable, ModelInput.generate)(
world_size=world_size,
tables=tables,
weighted_tables=weighted_tables or [],
num_float_features=num_float_features,
batch_size=batch_size,
)
)
]
sharding_single_rank_test(
world_size=world_size,
model=model,
inputs=inputs,
sharders=sharders,
quant_dtype = quant_dtype,
quant_output_dtype = quant_output_dtype,
)
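# Entry point. output_dtype=torch.qint8 triggers the error; switching it to
# torch.float32 does not.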
def main():
backend = "nccl"
world_size = 2
dtype = torch.qint8
output_dtype = torch.qint8
sharding_type = "table_wise"
kernel_type = "quant"
sharders = [TestQuantEBCSharder(sharding_type, kernel_type)]
main_test(
sharders = sharders,
world_size = world_size,
quant_dtype=dtype,
quant_output_dtype=output_dtype,
)
if __name__ == "__main__":
main()
Logs:
local_sparse_r:
torch.Size([2400, 920])
['feature_0']
[912]
split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
Traceback (most recent call last):
File "/mnt/util/reproduce_quant_unit_32.py", line 168, in sharding_single_rank_test
print(local_sparse_r.to_dict())
File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
split_values = self._values.split(lengths, dim=self._key_dim)
File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
torch.Size([2400, 812])
['weighted_feature_0']
[804]
split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]
Traceback (most recent call last):
File "/mnt/util/reproduce_quant_unit_32.py", line 177, in sharding_single_rank_test
print(local_sparse_weighted_r.to_dict())
File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
split_values = self._values.split(lengths, dim=self._key_dim)
File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]