Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
283 commits
Select commit Hold shift + click to select a range
1c044da
Fix pad token
danielhanchen Oct 28, 2024
5286f19
Update llama.py
danielhanchen Oct 28, 2024
02437a8
Typo
danielhanchen Oct 28, 2024
9d07be0
ignored labels
danielhanchen Oct 28, 2024
a8b37a3
Revert "ignored labels"
danielhanchen Oct 28, 2024
2dfdba3
More patching
danielhanchen Oct 28, 2024
5541ab4
Update _utils.py
danielhanchen Oct 28, 2024
c6e9af2
Update _utils.py
danielhanchen Oct 28, 2024
cac56d1
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
5ee1189
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
85a5f60
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
20e38ed
Feat/all tmp (#1219)
danielhanchen Oct 30, 2024
7e1692a
Bug fixes
danielhanchen Oct 30, 2024
6bef8f1
Update pyproject.toml
danielhanchen Oct 30, 2024
9ccbc0e
Update _utils.py
danielhanchen Oct 30, 2024
95ecc57
Update __init__.py
danielhanchen Oct 30, 2024
5f5fef8
Update __init__.py
danielhanchen Oct 30, 2024
784dd13
Update _utils.py
danielhanchen Oct 30, 2024
5b75e21
Update _utils.py
danielhanchen Oct 30, 2024
74ab93c
Update _utils.py
danielhanchen Oct 30, 2024
526505c
Update _utils.py
danielhanchen Oct 30, 2024
251ba77
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
530c495
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
07394c3
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
6d7004b
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
d86b20a
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9920950
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9f926ce
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
30cdf65
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
54b901b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
6db9d28
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8aefcd0
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
7bf626b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
d455751
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
055eeb8
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8090b7c
Tied weights
danielhanchen Oct 31, 2024
7559efb
Revert "Tied weights"
danielhanchen Oct 31, 2024
ad63a32
Tied weights
danielhanchen Oct 31, 2024
35aa992
Utils
danielhanchen Nov 3, 2024
0172ee3
CE Loss patching
danielhanchen Nov 3, 2024
c228682
Update __init__.py
danielhanchen Nov 3, 2024
9aa221a
Update __init__.py
danielhanchen Nov 3, 2024
751413e
Patching
danielhanchen Nov 3, 2024
82db087
Update cross_entropy_loss.py
danielhanchen Nov 3, 2024
cf68202
CE Loss
danielhanchen Nov 3, 2024
63a1828
Update _utils.py
danielhanchen Nov 3, 2024
3f0e56f
Update _utils.py
danielhanchen Nov 3, 2024
1190ed4
CE Loss
danielhanchen Nov 3, 2024
607ac34
Update _utils.py
danielhanchen Nov 3, 2024
32eac0b
Update _utils.py
danielhanchen Nov 3, 2024
5b6d401
Layernorm
danielhanchen Nov 4, 2024
3d19a71
Update _utils.py
danielhanchen Nov 4, 2024
76da511
Update _utils.py
danielhanchen Nov 4, 2024
013ebaa
Post patch
danielhanchen Nov 4, 2024
608916a
Update _utils.py
danielhanchen Nov 4, 2024
19836e3
Update llama.py
danielhanchen Nov 4, 2024
0164087
Update _utils.py
danielhanchen Nov 4, 2024
205f7ad
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2f1f393
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
05b8f66
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8d205c0
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
a1e9e13
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
94655f8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
085f998
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c796fd9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
e943d77
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
16a7df6
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f65b064
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
1ff49b8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
080e558
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f6d50c7
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
fad4202
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
736b16a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
eb76416
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
367e43f
typing
danielhanchen Nov 4, 2024
993df20
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8f566b3
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
22bb46b
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
b5c9f81
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c7b2220
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2d0ab26
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
428f662
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5023ce9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5ca3d4a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
3b32d81
int64
danielhanchen Nov 4, 2024
9bae6e2
Update _utils.py
danielhanchen Nov 4, 2024
5123623
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b1d9e2
constexpr
danielhanchen Nov 4, 2024
7d5111a
constexpr
danielhanchen Nov 4, 2024
dff5a52
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
969d1bd
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b5847f
Update _utils.py
danielhanchen Nov 4, 2024
766bf1e
Update _utils.py
danielhanchen Nov 4, 2024
646f1b7
Update _utils.py
danielhanchen Nov 5, 2024
97f37ac
CE
danielhanchen Nov 5, 2024
cc563fa
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
f643148
Update _utils.py
danielhanchen Nov 5, 2024
f28d7f6
Update llama.py
danielhanchen Nov 5, 2024
d8103e1
Update _utils.py
danielhanchen Nov 5, 2024
b9e1a49
Update rms_layernorm.py
danielhanchen Nov 5, 2024
56af302
Update rms_layernorm.py
danielhanchen Nov 5, 2024
a3c84a3
Update rms_layernorm.py
danielhanchen Nov 5, 2024
f7d5c56
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8496ff6
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2909eaf
Update rms_layernorm.py
danielhanchen Nov 5, 2024
afc8af6
Update utils.py
danielhanchen Nov 5, 2024
2d8d1e1
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ecc1ad2
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ae7cb78
Update rms_layernorm.py
danielhanchen Nov 5, 2024
22da266
Update rms_layernorm.py
danielhanchen Nov 5, 2024
beb6854
Update rms_layernorm.py
danielhanchen Nov 5, 2024
14c3d2f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef4b079
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef684f8
Update rms_layernorm.py
danielhanchen Nov 5, 2024
3e4c42f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8f825eb
Update rms_layernorm.py
danielhanchen Nov 5, 2024
bd4ac7b
Update rms_layernorm.py
danielhanchen Nov 5, 2024
6f38731
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2df35d4
typing
danielhanchen Nov 5, 2024
74d89d1
Update rope_embedding.py
danielhanchen Nov 5, 2024
98927ee
types
danielhanchen Nov 5, 2024
f3e2bd6
Disable compiling
danielhanchen Nov 5, 2024
c30bd2a
Update _utils.py
danielhanchen Nov 5, 2024
813cbdd
Update _utils.py
danielhanchen Nov 5, 2024
34ce5d1
Forward hook
danielhanchen Nov 5, 2024
f84cf4b
Update _utils.py
danielhanchen Nov 5, 2024
745814c
Update llama.py
danielhanchen Nov 5, 2024
ab9f8e1
Update _utils.py
danielhanchen Nov 5, 2024
daa7909
Update llama.py
danielhanchen Nov 5, 2024
536a1a6
Update llama.py
danielhanchen Nov 5, 2024
648ca59
Update _utils.py
danielhanchen Nov 5, 2024
486d0d6
Update pyproject.toml
danielhanchen Nov 5, 2024
eb4da9d
Update _utils.py
danielhanchen Nov 5, 2024
da397f4
Update llama.py
danielhanchen Nov 5, 2024
70b65cf
CE Loss
danielhanchen Nov 5, 2024
aeec57e
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
fb393fc
Update _utils.py
danielhanchen Nov 5, 2024
cab1e72
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
51fea97
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
58e541b
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
0ed0532
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
ef2c56f
Update llama.py
danielhanchen Nov 6, 2024
24ab0d2
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
13d7412
Update _utils.py
danielhanchen Nov 6, 2024
5a7eaf8
Update _utils.py
danielhanchen Nov 6, 2024
d2186ed
Update _utils.py
danielhanchen Nov 6, 2024
6434447
Update _utils.py
danielhanchen Nov 6, 2024
67611e6
Update _utils.py
danielhanchen Nov 6, 2024
36c5836
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
f24aef5
Fix: cast logits to float32 in cross_entropy_forward to prevent error…
Erland366 Nov 6, 2024
3d906e6
Throw error when inferencing longer than max_position_embeddings (#1…
Datta0 Nov 6, 2024
de1049b
CLI now handles user input strings for dtype correctly (#1235)
Rabbidon Nov 6, 2024
be72975
Update flex_attention.py
danielhanchen Nov 6, 2024
05170cd
Update _utils.py
danielhanchen Nov 6, 2024
7e0877d
Update _utils.py
danielhanchen Nov 6, 2024
6b5c599
Update flex_attention.py
danielhanchen Nov 6, 2024
1ba9f2e
Update flex_attention.py
danielhanchen Nov 6, 2024
da61c4d
Update loader.py
danielhanchen Nov 6, 2024
3316ee2
Update loader.py
danielhanchen Nov 6, 2024
501ca84
Update flex_attention.py
danielhanchen Nov 6, 2024
ce621b7
Update flex_attention.py
danielhanchen Nov 6, 2024
4b01ff1
Update flex_attention.py
danielhanchen Nov 6, 2024
ef5052a
Update flex_attention.py
danielhanchen Nov 7, 2024
52bca32
Update _utils.py
danielhanchen Nov 7, 2024
68b8d62
Merge branch 'main' into nightly
danielhanchen Nov 7, 2024
15da065
Merge branch 'main' into nightly
danielhanchen Nov 7, 2024
8b3e9c2
Update cross_entropy_loss.py
danielhanchen Nov 7, 2024
3a1e7ef
Update _utils.py
danielhanchen Nov 7, 2024
f1ec165
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
a4e9705
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
92c6a27
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
673f541
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
8fe9109
Update tokenizer_utils.py
danielhanchen Nov 11, 2024
ad41479
triton_cast
danielhanchen Nov 11, 2024
fcf2009
Update utils.py
danielhanchen Nov 11, 2024
af9ba07
Qwen 2.5 Coder
danielhanchen Nov 12, 2024
e99acdd
Merge branch 'main' into nightly
danielhanchen Nov 13, 2024
3fec577
Fix/export mistral (#1281)
Erland366 Nov 13, 2024
03c6243
DOC Update - Update README.md with os.environ in example (#1269)
udaygirish Nov 13, 2024
10565ef
fix/get_chat_template (#1246)
Erland366 Nov 13, 2024
dc0232c
fix/sft-trainer (#1276)
Erland366 Nov 14, 2024
84d6d36
Update __init__.py
danielhanchen Nov 14, 2024
a31027c
Update trainer.py
danielhanchen Nov 14, 2024
035bcce
Update trainer.py
danielhanchen Nov 14, 2024
597169c
Update trainer.py
danielhanchen Nov 14, 2024
11b350f
Update tokenizer_utils.py
danielhanchen Nov 14, 2024
e4d1754
Merge branch 'main' into nightly
danielhanchen Nov 14, 2024
3b11ae7
Update llama.py
danielhanchen Nov 14, 2024
5eb971f
Fix #853
danielhanchen Nov 14, 2024
a146521
fix/sfttrainer-compatibility (#1293)
Erland366 Nov 15, 2024
74382de
Update rms_layernorm.py
danielhanchen Nov 16, 2024
a6b8dda
Update rms_layernorm.py
danielhanchen Nov 16, 2024
82e4466
Gemma
danielhanchen Nov 16, 2024
50b0aba
Update rms_layernorm.py
danielhanchen Nov 16, 2024
9773fee
Update gemma2.py
danielhanchen Nov 16, 2024
1a3d2d5
Cut Cross Entropy
danielhanchen Nov 17, 2024
4f51d87
Update llama.py
danielhanchen Nov 17, 2024
b18edb9
Cut Cross Entropy
danielhanchen Nov 17, 2024
0a5c519
Update llama.py
danielhanchen Nov 17, 2024
59caca9
Update llama.py
danielhanchen Nov 17, 2024
49df51f
Update llama.py
danielhanchen Nov 18, 2024
cc314c8
Update __init__.py
danielhanchen Nov 18, 2024
42a76f1
Update __init__.py
danielhanchen Nov 18, 2024
4ed6ae8
Update _utils.py
danielhanchen Nov 18, 2024
2fade27
Update _utils.py
danielhanchen Nov 18, 2024
07ee0da
Update _utils.py
danielhanchen Nov 18, 2024
8eae7f9
Update _utils.py
danielhanchen Nov 18, 2024
6ab1d3a
Update _utils.py
danielhanchen Nov 18, 2024
d5c1c17
Update _utils.py
danielhanchen Nov 18, 2024
4abf3de
Update _utils.py
danielhanchen Nov 18, 2024
b144ff4
Update _utils.py
danielhanchen Nov 18, 2024
b9b7a5b
Update mapper.py
danielhanchen Nov 18, 2024
9f93c49
Update _utils.py
danielhanchen Nov 19, 2024
d00dc52
Update _utils.py
danielhanchen Nov 19, 2024
caf4cd4
Update _utils.py
danielhanchen Nov 19, 2024
4cd14bb
Update _utils.py
danielhanchen Nov 19, 2024
a0e709b
Update _utils.py
danielhanchen Nov 19, 2024
f92c16d
Update _utils.py
danielhanchen Nov 19, 2024
c7c984f
Update _utils.py
danielhanchen Nov 19, 2024
81538c3
Update _utils.py
danielhanchen Nov 19, 2024
029f5d5
Update _utils.py
danielhanchen Nov 20, 2024
bd1a175
patch_fast_lora
danielhanchen Nov 20, 2024
cabf21f
vision
danielhanchen Nov 20, 2024
7d5c9ed
Update fast_lora.py
danielhanchen Nov 21, 2024
4ddd1bb
Update _utils.py
danielhanchen Nov 21, 2024
1c94f04
Update _utils.py
danielhanchen Nov 21, 2024
d6ccbfb
Vision
danielhanchen Nov 21, 2024
8a44b6c
Update trainer.py
danielhanchen Nov 21, 2024
f077680
Merge branch 'main' into nightly
danielhanchen Nov 21, 2024
d5b8408
Update save.py
danielhanchen Nov 21, 2024
a5d4084
FastBaseVisionModel
danielhanchen Nov 21, 2024
7f5a9a7
Update loader_utils.py
danielhanchen Nov 21, 2024
d160618
Update vision.py
danielhanchen Nov 21, 2024
2420736
Update loader.py
danielhanchen Nov 21, 2024
0747078
Update vision.py
danielhanchen Nov 21, 2024
1f32b23
Update loader.py
danielhanchen Nov 21, 2024
a45e564
Update vision.py
danielhanchen Nov 21, 2024
767a31f
Update _utils.py
danielhanchen Nov 21, 2024
1ad1b46
tokenizer_name
danielhanchen Nov 21, 2024
26f2337
Update loader.py
danielhanchen Nov 21, 2024
5ab4b60
Update vision.py
danielhanchen Nov 21, 2024
fc7d747
Update save.py
danielhanchen Nov 21, 2024
e0b14fa
Update save.py
danielhanchen Nov 21, 2024
677cf9f
Update vision.py
danielhanchen Nov 21, 2024
8ab5dcb
Update vision.py
danielhanchen Nov 21, 2024
adaf6ee
Update vision.py
danielhanchen Nov 21, 2024
1a548f3
Update vision.py
danielhanchen Nov 21, 2024
5886ecb
Update vision.py
danielhanchen Nov 21, 2024
a98fc9c
Update vision.py
danielhanchen Nov 21, 2024
535e899
Update _utils.py
danielhanchen Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
pass

# Reduce VRAM usage by reducing fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]"

# Hugging Face Hub faster downloads
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
Expand Down
1 change: 1 addition & 0 deletions unsloth/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
apply_lora_mlp_geglu_approx,
apply_lora_qkv,
apply_lora_o,
fast_lora_forward,
)
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora

Expand Down
78 changes: 78 additions & 0 deletions unsloth/kernels/fast_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,81 @@ def apply_lora_o(self, X):
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
pass


IDENTITY_DROPOUT = torch.nn.Identity
@torch._disable_dynamo
def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError(
"Unsloth: Currently not supported yet - reshaping done incorrectly"
)
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
# Fastpath
if len(self.active_adapters) == 1:
active_adapter = self.active_adapters[0]
if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)

dropout = self.lora_dropout[active_adapter]
if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
lora_A = self.lora_A[active_adapter].weight
lora_B = self.lora_B[active_adapter].weight
scaling = self.scaling[active_adapter]
W = self.base_layer.weight
return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
pass
pass

result = self.base_layer(x, *args, **kwargs)
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
# The reason is that in some cases, an error can occur that backprop
# does not work on a manipulated view. This issue may be solved with
# newer PyTorch versions but this would need extensive testing to be
# sure.
result = result.clone()

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
if isinstance(dropout, torch.nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)
if requires_conversion:
result = result.to(expected_dtype)

return result
pass
10 changes: 8 additions & 2 deletions unsloth/kernels/rms_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _rms_layernorm_forward(
@triton.jit
def _rms_layernorm_backward(
dY, dY_row_stride,
dX, dX_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
Expand All @@ -78,6 +79,9 @@ def _rms_layernorm_backward(
X += row_idx * X_row_stride
r += row_idx * r_row_stride

if GEMMA: dX += row_idx * dY_row_stride
else: dX = dY

dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
Expand All @@ -91,7 +95,7 @@ def _rms_layernorm_backward(

rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
tl.store(dY + col_offsets, output, mask = mask)
tl.store(dX + col_offsets, output, mask = mask)
pass


Expand Down Expand Up @@ -172,9 +176,11 @@ def backward(ctx, dY : torch.Tensor):
n_cols : int
n_rows, n_cols = dY.shape
# dW = X
dX = torch.empty_like(dY, device = "cuda:0") if ctx.GEMMA else dY

_rms_layernorm_backward[(n_rows,)](
dY, dY.stride(0),
dX, dX.stride(0),
X, X .stride(0),
W, W .stride(0),
r, r .stride(0),
Expand All @@ -184,7 +190,7 @@ def backward(ctx, dY : torch.Tensor):
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dX = dY.view(*shape)
dX = dX.view(*shape)
return dX, None, None, None
pass
pass
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .loader import FastLanguageModel
from .loader import FastLanguageModel, FastVisionModel
from .llama import FastLlamaModel
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
Expand Down
Loading