Implement backward passes for llama with small training llama from scratch example #1360
Commits on May 1, 2023
- 73ac18d: implement 8 of 14 missing backward pass operations used by llama
  Implemented: GGML_OP_ADD_AT, GGML_OP_CPY, GGML_OP_MUL_MAT (src0.grad), GGML_OP_PERMUTE, GGML_OP_RESHAPE, GGML_OP_SCALE, GGML_OP_TRANSPOSE, GGML_OP_VIEW.
  Adds the new operation GGML_OP_ADD_AT, which is necessary for the backward pass of GGML_OP_VIEW: it adds src1 to src0 at a data offset, i.e. to view(src0, ..., offset). The result is a tensor with the size of src0; values outside [data+offset : data+offset+nbytes(src1)] are simply the original values from src0.
  Backward passes still missing for llama: GGML_OP_DIAG_MASK_INF, GGML_OP_GET_ROWS, GGML_OP_RMS_NORM, GGML_OP_ROPE, GGML_OP_SILU, GGML_OP_SOFT_MAX.
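A minimal sketch of these GGML_OP_ADD_AT semantics for the contiguous 1-D float case (illustrative names and layout, not the actual ggml implementation):

```c
#include <stdint.h>
#include <string.h>

// dst keeps src0's values everywhere; src1 is added only into the region
// starting at the byte offset, i.e. into view(src0, ..., offset).
void add_at_f32(float * dst, const float * src0, int64_t n0,
                const float * src1, int64_t n1, size_t offset_bytes) {
    const int64_t start = offset_bytes / sizeof(float);
    memcpy(dst, src0, n0 * sizeof(float));  // values outside the view stay as src0
    for (int64_t i = 0; i < n1; ++i) {
        dst[start + i] += src1[i];          // accumulate src1 inside the view
    }
}
```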
- b164343: implement 5 of 6 missing backward pass operations used by llama
  Implemented: GGML_OP_DIAG_MASK_INF, GGML_OP_GET_ROWS, GGML_OP_RMS_NORM, GGML_OP_SILU, GGML_OP_SOFT_MAX.
  Adds the necessary operations GGML_OP_ADD1, GGML_OP_SILU_BACK, GGML_OP_RMS_NORM_BACK, GGML_OP_DIAG_MASK_ZERO, and GGML_OP_ROPE_BACK.
  GGML_OP_ADD1 is necessary to add a scalar value in the backward pass of GGML_OP_SOFT_MAX. It could be replaced by GGML_OP_ADD plus GGML_OP_REPEAT, but performance would be worse. Additionally, GGML_OP_REPEAT returns an unexpected value when the input to GGML_OP_SOFT_MAX contains only a single scalar: in that case GGML_OP_REPEAT returns not the value that should be repeated (src1) but the value whose shape the result should take (src0), so it cannot replace GGML_OP_ADD1 there.
  GGML_OP_SILU_BACK, GGML_OP_RMS_NORM_BACK, and GGML_OP_ROPE_BACK are necessary for the backward passes of GGML_OP_SILU, GGML_OP_RMS_NORM, and GGML_OP_ROPE, which cannot easily be composed from existing operations. Since the backward pass itself builds a computation graph, we need forward-pass implementations of the required backward operations; sounds a bit confusing at first, I know.
  GGML_OP_DIAG_MASK_ZERO is necessary for the backward pass of GGML_OP_DIAG_MASK_INF.
  Some operations were previously inplace-only; the backward pass needs non-inplace variants. Staying consistent with other operations that have both, these operations are changed to non-inplace, and "_inplace" functions are added for the inplace behavior. llama calls the inplace variants so it behaves as before, while the llama backward pass uses the non-inplace variants.
  Backward passes still incomplete for llama: GGML_OP_ROPE (needs a forward pass for GGML_OP_ROPE_BACK) and GGML_OP_GET_ROWS (only necessary for the tokenizer).
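As an example of why the backward pass needs its own forward operation (standard calculus, restated here; the arrangement inside ggml may differ): the gradient of SILU is itself an elementwise function that must be evaluated in the forward direction of the backward graph. With $\sigma$ the logistic sigmoid:

$$
\mathrm{silu}(x) = x\,\sigma(x), \qquad
\frac{d}{dx}\,\mathrm{silu}(x) = \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr),
$$

so an operation like GGML_OP_SILU_BACK can compute $g \cdot \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr)$ from $x$ and the upstream gradient $g$ in a single pass.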
- b908007: norm & rms_norm can not be threaded
  After investigating rms_norm for quite some time, the conclusion is that neither norm nor rms_norm can be threaded, because we need the mean over all items, not just over the slice each thread sees.
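The obstruction is visible in the definition: every output element depends on a statistic of the entire row, so no thread can finish any element from its slice alone. For rms_norm, with row length $n$ and a small constant $\varepsilon$:

$$
y_i = \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}} .
$$

Per-thread partial sums would require a cross-thread reduction before any division can happen.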
- 36d8a05
- 488decf
- 4e1f81d
- 0da2675
- 20e3c1d
- 9345f4c
- 9d6fc28
- 6fb08b4
- 671e592
- a367eb9: bug fix for scale backward pass
  Use sum instead of mean for the gradient of the scalar scale parameter.
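Why sum is correct (a standard derivation, not quoted from the commit): the scalar $s$ multiplies every element, so its gradient accumulates contributions from all of them. With upstream gradient $g$:

$$
y_i = s\,x_i \quad\Longrightarrow\quad
\frac{\partial L}{\partial s} = \sum_i g_i\,x_i \quad \text{(a sum, not a mean)} .
$$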
- 0197bcb
- bfe5072: improve performance of sum backward pass
  Use add1(x, y) instead of add(x, repeat(y, x)).
- b583136: improve performance of sqr backward pass
  Use scale(x, y) instead of mul(x, repeat(y, x)).
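The identity behind the sum rewrite (standard calculus): with scalar upstream gradient $g = \partial L / \partial y$,

$$
y = \sum_i x_i \quad\Longrightarrow\quad \frac{\partial L}{\partial x_i} = g \ \text{ for all } i,
$$

so the backward pass adds the same scalar to every element of the accumulated gradient, which is exactly what add1 does; materializing repeat(y, x) first only costs memory and bandwidth. The sqr case is analogous, with scale replacing a mul against a repeated scalar.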
- 7571147
- 0ea8201
- b2bd822
- c483a7d
- ecf949b
- 54ab300: use ggml_opt to train a, b for minimal e = sum(sqr(c - a*b)), for random initial a, b, c
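A minimal sketch of such a test, assuming the ggml API of this period (dimensions, the random fill, and return-code handling are illustrative, not the test's actual code):

```c
#include "ggml.h"
#include <stdlib.h>

int main(void) {
    struct ggml_init_params iparams = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(iparams);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);
    struct ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);
    for (int i = 0; i < 8*8; ++i) {   // random initial a, b, c
        ((float *) a->data)[i] = (float) rand() / RAND_MAX;
        ((float *) b->data)[i] = (float) rand() / RAND_MAX;
        ((float *) c->data)[i] = (float) rand() / RAND_MAX;
    }

    ggml_set_param(ctx, a);           // mark a and b as trainable parameters
    ggml_set_param(ctx, b);

    // e = sum(sqr(c - a*b)) is the scalar loss to minimize
    struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx,
                                 ggml_sub(ctx, c, ggml_mul_mat(ctx, a, b))));

    enum ggml_opt_result res = ggml_opt(ctx, ggml_opt_default_params(GGML_OPT_ADAM), e);

    ggml_free(ctx);
    return res == GGML_OPT_OK ? 0 : 1;
}
```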
- 1a80e9a: correctly implement softmax backward pass using new operation ggml_diag
  ggml_diag constructs diagonal matrices from its entries: ggml_diag(shape[a,1,c,d]) -> shape[a,a,c,d].
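With ggml_diag available, the standard softmax Jacobian can be expressed directly as graph operations. For $s = \mathrm{softmax}(x)$ and upstream gradient $g$:

$$
\frac{\partial s_i}{\partial x_j} = s_i(\delta_{ij} - s_j)
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x} = \bigl(\operatorname{diag}(s) - s\,s^{\top}\bigr)\,g .
$$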
- fea42be
- 9310650
- 38675e5
- c1a8893: de-duplicate ggml_forward_dup code taking care of contiguous tensors of the same type
  With this we can duplicate tensors of any type as long as they are contiguous.
- 83fa6b3: fix ggml_compute_forward_dup_same_cont for when nelements < nthreads
  When more threads were used than elements exist, ie1 was less than ie0, resulting in an invalid negative byte count argument to memcpy.
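A self-contained sketch of the failure mode and the kind of clamping that fixes it (names follow the commit message; the real code differs in details):

```c
#include <string.h>

// Copy this thread's share of a contiguous tensor. ith/nth: thread index/count.
static void dup_same_cont_slice(char * dst, const char * src,
                                int ne, size_t type_size, int ith, int nth) {
    const int dr  = (ne + nth - 1) / nth;           // elements per thread, rounded up
    const int ie0 = dr * ith;                       // first element for this thread
    const int ie1 = ie0 + dr < ne ? ie0 + dr : ne;  // clamp: for the excess threads
                                                    // when ne < nth, ie0 can be >= ne
    if (ie0 < ie1) {                                // guard: unclamped, (ie1 - ie0) < 0
                                                    // underflows to a huge size_t in memcpy
        memcpy(dst + ie0 * type_size, src + ie0 * type_size, (ie1 - ie0) * type_size);
    }
}
```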
- cecd6c7: required for the view backward pass: src0 values must be copied to dst, because during the addition we don't touch all dst elements, in contrast to the normal add function.
- 124fdca
- 410a47a
- b9416d7: fix ggml_forward_add functions to work correctly with transposed tensors
  Uses the same logic as in ggml_compute_forward_add_q_f32, made consistent across all ggml_compute_forward_add_... functions. This also slightly changes the memory access pattern of the different threads to work as in ggml_compute_forward_add_q_f32.
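A sketch of the row-based threading pattern being described, with ggml-style names (illustrative, not the actual function): each thread takes a contiguous range of rows, and every access goes through byte strides, so a transposed src1 is read correctly element by element:

```c
#include <stddef.h>
#include <stdint.h>

// Add src1 to src0 into dst. ne[] holds element counts; nb/nb0/nb1 hold the
// byte strides per dimension for dst/src0/src1. ith/nth: thread index/count.
static void add_f32_sketch(char * dst, const int64_t * ne, const size_t * nb,
                           const char * src0, const size_t * nb0,
                           const char * src1, const size_t * nb1,
                           int ith, int nth) {
    const int64_t nr  = ne[1]*ne[2]*ne[3];          // total number of rows
    const int64_t dr  = (nr + nth - 1)/nth;         // rows per thread
    const int64_t ir0 = dr*ith;                     // this thread's row range
    const int64_t ir1 = ir0 + dr < nr ? ir0 + dr : nr;

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        // unflatten the row index into (i1, i2, i3)
        const int64_t i3 = ir/(ne[2]*ne[1]);
        const int64_t i2 = (ir - i3*ne[2]*ne[1])/ne[1];
        const int64_t i1 = ir - i3*ne[2]*ne[1] - i2*ne[1];

        float       * d  = (float *)      (dst  + i1*nb[1]  + i2*nb[2]  + i3*nb[3]);
        const float * s0 = (const float *)(src0 + i1*nb0[1] + i2*nb0[2] + i3*nb0[3]);

        for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
            // src1 is addressed per element through its own strides, so a
            // transposed (non-contiguous) src1 still reads the right value
            const float * s1 = (const float *)(src1 + i0*nb1[0] + i1*nb1[1]
                                                    + i2*nb1[2] + i3*nb1[3]);
            d[i0] = s0[i0] + *s1;
        }
    }
}
```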
- 339b2ad: fix ggml_forward_add1 functions to work correctly with transposed tensors
  Uses the same logic as in ggml_compute_forward_add1_q_f32, made consistent across all ggml_compute_forward_add1_... functions. This also slightly changes the memory access pattern of the different threads to work as in ggml_compute_forward_add1_q_f32.
- 86b44a0
- a7a8370
- b0555fc
- 02d3fd0: fix sub, mul and div functions to work correctly with transposed tensors
  Uses the same logic as in add.
- 3d21f26
- c601df9: successfully test transpose backward and permute for all permutations
  Also tests sub, mul and div up to max n_dims.
- 1997152: test-grad0.c: add TODO for view_2d and view_3d
  add_at (required for the view backward pass) is a bit tricky for n_dims > 1.
- d42531f
- 19f5159
- b9920e5: nargs and ndims were swapped, corrupting the stack
- 3dbd649
- 7281f60
- 96e773b
- f0302fa
- 8443638: add nb parameters to add_at, like in view
  Together with the offset, they define how dst and src0 are viewed during the add_at operation.
- b18b72d
- 84a4b39: fix backward pass for rms_norm
  I would have used formulas from other frameworks, but they differed from each other, so I could not decide which was correct. Instead, the gradient was derived in a code comment here, using manual forward-backward automatic differentiation of rms_norm and simplification.
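The result of such a derivation, restated here for reference (standard calculus, not quoted from the code comment): with row length $n$, $r = \sqrt{\tfrac{1}{n}\sum_j x_j^2 + \varepsilon}$, $y_i = x_i / r$, and upstream gradient $g$:

$$
\frac{\partial L}{\partial x_k}
= \frac{1}{r}\left( g_k - \frac{x_k}{n\,r^2} \sum_j g_j\, x_j \right),
$$

since $\partial r / \partial x_k = x_k/(n\,r)$, so each $y_i$ contributes both a direct term and a term through the shared denominator.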
- 2ecc690: successfully test backward pass of rms_norm
  Some tests may fail when gradients are large. Could not find a satisfying configuration of absolute-error and relative-error thresholds that passes all tests while still checking the results with tight enough bounds. Looking at the values, the "failed" tests actually look ok. For example:
  rms_norm: ndims=2, i=0, k=2, x0=0.000153, xm=0.000053, xp=0.000253, f0=0.278594, f1=0.086213, g0=961.905457, g1=966.064941, eps=0.000100, error_abs=4.159485, error_rel=0.004324
  They fail only because of the test logic in check_gradients.
- 2277053: add todos for llama backward pass
  - The implementation of the ADD1 backward pass should probably use sum instead of mean (but this backward pass is not required).
  - repeat is not yet tested and looks like it only works for single-element src0 inputs.
- c4539ed: ggml_sum_rows(shape[a,b,c,d]) -> shape[1,b,c,d]
- ba62c79
- 8b5b2f0
- 72bcfb5
- 1c4dc1e
- 8fde656: add baby-llama example training a very small llama model from scratch to output a sinusoidal wave
  Had to increase the maximum number of optimization parameters to train from scratch.
- 29a0f8b
- 5f23052: switching from training with adam to lbfgs produces much better results in the baby-llama example
- bc1c13b
Commits on May 6, 2023
- 83ee1cd: fix bug when using ggml_opt to optimize params in one context while using a renewable context for eval and opt
  When not keeping gradients of the model parameters, they are overwritten by tensors created by opt, which may become invalid after the opt context is renewed. So we need to keep the original gradients and make dups for opt.
- f1d51d1: train on multiple examples, generate & print tokens with the trained model afterwards
  ctx0 for evaluation and optimization is renewed for each sample.
- b4c273f
- 8cf04fe
- 65d9f73
- 5724628
- 7a15a83
- e6186d9
- 80223d9
- 73fd66e: fix training get_example_targets
  Predict the next token, not the current token!
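The off-by-one in a nutshell (a hypothetical helper, not the example's actual code):

```c
// For a token sequence t[0..n-1], the model at position i is trained to
// predict the token that FOLLOWS: target[i] = t[i+1], never t[i] itself.
void get_example_targets_sketch(const int * tokens, int n_tokens, int * targets) {
    for (int i = 0; i + 1 < n_tokens; ++i) {
        targets[i] = tokens[i + 1];  // next token, not the current one
    }
}
```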
- 7a5dec2
- 226521a: optimize loss over multiple samples
  This increases the computation graph; parallel batched forward is needed for more efficiency.
- 48bcc4d
- 47561de: add ggml_set(ctx, a, b) to set b in a view of a and return the modified a
  Necessary to set values into the kv_self cache while properly propagating the gradients.
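A sketch of the intended change in llama.cpp's cache update (the 1-d variant and its exact call shape are assumptions based on the message and ggml naming conventions):

```c
// Before: ggml_cpy writes k_cur into a raw view of the cache, so the gradient
// chain through kv_self.k is broken (only the copy carries gradients):
//   ggml_cpy(ctx0, k_cur, ggml_view_1d(ctx0, kv_self.k, n, offset));
//
// After (assumed API shape): ggml_set returns a tensor that is kv_self.k with
// k_cur written at the offset, so gradients propagate to both operands:
struct ggml_tensor * k = ggml_set_1d(ctx0, kv_self.k, k_cur, offset);
```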
- 956511b: fix kv_self gradients for training
  Use ggml_set instead of ggml_cpy to write the kv_self cache, properly propagating the gradients.
- 561fbe0: replace inplace operations for training with copying operations to allow gradient propagation
- e91b83b
Commits on May 7, 2023
- 93201ab: add trainable lora-only model with all big matrices C split into A,B with A*B=C
  This is not a lora-finetune; the whole model is changed to have only low-rank "lora" matrices. Training this instead of the normal model resulted in much worse results, though...
- 49d6daa: vastly improve training results
  Instead of logit targets 0 and 1, use -1 and +1.
- e0de09d
- 4764842
- ee565f3: Merge branch 'master' into train-example
  Conflicts: ggml.c, llama.cpp
- e643fa1
- d20ba6f
- 5d9fed7
- 47ad186
- 9dd8e40
- 660836f
- 7c8768f
- 2936dd6
- 4997bc5: reduce number of test-grad0 iterations
  Avoids exceeding the timeout of automated tests.
- f530106
Commits on May 8, 2023
- 1ecbece
- dea9c93: use C++ includes instead of C includes; use std::min, std::max instead of MIN, MAX macros
- 0d72207: use C++ includes instead of C includes; use std::min, std::max instead of MIN, MAX macros
- 78af3e9
- 6cc42de
- cafbb78: swap arguments to vDSP_vdiv call
  Documentation for vDSP_vdiv states: "Note that B comes before A!"
- 9c3fe4e: swap arguments to vDSP_vdiv call
  Documentation for vDSP_vdiv states: "Note that B comes before A!"
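The Accelerate call in question computes an elementwise quotient but takes the divisor first; a minimal check:

```c
#include <Accelerate/Accelerate.h>

int main(void) {
    float A[4] = {2, 4, 6, 8};
    float B[4] = {1, 2, 3, 4};
    float C[4];
    // vDSP_vdiv(B, strideB, A, strideA, C, strideC, n) computes C[i] = A[i] / B[i];
    // the divisor B is the FIRST argument, as the commit message warns.
    vDSP_vdiv(B, 1, A, 1, C, 1, 4);  // C = {2, 2, 2, 2}
    return 0;
}
```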
- 6ca682b
Commits on May 11, 2023
- 3e3ed95
- 581e5eb
- b9ef08c
Commits on May 13, 2023
- f977243
- 33034cf
- 092913e
- 95a487a
- ef3d42a
- dae6ba2