feat(agent): add WebSocket transport with HTTP fallback

Add Transport interface abstraction supporting WebSocket (via CF
Durable Objects) and HTTP (direct to origin) with automatic failover.

- Transport interface: Register, SendHeartbeat, SendProgress, Events()
- HTTPTransport: thin adapter over existing Client
- WSTransport: gorilla/websocket with auth handshake, readLoop, reconnect
- HybridTransport: tries WS first, falls back to HTTP, reconnects in bg
- Daemon refactored to always use Transport (no dual-path forks)
- ProgressReporter accepts StatusReporter interface
- deriveWSURL skips localhost/dev (returns "" → HTTP-only)
- API key passed in WS query param for connection auth
- Fixed: reconnectOnce race (mutex+bool), authDone double-close (sync.Once)
- Fixed: forwardWSEvents goroutine leak (select with stop signal)
- 20 transport tests + 2 E2E tests (full lifecycle, hybrid failover)
This commit is contained in:
Deivid Soto 2026-03-28 18:55:29 +01:00
parent 5e80911501
commit 5f337eebd7
10 changed files with 1646 additions and 64 deletions

View file

@ -7,6 +7,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
@ -104,8 +105,6 @@ func runDaemonStart() error {
heartbeatInterval = 30 * time.Second
}
// Create agent client (direct HTTP — always available as fallback)
ac := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, "unarr/"+Version)
userAgent := "unarr/" + Version
// Create daemon config
@ -119,6 +118,8 @@ func runDaemonStart() error {
}
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
wsURL := cfg.Auth.WSURL
if wsURL == "" {
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
@ -126,28 +127,19 @@ func runDaemonStart() error {
var transport agent.Transport
if wsURL != "" {
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
transport = agent.NewHybridTransport(wsT, httpT)
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
} else {
transport = httpT
log.Println("Transport: HTTP only")
}
// Create daemon
var d *agent.Daemon
if transport != nil {
d = agent.NewDaemonWithTransport(daemonCfg, transport)
} else {
d = agent.NewDaemon(daemonCfg, ac)
}
// Create daemon — always uses Transport interface
d := agent.NewDaemon(daemonCfg, transport)
// Wire state tracking (connected after manager creation below)
// Create progress reporter
var reporter *engine.ProgressReporter
if transport != nil {
reporter = engine.NewProgressReporterWithTransport(transport, 3*time.Second)
} else {
reporter = engine.NewProgressReporter(ac, 3*time.Second)
}
// Create progress reporter using transport
reporter := engine.NewProgressReporterWithTransport(transport, 3*time.Second)
// Parse speed limits
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
@ -190,7 +182,7 @@ func runDaemonStart() error {
MoviesDir: cfg.Organize.MoviesDir,
TVShowsDir: cfg.Organize.TVShowsDir,
},
}, reporter, torrentDl, debridDl)
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client()))
// Wire state tracking
d.GetActiveCount = manager.ActiveCount
@ -275,9 +267,9 @@ func runDaemonStart() error {
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(sr.FilePath), streamURL)
// Report stream URL back to the server
// Report stream URL back to the server via transport
go func() {
if _, err := ac.ReportStatus(ctx, agent.StatusUpdate{
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
StreamURL: streamURL,
}); err != nil {
@ -298,13 +290,18 @@ func runDaemonStart() error {
case "resume":
log.Printf("[%s] resume requested via WebSocket", taskID[:8])
case "stream":
// Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests
streamRegistry.mu.Lock()
if _, exists := streamRegistry.servers[taskID]; exists {
streamRegistry.mu.Unlock()
return
}
task := manager.GetTask(taskID)
if task == nil {
return
}
if task.GetStreamURL() != "" {
if task == nil || task.GetStreamURL() != "" {
streamRegistry.mu.Unlock()
return
}
streamRegistry.mu.Unlock()
srv, err := torrentDl.StartStream(taskID)
if err != nil {
log.Printf("[%s] stream failed: %v", taskID[:8], err)
@ -342,11 +339,7 @@ func runDaemonStart() error {
Version: result.NewVersion,
Error: errMsg,
}
if transport != nil {
_ = transport.ReportUpgradeResult(reportCtx, upgradeResult)
} else {
_ = ac.ReportUpgradeResult(reportCtx, upgradeResult)
}
_ = transport.ReportUpgradeResult(reportCtx, upgradeResult)
if !result.Success {
log.Printf("Upgrade failed: %v", result.Error)
@ -360,7 +353,7 @@ func runDaemonStart() error {
// Deregister first so the server knows we're restarting
deregCtx, deregCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer deregCancel()
_ = ac.Deregister(deregCtx, cfg.Agent.ID)
_ = transport.Deregister(deregCtx, cfg.Agent.ID)
// Flush progress reporter
cancel()
@ -418,6 +411,7 @@ func runDaemonStart() error {
// deriveWSURL derives a WebSocket URL from the API URL.
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
// Returns "" for localhost/dev environments where WS gateway isn't available.
func deriveWSURL(apiURL, agentID string) string {
if apiURL == "" || agentID == "" {
return ""
@ -437,6 +431,15 @@ func deriveWSURL(apiURL, agentID string) string {
break
}
}
// Strip port if present
if idx := strings.LastIndex(domain, ":"); idx > 0 {
domain = domain[:idx]
}
// Skip WS for localhost/dev — gateway only available in production
if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" {
return ""
}
return "wss://unarr." + domain + "/ws/" + agentID
}