diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h
index 8ca0cd9c357ef0..c653fc7317b595 100644
--- a/third_party/xla/xla/backends/interpreter/executor.h
+++ b/third_party/xla/xla/backends/interpreter/executor.h
@@ -86,8 +86,8 @@ class XlaInterpreterExecutor : public StreamExecutorCommon {
   absl::Status Init() override { return absl::OkStatus(); }
 
   int device_ordinal() const override { return device_ordinal_; };
 
-  absl::Status GetKernel(const MultiKernelLoaderSpec &spec,
-                         Kernel *kernel) override {
+  absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
+      const MultiKernelLoaderSpec &spec) override {
     return absl::UnimplementedError("Not Implemented");
   }
 
   absl::Status Launch(Stream *stream, const ThreadDim &thread_dims,
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
index 63df37d3c037d9..7f478df047be84 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
@@ -216,21 +216,19 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco,
       "Feature not supported on CUDA platform (LoadModuleFromHsaco)");
 }
 
-absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
-                                    Kernel* kernel) {
-  GpuKernel* cuda_kernel = AsGpuKernel(kernel);
+absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::LoadKernel(
+    const MultiKernelLoaderSpec& spec) {
+  auto cuda_kernel = std::make_unique<GpuKernel>(this);
   CUmodule module;
   const std::string* kernel_name;
 
-  VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();
-
   if (spec.has_cuda_cubin_in_memory()) {
     absl::MutexLock lock{&in_memory_modules_mu_};
     kernel_name = &spec.cuda_cubin_in_memory().kernel_name();
     const char* cubin = reinterpret_cast<const char*>(
         spec.cuda_cubin_in_memory().cubin_bytes().data());
     TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module));
-    kernel_to_gpu_binary_[kernel] = cubin;
+    kernel_to_gpu_binary_[cuda_kernel.get()] = cubin;
 
   } else if (spec.has_cuda_ptx_in_memory()) {
     kernel_name = &spec.cuda_ptx_in_memory().kernel_name();
@@ -249,7 +247,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
 
     absl::MutexLock lock{&in_memory_modules_mu_};
     TF_RETURN_IF_ERROR(LoadModuleFromPtx(ptx, &module));
-    kernel_to_gpu_binary_[kernel] = ptx;
+    kernel_to_gpu_binary_[cuda_kernel.get()] = ptx;
 
   } else if (spec.has_in_process_symbol()) {
     kernel_name = &spec.in_process_symbol().kernel_name();
@@ -265,7 +263,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
   } else {
     return absl::InternalError("No method of loading CUDA kernel provided");
   }
-
+  VLOG(3) << "LoadKernel on kernel : " << *kernel_name;
   // If we resolved kernel from a symbol pointer, there is no need to load it
   // from a module, as CUDA runtime did that automatically for us.
   if (!spec.has_in_process_symbol()) {
@@ -284,11 +282,11 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
   cuda_kernel->set_arity(spec.arity());
 
   KernelMetadata kernel_metadata;
-  TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel, &kernel_metadata));
-  kernel->set_metadata(kernel_metadata);
-  kernel->set_name(*kernel_name);
-  kernel->set_args_packing(spec.kernel_args_packing());
-  return absl::OkStatus();
+  TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel.get(), &kernel_metadata));
+  cuda_kernel->set_metadata(kernel_metadata);
+  cuda_kernel->set_name(*kernel_name);
+  cuda_kernel->set_args_packing(spec.kernel_args_packing());
+  return std::move(cuda_kernel);
 }
 
 absl::StatusOr<std::unique_ptr<EventBasedTimer>>
@@ -793,10 +791,6 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
   }
 }
 
-absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::CreateKernel() {
-  return std::make_unique<GpuKernel>(this);
-}
-
 absl::StatusOr<std::unique_ptr<CommandBuffer>> GpuExecutor::CreateCommandBuffer(
     CommandBuffer::Mode mode) {
   VLOG(2) << "Create CUDA command buffer (CUDA graph)";
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
index c19fa1cceeba0c..13b9b944d1beb2 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
@@ -122,8 +122,8 @@ class GpuExecutor : public StreamExecutorCommon {
 
   int device_ordinal() const override { return device_ordinal_; };
 
-  absl::Status GetKernel(const MultiKernelLoaderSpec& spec,
-                         Kernel* kernel) override;
+  absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
+      const MultiKernelLoaderSpec& spec) override;
 
   // (supported on CUDA only)
   void UnloadKernel(const Kernel* kernel) override;
@@ -240,8 +240,6 @@ class GpuExecutor : public StreamExecutorCommon {
       std::optional<std::variant<StreamPriority, int>> priority =
           std::nullopt) override;
 
-  absl::StatusOr<std::unique_ptr<Kernel>> CreateKernel() override;
-
   absl::StatusOr<std::unique_ptr<CommandBuffer>> CreateCommandBuffer(
       CommandBuffer::Mode mode) override;
 
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc
index ac1d22583d0fde..8d8eeb7e421de1 100644
--- a/third_party/xla/xla/stream_executor/host/host_executor.cc
+++ b/third_party/xla/xla/stream_executor/host/host_executor.cc
@@ -74,24 +74,18 @@ absl::Status HostExecutor::Init() {
   return absl::OkStatus();
 }
 
-absl::StatusOr<std::unique_ptr<Kernel>> HostExecutor::CreateKernel() {
-  return std::make_unique<HostKernel>(thread_pool_);
-}
-
-absl::Status HostExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
-                                     Kernel* kernel) {
-  HostKernel* host_kernel = AsHostKernel(kernel);
+absl::StatusOr<std::unique_ptr<Kernel>> HostExecutor::LoadKernel(
+    const MultiKernelLoaderSpec& spec) {
+  auto host_kernel = std::make_unique<HostKernel>(thread_pool_);
   host_kernel->SetArity(spec.arity());
 
-  VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();
-
   for (auto& loader : KernelFunctionLoaderRegistry()) {
     auto loaded = loader(spec);
     if (!loaded.has_value()) continue;
 
     TF_ASSIGN_OR_RETURN(auto kernel_function, *std::move(loaded));
     host_kernel->SetKernelFunction(std::move(kernel_function));
-    return absl::OkStatus();
+    return std::move(host_kernel);
   }
 
   return absl::InternalError("No method of loading host kernel provided");
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h
index 18ec5a739faca5..4e2a2230ffbd4c 100644
--- a/third_party/xla/xla/stream_executor/host/host_executor.h
+++ b/third_party/xla/xla/stream_executor/host/host_executor.h
@@ -70,10 +70,8 @@ class HostExecutor : public StreamExecutorCommon {
 
   absl::Status Init() override;
 
-  absl::Status GetKernel(const MultiKernelLoaderSpec& spec,
-                         Kernel* kernel) override;
-
-  absl::StatusOr<std::unique_ptr<Kernel>> CreateKernel() override;
+  absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
+      const MultiKernelLoaderSpec& spec) override;
 
   absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
                       const BlockDim& block_dims, const Kernel& kernel,
diff --git a/third_party/xla/xla/stream_executor/kernel_factory.h b/third_party/xla/xla/stream_executor/kernel_factory.h
index 24e594ed89d10e..17e07cd0f97950 100644
--- a/third_party/xla/xla/stream_executor/kernel_factory.h
+++ b/third_party/xla/xla/stream_executor/kernel_factory.h
@@ -22,8 +22,6 @@ limitations under the License.
 #include "xla/stream_executor/kernel.h"
 #include "xla/stream_executor/kernel_spec.h"
 #include "xla/stream_executor/stream_executor.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
 
 namespace stream_executor {
 
@@ -33,9 +31,7 @@ class KernelFactory {
   // Creates kernel on a given executor from a given kernel specification.
   static inline absl::StatusOr<std::unique_ptr<Kernel>> Create(
       StreamExecutor *executor, const MultiKernelLoaderSpec &spec) {
-    TF_ASSIGN_OR_RETURN(auto kernel, executor->CreateKernel());
-    TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get()));
-    return kernel;
+    return executor->LoadKernel(spec);
   }
 };
 
diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h
index 3787be1133b5d4..f58a553f9ebdd8 100644
--- a/third_party/xla/xla/stream_executor/mock_stream_executor.h
+++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h
@@ -59,8 +59,8 @@ class MockStreamExecutor : public StreamExecutor {
   MockStreamExecutor() = default;
   MOCK_METHOD(absl::Status, Init, (), (override));
   MOCK_METHOD(int, device_ordinal, (), (const, override));
-  MOCK_METHOD(absl::Status, GetKernel,
-              (const MultiKernelLoaderSpec& spec, Kernel* kernel), (override));
+  MOCK_METHOD(absl::StatusOr<std::unique_ptr<Kernel>>, LoadKernel,
+              (const MultiKernelLoaderSpec& spec), (override));
   MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override));
   MOCK_METHOD(absl::Status, LoadModule,
               (const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle),
@@ -124,8 +124,6 @@ class MockStreamExecutor : public StreamExecutor {
   MOCK_METHOD(blas::BlasSupport*, AsBlas, (), (override));
   MOCK_METHOD(fft::FftSupport*, AsFft, (), (override));
   MOCK_METHOD(dnn::DnnSupport*, AsDnn, (), (override));
-  MOCK_METHOD(absl::StatusOr<std::unique_ptr<Kernel>>, CreateKernel, (),
-              (override));
   MOCK_METHOD(absl::StatusOr<std::unique_ptr<CommandBuffer>>,
               CreateCommandBuffer, (CommandBuffer::Mode mode), (override));
   MOCK_METHOD(std::optional<AllocatorStats>, GetAllocatorStats, (), (override));
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
index 49fc9c646868ae..19a367a37ec27a 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
@@ -255,9 +255,9 @@ absl::StatusOr<bool> GpuExecutor::DelayKernelIsSupported(GpuStream* stream) {
   return false;
 }
 
-absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
-                                    Kernel* kernel) {
-  GpuKernel* rocm_kernel = AsGpuKernel(kernel);
+absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::LoadKernel(
+    const MultiKernelLoaderSpec& spec) {
+  auto rocm_kernel = std::make_unique<GpuKernel>(this);
   hipModule_t module = nullptr;
   const std::string* kernel_name;
 
@@ -272,7 +272,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
     if (module == nullptr) {
       TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module));
     }
-    kernel_to_gpu_binary_[kernel] = hsaco;
+    kernel_to_gpu_binary_[rocm_kernel.get()] = hsaco;
   } else if (spec.has_in_process_symbol()) {
     kernel_name = &spec.in_process_symbol().kernel_name();
     void* symbol = spec.in_process_symbol().symbol();
@@ -310,12 +310,12 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
   // unable to get kernel metadata for in-process kernel
   if (!spec.has_in_process_symbol()) {
     KernelMetadata kernel_metadata;
-    TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata));
-    kernel->set_metadata(kernel_metadata);
+    TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel.get(), &kernel_metadata));
+    rocm_kernel->set_metadata(kernel_metadata);
   }
-  kernel->set_name(*kernel_name);
-  kernel->set_args_packing(spec.kernel_args_packing());
-  return absl::OkStatus();
+  rocm_kernel->set_name(*kernel_name);
+  rocm_kernel->set_args_packing(spec.kernel_args_packing());
+  return std::move(rocm_kernel);
 }
 
 absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel,
@@ -669,10 +669,6 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
   }
 }
 
-absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::CreateKernel() {
-  return std::make_unique<GpuKernel>(this);
-}
-
 absl::StatusOr<std::unique_ptr<CommandBuffer>> GpuExecutor::CreateCommandBuffer(
     CommandBuffer::Mode mode) {
   VLOG(2) << "Create ROCm command buffer (ROCm graph)";
diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h
index 60fc20de835fb7..53c7ab9d33a08a 100644
--- a/third_party/xla/xla/stream_executor/stream_executor.h
+++ b/third_party/xla/xla/stream_executor/stream_executor.h
@@ -107,15 +107,13 @@ class StreamExecutor {
     return AllocateArray<T>(1);
   }
 
-  // Retrieves (loads) a kernel, if one exists.
+  // Loads a kernel from a MultiKernelLoaderSpec.
   //
   // Parameters:
   //  spec: The MultiKernelLoaderSpec is usually generated as a compile-time
   //   constant into an appropriate namespace.
-  //  kernel: Outparam that the kernel is loaded into. A given Kernel
-  //   instantiation should not be loaded into more than once.
-  virtual absl::Status GetKernel(const MultiKernelLoaderSpec& spec,
-                                 Kernel* kernel) {
+  virtual absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
+      const MultiKernelLoaderSpec& spec) {
     return absl::UnimplementedError("Not Implemented");
   }
 
@@ -314,12 +312,6 @@ class StreamExecutor {
   // underlying platform.
   virtual dnn::DnnSupport* AsDnn() { return nullptr; }
 
-  // Creates a new Kernel object.
-  // TODO(klucke) Combine with GetKernel.
-  virtual absl::StatusOr<std::unique_ptr<Kernel>> CreateKernel() {
-    return absl::UnimplementedError("Kernels are not implemented");
-  }
-
   // Creates a new CommandBuffer object.
   virtual absl::StatusOr<std::unique_ptr<CommandBuffer>> CreateCommandBuffer(
      CommandBuffer::Mode mode) {
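
For reference, call sites that used the removed CreateKernel()/GetKernel() pair migrate to the single LoadKernel() call, exactly as KernelFactory::Create does above. The sketch below is illustrative only and is not part of the patch: the helper name LoadKernelExample, the variable names, and the extra includes are assumptions; it relies on the same headers kernel_factory.h already uses plus "absl/status/statusor.h" and "tsl/platform/statusor.h" for TF_ASSIGN_OR_RETURN.

    #include <memory>

    #include "absl/status/statusor.h"
    #include "xla/stream_executor/kernel.h"
    #include "xla/stream_executor/kernel_spec.h"
    #include "xla/stream_executor/stream_executor.h"
    #include "tsl/platform/statusor.h"

    namespace stream_executor {

    // Hypothetical caller showing the new single-call loading path.
    absl::StatusOr<std::unique_ptr<Kernel>> LoadKernelExample(
        StreamExecutor* executor, const MultiKernelLoaderSpec& spec) {
      // Before this change, callers needed two virtual calls:
      //   TF_ASSIGN_OR_RETURN(auto kernel, executor->CreateKernel());
      //   TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get()));
      // Now a single call creates and loads the kernel, returning an owning
      // pointer on success or an error status on failure.
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Kernel> kernel,
                          executor->LoadKernel(spec));
      return kernel;
    }

    }  // namespace stream_executor

Returning the owning pointer directly also removes the partially initialized Kernel state that the old two-step create-then-load protocol allowed.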