feat(agent): add WebSocket transport with HTTP fallback
Add Transport interface abstraction supporting WebSocket (via CF Durable Objects) and HTTP (direct to origin) with automatic failover. - Transport interface: Register, SendHeartbeat, SendProgress, Events() - HTTPTransport: thin adapter over existing Client - WSTransport: gorilla/websocket with auth handshake, readLoop, reconnect - HybridTransport: tries WS first, falls back to HTTP, reconnects in bg - Daemon refactored to always use Transport (no dual-path forks) - ProgressReporter accepts StatusReporter interface - deriveWSURL skips localhost/dev (returns "" → HTTP-only) - API key passed in WS query param for connection auth - Fixed: reconnectOnce race (mutex+bool), authDone double-close (sync.Once) - Fixed: forwardWSEvents goroutine leak (select with stop signal) - 20 transport tests + 2 E2E tests (full lifecycle, hybrid failover)
This commit is contained in:
parent
5e80911501
commit
5f337eebd7
10 changed files with 1646 additions and 64 deletions
|
|
@ -7,6 +7,7 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
|
|
@ -104,8 +105,6 @@ func runDaemonStart() error {
|
|||
heartbeatInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
// Create agent client (direct HTTP — always available as fallback)
|
||||
ac := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, "unarr/"+Version)
|
||||
userAgent := "unarr/" + Version
|
||||
|
||||
// Create daemon config
|
||||
|
|
@ -119,6 +118,8 @@ func runDaemonStart() error {
|
|||
}
|
||||
|
||||
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
|
||||
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
||||
|
||||
wsURL := cfg.Auth.WSURL
|
||||
if wsURL == "" {
|
||||
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
|
||||
|
|
@ -126,28 +127,19 @@ func runDaemonStart() error {
|
|||
|
||||
var transport agent.Transport
|
||||
if wsURL != "" {
|
||||
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
||||
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
|
||||
transport = agent.NewHybridTransport(wsT, httpT)
|
||||
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
|
||||
} else {
|
||||
transport = httpT
|
||||
log.Println("Transport: HTTP only")
|
||||
}
|
||||
|
||||
// Create daemon
|
||||
var d *agent.Daemon
|
||||
if transport != nil {
|
||||
d = agent.NewDaemonWithTransport(daemonCfg, transport)
|
||||
} else {
|
||||
d = agent.NewDaemon(daemonCfg, ac)
|
||||
}
|
||||
// Create daemon — always uses Transport interface
|
||||
d := agent.NewDaemon(daemonCfg, transport)
|
||||
|
||||
// Wire state tracking (connected after manager creation below)
|
||||
// Create progress reporter
|
||||
var reporter *engine.ProgressReporter
|
||||
if transport != nil {
|
||||
reporter = engine.NewProgressReporterWithTransport(transport, 3*time.Second)
|
||||
} else {
|
||||
reporter = engine.NewProgressReporter(ac, 3*time.Second)
|
||||
}
|
||||
// Create progress reporter using transport
|
||||
reporter := engine.NewProgressReporterWithTransport(transport, 3*time.Second)
|
||||
|
||||
// Parse speed limits
|
||||
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
|
||||
|
|
@ -190,7 +182,7 @@ func runDaemonStart() error {
|
|||
MoviesDir: cfg.Organize.MoviesDir,
|
||||
TVShowsDir: cfg.Organize.TVShowsDir,
|
||||
},
|
||||
}, reporter, torrentDl, debridDl)
|
||||
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client()))
|
||||
|
||||
// Wire state tracking
|
||||
d.GetActiveCount = manager.ActiveCount
|
||||
|
|
@ -275,9 +267,9 @@ func runDaemonStart() error {
|
|||
|
||||
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(sr.FilePath), streamURL)
|
||||
|
||||
// Report stream URL back to the server
|
||||
// Report stream URL back to the server via transport
|
||||
go func() {
|
||||
if _, err := ac.ReportStatus(ctx, agent.StatusUpdate{
|
||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
StreamURL: streamURL,
|
||||
}); err != nil {
|
||||
|
|
@ -298,13 +290,18 @@ func runDaemonStart() error {
|
|||
case "resume":
|
||||
log.Printf("[%s] resume requested via WebSocket", taskID[:8])
|
||||
case "stream":
|
||||
// Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests
|
||||
streamRegistry.mu.Lock()
|
||||
if _, exists := streamRegistry.servers[taskID]; exists {
|
||||
streamRegistry.mu.Unlock()
|
||||
return
|
||||
}
|
||||
task := manager.GetTask(taskID)
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if task.GetStreamURL() != "" {
|
||||
if task == nil || task.GetStreamURL() != "" {
|
||||
streamRegistry.mu.Unlock()
|
||||
return
|
||||
}
|
||||
streamRegistry.mu.Unlock()
|
||||
srv, err := torrentDl.StartStream(taskID)
|
||||
if err != nil {
|
||||
log.Printf("[%s] stream failed: %v", taskID[:8], err)
|
||||
|
|
@ -342,11 +339,7 @@ func runDaemonStart() error {
|
|||
Version: result.NewVersion,
|
||||
Error: errMsg,
|
||||
}
|
||||
if transport != nil {
|
||||
_ = transport.ReportUpgradeResult(reportCtx, upgradeResult)
|
||||
} else {
|
||||
_ = ac.ReportUpgradeResult(reportCtx, upgradeResult)
|
||||
}
|
||||
_ = transport.ReportUpgradeResult(reportCtx, upgradeResult)
|
||||
|
||||
if !result.Success {
|
||||
log.Printf("Upgrade failed: %v", result.Error)
|
||||
|
|
@ -360,7 +353,7 @@ func runDaemonStart() error {
|
|||
// Deregister first so the server knows we're restarting
|
||||
deregCtx, deregCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer deregCancel()
|
||||
_ = ac.Deregister(deregCtx, cfg.Agent.ID)
|
||||
_ = transport.Deregister(deregCtx, cfg.Agent.ID)
|
||||
|
||||
// Flush progress reporter
|
||||
cancel()
|
||||
|
|
@ -418,6 +411,7 @@ func runDaemonStart() error {
|
|||
|
||||
// deriveWSURL derives a WebSocket URL from the API URL.
|
||||
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
|
||||
// Returns "" for localhost/dev environments where WS gateway isn't available.
|
||||
func deriveWSURL(apiURL, agentID string) string {
|
||||
if apiURL == "" || agentID == "" {
|
||||
return ""
|
||||
|
|
@ -437,6 +431,15 @@ func deriveWSURL(apiURL, agentID string) string {
|
|||
break
|
||||
}
|
||||
}
|
||||
// Strip port if present
|
||||
if idx := strings.LastIndex(domain, ":"); idx > 0 {
|
||||
domain = domain[:idx]
|
||||
}
|
||||
|
||||
// Skip WS for localhost/dev — gateway only available in production
|
||||
if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "wss://unarr." + domain + "/ws/" + agentID
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue