diff --git a/Sources/TensorFlow/Core/Runtime.swift b/Sources/TensorFlow/Core/Runtime.swift index d9acb42e6..8370a509b 100644 --- a/Sources/TensorFlow/Core/Runtime.swift +++ b/Sources/TensorFlow/Core/Runtime.swift @@ -569,6 +569,12 @@ public final class _ExecutionContext { // Initialize the TF runtime exactly once. Only affects local execution // (when _RuntimeConfig.tensorFlowServer is set to ""). if !_RuntimeConfig.tensorFlowRuntimeInitialized { + // Install a signal handler to ensure we exit when interrupted. + signal(SIGINT) { _ in + print("Caught interrupt signal, exiting...") + exit(1) + } + var args = ["dummyProgramName"] if _RuntimeConfig.printsDebugLog { args.append("--alsologtostderr") @@ -588,9 +594,9 @@ public final class _ExecutionContext { // Calculate the addresses of all the strings within our single buffer, and then call // TF_InitMain. - flattenedStringBytes.withUnsafeMutableBufferPointer { flattenedStringBytesBuffer in + flattenedStringBytes.withUnsafeMutableBufferPointer { buffer in var stringAddrs: [UnsafeMutablePointer?] = [] - var currentStringAddr = flattenedStringBytesBuffer.baseAddress + var currentStringAddr = buffer.baseAddress .map(UnsafeMutablePointer.init) for length in lengths { stringAddrs.append(currentStringAddr) @@ -598,14 +604,9 @@ public final class _ExecutionContext { } stringAddrs.withUnsafeMutableBufferPointer { stringAddrsBuffer in - var cArgs = [stringAddrsBuffer.baseAddress.map(UnsafeMutablePointer.init)] - var cArgsCount = [Int32(args.count)] - - cArgs.withUnsafeMutableBufferPointer { cArgsBuffer in - cArgsCount.withUnsafeMutableBufferPointer { cArgsCountBuffer in - TF_InitMain(nil, cArgsCountBuffer.baseAddress, cArgsBuffer.baseAddress) - } - } + var cArgsCount = Int32(args.count) + var cArgs = stringAddrsBuffer.baseAddress.map(UnsafeMutablePointer.init) + TF_InitMain(nil, &cArgsCount, &cArgs) } } _RuntimeConfig.tensorFlowRuntimeInitialized = true