@@ -175,6 +175,7 @@ def _init_vocab(self):
             self._vocab = read_vocab_file(vocab_file_path)
         else:
             self._vocab = self._embedding.vocab
+        #! if 0 buckets and oov_fn is set, the train vocab is added; each token will be assigned a vector
         if self._oov_fn and not self._num_oov_buckets:
            train_vocab = set(
                corpora_vocab(
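
The new branch means: when an `oov_fn` is enabled but no hash buckets are requested, train-corpus tokens are folded into the vocab so each one gets a real vector. A minimal stand-alone sketch of that policy, with toy data standing in for the `corpora_vocab` output:

    # Toy illustration of the vocab-extension branch above; the token lists
    # are made up, and corpora_vocab's output is faked as a plain set.
    embedding_vocab = ["the", "movie", "was"]
    train_tokens = {"the", "movie", "was", "riveting"}  # e.g. from corpora_vocab

    num_oov_buckets = 0
    oov_fn_enabled = True

    vocab = list(embedding_vocab)
    if oov_fn_enabled and not num_oov_buckets:
        # train-only tokens are appended so each gets its own learned vector
        vocab.extend(sorted(train_tokens - set(vocab)))

    print(vocab)  # ['the', 'movie', 'was', 'riveting']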
@@ -210,19 +211,13 @@ def _init_token_data(self):
             data_dict = getattr(self, data_dict_attr)
             to_tokenize[mode] = data_dict
         if to_tokenize:
-            include = set(self._vocab) | (
-                set(
-                    corpora_vocab(
-                        self._train_corpus,
-                        self._test_corpus,
-                        case_insensitive=self._embedding.case_insensitive,
-                    )
+            #! Regardless of buckets, all vocab must be tokenized, otherwise the experiment risks failing with an empty target
+            include = set(self._vocab) | set(
+                corpora_vocab(
+                    self._train_corpus,
+                    self._test_corpus,
+                    case_insensitive=self._embedding.case_insensitive,
                 )
-                if self._num_oov_buckets
-                #! an OOV target that appears only in the test dataset oov_buckets = 0
-                #! will break this system as it will not be included in the tokens,
-                #! resulting in an empty target
-                else set()
             )
             include_tokens_path = join(self._gen_dir, "_incl_tokens.pkl")
             pickle_file(path=include_tokens_path, data=include)
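
The removed conditional is the bug being fixed: with `num_oov_buckets == 0`, the include set collapsed to the embedding vocab alone, so a target token seen only in the test corpus was never tokenized and surfaced downstream as an empty target. A toy reproduction (all values are made up):

    # Why the unconditional union matters.
    vocab = {"service", "food"}
    corpus_tokens = {"service", "food", "ambience"}
    test_only_target = "ambience"
    num_oov_buckets = 0

    old_include = vocab | (corpus_tokens if num_oov_buckets else set())
    new_include = vocab | corpus_tokens

    print(test_only_target in old_include)  # False -> empty target downstream
    print(test_only_target in new_include)  # True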
@@ -268,8 +263,9 @@ def _init_tfrecords(self):
             write_vocab_file(
                 filtered_vocab_path, filtered_vocab, indices
             )
+            #! There has to be at least 1 bucket for any test-time oov tokens (possibly targets)
             lookup_table = ids_lookup_table(
-                filtered_vocab_path, self._num_oov_buckets
+                filtered_vocab_path, max(self._num_oov_buckets, 1)
             )
             fetch_dict = fetch_lookup_ops(lookup_table, **tokens_lists)
             fetch_results = run_lookups(
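
`max(self._num_oov_buckets, 1)` guarantees the lookup table always has somewhere to send an unseen token. A rough pure-Python model of how a vocab-plus-hash-buckets table assigns ids; the real `ids_lookup_table` presumably wraps a TensorFlow lookup table, so this behavior is an assumption for illustration:

    # In-vocab tokens get their vocab index; OOV tokens hash into one of the
    # buckets appended after the vocab. Zero buckets would leave OOV tokens
    # with no valid id at all.
    def lookup_id(token, vocab, num_oov_buckets):
        if num_oov_buckets < 1:
            raise ValueError("need at least one OOV bucket")
        if token in vocab:
            return vocab.index(token)
        return len(vocab) + hash(token) % num_oov_buckets

    vocab = ["good", "bad"]
    print(lookup_id("good", vocab, max(0, 1)))  # 0: in-vocab id
    print(lookup_id("meh", vocab, max(0, 1)))   # 2: the single forced bucket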
@@ -289,18 +285,17 @@ def _init_tfrecords(self):
             tfrecord_folder = "_{mode}".format(mode=mode)
             tfrecord_path = join(self._gen_dir, tfrecord_folder)
             write_tfrecords(tfrecord_path, tfexamples)
-            if self._num_oov_buckets:
-                buckets = [
-                    BUCKET_TOKEN.format(num=n + 1)
-                    for n in range(self._num_oov_buckets)
-                ]
-                oov_buckets[mode] = tokens_by_assigned_id(
-                    string_features,
-                    int_features,
-                    start=len(self._vocab),
-                    keys=buckets,
-                )
-        if oov_buckets:
+            #! There has to be at least 1 bucket for any test-time oov tokens (possibly targets)
+            buckets = [
+                BUCKET_TOKEN.format(num=n + 1)
+                for n in range(max(self._num_oov_buckets, 1))
+            ]
+            oov_buckets[mode] = tokens_by_assigned_id(
+                string_features,
+                int_features,
+                start=len(self._vocab),
+                keys=buckets,
+            )
         accum_oov_buckets = accumulate_dicts(
             **oov_buckets,
             accum_fn=lambda prev, curr: list(set(prev) | set(curr)),
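
For context, a hypothetical re-implementation of `tokens_by_assigned_id` showing what the bucket bookkeeping is doing: the signature matches the call above, but the body is a guess, not the repo's code.

    # Pair each token with its looked-up id and collect the tokens whose id
    # falls in the bucket range [start, start + len(keys)).
    def tokens_by_assigned_id(string_features, int_features, start, keys):
        by_bucket = {}
        for tokens, ids in zip(string_features, int_features):
            for token, i in zip(tokens, ids):
                if i >= start:  # ids below `start` are in-vocab
                    by_bucket.setdefault(keys[i - start], []).append(token)
        return by_bucket

    buckets = ["<bucket-1>"]  # e.g. BUCKET_TOKEN.format(num=1)
    print(tokens_by_assigned_id([["good", "meh"]], [[0, 2]], start=2, keys=buckets))
    # {'<bucket-1>': ['meh']}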
@@ -316,17 +311,19 @@ def _init_embedding_params(self):
         np.random.seed(RANDOM_SEED)
         dim_size = self._embedding.dim_size
         vectors = self._embedding.vectors
-        num_oov_vectors = len(self._vocab) - self._embedding.vocab_size
-        num_oov_vectors += self._num_oov_buckets
-        if num_oov_vectors:
-            oov_fn = self._oov_fn or DEFAULT_OOV_FN
-            oov_vectors = oov_fn(size=(num_oov_vectors, dim_size))
-            vectors = np.concatenate([vectors, oov_vectors], axis=0)
+        #! There has to be at least 1 bucket for any test-time oov tokens (possibly targets)
+        num_oov_vectors = (self._num_oov_buckets or 1) + (
+            len(self._vocab) - self._embedding.vocab_size
+        )
+        oov_fn = self._oov_fn or DEFAULT_OOV_FN
+        oov_vectors = oov_fn(size=(num_oov_vectors, dim_size))
+        vectors = np.concatenate([vectors, oov_vectors], axis=0)
         vocab_size = len(vectors)
         num_shards = partitioner_num_shards(vocab_size)
         init_fn = embedding_initializer_fn(vectors, num_shards)
         self._embedding_params = {
             "_vocab_size": vocab_size,
+            "_num_oov_buckets": max(self._num_oov_buckets, 1),
             "_vocab_file": self._vocab_file,
             "_embedding_dim": dim_size,
             "_embedding_init": init_fn,
@@ -345,6 +342,7 @@ def _write_info_file(self):
             "embedding": {
                 "uid": self.embedding.uid,
                 "name": self.embedding.name,
+                "params": {k: stringify(v) for k, v in self._embedding_params.items()},
             },
             "oov_policy": {
                 "oov": stringify(self._oov_fn),