11from django .core import checks
2+ from django .core .exceptions import FieldDoesNotExist
23from django .db import models
34from django .db .models .fields .related import lazy_related_operation
45from django .db .models .lookups import Transform
@@ -112,7 +113,8 @@ def get_transform(self, name):
112113 transform = super ().get_transform (name )
113114 if transform :
114115 return transform
115- return KeyTransformFactory (name )
116+ field = self .embedded_model ._meta .get_field (name )
117+ return KeyTransformFactory (name , field )
116118
117119 def validate (self , value , model_instance ):
118120 super ().validate (value , model_instance )
@@ -134,9 +136,25 @@ def formfield(self, **kwargs):
134136
135137
136138class KeyTransform (Transform ):
137- def __init__ (self , key_name , * args , ** kwargs ):
139+ def __init__ (self , key_name , ref_field , * args , ** kwargs ):
138140 super ().__init__ (* args , ** kwargs )
139141 self .key_name = str (key_name )
142+ self .ref_field = ref_field
143+
144+ def get_transform (self , name ):
145+ result = None
146+ if isinstance (self .ref_field , EmbeddedModelField ):
147+ opts = self .ref_field .embedded_model ._meta
148+ new_field = opts .get_field (name )
149+ result = KeyTransformFactory (name , new_field )
150+ else :
151+ if self .ref_field .get_transform (name ) is None :
152+ raise FieldDoesNotExist (
153+ f"{ self .ref_field .model ._meta .object_name } .{ self .ref_field .name } "
154+ f" has no field named '{ name } '"
155+ )
156+ result = KeyTransformFactory (name , self .ref_field )
157+ return result
140158
141159 def preprocess_lhs (self , compiler , connection ):
142160 key_transforms = [self .key_name ]
@@ -154,8 +172,9 @@ def as_mql(self, compiler, connection):
154172
155173
156174class KeyTransformFactory :
157- def __init__ (self , key_name ):
175+ def __init__ (self , key_name , ref_field ):
158176 self .key_name = key_name
177+ self .ref_field = ref_field
159178
160179 def __call__ (self , * args , ** kwargs ):
161- return KeyTransform (self .key_name , * args , ** kwargs )
180+ return KeyTransform (self .key_name , self . ref_field , * args , ** kwargs )
0 commit comments