Skip to content

Commit

Permalink
Fix segfault in XlaCallModule shape inference
Browse files Browse the repository at this point in the history
Entries in `xla_call_module_loaders_` were added eagerly with nullptr values (to save a second lookup on a miss), but this meant an entry could permanently hold a nullptr loader if loader initialization failed, causing a segfault on the next lookup. This CL changes the logic so that only non-nullptr loaders are inserted into the map.

Confirmed that, without the fix, shape inference segfaults on the added MLIR test; with the fix it passes.

PiperOrigin-RevId: 547656759
  • Loading branch information
junwhanahn authored and tensorflower-gardener committed Jul 13, 2023
1 parent a310a86 commit 1c0b86b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
func.return %0 : tensor<*xf32>
}

// Regression test: the `module` attribute is deliberately not a valid
// serialized StableHLO module, so the XlaCallModule loader fails to
// initialize. Shape inference must bail out gracefully (leaving the result
// type unrefined) rather than caching a null loader and segfaulting on a
// later lookup. Intentionally has no FileCheck directives — the test passes
// as long as the shape-inference pass does not crash.
func.func @xla_call_module_parsing_error(%arg0: tensor<f32>) -> tensor<*xf32> {
%0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<*xf32>
%1 = "tf.XlaCallModule"(%arg0, %0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], module = "invalid-stablehlo-module", platforms = [], version = 4 : i64} : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
}

// CHECK-LABEL: func @xla_host_compute_mlir_empty_module
func.func @xla_host_compute_mlir_empty_module(%arg0: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: "tf._XlaHostComputeMlir"
Expand Down
78 changes: 39 additions & 39 deletions tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1191,49 +1191,49 @@ bool ShapeInference::InferShapeForCaseRegion(CaseRegionOp op) {

bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) {
tensorflow::XlaCallModuleLoader* loader;
{
const auto [it, inserted] = xla_call_module_loaders_.insert({op, nullptr});

if (auto it = xla_call_module_loaders_.find(op);
it != xla_call_module_loaders_.end()) {
loader = it->second.get();
} else {
// Lazily parse XlaCallModule's embedded HLO module and cache the loader to
// avoid repeatedly parsing the module.
if (inserted) {
std::vector<std::string> dim_args_spec;
for (auto attr : op.getDimArgsSpec().getAsRange<StringAttr>()) {
dim_args_spec.push_back(attr.getValue().str());
}
std::vector<std::string> disabled_checks;
for (auto attr : op.getDisabledChecks().getAsRange<StringAttr>()) {
disabled_checks.push_back(attr.getValue().str());
}
std::vector<std::string> platforms;
for (auto attr : op.getPlatforms().getAsRange<StringAttr>()) {
platforms.push_back(attr.getValue().str());
}
// Always use the first platform. The assumption is that shape inference
// results should be the same regardless of which platform is chosen.
// Very old versions of the op have an empty platforms attribute.
std::string loading_platform =
(platforms.empty() ? "CPU" : platforms.front());

// It is a terrible idea to have local MLIR contexts so we need to
// register extensions here, again.
mlir::DialectRegistry registry;
registry.insert<mlir::func::FuncDialect>();
mlir::func::registerAllExtensions(registry);
xla_call_module_context_.appendDialectRegistry(registry);

auto l = tensorflow::XlaCallModuleLoader::Create(
&xla_call_module_context_, op.getVersion(), op.getModule().str(),
std::move(dim_args_spec), std::move(disabled_checks),
std::move(platforms), std::move(loading_platform));
if (!l.ok()) {
LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: "
<< l.status().ToString() << "\n");
return false;
}
it->second = *std::move(l);

std::vector<std::string> dim_args_spec;
for (auto attr : op.getDimArgsSpec().getAsRange<StringAttr>()) {
dim_args_spec.push_back(attr.getValue().str());
}
std::vector<std::string> disabled_checks;
for (auto attr : op.getDisabledChecks().getAsRange<StringAttr>()) {
disabled_checks.push_back(attr.getValue().str());
}
std::vector<std::string> platforms;
for (auto attr : op.getPlatforms().getAsRange<StringAttr>()) {
platforms.push_back(attr.getValue().str());
}
// Always use the first platform. The assumption is that shape inference
// results should be the same regardless of which platform is chosen.
// Very old versions of the op have an empty platforms attribute.
std::string loading_platform =
(platforms.empty() ? "CPU" : platforms.front());

// It is a terrible idea to have local MLIR contexts so we need to
// register extensions here, again.
mlir::DialectRegistry registry;
registry.insert<mlir::func::FuncDialect>();
mlir::func::registerAllExtensions(registry);
xla_call_module_context_.appendDialectRegistry(registry);

auto l = tensorflow::XlaCallModuleLoader::Create(
&xla_call_module_context_, op.getVersion(), op.getModule().str(),
std::move(dim_args_spec), std::move(disabled_checks),
std::move(platforms), std::move(loading_platform));
if (!l.ok()) {
LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: "
<< l.status().ToString() << "\n");
return false;
}

it = xla_call_module_loaders_.insert({op, *std::move(l)}).first;
loader = it->second.get();
}

Expand Down

0 comments on commit 1c0b86b

Please sign in to comment.