diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a3ce87096e790..65b902cf3c4d5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2219,6 +2219,20 @@ def semanticHash(self):
         """
         return self._jdf.semanticHash()
 
+    @since(3.1)
+    def inputFiles(self):
+        """
+        Returns a best-effort snapshot of the files that compose this :class:`DataFrame`.
+        This method simply asks each constituent BaseRelation for its respective files and
+        takes the union of all results. Depending on the source relations, this may not find
+        all input files. Duplicates are removed.
+
+        >>> df = spark.read.load("examples/src/main/resources/people.json", format="json")
+        >>> len(df.inputFiles())
+        1
+        """
+        return list(self._jdf.inputFiles())
+
     where = copy_func(
         filter,
         sinceversion=1.3,
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 9861178158f85..062e61663a332 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -17,6 +17,8 @@
 
 import os
 import pydoc
+import shutil
+import tempfile
 import time
 import unittest
 
@@ -820,6 +822,22 @@ def test_same_semantics_error(self):
         with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"):
             self.spark.range(10).sameSemantics(1)
 
+    def test_input_files(self):
+        tpath = tempfile.mkdtemp()
+        shutil.rmtree(tpath)
+        try:
+            self.spark.range(1, 100, 1, 10).write.parquet(tpath)
+            # read parquet file and get the input files list
+            input_files_list = self.spark.read.parquet(tpath).inputFiles()
+
+            # input files list should contain 10 entries
+            self.assertEquals(len(input_files_list), 10)
+            # all file paths in list must contain tpath
+            for file_path in input_files_list:
+                self.assertTrue(tpath in file_path)
+        finally:
+            shutil.rmtree(tpath)
+
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
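
For context, a minimal end-to-end sketch of the new DataFrame.inputFiles() API. This is not part of the patch: the session bootstrap, app name, and temp-directory handling are illustrative assumptions that mirror the test above, and it presumes a local Spark build containing this change (3.1+).

import shutil
import tempfile

from pyspark.sql import SparkSession

# Illustrative local session; any SparkSession with this patch works.
spark = SparkSession.builder.master("local[2]").appName("inputFiles-demo").getOrCreate()

# write.parquet() refuses to write into an existing directory, so create a
# unique temp path and remove it before writing (same trick as the test).
path = tempfile.mkdtemp()
shutil.rmtree(path)

# 10 partitions, so the writer emits 10 parquet part files.
spark.range(1, 100, 1, 10).write.parquet(path)

# inputFiles() returns a deduplicated, best-effort list of the files backing
# the DataFrame; for this freshly written dataset, one entry per part file.
files = spark.read.parquet(path).inputFiles()
print(len(files))  # expected: 10

shutil.rmtree(path)
spark.stop()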