diff --git a/abandoned_source/minibatching_try/json_data_module.py b/abandoned_source/minibatching_try/json_data_module.py new file mode 100644 index 0000000..9ed75dc --- /dev/null +++ b/abandoned_source/minibatching_try/json_data_module.py @@ -0,0 +1,154 @@ +from typing import Union +from pathlib import Path +import json +from itertools import groupby +from collections import defaultdict +import string +from numbers import Number + +from more_itertools import sort_together + +import torch +from torch.utils.data import DataLoader, Dataset +import pytorch_lightning as pl +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from json_parser import JSONParseTree + +ALL_CHARACTERS = chr(3) + string.printable +NUM_CHARACTERS = len(ALL_CHARACTERS) +JSON_TYPES = ('___stop___', + '___object___', '___array___', + '___string___', '___number___', + '___bool___', '___null___') + +JSON_PRIMITIVES = JSON_TYPES[3:] + +NUM_JSON_TYPES = len(JSON_TYPES) + + +def collate_tree_batches(batch): + max_sequence_lengths = {identifier: max([len(sample_data['leaf_data'])]) + for sample in batch + for identifier, sample_data in sample.items()} + tensor_dict = defaultdict(list) + tree_dict = defaultdict(list) + type_dict = defaultdict(list) + tree_index_dict = defaultdict(list) + + unique_trees = list({value['parse_tree'] for data in batch for identifier, value in data.items()}) + + for sample in batch: + for identifier, sample_data in sample.items(): + tensor_dict[identifier].append(sample_data['leaf_data']) + tree_dict[identifier].append(unique_trees) + type_dict[identifier].append(sample_data['type']) + tree_index_dict[identifier].append(unique_trees.index(sample_data['parse_tree'])) + + collated_samples = {} + for identifier in tensor_dict.keys(): + trees = tree_dict[identifier] + tree_index = tree_index_dict[identifier] + # TODO: isinstance check is not enough, to perfectly separate cases we need type information from the parse tree + types = type_dict[identifier] + 
tensors = tensor_dict[identifier] # pad_sequence(tensors, padding_value=0) if isinstance(tensors[0], torch.LongTensor) else torch.cat(tensors, dim=0) + type_masks = {type: torch.BoolTensor([tp == type for tp in types]) for type in types} + masked_tensors = {type: cat_or_pad([ + tensor for tensor, m in zip(tensors, mask) if m + ], type=type) + for type, mask in type_masks.items() + } + collated_samples[identifier] = { + 'type': types, + 'leaf_data': masked_tensors, + 'parse_trees': trees, + 'tree_index': tree_index, + } + + return collated_samples + + +def cat_or_pad(tensors, type): + if type == '___string___': + return pad_sequence(tensors, padding_value=0) + + return torch.cat(tensors, dim=0) + + +class LeafDataToTensor: + """Convert leaf data to tensors""" + all_characters = string.printable + + def __call__(self, sample): + tensor_sample = {identifier: { + 'type': data['type'], + 'leaf_data': self._transform(data['leaf_data']), + 'parse_tree': data['parse_tree'] + } + for identifier, data in sample.items()} + + return tensor_sample + + def _transform(self, value): + + if isinstance(value, Number) or isinstance(value, bool): + data = torch.Tensor([[value]]) + elif isinstance(value, str): + data = torch.LongTensor( + [ALL_CHARACTERS.index(char) for char in value] + ) + else: + data = torch.zeros(1, 1) + + return data + + +class SimpleDataset(Dataset): + """Simple dataset for json data""" + + def __init__(self, data_file, transform=None): + + with open(data_file, 'r') as json_file: + self.jsons = json.load(json_file) + self.transform = transform + + def __len__(self): + return len(self.jsons) + + def __getitem__(self, idx): + if not isinstance(idx, int): + raise ValueError('Index has to be a single integer') + + tree = JSONParseTree.parse_start('___root___', self.jsons[idx]) + + sample = {identifier: {'type': datum.type, 'leaf_data': datum.data, 'parse_tree': tree} + for identifier, datum in tree.leaf_data()} + + if self.transform: + sample = self.transform(sample) 
+ + return sample + + +class JSONDataModule(pl.LightningDataModule): + + def __init__(self, data_dir: Union[str, Path] = './'): + super().__init__() + + self.data_dir = Path(data_dir) + + def setup(self, stage=None): + self.json_data = SimpleDataset(self.data_dir, transform=LeafDataToTensor()) + + def train_dataloader(self): + return DataLoader(self.json_data, collate_fn=collate_tree_batches, batch_size=4) + + +if __name__ == '__main__': + from pprint import pprint + + jsons = JSONDataModule('../../some_json.json') + jsons.setup() + + for batch in jsons.train_dataloader(): + pprint(batch) diff --git a/abandoned_source/minibatching_try/json_lstm_encoder.py b/abandoned_source/minibatching_try/json_lstm_encoder.py new file mode 100644 index 0000000..65363b2 --- /dev/null +++ b/abandoned_source/minibatching_try/json_lstm_encoder.py @@ -0,0 +1,233 @@ +from typing import Callable, Dict, Any +from collections import defaultdict, namedtuple +import heapq + +import torch +from torch import nn +from more_itertools import sort_together + +from json_data_module import NUM_CHARACTERS, JSONDataModule, JSON_TYPES, JSON_PRIMITIVES +from json_parser import JSONParseTree + + +NodeEmbedding = namedtuple('NodeEmbedding', ['memory', 'hidden']) + + +def first_true(iterable, default=False, pred=None): + """Returns the first true value in the iterable. + + If no true value is found, returns *default* + + If *pred* is not None, returns the first item + for which pred(item) is true. 
+ + """ + # first_true([a,b,c], x) --> a or b or c or x + # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x + return next(filter(pred, iterable), default) + + +class DefaultModuleDict(nn.ModuleDict): + def __init__(self, default_factory: Callable, *args, **kwargs): + super().__init__(*args, **kwargs) + self.default_factory = default_factory + + def __getitem__(self, item): + try: + return super(DefaultModuleDict, self).__getitem__(item) + except (NameError, KeyError): + return self.__missing__(item) + + def __missing__(self, key): + # Taken from json2vec.py + if self.default_factory is None: + raise RuntimeError('default_factory is not set') + else: + ret = self[key] = self.default_factory() + return ret + +class TypeModule: + def __init__(self, default_factory: Callable): + self.default_factory = default_factory + + def __set_name__(self, owner, name): + self.public_name = name + self.private_name = '_' + name + + def __get__(self, obj, objtype=None): + if obj not in getattr(obj, self.private_name): + setattr(obj, self.private_name, nn.ModuleDict()) + + +class ChildSumTreeLSTM(nn.Module): + def __init__(self, mem_dim: int): + super().__init__() + self.mem_dim = mem_dim + + self.childsum_forget = nn.Linear(mem_dim, mem_dim) + self.childsum_iou = nn.Linear(mem_dim, 3 * mem_dim) + + def forward(self, children_memory: torch.Tensor, children_hidden: torch.Tensor): + """ + Tensor shape: object_size X sample_indices X embedding_size + """ + hidden_sum = torch.sum(children_hidden, dim=0) + + forget_gates = torch.sigmoid(self.childsum_forget(children_hidden)) + sigmoid_gates, tanh_gate = torch.split( + self.childsum_iou(hidden_sum), + (2 * self.mem_dim, self.mem_dim), + dim=1 + ) + input_gate, output_gate = torch.split( + torch.sigmoid(sigmoid_gates), + (self.mem_dim, self.mem_dim), + dim=1 + ) + memory_gate = torch.tanh(tanh_gate) + node_memory = input_gate * memory_gate + torch.sum(forget_gates * children_memory, dim=0) + node_hidden = output_gate * 
torch.tanh(node_memory) + + return NodeEmbedding(node_memory, node_hidden) + + +class JSONLSTMEncoder(nn.Module): + def __init__(self, mem_dim): + super().__init__() + + self.mem_dim = mem_dim + + self.__build_model() + + def __build_model(self): + padding_index = 0 + self.string_embedding = nn.Embedding( + num_embeddings=NUM_CHARACTERS, + embedding_dim = self.mem_dim, + padding_idx=padding_index, + ) + self.object_lstm = ChildSumTreeLSTM(self.mem_dim) + + self._string_modules = DefaultModuleDict(lambda: nn.LSTM(self.mem_dim, self.mem_dim)) + self._number_modules = DefaultModuleDict(lambda: nn.Linear(1, self.mem_dim)) + self._bool_modules = DefaultModuleDict(lambda: nn.Linear(1, self.mem_dim)) + self._null_modules = DefaultModuleDict(lambda: nn.Linear(1, self.mem_dim)) + self._array_modules = DefaultModuleDict(lambda: nn.LSTM(2 * self.mem_dim, self.mem_dim)) + self._object_modules = DefaultModuleDict(lambda: nn.Linear(self.mem_dim, self.mem_dim)) + + def _empty_tensor(self, batch_size): + return torch.zeros(batch_size, self.mem_dim) + + def forward(self, batch: Dict[str, Any]): + unique_trees = {tree for keys, data in batch.items() for tree in data['parse_tree']} + # tree_indices = [i for i, tree in zip(data['sample_index'], data['parse_tree'])] + if len(unique_trees) == -1: + """Traverse the tree""" + tree: JSONParseTree = unique_trees.pop() + root = tree[tree.root] + root_type = root.data + return self.embed_node(root, root_type, batch) + + root = ('___root___',) + return self.embed_node(root, batch) + + def embed_node(self, node, batch): + # Find batch node type corresponding to each sample index + for child_name, child_data in batch.items(): + batch_size = len(child_data['parse_tree']) + + accumulated_memory = self._empty_tensor(batch_size) + accumulated_hidden = self._empty_tensor(batch_size) + for type, data in child_data['leaf_data'].items(): + index_tensor = torch.LongTensor([[i for i, tp in enumerate(child_data['type']) if tp == type]]*self.mem_dim).t() + 
if type in JSON_PRIMITIVES: + result = self.embed_leaf(child_name, type, data) + accumulated_memory.scatter_(dim=0, index=index_tensor, src=result.memory) + accumulated_hidden.scatter_(dim=0, index=index_tensor, src=result.hidden) + a = 1 + 2 + """ + type_indices = { + node_type: torch.arange(len(data['type']))[ + torch.BoolTensor([True if type == node_type else False for type in data['type']])] + for node_type in JSON_TYPES[1:] + } + batch_size = len(data['type']) + accumulated_memory = self._empty_tensor(batch_size) + accumulated_hidden = self._empty_tensor((batch_size)) + for child_type, child_index in type_indices.items(): + if len(child_index) == 0: + continue + temp_value_tensor = data['leaf_data'].index_select(1, child_index) + temp_batch = {'type': child_type, 'leaf_data': temp_value_tensor, 'parse_tree': data['parse_tree']} + if child_type in JSON_PRIMITIVES: + child_embeddings = self.embed_leaf(identifier, temp_batch) + accumulation_mask = torch.IntTensor([[1] if i in child_index else [0] for i in range(batch_size)]) + accumulated_memory += accumulation_mask * child_embeddings.memory + accumulated_hidden += accumulation_mask * child_embeddings.hidden + + node_embeddings = NodeEmbedding(accumulated_memory, accumulated_hidden) + a = 1 + 2 + """ + + def embed_leaf(self, identifier, node_type, tensors): + if node_type == '___string___': + node_embedding = self.embed_string(tensors, identifier) + elif node_type == '___number___': + node_embedding = self.embed_number(tensors, identifier) + elif node_type == '___bool___': + node_embedding = self.embed_bool(tensors, identifier) + elif node_type == '___null___': + node_embedding = self.embed_null(tensors, identifier) + else: + raise ValueError(f'node is of unknown type {node_type}') + + return node_embedding + + def embed_object(self, identifier, node_embeddings: NodeEmbedding) -> NodeEmbedding: + memory, hidden = node_embeddings + memory, hidden = self.object_lstm(memory, hidden) + hidden = 
self._object_modules[str(identifier)](hidden) + + return NodeEmbedding(memory, hidden) + + def embed_array(self, identifier, node_embeddings: NodeEmbedding): + memory, hidden = node_embeddings + + + def embed_string(self, string_batch, key): + batch_size = string_batch.shape[1] + string_embeddings = self.string_embedding(string_batch) + _, (memory, hidden) = self._string_modules[str(key)](string_embeddings) + return NodeEmbedding(self._empty_tensor(batch_size), hidden.view(batch_size, -1)) + + def embed_number(self, number_batch: torch.Tensor, key): + batch_size = len(number_batch) + if len(number_batch) > 1: + # TODO: This is unstable and should be fixed vvvvvvvvvvvvvvvvvvvvv + number_batch = (number_batch - torch.mean(number_batch, dim=0)) / (torch.Tensor((1e-21,)) + torch.std(number_batch, dim=0)) + + return NodeEmbedding(self._empty_tensor(batch_size), self._number_modules[str(key)](number_batch)) + + def embed_bool(self, bool_batch, key): + batch_size = len(bool_batch) + return NodeEmbedding(self._empty_tensor(batch_size), self._bool_modules[str(key)](bool_batch)) + + def embed_null(self, null_batch, key): + batch_size = len(null_batch) + return NodeEmbedding(self._empty_tensor(batch_size), self._empty_tensor(batch_size)) + + +if __name__ == '__main__': + data_module = JSONDataModule('../../some_json.json') + + test = JSONLSTMEncoder(128) + + data_module.setup() + for batch in data_module.train_dataloader(): + print('#### NEW BATCH ####') + print(test(batch).memory.shape) + + + #print([module for module in test.named_modules()]) + + diff --git a/json_data_module.py b/json_data_module.py index f912fa9..2ecf91f 100644 --- a/json_data_module.py +++ b/json_data_module.py @@ -15,32 +15,17 @@ from json_parser import JSONParseTree + ALL_CHARACTERS = chr(3) + string.printable NUM_CHARACTERS = len(ALL_CHARACTERS) +JSON_TYPES = ('___stop___', + '___object___', '___array___', + '___string___', '___number___', + '___bool___', '___null___') +JSON_PRIMITIVES = 
JSON_TYPES[3:] -def collate_tree_batches(batch): - max_sequence_lengths = {identifier: max([len(sample_data['leaf_data'])]) - for sample in batch - for identifier, sample_data in sample.items()} - tensor_dict = defaultdict(list) - index_dict = defaultdict(list) - - for sample in batch: - for identifier, sample_data in sample.items(): - tensor_dict[identifier].append(sample_data['leaf_data']) - index_dict[identifier].append(sample_data['sample_index']) - - collated_samples = {} - for (identifier, index) in index_dict.items(): - tensors = tensor_dict[identifier] - # TODO: isinstance check is not enough, to perfectly separate cases we need type information from the parse tree - tensors = pad_sequence(tensors, padding_value=0) if isinstance(tensors[0], torch.LongTensor) else torch.cat(tensors, dim=0) - index = torch.cat(index, dim=0) - collated_samples[identifier] = {'sample_index': index, 'leaf_data': tensors} - - - return collated_samples +NUM_JSON_TYPES = len(JSON_TYPES) class LeafDataToTensor: @@ -48,8 +33,12 @@ class LeafDataToTensor: all_characters = string.printable def __call__(self, sample): - tensor_sample = {identifier: {'sample_index': torch.LongTensor([[data['sample_index']]]), 'leaf_data': self._transform(data['leaf_data'])} - for identifier, data in sample.items()} + tensor_sample = {identifier: { + 'type': data['type'], + 'leaf_data': self._transform(data['leaf_data']), + 'parse_tree': data['parse_tree'] + } + for identifier, data in sample.items()} return tensor_sample @@ -66,6 +55,7 @@ def _transform(self, value): return data + class SimpleDataset(Dataset): """Simple dataset for json data""" @@ -74,7 +64,7 @@ def __init__(self, data_file, transform=None): with open(data_file, 'r') as json_file: self.jsons = json.load(json_file) self.transform = transform - + def __len__(self): return len(self.jsons) @@ -84,7 +74,8 @@ def __getitem__(self, idx): tree = JSONParseTree.parse_start('___root___', self.jsons[idx]) - sample = {identifier: {'sample_index': 
idx, 'leaf_data': datum} for identifier, datum in tree.leaf_data()} + sample = {identifier: {'type': datum.type, 'leaf_data': datum.data, 'parse_tree': tree} + for identifier, datum in tree.leaf_data()} if self.transform: sample = self.transform(sample) @@ -103,17 +94,14 @@ def setup(self, stage=None): self.json_data = SimpleDataset(self.data_dir, transform=LeafDataToTensor()) def train_dataloader(self): - return DataLoader(self.json_data, collate_fn=collate_tree_batches, batch_size=4) + return DataLoader(self.json_data, collate_fn=lambda batch: batch, batch_size=4) if __name__ == '__main__': - from pprint import pprint - jsons = JSONDataModule('some_json.json') - jsons.setup() - - for batch in jsons.train_dataloader(): - pprint(batch) - - - - + data = SimpleDataset('some_json.json') + data_module = JSONDataModule('some_json.json'); data_module.setup() + print(data[0]) + for i, batch in enumerate(data_module.train_dataloader()): + if i != 0: + continue + print(len(batch)) diff --git a/json_parser.py b/json_parser.py index 56147d7..4734738 100644 --- a/json_parser.py +++ b/json_parser.py @@ -5,10 +5,14 @@ from numbers import Number from more_itertools import sort_together, split_when, bucket from itertools import groupby +from collections import namedtuple import torch +NodeData = namedtuple('NodeData', ['type', 'data']) + + def is_primitive(node): return ( isinstance(node, Number) @@ -17,6 +21,22 @@ def is_primitive(node): or node is None ) +def node_type(node): + if isinstance(node, bool): + return '___bool___' + elif isinstance(node, Number): + return '___number___' + elif isinstance(node, str): + return '___string___' + elif isinstance(node, dict): + return '___object___' + elif isinstance(node, list): + return '___array___' + elif node is None: + return '___null___' + else: + raise ValueError('node must be of type numbers.Number, bool, str, or NoneType') + def is_array(node): return ( @@ -39,7 +59,7 @@ def __init__(self, *args, **kwargs): @classmethod def 
parse_start(cls, root_name: str, node: Any) -> JSONParseTree: tree = cls() - tree.create_node(tag=root_name, identifier=(root_name,), data=node) + tree.create_node(tag=root_name, identifier=(root_name,), data=NodeData(node_type(node), None)) if is_array(node): for i, child in enumerate(node): tree.parse_object((root_name,), str(i), child) @@ -50,12 +70,18 @@ def parse_start(cls, root_name: str, node: Any) -> JSONParseTree: def parse_object(self, parent_path: tuple, name: str, node: Any): if is_primitive(node): - self.create_node(tag=name, identifier=parent_path + (name,), parent=parent_path, data=node) + self.create_node( + tag=name, + identifier=parent_path + (name,), + parent=parent_path, + data=NodeData(node_type(node), node) + ) elif is_array(node): self.create_node( tag=name, identifier=(new_path := parent_path + (name,)), - parent=parent_path, data='___array___' + parent=parent_path, + data=NodeData('___array___', None) ) for i, child in enumerate(node): self.parse_object(new_path, str(i), child) @@ -63,11 +89,12 @@ def parse_object(self, parent_path: tuple, name: str, node: Any): self.create_node( tag=name, identifier=(new_path := parent_path + (name,)), - parent=parent_path, data='___dict___') + parent=parent_path, + data=NodeData('___object___', None)) for child_name, child in node.items(): self.parse_object(new_path, child_name, child) - def leaf_data(self) -> Tuple[Tuple, Any]: + def leaf_data(self) -> Tuple[Tuple, NodeData]: for leaf in self.leaves(): yield leaf.identifier, leaf.data @@ -83,6 +110,17 @@ def leaf_tensors(self) -> Tuple[Tuple, Any]: data = torch.zeros(1, 1) yield leaf.identifier, data + def __eq__(self, other): + if not isinstance(other, JSONParseTree): + return False + self_nodes = sorted(self.nodes) + other_nodes = sorted(other.nodes) + + return self_nodes == other_nodes + + def __hash__(self): + return hash(tuple(sorted(self.nodes))) + if __name__ == '__main__': from pprint import pprint diff --git a/some_json.json b/some_json.json index 
1d49de1..775e583 100644 --- a/some_json.json +++ b/some_json.json @@ -1 +1,7 @@ -[{"n": "OO_temp_sensor", "t": 0, "u": "K", "v": 290.02483570765054}, {"n": "CC_temp_sensor", "t": 0, "u": "K", "v": 290.032384426905}, {"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, {"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, {"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, {"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, {"n": "NE_temp_sensor", "t": 0, "u": "K", "v": 289.97652628070324}] \ No newline at end of file +[{"n": "OO_temp_sensor", "t": 0, "u": 1, "v": 290.02483570765054}, + {"n": "CC_temp_sensor", "t": [0, 0, 0], "u": "K", "v": 290.032384426905}, + {"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, + {"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, + {"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, + {"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, + {"n": "NE_temp_sensor", "t": 0, "u": ["K", "K"], "v": 289.97652628070324}] \ No newline at end of file