Files
watcher/oauth_session.go
primal 53919fa31e Add OAuth files - authentication system
Migrated from app/:
- oauth.go - OAuthManager, config loading, handle/DID resolution
- oauth_session.go - SessionStore, encrypted cookies, token storage
- oauth_middleware.go - RequireAuth middleware, token refresh
- oauth_handlers.go - Login, callback, logout, JWKS endpoints

Changed *DB to *shared.DB, using shared.StringValue/NullableString helpers.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 12:48:23 -05:00

355 lines
9.2 KiB
Go

package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/1440news/shared"
)
const (
// sessionCookieName is the name of the HTTP cookie that carries the
// encrypted session ID (see SetSessionCookie / GetSessionFromCookie).
sessionCookieName = "1440_session"
// sessionTTL is the lifetime of a session: CreateSession uses it to compute
// ExpiresAt and SetSessionCookie uses it for the cookie's MaxAge.
sessionTTL = 24 * time.Hour
)
// OAuthSession stores the OAuth session state for a user.
//
// CreateSession populates only ID, DID, Handle, CreatedAt and ExpiresAt. The
// token, DPoP and server fields are written later by UpdateSession and are
// read back by GetSession through nullable scan targets, so any of them may
// be empty on a freshly created session.
//
// NOTE(review): the token fields carry json tags, so marshalling an
// OAuthSession directly would emit them; SessionInfo is the type intended
// for the client.
type OAuthSession struct {
// ID is the random identifier produced by generateRandomID. It is the value
// that SetSessionCookie encrypts into the session cookie.
ID string `json:"id"`
// DID and Handle identify the user; both are supplied by the caller of
// CreateSession.
DID string `json:"did"`
Handle string `json:"handle"`
// CreatedAt is the creation time; ExpiresAt is CreatedAt plus sessionTTL.
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
// OAuth tokens (stored server-side only)
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
// TokenExpiry stays at the zero time when the stored column is NULL.
TokenExpiry time.Time `json:"token_expiry"`
// DPoP state
// presumably a serialized private JWK plus the last nonces seen from the
// authorization server and the PDS — TODO confirm against the OAuth client code.
DpopPrivateJWK string `json:"dpop_private_jwk"`
DpopAuthserverNonce string `json:"dpop_authserver_nonce"`
DpopPdsNonce string `json:"dpop_pds_nonce"`
// Auth server info
PdsURL string `json:"pds_url"`
AuthserverIss string `json:"authserver_iss"`
}
// PendingAuth stores state during the OAuth flow (before callback).
//
// Values are kept only in SessionStore's in-memory map, keyed by the state
// string passed to SavePending, and are handed out at most once by GetPending.
type PendingAuth struct {
// State is the OAuth state value.
// NOTE(review): expected to match the map key used in SavePending — the
// store itself does not enforce this.
State string `json:"state"`
// PkceVerifier is presumably the PKCE code verifier to present at token
// exchange — TODO confirm against the callback handler.
PkceVerifier string `json:"pkce_verifier"`
DpopPrivateJWK string `json:"dpop_private_jwk"`
DpopNonce string `json:"dpop_nonce"`
DID string `json:"did"`
PdsURL string `json:"pds_url"`
AuthserverIss string `json:"authserver_iss"`
// CreatedAt is overwritten with the current time by SavePending and is what
// the cleanup sweep compares against its ten-minute limit.
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages sessions in the database and pending OAuth flows in
// memory.
//
// Established sessions live in the oauth_sessions table reached through db.
// In-flight authorizations live in the pending map, which is guarded by mu.
type SessionStore struct {
// db is the handle used for every oauth_sessions query.
db *shared.DB
pending map[string]*PendingAuth // keyed by state (short-lived, kept in memory)
// mu guards pending; the methods in this file only ever take the write lock.
mu sync.RWMutex
// cleanupOnce ensures startCleanup launches its goroutine at most once per store.
cleanupOnce sync.Once
}
// NewSessionStore builds a SessionStore backed by db, with an empty in-memory
// table of pending authorizations, and starts its background cleanup loop
// before returning.
func NewSessionStore(db *shared.DB) *SessionStore {
	store := new(SessionStore)
	store.db = db
	store.pending = map[string]*PendingAuth{}
	store.startCleanup()
	return store
}
// startCleanup launches the background sweeper that calls cleanup every five
// minutes. The sync.Once guard makes repeated calls harmless. The goroutine
// has no stop signal: once started it keeps ticking for as long as the
// process runs.
func (s *SessionStore) startCleanup() {
	sweep := func() {
		tick := time.NewTicker(5 * time.Minute)
		defer tick.Stop()
		for {
			<-tick.C
			s.cleanup()
		}
	}
	s.cleanupOnce.Do(func() { go sweep() })
}
// cleanup removes expired sessions from the database and drops pending
// authorizations older than ten minutes from memory.
//
// A database failure is logged rather than returned: cleanup is driven by the
// background ticker, so there is no caller to hand the error to, and the
// in-memory sweep below must still run. Previously the error was discarded,
// which made a permanently failing sweep invisible.
func (s *SessionStore) cleanup() {
	// Expiry is evaluated against the database clock (NOW()).
	if _, err := s.db.Exec("DELETE FROM oauth_sessions WHERE expires_at < NOW()"); err != nil {
		fmt.Printf("cleanup: failed to delete expired sessions: %v\n", err)
	}

	// Clean up old pending auths (10 minute timeout) from memory.
	s.mu.Lock()
	defer s.mu.Unlock()
	now := time.Now()
	for state, pending := range s.pending {
		if now.Sub(pending.CreatedAt) > 10*time.Minute {
			delete(s.pending, state)
		}
	}
}
// CreateSession allocates a fresh random session ID for the given DID and
// handle, inserts the row into oauth_sessions with an expiry of now plus
// sessionTTL, and returns the in-memory session.
//
// Only the identity and timestamp columns are written here; token and DPoP
// fields stay empty until UpdateSession is called. Any error from ID
// generation or from the insert is returned unchanged with a nil session.
func (s *SessionStore) CreateSession(did, handle string) (*OAuthSession, error) {
	sessionID, err := generateRandomID()
	if err != nil {
		return nil, err
	}

	createdAt := time.Now()
	created := OAuthSession{
		ID:        sessionID,
		DID:       did,
		Handle:    handle,
		CreatedAt: createdAt,
		ExpiresAt: createdAt.Add(sessionTTL),
	}

	if _, err := s.db.Exec(`
INSERT INTO oauth_sessions (id, did, handle, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
`, created.ID, created.DID, created.Handle, created.CreatedAt, created.ExpiresAt); err != nil {
		return nil, err
	}
	return &created, nil
}
// GetSession loads the unexpired session with the given ID.
//
// It returns nil when no matching row exists, when the row has expired
// according to the database clock, or when the query/scan fails for any other
// reason; callers cannot tell these cases apart. Columns that are NULL in the
// database come back as empty strings (or the zero time for TokenExpiry).
func (s *SessionStore) GetSession(id string) *OAuthSession {
	var (
		found                    OAuthSession
		access, refresh          *string
		expiry                   *time.Time
		jwk, authNonce, pdsNonce *string
		pds, issuer              *string
	)

	err := s.db.QueryRow(`
SELECT id, did, handle, created_at, expires_at,
access_token, refresh_token, token_expiry,
dpop_private_jwk, dpop_authserver_nonce, dpop_pds_nonce,
pds_url, authserver_iss
FROM oauth_sessions
WHERE id = $1 AND expires_at > NOW()
`, id).Scan(
		&found.ID, &found.DID, &found.Handle, &found.CreatedAt, &found.ExpiresAt,
		&access, &refresh, &expiry,
		&jwk, &authNonce, &pdsNonce,
		&pds, &issuer,
	)
	if err != nil {
		return nil
	}

	// Flatten the nullable columns into the plain struct fields.
	found.AccessToken = shared.StringValue(access)
	found.RefreshToken = shared.StringValue(refresh)
	if expiry != nil {
		found.TokenExpiry = *expiry
	}
	found.DpopPrivateJWK = shared.StringValue(jwk)
	found.DpopAuthserverNonce = shared.StringValue(authNonce)
	found.DpopPdsNonce = shared.StringValue(pdsNonce)
	found.PdsURL = shared.StringValue(pds)
	found.AuthserverIss = shared.StringValue(issuer)
	return &found
}
// UpdateSession persists the token, DPoP and server fields of session to the
// row whose id matches session.ID. The identity and timestamp columns (did,
// handle, created_at, expires_at) are not touched.
//
// Field values are passed through the shared.Nullable* helpers before being
// bound as query parameters.
//
// The method has no error return, so a failed write is logged instead of
// being silently dropped; otherwise a refresh that could not be stored would
// leave stale tokens in the database with no trace.
func (s *SessionStore) UpdateSession(session *OAuthSession) {
	_, err := s.db.Exec(`
UPDATE oauth_sessions SET
access_token = $2,
refresh_token = $3,
token_expiry = $4,
dpop_private_jwk = $5,
dpop_authserver_nonce = $6,
dpop_pds_nonce = $7,
pds_url = $8,
authserver_iss = $9
WHERE id = $1
`,
		session.ID,
		shared.NullableString(session.AccessToken),
		shared.NullableString(session.RefreshToken),
		shared.NullableTime(session.TokenExpiry),
		shared.NullableString(session.DpopPrivateJWK),
		shared.NullableString(session.DpopAuthserverNonce),
		shared.NullableString(session.DpopPdsNonce),
		shared.NullableString(session.PdsURL),
		shared.NullableString(session.AuthserverIss),
	)
	if err != nil {
		// Log the DID rather than the session ID: the ID is a bearer secret.
		fmt.Printf("UpdateSession: failed to update session for %s: %v\n", session.DID, err)
	}
}
// DeleteSession removes the session row with the given ID.
//
// The method has no error return, so a failed delete is logged instead of
// being silently dropped: a logout that could not be persisted leaves a
// usable session behind and should be visible to operators. The session ID
// itself is kept out of the log because it is a bearer secret.
func (s *SessionStore) DeleteSession(id string) {
	if _, err := s.db.Exec("DELETE FROM oauth_sessions WHERE id = $1", id); err != nil {
		fmt.Printf("DeleteSession: failed to delete session: %v\n", err)
	}
}
// SavePending records an in-flight authorization under its state key. The
// entry lives only in memory. The caller's PendingAuth is stored by pointer
// and its CreatedAt is overwritten with the current time, regardless of any
// value it held before.
func (s *SessionStore) SavePending(state string, pending *PendingAuth) {
	s.mu.Lock()
	defer s.mu.Unlock()

	stamped := pending
	stamped.CreatedAt = time.Now()
	s.pending[state] = stamped
}
// GetPending retrieves and removes the pending OAuth state stored under
// state. Each entry can be claimed at most once.
//
// It returns nil when the state is unknown or when the entry is older than
// ten minutes. The age check matters because the background sweep only runs
// every five minutes: without it, a state past the ten-minute limit could
// still be redeemed until the next sweep. Expired entries are still removed
// from the map.
func (s *SessionStore) GetPending(state string) *PendingAuth {
	s.mu.Lock()
	defer s.mu.Unlock()

	pending, ok := s.pending[state]
	if !ok {
		return nil
	}
	delete(s.pending, state)

	// Same ten-minute limit that cleanup applies to the pending map.
	if pending == nil || time.Since(pending.CreatedAt) > 10*time.Minute {
		return nil
	}
	return pending
}
// generateRandomID returns a new session identifier: 32 bytes read from
// crypto/rand, encoded with padded URL-safe base64 (44 characters). An error
// from the random source is returned with an empty string.
func generateRandomID() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(raw[:]), nil
}
// encryptSessionID seals sessionID with AES-GCM under key and returns the
// result as padded URL-safe base64.
//
// The encoded payload is nonce || ciphertext || tag, with a fresh random
// nonce drawn for every call, so encrypting the same ID twice yields
// different strings. key must be a valid AES key length (16, 24 or 32 bytes;
// 32 selects AES-256); any other length makes aes.NewCipher fail and that
// error is returned.
func encryptSessionID(sessionID string, key []byte) (string, error) {
	aesBlock, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(aesBlock)
	if err != nil {
		return "", err
	}

	iv := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return "", err
	}

	// Passing iv as dst makes Seal append to it, producing nonce-prefixed output.
	sealed := aead.Seal(iv, iv, []byte(sessionID), nil)
	return base64.URLEncoding.EncodeToString(sealed), nil
}
// decryptSessionID reverses encryptSessionID: it base64-decodes encrypted,
// splits off the leading GCM nonce and authenticates/decrypts the remainder
// with key.
//
// An error is returned when the input is not valid URL-safe base64, when key
// is not a valid AES key length, when the payload is shorter than one nonce
// ("ciphertext too short"), or when authentication fails (wrong key or
// tampered data).
func decryptSessionID(encrypted string, key []byte) (string, error) {
	raw, err := base64.URLEncoding.DecodeString(encrypted)
	if err != nil {
		return "", err
	}

	aesBlock, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(aesBlock)
	if err != nil {
		return "", err
	}

	nonceLen := aead.NonceSize()
	if len(raw) < nonceLen {
		return "", fmt.Errorf("ciphertext too short")
	}

	opened, err := aead.Open(nil, raw[:nonceLen], raw[nonceLen:], nil)
	if err != nil {
		return "", err
	}
	return string(opened), nil
}
// SetSessionCookie encrypts sessionID with the manager's cookie secret and
// writes it to the response as the session cookie.
//
// The cookie is HttpOnly, SameSite=Lax, scoped to "/" and lives for
// sessionTTL. The Secure attribute is set only when the request arrived over
// TLS or carries "X-Forwarded-Proto: https", so plain-HTTP setups still
// receive a usable cookie. If encryption fails the error is returned and no
// cookie is written.
func (m *OAuthManager) SetSessionCookie(w http.ResponseWriter, r *http.Request, sessionID string) error {
	sealed, err := encryptSessionID(sessionID, m.cookieSecret)
	if err != nil {
		return err
	}

	overHTTPS := r.TLS != nil
	if !overHTTPS {
		overHTTPS = r.Header.Get("X-Forwarded-Proto") == "https"
	}

	cookie := http.Cookie{
		Name:     sessionCookieName,
		Value:    sealed,
		Path:     "/",
		MaxAge:   int(sessionTTL.Seconds()),
		Secure:   overHTTPS,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &cookie)
	return nil
}
// GetSessionFromCookie resolves the request's session cookie to a stored
// session.
//
// It returns nil when the cookie is absent, when it cannot be decrypted with
// the manager's cookie secret, or when the session store has no session for
// the decrypted ID. Each outcome is reported on stdout for diagnostics.
//
// The decrypted session ID is never logged: it is the bearer secret the
// cookie encryption is there to protect, and writing it to stdout would
// expose live sessions to anyone who can read the logs.
func (m *OAuthManager) GetSessionFromCookie(r *http.Request) *OAuthSession {
	cookie, err := r.Cookie(sessionCookieName)
	if err != nil {
		fmt.Printf("GetSessionFromCookie: no cookie found: %v\n", err)
		return nil
	}
	fmt.Printf("GetSessionFromCookie: found cookie, length=%d\n", len(cookie.Value))

	sessionID, err := decryptSessionID(cookie.Value, m.cookieSecret)
	if err != nil {
		fmt.Printf("GetSessionFromCookie: decrypt failed: %v\n", err)
		return nil
	}

	session := m.sessions.GetSession(sessionID)
	if session == nil {
		fmt.Printf("GetSessionFromCookie: session not found in store\n")
		return nil
	}
	fmt.Printf("GetSessionFromCookie: found session for %s\n", session.Handle)
	return session
}
// ClearSessionCookie tells the browser to drop the session cookie by sending
// an empty value with a negative MaxAge under the same name and path.
//
// NOTE(review): Secure is unconditionally true here, whereas SetSessionCookie
// derives it from the request. Browsers may ignore a Secure cookie delivered
// over plain HTTP, so on non-TLS setups the clearing cookie might not take
// effect — confirm whether that matters for local/dev deployments.
func (m *OAuthManager) ClearSessionCookie(w http.ResponseWriter) {
	expired := http.Cookie{
		Name:     sessionCookieName,
		Path:     "/",
		MaxAge:   -1,
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &expired)
}
// SessionInfo is the public session info returned to the client.
//
// It carries only the identity and expiry fields; unlike OAuthSession it has
// no token or DPoP fields, so serializing it cannot expose credentials.
type SessionInfo struct {
// DID and Handle identify the signed-in user.
DID string `json:"did"`
Handle string `json:"handle"`
// ExpiresAt is when the session stops being valid.
ExpiresAt time.Time `json:"expires_at"`
}
// MarshalJSON encodes SessionInfo using the default struct encoding.
//
// The value is converted to a local type with the same fields but no
// methods, so json.Marshal does not call back into this method and recurse
// forever. A nil receiver encodes as JSON null.
func (s *SessionInfo) MarshalJSON() ([]byte, error) {
	type plainSessionInfo SessionInfo
	view := (*plainSessionInfo)(s)
	return json.Marshal(view)
}