Skip to content

Commit

Permalink
Use queue
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jan 9, 2025
1 parent 812c657 commit a6cea1d
Show file tree
Hide file tree
Showing 12 changed files with 530 additions and 217 deletions.
186 changes: 125 additions & 61 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"time"

"github.com/daocloud/crproxy/cache"
"github.com/daocloud/crproxy/internal/unique"
"github.com/daocloud/crproxy/internal/queue"
"github.com/daocloud/crproxy/internal/utils"
"github.com/daocloud/crproxy/token"
"github.com/docker/distribution/registry/api/errcode"
Expand All @@ -33,10 +33,11 @@ type BlobInfo struct {
}

type Agent struct {
uniq unique.Unique[string]
httpClient *http.Client
logger *slog.Logger
cache *cache.Cache
concurrency int
queue *queue.Queue[BlobInfo]
httpClient *http.Client
logger *slog.Logger
cache *cache.Cache

blobCacheDuration time.Duration
blobCache *blobCache
Expand Down Expand Up @@ -84,24 +85,45 @@ func WithBlobsLENoAgent(blobsLENoAgent int) Option {

// WithBlobCacheDuration returns an Option that stores blobCacheDuration on the
// Agent. Values below ten seconds are raised to ten seconds before being
// stored. The returned Option never fails.
func WithBlobCacheDuration(blobCacheDuration time.Duration) Option {
	const floor = 10 * time.Second

	// Clamp once, up front; the closure only has to assign the result.
	ttl := blobCacheDuration
	if ttl < floor {
		ttl = floor
	}

	return func(a *Agent) error {
		a.blobCacheDuration = ttl
		return nil
	}
}

// WithConcurrency returns an Option that stores concurrency on the Agent.
// Values below one are raised to one before being stored. The returned Option
// never fails.
func WithConcurrency(concurrency int) Option {
	// Clamp once, up front; the closure only has to assign the result.
	workers := concurrency
	if workers <= 0 {
		workers = 1
	}

	return func(a *Agent) error {
		a.concurrency = workers
		return nil
	}
}

// NewAgent builds an Agent, applies opts in order, and then launches its
// background machinery.
//
// Defaults before options are applied: slog.Default() as logger,
// http.DefaultClient as HTTP client, a one-hour blob cache duration, a fresh
// blob queue, and a concurrency of 10.
//
// If any option returns an error, NewAgent returns that error immediately and
// starts nothing. Otherwise it creates the blob cache, calls its Start method
// once, and spawns exactly c.concurrency worker goroutines. Both run on
// context.Background(), so NewAgent gives the caller no way to cancel them.
func NewAgent(opts ...Option) (*Agent, error) {
	c := &Agent{
		logger:            slog.Default(),
		httpClient:        http.DefaultClient,
		blobCacheDuration: time.Hour,
		queue:             queue.NewQueue[BlobInfo](),
		concurrency:       10,
	}

	for _, opt := range opts {
		// Surface option failures instead of silently building a
		// half-configured agent.
		if err := opt(c); err != nil {
			return nil, err
		}
	}

	ctx := context.Background()

	c.blobCache = newBlobCache(c.blobCacheDuration)
	c.blobCache.Start(ctx, c.logger)

	// Strictly less-than: "<=" would start one worker more than configured.
	for i := 0; i < c.concurrency; i++ {
		go c.worker(ctx)
	}

	return c, nil
}
Expand All @@ -126,6 +148,20 @@ func parsePath(path string) (string, string, string, bool) {
return source, image, digest, true
}

// worker is the loop run by each background goroutine started in NewAgent.
//
// It repeatedly takes a BlobInfo from c.queue via GetOrWaitWithDone, passing
// ctx.Done() as the done channel, and hands it to cacheBlob. When cacheBlob
// fails, the error and the status code it returned are recorded in the blob
// cache under the blob's key. The finish callback returned by the queue is
// invoked after every item, whether caching succeeded or not.
//
// The loop exits when GetOrWaitWithDone reports ok == false (presumably once
// ctx is done — confirm against the queue implementation).
func (c *Agent) worker(ctx context.Context) {
	stop := ctx.Done()
	for {
		blob, release, got := c.queue.GetOrWaitWithDone(stop)
		if !got {
			return
		}

		if status, cacheErr := c.cacheBlob(&blob); cacheErr != nil {
			c.blobCache.PutError(blob.Blobs, cacheErr, status)
		}

		release()
	}
}

func (c *Agent) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
oriPath := r.URL.Path
if !strings.HasPrefix(oriPath, prefix) {
Expand Down Expand Up @@ -185,43 +221,64 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t
start = time.Now()
}

err := c.uniq.Do(ctx, info.Blobs,
func(ctx context.Context) (passCtx context.Context, done bool) {
value, ok := c.blobCache.Get(info.Blobs)
if ok {
if value.Error != nil {
utils.ServeError(rw, r, value.Error, 0)
return ctx, true
}
c.serveCachedBlob(rw, r, info.Blobs, info, t, value.Size, start)
return ctx, true
}
stat, err := c.cache.StatBlob(ctx, info.Blobs)
if err != nil {
return ctx, false
}
c.serveCachedBlob(rw, r, info.Blobs, info, t, stat.Size(), start)
return ctx, true
},
func(ctx context.Context) (err error) {
if ctx.Err() != nil {
return nil
}
size, sc, err := c.cacheBlob(info)
if err != nil {
utils.ServeError(rw, r, err, sc)
return nil
}
c.serveCachedBlob(rw, r, info.Blobs, info, t, size, start)
return nil
},
)
if err != nil {
c.logger.Warn("error response", "remoteAddr", r.RemoteAddr, "error", err)
c.blobCache.PutError(info.Blobs, err)
utils.ServeError(rw, r, err, 0)
value, ok := c.blobCache.Get(info.Blobs)
if ok {
if value.Error != nil {
utils.ServeError(rw, r, value.Error, 0)
return
}

if c.serveCachedBlobHead(rw, r, value.Size) {
return
}

c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start)
c.serveCachedBlob(rw, r, info.Blobs, info, t, value.Size)
return
}

stat, err := c.cache.StatBlob(ctx, info.Blobs)
if err == nil {
if c.serveCachedBlobHead(rw, r, value.Size) {
return
}

c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start)
c.serveCachedBlob(rw, r, info.Blobs, info, t, stat.Size())
return
}

c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start)
select {
case <-ctx.Done():
return
case <-c.queue.AddWeight(*info, t.Weight):
}

value, ok = c.blobCache.Get(info.Blobs)
if ok {
if value.Error != nil {
utils.ServeError(rw, r, value.Error, 0)
return
}
if c.serveCachedBlobHead(rw, r, value.Size) {
return
}
c.serveCachedBlob(rw, r, info.Blobs, info, t, value.Size)
return
}

stat, err = c.cache.StatBlob(ctx, info.Blobs)
if err == nil {
if c.serveCachedBlobHead(rw, r, value.Size) {
return
}
c.serveCachedBlob(rw, r, info.Blobs, info, t, stat.Size())
return
}

c.logger.Error("here should never be executed", "info", info)
utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0)
}

func sleepDuration(ctx context.Context, size, limit float64, start time.Time) error {
Expand All @@ -241,7 +298,7 @@ func sleepDuration(ctx context.Context, size, limit float64, start time.Time) er
return nil
}

func (c *Agent) cacheBlob(info *BlobInfo) (int64, int, error) {
func (c *Agent) cacheBlob(info *BlobInfo) (int, error) {
ctx := context.Background()
u := &url.URL{
Scheme: "https",
Expand All @@ -251,7 +308,7 @@ func (c *Agent) cacheBlob(info *BlobInfo) (int64, int, error) {
forwardReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
c.logger.Warn("failed to new request", "url", u.String(), "error", err)
return 0, 0, err
return 0, err
}

forwardReq.Header.Set("Accept", "*/*")
Expand All @@ -260,71 +317,78 @@ func (c *Agent) cacheBlob(info *BlobInfo) (int64, int, error) {
if err != nil {
var tErr *transport.Error
if errors.As(err, &tErr) {
return 0, http.StatusForbidden, errcode.ErrorCodeDenied
return http.StatusForbidden, errcode.ErrorCodeDenied
}
c.logger.Warn("failed to request", "url", u.String(), "error", err)
return 0, 0, errcode.ErrorCodeUnknown
return 0, errcode.ErrorCodeUnknown
}
defer func() {
resp.Body.Close()
}()

switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
return 0, 0, errcode.ErrorCodeDenied
return 0, errcode.ErrorCodeDenied
}

switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
c.logger.Error("upstream denied", "statusCode", resp.StatusCode, "url", u.String())
return 0, 0, errcode.ErrorCodeDenied
return 0, errcode.ErrorCodeDenied
}
if resp.StatusCode < http.StatusOK ||
(resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest) {
c.logger.Error("upstream unkown code", "statusCode", resp.StatusCode, "url", u.String())
return 0, 0, errcode.ErrorCodeUnknown
return 0, errcode.ErrorCodeUnknown
}

if resp.StatusCode >= http.StatusBadRequest {
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
c.logger.Error("failed to get body", "statusCode", resp.StatusCode, "url", u.String(), "error", err)
return 0, 0, errcode.ErrorCodeUnknown
return 0, errcode.ErrorCodeUnknown
}
if !json.Valid(body) {
c.logger.Error("invalid body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body))
return 0, 0, errcode.ErrorCodeDenied
return 0, errcode.ErrorCodeDenied
}
var retErrs errcode.Errors
err = retErrs.UnmarshalJSON(body)
if err != nil {
c.logger.Error("failed to unmarshal body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body))
return 0, 0, errcode.ErrorCodeUnknown
return 0, errcode.ErrorCodeUnknown
}
return 0, resp.StatusCode, retErrs
return resp.StatusCode, retErrs
}

size, err := c.cache.PutBlob(ctx, info.Blobs, resp.Body)
if err != nil {
return 0, 0, err
return 0, err
}
return size, 0, nil
}
c.blobCache.Put(info.Blobs, size)

func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64, start time.Time) {
if size != 0 && r.Method == http.MethodHead {
rw.Header().Set("Content-Length", strconv.FormatInt(size, 10))
rw.Header().Set("Content-Type", "application/octet-stream")
return
}
return 0, nil
}

// rateLimit applies the token's rate limit to the current request.
//
// When t.NoRateLimit is set it does nothing. Otherwise it calls sleepDuration
// with the request context, the blob size, t.RateLimitPerSecond and start.
// The error returned by sleepDuration is not propagated: this method has no
// return value, so callers continue regardless of the outcome.
//
// rw, blob and info are accepted but not used here.
func (c *Agent) rateLimit(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64, start time.Time) {
	if t.NoRateLimit {
		return
	}

	_ = sleepDuration(r.Context(), float64(size), float64(t.RateLimitPerSecond), start)
}

// serveCachedBlobHead answers a HEAD request for a blob of known size.
//
// If size is non-zero and the request method is HEAD, it sets Content-Length
// to size and Content-Type to application/octet-stream on rw and returns true,
// meaning the request has been handled. In every other case it leaves rw
// untouched and returns false so the caller can continue.
func (c *Agent) serveCachedBlobHead(rw http.ResponseWriter, r *http.Request, size int64) bool {
	if size == 0 || r.Method != http.MethodHead {
		return false
	}

	hdr := rw.Header()
	hdr.Set("Content-Length", strconv.FormatInt(size, 10))
	hdr.Set("Content-Type", "application/octet-stream")
	return true
}

func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64) {
if c.blobsLENoAgent < 0 || int64(c.blobsLENoAgent) > size {
data, err := c.cache.GetBlob(r.Context(), info.Blobs)
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions agent/blob_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ func (m *blobCache) Remove(key string) {
m.digest.Remove(key)
}

func (m *blobCache) PutError(key string, err error) {
func (m *blobCache) PutError(key string, err error, sc int) {
m.digest.SetWithTTL(key, blobValue{
Error: err,
Error: err,
StatusCode: sc,
}, m.duration)
}

Expand All @@ -53,6 +54,7 @@ func (m *blobCache) Put(key string, size int64) {
}

// blobValue is the entry kept in blobCache for one blob key. It describes
// either a blob with a known size or a recorded failure.
type blobValue struct {
	// Size is the blob size passed to Put; PutError leaves it at zero.
	Size int64
	// Error is the failure recorded by PutError; nil when none was recorded.
	Error error
	// StatusCode is the status code recorded together with Error by PutError.
	StatusCode int
}
27 changes: 20 additions & 7 deletions cache/cache_manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,20 @@ func (c *Cache) RelinkManifest(ctx context.Context, host, image, tag string, blo
return nil
}

func (c *Cache) PutManifestContent(ctx context.Context, host, image, tagOrBlob string, content []byte) (int64, string, error) {
func (c *Cache) PutManifestContent(ctx context.Context, host, image, tagOrBlob string, content []byte) (int64, string, string, error) {
mt := struct {
MediaType string `json:"mediaType"`
}{}
err := json.Unmarshal(content, &mt)
if err != nil {
return 0, "", "", fmt.Errorf("invalid content: %w: %s", err, string(content))
}

mediaType := mt.MediaType
if mediaType == "" {
mediaType = "application/vnd.docker.distribution.manifest.v1+json"
}

h := sha256.New()
h.Write(content)
hash := hex.EncodeToString(h.Sum(nil)[:])
Expand All @@ -43,27 +56,27 @@ func (c *Cache) PutManifestContent(ctx context.Context, host, image, tagOrBlob s
if isHash {
tagOrBlob = tagOrBlob[7:]
if tagOrBlob != hash {
return 0, "", fmt.Errorf("expected hash %s is not same to %s", tagOrBlob, hash)
return 0, "", "", fmt.Errorf("expected hash %s is not same to %s", tagOrBlob, hash)
}
} else {
manifestLinkPath := manifestTagCachePath(host, image, tagOrBlob)
err := c.PutContent(ctx, manifestLinkPath, []byte("sha256:"+hash))
if err != nil {
return 0, "", fmt.Errorf("put manifest link path %s error: %w", manifestLinkPath, err)
return 0, "", "", fmt.Errorf("put manifest link path %s error: %w", manifestLinkPath, err)
}
}

manifestLinkPath := manifestRevisionsCachePath(host, image, hash)
err := c.PutContent(ctx, manifestLinkPath, []byte("sha256:"+hash))
err = c.PutContent(ctx, manifestLinkPath, []byte("sha256:"+hash))
if err != nil {
return 0, "", fmt.Errorf("put manifest revisions path %s error: %w", manifestLinkPath, err)
return 0, "", "", fmt.Errorf("put manifest revisions path %s error: %w", manifestLinkPath, err)
}

n, err := c.PutBlobContent(ctx, hash, content)
if err != nil {
return 0, "", fmt.Errorf("put manifest blob path %s error: %w", hash, err)
return 0, "", "", fmt.Errorf("put manifest blob path %s error: %w", hash, err)
}
return n, hash, nil
return n, hash, mediaType, nil
}

func (c *Cache) GetManifestContent(ctx context.Context, host, image, tagOrBlob string) ([]byte, string, string, error) {
Expand Down
Loading

0 comments on commit a6cea1d

Please sign in to comment.