Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/pyspark/sql/readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,9 @@ def insertInto(self, tableName, overwrite=False):

Optionally overwriting any existing data.
"""
self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
# Deliberately no "append" fallback here: unlike the previous
# `mode("overwrite" if overwrite else "append")` one-liner, leaving the mode
# untouched when overwrite=False respects any mode the caller configured
# earlier via .mode(...).
if overwrite:
    self._jwrite.mode("overwrite")
self._jwrite.insertInto(tableName)

@since(1.4)
def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options):
Expand Down
19 changes: 19 additions & 0 deletions python/pyspark/sql/tests/test_readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,25 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
.mode("overwrite").saveAsTable("pyspark_bucket"))
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

def test_insert_into(self):
    """insertInto appends by default, honors a previously-set writer mode,
    and honors the explicit ``overwrite`` argument.

    The table is dropped in a ``finally`` block so a failing assertion
    cannot leak ``test_table`` into the shared warehouse and poison
    later test runs.
    """
    df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
    df.write.saveAsTable("test_table")
    try:
        # saveAsTable wrote the two seed rows.
        self.assertEqual(2, self.spark.sql("select * from test_table").count())

        # Default behavior: append.
        df.write.insertInto("test_table")
        self.assertEqual(4, self.spark.sql("select * from test_table").count())

        # A mode set on the writer is respected by insertInto.
        df.write.mode("overwrite").insertInto("test_table")
        self.assertEqual(2, self.spark.sql("select * from test_table").count())

        # Explicit overwrite=True replaces existing rows.
        df.write.insertInto("test_table", True)
        self.assertEqual(2, self.spark.sql("select * from test_table").count())

        # Explicit overwrite=False appends.
        df.write.insertInto("test_table", False)
        self.assertEqual(4, self.spark.sql("select * from test_table").count())
    finally:
        # Clean up even when an assertion above fails.
        self.spark.sql("drop table test_table")


if __name__ == "__main__":
import unittest
Expand Down