Skip to content

Commit dc95f9b

Browse files
Merge pull request #239 from MasoniteFramework/feature/233
Feature/233
2 parents 8277489 + 6c0a702 commit dc95f9b

File tree

11 files changed

+181
-40
lines changed

11 files changed

+181
-40
lines changed

cc.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
from src.masoniteorm.connections import MySQLConnection, PostgresConnection
33
from src.masoniteorm.query.grammars import MySQLGrammar, PostgresGrammar
44
from src.masoniteorm.models import Model
5+
from src.masoniteorm.relationships import has_many
6+
import inspect
57

68

7-
builder = QueryBuilder(connection=PostgresConnection, grammar=PostgresGrammar).table("users").on("postgres")
9+
# builder = QueryBuilder(connection=PostgresConnection, grammar=PostgresGrammar).table("users").on("postgres")
810

911

1012

@@ -14,8 +16,16 @@ class User(Model):
1416
__connection__ = "sqlite"
1517
__table__ = "users"
1618

17-
user = User.create({"name": "phill", "email": "phill"})
18-
print(User.get().count())
19+
@has_many("id", "user_id")
20+
def articles(self):
21+
return Article
22+
class Article(Model):
23+
__connection__ = "sqlite"
24+
25+
26+
# user = User.create({"name": "phill", "email": "phill"})
27+
# print(inspect.isclass(User))
28+
print(User.find(1).with_("articles").first().serialize())
1929

20-
print(user.serialize())
30+
# print(user.serialize())
2131
# print(User.first())

queries.log

Whitespace-only changes.

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Versions should comply with PEP440. For a discussion on single-sourcing
77
# the version across setup.py and the project code, see
88
# https://packaging.python.org/en/latest/single_source_version.html
9-
version='0.8.0b',
9+
version='0.8.0b1',
1010
package_dir={'': 'src'},
1111

1212
description='The Official Masonite ORM',

src/masoniteorm/models/Model.py

+4
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,10 @@ def related(self, relation):
611611
related = getattr(self.__class__, relation)
612612
return related.where(related.foreign_key, self.get_primary_key_value())
613613

614+
def get_related(self, relation):
615+
related = getattr(self.__class__, relation)
616+
return related
617+
614618
def attach(self, relation, related_record):
615619
related = getattr(self.__class__, relation)
616620
setattr(
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
class EagerRelations:
2+
def __init__(self, relation=None):
3+
self.eagers = []
4+
self.nested_eagers = {}
5+
self.is_nested = False
6+
self.relation = relation
7+
8+
def register(self, *relations):
9+
for relation in relations:
10+
if isinstance(relation, str) and "." not in relation:
11+
self.eagers += [relation]
12+
elif isinstance(relation, str) and "." in relation:
13+
self.is_nested = True
14+
relation_key = relation.split(".")[0]
15+
if relation_key not in self.nested_eagers:
16+
self.nested_eagers = {relation_key: relation.split(".")[1:]}
17+
else:
18+
self.nested_eagers[relation_key] += relation.split(".")[1:]
19+
elif isinstance(relation, tuple):
20+
for eagers in relations:
21+
for eager in eagers:
22+
self.register(eager)
23+
elif isinstance(relation, list):
24+
for eagers in relations:
25+
for eager in eagers:
26+
self.register(eager)
27+
28+
return self
29+
30+
def get_eagers(self):
31+
eagers = []
32+
if self.eagers:
33+
eagers.append(self.eagers)
34+
35+
if self.nested_eagers:
36+
eagers.append(self.nested_eagers)
37+
38+
return eagers

src/masoniteorm/query/QueryBuilder.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
from ..pagination import LengthAwarePaginator, SimplePaginator
2626

27+
from .EagerRelation import EagerRelations
28+
2729

2830
class QueryBuilder(ObservesEvents):
2931
"""A builder class to manage the building and creation of query expressions."""
@@ -56,7 +58,7 @@ def __init__(
5658
self._connection_details = connection_details
5759
self._connection_driver = connection_driver
5860
self._scopes = scopes
59-
self._eager_loads = ()
61+
self._eager_relation = EagerRelations()
6062
if model:
6163
self._global_scopes = model._global_scopes
6264
if model.__with__:
@@ -1095,27 +1097,39 @@ def prepare_result(self, result, collection=False):
10951097
if self._model:
10961098
# eager load here
10971099
hydrated_model = self._model.hydrate(result)
1098-
if self._eager_loads and hydrated_model:
1099-
for eager in self._eager_loads:
1100-
if "." in eager:
1101-
last_owner = self._model
1102-
last_result = hydrated_model
1103-
for eager in eager.split("."):
1104-
related = getattr(last_owner, eager)
1105-
result_set = related.get_related(last_result)
1100+
if self._eager_relation.eagers and hydrated_model:
1101+
for eager_load in self._eager_relation.get_eagers():
1102+
if isinstance(eager_load, dict):
1103+
# Nested
1104+
for relation, eagers in eager_load.items():
1105+
if inspect.isclass(self._model):
1106+
related = getattr(self._model, relation)
1107+
else:
1108+
related = self._model.get_related(relation)
1109+
1110+
result_set = related.get_related(
1111+
hydrated_model, eagers=eagers
1112+
)
11061113

11071114
self._register_relationships_to_model(
1108-
related, result_set, last_result, relation_key=eager
1115+
related,
1116+
result_set,
1117+
hydrated_model,
1118+
relation_key=relation,
11091119
)
1110-
1111-
last_result = result_set
1112-
last_owner = related.get_builder()._model
11131120
else:
1114-
related = getattr(self._model, eager)
1115-
related_result = related.get_related(hydrated_model)
1116-
self._register_relationships_to_model(
1117-
related, related_result, hydrated_model, relation_key=eager
1118-
)
1121+
# Not Nested
1122+
for eager in eager_load:
1123+
if inspect.isclass(self._model):
1124+
related = getattr(self._model, eager)
1125+
else:
1126+
related = self._model.get_related(eager)
1127+
1128+
result_set = related.get_related(hydrated_model)
1129+
1130+
self._register_relationships_to_model(
1131+
related, result_set, hydrated_model, relation_key=eager
1132+
)
11191133

11201134
if collection:
11211135
return hydrated_model if result else Collection([])
@@ -1149,7 +1163,6 @@ def _register_relationships_to_model(
11491163
else:
11501164
model.add_relation({relation_key: related_result or None})
11511165
else:
1152-
11531166
hydrated_model.add_relation({relation_key: related_result or None})
11541167
return self
11551168

@@ -1192,14 +1205,8 @@ def without_eager(self):
11921205
self._should_eager = False
11931206
return self
11941207

1195-
def with_(self, eagers=(), *others):
1196-
if not isinstance(eagers, (tuple, list)):
1197-
eagers = (eagers,)
1198-
1199-
if others:
1200-
eagers += others
1201-
1202-
self._eager_loads += eagers
1208+
def with_(self, *eagers):
1209+
self._eager_relation.register(eagers)
12031210
return self
12041211

12051212
def paginate(self, per_page, page=1):

src/masoniteorm/relationships/BelongsTo.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def apply_query(self, foreign, owner):
1919
self.foreign_key, owner.__attributes__[self.local_key]
2020
).first()
2121

22-
def get_related(self, relation):
22+
def get_related(self, relation, eagers=()):
2323
"""Gets the relation needed between the relation and the related builder. If the relation is a collection
2424
then will need to pluck out all the keys from the collection and fetch from the related builder. If
2525
relation is just a Model then we can just call the model based on the value of the related
@@ -31,7 +31,7 @@ def get_related(self, relation):
3131
Returns:
3232
Model|Collection
3333
"""
34-
builder = self.get_builder()
34+
builder = self.get_builder().with_(eagers)
3535
if isinstance(relation, Collection):
3636
return builder.where_in(
3737
f"{builder.get_table_name()}.{self.foreign_key}",

src/masoniteorm/relationships/HasMany.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def apply_query(self, foreign, owner):
2121

2222
return result
2323

24-
def get_related(self, relation):
25-
builder = self.get_builder()
24+
def get_related(self, relation, eagers=[]):
25+
builder = self.get_builder().with_(eagers)
2626
if isinstance(relation, Collection):
2727
return builder.where_in(
2828
f"{builder.get_table_name()}.{self.foreign_key}",

tests/eagers/test_eager.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import os
2+
import unittest
3+
4+
from src.masoniteorm.query.EagerRelation import EagerRelations
5+
6+
7+
class TestEagerRelation(unittest.TestCase):
8+
def test_can_register_string_eager_load(self):
9+
10+
self.assertEqual(
11+
EagerRelations().register("profile").get_eagers(), [["profile"]]
12+
)
13+
self.assertEqual(EagerRelations().register("profile").is_nested, False)
14+
self.assertEqual(
15+
EagerRelations().register("profile.user").get_eagers(),
16+
[{"profile": ["user"]}],
17+
)
18+
self.assertEqual(
19+
EagerRelations().register("profile.user", "profile.logo").get_eagers(),
20+
[{"profile": ["user", "logo"]}],
21+
)
22+
self.assertEqual(
23+
EagerRelations()
24+
.register("profile.user", "profile.logo", "profile.bio")
25+
.get_eagers(),
26+
[{"profile": ["user", "logo", "bio"]}],
27+
)
28+
self.assertEqual(
29+
EagerRelations().register("user", "logo", "bio").get_eagers(),
30+
[["user", "logo", "bio"]],
31+
)
32+
33+
def test_can_register_tuple_eager_load(self):
34+
35+
self.assertEqual(
36+
EagerRelations().register(("profile",)).get_eagers(), [["profile"]]
37+
)
38+
self.assertEqual(
39+
EagerRelations().register(("profile", "user")).get_eagers(),
40+
[["profile", "user"]],
41+
)
42+
self.assertEqual(
43+
EagerRelations().register(("profile.name", "profile.user")).get_eagers(),
44+
[{"profile": ["name", "user"]}],
45+
)
46+
47+
def test_can_register_list_eager_load(self):
48+
49+
self.assertEqual(
50+
EagerRelations().register(["profile"]).get_eagers(), [["profile"]]
51+
)
52+
self.assertEqual(
53+
EagerRelations().register(["profile", "user"]).get_eagers(),
54+
[["profile", "user"]],
55+
)
56+
self.assertEqual(
57+
EagerRelations().register(["profile.name", "profile.user"]).get_eagers(),
58+
[{"profile": ["name", "user"]}],
59+
)
60+
self.assertEqual(
61+
EagerRelations().register(["profile.name"]).get_eagers(),
62+
[{"profile": ["name"]}],
63+
)
64+
self.assertEqual(
65+
EagerRelations().register(["profile.name", "logo"]).get_eagers(),
66+
[["logo"], {"profile": ["name"]}],
67+
)
68+
self.assertEqual(
69+
EagerRelations()
70+
.register(["profile.name", "logo", "profile.user"])
71+
.get_eagers(),
72+
[["logo"], {"profile": ["name", "user"]}],
73+
)

tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from src.masoniteorm.models import Model
77
from src.masoniteorm.query import QueryBuilder
88
from src.masoniteorm.query.grammars import SQLiteGrammar
9-
from src.masoniteorm.relationships import belongs_to
9+
from src.masoniteorm.relationships import belongs_to, has_many
1010
from tests.utils import MockConnectionFactory
1111

1212

@@ -21,6 +21,10 @@ class Article(Model):
2121
def logo(self):
2222
return Logo
2323

24+
@belongs_to("user_id", "id")
25+
def user(self):
26+
return User
27+
2428

2529
class Profile(Model):
2630
__connection__ = "sqlite"
@@ -29,9 +33,9 @@ class Profile(Model):
2933
class User(Model):
3034
__connection__ = "sqlite"
3135

32-
__with__ = ("profile",)
36+
__with__ = ["articles.logo"]
3337

34-
@belongs_to("id", "user_id")
38+
@has_many("id", "user_id")
3539
def articles(self):
3640
return Article
3741

@@ -88,3 +92,9 @@ def test_with_where_no_relation(self):
8892
builder = self.get_builder()
8993
result = builder.with_("profile").where("id", 5).first()
9094
result.serialize()
95+
96+
def test_with_multiple_per_same_relation(self):
97+
builder = self.get_builder()
98+
result = User.with_("articles", "articles.logo").where("id", 1).first()
99+
self.assertTrue(result.serialize()["articles"])
100+
self.assertTrue(result.serialize()["articles"][0]["logo"])

tests/sqlite/relationships/test_sqlite_relationships.py

-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def test_loading(self):
7777
def test_loading_with_nested_with(self):
7878
users = User.with_("articles", "articles.logo").get()
7979
for user in users:
80-
print(user.articles)
8180
for article in user.articles:
8281
if article.logo:
8382
print("aa", article.logo.url)

0 commit comments

Comments
 (0)