diff --git a/src/nf/nf_network.f90 b/src/nf/nf_network.f90
index 38bd3e7b..c8fb764c 100644
--- a/src/nf/nf_network.f90
+++ b/src/nf/nf_network.f90
@@ -3,7 +3,7 @@ module nf_network
   !! This module provides the network type to create new models.
 
   use nf_layer, only: layer
-  use nf_optimizers, only: sgd
+  use nf_optimizers, only: optimizer_base_type
 
   implicit none
 
@@ -193,7 +193,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
         !! Set to `size(input_data, dim=2)` for a batch gradient descent.
       integer, intent(in) :: epochs
         !! Number of epochs to run
-      type(sgd), intent(in) :: optimizer
+      class(optimizer_base_type), intent(in) :: optimizer
         !! Optimizer instance; currently this is an `sgd` optimizer type
         !! and it will be made to be a more general optimizer type.
     end subroutine train
diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90
index df4e04cc..c85646ec 100644
--- a/src/nf/nf_network_submodule.f90
+++ b/src/nf/nf_network_submodule.f90
@@ -12,7 +12,7 @@
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic_derivative
-  use nf_optimizers, only: sgd
+  use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
 
   implicit none
@@ -426,7 +426,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
     real, intent(in) :: output_data(:,:)
     integer, intent(in) :: batch_size
     integer, intent(in) :: epochs
-    type(sgd), intent(in) :: optimizer
+    class(optimizer_base_type), intent(in) :: optimizer
 
     real :: pos
     integer :: dataset_size
@@ -439,26 +439,31 @@ module subroutine train(self, input_data, output_data, batch_size, &
     epoch_loop: do n = 1, epochs
       batch_loop: do i = 1, dataset_size / batch_size
 
-        ! Pull a random mini-batch from the dataset
-        call random_number(pos)
-        batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
-        batch_end = batch_start + batch_size - 1
-
-        ! FIXME shuffle in a way that doesn't require co_broadcast
-        call co_broadcast(batch_start, 1)
-        call co_broadcast(batch_end, 1)
-
-        ! Distribute the batch in nearly equal pieces to all images
-        indices = tile_indices(batch_size)
-        istart = indices(1) + batch_start - 1
-        iend = indices(2) + batch_start - 1
-
-        do concurrent(j = istart:iend)
-          call self % forward(input_data(:,j))
-          call self % backward(output_data(:,j))
-        end do
-
-        call self % update(optimizer % learning_rate / batch_size)
+        ! Pull a random mini-batch from the dataset
+        call random_number(pos)
+        batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
+        batch_end = batch_start + batch_size - 1
+
+        ! FIXME shuffle in a way that doesn't require co_broadcast
+        call co_broadcast(batch_start, 1)
+        call co_broadcast(batch_end, 1)
+
+        ! Distribute the batch in nearly equal pieces to all images
+        indices = tile_indices(batch_size)
+        istart = indices(1) + batch_start - 1
+        iend = indices(2) + batch_start - 1
+
+        do concurrent(j = istart:iend)
+          call self % forward(input_data(:,j))
+          call self % backward(output_data(:,j))
+        end do
+
+        select type (optimizer)
+          type is (sgd)
+            call self % update(optimizer % learning_rate / batch_size)
+          class default
+            error stop 'Unsupported optimizer'
+        end select
 
       end do batch_loop
     end do epoch_loop
diff --git a/src/nf/nf_optimizers.f90 b/src/nf/nf_optimizers.f90
index 2ba89904..7d00a3cf 100644
--- a/src/nf/nf_optimizers.f90
+++ b/src/nf/nf_optimizers.f90
@@ -5,9 +5,13 @@ module nf_optimizers
   implicit none
 
   private
-  public :: sgd
+  public :: optimizer_base_type, sgd
 
-  type :: sgd
+  type, abstract :: optimizer_base_type
+    character(:), allocatable :: name
+  end type optimizer_base_type
+
+  type, extends(optimizer_base_type) :: sgd
     !! Stochastic Gradient Descent optimizer
     real :: learning_rate
     real :: momentum = 0 !TODO
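
A minimal usage sketch of the new polymorphic `optimizer` argument follows. It assumes the top-level `nf` module re-exports `network`, `input`, `dense`, and `sgd`, and that the `sgd` structure constructor can be called with just `learning_rate`; the layer sizes and toy data are illustrative only, not taken from this change.

program sgd_train_example
  ! Illustrative sketch: passes a concrete sgd instance through the
  ! class(optimizer_base_type) dummy argument of train(), which dispatches
  ! on it with select type as shown in the diff above.
  use nf, only: dense, input, network, sgd
  implicit none

  type(network) :: net
  real :: x(3, 96), y(1, 96)

  ! Toy dataset: learn the mean of three inputs (samples are stored as columns).
  call random_number(x)
  y(1,:) = sum(x, dim=1) / 3

  net = network([input(3), dense(8), dense(1)])

  call net % train(x, y, batch_size=8, epochs=5, optimizer=sgd(learning_rate=0.1))

end program sgd_train_example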