train_batch Subroutine

private subroutine train_batch(self, x, y, eta)

Trains a network using input data x and output data y, and learning rate eta. The learning rate is normalized with the size of the data batch. Local variables: im is the mini-batch size and nm is the number of layers.

Arguments

Type | Intent | Optional Attributes | Name
class(network_type), intent(inout) :: self
real(kind=rk), intent(in) :: x(:,:)
real(kind=rk), intent(in) :: y(:,:)
real(kind=rk), intent(in) :: eta

Calls

Call graph: train_batch calls dw_co_sum, db_co_sum, tile_indices, db_init and dw_init. db_init calls the array1d interface, which resolves to array1d_constructor. dw_init calls the array2d interface, which resolves to array2d_constructor.

Contents

Source Code


Source Code

  subroutine train_batch(self, x, y, eta)
    !! Trains a network using input data x and output data y,
    !! and learning rate eta. The learning rate is normalized
    !! with the size of the data batch.
    !!
    !! Each column x(:,i) is one input sample and y(:,i) is its target.
    !! This image processes only the column range returned by
    !! tile_indices(size(x, 2)). The per-sample weight/bias gradients
    !! produced by backprop are accumulated into dw_batch/db_batch.
    !! When more than one image is running, they are combined through
    !! dw_co_sum/db_co_sum. Finally they are passed to self % update
    !! with step eta / size(x, 2). An empty batch (size(x, 2) == 0)
    !! leaves the network unchanged.
    class(network_type), intent(in out) :: self
    real(rk), intent(in) :: x(:,:), y(:,:), eta
    type(array1d), allocatable :: db(:), db_batch(:)
    type(array2d), allocatable :: dw(:), dw_batch(:)
    integer(ik) :: i, im, n, nm
    integer(ik) :: is, ie, indices(2)

    im = size(x, dim=2) ! mini-batch size
    nm = size(self % dims) ! number of layers

    ! get start and end index for mini-batch
    indices = tile_indices(im)
    is = indices(1)
    ie = indices(2)

    call db_init(db_batch, self % dims)
    call dw_init(dw_batch, self % dims)

    ! A plain do loop, not do concurrent. Each iteration reads and updates
    ! dw_batch/db_batch (a cross-iteration reduction) and reuses the shared
    ! dw/db buffers filled by backprop. It also calls fwdprop/backprop on the
    ! shared self, which presumably stores per-sample state. The iterations
    ! are therefore not independent.
    do i = is, ie
      call self % fwdprop(x(:,i))
      call self % backprop(y(:,i), dw, db)
      do concurrent(n = 1:nm)
        dw_batch(n) % array = dw_batch(n) % array + dw(n) % array
        db_batch(n) % array = db_batch(n) % array + db(n) % array
      end do
    end do

    if (num_images() > 1) then
      call dw_co_sum(dw_batch)
      call db_co_sum(db_batch)
    end if

    ! Guard against an empty batch: eta / 0 would push Inf/NaN into the
    ! weights. The co_sum calls above still execute so that collective
    ! calls remain matched across images.
    if (im > 0) call self % update(dw_batch, db_batch, eta / real(im, rk))

  end subroutine train_batch