[LoRA] introduce LoraBaseMixin to promote reusability. #8774
Conversation
Nice initiative 👍🏽. A lot to unpack here, so perhaps it's best to start bit by bit. I just went over the pipeline-related components here.

Regarding the `LoraBaseMixin`: there are quite a few methods in there that make assumptions about the inheriting class, which isn't really how a base class should behave. So loading methods related to specific model components are better left out, e.g. `load_lora_into_unet`.

I would assume that these are the methods that need to be defined for managing LoRAs across all pipelines?

from typing import Callable, Dict, Optional, Union

import torch


class LoraBaseMixin:
    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        raise NotImplementedError()

    @classmethod
    def _fetch_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict,
        weight_name,
        use_safetensors,
        local_files_only,
        cache_dir,
        force_download,
        resume_download,
        proxies,
        token,
        revision,
        subfolder,
        user_agent,
        allow_pickle,
    ):
        raise NotImplementedError()

    @classmethod
    def _best_guess_weight_name(
        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
    ):
        raise NotImplementedError()

    @classmethod
    def save_lora_weights(cls, **kwargs):
        raise NotImplementedError("`save_lora_weights()` is not implemented.")

    @classmethod
    def lora_state_dict(cls, **kwargs):
        raise NotImplementedError("`lora_state_dict()` is not implemented.")

    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")

    def unload_lora_weights(self, **kwargs):
        raise NotImplementedError("`unload_lora_weights()` is not implemented.")

    def fuse_lora(self, **kwargs):
        raise NotImplementedError("`fuse_lora()` is not implemented.")

    def unfuse_lora(self, **kwargs):
        raise NotImplementedError("`unfuse_lora()` is not implemented.")

    def disable_lora(self):
        raise NotImplementedError("`disable_lora()` is not implemented.")

    def enable_lora(self):
        raise NotImplementedError("`enable_lora()` is not implemented.")

    def get_active_adapters(self):
        raise NotImplementedError("`get_active_adapters()` is not implemented.")

    def delete_adapters(self, adapter_names):
        raise NotImplementedError("`delete_adapters()` is not implemented.")

    def set_lora_device(self, adapter_names):
        raise NotImplementedError("`set_lora_device()` is not implemented.")

    @staticmethod
    def pack_weights(layers, prefix):
        raise NotImplementedError()

    @staticmethod
    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        raise NotImplementedError()

    @property
    def lora_scale(self) -> float:
        raise NotImplementedError()

Quite a few of these methods probably cannot be defined in the base class itself. I think it might be better to define those in a pipeline-specific class that inherits from the `LoraBaseMixin`, e.g.:

class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        _load_lora_into_unet(**kwargs)
        _load_lora_into_text_encoder(**kwargs)

    def fuse_lora(self, components=["unet", "text_encoder"], **kwargs):
        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError()

            model = getattr(self, fuse_component)
            # check if this is a diffusers model
            if isinstance(model, ModelMixin):
                model.fuse_lora()
            # handle transformers models
            elif isinstance(model, PreTrainedModel):
                fuse_text_encoder()

I saw this comment about using the term "fuse_denoiser" in the fusing methods. I'm not so sure about that. I think if we want to fuse the LoRA in a specific component, it's better to pass in the actual name of the component used in the pipeline, rather than track it through another attribute such as a denoiser alias. I also think the constants and class attributes such as `TEXT_ENCODER_NAME` and `UNET_NAME` would no longer be needed with this approach.
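As a concrete illustration of passing component names directly, usage might look like this (a hypothetical sketch; `pipe` is assumed to be a pipeline exposing the proposed mixin):

# Hypothetical usage: callers name the pipeline components explicitly,
# so no separate "denoiser" alias has to be tracked.
pipe.fuse_lora(components=["unet"])                  # fuse LoRA only into the UNet
pipe.fuse_lora(components=["unet", "text_encoder"])  # fuse into both components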
@DN6 as discussed over Slack, I have unified the `fuse_lora()` and `unfuse_lora()` implementations in the base class. One thing to note is that I had to still keep `load_lora_into_text_encoder()` copied over in the pipeline-specific classes. We could have two additional classes under the `loaders` module if we want to split this further. LMK.
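For reference, a unified `fuse_lora()` in the base class could look roughly like the sketch below. This is only an illustration under assumptions, not the exact code merged here; `fuse_text_encoder_lora` is a hypothetical helper standing in for the transformers-specific fusing path.

from diffusers import ModelMixin
from transformers import PreTrainedModel


def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, **kwargs):
    # Hypothetical placeholder for the transformers-specific fusing logic.
    ...


class LoraBaseMixin:
    _lora_loadable_modules = []

    def fuse_lora(self, components=None, lora_scale=1.0, **kwargs):
        # Default to every LoRA-loadable module declared by the subclass.
        components = components or self._lora_loadable_modules
        for name in components:
            if name not in self._lora_loadable_modules:
                raise ValueError(f"{name} is not a LoRA-loadable module of this pipeline.")
            model = getattr(self, name, None)
            if model is None:
                continue
            if isinstance(model, ModelMixin):
                # diffusers models expose fuse_lora() themselves.
                model.fuse_lora(lora_scale=lora_scale, **kwargs)
            elif isinstance(model, PreTrainedModel):
                # transformers text encoders go through the dedicated helper.
                fuse_text_encoder_lora(model, lora_scale=lora_scale, **kwargs)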
@DN6 I think this is ready for another review now.
weights = [w if w is not None else 1.0 for w in weights]

# e.g. [{...}, 7] -> [{expanded dict...}, 7]
scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
Let's just add a check in case this is applied to a model that doesn't exist in the mapping. Edge case, because we would probably always verify, but better to be safe.

if scale_expansion_fn is not None:
    ...
But `scale_expansion_fn` cannot be None, no? We are directly indexing the dictionary here and not using `get()`. So wrong indexing will lead to an error anyway. But LMK if I am missing something.
Actually on second thought, the check might be overkill. If we add to a model not in the mapping, we should error out.
What I was thinking is that `SD3Transformer2DModel` doesn't even need to be in the mapping. We use `get()` to check if a `scale_expansion_fn` exists for a model class, and return `None` if it doesn't. Either approach works.
"If we add to a model not in the mapping, we should error out."

Yeah, this already works. So I would prefer that.
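To make the two lookup behaviors concrete, here is a small sketch (the mapping entry and the `expand_scales` helper are illustrative, not the PR's actual code):

_SET_ADAPTER_SCALE_FN_MAPPING = {
    "SD3Transformer2DModel": lambda model, weights: weights,  # illustrative no-op entry
}


def expand_scales(model, weights, strict=True):
    name = model.__class__.__name__
    if strict:
        # Direct indexing, as the PR does: an unmapped model class raises KeyError.
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[name]
    else:
        # The get() variant: unmapped model classes simply skip scale expansion.
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING.get(name)
        if scale_expansion_fn is None:
            return weights
    return scale_expansion_fn(model, weights)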
@DN6 anything else you would like me to address?
LGTM 👍🏽
Thanks for the massive help and guidance, Dhruv!
* introduce LoraBaseMixin to promote reusability.
* up
* add more tests
* up
* remove comments.
* fix fuse_nan test
* clarify the scope of fuse_lora and unfuse_lora
* remove space
* rewrite fuse_lora a bit.
* feedback
* copy over load_lora_into_text_encoder.
* address dhruv's feedback.
* fix-copies
* fix issubclass.
* num_fused_loras
* fix
* fix
* remove mapping
* up
* fix
* style
* fix-copies
* change to SD3TransformerLoRALoadersMixin
* Apply suggestions from code review

Co-authored-by: Dhruv Nair <[email protected]>

* up
* handle wuerstchen
* up
* move lora to lora_pipeline.py
* up
* fix-copies
* fix documentation.
* comment set_adapters().
* fix-copies
* fix set_adapters() at the model level.
* fix?
* fix

---------

Co-authored-by: Dhruv Nair <[email protected]>
What does this PR do?
It is basically a mirror of #8670. I had accidentally merged that PR, but reverted it in #8773. Apologies for this.
Check #8774 (comment) as well.
I have left inline comments to address the questions brought up by @yiyixuxu.