diff --git a/db.go b/db.go index f16483e..d6725fd 100644 --- a/db.go +++ b/db.go @@ -160,6 +160,25 @@ CREATE TABLE IF NOT EXISTS clicks ( CREATE INDEX IF NOT EXISTS idx_clicks_short_code ON clicks(short_code); CREATE INDEX IF NOT EXISTS idx_clicks_clicked_at ON clicks(clicked_at DESC); +-- OAuth sessions (persisted for login persistence across deploys) +CREATE TABLE IF NOT EXISTS oauth_sessions ( + id TEXT PRIMARY KEY, + did TEXT NOT NULL, + handle TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + access_token TEXT, + refresh_token TEXT, + token_expiry TIMESTAMPTZ, + dpop_private_jwk TEXT, + dpop_authserver_nonce TEXT, + dpop_pds_nonce TEXT, + pds_url TEXT, + authserver_iss TEXT +); + +CREATE INDEX IF NOT EXISTS idx_oauth_sessions_expires_at ON oauth_sessions(expires_at); + -- Trigger to normalize feed URLs on insert/update (strips https://, http://, www.) CREATE OR REPLACE FUNCTION normalize_feed_url() RETURNS TRIGGER AS $$ diff --git a/oauth.go b/oauth.go index 39610e7..aeff41b 100644 --- a/oauth.go +++ b/oauth.go @@ -37,7 +37,7 @@ type OAuthConfig struct { } // NewOAuthManager creates a new OAuth manager -func NewOAuthManager(cfg OAuthConfig) (*OAuthManager, error) { +func NewOAuthManager(cfg OAuthConfig, db *DB) (*OAuthManager, error) { // Parse cookie secret (must be 32 bytes for AES-256) cookieSecret, err := parseHexSecret(cfg.CookieSecret) if err != nil { @@ -81,7 +81,7 @@ func NewOAuthManager(cfg OAuthConfig) (*OAuthManager, error) { redirectURI: cfg.RedirectURI, privateJWK: privateJWK, publicJWK: publicJWK, - sessions: NewSessionStore(), + sessions: NewSessionStore(db), cookieSecret: cookieSecret, allowedScope: "atproto", }, nil diff --git a/oauth_session.go b/oauth_session.go index 28a51c9..9181d9f 100644 --- a/oauth_session.go +++ b/oauth_session.go @@ -53,19 +53,19 @@ type PendingAuth struct { CreatedAt time.Time `json:"created_at"` } -// SessionStore manages sessions in memory +// SessionStore manages sessions in the database type SessionStore struct { - sessions map[string]*OAuthSession - pending map[string]*PendingAuth // keyed by state + db *DB + pending map[string]*PendingAuth // keyed by state (short-lived, kept in memory) mu sync.RWMutex cleanupOnce sync.Once } // NewSessionStore creates a new session store -func NewSessionStore() *SessionStore { +func NewSessionStore(db *DB) *SessionStore { s := &SessionStore{ - sessions: make(map[string]*OAuthSession), - pending: make(map[string]*PendingAuth), + db: db, + pending: make(map[string]*PendingAuth), } s.startCleanup() return s @@ -86,19 +86,13 @@ func (s *SessionStore) startCleanup() { // cleanup removes expired sessions and pending auths func (s *SessionStore) cleanup() { + // Clean up expired sessions from database + s.db.Exec("DELETE FROM oauth_sessions WHERE expires_at < NOW()") + + // Clean up old pending auths (10 minute timeout) from memory s.mu.Lock() defer s.mu.Unlock() - now := time.Now() - - // Clean up expired sessions - for id, sess := range s.sessions { - if now.After(sess.ExpiresAt) { - delete(s.sessions, id) - } - } - - // Clean up old pending auths (10 minute timeout) for state, pending := range s.pending { if now.Sub(pending.CreatedAt) > 10*time.Minute { delete(s.pending, state) @@ -122,40 +116,88 @@ func (s *SessionStore) CreateSession(did, handle string) (*OAuthSession, error) ExpiresAt: now.Add(sessionTTL), } - s.mu.Lock() - s.sessions[id] = session - s.mu.Unlock() + _, err = s.db.Exec(` + INSERT INTO oauth_sessions (id, did, handle, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + `, session.ID, session.DID, session.Handle, session.CreatedAt, session.ExpiresAt) + if err != nil { + return nil, err + } return session, nil } // GetSession retrieves a session by ID func (s *SessionStore) GetSession(id string) *OAuthSession { - s.mu.RLock() - defer s.mu.RUnlock() + row := s.db.QueryRow(` + SELECT id, did, handle, created_at, expires_at, + access_token, refresh_token, token_expiry, + dpop_private_jwk, dpop_authserver_nonce, dpop_pds_nonce, + pds_url, authserver_iss + FROM oauth_sessions + WHERE id = $1 AND expires_at > NOW() + `, id) - session, ok := s.sessions[id] - if !ok || time.Now().After(session.ExpiresAt) { + var session OAuthSession + var accessToken, refreshToken, dpopJwk, dpopAuthNonce, dpopPdsNonce, pdsURL, authIss *string + var tokenExpiry *time.Time + + err := row.Scan( + &session.ID, &session.DID, &session.Handle, &session.CreatedAt, &session.ExpiresAt, + &accessToken, &refreshToken, &tokenExpiry, + &dpopJwk, &dpopAuthNonce, &dpopPdsNonce, + &pdsURL, &authIss, + ) + if err != nil { return nil } - return session + + session.AccessToken = StringValue(accessToken) + session.RefreshToken = StringValue(refreshToken) + if tokenExpiry != nil { + session.TokenExpiry = *tokenExpiry + } + session.DpopPrivateJWK = StringValue(dpopJwk) + session.DpopAuthserverNonce = StringValue(dpopAuthNonce) + session.DpopPdsNonce = StringValue(dpopPdsNonce) + session.PdsURL = StringValue(pdsURL) + session.AuthserverIss = StringValue(authIss) + + return &session } // UpdateSession updates a session func (s *SessionStore) UpdateSession(session *OAuthSession) { - s.mu.Lock() - defer s.mu.Unlock() - s.sessions[session.ID] = session + s.db.Exec(` + UPDATE oauth_sessions SET + access_token = $2, + refresh_token = $3, + token_expiry = $4, + dpop_private_jwk = $5, + dpop_authserver_nonce = $6, + dpop_pds_nonce = $7, + pds_url = $8, + authserver_iss = $9 + WHERE id = $1 + `, + session.ID, + NullableString(session.AccessToken), + NullableString(session.RefreshToken), + NullableTime(session.TokenExpiry), + NullableString(session.DpopPrivateJWK), + NullableString(session.DpopAuthserverNonce), + NullableString(session.DpopPdsNonce), + NullableString(session.PdsURL), + NullableString(session.AuthserverIss), + ) } // DeleteSession removes a session func (s *SessionStore) DeleteSession(id string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.sessions, id) + s.db.Exec("DELETE FROM oauth_sessions WHERE id = $1", id) } -// SavePending saves pending OAuth state +// SavePending saves pending OAuth state (kept in memory - short lived) func (s *SessionStore) SavePending(state string, pending *PendingAuth) { s.mu.Lock() defer s.mu.Unlock() diff --git a/routes.go b/routes.go index 9ce1f4c..67a556f 100644 --- a/routes.go +++ b/routes.go @@ -25,7 +25,7 @@ func (c *Crawler) StartDashboard(addr string) error { if err != nil { fmt.Printf("OAuth not configured: %v (dashboard will be unprotected)\n", err) } else { - oauth, err = NewOAuthManager(*oauthCfg) + oauth, err = NewOAuthManager(*oauthCfg, c.db) if err != nil { fmt.Printf("Failed to initialize OAuth: %v (dashboard will be unprotected)\n", err) oauth = nil diff --git a/templates.go b/templates.go index 1663261..29e4a54 100644 --- a/templates.go +++ b/templates.go @@ -534,7 +534,7 @@ const dashboardHTML = `
-