diff --git a/server/etcdserver/txn/txn.go b/server/etcdserver/txn/txn.go index 9e8d397cb88e..22d232a864f8 100644 --- a/server/etcdserver/txn/txn.go +++ b/server/etcdserver/txn/txn.go @@ -37,55 +37,40 @@ func Put(ctx context.Context, lg *zap.Logger, lessor lease.Lessor, kv mvcc.KV, p traceutil.Field{Key: "key", Value: string(p.Key)}, traceutil.Field{Key: "req_size", Value: p.Size()}, ) - leaseID := lease.LeaseID(p.Lease) - if leaseID != lease.NoLease { - if l := lessor.Lookup(leaseID); l == nil { - return nil, nil, lease.ErrLeaseNotFound - } + err = checkLease(lessor, p) + if err != nil { + return nil, trace, err } txnWrite := kv.Write(trace) defer txnWrite.End() - resp, err = put(ctx, txnWrite, p) - return resp, trace, err + prevKV, err := checkAndGetPrevKV(trace, txnWrite, p) + if err != nil { + return nil, trace, err + } + return put(ctx, txnWrite, p, prevKV), trace, nil } -func put(ctx context.Context, txnWrite mvcc.TxnWrite, p *pb.PutRequest) (resp *pb.PutResponse, err error) { +func put(ctx context.Context, txnWrite mvcc.TxnWrite, p *pb.PutRequest, prevKV *mvcc.RangeResult) *pb.PutResponse { trace := traceutil.Get(ctx) - resp = &pb.PutResponse{} + resp := &pb.PutResponse{} resp.Header = &pb.ResponseHeader{} val, leaseID := p.Value, lease.LeaseID(p.Lease) - var rr *mvcc.RangeResult - if p.IgnoreValue || p.IgnoreLease || p.PrevKv { - trace.StepWithFunction(func() { - rr, err = txnWrite.Range(context.TODO(), p.Key, nil, mvcc.RangeOptions{}) - }, "get previous kv pair") - - if err != nil { - return nil, err - } - } - if p.IgnoreValue || p.IgnoreLease { - if rr == nil || len(rr.KVs) == 0 { - // ignore_{lease,value} flag expects previous key-value pair - return nil, errors.ErrKeyNotFound - } - } if p.IgnoreValue { - val = rr.KVs[0].Value + val = prevKV.KVs[0].Value } if p.IgnoreLease { - leaseID = lease.LeaseID(rr.KVs[0].Lease) + leaseID = lease.LeaseID(prevKV.KVs[0].Lease) } if p.PrevKv { - if rr != nil && len(rr.KVs) != 0 { - resp.PrevKv = &rr.KVs[0] + if prevKV != nil && len(prevKV.KVs) != 0 { + resp.PrevKv = &prevKV.KVs[0] } } resp.Header.Revision = txnWrite.Put(p.Key, val, leaseID) trace.AddField(traceutil.Field{Key: "response_revision", Value: resp.Header.Revision}) - return resp, nil + return resp } func DeleteRange(ctx context.Context, lg *zap.Logger, kv mvcc.KV, dr *pb.DeleteRangeRequest) (resp *pb.DeleteRangeResponse, trace *traceutil.Trace, err error) { @@ -255,7 +240,7 @@ func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWit if isWrite { trace.AddField(traceutil.Field{Key: "read_only", Value: false}) } - _, err = checkTxn(txnRead, rt, lessor, txnPath) + _, err = checkTxn(trace, txnRead, rt, lessor, txnPath) if err != nil { txnRead.End() return nil, nil, err @@ -362,10 +347,11 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt traceutil.Field{Key: "req_type", Value: "put"}, traceutil.Field{Key: "key", Value: string(tv.RequestPut.Key)}, traceutil.Field{Key: "req_size", Value: tv.RequestPut.Size()}) - resp, err := put(ctx, txnWrite, tv.RequestPut) + prevKV, err := getPrevKV(trace, txnWrite, tv.RequestPut) if err != nil { - return 0, fmt.Errorf("applyTxn: failed Put: %w", err) + return 0, fmt.Errorf("applyTxn: failed to get prevKV on put: %w", err) } + resp := put(ctx, txnWrite, tv.RequestPut, prevKV) respi.(*pb.ResponseOp_ResponsePut).ResponsePut = resp trace.StopSubTrace() case *pb.RequestOp_RequestDeleteRange: @@ -390,25 +376,52 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt return txns, nil } -func checkPut(rv mvcc.ReadView, lessor lease.Lessor, req *pb.PutRequest) error { - if req.IgnoreValue || req.IgnoreLease { - // expects previous key-value, error if not exist - rr, err := rv.Range(context.TODO(), req.Key, nil, mvcc.RangeOptions{}) - if err != nil { - return err - } - if rr == nil || len(rr.KVs) == 0 { - return errors.ErrKeyNotFound - } +func checkPut(trace *traceutil.Trace, txnWrite mvcc.ReadView, lessor lease.Lessor, p *pb.PutRequest) error { + err := checkLease(lessor, p) + if err != nil { + return err } - if lease.LeaseID(req.Lease) != lease.NoLease { - if l := lessor.Lookup(lease.LeaseID(req.Lease)); l == nil { + _, err = checkAndGetPrevKV(trace, txnWrite, p) + return err +} + +func checkLease(lessor lease.Lessor, p *pb.PutRequest) error { + leaseID := lease.LeaseID(p.Lease) + if leaseID != lease.NoLease { + if l := lessor.Lookup(leaseID); l == nil { return lease.ErrLeaseNotFound } } return nil } +func checkAndGetPrevKV(trace *traceutil.Trace, txnWrite mvcc.ReadView, p *pb.PutRequest) (prevKV *mvcc.RangeResult, err error) { + prevKV, err = getPrevKV(trace, txnWrite, p) + if err != nil { + return nil, err + } + if p.IgnoreValue || p.IgnoreLease { + if prevKV == nil || len(prevKV.KVs) == 0 { + // ignore_{lease,value} flag expects previous key-value pair + return nil, errors.ErrKeyNotFound + } + } + return prevKV, nil +} + +func getPrevKV(trace *traceutil.Trace, txnWrite mvcc.ReadView, p *pb.PutRequest) (prevKV *mvcc.RangeResult, err error) { + if p.IgnoreValue || p.IgnoreLease || p.PrevKv { + trace.StepWithFunction(func() { + prevKV, err = txnWrite.Range(context.TODO(), p.Key, nil, mvcc.RangeOptions{}) + }, "get previous kv pair") + + if err != nil { + return nil, err + } + } + return prevKV, nil +} + func checkRange(rv mvcc.ReadView, req *pb.RangeRequest) error { switch { case req.Revision == 0: @@ -421,7 +434,7 @@ func checkRange(rv mvcc.ReadView, req *pb.RangeRequest) error { return nil } -func checkTxn(rv mvcc.ReadView, rt *pb.TxnRequest, lessor lease.Lessor, txnPath []bool) (int, error) { +func checkTxn(trace *traceutil.Trace, rv mvcc.ReadView, rt *pb.TxnRequest, lessor lease.Lessor, txnPath []bool) (int, error) { txnCount := 0 reqs := rt.Success if !txnPath[0] { @@ -434,10 +447,10 @@ func checkTxn(rv mvcc.ReadView, rt *pb.TxnRequest, lessor lease.Lessor, txnPath case *pb.RequestOp_RequestRange: err = checkRange(rv, tv.RequestRange) case *pb.RequestOp_RequestPut: - err = checkPut(rv, lessor, tv.RequestPut) + err = checkPut(trace, rv, lessor, tv.RequestPut) case *pb.RequestOp_RequestDeleteRange: case *pb.RequestOp_RequestTxn: - txns, err = checkTxn(rv, tv.RequestTxn, lessor, txnPath[1:]) + txns, err = checkTxn(trace, rv, tv.RequestTxn, lessor, txnPath[1:]) txnCount += txns + 1 txnPath = txnPath[txns+1:] default: