
Support more stypes in LinearModelEncoder #325

Merged · 15 commits · Jan 9, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]

### Added
- Support more stypes in `LinearModelEncoder` ([#325](https://github.com/pyg-team/pytorch-frame/pull/325))
- Added `stype_encoder_dict` to some models ([#319](https://github.com/pyg-team/pytorch-frame/pull/319))

- Added `HuggingFaceDatasetDict` ([#287](https://github.com/pyg-team/pytorch-frame/pull/287))
141 changes: 140 additions & 1 deletion test/nn/encoder/test_stype_encoder.py
@@ -2,13 +2,18 @@

import pytest
import torch
from torch.nn import ReLU
from torch import Tensor
from torch.nn import Linear, ReLU, Sequential

import torch_frame
from torch_frame import NAStrategy, stype
from torch_frame.config import ModelConfig
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.config.text_tokenizer import TextTokenizerConfig
from torch_frame.data.dataset import Dataset
from torch_frame.data.mapper import TimestampTensorMapper
from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
from torch_frame.data.multi_nested_tensor import MultiNestedTensor
from torch_frame.data.stats import StatType
from torch_frame.datasets import FakeDataset
from torch_frame.nn import (
@@ -23,6 +28,8 @@
    StackEncoder,
    TimestampEncoder,
)
from torch_frame.nn.encoding import CyclicEncoding
from torch_frame.testing.text_embedder import HashTextEmbedder
from torch_frame.testing.text_tokenizer import (
    RandomTextModel,
    WhiteSpaceHashTokenizer,
@@ -426,3 +433,135 @@
        assert torch.allclose(
            feat_text[key].offset,
            tensor_frame.feat_dict[stype.text_tokenized][key].offset)


def test_linear_model_encoder():
    num_rows = 20
    out_channels = 8
    data_stypes = [
        torch_frame.numerical,
        torch_frame.text_embedded,
        torch_frame.timestamp,
        torch_frame.categorical,
        torch_frame.multicategorical,
        torch_frame.embedding,
    ]
    dataset = FakeDataset(
        num_rows=num_rows,
        stypes=data_stypes,
        col_to_text_embedder_cfg=TextEmbedderConfig(
            text_embedder=HashTextEmbedder(out_channels=out_channels),
            batch_size=None,
        ),
    )
    dataset.materialize()
    tensor_frame = dataset.tensor_frame
    stats_list = []
    col_to_model_cfg = {}
    encoder_dict = {}
    for data_stype in data_stypes:
        data_stype = data_stype.parent
        stats_list.extend(
            dataset.col_stats[col_name]
            for col_name in tensor_frame.col_names_dict[data_stype])
        for col_name in tensor_frame.col_names_dict[data_stype]:
            if data_stype == torch_frame.embedding:
                in_channels = dataset.col_stats[col_name][StatType.EMB_DIM]
                model = EmbeddingModel(in_channels, out_channels)
            elif data_stype == torch_frame.numerical:
                model = NumericalModel(out_channels)
            elif data_stype == torch_frame.timestamp:
                model = TimestampModel(out_channels)
            elif data_stype == torch_frame.categorical:
                count_index, _ = dataset.col_stats[col_name][StatType.COUNT]
                model = CategoricalModel(len(count_index), out_channels)
            elif data_stype == torch_frame.multicategorical:
                count_index, _ = dataset.col_stats[col_name][
                    StatType.MULTI_COUNT]
                model = MultiCategoricalModel(len(count_index), out_channels)
            else:
                raise ValueError(f"Stype {data_stype} not supported")

            col_to_model_cfg[col_name] = ModelConfig(model=model,
                                                     out_channels=out_channels)

        encoder_dict[data_stype] = LinearModelEncoder(
            out_channels=out_channels,
            stats_list=stats_list,
            stype=data_stype,
            col_to_model_cfg=col_to_model_cfg,
        )

    for data_stype in data_stypes:
        data_stype = data_stype.parent
        col_names = tensor_frame.col_names_dict[data_stype]
        x = encoder_dict[data_stype](tensor_frame.feat_dict[data_stype],
                                     col_names)
        assert x.shape == (
            num_rows,
            len(tensor_frame.col_names_dict[data_stype]),
            out_channels,
        )


class EmbeddingModel(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.mlp = Sequential(Linear(in_channels, out_channels), ReLU(),
                              Linear(out_channels, out_channels))

    def forward(self, x: MultiEmbeddingTensor) -> Tensor:
        # [batch_size, 1, embedding_size]
        return self.mlp(x.values.unsqueeze(dim=1))


class NumericalModel(torch.nn.Module):
    def __init__(self, out_channels: int):
        super().__init__()
        self.mlp = Sequential(Linear(1, out_channels), ReLU(),
                              Linear(out_channels, out_channels))

    def forward(self, x: Tensor) -> Tensor:
        # [batch_size, 1, 1] -> [batch_size, 1, out_channels]
        return self.mlp(x)


class TimestampModel(torch.nn.Module):
    def __init__(self, out_channels: int):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.empty(
                len(TimestampTensorMapper.TIME_TO_INDEX),
                out_channels,
                out_channels,
            ))
        self.cyclic_encoding = CyclicEncoding(out_size=out_channels)

    def forward(self, x: Tensor) -> Tensor:
        # [batch_size, 1, num_time_feats]
        x = x.to(torch.float32)
        # [batch_size, 1, num_time_feats, out_channels]
        x_cyclic = self.cyclic_encoding(x / x.max())
        # Contract time features against per-feature weight matrices:
        # [batch_size, num_time_feats, out_channels] einsum
        # [num_time_feats, out_channels, out_channels] gives
        # [batch_size, out_channels], then unsqueeze back to
        # [batch_size, 1, out_channels].
        return torch.einsum('ijk,jkl->il', x_cyclic.squeeze(1),
                            self.weight).unsqueeze(dim=1)


class CategoricalModel(torch.nn.Module):
    def __init__(self, num_categories: int, out_channels: int):
        super().__init__()
        self.emb = torch.nn.Embedding(num_categories, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        # [batch_size, 1, 1] -> [batch_size, 1]
        x = x.squeeze(dim=1)
        # [batch_size, 1] -> [batch_size, 1, out_channels]
        return self.emb(x)


class MultiCategoricalModel(torch.nn.Module):
    def __init__(self, num_categories: int, out_channels: int):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_categories, out_channels)

    def forward(self, x: MultiNestedTensor) -> Tensor:
        # Pool each row's category embeddings into [batch_size, out_channels],
        # then unsqueeze to [batch_size, 1, out_channels].
        return self.emb(x.values, x.offset[:-1]).unsqueeze(dim=1)
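Aside on the `x.offset[:-1]` call above: `torch.nn.EmbeddingBag` expects one start offset per bag, while the `MultiNestedTensor` slice here carries `batch_size + 1` boundary offsets (the last entry being the total length of `values`), so the final boundary is dropped. A minimal standalone sketch of the same pattern, with toy values assumed purely for illustration:

import torch

emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)
values = torch.tensor([1, 2, 3, 7])  # two bags flattened: [1, 2, 3] and [7]
offsets = torch.tensor([0, 3])       # start index of each bag within `values`
out = emb(values, offsets)           # pooled embeddings, shape [2, 4]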
34 changes: 26 additions & 8 deletions torch_frame/nn/encoder/stype_encoder.py
@@ -752,12 +752,15 @@ class LinearModelEncoder(StypeEncoder):
    :obj:`[batch_size, 1, model_out_channels]`.
    """

    # NOTE: We currently support text embeddings, but in principle this
    # encoder can support any model that outputs embeddings, including
    # image/audio/graph embeddings.
    # NOTE: This can in principle support any stype and allows MLP-based
    # non-linear modeling for each column.
    supported_stypes = {stype.text_tokenized}
    supported_stypes = {
        stype.text_embedded,
        stype.text_tokenized,
        stype.numerical,
        stype.embedding,
        stype.timestamp,
        stype.categorical,
        stype.multicategorical,
    }

    def __init__(
        self,
@@ -812,14 +815,29 @@ def encode_forward(
    ) -> Tensor:
        xs = []
        for i, col_name in enumerate(col_names):
            # [batch_size, 1, in_channels]
            if self.stype.use_dict_multi_nested_tensor:
                # [batch_size, 1, in_channels]
                x = self.model_dict[col_name]({
                    key: feat[key][:, i]
                    for key in feat
                })
            else:
                x = self.model_dict[col_name](feat[:, i])
                input_feat = feat[:, i]

                # Numerical, categorical, etc.: reshape the per-column
                # feature to [batch_size, 1, *].
                if input_feat.ndim == 1:
                    input_feat = input_feat.view(-1, 1, 1)
                elif input_feat.ndim == 2:
                    input_feat = input_feat.unsqueeze(dim=1)

                assert input_feat.ndim == 3
                if isinstance(input_feat, Tensor):
                    batch_size = input_feat.size(0)
                else:
                    batch_size = input_feat.num_rows
                assert input_feat.shape[:2] == (batch_size, 1)

                x = self.model_dict[col_name](input_feat)
            # [batch_size, 1, out_channels]
            x_lin = x @ self.weight_dict[col_name] + self.bias_dict[col_name]
            xs.append(x_lin)
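
For context, a minimal end-to-end sketch of driving the extended encoder for a single stype, mirroring the test added above (`NumericalModel` is the toy MLP defined there; the fake dataset and channel sizes are assumptions for illustration):

import torch_frame
from torch_frame.config import ModelConfig
from torch_frame.datasets import FakeDataset
from torch_frame.nn import LinearModelEncoder

# Assumes FakeDataset can be materialized with only numerical columns.
dataset = FakeDataset(num_rows=10, stypes=[torch_frame.numerical])
dataset.materialize()
tf = dataset.tensor_frame

cols = tf.col_names_dict[torch_frame.numerical]
encoder = LinearModelEncoder(
    out_channels=8,
    stats_list=[dataset.col_stats[c] for c in cols],
    stype=torch_frame.numerical,
    col_to_model_cfg={
        c: ModelConfig(model=NumericalModel(8), out_channels=8)
        for c in cols
    },
)
# Per-column model output is projected linearly per column;
# the result has shape [10, len(cols), 8].
x = encoder(tf.feat_dict[torch_frame.numerical], cols)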