Commit
tried making batching work, it's too difficult so I'm reverting back to going one sample at a time.
ajoino committed Apr 7, 2021
1 parent d3e052b · commit c79a8d2
Showing 5 changed files with 462 additions and 43 deletions.
@@ -0,0 +1,154 @@
from typing import Union
from pathlib import Path
import json
from collections import defaultdict
import string
from numbers import Number

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence

from json_parser import JSONParseTree

# chr(3) occupies index 0, which pad_sequence and the string embedding
# both use as the padding index.
ALL_CHARACTERS = chr(3) + string.printable
NUM_CHARACTERS = len(ALL_CHARACTERS)
JSON_TYPES = ('___stop___',
              '___object___', '___array___',
              '___string___', '___number___',
              '___bool___', '___null___')

JSON_PRIMITIVES = JSON_TYPES[3:]

NUM_JSON_TYPES = len(JSON_TYPES)

def collate_tree_batches(batch):
    # Longest leaf-data sequence per identifier across the batch (currently unused).
    max_sequence_lengths = defaultdict(int)
    for sample in batch:
        for identifier, sample_data in sample.items():
            max_sequence_lengths[identifier] = max(max_sequence_lengths[identifier],
                                                   len(sample_data['leaf_data']))
    tensor_dict = defaultdict(list)
    tree_dict = defaultdict(list)
    type_dict = defaultdict(list)
    tree_index_dict = defaultdict(list)

    unique_trees = list({value['parse_tree'] for data in batch for identifier, value in data.items()})

    for sample in batch:
        for identifier, sample_data in sample.items():
            tensor_dict[identifier].append(sample_data['leaf_data'])
            tree_dict[identifier].append(unique_trees)
            type_dict[identifier].append(sample_data['type'])
            tree_index_dict[identifier].append(unique_trees.index(sample_data['parse_tree']))

    collated_samples = {}
    for identifier in tensor_dict.keys():
        trees = tree_dict[identifier]
        tree_index = tree_index_dict[identifier]
        # TODO: this check is not enough; to perfectly separate cases we need type information from the parse tree
        types = type_dict[identifier]
        tensors = tensor_dict[identifier]
        # Group the tensors by JSON type, then pad or concatenate each group.
        type_masks = {node_type: torch.BoolTensor([tp == node_type for tp in types])
                      for node_type in types}
        masked_tensors = {node_type: cat_or_pad([tensor for tensor, m in zip(tensors, mask) if m],
                                                node_type=node_type)
                          for node_type, mask in type_masks.items()}
        collated_samples[identifier] = {
            'type': types,
            'leaf_data': masked_tensors,
            'parse_trees': trees,
            'tree_index': tree_index,
        }

    return collated_samples
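
# Shape sketch of the collated output (hypothetical toy batch, not from the
# original commit): for one identifier holding two strings and one number,
#
#     collated_samples[('___root___', 'name')] == {
#         'type': ['___string___', '___string___', '___number___'],
#         'leaf_data': {'___string___': LongTensor of shape (max_len, 2),
#                       '___number___': FloatTensor of shape (1, 1)},
#         'parse_trees': [unique_trees, unique_trees, unique_trees],
#         'tree_index': [0, 0, 1],
#     }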


def cat_or_pad(tensors, node_type):
    # Strings are variable-length index sequences and need padding;
    # every other leaf type is already a fixed-size tensor and can be concatenated.
    if node_type == '___string___':
        return pad_sequence(tensors, padding_value=0)

    return torch.cat(tensors, dim=0)
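

def _cat_or_pad_example():
    # A minimal sanity sketch (not part of the original commit): the string
    # branch pads to the longest sequence, the default branch concatenates.
    strings = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8])]
    numbers = [torch.Tensor([[1.0]]), torch.Tensor([[2.0]])]
    assert cat_or_pad(strings, node_type='___string___').shape == (3, 2)
    assert cat_or_pad(numbers, node_type='___number___').shape == (2, 1)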


class LeafDataToTensor:
    """Convert leaf data to tensors."""

    def __call__(self, sample):
        tensor_sample = {identifier: {
            'type': data['type'],
            'leaf_data': self._transform(data['leaf_data']),
            'parse_tree': data['parse_tree']
        }
            for identifier, data in sample.items()}

        return tensor_sample

    def _transform(self, value):
        # bool is a subclass of Number, so numbers and bools both end up
        # as 1 x 1 float tensors here.
        if isinstance(value, Number):
            data = torch.Tensor([[value]])
        elif isinstance(value, str):
            data = torch.LongTensor(
                [ALL_CHARACTERS.index(char) for char in value]
            )
        else:
            # Nulls (and anything unrecognized) become a zero tensor.
            data = torch.zeros(1, 1)

        return data


class SimpleDataset(Dataset):
    """Simple dataset for JSON data."""

    def __init__(self, data_file, transform=None):
        with open(data_file, 'r') as json_file:
            self.jsons = json.load(json_file)
        self.transform = transform

    def __len__(self):
        return len(self.jsons)

    def __getitem__(self, idx):
        if not isinstance(idx, int):
            raise ValueError('Index has to be a single integer')

        tree = JSONParseTree.parse_start('___root___', self.jsons[idx])

        sample = {identifier: {'type': datum.type, 'leaf_data': datum.data, 'parse_tree': tree}
                  for identifier, datum in tree.leaf_data()}

        if self.transform:
            sample = self.transform(sample)

        return sample
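
# Hedged sketch of a single raw sample (assumes tuple-path identifiers such
# as ('___root___', 'age'), as suggested by the ('___root___',) tuple in the
# encoder; the exact identifier scheme comes from json_parser):
#
#     {('___root___', 'age'): {'type': '___number___',
#                              'leaf_data': 42,
#                              'parse_tree': <JSONParseTree>}}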


class JSONDataModule(pl.LightningDataModule):

    def __init__(self, data_dir: Union[str, Path] = './'):
        super().__init__()

        # Despite the name, data_dir points at a single JSON file.
        self.data_dir = Path(data_dir)

    def setup(self, stage=None):
        self.json_data = SimpleDataset(self.data_dir, transform=LeafDataToTensor())

    def train_dataloader(self):
        return DataLoader(self.json_data, collate_fn=collate_tree_batches, batch_size=4)


if __name__ == '__main__':
    from pprint import pprint

    jsons = JSONDataModule('../../some_json.json')
    jsons.setup()

    for batch in jsons.train_dataloader():
        pprint(batch)
@@ -0,0 +1,233 @@
from typing import Callable, Dict, Any
from collections import namedtuple

import torch
from torch import nn

from json_data_module import NUM_CHARACTERS, JSONDataModule, JSON_TYPES, JSON_PRIMITIVES
from json_parser import JSONParseTree


# Pairs an LSTM-style cell state ("memory") with its hidden state.
NodeEmbedding = namedtuple('NodeEmbedding', ['memory', 'hidden'])
||
def first_true(iterable, default=False, pred=None): | ||
"""Returns the first true value in the iterable. | ||
If no true value is found, returns *default* | ||
If *pred* is not None, returns the first item | ||
for which pred(item) is true. | ||
""" | ||
# first_true([a,b,c], x) --> a or b or c or x | ||
# first_true([a,b], x, f) --> a if f(a) else b if f(b) else x | ||
return next(filter(pred, iterable), default) | ||
|
||
|
||
class DefaultModuleDict(nn.ModuleDict):
    """nn.ModuleDict with defaultdict semantics: missing keys are built by
    default_factory and registered as submodules on first access."""

    def __init__(self, default_factory: Callable, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.default_factory = default_factory

    def __getitem__(self, item):
        try:
            return super().__getitem__(item)
        except KeyError:
            return self.__missing__(item)

    def __missing__(self, key):
        # Taken from json2vec.py
        if self.default_factory is None:
            raise RuntimeError('default_factory is not set')
        ret = self[key] = self.default_factory()
        return ret
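

def _default_module_dict_example():
    # Hedged usage sketch (not in the original commit): a missing key is
    # created by the factory and then reused on later lookups.
    modules = DefaultModuleDict(lambda: nn.Linear(4, 4))
    layer = modules['age']            # built and registered on first access
    assert modules['age'] is layer    # later lookups return the same module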


class TypeModule:
    """Descriptor that lazily attaches a per-instance nn.ModuleDict.

    Currently unused; JSONLSTMEncoder uses DefaultModuleDict directly."""

    def __init__(self, default_factory: Callable):
        self.default_factory = default_factory

    def __set_name__(self, owner, name):
        self.public_name = name
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        if not hasattr(obj, self.private_name):
            setattr(obj, self.private_name, nn.ModuleDict())
        return getattr(obj, self.private_name)


class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM cell: one forget gate per child, with the
    input/output/update gates computed from the summed child hidden states."""

    def __init__(self, mem_dim: int):
        super().__init__()
        self.mem_dim = mem_dim

        self.childsum_forget = nn.Linear(mem_dim, mem_dim)
        self.childsum_iou = nn.Linear(mem_dim, 3 * mem_dim)

    def forward(self, children_memory: torch.Tensor, children_hidden: torch.Tensor):
        """
        Tensor shape: num_children x batch_size x mem_dim
        """
        hidden_sum = torch.sum(children_hidden, dim=0)

        # One forget gate per child; the i/o/u gates share the summed hidden state.
        forget_gates = torch.sigmoid(self.childsum_forget(children_hidden))
        sigmoid_gates, tanh_gate = torch.split(
            self.childsum_iou(hidden_sum),
            (2 * self.mem_dim, self.mem_dim),
            dim=1
        )
        input_gate, output_gate = torch.split(
            torch.sigmoid(sigmoid_gates),
            (self.mem_dim, self.mem_dim),
            dim=1
        )
        memory_gate = torch.tanh(tanh_gate)
        node_memory = input_gate * memory_gate + torch.sum(forget_gates * children_memory, dim=0)
        node_hidden = output_gate * torch.tanh(node_memory)

        return NodeEmbedding(node_memory, node_hidden)
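

def _child_sum_example():
    # Hedged shape sketch (not in the original commit): three children,
    # batch of two, mem_dim 8 -> one parent embedding per batch element.
    cell = ChildSumTreeLSTM(mem_dim=8)
    children_memory = torch.randn(3, 2, 8)
    children_hidden = torch.randn(3, 2, 8)
    parent = cell(children_memory, children_hidden)
    assert parent.memory.shape == (2, 8)
    assert parent.hidden.shape == (2, 8)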


class JSONLSTMEncoder(nn.Module):
    def __init__(self, mem_dim):
        super().__init__()

        self.mem_dim = mem_dim

        self.__build_model()

    def __build_model(self):
        padding_index = 0
        self.string_embedding = nn.Embedding(
            num_embeddings=NUM_CHARACTERS,
            embedding_dim=self.mem_dim,
            padding_idx=padding_index,
        )
        self.object_lstm = ChildSumTreeLSTM(self.mem_dim)

        # One lazily-created module per identifier and JSON type.
        self._string_modules = DefaultModuleDict(lambda: nn.LSTM(self.mem_dim, self.mem_dim))
        self._number_modules = DefaultModuleDict(lambda: nn.Linear(1, self.mem_dim))
        self._bool_modules = DefaultModuleDict(lambda: nn.Linear(1, self.mem_dim))
        self._null_modules = DefaultModuleDict(lambda: nn.Linear(1, self.mem_dim))
        self._array_modules = DefaultModuleDict(lambda: nn.LSTM(2 * self.mem_dim, self.mem_dim))
        self._object_modules = DefaultModuleDict(lambda: nn.Linear(self.mem_dim, self.mem_dim))

    def _empty_tensor(self, batch_size):
        return torch.zeros(batch_size, self.mem_dim)

    def forward(self, batch: Dict[str, Any]):
        # Each collated entry stores the list of unique trees under 'parse_trees'.
        unique_trees = {tree for data in batch.values()
                        for trees in data['parse_trees'] for tree in trees}
        if len(unique_trees) == -1:
            # Note: len() is never -1, so this branch is effectively disabled;
            # it is a leftover of the abandoned per-tree batching attempt.
            tree: JSONParseTree = unique_trees.pop()
            root = tree[tree.root]
            return self.embed_node(root, batch)

        root = ('___root___',)
        return self.embed_node(root, batch)

    def embed_node(self, node, batch):
        # Find the batch node type corresponding to each sample index.
        for child_name, child_data in batch.items():
            batch_size = len(child_data['parse_trees'])

            accumulated_memory = self._empty_tensor(batch_size)
            accumulated_hidden = self._empty_tensor(batch_size)
            for node_type, data in child_data['leaf_data'].items():
                # Rows of the batch holding this type, repeated across mem_dim
                # columns so scatter_ can write whole embedding rows.
                index_tensor = torch.LongTensor(
                    [[i for i, tp in enumerate(child_data['type']) if tp == node_type]] * self.mem_dim
                ).t()
                if node_type in JSON_PRIMITIVES:
                    result = self.embed_leaf(child_name, node_type, data)
                    accumulated_memory.scatter_(dim=0, index=index_tensor, src=result.memory)
                    accumulated_hidden.scatter_(dim=0, index=index_tensor, src=result.hidden)
        # TODO: combine the accumulated leaf embeddings into interior-node
        # embeddings and return a NodeEmbedding.

        # Remnant of the abandoned batching attempt, kept for reference:
        """
        type_indices = {
            node_type: torch.arange(len(data['type']))[
                torch.BoolTensor([True if type == node_type else False for type in data['type']])]
            for node_type in JSON_TYPES[1:]
        }
        batch_size = len(data['type'])
        accumulated_memory = self._empty_tensor(batch_size)
        accumulated_hidden = self._empty_tensor(batch_size)
        for child_type, child_index in type_indices.items():
            if len(child_index) == 0:
                continue
            temp_value_tensor = data['leaf_data'].index_select(1, child_index)
            temp_batch = {'type': child_type, 'leaf_data': temp_value_tensor, 'parse_tree': data['parse_tree']}
            if child_type in JSON_PRIMITIVES:
                child_embeddings = self.embed_leaf(identifier, temp_batch)
            accumulation_mask = torch.IntTensor([[1] if i in child_index else [0] for i in range(batch_size)])
            accumulated_memory += accumulation_mask * child_embeddings.memory
            accumulated_hidden += accumulation_mask * child_embeddings.hidden
        node_embeddings = NodeEmbedding(accumulated_memory, accumulated_hidden)
        """

    def embed_leaf(self, identifier, node_type, tensors):
        # Dispatch on the leaf's JSON type.
        if node_type == '___string___':
            node_embedding = self.embed_string(tensors, identifier)
        elif node_type == '___number___':
            node_embedding = self.embed_number(tensors, identifier)
        elif node_type == '___bool___':
            node_embedding = self.embed_bool(tensors, identifier)
        elif node_type == '___null___':
            node_embedding = self.embed_null(tensors, identifier)
        else:
            raise ValueError(f'node is of unknown type {node_type}')

        return node_embedding

    def embed_object(self, identifier, node_embeddings: NodeEmbedding) -> NodeEmbedding:
        memory, hidden = node_embeddings
        memory, hidden = self.object_lstm(memory, hidden)
        hidden = self._object_modules[str(identifier)](hidden)

        return NodeEmbedding(memory, hidden)

    def embed_array(self, identifier, node_embeddings: NodeEmbedding):
        # TODO: stub; presumably embeds the children via
        # self._array_modules[str(identifier)].
        memory, hidden = node_embeddings

    def embed_string(self, string_batch, key):
        batch_size = string_batch.shape[1]
        string_embeddings = self.string_embedding(string_batch)
        # nn.LSTM returns (output, (h_n, c_n)); keep the final hidden state.
        _, (hidden, memory) = self._string_modules[str(key)](string_embeddings)
        return NodeEmbedding(self._empty_tensor(batch_size), hidden.view(batch_size, -1))

    def embed_number(self, number_batch: torch.Tensor, key):
        batch_size = len(number_batch)
        if len(number_batch) > 1:
            # TODO: This normalization is unstable and should be fixed.
            number_batch = ((number_batch - torch.mean(number_batch, dim=0))
                            / (torch.Tensor((1e-21,)) + torch.std(number_batch, dim=0)))

        return NodeEmbedding(self._empty_tensor(batch_size), self._number_modules[str(key)](number_batch))

    def embed_bool(self, bool_batch, key):
        batch_size = len(bool_batch)
        return NodeEmbedding(self._empty_tensor(batch_size), self._bool_modules[str(key)](bool_batch))

    def embed_null(self, null_batch, key):
        # Nulls carry no information; both states stay zero.
        batch_size = len(null_batch)
        return NodeEmbedding(self._empty_tensor(batch_size), self._empty_tensor(batch_size))
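

def _embed_number_example():
    # Hedged sketch (not in the original commit): with more than one value,
    # embed_number standardizes the batch before the per-key linear projection.
    encoder = JSONLSTMEncoder(mem_dim=8)
    numbers = torch.Tensor([[1.0], [3.0]])
    embedding = encoder.embed_number(numbers, key=('___root___', 'age'))
    assert embedding.hidden.shape == (2, 8)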


if __name__ == '__main__':
    data_module = JSONDataModule('../../some_json.json')

    test = JSONLSTMEncoder(128)

    data_module.setup()
    for batch in data_module.train_dataloader():
        print('#### NEW BATCH ####')
        # NOTE: embed_node does not yet return a NodeEmbedding, so accessing
        # .memory here still fails (see the commit message about reverting).
        print(test(batch).memory.shape)

    # print([module for module in test.named_modules()])