Add OAuth files - authentication system
Migrated from app/: - oauth.go - OAuthManager, config loading, handle/DID resolution - oauth_session.go - SessionStore, encrypted cookies, token storage - oauth_middleware.go - RequireAuth middleware, token refresh - oauth_handlers.go - Login, callback, logout, JWKS endpoints Changed *DB to *shared.DB, using shared.StringValue/NullableString helpers. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,354 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/1440news/shared"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionCookieName = "1440_session"
|
||||
sessionTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// OAuthSession stores the OAuth session state for a user
|
||||
type OAuthSession struct {
|
||||
ID string `json:"id"`
|
||||
DID string `json:"did"`
|
||||
Handle string `json:"handle"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
// OAuth tokens (stored server-side only)
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenExpiry time.Time `json:"token_expiry"`
|
||||
|
||||
// DPoP state
|
||||
DpopPrivateJWK string `json:"dpop_private_jwk"`
|
||||
DpopAuthserverNonce string `json:"dpop_authserver_nonce"`
|
||||
DpopPdsNonce string `json:"dpop_pds_nonce"`
|
||||
|
||||
// Auth server info
|
||||
PdsURL string `json:"pds_url"`
|
||||
AuthserverIss string `json:"authserver_iss"`
|
||||
}
|
||||
|
||||
// PendingAuth stores state during the OAuth flow (before callback)
|
||||
type PendingAuth struct {
|
||||
State string `json:"state"`
|
||||
PkceVerifier string `json:"pkce_verifier"`
|
||||
DpopPrivateJWK string `json:"dpop_private_jwk"`
|
||||
DpopNonce string `json:"dpop_nonce"`
|
||||
DID string `json:"did"`
|
||||
PdsURL string `json:"pds_url"`
|
||||
AuthserverIss string `json:"authserver_iss"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages sessions in the database
|
||||
type SessionStore struct {
|
||||
db *shared.DB
|
||||
pending map[string]*PendingAuth // keyed by state (short-lived, kept in memory)
|
||||
mu sync.RWMutex
|
||||
cleanupOnce sync.Once
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore(db *shared.DB) *SessionStore {
|
||||
s := &SessionStore{
|
||||
db: db,
|
||||
pending: make(map[string]*PendingAuth),
|
||||
}
|
||||
s.startCleanup()
|
||||
return s
|
||||
}
|
||||
|
||||
// startCleanup starts a background goroutine to clean up expired sessions
|
||||
func (s *SessionStore) startCleanup() {
|
||||
s.cleanupOnce.Do(func() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.cleanup()
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions and pending auths
|
||||
func (s *SessionStore) cleanup() {
|
||||
// Clean up expired sessions from database
|
||||
s.db.Exec("DELETE FROM oauth_sessions WHERE expires_at < NOW()")
|
||||
|
||||
// Clean up old pending auths (10 minute timeout) from memory
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
now := time.Now()
|
||||
for state, pending := range s.pending {
|
||||
if now.Sub(pending.CreatedAt) > 10*time.Minute {
|
||||
delete(s.pending, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new session and returns it
|
||||
func (s *SessionStore) CreateSession(did, handle string) (*OAuthSession, error) {
|
||||
id, err := generateRandomID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
session := &OAuthSession{
|
||||
ID: id,
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(sessionTTL),
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO oauth_sessions (id, did, handle, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, session.ID, session.DID, session.Handle, session.CreatedAt, session.ExpiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID
|
||||
func (s *SessionStore) GetSession(id string) *OAuthSession {
|
||||
row := s.db.QueryRow(`
|
||||
SELECT id, did, handle, created_at, expires_at,
|
||||
access_token, refresh_token, token_expiry,
|
||||
dpop_private_jwk, dpop_authserver_nonce, dpop_pds_nonce,
|
||||
pds_url, authserver_iss
|
||||
FROM oauth_sessions
|
||||
WHERE id = $1 AND expires_at > NOW()
|
||||
`, id)
|
||||
|
||||
var session OAuthSession
|
||||
var accessToken, refreshToken, dpopJwk, dpopAuthNonce, dpopPdsNonce, pdsURL, authIss *string
|
||||
var tokenExpiry *time.Time
|
||||
|
||||
err := row.Scan(
|
||||
&session.ID, &session.DID, &session.Handle, &session.CreatedAt, &session.ExpiresAt,
|
||||
&accessToken, &refreshToken, &tokenExpiry,
|
||||
&dpopJwk, &dpopAuthNonce, &dpopPdsNonce,
|
||||
&pdsURL, &authIss,
|
||||
)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
session.AccessToken = shared.StringValue(accessToken)
|
||||
session.RefreshToken = shared.StringValue(refreshToken)
|
||||
if tokenExpiry != nil {
|
||||
session.TokenExpiry = *tokenExpiry
|
||||
}
|
||||
session.DpopPrivateJWK = shared.StringValue(dpopJwk)
|
||||
session.DpopAuthserverNonce = shared.StringValue(dpopAuthNonce)
|
||||
session.DpopPdsNonce = shared.StringValue(dpopPdsNonce)
|
||||
session.PdsURL = shared.StringValue(pdsURL)
|
||||
session.AuthserverIss = shared.StringValue(authIss)
|
||||
|
||||
return &session
|
||||
}
|
||||
|
||||
// UpdateSession updates a session
|
||||
func (s *SessionStore) UpdateSession(session *OAuthSession) {
|
||||
s.db.Exec(`
|
||||
UPDATE oauth_sessions SET
|
||||
access_token = $2,
|
||||
refresh_token = $3,
|
||||
token_expiry = $4,
|
||||
dpop_private_jwk = $5,
|
||||
dpop_authserver_nonce = $6,
|
||||
dpop_pds_nonce = $7,
|
||||
pds_url = $8,
|
||||
authserver_iss = $9
|
||||
WHERE id = $1
|
||||
`,
|
||||
session.ID,
|
||||
shared.NullableString(session.AccessToken),
|
||||
shared.NullableString(session.RefreshToken),
|
||||
shared.NullableTime(session.TokenExpiry),
|
||||
shared.NullableString(session.DpopPrivateJWK),
|
||||
shared.NullableString(session.DpopAuthserverNonce),
|
||||
shared.NullableString(session.DpopPdsNonce),
|
||||
shared.NullableString(session.PdsURL),
|
||||
shared.NullableString(session.AuthserverIss),
|
||||
)
|
||||
}
|
||||
|
||||
// DeleteSession removes a session
|
||||
func (s *SessionStore) DeleteSession(id string) {
|
||||
s.db.Exec("DELETE FROM oauth_sessions WHERE id = $1", id)
|
||||
}
|
||||
|
||||
// SavePending saves pending OAuth state (kept in memory - short lived)
|
||||
func (s *SessionStore) SavePending(state string, pending *PendingAuth) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
pending.CreatedAt = time.Now()
|
||||
s.pending[state] = pending
|
||||
}
|
||||
|
||||
// GetPending retrieves and removes pending OAuth state
|
||||
func (s *SessionStore) GetPending(state string) *PendingAuth {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
pending, ok := s.pending[state]
|
||||
if ok {
|
||||
delete(s.pending, state)
|
||||
}
|
||||
return pending
|
||||
}
|
||||
|
||||
// generateRandomID generates a random session ID
|
||||
func generateRandomID() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// encryptSessionID encrypts a session ID using AES-256-GCM
|
||||
func encryptSessionID(sessionID string, key []byte) (string, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(sessionID), nil)
|
||||
return base64.URLEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// decryptSessionID decrypts a session ID using AES-256-GCM
|
||||
func decryptSessionID(encrypted string, key []byte) (string, error) {
|
||||
ciphertext, err := base64.URLEncoding.DecodeString(encrypted)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(ciphertext) < gcm.NonceSize() {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := ciphertext[:gcm.NonceSize()], ciphertext[gcm.NonceSize():]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// SetSessionCookie sets an encrypted session cookie
|
||||
func (m *OAuthManager) SetSessionCookie(w http.ResponseWriter, r *http.Request, sessionID string) error {
|
||||
encrypted, err := encryptSessionID(sessionID, m.cookieSecret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only set Secure flag for HTTPS connections
|
||||
secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: encrypted,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(sessionTTL.Seconds()),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSessionFromCookie retrieves the session from the request cookie
|
||||
func (m *OAuthManager) GetSessionFromCookie(r *http.Request) *OAuthSession {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
fmt.Printf("GetSessionFromCookie: no cookie found: %v\n", err)
|
||||
return nil
|
||||
}
|
||||
fmt.Printf("GetSessionFromCookie: found cookie, length=%d\n", len(cookie.Value))
|
||||
|
||||
sessionID, err := decryptSessionID(cookie.Value, m.cookieSecret)
|
||||
if err != nil {
|
||||
fmt.Printf("GetSessionFromCookie: decrypt failed: %v\n", err)
|
||||
return nil
|
||||
}
|
||||
fmt.Printf("GetSessionFromCookie: decrypted session ID: %s\n", sessionID)
|
||||
|
||||
session := m.sessions.GetSession(sessionID)
|
||||
if session == nil {
|
||||
fmt.Printf("GetSessionFromCookie: session not found in store\n")
|
||||
} else {
|
||||
fmt.Printf("GetSessionFromCookie: found session for %s\n", session.Handle)
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
// ClearSessionCookie removes the session cookie
|
||||
func (m *OAuthManager) ClearSessionCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: -1,
|
||||
})
|
||||
}
|
||||
|
||||
// SessionInfo is the public session info returned to the client
|
||||
type SessionInfo struct {
|
||||
DID string `json:"did"`
|
||||
Handle string `json:"handle"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// MarshalJSON converts SessionInfo to JSON
|
||||
func (s *SessionInfo) MarshalJSON() ([]byte, error) {
|
||||
type Alias SessionInfo
|
||||
return json.Marshal((*Alias)(s))
|
||||
}
|
||||
Reference in New Issue
Block a user