diff --git a/.drone.yml b/.drone.yml index 323a6a18..166ea448 100644 --- a/.drone.yml +++ b/.drone.yml @@ -59,6 +59,8 @@ steps: from_secret: ATHENS_AZURE_ACCOUNT_NAME ATHENS_AZURE_ACCOUNT_KEY: from_secret: ATHENS_AZURE_ACCOUNT_KEY + PROPAGATE_AUTH_TEST_TOKEN: + from_secret: PROPAGATE_AUTH_TEST_TOKEN when: branch: - main diff --git a/cmd/proxy/actions/app.go b/cmd/proxy/actions/app.go index a0c0de8a..af423a97 100644 --- a/cmd/proxy/actions/app.go +++ b/cmd/proxy/actions/app.go @@ -51,14 +51,17 @@ func App(conf *config.Config) (http.Handler, error) { lggr := log.New(conf.CloudRuntime, logLvl) r := mux.NewRouter() - r.Use(mw.WithRequestID) - r.Use(mw.LogEntryMiddleware(lggr)) - r.Use(mw.RequestLogger) - r.Use(secure.New(secure.Options{ - SSLRedirect: conf.ForceSSL, - SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, - }).Handler) - r.Use(mw.ContentType) + r.Use( + mw.WithRequestID, + mw.LogEntryMiddleware(lggr), + mw.RequestLogger, + secure.New(secure.Options{ + SSLRedirect: conf.ForceSSL, + SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, + }).Handler, + mw.ContentType, + mw.WithAuth, + ) var subRouter *mux.Router if prefix := conf.PathPrefix; prefix != "" { diff --git a/cmd/proxy/actions/app_proxy.go b/cmd/proxy/actions/app_proxy.go index 99e8fecb..b3dd3450 100644 --- a/cmd/proxy/actions/app_proxy.go +++ b/cmd/proxy/actions/app_proxy.go @@ -96,12 +96,12 @@ func addProxyRoutes( if err := c.GoBinaryEnvVars.Validate(); err != nil { return err } - mf, err := module.NewGoGetFetcher(c.GoBinary, c.GoGetDir, c.GoBinaryEnvVars, fs) + mf, err := module.NewGoGetFetcher(c.GoBinary, c.GoGetDir, c.GoBinaryEnvVars, fs, c.PropagateAuthHost) if err != nil { return err } - lister := module.NewVCSLister(c.GoBinary, c.GoBinaryEnvVars, fs) + lister := module.NewVCSLister(c.GoBinary, c.GoBinaryEnvVars, fs, c.PropagateAuthHost) checker := storage.WithChecker(s) withSingleFlight, err := getSingleFlight(c, checker) if err != nil { diff --git a/cmd/proxy/actions/auth.go b/cmd/proxy/actions/auth.go index 767dea1a..3e75a12f 100644 --- a/cmd/proxy/actions/auth.go +++ b/cmd/proxy/actions/auth.go @@ -44,7 +44,7 @@ func netrcFromToken(tok string) { if err != nil { log.Fatalf("netrcFromToken: could not get homedir: %v", err) } - rcp := filepath.Join(hdir, getNetrcFileName()) + rcp := filepath.Join(hdir, getNETRCFilename()) if err := ioutil.WriteFile(rcp, []byte(fileContent), 0600); err != nil { log.Fatalf("netrcFromToken: could not write to file: %v", err) } @@ -52,12 +52,12 @@ func netrcFromToken(tok string) { func transformAuthFileName(authFileName string) string { if root := strings.TrimLeft(authFileName, "._"); root == "netrc" { - return getNetrcFileName() + return getNETRCFilename() } return authFileName } -func getNetrcFileName() string { +func getNETRCFilename() string { if runtime.GOOS == "windows" { return "_netrc" } diff --git a/config.dev.toml b/config.dev.toml index 426f0b88..ba34b3d5 100755 --- a/config.dev.toml +++ b/config.dev.toml @@ -166,6 +166,30 @@ BasicAuthUser = "" # Env override: BASIC_AUTH_PASS BasicAuthPass = "" +# PropagateAuthHost, when set to a hostname such as "github.com", will pass the Basic Authentication +# Headers to the "go mod download" operations. This will allow a user +# to pass their credentials for a private repository and have Athens be +# able to download and store it. Note that, once a private repository is stored, +# Athens will naively serve it to anyone who requests it. +# +# Therefore, it is **important** that you +# make sure you have a ValidatorHook or put Athens behind an auth proxy that always +# ensures access to modules are securely authorized. +# +# Note that "go mod download" uses "git clone" which will look for these credentials +# in the $HOME directory of the process. Therefore, turning this feature on means that each +# "go mod download" will have its own $HOME direcotry with only the .netrc file. If +# your "go mod download" relies on your global $HOME directory (such as .gitconfig), then +# you must turn this feature off. If you'd like to specify files to be copied from the global +# $HOME directory to the temporary one, please open an issue at https://github.com/gomods/athens +# to gauge demand for such a feature before implementing. +# +# You must also specify the import path host using PropagateAuthHost so that the .netrc file knows +# when to forward the credentials and when not to. +# +# Env override: ATHENS_PROPAGATE_AUTH_HOST +PropagateAuthHost = "" + # Set to true to force an SSL redirect # Env override: PROXY_FORCE_SSL ForceSSL = false diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 00000000..e69b32cf --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,64 @@ +package auth + +import ( + "context" + "fmt" + "io/ioutil" + "path/filepath" + "runtime" + + "github.com/gomods/athens/pkg/errors" +) + +type authkey struct{} + +// BasicAuth is the embedded credentials in a context +type BasicAuth struct { + User, Password string +} + +// SetAuthInContext sets the auth value in context +func SetAuthInContext(ctx context.Context, auth BasicAuth) context.Context { + return context.WithValue(ctx, authkey{}, auth) +} + +// FromContext retrieves the auth value +func FromContext(ctx context.Context) (BasicAuth, bool) { + auth, ok := ctx.Value(authkey{}).(BasicAuth) + return auth, ok +} + +// WriteNETRC writes the netrc file to the specified directory +func WriteNETRC(path, host, user, password string) error { + const op errors.Op = "auth.WriteNETRC" + fileContent := fmt.Sprintf("machine %s login %s password %s\n", host, user, password) + if err := ioutil.WriteFile(path, []byte(fileContent), 0600); err != nil { + return errors.E(op, fmt.Errorf("netrcFromToken: could not write to file: %v", err)) + } + return nil +} + +// WriteTemporaryNETRC writes a netrc file to a temporary directory, returning +// the directory it was written to. +func WriteTemporaryNETRC(host, user, password string) (string, error) { + const op errors.Op = "auth.WriteTemporaryNETRC" + dir, err := ioutil.TempDir("", "netrcp") + if err != nil { + return "", errors.E(op, err) + } + rcp := filepath.Join(dir, GetNETRCFilename()) + err = WriteNETRC(rcp, host, user, password) + if err != nil { + return "", errors.E(op, err) + } + return dir, nil +} + +// GetNETRCFilename returns the name of the netrc file +// according to the contextual platform +func GetNETRCFilename() string { + if runtime.GOOS == "windows" { + return "_netrc" + } + return ".netrc" +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 656ff597..7f05cb58 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -21,44 +21,45 @@ const defaultConfigFile = "athens.toml" // Config provides configuration values for all components type Config struct { TimeoutConf - GoEnv string `validate:"required" envconfig:"GO_ENV"` - GoBinary string `validate:"required" envconfig:"GO_BINARY_PATH"` - GoProxy string `envconfig:"GOPROXY"` - GoBinaryEnvVars EnvList `envconfig:"ATHENS_GO_BINARY_ENV_VARS"` - GoGetWorkers int `validate:"required" envconfig:"ATHENS_GOGET_WORKERS"` - GoGetDir string `envconfig:"ATHENS_GOGOET_DIR"` - ProtocolWorkers int `validate:"required" envconfig:"ATHENS_PROTOCOL_WORKERS"` - LogLevel string `validate:"required" envconfig:"ATHENS_LOG_LEVEL"` - CloudRuntime string `validate:"required" envconfig:"ATHENS_CLOUD_RUNTIME"` - EnablePprof bool `envconfig:"ATHENS_ENABLE_PPROF"` - PprofPort string `envconfig:"ATHENS_PPROF_PORT"` - FilterFile string `envconfig:"ATHENS_FILTER_FILE"` - TraceExporterURL string `envconfig:"ATHENS_TRACE_EXPORTER_URL"` - TraceExporter string `envconfig:"ATHENS_TRACE_EXPORTER"` - StatsExporter string `envconfig:"ATHENS_STATS_EXPORTER"` - StorageType string `validate:"required" envconfig:"ATHENS_STORAGE_TYPE"` - GlobalEndpoint string `envconfig:"ATHENS_GLOBAL_ENDPOINT"` // This feature is not yet implemented - Port string `envconfig:"ATHENS_PORT"` - BasicAuthUser string `envconfig:"BASIC_AUTH_USER"` - BasicAuthPass string `envconfig:"BASIC_AUTH_PASS"` - ForceSSL bool `envconfig:"PROXY_FORCE_SSL"` - ValidatorHook string `envconfig:"ATHENS_PROXY_VALIDATOR"` - PathPrefix string `envconfig:"ATHENS_PATH_PREFIX"` - NETRCPath string `envconfig:"ATHENS_NETRC_PATH"` - GithubToken string `envconfig:"ATHENS_GITHUB_TOKEN"` - HGRCPath string `envconfig:"ATHENS_HGRC_PATH"` - TLSCertFile string `envconfig:"ATHENS_TLSCERT_FILE"` - TLSKeyFile string `envconfig:"ATHENS_TLSKEY_FILE"` - SumDBs []string `envconfig:"ATHENS_SUM_DBS"` - NoSumPatterns []string `envconfig:"ATHENS_GONOSUM_PATTERNS"` - DownloadMode mode.Mode `envconfig:"ATHENS_DOWNLOAD_MODE"` - DownloadURL string `envconfig:"ATHENS_DOWNLOAD_URL"` - SingleFlightType string `envconfig:"ATHENS_SINGLE_FLIGHT_TYPE"` - RobotsFile string `envconfig:"ATHENS_ROBOTS_FILE"` - IndexType string `envconfig:"ATHENS_INDEX_TYPE"` - SingleFlight *SingleFlight - Storage *Storage - Index *Index + GoEnv string `validate:"required" envconfig:"GO_ENV"` + GoBinary string `validate:"required" envconfig:"GO_BINARY_PATH"` + GoProxy string `envconfig:"GOPROXY"` + GoBinaryEnvVars EnvList `envconfig:"ATHENS_GO_BINARY_ENV_VARS"` + GoGetWorkers int `validate:"required" envconfig:"ATHENS_GOGET_WORKERS"` + GoGetDir string `envconfig:"ATHENS_GOGOET_DIR"` + ProtocolWorkers int `validate:"required" envconfig:"ATHENS_PROTOCOL_WORKERS"` + LogLevel string `validate:"required" envconfig:"ATHENS_LOG_LEVEL"` + CloudRuntime string `validate:"required" envconfig:"ATHENS_CLOUD_RUNTIME"` + EnablePprof bool `envconfig:"ATHENS_ENABLE_PPROF"` + PprofPort string `envconfig:"ATHENS_PPROF_PORT"` + FilterFile string `envconfig:"ATHENS_FILTER_FILE"` + TraceExporterURL string `envconfig:"ATHENS_TRACE_EXPORTER_URL"` + TraceExporter string `envconfig:"ATHENS_TRACE_EXPORTER"` + StatsExporter string `envconfig:"ATHENS_STATS_EXPORTER"` + StorageType string `validate:"required" envconfig:"ATHENS_STORAGE_TYPE"` + GlobalEndpoint string `envconfig:"ATHENS_GLOBAL_ENDPOINT"` // This feature is not yet implemented + Port string `envconfig:"ATHENS_PORT"` + BasicAuthUser string `envconfig:"BASIC_AUTH_USER"` + BasicAuthPass string `envconfig:"BASIC_AUTH_PASS"` + PropagateAuthHost string `envconfig:"ATHENS_PROPAGATE_AUTH_HOST"` + ForceSSL bool `envconfig:"PROXY_FORCE_SSL"` + ValidatorHook string `envconfig:"ATHENS_PROXY_VALIDATOR"` + PathPrefix string `envconfig:"ATHENS_PATH_PREFIX"` + NETRCPath string `envconfig:"ATHENS_NETRC_PATH"` + GithubToken string `envconfig:"ATHENS_GITHUB_TOKEN"` + HGRCPath string `envconfig:"ATHENS_HGRC_PATH"` + TLSCertFile string `envconfig:"ATHENS_TLSCERT_FILE"` + TLSKeyFile string `envconfig:"ATHENS_TLSKEY_FILE"` + SumDBs []string `envconfig:"ATHENS_SUM_DBS"` + NoSumPatterns []string `envconfig:"ATHENS_GONOSUM_PATTERNS"` + DownloadMode mode.Mode `envconfig:"ATHENS_DOWNLOAD_MODE"` + DownloadURL string `envconfig:"ATHENS_DOWNLOAD_URL"` + SingleFlightType string `envconfig:"ATHENS_SINGLE_FLIGHT_TYPE"` + RobotsFile string `envconfig:"ATHENS_ROBOTS_FILE"` + IndexType string `envconfig:"ATHENS_INDEX_TYPE"` + SingleFlight *SingleFlight + Storage *Storage + Index *Index } // EnvList is a list of key-value environment @@ -142,29 +143,30 @@ func Load(configFile string) (*Config, error) { func defaultConfig() *Config { return &Config{ - GoBinary: "go", - GoBinaryEnvVars: EnvList{"GOPROXY=direct"}, - GoEnv: "development", - GoProxy: "direct", - GoGetWorkers: 10, - ProtocolWorkers: 30, - LogLevel: "debug", - CloudRuntime: "none", - EnablePprof: false, - PprofPort: ":3001", - StatsExporter: "prometheus", - TimeoutConf: TimeoutConf{Timeout: 300}, - StorageType: "memory", - Port: ":3000", - SingleFlightType: "memory", - GlobalEndpoint: "http://localhost:3001", - TraceExporterURL: "http://localhost:14268", - SumDBs: []string{"https://sum.golang.org"}, - NoSumPatterns: []string{}, - DownloadMode: "sync", - DownloadURL: "", - RobotsFile: "robots.txt", - IndexType: "none", + GoBinary: "go", + GoBinaryEnvVars: EnvList{"GOPROXY=direct"}, + GoEnv: "development", + GoProxy: "direct", + GoGetWorkers: 10, + ProtocolWorkers: 30, + LogLevel: "debug", + CloudRuntime: "none", + EnablePprof: false, + PprofPort: ":3001", + StatsExporter: "prometheus", + TimeoutConf: TimeoutConf{Timeout: 300}, + StorageType: "memory", + Port: ":3000", + PropagateAuthHost: "", + SingleFlightType: "memory", + GlobalEndpoint: "http://localhost:3001", + TraceExporterURL: "http://localhost:14268", + SumDBs: []string{"https://sum.golang.org"}, + NoSumPatterns: []string{}, + DownloadMode: "sync", + DownloadURL: "", + RobotsFile: "robots.txt", + IndexType: "none", SingleFlight: &SingleFlight{ Etcd: &Etcd{"localhost:2379,localhost:22379,localhost:32379"}, Redis: &Redis{"127.0.0.1:6379", ""}, diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index c8d6b2aa..be9580e2 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -274,26 +274,27 @@ func TestParseExampleConfig(t *testing.T) { TimeoutConf: TimeoutConf{ Timeout: 300, }, - StorageType: "memory", - GlobalEndpoint: "http://localhost:3001", - Port: ":3000", - EnablePprof: false, - PprofPort: ":3001", - BasicAuthUser: "", - BasicAuthPass: "", - Storage: expStorage, - TraceExporterURL: "http://localhost:14268", - TraceExporter: "", - StatsExporter: "prometheus", - SingleFlightType: "memory", - GoBinaryEnvVars: []string{"GOPROXY=direct"}, - SingleFlight: &SingleFlight{}, - SumDBs: []string{"https://sum.golang.org"}, - NoSumPatterns: []string{}, - DownloadMode: "sync", - RobotsFile: "robots.txt", - IndexType: "none", - Index: &Index{}, + StorageType: "memory", + GlobalEndpoint: "http://localhost:3001", + Port: ":3000", + PropagateAuthHost: "", + EnablePprof: false, + PprofPort: ":3001", + BasicAuthUser: "", + BasicAuthPass: "", + Storage: expStorage, + TraceExporterURL: "http://localhost:14268", + TraceExporter: "", + StatsExporter: "prometheus", + SingleFlightType: "memory", + GoBinaryEnvVars: []string{"GOPROXY=direct"}, + SingleFlight: &SingleFlight{}, + SumDBs: []string{"https://sum.golang.org"}, + NoSumPatterns: []string{}, + DownloadMode: "sync", + RobotsFile: "robots.txt", + IndexType: "none", + Index: &Index{}, } absPath, err := filepath.Abs(testConfigFile(t)) diff --git a/pkg/download/protocol_test.go b/pkg/download/protocol_test.go index 5703f6b8..eea12f78 100644 --- a/pkg/download/protocol_test.go +++ b/pkg/download/protocol_test.go @@ -37,7 +37,7 @@ func getDP(t *testing.T) Protocol { } goBin := conf.GoBinary fs := afero.NewOsFs() - mf, err := module.NewGoGetFetcher(goBin, conf.GoGetDir, conf.GoBinaryEnvVars, fs) + mf, err := module.NewGoGetFetcher(goBin, conf.GoGetDir, conf.GoBinaryEnvVars, fs, "") if err != nil { t.Fatal(err) } @@ -46,7 +46,7 @@ func getDP(t *testing.T) Protocol { t.Fatal(err) } st := stash.New(mf, s, nop.New()) - return New(&Opts{s, st, module.NewVCSLister(goBin, conf.GoBinaryEnvVars, fs), nil}) + return New(&Opts{s, st, module.NewVCSLister(goBin, conf.GoBinaryEnvVars, fs, ""), nil}) } type listTest struct { diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go new file mode 100644 index 00000000..af1eca26 --- /dev/null +++ b/pkg/middleware/auth.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net/http" + + "github.com/gomods/athens/pkg/auth" +) + +type authkey struct{} + +// WithAuth inserts the Authorization header +// into the request context +func WithAuth(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, password, ok := r.BasicAuth() + if ok { + ctx := auth.SetAuthInContext(r.Context(), auth.BasicAuth{User: user, Password: password}) + r = r.WithContext(ctx) + } + h.ServeHTTP(w, r) + }) +} diff --git a/pkg/middleware/auth_test.go b/pkg/middleware/auth_test.go new file mode 100644 index 00000000..c3147849 --- /dev/null +++ b/pkg/middleware/auth_test.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gomods/athens/pkg/auth" +) + +func TestAuthMiddleware(t *testing.T) { + var tests = []struct { + name string + reqfunc func(r *http.Request) + wantok bool + wantauth auth.BasicAuth + }{ + { + name: "no auth", + reqfunc: func(r *http.Request) {}, + }, + { + name: "with basic auth", + reqfunc: func(r *http.Request) { + r.SetBasicAuth("user", "pass") + }, + wantok: true, + wantauth: auth.BasicAuth{User: "user", Password: "pass"}, + }, + { + name: "only user", + reqfunc: func(r *http.Request) { + r.SetBasicAuth("justuser", "") + }, + wantok: true, + wantauth: auth.BasicAuth{User: "justuser"}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var ( + givenok bool + givenauth auth.BasicAuth + ) + h := WithAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + givenauth.User, givenauth.Password, givenok = r.BasicAuth() + })) + + r := httptest.NewRequest("GET", "/", nil) + tc.reqfunc(r) + w := httptest.NewRecorder() + + h.ServeHTTP(w, r) + + if givenok != tc.wantok { + t.Fatalf("expected basic auth existence to be %t but got %t", tc.wantok, givenok) + } + if givenauth != tc.wantauth { + t.Fatalf("expected basic auth to be %+v but got %+v", tc.wantauth, givenauth) + } + }) + } +} diff --git a/pkg/module/all_test.go b/pkg/module/all_test.go index 975f81ba..1254fc0a 100644 --- a/pkg/module/all_test.go +++ b/pkg/module/all_test.go @@ -13,6 +13,9 @@ const ( // github.com/NYTimes/gizmo is a example of a path that needs to be encoded so we can cover that case as well repoURI = "github.com/NYTimes/gizmo" version = "v0.1.4" + + privateRepoURI = "github.com/athens-artifacts/private" + privateRepoVersion = "v0.0.1" ) type ModuleSuite struct { diff --git a/pkg/module/go_get_fetcher.go b/pkg/module/go_get_fetcher.go index 9055ae1a..4f2d26c5 100644 --- a/pkg/module/go_get_fetcher.go +++ b/pkg/module/go_get_fetcher.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" + "github.com/gomods/athens/pkg/auth" "github.com/gomods/athens/pkg/errors" "github.com/gomods/athens/pkg/observ" "github.com/gomods/athens/pkg/storage" @@ -17,10 +18,11 @@ import ( ) type goGetFetcher struct { - fs afero.Fs - goBinaryName string - envVars []string - gogetDir string + fs afero.Fs + goBinaryName string + envVars []string + gogetDir string + propagateAuthHost string } type goModule struct { @@ -36,16 +38,17 @@ type goModule struct { } // NewGoGetFetcher creates fetcher which uses go get tool to fetch modules -func NewGoGetFetcher(goBinaryName, gogetDir string, envVars []string, fs afero.Fs) (Fetcher, error) { +func NewGoGetFetcher(goBinaryName, gogetDir string, envVars []string, fs afero.Fs, propagateAuthHost string) (Fetcher, error) { const op errors.Op = "module.NewGoGetFetcher" if err := validGoBinary(goBinaryName); err != nil { return nil, errors.E(op, err) } return &goGetFetcher{ - fs: fs, - goBinaryName: goBinaryName, - envVars: envVars, - gogetDir: gogetDir, + fs: fs, + goBinaryName: goBinaryName, + envVars: envVars, + gogetDir: gogetDir, + propagateAuthHost: propagateAuthHost, }, nil } @@ -68,7 +71,7 @@ func (g *goGetFetcher) Fetch(ctx context.Context, mod, ver string) (*storage.Ver return nil, errors.E(op, err) } - m, err := downloadModule(g.goBinaryName, g.envVars, g.fs, goPathRoot, modPath, mod, ver) + m, err := g.downloadModule(ctx, goPathRoot, modPath, mod, ver) if err != nil { clearFiles(g.fs, goPathRoot) return nil, errors.E(op, err) @@ -103,19 +106,34 @@ func (g *goGetFetcher) Fetch(ctx context.Context, mod, ver string) (*storage.Ver // given a filesystem, gopath, repository root, module and version, runs 'go mod download -json' // on module@version from the repoRoot with GOPATH=gopath, and returns a non-nil error if anything went wrong. -func downloadModule(goBinaryName string, envVars []string, fs afero.Fs, gopath, repoRoot, module, version string) (goModule, error) { +func (g *goGetFetcher) downloadModule(ctx context.Context, gopath, repoRoot, module, version string) (goModule, error) { const op errors.Op = "module.downloadModule" + var ( + netrcDir string + err error + ) + creds, ok := auth.FromContext(ctx) + if ok && g.shouldPropAuth() { + if ok { + netrcDir, err = auth.WriteTemporaryNETRC(g.propagateAuthHost, creds.User, creds.Password) + if err != nil { + return goModule{}, errors.E(op, err) + } + defer os.RemoveAll(netrcDir) + } + } uri := strings.TrimSuffix(module, "/") - fullURI := fmt.Sprintf("%s@%s", uri, version) - cmd := exec.Command(goBinaryName, "mod", "download", "-json", fullURI) - cmd.Env = prepareEnv(gopath, envVars) + fullURI := fmt.Sprintf("%s@%s", uri, version) + cmd := exec.CommandContext(ctx, g.goBinaryName, "mod", "download", "-json", fullURI) + cmd.Env = prepareEnv(gopath, netrcDir, g.envVars) cmd.Dir = repoRoot stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} cmd.Stdout = stdout cmd.Stderr = stderr - err := cmd.Run() + + err = cmd.Run() if err != nil { err = fmt.Errorf("%v: %s", err, stderr) var m goModule @@ -140,6 +158,10 @@ func downloadModule(goBinaryName string, envVars []string, fs afero.Fs, gopath, return m, nil } +func (g *goGetFetcher) shouldPropAuth() bool { + return len(g.propagateAuthHost) > 0 +} + func isLimitHit(o string) bool { return strings.Contains(o, "403 response from api.github.com") } diff --git a/pkg/module/go_get_fetcher_test.go b/pkg/module/go_get_fetcher_test.go index 37273c17..ac5fce83 100644 --- a/pkg/module/go_get_fetcher_test.go +++ b/pkg/module/go_get_fetcher_test.go @@ -7,24 +7,28 @@ import ( "net/http/httptest" "os" "runtime" + "testing" + "github.com/gobuffalo/envy" + "github.com/gomods/athens/pkg/auth" "github.com/gomods/athens/pkg/errors" "github.com/spf13/afero" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ctx = context.Background() func (s *ModuleSuite) TestNewGoGetFetcher() { r := s.Require() - fetcher, err := NewGoGetFetcher(s.goBinaryName, "", s.env, s.fs) + fetcher, err := NewGoGetFetcher(s.goBinaryName, "", s.env, s.fs, "") r.NoError(err) _, ok := fetcher.(*goGetFetcher) r.True(ok) } func (s *ModuleSuite) TestGoGetFetcherError() { - fetcher, err := NewGoGetFetcher("invalidpath", "", s.env, afero.NewOsFs()) + fetcher, err := NewGoGetFetcher("invalidpath", "", s.env, afero.NewOsFs(), "") assert.Nil(s.T(), fetcher) if runtime.GOOS == "windows" { @@ -38,7 +42,7 @@ func (s *ModuleSuite) TestGoGetFetcherFetch() { r := s.Require() // we need to use an OS filesystem because fetch executes vgo on the command line, which // always writes to the filesystem - fetcher, err := NewGoGetFetcher(s.goBinaryName, "", s.env, afero.NewOsFs()) + fetcher, err := NewGoGetFetcher(s.goBinaryName, "", s.env, afero.NewOsFs(), "") r.NoError(err) ver, err := fetcher.Fetch(ctx, repoURI, version) r.NoError(err) @@ -56,9 +60,137 @@ func (s *ModuleSuite) TestGoGetFetcherFetch() { r.NoError(ver.Zip.Close()) } +func TestGoGetFetcherFetchPrivate(t *testing.T) { + token := os.Getenv("PROPAGATE_AUTH_TEST_TOKEN") + if token == "" { + t.SkipNow() + } + var tests = []struct { + name string + desc string + host string + auth auth.BasicAuth + hasErr bool + preTest func(t *testing.T, fetcher Fetcher) + }{ + { + name: "private no token", + desc: "cannot fetch a private repository without a basic auth token", + auth: auth.BasicAuth{User: "", Password: ""}, + hasErr: true, + host: "github.com", + }, + { + name: "prive fetch", + desc: "can successfully download private repository with a valid auth header", + host: "github.com", + auth: auth.BasicAuth{ + User: "athensuser", + Password: token, + }, + }, + { + name: "disable propagation", + desc: "cannot fetch a private repository even if basic auth is provided when there is no host", + auth: auth.BasicAuth{ + User: "athensuser", + Password: token, + }, + host: "", + hasErr: true, + }, + { + name: "mismatched auth host", + desc: "cannot fetch a private repository unless the module matches the provided host", + auth: auth.BasicAuth{ + User: "athensuser", + Password: token, + }, + host: "bitbucket.org", + hasErr: true, + }, + { + name: "consecutive private fetch", + desc: "this test ensures that the .netrc is removed after a private fetch so credentials are not leakaed to proceeding requests", + host: "github.com", + auth: auth.BasicAuth{}, + preTest: func(t *testing.T, fetcher Fetcher) { + a := auth.BasicAuth{ + User: "athensuser", + Password: token, + } + ctx := auth.SetAuthInContext(ctx, a) + ver, err := fetcher.Fetch(ctx, privateRepoURI, privateRepoVersion) + require.NoError(t, err) + require.NoError(t, ver.Zip.Close()) + }, + hasErr: true, + }, + } + goBinaryPath := envy.Get("GO_BINARY_PATH", "go") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + envs := []string{"GOPROXY=direct", "GONOSUMDB=github.com/athens-artifacts/private"} + fetcher, err := NewGoGetFetcher( + goBinaryPath, + "", + envs, + afero.NewOsFs(), + tc.host, + ) + require.NoError(t, err) + if tc.preTest != nil { + tc.preTest(t, fetcher) + } + ctx := auth.SetAuthInContext(ctx, tc.auth) + ver, err := fetcher.Fetch(ctx, privateRepoURI, privateRepoVersion) + if ver != nil && ver.Zip != nil { + t.Cleanup(func() { + require.NoError(t, ver.Zip.Close()) + }) + } + if checkErr(t, tc.hasErr, err) { + return + } + require.True(t, len(ver.Info) > 0) + require.True(t, len(ver.Mod) > 0) + + zipBytes, err := ioutil.ReadAll(ver.Zip) + require.NoError(t, err) + require.True(t, len(zipBytes) > 0) + + lister := NewVCSLister(goBinaryPath, envs, afero.NewOsFs(), tc.host) + _, vers, err := lister.List(ctx, privateRepoURI) + if checkErr(t, tc.hasErr, err) { + return + } + if len(vers) != 1 { + t.Fatalf("expected number of version to be 1 but got %v", len(vers)) + } + if vers[0] != "v0.0.1" { + t.Fatalf("expected the version to be %q but got %q", "v0.0.1", vers[0]) + } + }) + } +} + +// checkErr fails on whether we were expecting an error and did not get one, +// or whether we were not expecting an error and did get one. +// It returns a boolean to say whether the caller should return early +func checkErr(t *testing.T, wantErr bool, err error) bool { + if wantErr { + if err == nil { + t.Fatal("expected an error but got nil") + } + return true + } + require.NoError(t, err) + return false +} + func (s *ModuleSuite) TestNotFoundFetches() { r := s.Require() - fetcher, err := NewGoGetFetcher(s.goBinaryName, "", s.env, afero.NewOsFs()) + fetcher, err := NewGoGetFetcher(s.goBinaryName, "", s.env, afero.NewOsFs(), "") r.NoError(err) // when someone buys laks47dfjoijskdvjxuyyd.com, and implements // a git server on top of it, this test will fail :) @@ -86,13 +218,13 @@ func (s *ModuleSuite) TestGoGetFetcherSumDB() { proxyAddr, close := s.getProxy(mp) defer close() - fetcher, err := NewGoGetFetcher(s.goBinaryName, "", []string{"GOPROXY=" + proxyAddr}, afero.NewOsFs()) + fetcher, err := NewGoGetFetcher(s.goBinaryName, "", []string{"GOPROXY=" + proxyAddr}, afero.NewOsFs(), "") r.NoError(err) _, err = fetcher.Fetch(ctx, "mockmod.xyz", "v1.2.3") if err == nil { s.T().Fatal("expected a gosum error but got nil") } - fetcher, err = NewGoGetFetcher(s.goBinaryName, "", []string{"GONOSUMDB=mockmod.xyz", "GOPROXY=" + proxyAddr}, afero.NewOsFs()) + fetcher, err = NewGoGetFetcher(s.goBinaryName, "", []string{"GONOSUMDB=mockmod.xyz", "GOPROXY=" + proxyAddr}, afero.NewOsFs(), "") r.NoError(err) _, err = fetcher.Fetch(ctx, "mockmod.xyz", "v1.2.3") r.NoError(err, "expected the go sum to not be consulted but got an error") @@ -106,7 +238,7 @@ func (s *ModuleSuite) TestGoGetDir() { t.Cleanup(func() { os.RemoveAll(dir) }) - fetcher, err := NewGoGetFetcher(s.goBinaryName, dir, s.env, afero.NewOsFs()) + fetcher, err := NewGoGetFetcher(s.goBinaryName, dir, s.env, afero.NewOsFs(), "") r.NoError(err) ver, err := fetcher.Fetch(ctx, repoURI, version) diff --git a/pkg/module/go_vcs_lister.go b/pkg/module/go_vcs_lister.go index dde9e257..3ea3c06d 100644 --- a/pkg/module/go_vcs_lister.go +++ b/pkg/module/go_vcs_lister.go @@ -5,11 +5,14 @@ import ( "context" "encoding/json" "fmt" + "os" "os/exec" "time" + "github.com/gomods/athens/pkg/auth" "github.com/gomods/athens/pkg/config" "github.com/gomods/athens/pkg/errors" + "github.com/gomods/athens/pkg/log" "github.com/gomods/athens/pkg/observ" "github.com/gomods/athens/pkg/storage" "github.com/spf13/afero" @@ -23,16 +26,43 @@ type listResp struct { } type vcsLister struct { - goBinPath string - env []string - fs afero.Fs + goBinPath string + env []string + fs afero.Fs + propagateAuthHost string } -func (l *vcsLister) List(ctx context.Context, mod string) (*storage.RevInfo, []string, error) { +// NewVCSLister creates an UpstreamLister which uses VCS to fetch a list of available versions +func NewVCSLister(goBinPath string, env []string, fs afero.Fs, propagateAuthHost string) UpstreamLister { + return &vcsLister{ + goBinPath: goBinPath, + env: env, + fs: fs, + propagateAuthHost: propagateAuthHost, + } +} + +func (l *vcsLister) shouldPropAuth() bool { + return len(l.propagateAuthHost) > 0 +} + +func (l *vcsLister) List(ctx context.Context, module string) (*storage.RevInfo, []string, error) { const op errors.Op = "vcsLister.List" ctx, span := observ.StartSpan(ctx, op.String()) defer span.End() - + var ( + netrcDir string + err error + ) + creds, ok := auth.FromContext(ctx) + if ok && l.shouldPropAuth() { + log.EntryFromContext(ctx).Debugf("propagating authentication") + netrcDir, err = auth.WriteTemporaryNETRC(l.propagateAuthHost, creds.User, creds.Password) + if err != nil { + return nil, nil, errors.E(op, err) + } + defer os.RemoveAll(netrcDir) + } tmpDir, err := afero.TempDir(l.fs, "", "go-list") if err != nil { return nil, nil, errors.E(op, err) @@ -42,7 +72,7 @@ func (l *vcsLister) List(ctx context.Context, mod string) (*storage.RevInfo, []s cmd := exec.Command( l.goBinPath, "list", "-m", "-versions", "-json", - config.FmtModVer(mod, "latest"), + config.FmtModVer(module, "latest"), ) cmd.Dir = tmpDir stdout := &bytes.Buffer{} @@ -55,7 +85,7 @@ func (l *vcsLister) List(ctx context.Context, mod string) (*storage.RevInfo, []s return nil, nil, errors.E(op, err) } defer clearFiles(l.fs, gopath) - cmd.Env = prepareEnv(gopath, l.env) + cmd.Env = prepareEnv(gopath, netrcDir, l.env) err = cmd.Run() if err != nil { @@ -81,8 +111,3 @@ func (l *vcsLister) List(ctx context.Context, mod string) (*storage.RevInfo, []s } return &rev, lr.Versions, nil } - -// NewVCSLister creates an UpstreamLister which uses VCS to fetch a list of available versions -func NewVCSLister(goBinPath string, env []string, fs afero.Fs) UpstreamLister { - return &vcsLister{goBinPath: goBinPath, env: env, fs: fs} -} diff --git a/pkg/module/prepare_env.go b/pkg/module/prepare_env.go index 49549759..2d6e57d4 100644 --- a/pkg/module/prepare_env.go +++ b/pkg/module/prepare_env.go @@ -10,7 +10,7 @@ import ( // prepareEnv will return all the appropriate // environment variables for a Go Command to run // successfully (such as GOPATH, GOCACHE, PATH etc) -func prepareEnv(gopath string, envVars []string) []string { +func prepareEnv(gopath, homedir string, envVars []string) []string { gopathEnv := fmt.Sprintf("GOPATH=%s", gopath) cacheEnv := fmt.Sprintf("GOCACHE=%s", filepath.Join(gopath, "cache")) disableCgo := "CGO_ENABLED=0" @@ -21,9 +21,9 @@ func prepareEnv(gopath string, envVars []string) []string { disableCgo, enableGoModules, } + cmdEnv = append(cmdEnv, withHomeDir(homedir)...) keys := []string{ "PATH", - "HOME", "GIT_SSH", "GIT_SSH_COMMAND", "HTTP_PROXY", @@ -36,7 +36,6 @@ func prepareEnv(gopath string, envVars []string) []string { } if runtime.GOOS == "windows" { windowsSpecificKeys := []string{ - "USERPROFILE", "SystemRoot", "ALLUSERSPROFILE", "HOMEDRIVE", @@ -62,3 +61,18 @@ func prepareEnv(gopath string, envVars []string) []string { } return cmdEnv } + +func withHomeDir(dir string) []string { + key := "HOME" + if runtime.GOOS == "windows" { + key = "USERPROFILE" + } + if dir != "" { + return []string{key + "=" + dir} + } + val, ok := os.LookupEnv(key) + if ok { + return []string{key + "=" + val} + } + return []string{} +} diff --git a/pkg/paths/path_test.go b/pkg/paths/path_test.go index 1514c98f..754179e3 100644 --- a/pkg/paths/path_test.go +++ b/pkg/paths/path_test.go @@ -80,6 +80,14 @@ func TestMatchesPattern(t *testing.T) { }, want: false, }, + { + name: "matches everything", + args: args{ + pattern: "*", + name: "github.com/gomods/athen", + }, + want: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/stash/stasher.go b/pkg/stash/stasher.go index 1909ace4..97f040d1 100644 --- a/pkg/stash/stasher.go +++ b/pkg/stash/stasher.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/gomods/athens/pkg/auth" "github.com/gomods/athens/pkg/errors" "github.com/gomods/athens/pkg/index" "github.com/gomods/athens/pkg/log" @@ -52,9 +53,12 @@ func (s *stasher) Stash(ctx context.Context, mod, ver string) (string, error) { // create a new context that ditches whatever deadline the caller passed // but keep the tracing info so that we can properly trace the whole thing. + tok, ok := auth.FromContext(ctx) ctx, cancel := context.WithTimeout(trace.NewContext(context.Background(), span), time.Minute*10) defer cancel() - + if ok { + ctx = auth.SetAuthInContext(ctx, tok) + } v, err := s.fetchModule(ctx, mod, ver) if err != nil { return "", errors.E(op, err) diff --git a/pkg/stash/stasher_test.go b/pkg/stash/stasher_test.go index 0f53b788..fc3931bc 100644 --- a/pkg/stash/stasher_test.go +++ b/pkg/stash/stasher_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + "github.com/gomods/athens/pkg/auth" "github.com/gomods/athens/pkg/index/nop" "github.com/gomods/athens/pkg/storage" ) @@ -77,6 +78,24 @@ func TestStash(t *testing.T) { } } +func TestStashWithAuthContext(t *testing.T) { + var mf mockFetcher + var ms mockStorage + s := New(&mf, &ms, nop.New()) + want := auth.BasicAuth{ + User: "gomods", + Password: "athens", + } + ctx := auth.SetAuthInContext(context.Background(), want) + _, err := s.Stash(ctx, "mod", "ver") + if err != nil { + t.Fatal(err) + } + if mf.auth != want { + t.Fatalf("expected %+v but got %+v", want, mf.auth) + } +} + type mockStorage struct { storage.Backend existsCalled bool @@ -97,10 +116,13 @@ func (ms *mockStorage) Exists(ctx context.Context, mod, ver string) (bool, error } type mockFetcher struct { - ver string + ver string + auth auth.BasicAuth } func (mf *mockFetcher) Fetch(ctx context.Context, mod, ver string) (*storage.Version, error) { + a, _ := auth.FromContext(ctx) + mf.auth = a return &storage.Version{ Info: []byte("info"), Mod: []byte("gomod"),