diff --git a/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py b/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py index fc05b1d..f66ddc7 100644 --- a/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py +++ b/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py @@ -184,14 +184,6 @@ def build_net(self): with misc.timer("Creating PyTorch model"): with torch.device(init_device): net = lazy_instantiate(config.net) - if config.use_lora: - self.add_lora( - net, - lora_rank=config.lora_rank, - lora_alpha=config.lora_alpha, - lora_target_modules=config.lora_target_modules, - init_lora_weights=config.init_lora_weights, - ) self._param_count = count_params(net, verbose=False) @@ -204,6 +196,15 @@ def build_net(self): # IMPORTANT: (qsh) model init should not depends on current tensor shape, or it can handle Dtensor shape. net.init_weights() + if config.use_lora: + self.add_lora( + net, + lora_rank=config.lora_rank, + lora_alpha=config.lora_alpha, + lora_target_modules=config.lora_target_modules, + init_lora_weights=config.init_lora_weights, + ) + if self.fsdp_device_mesh: broadcast_dtensor_model_states(net, self.fsdp_device_mesh) for name, param in net.named_parameters():