[XLA:GPU] Move singleton to HloOpProfiles.
Profiler data is already a singleton inside GpuHloCostAnalysis. This way we'll be able to use it in different places.

PiperOrigin-RevId: 603326939
olegshyshkov authored and tensorflower-gardener committed Feb 1, 2024
1 parent 131f388 commit a68de63
Showing 4 changed files with 18 additions and 11 deletions.
10 changes: 5 additions & 5 deletions third_party/xla/xla/service/gpu/model/BUILD
@@ -141,10 +141,7 @@ xla_cc_test(
 cc_library(
     name = "gpu_hlo_cost_analysis",
     srcs = ["gpu_hlo_cost_analysis.cc"],
-    hdrs = [
-        "gpu_hlo_cost_analysis.h",
-        "hlo_op_profiles_data.h",
-    ],
+    hdrs = ["gpu_hlo_cost_analysis.h"],
     compatible_with = get_compatible_with_portable(),
     visibility = ["//visibility:public"],
     deps = [
@@ -452,7 +449,10 @@ tf_proto_library(
 cc_library(
     name = "hlo_op_profiles",
     srcs = ["hlo_op_profiles.cc"],
-    hdrs = ["hlo_op_profiles.h"],
+    hdrs = [
+        "hlo_op_profiles.h",
+        "hlo_op_profiles_data.h",
+    ],
     compatible_with = get_compatible_with_portable(),
     visibility = ["//visibility:public"],
     deps = [
7 changes: 1 addition & 6 deletions third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc
@@ -41,7 +41,6 @@ limitations under the License.
 #include "xla/service/gpu/cublas_cudnn.h"
 #include "xla/service/gpu/model/hlo_op_profile.pb.h"
 #include "xla/service/gpu/model/hlo_op_profiles.h"
-#include "xla/service/gpu/model/hlo_op_profiles_data.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/shape.h"
@@ -327,11 +326,7 @@ int64_t GpuHloCostAnalysis::GetConvolutionFlops(
 
 int64_t FlopsPerElement(const se::DeviceDescription* device_info,
                         const PrimitiveType type, const HloOpcode opcode) {
-  static const auto* hlo_op_profiles =
-      HloOpProfiles::Load(kDeviceHloOpProfiles,
-                          /*default_profile_name=*/"sm_86")
-          .release();
-  auto device_profile = hlo_op_profiles->GetProfile(device_info);
+  auto device_profile = HloOpProfiles::Singleton().GetProfile(device_info);
   // Elementwise instructions typically take at least a few clock cycles.
   constexpr int64_t kDefaultFlopsPerElement = 3;
   return FindOrDefault(device_profile, std::make_pair(opcode, type),
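With loading folded into HloOpProfiles::Singleton(), the commit's stated goal of using the profile data in different places reduces to a one-line lookup at any call site. Below is a minimal hypothetical caller in another xla::gpu translation unit; the function HasMeasuredCost and its semantics are illustrative and not part of this commit, while Singleton() and GetProfile() come from the diff above.

// Hypothetical caller in a different xla::gpu component (illustrative only).
#include <utility>

#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/model/hlo_op_profiles.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla_data.pb.h"

namespace xla {
namespace gpu {

// Returns true if the profile has a measured entry for (opcode, type) on the
// given device, i.e. FlopsPerElement would not fall back to its default.
bool HasMeasuredCost(const se::DeviceDescription* device_info,
                     HloOpcode opcode, PrimitiveType type) {
  // Shared, lazily initialized profile data; the same object the cost
  // analysis now reads through HloOpProfiles::Singleton().
  const auto& profile = HloOpProfiles::Singleton().GetProfile(device_info);
  return profile.find(std::make_pair(opcode, type)) != profile.end();
}

}  // namespace gpu
}  // namespace xla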
9 changes: 9 additions & 0 deletions third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc
@@ -24,13 +24,22 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/model/hlo_op_profiles_data.h"
 #include "xla/stream_executor/device_description.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/protobuf.h"
 
 namespace xla {
 namespace gpu {
 
+/*static*/ const HloOpProfiles& HloOpProfiles::Singleton() {
+  static const auto* hlo_op_profiles =
+      HloOpProfiles::Load(kDeviceHloOpProfiles,
+                          /*default_profile_name=*/"sm_86")
+          .release();
+  return *hlo_op_profiles;
+}
+
 /*static*/ std::string HloOpProfiles::GetProfileName(
     const se::DeviceDescription* device_info) {
   if (device_info != nullptr) {
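The new Singleton() body is the usual function-local static with an intentional leak: the profiles are parsed once on first use, and calling release() on the unique_ptr returned by Load() means the object is never destroyed, so no destructor runs during static shutdown. A stripped-down, self-contained sketch of the same idiom; the types and loader here are stand-ins, not XLA code.

#include <memory>
#include <string>

// Stand-in for expensive-to-build, program-lifetime data (illustrative only).
struct HeavyData {
  std::string payload;
};

std::unique_ptr<HeavyData> LoadHeavyData() {
  return std::make_unique<HeavyData>(HeavyData{"parsed profile text"});
}

const HeavyData& HeavyDataSingleton() {
  // Function-local statics are initialized exactly once, thread-safely (C++11).
  // release() deliberately leaks the object so its destructor never runs at
  // program exit, sidestepping static destruction-order issues.
  static const HeavyData* data = LoadHeavyData().release();
  return *data;
}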
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/model/hlo_op_profiles.h
@@ -40,6 +40,9 @@ class HloOpProfiles {
       absl::flat_hash_map<std::string,  // compute capability.
                           HloOpProfile>;
 
+  // Returns singleton with profiler data.
+  static const HloOpProfiles& Singleton();
+
   // Returns profile name for the gived device.
   // For CUDA, the format is "sm_XX".
   static std::string GetProfileName(const se::DeviceDescription* device_info);
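GetProfileName() is the other half of the lookup: it maps a device description to a profile key such as "sm_90", with "sm_86" passed to Load() above as the default profile name (presumably the fallback when no exact match exists). A tiny hypothetical helper, just to show the intended call shape; it is not part of this commit.

#include "tsl/platform/logging.h"
#include "xla/service/gpu/model/hlo_op_profiles.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
namespace gpu {

// Hypothetical: log which per-device op profile will be consulted.
void LogProfileChoice(const se::DeviceDescription* device_info) {
  // "sm_XX" for CUDA devices, per the comment in hlo_op_profiles.h.
  LOG(INFO) << "Using HLO op profile: "
            << HloOpProfiles::GetProfileName(device_info);
}

}  // namespace gpu
}  // namespace xla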
