From 6056fdbddcc493e4f41545bc05ccc1c031fa1a3f Mon Sep 17 00:00:00 2001 From: Noah Hsu Date: Mon, 13 Jun 2022 20:24:13 +0800 Subject: [PATCH] feat: use singleflight to prevent cache breakdown --- internal/operations/fs.go | 20 +- pkg/singleflight/signleflight_test.go | 320 ++++++++++++++++++++++++++ pkg/singleflight/singleflight.go | 212 +++++++++++++++++ 3 files changed, 545 insertions(+), 7 deletions(-) create mode 100644 pkg/singleflight/signleflight_test.go create mode 100644 pkg/singleflight/singleflight.go diff --git a/internal/operations/fs.go b/internal/operations/fs.go index df494a72dc4..938854cc603 100644 --- a/internal/operations/fs.go +++ b/internal/operations/fs.go @@ -5,6 +5,7 @@ import ( "github.com/Xhofe/go-cache" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/singleflight" "github.com/alist-org/alist/v3/pkg/utils" "github.com/pkg/errors" stdpath "path" @@ -52,6 +53,7 @@ func Get(ctx context.Context, account driver.Driver, path string) (driver.FileIn } var linkCache = cache.NewMemCache[*driver.Link]() +var linkG singleflight.Group[*driver.Link] // Link get link, if is an url. should have an expiry time func Link(ctx context.Context, account driver.Driver, path string, args driver.LinkArgs) (*driver.Link, error) { @@ -59,14 +61,18 @@ func Link(ctx context.Context, account driver.Driver, path string, args driver.L if link, ok := linkCache.Get(key); ok { return link, nil } - link, err := account.Link(ctx, path, args) - if err != nil { - return nil, errors.WithMessage(err, "failed get link") - } - if link.Expiration != nil { - linkCache.Set(key, link, cache.WithEx(*link.Expiration)) + fn := func() (*driver.Link, error) { + link, err := account.Link(ctx, path, args) + if err != nil { + return nil, errors.WithMessage(err, "failed get link") + } + if link.Expiration != nil { + linkCache.Set(key, link, cache.WithEx[*driver.Link](*link.Expiration)) + } + return link, nil } - return link, nil + link, err, _ := linkG.Do(key, fn) + return link, err } func MakeDir(ctx context.Context, account driver.Driver, path string) error { diff --git a/pkg/singleflight/signleflight_test.go b/pkg/singleflight/signleflight_test.go new file mode 100644 index 00000000000..34250299776 --- /dev/null +++ b/pkg/singleflight/signleflight_test.go @@ -0,0 +1,320 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "runtime/debug" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var g Group[string] + v, err, _ := g.Do("key", func() (string, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + var g Group[any] + someErr := errors.New("Some error") + v, err, _ := g.Do("key", func() (any, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + var g Group[string] + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (string, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err, _ := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if v != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +// Test that singleflight behaves correctly after Forget called. +// See https://github.com/golang/go/issues/31420 +func TestForget(t *testing.T) { + var g Group[any] + + var ( + firstStarted = make(chan struct{}) + unblockFirst = make(chan struct{}) + firstFinished = make(chan struct{}) + ) + + go func() { + g.Do("key", func() (i any, e error) { + close(firstStarted) + <-unblockFirst + close(firstFinished) + return + }) + }() + <-firstStarted + g.Forget("key") + + unblockSecond := make(chan struct{}) + secondResult := g.DoChan("key", func() (i any, e error) { + <-unblockSecond + return 2, nil + }) + + close(unblockFirst) + <-firstFinished + + thirdResult := g.DoChan("key", func() (i any, e error) { + return 3, nil + }) + + close(unblockSecond) + <-secondResult + r := <-thirdResult + if r.Val != 2 { + t.Errorf("We should receive result produced by second call, expected: 2, got %d", r.Val) + } +} + +func TestDoChan(t *testing.T) { + var g Group[string] + ch := g.DoChan("key", func() (string, error) { + return "bar", nil + }) + + res := <-ch + v := res.Val + err := res.Err + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +// Test singleflight behaves correctly after Do panic. +// See https://github.com/golang/go/issues/41133 +func TestPanicDo(t *testing.T) { + var g Group[any] + fn := func() (any, error) { + panic("invalid memory address or nil pointer dereference") + } + + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + t.Logf("Got panic: %v\n%s", err, debug.Stack()) + atomic.AddInt32(&panicCount, 1) + } + + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + + g.Do("key", fn) + }() + } + + select { + case <-done: + if panicCount != n { + t.Errorf("Expect %d panic, but got %d", n, panicCount) + } + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestGoexitDo(t *testing.T) { + var g Group[any] + fn := func() (any, error) { + runtime.Goexit() + return nil, nil + } + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, err, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestPanicDoChan(t *testing.T) { + if runtime.GOOS == "js" { + t.Skipf("js does not support exec") + } + + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + defer func() { + recover() + }() + + g := new(Group[any]) + ch := g.DoChan("", func() (any, error) { + panic("Panicking in DoChan") + }) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in DoChan")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in DoChan") + } +} + +func TestPanicDoSharedByDoChan(t *testing.T) { + if runtime.GOOS == "js" { + t.Skipf("js does not support exec") + } + + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + blocked := make(chan struct{}) + unblock := make(chan struct{}) + + g := new(Group[any]) + go func() { + defer func() { + recover() + }() + g.Do("", func() (any, error) { + close(blocked) + <-unblock + panic("Panicking in Do") + }) + }() + + <-blocked + ch := g.DoChan("", func() (any, error) { + panic("DoChan unexpectedly executed callback") + }) + close(unblock) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in Do shared by DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") + } +} diff --git a/pkg/singleflight/singleflight.go b/pkg/singleflight/singleflight.go new file mode 100644 index 00000000000..dcd84a3b9d4 --- /dev/null +++ b/pkg/singleflight/singleflight.go @@ -0,0 +1,212 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value any + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v any) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call[T any] struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val T + err error + + // forgotten indicates whether Forget was called with this call's key + // while the call was still in flight. + forgotten bool + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result[T] +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group[T any] struct { + mu sync.Mutex // protects m + m map[string]*call[T] // lazily initialized +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result[T any] struct { + Val T + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group[T]) Do(key string, fn func() (T, error)) (v T, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call[T]) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err, true + } + c := new(call[T]) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +// +// The returned channel will not be closed. +func (g *Group[T]) DoChan(key string, fn func() (T, error)) <-chan Result[T] { + ch := make(chan Result[T], 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call[T]) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call[T]{chans: []chan<- Result[T]{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. +func (g *Group[T]) doCall(c *call[T], key string, fn func() (T, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + c.wg.Done() + g.mu.Lock() + defer g.mu.Unlock() + if !c.forgotten { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + // In order to prevent the waiting channels from being blocked forever, + // needs to ensure that this panic cannot be recovered. + if len(c.chans) > 0 { + go panic(e) + select {} // Keep this goroutine around so that it will appear in the crash dump. + } else { + panic(e) + } + } else if c.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + for _, ch := range c.chans { + ch <- Result[T]{c.val, c.err, c.dups > 0} + } + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. +func (g *Group[T]) Forget(key string) { + g.mu.Lock() + if c, ok := g.m[key]; ok { + c.forgotten = true + } + delete(g.m, key) + g.mu.Unlock() +}