diff --git a/frontend/gateway/forwarder/forward.go b/frontend/gateway/forwarder/forward.go index b866de818c93..c4b32e961b0e 100644 --- a/frontend/gateway/forwarder/forward.go +++ b/frontend/gateway/forwarder/forward.go @@ -33,7 +33,6 @@ func llbBridgeToGatewayClient(ctx context.Context, llbBridge frontend.FrontendLL sid: sid, sm: sm, workers: w, - final: map[*ref]struct{}{}, workerRefByID: make(map[string]*worker.WorkerRef), } bc.buildOpts = bc.loadBuildOpts() @@ -45,7 +44,6 @@ type bridgeClient struct { mu sync.Mutex opts map[string]string inputs map[string]*opspb.Definition - final map[*ref]struct{} sid string sm *session.Manager refs []*ref @@ -194,8 +192,7 @@ func (c *bridgeClient) toFrontendResult(r *client.Result) (*frontend.Result, err if !ok { return nil, errors.Errorf("invalid reference type for forward %T", r) } - c.final[rr] = struct{}{} - res.Refs[k] = rr.ResultProxy + res.Refs[k] = rr.acquireResultProxy() } } if r := r.Ref; r != nil { @@ -203,8 +200,7 @@ func (c *bridgeClient) toFrontendResult(r *client.Result) (*frontend.Result, err if !ok { return nil, errors.Errorf("invalid reference type for forward %T", r) } - c.final[rr] = struct{}{} - res.Ref = rr.ResultProxy + res.Ref = rr.acquireResultProxy() } res.Metadata = r.Metadata @@ -218,8 +214,11 @@ func (c *bridgeClient) discard(err error) { } for _, r := range c.refs { if r != nil { - if _, ok := c.final[r]; !ok || err != nil { - r.Release(context.TODO()) + r.resultProxy.Release(context.TODO()) + if err != nil { + for _, clone := range r.resultProxyClones { + clone.Release(context.TODO()) + } } } } @@ -248,7 +247,7 @@ func (c *bridgeClient) NewContainer(ctx context.Context, req client.NewContainer return errors.Errorf("unexpected Ref type: %T", m.Ref) } - res, err := refProxy.Result(ctx) + res, err := refProxy.resultProxy.Result(ctx) if err != nil { return err } @@ -304,17 +303,26 @@ func (c *bridgeClient) NewContainer(ctx context.Context, req client.NewContainer } type ref struct { - solver.ResultProxy + resultProxy solver.ResultProxy + resultProxyClones []solver.ResultProxy + session session.Group c *bridgeClient } func (c *bridgeClient) newRef(r solver.ResultProxy, s session.Group) (*ref, error) { - return &ref{ResultProxy: r, session: s, c: c}, nil + return &ref{resultProxy: r, session: s, c: c}, nil +} + +func (r *ref) acquireResultProxy() solver.ResultProxy { + s1, s2 := solver.SplitResultProxy(r.resultProxy) + r.resultProxy = s1 + r.resultProxyClones = append(r.resultProxyClones, s2) + return s2 } func (r *ref) ToState() (st llb.State, err error) { - defop, err := llb.NewDefinitionOp(r.Definition()) + defop, err := llb.NewDefinitionOp(r.resultProxy.Definition()) if err != nil { return st, err } @@ -359,7 +367,7 @@ func (r *ref) StatFile(ctx context.Context, req client.StatRequest) (*fstypes.St } func (r *ref) getMountable(ctx context.Context) (snapshot.Mountable, error) { - rr, err := r.ResultProxy.Result(ctx) + rr, err := r.resultProxy.Result(ctx) if err != nil { return nil, r.c.wrapSolveError(err) } diff --git a/frontend/gateway/gateway.go b/frontend/gateway/gateway.go index 5bf1ccd86456..b983d37a60be 100644 --- a/frontend/gateway/gateway.go +++ b/frontend/gateway/gateway.go @@ -349,22 +349,16 @@ func (lbf *llbBridgeForwarder) Discard() { workerRef.ImmutableRef.Release(context.TODO()) delete(lbf.workerRefByID, id) } - for id, r := range lbf.refs { - if lbf.err == nil && lbf.result != nil { - keep := false - lbf.result.EachRef(func(r2 solver.ResultProxy) error { - if r == r2 { - keep = true - } - return nil - }) - if keep { - continue - } - } + if lbf.err != nil && lbf.result != nil { + lbf.result.EachRef(func(r solver.ResultProxy) error { + r.Release(context.TODO()) + return nil + }) + } + for _, r := range lbf.refs { r.Release(context.TODO()) - delete(lbf.refs, id) } + lbf.refs = map[string]solver.ResultProxy{} } func (lbf *llbBridgeForwarder) Done() <-chan struct{} { @@ -864,7 +858,7 @@ func (lbf *llbBridgeForwarder) Return(ctx context.Context, in *pb.ReturnRequest) switch res := in.Result.Result.(type) { case *pb.Result_RefDeprecated: - ref, err := lbf.convertRef(res.RefDeprecated) + ref, err := lbf.cloneRef(res.RefDeprecated) if err != nil { return nil, err } @@ -872,7 +866,7 @@ func (lbf *llbBridgeForwarder) Return(ctx context.Context, in *pb.ReturnRequest) case *pb.Result_RefsDeprecated: m := map[string]solver.ResultProxy{} for k, id := range res.RefsDeprecated.Refs { - ref, err := lbf.convertRef(id) + ref, err := lbf.cloneRef(id) if err != nil { return nil, err } @@ -880,7 +874,7 @@ func (lbf *llbBridgeForwarder) Return(ctx context.Context, in *pb.ReturnRequest) } r.Refs = m case *pb.Result_Ref: - ref, err := lbf.convertRef(res.Ref.Id) + ref, err := lbf.cloneRef(res.Ref.Id) if err != nil { return nil, err } @@ -888,7 +882,7 @@ func (lbf *llbBridgeForwarder) Return(ctx context.Context, in *pb.ReturnRequest) case *pb.Result_Refs: m := map[string]solver.ResultProxy{} for k, ref := range res.Refs.Refs { - ref, err := lbf.convertRef(ref.Id) + ref, err := lbf.cloneRef(ref.Id) if err != nil { return nil, err } @@ -1383,10 +1377,27 @@ func (lbf *llbBridgeForwarder) convertRef(id string) (solver.ResultProxy, error) if !ok { return nil, errors.Errorf("return reference %s not found", id) } - return r, nil } +func (lbf *llbBridgeForwarder) cloneRef(id string) (solver.ResultProxy, error) { + if id == "" { + return nil, nil + } + + lbf.mu.Lock() + defer lbf.mu.Unlock() + + r, ok := lbf.refs[id] + if !ok { + return nil, errors.Errorf("return reference %s not found", id) + } + + s1, s2 := solver.SplitResultProxy(r) + lbf.refs[id] = s1 + return s2, nil +} + func serve(ctx context.Context, grpcServer *grpc.Server, conn net.Conn) { go func() { <-ctx.Done() diff --git a/solver/llbsolver/bridge.go b/solver/llbsolver/bridge.go index 0f1892feb550..0bf8ac905354 100644 --- a/solver/llbsolver/bridge.go +++ b/solver/llbsolver/bridge.go @@ -188,8 +188,8 @@ func (b *llbBridge) Solve(ctx context.Context, req frontend.SolveRequest, sid st } type resultProxy struct { - cb func(context.Context) (solver.CachedResult, solver.BuildSources, error) - def *pb.Definition + b *llbBridge + req frontend.SolveRequest g flightcontrol.Group mu sync.Mutex released bool @@ -200,27 +200,11 @@ type resultProxy struct { } func newResultProxy(b *llbBridge, req frontend.SolveRequest) *resultProxy { - rp := &resultProxy{ - def: req.Definition, - } - rp.cb = func(ctx context.Context) (solver.CachedResult, solver.BuildSources, error) { - res, bsrc, err := b.loadResult(ctx, req.Definition, req.CacheImports) - var ee *llberrdefs.ExecError - if errors.As(err, &ee) { - ee.EachRef(func(res solver.Result) error { - rp.errResults = append(rp.errResults, res) - return nil - }) - // acquire ownership so ExecError finalizer doesn't attempt to release as well - ee.OwnerBorrowed = true - } - return res, bsrc, err - } - return rp + return &resultProxy{req: req, b: b} } func (rp *resultProxy) Definition() *pb.Definition { - return rp.def + return rp.req.Definition } func (rp *resultProxy) BuildSources() solver.BuildSources { @@ -255,12 +239,12 @@ func (rp *resultProxy) wrapError(err error) error { } var ve *errdefs.VertexError if errors.As(err, &ve) { - if rp.def.Source != nil { - locs, ok := rp.def.Source.Locations[string(ve.Digest)] + if rp.req.Definition.Source != nil { + locs, ok := rp.req.Definition.Source.Locations[string(ve.Digest)] if ok { for _, loc := range locs.Locations { err = errdefs.WithSource(err, errdefs.Source{ - Info: rp.def.Source.Infos[loc.SourceIndex], + Info: rp.req.Definition.Source.Infos[loc.SourceIndex], Ranges: loc.Ranges, }) } @@ -270,6 +254,20 @@ func (rp *resultProxy) wrapError(err error) error { return err } +func (rp *resultProxy) loadResult(ctx context.Context) (solver.CachedResult, solver.BuildSources, error) { + res, bsrc, err := rp.b.loadResult(ctx, rp.req.Definition, rp.req.CacheImports) + var ee *llberrdefs.ExecError + if errors.As(err, &ee) { + ee.EachRef(func(res solver.Result) error { + rp.errResults = append(rp.errResults, res) + return nil + }) + // acquire ownership so ExecError finalizer doesn't attempt to release as well + ee.OwnerBorrowed = true + } + return res, bsrc, err +} + func (rp *resultProxy) Result(ctx context.Context) (res solver.CachedResult, err error) { defer func() { err = rp.wrapError(err) @@ -285,7 +283,7 @@ func (rp *resultProxy) Result(ctx context.Context) (res solver.CachedResult, err return rp.v, rp.err } rp.mu.Unlock() - v, bsrc, err := rp.cb(ctx) + v, bsrc, err := rp.loadResult(ctx) if err != nil { select { case <-ctx.Done(): diff --git a/solver/result.go b/solver/result.go index 81766a30f4bc..2ba1ef9bc1b6 100644 --- a/solver/result.go +++ b/solver/result.go @@ -108,3 +108,26 @@ type SharedCachedResult struct { *SharedResult CachedResult } + +type splitResultProxy struct { + released int64 + sem *int64 + ResultProxy +} + +func (r *splitResultProxy) Release(ctx context.Context) error { + if atomic.AddInt64(&r.released, 1) > 1 { + err := errors.New("releasing already released reference") + bklog.G(ctx).Error(err) + return err + } + if atomic.AddInt64(r.sem, 1) == 2 { + return r.ResultProxy.Release(ctx) + } + return nil +} + +func SplitResultProxy(res ResultProxy) (ResultProxy, ResultProxy) { + sem := int64(0) + return &splitResultProxy{ResultProxy: res, sem: &sem}, &splitResultProxy{ResultProxy: res, sem: &sem} +}