From 03ae364bd8881ef90821744951bb08720722d3b9 Mon Sep 17 00:00:00 2001 From: Shiming Zhang Date: Tue, 10 Dec 2024 18:46:03 +0800 Subject: [PATCH] Add transport --- clientset/clientset.go | 304 --------------------------------------- clientset/credentials.go | 98 ------------- cmd/crproxy/main.go | 46 ++---- cmd/crproxy/sync/sync.go | 17 ++- crproxy.go | 53 +++++-- crproxy_blob.go | 6 +- crproxy_manifest.go | 3 +- go.mod | 7 +- go.sum | 10 ++ sync/sync.go | 23 +-- transport/transport.go | 248 ++++++++++++++++++++++++++++++++ 11 files changed, 332 insertions(+), 483 deletions(-) delete mode 100644 clientset/clientset.go delete mode 100644 clientset/credentials.go create mode 100644 transport/transport.go diff --git a/clientset/clientset.go b/clientset/clientset.go deleted file mode 100644 index fa260c5..0000000 --- a/clientset/clientset.go +++ /dev/null @@ -1,304 +0,0 @@ -package clientset - -import ( - "context" - "errors" - "log/slog" - "net/http" - "net/url" - "sync" - "time" - - "github.com/daocloud/crproxy/internal/maps" - "github.com/docker/distribution/registry/client/auth" - "github.com/docker/distribution/registry/client/auth/challenge" - "github.com/docker/distribution/registry/client/transport" - storagedriver "github.com/docker/distribution/registry/storage/driver" - "github.com/wzshiming/httpseek" - "github.com/wzshiming/lru" -) - -const ( - prefix = "/v2/" -) - -type Clientset struct { - baseClient *http.Client - challengeManager challenge.Manager - clientset maps.SyncMap[string, *lru.LRU[string, *http.Client]] - clientSize int - insecureDomain map[string]struct{} - domainDisableKeepAlives map[string]struct{} - domainAlias map[string]string - userAndPass map[string]Userpass - basicCredentials *basicCredentials - mutClientset sync.Mutex - logger *slog.Logger - retry int - retryInterval time.Duration - storageDriver storagedriver.StorageDriver - mutCache sync.Map - allowHeadMethod bool -} - -type Option func(c *Clientset) - -func WithStorageDriver(storageDriver storagedriver.StorageDriver) Option { - return func(c *Clientset) { - c.storageDriver = storageDriver - } -} - -func WithBaseClient(baseClient *http.Client) Option { - return func(c *Clientset) { - c.baseClient = baseClient - } -} - -func WithLogger(logger *slog.Logger) Option { - return func(c *Clientset) { - c.logger = logger - } -} - -func WithUserAndPass(userAndPass map[string]Userpass) Option { - return func(c *Clientset) { - c.userAndPass = userAndPass - } -} - -func WithDomainAlias(domainAlias map[string]string) Option { - return func(c *Clientset) { - c.domainAlias = domainAlias - } -} - -func WithMaxClientSizeForEachRegistry(clientSize int) Option { - return func(c *Clientset) { - c.clientSize = clientSize - } -} - -func WithDisableKeepAlives(disableKeepAlives []string) Option { - return func(c *Clientset) { - c.domainDisableKeepAlives = map[string]struct{}{} - for _, v := range disableKeepAlives { - c.domainDisableKeepAlives[v] = struct{}{} - } - } -} - -func WithRetry(retry int, retryInterval time.Duration) Option { - return func(c *Clientset) { - c.retry = retry - c.retryInterval = retryInterval - } -} - -func WithAllowHeadMethod(allowHeadMethod bool) Option { - return func(c *Clientset) { - c.allowHeadMethod = allowHeadMethod - } -} - -func NewClientset(opts ...Option) (*Clientset, error) { - c := &Clientset{ - logger: slog.Default(), - challengeManager: challenge.NewSimpleManager(), - clientSize: 10240, - baseClient: http.DefaultClient, - } - - for _, opt := range opts { - opt(c) - } - if len(c.userAndPass) != 0 { - bc, err := newBasicCredentials(c.userAndPass, c.getDomainAlias, c.GetScheme) - if err != nil { - return nil, err - } - c.basicCredentials = bc - } - - return c, nil -} - -func (c *Clientset) HostURL(host string) string { - return c.GetScheme(host) + "://" + host -} - -func (c *Clientset) pingURL(host string) string { - return c.HostURL(host) + prefix -} - -func (c *Clientset) GetScheme(host string) string { - if c.insecureDomain != nil { - _, ok := c.insecureDomain[host] - if ok { - return "http" - } - } - return "https" -} - -func (c *Clientset) GetClientset(host string, image string) *http.Client { - sets, hasSets := c.clientset.Load(host) - if hasSets { - client, ok := sets.Get(image) - if ok { - return client - } - } - - c.mutClientset.Lock() - defer c.mutClientset.Unlock() - if sets == nil { - sets = lru.NewLRU(c.clientSize, func(image string, client *http.Client) { - c.logger.Info("evicted client", "host", host, "image", image) - client.CloseIdleConnections() - }) - c.clientset.Store(host, sets) - } - - c.logger.Info("cache client", "host", host, "image", image) - var credentialStore auth.CredentialStore - if c.basicCredentials != nil { - credentialStore = c.basicCredentials - } - authHandler := auth.NewTokenHandler(nil, credentialStore, image, "pull") - - tr := c.baseClient.Transport - - if c.domainDisableKeepAlives != nil { - if _, ok := c.domainDisableKeepAlives[host]; ok { - tr = c.disableKeepAlives(tr) - } - } - - if c.retryInterval > 0 { - if tr == nil { - tr = http.DefaultTransport - } - tr = httpseek.NewMustReaderTransport(tr, func(request *http.Request, retry int, err error) error { - if errors.Is(err, context.Canceled) || - errors.Is(err, context.DeadlineExceeded) { - return err - } - if c.retry > 0 && retry >= c.retry { - return err - } - c.logger.Info("Retry", "url", request.URL, "retry", retry, "error", err) - time.Sleep(c.retryInterval) - return nil - }) - } - - tr = transport.NewTransport(tr, auth.NewAuthorizer(c.challengeManager, authHandler)) - - client := &http.Client{ - Transport: tr, - CheckRedirect: c.baseClient.CheckRedirect, - Timeout: c.baseClient.Timeout, - Jar: c.baseClient.Jar, - } - - sets.Put(image, client) - return client -} - -func (c *Clientset) disableKeepAlives(rt http.RoundTripper) http.RoundTripper { - if rt == nil { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.DisableKeepAlives = true - return tr - } - if tr, ok := rt.(*http.Transport); ok { - if !tr.DisableKeepAlives { - tr = tr.Clone() - tr.DisableKeepAlives = true - } - return tr - } - c.logger.Warn("failed to disable keep alives") - return rt -} - -func (c *Clientset) Ping(host string) error { - c.logger.Info("ping", "host", host) - - ep := c.pingURL(host) - e, err := url.Parse(ep) - if err != nil { - return err - } - challenges, err := c.challengeManager.GetChallenges(*e) - if err == nil && len(challenges) != 0 { - return nil - } - - resp, err := c.baseClient.Get(ep) - if err != nil { - return err - } - defer resp.Body.Close() - err = c.challengeManager.AddResponse(resp) - if err != nil { - return err - } - return nil -} - -func (c *Clientset) Do(cli *http.Client, r *http.Request) (resp *http.Response, err error) { - forHead := !c.allowHeadMethod && r.Method == http.MethodHead - if forHead { - r.Method = http.MethodGet - } - resp, err = cli.Do(r) - if err != nil { - return nil, err - } - - if forHead { - r.Method = http.MethodHead - if resp.Body != nil { - resp.Body.Close() - } - resp.Body = http.NoBody - } - return resp, err -} - -func (c *Clientset) DoWithAuth(cli *http.Client, r *http.Request, host string) (*http.Response, error) { - resp, err := c.Do(cli, r) - if err != nil { - return nil, err - } - - if resp.StatusCode == http.StatusUnauthorized { - err = c.Ping(host) - if err != nil { - c.logger.Warn("failed to ping", "host", host, "error", err) - return resp, nil - } - - resp0, err0 := c.Do(cli, r) - if err0 != nil { - c.logger.Warn("failed to redo", "host", host, "error", err) - return resp, nil - } - resp.Body.Close() - resp = resp0 - } - return resp, nil -} - -func (c *Clientset) getDomainAlias(host string) string { - if c.domainAlias == nil { - return host - } - h, ok := c.domainAlias[host] - if !ok { - return host - } - return h -} diff --git a/clientset/credentials.go b/clientset/credentials.go deleted file mode 100644 index 1e36989..0000000 --- a/clientset/credentials.go +++ /dev/null @@ -1,98 +0,0 @@ -package clientset - -import ( - "fmt" - "net/http" - "net/url" - "strings" - - "github.com/docker/distribution/registry/client/auth/challenge" -) - -type Userpass struct { - Username string - Password string -} - -func ToUserAndPass(userpass []string) (map[string]Userpass, error) { - bc := map[string]Userpass{} - for _, up := range userpass { - s := strings.SplitN(up, "@", 3) - if len(s) != 2 { - return nil, fmt.Errorf("invalid userpass %q", up) - } - - u := strings.SplitN(s[0], ":", 3) - if len(s) != 2 { - return nil, fmt.Errorf("invalid userpass %q", up) - } - host := s[1] - user := u[0] - pwd := u[1] - bc[host] = Userpass{ - Username: user, - Password: pwd, - } - } - return bc, nil -} - -type basicCredentials struct { - credentials map[string]Userpass -} - -func newBasicCredentials(cred map[string]Userpass, domainAlias func(string) string, hostScheme func(string) string) (*basicCredentials, error) { - bc := &basicCredentials{ - credentials: map[string]Userpass{}, - } - for domain, c := range cred { - urls, err := getAuthURLs(hostScheme(domain)+"://"+domain, domainAlias) - if err != nil { - return nil, err - } - for _, u := range urls { - bc.credentials[u] = c - } - } - return bc, nil -} - -func (c *basicCredentials) Basic(u *url.URL) (string, string) { - up := c.credentials[u.String()] - - return up.Username, up.Password -} - -func (c *basicCredentials) RefreshToken(u *url.URL, service string) string { - return "" -} - -func (c *basicCredentials) SetRefreshToken(u *url.URL, service, token string) { -} - -func getAuthURLs(remoteURL string, domainAlias func(string) string) ([]string, error) { - authURLs := []string{} - - u, err := url.Parse(remoteURL) - if err != nil { - return nil, err - } - if domainAlias != nil { - u.Host = domainAlias(u.Host) - } - remoteURL = u.String() - - resp, err := http.Get(remoteURL + "/v2/") - if err != nil { - return nil, err - } - defer resp.Body.Close() - - for _, c := range challenge.ResponseChallenges(resp) { - if strings.EqualFold(c.Scheme, "bearer") { - authURLs = append(authURLs, c.Parameters["realm"]) - } - } - - return authURLs, nil -} diff --git a/cmd/crproxy/main.go b/cmd/crproxy/main.go index 9dee4bd..910ed29 100644 --- a/cmd/crproxy/main.go +++ b/cmd/crproxy/main.go @@ -20,8 +20,8 @@ import ( "time" "github.com/daocloud/crproxy/cache" - "github.com/daocloud/crproxy/clientset" csync "github.com/daocloud/crproxy/cmd/crproxy/sync" + "github.com/daocloud/crproxy/transport" "github.com/docker/distribution/registry/storage/driver/factory" "github.com/gorilla/handlers" "github.com/spf13/cobra" @@ -167,31 +167,10 @@ func run(ctx context.Context) { logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) mux := http.NewServeMux() - cli := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) > 10 { - return http.ErrUseLastResponse - } - s := make([]string, 0, len(via)+1) - for _, v := range via { - s = append(s, v.URL.String()) - } - - lastRedirect := req.URL.String() - s = append(s, lastRedirect) - logger.Info("redirect", "redirects", s) - if v := crproxy.GetCtxValue(req.Context()); v != nil { - v.LastRedirect = lastRedirect - } - return nil - }, - } - clientOpts := []clientset.Option{ - clientset.WithLogger(logger), - clientset.WithBaseClient(cli), - clientset.WithMaxClientSizeForEachRegistry(16), - clientset.WithDisableKeepAlives(disableKeepAlives), + transportOpts := []transport.Option{ + transport.WithLogger(logger), + transport.WithDisableKeepAlives(disableKeepAlives), } opts := []crproxy.Option{ @@ -426,12 +405,12 @@ func run(ctx context.Context) { } if len(userpass) != 0 { - bc, err := clientset.ToUserAndPass(userpass) + bc, err := transport.ToUserAndPass(userpass) if err != nil { logger.Error("failed to toUserAndPass", "error", err) os.Exit(1) } - clientOpts = append(clientOpts, clientset.WithUserAndPass(bc)) + transportOpts = append(transportOpts, transport.WithUserAndPass(bc)) } if ipsSpeedLimit != "" { @@ -466,7 +445,7 @@ func run(ctx context.Context) { } if retry > 0 { - clientOpts = append(clientOpts, clientset.WithRetry(retry, retryInterval)) + transportOpts = append(transportOpts, transport.WithRetry(retry, retryInterval)) } if limitDelay { opts = append(opts, crproxy.WithLimitDelay(true)) @@ -565,20 +544,17 @@ func run(ctx context.Context) { })) } - if allowHeadMethod { - clientOpts = append(clientOpts, clientset.WithAllowHeadMethod(allowHeadMethod)) - } - if manifestCacheDuration != 0 { opts = append(opts, crproxy.WithManifestCacheDuration(manifestCacheDuration)) } - clientset, err := clientset.NewClientset(clientOpts...) + transport, err := transport.NewTransport(transportOpts...) if err != nil { - logger.Error("failed to NewClientset", "error", err) + logger.Error("failed to NewTransport", "error", err) os.Exit(1) } - opts = append(opts, crproxy.WithClient(clientset)) + + opts = append(opts, crproxy.WithTransport(transport)) crp, err := crproxy.NewCRProxy(opts...) if err != nil { diff --git a/cmd/crproxy/sync/sync.go b/cmd/crproxy/sync/sync.go index 5858ab8..75d35f1 100644 --- a/cmd/crproxy/sync/sync.go +++ b/cmd/crproxy/sync/sync.go @@ -9,8 +9,8 @@ import ( "strings" "github.com/daocloud/crproxy/cache" - "github.com/daocloud/crproxy/clientset" csync "github.com/daocloud/crproxy/sync" + "github.com/daocloud/crproxy/transport" "github.com/docker/distribution/manifest/manifestlist" "github.com/docker/distribution/registry/storage/driver/factory" "github.com/spf13/cobra" @@ -78,22 +78,21 @@ func runE(ctx context.Context, flags *flagpole) error { return fmt.Errorf("create cache failed: %w", err) } - clientOpts := []clientset.Option{ - clientset.WithLogger(logger), - clientset.WithMaxClientSizeForEachRegistry(16), + transportOpts := []transport.Option{ + transport.WithLogger(logger), } if len(flags.Userpass) != 0 { - bc, err := clientset.ToUserAndPass(flags.Userpass) + bc, err := transport.ToUserAndPass(flags.Userpass) if err != nil { return fmt.Errorf("failed to toUserAndPass: %w", err) } - clientOpts = append(clientOpts, clientset.WithUserAndPass(bc)) + transportOpts = append(transportOpts, transport.WithUserAndPass(bc)) } - client, err := clientset.NewClientset(clientOpts...) + tp, err := transport.NewTransport(transportOpts...) if err != nil { - return fmt.Errorf("create clientset failed: %w", err) + return fmt.Errorf("create transport failed: %w", err) } opts = append(opts, @@ -103,7 +102,7 @@ func runE(ctx context.Context, flags *flagpole) error { "ollama.ai": "registry.ollama.ai", }), csync.WithDeep(flags.Deep), - csync.WithClient(client), + csync.WithTransport(tp), csync.WithLogger(logger), csync.WithFilterPlatform(filterPlatform(flags.Platform)), ) diff --git a/crproxy.go b/crproxy.go index cb29942..c63c526 100644 --- a/crproxy.go +++ b/crproxy.go @@ -7,12 +7,12 @@ import ( "net" "net/http" "net/textproto" + "net/url" "strings" "sync" "time" "github.com/daocloud/crproxy/cache" - "github.com/daocloud/crproxy/clientset" "github.com/daocloud/crproxy/internal/maps" "github.com/daocloud/crproxy/token" "github.com/docker/distribution/registry/api/errcode" @@ -36,7 +36,7 @@ type BlockInfo struct { } type CRProxy struct { - client *clientset.Clientset + httpClient *http.Client modify func(info *ImageInfo) *ImageInfo domainAlias map[string]string bytesPool sync.Pool @@ -65,9 +65,29 @@ type CRProxy struct { type Option func(c *CRProxy) -func WithClient(client *clientset.Clientset) Option { +func WithTransport(transport http.RoundTripper) Option { return func(c *CRProxy) { - c.client = client + cli := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) > 10 { + return http.ErrUseLastResponse + } + s := make([]string, 0, len(via)+1) + for _, v := range via { + s = append(s, v.URL.String()) + } + + lastRedirect := req.URL.String() + s = append(s, lastRedirect) + + if v := GetCtxValue(req.Context()); v != nil { + v.LastRedirect = lastRedirect + } + return nil + }, + Transport: transport, + } + c.httpClient = cli } } @@ -358,14 +378,20 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) return } - r.RequestURI = "" - r.Host = info.Host - r.URL.Host = info.Host - r.URL.Scheme = c.client.GetScheme(info.Host) - r.URL.Path = path - r.URL.RawQuery = "" - r.URL.ForceQuery = false - r.Body = http.NoBody + + u := url.URL{ + Scheme: "https", + Host: info.Host, + Path: path, + } + + r, err = http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil) + if err != nil { + c.logger.Warn("failed to new request", "error", err) + errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + return + } + if info.Blobs != "" && c.isRedirectToOriginBlob(r, imageInfo) { c.redirectBlobResponse(rw, r, info) return @@ -392,8 +418,7 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) { - cli := c.client.GetClientset(info.Host, info.Image) - resp, err := c.client.DoWithAuth(cli, r, info.Host) + resp, err := c.httpClient.Do(r) if err != nil { c.logger.Warn("failed to request", "host", info.Host, "image", info.Image, "error", err) errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) diff --git a/crproxy_blob.go b/crproxy_blob.go index 4a65f07..9d5070b 100644 --- a/crproxy_blob.go +++ b/crproxy_blob.go @@ -108,8 +108,7 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf } func (c *CRProxy) cacheBlobContent(ctx context.Context, r *http.Request, info *PathInfo) (int64, error) { - cli := c.client.GetClientset(info.Host, info.Image) - resp, err := c.client.DoWithAuth(cli, r.WithContext(ctx), info.Host) + resp, err := c.httpClient.Do(r) if err != nil { return 0, err } @@ -132,8 +131,7 @@ func (c *CRProxy) cacheBlobContent(ctx context.Context, r *http.Request, info *P func (c *CRProxy) redirectBlobResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) { r = r.WithContext(withCtxValue(r.Context())) - cli := c.client.GetClientset(info.Host, info.Image) - resp, err := c.client.DoWithAuth(cli, r, info.Host) + resp, err := c.httpClient.Do(r) if err != nil { c.logger.Error("failed to request", "host", info.Host, "image", info.Image, "error", err) errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) diff --git a/crproxy_manifest.go b/crproxy_manifest.go index 862231b..2a803a7 100644 --- a/crproxy_manifest.go +++ b/crproxy_manifest.go @@ -18,8 +18,7 @@ func (c *CRProxy) cacheManifestResponse(rw http.ResponseWriter, r *http.Request, return } - cli := c.client.GetClientset(info.Host, info.Image) - resp, err := c.client.DoWithAuth(cli, r, info.Host) + resp, err := c.httpClient.Do(r) if err != nil { if c.fallbackServeCachedManifest(rw, r, info) { return diff --git a/go.mod b/go.mod index f4187ab..09b90da 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,8 @@ go 1.22 require ( github.com/denverdino/aliyungo v0.0.0 github.com/distribution/reference v0.6.0 - github.com/docker/distribution v0.0.0 + github.com/docker/distribution v2.8.2+incompatible + github.com/google/go-containerregistry v0.20.2 github.com/gorilla/handlers v1.5.2 github.com/huaweicloud/huaweicloud-sdk-go-obs v3.24.6+incompatible github.com/opencontainers/go-digest v1.0.0 @@ -38,6 +39,8 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dnaeon/go-vcr v1.2.0 // indirect + github.com/docker/cli v27.1.1+incompatible // indirect + github.com/docker/docker-credential-helpers v0.7.0 // indirect github.com/docker/go-metrics v0.0.1 // indirect github.com/docker/libtrust v0.0.0-20150114040149-fa567046d9b1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -53,9 +56,11 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kr/text v0.2.0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/opencontainers/image-spec v1.1.0-rc3 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_golang v1.19.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect diff --git a/go.sum b/go.sum index 0f062ea..78f7aaf 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,10 @@ github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/docker/cli v27.1.1+incompatible h1:goaZxOqs4QKxznZjjBWKONQci/MywhtRv2oNn0GkeZE= +github.com/docker/cli v27.1.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/docker-credential-helpers v0.7.0 h1:xtCHsjxogADNZcdv1pKUHXryefjlVRqWqIhk/uXJp0A= +github.com/docker/docker-credential-helpers v0.7.0/go.mod h1:rETQfLdHNT3foU5kuNkFR1R1V12OJRRO5lzt2D1b5X0= github.com/docker/go-metrics v0.0.1 h1:AgB/0SvBxihN0X8OR4SjsblXkbMvalQ8cjmtKQ2rQV8= github.com/docker/go-metrics v0.0.1/go.mod h1:cG1hvH2utMXtqgqqYE9plW6lDxS3/5ayHzueweSI3Vw= github.com/docker/libtrust v0.0.0-20150114040149-fa567046d9b1 h1:ZClxb8laGDf5arXfYcAtECDFgAgHklGI8CxgjHnXKJ4= @@ -107,6 +111,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l/DSArMxlbwseo= +github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc= github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= @@ -141,6 +147,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/magiconair/properties v1.8.6/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -358,5 +366,7 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= +gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/sync/sync.go b/sync/sync.go index 7404bb2..487905f 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -4,11 +4,11 @@ import ( "context" "fmt" "log/slog" + "net/http" "regexp" "strings" "github.com/daocloud/crproxy/cache" - "github.com/daocloud/crproxy/clientset" "github.com/distribution/reference" "github.com/docker/distribution" "github.com/docker/distribution/manifest/manifestlist" @@ -36,7 +36,7 @@ func newNameWithoutDomain(named reference.Named, name string) reference.Named { } type SyncManager struct { - client *clientset.Clientset + transport http.RoundTripper cache *cache.Cache logger *slog.Logger domainAlias map[string]string @@ -82,9 +82,9 @@ func WithCache(cache *cache.Cache) Option { } } -func WithClient(client *clientset.Clientset) Option { +func WithTransport(transport http.RoundTripper) Option { return func(c *SyncManager) { - c.client = client + c.transport = transport } } @@ -96,14 +96,12 @@ func WithFilterPlatform(filterPlatform func(pf manifestlist.PlatformSpec) bool) func NewSyncManager(opts ...Option) (*SyncManager, error) { c := &SyncManager{ - logger: slog.Default(), + logger: slog.Default(), + transport: http.DefaultTransport, } for _, opt := range opts { opt(c) } - if c.client == nil { - return nil, fmt.Errorf("client is required") - } if c.cache == nil { return nil, fmt.Errorf("cache is required") @@ -140,14 +138,7 @@ func (c *SyncManager) Image(ctx context.Context, image string) error { name := newNameWithoutDomain(named, path) - err = c.client.Ping(host) - if err != nil { - return fmt.Errorf("ping registry failed: %w", err) - } - - cli := c.client.GetClientset(host, path) - - repo, err := client.NewRepository(name, c.client.HostURL(host), cli.Transport) + repo, err := client.NewRepository(name, "https://"+host, c.transport) if err != nil { return fmt.Errorf("create repository failed: %w", err) } diff --git a/transport/transport.go b/transport/transport.go new file mode 100644 index 0000000..953b7f0 --- /dev/null +++ b/transport/transport.go @@ -0,0 +1,248 @@ +package transport + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "github.com/daocloud/crproxy/internal/maps" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "github.com/wzshiming/httpseek" +) + +type Userpass = authn.AuthConfig + +type Transport struct { + baseTransport http.RoundTripper + insecureDomain map[string]struct{} + domainDisableKeepAlives map[string]struct{} + userAndPass map[string]Userpass + clientset maps.SyncMap[string, maps.SyncMap[string, http.RoundTripper]] + mutClientset sync.Mutex + logger *slog.Logger + retry int + retryInterval time.Duration + allowHeadMethod bool +} + +type Option func(c *Transport) + +func WithBaseTransport(baseTransport http.RoundTripper) Option { + return func(c *Transport) { + c.baseTransport = baseTransport + } +} + +func WithLogger(logger *slog.Logger) Option { + return func(c *Transport) { + c.logger = logger + } +} + +func WithUserAndPass(userAndPass map[string]Userpass) Option { + return func(c *Transport) { + c.userAndPass = userAndPass + } +} + +func WithDisableKeepAlives(disableKeepAlives []string) Option { + return func(c *Transport) { + c.domainDisableKeepAlives = map[string]struct{}{} + for _, v := range disableKeepAlives { + c.domainDisableKeepAlives[v] = struct{}{} + } + } +} + +func WithRetry(retry int, retryInterval time.Duration) Option { + return func(c *Transport) { + c.retry = retry + c.retryInterval = retryInterval + } +} + +func WithAllowHeadMethod(allowHeadMethod bool) Option { + return func(c *Transport) { + c.allowHeadMethod = allowHeadMethod + } +} + +func NewTransport(opts ...Option) (*Transport, error) { + c := &Transport{ + logger: slog.Default(), + baseTransport: http.DefaultTransport, + } + + for _, opt := range opts { + opt(c) + } + + return c, nil +} + +func (c *Transport) getRegistry(host string) (name.Registry, error) { + if c.insecureDomain != nil { + _, ok := c.insecureDomain[host] + if ok { + return name.NewRegistry(host, name.Insecure) + } + } + return name.NewRegistry(host) +} + +func (c *Transport) getUserpass(host string) Userpass { + userpass, ok := c.userAndPass[host] + if !ok { + return Userpass{} + } + return userpass +} + +func parsePath(path string) (string, bool) { + path = strings.TrimPrefix(path, "/v2/") + parts := strings.Split(path, "/") + if len(parts) < 3 { + return "", false + } + + image := strings.Join(parts[0:len(parts)-2], "/") + + return image, true +} + +func (c *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + image, ok := parsePath(req.URL.Path) + if !ok { + return nil, fmt.Errorf("invalid path: %s", req.URL.Path) + } + + host := req.Host + if host == "" { + host = req.URL.Host + } + rt, err := c.getRoundTripper(host, image) + if err != nil { + return nil, err + } + + resp, err := rt.RoundTrip(req.WithContext(req.Context())) + if err != nil { + return nil, err + } + if resp.StatusCode == http.StatusUnauthorized { + resp.Body.Close() + c.logger.Warn("unauthorized retry", "url", req.URL) + + sets, hasSets := c.clientset.Load(host) + if hasSets { + sets.Delete(image) + } + + resp, err = rt.RoundTrip(req.WithContext(req.Context())) + if err != nil { + return nil, err + } + } + + return resp, nil +} + +func (c *Transport) getRoundTripper(host string, image string) (http.RoundTripper, error) { + sets, hasSets := c.clientset.Load(host) + if hasSets { + client, ok := sets.Load(image) + if ok { + return client, nil + } + } + + c.mutClientset.Lock() + defer c.mutClientset.Unlock() + + registry, err := c.getRegistry(host) + if err != nil { + return nil, err + } + + tr, err := transport.NewWithContext( + context.Background(), + registry, + authn.FromConfig(c.getUserpass(host)), + c.baseTransport, + []string{"repository:" + image + ":pull"}, + ) + if err != nil { + return nil, err + } + + if c.domainDisableKeepAlives != nil { + if _, ok := c.domainDisableKeepAlives[host]; ok { + tr = c.disableKeepAlives(tr) + } + } + + if c.retryInterval > 0 { + tr = httpseek.NewMustReaderTransport(tr, func(request *http.Request, retry int, err error) error { + if errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) { + return err + } + if c.retry > 0 && retry >= c.retry { + return err + } + c.logger.Info("Retry", "url", request.URL, "retry", retry, "error", err) + time.Sleep(c.retryInterval) + return nil + }) + } + + sets.Store(image, tr) + return tr, nil +} + +func (c *Transport) disableKeepAlives(rt http.RoundTripper) http.RoundTripper { + if rt == nil { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.DisableKeepAlives = true + return tr + } + if tr, ok := rt.(*http.Transport); ok { + if !tr.DisableKeepAlives { + tr = tr.Clone() + tr.DisableKeepAlives = true + } + return tr + } + c.logger.Warn("Failed to disable keep alives") + return rt +} + +func ToUserAndPass(userpass []string) (map[string]Userpass, error) { + bc := map[string]Userpass{} + for _, up := range userpass { + s := strings.SplitN(up, "@", 3) + if len(s) != 2 { + return nil, fmt.Errorf("invalid userpass %q", up) + } + + u := strings.SplitN(s[0], ":", 3) + if len(s) != 2 { + return nil, fmt.Errorf("invalid userpass %q", up) + } + host := s[1] + user := u[0] + pwd := u[1] + bc[host] = Userpass{ + Username: user, + Password: pwd, + } + } + return bc, nil +}