-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Started working on my own implementation of the json2vec model, which…
… will feature batching!
- Loading branch information
ajoino
committed
Apr 4, 2021
1 parent
1120d3c
commit d3e052b
Showing
7 changed files
with
396 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,3 +140,6 @@ cython_debug/ | |
|
||
# Pycharm files | ||
.idea/ | ||
|
||
# Pytorch stuff | ||
lightning_logs/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from typing import Union | ||
from pathlib import Path | ||
import json | ||
from itertools import groupby | ||
from collections import defaultdict | ||
import string | ||
from numbers import Number | ||
|
||
from more_itertools import sort_together | ||
|
||
import torch | ||
from torch.utils.data import DataLoader, Dataset | ||
import pytorch_lightning as pl | ||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence | ||
|
||
from json_parser import JSONParseTree | ||
|
||
ALL_CHARACTERS = chr(3) + string.printable | ||
NUM_CHARACTERS = len(ALL_CHARACTERS) | ||
|
||
|
||
def collate_tree_batches(batch): | ||
max_sequence_lengths = {identifier: max([len(sample_data['leaf_data'])]) | ||
for sample in batch | ||
for identifier, sample_data in sample.items()} | ||
tensor_dict = defaultdict(list) | ||
index_dict = defaultdict(list) | ||
|
||
for sample in batch: | ||
for identifier, sample_data in sample.items(): | ||
tensor_dict[identifier].append(sample_data['leaf_data']) | ||
index_dict[identifier].append(sample_data['sample_index']) | ||
|
||
collated_samples = {} | ||
for (identifier, index) in index_dict.items(): | ||
tensors = tensor_dict[identifier] | ||
# TODO: isinstance check is not enough, to perfectly separate cases we need type information from the parse tree | ||
tensors = pad_sequence(tensors, padding_value=0) if isinstance(tensors[0], torch.LongTensor) else torch.cat(tensors, dim=0) | ||
index = torch.cat(index, dim=0) | ||
collated_samples[identifier] = {'sample_index': index, 'leaf_data': tensors} | ||
|
||
|
||
return collated_samples | ||
|
||
|
||
class LeafDataToTensor: | ||
"""Convert leaf data to tensors""" | ||
all_characters = string.printable | ||
|
||
def __call__(self, sample): | ||
tensor_sample = {identifier: {'sample_index': torch.LongTensor([[data['sample_index']]]), 'leaf_data': self._transform(data['leaf_data'])} | ||
for identifier, data in sample.items()} | ||
|
||
return tensor_sample | ||
|
||
def _transform(self, value): | ||
|
||
if isinstance(value, Number) or isinstance(value, bool): | ||
data = torch.Tensor([[value]]) | ||
elif isinstance(value, str): | ||
data = torch.LongTensor( | ||
[ALL_CHARACTERS.index(char) for char in value] | ||
) | ||
else: | ||
data = torch.zeros(1, 1) | ||
|
||
return data | ||
|
||
class SimpleDataset(Dataset): | ||
"""Simple dataset for json data""" | ||
|
||
def __init__(self, data_file, transform=None): | ||
|
||
with open(data_file, 'r') as json_file: | ||
self.jsons = json.load(json_file) | ||
self.transform = transform | ||
|
||
def __len__(self): | ||
return len(self.jsons) | ||
|
||
def __getitem__(self, idx): | ||
if not isinstance(idx, int): | ||
raise ValueError('Index has to be a single integer') | ||
|
||
tree = JSONParseTree.parse_start('___root___', self.jsons[idx]) | ||
|
||
sample = {identifier: {'sample_index': idx, 'leaf_data': datum} for identifier, datum in tree.leaf_data()} | ||
|
||
if self.transform: | ||
sample = self.transform(sample) | ||
|
||
return sample | ||
|
||
|
||
class JSONDataModule(pl.LightningDataModule): | ||
|
||
def __init__(self, data_dir: Union[str, Path] = './'): | ||
super().__init__() | ||
|
||
self.data_dir = Path(data_dir) | ||
|
||
def setup(self, stage=None): | ||
self.json_data = SimpleDataset(self.data_dir, transform=LeafDataToTensor()) | ||
|
||
def train_dataloader(self): | ||
return DataLoader(self.json_data, collate_fn=collate_tree_batches, batch_size=4) | ||
|
||
|
||
if __name__ == '__main__': | ||
from pprint import pprint | ||
jsons = JSONDataModule('some_json.json') | ||
jsons.setup() | ||
|
||
for batch in jsons.train_dataloader(): | ||
pprint(batch) | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from __future__ import annotations | ||
import string | ||
from typing import Union, Sequence, Mapping, Any, Tuple | ||
from treelib import Tree, Node | ||
from numbers import Number | ||
from more_itertools import sort_together, split_when, bucket | ||
from itertools import groupby | ||
|
||
import torch | ||
|
||
|
||
def is_primitive(node): | ||
return ( | ||
isinstance(node, Number) | ||
or isinstance(node, bool) | ||
or isinstance(node, str) | ||
or node is None | ||
) | ||
|
||
|
||
def is_array(node): | ||
return ( | ||
isinstance(node, list) | ||
or isinstance(node, tuple) | ||
) | ||
|
||
|
||
def is_object(node): | ||
return isinstance(node, dict) | ||
|
||
|
||
class JSONParseTree(Tree): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
self.all_characters = string.printable | ||
self.num_characters = len(self.all_characters) | ||
|
||
@classmethod | ||
def parse_start(cls, root_name: str, node: Any) -> JSONParseTree: | ||
tree = cls() | ||
tree.create_node(tag=root_name, identifier=(root_name,), data=node) | ||
if is_array(node): | ||
for i, child in enumerate(node): | ||
tree.parse_object((root_name,), str(i), child) | ||
elif is_object(node): | ||
for child_name, child in node.items(): | ||
tree.parse_object((root_name,), child_name, child) | ||
return tree | ||
|
||
def parse_object(self, parent_path: tuple, name: str, node: Any): | ||
if is_primitive(node): | ||
self.create_node(tag=name, identifier=parent_path + (name,), parent=parent_path, data=node) | ||
elif is_array(node): | ||
self.create_node( | ||
tag=name, | ||
identifier=(new_path := parent_path + (name,)), | ||
parent=parent_path, data='___array___' | ||
) | ||
for i, child in enumerate(node): | ||
self.parse_object(new_path, str(i), child) | ||
elif is_object(node): | ||
self.create_node( | ||
tag=name, | ||
identifier=(new_path := parent_path + (name,)), | ||
parent=parent_path, data='___dict___') | ||
for child_name, child in node.items(): | ||
self.parse_object(new_path, child_name, child) | ||
|
||
def leaf_data(self) -> Tuple[Tuple, Any]: | ||
for leaf in self.leaves(): | ||
yield leaf.identifier, leaf.data | ||
|
||
def leaf_tensors(self) -> Tuple[Tuple, Any]: | ||
for leaf in self.leaves(): | ||
if isinstance(leaf.data, Number) or isinstance(leaf.data, bool): | ||
data = torch.Tensor([[leaf.data]]) | ||
elif isinstance(leaf.data, str): | ||
data = torch.LongTensor( | ||
[self.all_characters.index(char) for char in leaf.data] | ||
) | ||
else: | ||
data = torch.zeros(1, 1) | ||
yield leaf.identifier, data | ||
|
||
|
||
if __name__ == '__main__': | ||
from pprint import pprint | ||
import json | ||
|
||
array = [{"test": {"iest": ['stest', [1, 2, 3]]}, "other": [None, 1, True], "empty": [], "empty_2": {}}] * 3 | ||
some_json = json.loads(r""" | ||
[{"n": "OO_temp_sensor", "t": 0, "u": "K", "v": 290.02483570765054}, | ||
{"n": "CC_temp_sensor", "t": 0, "u": "K", "v": 290.032384426905}, | ||
{"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, | ||
{"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, | ||
{"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, | ||
{"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, | ||
{"n": "NE_temp_sensor", "t": 0, "u": "K", "v": 289.97652628070324} | ||
] | ||
""") | ||
|
||
trees = [JSONParseTree.parse_start('___root___', arr) for arr in some_json] | ||
|
||
sample_identifiers, sample_index, sample_data = sort_together([*zip(*[ | ||
(leaf_identifier, i, leaf_data) for i, tree in enumerate(trees) | ||
for leaf_identifier, leaf_data in tree.leaf_tensors() | ||
])]) | ||
|
||
from prettytable import PrettyTable | ||
|
||
|
||
def sample_table(identifiers, index, data): | ||
table = PrettyTable(('Identifier', 'Batch index', 'Data')) | ||
for id, idx, dat in zip(identifiers, index, data): | ||
table.add_row([id, idx, dat]) | ||
print(table) | ||
|
||
|
||
sample_table(sample_identifiers, sample_index, sample_data) | ||
|
||
# pprint(list(split_when(zip(sample_identifiers, sample_index, sample_data), lambda x, y: x[0] != y[0]))) | ||
import torch | ||
buck = {k: [sorted_elem | ||
if not isinstance(sorted_elem, str) and sorted_elem is not None else sorted_elem | ||
for sorted_elem in | ||
zip(*sorted(elem[1:] for elem in g))] | ||
for k, g in groupby( | ||
zip( | ||
sample_identifiers, | ||
sample_index, | ||
sample_data | ||
), lambda x: x[0] | ||
)} | ||
|
||
pprint(buck) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from json2vec import JSONTreeLSTM | ||
import torch | ||
from datasets import load_seismic_dataset | ||
from pytorch_lightning_test import JsonTreeSystem | ||
|
||
jsons, vectors, labels = load_seismic_dataset() | ||
|
||
num_classes = 1 | ||
embedder = JsonTreeSystem(mem_dim=64) | ||
output_layer = torch.nn.Linear(128, num_classes) | ||
model = torch.nn.Sequential(embedder) | ||
|
||
some_json = r""" | ||
[{"n": "OO_temp_sensor", "t": 0, "u": "K", "v": 290.02483570765054}, | ||
{"n": "CC_temp_sensor", "t": 0, "u": "K", "v": 290.032384426905}, | ||
{"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, | ||
{"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, | ||
{"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, | ||
{"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, | ||
{"n": "NE_temp_sensor", "t": {"1": 1}, "u": "K", "v": 289.97652628070324} | ||
] | ||
""" | ||
same_json = """ | ||
{"n": "OO_temp_sensor", "u": "K", "v": 290.02483570765054, "t": 0} | ||
""" | ||
""" | ||
from tqdm import tqdm | ||
for some_json in tqdm(jsons): | ||
output_3 = model(some_json) | ||
""" | ||
model("""[1, [2, {"num": {"as_text": "four", "as_int": 4}}]]""") | ||
from prettytable import PrettyTable | ||
|
||
def count_parameters(model): | ||
table = PrettyTable(["Modules", "Parameters"]) | ||
total_params = 0 | ||
for name, parameter in model.named_parameters(): | ||
if not parameter.requires_grad: continue | ||
param = parameter.numel() | ||
table.add_row([name, param]) | ||
total_params+=param | ||
print(table) | ||
print(f"Total Trainable Params: {total_params}") | ||
return total_params | ||
|
||
count_parameters(embedder) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn | ||
from torch.utils.data import Dataset, DataLoader | ||
import pytorch_lightning as pl | ||
import sklearn | ||
pl.LightningDataModule() | ||
|
||
from json2vec import JSONTreeLSTM | ||
from datasets import load_seismic_dataset | ||
|
||
|
||
class SeismicDataset(Dataset): | ||
"""Seismic dataset""" | ||
|
||
def __init__(self): | ||
jsons, vectors, labels = load_seismic_dataset() | ||
labels = torch.LongTensor([int(label) for label in labels]) | ||
self.jsons, self.vectors, self.labels = sklearn.utils.shuffle(jsons, vectors, labels) | ||
|
||
def __len__(self): | ||
return len(self.jsons) | ||
|
||
def __getitem__(self, idx): | ||
return self.jsons[idx], self.labels[idx] | ||
|
||
|
||
class JsonTreeSystem(pl.LightningModule): | ||
|
||
def __init__(self, mem_dim=128): | ||
super().__init__() | ||
|
||
self.json_tree_lstm = JSONTreeLSTM(mem_dim=mem_dim) | ||
self.output = nn.Linear(2 * mem_dim, 1) | ||
|
||
def forward(self, *args, **kwargs): | ||
return self.json_tree_lstm(*args, **kwargs) | ||
|
||
def training_step(self, batch, batch_idx): | ||
jsons, labels = batch | ||
labels = torch.LongTensor([int(label) for label in labels]) | ||
output = self.json_tree_lstm(*jsons) | ||
output = torch.sigmoid(self.output(output).view(1)) | ||
|
||
loss = F.binary_cross_entropy(output, labels.float()) | ||
|
||
self.log('train_loss', loss) | ||
|
||
return loss | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) | ||
return optimizer | ||
|
||
if __name__ == '__main__': | ||
from pprint import pprint | ||
seismic_dataset = SeismicDataset() | ||
train_loader = DataLoader(seismic_dataset) | ||
json_tree = JsonTreeSystem(128) | ||
trainer = pl.Trainer(overfit_batches=1) | ||
trainer.fit(json_tree, train_loader) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[{"n": "OO_temp_sensor", "t": 0, "u": "K", "v": 290.02483570765054}, {"n": "CC_temp_sensor", "t": 0, "u": "K", "v": 290.032384426905}, {"n": "NW_temp_sensor", "t": 0, "u": "K", "v": 289.98829233126384}, {"n": "NW_Heater", "t": 0, "u": "W", "v": 185.8732269977827}, {"n": "NN_temp_sensor", "t": 0, "u": "K", "v": 290.0789606407754}, {"n": "NN_Heater", "t": 0, "u": "W", "v": 171.3662974759336}, {"n": "NE_temp_sensor", "t": 0, "u": "K", "v": 289.97652628070324}] |
Oops, something went wrong.