Migrated from app/:
- oauth.go - OAuthManager, config loading, handle/DID resolution
- oauth_session.go - SessionStore, encrypted cookies, token storage
- oauth_middleware.go - RequireAuth middleware, token refresh
- oauth_handlers.go - Login, callback, logout, JWKS endpoints

Changed *DB to *shared.DB, using shared.StringValue/NullableString helpers.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
355 lines
9.2 KiB
Go
355 lines
9.2 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/1440news/shared"
|
|
)
|
|
|
|
// Session cookie settings.
const (
	// sessionCookieName is the name of the browser cookie that carries the
	// encrypted session ID.
	sessionCookieName = "1440_session"

	// sessionTTL is how long a session lasts. It is added to the creation
	// time to compute a session's expiry and is also used as the cookie's
	// MaxAge.
	sessionTTL = 24 * time.Hour
)
|
|
|
|
// OAuthSession is the per-user session record: who the user is, how long the
// session lives, and the OAuth/DPoP material associated with it.
type OAuthSession struct {
	// Identity and lifetime.
	ID        string    `json:"id"`
	DID       string    `json:"did"`
	Handle    string    `json:"handle"`
	CreatedAt time.Time `json:"created_at"`
	ExpiresAt time.Time `json:"expires_at"`

	// OAuth tokens (stored server-side only).
	AccessToken  string    `json:"access_token"`
	RefreshToken string    `json:"refresh_token"`
	TokenExpiry  time.Time `json:"token_expiry"`

	// DPoP state.
	DpopPrivateJWK      string `json:"dpop_private_jwk"`
	DpopAuthserverNonce string `json:"dpop_authserver_nonce"`
	DpopPdsNonce        string `json:"dpop_pds_nonce"`

	// Auth server info.
	PdsURL        string `json:"pds_url"`
	AuthserverIss string `json:"authserver_iss"`
}
|
|
|
|
// PendingAuth is the state kept for an OAuth flow that has been started but
// whose callback has not been handled yet.
type PendingAuth struct {
	State          string `json:"state"`
	PkceVerifier   string `json:"pkce_verifier"`
	DpopPrivateJWK string `json:"dpop_private_jwk"`
	DpopNonce      string `json:"dpop_nonce"`
	DID            string `json:"did"`
	PdsURL         string `json:"pds_url"`
	AuthserverIss  string `json:"authserver_iss"`

	// CreatedAt is overwritten by SessionStore.SavePending and is what
	// SessionStore.cleanup compares against its 10 minute timeout.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// SessionStore manages sessions in the database
|
|
type SessionStore struct {
|
|
db *shared.DB
|
|
pending map[string]*PendingAuth // keyed by state (short-lived, kept in memory)
|
|
mu sync.RWMutex
|
|
cleanupOnce sync.Once
|
|
}
|
|
|
|
// NewSessionStore creates a new session store
|
|
func NewSessionStore(db *shared.DB) *SessionStore {
|
|
s := &SessionStore{
|
|
db: db,
|
|
pending: make(map[string]*PendingAuth),
|
|
}
|
|
s.startCleanup()
|
|
return s
|
|
}
|
|
|
|
// startCleanup starts a background goroutine to clean up expired sessions
|
|
func (s *SessionStore) startCleanup() {
|
|
s.cleanupOnce.Do(func() {
|
|
go func() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
s.cleanup()
|
|
}
|
|
}()
|
|
})
|
|
}
|
|
|
|
// cleanup removes expired sessions and pending auths
|
|
func (s *SessionStore) cleanup() {
|
|
// Clean up expired sessions from database
|
|
s.db.Exec("DELETE FROM oauth_sessions WHERE expires_at < NOW()")
|
|
|
|
// Clean up old pending auths (10 minute timeout) from memory
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
now := time.Now()
|
|
for state, pending := range s.pending {
|
|
if now.Sub(pending.CreatedAt) > 10*time.Minute {
|
|
delete(s.pending, state)
|
|
}
|
|
}
|
|
}
|
|
|
|
// CreateSession creates a new session and returns it
|
|
func (s *SessionStore) CreateSession(did, handle string) (*OAuthSession, error) {
|
|
id, err := generateRandomID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now()
|
|
session := &OAuthSession{
|
|
ID: id,
|
|
DID: did,
|
|
Handle: handle,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(sessionTTL),
|
|
}
|
|
|
|
_, err = s.db.Exec(`
|
|
INSERT INTO oauth_sessions (id, did, handle, created_at, expires_at)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
`, session.ID, session.DID, session.Handle, session.CreatedAt, session.ExpiresAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// GetSession retrieves a session by ID
|
|
func (s *SessionStore) GetSession(id string) *OAuthSession {
|
|
row := s.db.QueryRow(`
|
|
SELECT id, did, handle, created_at, expires_at,
|
|
access_token, refresh_token, token_expiry,
|
|
dpop_private_jwk, dpop_authserver_nonce, dpop_pds_nonce,
|
|
pds_url, authserver_iss
|
|
FROM oauth_sessions
|
|
WHERE id = $1 AND expires_at > NOW()
|
|
`, id)
|
|
|
|
var session OAuthSession
|
|
var accessToken, refreshToken, dpopJwk, dpopAuthNonce, dpopPdsNonce, pdsURL, authIss *string
|
|
var tokenExpiry *time.Time
|
|
|
|
err := row.Scan(
|
|
&session.ID, &session.DID, &session.Handle, &session.CreatedAt, &session.ExpiresAt,
|
|
&accessToken, &refreshToken, &tokenExpiry,
|
|
&dpopJwk, &dpopAuthNonce, &dpopPdsNonce,
|
|
&pdsURL, &authIss,
|
|
)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
session.AccessToken = shared.StringValue(accessToken)
|
|
session.RefreshToken = shared.StringValue(refreshToken)
|
|
if tokenExpiry != nil {
|
|
session.TokenExpiry = *tokenExpiry
|
|
}
|
|
session.DpopPrivateJWK = shared.StringValue(dpopJwk)
|
|
session.DpopAuthserverNonce = shared.StringValue(dpopAuthNonce)
|
|
session.DpopPdsNonce = shared.StringValue(dpopPdsNonce)
|
|
session.PdsURL = shared.StringValue(pdsURL)
|
|
session.AuthserverIss = shared.StringValue(authIss)
|
|
|
|
return &session
|
|
}
|
|
|
|
// UpdateSession updates a session
|
|
func (s *SessionStore) UpdateSession(session *OAuthSession) {
|
|
s.db.Exec(`
|
|
UPDATE oauth_sessions SET
|
|
access_token = $2,
|
|
refresh_token = $3,
|
|
token_expiry = $4,
|
|
dpop_private_jwk = $5,
|
|
dpop_authserver_nonce = $6,
|
|
dpop_pds_nonce = $7,
|
|
pds_url = $8,
|
|
authserver_iss = $9
|
|
WHERE id = $1
|
|
`,
|
|
session.ID,
|
|
shared.NullableString(session.AccessToken),
|
|
shared.NullableString(session.RefreshToken),
|
|
shared.NullableTime(session.TokenExpiry),
|
|
shared.NullableString(session.DpopPrivateJWK),
|
|
shared.NullableString(session.DpopAuthserverNonce),
|
|
shared.NullableString(session.DpopPdsNonce),
|
|
shared.NullableString(session.PdsURL),
|
|
shared.NullableString(session.AuthserverIss),
|
|
)
|
|
}
|
|
|
|
// DeleteSession removes a session
|
|
func (s *SessionStore) DeleteSession(id string) {
|
|
s.db.Exec("DELETE FROM oauth_sessions WHERE id = $1", id)
|
|
}
|
|
|
|
// SavePending saves pending OAuth state (kept in memory - short lived)
|
|
func (s *SessionStore) SavePending(state string, pending *PendingAuth) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
pending.CreatedAt = time.Now()
|
|
s.pending[state] = pending
|
|
}
|
|
|
|
// GetPending retrieves and removes pending OAuth state
|
|
func (s *SessionStore) GetPending(state string) *PendingAuth {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
pending, ok := s.pending[state]
|
|
if ok {
|
|
delete(s.pending, state)
|
|
}
|
|
return pending
|
|
}
|
|
|
|
// generateRandomID returns 32 bytes from crypto/rand encoded as padded
// URL-safe base64 (44 characters). It returns "" and the error if the random
// source fails.
func generateRandomID() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
// encryptSessionID seals sessionID with AES-GCM under key and returns the
// result as padded URL-safe base64.
//
// The encoded bytes are laid out as nonce || ciphertext || tag, with a fresh
// random nonce per call. key must be a valid AES key length (16, 24 or 32
// bytes; 32 selects AES-256), otherwise the cipher error is returned.
func encryptSessionID(sessionID string, key []byte) (string, error) {
	blk, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(blk)
	if err != nil {
		return "", err
	}

	iv := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return "", err
	}

	// Seal appends to its first argument, so passing iv as the destination
	// makes the nonce the prefix of the output.
	sealed := aead.Seal(iv, iv, []byte(sessionID), nil)
	return base64.URLEncoding.EncodeToString(sealed), nil
}
|
|
|
|
// decryptSessionID reverses encryptSessionID: it base64-decodes encrypted,
// splits off the leading nonce, and opens the remainder with AES-GCM under
// key.
//
// It returns an error when the input is not valid URL-safe base64, the key is
// not a valid AES key length, the decoded data is shorter than the nonce, or
// authentication fails (wrong key or modified data).
func decryptSessionID(encrypted string, key []byte) (string, error) {
	raw, err := base64.URLEncoding.DecodeString(encrypted)
	if err != nil {
		return "", err
	}

	blk, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(blk)
	if err != nil {
		return "", err
	}

	ivLen := aead.NonceSize()
	if len(raw) < ivLen {
		return "", fmt.Errorf("ciphertext too short")
	}

	opened, err := aead.Open(nil, raw[:ivLen], raw[ivLen:], nil)
	if err != nil {
		return "", err
	}
	return string(opened), nil
}
|
|
|
|
// SetSessionCookie sets an encrypted session cookie
|
|
func (m *OAuthManager) SetSessionCookie(w http.ResponseWriter, r *http.Request, sessionID string) error {
|
|
encrypted, err := encryptSessionID(sessionID, m.cookieSecret)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Only set Secure flag for HTTPS connections
|
|
secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: sessionCookieName,
|
|
Value: encrypted,
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: secure,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: int(sessionTTL.Seconds()),
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetSessionFromCookie retrieves the session from the request cookie
|
|
func (m *OAuthManager) GetSessionFromCookie(r *http.Request) *OAuthSession {
|
|
cookie, err := r.Cookie(sessionCookieName)
|
|
if err != nil {
|
|
fmt.Printf("GetSessionFromCookie: no cookie found: %v\n", err)
|
|
return nil
|
|
}
|
|
fmt.Printf("GetSessionFromCookie: found cookie, length=%d\n", len(cookie.Value))
|
|
|
|
sessionID, err := decryptSessionID(cookie.Value, m.cookieSecret)
|
|
if err != nil {
|
|
fmt.Printf("GetSessionFromCookie: decrypt failed: %v\n", err)
|
|
return nil
|
|
}
|
|
fmt.Printf("GetSessionFromCookie: decrypted session ID: %s\n", sessionID)
|
|
|
|
session := m.sessions.GetSession(sessionID)
|
|
if session == nil {
|
|
fmt.Printf("GetSessionFromCookie: session not found in store\n")
|
|
} else {
|
|
fmt.Printf("GetSessionFromCookie: found session for %s\n", session.Handle)
|
|
}
|
|
return session
|
|
}
|
|
|
|
// ClearSessionCookie removes the session cookie
|
|
func (m *OAuthManager) ClearSessionCookie(w http.ResponseWriter) {
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: sessionCookieName,
|
|
Value: "",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: -1,
|
|
})
|
|
}
|
|
|
|
// SessionInfo is the subset of session data that is returned to the client:
// identity and expiry only, with no tokens or DPoP material.
type SessionInfo struct {
	DID       string    `json:"did"`
	Handle    string    `json:"handle"`
	ExpiresAt time.Time `json:"expires_at"`
}

// MarshalJSON encodes the SessionInfo using its struct tags. A nil receiver
// encodes as JSON null.
func (s *SessionInfo) MarshalJSON() ([]byte, error) {
	// wire has SessionInfo's fields but none of its methods, so marshalling
	// through it does not call MarshalJSON again.
	type wire SessionInfo

	plain := (*wire)(s)
	return json.Marshal(plain)
}
|