sync: let SyncManager write to several caches instead of one.

Review notes on this revision of the patch:
  * blobCallback, multi-cache path: every pipe writer is now closed once the
    copy finishes and every pipe reader is closed when its PutBlob returns.
    Before, the readers never saw EOF, so wg.Wait() blocked forever, and a
    PutBlob that returned early blocked the copy. PutBlob failures are now
    returned to the caller instead of only being logged.
  * syncLayerFromManifestList, tagged references: the cached size is compared
    with desc.Size again, as the pre-patch code did.
  * NOTE(review): this patch was recovered from a copy whose line breaks and
    indentation had been collapsed. Blank context lines and column alignment
    were reconstructed from the hunk line counts and gofmt conventions, and
    the post-image blob id on the sync/sync.go "index" line predates the two
    edits above. Regenerate the patch from a working tree before applying it.

diff --git a/cmd/crproxy/sync/sync.go b/cmd/crproxy/sync/sync.go
index c5b9d3a..507ade9 100644
--- a/cmd/crproxy/sync/sync.go
+++ b/cmd/crproxy/sync/sync.go
@@ -92,7 +92,7 @@ func runE(ctx context.Context, flags *flagpole) error {
 	}
 
 	opts = append(opts,
-		csync.WithCache(cache),
+		csync.WithCaches(cache),
 		csync.WithDomainAlias(map[string]string{
 			"docker.io": "registry-1.docker.io",
 			"ollama.ai": "registry.ollama.ai",
diff --git a/sync/sync.go b/sync/sync.go
index 487905f..8d3c9f7 100644
--- a/sync/sync.go
+++ b/sync/sync.go
@@ -3,10 +3,12 @@ package crproxy
 import (
 	"context"
 	"fmt"
+	"io"
 	"log/slog"
 	"net/http"
 	"regexp"
 	"strings"
+	"sync"
 
 	"github.com/daocloud/crproxy/cache"
 	"github.com/distribution/reference"
@@ -37,7 +39,7 @@ func newNameWithoutDomain(named reference.Named, name string) reference.Named {
 
 type SyncManager struct {
 	transport   http.RoundTripper
-	cache       *cache.Cache
+	caches      []*cache.Cache
 	logger      *slog.Logger
 	domainAlias map[string]string
 	deep        bool
@@ -76,9 +78,9 @@ func WithLogger(logger *slog.Logger) Option {
 	}
 }
 
-func WithCache(cache *cache.Cache) Option {
+func WithCaches(caches ...*cache.Cache) Option {
 	return func(c *SyncManager) {
-		c.cache = cache
+		c.caches = caches
 	}
 }
 
@@ -103,7 +105,7 @@ func NewSyncManager(opts ...Option) (*SyncManager, error) {
 		opt(c)
 	}
 
-	if c.cache == nil {
+	if len(c.caches) == 0 {
 		return nil, fmt.Errorf("cache is required")
 	}
 
@@ -151,7 +153,7 @@ func (c *SyncManager) Image(ctx context.Context, image string) error {
 	bs := repo.Blobs(ctx)
 
 	uniq := map[digest.Digest]struct{}{}
-	blobCallback := func(dgst digest.Digest, size int64, pf *manifestlist.PlatformSpec, name string) error {
+	blobCallback := func(caches []*cache.Cache, dgst digest.Digest, size int64, pf *manifestlist.PlatformSpec, name string) error {
 		_, ok := uniq[dgst]
 		if ok {
 			c.logger.Debug("skip blob by unique", "image", image, "digest", dgst)
@@ -160,19 +162,26 @@ func (c *SyncManager) Image(ctx context.Context, image string) error {
 		uniq[dgst] = struct{}{}
 		blob := dgst.String()
 
-		stat, err := c.cache.StatBlob(ctx, blob)
-		if err == nil {
-			if size > 0 {
-				gotSize := stat.Size()
-				if size == gotSize {
-					c.logger.Debug("skip blob", "image", image, "digest", dgst)
-					return nil
+		var subCaches []*cache.Cache
+		for _, cache := range caches {
+			stat, err := cache.StatBlob(ctx, blob)
+			if err == nil {
+				if size > 0 {
+					gotSize := stat.Size()
+					if size == gotSize {
+						continue
+					}
+					c.logger.Error("size is not meeting expectations", "digest", dgst, "size", size, "gotSize", gotSize)
+				} else {
+					continue
 				}
-				c.logger.Error("size is not meeting expectations", "digest", dgst, "size", size, "gotSize", gotSize)
-			} else {
-				c.logger.Debug("skip blob", "image", image, "digest", dgst)
-				return nil
 			}
+			subCaches = append(subCaches, cache)
+		}
+
+		if len(subCaches) == 0 {
+			c.logger.Debug("skip blob by cache", "image", image, "digest", dgst)
+			return nil
 		}
 
 		f, err := bs.Open(ctx, dgst)
@@ -181,24 +190,74 @@ func (c *SyncManager) Image(ctx context.Context, image string) error {
 		}
 		defer f.Close()
 
-		n, err := c.cache.PutBlob(ctx, blob, f)
+		if len(subCaches) == 1 {
+			n, err := subCaches[0].PutBlob(ctx, blob, f)
+			if err != nil {
+				return fmt.Errorf("put blob failed: %w", err)
+			}
+			c.logger.Info("sync blob", "image", image, "digest", dgst, "size", n, "platform", pf, "name", name)
+			return nil
+		}
+
+		// Fan the blob out to every cache that still needs it: each cache
+		// reads from its own pipe while one copy below feeds all of them.
+		writers := make([]io.Writer, 0, len(subCaches))
+		pipes := make([]*io.PipeWriter, 0, len(subCaches))
+		putErrs := make([]error, len(subCaches))
+		var wg sync.WaitGroup
+		wg.Add(len(subCaches))
+
+		for i, ca := range subCaches {
+			pr, pw := io.Pipe()
+			writers = append(writers, pw)
+			pipes = append(pipes, pw)
+
+			go func(i int, ca *cache.Cache, pr *io.PipeReader) {
+				defer wg.Done()
+
+				_, err := ca.PutBlob(ctx, blob, pr)
+				// Close the read side so that a PutBlob which returns before
+				// draining the pipe cannot block the copy below forever.
+				_ = pr.CloseWithError(err)
+				if err != nil {
+					putErrs[i] = err
+					c.logger.Error("put blob failed", "image", image, "digest", dgst, "platform", pf, "name", name, "error", err)
+				}
+			}(i, ca, pr)
+		}
+
+		n, err := io.Copy(io.MultiWriter(writers...), f)
+		// Close every write side (passing on the copy error, if any) so the
+		// readers see the end of the stream; otherwise PutBlob never returns
+		// and wg.Wait blocks forever.
+		for _, pw := range pipes {
+			_ = pw.CloseWithError(err)
+		}
+		wg.Wait()
 		if err != nil {
-			return fmt.Errorf("put blob failed: %w", err)
+			return fmt.Errorf("copy blob failed: %w", err)
 		}
+		for _, putErr := range putErrs {
+			if putErr != nil {
+				return fmt.Errorf("put blob failed: %w", putErr)
+			}
+		}
 
 		c.logger.Info("sync blob", "image", image, "digest", dgst, "size", n, "platform", pf, "name", name)
 		return nil
 	}
 
-	manifestCallback := func(tagOrHash string, m distribution.Manifest) error {
+	manifestCallback := func(caches []*cache.Cache, tagOrHash string, m distribution.Manifest) error {
 		_, playload, err := m.Payload()
 		if err != nil {
 			return fmt.Errorf("get manifest payload failed: %w", err)
 		}
 
-		_, _, err = c.cache.PutManifestContent(ctx, host, path, tagOrHash, playload)
-		if err != nil {
-			return fmt.Errorf("put manifest content failed: %w", err)
+		for _, cache := range caches {
+			_, _, err = cache.PutManifestContent(ctx, host, path, tagOrHash, playload)
+			if err != nil {
+				return fmt.Errorf("put manifest content failed: %w", err)
+			}
 		}
 		return nil
 	}
@@ -245,21 +304,32 @@ func (c *SyncManager) Image(ctx context.Context, image string) error {
 }
 
 func (c *SyncManager) syncLayerFromManifestList(ctx context.Context, image string, ms distribution.ManifestService, ts distribution.TagService, ref reference.Reference,
-	digestCallback func(dgst digest.Digest, size int64, pf *manifestlist.PlatformSpec, name string) error,
-	manifestCallback func(tagOrHash string, m distribution.Manifest) error, name string) error {
+	digestCallback func(caches []*cache.Cache, dgst digest.Digest, size int64, pf *manifestlist.PlatformSpec, name string) error,
+	manifestCallback func(caches []*cache.Cache, tagOrHash string, m distribution.Manifest) error, name string) error {
 	var (
 		m   distribution.Manifest
 		err error
 	)
 
+	var caches []*cache.Cache
+
+	if c.deep {
+		caches = c.caches
+	}
+
 	var hash digest.Digest
 
 	switch r := ref.(type) {
 	case reference.Digested:
 		hash = r.Digest()
 		if !c.deep {
-			stat, err := c.cache.StatBlob(ctx, hash.String())
-			if err == nil && stat.Size() > 0 {
+			for _, cache := range c.caches {
+				stat, err := cache.StatBlob(ctx, hash.String())
+				if err != nil || stat.Size() == 0 {
+					caches = append(caches, cache)
+				}
+			}
+			if len(caches) == 0 {
 				c.logger.Debug("skip manifest by cache", "image", image, "digest", hash)
 				return nil
 			}
@@ -268,7 +338,7 @@ func (c *SyncManager) syncLayerFromManifestList(ctx context.Context, image strin
 		if err != nil {
 			return fmt.Errorf("get manifest digest failed: %w", err)
 		}
-		err = manifestCallback(hash.String(), m)
+		err = manifestCallback(caches, hash.String(), m)
 		if err != nil {
 			return fmt.Errorf("manifest callback failed: %w", err)
 		}
@@ -280,8 +350,13 @@ func (c *SyncManager) syncLayerFromManifestList(ctx context.Context, image strin
 		}
 		hash = desc.Digest
 		if !c.deep {
-			stat, err := c.cache.StatBlob(ctx, hash.String())
-			if err == nil && stat.Size() == desc.Size {
+			for _, cache := range c.caches {
+				stat, err := cache.StatBlob(ctx, hash.String())
+				if err != nil || stat.Size() != desc.Size {
+					caches = append(caches, cache)
+				}
+			}
+			if len(caches) == 0 {
 				c.logger.Debug("skip manifest by cache", "image", image, "digest", hash, "tag", tag)
 				return nil
 			}
@@ -290,7 +365,7 @@ func (c *SyncManager) syncLayerFromManifestList(ctx context.Context, image strin
 		if err != nil {
 			return fmt.Errorf("get manifest digest failed: %w", err)
 		}
-		err = manifestCallback(tag, m)
+		err = manifestCallback(caches, tag, m)
 		if err != nil {
 			return fmt.Errorf("manifest callback failed: %w", err)
 		}
@@ -305,23 +380,35 @@ func (c *SyncManager) syncLayerFromManifestList(ctx context.Context, image strin
 				c.logger.Debug("skip manifest by filter platform", "image", image, "digest", mfest.Digest, "platform", mfest.Platform)
 				continue
 			}
+
+			var subCaches []*cache.Cache
 			if !c.deep {
-				stat, err := c.cache.StatBlob(ctx, mfest.Digest.String())
-				if err == nil && stat.Size() == mfest.Size {
+				for _, cache := range caches {
+					stat, err := cache.StatBlob(ctx, mfest.Digest.String())
+					if err == nil && stat.Size() == mfest.Size {
+						continue
+					}
+					subCaches = append(subCaches, cache)
+				}
+
+				if len(subCaches) == 0 {
+					c.logger.Debug("skip manifest by cache", "image", image, "digest", mfest.Digest, "platform", mfest.Platform)
 					continue
 				}
+			} else {
+				subCaches = caches
 			}
 			m0, err := ms.Get(ctx, mfest.Digest)
 			if err != nil {
 				return fmt.Errorf("get manifest failed: %w", err)
 			}
-			err = manifestCallback(mfest.Digest.String(), m0)
+			err = manifestCallback(subCaches, mfest.Digest.String(), m0)
 			if err != nil {
 				return fmt.Errorf("manifest callback failed: %w", err)
 			}
 
 			err = c.syncLayerFromManifest(m0, func(dgst digest.Digest, size int64) error {
-				return digestCallback(dgst, size, &mfest.Platform, name)
+				return digestCallback(subCaches, dgst, size, &mfest.Platform, name)
 			})
 			if err != nil {
 				return fmt.Errorf("get layer from manifest failed: %w", err)
@@ -330,7 +417,7 @@ func (c *SyncManager) syncLayerFromManifestList(ctx context.Context, image strin
 		return nil
 	default:
 		return c.syncLayerFromManifest(m, func(dgst digest.Digest, size int64) error {
-			return digestCallback(dgst, size, nil, name)
+			return digestCallback(caches, dgst, size, nil, name)
 		})
 	}
 }