-
Notifications
You must be signed in to change notification settings - Fork 367
/
raw_sql.py
70 lines (53 loc) · 2.4 KB
/
raw_sql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from django.db import connections
from django.db.models.sql.constants import INNER, LOUTER
from django.db.models.sql.datastructures import Join
from judge.utils.cachedict import CacheDict
class RawSQLJoin(Join):
    """A ``Join`` whose right-hand side is a raw SQL subquery rather than a table.

    The subquery text is handed to ``Join`` in its table-name slot. When the
    join is compiled, that text is emitted wrapped in parentheses instead of
    being quoted as an identifier, and ``subquery_params`` are placed ahead of
    the join's own parameters.

    NOTE(review): the inherited ``Join.relabeled_clone()`` appears to rebuild the
    join via ``self.__class__`` using ``Join``'s own positional signature, which
    does not match this constructor. If these joins are ever relabeled (e.g. the
    queryset is used as a subquery), that path likely fails. TODO confirm.
    """

    def __init__(self, subquery, subquery_params, parent_alias, table_alias, join_type, join_field, nullable,
                 filtered_relation=None):
        """Store the subquery parameters, then initialise ``Join`` with ``subquery`` as its table name.

        :param subquery: raw SQL of the subquery, without surrounding parentheses.
        :param subquery_params: parameters for the subquery's placeholders: any
            iterable (list or tuple), or ``None`` for no parameters.

        All remaining arguments are forwarded unchanged to ``Join.__init__``.
        """
        self.subquery_params = subquery_params
        super().__init__(subquery, parent_alias, table_alias, join_type, join_field, nullable, filtered_relation)

    def as_sql(self, compiler, connection):
        """Compile the join and return ``(sql, params)``, with the subquery's parameters first."""
        # Seed the compiler's quote cache so the table name (the subquery text)
        # comes out as "(subquery)" instead of a quoted identifier.
        # NOTE(review): relies on Join.as_sql quoting names through a compiler
        # helper that consults quote_cache. Re-check when upgrading Django.
        compiler.quote_cache[self.table_name] = '(%s)' % self.table_name
        sql, params = super().as_sql(compiler, connection)
        # The subquery text precedes the ON clause, so its parameters go first.
        # Normalise both sides to lists so tuple (or None) subquery params can be
        # combined with whatever sequence type Join.as_sql returns.
        return sql, list(self.subquery_params or ()) + list(params)
class FakeJoinField:
    """Duck-typed stand-in for a relation field, passed as a ``Join``'s ``join_field``.

    It supplies only what the join machinery presumably reads from a join field:
    the joining column pairs, the related model, and an always-empty extra
    restriction.
    """

    def __init__(self, joining_columns, related_model):
        # Returned verbatim by get_joining_columns(). Presumably a sequence of
        # (parent_column, joined_column) pairs; confirm against callers.
        self.joining_columns = joining_columns
        self.related_model = related_model

    def get_joining_columns(self):
        """Return the joining column pairs given at construction."""
        return self.joining_columns

    def get_extra_restriction(self, *args, **kwargs):
        """Return ``None``: a raw subquery join adds no condition beyond its column pairs.

        Accepts any arguments because the caller's signature differs between
        Django versions. Django 3.x passes ``(where_class, alias, remote_alias)``;
        later releases are believed to drop ``where_class``.
        """
        return None
def join_sql_subquery(
        queryset, subquery, params, join_fields, alias, related_model, join_type=INNER, parent_model=None):
    """Join a raw SQL subquery onto ``queryset`` in place, exposed under ``alias``.

    The queryset's query object is mutated directly; no clone is made.

    :param queryset: queryset whose ``query`` receives the join.
    :param subquery: raw SQL text of the subquery, without surrounding parentheses.
    :param params: parameters for placeholders in ``subquery``.
    :param join_fields: value returned by the join field's ``get_joining_columns()``;
        presumably (parent_column, subquery_column) pairs. Confirm with callers.
    :param alias: SQL alias under which the subquery is referenced.
    :param related_model: model recorded as the join field's ``related_model``.
    :param join_type: ``INNER`` by default; ``LOUTER`` also marks the join nullable.
    :param parent_model: when given, join against its ``db_table`` instead of the
        query's initial alias.
    """
    query = queryset.query

    if parent_model is None:
        parent_alias = query.get_initial_alias()
    else:
        parent_alias = parent_model._meta.db_table

    # Register the alias as external. Django 3.x stores these in a dict;
    # otherwise the container is treated as set-like (has .add()).
    known_external = query.external_aliases
    if isinstance(known_external, dict):
        known_external[alias] = True
    else:
        known_external.add(alias)

    is_nullable = join_type == LOUTER
    raw_join = RawSQLJoin(
        subquery, params, parent_alias, alias, join_type,
        FakeJoinField(join_fields, related_model),
        is_nullable,
    )
    query.join(raw_join)
    # Query.join() may have assigned its own generated alias; force the caller's one.
    raw_join.table_alias = alias
def make_straight_join_query(QueryType):
    """Build a subclass of ``QueryType`` whose inner joins are rendered as ``STRAIGHT_JOIN``.

    ``STRAIGHT_JOIN`` is MySQL's inner join that fixes the table read order.
    Non-inner joins are left untouched.
    """
    class Query(QueryType):
        def join(self, join, *args, **kwargs):
            """Register ``join`` via the parent class, then upgrade an INNER result to STRAIGHT_JOIN.

            Returns whatever alias the parent ``join`` returned.
            """
            joined_alias = super().join(join, *args, **kwargs)
            # Look the join up by alias rather than using ``join`` itself,
            # presumably because the parent may hand back a previously registered join.
            registered = self.alias_map[joined_alias]
            if registered.join_type == INNER:
                registered.join_type = 'STRAIGHT_JOIN'
            return joined_alias

    return Query
# Straight-join Query subclasses, keyed by the original Query class.
# Presumably CacheDict calls make_straight_join_query for a missing key and
# memoises the result; confirm in judge.utils.cachedict.
straight_join_cache = CacheDict(make_straight_join_query)


def use_straight_join(queryset):
    """Make ``queryset`` render its inner joins as MySQL ``STRAIGHT_JOIN``.

    On a MySQL connection, ``queryset.query`` is replaced in place by a copy
    whose class comes from ``straight_join_cache``. On any other backend this
    does nothing. Always returns ``None``.
    """
    if connections[queryset.db].vendor == 'mysql':
        current = queryset.query
        # Prefer Query.chain(); fall back to Query.clone() when the query has no
        # chain attribute (presumably older Django versions).
        try:
            copy_query = current.chain
        except AttributeError:
            copy_query = current.clone
        queryset.query = copy_query(straight_join_cache[type(current)])