package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"sync"
	"time"
)

const (
	sessionCookieName = "1440_session"
	sessionTTL        = 24 * time.Hour

	// pendingAuthTTL is how long a PendingAuth stays valid after SavePending.
	// It is enforced both by the periodic cleanup and at retrieval time.
	pendingAuthTTL = 10 * time.Minute
)

// OAuthSession stores the OAuth session state for a user.
type OAuthSession struct {
	ID        string    `json:"id"`
	DID       string    `json:"did"`
	Handle    string    `json:"handle"`
	CreatedAt time.Time `json:"created_at"`
	ExpiresAt time.Time `json:"expires_at"`

	// OAuth tokens (stored server-side only)
	AccessToken  string    `json:"access_token"`
	RefreshToken string    `json:"refresh_token"`
	TokenExpiry  time.Time `json:"token_expiry"`

	// DPoP state
	DpopPrivateJWK      string `json:"dpop_private_jwk"`
	DpopAuthserverNonce string `json:"dpop_authserver_nonce"`
	DpopPdsNonce        string `json:"dpop_pds_nonce"`

	// Auth server info
	PdsURL        string `json:"pds_url"`
	AuthserverIss string `json:"authserver_iss"`
}

// PendingAuth stores state during the OAuth flow (before callback).
type PendingAuth struct {
	State          string    `json:"state"`
	PkceVerifier   string    `json:"pkce_verifier"`
	DpopPrivateJWK string    `json:"dpop_private_jwk"`
	DpopNonce      string    `json:"dpop_nonce"`
	DID            string    `json:"did"`
	PdsURL         string    `json:"pds_url"`
	AuthserverIss  string    `json:"authserver_iss"`
	CreatedAt      time.Time `json:"created_at"`
}

// SessionStore manages sessions in the database and pending OAuth flows in
// memory.
type SessionStore struct {
	db          *DB
	pending     map[string]*PendingAuth // keyed by state (short-lived, kept in memory)
	mu          sync.RWMutex            // guards pending
	cleanupOnce sync.Once               // ensures the cleanup goroutine is started at most once
}

// NewSessionStore creates a new session store backed by db and starts its
// background cleanup goroutine.
func NewSessionStore(db *DB) *SessionStore {
	s := &SessionStore{
		db:      db,
		pending: make(map[string]*PendingAuth),
	}
	s.startCleanup()
	return s
}

// startCleanup starts a background goroutine that calls cleanup every five
// minutes. It is guarded by cleanupOnce, so repeated calls start only one
// goroutine.
//
// NOTE(review): the goroutine has no stop signal; it runs until the process
// exits. Adding one would require new API on SessionStore.
func (s *SessionStore) startCleanup() {
	s.cleanupOnce.Do(func() {
		go func() {
			ticker := time.NewTicker(5 * time.Minute)
			defer ticker.Stop()
			for range ticker.C {
				s.cleanup()
			}
		}()
	})
}

// cleanup removes expired sessions from the database and pending auths older
// than pendingAuthTTL from memory. A database failure is logged and does not
// prevent the in-memory sweep.
func (s *SessionStore) cleanup() {
	// Clean up expired sessions from database.
	if _, err := s.db.Exec("DELETE FROM oauth_sessions WHERE expires_at < NOW()"); err != nil {
		log.Printf("session cleanup: deleting expired sessions: %v", err)
	}

	// Clean up old pending auths from memory.
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	for state, pending := range s.pending {
		if now.Sub(pending.CreatedAt) > pendingAuthTTL {
			delete(s.pending, state)
		}
	}
}

// CreateSession creates a new session for the given DID and handle, inserts it
// into oauth_sessions, and returns it. The session expires sessionTTL after
// creation. On error the returned session is nil.
func (s *SessionStore) CreateSession(did, handle string) (*OAuthSession, error) {
	id, err := generateRandomID()
	if err != nil {
		return nil, fmt.Errorf("generating session id: %w", err)
	}

	now := time.Now()
	session := &OAuthSession{
		ID:        id,
		DID:       did,
		Handle:    handle,
		CreatedAt: now,
		ExpiresAt: now.Add(sessionTTL),
	}

	_, err = s.db.Exec(`
		INSERT INTO oauth_sessions (id, did, handle, created_at, expires_at)
		VALUES ($1, $2, $3, $4, $5)
	`, session.ID, session.DID, session.Handle, session.CreatedAt, session.ExpiresAt)
	if err != nil {
		return nil, fmt.Errorf("inserting session: %w", err)
	}
	return session, nil
}

// GetSession retrieves a non-expired session by ID. It returns nil when the
// row cannot be scanned; a missing row and a database failure are not
// distinguished here.
func (s *SessionStore) GetSession(id string) *OAuthSession {
	row := s.db.QueryRow(`
		SELECT id, did, handle, created_at, expires_at,
		       access_token, refresh_token, token_expiry,
		       dpop_private_jwk, dpop_authserver_nonce, dpop_pds_nonce,
		       pds_url, authserver_iss
		FROM oauth_sessions
		WHERE id = $1 AND expires_at > NOW()
	`, id)

	// The token/DPoP/server columns are scanned through pointers so that SQL
	// NULL can be told apart from a value.
	var session OAuthSession
	var accessToken, refreshToken, dpopJwk, dpopAuthNonce, dpopPdsNonce, pdsURL, authIss *string
	var tokenExpiry *time.Time

	err := row.Scan(
		&session.ID, &session.DID, &session.Handle, &session.CreatedAt, &session.ExpiresAt,
		&accessToken, &refreshToken, &tokenExpiry,
		&dpopJwk, &dpopAuthNonce, &dpopPdsNonce,
		&pdsURL, &authIss,
	)
	if err != nil {
		return nil
	}

	session.AccessToken = StringValue(accessToken)
	session.RefreshToken = StringValue(refreshToken)
	if tokenExpiry != nil {
		session.TokenExpiry = *tokenExpiry
	}
	session.DpopPrivateJWK = StringValue(dpopJwk)
	session.DpopAuthserverNonce = StringValue(dpopAuthNonce)
	session.DpopPdsNonce = StringValue(dpopPdsNonce)
	session.PdsURL = StringValue(pdsURL)
	session.AuthserverIss = StringValue(authIss)

	return &session
}

// UpdateSession writes the token, DPoP, and server fields of session back to
// the row with the same ID. The method has no error return, so a database
// failure is logged instead of being silently dropped.
func (s *SessionStore) UpdateSession(session *OAuthSession) {
	_, err := s.db.Exec(`
		UPDATE oauth_sessions SET
			access_token = $2,
			refresh_token = $3,
			token_expiry = $4,
			dpop_private_jwk = $5,
			dpop_authserver_nonce = $6,
			dpop_pds_nonce = $7,
			pds_url = $8,
			authserver_iss = $9
		WHERE id = $1
	`,
		session.ID,
		NullableString(session.AccessToken),
		NullableString(session.RefreshToken),
		NullableTime(session.TokenExpiry),
		NullableString(session.DpopPrivateJWK),
		NullableString(session.DpopAuthserverNonce),
		NullableString(session.DpopPdsNonce),
		NullableString(session.PdsURL),
		NullableString(session.AuthserverIss),
	)
	if err != nil {
		// The session ID is deliberately not logged: it is the lookup key
		// protected by the encrypted cookie.
		log.Printf("updating session: %v", err)
	}
}

// DeleteSession removes the session with the given ID. A database failure is
// logged because the method has no error return.
func (s *SessionStore) DeleteSession(id string) {
	if _, err := s.db.Exec("DELETE FROM oauth_sessions WHERE id = $1", id); err != nil {
		log.Printf("deleting session: %v", err)
	}
}

// SavePending saves pending OAuth state (kept in memory - short lived).
// It overwrites pending.CreatedAt with the current time.
func (s *SessionStore) SavePending(state string, pending *PendingAuth) {
	s.mu.Lock()
	defer s.mu.Unlock()

	pending.CreatedAt = time.Now()
	s.pending[state] = pending
}

// GetPending retrieves and removes pending OAuth state. The entry is
// single-use: it is deleted whether or not it is returned. It returns nil when
// the state is unknown or when the entry is older than pendingAuthTTL, so an
// expired state cannot be redeemed in the window before the next cleanup tick.
func (s *SessionStore) GetPending(state string) *PendingAuth {
	s.mu.Lock()
	defer s.mu.Unlock()

	pending, ok := s.pending[state]
	if !ok {
		return nil
	}
	delete(s.pending, state)

	if time.Since(pending.CreatedAt) > pendingAuthTTL {
		return nil
	}
	return pending
}

// generateRandomID generates a random session ID: 32 bytes from crypto/rand,
// encoded as padded URL-safe base64.
func generateRandomID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(b), nil
}

// encryptSessionID encrypts a session ID using AES-GCM (AES-256 when key is
// 32 bytes). The result is URL-safe base64 of the random nonce followed by the
// sealed ciphertext.
func encryptSessionID(sessionID string, key []byte) (string, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Passing nonce as dst makes Seal append the ciphertext after the nonce,
	// so decryptSessionID can split the two apart again.
	ciphertext := gcm.Seal(nonce, nonce, []byte(sessionID), nil)
	return base64.URLEncoding.EncodeToString(ciphertext), nil
}

// decryptSessionID reverses encryptSessionID. It returns an error when the
// input is not valid base64, is shorter than the nonce, or fails GCM
// authentication (wrong key or tampered data).
func decryptSessionID(encrypted string, key []byte) (string, error) {
	ciphertext, err := base64.URLEncoding.DecodeString(encrypted)
	if err != nil {
		return "", err
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	if len(ciphertext) < gcm.NonceSize() {
		return "", fmt.Errorf("ciphertext too short")
	}

	nonce, ciphertext := ciphertext[:gcm.NonceSize()], ciphertext[gcm.NonceSize():]
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}

// SetSessionCookie sets an encrypted session cookie carrying sessionID. The
// cookie lives for sessionTTL and is HttpOnly with SameSite=Lax.
func (m *OAuthManager) SetSessionCookie(w http.ResponseWriter, r *http.Request, sessionID string) error {
	encrypted, err := encryptSessionID(sessionID, m.cookieSecret)
	if err != nil {
		return err
	}

	// Only set Secure flag for HTTPS connections (direct TLS, or a forwarding
	// header that reports https).
	secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"

	http.SetCookie(w, &http.Cookie{
		Name:     sessionCookieName,
		Value:    encrypted,
		Path:     "/",
		HttpOnly: true,
		Secure:   secure,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(sessionTTL.Seconds()),
	})
	return nil
}

// GetSessionFromCookie retrieves the session from the request cookie. It
// returns nil when the cookie is absent, cannot be decrypted, or does not map
// to a live session.
func (m *OAuthManager) GetSessionFromCookie(r *http.Request) *OAuthSession {
	cookie, err := r.Cookie(sessionCookieName)
	if err != nil {
		return nil
	}

	sessionID, err := decryptSessionID(cookie.Value, m.cookieSecret)
	if err != nil {
		return nil
	}

	return m.sessions.GetSession(sessionID)
}

// ClearSessionCookie removes the session cookie by sending an empty value with
// a negative MaxAge.
//
// NOTE(review): Secure is always true here, while SetSessionCookie sets it only
// for HTTPS. Browsers may ignore a Secure cookie sent over plain HTTP, in which
// case the cookie would not be cleared there - confirm against the deployment.
// Matching SetSessionCookie would need the request, which this signature does
// not receive.
func (m *OAuthManager) ClearSessionCookie(w http.ResponseWriter) {
	http.SetCookie(w, &http.Cookie{
		Name:     sessionCookieName,
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   -1,
	})
}

// SessionInfo is the public session info returned to the client.
type SessionInfo struct {
	DID       string    `json:"did"`
	Handle    string    `json:"handle"`
	ExpiresAt time.Time `json:"expires_at"`
}

// MarshalJSON converts SessionInfo to JSON. The local Alias type has the same
// fields but no methods, which keeps json.Marshal from calling MarshalJSON
// again.
func (s *SessionInfo) MarshalJSON() ([]byte, error) {
	type Alias SessionInfo
	return json.Marshal((*Alias)(s))
}