
Implement backward passes for llama with small training llama from scratch example #1360

Merged: 110 commits (May 13, 2023)

Commits on May 1, 2023

  1. implement 8 of 14 missing backward pass operations used by llama

    - GGML_OP_ADD_AT
    - GGML_OP_CPY
    - GGML_OP_MUL_MAT (src0.grad)
    - GGML_OP_PERMUTE
    - GGML_OP_RESHAPE
    - GGML_OP_SCALE
    - GGML_OP_TRANSPOSE
    - GGML_OP_VIEW
    
    implement the additional ggml operation GGML_OP_ADD_AT, which is necessary for the backward pass of GGML_OP_VIEW.
    
    this operation adds src1 to src0 at a data offset, i.e. to view(src0, ..., offset).
    the result is returned in a tensor the size of src0. values outside of [data+offset : data+offset+nbytes(src1)] are simply the original values from src0.
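    As a rough illustration of these semantics (plain C over contiguous f32 data; not the actual ggml kernel, and the function name and flat layout are assumptions for this sketch):

    ```c
    #include <string.h>
    #include <stdint.h>

    // dst starts out as a copy of src0; src1 is then added into the region selected
    // by `offset` (in bytes), i.e. into view(dst, ..., offset).
    static void add_at_f32(float * dst, const float * src0, const float * src1,
                           int64_t n0, int64_t n1, size_t offset) {
        memcpy(dst, src0, n0*sizeof(float));             // elements outside the view keep src0 values
        float * view = (float *)((char *)dst + offset);  // the viewed region of dst
        for (int64_t i = 0; i < n1; ++i) {
            view[i] += src1[i];                          // add src1 into the view
        }
    }
    ```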
    
    still missing backward passes for llama:
    
    - GGML_OP_DIAG_MASK_INF
    - GGML_OP_GET_ROWS
    - GGML_OP_RMS_NORM
    - GGML_OP_ROPE
    - GGML_OP_SILU
    - GGML_OP_SOFT_MAX
    xaedes committed May 1, 2023 (73ac18d)
  2. implement 5 of 6 missing backward pass operations used by llama

    - GGML_OP_DIAG_MASK_INF
    - GGML_OP_GET_ROWS
    - GGML_OP_RMS_NORM
    - GGML_OP_SILU
    - GGML_OP_SOFT_MAX
    
    add necessary ggml operations GGML_OP_ADD1, GGML_OP_SILU_BACK, GGML_OP_RMS_NORM_BACK, GGML_OP_DIAG_MASK_ZERO, and GGML_OP_ROPE_BACK
    
    GGML_OP_ADD1 is necessary to add a scalar value in the backward pass of GGML_OP_SOFT_MAX.
    GGML_OP_ADD1 could also be replaced by a combination of GGML_OP_ADD and GGML_OP_REPEAT, but the performance would be worse. additionally, GGML_OP_REPEAT returns an unexpected value when the input to GGML_OP_SOFT_MAX contains only a single scalar: in that case it does not return the value that should be repeated (src1) but the value whose shape the result should take (src0), so it cannot replace GGML_OP_ADD1 there.
    
    GGML_OP_SILU_BACK, GGML_OP_RMS_NORM_BACK and GGML_OP_ROPE_BACK are necessary for the backward passes of GGML_OP_SILU, GGML_OP_RMS_NORM and GGML_OP_ROPE. The backward passes of these functions cannot easily be composed from existing operations. Since the backward pass itself builds a computation graph, we need forward-pass implementations of the required backward operations. Sounds a bit confusing at first, I know...
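    For reference, a minimal sketch of the elementwise math behind GGML_OP_SILU_BACK (illustrative only, not the ggml kernel):

    ```c
    #include <math.h>

    // silu(x) = x*sigmoid(x), so d(silu)/dx = sigmoid(x)*(1 + x*(1 - sigmoid(x))).
    static float silu_back_1(float x, float dy) {
        const float s = 1.0f/(1.0f + expf(-x));   // sigmoid(x)
        return dy*s*(1.0f + x*(1.0f - s));        // upstream gradient times local derivative
    }
    ```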
    
    GGML_OP_DIAG_MASK_ZERO is necessary for backward pass of GGML_OP_DIAG_MASK_INF.
    
    Some operations were previously inplace-only; for the backward pass there need to be non-inplace variants.
    staying consistent with other operations that have non-inplace and inplace variants, these operations are changed to be non-inplace by default, and
    "_inplace" functions are added for the inplace behavior.
    in llama we need to call the inplace variants so that the forward pass works as before.
    for the llama backward pass we need to use the non-inplace variants.
    
    still not completely implemented backward passes for llama:
    
    - GGML_OP_ROPE: needs forward pass for GGML_OP_ROPE_BACK
    - GGML_OP_GET_ROWS: only necessary for tokenizer
    xaedes committed May 1, 2023 (b164343)
  3. norm & rms_norm cannot be threaded

    after investigating rms_norm for quite some time I came to the conclusion that neither norm nor rms_norm can be threaded, because we need the mean over all items, not just over the slices each thread sees.
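    A sketch of the constraint (illustrative only; the function name and eps parameter are assumptions): every output of a row depends on the mean of squares over the whole row, so the per-row reduction cannot be split across threads.

    ```c
    #include <math.h>

    static void rms_norm_row(float * y, const float * x, int n, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            sum += x[i]*x[i];                        // needs ALL elements of the row
        }
        const float scale = 1.0f/sqrtf(sum/n + eps); // 1/rms of the full row
        for (int i = 0; i < n; ++i) {
            y[i] = x[i]*scale;                       // every output uses that global mean
        }
    }
    ```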
    xaedes committed May 1, 2023 (b908007)
  4. remove already resolved TODO

    xaedes committed May 1, 2023 (36d8a05)
  5. 488decf
  6. 4e1f81d
  7. add test-grad0.c

    xaedes committed May 1, 2023 (0da2675)
  8. 20e3c1d
  9. 9345f4c
  10. 9d6fc28
  11. bug fixes for silu_back

    xaedes committed May 1, 2023 (6fb08b4)
  12. 671e592
  13. bug fix for scale backward pass

    use sum instead of mean for gradient of scalar scale parameter
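    For reference, a minimal sketch of the corrected gradient (plain C, illustrative names): with y = scale(x, s), i.e. y[i] = s*x[i], and upstream gradients g[i] = dE/dy[i], the scalar s picks up a contribution from every element, so its gradient is a sum, not a mean.

    ```c
    static float scale_back_s(const float * x, const float * g, int n) {
        float grad_s = 0.0f;
        for (int i = 0; i < n; ++i) {
            grad_s += g[i]*x[i];   // dE/ds = sum_i g[i]*x[i]
        }
        return grad_s;
    }
    ```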
    xaedes committed May 1, 2023 (a367eb9)
  14. 0197bcb
  15. improve performance of sum backward pass

    use add1(x,y) instead of add(x,repeat(y,x))
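    The idea, as a small sketch (illustrative only): for e = sum(x) the gradient of every element is the same scalar g, so add1 can accumulate it directly instead of first materializing repeat(g, x).

    ```c
    static void sum_back(float * grad_x, float g, int n) {
        for (int i = 0; i < n; ++i) {
            grad_x[i] += g;   // same effect as add(grad_x, repeat(g, x)), without the temporary
        }
    }
    ```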
    xaedes committed May 1, 2023 (bfe5072)
  16. improve performance of sqr backward pass

    use scale(x,y) instead of mul(x,repeat(y,x))
    xaedes committed May 1, 2023 (b583136)
  17. 7571147
  18. bug fix for cpy backward pass

    xaedes committed May 1, 2023 (0ea8201)
  19. b2bd822
  20. c483a7d
  21. ecf949b
  22. add test-opt.c

    this uses ggml_opt to train a,b for minimal e=sum(sqr(c - a*b)) for random initial a,b,c
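    The objective being minimized, written out for a single scalar element with plain gradient descent (illustrative only; the actual test uses ggml tensors and ggml_opt):

    ```c
    #include <stdio.h>

    int main(void) {
        float a = 0.3f, b = -0.7f;        // "random" initial parameters
        const float c  = 1.5f;            // target
        const float lr = 0.01f;
        for (int step = 0; step < 2000; ++step) {
            const float r  = c - a*b;     // residual of e = sum(sqr(c - a*b))
            const float da = -2.0f*r*b;   // de/da
            const float db = -2.0f*r*a;   // de/db
            a -= lr*da;
            b -= lr*db;
        }
        printf("a*b = %f (target %f)\n", a*b, c);
        return 0;
    }
    ```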
    xaedes committed May 1, 2023 (54ab300)
  23. correctly implement softmax backward pass using new operation ggml_diag

    ggml_diag constructs diagonal matrices from the given entries.
    ggml_diag(shape[a,1,c,d]) -> shape[a,a,c,d]
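    A minimal sketch of that backward pass for one row (illustrative only): with s = softmax(x) the Jacobian is diag(s) - s*s^T, which reduces to an elementwise expression.

    ```c
    static void softmax_back_row(float * grad_x, const float * s, const float * grad_y, int n) {
        float dot = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot += grad_y[i]*s[i];                   // s^T * grad_y
        }
        for (int i = 0; i < n; ++i) {
            grad_x[i] = s[i]*(grad_y[i] - dot);      // (diag(s) - s*s^T) * grad_y
        }
    }
    ```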
    xaedes committed May 1, 2023 (1a80e9a)
  24. fea42be
  25. align shape annotations

    xaedes committed May 1, 2023 (9310650)
  26. 38675e5
  27. de-duplicate ggml_forward_dup code taking care of contiguous tensors of the same type

    with this we can duplicate tensors of any type as long as they are contiguous.
    xaedes committed May 1, 2023 (c1a8893)
  28. fix ggml_compute_forward_dup_same_cont for when nelements < nthreads

    when more threads are used than elements exist, ie1 was less than ie0, resulting in an invalid negative byte-count argument to memcpy
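    A sketch of the kind of fix (the real ggml_compute_forward_dup_same_cont differs in detail; names here are illustrative):

    ```c
    #include <string.h>
    #include <stdint.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // split ne elements of size ts bytes across nth threads; clamp the range so ie1
    // never falls below ie0 when there are more threads than elements.
    static void dup_same_cont_slice(void * dst, const void * src, int64_t ne, size_t ts, int ith, int nth) {
        const int64_t dr  = (ne + nth - 1)/nth;   // elements per thread, rounded up
        const int64_t ie0 = dr*ith;               // first element for this thread
        const int64_t ie1 = MIN(ie0 + dr, ne);    // one past the last element, clamped to ne
        if (ie0 < ie1) {                          // threads beyond the data copy nothing
            memcpy((char *)dst + ie0*ts, (const char *)src + ie0*ts, (size_t)(ie1 - ie0)*ts);
        }
    }
    ```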
    xaedes committed May 1, 2023 (83fa6b3)
  29. bug fix for add_at forward

    required for view backward pass
    
    src0 values must be copied to dst, because during addition we don't touch all dst elements in contrast to the normal add function.
    xaedes committed May 1, 2023 (cecd6c7)
  30. 124fdca
  31. minor code format improvement

    xaedes committed May 1, 2023 (410a47a)
  32. fix ggml_forward_add functions to work correctly with transposed tensors

    uses the same logic as in ggml_compute_forward_add_q_f32, making it consistent across all ggml_compute_forward_add_... functions.
    this also slightly changes the memory access pattern of the different threads to work as in ggml_compute_forward_add_q_f32.
    xaedes committed May 1, 2023 (b9416d7)
  33. fix ggml_forward_add1 functions to work correctly with transposed tensors

    uses the same logic as in ggml_compute_forward_add1_q_f32, making it consistent across all ggml_compute_forward_add1_... functions.
    this also slightly changes the memory access pattern of the different threads to work as in ggml_compute_forward_add1_q_f32.
    xaedes committed May 1, 2023 (339b2ad)
  34. 86b44a0
  35. a7a8370
  36. some minor test-grad0 fixes

    xaedes committed May 1, 2023 (b0555fc)
  37. fix sub, mul and div functions to work correctly with transposed tensors

    uses the same logic as in add
    xaedes committed May 1, 2023 (02d3fd0)
  38. 3d21f26
  39. successfully test transpose backward and permute for all permutations

    also test sub, mul and div up to max n_dims
    xaedes committed May 1, 2023 (c601df9)
  40. test-grad0.c add TODO for view_2d and view_3d

    add_at (required for view backward pass) is a bit tricky for n_dims > 1.
    xaedes committed May 1, 2023 (1997152)
  41. fix comments

    xaedes committed May 1, 2023 (d42531f)
  42. 19f5159
  43. test-grad0 : fix test for div

    nargs and ndims were swapped, corrupting the stack
    xaedes committed May 1, 2023 (b9920e5)
  44. 3dbd649
  45. 7281f60
  46. fix get rows backward pass

    xaedes committed May 1, 2023 (96e773b)
  47. f0302fa
  48. fix view backward pass

    add nb parameters to add_at like in view.
    together with offset they define how to view dst and src0 during the add_at operation.
    xaedes committed May 1, 2023 (8443638)
  49. b18b72d
  50. fix backward pass for rms_norm

    I would have used formulas from other frameworks, but they differed, so I could not decide which was correct.
    Instead it was derived here in a comment using manual forward-backward automatic differentiation of rms_norm and simplification.
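    The resulting per-row formula, as a plain-C sketch (illustrative only; the ggml kernel handles full tensors and threading, and the eps handling may differ):

    ```c
    #include <math.h>

    // y = x / sqrt(mean(x^2) + eps), g = dL/dy
    // dL/dx_i = g_i/rms - x_i * dot(g, x) / (n * rms^3)
    static void rms_norm_back_row(float * grad_x, const float * x, const float * g, int n, float eps) {
        float sum_xx = 0.0f, sum_gx = 0.0f;
        for (int i = 0; i < n; ++i) {
            sum_xx += x[i]*x[i];
            sum_gx += g[i]*x[i];
        }
        const float rrms = 1.0f/sqrtf(sum_xx/n + eps);   // 1/rms
        const float k    = rrms*rrms*rrms*sum_gx/n;      // dot(g, x) / (n * rms^3)
        for (int i = 0; i < n; ++i) {
            grad_x[i] = g[i]*rrms - x[i]*k;
        }
    }
    ```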
    xaedes committed May 1, 2023 (84a4b39)
  51. successfully test backward pass of rms_norm

    some tests may fail when gradients are large.
    could not find a satisfying configuration of absolute-error and relative-error bounds that passes all tests while still actually testing the results with tight enough error bounds.
    when looking at the values, the "failed" tests actually look ok. for example:
    
    rms_norm: ndims=2, i=0, k=2, x0=0.000153, xm=0.000053, xp=0.000253, f0=0.278594, f1=0.086213, g0=961.905457, g1=966.064941, eps=0.000100, error_abs=4.159485, error_rel=0.004324
    
    they fail only because of the test logic in check_gradients.
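    For context, the kind of central-difference check involved (a sketch with assumed names, not the actual test-grad0 code): g0 is the numerical gradient (f(x0+eps) - f(x0-eps)) / (2*eps), g1 is the gradient from ggml's backward pass, and a test passes only if both the absolute and the relative error are within bounds, which is why the example above fails even though its relative error is tiny.

    ```c
    #include <math.h>
    #include <stdbool.h>

    static bool check_gradient_1(float (*f)(float), float x0, float g_backward,
                                 float eps, float max_error_abs, float max_error_rel) {
        const float f0 = f(x0 + eps);
        const float f1 = f(x0 - eps);
        const float g0 = (f0 - f1)/(2.0f*eps);                       // numerical gradient
        const float error_abs = fabsf(g_backward - g0);
        const float error_rel = g0 != 0.0f ? error_abs/fabsf(g0) : error_abs;
        return error_abs <= max_error_abs && error_rel <= max_error_rel;
    }
    ```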
    xaedes committed May 1, 2023 (2ecc690)
  52. add todos for llama backward pass

    - implementation for ADD1 backward pass should probably use sum instead of mean (but this backward pass is not required)
    - repeat is not yet tested and looks like it only works for single element src0 inputs.
    xaedes committed May 1, 2023 (2277053)
  53. add operation ggml_sum_rows

    ggml_sum_rows(shape[a,b,c,d]) -> shape[1,b,c,d]
    xaedes committed May 1, 2023 (c4539ed)
  54. add missing GGML_OP_SUM_ROWS

    xaedes committed May 1, 2023 (ba62c79)
  55. fix backward pass for repeat

    requires ggml_sum_rows
    xaedes committed May 1, 2023 (8b5b2f0)
  56. 72bcfb5
  57. 1c4dc1e
  58. add baby-llama example training a very small llama model from scratch to output a sinusoidal wave

    had to increase maximum number of optimization parameters to train from scratch.
    xaedes committed May 1, 2023 (8fde656)
  59. 29a0f8b
  60. switching from training with adam to lbfgs produces much better results in the baby-llama example

    xaedes committed May 1, 2023 (5f23052)
  61. bc1c13b

Commits on May 6, 2023

  1. fix bug when using ggml_opt to optimize params in one context and use a renewable context for eval and opt

    when the gradients of the model parameters are not kept, they are overwritten by tensors created by opt, which may be invalid after the opt context is renewed.
    so we need to keep the original gradients and make dups for opt.
    xaedes committed May 6, 2023 (83ee1cd)
  2. train on multiple examples, generate & print tokens with trained model afterwards

    ctx0 for evaluation and optimization is renewed for each sample
    xaedes committed May 6, 2023 (f1d51d1)
  3. b4c273f
  4. 8cf04fe
  5. 65d9f73
  6. 5724628
  7. 7a15a83
  8. e6186d9
  9. 80223d9
  10. fix training get_example_targets

    predict the next token, not the current token!
    xaedes committed May 6, 2023 (73fd66e)
  11. 7a5dec2
  12. optimize loss over multiple samples

    this increases the size of the computation graph; a parallel batched forward pass is needed for more efficiency.
    xaedes committed May 6, 2023 (226521a)
  13. 48bcc4d
  14. add ggml_set(ctx, a, b) to set b in view of a and return modified a

    necessary to set values into kv_self cache and properly propagate the gradients
    xaedes committed May 6, 2023 (47561de)
  15. fix kv_self gradients for training

    use ggml_set instead of ggml_cpy to set kv_self cache with properly propagating gradients
    xaedes committed May 6, 2023 (956511b)
  16. replace inplace operations for training with copying operations to allow gradient propagation

    xaedes committed May 6, 2023 (561fbe0)
  17. e91b83b

Commits on May 7, 2023

  1. add trainable lora-only model with all big matrices C split into A,B with A*B=C

    this is not a lora-finetune; the whole model is changed to have only low-rank "lora" matrices.
    
    training this instead of the normal model resulted in much worse results though...
    xaedes committed May 7, 2023 (93201ab)
  2. vastly improve training results

    instead of logit targets 0 and 1 use -1 and +1.
    xaedes committed May 7, 2023 (49d6daa)
  3. shorten code using a variable

    xaedes committed May 7, 2023 (e0de09d)
  4. 4764842
  5. Merge branch 'master' into train-example

    # Conflicts:
    #	ggml.c
    #	llama.cpp
    xaedes committed May 7, 2023 (ee565f3)
  6. e643fa1
  7. d20ba6f
  8. 5d9fed7
  9. 47ad186
  10. 9dd8e40
  11. fix call to ggml_set_name

    xaedes committed May 7, 2023 (660836f)
  12. 7c8768f
  13. remove trailing whitespace

    xaedes committed May 7, 2023 (2936dd6)
  14. reduce number of test-grad0 iterations

    avoid exceeding timeout of automated tests
    xaedes committed May 7, 2023 (4997bc5)
  15. f530106

Commits on May 8, 2023

  1. 1ecbece
  2. c++ in baby-llama example

    use c++ includes instead of c includes
    use std::min, std::max instead of MIN, MAX macros
    xaedes committed May 8, 2023 (dea9c93)
  3. c++ in baby-llama example

    use c++ includes instead of c includes
    use std::min, std::max instead of MIN, MAX macros
    xaedes committed May 8, 2023 (0d72207)
  4. 78af3e9
  5. 6cc42de
  6. swap arguments to vDSP_vdiv call

    documentation for vDSP_vdiv states: "Note that B comes before A!"
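    For reference, Apple's signature puts the divisor first, so an elementwise x/y looks like this (small illustrative example):

    ```c
    #include <Accelerate/Accelerate.h>

    // vDSP_vdiv computes C[i] = A[i]/B[i], with the divisor B as the FIRST argument.
    static void div_example(void) {
        float x[4] = {2, 4, 6, 8}, y[4] = {1, 2, 3, 4}, out[4];
        vDSP_vdiv(y, 1, x, 1, out, 1, 4);   // out[i] = x[i]/y[i] -> {2, 2, 2, 2}
    }
    ```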
    xaedes committed May 8, 2023 (cafbb78)
  7. swap arguments to vDSP_vdiv call

    documentation for vDSP_vdiv states: "Note that B comes before A!"
    xaedes authored and ggerganov committed May 8, 2023 (9c3fe4e)
  8. 6ca682b

Commits on May 11, 2023

  1. 3e3ed95
  2. 581e5eb
  3. remove trailing whitespace

    xaedes committed May 11, 2023 (b9ef08c)

Commits on May 13, 2023

  1. f977243
  2. 33034cf
  3. 092913e
  4. ggml : remove Q4_2 remnants

    ggerganov committed May 13, 2023 (95a487a)
  5. ef3d42a
  6. dae6ba2