-
Notifications
You must be signed in to change notification settings - Fork 194
/
Copy pathprefetcher.py
46 lines (39 loc) · 1.84 KB
/
prefetcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Provides functions to prefetch tensors to feed into models."""
import tensorflow as tf
def prefetch(tensor_dict, capacity):
    """Creates a prefetch queue for tensors.

    Creates a padding FIFO queue to asynchronously enqueue tensor_dicts and
    returns the queue, whose dequeue op evaluates to a tensor_dict. This
    function is useful in prefetching preprocessed tensors so that the data is
    readily available for consumers.

    As side effects, a queue runner driving the enqueue op is registered via
    `tf.train.queue_runner.add_queue_runner`, and a scalar summary reporting
    the fraction of the queue that is currently filled is added.

    Example input pipeline when you don't need batching:
    ----------------------------------------------------
    key, string_tensor = slim.parallel_reader.parallel_read(...)
    tensor_dict = decoder.decode(string_tensor)
    tensor_dict = preprocessor.preprocess(tensor_dict, ...)
    prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20)
    tensor_dict = prefetch_queue.dequeue()
    outputs = Model(tensor_dict)
    ...
    ----------------------------------------------------

    For input pipelines with batching, refer to core/batcher.py

    Args:
      tensor_dict: a dictionary of tensors to prefetch.
      capacity: the size of the prefetch queue.

    Returns:
      a FIFO prefetcher queue
    """
    # names, dtypes and shapes are all derived from the same unmodified dict,
    # so their positions line up component by component.
    names = list(tensor_dict.keys())
    tensors = list(tensor_dict.values())
    dtypes = [t.dtype for t in tensors]
    shapes = [t.get_shape() for t in tensors]

    prefetch_queue = tf.PaddingFIFOQueue(
        capacity,
        dtypes=dtypes,
        shapes=shapes,
        names=names,
        name='prefetch_queue')

    # Passing `names` above lets the queue accept the dict directly.
    enqueue_op = prefetch_queue.enqueue(tensor_dict)
    tf.train.queue_runner.add_queue_runner(
        tf.train.queue_runner.QueueRunner(prefetch_queue, [enqueue_op]))

    # tf.cast replaces the deprecated tf.to_float; the computed value is the
    # same float32 fraction size / capacity.
    tf.summary.scalar(
        'queue/%s/fraction_of_%d_full' % (prefetch_queue.name, capacity),
        tf.cast(prefetch_queue.size(), tf.float32) * (1. / capacity))
    return prefetch_queue