@@ -438,6 +438,7 @@ template <typename TKey, typename TValue>
438
438
class KvResourceGatherOp : public OpKernel {
439
439
public:
440
440
explicit KvResourceGatherOp (OpKernelConstruction* c) : OpKernel(c) {
441
+ OP_REQUIRES_OK (c, c->GetAttr (" is_inference" , &is_inference_));
441
442
OP_REQUIRES_OK (c,
442
443
c->GetAttr (" is_use_default_value_tensor" ,
443
444
&is_use_default_value_tensor_));
@@ -461,6 +462,17 @@ class KvResourceGatherOp : public OpKernel {
461
462
return 1 ;
462
463
};
463
464
}
465
+ if (!is_inference_) {
466
+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
467
+ TValue* val, TValue* default_v, int count) {
468
+ ev->LookupOrCreate (key, val, default_v, count);
469
+ };
470
+ } else {
471
+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
472
+ TValue* val, TValue* default_v, int count) {
473
+ ev->Lookup (key, val, default_v);
474
+ };
475
+ }
464
476
}
465
477
466
478
void Compute (OpKernelContext* c) override {
@@ -511,7 +523,7 @@ class KvResourceGatherOp : public OpKernel {
511
523
default_v, indices_flat (i), i, ev->GetDefaultValueDim (),
512
524
ev->ValueLen ());
513
525
int32 count = get_count_fn_ (counts, i);
514
- ev-> LookupOrCreate ( indices_flat (i),
526
+ lookup_fn_ (ev, indices_flat (i),
515
527
out_base + i * slice_elems, default_v_ptr, count);
516
528
}
517
529
};
@@ -530,9 +542,12 @@ class KvResourceGatherOp : public OpKernel {
530
542
531
543
private:
532
544
bool is_use_default_value_tensor_;
545
+ bool is_inference_;
533
546
std::function<
534
547
TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
535
548
std::function<int32(int32*, int64)> get_count_fn_;
549
+ std::function<void (EmbeddingVar<TKey, TValue>* ev,
550
+ TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
536
551
};
537
552
538
553
#define REGISTER_GATHER_FULL (dev, ktype, vtype ) \
0 commit comments