@@ -55,7 +55,11 @@ def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
     reduces the accumulated gradients across devices
     """
 
-    fsdp_plugin.ignored_modules = modules
+    # NOTE: assuming lora has no bias
+    fsdp_plugin.ignored_modules = []
+    for mod in modules:
+        fsdp_plugin.ignored_modules.append(mod.lora_A)
+        fsdp_plugin.ignored_modules.append(mod.lora_B)
 
     def _all_reduce_hook(grad):
         if grad is not None:
@@ -64,7 +68,6 @@ def _all_reduce_hook(grad):
         return grad
 
     for mod in modules:
-        # NOTE: assuming lora has no bias
         A = mod.lora_A.default
         B = mod.lora_B.default
 
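For context, a minimal usage sketch of the helper patched above. Everything here is illustrative wiring, not part of this diff: the model id, the LoraConfig values, and the module-collection loop are placeholder assumptions; only lora_adapters_switch_ddp_from_fsdp and accelerate's FSDP plugin object come from the code being changed.

# Hypothetical wiring (not part of this PR): hand the LoRA-wrapped modules of a
# PEFT model to the patched helper so FSDP ignores only lora_A / lora_B, whose
# gradients are then all-reduced across ranks by the hooks installed below it.
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "some/4bit-base-model",  # placeholder model id
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = get_peft_model(model, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

accelerator = Accelerator()                  # FSDP assumed configured via `accelerate launch`
fsdp_plugin = accelerator.state.fsdp_plugin  # the plugin object the helper mutates

# collect every module that carries LoRA adapters (PEFT naming conventions)
lora_modules = [m for m in model.modules() if hasattr(m, "lora_A") and hasattr(m, "lora_B")]
lora_adapters_switch_ddp_from_fsdp(lora_modules, fsdp_plugin)

model = accelerator.prepare(model)           # adapter weights stay un-sharded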
@@ -41,27 +41,35 @@ def calculate_settings(n):
 cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
 cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
 
 # modified by [email protected]
-def QUANT_STATE(W):
-    return getattr(W, "quant_state", None)
-pass
+def QUANT_STATE(W, base_layer):
+
+    # if the weights has quant_state just take it from there
+    if hasattr(W, 'quant_state'):
+        return W.quant_state
+
+    # otherwise fall back to checking if it is on the base layer
+    # This is needed when FSDP shards the parameters, and destroys the original
+    # weight matrix, so we can get the quant state back
+    return getattr(base_layer, 'quant_state', None)
+pass
 
 # modified by [email protected]
 def get_lora_parameters(proj):
     # For DPO or disabled adapters
     base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
     W = base_layer.weight
 
     if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
-        return W, QUANT_STATE(W), None, None, None
+        return W, QUANT_STATE(W, base_layer), None, None, None
     pass
 
     active_adapter = proj.active_adapters[0] if \
         hasattr(proj, "active_adapters") else proj.active_adapter
     A = proj.lora_A [active_adapter].weight
     B = proj.lora_B [active_adapter].weight
     s = proj.scaling[active_adapter]
-    return W, QUANT_STATE(W), A, B, s
+    return W, QUANT_STATE(W, base_layer), A, B, s
 pass


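The QUANT_STATE change above relies on the quant state being recoverable from the base layer once FSDP has flattened and re-sharded the original 4-bit weight. A minimal sketch of how such a fallback attribute could be populated before accelerator.prepare(): the helper name cache_quant_states and the exact call site are assumptions; only the base_layer.quant_state attribute read via getattr in the diff is taken from the code above.

# Hedged sketch: stash each Linear4bit base layer's quant_state on the module
# itself before FSDP wraps the model, so QUANT_STATE(W, base_layer) can still
# find it after the original Params4bit weight has been sharded away.
import bitsandbytes as bnb

def cache_quant_states(model):  # hypothetical helper, not part of this PR
    for module in model.modules():
        if isinstance(module, bnb.nn.Linear4bit) and hasattr(module.weight, "quant_state"):
            # this is the attribute the patched QUANT_STATE falls back to
            module.quant_state = module.weight.quant_state

# continuing the sketch from the first hunk above
cache_quant_states(model)           # run on the still-unwrapped model
model = accelerator.prepare(model)  # afterwards W.quant_state may be gone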