288 lines
7.6 KiB
Go
288 lines
7.6 KiB
Go
package main
|
|
|
|
import (
	"context"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	neturl "net/url"
	"os"
	"strings"
	"time"

	oauth "github.com/haileyok/atproto-oauth-golang"
	"github.com/haileyok/atproto-oauth-golang/helpers"
	"github.com/lestrrat-go/jwx/v2/jwk"
)
|
|
|
|
// OAuthManager handles OAuth 2.0 authentication for the dashboard
|
|
type OAuthManager struct {
|
|
client *oauth.Client
|
|
clientID string
|
|
redirectURI string
|
|
privateJWK jwk.Key
|
|
publicJWK jwk.Key
|
|
sessions *SessionStore
|
|
cookieSecret []byte
|
|
allowedScope string
|
|
}
|
|
|
|
// OAuthConfig holds the settings needed to construct an OAuthManager.
//
// LoadOAuthConfig derives ClientID and RedirectURI from a base URL and reads
// CookieSecret and PrivateJWK from the environment or an oauth.env file.
type OAuthConfig struct {
	// ClientID is the URL of the client metadata document
	// (e.g. https://app.1440.news/.well-known/oauth-client-metadata).
	ClientID string

	// RedirectURI is the OAuth callback URL
	// (e.g. https://app.1440.news/auth/callback).
	RedirectURI string

	// CookieSecret is the hex-encoded AES-256-GCM cookie key. It must be
	// exactly 64 hex characters, which decode to the required 32 bytes; a
	// 32-character value decodes to only 16 bytes and is rejected.
	CookieSecret string

	// PrivateJWK is the ES256 private key serialized as a JSON Web Key.
	PrivateJWK string
}
|
|
|
|
// NewOAuthManager creates a new OAuth manager
|
|
func NewOAuthManager(cfg OAuthConfig, db *DB) (*OAuthManager, error) {
|
|
// Parse cookie secret (must be 32 bytes for AES-256)
|
|
cookieSecret, err := parseHexSecret(cfg.CookieSecret)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid cookie secret: %v", err)
|
|
}
|
|
if len(cookieSecret) != 32 {
|
|
return nil, fmt.Errorf("cookie secret must be 32 bytes, got %d", len(cookieSecret))
|
|
}
|
|
|
|
// Parse private JWK
|
|
privateJWK, err := helpers.ParseJWKFromBytes([]byte(cfg.PrivateJWK))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid private JWK: %v", err)
|
|
}
|
|
|
|
// Extract public key
|
|
publicJWK, err := privateJWK.PublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to extract public key: %v", err)
|
|
}
|
|
|
|
// Create HTTP client with longer timeout
|
|
httpClient := &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
|
|
// Create OAuth client
|
|
client, err := oauth.NewClient(oauth.ClientArgs{
|
|
Http: httpClient,
|
|
ClientJwk: privateJWK,
|
|
ClientId: cfg.ClientID,
|
|
RedirectUri: cfg.RedirectURI,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create OAuth client: %v", err)
|
|
}
|
|
|
|
return &OAuthManager{
|
|
client: client,
|
|
clientID: cfg.ClientID,
|
|
redirectURI: cfg.RedirectURI,
|
|
privateJWK: privateJWK,
|
|
publicJWK: publicJWK,
|
|
sessions: NewSessionStore(db),
|
|
cookieSecret: cookieSecret,
|
|
allowedScope: "atproto",
|
|
}, nil
|
|
}
|
|
|
|
// LoadOAuthConfig loads OAuth configuration from environment or oauth.env file
|
|
func LoadOAuthConfig(baseURL string) (*OAuthConfig, error) {
|
|
cfg := &OAuthConfig{
|
|
ClientID: baseURL + "/.well-known/oauth-client-metadata",
|
|
RedirectURI: baseURL + "/auth/callback",
|
|
}
|
|
|
|
// Try environment variables first
|
|
cfg.CookieSecret = os.Getenv("OAUTH_COOKIE_SECRET")
|
|
cfg.PrivateJWK = os.Getenv("OAUTH_PRIVATE_JWK")
|
|
|
|
// Fall back to oauth.env file
|
|
if cfg.CookieSecret == "" || cfg.PrivateJWK == "" {
|
|
if data, err := os.ReadFile("oauth.env"); err == nil {
|
|
for _, line := range strings.Split(string(data), "\n") {
|
|
line = strings.TrimSpace(line)
|
|
if strings.HasPrefix(line, "#") || line == "" {
|
|
continue
|
|
}
|
|
parts := strings.SplitN(line, "=", 2)
|
|
if len(parts) == 2 {
|
|
key := strings.TrimSpace(parts[0])
|
|
value := strings.TrimSpace(parts[1])
|
|
switch key {
|
|
case "OAUTH_COOKIE_SECRET":
|
|
cfg.CookieSecret = value
|
|
case "OAUTH_PRIVATE_JWK":
|
|
cfg.PrivateJWK = value
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Validate required fields
|
|
if cfg.CookieSecret == "" {
|
|
return nil, fmt.Errorf("OAUTH_COOKIE_SECRET not configured")
|
|
}
|
|
if cfg.PrivateJWK == "" {
|
|
return nil, fmt.Errorf("OAUTH_PRIVATE_JWK not configured")
|
|
}
|
|
|
|
return cfg, nil
|
|
}
|
|
|
|
// parseHexSecret decodes a hex-encoded secret into raw bytes. Both upper- and
// lower-case hex digits are accepted, and an empty input yields an empty
// slice.
//
// It returns an error if the input has odd length or contains a character
// that is not a hex digit. Both error messages are kept identical to the
// previous hand-rolled decoder.
func parseHexSecret(s string) ([]byte, error) {
	// Check the length first so odd-length input reports this message rather
	// than encoding/hex's ErrLength text.
	if len(s)%2 != 0 {
		return nil, fmt.Errorf("hex string must have even length")
	}

	b, err := hex.DecodeString(s)
	if err != nil {
		if bad, ok := err.(hex.InvalidByteError); ok {
			return nil, fmt.Errorf("invalid hex character: %c", byte(bad))
		}
		return nil, err
	}
	return b, nil
}
|
|
|
|
// resolveHandle resolves an atproto/Bluesky handle to its DID by calling the
// com.atproto.identity.resolveHandle XRPC endpoint on bsky.social.
//
// The handle is normalized before lookup: surrounding whitespace and a
// leading "@" are removed, and the result is lower-cased.
//
// If ctx has no deadline, a 30-second timeout is applied because
// http.DefaultClient has none. A non-200 response, or a 200 response whose
// body carries no DID, is reported as an error.
func resolveHandle(ctx context.Context, handle string) (string, error) {
	handle = strings.ToLower(strings.TrimPrefix(strings.TrimSpace(handle), "@"))

	if _, ok := ctx.Deadline(); !ok {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
		defer cancel()
	}

	// Resolution is delegated entirely to bsky.social; no local DNS TXT or
	// /.well-known/atproto-did lookup is attempted.
	endpoint := "https://bsky.social/xrpc/com.atproto.identity.resolveHandle?handle=" + neturl.QueryEscape(handle)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", err
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Cap how much of the error body ends up in the message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return "", fmt.Errorf("resolve handle failed (HTTP %d): %s", resp.StatusCode, string(body))
	}

	var result struct {
		DID string `json:"did"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", err
	}
	if result.DID == "" {
		return "", fmt.Errorf("resolve handle failed: no DID returned for %q", handle)
	}

	return result.DID, nil
}
|
|
|
|
// resolveDIDToHandle resolves a DID to the handle its DID document claims.
//
// The document is fetched from one of two places:
//   - https://plc.directory/<did> for did:plc;
//   - https://<domain>/.well-known/did.json for did:web.
//
// Other DID methods are rejected. The result is the first non-empty "at://"
// entry of alsoKnownAs, with the scheme removed. It is returned exactly as
// written in the document and is not cross-checked by resolving the handle
// back to the DID.
//
// If ctx has no deadline, a 30-second timeout is applied because
// http.DefaultClient has none.
func resolveDIDToHandle(ctx context.Context, did string) (string, error) {
	var docURL string
	switch {
	case strings.HasPrefix(did, "did:plc:"):
		docURL = fmt.Sprintf("https://plc.directory/%s", did)
	case strings.HasPrefix(did, "did:web:"):
		domain := strings.TrimPrefix(did, "did:web:")
		docURL = fmt.Sprintf("https://%s/.well-known/did.json", domain)
	default:
		return "", fmt.Errorf("unsupported DID method: %s", did)
	}

	if _, ok := ctx.Deadline(); !ok {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
		defer cancel()
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, docURL, nil)
	if err != nil {
		return "", err
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to fetch DID document: %d", resp.StatusCode)
	}

	var doc struct {
		AlsoKnownAs []string `json:"alsoKnownAs"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
		return "", err
	}

	for _, aka := range doc.AlsoKnownAs {
		if !strings.HasPrefix(aka, "at://") {
			continue
		}
		// Skip a bare "at://" so an empty handle is never reported as success.
		if handle := strings.TrimPrefix(aka, "at://"); handle != "" {
			return handle, nil
		}
	}

	return "", fmt.Errorf("no handle found for DID %s", did)
}
|
|
|
|
// resolveDIDToService returns the PDS (Personal Data Server) endpoint URL
// advertised in a DID's document.
//
// The document is fetched from one of two places:
//   - https://plc.directory/<did> for did:plc;
//   - https://<domain>/.well-known/did.json for did:web.
//
// Other DID methods are rejected.
//
// The service entry is chosen in two passes:
//  1. Prefer an entry whose type is "AtprotoPersonalDataServer" and whose id
//     ends in "#atproto_pds" (the id may be relative or include the full
//     DID).
//  2. Otherwise, fall back to the first entry matching either the type or an
//     id of exactly "#atproto_pds".
//
// serviceEndpoint values are decoded only for candidate entries. This way,
// unrelated entries with non-string endpoints (maps or sets, which DID Core
// permits) do not make the whole document fail to parse. A candidate whose
// endpoint is not a non-empty string is skipped.
//
// If ctx has no deadline, a 30-second timeout is applied because
// http.DefaultClient has none.
func resolveDIDToService(ctx context.Context, did string) (string, error) {
	var docURL string
	switch {
	case strings.HasPrefix(did, "did:plc:"):
		docURL = fmt.Sprintf("https://plc.directory/%s", did)
	case strings.HasPrefix(did, "did:web:"):
		domain := strings.TrimPrefix(did, "did:web:")
		docURL = fmt.Sprintf("https://%s/.well-known/did.json", domain)
	default:
		return "", fmt.Errorf("unsupported DID method: %s", did)
	}

	if _, ok := ctx.Deadline(); !ok {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
		defer cancel()
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, docURL, nil)
	if err != nil {
		return "", err
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to fetch DID document: %d", resp.StatusCode)
	}

	var doc struct {
		Service []struct {
			ID              string          `json:"id"`
			Type            string          `json:"type"`
			ServiceEndpoint json.RawMessage `json:"serviceEndpoint"`
		} `json:"service"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
		return "", err
	}

	const (
		pdsType     = "AtprotoPersonalDataServer"
		pdsFragment = "#atproto_pds"
	)

	// endpointOf reports the endpoint if it is a non-empty JSON string.
	endpointOf := func(raw json.RawMessage) (string, bool) {
		var s string
		if err := json.Unmarshal(raw, &s); err != nil || s == "" {
			return "", false
		}
		return s, true
	}

	// Pass 1: spec-conformant entry, with both type and id fragment matching.
	for _, svc := range doc.Service {
		if svc.Type == pdsType && strings.HasSuffix(svc.ID, pdsFragment) {
			if ep, ok := endpointOf(svc.ServiceEndpoint); ok {
				return ep, nil
			}
		}
	}

	// Pass 2: lenient match on either field, kept for documents that only
	// satisfy one of the two conditions.
	for _, svc := range doc.Service {
		if svc.Type == pdsType || svc.ID == pdsFragment {
			if ep, ok := endpointOf(svc.ServiceEndpoint); ok {
				return ep, nil
			}
		}
	}

	return "", fmt.Errorf("no PDS service found for DID %s", did)
}
|