[BUG] Pre-trained embeddings fails using variable input sequence lengths #797

Open
mvidela31 opened this issue on Jan 10, 2025 · 0 comments
Labels: bug (Something isn't working), status/needs-triage
Bug description

I was trying to use the pre-trained embeddings feature following the test_trainer_with_pretrained_embeddings() unit test. Although that test passes, I noticed that the sequences in its input dataset (tr.data.music_streaming_testing_data) all have the same length (max_sequence_length=20). When I tried to replicate the example with a dataset that has variable input sequence lengths (e.g., the dataset used in the example notebooks), model training fails with a matrix multiplication shape-mismatch error.
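
As a quick (hypothetical) sanity check, reading the processed parquet produced by the reproduction script below can be used to confirm that the list columns have varying lengths rather than a fixed length of 20; the path and column name are assumptions taken from that script:

import pandas as pd

# Sanity-check sketch: path and column name are taken from the reproduction script below.
sessions = pd.read_parquet("./data/processed_nvt/part_0.parquet")
lengths = sessions["item_id-list"].apply(len)
print(lengths.describe())  # per-session lengths vary instead of all being 20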

Steps/Code to reproduce bug

Use the dataset generated in the official 01-ETL-with-NVTabular.ipynb example to feed a model with pre-trained embeddings. Run the following code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import glob

import cudf
import numpy as np
import pandas as pd

import nvtabular as nvt
from nvtabular.ops import *
from merlin.schema.tags import Tags

from transformers4rec.utils.data_utils import save_time_based_splits

import torch 
from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt
from transformers4rec.torch.utils.examples_utils import wipe_memory

from merlin.schema import Schema
from merlin.io import Dataset

from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
# Imports needed for the pre-trained embeddings data loader used below
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from transformers4rec.torch.utils.data_utils import MerlinDataLoader


INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "./data/")
NUM_ROWS = os.environ.get("NUM_ROWS", 100000)
long_tailed_item_distribution = np.clip(np.random.lognormal(3., 1., int(NUM_ROWS)).astype(np.int32), 1, 50000)
# generate random item interaction features 
df = pd.DataFrame(np.random.randint(70000, 90000, int(NUM_ROWS)), columns=['session_id'])
df['item_id'] = long_tailed_item_distribution

# generate category mapping for each item-id
df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)
df['age_days'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)
df['weekday_sin'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)

# generate day mapping for each session 
map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique()))))
df['day'] = df.session_id.map(map_day)
SESSIONS_MAX_LENGTH = 20

# Categorify categorical features
categ_feats = ['item_id', 'category'] >> nvt.ops.Categorify()

# Define Groupby Workflow
groupby_feats = categ_feats + ['session_id', 'day', 'age_days', 'weekday_sin']

# Group interaction features by session
groupby_features = groupby_feats >> nvt.ops.Groupby(
    groupby_cols=["session_id"], 
    aggs={
        "item_id": ["list", "count"],
        "category": ["list"],     
        "day": ["first"],
        "age_days": ["list"],
        'weekday_sin': ["list"],
        },
    name_sep="-")

# Select and truncate the sequential features
sequence_features_truncated = (
    groupby_features['category-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) 
)

sequence_features_truncated_item = (
    groupby_features['item_id-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) 
    >> TagAsItemID()
)  
sequence_features_truncated_cont = (
    groupby_features['age_days-list', 'weekday_sin-list'] 
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) 
    >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])
)

# Filter out sessions with length 1 (not valid for next-item prediction training and evaluation)
MINIMUM_SESSION_LENGTH = 2
selected_features = (
    groupby_features['item_id-count', 'day-first', 'session_id'] + 
    sequence_features_truncated_item +
    sequence_features_truncated + 
    sequence_features_truncated_cont
)

filtered_sessions = selected_features >> nvt.ops.Filter(f=lambda df: df["item_id-count"] >= MINIMUM_SESSION_LENGTH)
seq_feats_list = filtered_sessions['item_id-list', 'category-list', 'age_days-list', 'weekday_sin-list'] >>  nvt.ops.ValueCount()
workflow = nvt.Workflow(filtered_sessions['session_id', 'day-first'] + seq_feats_list)
dataset = nvt.Dataset(df)

# Generate statistics for the features and export parquet files
# this step will generate the schema file
workflow.fit_transform(dataset).to_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt"))
workflow.output_schema
workflow.save(os.path.join(INPUT_DATA_DIR, "workflow_etl"))
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.join(INPUT_DATA_DIR, "sessions_by_day"))
# Read in the processed parquet file
sessions_gdf = cudf.read_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
save_time_based_splits(
    data=nvt.Dataset(sessions_gdf),
    output_dir=OUTPUT_DIR,
    partition_col='day-first',
    timestamp_col='session_id',
)
INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "./data")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/sessions_by_day")
train = Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
schema = train.schema
schema = schema.select_by_name(['item_id-list', 
                            'category-list', 
                            'weekday_sin-list',
                            'age_days-list'])
pretrained_dim = 200
item_cardinality = schema["item_id-list"].int_domain.max + 1
np_emb_item_id = np.random.rand(item_cardinality, pretrained_dim)

embeddings_op = EmbeddingOperator(
    np_emb_item_id, lookup_key="item_id-list", embedding_name="pretrained_item_id_embeddings"
)
# set dataloader with pre-trained embeddings
data_loader = MerlinDataLoader.from_schema(
    schema,
    Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"), schema=schema),
    max_sequence_length=20,
    batch_size=128,
    transforms=[embeddings_op],
    shuffle=False,
)

# set the model schema from data-loader
model_schema = data_loader.output_schema
inputs = tr.TabularSequenceFeatures.from_schema(
    model_schema,
    max_sequence_length=20,
    continuous_projection=64,
    pretrained_output_dims=8,
    normalizer="layer-norm",
    d_output=100,
    masking="mlm",
)
# Define XLNetConfig class and set default parameters for HF XLNet config  
transformer_config = tr.XLNetConfig.build(
    d_model=64, n_head=4, n_layer=2, total_seq_length=20
)
# Define the model block including: inputs, masking, projection and transformer block.
body = tr.SequentialBlock(
    inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config, masking=inputs.masking)
)

# Define the evaluation top-N metrics and the cut-offs
metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True),  
           RecallAt(top_ks=[20, 40], labels_onehot=True)]

# Define a head related to next item prediction task 
head = tr.Head(
    body,
    tr.NextItemPredictionTask(weight_tying=True, 
                              metrics=metrics),
    inputs=inputs,
)

# Get the end-to-end Model class 
model = tr.Model(head)

per_device_train_batch_size = int(os.environ.get(
    "per_device_train_batch_size", 
    '128'
))

per_device_eval_batch_size = int(os.environ.get(
    "per_device_eval_batch_size", 
    '32'
))

train_args = T4RecTrainingArguments(
    data_loader_engine='merlin', 
    dataloader_drop_last = True,
    gradient_accumulation_steps = 1,
    per_device_train_batch_size = per_device_train_batch_size, 
    per_device_eval_batch_size = per_device_eval_batch_size,
    output_dir = "./tmp", 
    learning_rate=0.0005,
    lr_scheduler_type='cosine', 
    learning_rate_num_cosine_cycles_by_epoch=1.5,
    num_train_epochs=5,
    max_sequence_length=20, 
    report_to = [],
    logging_steps=50,
)
# Explicitly pass the merlin dataloader with pre-trained embeddings
trainer = Trainer(
    model=model,
    args=train_args,
    schema=schema,
    train_dataloader=data_loader,
    eval_dataloader=data_loader,
    compute_metrics=True,
)
trainer.train()
eval_metrics = trainer.evaluate(eval_dataset=os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"), metric_key_prefix="eval")

Training fails with the following traceback:

File "/opt/ml/code/train.py", line 259, in demo
  trainer.train()
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1633, in train
  return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1902, in _inner_training_loop
  tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2645, in training_step
  loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/trainer.py", line 323, in compute_loss
  outputs = model(inputs, targets=targets, training=True)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 560, in forward
  head_output = head(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 382, in forward
  body_outputs = self.body(body_outputs, training=training, testing=testing, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
  return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 256, in forward
  input = module(input, training=training, testing=testing)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
  return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 392, in __call__
  outputs = super().__call__(inputs, *args, **kwargs)  # noqa
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/sequence.py", line 253, in forward
  outputs = super(TabularSequenceFeatures, self).forward(inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 604, in forward
  outputs.update(layer(inputs))
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
  return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 392, in __call__
  outputs = super().__call__(inputs, *args, **kwargs)  # noqa
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/embedding.py", line 700, in forward
  output = {key: self.projection[key](val) for key, val in output.items()}
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/embedding.py", line 700, in <dictcomp>
  output = {key: self.projection[key](val) for key, val in output.items()}
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
  return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x208 and 200x8)

Expected behavior

Successful training and evaluation of a model using pre-trained embeddings.

Environment details

  • Transformers4Rec version: nvcr.io/nvidia/merlin/merlin-pytorch:23.12
  • Platform:
  • Python version:
  • Huggingface Transformers version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?):

Additional context

Note that if the dataset sequences are pre-padded (e.g., using nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True, pad_value=0) in the NVTabular workflow), model training with pre-trained embeddings works as expected. However, that workaround is not ideal, since the padding is applied to the right side of the sequence and every sequence is padded to SESSIONS_MAX_LENGTH instead of to the maximum length of each batch.
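
For completeness, here is a minimal sketch of that workaround applied to the item-id branch of the workflow above (the other list features would need the same change); it pads every session on the right up to SESSIONS_MAX_LENGTH rather than to the per-batch maximum:

# Workaround sketch: pre-pad the truncated item-id sequences inside the
# NVTabular workflow so the list feature always has length SESSIONS_MAX_LENGTH.
sequence_features_truncated_item = (
    groupby_features['item_id-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True, pad_value=0)
    >> TagAsItemID()
)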
