Commit e12cd1a

implements minimum 1 bucket for oov
1 parent b3d95f5 commit e12cd1a

6 files changed (+53 −40 lines)

.autoenv.zsh (+8, new file)

@@ -0,0 +1,8 @@
+if [[ $autoenv_event == 'enter' ]]; then
+    typeset -ax PREVIOUS_LEFT_PROMPT_ELEMENTS
+    PREVIOUS_LEFT_PROMPT_ELEMENTS=(${(v)POWERLEVEL9K_LEFT_PROMPT_ELEMENTS})
+    POWERLEVEL9K_LEFT_PROMPT_ELEMENTS=(time root_indicator virtualenv dir vcs vi_mode)
+else
+    typeset -ax POWERLEVEL9K_LEFT_PROMPT_ELEMENTS
+    POWERLEVEL9K_LEFT_PROMPT_ELEMENTS=(${(v)PREVIOUS_LEFT_PROMPT_ELEMENTS})
+fi

.vscode/launch.json (+9 −6)

@@ -64,19 +64,22 @@
       "console": "integratedTerminal"
     },
     {
-      "name": "Embedding-filter-fn-test",
+      "name": "min-1-oov-bucket-fix",
       "type": "python",
       "request": "launch",
       "module": "tsaplay.task",
       "args": [
         "single",
-        "-em='twitter-50[corpus,only_adjectives]'",
+        "-em='wiki-50[corpus]'",
         "-ds='dong'",
-        "-m=lstm",
-        "-b=25",
-        "-s=200",
+        "-m=lcrrot",
+        "-b=5",
+        "-s=100",
         "-mp",
-        "num_oov_buckets=100"
+        "oov=true",
+        "hidden_units=5",
+        "-aux",
+        "attn_heatmaps=false"
       ],
       "console": "integratedTerminal"
     },

tsaplay/features.py (+29 −31)

@@ -175,6 +175,7 @@ def _init_vocab(self):
             self._vocab = read_vocab_file(vocab_file_path)
         else:
             self._vocab = self._embedding.vocab
+        #! With 0 buckets and oov=true, the train vocab is added and each token is assigned a vector
         if self._oov_fn and not self._num_oov_buckets:
             train_vocab = set(
                 corpora_vocab(
@@ -210,19 +211,13 @@ def _init_token_data(self):
             data_dict = getattr(self, data_dict_attr)
             to_tokenize[mode] = data_dict
         if to_tokenize:
-            include = set(self._vocab) | (
-                set(
-                    corpora_vocab(
-                        self._train_corpus,
-                        self._test_corpus,
-                        case_insensitive=self._embedding.case_insensitive,
-                    )
+            #! Regardless of buckets, all vocab must be tokenized; otherwise the experiment risks failing on an empty target
+            include = set(self._vocab) | set(
+                corpora_vocab(
+                    self._train_corpus,
+                    self._test_corpus,
+                    case_insensitive=self._embedding.case_insensitive,
                 )
-                if self._num_oov_buckets
-                #! an OOV target that appears only in the test dataset oov_buckets = 0
-                #! will break this system as it will not be included in the tokens,
-                #! resulting in an empty target
-                else set()
             )
             include_tokens_path = join(self._gen_dir, "_incl_tokens.pkl")
             pickle_file(path=include_tokens_path, data=include)
@@ -268,8 +263,9 @@ def _init_tfrecords(self):
                 write_vocab_file(
                     filtered_vocab_path, filtered_vocab, indices
                 )
+            #! There has to be at least 1 bucket for any test-time OOV tokens (possibly targets)
            lookup_table = ids_lookup_table(
-                filtered_vocab_path, self._num_oov_buckets
+                filtered_vocab_path, max(self._num_oov_buckets, 1)
            )
            fetch_dict = fetch_lookup_ops(lookup_table, **tokens_lists)
            fetch_results = run_lookups(
@@ -289,18 +285,17 @@ def _init_tfrecords(self):
             tfrecord_folder = "_{mode}".format(mode=mode)
             tfrecord_path = join(self._gen_dir, tfrecord_folder)
             write_tfrecords(tfrecord_path, tfexamples)
-            if self._num_oov_buckets:
-                buckets = [
-                    BUCKET_TOKEN.format(num=n + 1)
-                    for n in range(self._num_oov_buckets)
-                ]
-                oov_buckets[mode] = tokens_by_assigned_id(
-                    string_features,
-                    int_features,
-                    start=len(self._vocab),
-                    keys=buckets,
-                )
-        if oov_buckets:
+            #! There has to be at least 1 bucket for any test-time OOV tokens (possibly targets)
+            buckets = [
+                BUCKET_TOKEN.format(num=n + 1)
+                for n in range(max(self._num_oov_buckets, 1))
+            ]
+            oov_buckets[mode] = tokens_by_assigned_id(
+                string_features,
+                int_features,
+                start=len(self._vocab),
+                keys=buckets,
+            )
         accum_oov_buckets = accumulate_dicts(
             **oov_buckets,
             accum_fn=lambda prev, curr: list(set(prev) | set(curr)),
@@ -316,17 +311,19 @@ def _init_embedding_params(self):
         np.random.seed(RANDOM_SEED)
         dim_size = self._embedding.dim_size
         vectors = self._embedding.vectors
-        num_oov_vectors = len(self._vocab) - self._embedding.vocab_size
-        num_oov_vectors += self._num_oov_buckets
-        if num_oov_vectors:
-            oov_fn = self._oov_fn or DEFAULT_OOV_FN
-            oov_vectors = oov_fn(size=(num_oov_vectors, dim_size))
-            vectors = np.concatenate([vectors, oov_vectors], axis=0)
+        #! There has to be at least 1 bucket for any test-time OOV tokens (possibly targets)
+        num_oov_vectors = (self._num_oov_buckets or 1) + (
+            len(self._vocab) - self._embedding.vocab_size
+        )
+        oov_fn = self._oov_fn or DEFAULT_OOV_FN
+        oov_vectors = oov_fn(size=(num_oov_vectors, dim_size))
+        vectors = np.concatenate([vectors, oov_vectors], axis=0)
         vocab_size = len(vectors)
         num_shards = partitioner_num_shards(vocab_size)
         init_fn = embedding_initializer_fn(vectors, num_shards)
         self._embedding_params = {
             "_vocab_size": vocab_size,
+            "_num_oov_buckets": max(self._num_oov_buckets, 1),
             "_vocab_file": self._vocab_file,
             "_embedding_dim": dim_size,
             "_embedding_init": init_fn,
@@ -345,6 +342,7 @@ def _write_info_file(self):
             "embedding": {
                 "uid": self.embedding.uid,
                 "name": self.embedding.name,
+                "params": {k: stringify(v) for k, v in self._embedding_params.items()},
             },
             "oov_policy": {
                 "oov": stringify(self._oov_fn),

tsaplay/models/tsa_model.py (+5 −1)

@@ -184,7 +184,11 @@ def _serving_input_receiver_fn(self):
        }
        parsed_example = tf.parse_example(inputs_serialized, feature_spec)

-        ids_table = ids_lookup_table(self.params["_vocab_file"])
+        # NOTE: the serving graph must use the same number of OOV buckets as training
+        ids_table = ids_lookup_table(
+            self.params["_vocab_file"],
+            oov_buckets=self.params["_num_oov_buckets"],
+        )
        features = {
            "left": parsed_example["left"],
            "target": parsed_example["target"],

tsaplay/task.py (+1 −1)

@@ -238,7 +238,7 @@ def run_next_experiment(batch_file_path, job_dir=None, defaults=None):
     try:
         task_args = task_parser.parse_args(tasks[task_index])
         cprnt("RUNNING TASK {0}: {1}".format(task_index, task_args))
-        # run_experiment(task_args, experiment_index=task_index)
+        run_experiment(task_args, experiment_index=task_index)
     except Exception:  # pylint: disable=W0703
         traceback.print_exc()
     environ["TSATASK"] = str(task_index + 1)

tsaplay/utils/tf.py (+1 −1)

@@ -471,7 +471,7 @@ def image_to_summary(name, image):
     return summary


-def ids_lookup_table(vocab_file_path, oov_buckets=0):
+def ids_lookup_table(vocab_file_path, oov_buckets=1):
    return tf.contrib.lookup.index_table_from_file(
        vocabulary_file=vocab_file_path,
        key_column_index=0,
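
With the new default, any caller that omits `oov_buckets` (as the serving code previously did) now gets a table that sends unseen tokens to a single bucket id instead of -1. A quick check, assuming the same throwaway `vocab.txt` as in the sketch above:

table = ids_lookup_table("vocab.txt")  # oov_buckets now defaults to 1
with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    # An unseen token maps to id == vocab_size (the single bucket), not -1
    print(sess.run(table.lookup(tf.constant(["unseen"]))))  # [3]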
