From 011cbc40c47ab5a3997d0c53f14f989da336bce1 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Sat, 30 Nov 2024 01:29:08 -0800 Subject: [PATCH] Make check_feature_gate_key PT2 compatible Summary: Add a new API for `check_feature_gate_key` that is PT2 compatible. PT2 complains when an op does not take/return a tensor. Thus, `check_feature_gate_key_pt2` (the new API) takes a dummy tensor as an input and returns a boolean tensor as an output. Differential Revision: D66611784 --- fbgemm_gpu/src/config/feature_gates.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/fbgemm_gpu/src/config/feature_gates.cpp b/fbgemm_gpu/src/config/feature_gates.cpp index e0e07a12d0..cb3aad5925 100644 --- a/fbgemm_gpu/src/config/feature_gates.cpp +++ b/fbgemm_gpu/src/config/feature_gates.cpp @@ -65,6 +65,21 @@ DLL_PUBLIC bool check_feature_gate_key(const std::string& key) { } } +DLL_PUBLIC at::Tensor check_feature_gate_key_pt2( + at::Tensor& tensor, + const std::string& key) { + auto output = at::empty({1}, tensor.options().dtype(at::kBool)); + output.data_ptr()[0] = check_feature_gate_key(key); + return output; +} + +DLL_PUBLIC at::Tensor check_feature_gate_key_pt2_meta( + at::Tensor& tensor, + const std::string& key) { + auto output = at::empty({1}, tensor.options().dtype(at::kBool)); + return output; +} + DLL_PUBLIC bool is_feature_enabled(const FeatureGateName& feature) { return check_feature_gate_key(to_string(feature)); } @@ -81,4 +96,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "check_feature_gate_key(str key) -> bool", fbgemm_gpu::config::check_feature_gate_key); + m.def("check_feature_gate_key_pt2(Tensor tensor, str key) -> Tensor"); + DISPATCH_TO_CPU( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2); + DISPATCH_TO_CUDA( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2); + DISPATCH_TO_META( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2_meta); }