 _DEFAULT_READER_SCHEMA = ""
 # From https://github.com/tensorflow/tensorflow/blob/v2.0.0/tensorflow/python/data/ops/readers.py
 
+
+def _require(condition: bool, err_msg: str = None) -> None:
+    """Checks that the specified condition is true and raises a `ValueError` otherwise.
+
+    Args:
+        condition: The condition to test.
+        err_msg: If specified, the error message to use when the condition is false.
+
+    Raises:
+        ValueError: Raised when the condition is false.
+
+    Returns:
+        None
+    """
+    if not condition:
+        raise ValueError(err_msg)
+
+
 # copied from https://github.com/tensorflow/tensorflow/blob/
 # 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L36
 def _create_or_validate_filenames_dataset(filenames):
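For orientation, a self-contained sketch of the helper's intended failure mode; the value below is hypothetical:

```
def _require(condition: bool, err_msg: str = None) -> None:
    if not condition:
        raise ValueError(err_msg)


num_parallel_calls = 0  # hypothetical invalid value
try:
    _require(
        num_parallel_calls > 0,
        f"num_parallel_calls: {num_parallel_calls} must be greater than 0",
    )
except ValueError as err:
    print(err)  # num_parallel_calls: 0 must be greater than 0
```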
@@ -52,21 +70,62 @@ def _create_or_validate_filenames_dataset(filenames):
 
 # copied from https://github.com/tensorflow/tensorflow/blob/
 # 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L67
-def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None):
-    """create_dataset_reader"""
-
-    def read_one_file(filename):
-        filename = tf.convert_to_tensor(filename, tf.string, name="filename")
-        return dataset_creator(filename)
-
-    if num_parallel_reads is None:
-        return filenames.flat_map(read_one_file)
-    if num_parallel_reads == tf.data.experimental.AUTOTUNE:
-        return filenames.interleave(
-            read_one_file, num_parallel_calls=num_parallel_reads
-        )
+def _create_dataset_reader(
+    dataset_creator,
+    filenames,
+    cycle_length=None,
+    num_parallel_calls=None,
+    deterministic=None,
+    block_length=1,
+):
+    """Creates a dataset reader that reads records from multiple files and
+    interleaves them together.
+    ```
+    dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
+    # NOTE: New lines indicate "block" boundaries.
+    dataset = dataset.interleave(
+        lambda x: Dataset.from_tensors(x).repeat(6),
+        cycle_length=2, block_length=4)
+    list(dataset.as_numpy_iterator())
+    ```
+    Results in the following output:
+    [1,1,1,1,
+     2,2,2,2,
+     1,1,
+     2,2,
+     3,3,3,3,
+     4,4,4,4,
+     3,3,
+     4,4,
+     5,5,5,5,
+     5,5,
+    ]
+    Args:
+        dataset_creator: Initializer for `_AvroRecordDataset`.
+        filenames: A `tf.data.Dataset` of filenames to read.
+        cycle_length: The number of files processed in parallel; passed to
+            `tf.data.Dataset.interleave` as `cycle_length`. Setting it equal to
+            `block_length` returns `n` records at a time from each of `n` files.
+        num_parallel_calls: The number of threads spawned by the interleave call.
+        deterministic: Whether the interleaved records are produced in
+            deterministic order; `tf.data.Dataset.interleave` defaults this to `True`.
+        block_length: The number of consecutive records produced from each file
+            before cycling to the next. Defaults to 1.
+    Returns:
+        A dataset of parsed Avro records, interleaved across the input files.
+    """
+
+    def read_many_files(filenames):
+        filenames = tf.convert_to_tensor(filenames, tf.string, name="filename")
+        return dataset_creator(filenames)
+
+    if cycle_length is None:
+        return filenames.flat_map(read_many_files)
+
     return filenames.interleave(
-        read_one_file, cycle_length=num_parallel_reads, block_length=1
+        read_many_files,
+        cycle_length=cycle_length,
+        num_parallel_calls=num_parallel_calls,
+        block_length=block_length,
+        deterministic=deterministic,
     )
 
 
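The interleave example from the docstring above runs as-is against the public `tf.data` API; a self-contained version with the expected output:

```
import tensorflow as tf

# Two inputs (1 and 2) are cycled concurrently (cycle_length=2), emitting
# 4 consecutive elements from each (block_length=4) before switching.
dataset = tf.data.Dataset.range(1, 6)
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=2,
    block_length=4,
)
print(list(dataset.as_numpy_iterator()))
# [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5]
```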
@@ -128,10 +187,16 @@ class AvroRecordDataset(tf.data.Dataset):
     """A `Dataset` comprising records from one or more AvroRecord files."""
 
     def __init__(
-        self, filenames, buffer_size=None, num_parallel_reads=None, reader_schema=None
+        self,
+        filenames,
+        buffer_size=None,
+        num_parallel_reads=None,
+        num_parallel_calls=None,
+        reader_schema=None,
+        deterministic=True,
+        block_length=1,
     ):
         """Creates an `AvroRecordDataset` to read one or more AvroRecord files.
-
         Args:
             filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
                 more filenames.
@@ -144,25 +209,61 @@ def __init__(
                 files read in parallel are outputted in an interleaved order. If your
                 input pipeline is I/O bottlenecked, consider setting this parameter to a
                 value greater than one to parallelize the I/O. If `None`, files will be
-                read sequentially.
+                read sequentially. This must be greater than or equal to
+                `num_parallel_calls`. The constraint exists because `num_parallel_reads`
+                becomes the `cycle_length` in the underlying call to
+                `tf.data.Dataset.interleave`, and `cycle_length` must be greater than
+                or equal to the number of threads (`num_parallel_calls`);
+                `cycle_length` dictates how many input elements the interleave
+                processes concurrently.
+            num_parallel_calls: (Optional.) The number of threads to spawn. This must
+                be `None`, `tf.data.experimental.AUTOTUNE`, or greater than 0, and it
+                must be less than or equal to `num_parallel_reads`. It defines the
+                degree of parallelism of the underlying `Dataset.interleave` call.
             reader_schema: (Optional.) A `tf.string` scalar representing the reader
                 schema or None.
-
+            deterministic: (Optional.) A boolean controlling whether determinism
+                should be traded for performance by allowing elements to be produced
+                out of order. Defaults to `True`.
+            block_length: The number of consecutive records produced from each file
+                before cycling to the next. Defaults to 1.
         Raises:
             TypeError: If any argument does not have the expected type.
             ValueError: If any argument does not have the expected shape.
         """
+        _require(
+            num_parallel_calls is None
+            or num_parallel_calls == tf.data.experimental.AUTOTUNE
+            or num_parallel_calls > 0,
+            f"num_parallel_calls: {num_parallel_calls} must be set to None, "
+            f"tf.data.experimental.AUTOTUNE, or greater than 0",
+        )
+        if num_parallel_calls is not None:
+            _require(
+                num_parallel_reads is not None
+                and (
+                    num_parallel_reads >= num_parallel_calls
+                    or num_parallel_reads == tf.data.experimental.AUTOTUNE
+                ),
+                f"num_parallel_reads: {num_parallel_reads} must be greater than or "
+                f"equal to num_parallel_calls: {num_parallel_calls} or set to "
+                f"tf.data.experimental.AUTOTUNE",
+            )
+
         filenames = _create_or_validate_filenames_dataset(filenames)
 
         self._filenames = filenames
         self._buffer_size = buffer_size
         self._num_parallel_reads = num_parallel_reads
+        self._num_parallel_calls = num_parallel_calls
         self._reader_schema = reader_schema
+        self._block_length = block_length
 
-        def creator_fn(filename):
-            return _AvroRecordDataset(filename, buffer_size, reader_schema)
+        def read_multiple_files(filenames):
+            return _AvroRecordDataset(filenames, buffer_size, reader_schema)
 
-        self._impl = _create_dataset_reader(creator_fn, filenames, num_parallel_reads)
+        self._impl = _create_dataset_reader(
+            read_multiple_files,
+            filenames,
+            cycle_length=num_parallel_reads,
+            num_parallel_calls=num_parallel_calls,
+            deterministic=deterministic,
+            block_length=block_length,
+        )
         variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
         super().__init__(variant_tensor)
 
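A usage sketch of the new constructor parameters, assuming `AvroRecordDataset` is imported from this module; the file paths are hypothetical and the reader schema is omitted:

```
from avro_record_dataset_ops import AvroRecordDataset  # import path assumed

# Cycle through 2 files with 2 interleave threads; num_parallel_reads (the
# cycle_length) must be >= num_parallel_calls, or the _require checks above
# raise ValueError.
dataset = AvroRecordDataset(
    filenames=["part-00000.avro", "part-00001.avro"],  # hypothetical paths
    num_parallel_reads=2,
    num_parallel_calls=2,
    block_length=2,
)
for record in dataset.take(2):  # elements are serialized Avro records
    print(record)
```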
@@ -171,13 +272,17 @@ def _clone(
         filenames=None,
         buffer_size=None,
         num_parallel_reads=None,
+        num_parallel_calls=None,
         reader_schema=None,
+        block_length=None,
     ):
         return AvroRecordDataset(
             filenames or self._filenames,
             buffer_size or self._buffer_size,
             num_parallel_reads or self._num_parallel_reads,
+            num_parallel_calls or self._num_parallel_calls,
             reader_schema or self._reader_schema,
+            block_length or self._block_length,
         )
 
     def _inputs(self):
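`_clone` rebuilds the dataset with selectively overridden parameters, falling back to the stored values via `or`; a sketch of the intended semantics:

```
# Hypothetical internal use: derive a reader with more parallelism while
# keeping the original filenames, buffer size, and reader schema.
faster = dataset._clone(num_parallel_reads=8, num_parallel_calls=8)
```

Note that the `or` fallback treats falsy overrides (for example `block_length=0`) the same as `None`, so they silently revert to the stored value.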