@@ -171,7 +171,8 @@ Tensor split_embedding_codegen_lookup_dense_function(
171
171
Tensor>& /* vbe_B_offsets_rank_per_feature = std::nullopt */ ,
172
172
c10::SymInt /* max_B = -1 */ ,
173
173
c10::SymInt /* max_B_feature_rank = -1 */ ,
174
- c10::SymInt /* vbe_output_size = -1 */ ) {
174
+ c10::SymInt /* vbe_output_size = -1 */ ,
175
+ bool /* mixed_D = false */ ) {
175
176
return SplitLookupFunction_Dense_Op::apply (
176
177
host_weights,
177
178
weights_offsets,
@@ -190,15 +191,15 @@ Tensor split_embedding_codegen_lookup_dense_function(
190
191
// Deprecated for fb namespace! Please use fbgemm namespace instead!
191
192
TORCH_LIBRARY_FRAGMENT (fb, m) {
192
193
m.def (
193
- " dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor" );
194
+ " dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False ) -> Tensor" );
194
195
DISPATCH_TO_CPU (
195
196
" dense_embedding_codegen_lookup_function" ,
196
197
split_embedding_codegen_lookup_dense_function);
197
198
}
198
199
199
200
TORCH_LIBRARY_FRAGMENT (fbgemm, m) {
200
201
m.def (
201
- " dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor" );
202
+ " dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False ) -> Tensor" );
202
203
DISPATCH_TO_CPU (
203
204
" dense_embedding_codegen_lookup_function" ,
204
205
split_embedding_codegen_lookup_dense_function);
0 commit comments