Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define Base.close(::Session) #342

Merged
merged 3 commits into from
Oct 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,7 @@ mutable struct Session
this = new(ptr, graph)
check_status(status)
finalizer(this, self->begin
status = Status()
@tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), self.ptr, status.ptr)
close(self)
end)
return this
end
Expand All @@ -571,6 +570,21 @@ mutable struct Session
end
end

"""
close(sess::Session)

Closes the TensorFlow session, freeing the associated computational resources.
"""
function Base.close(sess::Session)
if sess.ptr != C_NULL
status = Status()
@tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), sess.ptr, status.ptr)
check_status(status)
sess.ptr = C_NULL
end
return nothing
end


mutable struct Buffer
ptr::Ptr{Void}
Expand Down
14 changes: 13 additions & 1 deletion src/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,18 @@ function build_input(tensor_map::Dict)
input_tensors, input_values
end

struct ClosedSessionError <: Exception
end

function Base.show(io::IO, err::ClosedSessionError)
print(io, "An operation was attempted on a closed TensorFlow session.")
end

function run(sess::Session, inputs, input_values, outputs, targets)
#Low level run, without size checking, and type conversion etc.

if sess.ptr == C_NULL
throw(ClosedSessionError())
end
status = Status()
output_values = fill(C_NULL, length(outputs))
input_tensors = [RawTensor(x) for x in input_values]
Expand Down Expand Up @@ -184,6 +193,9 @@ end


"""
run(sess::Session, output, input_dict::Dict)


Compute the result of one of more operations in the computation graph.
"""
function run(sess::Session, output, input_dict)
Expand Down
9 changes: 9 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ end
end
end

@testset "Session closing" begin
session = tf.Session(Graph())
x = constant(1)
@test run(session, x) == 1
close(session)
close(session) # Test that we can safely call `close` twice on the same session
@test_throws tf.ClosedSessionError run(session, x)
end

@testset "get_operations" begin
let
graph = Graph()
Expand Down