Combine StreamExecutor::GetKernel and ::CreateKernel into a single new method ::LoadKernel.

PiperOrigin-RevId: 657249905
klucke authored and tensorflower-gardener committed Jul 29, 2024
1 parent 9b8e6bb commit 2cf91be
Showing 9 changed files with 36 additions and 70 deletions.
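For reference, the caller-side effect of the change, mirroring the KernelFactory::Create diff below. A minimal sketch, assuming an existing StreamExecutor* executor and MultiKernelLoaderSpec spec; the two helper functions are illustrative only and correspond to the pre- and post-commit APIs, so they would not compile together against a single revision.

#include <memory>

#include "absl/status/statusor.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {

// Before this commit: two virtual calls, with GetKernel filling in the empty
// Kernel object produced by CreateKernel.
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernelOldStyle(
    StreamExecutor* executor, const MultiKernelLoaderSpec& spec) {
  TF_ASSIGN_OR_RETURN(auto kernel, executor->CreateKernel());
  TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get()));
  return kernel;
}

// After this commit: a single virtual call that constructs, loads, and
// returns the Kernel.
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernelNewStyle(
    StreamExecutor* executor, const MultiKernelLoaderSpec& spec) {
  return executor->LoadKernel(spec);
}

}  // namespace stream_executor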
4 changes: 2 additions & 2 deletions third_party/xla/xla/backends/interpreter/executor.h
@@ -86,8 +86,8 @@ class XlaInterpreterExecutor : public StreamExecutorCommon {
absl::Status Init() override { return absl::OkStatus(); }

int device_ordinal() const override { return device_ordinal_; };
absl::Status GetKernel(const MultiKernelLoaderSpec &spec,
Kernel *kernel) override {
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec &spec) override {
return absl::UnimplementedError("Not Implemented");
}
absl::Status Launch(Stream *stream, const ThreadDim &thread_dims,
28 changes: 11 additions & 17 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
@@ -216,21 +216,19 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco,
"Feature not supported on CUDA platform (LoadModuleFromHsaco)");
}

absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
Kernel* kernel) {
GpuKernel* cuda_kernel = AsGpuKernel(kernel);
absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::LoadKernel(
const MultiKernelLoaderSpec& spec) {
auto cuda_kernel = std::make_unique<GpuKernel>(this);
CUmodule module;
const std::string* kernel_name;

VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();

if (spec.has_cuda_cubin_in_memory()) {
absl::MutexLock lock{&in_memory_modules_mu_};
kernel_name = &spec.cuda_cubin_in_memory().kernel_name();
const char* cubin = reinterpret_cast<const char*>(
spec.cuda_cubin_in_memory().cubin_bytes().data());
TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module));
kernel_to_gpu_binary_[kernel] = cubin;
kernel_to_gpu_binary_[cuda_kernel.get()] = cubin;

} else if (spec.has_cuda_ptx_in_memory()) {
kernel_name = &spec.cuda_ptx_in_memory().kernel_name();
@@ -249,7 +247,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,

absl::MutexLock lock{&in_memory_modules_mu_};
TF_RETURN_IF_ERROR(LoadModuleFromPtx(ptx, &module));
kernel_to_gpu_binary_[kernel] = ptx;
kernel_to_gpu_binary_[cuda_kernel.get()] = ptx;

} else if (spec.has_in_process_symbol()) {
kernel_name = &spec.in_process_symbol().kernel_name();
@@ -265,7 +263,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
} else {
return absl::InternalError("No method of loading CUDA kernel provided");
}

VLOG(3) << "LoadKernel on kernel : " << *kernel_name;
// If we resolved kernel from a symbol pointer, there is no need to load it
// from a module, as CUDA runtime did that automatically for us.
if (!spec.has_in_process_symbol()) {
@@ -284,11 +282,11 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
cuda_kernel->set_arity(spec.arity());

KernelMetadata kernel_metadata;
TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel, &kernel_metadata));
kernel->set_metadata(kernel_metadata);
kernel->set_name(*kernel_name);
kernel->set_args_packing(spec.kernel_args_packing());
return absl::OkStatus();
TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel.get(), &kernel_metadata));
cuda_kernel->set_metadata(kernel_metadata);
cuda_kernel->set_name(*kernel_name);
cuda_kernel->set_args_packing(spec.kernel_args_packing());
return std::move(cuda_kernel);
}

absl::StatusOr<std::unique_ptr<EventBasedTimer>>
@@ -793,10 +791,6 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
}
}

absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::CreateKernel() {
return std::make_unique<GpuKernel>(this);
}

absl::StatusOr<std::unique_ptr<CommandBuffer>> GpuExecutor::CreateCommandBuffer(
CommandBuffer::Mode mode) {
VLOG(2) << "Create CUDA command buffer (CUDA graph)";
6 changes: 2 additions & 4 deletions third_party/xla/xla/stream_executor/gpu/gpu_executor.h
@@ -122,8 +122,8 @@ class GpuExecutor : public StreamExecutorCommon {

int device_ordinal() const override { return device_ordinal_; };

absl::Status GetKernel(const MultiKernelLoaderSpec& spec,
Kernel* kernel) override;
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec& spec) override;

// (supported on CUDA only)
void UnloadKernel(const Kernel* kernel) override;
@@ -240,8 +240,6 @@ class GpuExecutor : public StreamExecutorCommon {
std::optional<std::variant<StreamPriority, int>> priority =
std::nullopt) override;

absl::StatusOr<std::unique_ptr<Kernel>> CreateKernel() override;

absl::StatusOr<std::unique_ptr<CommandBuffer>> CreateCommandBuffer(
CommandBuffer::Mode mode) override;

14 changes: 4 additions & 10 deletions third_party/xla/xla/stream_executor/host/host_executor.cc
@@ -74,24 +74,18 @@ absl::Status HostExecutor::Init() {
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<Kernel>> HostExecutor::CreateKernel() {
return std::make_unique<HostKernel>(thread_pool_);
}

absl::Status HostExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
Kernel* kernel) {
HostKernel* host_kernel = AsHostKernel(kernel);
absl::StatusOr<std::unique_ptr<Kernel>> HostExecutor::LoadKernel(
const MultiKernelLoaderSpec& spec) {
auto host_kernel = std::make_unique<HostKernel>(thread_pool_);
host_kernel->SetArity(spec.arity());

VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();

for (auto& loader : KernelFunctionLoaderRegistry()) {
auto loaded = loader(spec);
if (!loaded.has_value()) continue;

TF_ASSIGN_OR_RETURN(auto kernel_function, *std::move(loaded));
host_kernel->SetKernelFunction(std::move(kernel_function));
return absl::OkStatus();
return std::move(host_kernel);
}

return absl::InternalError("No method of loading host kernel provided");
6 changes: 2 additions & 4 deletions third_party/xla/xla/stream_executor/host/host_executor.h
@@ -70,10 +70,8 @@ class HostExecutor : public StreamExecutorCommon {

absl::Status Init() override;

absl::Status GetKernel(const MultiKernelLoaderSpec& spec,
Kernel* kernel) override;

absl::StatusOr<std::unique_ptr<Kernel>> CreateKernel() override;
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec& spec) override;

absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
const BlockDim& block_dims, const Kernel& kernel,
6 changes: 1 addition & 5 deletions third_party/xla/xla/stream_executor/kernel_factory.h
@@ -22,8 +22,6 @@ limitations under the License.
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {

@@ -33,9 +31,7 @@ class KernelFactory {
// Creates kernel on a given executor from a given kernel specification.
static inline absl::StatusOr<std::unique_ptr<Kernel>> Create(
StreamExecutor *executor, const MultiKernelLoaderSpec &spec) {
TF_ASSIGN_OR_RETURN(auto kernel, executor->CreateKernel());
TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get()));
return kernel;
return executor->LoadKernel(spec);
}
};

6 changes: 2 additions & 4 deletions third_party/xla/xla/stream_executor/mock_stream_executor.h
@@ -59,8 +59,8 @@ class MockStreamExecutor : public StreamExecutor {
MockStreamExecutor() = default;
MOCK_METHOD(absl::Status, Init, (), (override));
MOCK_METHOD(int, device_ordinal, (), (const, override));
MOCK_METHOD(absl::Status, GetKernel,
(const MultiKernelLoaderSpec& spec, Kernel* kernel), (override));
MOCK_METHOD(absl::StatusOr<std::unique_ptr<Kernel>>, LoadKernel,
(const MultiKernelLoaderSpec& spec), (override));
MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override));
MOCK_METHOD(absl::Status, LoadModule,
(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle),
@@ -124,8 +124,6 @@ class MockStreamExecutor : public StreamExecutor {
MOCK_METHOD(blas::BlasSupport*, AsBlas, (), (override));
MOCK_METHOD(fft::FftSupport*, AsFft, (), (override));
MOCK_METHOD(dnn::DnnSupport*, AsDnn, (), (override));
MOCK_METHOD(absl::StatusOr<std::unique_ptr<Kernel>>, CreateKernel, (),
(override));
MOCK_METHOD(absl::StatusOr<std::unique_ptr<CommandBuffer>>,
CreateCommandBuffer, (CommandBuffer::Mode mode), (override));
MOCK_METHOD(std::optional<AllocatorStats>, GetAllocatorStats, (), (override));
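Test code that previously set expectations on CreateKernel and GetKernel now stubs the single LoadKernel method. A minimal, hypothetical test sketch (not part of this commit) using gMock; since the return type wraps a move-only std::unique_ptr, ByMove is required, and the arity-taking MultiKernelLoaderSpec constructor is assumed.

#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
#include "xla/stream_executor/mock_stream_executor.h"

namespace stream_executor {
namespace {

using ::testing::_;
using ::testing::ByMove;
using ::testing::Return;

TEST(MockStreamExecutorTest, StubsLoadKernel) {
  MockStreamExecutor executor;
  // One expectation replaces the old CreateKernel + GetKernel pair.
  EXPECT_CALL(executor, LoadKernel(_))
      .WillOnce(Return(ByMove(absl::StatusOr<std::unique_ptr<Kernel>>(
          absl::UnimplementedError("kernels disabled in this test")))));

  MultiKernelLoaderSpec spec(/*arity=*/0);  // Assumed constructor.
  EXPECT_FALSE(executor.LoadKernel(spec).ok());
}

}  // namespace
}  // namespace stream_executor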
22 changes: 9 additions & 13 deletions third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
@@ -255,9 +255,9 @@ absl::StatusOr<bool> GpuExecutor::DelayKernelIsSupported(GpuStream* stream) {
return false;
}

absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
Kernel* kernel) {
GpuKernel* rocm_kernel = AsGpuKernel(kernel);
absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::LoadKernel(
const MultiKernelLoaderSpec& spec) {
auto rocm_kernel = std::make_unique<GpuKernel>(this);
hipModule_t module = nullptr;
const std::string* kernel_name;

@@ -272,7 +272,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
if (module == nullptr) {
TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module));
}
kernel_to_gpu_binary_[kernel] = hsaco;
kernel_to_gpu_binary_[rocm_kernel.get()] = hsaco;
} else if (spec.has_in_process_symbol()) {
kernel_name = &spec.in_process_symbol().kernel_name();
void* symbol = spec.in_process_symbol().symbol();
@@ -310,12 +310,12 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
// unable to get kernel metadata for in-process kernel
if (!spec.has_in_process_symbol()) {
KernelMetadata kernel_metadata;
TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata));
kernel->set_metadata(kernel_metadata);
TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel.get(), &kernel_metadata));
rocm_kernel->set_metadata(kernel_metadata);
}
kernel->set_name(*kernel_name);
kernel->set_args_packing(spec.kernel_args_packing());
return absl::OkStatus();
rocm_kernel->set_name(*kernel_name);
rocm_kernel->set_args_packing(spec.kernel_args_packing());
return std::move(rocm_kernel);
}

absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel,
@@ -669,10 +669,6 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
}
}

absl::StatusOr<std::unique_ptr<Kernel>> GpuExecutor::CreateKernel() {
return std::make_unique<GpuKernel>(this);
}

absl::StatusOr<std::unique_ptr<CommandBuffer>> GpuExecutor::CreateCommandBuffer(
CommandBuffer::Mode mode) {
VLOG(2) << "Create ROCm command buffer (ROCm graph)";
14 changes: 3 additions & 11 deletions third_party/xla/xla/stream_executor/stream_executor.h
@@ -107,15 +107,13 @@ class StreamExecutor {
return AllocateArray<T>(1);
}

// Retrieves (loads) a kernel, if one exists.
// Loads a kernel from a MultiKernelLoaderSpec.
//
// Parameters:
// spec: The MultiKernelLoaderSpec is usually generated as a compile-time
// constant into an appropriate namespace.
// kernel: Outparam that the kernel is loaded into. A given Kernel
// instantiation should not be loaded into more than once.
virtual absl::Status GetKernel(const MultiKernelLoaderSpec& spec,
Kernel* kernel) {
virtual absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec& spec) {
return absl::UnimplementedError("Not Implemented");
}

@@ -314,12 +312,6 @@ class StreamExecutor {
// underlying platform.
virtual dnn::DnnSupport* AsDnn() { return nullptr; }

// Creates a new Kernel object.
// TODO(klucke) Combine with GetKernel.
virtual absl::StatusOr<std::unique_ptr<Kernel>> CreateKernel() {
return absl::UnimplementedError("Kernels are not implemented");
}

// Creates a new CommandBuffer object.
virtual absl::StatusOr<std::unique_ptr<CommandBuffer>> CreateCommandBuffer(
CommandBuffer::Mode mode) {
