v40: Persist OAuth sessions to database

This commit is contained in:
primal
2026-01-30 16:09:46 -05:00
parent 31b7b61bb0
commit e0602b0123
5 changed files with 96 additions and 35 deletions
+19
View File
@@ -160,6 +160,25 @@ CREATE TABLE IF NOT EXISTS clicks (
CREATE INDEX IF NOT EXISTS idx_clicks_short_code ON clicks(short_code);
CREATE INDEX IF NOT EXISTS idx_clicks_clicked_at ON clicks(clicked_at DESC);
-- OAuth sessions (persisted for login persistence across deploys)
CREATE TABLE IF NOT EXISTS oauth_sessions (
id TEXT PRIMARY KEY,
did TEXT NOT NULL,
handle TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
access_token TEXT,
refresh_token TEXT,
token_expiry TIMESTAMPTZ,
dpop_private_jwk TEXT,
dpop_authserver_nonce TEXT,
dpop_pds_nonce TEXT,
pds_url TEXT,
authserver_iss TEXT
);
CREATE INDEX IF NOT EXISTS idx_oauth_sessions_expires_at ON oauth_sessions(expires_at);
-- Trigger to normalize feed URLs on insert/update (strips https://, http://, www.)
CREATE OR REPLACE FUNCTION normalize_feed_url()
RETURNS TRIGGER AS $$
+2 -2
View File
@@ -37,7 +37,7 @@ type OAuthConfig struct {
}
// NewOAuthManager creates a new OAuth manager
func NewOAuthManager(cfg OAuthConfig) (*OAuthManager, error) {
func NewOAuthManager(cfg OAuthConfig, db *DB) (*OAuthManager, error) {
// Parse cookie secret (must be 32 bytes for AES-256)
cookieSecret, err := parseHexSecret(cfg.CookieSecret)
if err != nil {
@@ -81,7 +81,7 @@ func NewOAuthManager(cfg OAuthConfig) (*OAuthManager, error) {
redirectURI: cfg.RedirectURI,
privateJWK: privateJWK,
publicJWK: publicJWK,
sessions: NewSessionStore(),
sessions: NewSessionStore(db),
cookieSecret: cookieSecret,
allowedScope: "atproto",
}, nil
+73 -31
View File
@@ -53,19 +53,19 @@ type PendingAuth struct {
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages sessions in memory
// SessionStore manages sessions in the database
type SessionStore struct {
sessions map[string]*OAuthSession
pending map[string]*PendingAuth // keyed by state
db *DB
pending map[string]*PendingAuth // keyed by state (short-lived, kept in memory)
mu sync.RWMutex
cleanupOnce sync.Once
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
func NewSessionStore(db *DB) *SessionStore {
s := &SessionStore{
sessions: make(map[string]*OAuthSession),
pending: make(map[string]*PendingAuth),
db: db,
pending: make(map[string]*PendingAuth),
}
s.startCleanup()
return s
@@ -86,19 +86,13 @@ func (s *SessionStore) startCleanup() {
// cleanup removes expired sessions and pending auths
func (s *SessionStore) cleanup() {
// Clean up expired sessions from database
s.db.Exec("DELETE FROM oauth_sessions WHERE expires_at < NOW()")
// Clean up old pending auths (10 minute timeout) from memory
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
// Clean up expired sessions
for id, sess := range s.sessions {
if now.After(sess.ExpiresAt) {
delete(s.sessions, id)
}
}
// Clean up old pending auths (10 minute timeout)
for state, pending := range s.pending {
if now.Sub(pending.CreatedAt) > 10*time.Minute {
delete(s.pending, state)
@@ -122,40 +116,88 @@ func (s *SessionStore) CreateSession(did, handle string) (*OAuthSession, error)
ExpiresAt: now.Add(sessionTTL),
}
s.mu.Lock()
s.sessions[id] = session
s.mu.Unlock()
_, err = s.db.Exec(`
INSERT INTO oauth_sessions (id, did, handle, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
`, session.ID, session.DID, session.Handle, session.CreatedAt, session.ExpiresAt)
if err != nil {
return nil, err
}
return session, nil
}
// GetSession retrieves a session by ID
func (s *SessionStore) GetSession(id string) *OAuthSession {
s.mu.RLock()
defer s.mu.RUnlock()
row := s.db.QueryRow(`
SELECT id, did, handle, created_at, expires_at,
access_token, refresh_token, token_expiry,
dpop_private_jwk, dpop_authserver_nonce, dpop_pds_nonce,
pds_url, authserver_iss
FROM oauth_sessions
WHERE id = $1 AND expires_at > NOW()
`, id)
session, ok := s.sessions[id]
if !ok || time.Now().After(session.ExpiresAt) {
var session OAuthSession
var accessToken, refreshToken, dpopJwk, dpopAuthNonce, dpopPdsNonce, pdsURL, authIss *string
var tokenExpiry *time.Time
err := row.Scan(
&session.ID, &session.DID, &session.Handle, &session.CreatedAt, &session.ExpiresAt,
&accessToken, &refreshToken, &tokenExpiry,
&dpopJwk, &dpopAuthNonce, &dpopPdsNonce,
&pdsURL, &authIss,
)
if err != nil {
return nil
}
return session
session.AccessToken = StringValue(accessToken)
session.RefreshToken = StringValue(refreshToken)
if tokenExpiry != nil {
session.TokenExpiry = *tokenExpiry
}
session.DpopPrivateJWK = StringValue(dpopJwk)
session.DpopAuthserverNonce = StringValue(dpopAuthNonce)
session.DpopPdsNonce = StringValue(dpopPdsNonce)
session.PdsURL = StringValue(pdsURL)
session.AuthserverIss = StringValue(authIss)
return &session
}
// UpdateSession updates a session
func (s *SessionStore) UpdateSession(session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[session.ID] = session
s.db.Exec(`
UPDATE oauth_sessions SET
access_token = $2,
refresh_token = $3,
token_expiry = $4,
dpop_private_jwk = $5,
dpop_authserver_nonce = $6,
dpop_pds_nonce = $7,
pds_url = $8,
authserver_iss = $9
WHERE id = $1
`,
session.ID,
NullableString(session.AccessToken),
NullableString(session.RefreshToken),
NullableTime(session.TokenExpiry),
NullableString(session.DpopPrivateJWK),
NullableString(session.DpopAuthserverNonce),
NullableString(session.DpopPdsNonce),
NullableString(session.PdsURL),
NullableString(session.AuthserverIss),
)
}
// DeleteSession removes a session
func (s *SessionStore) DeleteSession(id string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, id)
s.db.Exec("DELETE FROM oauth_sessions WHERE id = $1", id)
}
// SavePending saves pending OAuth state
// SavePending saves pending OAuth state (kept in memory - short lived)
func (s *SessionStore) SavePending(state string, pending *PendingAuth) {
s.mu.Lock()
defer s.mu.Unlock()
+1 -1
View File
@@ -25,7 +25,7 @@ func (c *Crawler) StartDashboard(addr string) error {
if err != nil {
fmt.Printf("OAuth not configured: %v (dashboard will be unprotected)\n", err)
} else {
oauth, err = NewOAuthManager(*oauthCfg)
oauth, err = NewOAuthManager(*oauthCfg, c.db)
if err != nil {
fmt.Printf("Failed to initialize OAuth: %v (dashboard will be unprotected)\n", err)
oauth = nil
+1 -1
View File
@@ -534,7 +534,7 @@ const dashboardHTML = `<!DOCTYPE html>
<div id="output"></div>
</div>
<div style="color: #333; font-size: 11px; margin-top: 10px;">v39</div>
<div style="color: #333; font-size: 11px; margin-top: 10px;">v40</div>
<div class="updated" id="updatedAt">Last updated: {{.UpdatedAt.Format "2006-01-02 15:04:05"}}</div>
</body>