Skip to content

Commit

Permalink
Make check_feature_gate_key PT2 compatible
Browse files Browse the repository at this point in the history
Summary:
Add a new API for `check_feature_gate_key` that is PT2 compatible.
PT2 complains when an op does not take/return a tensor.  Thus,
`check_feature_gate_key_pt2` (the new API) takes a dummy tensor as an
input and returns a boolean tensor as an output.

Differential Revision: D66611784
  • Loading branch information
sryap authored and facebook-github-bot committed Nov 30, 2024
1 parent 6eb379a commit 011cbc4
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions fbgemm_gpu/src/config/feature_gates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ DLL_PUBLIC bool check_feature_gate_key(const std::string& key) {
}
}

DLL_PUBLIC at::Tensor check_feature_gate_key_pt2(
at::Tensor& tensor,
const std::string& key) {
auto output = at::empty({1}, tensor.options().dtype(at::kBool));
output.data_ptr<bool>()[0] = check_feature_gate_key(key);
return output;
}

DLL_PUBLIC at::Tensor check_feature_gate_key_pt2_meta(
at::Tensor& tensor,
const std::string& key) {
auto output = at::empty({1}, tensor.options().dtype(at::kBool));
return output;
}

DLL_PUBLIC bool is_feature_enabled(const FeatureGateName& feature) {
return check_feature_gate_key(to_string(feature));
}
Expand All @@ -81,4 +96,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"check_feature_gate_key(str key) -> bool",
fbgemm_gpu::config::check_feature_gate_key);
m.def("check_feature_gate_key_pt2(Tensor tensor, str key) -> Tensor");
DISPATCH_TO_CPU(
"check_feature_gate_key_pt2",
fbgemm_gpu::config::check_feature_gate_key_pt2);
DISPATCH_TO_CUDA(
"check_feature_gate_key_pt2",
fbgemm_gpu::config::check_feature_gate_key_pt2);
DISPATCH_TO_META(
"check_feature_gate_key_pt2",
fbgemm_gpu::config::check_feature_gate_key_pt2_meta);
}

0 comments on commit 011cbc4

Please sign in to comment.