package main import ( "context" "encoding/json" "fmt" "net/http" "strings" "time" "github.com/haileyok/atproto-oauth-golang/helpers" ) // RequireAuth is middleware that protects routes requiring authentication func (m *OAuthManager) RequireAuth(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { session := m.GetSessionFromCookie(r) if session == nil { fmt.Printf("RequireAuth: no session found for %s\n", r.URL.Path) // Check if this is an API call (wants JSON response) if isAPIRequest(r) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(map[string]string{ "error": "unauthorized", }) return } // Redirect to login for browser requests http.Redirect(w, r, "/auth/login", http.StatusFound) return } // Check if token needs refresh (refresh when within 5 minutes of expiry) if time.Until(session.TokenExpiry) < 5*time.Minute { if err := m.refreshToken(r.Context(), session); err != nil { // Token refresh failed - clear session and redirect to login m.sessions.DeleteSession(session.ID) m.ClearSessionCookie(w) if isAPIRequest(r) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(map[string]string{ "error": "session expired", }) return } http.Redirect(w, r, "/auth/login", http.StatusFound) return } } // Add session to request context ctx := context.WithValue(r.Context(), sessionContextKey, session) next(w, r.WithContext(ctx)) } } // sessionContextKey is the context key for the OAuth session type contextKey string const sessionContextKey contextKey = "oauth_session" // GetSession retrieves the session from request context func GetSession(r *http.Request) *OAuthSession { session, _ := r.Context().Value(sessionContextKey).(*OAuthSession) return session } // isAPIRequest checks if the request expects JSON response func isAPIRequest(r *http.Request) bool { // Check Accept header accept := r.Header.Get("Accept") if strings.Contains(accept, "application/json") { return true } // Check URL path if strings.HasPrefix(r.URL.Path, "/api/") { return true } // Check X-Requested-With header (for AJAX) if r.Header.Get("X-Requested-With") == "XMLHttpRequest" { return true } return false } // refreshToken refreshes the OAuth access token func (m *OAuthManager) refreshToken(ctx context.Context, session *OAuthSession) error { if session.RefreshToken == "" { return nil // No refresh token available } // Parse the DPoP private key dpopKey, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJWK)) if err != nil { return err } // Refresh the token tokenResp, err := m.client.RefreshTokenRequest( ctx, session.RefreshToken, session.AuthserverIss, session.DpopAuthserverNonce, dpopKey, ) if err != nil { return err } // Update session with new tokens session.AccessToken = tokenResp.AccessToken session.RefreshToken = tokenResp.RefreshToken session.TokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) session.DpopAuthserverNonce = tokenResp.DpopAuthserverNonce // Save updated session m.sessions.UpdateSession(session) return nil }