 from keras import models
 from keras import ops
 
+from keras_nlp.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
 from keras_nlp.src.utils.keras_utils import standardize_data_format
 
@@ -58,45 +59,45 @@ def get_config(self):
         return config
 
 
-class PositionEmbedding(layers.Layer):
+class AdjustablePositionEmbedding(PositionEmbedding):
     def __init__(
         self,
-        sequence_length,
+        height,
+        width,
         initializer="glorot_uniform",
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        if sequence_length is None:
-            raise ValueError(
-                "`sequence_length` must be an Integer, received `None`."
-            )
-        self.sequence_length = int(sequence_length)
-        self.initializer = keras.initializers.get(initializer)
-
-    def build(self, inputs_shape):
-        feature_size = inputs_shape[-1]
-        self.position_embeddings = self.add_weight(
-            name="embeddings",
-            shape=[self.sequence_length, feature_size],
-            initializer=self.initializer,
-            trainable=True,
+        height = int(height)
+        width = int(width)
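+        # The embedding table covers a `height * width` grid of patch
+        # positions, flattened into a single sequence.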
+        sequence_length = height * width
+        super().__init__(sequence_length, initializer, **kwargs)
+        self.height = height
+        self.width = width
+
+    def call(self, inputs, height=None, width=None):
+        height = height or self.height
+        width = width or self.width
+        shape = ops.shape(inputs)
+        feature_length = shape[-1]
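+        # Center-crop the stored grid of embeddings so inputs smaller than
+        # the trained `position_size` reuse the middle of the table.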
+        top = ops.floor_divide(self.height - height, 2)
+        left = ops.floor_divide(self.width - width, 2)
+        position_embedding = ops.convert_to_tensor(self.position_embeddings)
+        position_embedding = ops.reshape(
+            position_embedding, (self.height, self.width, feature_length)
         )
-
-    def call(self, inputs):
-        return ops.convert_to_tensor(self.position_embeddings)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "initializer": keras.initializers.serialize(self.initializer),
-            }
+        position_embedding = ops.slice(
+            position_embedding,
+            (top, left, 0),
+            (height, width, feature_length),
         )
-        return config
+        position_embedding = ops.reshape(
+            position_embedding, (height * width, feature_length)
+        )
+        position_embedding = ops.expand_dims(position_embedding, axis=0)
+        return position_embedding
 
     def compute_output_shape(self, input_shape):
-        return list(self.position_embeddings.shape)
+        return input_shape
 
 
 class TimestepEmbedding(layers.Layer):
@@ -112,18 +113,13 @@ def __init__(
         self.mlp = models.Sequential(
             [
                 layers.Dense(
-                    embedding_dim,
-                    activation="silu",
-                    dtype=self.dtype_policy,
-                    name="dense0",
+                    embedding_dim, activation="silu", dtype=self.dtype_policy
                 ),
                 layers.Dense(
-                    embedding_dim,
-                    activation=None,
-                    dtype=self.dtype_policy,
-                    name="dense1",
+                    embedding_dim, activation=None, dtype=self.dtype_policy
                 ),
-            ]
+            ],
+            name="mlp",
         )
 
     def build(self, inputs_shape):
@@ -181,9 +177,7 @@ def __init__(self, hidden_dim, output_dim, **kwargs):
             [
                 layers.Activation("silu", dtype=self.dtype_policy),
                 layers.Dense(
-                    num_modulation * hidden_dim,
-                    dtype=self.dtype_policy,
-                    name="dense",
+                    num_modulation * hidden_dim, dtype=self.dtype_policy
                 ),
             ],
             name="adaptive_norm_modulation",
@@ -234,6 +228,41 @@ def get_config(self):
         return config
 
 
+class Unpatch(layers.Layer):
+    def __init__(self, patch_size, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.patch_size = int(patch_size)
+        self.output_dim = int(output_dim)
+
+    def call(self, inputs, height, width):
+        patch_size = self.patch_size
+        output_dim = self.output_dim
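+        # Reassemble flattened patch tokens into an image: split each token
+        # into a (patch_size, patch_size) tile, then interleave the tiles
+        # back into spatial order.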
+        x = ops.reshape(
+            inputs,
+            (-1, height, width, patch_size, patch_size, output_dim),
+        )
+        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
+        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
+        return ops.reshape(
+            x,
+            (-1, height * patch_size, width * patch_size, output_dim),
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "patch_size": self.patch_size,
+                "output_dim": self.output_dim,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        inputs_shape = list(inputs_shape)
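+        # Height and width depend on call-time arguments, so they are
+        # reported as dynamic (None) here.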
+        return [inputs_shape[0], None, None, self.output_dim]
+
+
 class MMDiT(keras.Model):
     def __init__(
         self,
@@ -251,13 +280,19 @@ def __init__(
         dtype=None,
         **kwargs,
     ):
+        if None in latent_shape:
+            raise ValueError(
+                "`latent_shape` must be fully specified. "
+                f"Received: latent_shape={latent_shape}"
+            )
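+        # Spatial size of the patch-token grid: the latent image is tiled
+        # into patch_size x patch_size patches.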
+        image_height = latent_shape[0] // patch_size
+        image_width = latent_shape[1] // patch_size
+        output_dim_in_final = patch_size**2 * output_dim
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
                 "Currently only 'channels_last' is supported."
             )
-        position_sequence_length = position_size * position_size
-        output_dim_in_final = patch_size**2 * output_dim
 
         # === Layers ===
         self.patch_embedding = PatchEmbedding(
@@ -267,8 +302,11 @@ def __init__(
             dtype=dtype,
             name="patch_embedding",
         )
-        self.position_embedding = PositionEmbedding(
-            position_sequence_length, dtype=dtype, name="position_embedding"
+        self.position_embedding_add = layers.Add(
+            dtype=dtype, name="position_embedding_add"
+        )
+        self.position_embedding = AdjustablePositionEmbedding(
+            position_size, position_size, dtype=dtype, name="position_embedding"
         )
         self.context_embedding = layers.Dense(
             hidden_dim,
@@ -277,19 +315,13 @@ def __init__(
         )
         self.vector_embedding = models.Sequential(
             [
-                layers.Dense(
-                    hidden_dim,
-                    activation="silu",
-                    dtype=dtype,
-                    name="vector_embedding_dense_0",
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    activation=None,
-                    dtype=dtype,
-                    name="vector_embedding_dense_1",
-                ),
-            ]
+                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
+                layers.Dense(hidden_dim, activation=None, dtype=dtype),
+            ],
+            name="vector_embedding",
+        )
+        self.vector_embedding_add = layers.Add(
+            dtype=dtype, name="vector_embedding_add"
         )
         self.timestep_embedding = TimestepEmbedding(
             hidden_dim, dtype=dtype, name="timestep_embedding"
@@ -301,12 +333,15 @@ def __init__(
                 mlp_ratio,
                 use_context_projection=not (i == depth - 1),
                 dtype=dtype,
-                name=f"joint_block{i}",
+                name=f"joint_block_{i}",
             )
             for i in range(depth)
         ]
-        self.final_layer = OutputLayer(
-            hidden_dim, output_dim_in_final, dtype=dtype, name="final_layer"
+        self.output_layer = OutputLayer(
+            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
+        )
+        self.unpatch = Unpatch(
+            patch_size, output_dim, dtype=dtype, name="unpatch"
         )
 
         # === Functional Model ===
@@ -316,18 +351,17 @@ def __init__(
             shape=pooled_projection_shape, name="pooled_projection"
         )
         timestep_inputs = layers.Input(shape=(1,), name="timestep")
-        image_size = latent_shape[:2]
 
         # Embeddings.
         x = self.patch_embedding(latent_inputs)
-        cropped_position_embedding = self._get_cropped_position_embedding(
-            x, patch_size, image_size, position_size
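+        # Look up position embeddings cropped to the current token grid.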
+        position_embedding = self.position_embedding(
+            x, height=image_height, width=image_width
         )
-        x = layers.Add(dtype=dtype)([x, cropped_position_embedding])
+        x = self.position_embedding_add([x, position_embedding])
         context = self.context_embedding(context_inputs)
         pooled_projection = self.vector_embedding(pooled_projection_inputs)
         timestep_embedding = self.timestep_embedding(timestep_inputs)
-        timestep_embedding = layers.Add(dtype=dtype)(
+        timestep_embedding = self.vector_embedding_add(
             [timestep_embedding, pooled_projection]
         )
 
@@ -338,9 +372,9 @@ def __init__(
         else:
             x = block(x, context, timestep_embedding)
 
-        # Final layer.
-        x = self.final_layer(x, timestep_embedding)
-        output_image = self._unpatchify(x, patch_size, image_size, output_dim)
+        # Output layer.
+        x = self.output_layer(x, timestep_embedding)
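+        # Fold the patch tokens back into a latent image of shape
+        # (height * patch_size, width * patch_size, output_dim).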
+        outputs = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(
             inputs={
@@ -349,7 +383,7 @@ def __init__(
349383 "pooled_projection" : pooled_projection_inputs ,
350384 "timestep" : timestep_inputs ,
351385 },
352- outputs = output_image ,
386+ outputs = outputs ,
353387 ** kwargs ,
354388 )
355389
@@ -374,42 +408,6 @@ def __init__(
             dtype = dtype.name
         self.dtype_policy = keras.DTypePolicy(dtype)
 
-    def _get_cropped_position_embedding(
-        self, inputs, patch_size, image_size, position_size
-    ):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        top = (position_size - h) // 2
-        left = (position_size - w) // 2
-        hidden_dim = ops.shape(inputs)[-1]
-        position_embedding = self.position_embedding(inputs)
-        position_embedding = ops.reshape(
-            position_embedding,
-            (1, position_size, position_size, hidden_dim),
-        )
-        cropped_position_embedding = position_embedding[
-            :, top : top + h, left : left + w, :
-        ]
-        cropped_position_embedding = ops.reshape(
-            cropped_position_embedding, (1, h * w, hidden_dim)
-        )
-        return cropped_position_embedding
-
-    def _unpatchify(self, x, patch_size, image_size, output_dim):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        batch_size = ops.shape(x)[0]
-        x = ops.reshape(
-            x, (batch_size, h, w, patch_size, patch_size, output_dim)
-        )
-        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
-        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
-        return ops.reshape(
-            x, (batch_size, h * patch_size, w * patch_size, output_dim)
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(