diff --git a/.gitignore b/.gitignore index 0fca5cb..bf72e12 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,6 @@ cython_debug/ # Pycharm files .idea/ + +# Pytorch stuff +lightning_logs/ \ No newline at end of file diff --git a/json_data_module.py b/json_data_module.py new file mode 100644 index 0000000..f912fa9 --- /dev/null +++ b/json_data_module.py @@ -0,0 +1,119 @@ +from typing import Union +from pathlib import Path +import json +from itertools import groupby +from collections import defaultdict +import string +from numbers import Number + +from more_itertools import sort_together + +import torch +from torch.utils.data import DataLoader, Dataset +import pytorch_lightning as pl +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from json_parser import JSONParseTree + +ALL_CHARACTERS = chr(3) + string.printable +NUM_CHARACTERS = len(ALL_CHARACTERS) + + +def collate_tree_batches(batch): + max_sequence_lengths = {identifier: max([len(sample_data['leaf_data'])]) + for sample in batch + for identifier, sample_data in sample.items()} + tensor_dict = defaultdict(list) + index_dict = defaultdict(list) + + for sample in batch: + for identifier, sample_data in sample.items(): + tensor_dict[identifier].append(sample_data['leaf_data']) + index_dict[identifier].append(sample_data['sample_index']) + + collated_samples = {} + for (identifier, index) in index_dict.items(): + tensors = tensor_dict[identifier] + # TODO: isinstance check is not enough, to perfectly separate cases we need type information from the parse tree + tensors = pad_sequence(tensors, padding_value=0) if isinstance(tensors[0], torch.LongTensor) else torch.cat(tensors, dim=0) + index = torch.cat(index, dim=0) + collated_samples[identifier] = {'sample_index': index, 'leaf_data': tensors} + + + return collated_samples + + +class LeafDataToTensor: + """Convert leaf data to tensors""" + all_characters = string.printable + + def __call__(self, sample): + tensor_sample = {identifier: {'sample_index': torch.LongTensor([[data['sample_index']]]), 'leaf_data': self._transform(data['leaf_data'])} + for identifier, data in sample.items()} + + return tensor_sample + + def _transform(self, value): + + if isinstance(value, Number) or isinstance(value, bool): + data = torch.Tensor([[value]]) + elif isinstance(value, str): + data = torch.LongTensor( + [ALL_CHARACTERS.index(char) for char in value] + ) + else: + data = torch.zeros(1, 1) + + return data + +class SimpleDataset(Dataset): + """Simple dataset for json data""" + + def __init__(self, data_file, transform=None): + + with open(data_file, 'r') as json_file: + self.jsons = json.load(json_file) + self.transform = transform + + def __len__(self): + return len(self.jsons) + + def __getitem__(self, idx): + if not isinstance(idx, int): + raise ValueError('Index has to be a single integer') + + tree = JSONParseTree.parse_start('___root___', self.jsons[idx]) + + sample = {identifier: {'sample_index': idx, 'leaf_data': datum} for identifier, datum in tree.leaf_data()} + + if self.transform: + sample = self.transform(sample) + + return sample + + +class JSONDataModule(pl.LightningDataModule): + + def __init__(self, data_dir: Union[str, Path] = './'): + super().__init__() + + self.data_dir = Path(data_dir) + + def setup(self, stage=None): + self.json_data = SimpleDataset(self.data_dir, transform=LeafDataToTensor()) + + def train_dataloader(self): + return DataLoader(self.json_data, collate_fn=collate_tree_batches, batch_size=4) + + +if __name__ == '__main__': + from pprint import pprint + jsons = JSONDataModule('some_json.json') + jsons.setup() + + for batch in jsons.train_dataloader(): + pprint(batch) + + + + diff --git a/json_parser.py b/json_parser.py new file mode 100644 index 0000000..56147d7 --- /dev/null +++ b/json_parser.py @@ -0,0 +1,136 @@ +from __future__ import annotations +import string +from typing import Union, Sequence, Mapping, Any, Tuple +from treelib import Tree, Node +from numbers import Number +from more_itertools import sort_together, split_when, bucket +from itertools import groupby + +import torch + + +def is_primitive(node): + return ( + isinstance(node, Number) + or isinstance(node, bool) + or isinstance(node, str) + or node is None + ) + + +def is_array(node): + return ( + isinstance(node, list) + or isinstance(node, tuple) + ) + + +def is_object(node): + return isinstance(node, dict) + + +class JSONParseTree(Tree): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.all_characters = string.printable + self.num_characters = len(self.all_characters) + + @classmethod + def parse_start(cls, root_name: str, node: Any) -> JSONParseTree: + tree = cls() + tree.create_node(tag=root_name, identifier=(root_name,), data=node) + if is_array(node): + for i, child in enumerate(node): + tree.parse_object((root_name,), str(i), child) + elif is_object(node): + for child_name, child in node.items(): + tree.parse_object((root_name,), child_name, child) + return tree + + def parse_object(self, parent_path: tuple, name: str, node: Any): + if is_primitive(node): + self.create_node(tag=name, identifier=parent_path + (name,), parent=parent_path, data=node) + elif is_array(node): + self.create_node( + tag=name, + identifier=(new_path := parent_path + (name,)), + parent=parent_path, data='___array___' + ) + for i, child in enumerate(node): + self.parse_object(new_path, str(i), child) + elif is_object(node): + self.create_node( + tag=name, + identifier=(new_path := parent_path + (name,)), + parent=parent_path, data='___dict___') + for child_name, child in node.items(): + self.parse_object(new_path, child_name, child) + + def leaf_data(self) -> Tuple[Tuple, Any]: + for leaf in self.leaves(): + yield leaf.identifier, leaf.data + + def leaf_tensors(self) -> Tuple[Tuple, Any]: + for leaf in self.leaves(): + if isinstance(leaf.data, Number) or isinstance(leaf.data, bool): + data = torch.Tensor([[leaf.data]]) + elif isinstance(leaf.data, str): + data = torch.LongTensor( + [self.all_characters.index(char) for char in leaf.data] + ) + else: + data = torch.zeros(1, 1) + yield leaf.identifier, data + + +if __name__ == '__main__': + from pprint import pprint + import json + + array = [{"test": {"iest": ['stest', [1, 2, 3]]}, "other": [None, 1, True], "empty": [], "empty_2": {}}] * 3 + some_json = json.loads(r""" + [{"n": "OO_temp_sensor", "t": 0, "u": "K", "v": 290.02483570765054}, + {"n": "CC_temp_sensor", "t": 0, "u": "K", "v": 290.032384426905}, + {"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, + {"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, + {"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, + {"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, + {"n": "NE_temp_sensor", "t": 0, "u": "K", "v": 289.97652628070324} + ] + """) + + trees = [JSONParseTree.parse_start('___root___', arr) for arr in some_json] + + sample_identifiers, sample_index, sample_data = sort_together([*zip(*[ + (leaf_identifier, i, leaf_data) for i, tree in enumerate(trees) + for leaf_identifier, leaf_data in tree.leaf_tensors() + ])]) + + from prettytable import PrettyTable + + + def sample_table(identifiers, index, data): + table = PrettyTable(('Identifier', 'Batch index', 'Data')) + for id, idx, dat in zip(identifiers, index, data): + table.add_row([id, idx, dat]) + print(table) + + + sample_table(sample_identifiers, sample_index, sample_data) + + # pprint(list(split_when(zip(sample_identifiers, sample_index, sample_data), lambda x, y: x[0] != y[0]))) + import torch + buck = {k: [sorted_elem + if not isinstance(sorted_elem, str) and sorted_elem is not None else sorted_elem + for sorted_elem in + zip(*sorted(elem[1:] for elem in g))] + for k, g in groupby( + zip( + sample_identifiers, + sample_index, + sample_data + ), lambda x: x[0] + )} + + pprint(buck) diff --git a/minimum_working_example.py b/minimum_working_example.py new file mode 100644 index 0000000..4ea6eac --- /dev/null +++ b/minimum_working_example.py @@ -0,0 +1,46 @@ +from json2vec import JSONTreeLSTM +import torch +from datasets import load_seismic_dataset +from pytorch_lightning_test import JsonTreeSystem + +jsons, vectors, labels = load_seismic_dataset() + +num_classes = 1 +embedder = JsonTreeSystem(mem_dim=64) +output_layer = torch.nn.Linear(128, num_classes) +model = torch.nn.Sequential(embedder) + +some_json = r""" +[{"n": "OO_temp_sensor", "t": 0, "u": "K", "v": 290.02483570765054}, +{"n": "CC_temp_sensor", "t": 0, "u": "K", "v": 290.032384426905}, +{"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, +{"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, +{"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, +{"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, +{"n": "NE_temp_sensor", "t": {"1": 1}, "u": "K", "v": 289.97652628070324} +] +""" +same_json = """ +{"n": "OO_temp_sensor", "u": "K", "v": 290.02483570765054, "t": 0} +""" +""" +from tqdm import tqdm +for some_json in tqdm(jsons): + output_3 = model(some_json) +""" +model("""[1, [2, {"num": {"as_text": "four", "as_int": 4}}]]""") +from prettytable import PrettyTable + +def count_parameters(model): + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in model.named_parameters(): + if not parameter.requires_grad: continue + param = parameter.numel() + table.add_row([name, param]) + total_params+=param + print(table) + print(f"Total Trainable Params: {total_params}") + return total_params + +count_parameters(embedder) \ No newline at end of file diff --git a/pytorch_lightning_test.py b/pytorch_lightning_test.py new file mode 100644 index 0000000..22c0ece --- /dev/null +++ b/pytorch_lightning_test.py @@ -0,0 +1,61 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import Dataset, DataLoader +import pytorch_lightning as pl +import sklearn +pl.LightningDataModule() + +from json2vec import JSONTreeLSTM +from datasets import load_seismic_dataset + + +class SeismicDataset(Dataset): + """Seismic dataset""" + + def __init__(self): + jsons, vectors, labels = load_seismic_dataset() + labels = torch.LongTensor([int(label) for label in labels]) + self.jsons, self.vectors, self.labels = sklearn.utils.shuffle(jsons, vectors, labels) + + def __len__(self): + return len(self.jsons) + + def __getitem__(self, idx): + return self.jsons[idx], self.labels[idx] + + +class JsonTreeSystem(pl.LightningModule): + + def __init__(self, mem_dim=128): + super().__init__() + + self.json_tree_lstm = JSONTreeLSTM(mem_dim=mem_dim) + self.output = nn.Linear(2 * mem_dim, 1) + + def forward(self, *args, **kwargs): + return self.json_tree_lstm(*args, **kwargs) + + def training_step(self, batch, batch_idx): + jsons, labels = batch + labels = torch.LongTensor([int(label) for label in labels]) + output = self.json_tree_lstm(*jsons) + output = torch.sigmoid(self.output(output).view(1)) + + loss = F.binary_cross_entropy(output, labels.float()) + + self.log('train_loss', loss) + + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + +if __name__ == '__main__': + from pprint import pprint + seismic_dataset = SeismicDataset() + train_loader = DataLoader(seismic_dataset) + json_tree = JsonTreeSystem(128) + trainer = pl.Trainer(overfit_batches=1) + trainer.fit(json_tree, train_loader) \ No newline at end of file diff --git a/some_json.json b/some_json.json new file mode 100644 index 0000000..1d49de1 --- /dev/null +++ b/some_json.json @@ -0,0 +1 @@ +[{"n": "OO_temp_sensor", "t": 0, "u": "K", "v": 290.02483570765054}, {"n": "CC_temp_sensor", "t": 0, "u": "K", "v": 290.032384426905}, {"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, {"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, {"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, {"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, {"n": "NE_temp_sensor", "t": 0, "u": "K", "v": 289.97652628070324}] \ No newline at end of file diff --git a/tree.txt b/tree.txt new file mode 100644 index 0000000..2b44c7c --- /dev/null +++ b/tree.txt @@ -0,0 +1,30 @@ +digraph tree { + "('___root___',)" [label="___root___", shape=circle] + "('___root___', 'empty')" [label="empty", shape=circle] + "('___root___', 'empty_2')" [label="empty_2", shape=circle] + "('___root___', 'other')" [label="other", shape=circle] + "('___root___', 'test')" [label="test", shape=circle] + "('___root___', 'other', '0')" [label="0", shape=circle] + "('___root___', 'other', '1')" [label="1", shape=circle] + "('___root___', 'other', '2')" [label="2", shape=circle] + "('___root___', 'test', 'iest')" [label="iest", shape=circle] + "('___root___', 'test', 'iest', '0')" [label="0", shape=circle] + "('___root___', 'test', 'iest', '1')" [label="1", shape=circle] + "('___root___', 'test', 'iest', '1', '0')" [label="0", shape=circle] + "('___root___', 'test', 'iest', '1', '1')" [label="1", shape=circle] + "('___root___', 'test', 'iest', '1', '2')" [label="2", shape=circle] + + "('___root___',)" -> "('___root___', 'test')" + "('___root___',)" -> "('___root___', 'other')" + "('___root___',)" -> "('___root___', 'empty')" + "('___root___',)" -> "('___root___', 'empty_2')" + "('___root___', 'other')" -> "('___root___', 'other', '0')" + "('___root___', 'other')" -> "('___root___', 'other', '1')" + "('___root___', 'other')" -> "('___root___', 'other', '2')" + "('___root___', 'test')" -> "('___root___', 'test', 'iest')" + "('___root___', 'test', 'iest')" -> "('___root___', 'test', 'iest', '0')" + "('___root___', 'test', 'iest')" -> "('___root___', 'test', 'iest', '1')" + "('___root___', 'test', 'iest', '1')" -> "('___root___', 'test', 'iest', '1', '0')" + "('___root___', 'test', 'iest', '1')" -> "('___root___', 'test', 'iest', '1', '1')" + "('___root___', 'test', 'iest', '1')" -> "('___root___', 'test', 'iest', '1', '2')" +} \ No newline at end of file