From 5f337eebd762391d07fd61d3bf515f4b9e755a87 Mon Sep 17 00:00:00 2001 From: Deivid Soto Date: Sat, 28 Mar 2026 18:55:29 +0100 Subject: [PATCH] feat(agent): add WebSocket transport with HTTP fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Transport interface abstraction supporting WebSocket (via CF Durable Objects) and HTTP (direct to origin) with automatic failover. - Transport interface: Register, SendHeartbeat, SendProgress, Events() - HTTPTransport: thin adapter over existing Client - WSTransport: gorilla/websocket with auth handshake, readLoop, reconnect - HybridTransport: tries WS first, falls back to HTTP, reconnects in bg - Daemon refactored to always use Transport (no dual-path forks) - ProgressReporter accepts StatusReporter interface - deriveWSURL skips localhost/dev (returns "" → HTTP-only) - API key passed in WS query param for connection auth - Fixed: reconnectOnce race (mutex+bool), authDone double-close (sync.Once) - Fixed: forwardWSEvents goroutine leak (select with stop signal) - 20 transport tests + 2 E2E tests (full lifecycle, hybrid failover) --- internal/agent/daemon.go | 176 +++++++++-- internal/agent/transport.go | 53 ++++ internal/agent/transport_e2e_test.go | 295 ++++++++++++++++++ internal/agent/transport_http.go | 50 +++ internal/agent/transport_hybrid.go | 226 ++++++++++++++ internal/agent/transport_test.go | 445 +++++++++++++++++++++++++++ internal/agent/transport_ws.go | 360 ++++++++++++++++++++++ internal/cmd/daemon.go | 65 ++-- internal/config/config.go | 1 + internal/engine/progress.go | 39 ++- 10 files changed, 1646 insertions(+), 64 deletions(-) create mode 100644 internal/agent/transport.go create mode 100644 internal/agent/transport_e2e_test.go create mode 100644 internal/agent/transport_http.go create mode 100644 internal/agent/transport_hybrid.go create mode 100644 internal/agent/transport_test.go create mode 100644 internal/agent/transport_ws.go diff --git a/internal/agent/daemon.go b/internal/agent/daemon.go index 82924da..06e0c3e 100644 --- a/internal/agent/daemon.go +++ b/internal/agent/daemon.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "os" "runtime" "time" ) @@ -20,20 +21,34 @@ type DaemonConfig struct { // Daemon manages the main loop: register, heartbeat, poll tasks. type Daemon struct { - cfg DaemonConfig - client *Client + cfg DaemonConfig + transport Transport // Callbacks - OnTasksClaimed func(tasks []Task) + OnTasksClaimed func(tasks []Task) + OnStreamRequested func(req StreamRequest) + OnUpgradeRequested func(version string) + OnControlAction func(action, taskID string) // State - User UserInfo - Features FeatureFlags - Info AgentInfo + User UserInfo + Features FeatureFlags + Info AgentInfo + State DaemonState + upgradeInProgress bool + heartbeatFailures int + + // Callbacks for state tracking (set by cmd/daemon.go) + GetActiveCount func() int + + // Exposed tickers for hot-reload + PollTicker *time.Ticker + HeartbeatTicker *time.Ticker } -// NewDaemon creates a daemon with the given config and agent client. -func NewDaemon(cfg DaemonConfig, client *Client) *Daemon { +// NewDaemon creates a daemon with the given transport. +// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP. +func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon { if cfg.PollInterval == 0 { cfg.PollInterval = 30 * time.Second } @@ -42,11 +57,14 @@ func NewDaemon(cfg DaemonConfig, client *Client) *Daemon { } return &Daemon{ - cfg: cfg, - client: client, + cfg: cfg, + transport: transport, } } +// Transport returns the configured transport. +func (d *Daemon) Transport() Transport { return d.transport } + // Register registers the agent and fetches user info + features. func (d *Daemon) Register(ctx context.Context) error { req := RegisterRequest{ @@ -62,20 +80,30 @@ func (d *Daemon) Register(ctx context.Context) error { req.DiskTotalBytes = total } - resp, err := d.client.Register(ctx, req) + resp, err := d.transport.Register(ctx, req) if err != nil { return fmt.Errorf("register: %w", err) } d.User = resp.User d.Features = resp.Features + now := time.Now() d.Info = AgentInfo{ ID: d.cfg.AgentID, Name: d.cfg.AgentName, User: resp.User, Features: resp.Features, - StartedAt: time.Now(), + StartedAt: now, } + d.State = DaemonState{ + AgentID: d.cfg.AgentID, + Status: "running", + Version: d.cfg.Version, + PID: os.Getpid(), + StartedAt: now, + MethodStats: make(map[string]int), + } + WriteState(&d.State) return nil } @@ -91,26 +119,38 @@ func (d *Daemon) Run(ctx context.Context) error { log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet) log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval) - heartbeatTicker := time.NewTicker(d.cfg.HeartbeatInterval) - defer heartbeatTicker.Stop() + d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval) + defer d.HeartbeatTicker.Stop() - pollTicker := time.NewTicker(d.cfg.PollInterval) - defer pollTicker.Stop() + d.PollTicker = time.NewTicker(d.cfg.PollInterval) + defer d.PollTicker.Stop() + + heartbeatTicker := d.HeartbeatTicker + pollTicker := d.PollTicker // Initial poll immediately d.poll(ctx) + eventsCh := d.transport.Events() + for { select { case <-ctx.Done(): log.Println("Daemon shutting down...") + d.deregister() return nil + case event := <-eventsCh: + d.handleEvent(event) + case <-heartbeatTicker.C: d.heartbeat(ctx) case <-pollTicker.C: - d.poll(ctx) + // Only poll in HTTP mode — WS mode receives tasks via Events + if d.transport.Mode() == "http" { + d.poll(ctx) + } } } } @@ -128,13 +168,93 @@ func (d *Daemon) heartbeat(ctx context.Context) { req.DiskTotalBytes = total } - if err := d.client.Heartbeat(ctx, req); err != nil { - log.Printf("Heartbeat failed: %v", err) + resp, err := d.transport.SendHeartbeat(ctx, req) + if err != nil { + d.heartbeatFailures++ + if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 { + log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures) + } else { + log.Printf("Heartbeat failed: %v", err) + } + return + } + if d.heartbeatFailures > 0 { + log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures) + d.heartbeatFailures = 0 + } + + // Update state file + d.State.LastHeartbeat = time.Now() + if d.GetActiveCount != nil { + d.State.ActiveTasks = d.GetActiveCount() + } + WriteState(&d.State) + + // Check for upgrade signal from server + if resp.Upgrade != nil && resp.Upgrade.Version != "" && !d.upgradeInProgress { + d.upgradeInProgress = true + log.Printf("Upgrade requested by server: %s → %s", d.cfg.Version, resp.Upgrade.Version) + if d.OnUpgradeRequested != nil { + go d.OnUpgradeRequested(resp.Upgrade.Version) + } } } +// handleEvent processes a server-initiated event from the WebSocket transport. +func (d *Daemon) handleEvent(event ServerEvent) { + switch event.Type { + case "tasks": + if event.Tasks != nil && len(event.Tasks.Tasks) > 0 { + log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks)) + if d.OnTasksClaimed != nil { + d.OnTasksClaimed(event.Tasks.Tasks) + } + } + if event.Tasks != nil && d.OnStreamRequested != nil { + for _, sr := range event.Tasks.StreamRequests { + d.OnStreamRequested(sr) + } + } + + case "upgrade": + if event.Upgrade != nil && event.Upgrade.Version != "" && !d.upgradeInProgress { + d.upgradeInProgress = true + log.Printf("Upgrade requested via WebSocket: %s → %s", d.cfg.Version, event.Upgrade.Version) + if d.OnUpgradeRequested != nil { + go d.OnUpgradeRequested(event.Upgrade.Version) + } + } + + case "control": + if event.Control != nil && d.OnControlAction != nil { + log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID) + d.OnControlAction(event.Control.Action, event.Control.TaskID) + } + + case "disconnected": + log.Println("WebSocket disconnected, switching to HTTP polling") + } +} + +// ClearUpgradeInProgress resets the upgrade flag so a retry can be attempted. +func (d *Daemon) ClearUpgradeInProgress() { + d.upgradeInProgress = false +} + +func (d *Daemon) deregister() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := d.transport.Deregister(ctx, d.cfg.AgentID) + if err != nil { + log.Printf("Deregister failed: %v", err) + } else { + log.Println("Agent deregistered") + } + RemoveState() +} + func (d *Daemon) poll(ctx context.Context) { - tasks, err := d.client.ClaimTasks(ctx, d.cfg.AgentID) + resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID) if err != nil { log.Printf("Poll failed: %v", err) return @@ -142,13 +262,17 @@ func (d *Daemon) poll(ctx context.Context) { d.Info.LastPollAt = time.Now() - if len(tasks) == 0 { - return + if len(resp.Tasks) > 0 { + log.Printf("Claimed %d task(s)", len(resp.Tasks)) + if d.OnTasksClaimed != nil { + d.OnTasksClaimed(resp.Tasks) + } } - log.Printf("Claimed %d task(s)", len(tasks)) - - if d.OnTasksClaimed != nil { - d.OnTasksClaimed(tasks) + // Handle stream requests for completed downloads + if d.OnStreamRequested != nil { + for _, sr := range resp.StreamRequests { + d.OnStreamRequested(sr) + } } } diff --git a/internal/agent/transport.go b/internal/agent/transport.go new file mode 100644 index 0000000..9aeee53 --- /dev/null +++ b/internal/agent/transport.go @@ -0,0 +1,53 @@ +package agent + +import "context" + +// Transport abstracts the communication protocol between the agent and server. +// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this. +type Transport interface { + // Connect establishes the transport connection. + Connect(ctx context.Context) error + + // Close tears down the connection gracefully. + Close() error + + // Mode returns the current transport mode ("ws" or "http"). + Mode() string + + // Register sends agent registration and returns user info + features. + Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) + + // SendHeartbeat sends a periodic keep-alive. + SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) + + // SendProgress reports download progress for a task. + SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) + + // ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events). + ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) + + // Deregister notifies the server of graceful shutdown. + Deregister(ctx context.Context, agentID string) error + + // ReportUpgradeResult reports upgrade outcome. + ReportUpgradeResult(ctx context.Context, result UpgradeResult) error + + // Events returns a channel that emits server-initiated events. + // In HTTP mode this channel is never written to (polling handles it). + // In WS mode, tasks/upgrade/control arrive here. + Events() <-chan ServerEvent +} + +// ServerEvent represents a server-initiated message received via WebSocket. +type ServerEvent struct { + Type string // "tasks", "upgrade", "control", "disconnected" + Tasks *TasksResponse // populated when Type == "tasks" + Upgrade *UpgradeSignal // populated when Type == "upgrade" + Control *ControlAction // populated when Type == "control" +} + +// ControlAction represents a server push for task control. +type ControlAction struct { + Action string `json:"action"` // "pause", "resume", "cancel", "stream" + TaskID string `json:"taskId"` +} diff --git a/internal/agent/transport_e2e_test.go b/internal/agent/transport_e2e_test.go new file mode 100644 index 0000000..0dd3668 --- /dev/null +++ b/internal/agent/transport_e2e_test.go @@ -0,0 +1,295 @@ +package agent + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +// TestE2EFullLifecycle tests the full lifecycle: +// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect +func TestE2EFullLifecycle(t *testing.T) { + var mu sync.Mutex + var receivedMessages []map[string]interface{} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + for { + _, msg, err := conn.ReadMessage() + if err != nil { + return + } + + var parsed map[string]interface{} + json.Unmarshal(msg, &parsed) + + mu.Lock() + receivedMessages = append(receivedMessages, parsed) + mu.Unlock() + + msgType, _ := parsed["type"].(string) + switch msgType { + case "auth": + conn.WriteJSON(wsRegisteredMessage{ + Type: "registered", + User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true}, + Features: FeatureFlags{Torrent: true, Debrid: true}, + }) + + case "heartbeat": + // No response in WS mode + + case "progress": + // Simulate server-side cancel after progress + if progress, ok := parsed["progress"].(float64); ok && progress >= 50 { + conn.WriteJSON(map[string]string{ + "type": "control", + "action": "cancel", + "taskId": parsed["taskId"].(string), + }) + } + + case "upgrade-result": + // Acknowledged + } + } + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0") + + ctx := context.Background() + + // 1. Connect + if err := tr.Connect(ctx); err != nil { + t.Fatalf("Connect: %v", err) + } + defer tr.Close() + + // 2. Auth + resp, err := tr.Register(ctx, RegisterRequest{ + AgentID: "e2e-agent", + Name: "E2E Test Agent", + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + }) + if err != nil { + t.Fatalf("Register: %v", err) + } + if resp.User.Name != "E2E User" { + t.Errorf("expected E2E User, got %s", resp.User.Name) + } + if !resp.Features.Debrid { + t.Error("expected debrid feature") + } + + // 3. Send heartbeat + _, err = tr.SendHeartbeat(ctx, HeartbeatRequest{ + AgentID: "e2e-agent", + DiskFreeBytes: 1000000000, + DiskTotalBytes: 5000000000, + }) + if err != nil { + t.Fatalf("SendHeartbeat: %v", err) + } + + // 4. Send progress (50% → should trigger cancel control) + _, err = tr.SendProgress(ctx, StatusUpdate{ + TaskID: "task-e2e-1", + Status: "downloading", + Progress: 50, + DownloadedBytes: 500, + TotalBytes: 1000, + SpeedBps: 100, + }) + if err != nil { + t.Fatalf("SendProgress: %v", err) + } + + // 5. Wait for control event (cancel) + select { + case event := <-tr.Events(): + if event.Type != "control" { + t.Errorf("expected control event, got %s", event.Type) + } + if event.Control.Action != "cancel" { + t.Errorf("expected cancel, got %s", event.Control.Action) + } + if event.Control.TaskID != "task-e2e-1" { + t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID) + } + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for cancel control") + } + + // 6. Send upgrade result + err = tr.ReportUpgradeResult(ctx, UpgradeResult{ + AgentID: "e2e-agent", + Success: true, + Version: "2.0.0", + }) + if err != nil { + t.Fatalf("ReportUpgradeResult: %v", err) + } + + // Verify server received all messages + time.Sleep(100 * time.Millisecond) + mu.Lock() + defer mu.Unlock() + + if len(receivedMessages) < 4 { + t.Fatalf("expected at least 4 messages, got %d", len(receivedMessages)) + } + + types := make([]string, len(receivedMessages)) + for i, m := range receivedMessages { + types[i], _ = m["type"].(string) + } + + expected := []string{"auth", "heartbeat", "progress", "upgrade-result"} + for _, exp := range expected { + found := false + for _, got := range types { + if got == exp { + found = true + break + } + } + if !found { + t.Errorf("missing message type %q in %v", exp, types) + } + } +} + +// TestE2EHybridFailover tests the full failover scenario: +// WS connect → download → WS disconnect → switch to HTTP → continue working +func TestE2EHybridFailover(t *testing.T) { + connectionCount := 0 + var mu sync.Mutex + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + mu.Lock() + connectionCount++ + connNum := connectionCount + mu.Unlock() + + // Read auth + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{ + Type: "registered", + User: UserInfo{Name: "Failover User"}, + }) + + if connNum == 1 { + // First connection: push tasks then disconnect after 200ms + time.Sleep(50 * time.Millisecond) + conn.WriteJSON(wsTasksMessage{ + Type: "tasks", + Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}}, + }) + time.Sleep(150 * time.Millisecond) + conn.Close() + } else { + // Second connection (after reconnect): push upgrade + time.Sleep(50 * time.Millisecond) + conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"}) + time.Sleep(500 * time.Millisecond) + conn.Close() + } + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + wsT := NewWSTransport(wsURL, "key", "a1", "ua") + + // HTTP mock for fallback + httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simple heartbeat response + json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) + })) + defer httpSrv.Close() + + httpT := NewHTTPTransport(httpSrv.URL, "key", "ua") + h := NewHybridTransport(wsT, httpT) + + ctx := context.Background() + err := h.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer h.Close() + + // Should start in WS mode + if h.Mode() != "ws" { + t.Fatalf("expected ws mode, got %s", h.Mode()) + } + + // Register via WS + _, err = h.Register(ctx, RegisterRequest{AgentID: "a1"}) + if err != nil { + t.Fatalf("Register: %v", err) + } + + // Receive tasks via WS + var tasksReceived bool + var disconnected bool + + for i := 0; i < 3; i++ { + select { + case event := <-h.Events(): + switch event.Type { + case "tasks": + tasksReceived = true + if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" { + t.Errorf("unexpected tasks: %+v", event.Tasks) + } + case "disconnected": + disconnected = true + } + case <-time.After(2 * time.Second): + break + } + if disconnected { + break + } + } + + if !tasksReceived { + t.Error("did not receive tasks before disconnect") + } + if !disconnected { + t.Error("did not receive disconnect event") + } + + // Should now be in HTTP mode + time.Sleep(100 * time.Millisecond) + if h.Mode() != "http" { + t.Errorf("expected http mode after disconnect, got %s", h.Mode()) + } + + // Heartbeat should work via HTTP fallback + hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"}) + if err != nil { + t.Fatalf("SendHeartbeat via HTTP fallback: %v", err) + } + if !hbResp.Success { + t.Error("expected heartbeat success") + } +} diff --git a/internal/agent/transport_http.go b/internal/agent/transport_http.go new file mode 100644 index 0000000..d5f52a4 --- /dev/null +++ b/internal/agent/transport_http.go @@ -0,0 +1,50 @@ +package agent + +import "context" + +// HTTPTransport wraps the existing Client to implement Transport. +// This is a thin adapter — no behavioral changes from the current HTTP protocol. +type HTTPTransport struct { + client *Client + events chan ServerEvent +} + +// NewHTTPTransport creates a new HTTP-based transport. +func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport { + return &HTTPTransport{ + client: NewClient(baseURL, apiKey, userAgent), + events: make(chan ServerEvent, 10), + } +} + +func (t *HTTPTransport) Connect(_ context.Context) error { return nil } +func (t *HTTPTransport) Close() error { return nil } +func (t *HTTPTransport) Mode() string { return "http" } +func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events } + +func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) { + return t.client.Register(ctx, req) +} + +func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) { + return t.client.Heartbeat(ctx, req) +} + +func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) { + return t.client.ReportStatus(ctx, update) +} + +func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) { + return t.client.ClaimTasks(ctx, agentID) +} + +func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error { + return t.client.Deregister(ctx, agentID) +} + +func (t *HTTPTransport) ReportUpgradeResult(ctx context.Context, result UpgradeResult) error { + return t.client.ReportUpgradeResult(ctx, result) +} + +// Client returns the underlying HTTP client for direct use if needed. +func (t *HTTPTransport) Client() *Client { return t.client } diff --git a/internal/agent/transport_hybrid.go b/internal/agent/transport_hybrid.go new file mode 100644 index 0000000..c2bd831 --- /dev/null +++ b/internal/agent/transport_hybrid.go @@ -0,0 +1,226 @@ +package agent + +import ( + "context" + "log" + "sync" + "sync/atomic" + "time" +) + +// HybridTransport tries WebSocket first, falls back to HTTP if WS fails. +// Automatically reconnects WS in the background. +type HybridTransport struct { + ws *WSTransport + http *HTTPTransport + + mode atomic.Value // "ws" or "http" + events chan ServerEvent + + reconnectMu sync.Mutex + reconnectRunning bool + reconnectStop chan struct{} + closed atomic.Bool +} + +// NewHybridTransport creates a transport that prefers WS with HTTP fallback. +func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport { + h := &HybridTransport{ + ws: ws, + http: http, + events: make(chan ServerEvent, 50), + reconnectStop: make(chan struct{}), + } + h.mode.Store("http") // start in HTTP, upgrade to WS on Connect + return h +} + +func (h *HybridTransport) Mode() string { return h.mode.Load().(string) } +func (h *HybridTransport) Events() <-chan ServerEvent { return h.events } + +// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop. +func (h *HybridTransport) Connect(ctx context.Context) error { + // Try WebSocket first + if err := h.ws.Connect(ctx); err != nil { + log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err) + h.mode.Store("http") + h.startReconnectLoop() + return h.http.Connect(ctx) + } + + h.mode.Store("ws") + log.Println("[transport] Connected via WebSocket") + + // Forward WS events to unified channel + watch for disconnection + go h.forwardWSEvents() + + return nil +} + +// Close shuts down both transports and stops reconnection. +func (h *HybridTransport) Close() error { + h.closed.Store(true) + select { + case <-h.reconnectStop: + default: + close(h.reconnectStop) + } + _ = h.ws.Close() + return h.http.Close() +} + +// Register delegates to the active transport. +func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) { + if h.mode.Load() == "ws" { + return h.ws.Register(ctx, req) + } + return h.http.Register(ctx, req) +} + +// SendHeartbeat delegates to the active transport. +func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) { + if h.mode.Load() == "ws" { + resp, err := h.ws.SendHeartbeat(ctx, req) + if err != nil { + // WS write failed — switch to HTTP + h.switchToHTTP() + return h.http.SendHeartbeat(ctx, req) + } + return resp, nil + } + return h.http.SendHeartbeat(ctx, req) +} + +// SendProgress delegates to the active transport. +func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) { + if h.mode.Load() == "ws" { + resp, err := h.ws.SendProgress(ctx, update) + if err != nil { + h.switchToHTTP() + return h.http.SendProgress(ctx, update) + } + return resp, nil + } + return h.http.SendProgress(ctx, update) +} + +// ClaimTasks delegates to the active transport. +func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) { + if h.mode.Load() == "ws" { + return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode + } + return h.http.ClaimTasks(ctx, agentID) +} + +// Deregister delegates to the active transport. +func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error { + if h.mode.Load() == "ws" { + return h.ws.Deregister(ctx, agentID) + } + return h.http.Deregister(ctx, agentID) +} + +// ReportUpgradeResult delegates to the active transport. +func (h *HybridTransport) ReportUpgradeResult(ctx context.Context, result UpgradeResult) error { + if h.mode.Load() == "ws" { + if err := h.ws.ReportUpgradeResult(ctx, result); err != nil { + h.switchToHTTP() + return h.http.ReportUpgradeResult(ctx, result) + } + return nil + } + return h.http.ReportUpgradeResult(ctx, result) +} + +// ── Internal ───────────────────────────────────────────────────────────────── + +func (h *HybridTransport) switchToHTTP() { + if h.mode.Load() == "http" { + return + } + log.Println("[transport] Switching to HTTP fallback") + h.mode.Store("http") + _ = h.ws.Close() + h.startReconnectLoop() +} + +func (h *HybridTransport) forwardWSEvents() { + for { + select { + case <-h.reconnectStop: + return + case event, ok := <-h.ws.Events(): + if !ok { + return // channel closed + } + if event.Type == "disconnected" { + h.switchToHTTP() + select { + case h.events <- event: + default: + } + return + } + select { + case h.events <- event: + default: + log.Printf("[transport] events channel full, dropping %s event", event.Type) + } + } + } +} + +func (h *HybridTransport) startReconnectLoop() { + h.reconnectMu.Lock() + defer h.reconnectMu.Unlock() + if h.reconnectRunning { + return + } + h.reconnectRunning = true + go h.reconnectLoop() +} + +func (h *HybridTransport) reconnectLoop() { + backoff := 5 * time.Second + maxBackoff := 60 * time.Second + + for { + select { + case <-h.reconnectStop: + return + case <-time.After(backoff): + } + + if h.closed.Load() { + return + } + + // Already on WS? (someone else reconnected) + if h.mode.Load() == "ws" { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + err := h.ws.Connect(ctx) + cancel() + + if err != nil { + log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff) + backoff = min(backoff*2, maxBackoff) + continue + } + + // WS reconnected — switch back + log.Println("[transport] WebSocket reconnected") + h.mode.Store("ws") + + // Reset reconnect flag so loop can start again if WS drops + h.reconnectMu.Lock() + h.reconnectRunning = false + h.reconnectMu.Unlock() + + // Forward events from new WS connection + go h.forwardWSEvents() + return + } +} diff --git a/internal/agent/transport_test.go b/internal/agent/transport_test.go new file mode 100644 index 0000000..e2270a8 --- /dev/null +++ b/internal/agent/transport_test.go @@ -0,0 +1,445 @@ +package agent + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// ── HTTP Transport Tests ───────────────────────────────────────────────────── + +func TestHTTPTransportMode(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + if tr.Mode() != "http" { + t.Errorf("expected http, got %s", tr.Mode()) + } +} + +func TestHTTPTransportEventsNeverEmit(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + select { + case <-tr.Events(): + t.Error("events channel should never emit in HTTP mode") + case <-time.After(50 * time.Millisecond): + // expected + } +} + +func TestHTTPTransportDelegates(t *testing.T) { + // Mock server for register + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(RegisterResponse{ + Success: true, + User: UserInfo{Name: "Test", Plan: "pro"}, + }) + })) + defer srv.Close() + + tr := NewHTTPTransport(srv.URL, "test-key", "test-agent") + resp, err := tr.Register(context.Background(), RegisterRequest{AgentID: "a1"}) + if err != nil { + t.Fatalf("Register failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } + if resp.User.Name != "Test" { + t.Errorf("expected Test, got %s", resp.User.Name) + } +} + +// ── WebSocket Transport Tests ──────────────────────────────────────────────── + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, +} + +func TestWSTransportConnectAndAuth(t *testing.T) { + var received wsAuthMessage + var mu sync.Mutex + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("upgrade: %v", err) + } + defer conn.Close() + + // Read auth message + _, msg, err := conn.ReadMessage() + if err != nil { + return + } + mu.Lock() + json.Unmarshal(msg, &received) + mu.Unlock() + + // Send registered response + conn.WriteJSON(wsRegisteredMessage{ + Type: "registered", + User: UserInfo{Name: "WS User", Plan: "pro", IsPro: true}, + Features: FeatureFlags{Torrent: true}, + }) + + // Keep connection open + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "my-api-key", "agent-123", "test/1.0") + + ctx := context.Background() + if err := tr.Connect(ctx); err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer tr.Close() + + resp, err := tr.Register(ctx, RegisterRequest{ + AgentID: "agent-123", + Name: "test-agent", + Version: "1.0.0", + }) + if err != nil { + t.Fatalf("Register failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } + if resp.User.Name != "WS User" { + t.Errorf("expected WS User, got %s", resp.User.Name) + } + + mu.Lock() + if received.APIKey != "my-api-key" { + t.Errorf("expected my-api-key, got %s", received.APIKey) + } + if received.AgentID != "agent-123" { + t.Errorf("expected agent-123, got %s", received.AgentID) + } + mu.Unlock() +} + +func TestWSTransportReceiveTasks(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + // Read auth + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{ + Type: "registered", + User: UserInfo{Name: "Test"}, + }) + + // Push tasks + time.Sleep(50 * time.Millisecond) + conn.WriteJSON(wsTasksMessage{ + Type: "tasks", + Tasks: []Task{ + {ID: "t1", InfoHash: "abc123", Title: "Test Movie"}, + {ID: "t2", InfoHash: "def456", Title: "Test Show"}, + }, + }) + + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "agent1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + + tr.Register(ctx, RegisterRequest{AgentID: "agent1"}) + + // Wait for tasks event + select { + case event := <-tr.Events(): + if event.Type != "tasks" { + t.Errorf("expected tasks, got %s", event.Type) + } + if len(event.Tasks.Tasks) != 2 { + t.Errorf("expected 2 tasks, got %d", len(event.Tasks.Tasks)) + } + if event.Tasks.Tasks[0].Title != "Test Movie" { + t.Errorf("expected Test Movie, got %s", event.Tasks.Tasks[0].Title) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for tasks event") + } +} + +func TestWSTransportReceiveControl(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + time.Sleep(50 * time.Millisecond) + conn.WriteJSON(map[string]string{ + "type": "control", + "action": "cancel", + "taskId": "task-99", + }) + + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + select { + case event := <-tr.Events(): + if event.Type != "control" { + t.Errorf("expected control, got %s", event.Type) + } + if event.Control.Action != "cancel" { + t.Errorf("expected cancel, got %s", event.Control.Action) + } + if event.Control.TaskID != "task-99" { + t.Errorf("expected task-99, got %s", event.Control.TaskID) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for control event") + } +} + +func TestWSTransportReceiveUpgrade(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + time.Sleep(50 * time.Millisecond) + conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "2.0.0"}) + + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + select { + case event := <-tr.Events(): + if event.Type != "upgrade" { + t.Errorf("expected upgrade, got %s", event.Type) + } + if event.Upgrade.Version != "2.0.0" { + t.Errorf("expected 2.0.0, got %s", event.Upgrade.Version) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for upgrade event") + } +} + +func TestWSTransportDisconnect(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + // Close after a short delay to simulate disconnection + time.Sleep(100 * time.Millisecond) + conn.Close() + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + select { + case event := <-tr.Events(): + if event.Type != "disconnected" { + t.Errorf("expected disconnected, got %s", event.Type) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for disconnected event") + } +} + +func TestWSTransportSendProgress(t *testing.T) { + var receivedMsg map[string]interface{} + var mu sync.Mutex + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + // Read auth + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + // Read progress + _, msg, err := conn.ReadMessage() + if err != nil { + return + } + mu.Lock() + json.Unmarshal(msg, &receivedMsg) + mu.Unlock() + + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + time.Sleep(50 * time.Millisecond) + resp, err := tr.SendProgress(ctx, StatusUpdate{ + TaskID: "t1", + Status: "downloading", + Progress: 42, + }) + if err != nil { + t.Fatalf("SendProgress failed: %v", err) + } + if !resp.Success { + t.Error("expected success response") + } + + time.Sleep(100 * time.Millisecond) + mu.Lock() + if receivedMsg["type"] != "progress" { + t.Errorf("expected progress, got %v", receivedMsg["type"]) + } + if receivedMsg["taskId"] != "t1" { + t.Errorf("expected t1, got %v", receivedMsg["taskId"]) + } + mu.Unlock() +} + +// ── Hybrid Transport Tests ─────────────────────────────────────────────────── + +func TestHybridTransportWSSuccess(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + wsT := NewWSTransport(wsURL, "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + err := h.Connect(context.Background()) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer h.Close() + + if h.Mode() != "ws" { + t.Errorf("expected ws mode, got %s", h.Mode()) + } +} + +func TestHybridTransportWSFailFallbackHTTP(t *testing.T) { + // WS URL points to nowhere + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + err := h.Connect(context.Background()) + if err != nil { + t.Fatalf("Connect should succeed with HTTP fallback: %v", err) + } + defer h.Close() + + if h.Mode() != "http" { + t.Errorf("expected http mode after WS failure, got %s", h.Mode()) + } +} + +func TestHybridTransportWSDisconnectSwitchesToHTTP(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + // Close immediately to trigger disconnect + time.Sleep(100 * time.Millisecond) + conn.Close() + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + wsT := NewWSTransport(wsURL, "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + h.Connect(context.Background()) + defer h.Close() + + // Wait for disconnect event + select { + case event := <-h.Events(): + if event.Type != "disconnected" { + t.Errorf("expected disconnected, got %s", event.Type) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for disconnected event") + } + + // Mode should be HTTP now + time.Sleep(100 * time.Millisecond) + if h.Mode() != "http" { + t.Errorf("expected http after disconnect, got %s", h.Mode()) + } +} diff --git a/internal/agent/transport_ws.go b/internal/agent/transport_ws.go new file mode 100644 index 0000000..8625b7c --- /dev/null +++ b/internal/agent/transport_ws.go @@ -0,0 +1,360 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" +) + +// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object. +type WSTransport struct { + wsURL string // wss://unarr.torrentclaw.com/ws/{agentId} + apiKey string + agentID string + userAgent string + + conn *websocket.Conn + mu sync.Mutex + events chan ServerEvent + closed atomic.Bool + + // Cached auth response from the DO + authResp *RegisterResponse + authMu sync.Mutex + authDone chan struct{} + authDoneOnce sync.Once +} + +// NewWSTransport creates a WebSocket-based transport. +func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport { + return &WSTransport{ + wsURL: wsURL, + apiKey: apiKey, + agentID: agentID, + userAgent: userAgent, + events: make(chan ServerEvent, 50), + authDone: make(chan struct{}), + } +} + +func (t *WSTransport) Mode() string { return "ws" } +func (t *WSTransport) Events() <-chan ServerEvent { return t.events } + +// Connect dials the WebSocket server and starts the read loop. +func (t *WSTransport) Connect(ctx context.Context) error { + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + + header := http.Header{} + header.Set("User-Agent", t.userAgent) + + // Append API key as query param for auth on WS upgrade + wsURLWithKey := t.wsURL + if t.apiKey != "" { + sep := "?" + if strings.Contains(wsURLWithKey, "?") { + sep = "&" + } + wsURLWithKey += sep + "key=" + t.apiKey + } + + conn, _, err := dialer.DialContext(ctx, wsURLWithKey, header) + if err != nil { + return fmt.Errorf("ws dial: %w", err) + } + + t.mu.Lock() + t.conn = conn + t.closed.Store(false) + t.authDone = make(chan struct{}) + t.authDoneOnce = sync.Once{} + t.mu.Unlock() + + go t.readLoop() + return nil +} + +// Close sends a close frame and shuts down the connection. +func (t *WSTransport) Close() error { + t.closed.Store(true) + t.mu.Lock() + defer t.mu.Unlock() + if t.conn != nil { + _ = t.conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ) + err := t.conn.Close() + t.conn = nil + return err + } + return nil +} + +// Register sends auth message and waits for the registered response. +func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) { + msg := wsAuthMessage{ + Type: "auth", + APIKey: t.apiKey, + AgentID: req.AgentID, + Name: req.Name, + OS: req.OS, + Arch: req.Arch, + Version: req.Version, + DownloadDir: req.DownloadDir, + DiskFreeBytes: req.DiskFreeBytes, + DiskTotalBytes: req.DiskTotalBytes, + } + + if err := t.send(msg); err != nil { + return nil, fmt.Errorf("ws auth send: %w", err) + } + + // Wait for the auth response or context cancellation + select { + case <-t.authDone: + t.authMu.Lock() + resp := t.authResp + t.authMu.Unlock() + if resp == nil { + return nil, fmt.Errorf("ws auth: no response received") + } + return resp, nil + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(15 * time.Second): + return nil, fmt.Errorf("ws auth: timeout waiting for registered response") + } +} + +// SendHeartbeat sends a heartbeat message. No blocking response in WS mode. +func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) { + msg := struct { + Type string `json:"type"` + Disk *struct { + Free int64 `json:"free"` + Total int64 `json:"total"` + } `json:"disk,omitempty"` + }{Type: "heartbeat"} + + if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 { + msg.Disk = &struct { + Free int64 `json:"free"` + Total int64 `json:"total"` + }{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes} + } + + if err := t.send(msg); err != nil { + return nil, err + } + // WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events(). + return &HeartbeatResponse{Success: true}, nil +} + +// SendProgress sends a progress update. Control signals arrive async via Events(). +func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) { + msg := struct { + Type string `json:"type"` + TaskID string `json:"taskId"` + Status string `json:"status,omitempty"` + Progress int `json:"progress,omitempty"` + DownloadedBytes int64 `json:"downloadedBytes,omitempty"` + TotalBytes int64 `json:"totalBytes,omitempty"` + SpeedBps int64 `json:"speedBps,omitempty"` + ETA int `json:"eta,omitempty"` + ResolvedMethod string `json:"resolvedMethod,omitempty"` + FileName string `json:"fileName,omitempty"` + FilePath string `json:"filePath,omitempty"` + StreamURL string `json:"streamUrl,omitempty"` + ErrorMessage string `json:"errorMessage,omitempty"` + }{ + Type: "progress", + TaskID: update.TaskID, + Status: update.Status, + Progress: update.Progress, + DownloadedBytes: update.DownloadedBytes, + TotalBytes: update.TotalBytes, + SpeedBps: update.SpeedBps, + ETA: update.ETA, + ResolvedMethod: update.ResolvedMethod, + FileName: update.FileName, + FilePath: update.FilePath, + StreamURL: update.StreamURL, + ErrorMessage: update.ErrorMessage, + } + + if err := t.send(msg); err != nil { + return nil, err + } + // In WS mode, control signals come via Events(), not in the progress response. + return &StatusResponse{Success: true}, nil +} + +// ClaimTasks is a no-op in WS mode — tasks arrive via Events(). +func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) { + return &TasksResponse{}, nil +} + +// Deregister is handled by WebSocket close (DO detects disconnection). +func (t *WSTransport) Deregister(_ context.Context, _ string) error { + return t.Close() +} + +// ReportUpgradeResult sends upgrade result to the DO. +func (t *WSTransport) ReportUpgradeResult(_ context.Context, result UpgradeResult) error { + msg := struct { + Type string `json:"type"` + Success bool `json:"success"` + Version string `json:"version,omitempty"` + Error string `json:"error,omitempty"` + }{ + Type: "upgrade-result", + Success: result.Success, + Version: result.Version, + Error: result.Error, + } + return t.send(msg) +} + +// ── Internal ───────────────────────────────────────────────────────────────── + +func (t *WSTransport) send(msg any) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.conn == nil { + return fmt.Errorf("ws: not connected") + } + data, err := json.Marshal(msg) + if err != nil { + return err + } + return t.conn.WriteMessage(websocket.TextMessage, data) +} + +func (t *WSTransport) readLoop() { + for { + _, msg, err := t.conn.ReadMessage() + if err != nil { + if !t.closed.Load() { + log.Printf("[ws] read error: %v", err) + // Signal disconnection to the daemon + select { + case t.events <- ServerEvent{Type: "disconnected"}: + default: + } + } + return + } + + var envelope struct { + Type string `json:"type"` + } + if err := json.Unmarshal(msg, &envelope); err != nil { + log.Printf("[ws] invalid message: %v", err) + continue + } + + switch envelope.Type { + case "registered": + var resp wsRegisteredMessage + if json.Unmarshal(msg, &resp) == nil { + t.authMu.Lock() + t.authResp = &RegisterResponse{ + Success: true, + User: resp.User, + Features: resp.Features, + } + t.authMu.Unlock() + // Signal that auth is complete (sync.Once prevents double-close panic) + t.authDoneOnce.Do(func() { close(t.authDone) }) + } + + case "tasks": + var resp wsTasksMessage + if json.Unmarshal(msg, &resp) == nil { + select { + case t.events <- ServerEvent{ + Type: "tasks", + Tasks: &TasksResponse{ + Tasks: resp.Tasks, + StreamRequests: resp.StreamRequests, + }, + }: + default: + log.Printf("[ws] events channel full, dropping tasks message") + } + } + + case "upgrade": + var resp wsUpgradeMessage + if json.Unmarshal(msg, &resp) == nil { + select { + case t.events <- ServerEvent{ + Type: "upgrade", + Upgrade: &UpgradeSignal{Version: resp.Version}, + }: + default: + } + } + + case "control": + var resp ControlAction + if json.Unmarshal(msg, &resp) == nil { + select { + case t.events <- ServerEvent{ + Type: "control", + Control: &resp, + }: + default: + } + } + + case "error": + var resp struct{ Message string `json:"message"` } + if json.Unmarshal(msg, &resp) == nil { + log.Printf("[ws] server error: %s", resp.Message) + } + } + } +} + +// ── WS message types ───────────────────────────────────────────────────────── + +type wsAuthMessage struct { + Type string `json:"type"` + APIKey string `json:"apiKey"` + AgentID string `json:"agentId"` + Name string `json:"name,omitempty"` + OS string `json:"os,omitempty"` + Arch string `json:"arch,omitempty"` + Version string `json:"version,omitempty"` + DownloadDir string `json:"downloadDir,omitempty"` + DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"` + DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"` +} + +type wsRegisteredMessage struct { + Type string `json:"type"` + User UserInfo `json:"user"` + Features FeatureFlags `json:"features"` +} + +type wsTasksMessage struct { + Type string `json:"type"` + Tasks []Task `json:"tasks"` + StreamRequests []StreamRequest `json:"streamRequests,omitempty"` +} + +type wsUpgradeMessage struct { + Type string `json:"type"` + Version string `json:"version"` +} diff --git a/internal/cmd/daemon.go b/internal/cmd/daemon.go index f904cfa..e2cc59a 100644 --- a/internal/cmd/daemon.go +++ b/internal/cmd/daemon.go @@ -7,6 +7,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" "time" @@ -104,8 +105,6 @@ func runDaemonStart() error { heartbeatInterval = 30 * time.Second } - // Create agent client (direct HTTP — always available as fallback) - ac := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, "unarr/"+Version) userAgent := "unarr/" + Version // Create daemon config @@ -119,6 +118,8 @@ func runDaemonStart() error { } // Create transport: Hybrid (WS + HTTP fallback) or HTTP-only + httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent) + wsURL := cfg.Auth.WSURL if wsURL == "" { wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID) @@ -126,28 +127,19 @@ func runDaemonStart() error { var transport agent.Transport if wsURL != "" { - httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent) wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent) transport = agent.NewHybridTransport(wsT, httpT) log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL) + } else { + transport = httpT + log.Println("Transport: HTTP only") } - // Create daemon - var d *agent.Daemon - if transport != nil { - d = agent.NewDaemonWithTransport(daemonCfg, transport) - } else { - d = agent.NewDaemon(daemonCfg, ac) - } + // Create daemon — always uses Transport interface + d := agent.NewDaemon(daemonCfg, transport) - // Wire state tracking (connected after manager creation below) - // Create progress reporter - var reporter *engine.ProgressReporter - if transport != nil { - reporter = engine.NewProgressReporterWithTransport(transport, 3*time.Second) - } else { - reporter = engine.NewProgressReporter(ac, 3*time.Second) - } + // Create progress reporter using transport + reporter := engine.NewProgressReporterWithTransport(transport, 3*time.Second) // Parse speed limits maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed) @@ -190,7 +182,7 @@ func runDaemonStart() error { MoviesDir: cfg.Organize.MoviesDir, TVShowsDir: cfg.Organize.TVShowsDir, }, - }, reporter, torrentDl, debridDl) + }, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client())) // Wire state tracking d.GetActiveCount = manager.ActiveCount @@ -275,9 +267,9 @@ func runDaemonStart() error { log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(sr.FilePath), streamURL) - // Report stream URL back to the server + // Report stream URL back to the server via transport go func() { - if _, err := ac.ReportStatus(ctx, agent.StatusUpdate{ + if _, err := transport.SendProgress(ctx, agent.StatusUpdate{ TaskID: sr.TaskID, StreamURL: streamURL, }); err != nil { @@ -298,13 +290,18 @@ func runDaemonStart() error { case "resume": log.Printf("[%s] resume requested via WebSocket", taskID[:8]) case "stream": + // Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests + streamRegistry.mu.Lock() + if _, exists := streamRegistry.servers[taskID]; exists { + streamRegistry.mu.Unlock() + return + } task := manager.GetTask(taskID) - if task == nil { - return - } - if task.GetStreamURL() != "" { + if task == nil || task.GetStreamURL() != "" { + streamRegistry.mu.Unlock() return } + streamRegistry.mu.Unlock() srv, err := torrentDl.StartStream(taskID) if err != nil { log.Printf("[%s] stream failed: %v", taskID[:8], err) @@ -342,11 +339,7 @@ func runDaemonStart() error { Version: result.NewVersion, Error: errMsg, } - if transport != nil { - _ = transport.ReportUpgradeResult(reportCtx, upgradeResult) - } else { - _ = ac.ReportUpgradeResult(reportCtx, upgradeResult) - } + _ = transport.ReportUpgradeResult(reportCtx, upgradeResult) if !result.Success { log.Printf("Upgrade failed: %v", result.Error) @@ -360,7 +353,7 @@ func runDaemonStart() error { // Deregister first so the server knows we're restarting deregCtx, deregCancel := context.WithTimeout(context.Background(), 5*time.Second) defer deregCancel() - _ = ac.Deregister(deregCtx, cfg.Agent.ID) + _ = transport.Deregister(deregCtx, cfg.Agent.ID) // Flush progress reporter cancel() @@ -418,6 +411,7 @@ func runDaemonStart() error { // deriveWSURL derives a WebSocket URL from the API URL. // https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId} +// Returns "" for localhost/dev environments where WS gateway isn't available. func deriveWSURL(apiURL, agentID string) string { if apiURL == "" || agentID == "" { return "" @@ -437,6 +431,15 @@ func deriveWSURL(apiURL, agentID string) string { break } } + // Strip port if present + if idx := strings.LastIndex(domain, ":"); idx > 0 { + domain = domain[:idx] + } + + // Skip WS for localhost/dev — gateway only available in production + if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" { + return "" + } return "wss://unarr." + domain + "/ws/" + agentID } diff --git a/internal/config/config.go b/internal/config/config.go index 00d9894..2de9f9b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,6 +25,7 @@ type Config struct { type AuthConfig struct { APIKey string `toml:"api_key"` APIURL string `toml:"api_url"` + WSURL string `toml:"ws_url"` // optional, derived from api_url if empty } type AgentConfig struct { diff --git a/internal/engine/progress.go b/internal/engine/progress.go index 04d9552..e3c15fb 100644 --- a/internal/engine/progress.go +++ b/internal/engine/progress.go @@ -12,11 +12,17 @@ import ( // ActionFunc is called when the server signals an action on a task. type ActionFunc func(taskID string) +// StatusReporter is the interface used by ProgressReporter to send progress updates. +// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods. +type StatusReporter interface { + ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) +} + // ProgressReporter aggregates progress from downloads and reports to the API. // It batches updates to avoid flooding the server. type ProgressReporter struct { - agentClient *agent.Client - interval time.Duration + reporter StatusReporter + interval time.Duration onCancel ActionFunc onPause ActionFunc @@ -28,14 +34,33 @@ type ProgressReporter struct { } // NewProgressReporter creates a reporter that flushes every interval. +// Accepts *agent.Client directly (backwards compatible). func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter { return &ProgressReporter{ - agentClient: ac, - interval: interval, - latest: make(map[string]*Task), + reporter: ac, + interval: interval, + latest: make(map[string]*Task), } } +// NewProgressReporterWithTransport creates a reporter using a Transport. +func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter { + return &ProgressReporter{ + reporter: &transportStatusAdapter{t: t}, + interval: interval, + latest: make(map[string]*Task), + } +} + +// transportStatusAdapter adapts agent.Transport to StatusReporter. +type transportStatusAdapter struct { + t agent.Transport +} + +func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) { + return a.t.SendProgress(ctx, update) +} + // SetCancelHandler sets the callback invoked when the server says a task is cancelled. func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn } @@ -95,7 +120,7 @@ func (r *ProgressReporter) flush(ctx context.Context) { } update := task.ToStatusUpdate() - resp, err := r.agentClient.ReportStatus(ctx, update) + resp, err := r.reporter.ReportStatus(ctx, update) if err != nil { log.Printf("[%s] progress report failed: %v", task.ID[:8], err) continue @@ -130,7 +155,7 @@ func (r *ProgressReporter) flush(ctx context.Context) { // ReportFinal sends a final status update for a completed/failed task. func (r *ProgressReporter) ReportFinal(ctx context.Context, task *Task) { update := task.ToStatusUpdate() - if _, err := r.agentClient.ReportStatus(ctx, update); err != nil { + if _, err := r.reporter.ReportStatus(ctx, update); err != nil { log.Printf("[%s] final report failed: %v", task.ID[:8], err) } r.Untrack(task.ID)