package main

import (
	"bufio"
	"compress/gzip"
	"database/sql"
	"fmt"
	"io"
	"os"
	"strings"
	"sync/atomic"
	"time"
)

// Domain represents a host to be crawled for feeds. Status is "unchecked" when
// a host is first imported and becomes "checked" or "error" after a crawl
// attempt (see markDomainCrawled).
type Domain struct {
	Host          string    `json:"host"`
	Status        string    `json:"status"`
	DiscoveredAt  time.Time `json:"discovered_at"`
	LastCrawledAt time.Time `json:"last_crawled_at,omitempty"`
	FeedsFound    int       `json:"feeds_found,omitempty"`
	LastError     string    `json:"last_error,omitempty"`
	TLD           string    `json:"tld,omitempty"`
}

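// All queries in this file assume a "domains" table keyed by host. The actual
// DDL is created elsewhere (wherever c.db is opened), so the following is only
// an illustrative sketch inferred from the columns referenced here, with
// guessed types and defaults:
//
//	CREATE TABLE IF NOT EXISTS domains (
//	    host          TEXT PRIMARY KEY,
//	    status        TEXT NOT NULL,
//	    discoveredAt  TIMESTAMP NOT NULL,
//	    lastCrawledAt TIMESTAMP,
//	    feedsFound    INTEGER NOT NULL DEFAULT 0,
//	    lastError     TEXT,
//	    tld           TEXT
//	);
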
// saveDomain stores a domain in SQLite, updating the existing row on conflict.
// discoveredAt is not overwritten for hosts that already exist.
func (c *Crawler) saveDomain(domain *Domain) error {
	_, err := c.db.Exec(`
		INSERT INTO domains (host, status, discoveredAt, lastCrawledAt, feedsFound, lastError, tld)
		VALUES (?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT(host) DO UPDATE SET
			status = excluded.status,
			lastCrawledAt = excluded.lastCrawledAt,
			feedsFound = excluded.feedsFound,
			lastError = excluded.lastError,
			tld = excluded.tld
	`, domain.Host, domain.Status, domain.DiscoveredAt, nullTime(domain.LastCrawledAt),
		domain.FeedsFound, nullString(domain.LastError), domain.TLD)
	return err
}

// saveDomainTx stores a domain inside an existing transaction. Unlike
// saveDomain, it leaves rows for already-known hosts untouched
// (ON CONFLICT DO NOTHING).
func (c *Crawler) saveDomainTx(tx *sql.Tx, domain *Domain) error {
	_, err := tx.Exec(`
		INSERT INTO domains (host, status, discoveredAt, lastCrawledAt, feedsFound, lastError, tld)
		VALUES (?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT(host) DO NOTHING
	`, domain.Host, domain.Status, domain.DiscoveredAt, nullTime(domain.LastCrawledAt),
		domain.FeedsFound, nullString(domain.LastError), domain.TLD)
	return err
}

// domainExists checks if a domain already exists in the database
func (c *Crawler) domainExists(host string) bool {
	var exists bool
	err := c.db.QueryRow("SELECT EXISTS(SELECT 1 FROM domains WHERE host = ?)", normalizeHost(host)).Scan(&exists)
	return err == nil && exists
}

// getDomain retrieves a domain from SQLite. It returns (nil, nil) when the
// host is not present.
func (c *Crawler) getDomain(host string) (*Domain, error) {
	domain := &Domain{}
	var lastCrawledAt sql.NullTime
	var lastError sql.NullString

	err := c.db.QueryRow(`
		SELECT host, status, discoveredAt, lastCrawledAt, feedsFound, lastError, tld
		FROM domains WHERE host = ?
	`, normalizeHost(host)).Scan(
		&domain.Host, &domain.Status, &domain.DiscoveredAt, &lastCrawledAt,
		&domain.FeedsFound, &lastError, &domain.TLD,
	)

	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	if lastCrawledAt.Valid {
		domain.LastCrawledAt = lastCrawledAt.Time
	}
	if lastError.Valid {
		domain.LastError = lastError.String
	}

	return domain, nil
}

// GetUncheckedDomains returns all domains with status "unchecked"
func (c *Crawler) GetUncheckedDomains() ([]*Domain, error) {
	rows, err := c.db.Query(`
		SELECT host, status, discoveredAt, lastCrawledAt, feedsFound, lastError, tld
		FROM domains WHERE status = 'unchecked'
	`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return c.scanDomains(rows)
}

// GetUncheckedDomainsRandom returns up to limit unchecked domains in random order
func (c *Crawler) GetUncheckedDomainsRandom(limit int) ([]*Domain, error) {
	rows, err := c.db.Query(`
		SELECT host, status, discoveredAt, lastCrawledAt, feedsFound, lastError, tld
		FROM domains WHERE status = 'unchecked'
		ORDER BY RANDOM()
		LIMIT ?
	`, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return c.scanDomains(rows)
}

// scanDomains is a helper to scan multiple domain rows. Rows that fail to
// scan are skipped rather than aborting the whole result set.
func (c *Crawler) scanDomains(rows *sql.Rows) ([]*Domain, error) {
	var domains []*Domain
	for rows.Next() {
		domain := &Domain{}
		var lastCrawledAt sql.NullTime
		var lastError sql.NullString

		if err := rows.Scan(
			&domain.Host, &domain.Status, &domain.DiscoveredAt, &lastCrawledAt,
			&domain.FeedsFound, &lastError, &domain.TLD,
		); err != nil {
			continue
		}

		if lastCrawledAt.Valid {
			domain.LastCrawledAt = lastCrawledAt.Time
		}
		if lastError.Valid {
			domain.LastError = lastError.String
		}

		domains = append(domains, domain)
	}

	return domains, rows.Err()
}

// markDomainCrawled updates a domain's status after crawling
func (c *Crawler) markDomainCrawled(host string, feedsFound int, lastError string) error {
	status := "checked"
	if lastError != "" {
		status = "error"
	}

	var err error
	if lastError != "" {
		_, err = c.db.Exec(`
			UPDATE domains SET status = ?, lastCrawledAt = ?, feedsFound = ?, lastError = ?
			WHERE host = ?
		`, status, time.Now(), feedsFound, lastError, normalizeHost(host))
	} else {
		_, err = c.db.Exec(`
			UPDATE domains SET status = ?, lastCrawledAt = ?, feedsFound = ?, lastError = NULL
			WHERE host = ?
		`, status, time.Now(), feedsFound, normalizeHost(host))
	}
	return err
}

// GetDomainCount returns the total number of domains in the database and how
// many of them are still unchecked
func (c *Crawler) GetDomainCount() (total int, unchecked int, err error) {
	err = c.db.QueryRow("SELECT COUNT(*) FROM domains").Scan(&total)
	if err != nil {
		return 0, 0, err
	}
	err = c.db.QueryRow("SELECT COUNT(*) FROM domains WHERE status = 'unchecked'").Scan(&unchecked)
	return total, unchecked, err
}

// ImportDomainsFromFile reads a vertices file and stores new domains as "unchecked"
func (c *Crawler) ImportDomainsFromFile(filename string, limit int) (imported int, skipped int, err error) {
	file, err := os.Open(filename)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to open file: %v", err)
	}
	defer file.Close()

	return c.parseAndStoreDomains(file, limit)
}

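// The vertices input (here and in ImportDomainsInBackground below) is read as
// plain or gzip-compressed text with one tab-separated record per line; the
// second column is taken as a host name in reversed notation, which
// reverseHost (defined elsewhere in this package) presumably restores, e.g.
// "org.example.blog" -> "blog.example.org". That example is an assumption
// based on the function's name; its definition is not part of this file.
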
// ImportDomainsInBackground starts domain import in a background goroutine
func (c *Crawler) ImportDomainsInBackground(filename string) {
	go func() {
		file, err := os.Open(filename)
		if err != nil {
			fmt.Printf("Failed to open vertices file: %v\n", err)
			return
		}
		defer file.Close()

		var bodyReader io.Reader

		// Sniff the first two bytes: 0x1f 0x8b is the gzip magic number, so
		// transparently decompress when the file is gzipped.
		bufReader := bufio.NewReader(file)
		peekBytes, err := bufReader.Peek(2)
		if err != nil && err != io.EOF {
			fmt.Printf("Failed to peek at file: %v\n", err)
			return
		}

		if len(peekBytes) >= 2 && peekBytes[0] == 0x1f && peekBytes[1] == 0x8b {
			gzReader, err := gzip.NewReader(bufReader)
			if err != nil {
				fmt.Printf("Failed to create gzip reader: %v\n", err)
				return
			}
			defer gzReader.Close()
			bodyReader = gzReader
		} else {
			bodyReader = bufReader
		}

		// Allow lines up to 1 MiB (the bufio.Scanner default is 64 KiB).
		scanner := bufio.NewScanner(bodyReader)
		buf := make([]byte, 0, 64*1024)
		scanner.Buffer(buf, 1024*1024)

		const batchSize = 10000
		now := time.Now()
		nowStr := now.Format("2006-01-02 15:04:05")
		totalImported := 0
		batchCount := 0

		type domainEntry struct {
			host string
			tld  string
		}

		for {
			// Read and canonicalize one batch of hosts
			var domains []domainEntry
			for len(domains) < batchSize && scanner.Scan() {
				line := scanner.Text()
				parts := strings.Split(line, "\t")
				if len(parts) >= 2 {
					reverseHostName := strings.TrimSpace(parts[1])
					if reverseHostName != "" {
						host := normalizeHost(reverseHost(reverseHostName))
						domains = append(domains, domainEntry{host: host, tld: getTLD(host)})
					}
				}
			}

			if len(domains) == 0 {
				break
			}

			// Build a single multi-row INSERT statement. Three parameters are
			// bound per row, so a full batch binds 30,000 parameters; this
			// assumes the SQLite build's per-statement variable limit is at
			// least that high.
			var sb strings.Builder
			sb.WriteString("INSERT INTO domains (host, status, discoveredAt, tld) VALUES ")
			args := make([]interface{}, 0, len(domains)*3)
			for i, d := range domains {
				if i > 0 {
					sb.WriteString(",")
				}
				sb.WriteString("(?, 'unchecked', ?, ?)")
				args = append(args, d.host, nowStr, d.tld)
			}
			sb.WriteString(" ON CONFLICT(host) DO NOTHING")

			// Execute bulk insert; already-known hosts are left untouched
			result, err := c.db.Exec(sb.String(), args...)
			imported := 0
			if err != nil {
				fmt.Printf("Bulk insert error: %v\n", err)
			} else {
				rowsAffected, _ := result.RowsAffected()
				imported = int(rowsAffected)
			}

			batchCount++
			totalImported += imported
			atomic.AddInt32(&c.domainsImported, int32(imported))

			// Wait 1 second before the next batch to throttle the background import
			time.Sleep(1 * time.Second)
		}

		if err := scanner.Err(); err != nil {
			fmt.Printf("Error reading vertices file: %v\n", err)
		}

		fmt.Printf("Background import complete: %d domains imported\n", totalImported)
	}()
}

// parseAndStoreDomains streams tab-separated vertices records from reader and
// bulk-inserts them as "unchecked" domains, honoring an optional limit
// (limit <= 0 means no limit). It returns how many rows were inserted and how
// many were skipped (duplicates or failed batches).
func (c *Crawler) parseAndStoreDomains(reader io.Reader, limit int) (imported int, skipped int, err error) {
	var bodyReader io.Reader

	// Sniff the gzip magic number (0x1f 0x8b) to decide whether to decompress.
	bufReader := bufio.NewReader(reader)
	peekBytes, err := bufReader.Peek(2)
	if err != nil && err != io.EOF {
		return 0, 0, fmt.Errorf("failed to peek at file: %v", err)
	}

	if len(peekBytes) >= 2 && peekBytes[0] == 0x1f && peekBytes[1] == 0x8b {
		gzReader, err := gzip.NewReader(bufReader)
		if err != nil {
			return 0, 0, fmt.Errorf("failed to create gzip reader: %v", err)
		}
		defer gzReader.Close()
		bodyReader = gzReader
	} else {
		bodyReader = bufReader
	}

	scanner := bufio.NewScanner(bodyReader)
	buf := make([]byte, 0, 64*1024)
	scanner.Buffer(buf, 1024*1024)

	now := time.Now()
	nowStr := now.Format("2006-01-02 15:04:05")
	count := 0
	const batchSize = 1000

	type domainEntry struct {
		host string
		tld  string
	}

	for {
		// Read and canonicalize one batch of hosts
		var domains []domainEntry
		for len(domains) < batchSize && scanner.Scan() {
			if limit > 0 && count >= limit {
				break
			}
			line := scanner.Text()
			parts := strings.Split(line, "\t")
			if len(parts) >= 2 {
				reverseHostName := strings.TrimSpace(parts[1])
				if reverseHostName != "" {
					host := normalizeHost(reverseHost(reverseHostName))
					domains = append(domains, domainEntry{host: host, tld: getTLD(host)})
					count++
				}
			}
		}

		if len(domains) == 0 {
			break
		}

		// Build a single multi-row INSERT statement (three parameters per row)
		var sb strings.Builder
		sb.WriteString("INSERT INTO domains (host, status, discoveredAt, tld) VALUES ")
		args := make([]interface{}, 0, len(domains)*3)
		for i, d := range domains {
			if i > 0 {
				sb.WriteString(",")
			}
			sb.WriteString("(?, 'unchecked', ?, ?)")
			args = append(args, d.host, nowStr, d.tld)
		}
		sb.WriteString(" ON CONFLICT(host) DO NOTHING")

		// Execute bulk insert; a failed batch is counted as skipped
		result, execErr := c.db.Exec(sb.String(), args...)
		if execErr != nil {
			skipped += len(domains)
			continue
		}
		rowsAffected, _ := result.RowsAffected()
		imported += int(rowsAffected)
		skipped += len(domains) - int(rowsAffected)

		if limit > 0 && count >= limit {
			break
		}
	}

	if err := scanner.Err(); err != nil {
		return imported, skipped, fmt.Errorf("error reading file: %v", err)
	}

	return imported, skipped, nil
}

// Helper functions for SQL null handling
func nullTime(t time.Time) sql.NullTime {
	if t.IsZero() {
		return sql.NullTime{}
	}
	return sql.NullTime{Time: t, Valid: true}
}

func nullString(s string) sql.NullString {
	if s == "" {
		return sql.NullString{}
	}
	return sql.NullString{String: s, Valid: true}
}
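
// Illustrative usage of the functions in this file (a sketch, not code that
// exists in this project; the Crawler constructor, the crawl step, and error
// handling are assumed to live elsewhere, and crawlHost is hypothetical):
//
//	imported, skipped, err := c.ImportDomainsFromFile("vertices.txt.gz", 0)
//	...
//	batch, err := c.GetUncheckedDomainsRandom(100)
//	for _, d := range batch {
//	    feeds, crawlErr := crawlHost(d.Host) // hypothetical crawl step
//	    msg := ""
//	    if crawlErr != nil {
//	        msg = crawlErr.Error()
//	    }
//	    _ = c.markDomainCrawled(d.Host, feeds, msg)
//	}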