Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/cnn.f90
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ program cnn
allocate(x(3,32,32))
call random_number(x)

print *, 'Output:', net % output(x)
print *, 'Output:', net % predict(x)

end program cnn
2 changes: 1 addition & 1 deletion example/cnn_from_keras.f90
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ real function accuracy(net, x, y)
integer :: i, good
good = 0
do i = 1, size(x, dim=4)
if (all(maxloc(net % output(x(:,:,:,i))) == maxloc(y(:,i)))) then
if (all(maxloc(net % predict(x(:,:,:,i))) == maxloc(y(:,i)))) then
good = good + 1
end if
end do
Expand Down
2 changes: 1 addition & 1 deletion example/dense_from_keras.f90
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ real function accuracy(net, x, y)
integer :: i, good
good = 0
do i = 1, size(x, dim=2)
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
good = good + 1
end if
end do
Expand Down
2 changes: 1 addition & 1 deletion example/mnist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ real function accuracy(net, x, y)
integer :: i, good
good = 0
do i = 1, size(x, dim=2)
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
good = good + 1
end if
end do
Expand Down
2 changes: 1 addition & 1 deletion example/simple.f90
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ program simple
call net % update(1.)

if (mod(n, 50) == 0) &
print '(i4,2(3x,f8.6))', n, net % output(x)
print '(i4,2(3x,f8.6))', n, net % predict(x)

end do

Expand Down
2 changes: 1 addition & 1 deletion example/sine.f90
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ program sine
call net % update(1.)

if (mod(n, 10000) == 0) then
ypred = [(net % output([xtest(i)]), i = 1, test_size)]
ypred = [(net % predict([xtest(i)]), i = 1, test_size)]
print '(i0,1x,f9.6)', n, sum((ypred - ytest)**2) / size(ypred)
end if

Expand Down
26 changes: 13 additions & 13 deletions src/nf/nf_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ module nf_network

procedure, private :: forward_1d
procedure, private :: forward_3d
procedure, private :: output_1d
procedure, private :: output_3d
procedure, private :: output_batch_1d
procedure, private :: output_batch_3d
procedure, private :: predict_1d
procedure, private :: predict_3d
procedure, private :: predict_batch_1d
procedure, private :: predict_batch_3d

generic :: forward => forward_1d, forward_3d
generic :: output => output_1d, output_3d, output_batch_1d, output_batch_3d
generic :: predict => predict_1d, predict_3d, predict_batch_1d, predict_batch_3d

end type network

Expand Down Expand Up @@ -89,45 +89,45 @@ end subroutine forward_3d

interface output

module function output_1d(self, input) result(res)
module function predict_1d(self, input) result(res)
!! Return the output of the network given the input 1-d array.
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input(:)
!! Input data
real, allocatable :: res(:)
!! Output of the network
end function output_1d
end function predict_1d

module function output_3d(self, input) result(res)
module function predict_3d(self, input) result(res)
!! Return the output of the network given the input 3-d array.
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input(:,:,:)
!! Input data
real, allocatable :: res(:)
!! Output of the network
end function output_3d
end function predict_3d

module function output_batch_1d(self, input) result(res)
module function predict_batch_1d(self, input) result(res)
!! Return the output of the network given an input batch of 3-d data.
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input(:,:)
!! Input data; the last dimension is the batch
real, allocatable :: res(:,:)
!! Output of the network; the last dimension is the batch
end function output_batch_1d
end function predict_batch_1d

module function output_batch_3d(self, input) result(res)
module function predict_batch_3d(self, input) result(res)
!! Return the output of the network given an input batch of 3-d data.
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input(:,:,:,:)
!! Input data; the last dimension is the batch
real, allocatable :: res(:,:)
!! Output of the network; the last dimension is the batch
end function output_batch_3d
end function predict_batch_3d

end interface output

Expand Down
40 changes: 20 additions & 20 deletions src/nf/nf_network_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ pure module subroutine forward_3d(self, input)
end subroutine forward_3d


module function output_1d(self, input) result(res)
module function predict_1d(self, input) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input(:)
real, allocatable :: res(:)
Expand All @@ -263,10 +263,10 @@ module function output_1d(self, input) result(res)
error stop 'network % output not implemented for this output layer'
end select

end function output_1d
end function predict_1d


module function output_3d(self, input) result(res)
module function predict_3d(self, input) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input(:,:,:)
real, allocatable :: res(:)
Expand All @@ -288,10 +288,10 @@ module function output_3d(self, input) result(res)
error stop 'network % output not implemented for this output layer'
end select

end function output_3d
end function predict_3d


module function output_batch_1d(self, input) result(res)
module function predict_batch_1d(self, input) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input(:,:)
real, allocatable :: res(:,:)
Expand All @@ -318,10 +318,10 @@ module function output_batch_1d(self, input) result(res)

end do batch

end function output_batch_1d
end function predict_batch_1d


module function output_batch_3d(self, input) result(res)
module function predict_batch_3d(self, input) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input(:,:,:,:)
real, allocatable :: res(:,:)
Expand All @@ -335,23 +335,23 @@ module function output_batch_3d(self, input) result(res)

batch: do concurrent(i = 1:batch_size)

call self % forward(input(:,:,:,i))
call self % forward(input(:,:,:,i))

select type(output_layer => self % layers(num_layers) % p)
type is(conv2d_layer)
!FIXME flatten the result for now; find a better solution
res(:,i) = pack(output_layer % output, .true.)
type is(dense_layer)
res(:,i) = output_layer % output
type is(flatten_layer)
res(:,i) = output_layer % output
class default
error stop 'network % output not implemented for this output layer'
end select
select type(output_layer => self % layers(num_layers) % p)
type is(conv2d_layer)
!FIXME flatten the result for now; find a better solution
res(:,i) = pack(output_layer % output, .true.)
type is(dense_layer)
res(:,i) = output_layer % output
type is(flatten_layer)
res(:,i) = output_layer % output
class default
error stop 'network % output not implemented for this output layer'
end select

end do batch

end function output_batch_3d
end function predict_batch_3d


module subroutine print_info(self)
Expand Down
2 changes: 1 addition & 1 deletion test/test_cnn_from_keras.f90
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ real function accuracy(net, x, y)
integer :: i, good
good = 0
do i = 1, size(x, dim=4)
if (all(maxloc(net % output(x(:,:,:,i))) == maxloc(y(:,i)))) then
if (all(maxloc(net % predict(x(:,:,:,i))) == maxloc(y(:,i)))) then
good = good + 1
end if
end do
Expand Down
4 changes: 2 additions & 2 deletions test/test_dense_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ program test_dense_network
ok = .false.
end if

if (.not. all(net % output([0.]) == 0.5)) then
if (.not. all(net % predict([0.]) == 0.5)) then
write(stderr, '(a)') &
'dense network should output exactly 0.5 for input 0.. failed'
ok = .false.
Expand All @@ -35,7 +35,7 @@ program test_dense_network
call net % forward(x)
call net % backward(y)
call net % update(1.)
if (all(abs(net % output(x) - y) < tolerance)) exit
if (all(abs(net % predict(x) - y) < tolerance)) exit
end do

if (.not. n <= num_iterations) then
Expand Down
2 changes: 1 addition & 1 deletion test/test_dense_network_from_keras.f90
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ real function accuracy(net, x, y)
integer :: i, good
good = 0
do i = 1, size(x, dim=2)
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
good = good + 1
end if
end do
Expand Down