diff --git a/cmd/olympus/actions/app.go b/cmd/olympus/actions/app.go index ba6899d2..e93e6ace 100644 --- a/cmd/olympus/actions/app.go +++ b/cmd/olympus/actions/app.go @@ -15,12 +15,14 @@ import ( "github.com/gomods/athens/pkg/cdn/metadata/azurecdn" "github.com/gomods/athens/pkg/config/env" "github.com/gomods/athens/pkg/download" - "github.com/gomods/athens/pkg/download/goget" "github.com/gomods/athens/pkg/eventlog" "github.com/gomods/athens/pkg/log" + "github.com/gomods/athens/pkg/module" + "github.com/gomods/athens/pkg/stash" "github.com/gomods/athens/pkg/storage" "github.com/gomodule/redigo/redis" "github.com/rs/cors" + "github.com/spf13/afero" "github.com/unrolled/secure" ) @@ -85,6 +87,7 @@ func App(config *AppConfig) (*buffalo.App, error) { WorkerOff: true, // TODO(marwan): turned off until worker is being used. Logger: blggr, }) + // Automatically redirect to SSL app.Use(ssl.ForceSSL(secure.Options{ SSLRedirect: ENV == "production", @@ -123,13 +126,23 @@ func App(config *AppConfig) (*buffalo.App, error) { app.GET("/healthz", healthHandler) // Download Protocol - gg, err := goget.New() + goBin := env.GoBinPath() + fs := afero.NewOsFs() + mf, err := module.NewGoGetFetcher(goBin, fs) if err != nil { return nil, err } - dp := download.New(gg, config.Storage, env.GoGetWorkers()) - opts := &download.HandlerOpts{Protocol: dp, Logger: lggr, Engine: renderEng} - download.RegisterHandlers(app, opts) + st := stash.New(mf, config.Storage) + dpOpts := &download.Opts{ + Storage: config.Storage, + Stasher: st, + GoBinPath: goBin, + Fs: fs, + } + dp := download.New(dpOpts) + + handlerOpts := &download.HandlerOpts{Protocol: dp, Logger: lggr, Engine: renderEng} + download.RegisterHandlers(app, handlerOpts) app.ServeFiles("/", assetsBox) // serve files from the public directory diff --git a/cmd/proxy/actions/app_proxy.go b/cmd/proxy/actions/app_proxy.go index 3b9322df..244bd0c5 100644 --- a/cmd/proxy/actions/app_proxy.go +++ b/cmd/proxy/actions/app_proxy.go @@ -4,9 +4,12 @@ 
import ( "github.com/gobuffalo/buffalo" "github.com/gomods/athens/pkg/config/env" "github.com/gomods/athens/pkg/download" - "github.com/gomods/athens/pkg/download/goget" + "github.com/gomods/athens/pkg/download/addons" "github.com/gomods/athens/pkg/log" + "github.com/gomods/athens/pkg/module" + "github.com/gomods/athens/pkg/stash" "github.com/gomods/athens/pkg/storage" + "github.com/spf13/afero" ) func addProxyRoutes( @@ -18,13 +21,45 @@ func addProxyRoutes( app.GET("/healthz", healthHandler) // Download Protocol - gg, err := goget.New() + // the download.Protocol and the stash.Stasher interfaces are composable + // in a middleware fashion. Therefore you can separate concerns + // by the functionality: a download.Protocol that just takes care + // of "go getting" things, and another Protocol that just takes care + // of "pooling" requests etc. + + // In our case, we'd like to compose both interfaces in a particular + // order to ensure logical ordering of execution. + + // Here's the order of an incoming request to the download.Protocol: + + // 1. The downloadpool gets hit first, and manages concurrent requests + // 2. The downloadpool passes the request to its parent Protocol: stasher + // 3. The stasher Protocol checks storage first, and if storage is empty + // it makes a Stash request to the stash.Stasher interface. + + // Once the stasher picks up an order, here's how the requests go in order: + // 1. The singleflight picks up the first request and latches duplicate ones. + // 2. The singleflight passes the stash to its parent: stashpool. + // 3. The stashpool manages limiting concurrent requests and passes them to stash. + // 4. The plain stash.New just takes a request from upstream and saves it into storage. 
+ goBin := env.GoBinPath() + fs := afero.NewOsFs() + mf, err := module.NewGoGetFetcher(goBin, fs) if err != nil { return err } - p := download.New(gg, s, env.GoGetWorkers()) - opts := &download.HandlerOpts{Protocol: p, Logger: l, Engine: proxy} - download.RegisterHandlers(app, opts) + st := stash.New(mf, s, stash.WithPool(env.GoGetWorkers()), stash.WithSingleflight) + + dpOpts := &download.Opts{ + Storage: s, + Stasher: st, + GoBinPath: goBin, + Fs: fs, + } + dp := download.New(dpOpts, addons.WithPool(env.ProtocolWorkers())) + + handlerOpts := &download.HandlerOpts{Protocol: dp, Logger: l, Engine: proxy} + download.RegisterHandlers(app, handlerOpts) return nil } diff --git a/pkg/config/env/go.go b/pkg/config/env/go.go index e2cd1b14..66f4a4ea 100644 --- a/pkg/config/env/go.go +++ b/pkg/config/env/go.go @@ -36,3 +36,24 @@ func GoGetWorkers() int { return num } + +// ProtocolWorkers returns how many concurrent +// requests can you handle at a time for all +// download protocol paths. This is different from +// GoGetWorkers in that you can potentially serve +// 30 requests to the Download Protocol but only 5 +// at a time can stash a module from Upstream to Storage. +func ProtocolWorkers() int { + defaultNum := 30 + str := os.Getenv("ATHENS_PROTOCOL_WORKERS") + if str == "" { + return defaultNum + } + + num, err := strconv.Atoi(str) + if err != nil { + return defaultNum + } + + return num +} diff --git a/pkg/download/addons/with_pool.go b/pkg/download/addons/with_pool.go new file mode 100644 index 00000000..605dad80 --- /dev/null +++ b/pkg/download/addons/with_pool.go @@ -0,0 +1,129 @@ +package addons + +import ( + "context" + "io" + + "github.com/gomods/athens/pkg/download" + "github.com/gomods/athens/pkg/errors" + "github.com/gomods/athens/pkg/storage" +) + +type withpool struct { + dp download.Protocol + + // jobCh is a channel that takes an anonymous + // function that it executes based on the pool's + // business. 
The design leverages closures + // so that the worker does not need to worry about + // what the type of job it is taking (Info, Zip etc), + // it just regulates functions and executes them + // in a worker-pool fashion. + jobCh chan func() +} + +// WithPool takes a download Protocol and a number of workers +// and creates a N worker pool that share all the download.Protocol +// methods. +func WithPool(workers int) download.Wrapper { + return func(dp download.Protocol) download.Protocol { + jobCh := make(chan func()) + p := &withpool{dp: dp, jobCh: jobCh} + + p.start(workers) + return p + } +} + +func (p *withpool) start(numWorkers int) { + for i := 0; i < numWorkers; i++ { + go p.listen() + } +} + +func (p *withpool) listen() { + for f := range p.jobCh { + f() + } +} + +func (p *withpool) List(ctx context.Context, mod string) ([]string, error) { + const op errors.Op = "pool.List" + var vers []string + var err error + done := make(chan struct{}, 1) + p.jobCh <- func() { + vers, err = p.dp.List(ctx, mod) + close(done) + } + <-done + if err != nil { + return nil, errors.E(op, err) + } + + return vers, nil +} + +func (p *withpool) Info(ctx context.Context, mod, ver string) ([]byte, error) { + const op errors.Op = "pool.Info" + var info []byte + var err error + done := make(chan struct{}, 1) + p.jobCh <- func() { + info, err = p.dp.Info(ctx, mod, ver) + close(done) + } + <-done + if err != nil { + return nil, errors.E(op, err) + } + return info, nil +} + +func (p *withpool) Latest(ctx context.Context, mod string) (*storage.RevInfo, error) { + const op errors.Op = "pool.Latest" + var info *storage.RevInfo + var err error + done := make(chan struct{}, 1) + p.jobCh <- func() { + info, err = p.dp.Latest(ctx, mod) + close(done) + } + <-done + if err != nil { + return nil, errors.E(op, err) + } + return info, nil +} + +func (p *withpool) GoMod(ctx context.Context, mod, ver string) ([]byte, error) { + const op errors.Op = "pool.GoMod" + var goMod []byte + var err error + done
:= make(chan struct{}, 1) + p.jobCh <- func() { + goMod, err = p.dp.GoMod(ctx, mod, ver) + close(done) + } + <-done + if err != nil { + return nil, errors.E(op, err) + } + return goMod, nil +} + +func (p *withpool) Zip(ctx context.Context, mod, ver string) (io.ReadCloser, error) { + const op errors.Op = "pool.Zip" + var zip io.ReadCloser + var err error + done := make(chan struct{}, 1) + p.jobCh <- func() { + zip, err = p.dp.Zip(ctx, mod, ver) + close(done) + } + <-done + if err != nil { + return nil, errors.E(op, err) + } + return zip, nil +} diff --git a/pkg/download/addons/with_pool_test.go b/pkg/download/addons/with_pool_test.go new file mode 100644 index 00000000..307e4588 --- /dev/null +++ b/pkg/download/addons/with_pool_test.go @@ -0,0 +1,153 @@ +package addons + +import ( + "bytes" + "context" + "fmt" + "io" + "reflect" + "sync" + "testing" + "time" + + "github.com/gomods/athens/pkg/download" + "github.com/gomods/athens/pkg/storage" +) + +// TestPoolLogic ensures that no +// more than given workers are working +// at one time. +func TestPoolLogic(t *testing.T) { + m := &mockPool{} + dp := WithPool(5)(m) + ctx := context.Background() + m.ch = make(chan struct{}) + for i := 0; i < 10; i++ { + go dp.List(ctx, "") + } + <-m.ch + if m.num != 5 { + t.Fatalf("expected 4 workers but got %v", m.num) + } +} + +type mockPool struct { + download.Protocol + num int + mu sync.Mutex + ch chan struct{} +} + +func (m *mockPool) List(ctx context.Context, mod string) ([]string, error) { + m.mu.Lock() + m.num++ + if m.num == 5 { + m.ch <- struct{}{} + } + m.mu.Unlock() + + time.Sleep(time.Minute) + return nil, nil +} + +// TestPoolWrapper ensures all upstream methods +// are successfully called. 
+func TestPoolWrapper(t *testing.T) { + m := &mockDP{} + dp := WithPool(1)(m) + ctx := context.Background() + mod := "pkg" + ver := "v0.1.0" + m.inputMod = mod + m.inputVer = ver + m.list = []string{"v0.0.0", "v0.1.0"} + givenList, err := dp.List(ctx, mod) + if err != m.err { + t.Fatalf("expected dp.List err to be %v but got %v", m.err, err) + } + if !reflect.DeepEqual(m.list, givenList) { + t.Fatalf("dp.List: expected %v and %v to be equal", m.list, givenList) + } + m.info = []byte("info response") + givenInfo, err := dp.Info(ctx, mod, ver) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(m.info, givenInfo) { + t.Fatalf("dp.Info: expected %s and %s to be equal", m.info, givenInfo) + } + m.err = fmt.Errorf("mod err") + _, err = dp.GoMod(ctx, mod, ver) + if m.err.Error() != err.Error() { + t.Fatalf("dp.GoMod: expected err to be `%v` but got `%v`", m.err, err) + } + _, err = dp.Zip(ctx, mod, ver) + if m.err.Error() != err.Error() { + t.Fatalf("dp.Zip: expected err to be `%v` but got `%v`", m.err, err) + } +} + +type mockDP struct { + err error + list []string + info []byte + latest *storage.RevInfo + gomod []byte + zip io.ReadCloser + inputMod string + inputVer string +} + +// List implements GET /{module}/@v/list +func (m *mockDP) List(ctx context.Context, mod string) ([]string, error) { + if m.inputMod != mod { + return nil, fmt.Errorf("expected mod input %v but got %v", m.inputMod, mod) + } + return m.list, m.err +} + +// Info implements GET /{module}/@v/{version}.info +func (m *mockDP) Info(ctx context.Context, mod, ver string) ([]byte, error) { + if m.inputMod != mod { + return nil, fmt.Errorf("expected mod input %v but got %v", m.inputMod, mod) + } + if m.inputVer != ver { + return nil, fmt.Errorf("expected ver input %v but got %v", m.inputVer, ver) + } + return m.info, m.err +} + +// Latest implements GET /{module}/@latest +func (m *mockDP) Latest(ctx context.Context, mod string) (*storage.RevInfo, error) { + if m.inputMod != mod { + return nil, 
fmt.Errorf("expected mod input %v but got %v", m.inputMod, mod) + } + return m.latest, m.err +} + +// GoMod implements GET /{module}/@v/{version}.mod +func (m *mockDP) GoMod(ctx context.Context, mod, ver string) ([]byte, error) { + if m.inputMod != mod { + return nil, fmt.Errorf("expected mod input %v but got %v", m.inputMod, mod) + } + if m.inputVer != ver { + return nil, fmt.Errorf("expected ver input %v but got %v", m.inputVer, ver) + } + return m.gomod, m.err +} + +// Zip implements GET /{module}/@v/{version}.zip +func (m *mockDP) Zip(ctx context.Context, mod, ver string) (io.ReadCloser, error) { + if m.inputMod != mod { + return nil, fmt.Errorf("expected mod input %v but got %v", m.inputMod, mod) + } + if m.inputVer != ver { + return nil, fmt.Errorf("expected ver input %v but got %v", m.inputVer, ver) + } + return m.zip, m.err +} + +// Version is a helper method to get Info, GoMod, and Zip together. +func (m *mockDP) Version(ctx context.Context, mod, ver string) (*storage.Version, error) { + panic("skipped") +} diff --git a/pkg/download/download_test.go b/pkg/download/download_test.go deleted file mode 100644 index eaff77d2..00000000 --- a/pkg/download/download_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package download - -import ( - "bytes" - "context" - "io/ioutil" - "testing" - - "github.com/gomods/athens/pkg/storage/mem" - - "github.com/gomods/athens/pkg/storage" - "golang.org/x/sync/errgroup" -) - -type testMod struct { - mod, ver string -} - -var mods = []testMod{ - {"github.com/athens-artifacts/no-tags", "v0.0.2"}, - {"github.com/athens-artifacts/happy-path", "v0.0.0-20180803035119-e4e0177efdb5"}, - {"github.com/athens-artifacts/samplelib", "v1.0.0"}, -} - -func TestDownloadProtocol(t *testing.T) { - s, err := mem.NewStorage() - if err != nil { - t.Fatal(err) - } - dp := New(&mockProtocol{}, s, 2) - ctx := context.Background() - - var eg errgroup.Group - for i := 0; i < len(mods); i++ { - m := mods[i] - eg.Go(func() error { - _, err := dp.GoMod(ctx, m.mod, 
m.ver) - return err - }) - } - - err = eg.Wait() - if err != nil { - t.Fatal(err) - } - - for _, m := range mods { - bts, err := dp.GoMod(ctx, m.mod, m.ver) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(bts, []byte(m.mod+"@"+m.ver)) { - t.Fatalf("unexpected gomod content: %s", bts) - } - } -} - -type mockProtocol struct { - Protocol -} - -// Info implements GET /{module}/@v/{version}.info -func (m *mockProtocol) Info(ctx context.Context, mod, ver string) ([]byte, error) { - return []byte(mod + "@" + ver), nil -} - -func (m *mockProtocol) Version(ctx context.Context, mod, ver string) (*storage.Version, error) { - bts := []byte(mod + "@" + ver) - return &storage.Version{ - Mod: bts, - Info: bts, - Zip: ioutil.NopCloser(bytes.NewReader(bts)), - }, nil -} diff --git a/pkg/download/goget/goget.go b/pkg/download/goget/goget.go deleted file mode 100644 index 677815df..00000000 --- a/pkg/download/goget/goget.go +++ /dev/null @@ -1,160 +0,0 @@ -package goget - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "os/exec" - "time" - - "github.com/gomods/athens/pkg/config" - "github.com/gomods/athens/pkg/config/env" - "github.com/gomods/athens/pkg/download" - "github.com/gomods/athens/pkg/errors" - "github.com/gomods/athens/pkg/module" - "github.com/gomods/athens/pkg/storage" - "github.com/spf13/afero" -) - -// New returns a download protocol by using -// go get. You must have a modules supported -// go binary for this to work. 
-func New() (download.Protocol, error) { - const op errors.Op = "goget.New" - goBin := env.GoBinPath() - fs := afero.NewOsFs() - mf, err := module.NewGoGetFetcher(goBin, fs) - if err != nil { - return nil, errors.E(op, err) - } - return &goget{ - goBinPath: goBin, - fetcher: mf, - fs: fs, - }, nil -} - -type goget struct { - goBinPath string - fetcher module.Fetcher - fs afero.Fs -} - -func (gg *goget) List(ctx context.Context, mod string) ([]string, error) { - const op errors.Op = "goget.List" - lr, err := gg.list(op, mod) - if err != nil { - return nil, err - } - - return lr.Versions, nil -} - -type listResp struct { - Path string - Version string - Versions []string `json:",omitempty"` - Time time.Time -} - -func (gg *goget) Info(ctx context.Context, mod string, ver string) ([]byte, error) { - const op errors.Op = "goget.Info" - v, err := gg.Version(ctx, mod, ver) - if err != nil { - return nil, errors.E(op, err) - } - v.Zip.Close() - - return v.Info, nil -} - -func (gg *goget) Latest(ctx context.Context, mod string) (*storage.RevInfo, error) { - const op errors.Op = "goget.Latest" - lr, err := gg.list(op, mod) - if err != nil { - return nil, errors.E(op, err) - } - - return &storage.RevInfo{ - Time: lr.Time, - Version: lr.Version, - }, nil -} - -func (gg *goget) list(op errors.Op, mod string) (*listResp, error) { - hackyPath, err := afero.TempDir(gg.fs, "", "hackymod") - if err != nil { - return nil, errors.E(op, err) - } - defer gg.fs.RemoveAll(hackyPath) - err = module.Dummy(gg.fs, hackyPath) - - cmd := exec.Command( - gg.goBinPath, - "list", "-m", "-versions", "-json", - config.FmtModVer(mod, "latest"), - ) - cmd.Dir = hackyPath - stdout := &bytes.Buffer{} - stderr := &bytes.Buffer{} - cmd.Stdout = stdout - cmd.Stderr = stderr - - gopath, err := afero.TempDir(gg.fs, "", "athens") - if err != nil { - return nil, errors.E(op, err) - } - defer module.ClearFiles(gg.fs, gopath) - cmd.Env = module.PrepareEnv(gopath) - - err = cmd.Run() - if err != nil { - err = 
fmt.Errorf("%v: %s", err, stderr) - return nil, errors.E(op, err) - } - - var lr listResp - err = json.NewDecoder(stdout).Decode(&lr) - if err != nil { - return nil, errors.E(op, err) - } - - return &lr, nil -} - -func (gg *goget) GoMod(ctx context.Context, mod string, ver string) ([]byte, error) { - const op errors.Op = "goget.Info" - v, err := gg.Version(ctx, mod, ver) - if err != nil { - return nil, errors.E(op, err) - } - v.Zip.Close() - - return v.Mod, nil -} - -func (gg *goget) Zip(ctx context.Context, mod, ver string) (io.ReadCloser, error) { - const op errors.Op = "goget.Info" - v, err := gg.Version(ctx, mod, ver) - if err != nil { - return nil, errors.E(op, err) - } - - return v.Zip, nil -} - -func (gg *goget) Version(ctx context.Context, mod, ver string) (*storage.Version, error) { - const op errors.Op = "goget.Version" - ref, err := gg.fetcher.Fetch(mod, ver) - if err != nil { - return nil, errors.E(op, err) - } - v, err := ref.Read() - if err != nil { - return nil, errors.E(op, err) - } - - return v, nil -} diff --git a/pkg/download/download.go b/pkg/download/protocol.go similarity index 52% rename from pkg/download/download.go rename to pkg/download/protocol.go index 64c645aa..e535df4e 100644 --- a/pkg/download/download.go +++ b/pkg/download/protocol.go @@ -1,12 +1,20 @@ package download import ( + "bytes" "context" + "encoding/json" + "fmt" "io" + "os/exec" "time" + "github.com/gomods/athens/pkg/config" "github.com/gomods/athens/pkg/errors" + "github.com/gomods/athens/pkg/module" + "github.com/gomods/athens/pkg/stash" "github.com/gomods/athens/pkg/storage" + "github.com/spf13/afero" ) // Protocol is the download protocol which mirrors @@ -26,63 +34,119 @@ type Protocol interface { // Zip implements GET /{module}/@v/{version}.zip Zip(ctx context.Context, mod, ver string) (io.ReadCloser, error) - - // Version is a helper method to get Info, GoMod, and Zip together. 
- Version(ctx context.Context, mod, ver string) (*storage.Version, error) } -type protocol struct { - s storage.Backend - dp Protocol - ch chan *job +// Wrapper helps extend the main Protocol's functionality with addons. +type Wrapper func(Protocol) Protocol + +// Opts specifies download protocol options to avoid long func signature. +type Opts struct { + Storage storage.Backend + Stasher stash.Stasher + GoBinPath string + Fs afero.Fs } -type job struct { - mod, ver string - done chan error -} +// New returns a full implementation of the download.Protocol +// that the proxy needs. New also takes a variadic list of wrappers +// to extend the protocol's functionality (see addons package). +// The wrappers are applied in order, meaning the last wrapper +// passed is the Protocol that gets hit first. +func New(opts *Opts, wrappers ...Wrapper) Protocol { + var p Protocol = &protocol{opts.Storage, opts.Stasher, opts.GoBinPath, opts.Fs} + for _, w := range wrappers { + p = w(p) + } -// New takes an upstream Protocol and storage -// it always prefers storage, otherwise it goes to upstream -// and fills the storage with the results.
-func New(dp Protocol, s storage.Backend, workers int) Protocol { - ch := make(chan *job) - p := &protocol{dp: dp, s: s, ch: ch} - p.start(workers) return p } -func (p *protocol) start(numWorkers int) { - for i := 0; i < numWorkers; i++ { - go p.listen() - } -} - -func (p *protocol) listen() { - for j := range p.ch { - j.done <- p.fillCache(j.mod, j.ver) - } -} - -func (p *protocol) request(mod, ver string) error { - j := &job{ - mod: mod, - ver: ver, - done: make(chan error), - } - p.ch <- j - return <-j.done +type protocol struct { + s storage.Backend + stasher stash.Stasher + goBinPath string + fs afero.Fs } func (p *protocol) List(ctx context.Context, mod string) ([]string, error) { - return p.dp.List(ctx, mod) + const op errors.Op = "protocol.List" + lr, err := p.list(op, mod) + if err != nil { + return nil, err + } + + return lr.Versions, nil +} + +func (p *protocol) Latest(ctx context.Context, mod string) (*storage.RevInfo, error) { + const op errors.Op = "protocol.Latest" + lr, err := p.list(op, mod) + if err != nil { + return nil, errors.E(op, err) + } + + return &storage.RevInfo{ + Time: lr.Time, + Version: lr.Version, + }, nil +} + +type listResp struct { + Path string + Version string + Versions []string `json:",omitempty"` + Time time.Time +} + +func (p *protocol) list(op errors.Op, mod string) (*listResp, error) { + hackyPath, err := afero.TempDir(p.fs, "", "hackymod") + if err != nil { + return nil, errors.E(op, err) + } + defer p.fs.RemoveAll(hackyPath) + err = module.Dummy(p.fs, hackyPath) + if err != nil { + return nil, errors.E(op, err) + } + + cmd := exec.Command( + p.goBinPath, + "list", "-m", "-versions", "-json", + config.FmtModVer(mod, "latest"), + ) + cmd.Dir = hackyPath + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd.Stdout = stdout + cmd.Stderr = stderr + + gopath, err := afero.TempDir(p.fs, "", "athens") + if err != nil { + return nil, errors.E(op, err) + } + defer module.ClearFiles(p.fs, gopath) + cmd.Env = 
module.PrepareEnv(gopath) + + err = cmd.Run() + if err != nil { + err = fmt.Errorf("%v: %s", err, stderr) + return nil, errors.E(op, err) + } + + var lr listResp + err = json.NewDecoder(stdout).Decode(&lr) + if err != nil { + return nil, errors.E(op, err) + } + + return &lr, nil } func (p *protocol) Info(ctx context.Context, mod, ver string) ([]byte, error) { const op errors.Op = "protocol.Info" info, err := p.s.Info(ctx, mod, ver) if errors.IsNotFoundErr(err) { - err = p.request(mod, ver) + err = p.stasher.Stash(mod, ver) if err != nil { return nil, errors.E(op, err) } @@ -95,38 +159,11 @@ func (p *protocol) Info(ctx context.Context, mod, ver string) ([]byte, error) { return info, nil } -func (p *protocol) fillCache(mod, ver string) error { - const op errors.Op = "protocol.fillCache" - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) - defer cancel() - v, err := p.dp.Version(ctx, mod, ver) - if err != nil { - return errors.E(op, err) - } - defer v.Zip.Close() - err = p.s.Save(ctx, mod, ver, v.Mod, v.Zip, v.Info) - if err != nil { - return errors.E(op, err) - } - - return nil -} - -func (p *protocol) Latest(ctx context.Context, mod string) (*storage.RevInfo, error) { - const op errors.Op = "protocol.Latest" - info, err := p.dp.Latest(ctx, mod) - if err != nil { - return nil, errors.E(op, err) - } - - return info, nil -} - func (p *protocol) GoMod(ctx context.Context, mod, ver string) ([]byte, error) { const op errors.Op = "protocol.GoMod" goMod, err := p.s.GoMod(ctx, mod, ver) if errors.IsNotFoundErr(err) { - err = p.request(mod, ver) + err = p.stasher.Stash(mod, ver) if err != nil { return nil, errors.E(op, err) } @@ -143,7 +180,7 @@ func (p *protocol) Zip(ctx context.Context, mod, ver string) (io.ReadCloser, err const op errors.Op = "protocol.Zip" zip, err := p.s.Zip(ctx, mod, ver) if errors.IsNotFoundErr(err) { - err = p.request(mod, ver) + err = p.stasher.Stash(mod, ver) if err != nil { return nil, errors.E(op, err) } @@ -155,27 +192,3 
@@ func (p *protocol) Zip(ctx context.Context, mod, ver string) (io.ReadCloser, err return zip, nil } - -func (p *protocol) Version(ctx context.Context, mod, ver string) (*storage.Version, error) { - const op errors.Op = "protocol.Version" - info, err := p.Info(ctx, mod, ver) - if err != nil { - return nil, errors.E(op, err) - } - - goMod, err := p.GoMod(ctx, mod, ver) - if err != nil { - return nil, errors.E(op, err) - } - - zip, err := p.s.Zip(ctx, mod, ver) - if err != nil { - return nil, errors.E(op, err) - } - - return &storage.Version{ - Info: info, - Mod: goMod, - Zip: zip, - }, nil -} diff --git a/pkg/download/goget/goget_test.go b/pkg/download/protocol_test.go similarity index 71% rename from pkg/download/goget/goget_test.go rename to pkg/download/protocol_test.go index 589de42e..f3d2335e 100644 --- a/pkg/download/goget/goget_test.go +++ b/pkg/download/protocol_test.go @@ -1,4 +1,4 @@ -package goget +package download import ( "bytes" @@ -11,10 +11,32 @@ import ( "testing" "time" + "github.com/gomods/athens/pkg/config/env" + "github.com/gomods/athens/pkg/module" + "github.com/gomods/athens/pkg/stash" "github.com/gomods/athens/pkg/storage" + "github.com/gomods/athens/pkg/storage/mem" + "github.com/spf13/afero" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) +func getDP(t *testing.T) Protocol { + t.Helper() + goBin := env.GoBinPath() + fs := afero.NewOsFs() + mf, err := module.NewGoGetFetcher(goBin, fs) + if err != nil { + t.Fatal(err) + } + s, err := mem.NewStorage() + if err != nil { + t.Fatal(err) + } + st := stash.New(mf, s) + return New(&Opts{s, st, goBin, fs}) +} + type listTest struct { name string path string @@ -34,8 +56,7 @@ var listTests = []listTest{ } func TestList(t *testing.T) { - dp, err := New() - require.NoError(t, err, "failed to create protocol") + dp := getDP(t) ctx := context.Background() for _, tc := range listTests { @@ -48,8 +69,7 @@ func TestList(t *testing.T) { } func TestConcurrentLists(t *testing.T) { - dp, 
err := New() - require.NoError(t, err, "failed to create protocol") + dp := getDP(t) ctx := context.Background() pkg := "github.com/athens-artifacts/samplelib" @@ -106,8 +126,7 @@ var latestTests = []latestTest{ } func TestLatest(t *testing.T) { - dp, err := New() - require.NoError(t, err) + dp := getDP(t) ctx := context.Background() for _, tc := range latestTests { @@ -153,8 +172,7 @@ var infoTests = []infoTest{ } func TestInfo(t *testing.T) { - dp, err := New() - require.NoError(t, err) + dp := getDP(t) ctx := context.Background() for _, tc := range infoTests { @@ -200,8 +218,7 @@ var modTests = []modTest{ } func TestGoMod(t *testing.T) { - dp, err := New() - require.NoError(t, err) + dp := getDP(t) ctx := context.Background() for _, tc := range modTests { @@ -228,3 +245,59 @@ func getGoldenFile(t *testing.T, name string) []byte { return bts } + +type testMod struct { + mod, ver string +} + +var mods = []testMod{ + {"github.com/athens-artifacts/no-tags", "v0.0.2"}, + {"github.com/athens-artifacts/happy-path", "v0.0.0-20180803035119-e4e0177efdb5"}, + {"github.com/athens-artifacts/samplelib", "v1.0.0"}, +} + +func TestDownloadProtocol(t *testing.T) { + s, err := mem.NewStorage() + if err != nil { + t.Fatal(err) + } + mp := &mockFetcher{} + st := stash.New(mp, s) + dp := New(&Opts{s, st, "", afero.NewMemMapFs()}) + ctx := context.Background() + + var eg errgroup.Group + for i := 0; i < len(mods); i++ { + m := mods[i] + eg.Go(func() error { + _, err := dp.GoMod(ctx, m.mod, m.ver) + return err + }) + } + + err = eg.Wait() + if err != nil { + t.Fatal(err) + } + + for _, m := range mods { + bts, err := dp.GoMod(ctx, m.mod, m.ver) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(bts, []byte(m.mod+"@"+m.ver)) { + t.Fatalf("unexpected gomod content: %s", bts) + } + } +} + +type mockFetcher struct{} + +func (m *mockFetcher) Fetch(ctx context.Context, mod, ver string) (*storage.Version, error) { + bts := []byte(mod + "@" + ver) + return &storage.Version{ + Mod: bts, + 
Info: bts, + Zip: ioutil.NopCloser(bytes.NewReader(bts)), + }, nil +} diff --git a/pkg/download/goget/test_data/no_mod_file.golden b/pkg/download/test_data/no_mod_file.golden similarity index 100% rename from pkg/download/goget/test_data/no_mod_file.golden rename to pkg/download/test_data/no_mod_file.golden diff --git a/pkg/download/goget/test_data/upstream_mod_file.golden b/pkg/download/test_data/upstream_mod_file.golden similarity index 100% rename from pkg/download/goget/test_data/upstream_mod_file.golden rename to pkg/download/test_data/upstream_mod_file.golden diff --git a/pkg/module/fetcher.go b/pkg/module/fetcher.go index 1135b99c..1f76bdfe 100644 --- a/pkg/module/fetcher.go +++ b/pkg/module/fetcher.go @@ -1,10 +1,14 @@ package module +import ( + "context" + + "github.com/gomods/athens/pkg/storage" +) + // Fetcher fetches module from an upstream source type Fetcher interface { - // Fetch fetches the module and puts it somewhere addressable by ModuleRef. - // returns a non-nil error on failure. - // - // The caller should call moduleRef.Clear() after they're done with the module - Fetch(mod, ver string) (Ref, error) + // Fetch downloads the sources from an upstream and returns the corresponding + // .info, .mod, and .zip files. + Fetch(ctx context.Context, mod, ver string) (*storage.Version, error) } diff --git a/pkg/module/go_get_fetcher.go b/pkg/module/go_get_fetcher.go index c0a03c4f..e58fd29a 100644 --- a/pkg/module/go_get_fetcher.go +++ b/pkg/module/go_get_fetcher.go @@ -2,6 +2,7 @@ package module import ( "bytes" + "context" "fmt" "os" "os/exec" @@ -11,6 +12,7 @@ import ( "github.com/gomods/athens/pkg/errors" "github.com/gomods/athens/pkg/paths" + "github.com/gomods/athens/pkg/storage" "github.com/spf13/afero" ) @@ -31,9 +33,9 @@ func NewGoGetFetcher(goBinaryName string, fs afero.Fs) (Fetcher, error) { }, nil } -// Fetch downloads the sources and returns path where it can be found. 
Make sure to call Clear -// on the returned Ref when you are done with it -func (g *goGetFetcher) Fetch(mod, ver string) (Ref, error) { +// Fetch downloads the sources from the go binary and returns the corresponding +// .info, .mod, and .zip files. +func (g *goGetFetcher) Fetch(ctx context.Context, mod, ver string) (*storage.Version, error) { const op errors.Op = "goGetFetcher.Fetch" // setup the GOPATH goPathRoot, err := afero.TempDir(g.fs, "", "athens") @@ -59,7 +61,8 @@ func (g *goGetFetcher) Fetch(mod, ver string) (Ref, error) { return nil, errors.E(op, err) } - return newDiskRef(g.fs, goPathRoot, mod, ver), nil + dr := newDiskRef(g.fs, goPathRoot, mod, ver) + return dr.Read() } // Dummy Hacky thing makes vgo not to complain diff --git a/pkg/module/go_get_fetcher_test.go b/pkg/module/go_get_fetcher_test.go index 7989e201..7dcfbdcc 100644 --- a/pkg/module/go_get_fetcher_test.go +++ b/pkg/module/go_get_fetcher_test.go @@ -1,16 +1,18 @@ package module import ( + "context" "fmt" "io/ioutil" "log" "github.com/gomods/athens/pkg/config/env" - "github.com/stretchr/testify/assert" - "github.com/spf13/afero" + "github.com/stretchr/testify/assert" ) +var ctx = context.Background() + func (s *ModuleSuite) TestNewGoGetFetcher() { r := s.Require() fetcher, err := NewGoGetFetcher(s.goBinaryName, s.fs) @@ -32,9 +34,7 @@ func (s *ModuleSuite) TestGoGetFetcherFetch() { // always writes to the filesystem fetcher, err := NewGoGetFetcher(s.goBinaryName, afero.NewOsFs()) r.NoError(err) - ref, err := fetcher.Fetch(repoURI, version) - r.NoError(err, "fetch shouldn't error") - ver, err := ref.Read() + ver, err := fetcher.Fetch(ctx, repoURI, version) r.NoError(err) defer ver.Zip.Close() @@ -48,9 +48,6 @@ func (s *ModuleSuite) TestGoGetFetcherFetch() { // close the version's zip file (which also cleans up the underlying diskref's GOPATH) and expect it to fail again r.NoError(ver.Zip.Close()) - ver, err = ref.Read() - r.NotNil(err) - r.Nil(ver) } func ExampleFetcher() { @@ -61,12 +58,11 
@@ func ExampleFetcher() { if err != nil { log.Fatal(err) } - ref, err := fetcher.Fetch(repoURI, version) + versionData, err := fetcher.Fetch(ctx, repoURI, version) // handle errors if any if err != nil { return } - versionData, err := ref.Read() // Close the handle to versionData.Zip once done // This will also handle cleanup so it's important to call Close defer versionData.Zip.Close() diff --git a/pkg/module/noop_ref.go b/pkg/module/noop_ref.go deleted file mode 100644 index 1bd2cbce..00000000 --- a/pkg/module/noop_ref.go +++ /dev/null @@ -1,13 +0,0 @@ -package module - -import ( - "fmt" - - "github.com/gomods/athens/pkg/storage" -) - -type noopRef struct{} - -func (n noopRef) Read() (*storage.Version, error) { - return nil, fmt.Errorf("noop ref doesn't have a storage.Version") -} diff --git a/pkg/module/ref.go b/pkg/module/ref.go deleted file mode 100644 index 5630df9f..00000000 --- a/pkg/module/ref.go +++ /dev/null @@ -1,13 +0,0 @@ -package module - -import ( - "github.com/gomods/athens/pkg/storage" -) - -// Ref points to a module somewhere -type Ref interface { - // Read reads the module into memory and returns it. Notice that the Zip field on the returned - // storage.Version is an io.ReadCloser, so make sure to call Close on it after you're done - // with it. - Read() (*storage.Version, error) -} diff --git a/pkg/stash/stasher.go b/pkg/stash/stasher.go new file mode 100644 index 00000000..ab655add --- /dev/null +++ b/pkg/stash/stasher.go @@ -0,0 +1,62 @@ +package stash + +import ( + "context" + "time" + + "github.com/gomods/athens/pkg/errors" + "github.com/gomods/athens/pkg/module" + "github.com/gomods/athens/pkg/storage" +) + +// Stasher has the job of taking a module +// from an upstream entity and stashing it to a Storage Backend. +type Stasher interface { + Stash(string, string) error +} + +// Wrapper helps extend the main stasher's functionality with addons. 
+type Wrapper func(Stasher) Stasher + +// New returns a plain stasher that takes +// a module from a module.Fetcher and +// stashes it into a storage.Backend. +func New(f module.Fetcher, s storage.Backend, wrappers ...Wrapper) Stasher { + var st Stasher = &stasher{f, s} + for _, w := range wrappers { + st = w(st) + } + + return st +} + +type stasher struct { + f module.Fetcher + s storage.Backend +} + +func (s *stasher) Stash(mod, ver string) error { + const op errors.Op = "stasher.Stash" + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) + defer cancel() + v, err := s.fetchModule(ctx, mod, ver) + if err != nil { + return errors.E(op, err) + } + defer v.Zip.Close() + err = s.s.Save(ctx, mod, ver, v.Mod, v.Zip, v.Info) + if err != nil { + return errors.E(op, err) + } + return nil +} + +func (s *stasher) fetchModule(ctx context.Context, mod, ver string) (*storage.Version, error) { + const op errors.Op = "stasher.fetchModule" + v, err := s.f.Fetch(ctx, mod, ver) + if err != nil { + return nil, errors.E(op, err) + } + + return v, nil +} diff --git a/pkg/stash/with_pool.go b/pkg/stash/with_pool.go new file mode 100644 index 00000000..ba4e2ba0 --- /dev/null +++ b/pkg/stash/with_pool.go @@ -0,0 +1,54 @@ +package stash + +import ( + "github.com/gomods/athens/pkg/errors" +) + +type withpool struct { + s Stasher + + // see download/addons/with_pool + // for design docs about this channel. + jobCh chan func() +} + +// WithPool returns a stasher that runs a stash operation +// {numWorkers} at a time.
+func WithPool(numWorkers int) Wrapper { + return func(s Stasher) Stasher { + st := &withpool{ + s: s, + jobCh: make(chan func()), + } + st.start(numWorkers) + return st + } +} + +func (s *withpool) start(numWorkers int) { + for i := 0; i < numWorkers; i++ { + go s.listen() + } +} + +func (s *withpool) listen() { + for f := range s.jobCh { + f() + } +} + +func (s *withpool) Stash(mod, ver string) error { + const op errors.Op = "stash.Pool" + var err error + done := make(chan struct{}, 1) + s.jobCh <- func() { + err = s.s.Stash(mod, ver) + close(done) + } + <-done + if err != nil { + return errors.E(op, err) + } + + return nil +} diff --git a/pkg/stash/with_pool_test.go b/pkg/stash/with_pool_test.go new file mode 100644 index 00000000..676aab80 --- /dev/null +++ b/pkg/stash/with_pool_test.go @@ -0,0 +1,31 @@ +package stash + +import ( + "fmt" + "testing" +) + +func TestPoolWrapper(t *testing.T) { + m := &mockStasher{inputMod: "mod", inputVer: "ver", err: fmt.Errorf("wrapped err")} + s := WithPool(2)(m) + err := s.Stash(m.inputMod, m.inputVer) + if err.Error() != m.err.Error() { + t.Fatalf("expected err to be `%v` but got `%v`", m.err, err) + } +} + +type mockStasher struct { + inputMod string + inputVer string + err error +} + +func (m *mockStasher) Stash(mod, ver string) error { + if m.inputMod != mod { + return fmt.Errorf("expected input mod %v but got %v", m.inputMod, mod) + } + if m.inputVer != ver { + return fmt.Errorf("expected input ver %v but got %v", m.inputVer, ver) + } + return m.err +} diff --git a/pkg/stash/with_singleflight.go b/pkg/stash/with_singleflight.go new file mode 100644 index 00000000..347a5420 --- /dev/null +++ b/pkg/stash/with_singleflight.go @@ -0,0 +1,54 @@ +package stash + +import ( + "sync" + + "github.com/gomods/athens/pkg/config" +) + +// WithSingleflight returns a singleflight stasher. 
+// If two clients make two overlapping +// requests to stash a module, then +// it will only do it once and give the first +// response to both the first and the second client. +func WithSingleflight(s Stasher) Stasher { + sf := &withsf{} + sf.s = s + sf.subs = map[string][]chan error{} + + return sf +} + +type withsf struct { + s Stasher + + mu sync.Mutex + subs map[string][]chan error +} + +func (s *withsf) process(mod, ver string) { + mv := config.FmtModVer(mod, ver) + err := s.s.Stash(mod, ver) + s.mu.Lock() + defer s.mu.Unlock() + for _, ch := range s.subs[mv] { + ch <- err + } + delete(s.subs, mv) +} + +func (s *withsf) Stash(mod, ver string) error { + mv := config.FmtModVer(mod, ver) + s.mu.Lock() + subCh := make(chan error, 1) + _, inFlight := s.subs[mv] + if !inFlight { + s.subs[mv] = []chan error{subCh} + go s.process(mod, ver) + } else { + s.subs[mv] = append(s.subs[mv], subCh) + } + s.mu.Unlock() + + return <-subCh +} diff --git a/pkg/stash/with_singleflight_test.go b/pkg/stash/with_singleflight_test.go new file mode 100644 index 00000000..a7d54690 --- /dev/null +++ b/pkg/stash/with_singleflight_test.go @@ -0,0 +1,61 @@ +package stash + +import ( + "fmt" + "sync" + "testing" + "time" + + "golang.org/x/sync/errgroup" +) + +// TestSingleFlight will ensure that 5 concurrent requests will all get the first request's +// response. We can ensure that because only the first response does not return an error +// and therefore all 5 responses should have no error.
+func TestSingleFlight(t *testing.T) { + ms := &mockSFStasher{} + s := WithSingleflight(ms) + + var eg errgroup.Group + for i := 0; i < 5; i++ { + eg.Go(func() error { + return s.Stash("mod", "ver") + }) + } + + err := eg.Wait() + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 5; i++ { + eg.Go(func() error { + return s.Stash("mod", "ver") + }) + } + err = eg.Wait() + if err == nil { + t.Fatal("expected second error to return") + } +} + +// mockSFStasher mocks a Stash request that +// will always return a different result after the +// first one. This way we can prove that a second +// request did not get a second result, but the first +// one, provided the request came in at the right time. +type mockSFStasher struct { + mu sync.Mutex + num int +} + +func (ms *mockSFStasher) Stash(mod, ver string) error { + time.Sleep(time.Millisecond * 100) // allow for second requests to come in. + ms.mu.Lock() + defer ms.mu.Unlock() + if ms.num == 0 { + ms.num++ + return nil + } + return fmt.Errorf("second time error") +}