diff --git a/.github/workflows/build.wheel.sh b/.github/workflows/build.wheel.sh index 502984e72..afb382cbb 100755 --- a/.github/workflows/build.wheel.sh +++ b/.github/workflows/build.wheel.sh @@ -7,10 +7,10 @@ run_test() { CPYTHON_VERSION=$($entry -c 'import sys; print(str(sys.version_info[0])+str(sys.version_info[1]))') (cd wheelhouse && $entry -m pip install tensorflow_io-*-cp${CPYTHON_VERSION}-*.whl) $entry -m pip install -q pytest pytest-benchmark boto3 fastavro avro-python3 scikit-image pandas pyarrow==3.0.0 google-cloud-pubsub==2.1.0 google-cloud-bigtable==1.6.0 google-cloud-bigquery-storage==1.1.0 google-cloud-bigquery==2.3.1 google-cloud-storage==1.32.0 - (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! \( -iname "test_*_eager.py" \) \))) - (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*_eager.py" ! \( -iname "test_bigquery_eager.py" \) \))) + (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*_v1.py" \))) + (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! \( -iname "test_*_v1.py" -o -iname "test_bigquery.py" \) \))) # GRPC and test_bigquery_eager tests have to be executed separately because of https://github.com/grpc/grpc/issues/20034 - (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_bigquery_eager.py" \))) + (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_bigquery.py" \))) } PYTHON_VERSION=python diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0df1e111d..74128276b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -434,14 +434,13 @@ jobs: python --version python -m pip install -U pytest-benchmark rm -rf tensorflow_io - (cd tests && python -m pytest -s -v test_lmdb_eager.py) - (python -m pytest -s -v test_image_eager.py -k "webp or ppm or bmp or bounding or exif or hdr or openexr or tiff or avif") - (python -m pytest -s -v test_serialization_eager.py) - (python -m pytest -s -v test_io_dataset_eager.py -k "numpy or hdf5 or audio or to_file") - (python -m pytest -s -v test_http_eager.py) + (cd tests && python -m pytest -s -v test_lmdb.py) + (python -m pytest -s -v test_image.py -k "webp or ppm or bmp or bounding or exif or hdr or openexr or tiff or avif") + (python -m pytest -s -v test_serialization.py) + (python -m pytest -s -v test_io_dataset.py -k "numpy or hdf5 or audio or to_file") + (python -m pytest -s -v test_http.py) python -m pip install google-cloud-bigquery-storage==0.7.0 google-cloud-bigquery==1.22.0 fastavro - (python -m pytest -s -v test_bigquery_eager.py) - (python -m pytest -s -v test_dicom_eager.py) + (python -m pytest -s -v test_bigquery.py) (python -m pytest -s -v test_dicom.py) release: diff --git a/docs/development.md b/docs/development.md index c23c7d66a..5c43a509a 100644 --- a/docs/development.md +++ b/docs/development.md @@ -88,7 +88,7 @@ bazel build -s --verbose_failures $BAZEL_OPTIMIZATION //tensorflow_io/... 
# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible # to run tests with `pytest`, e.g.: sudo python3 -m pip install pytest -TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization_eager.py +TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization.py ``` NOTE: When running pytest, `TFIO_DATAPATH=bazel-bin` has to be passed so that python can utilize the generated shared libraries after the build process. @@ -147,7 +147,7 @@ bazel build -s --verbose_failures $BAZEL_OPTIMIZATION //tensorflow_io/... # `bazel-bin/tensorflow_io/core/python/ops/` and it is possible # to run tests with `pytest`, e.g.: sudo python3 -m pip install pytest -TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization_eager.py +TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization.py ``` ##### CentOS 8 @@ -207,7 +207,7 @@ scl enable rh-python36 devtoolset-9 \ TFIO_DATAPATH=bazel-bin \ scl enable rh-python36 devtoolset-9 \ - 'python3 -m pytest -s -v tests/test_serialization_eager.py' + 'python3 -m pytest -s -v tests/test_serialization.py' ``` #### Python Wheels @@ -295,7 +295,7 @@ use: $ bash -x -e tests/test_kafka/kafka_test.sh # Run the tests -$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_kafka_eager.py +$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_kafka.py ``` Testing `Datasets` associated with tools such as `Elasticsearch` or `MongoDB` @@ -307,7 +307,7 @@ require docker to be available on the system. In such scenarios, use: $ bash tests/test_elasticsearch/elasticsearch_test.sh start # Run the tests -$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_elasticsearch_eager.py +$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_elasticsearch.py # Stop and remove the container $ bash tests/test_elasticsearch/elasticsearch_test.sh stop @@ -319,7 +319,7 @@ For example, to run tests related to `parquet` dataset's, use: ```sh # Just run the test -$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_parquet_eager.py +$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_parquet.py ``` diff --git a/tests/test_archive_eager.py b/tests/test_archive.py similarity index 100% rename from tests/test_archive_eager.py rename to tests/test_archive.py diff --git a/tests/test_arrow_eager.py b/tests/test_arrow.py similarity index 100% rename from tests/test_arrow_eager.py rename to tests/test_arrow.py diff --git a/tests/test_audio_eager.py b/tests/test_audio.py similarity index 100% rename from tests/test_audio_eager.py rename to tests/test_audio.py diff --git a/tests/test_audio_ops_eager.py b/tests/test_audio_ops.py similarity index 100% rename from tests/test_audio_ops_eager.py rename to tests/test_audio_ops.py diff --git a/tests/test_avro_eager.py b/tests/test_avro.py similarity index 100% rename from tests/test_avro_eager.py rename to tests/test_avro.py diff --git a/tests/test_bigquery_eager.py b/tests/test_bigquery.py similarity index 100% rename from tests/test_bigquery_eager.py rename to tests/test_bigquery.py diff --git a/tests/test_bigtable_eager.py b/tests/test_bigtable.py similarity index 100% rename from tests/test_bigtable_eager.py rename to tests/test_bigtable.py diff --git a/tests/test_color_eager.py b/tests/test_color.py similarity index 100% rename from tests/test_color_eager.py rename to tests/test_color.py diff --git a/tests/test_csv_eager.py b/tests/test_csv.py similarity index 100% rename from tests/test_csv_eager.py rename to tests/test_csv.py diff --git a/tests/test_dicom.py b/tests/test_dicom.py index c4605e61a..e867465cb 100644 --- 
a/tests/test_dicom.py +++ b/tests/test_dicom.py @@ -16,6 +16,7 @@ import os +import numpy as np import pytest import tensorflow as tf @@ -35,8 +36,7 @@ def test_dicom_input(): - """test_dicom_input - """ + """test_dicom_input""" _ = tfio.image.decode_dicom_data _ = tfio.image.decode_dicom_image _ = tfio.image.dicom_tags @@ -66,32 +66,26 @@ def test_dicom_input(): ("MR-MONO2-12-shoulder.dcm", (1, 1024, 1024, 1)), ("OT-MONO2-8-a7.dcm", (1, 512, 512, 1)), ("US-PAL-8-10x-echo.dcm", (10, 430, 600, 3)), + ("TOSHIBA_J2K_OpenJPEGv2Regression.dcm", (1, 512, 512, 1)), ], ) def test_decode_dicom_image(fname, exp_shape): - """test_decode_dicom_image - """ + """test_decode_dicom_image""" dcm_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname ) - g1 = tf.compat.v1.Graph() + file_contents = tf.io.read_file(filename=dcm_path) - with g1.as_default(): - file_contents = tf.io.read_file(filename=dcm_path) - dcm_image = tfio.image.decode_dicom_image( - contents=file_contents, - dtype=tf.float32, - on_error="strict", - scale="auto", - color_dim=True, - ) - - sess = tf.compat.v1.Session(graph=g1) - dcm_image_np = sess.run(dcm_image) - - assert dcm_image_np.shape == exp_shape + dcm_image = tfio.image.decode_dicom_image( + contents=file_contents, + dtype=tf.float32, + on_error="strict", + scale="auto", + color_dim=True, + ) + assert dcm_image.numpy().shape == exp_shape @pytest.mark.parametrize( @@ -121,23 +115,108 @@ def test_decode_dicom_image(fname, exp_shape): ], ) def test_decode_dicom_data(fname, tag, exp_value): - """test_decode_dicom_data - """ + """test_decode_dicom_data""" dcm_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname ) - g1 = tf.compat.v1.Graph() + file_contents = tf.io.read_file(filename=dcm_path) + + dcm_data = tfio.image.decode_dicom_data(contents=file_contents, tags=tag) + + assert dcm_data.numpy() == exp_value + + +def test_dicom_image_shape(): + """test_dicom_image_shape""" + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "US-PAL-8-10x-echo.dcm", + ) + + dataset = tf.data.Dataset.from_tensor_slices([dcm_path]) + dataset = dataset.map(tf.io.read_file) + dataset = dataset.map(lambda e: tfio.image.decode_dicom_image(e, dtype=tf.uint16)) + dataset = dataset.map(lambda e: tf.image.resize(e, (224, 224))) + + +def test_dicom_image_concurrency(): + """test_dicom_image_concurrency""" - with g1.as_default(): - file_contents = tf.io.read_file(filename=dcm_path) - dcm_data = tfio.image.decode_dicom_data(contents=file_contents, tags=tag) + @tf.function + def preprocess(dcm_content): + tags = tfio.image.decode_dicom_data( + dcm_content, tags=[tfio.image.dicom_tags.PatientsName] + ) + tf.print(tags) + image = tfio.image.decode_dicom_image(dcm_content, dtype=tf.float32) + return image + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "TOSHIBA_J2K_OpenJPEGv2Regression.dcm", + ) + + dataset = ( + tf.data.Dataset.from_tensor_slices([dcm_path]) + .repeat() + .map(tf.io.read_file) + .map(preprocess, num_parallel_calls=8) + .take(200) + ) + for i, item in enumerate(dataset): + print(tf.shape(item), i) + assert np.array_equal(tf.shape(item), [1, 512, 512, 1]) + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "US-PAL-8-10x-echo.dcm", + ) + + dataset = ( + tf.data.Dataset.from_tensor_slices([dcm_path]) + .repeat() + .map(tf.io.read_file) + .map(preprocess, num_parallel_calls=8) + .take(200) + ) + for i, item in
enumerate(dataset): + print(tf.shape(item), i) + assert np.array_equal(tf.shape(item), [10, 430, 600, 3]) + + +def test_dicom_sequence(): + """test_decode_dicom_sequence""" + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "2.25.304589190180579357564631626197663875025.dcm", + ) + dcm_content = tf.io.read_file(filename=dcm_path) + + tags = tfio.image.decode_dicom_data( + dcm_content, tags=["[0x0008,0x1115][0][0x0008,0x1140][0][0x0008,0x1155]"] + ) + assert np.array_equal(tags, [b"2.25.211904290918469145111906856660599393535"]) + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "US-PAL-8-10x-echo.dcm", + ) + dcm_content = tf.io.read_file(filename=dcm_path) - sess = tf.compat.v1.Session(graph=g1) - dcm_data_np = sess.run(dcm_data) + tags = tfio.image.decode_dicom_data(dcm_content, tags=["[0x0020,0x000E]"]) + assert np.array_equal(tags, [b"999.999.94827453"]) - assert dcm_data_np == exp_value + tags = tfio.image.decode_dicom_data(dcm_content, tags=["0x0020,0x000e"]) + assert np.array_equal(tags, [b"999.999.94827453"]) if __name__ == "__main__": diff --git a/tests/test_dicom_eager.py b/tests/test_dicom_eager.py deleted file mode 100644 index e867465cb..000000000 --- a/tests/test_dicom_eager.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. 
-# ============================================================================== -"""Tests for DICOM.""" - - -import os -import numpy as np -import pytest - -import tensorflow as tf -import tensorflow_io as tfio - -# The DICOM sample files must be downloaded befor running the tests -# -# To download the DICOM samples: -# $ bash dicom_samples.sh download -# $ bash dicom_samples.sh extract -# -# To remopve the DICOM samples: -# $ bash dicom_samples.sh clean_dcm -# -# To remopve all the downloaded files: -# $ bash dicom_samples.sh clean_all - - -def test_dicom_input(): - """test_dicom_input""" - _ = tfio.image.decode_dicom_data - _ = tfio.image.decode_dicom_image - _ = tfio.image.dicom_tags - - -@pytest.mark.parametrize( - "fname, exp_shape", - [ - ("OT-MONO2-8-colon.dcm", (1, 512, 512, 1)), - ("CR-MONO1-10-chest.dcm", (1, 440, 440, 1)), - ("CT-MONO2-16-ort.dcm", (1, 512, 512, 1)), - ("MR-MONO2-16-head.dcm", (1, 256, 256, 1)), - ("US-RGB-8-epicard.dcm", (1, 480, 640, 3)), - ("CT-MONO2-8-abdo.dcm", (1, 512, 512, 1)), - ("MR-MONO2-16-knee.dcm", (1, 256, 256, 1)), - ("OT-MONO2-8-hip.dcm", (1, 512, 512, 1)), - ("US-RGB-8-esopecho.dcm", (1, 120, 256, 3)), - ("CT-MONO2-16-ankle.dcm", (1, 512, 512, 1)), - ("MR-MONO2-12-an2.dcm", (1, 256, 256, 1)), - ("MR-MONO2-8-16x-heart.dcm", (16, 256, 256, 1)), - ("OT-PAL-8-face.dcm", (1, 480, 640, 3)), - ("XA-MONO2-8-12x-catheter.dcm", (12, 512, 512, 1)), - ("CT-MONO2-16-brain.dcm", (1, 512, 512, 1)), - ("NM-MONO2-16-13x-heart.dcm", (13, 64, 64, 1)), - ("US-MONO2-8-8x-execho.dcm", (8, 120, 128, 1)), - ("CT-MONO2-16-chest.dcm", (1, 400, 512, 1)), - ("MR-MONO2-12-shoulder.dcm", (1, 1024, 1024, 1)), - ("OT-MONO2-8-a7.dcm", (1, 512, 512, 1)), - ("US-PAL-8-10x-echo.dcm", (10, 430, 600, 3)), - ("TOSHIBA_J2K_OpenJPEGv2Regression.dcm", (1, 512, 512, 1)), - ], -) -def test_decode_dicom_image(fname, exp_shape): - """test_decode_dicom_image""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname - ) - - file_contents = tf.io.read_file(filename=dcm_path) - - dcm_image = tfio.image.decode_dicom_image( - contents=file_contents, - dtype=tf.float32, - on_error="strict", - scale="auto", - color_dim=True, - ) - assert dcm_image.numpy().shape == exp_shape - - -@pytest.mark.parametrize( - "fname, tag, exp_value", - [ - ( - "OT-MONO2-8-colon.dcm", - tfio.image.dicom_tags.StudyInstanceUID, - b"1.3.46.670589.17.1.7.1.1.16", - ), - ("OT-MONO2-8-colon.dcm", tfio.image.dicom_tags.Rows, b"512"), - ("OT-MONO2-8-colon.dcm", tfio.image.dicom_tags.Columns, b"512"), - ("OT-MONO2-8-colon.dcm", tfio.image.dicom_tags.SamplesperPixel, b"1"), - ( - "US-PAL-8-10x-echo.dcm", - tfio.image.dicom_tags.StudyInstanceUID, - b"999.999.3859744", - ), - ( - "US-PAL-8-10x-echo.dcm", - tfio.image.dicom_tags.SeriesInstanceUID, - b"999.999.94827453", - ), - ("US-PAL-8-10x-echo.dcm", tfio.image.dicom_tags.NumberofFrames, b"10"), - ("US-PAL-8-10x-echo.dcm", tfio.image.dicom_tags.Rows, b"430"), - ("US-PAL-8-10x-echo.dcm", tfio.image.dicom_tags.Columns, b"600"), - ], -) -def test_decode_dicom_data(fname, tag, exp_value): - """test_decode_dicom_data""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname - ) - - file_contents = tf.io.read_file(filename=dcm_path) - - dcm_data = tfio.image.decode_dicom_data(contents=file_contents, tags=tag) - - assert dcm_data.numpy() == exp_value - - -def test_dicom_image_shape(): - """test_decode_dicom_image""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", 
- "US-PAL-8-10x-echo.dcm", - ) - - dataset = tf.data.Dataset.from_tensor_slices([dcm_path]) - dataset = dataset.map(tf.io.read_file) - dataset = dataset.map(lambda e: tfio.image.decode_dicom_image(e, dtype=tf.uint16)) - dataset = dataset.map(lambda e: tf.image.resize(e, (224, 224))) - - -def test_dicom_image_concurrency(): - """test_decode_dicom_image_currency""" - - @tf.function - def preprocess(dcm_content): - tags = tfio.image.decode_dicom_data( - dcm_content, tags=[tfio.image.dicom_tags.PatientsName] - ) - tf.print(tags) - image = tfio.image.decode_dicom_image(dcm_content, dtype=tf.float32) - return image - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "TOSHIBA_J2K_OpenJPEGv2Regression.dcm", - ) - - dataset = ( - tf.data.Dataset.from_tensor_slices([dcm_path]) - .repeat() - .map(tf.io.read_file) - .map(preprocess, num_parallel_calls=8) - .take(200) - ) - for i, item in enumerate(dataset): - print(tf.shape(item), i) - assert np.array_equal(tf.shape(item), [1, 512, 512, 1]) - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "US-PAL-8-10x-echo.dcm", - ) - - dataset = ( - tf.data.Dataset.from_tensor_slices([dcm_path]) - .repeat() - .map(tf.io.read_file) - .map(preprocess, num_parallel_calls=8) - .take(200) - ) - for i, item in enumerate(dataset): - print(tf.shape(item), i) - assert np.array_equal(tf.shape(item), [10, 430, 600, 3]) - - -def test_dicom_sequence(): - """test_decode_dicom_sequence""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "2.25.304589190180579357564631626197663875025.dcm", - ) - dcm_content = tf.io.read_file(filename=dcm_path) - - tags = tfio.image.decode_dicom_data( - dcm_content, tags=["[0x0008,0x1115][0][0x0008,0x1140][0][0x0008,0x1155]"] - ) - assert np.array_equal(tags, [b"2.25.211904290918469145111906856660599393535"]) - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "US-PAL-8-10x-echo.dcm", - ) - dcm_content = tf.io.read_file(filename=dcm_path) - - tags = tfio.image.decode_dicom_data(dcm_content, tags=["[0x0020,0x000E]"]) - assert np.array_equal(tags, [b"999.999.94827453"]) - - tags = tfio.image.decode_dicom_data(dcm_content, tags=["0x0020,0x000e"]) - assert np.array_equal(tags, [b"999.999.94827453"]) - - -if __name__ == "__main__": - test.main() diff --git a/tests/test_documentation_eager.py b/tests/test_documentation.py similarity index 100% rename from tests/test_documentation_eager.py rename to tests/test_documentation.py diff --git a/tests/test_elasticsearch_eager.py b/tests/test_elasticsearch.py similarity index 100% rename from tests/test_elasticsearch_eager.py rename to tests/test_elasticsearch.py diff --git a/tests/test_feather_eager.py b/tests/test_feather.py similarity index 100% rename from tests/test_feather_eager.py rename to tests/test_feather.py diff --git a/tests/test_ffmpeg_eager.py b/tests/test_ffmpeg.py similarity index 100% rename from tests/test_ffmpeg_eager.py rename to tests/test_ffmpeg.py diff --git a/tests/test_filter_eager.py b/tests/test_filter.py similarity index 100% rename from tests/test_filter_eager.py rename to tests/test_filter.py diff --git a/tests/test_gcs_eager.py b/tests/test_gcs.py similarity index 100% rename from tests/test_gcs_eager.py rename to tests/test_gcs.py diff --git a/tests/test_genome.py b/tests/test_genome.py index 48b2642b3..1798b8914 100644 --- a/tests/test_genome.py +++ b/tests/test_genome.py @@ -19,8 +19,6 @@ import numpy as np 
import tensorflow as tf - -tf.compat.v1.disable_eager_execution() import tensorflow_io as tfio # pylint: disable=wrong-import-position fastq_path = os.path.join( @@ -30,13 +28,8 @@ def test_genome_fastq_reader(): """test_genome_fastq_reader""" - g1 = tf.compat.v1.Graph() - - with g1.as_default(): - data = tfio.genome.read_fastq(filename=fastq_path) - sess = tf.compat.v1.Session(graph=g1) - data_np = sess.run(data) + data = tfio.genome.read_fastq(filename=fastq_path) data_expected = [ b"GATTACA", @@ -52,8 +45,8 @@ def test_genome_fastq_reader(): b"FAD", ] - assert np.all(data_np.sequences == data_expected) - assert np.all(data_np.raw_quality == quality_expected) + assert np.all(data.sequences == data_expected) + assert np.all(data.raw_quality == quality_expected) def test_genome_sequences_to_onehot(): @@ -189,12 +182,10 @@ def test_genome_sequences_to_onehot(): [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], ] - with tf.compat.v1.Session() as sess: - raw_data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) - out = sess.run(data) + raw_data = tfio.genome.read_fastq(filename=fastq_path) + data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) - assert np.all(out.to_list() == expected) + assert np.all(data.to_list() == expected) def test_genome_phred_sequences_to_probability(): @@ -210,28 +201,21 @@ def test_genome_phred_sequences_to_probability(): 0.00019952621369156986, ] - with tf.compat.v1.Session() as sess: - example_quality = tf.constant(example_quality_list) - converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) - out = sess.run(converted_phred) + example_quality = tf.constant(example_quality_list) + converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) # Compare flat values - assert np.allclose(out.flat_values.flatten(), expected_probabilities) + assert np.allclose( + converted_phred.flat_values.numpy().flatten(), expected_probabilities + ) # Ensure nested array lengths are correct assert np.all( - [len(a) == len(b) for a, b in zip(out.to_list(), example_quality_list)] + [ + len(a) == len(b) + for a, b in zip(converted_phred.to_list(), example_quality_list) + ] ) -def test_genome_phred_sequences_to_probability_with_other_genome_ops(): - """Test quality op in graph with read_fastq op, ensure no errors""" - with tf.compat.v1.Session() as sess: - raw_data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.phred_sequences_to_probability( - phred_qualities=raw_data.raw_quality - ) - sess.run(data) - - if __name__ == "__main__": test.main() diff --git a/tests/test_genome_eager.py b/tests/test_genome_v1.py similarity index 79% rename from tests/test_genome_eager.py rename to tests/test_genome_v1.py index 1798b8914..48b2642b3 100644 --- a/tests/test_genome_eager.py +++ b/tests/test_genome_v1.py @@ -19,6 +19,8 @@ import numpy as np import tensorflow as tf + +tf.compat.v1.disable_eager_execution() import tensorflow_io as tfio # pylint: disable=wrong-import-position fastq_path = os.path.join( @@ -28,8 +30,13 @@ def test_genome_fastq_reader(): """test_genome_fastq_reader""" + g1 = tf.compat.v1.Graph() + + with g1.as_default(): + data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.read_fastq(filename=fastq_path) + sess = tf.compat.v1.Session(graph=g1) + data_np = sess.run(data) data_expected = [ b"GATTACA", @@ -45,8 +52,8 @@ def test_genome_fastq_reader(): b"FAD", ] - assert np.all(data.sequences == data_expected) - assert 
np.all(data.raw_quality == quality_expected) + assert np.all(data_np.sequences == data_expected) + assert np.all(data_np.raw_quality == quality_expected) def test_genome_sequences_to_onehot(): @@ -182,10 +189,12 @@ def test_genome_sequences_to_onehot(): [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], ] - raw_data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) + with tf.compat.v1.Session() as sess: + raw_data = tfio.genome.read_fastq(filename=fastq_path) + data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) + out = sess.run(data) - assert np.all(data.to_list() == expected) + assert np.all(out.to_list() == expected) def test_genome_phred_sequences_to_probability(): @@ -201,21 +210,28 @@ def test_genome_phred_sequences_to_probability(): 0.00019952621369156986, ] - example_quality = tf.constant(example_quality_list) - converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) + with tf.compat.v1.Session() as sess: + example_quality = tf.constant(example_quality_list) + converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) + out = sess.run(converted_phred) # Compare flat values - assert np.allclose( - converted_phred.flat_values.numpy().flatten(), expected_probabilities - ) + assert np.allclose(out.flat_values.flatten(), expected_probabilities) # Ensure nested array lengths are correct assert np.all( - [ - len(a) == len(b) - for a, b in zip(converted_phred.to_list(), example_quality_list) - ] + [len(a) == len(b) for a, b in zip(out.to_list(), example_quality_list)] ) +def test_genome_phred_sequences_to_probability_with_other_genome_ops(): + """Test quality op in graph with read_fastq op, ensure no errors""" + with tf.compat.v1.Session() as sess: + raw_data = tfio.genome.read_fastq(filename=fastq_path) + data = tfio.genome.phred_sequences_to_probability( + phred_qualities=raw_data.raw_quality + ) + sess.run(data) + + if __name__ == "__main__": test.main() diff --git a/tests/test_hdf5_eager.py b/tests/test_hdf5.py similarity index 100% rename from tests/test_hdf5_eager.py rename to tests/test_hdf5.py diff --git a/tests/test_hdfs_eager.py b/tests/test_hdfs.py similarity index 100% rename from tests/test_hdfs_eager.py rename to tests/test_hdfs.py diff --git a/tests/test_http_eager.py b/tests/test_http.py similarity index 100% rename from tests/test_http_eager.py rename to tests/test_http.py diff --git a/tests/test_ignite.py b/tests/test_ignite_v1.py similarity index 100% rename from tests/test_ignite.py rename to tests/test_ignite_v1.py diff --git a/tests/test_image_eager.py b/tests/test_image.py similarity index 100% rename from tests/test_image_eager.py rename to tests/test_image.py diff --git a/tests/test_io_dataset_eager.py b/tests/test_io_dataset.py similarity index 100% rename from tests/test_io_dataset_eager.py rename to tests/test_io_dataset.py diff --git a/tests/test_io_layer_eager.py b/tests/test_io_layer.py similarity index 100% rename from tests/test_io_layer_eager.py rename to tests/test_io_layer.py diff --git a/tests/test_io_tensor_eager.py b/tests/test_io_tensor.py similarity index 100% rename from tests/test_io_tensor_eager.py rename to tests/test_io_tensor.py diff --git a/tests/test_json_eager.py b/tests/test_json.py similarity index 100% rename from tests/test_json_eager.py rename to tests/test_json.py diff --git a/tests/test_kafka.py b/tests/test_kafka.py index 8c539473c..06e82b12a 100644 --- a/tests/test_kafka.py +++ b/tests/test_kafka.py @@ -1,4 +1,4 @@ 
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of @@ -12,500 +12,504 @@ # License for the specific language governing permissions and limitations under # the License. # ============================================================================== -"""Tests for KafkaDataset.""" +"""Tests for Kafka Output Sequence.""" import time import pytest +import numpy as np +import threading import tensorflow as tf - -tf.compat.v1.disable_eager_execution() - -from tensorflow import dtypes # pylint: disable=wrong-import-position -from tensorflow import errors # pylint: disable=wrong-import-position -from tensorflow import test # pylint: disable=wrong-import-position -from tensorflow.compat.v1 import data # pylint: disable=wrong-import-position - +import tensorflow_io as tfio +from tensorflow_io.kafka.python.ops import ( + kafka_ops, +) # pylint: disable=wrong-import-position import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position -class KafkaDatasetTest(test.TestCase): - """Tests for KafkaDataset.""" - - # The Kafka server has to be setup before the test - # and tear down after the test manually. - # The docker engine has to be installed. - # - # To setup the Kafka server: - # $ bash kafka_test.sh start kafka - # - # To tear down the Kafka server: - # $ bash kafka_test.sh stop kafka - - def test_kafka_dataset(self): - """Tests for KafkaDataset when reading non-keyed messages - from a single-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset(topics, group="test", eof=True).repeat( - num_epochs +def test_kafka_io_tensor(): + kafka = tfio.IOTensor.from_kafka("test") + assert kafka.dtype == tf.string + assert kafka.shape.as_list() == [None] + assert np.all( + kafka.to_tensor().numpy() == [("D" + str(i)).encode() for i in range(10)] + ) + assert len(kafka.to_tensor()) == 10 + + +@pytest.mark.skip(reason="TODO") +def test_kafka_output_sequence(): + """Test case based on fashion mnist tutorial""" + fashion_mnist = tf.keras.datasets.fashion_mnist + ((train_images, train_labels), (test_images, _)) = fashion_mnist.load_data() + + class_names = [ + "T-shirt/top", + "Trouser", + "Pullover", + "Dress", + "Coat", + "Sandal", + "Shirt", + "Sneaker", + "Bag", + "Ankle boot", + ] + + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation=tf.nn.relu), + tf.keras.layers.Dense(10, activation=tf.nn.softmax), + ] + ) + + model.compile( + optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] + ) + + model.fit(train_images, train_labels, epochs=5) + + class OutputCallback(tf.keras.callbacks.Callback): + """KafkaOutputCallback""" + + def __init__( + self, batch_size, topic, servers + ): # pylint: disable=super-init-not-called + self._sequence = kafka_ops.KafkaOutputSequence(topic=topic, servers=servers) + self._batch_size = batch_size + + def on_predict_batch_end(self, batch, logs=None): + index = batch * self._batch_size + for outputs in logs["outputs"]: + for output in outputs: + 
self._sequence.setitem(index, class_names[np.argmax(output)]) + index += 1 + + def flush(self): + self._sequence.flush() + + channel = "e{}e".format(time.time()) + topic = "test_" + channel + + # By default batch size is 32 + output = OutputCallback(32, topic, "localhost") + predictions = model.predict(test_images, callbacks=[output]) + output.flush() + + predictions = [class_names[v] for v in np.argmax(predictions, axis=1)] + + # Reading from `test_e(time)e` we should get the same result + dataset = tfio.kafka.KafkaDataset(topics=[topic], group="test", eof=True) + for entry, prediction in zip(dataset, predictions): + assert entry.numpy() == prediction.encode() + + +def test_avro_kafka_dataset(): + """test_avro_kafka_dataset""" + schema = ( + '{"type":"record","name":"myrecord","fields":[' + '{"name":"f1","type":"string"},' + '{"name":"f2","type":"long"},' + '{"name":"f3","type":["null","string"],"default":null}' + "]}" + ) + dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) + # remove kafka framing + dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) + # deserialize avro + dataset = dataset.map( + lambda e: tfio.experimental.serialization.decode_avro(e, schema=schema) + ) + entries = [(e["f1"], e["f2"], e["f3"]) for e in dataset] + np.all(entries == [("value1", 1, ""), ("value2", 2, ""), ("value3", 3, "")]) + + +def test_avro_kafka_dataset_with_resource(): + """test_avro_kafka_dataset_with_resource""" + schema = ( + '{"type":"record","name":"myrecord","fields":[' + '{"name":"f1","type":"string"},' + '{"name":"f2","type":"long"},' + '{"name":"f3","type":["null","string"],"default":null}' + ']}"' + ) + schema_resource = kafka_io.decode_avro_init(schema) + dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) + # remove kafka framing + dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) + # deserialize avro + dataset = dataset.map( + lambda e: kafka_io.decode_avro( + e, schema=schema_resource, dtype=[tf.string, tf.int64, tf.string] ) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read a limited number of messages from the topic. - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read all the messages from the topic from offset 5. - sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i + 5)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from different subscriptions of the same topic. - sess.run( - init_op, - feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 1}, - ) - for j in range(2): - for i in range(5): - self.assertEqual( - ("D" + str(i + j * 5)).encode(), sess.run(get_next) - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both subscriptions. 
- sess.run( - init_op, - feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 10}, - ) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual( - ("D" + str(i + j * 5)).encode(), sess.run(get_next) - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both subscriptions. - sess.run( - init_batch_op, - feed_dict={ - topics: ["test:0:0:4", "test:0:5:-1"], - num_epochs: 10, - batch_size: 5, - }, - ) - for _ in range(10): - self.assertAllEqual( - [("D" + str(i)).encode() for i in range(5)], sess.run(get_next) - ) - self.assertAllEqual( - [("D" + str(i + 5)).encode() for i in range(5)], sess.run(get_next) - ) - - @pytest.mark.skip(reason="TODO") - def test_kafka_dataset_save_and_restore(self): - """Tests for KafkaDataset save and restore.""" - g = tf.Graph() - with g.as_default(): - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True - ).repeat(num_epochs) - iterator = repeat_dataset.make_initializable_iterator() - get_next = iterator.get_next() - - it = tf.data.experimental.make_saveable_from_iterator(iterator) - g.add_to_collection(tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, it) - saver = tf.compat.v1.train.Saver() - - model_file = "/tmp/test-kafka-model" - with self.cached_session() as sess: - sess.run( - iterator.initializer, - feed_dict={topics: ["test:0:0:4"], num_epochs: 1}, - ) - for i in range(3): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - # Save current offset which is 2 - saver.save(sess, model_file, global_step=3) - - checkpoint_file = "/tmp/test-kafka-model-3" - with self.cached_session() as sess: - saver.restore(sess, checkpoint_file) - # Restore current offset to 2 - for i in [2, 3]: - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - - def test_kafka_topic_configuration(self): - """Tests for KafkaDataset topic configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - cfg_list = ["auto.offset.reset=earliest"] - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_topic=cfg_list - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Use a wrong offset 100 here to make sure - # configuration 'auto.offset.reset=earliest' works. 
- sess.run(init_op, feed_dict={topics: ["test:0:100:-1"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - - def test_kafka_global_configuration(self): - """Tests for KafkaDataset global configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - cfg_list = ["debug=generic", "enable.auto.commit=false"] - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_global=cfg_list - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_kafka_wrong_global_configuration_failed(self): - """Tests for KafkaDataset worng global configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - # Add wrong configuration - wrong_cfg = ["debug=al"] - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_global=wrong_cfg - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - with self.assertRaises(errors.InternalError): - sess.run(get_next) - - def test_kafka_wrong_topic_configuration_failed(self): - """Tests for KafkaDataset wrong topic configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - # Add wrong configuration - wrong_cfg = ["auto.offset.reset=arliest"] - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_topic=wrong_cfg - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - with self.assertRaises(errors.InternalError): - sess.run(get_next) - - def test_write_kafka(self): - """test_write_kafka""" - channel = "e{}e".format(time.time()) - - # Start with reading test topic, replace `D` with `e(time)e`, - # and write to test_e(time)e` topic. 
- dataset = kafka_io.KafkaDataset(topics=["test:0:0:4"], group="test", eof=True) - dataset = dataset.map( - lambda x: kafka_io.write_kafka( - tf.strings.regex_replace(x, "D", channel), topic="test_" + channel - ) + ) + entries = [(f1.numpy(), f2.numpy(), f3.numpy()) for (f1, f2, f3) in dataset] + np.all(entries == [("value1", 1), ("value2", 2), ("value3", 3)]) + + +def test_kafka_stream_dataset(): + dataset = tfio.IODataset.stream().from_kafka("test").batch(2) + assert np.all( + [k.numpy().tolist() for (k, _) in dataset] + == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) + ) + + +def test_kafka_io_dataset(): + dataset = tfio.IODataset.from_kafka( + "test", configuration=["fetch.min.bytes=2"] + ).batch(2) + # repeat multiple times will result in the same result + for _ in range(5): + assert np.all( + [k.numpy().tolist() for (k, _) in dataset] + == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) ) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read from topic 0. - sess.run(init_op) - for i in range(5): - self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Reading from `test_e(time)e` we should get the same result - dataset = kafka_io.KafkaDataset( - topics=["test_" + channel], group="test", eof=True + + +def test_avro_encode_decode(): + """test_avro_encode_decode""" + schema = ( + '{"type":"record","name":"myrecord","fields":' + '[{"name":"f1","type":"string"},{"name":"f2","type":"long"}]}' + ) + value = [("value1", 1), ("value2", 2), ("value3", 3)] + f1 = tf.cast([v[0] for v in value], tf.string) + f2 = tf.cast([v[1] for v in value], tf.int64) + message = tfio.experimental.serialization.encode_avro([f1, f2], schema=schema) + entries = tfio.experimental.serialization.decode_avro(message, schema=schema) + assert np.all(entries["f1"].numpy() == f1.numpy()) + assert np.all(entries["f2"].numpy() == f2.numpy()) + + +def test_kafka_group_io_dataset_primary_cg(): + """Test the functionality of the KafkaGroupIODataset when the consumer group + is being newly created. + + NOTE: After the kafka cluster is setup during the testing phase, 10 messages + are written to the 'key-partition-test' topic with 5 in each partition + (topic created with 2 partitions, the messages are split based on the keys). + And the same 10 messages are written into the 'key-test' topic (topic created + with 1 partition, so no splitting of the messages based on the keys). + + K0:D0, K1:D1, K0:D2, K1:D3, K0:D4, K1:D5, K0:D6, K1:D7, K0:D8, K1:D9. + + Here, messages D0, D2, D4, D6 and D8 are written into partition 0 and the rest are written + into partition 1. + + Also, since the messages are read from different partitions, the order of retrieval may not be + the same as storage. Thus, we sort and compare. 
+ """ + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10)]) + ) + + +def test_kafka_group_io_dataset_primary_cg_no_lag(): + """Test the functionality of the KafkaGroupIODataset when the + consumer group has read all the messages and committed the offsets. + """ + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], + ) + assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) + + +def test_kafka_group_io_dataset_primary_cg_new_topic(): + """Test the functionality of the KafkaGroupIODataset when the existing + consumer group reads data from a new topic. + """ + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10)]) + ) + + +def test_kafka_group_io_dataset_resume_primary_cg(): + """Test the functionality of the KafkaGroupIODataset when the + consumer group is yet to catch up with the newly added messages only + (Instead of reading from the beginning). + """ + + # Write new messages to the topic + for i in range(10, 100): + message = "D{}".format(i) + kafka_io.write_kafka(message=message, topic="key-partition-test") + # Read only the newly sent 90 messages + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10, 100)]) + ) + + +def test_kafka_group_io_dataset_resume_primary_cg_new_topic(): + """Test the functionality of the KafkaGroupIODataset when the + consumer group is yet to catch up with the newly added messages only + (Instead of reading from the beginning) from the new topic. + """ + + # Write new messages to the topic + for i in range(10, 100): + message = "D{}".format(i) + kafka_io.write_kafka(message=message, topic="key-test") + # Read only the newly sent 90 messages + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10, 100)]) + ) + + +def test_kafka_group_io_dataset_secondary_cg(): + """Test the functionality of the KafkaGroupIODataset when a + secondary consumer group is created and is yet to catch up all the messages, + from the beginning. 
+ """ + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestsecondary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)]) + ) + + +def test_kafka_group_io_dataset_tertiary_cg_multiple_topics(): + """Test the functionality of the KafkaGroupIODataset when a new + consumer group reads data from multiple topics from the beginning. + """ + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test", "key-test"], + group_id="cgtesttertiary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)] * 2) + ) + + +def test_kafka_group_io_dataset_auto_offset_reset(): + """Test the functionality of the `auto.offset.reset` configuration + at global and topic level""" + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgglobaloffsetearliest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)]) + ) + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgglobaloffsetlatest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=latest", + ], + ) + assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtopicoffsetearliest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "conf.topic.auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)]) + ) + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtopicoffsetlatest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "conf.topic.auto.offset.reset=latest", + ], + ) + assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) + + +def test_kafka_group_io_dataset_invalid_stream_timeout(): + """Test the functionality of the KafkaGroupIODataset when the + consumer is configured to have an invalid stream_timeout value which is + less than the message_timeout value. 
+ NOTE: The default value for message_timeout=5000 + """ + + STREAM_TIMEOUT = -20 + try: + tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test", "key-test"], + group_id="cgteststreaminvalid", + servers="localhost:9092", + stream_timeout=STREAM_TIMEOUT, + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], ) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - for i in range(5): - self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_kafka_dataset_with_key(self): - """Tests for KafkaDataset when reading keyed-messages - from a single-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, message_key=True - ).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read a limited number of keyed messages from the topic. - sess.run(init_op, feed_dict={topics: ["key-test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual( - (("D" + str(i)).encode(), ("K" + str(i % 2)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read all the keyed messages from the topic from offset 5. - sess.run(init_op, feed_dict={topics: ["key-test:0:5:-1"], num_epochs: 1}) - for i in range(5): - self.assertEqual( - (("D" + str(i + 5)).encode(), ("K" + str((i + 5) % 2)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from different subscriptions of the same topic. - sess.run( - init_op, - feed_dict={ - topics: ["key-test:0:0:4", "key-test:0:5:-1"], - num_epochs: 1, - }, - ) - for j in range(2): - for i in range(5): - self.assertEqual( - ( - ("D" + str(i + j * 5)).encode(), - ("K" + str((i + j * 5) % 2)).encode(), - ), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both subscriptions. - sess.run( - init_op, - feed_dict={ - topics: ["key-test:0:0:4", "key-test:0:5:-1"], - num_epochs: 10, - }, - ) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual( - ( - ("D" + str(i + j * 5)).encode(), - ("K" + str((i + j * 5) % 2)).encode(), - ), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both subscriptions. 
- sess.run( - init_batch_op, - feed_dict={ - topics: ["key-test:0:0:4", "key-test:0:5:-1"], - num_epochs: 10, - batch_size: 5, - }, - ) - for _ in range(10): - self.assertAllEqual( - [ - [("D" + str(i)).encode() for i in range(5)], - [("K" + str(i % 2)).encode() for i in range(5)], - ], - sess.run(get_next), - ) - self.assertAllEqual( - [ - [("D" + str(i + 5)).encode() for i in range(5)], - [("K" + str((i + 5) % 2)).encode() for i in range(5)], - ], - sess.run(get_next), - ) - - def test_kafka_dataset_with_partitioned_key(self): - """Tests for KafkaDataset when reading keyed-messages - from a multi-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, message_key=True - ).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read first 5 messages from the first partition of the topic. - # NOTE: The key-partition mapping occurs based on the order in which the data - # is being stored in kafka. Please check kafka_test.sh for the sample data. - - sess.run( - init_op, - feed_dict={topics: ["key-partition-test:0:0:5"], num_epochs: 1}, - ) - for i in range(5): - self.assertEqual( - (("D" + str(i * 2)).encode(), (b"K0")), sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read first 5 messages from the second partition of the topic. - sess.run( - init_op, - feed_dict={topics: ["key-partition-test:1:0:5"], num_epochs: 1}, - ) - for i in range(5): - self.assertEqual( - (("D" + str(i * 2 + 1)).encode(), (b"K1")), sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from different subscriptions to the same topic. - sess.run( - init_op, - feed_dict={ - topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], - num_epochs: 1, - }, - ) - for j in range(2): - for i in range(5): - self.assertEqual( - (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both subscriptions. - sess.run( - init_op, - feed_dict={ - topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], - num_epochs: 10, - }, - ) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual( - (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both subscriptions. 
- sess.run( - init_batch_op, - feed_dict={ - topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], - num_epochs: 10, - batch_size: 5, - }, + except ValueError as e: + assert str( + e + ) == "Invalid stream_timeout value: {} ,set it to -1 to block indefinitely.".format( + STREAM_TIMEOUT + ) + + +def test_kafka_group_io_dataset_stream_timeout_check(): + """Test the functionality of the KafkaGroupIODataset when the + consumer is configured to have a valid stream_timeout value and thus waits + for the new messages from kafka. + NOTE: The default value for message_timeout=5000 + """ + + def write_messages_background(): + # Write new messages to the topic in a background thread + time.sleep(6) + for i in range(100, 200): + message = "D{}".format(i) + kafka_io.write_kafka(message=message, topic="key-partition-test") + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgteststreamvalid", + servers="localhost:9092", + stream_timeout=20000, + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + + # start writing the new messages to kafka using the background job. + # the job sleeps for some time (< stream_timeout) and then writes the + # messages into the topic. + thread = threading.Thread(target=write_messages_background, args=()) + thread.daemon = True + thread.start() + + # At the end, after the timeout has occurred, we must have the old 100 messages + # along with the new 100 messages + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(200)]) + ) + + +def test_kafka_batch_io_dataset(): + """Test the functionality of the KafkaBatchIODataset by training a model + directly on the incoming kafka message batch(of type tf.data.Dataset), in an + online-training fashion. + + NOTE: This kind of dataset is suitable in scenarios where the 'keys' of 'messages' + act as labels. If not, additional transformations are required. 
+ """ + + dataset = tfio.experimental.streaming.KafkaBatchIODataset( + topics=["mini-batch-test"], + group_id="cgminibatch", + servers=None, + stream_timeout=5000, + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + + NUM_COLUMNS = 1 + model = tf.keras.Sequential( + [ + tf.keras.layers.Input(shape=(NUM_COLUMNS,)), + tf.keras.layers.Dense(4, activation="relu"), + tf.keras.layers.Dropout(0.1), + tf.keras.layers.Dense(1, activation="sigmoid"), + ] + ) + model.compile( + optimizer="adam", + loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), + metrics=["accuracy"], + ) + assert issubclass(type(dataset), tf.data.Dataset) + for mini_d in dataset: + mini_d = mini_d.map( + lambda m, k: ( + tf.strings.to_number(m, out_type=tf.float32), + tf.strings.to_number(k, out_type=tf.float32), ) - for _ in range(10): - for j in range(2): - self.assertAllEqual( - [ - [("D" + str(i * 2 + j)).encode() for i in range(5)], - [("K" + str(j)).encode() for i in range(5)], - ], - sess.run(get_next), - ) - - def test_kafka_dataset_with_offset(self): - """Tests for KafkaDataset when reading non-keyed messages - from a single-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, message_offset=True - ).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic offset test: read a limited number of messages from the topic. - sess.run(init_op, feed_dict={topics: ["offset-test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual( - (("D" + str(i)).encode(), ("0:" + str(i)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - -if __name__ == "__main__": - test.main() + ).batch(2) + assert issubclass(type(mini_d), tf.data.Dataset) + # Fits the model as long as the data keeps on streaming + model.fit(mini_d, epochs=5) diff --git a/tests/test_kafka_eager.py b/tests/test_kafka_eager.py deleted file mode 100644 index 06e82b12a..000000000 --- a/tests/test_kafka_eager.py +++ /dev/null @@ -1,515 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. 
-# ============================================================================== -"""Tests for Kafka Output Sequence.""" - - -import time -import pytest -import numpy as np -import threading - -import tensorflow as tf -import tensorflow_io as tfio -from tensorflow_io.kafka.python.ops import ( - kafka_ops, -) # pylint: disable=wrong-import-position -import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position - - -def test_kafka_io_tensor(): - kafka = tfio.IOTensor.from_kafka("test") - assert kafka.dtype == tf.string - assert kafka.shape.as_list() == [None] - assert np.all( - kafka.to_tensor().numpy() == [("D" + str(i)).encode() for i in range(10)] - ) - assert len(kafka.to_tensor()) == 10 - - -@pytest.mark.skip(reason="TODO") -def test_kafka_output_sequence(): - """Test case based on fashion mnist tutorial""" - fashion_mnist = tf.keras.datasets.fashion_mnist - ((train_images, train_labels), (test_images, _)) = fashion_mnist.load_data() - - class_names = [ - "T-shirt/top", - "Trouser", - "Pullover", - "Dress", - "Coat", - "Sandal", - "Shirt", - "Sneaker", - "Bag", - "Ankle boot", - ] - - train_images = train_images / 255.0 - test_images = test_images / 255.0 - - model = tf.keras.Sequential( - [ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(128, activation=tf.nn.relu), - tf.keras.layers.Dense(10, activation=tf.nn.softmax), - ] - ) - - model.compile( - optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] - ) - - model.fit(train_images, train_labels, epochs=5) - - class OutputCallback(tf.keras.callbacks.Callback): - """KafkaOutputCallback""" - - def __init__( - self, batch_size, topic, servers - ): # pylint: disable=super-init-not-called - self._sequence = kafka_ops.KafkaOutputSequence(topic=topic, servers=servers) - self._batch_size = batch_size - - def on_predict_batch_end(self, batch, logs=None): - index = batch * self._batch_size - for outputs in logs["outputs"]: - for output in outputs: - self._sequence.setitem(index, class_names[np.argmax(output)]) - index += 1 - - def flush(self): - self._sequence.flush() - - channel = "e{}e".format(time.time()) - topic = "test_" + channel - - # By default batch size is 32 - output = OutputCallback(32, topic, "localhost") - predictions = model.predict(test_images, callbacks=[output]) - output.flush() - - predictions = [class_names[v] for v in np.argmax(predictions, axis=1)] - - # Reading from `test_e(time)e` we should get the same result - dataset = tfio.kafka.KafkaDataset(topics=[topic], group="test", eof=True) - for entry, prediction in zip(dataset, predictions): - assert entry.numpy() == prediction.encode() - - -def test_avro_kafka_dataset(): - """test_avro_kafka_dataset""" - schema = ( - '{"type":"record","name":"myrecord","fields":[' - '{"name":"f1","type":"string"},' - '{"name":"f2","type":"long"},' - '{"name":"f3","type":["null","string"],"default":null}' - "]}" - ) - dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) - # remove kafka framing - dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) - # deserialize avro - dataset = dataset.map( - lambda e: tfio.experimental.serialization.decode_avro(e, schema=schema) - ) - entries = [(e["f1"], e["f2"], e["f3"]) for e in dataset] - np.all(entries == [("value1", 1, ""), ("value2", 2, ""), ("value3", 3, "")]) - - -def test_avro_kafka_dataset_with_resource(): - """test_avro_kafka_dataset_with_resource""" - schema = ( - '{"type":"record","name":"myrecord","fields":[' - 
'{"name":"f1","type":"string"},' - '{"name":"f2","type":"long"},' - '{"name":"f3","type":["null","string"],"default":null}' - ']}"' - ) - schema_resource = kafka_io.decode_avro_init(schema) - dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) - # remove kafka framing - dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) - # deserialize avro - dataset = dataset.map( - lambda e: kafka_io.decode_avro( - e, schema=schema_resource, dtype=[tf.string, tf.int64, tf.string] - ) - ) - entries = [(f1.numpy(), f2.numpy(), f3.numpy()) for (f1, f2, f3) in dataset] - np.all(entries == [("value1", 1), ("value2", 2), ("value3", 3)]) - - -def test_kafka_stream_dataset(): - dataset = tfio.IODataset.stream().from_kafka("test").batch(2) - assert np.all( - [k.numpy().tolist() for (k, _) in dataset] - == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) - ) - - -def test_kafka_io_dataset(): - dataset = tfio.IODataset.from_kafka( - "test", configuration=["fetch.min.bytes=2"] - ).batch(2) - # repeat multiple times will result in the same result - for _ in range(5): - assert np.all( - [k.numpy().tolist() for (k, _) in dataset] - == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) - ) - - -def test_avro_encode_decode(): - """test_avro_encode_decode""" - schema = ( - '{"type":"record","name":"myrecord","fields":' - '[{"name":"f1","type":"string"},{"name":"f2","type":"long"}]}' - ) - value = [("value1", 1), ("value2", 2), ("value3", 3)] - f1 = tf.cast([v[0] for v in value], tf.string) - f2 = tf.cast([v[1] for v in value], tf.int64) - message = tfio.experimental.serialization.encode_avro([f1, f2], schema=schema) - entries = tfio.experimental.serialization.decode_avro(message, schema=schema) - assert np.all(entries["f1"].numpy() == f1.numpy()) - assert np.all(entries["f2"].numpy() == f2.numpy()) - - -def test_kafka_group_io_dataset_primary_cg(): - """Test the functionality of the KafkaGroupIODataset when the consumer group - is being newly created. - - NOTE: After the kafka cluster is setup during the testing phase, 10 messages - are written to the 'key-partition-test' topic with 5 in each partition - (topic created with 2 partitions, the messages are split based on the keys). - And the same 10 messages are written into the 'key-test' topic (topic created - with 1 partition, so no splitting of the messages based on the keys). - - K0:D0, K1:D1, K0:D2, K1:D3, K0:D4, K1:D5, K0:D6, K1:D7, K0:D8, K1:D9. - - Here, messages D0, D2, D4, D6 and D8 are written into partition 0 and the rest are written - into partition 1. - - Also, since the messages are read from different partitions, the order of retrieval may not be - the same as storage. Thus, we sort and compare. - """ - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10)]) - ) - - -def test_kafka_group_io_dataset_primary_cg_no_lag(): - """Test the functionality of the KafkaGroupIODataset when the - consumer group has read all the messages and committed the offsets. 
- """ - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) - - -def test_kafka_group_io_dataset_primary_cg_new_topic(): - """Test the functionality of the KafkaGroupIODataset when the existing - consumer group reads data from a new topic. - """ - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10)]) - ) - - -def test_kafka_group_io_dataset_resume_primary_cg(): - """Test the functionality of the KafkaGroupIODataset when the - consumer group is yet to catch up with the newly added messages only - (Instead of reading from the beginning). - """ - - # Write new messages to the topic - for i in range(10, 100): - message = "D{}".format(i) - kafka_io.write_kafka(message=message, topic="key-partition-test") - # Read only the newly sent 90 messages - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10, 100)]) - ) - - -def test_kafka_group_io_dataset_resume_primary_cg_new_topic(): - """Test the functionality of the KafkaGroupIODataset when the - consumer group is yet to catch up with the newly added messages only - (Instead of reading from the beginning) from the new topic. - """ - - # Write new messages to the topic - for i in range(10, 100): - message = "D{}".format(i) - kafka_io.write_kafka(message=message, topic="key-test") - # Read only the newly sent 90 messages - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10, 100)]) - ) - - -def test_kafka_group_io_dataset_secondary_cg(): - """Test the functionality of the KafkaGroupIODataset when a - secondary consumer group is created and is yet to catch up all the messages, - from the beginning. - """ - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestsecondary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)]) - ) - - -def test_kafka_group_io_dataset_tertiary_cg_multiple_topics(): - """Test the functionality of the KafkaGroupIODataset when a new - consumer group reads data from multiple topics from the beginning. 
- """ - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test", "key-test"], - group_id="cgtesttertiary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)] * 2) - ) - - -def test_kafka_group_io_dataset_auto_offset_reset(): - """Test the functionality of the `auto.offset.reset` configuration - at global and topic level""" - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgglobaloffsetearliest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)]) - ) - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgglobaloffsetlatest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=latest", - ], - ) - assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtopicoffsetearliest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "conf.topic.auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)]) - ) - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtopicoffsetlatest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "conf.topic.auto.offset.reset=latest", - ], - ) - assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) - - -def test_kafka_group_io_dataset_invalid_stream_timeout(): - """Test the functionality of the KafkaGroupIODataset when the - consumer is configured to have an invalid stream_timeout value which is - less than the message_timeout value. - NOTE: The default value for message_timeout=5000 - """ - - STREAM_TIMEOUT = -20 - try: - tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test", "key-test"], - group_id="cgteststreaminvalid", - servers="localhost:9092", - stream_timeout=STREAM_TIMEOUT, - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - except ValueError as e: - assert str( - e - ) == "Invalid stream_timeout value: {} ,set it to -1 to block indefinitely.".format( - STREAM_TIMEOUT - ) - - -def test_kafka_group_io_dataset_stream_timeout_check(): - """Test the functionality of the KafkaGroupIODataset when the - consumer is configured to have a valid stream_timeout value and thus waits - for the new messages from kafka. 
- NOTE: The default value for message_timeout=5000 - """ - - def write_messages_background(): - # Write new messages to the topic in a background thread - time.sleep(6) - for i in range(100, 200): - message = "D{}".format(i) - kafka_io.write_kafka(message=message, topic="key-partition-test") - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgteststreamvalid", - servers="localhost:9092", - stream_timeout=20000, - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - - # start writing the new messages to kafka using the background job. - # the job sleeps for some time (< stream_timeout) and then writes the - # messages into the topic. - thread = threading.Thread(target=write_messages_background, args=()) - thread.daemon = True - thread.start() - - # At the end, after the timeout has occurred, we must have the old 100 messages - # along with the new 100 messages - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(200)]) - ) - - -def test_kafka_batch_io_dataset(): - """Test the functionality of the KafkaBatchIODataset by training a model - directly on the incoming kafka message batch(of type tf.data.Dataset), in an - online-training fashion. - - NOTE: This kind of dataset is suitable in scenarios where the 'keys' of 'messages' - act as labels. If not, additional transformations are required. - """ - - dataset = tfio.experimental.streaming.KafkaBatchIODataset( - topics=["mini-batch-test"], - group_id="cgminibatch", - servers=None, - stream_timeout=5000, - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - - NUM_COLUMNS = 1 - model = tf.keras.Sequential( - [ - tf.keras.layers.Input(shape=(NUM_COLUMNS,)), - tf.keras.layers.Dense(4, activation="relu"), - tf.keras.layers.Dropout(0.1), - tf.keras.layers.Dense(1, activation="sigmoid"), - ] - ) - model.compile( - optimizer="adam", - loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), - metrics=["accuracy"], - ) - assert issubclass(type(dataset), tf.data.Dataset) - for mini_d in dataset: - mini_d = mini_d.map( - lambda m, k: ( - tf.strings.to_number(m, out_type=tf.float32), - tf.strings.to_number(k, out_type=tf.float32), - ) - ).batch(2) - assert issubclass(type(mini_d), tf.data.Dataset) - # Fits the model as long as the data keeps on streaming - model.fit(mini_d, epochs=5) diff --git a/tests/test_kafka_v1.py b/tests/test_kafka_v1.py new file mode 100644 index 000000000..8c539473c --- /dev/null +++ b/tests/test_kafka_v1.py @@ -0,0 +1,511 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== +"""Tests for KafkaDataset.""" + + +import time +import pytest + +import tensorflow as tf + +tf.compat.v1.disable_eager_execution() + +from tensorflow import dtypes # pylint: disable=wrong-import-position +from tensorflow import errors # pylint: disable=wrong-import-position +from tensorflow import test # pylint: disable=wrong-import-position +from tensorflow.compat.v1 import data # pylint: disable=wrong-import-position + +import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position + + +class KafkaDatasetTest(test.TestCase): + """Tests for KafkaDataset.""" + + # The Kafka server has to be setup before the test + # and tear down after the test manually. + # The docker engine has to be installed. + # + # To setup the Kafka server: + # $ bash kafka_test.sh start kafka + # + # To tear down the Kafka server: + # $ bash kafka_test.sh stop kafka + + def test_kafka_dataset(self): + """Tests for KafkaDataset when reading non-keyed messages + from a single-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset(topics, group="test", eof=True).repeat( + num_epochs + ) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic test: read a limited number of messages from the topic. + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read all the messages from the topic from offset 5. + sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual(("D" + str(i + 5)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from different subscriptions of the same topic. + sess.run( + init_op, + feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 1}, + ) + for j in range(2): + for i in range(5): + self.assertEqual( + ("D" + str(i + j * 5)).encode(), sess.run(get_next) + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both subscriptions. + sess.run( + init_op, + feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 10}, + ) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual( + ("D" + str(i + j * 5)).encode(), sess.run(get_next) + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both subscriptions. 
+ sess.run( + init_batch_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10, + batch_size: 5, + }, + ) + for _ in range(10): + self.assertAllEqual( + [("D" + str(i)).encode() for i in range(5)], sess.run(get_next) + ) + self.assertAllEqual( + [("D" + str(i + 5)).encode() for i in range(5)], sess.run(get_next) + ) + + @pytest.mark.skip(reason="TODO") + def test_kafka_dataset_save_and_restore(self): + """Tests for KafkaDataset save and restore.""" + g = tf.Graph() + with g.as_default(): + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True + ).repeat(num_epochs) + iterator = repeat_dataset.make_initializable_iterator() + get_next = iterator.get_next() + + it = tf.data.experimental.make_saveable_from_iterator(iterator) + g.add_to_collection(tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, it) + saver = tf.compat.v1.train.Saver() + + model_file = "/tmp/test-kafka-model" + with self.cached_session() as sess: + sess.run( + iterator.initializer, + feed_dict={topics: ["test:0:0:4"], num_epochs: 1}, + ) + for i in range(3): + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + # Save current offset which is 2 + saver.save(sess, model_file, global_step=3) + + checkpoint_file = "/tmp/test-kafka-model-3" + with self.cached_session() as sess: + saver.restore(sess, checkpoint_file) + # Restore current offset to 2 + for i in [2, 3]: + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + + def test_kafka_topic_configuration(self): + """Tests for KafkaDataset topic configuration properties.""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + cfg_list = ["auto.offset.reset=earliest"] + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, config_topic=cfg_list + ).repeat(num_epochs) + + iterator = data.Iterator.from_structure(repeat_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Use a wrong offset 100 here to make sure + # configuration 'auto.offset.reset=earliest' works. 
+            sess.run(init_op, feed_dict={topics: ["test:0:100:-1"], num_epochs: 1})
+            for i in range(5):
+                self.assertEqual(("D" + str(i)).encode(), sess.run(get_next))
+
+    def test_kafka_global_configuration(self):
+        """Tests for KafkaDataset global configuration properties."""
+        topics = tf.compat.v1.placeholder(dtypes.string, shape=[None])
+        num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[])
+        cfg_list = ["debug=generic", "enable.auto.commit=false"]
+
+        repeat_dataset = kafka_io.KafkaDataset(
+            topics, group="test", eof=True, config_global=cfg_list
+        ).repeat(num_epochs)
+
+        iterator = data.Iterator.from_structure(repeat_dataset.output_types)
+        init_op = iterator.make_initializer(repeat_dataset)
+        get_next = iterator.get_next()
+
+        with self.cached_session() as sess:
+            sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
+            for i in range(5):
+                self.assertEqual(("D" + str(i)).encode(), sess.run(get_next))
+            with self.assertRaises(errors.OutOfRangeError):
+                sess.run(get_next)
+
+    def test_kafka_wrong_global_configuration_failed(self):
+        """Tests for KafkaDataset wrong global configuration properties."""
+        topics = tf.compat.v1.placeholder(dtypes.string, shape=[None])
+        num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[])
+
+        # Add wrong configuration
+        wrong_cfg = ["debug=al"]
+        repeat_dataset = kafka_io.KafkaDataset(
+            topics, group="test", eof=True, config_global=wrong_cfg
+        ).repeat(num_epochs)
+
+        iterator = data.Iterator.from_structure(repeat_dataset.output_types)
+        init_op = iterator.make_initializer(repeat_dataset)
+        get_next = iterator.get_next()
+
+        with self.cached_session() as sess:
+            sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
+            with self.assertRaises(errors.InternalError):
+                sess.run(get_next)
+
+    def test_kafka_wrong_topic_configuration_failed(self):
+        """Tests for KafkaDataset wrong topic configuration properties."""
+        topics = tf.compat.v1.placeholder(dtypes.string, shape=[None])
+        num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[])
+
+        # Add wrong configuration
+        wrong_cfg = ["auto.offset.reset=arliest"]
+        repeat_dataset = kafka_io.KafkaDataset(
+            topics, group="test", eof=True, config_topic=wrong_cfg
+        ).repeat(num_epochs)
+
+        iterator = data.Iterator.from_structure(repeat_dataset.output_types)
+        init_op = iterator.make_initializer(repeat_dataset)
+        get_next = iterator.get_next()
+
+        with self.cached_session() as sess:
+            sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
+            with self.assertRaises(errors.InternalError):
+                sess.run(get_next)
+
+    def test_write_kafka(self):
+        """test_write_kafka"""
+        channel = "e{}e".format(time.time())
+
+        # Start with reading the test topic, replace `D` with `e(time)e`,
+        # and write to the `test_e(time)e` topic.
+        dataset = kafka_io.KafkaDataset(topics=["test:0:0:4"], group="test", eof=True)
+        dataset = dataset.map(
+            lambda x: kafka_io.write_kafka(
+                tf.strings.regex_replace(x, "D", channel), topic="test_" + channel
+            )
+        )
+        iterator = dataset.make_initializable_iterator()
+        init_op = iterator.initializer
+        get_next = iterator.get_next()
+
+        with self.cached_session() as sess:
+            # Basic test: read from topic 0.
+ sess.run(init_op) + for i in range(5): + self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Reading from `test_e(time)e` we should get the same result + dataset = kafka_io.KafkaDataset( + topics=["test_" + channel], group="test", eof=True + ) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for i in range(5): + self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_kafka_dataset_with_key(self): + """Tests for KafkaDataset when reading keyed-messages + from a single-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, message_key=True + ).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic test: read a limited number of keyed messages from the topic. + sess.run(init_op, feed_dict={topics: ["key-test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual( + (("D" + str(i)).encode(), ("K" + str(i % 2)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read all the keyed messages from the topic from offset 5. + sess.run(init_op, feed_dict={topics: ["key-test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual( + (("D" + str(i + 5)).encode(), ("K" + str((i + 5) % 2)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from different subscriptions of the same topic. + sess.run( + init_op, + feed_dict={ + topics: ["key-test:0:0:4", "key-test:0:5:-1"], + num_epochs: 1, + }, + ) + for j in range(2): + for i in range(5): + self.assertEqual( + ( + ("D" + str(i + j * 5)).encode(), + ("K" + str((i + j * 5) % 2)).encode(), + ), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both subscriptions. + sess.run( + init_op, + feed_dict={ + topics: ["key-test:0:0:4", "key-test:0:5:-1"], + num_epochs: 10, + }, + ) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual( + ( + ("D" + str(i + j * 5)).encode(), + ("K" + str((i + j * 5) % 2)).encode(), + ), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both subscriptions. 
+ sess.run( + init_batch_op, + feed_dict={ + topics: ["key-test:0:0:4", "key-test:0:5:-1"], + num_epochs: 10, + batch_size: 5, + }, + ) + for _ in range(10): + self.assertAllEqual( + [ + [("D" + str(i)).encode() for i in range(5)], + [("K" + str(i % 2)).encode() for i in range(5)], + ], + sess.run(get_next), + ) + self.assertAllEqual( + [ + [("D" + str(i + 5)).encode() for i in range(5)], + [("K" + str((i + 5) % 2)).encode() for i in range(5)], + ], + sess.run(get_next), + ) + + def test_kafka_dataset_with_partitioned_key(self): + """Tests for KafkaDataset when reading keyed-messages + from a multi-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, message_key=True + ).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic test: read first 5 messages from the first partition of the topic. + # NOTE: The key-partition mapping occurs based on the order in which the data + # is being stored in kafka. Please check kafka_test.sh for the sample data. + + sess.run( + init_op, + feed_dict={topics: ["key-partition-test:0:0:5"], num_epochs: 1}, + ) + for i in range(5): + self.assertEqual( + (("D" + str(i * 2)).encode(), (b"K0")), sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read first 5 messages from the second partition of the topic. + sess.run( + init_op, + feed_dict={topics: ["key-partition-test:1:0:5"], num_epochs: 1}, + ) + for i in range(5): + self.assertEqual( + (("D" + str(i * 2 + 1)).encode(), (b"K1")), sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from different subscriptions to the same topic. + sess.run( + init_op, + feed_dict={ + topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], + num_epochs: 1, + }, + ) + for j in range(2): + for i in range(5): + self.assertEqual( + (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both subscriptions. + sess.run( + init_op, + feed_dict={ + topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], + num_epochs: 10, + }, + ) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual( + (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both subscriptions. 
+ sess.run( + init_batch_op, + feed_dict={ + topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], + num_epochs: 10, + batch_size: 5, + }, + ) + for _ in range(10): + for j in range(2): + self.assertAllEqual( + [ + [("D" + str(i * 2 + j)).encode() for i in range(5)], + [("K" + str(j)).encode() for i in range(5)], + ], + sess.run(get_next), + ) + + def test_kafka_dataset_with_offset(self): + """Tests for KafkaDataset when reading non-keyed messages + from a single-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, message_offset=True + ).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic offset test: read a limited number of messages from the topic. + sess.run(init_op, feed_dict={topics: ["offset-test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual( + (("D" + str(i)).encode(), ("0:" + str(i)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tests/test_libsvm_eager.py b/tests/test_libsvm.py similarity index 100% rename from tests/test_libsvm_eager.py rename to tests/test_libsvm.py diff --git a/tests/test_lmdb_eager.py b/tests/test_lmdb.py similarity index 100% rename from tests/test_lmdb_eager.py rename to tests/test_lmdb.py diff --git a/tests/test_mongodb_eager.py b/tests/test_mongodb.py similarity index 100% rename from tests/test_mongodb_eager.py rename to tests/test_mongodb.py diff --git a/tests/test_parquet_eager.py b/tests/test_parquet.py similarity index 100% rename from tests/test_parquet_eager.py rename to tests/test_parquet.py diff --git a/tests/test_parse_avro_eager.py b/tests/test_parse_avro.py similarity index 100% rename from tests/test_parse_avro_eager.py rename to tests/test_parse_avro.py diff --git a/tests/test_pcap_eager.py b/tests/test_pcap.py similarity index 100% rename from tests/test_pcap_eager.py rename to tests/test_pcap.py diff --git a/tests/test_pulsar_eager.py b/tests/test_pulsar.py similarity index 100% rename from tests/test_pulsar_eager.py rename to tests/test_pulsar.py diff --git a/tests/test_s3_eager.py b/tests/test_s3.py similarity index 100% rename from tests/test_s3_eager.py rename to tests/test_s3.py diff --git a/tests/test_serialization_eager.py b/tests/test_serialization.py similarity index 100% rename from tests/test_serialization_eager.py rename to tests/test_serialization.py diff --git a/tests/test_text_eager.py b/tests/test_text.py similarity index 100% rename from tests/test_text_eager.py rename to tests/test_text.py diff --git a/tests/test_version_eager.py b/tests/test_version.py similarity index 100% rename from tests/test_version_eager.py rename to tests/test_version.py diff --git a/tests/test_video_eager.py b/tests/test_video.py similarity index 100% rename from tests/test_video_eager.py rename to tests/test_video.py
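Note (reviewer aid, not part of the patch): the snippet below is a minimal sketch of the graph-mode consumption pattern that the new tests/test_kafka_v1.py exercises. It assumes a local broker seeded with the "test" topic via the kafka_test.sh helper referenced in the test's header comments; everything else mirrors the test code itself rather than any additional API.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

import tensorflow_io.kafka as kafka_io

# Placeholder-driven pipeline, following KafkaDatasetTest.test_kafka_dataset.
topics = tf.compat.v1.placeholder(tf.dtypes.string, shape=[None])
dataset = kafka_io.KafkaDataset(topics, group="test", eof=True)

iterator = tf.compat.v1.data.Iterator.from_structure(dataset.output_types)
init_op = iterator.make_initializer(dataset)
get_next = iterator.get_next()

with tf.compat.v1.Session() as sess:
    # "test:0:0:4" selects topic "test", partition 0, offsets 0 through 4.
    sess.run(init_op, feed_dict={topics: ["test:0:0:4"]})
    while True:
        try:
            print(sess.run(get_next))
        except tf.errors.OutOfRangeError:
            break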