Bug description
I was trying to use the pre-trained embeddings feature, following the test_trainer_with_pretrained_embeddings() unit test. Although that test passes, the sequences in its input dataset (tr.data.music_streaming_testing_data) all have the same length (max_sequence_length=20). When I replicated the example with a dataset that has variable-length input sequences (e.g., the dataset used in the example notebooks), model training fails with a matrix multiplication mismatch error.
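As a quick sanity check that the processed dataset really has variable-length sessions, a minimal sketch (it assumes the default INPUT_DATA_DIR and the processed parquet produced by the reproduction code below):

import pandas as pd

# session lengths should range from MINIMUM_SESSION_LENGTH (2) up to SESSIONS_MAX_LENGTH (20),
# i.e. they are not all equal, unlike tr.data.music_streaming_testing_data
sessions = pd.read_parquet("./data/processed_nvt/part_0.parquet")
print(sessions["item_id-list"].apply(len).describe())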
Steps/Code to reproduce bug
Use the dataset generated by the official 01-ETL-with-NVTabular.ipynb example to feed a model with pre-trained embeddings, and run the following code:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import glob
import cudf
import numpy as np
import pandas as pd
import nvtabular as nvt
from nvtabular.ops import *
from merlin.schema.tags import Tags
from transformers4rec.utils.data_utils import save_time_based_splits
import torch
from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt
from transformers4rec.torch.utils.examples_utils import wipe_memory
from merlin.schema import Schema
from merlin.io import Dataset
from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
# imports needed for the pre-trained embedding operator and the merlin dataloader used below
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from transformers4rec.torch.utils.data_utils import MerlinDataLoader

INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "./data/")
NUM_ROWS = os.environ.get("NUM_ROWS", 100000)
long_tailed_item_distribution = np.clip(np.random.lognormal(3., 1., int(NUM_ROWS)).astype(np.int32), 1, 50000)

# generate random item interaction features
df = pd.DataFrame(np.random.randint(70000, 90000, int(NUM_ROWS)), columns=['session_id'])
df['item_id'] = long_tailed_item_distribution

# generate category mapping for each item-id
df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)
df['age_days'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)
df['weekday_sin'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)

# generate day mapping for each session
map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique()))))
df['day'] = df.session_id.map(map_day)

SESSIONS_MAX_LENGTH = 20

# Categorify categorical features
categ_feats = ['item_id', 'category'] >> nvt.ops.Categorify()

# Define Groupby Workflow
groupby_feats = categ_feats + ['session_id', 'day', 'age_days', 'weekday_sin']

# Group interaction features by session
groupby_features = groupby_feats >> nvt.ops.Groupby(
    groupby_cols=["session_id"],
    aggs={
        "item_id": ["list", "count"],
        "category": ["list"],
        "day": ["first"],
        "age_days": ["list"],
        "weekday_sin": ["list"],
    },
    name_sep="-")
# Select and truncate the sequential features
sequence_features_truncated = (
    groupby_features['category-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
)
sequence_features_truncated_item = (
    groupby_features['item_id-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
    >> TagAsItemID()
)
sequence_features_truncated_cont = (
    groupby_features['age_days-list', 'weekday_sin-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
    >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])
)

# Filter out sessions with length 1 (not valid for next-item prediction training and evaluation)
MINIMUM_SESSION_LENGTH = 2
selected_features = (
    groupby_features['item_id-count', 'day-first', 'session_id']
    + sequence_features_truncated_item
    + sequence_features_truncated
    + sequence_features_truncated_cont
)
filtered_sessions = selected_features >> nvt.ops.Filter(f=lambda df: df["item_id-count"] >= MINIMUM_SESSION_LENGTH)

seq_feats_list = filtered_sessions['item_id-list', 'category-list', 'age_days-list', 'weekday_sin-list'] >> nvt.ops.ValueCount()
workflow = nvt.Workflow(filtered_sessions['session_id', 'day-first'] + seq_feats_list)
dataset = nvt.Dataset(df)

# Generate statistics for the features and export parquet files
# this step will generate the schema file
workflow.fit_transform(dataset).to_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt"))
workflow.output_schema
workflow.save(os.path.join(INPUT_DATA_DIR, "workflow_etl"))
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.join(INPUT_DATA_DIR, "sessions_by_day"))

# Read in the processed parquet file
sessions_gdf = cudf.read_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
save_time_based_splits(
    data=nvt.Dataset(sessions_gdf),
    output_dir=OUTPUT_DIR,
    partition_col='day-first',
    timestamp_col='session_id',
)

INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "./data")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/sessions_by_day")

train = Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
schema = train.schema
schema = schema.select_by_name(['item_id-list',
                                'category-list',
                                'weekday_sin-list',
                                'age_days-list'])

pretrained_dim = 200
item_cardinality = schema["item_id-list"].int_domain.max + 1
np_emb_item_id = np.random.rand(item_cardinality, pretrained_dim)
embeddings_op = EmbeddingOperator(
    np_emb_item_id, lookup_key="item_id-list", embedding_name="pretrained_item_id_embeddings"
)

# set dataloader with pre-trained embeddings
data_loader = MerlinDataLoader.from_schema(
    schema,
    Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"), schema=schema),
    max_sequence_length=20,
    batch_size=128,
    transforms=[embeddings_op],
    shuffle=False,
)

# set the model schema from data-loader
model_schema = data_loader.output_schema

inputs = tr.TabularSequenceFeatures.from_schema(
    model_schema,
    max_sequence_length=20,
    continuous_projection=64,
    pretrained_output_dims=8,
    normalizer="layer-norm",
    d_output=100,
    masking="mlm",
)

# Define XLNetConfig class and set default parameters for HF XLNet config
transformer_config = tr.XLNetConfig.build(
    d_model=64, n_head=4, n_layer=2, total_seq_length=20
)

# Define the model block including: inputs, masking, projection and transformer block.
body = tr.SequentialBlock(
    inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config, masking=inputs.masking)
)
# Define the evaluation top-N metrics and the cut-offs
metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True),
           RecallAt(top_ks=[20, 40], labels_onehot=True)]

# Define a head related to next item prediction task
head = tr.Head(
    body,
    tr.NextItemPredictionTask(weight_tying=True,
                              metrics=metrics),
    inputs=inputs,
)

# Get the end-to-end Model class
model = tr.Model(head)

per_device_train_batch_size = int(os.environ.get(
    "per_device_train_batch_size",
    '128'
))
per_device_eval_batch_size = int(os.environ.get(
    "per_device_eval_batch_size",
    '32'
))
train_args = T4RecTrainingArguments(
    data_loader_engine='merlin',
    dataloader_drop_last=True,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    output_dir="./tmp",
    learning_rate=0.0005,
    lr_scheduler_type='cosine',
    learning_rate_num_cosine_cycles_by_epoch=1.5,
    num_train_epochs=5,
    max_sequence_length=20,
    report_to=[],
    logging_steps=50,
)

# Explicitly pass the merlin dataloader with pre-trained embeddings
trainer = Trainer(
    model=model,
    args=train_args,
    schema=schema,
    train_dataloader=data_loader,
    eval_dataloader=data_loader,
    compute_metrics=True,
)

trainer.train()
eval_metrics = trainer.evaluate(eval_dataset=os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"), metric_key_prefix="eval")
File "/opt/ml/code/train.py", line 259, in demo
trainer.train()
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1633, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1902, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2645, in training_step
loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/trainer.py", line 323, in compute_loss
outputs = model(inputs, targets=targets, training=True)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 560, in forward
head_output = head(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 382, in forward
body_outputs = self.body(body_outputs, training=training, testing=testing, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 256, in forward
input = module(input, training=training, testing=testing)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 392, in __call__
outputs = super().__call__(inputs, *args, **kwargs) # noqa
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/sequence.py", line 253, in forward
outputs = super(TabularSequenceFeatures, self).forward(inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 604, in forward
outputs.update(layer(inputs))
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 392, in __call__
outputs = super().__call__(inputs, *args, **kwargs) # noqa
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/embedding.py", line 700, in forward
output = {key: self.projection[key](val) for key, val in output.items()}
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/embedding.py", line 700, in <dictcomp>
output = {key: self.projection[key](val) for key, val in output.items()}
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x208 and 200x8)
Expected behavior
Successful training and evaluation of a model using pre-trained embeddings.
Environment details
nvcr.io/nvidia/merlin/merlin-pytorch:23.12
Additional context
Note that if the dataset sequences are pre-padded (e.g., using nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True, pad_value=0) in the NVTabular workflow), model training with pre-trained embeddings works as expected. However, that workaround is not ideal, since the padding is applied on the right side of each sequence and every sequence is padded to SESSIONS_MAX_LENGTH instead of to the maximum length within each batch.
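For reference, a minimal sketch of that pre-padding workaround, applied to the same ListSlice steps from the reproduction code above (only the three truncation ops change; the rest of the workflow stays the same):

# workaround sketch: right-pad every truncated list feature to SESSIONS_MAX_LENGTH
sequence_features_truncated = (
    groupby_features['category-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True, pad_value=0)
)
sequence_features_truncated_item = (
    groupby_features['item_id-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True, pad_value=0)
    >> TagAsItemID()
)
sequence_features_truncated_cont = (
    groupby_features['age_days-list', 'weekday_sin-list']
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True, pad_value=0)
    >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])
)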