diff --git a/pkg/storage/s3/checker.go b/pkg/storage/s3/checker.go index 63c6a91a..e98c7678 100644 --- a/pkg/storage/s3/checker.go +++ b/pkg/storage/s3/checker.go @@ -2,13 +2,14 @@ package s3 import ( "context" - "fmt" + errs "errors" + "sync" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/s3" "github.com/gomods/athens/pkg/config" "github.com/gomods/athens/pkg/errors" - "github.com/gomods/athens/pkg/log" "github.com/gomods/athens/pkg/observ" ) @@ -19,27 +20,39 @@ func (s *Storage) Exists(ctx context.Context, module, version string) (bool, err ctx, span := observ.StartSpan(ctx, op.String()) defer span.End() - lsParams := &s3.ListObjectsInput{ - Bucket: aws.String(s.bucket), - Prefix: aws.String(fmt.Sprintf("%s/@v", module)), + files := []string{"info", "mod", "zip"} + errChan := make(chan error, len(files)) + defer close(errChan) + cancelingCtx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + for _, file := range files { + wg.Add(1) + go func(file string) { + defer wg.Done() + _, err := s.s3API.HeadObjectWithContext( + cancelingCtx, + &s3.HeadObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(config.PackageVersionedName(module, version, file)), + }) + errChan <- err + }(file) } - found := make(map[string]struct{}, 3) - err := s.s3API.ListObjectsPagesWithContext(ctx, lsParams, func(loo *s3.ListObjectsOutput, lastPage bool) bool { - for _, o := range loo.Contents { - if _, exists := found[*o.Key]; exists { - log.EntryFromContext(ctx).Warnf("duplicate key in prefix %q: %q", *lsParams.Prefix, *o.Key) - continue - } - if *o.Key == config.PackageVersionedName(module, version, "info") || - *o.Key == config.PackageVersionedName(module, version, "mod") || - *o.Key == config.PackageVersionedName(module, version, "zip") { - found[*o.Key] = struct{}{} - } + exists := true + var err error + for range files { + err = <-errChan + if err == nil { + continue } - return len(found) < 3 - }) - if err != nil { - return false, errors.E(op, err, errors.M(module), errors.V(version)) + var aerr awserr.Error + if errs.As(err, &aerr) && aerr.Code() == "NotFound" { + err = nil + exists = false + } + break } - return len(found) == 3, nil + cancel() + wg.Wait() + return exists, err }