EmbeddingBagCollection output raises error when calling to_dict() #2043

Open
jiannanWang opened this issue May 26, 2024 · 0 comments

I tried to print the output of a quantized EBC layer in a model. However, when I call .to_dict() on the layer output, I get the error: RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]. This bug occurs when I set output_dtype to torch.qint8 or torch.quint8, but not torch.float32.

Below are the reproduction code and the error log. The code creates a model (with a dense layer, a sparse layer, a weighted sparse layer, and an over layer), quantizes it, and runs a forward pass on random inputs. My environment is Python 3.10.14, torch 2.3.0+cu121, torchrec 0.7.0.

My guess is that the extra columns are per-row parameters added to support integer quantization. If that is the case, could to_dict() be fixed to account for the extra columns and produce the correct output dictionary? Thanks!
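
For what it's worth, in both failures the mismatch is exactly 8 columns (920 - 912 = 8 and 812 - 804 = 8). That would be consistent with a rowwise int8 layout that appends a 4-byte fp32 scale and a 4-byte fp32 bias to each embedding row, though I have not confirmed that this is what torchrec/fbgemm actually emits here:

# Assumption (unconfirmed): each quantized row carries 8 trailing bytes of
# metadata (float32 scale + float32 bias), which would explain the mismatch.
for embedding_dim, observed_cols in [(912, 920), (804, 812)]:
    print(embedding_dim, observed_cols, observed_cols - embedding_dim)  # -> 8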

Reproduction code:

import traceback
from typing import Protocol, cast, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedTensor

from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder
from torchrec.inference.modules import quantize_embeddings
from torchrec.modules.embedding_modules import EmbeddingBagCollection

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()

        table_params = [
            [777, 912],
        ]

        weighted_table_params = [
            [941, 804],
        ]

        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=table_params[i][0],
                embedding_dim=table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
                data_type=torch.int64,
            )
            for i in range(len(table_params))
        ]
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=weighted_table_params[i][0],
                embedding_dim=weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(weighted_table_params))
        ]

        self.dense = nn.Linear(in_features=914, out_features=930, bias=True)

        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables, 
            is_weighted=True,
        )

        in_features_concat = (
            self.dense.out_features
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.tables
                ]
            )
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.weighted_tables
                ]
            )
        )

        self.over = nn.Linear(in_features=in_features_concat, out_features=224, bias=True)

    def forward(
        self,
        input,
    ):
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )

        _features = [
            feature for table in self.tables for feature in table.feature_names
        ]
        _weighted_features = [
            feature for table in self.weighted_tables for feature in table.feature_names
        ]

        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)

        over_r = self.over(ret_concat)
        pred = torch.sigmoid(torch.mean(over_r, dim=1))
        if self.training:
            return (
                torch.nn.functional.binary_cross_entropy_with_logits(pred, input.label),
                pred, (dense_r, sparse_r, sparse_weighted_r, over_r),
            )
        else:
            return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)


def sharding_single_rank_test(
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    quant_dtype=None,
    quant_output_dtype=None,
) -> None:
    device = torch.device("cuda:0")
    model = model.to(device)

    # Quantize the embedding modules in place: table weights to quant_dtype,
    # module outputs to quant_output_dtype.
    model = quantize_embeddings(
        model, dtype=quant_dtype, inplace=True, output_dtype=quant_output_dtype
    )

    global_input_train = inputs[0][0].to(device)

    local_model = model

    # Plan how the quantized embedding modules should be sharded.
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            world_size, device.type
        )
    )
    plan: ShardingPlan = planner.plan(local_model, sharders)

    # Shard the model on a single local rank according to the plan.
    local_model = DistributedModelParallel(
        local_model,
        env=ShardingEnv.from_local(world_size=world_size, rank=0),
        plan=plan,
        sharders=sharders,
        device=device,
        init_data_parallel=False,
    )

    local_pred, (local_dense_r, local_sparse_r, local_sparse_weighted_r, local_over_r) = local_model(global_input_train)

    print("local_sparse_r: ")
    print(local_sparse_r.values().shape)
    print(local_sparse_r.keys())
    print(local_sparse_r.length_per_key())
    try:
        print(local_sparse_r.to_dict())
    except Exception as e:
        print(e)
        traceback.print_exc()
    
    print(local_sparse_weighted_r.values().shape)
    print(local_sparse_weighted_r.keys())
    print(local_sparse_weighted_r.length_per_key())
    try:
        print(local_sparse_weighted_r.to_dict())
    except Exception as e:
        print(e)
        traceback.print_exc()

class ModelInputCallable(Protocol):
    def __call__(
        self,
        batch_size: int,
        world_size: int,
        num_float_features: int,
        tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        pooling_avg: int = 10,
        dedup_tables: Optional[
            Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
        ] = None,
        variable_batch_size: bool = False,
        long_indices: bool = True,
    ) -> Tuple["ModelInput", List["ModelInput"]]: ...

def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    world_size: int = 2,
    quant_dtype=torch.qint8,
    quant_output_dtype=torch.float32,
) -> None:
    model = TestModel()
    batch_size = 2400
    world_size = 1  # note: overrides the world_size argument above
    num_float_features = model.dense.in_features
    tables = model.tables
    weighted_tables = model.weighted_tables

    inputs = [
        (
            cast(ModelInputCallable, ModelInput.generate)(
                world_size=world_size,
                tables=tables,
                weighted_tables=weighted_tables or [],
                num_float_features=num_float_features,
                batch_size=batch_size,
            )
        )
    ]

    sharding_single_rank_test(
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        quant_dtype=quant_dtype,
        quant_output_dtype=quant_output_dtype,
    )


def main():
    backend = "nccl"
    world_size = 2
    
    dtype = torch.qint8
    output_dtype = torch.qint8

    sharding_type = "table_wise"
    kernel_type = "quant"
    sharders = [TestQuantEBCSharder(sharding_type, kernel_type)]

    main_test(
        sharders=sharders,
        world_size=world_size,
        quant_dtype=dtype,
        quant_output_dtype=output_dtype,
    )


if __name__ == "__main__":
    main()

Logs:

local_sparse_r:
torch.Size([2400, 920])
['feature_0']
[912]
split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
Traceback (most recent call last):
  File "/mnt/util/reproduce_quant_unit_32.py", line 168, in sharding_single_rank_test
    print(local_sparse_r.to_dict())
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
    split_values = self._values.split(lengths, dim=self._key_dim)
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
torch.Size([2400, 812])
['weighted_feature_0']
[804]
split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]
Traceback (most recent call last):
  File "/mnt/util/reproduce_quant_unit_32.py", line 177, in sharding_single_rank_test
    print(local_sparse_weighted_r.to_dict())
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
    split_values = self._values.split(lengths, dim=self._key_dim)
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]
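
For now I can work around it by splitting the values myself. This is just a sketch under the assumption above (the helper name and the 8-byte pad per key are my own guesses, not torchrec API):

from torchrec.sparse.jagged_tensor import KeyedTensor

def to_dict_trimmed(kt: KeyedTensor, pad_per_key: int = 8) -> dict:
    # Assume each key's span in the quantized output carries pad_per_key
    # trailing columns (per-row scale + bias bytes): split with the padded
    # lengths, then drop the assumed metadata columns from each slice.
    padded_lengths = [length + pad_per_key for length in kt.length_per_key()]
    split_values = kt.values().split(padded_lengths, dim=1)
    return {
        key: values[:, :-pad_per_key]
        for key, values in zip(kt.keys(), split_values)
    }

With the repro above, to_dict_trimmed(local_sparse_r)["feature_0"] should come back as [2400, 912], matching the configured embedding_dim.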