From bbba3f5e8fce0a1db133cc2b0baa5ad38c448273 Mon Sep 17 00:00:00 2001 From: Joe Mancuso Date: Wed, 30 Sep 2020 23:00:05 -0400 Subject: [PATCH] fixed models not loading their primary keys when updating --- src/masoniteorm/query/QueryBuilder.py | 10 ++++-- tests/mysql/builder/test_query_builder.py | 2 +- .../builder/test_postgres_query_builder.py | 2 +- .../builder/test_sqlite_builder_insert.py | 2 +- tests/sqlite/models/test_sqlite_model.py | 35 +++++++++++++++++++ 5 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 tests/sqlite/models/test_sqlite_model.py diff --git a/src/masoniteorm/query/QueryBuilder.py b/src/masoniteorm/query/QueryBuilder.py index 9c0a484f..c5a46bc7 100644 --- a/src/masoniteorm/query/QueryBuilder.py +++ b/src/masoniteorm/query/QueryBuilder.py @@ -47,6 +47,7 @@ def __init__( """ self.grammar = grammar self._table = table + self.dry = dry self.connection = connection self._connection = None self._connection_details = connection_details @@ -786,9 +787,14 @@ def update(self, updates: dict, dry=False): Returns: self """ + if self._model and self._model.is_loaded(): + self.where( + self._model.get_primary_key(), self._model.get_primary_key_value() + ) + self._updates += (UpdateQueryExpression(updates),) self.set_action("update") - if dry: + if dry or self.dry: return self return self.new_connection().query(self.to_sql(), self._bindings) @@ -960,7 +966,7 @@ def _register_relationships_to_model( hydrated_model.add_relation({relation_key: related_result or {}}) return self - + def find(self, record_id): """Finds a row by the primary key ID. Requires a model diff --git a/tests/mysql/builder/test_query_builder.py b/tests/mysql/builder/test_query_builder.py index 373998fd..3bf1814a 100644 --- a/tests/mysql/builder/test_query_builder.py +++ b/tests/mysql/builder/test_query_builder.py @@ -27,7 +27,7 @@ def get_builder(self, table="users"): grammar=self.grammar, connection=connection, table=table, - model=User, + model=User(), connection_details=DATABASES, ) diff --git a/tests/postgres/builder/test_postgres_query_builder.py b/tests/postgres/builder/test_postgres_query_builder.py index e1a55092..6336daa4 100644 --- a/tests/postgres/builder/test_postgres_query_builder.py +++ b/tests/postgres/builder/test_postgres_query_builder.py @@ -23,7 +23,7 @@ def get_default_query_grammar(cls): class BaseTestQueryBuilder: def get_builder(self, table="users"): connection = MockConnectionFactory().make("default") - return QueryBuilder(self.grammar, connection, table=table, model=Model) + return QueryBuilder(self.grammar, connection, table=table, model=Model()) def test_sum(self): builder = self.get_builder() diff --git a/tests/sqlite/builder/test_sqlite_builder_insert.py b/tests/sqlite/builder/test_sqlite_builder_insert.py index 1d45fc68..348f1af9 100644 --- a/tests/sqlite/builder/test_sqlite_builder_insert.py +++ b/tests/sqlite/builder/test_sqlite_builder_insert.py @@ -27,7 +27,7 @@ def get_builder(self, table="users"): connection=connection, table=table, # model=User, - connection_details={}, + connection_details=DATABASES, ).on("sqlite") def test_insert(self): diff --git a/tests/sqlite/models/test_sqlite_model.py b/tests/sqlite/models/test_sqlite_model.py new file mode 100644 index 00000000..4913e6f2 --- /dev/null +++ b/tests/sqlite/models/test_sqlite_model.py @@ -0,0 +1,35 @@ +import inspect +import unittest + +from src.masoniteorm.query import QueryBuilder +from src.masoniteorm.query.grammars import SQLiteGrammar +from src.masoniteorm.connections import ConnectionFactory +from src.masoniteorm.relationships import belongs_to +from src.masoniteorm.models import Model +from tests.utils import MockConnectionFactory +from config.database import DATABASES + + +class User(Model): + __connection__ = "sqlite" + __timestamps__ = False + __dry__ = True + + +class BaseTestQueryRelationships(unittest.TestCase): + + maxDiff = None + + def test_update_specific_record(self): + user = User.first() + sql = user.update({"name": "joe"}).to_sql() + + self.assertEqual( + sql, + """UPDATE "users" SET "name" = 'joe' WHERE "id" = '{}'""".format(user.id), + ) + + def test_update_all_records(self): + sql = User.update({"name": "joe"}).to_sql() + + self.assertEqual(sql, """UPDATE "users" SET "name" = 'joe'""")