
Commit fd4fbba

moves function to bundle datasets into utils file under data

1 parent: adaac91

File tree

docker/Dockerfile
main.py
tsaplay/datasets/Dataset.py
tsaplay/utils/_data.py

4 files changed: +54 -30 lines

docker/Dockerfile (+1)

@@ -1,4 +1,5 @@
 FROM tensorflow/serving
+COPY export/ /models/
 RUN echo '#!/bin/bash \n\n\
 tensorflow_model_server --port=8500 --rest_api_port=8501 \
 --model_config_file=/models/tfserve.conf \
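
The added COPY bakes the local export/ directory into the image at /models/, where tensorflow_model_server expects both the exported models and the tfserve.conf named on the last line. For orientation, that file would be a TensorFlow Serving model config in protobuf text format, roughly like the sketch below; the model name and base path are hypothetical, not taken from this repo:

model_config_list {
  config {
    name: "lcr_rot"               # hypothetical model name
    base_path: "/models/lcr_rot"  # a SavedModel dir copied in by the new COPY
    model_platform: "tensorflow"
  }
}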

main.py (+22 -12)

@@ -15,6 +15,7 @@
     InteractiveAttentionNetwork
 )
 from tsaplay.models.Tang2016b.MemNet import MemNet
+from tsaplay.utils._data import bundle_datasets
 
 tf.logging.set_verbosity(tf.logging.INFO)
 
@@ -32,19 +33,28 @@
     "initializer": tf.initializers.random_uniform(minval=-0.1, maxval=0.1),
 }
 
-experiment = Experiment(
-    dataset=Dataset(
-        path=DATASETS.DEBUG_PATH,
-        parser=DATASETS.DEBUG_PARSER,
-        embedding=Embedding(path=EMBEDDINGS.DEBUG),
-    ),
-    model=LcrRot(),
-    contd_tag="debug",
-    # run_config=tf.estimator.RunConfig(tf_random_seed=1234),
+restaurants = Dataset(
+    path=DATASETS.XUE2018_RESTAURANTS_PATH, parser=DATASETS.XUE2018_PARSER
 )
-experiment.run(job="train+eval", steps=1)
-# experiment.export_model(overwrite=True)
-experiment.export_model(overwrite=True, restart_tfserve=True)
+laptops = Dataset(
+    path=DATASETS.XUE2018_LAPTOPS_PATH, parser=DATASETS.XUE2018_PARSER
+)
+
+rest_lapt = bundle_datasets(restaurants, laptops)
+
+# experiment = Experiment(
+#     dataset=Dataset(
+#         path=DATASETS.DEBUG_PATH,
+#         parser=DATASETS.DEBUG_PARSER,
+#         embedding=Embedding(path=EMBEDDINGS.DEBUG),
+#     ),
+#     model=LcrRot(),
+#     contd_tag="debug",
+#     # run_config=tf.estimator.RunConfig(tf_random_seed=1234),
+# )
+# experiment.run(job="train+eval", steps=1)
+# # experiment.export_model(overwrite=True)
+# experiment.export_model(overwrite=True, restart_tfserve=True)
 # experiment = Experiment(
 #     dataset=Dataset(
 #         path=DATASETS.XUE2018_LAPTOPS_PATH,
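
main.py now builds the two Xue2018 SemEval datasets and bundles them, with the debug Experiment block commented out. A hedged sketch of how rest_lapt could be wired back into an Experiment, mirroring the commented-out pattern above; the contd_tag is hypothetical, and note that the bundled Dataset is created with parser=None and no embedding, so an embedding would still have to be supplied before training (not shown in this commit):

# Hypothetical follow-up, not part of this commit:
experiment = Experiment(
    dataset=rest_lapt,      # bundled restaurants + laptops Dataset
    model=LcrRot(),
    contd_tag="rest_lapt",  # hypothetical continuation tag
)
experiment.run(job="train+eval", steps=1)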

tsaplay/datasets/Dataset.py (-18)

@@ -20,7 +20,6 @@
     unpickle_file as _unpickle,
     pickle_file as _pickle,
 )
-from tsaplay.utils._data import concat_dicts_lists
 import tsaplay.datasets._constants as DATASETS
 
 
@@ -221,23 +220,6 @@ def get_features_and_labels(self, mode, distribution=None):
 
         return features, labels, stats
 
-    def __add__(self, other):
-        if isinstance(other, Dataset):
-            gen_name = "_".join([self.name, other.name])
-            gen_path = join(DATASETS.PARENT_DIR, "_generated", gen_name)
-
-            joined_train_dict = concat_dicts_lists(
-                self.train_dict, other.train_dict
-            )
-            joined_test_dict = concat_dicts_lists(
-                self.test_dict, other.test_dict
-            )
-
-            _pickle(joined_train_dict, join(gen_path, "train_dict.pkl"))
-            _pickle(joined_test_dict, join(gen_path, "test_dict.pkl"))
-
-            return Dataset(path=gen_path, parser=None)
-
     def _reset(self, path):
         self.__path = path
         self.__all_docs = None
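
The deleted __add__ is superseded by bundle_datasets in tsaplay/utils/_data.py (below); the dependency between the two modules also flips direction, since Dataset.py no longer imports concat_dicts_lists while _data.py now imports Dataset. A before/after sketch of the call sites (variable names are illustrative, not from this commit):

# Before this commit: pairwise composition via the removed operator
rest_lapt = restaurants + laptops

# After: variadic, deduplicates by dataset.name, and pickles the merged
# train/test dicts under _generated/ only if they are not already there
rest_lapt = bundle_datasets(restaurants, laptops)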

tsaplay/utils/_data.py (+31)

@@ -1,9 +1,13 @@
 import tensorflow as tf
 from itertools import chain
+from os import makedirs
+from os.path import join, exists
 from collections import defaultdict
 from tensorflow.python.keras.preprocessing import ( # pylint: disable=E0611
     sequence
 )
+from tsaplay.datasets.Dataset import Dataset, DATASETS
+from tsaplay.utils._io import pickle_file
 
 
 def zip_str_join(first, second):
@@ -24,6 +28,33 @@ def concat_dicts_lists(first, second):
     return dict(new_dict)
 
 
+def bundle_datasets(*datasets, rebuild=False):
+    dataset_names = []
+    train_dict = {}
+    test_dict = {}
+    for dataset in datasets:
+        if isinstance(dataset, Dataset) and dataset.name not in dataset_names:
+            dataset_names.append(dataset.name)
+            train_dict = concat_dicts_lists(dataset.train_dict, train_dict)
+            test_dict = concat_dicts_lists(dataset.test_dict, test_dict)
+
+    dataset_name = "_".join(dataset_names)
+    gen_path = join(DATASETS.PARENT_DIR, "_generated", dataset_name)
+
+    makedirs(gen_path, exist_ok=True)
+
+    train_dict_path = join(gen_path, "train_dict.pkl")
+    test_dict_path = join(gen_path, "test_dict.pkl")
+
+    if not exists(train_dict_path) and not rebuild:
+        pickle_file(train_dict_path, train_dict)
+
+    if not exists(test_dict_path) and not rebuild:
+        pickle_file(test_dict_path, test_dict)
+
+    return Dataset(path=gen_path, parser=None)
+
+
 def make_labels_dataset_from_list(labels):
     low_bound = min(labels)
     if low_bound < 0:
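
The new helper deduplicates datasets by name, merges their train/test dicts with concat_dicts_lists, caches the merged dicts as pickles under _generated/, and returns a Dataset re-read from that path. One caveat worth flagging: because of the "and not rebuild" guards, passing rebuild=True suppresses pickling rather than forcing a rewrite. If the flag is meant to force regeneration, the intended guards are presumably the ones below; this is an assumption about intent, not what the commit does:

# Hypothetical variant: rebuild=True forces the cached pickles to be rewritten
if rebuild or not exists(train_dict_path):
    pickle_file(train_dict_path, train_dict)

if rebuild or not exists(test_dict_path):
    pickle_file(test_dict_path, test_dict)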
