diff --git a/docs/manager.md b/docs/manager.md index f5dfe66b..bf9250b1 100644 --- a/docs/manager.md +++ b/docs/manager.md @@ -209,7 +209,7 @@ obj = ( ) ``` -`bulk_insert` uses a single query to insert all specified rows at once. +`bulk_insert` uses a single query to insert all specified rows at once. It returns a `list` of `dict()` with each `dict()` being a merge of the `dict()` passed in along with any index returned from Postgres. #### Limitations In order to stick to the "everything in one query" principle, various, more advanced usages of `bulk_insert` are impossible. It is not possible to have different rows specify different amounts of columns. The following example does **not work**: diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index e4f87ffa..aeeb8dd2 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -48,7 +48,7 @@ def execute_sql(self, return_id=False): rows = [] for sql, params in self.as_sql(return_id): cursor.execute(sql, params) - rows.append(cursor.fetchone()) + rows.extend(cursor.fetchall()) # create a mapping between column names and column value return [ diff --git a/psqlextra/manager/manager.py b/psqlextra/manager/manager.py index aa6f8b19..c88462a1 100644 --- a/psqlextra/manager/manager.py +++ b/psqlextra/manager/manager.py @@ -139,7 +139,7 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action, index_predic return self - def bulk_insert(self, rows): + def bulk_insert(self, rows, return_model=False): """Creates multiple new records in the database. This allows specifying custom conflict behavior using .on_conflict(). @@ -150,16 +150,25 @@ def bulk_insert(self, rows): An array of dictionaries, where each dictionary describes the fields to insert. + return_model (default: False): + If model instances should be returned rather than + just dicts. + Returns: + A list of either the dicts of the rows inserted, including the pk or + the models of the rows inserted with defaults for any fields not specified """ if self.conflict_target or self.conflict_action: compiler = self._build_insert_compiler(rows) - compiler.execute_sql(return_id=True) - return + objs = compiler.execute_sql(return_id=True) + if return_model: + return [self.model(**dict(r, **k)) for r, k in zip(rows, objs)] + else: + return [dict(r, **k) for r, k in zip(rows, objs)] # no special action required, use the standard Django bulk_create(..) - super().bulk_create([self.model(**fields) for fields in rows]) + return super().bulk_create([self.model(**fields) for fields in rows]) def insert(self, **fields): """Creates a new record in the database. diff --git a/tests/test_on_conflict.py b/tests/test_on_conflict.py index 3111dd9a..314a42d4 100644 --- a/tests/test_on_conflict.py +++ b/tests/test_on_conflict.py @@ -413,3 +413,37 @@ def test_on_conflict_bulk(): for index, obj in enumerate(list(model.objects.all())): assert obj.title == rows[index]['title'] + +def test_bulk_return(): + """Tests if primary keys are properly returned from 'bulk_insert' + """ + + model = get_fake_model({ + 'id': models.BigAutoField(primary_key=True), + 'name': models.CharField(max_length=255, unique=True) + }) + + rows = [ + dict(name='John Smith'), + dict(name='Jane Doe') + ] + + objs = ( + model.objects + .on_conflict(['name'], ConflictAction.UPDATE) + .bulk_insert(rows) + ) + + for index, obj in enumerate(objs, 1): + assert obj['id'] == index + + """Add objects again, update should return the same ids + as we're just updating.""" + objs = ( + model.objects + .on_conflict(['name'], ConflictAction.UPDATE) + .bulk_insert(rows) + ) + + for index, obj in enumerate(objs, 1): + assert obj['id'] == index