Conversation

@dbaranchuk (Contributor)

TODO #1: double check that both memory_efficient_backward and has_fp16_weights options work properly.
TODO #2: make a clearer PR description.

This PR provides two features:

  1. A memory-efficient backward option:
  • Stores int8 weights in row-major format.
  • In the forward pass, the row-major matrix has to be cast to the Turing/Ampere layout at every iteration. This adds noticeable computational overhead for inference, so we suggest using this option only for training.
  • In the backward pass, we transform the row-major weight matrix into fp16 weights and perform an fp16 matmul with the fp16 grad_outputs. Note that we do not store the fp16 weights but compute them efficiently on the fly, so the overall computational and memory overheads are negligible (see the sketch after this list).
  2. Casting the inputs and outputs of the 8-bit layer to fp16 and back to the original input dtype, respectively. This allows us to use models of arbitrary dtypes without converting them to fp16.
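
A rough sketch of the backward idea, for illustration only (the CB/SCB names follow the snippets quoted further down this thread; this is not the PR's exact code):

```python
import torch

def int8_backward_sketch(grad_output: torch.Tensor,
                         CB: torch.Tensor,        # int8 weights, row-major, shape [out, in]
                         SCB: torch.Tensor,       # per-row absmax quantization scales
                         grad_shape: torch.Size) -> torch.Tensor:
    # Rebuild fp16 weights on the fly from the stored int8 matrix and its
    # row-wise scales; no persistent fp16 copy is kept, so the extra memory
    # cost is only transient.
    W_fp16 = CB.to(torch.float16) * (SCB.unsqueeze(1) / 127.0).to(torch.float16)
    # fp16 matmul with the fp16 grad_output gives the input gradient.
    grad_A = torch.mm(grad_output.to(torch.float16), W_fp16).view(grad_shape)
    return grad_A
```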

@TimDettmers (Collaborator) left a comment:

Looks good to me overall. Please also include a test that checks for a minimal difference between fp16 gradients and int8->fp16 gradients. You can either add it to the existing tests or create a new one.
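
For reference, a standalone sketch of such a check; it round-trips the weights through row-wise absmax int8 by hand rather than calling the bitsandbytes layer, and the shapes, thresholds and device (CUDA, fp16) are illustrative only:

```python
import torch

def test_int8_vs_fp16_grads():
    # Compare the input gradient of a plain fp16 matmul against the gradient
    # obtained when the weight goes through an int8 quantize/dequantize round trip.
    torch.manual_seed(0)
    x = torch.randn(32, 64, dtype=torch.float16, device="cuda", requires_grad=True)
    W = torch.randn(128, 64, dtype=torch.float16, device="cuda")
    grad_out = torch.randn(32, 128, dtype=torch.float16, device="cuda")

    # reference: fp16 gradient
    (x @ W.t() * grad_out).sum().backward()
    grad_ref = x.grad.clone()
    x.grad = None

    # int8 round trip of the weight (absmax quantization per output row)
    scale = W.abs().amax(dim=1, keepdim=True) / 127.0
    W_deq = torch.clamp((W / scale).round(), -127, 127).to(torch.int8).to(torch.float16) * scale

    (x @ W_deq.t() * grad_out).sum().backward()
    grad_q = x.grad

    atol = (0.05 * grad_ref.abs().mean()).item()
    torch.testing.assert_allclose(grad_q, grad_ref, rtol=0, atol=atol)
```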

Comment on lines 230 to 232
# Cast A to fp16
A_dtype = A.dtype
A = A.to(torch.float16)
Collaborator:

This is exactly what we talked about; this should take care of bfloat16! One thing we need to think about is some kind of warning when this is first executed, to make people aware that a cast happens and that the quantization operation is performed in fp16.

Additionally, it would be great to have bf16 tests that verify everything works correctly with those inputs. I think you just need to change one line (and check that everything else is executed correctly).
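
A minimal sketch of the requested warn-on-cast behaviour (the function name and message are illustrative, not the PR's actual code; with Python's default warning filter, an identical warning from the same location is only shown once):

```python
import warnings
import torch

def maybe_cast_to_fp16(A: torch.Tensor) -> torch.Tensor:
    # Warn when the input is silently cast, so users know the 8-bit matmul
    # and its quantization run in fp16.
    if A.dtype != torch.float16:
        warnings.warn(f"Input of dtype {A.dtype} will be cast to float16; "
                      "the 8-bit matmul is performed in fp16.")
        return A.to(torch.float16)
    return A
```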

@justheuristic (Contributor) commented on Sep 17, 2022:

  • added a warnings.warn the first time A is cast to a different dtype
  • added FP16 tests; fine-tuned the typecasts to fit almost all thresholds
    • increased atol from 0.0175 to 0.02 in one specific case for BF16 (this does not affect FP16)
    • e.g. grad_bias is now computed on the original grad_output before it is cast (and before it loses precision); see the sketch after this list
  • memory_efficient_backward will set Linear8bitLt.weight.requires_grad = False by default
  • added a test for the backward pass with memory_efficient_backward
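
As a sketch of that grad_bias ordering (illustrative names, not the PR's exact code):

```python
import torch

def bias_grad_then_cast(grad_output: torch.Tensor, bias_dtype: torch.dtype):
    # Reduce grad_output for the bias gradient while it is still in its
    # original dtype, so nothing is lost to the fp16 cast ...
    grad_bias = grad_output.sum(0, dtype=bias_dtype)
    # ... and only then cast grad_output to fp16 for the weight/input matmuls.
    grad_output_fp16 = grad_output.to(torch.float16)
    return grad_bias, grad_output_fp16
```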

output = output.to(A.dtype).add_(bias)

# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
@justheuristic (Contributor) commented on Sep 17, 2022:

Curiously, when I tried to replace the next line from

output += torch.matmul(subA, state.subB)

to

output.addmm_(subA, state.subB)

the precision would drop and the tests would fail.
I have no idea why; the dtypes of output, subA and subB are always equal (tested).
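
For what it's worth, here is a standalone way to compare the two variants against an fp32 reference (random data, illustrative shapes, assumes a CUDA device; this is not the PR's test):

```python
import torch

torch.manual_seed(0)
out = torch.randn(256, 1024, dtype=torch.float16, device="cuda")
subA = torch.randn(256, 32, dtype=torch.float16, device="cuda")
subB = torch.randn(32, 1024, dtype=torch.float16, device="cuda")

# fp32 reference for the accumulation
ref = out.float() + subA.float() @ subB.float()

# variant 1: separate fp16 matmul followed by in-place add
out_matmul = out.clone()
out_matmul += torch.matmul(subA, subB)

# variant 2: fused in-place addmm_
out_addmm = out.clone()
out_addmm.addmm_(subA, subB)

print("matmul+add max err:", (out_matmul.float() - ref).abs().max().item())
print("addmm_     max err:", (out_addmm.float() - ref).abs().max().item())
```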

Collaborator:

I cannot remember if I stumbled upon the same thing. I remember trying to make this matrix multiplication more efficient but failed. What is the increase that you see in errors?

It does not make much sense to me, since in cuBLAS you perform (A @ B) + D = C and the result of A @ B is in fp32, so the entire operation should be more precise. The same goes for fused multiply-add in general, which is more precise than a multiplication followed by an addition. It might be some weird tensor core issue, but it makes no sense to me.

If the error is only smaller some of the time and it has more variance, it would still be okay to have this. I believe it would be a good chunk faster.

@justheuristic (Contributor)

Ran all tests 10 times to check for stability

[screenshots: test results from the 10 runs]

@TimDettmers (Collaborator) left a comment:

Looks all good to me. Let's discuss this briefly tomorrow. I am curious if we can get the .addmm_ to work. Otherwise, just a couple of questions on the test performance. Overall great work! Thank you so much, Yozh!

SCB = (state.SCB.unsqueeze(1) / 127.0).half()
CB *= SCB
grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0))
Collaborator:

Not sure how PyTorch implements div, but multiplication is about 30x faster than division. Since we apply it over a matrix this might make a tiny but significant difference. So .mul(1/127.0) might be better here.
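
For example, a sketch of the suggested form (CB/SCB names follow the quoted line above; illustrative only, not the exact change):

```python
import torch

def dequantize_cb(CB: torch.Tensor, SCB: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Scale by the reciprocal of 127 with a multiply rather than a divide,
    # since elementwise multiplication is considerably cheaper than division.
    return CB.to(dtype, copy=True).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0))
```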

Contributor:

applied, thanks

(o1 * grad_proj).sum().backward()
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
scale = grad_ref.abs().mean()
assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
Collaborator:

I remember I had some tests that used the relative difference normalized by the standard deviation, which is similar to this. What is the range of errors that you see? It might also be good to test for a maximum of k elements that exceed a threshold. This helps to differentiate worst-case vs. general performance.

Contributor:

Got it, I've added a separate assert for the outliers:

        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005


@TimDettmers TimDettmers marked this pull request as ready for review September 20, 2022 04:08
@TimDettmers (Collaborator)

All good! Thank you both!

@TimDettmers TimDettmers merged commit 439f2b0 into bitsandbytes-foundation:main Sep 20, 2022