load_mnist Subroutine

public subroutine load_mnist(tr_images, tr_labels, te_images, te_labels, va_images, va_labels)

Loads the MNIST dataset into arrays.

Arguments

Type, Intent, Optional, Attributes :: Name
real(kind=rk), intent(inout), allocatable:: tr_images(:,:)
real(kind=rk), intent(inout), allocatable:: tr_labels(:)
real(kind=rk), intent(inout), allocatable:: te_images(:,:)
real(kind=rk), intent(inout), allocatable:: te_labels(:)
real(kind=rk), intent(inout), optional, allocatable:: va_images(:,:)
real(kind=rk), intent(inout), optional, allocatable:: va_labels(:)

Calls

Call graph: load_mnist calls the generic interface read_binary_file, which resolves to either read_binary_file_2d or read_binary_file_1d.

Contents

Source Code


Source Code

  subroutine load_mnist(tr_images, tr_labels, te_images,&
                        te_labels, va_images, va_labels)
    !! Loads the MNIST dataset into arrays.
    !!
    !! The training and testing sets are always read. The validation
    !! arrays are optional; each one is read whenever it is supplied,
    !! independently of whether the other one is present.
    !!
    !! All data is fetched through `read_binary_file` from fixed relative
    !! paths under `data/mnist/`.
    real(rk), allocatable, intent(in out) :: tr_images(:,:), tr_labels(:)
      !! Training images and training labels.
    real(rk), allocatable, intent(in out) :: te_images(:,:), te_labels(:)
      !! Testing images and testing labels.
    real(rk), allocatable, intent(in out), optional :: va_images(:,:), va_labels(:)
      !! Validation images and validation labels (each may be omitted).

    ! dtype is forwarded unchanged to read_binary_file; presumably the size
    ! in bytes of one stored value -- TODO confirm against read_binary_file.
    ! image_size is the per-image extent handed to the 2-d reads.
    integer(ik), parameter :: dtype = 4, image_size = 784
    ! Number of records requested from each data set.
    integer(ik), parameter :: tr_nimages = 50000
    integer(ik), parameter :: te_nimages = 10000
    integer(ik), parameter :: va_nimages = 10000

    ! Training set.
    call read_binary_file('data/mnist/mnist_training_images.dat',&
                          dtype, image_size, tr_nimages, tr_images)
    call read_binary_file('data/mnist/mnist_training_labels.dat',&
                          dtype, tr_nimages, tr_labels)

    ! Testing set.
    call read_binary_file('data/mnist/mnist_testing_images.dat',&
                          dtype, image_size, te_nimages, te_images)
    call read_binary_file('data/mnist/mnist_testing_labels.dat',&
                          dtype, te_nimages, te_labels)

    ! Validation set. Each optional array is guarded on its own so that a
    ! caller supplying only one of them still gets it filled instead of
    ! being silently ignored.
    if (present(va_images)) then
      call read_binary_file('data/mnist/mnist_validation_images.dat',&
                            dtype, image_size, va_nimages, va_images)
    end if

    if (present(va_labels)) then
      call read_binary_file('data/mnist/mnist_validation_labels.dat',&
                            dtype, va_nimages, va_labels)
    end if

  end subroutine load_mnist