 from keras import models
 from keras import ops
 
+from keras_nlp.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
 from keras_nlp.src.utils.keras_utils import standardize_data_format
 
@@ -58,45 +59,45 @@ def get_config(self):
         return config
 
 
-class PositionEmbedding(layers.Layer):
+class AdjustablePositionEmbedding(PositionEmbedding):
     def __init__(
         self,
-        sequence_length,
+        height,
+        width,
         initializer="glorot_uniform",
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        if sequence_length is None:
-            raise ValueError(
-                "`sequence_length` must be an Integer, received `None`."
-            )
-        self.sequence_length = int(sequence_length)
-        self.initializer = keras.initializers.get(initializer)
-
-    def build(self, inputs_shape):
-        feature_size = inputs_shape[-1]
-        self.position_embeddings = self.add_weight(
-            name="embeddings",
-            shape=[self.sequence_length, feature_size],
-            initializer=self.initializer,
-            trainable=True,
+        height = int(height)
+        width = int(width)
+        sequence_length = height * width
+        super().__init__(sequence_length, initializer, **kwargs)
+        self.height = height
+        self.width = width
+
+    def call(self, inputs, height=None, width=None):
+        height = height or self.height
+        width = width or self.width
+        shape = ops.shape(inputs)
+        feature_length = shape[-1]
+        top = ops.floor_divide(self.height - height, 2)
+        left = ops.floor_divide(self.width - width, 2)
+        position_embedding = ops.convert_to_tensor(self.position_embeddings)
+        position_embedding = ops.reshape(
+            position_embedding, (self.height, self.width, feature_length)
         )
-
-    def call(self, inputs):
-        return ops.convert_to_tensor(self.position_embeddings)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "initializer": keras.initializers.serialize(self.initializer),
-            }
+        position_embedding = ops.slice(
+            position_embedding,
+            (top, left, 0),
+            (height, width, feature_length),
         )
-        return config
+        position_embedding = ops.reshape(
+            position_embedding, (height * width, feature_length)
+        )
+        position_embedding = ops.expand_dims(position_embedding, axis=0)
+        return position_embedding
 
     def compute_output_shape(self, input_shape):
-        return list(self.position_embeddings.shape)
+        return input_shape
 
 
 class TimestepEmbedding(layers.Layer):
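
For reference, the centered crop that the new AdjustablePositionEmbedding.call performs can be sketched with NumPy stand-ins for the ops calls; the grid, feature, and crop sizes below are made-up illustration values, not anything used by the layer.

import numpy as np

# Hypothetical toy sizes: a trained 6x6 position grid with 8 features,
# cropped to the 4x2 patch grid of the current input.
position_height, position_width, feature_length = 6, 6, 8
height, width = 4, 2

grid = np.arange(position_height * position_width * feature_length, dtype="float32")
grid = grid.reshape(position_height, position_width, feature_length)

# Same arithmetic as the layer: take a centered (height, width) window.
top = (position_height - height) // 2
left = (position_width - width) // 2
cropped = grid[top : top + height, left : left + width, :]

# Flatten back to a sequence and add the batch axis, as in call().
flat = cropped.reshape(height * width, feature_length)[None, ...]
print(flat.shape)  # (1, 8, 8)
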
@@ -112,18 +113,13 @@ def __init__(
         self.mlp = models.Sequential(
             [
                 layers.Dense(
-                    embedding_dim,
-                    activation="silu",
-                    dtype=self.dtype_policy,
-                    name="dense0",
+                    embedding_dim, activation="silu", dtype=self.dtype_policy
                 ),
                 layers.Dense(
-                    embedding_dim,
-                    activation=None,
-                    dtype=self.dtype_policy,
-                    name="dense1",
+                    embedding_dim, activation=None, dtype=self.dtype_policy
                 ),
-            ]
+            ],
+            name="mlp",
         )
 
     def build(self, inputs_shape):
@@ -181,9 +177,7 @@ def __init__(self, hidden_dim, output_dim, **kwargs):
             [
                 layers.Activation("silu", dtype=self.dtype_policy),
                 layers.Dense(
-                    num_modulation * hidden_dim,
-                    dtype=self.dtype_policy,
-                    name="dense",
+                    num_modulation * hidden_dim, dtype=self.dtype_policy
                 ),
             ],
             name="adaptive_norm_modulation",
@@ -234,6 +228,41 @@ def get_config(self):
         return config
 
 
+class Unpatch(layers.Layer):
+    def __init__(self, patch_size, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.patch_size = int(patch_size)
+        self.output_dim = int(output_dim)
+
+    def call(self, inputs, height, width):
+        patch_size = self.patch_size
+        output_dim = self.output_dim
+        x = ops.reshape(
+            inputs,
+            (-1, height, width, patch_size, patch_size, output_dim),
+        )
+        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
+        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
+        return ops.reshape(
+            x,
+            (-1, height * patch_size, width * patch_size, output_dim),
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "patch_size": self.patch_size,
+                "output_dim": self.output_dim,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        inputs_shape = list(inputs_shape)
+        return [inputs_shape[0], None, None, self.output_dim]
+
+
 class MMDiT(keras.Model):
     def __init__(
         self,
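
The Unpatch layer added above inverts the usual patch extraction. A quick NumPy round trip with made-up sizes shows that its reshape/transpose/reshape sequence restores the original image layout:

import numpy as np

# Hypothetical sizes for the check.
batch, height, width, patch_size, output_dim = 2, 3, 4, 2, 5
image = np.random.rand(batch, height * patch_size, width * patch_size, output_dim)

# Patchify: (b, h*p, w*p, o) -> (b, h*w, p*p*o).
patches = image.reshape(batch, height, patch_size, width, patch_size, output_dim)
patches = patches.transpose(0, 1, 3, 2, 4, 5)
patches = patches.reshape(batch, height * width, patch_size * patch_size * output_dim)

# Mirror Unpatch.call: (b, h*w, p*p*o) -> (b, h*p, w*p, o).
x = patches.reshape(-1, height, width, patch_size, patch_size, output_dim)
x = x.transpose(0, 1, 3, 2, 4, 5)
restored = x.reshape(-1, height * patch_size, width * patch_size, output_dim)

assert np.allclose(image, restored)
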
@@ -251,13 +280,19 @@ def __init__(
         dtype=None,
         **kwargs,
     ):
+        if None in latent_shape:
+            raise ValueError(
+                "`latent_shape` must be fully specified. "
+                f"Received: latent_shape={latent_shape}"
+            )
+        image_height = latent_shape[0] // patch_size
+        image_width = latent_shape[1] // patch_size
+        output_dim_in_final = patch_size**2 * output_dim
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
                 "Currently only 'channels_last' is supported."
             )
-        position_sequence_length = position_size * position_size
-        output_dim_in_final = patch_size**2 * output_dim
 
         # === Layers ===
         self.patch_embedding = PatchEmbedding(
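
As a concrete example of the patch-grid arithmetic introduced above (the sizes are illustrative, not defaults from this file), a 64x64x16 latent with patch_size=2 yields a 32x32 grid of patch tokens:

latent_shape, patch_size, output_dim = (64, 64, 16), 2, 16  # hypothetical values
image_height = latent_shape[0] // patch_size  # 32
image_width = latent_shape[1] // patch_size   # 32
output_dim_in_final = patch_size**2 * output_dim  # 64 features per patch token
# The transformer blocks therefore operate on 32 * 32 = 1024 patch tokens.
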
@@ -267,8 +302,11 @@ def __init__(
             dtype=dtype,
             name="patch_embedding",
         )
-        self.position_embedding = PositionEmbedding(
-            position_sequence_length, dtype=dtype, name="position_embedding"
+        self.position_embedding_add = layers.Add(
+            dtype=dtype, name="position_embedding_add"
+        )
+        self.position_embedding = AdjustablePositionEmbedding(
+            position_size, position_size, dtype=dtype, name="position_embedding"
         )
         self.context_embedding = layers.Dense(
             hidden_dim,
@@ -277,19 +315,13 @@ def __init__(
         )
         self.vector_embedding = models.Sequential(
             [
-                layers.Dense(
-                    hidden_dim,
-                    activation="silu",
-                    dtype=dtype,
-                    name="vector_embedding_dense_0",
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    activation=None,
-                    dtype=dtype,
-                    name="vector_embedding_dense_1",
-                ),
-            ]
+                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
+                layers.Dense(hidden_dim, activation=None, dtype=dtype),
+            ],
+            name="vector_embedding",
+        )
+        self.vector_embedding_add = layers.Add(
+            dtype=dtype, name="vector_embedding_add"
         )
         self.timestep_embedding = TimestepEmbedding(
             hidden_dim, dtype=dtype, name="timestep_embedding"
@@ -301,12 +333,15 @@ def __init__(
                 mlp_ratio,
                 use_context_projection=not (i == depth - 1),
                 dtype=dtype,
-                name=f"joint_block{i}",
+                name=f"joint_block_{i}",
             )
             for i in range(depth)
         ]
-        self.final_layer = OutputLayer(
-            hidden_dim, output_dim_in_final, dtype=dtype, name="final_layer"
+        self.output_layer = OutputLayer(
+            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
+        )
+        self.unpatch = Unpatch(
+            patch_size, output_dim, dtype=dtype, name="unpatch"
         )
 
         # === Functional Model ===
@@ -316,18 +351,17 @@ def __init__(
             shape=pooled_projection_shape, name="pooled_projection"
         )
         timestep_inputs = layers.Input(shape=(1,), name="timestep")
-        image_size = latent_shape[:2]
 
         # Embeddings.
         x = self.patch_embedding(latent_inputs)
-        cropped_position_embedding = self._get_cropped_position_embedding(
-            x, patch_size, image_size, position_size
+        position_embedding = self.position_embedding(
+            x, height=image_height, width=image_width
         )
-        x = layers.Add(dtype=dtype)([x, cropped_position_embedding])
+        x = self.position_embedding_add([x, position_embedding])
         context = self.context_embedding(context_inputs)
         pooled_projection = self.vector_embedding(pooled_projection_inputs)
         timestep_embedding = self.timestep_embedding(timestep_inputs)
-        timestep_embedding = layers.Add(dtype=dtype)(
+        timestep_embedding = self.vector_embedding_add(
            [timestep_embedding, pooled_projection]
         )
 
@@ -338,9 +372,9 @@ def __init__(
             else:
                 x = block(x, context, timestep_embedding)
 
-        # Final layer.
-        x = self.final_layer(x, timestep_embedding)
-        output_image = self._unpatchify(x, patch_size, image_size, output_dim)
+        # Output layer.
+        x = self.output_layer(x, timestep_embedding)
+        output_image = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(
             inputs={
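
Putting the functional graph together, the shapes flow roughly as below; the sizes are hypothetical, hidden_dim is a placeholder, and PatchEmbedding is assumed to emit one token per patch.

# Rough shape walk-through with hypothetical sizes (batch=1, 64x64x16 latent,
# patch_size=2, output_dim=16):
#
#   latent input        (1, 64, 64, 16)
#   patch_embedding     (1, 1024, hidden_dim)   # 32 * 32 patch tokens
#   + position crop     (1, 1024, hidden_dim)   # centered 32x32 window of the grid
#   ... MMDiT blocks conditioned on timestep_embedding + pooled_projection ...
#   output_layer        (1, 1024, 64)           # patch_size**2 * output_dim
#   unpatch             (1, 64, 64, 16)
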
@@ -374,42 +408,6 @@ def __init__(
                     dtype = dtype.name
                 self.dtype_policy = keras.DTypePolicy(dtype)
 
-    def _get_cropped_position_embedding(
-        self, inputs, patch_size, image_size, position_size
-    ):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        top = (position_size - h) // 2
-        left = (position_size - w) // 2
-        hidden_dim = ops.shape(inputs)[-1]
-        position_embedding = self.position_embedding(inputs)
-        position_embedding = ops.reshape(
-            position_embedding,
-            (1, position_size, position_size, hidden_dim),
-        )
-        cropped_position_embedding = position_embedding[
-            :, top : top + h, left : left + w, :
-        ]
-        cropped_position_embedding = ops.reshape(
-            cropped_position_embedding, (1, h * w, hidden_dim)
-        )
-        return cropped_position_embedding
-
-    def _unpatchify(self, x, patch_size, image_size, output_dim):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        batch_size = ops.shape(x)[0]
-        x = ops.reshape(
-            x, (batch_size, h, w, patch_size, patch_size, output_dim)
-        )
-        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
-        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
-        return ops.reshape(
-            x, (batch_size, h * patch_size, w * patch_size, output_dim)
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(