diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index aa5bf635d187..842cd951fe5e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -757,7 +757,7 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): self._jwrite.save(path) @since(1.4) - def insertInto(self, tableName, overwrite=False): + def insertInto(self, tableName, overwrite=None): """Inserts the content of the :class:`DataFrame` to the specified table. It requires that the schema of the class:`DataFrame` is the same as the @@ -765,7 +765,9 @@ def insertInto(self, tableName, overwrite=False): Optionally overwriting any existing data. """ - self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) + if overwrite is not None: + self.mode("overwrite" if overwrite else "append") + self._jwrite.insertInto(tableName) @since(1.4) def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options): diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index a70807248960..2530cc2ebf22 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -141,6 +141,27 @@ def count_bucketed_cols(names, table="pyspark_bucket"): .mode("overwrite").saveAsTable("pyspark_bucket")) self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + def test_insert_into(self): + df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"]) + with self.table("test_table"): + df.write.saveAsTable("test_table") + self.assertEqual(2, self.spark.sql("select * from test_table").count()) + + df.write.insertInto("test_table") + self.assertEqual(4, self.spark.sql("select * from test_table").count()) + + df.write.mode("overwrite").insertInto("test_table") + self.assertEqual(2, self.spark.sql("select * from test_table").count()) + + df.write.insertInto("test_table", True) + self.assertEqual(2, self.spark.sql("select * from test_table").count()) + + df.write.insertInto("test_table", False) + self.assertEqual(4, self.spark.sql("select * from test_table").count()) + + df.write.mode("overwrite").insertInto("test_table", False) + self.assertEqual(6, self.spark.sql("select * from test_table").count()) + if __name__ == "__main__": import unittest