Skip to content

Commit

Permalink
Support storage
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jun 6, 2024
1 parent a1c1ef7 commit a984ed9
Show file tree
Hide file tree
Showing 6 changed files with 1,170 additions and 346 deletions.
39 changes: 39 additions & 0 deletions cmd/crproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@ import (
"log"
"net"
"net/http"
"net/url"
"os"
"slices"
"strings"
"time"

"github.com/distribution/distribution/v3/registry/storage/driver/factory"
"github.com/gorilla/handlers"
"github.com/spf13/pflag"
"github.com/wzshiming/geario"

_ "github.com/distribution/distribution/v3/registry/storage/driver/azure"
_ "github.com/distribution/distribution/v3/registry/storage/driver/gcs"
_ "github.com/distribution/distribution/v3/registry/storage/driver/s3-aws"
_ "github.com/wzshiming/crproxy/storage/driver/oss"

"github.com/wzshiming/crproxy"
)

Expand All @@ -27,6 +34,10 @@ var (
blockImageList []string
retry int
retryInterval time.Duration
storageDriver string
storageParameters map[string]string
linkExpires time.Duration
redirectLinks string
)

func init() {
Expand All @@ -38,6 +49,10 @@ func init() {
pflag.StringSliceVar(&blockImageList, "block-image-list", nil, "block image list")
pflag.IntVar(&retry, "retry", 0, "retry times")
pflag.DurationVar(&retryInterval, "retry-interval", 0, "retry interval")
pflag.StringVar(&storageDriver, "storage-driver", "", "storage driver")
pflag.StringToStringVar(&storageParameters, "storage-parameters", nil, "storage parameters")
pflag.DurationVar(&linkExpires, "link-expires", 0, "link expires")
pflag.StringVar(&redirectLinks, "redirect-links", "", "redirect links")
pflag.Parse()
}

Expand Down Expand Up @@ -105,6 +120,30 @@ func main() {
crproxy.WithDisableKeepAlives(disableKeepAlives),
}

if storageDriver != "" {
parameters := map[string]interface{}{}
for k, v := range storageParameters {
parameters[k] = v
}
sd, err := factory.Create(storageDriver, parameters)
if err != nil {
logger.Println("create storage driver failed:", err)
os.Exit(1)
}
opts = append(opts, crproxy.WithStorageDriver(sd))
if linkExpires > 0 {
opts = append(opts, crproxy.WithLinkExpires(linkExpires))
}
if redirectLinks != "" {
u, err := url.Parse(redirectLinks)
if err != nil {
logger.Println("parse redirect links failed:", err)
os.Exit(1)
}
opts = append(opts, crproxy.WithRedirectLinks(u))
}
}

if len(blockImageList) != 0 {
opts = append(opts, crproxy.WithBlockFunc(func(info *crproxy.PathInfo) bool {
image := info.Host + "/" + info.Image
Expand Down
197 changes: 191 additions & 6 deletions crproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package crproxy

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/textproto"
"net/url"
"path"
"strings"
"sync"
"time"
Expand All @@ -15,6 +19,7 @@ import (
"github.com/distribution/distribution/v3/registry/client/auth"
"github.com/distribution/distribution/v3/registry/client/auth/challenge"
"github.com/distribution/distribution/v3/registry/client/transport"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/wzshiming/geario"
"github.com/wzshiming/httpseek"
"github.com/wzshiming/lru"
Expand All @@ -41,18 +46,40 @@ type CRProxy struct {
domainAlias map[string]string
userAndPass map[string]Userpass
basicCredentials *basicCredentials
mut sync.Mutex
mutClientset sync.Mutex
bytesPool sync.Pool
logger Logger
totalBlobsSpeedLimit *geario.Gear
blobsSpeedLimit *geario.B
blockFunc func(*PathInfo) bool
retry int
retryInterval time.Duration
storageDriver storagedriver.StorageDriver
linkExpires time.Duration
mutCache sync.Map
redirectLinks *url.URL
}

type Option func(c *CRProxy)

func WithLinkExpires(d time.Duration) Option {
return func(c *CRProxy) {
c.linkExpires = d
}
}

func WithRedirectLinks(l *url.URL) Option {
return func(c *CRProxy) {
c.redirectLinks = l
}
}

func WithStorageDriver(storageDriver storagedriver.StorageDriver) Option {
return func(c *CRProxy) {
c.storageDriver = storageDriver
}
}

func WithBlobsSpeedLimit(limit geario.B) Option {
return func(c *CRProxy) {
c.blobsSpeedLimit = &limit
Expand Down Expand Up @@ -163,8 +190,8 @@ func (c *CRProxy) getScheme(host string) string {
}

func (c *CRProxy) getClientset(host string, image string) *http.Client {
c.mut.Lock()
defer c.mut.Unlock()
c.mutClientset.Lock()
defer c.mutClientset.Unlock()
if c.clientset[host] != nil {
client, ok := c.clientset[host].Get(image)
if ok {
Expand Down Expand Up @@ -249,8 +276,8 @@ func (c *CRProxy) disableKeepAlives(rt http.RoundTripper) http.RoundTripper {
}

func (c *CRProxy) ping(host string) error {
c.mut.Lock()
defer c.mut.Unlock()
c.mutClientset.Lock()
defer c.mutClientset.Unlock()

if c.logger != nil {
c.logger.Println("ping", host)
Expand Down Expand Up @@ -336,7 +363,7 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
return
}
if !strings.HasPrefix(oriPath, prefix) {
http.NotFound(rw, r)
c.notFoundResponse(rw, r)
return
}
if oriPath == catalog {
Expand Down Expand Up @@ -374,6 +401,14 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
r.URL.Scheme = c.getScheme(info.Host)
r.URL.Path = path

if c.storageDriver != nil && info.Blobs != "" {
c.cacheBlobResponse(rw, r, info)
return
}
c.directResponse(rw, r, info)
}

func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) {
cli := c.getClientset(info.Host, info.Image)
resp, err := c.doWithAuth(cli, r, info.Host)
if err != nil {
Expand Down Expand Up @@ -418,6 +453,156 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
}

func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) {
ctx := r.Context()

blob := strings.TrimPrefix(info.Blobs, "sha256:")
blobPath := path.Join("/docker/registry/v2/blobs/sha256", blob[:2], blob, "data")

closeValue, loaded := c.mutCache.LoadOrStore(blobPath, make(chan struct{}))
closeCh := closeValue.(chan struct{})
for loaded {
select {
case <-ctx.Done():
err := ctx.Err().Error()
if c.logger != nil {
c.logger.Println(err)
}
http.Error(rw, err, http.StatusInternalServerError)
return
case <-closeCh:
}
closeValue, loaded = c.mutCache.LoadOrStore(blobPath, make(chan struct{}))
closeCh = closeValue.(chan struct{})
}

doneCache := func() {
c.mutCache.Delete(blobPath)
close(closeCh)
}

_, err := c.storageDriver.Stat(ctx, blobPath)
if err == nil {
err = c.redirect(rw, r, blobPath)
if err == nil {
doneCache()
return
}
c.errorResponse(rw, r, ctx.Err())
return
} else {
if c.logger != nil {
c.logger.Println("Cache miss", blobPath)
}
}

errCh := make(chan error, 1)

go func() {
defer doneCache()
err = c.cacheBlobContent(r, blobPath, info)
errCh <- err
}()

select {
case <-ctx.Done():
c.errorResponse(rw, r, ctx.Err())
return
case err := <-errCh:
if err != nil {
c.errorResponse(rw, r, err)
return
}
err = c.redirect(rw, r, blobPath)
if err != nil {
if c.logger != nil {
c.logger.Println("failed to redirect", blobPath, err)
}
}
return
}
}

func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathInfo) error {
cli := c.getClientset(info.Host, info.Image)
resp, err := c.doWithAuth(cli, r, info.Host)
if err != nil {
return err
}
defer func() {
resp.Body.Close()
}()

buf := c.bytesPool.Get().([]byte)
defer c.bytesPool.Put(buf)

fw, err := c.storageDriver.Writer(context.Background(), blobPath, false)
if err != nil {
return err
}

h := sha256.New()
n, err := io.CopyBuffer(fw, io.TeeReader(resp.Body, h), buf)
if err != nil {
fw.Cancel()
return err
}

if n != resp.ContentLength {
fw.Cancel()
return fmt.Errorf("expected %d bytes, got %d", resp.ContentLength, n)
}

hash := hex.EncodeToString(h.Sum(nil)[:])
if info.Blobs[7:] != hash {
fw.Cancel()
return fmt.Errorf("expected %s hash, got %s", info.Blobs[7:], hash)
}

return fw.Commit()
}

func (c *CRProxy) errorResponse(rw http.ResponseWriter, r *http.Request, err error) {
if err != nil {
e := err.Error()
if c.logger != nil {
c.logger.Println(e)
}
}
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
}

func (c *CRProxy) notFoundResponse(rw http.ResponseWriter, r *http.Request) {
http.NotFound(rw, r)
}

func (c *CRProxy) redirect(rw http.ResponseWriter, r *http.Request, blobPath string) error {
options := map[string]interface{}{
"method": http.MethodGet,
}
linkExpires := c.linkExpires
if linkExpires > 0 {
options["expiry"] = time.Now().Add(linkExpires)
}
u, err := c.storageDriver.URLFor(r.Context(), blobPath, options)
if err != nil {
return err
}
if c.logger != nil {
c.logger.Println("Cache hit", blobPath, u)
}
if c.redirectLinks != nil {
uri, err := url.Parse(u)
if err == nil {
uri.Scheme = c.redirectLinks.Scheme
uri.Host = c.redirectLinks.Host
u = uri.String()
}
}
http.Redirect(rw, r, u, http.StatusTemporaryRedirect)
return nil
}

func (c *CRProxy) getDomainAlias(host string) string {
if c.domainAlias == nil {
return host
Expand Down
Loading

0 comments on commit a984ed9

Please sign in to comment.