diff --git a/stdlib/public/CTensorFlow/ctensorflow_init.cpp b/stdlib/public/CTensorFlow/ctensorflow_init.cpp index 56932a8b138bc..d40e66260a75b 100644 --- a/stdlib/public/CTensorFlow/ctensorflow_init.cpp +++ b/stdlib/public/CTensorFlow/ctensorflow_init.cpp @@ -17,33 +17,6 @@ void handle_sigint(int signal) { exit(1); } -void InitTensorFlowRuntime(unsigned char enable_debug_logging, - int verbose_level) { - // Install a signal handler to ensure we exit when interrupted. - signal(SIGINT, handle_sigint); - - // Synthesize argc and argv - char arg0[] = "dummyProgramName"; - std::vector my_argv; - my_argv.push_back(&arg0[0]); - // This allows us to dump TF logging to the output of a swift binary. - // We can only dump to stderr, since there is no flag alsologtostdout. - char arg1[] = "--alsologtostderr"; - if (enable_debug_logging > 0) { - my_argv.push_back(&arg1[0]); - } - char arg2[] = "--v=?"; - if (verbose_level > 0) { - assert(verbose_level <= 4); - arg2[4] = verbose_level + '0'; - my_argv.push_back(&arg2[0]); - } - int my_argc = my_argv.size(); - char** tmpArgv = my_argv.data(); - // Initialize GPU devices. - TF_InitMain(/*usage=*/nullptr, &my_argc, &tmpArgv); -} - static bool setValue(TF_DataType tfDtype, int64_t val, void *ptr) { switch (tfDtype) { case TF_INT8: diff --git a/stdlib/public/CTensorFlow/ctensorflow_init.h b/stdlib/public/CTensorFlow/ctensorflow_init.h index b0ab1550de6fe..16a36877717f0 100644 --- a/stdlib/public/CTensorFlow/ctensorflow_init.h +++ b/stdlib/public/CTensorFlow/ctensorflow_init.h @@ -7,15 +7,6 @@ extern "C" { #endif -// Call this API exactly once before any TensorFlow backend/runtime calls. -// -// It sets up device context for any GPU based computation. When -// `enable_debug_logging` is true, it also dumps TF logging for debugging -// purposes. In that case, when `verbose_level` is positive (must be <= 4), it -// also dumps verbose logs at that level. -extern void InitTensorFlowRuntime(unsigned char enable_debug_logging, - int verbose_level); - //===----------------------------------------------------------------------===// // - MARK: Runtime functions to be called via IRGen. //===----------------------------------------------------------------------===// diff --git a/stdlib/public/TensorFlow/CompilerRuntime.swift b/stdlib/public/TensorFlow/CompilerRuntime.swift index 33ebeb66be50d..221b25c6eb8c4 100644 --- a/stdlib/public/TensorFlow/CompilerRuntime.swift +++ b/stdlib/public/TensorFlow/CompilerRuntime.swift @@ -580,8 +580,44 @@ public final class _ExecutionContext { // Initialize the TF runtime exactly once. Only affects local execution // (when _RuntimeConfig.tensorFlowServer is set to ""). if !_RuntimeConfig.tensorFlowRuntimeInitialized { - InitTensorFlowRuntime(_RuntimeConfig.printsDebugLog ? 1 : 0, - _RuntimeConfig.tensorflowVerboseLogLevel) + var args = ["dummyProgramName"] + if _RuntimeConfig.printsDebugLog { + args.append("--alsologtostderr") + } + if _RuntimeConfig.tensorflowVerboseLogLevel > 0 { + args.append("--v=\(_RuntimeConfig.tensorflowVerboseLogLevel)") + } + // Collect all the strings' utf8 bytes into a single array so that we can + // address all the strings with a single `flattenedStringBytes.withUnsafeBufferPointer`. + var flattenedStringBytes: [Int8] = [] + var lengths: [Int] = [] + for arg in args { + let bytes = arg.utf8CString + flattenedStringBytes.append(contentsOf: bytes) + lengths.append(bytes.count) + } + + // Calculate the addresses of all the strings within our single buffer, and then call + // TF_InitMain. + flattenedStringBytes.withUnsafeMutableBufferPointer { flattenedStringBytesBuffer in + var stringAddrs: [UnsafeMutablePointer?] = [] + var currentStringAddr = flattenedStringBytesBuffer.baseAddress.map(UnsafeMutablePointer.init) + for length in lengths { + stringAddrs.append(currentStringAddr) + currentStringAddr = currentStringAddr?.advanced(by: length) + } + + stringAddrs.withUnsafeMutableBufferPointer { stringAddrsBuffer in + var cArgs = [stringAddrsBuffer.baseAddress.map(UnsafeMutablePointer.init)] + var cArgsCount = [Int32(args.count)] + + cArgs.withUnsafeMutableBufferPointer { cArgsBuffer in + cArgsCount.withUnsafeMutableBufferPointer { cArgsCountBuffer in + TF_InitMain(nil, cArgsCountBuffer.baseAddress, cArgsBuffer.baseAddress) + } + } + } + } _RuntimeConfig.tensorFlowRuntimeInitialized = true }