From 5d4a67c7a2e6bdccdbac1718a9f2f33f0b159ab0 Mon Sep 17 00:00:00 2001 From: Deivid Soto Date: Wed, 8 Apr 2026 18:50:59 +0200 Subject: [PATCH] feat(sync): replace WS+DO transport with unified HTTP sync Replace the WebSocket + Cloudflare Durable Object architecture with a single POST /sync endpoint. The CLI now operates autonomously with local state (tasks.json) and syncs bidirectionally via adaptive-interval HTTP polling (3s watching, 60s idle). - Remove transport_ws, transport_hybrid, transport_http (~2,600 lines) - Add SyncClient with adaptive interval loop - Add LocalState for CLI-side task persistence - Add TaskStateFromUpdate() helper (DRY) - Extract finalize() to deduplicate processTask/processTaskRetry - Consolidate shortID() into agent.ShortID (was in 3 packages) - Wire GetActiveCount so `unarr status` shows active tasks - Remove poll_interval, heartbeat_interval, ws_url from config - Simplify ProgressReporter (sync replaces direct HTTP reporting) --- CHANGELOG.md | 11 + go.mod | 2 +- internal/agent/client.go | 31 +- internal/agent/client_test.go | 122 +- internal/agent/daemon.go | 285 ++--- internal/agent/sync.go | 195 ++++ internal/agent/sync_test.go | 362 ++++++ internal/agent/taskstate.go | 136 +++ internal/agent/taskstate_test.go | 217 ++++ internal/agent/transport.go | 51 - internal/agent/transport_e2e_test.go | 285 ----- internal/agent/transport_http.go | 50 - internal/agent/transport_hybrid.go | 214 ---- internal/agent/transport_test.go | 1590 -------------------------- internal/agent/transport_ws.go | 395 ------- internal/agent/types.go | 68 +- internal/cmd/config_menu.go | 19 +- internal/cmd/daemon.go | 394 +++---- internal/cmd/daemon_test.go | 26 - internal/cmd/reload_unix.go | 19 +- internal/cmd/version.go | 2 +- internal/config/config.go | 10 +- internal/config/config_test.go | 4 +- internal/engine/debrid.go | 25 +- internal/engine/manager.go | 185 +-- internal/engine/progress.go | 22 - 26 files changed, 1320 insertions(+), 3400 deletions(-) create mode 100644 internal/agent/sync.go create mode 100644 internal/agent/sync_test.go create mode 100644 internal/agent/taskstate.go create mode 100644 internal/agent/taskstate_test.go delete mode 100644 internal/agent/transport.go delete mode 100644 internal/agent/transport_e2e_test.go delete mode 100644 internal/agent/transport_http.go delete mode 100644 internal/agent/transport_hybrid.go delete mode 100644 internal/agent/transport_test.go delete mode 100644 internal/agent/transport_ws.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 89e484d..18d0125 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.6] - 2026-04-07 + + +### Fixed + +- **ws**: add ping/pong keepalive and read deadline to detect zombie connections ## [0.5.5] - 2026-04-07 @@ -17,6 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - **daemon**: cancel watch reporter on stream switch and re-notify ready + +### Other + +- **release**: 0.5.5 ## [0.5.4] - 2026-04-07 @@ -153,6 +163,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - remove UPX compression (antivirus false positives, startup penalty) - add -s -w -trimpath to Makefile, add build-small target with UPX +[0.5.6]: https://github.com/torrentclaw/unarr/compare/v0.5.5...v0.5.6 [0.5.5]: https://github.com/torrentclaw/unarr/compare/v0.5.4...v0.5.5 [0.5.4]: https://github.com/torrentclaw/unarr/compare/v0.5.3...v0.5.4 [0.5.3]: https://github.com/torrentclaw/unarr/compare/v0.5.2...v0.5.3 diff --git a/go.mod b/go.mod index 5457304..6439955 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/fatih/color v1.19.0 github.com/getsentry/sentry-go v0.44.1 github.com/google/uuid v1.6.0 - github.com/gorilla/websocket v1.5.3 github.com/huin/goupnp v1.3.0 github.com/olekukonko/tablewriter v1.1.4 github.com/spf13/cobra v1.10.2 @@ -69,6 +68,7 @@ require ( github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/pprof v0.0.0-20260302011040-a15ffb7f9dcc // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect diff --git a/internal/agent/client.go b/internal/agent/client.go index b437e9e..fe4e04a 100644 --- a/internal/agent/client.go +++ b/internal/agent/client.go @@ -40,27 +40,6 @@ func (c *Client) Register(ctx context.Context, req RegisterRequest) (*RegisterRe return &resp, nil } -// Heartbeat sends a periodic keep-alive signal and returns server directives. -func (c *Client) Heartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) { - var resp HeartbeatResponse - if err := c.doPost(ctx, "/api/internal/agent/heartbeat", req, &resp); err != nil { - return nil, fmt.Errorf("heartbeat: %w", err) - } - return &resp, nil -} - -// ClaimTasks polls for pending download tasks and claims them atomically. -// Also returns any stream requests for completed downloads. -func (c *Client) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) { - url := fmt.Sprintf("/api/internal/agent/tasks?agentId=%s", agentID) - var resp TasksResponse - if err := c.doGet(ctx, url, &resp); err != nil { - return nil, fmt.Errorf("claim tasks: %w", err) - } - return &resp, nil -} - -// ReportStatus reports download progress or completion for a task. // Deregister notifies the server that the agent is shutting down. func (c *Client) Deregister(ctx context.Context, agentID string) error { req := struct { @@ -91,6 +70,16 @@ func (c *Client) BatchReportStatus(ctx context.Context, updates []StatusUpdate) return &resp, nil } +// Sync sends the CLI's full state and receives all pending server actions. +// This is the single endpoint for bidirectional state synchronization. +func (c *Client) Sync(ctx context.Context, req SyncRequest) (*SyncResponse, error) { + var resp SyncResponse + if err := c.doPost(ctx, "/api/internal/agent/sync", req, &resp); err != nil { + return nil, fmt.Errorf("sync: %w", err) + } + return &resp, nil +} + // --------------------------------------------------------------------------- // Usenet endpoints // --------------------------------------------------------------------------- diff --git a/internal/agent/client_test.go b/internal/agent/client_test.go index c7ff470..c78b9ba 100644 --- a/internal/agent/client_test.go +++ b/internal/agent/client_test.go @@ -72,70 +72,6 @@ func TestRegister(t *testing.T) { } } -func TestHeartbeat(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/api/internal/agent/heartbeat" { - t.Errorf("path = %s, want /api/internal/agent/heartbeat", r.URL.Path) - } - var req HeartbeatRequest - json.NewDecoder(r.Body).Decode(&req) - if req.AgentID != "agent-123" { - t.Errorf("agentId = %q, want agent-123", req.AgentID) - } - json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) - })) - defer srv.Close() - - c := NewClient(srv.URL, "test-key", "unarr-test") - resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"}) - if err != nil { - t.Fatalf("Heartbeat failed: %v", err) - } - if !resp.Success { - t.Error("expected success=true") - } -} - -func TestClaimTasks(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - t.Errorf("method = %s, want GET", r.Method) - } - if r.URL.Query().Get("agentId") != "agent-123" { - t.Errorf("agentId param = %q, want agent-123", r.URL.Query().Get("agentId")) - } - json.NewEncoder(w).Encode(TasksResponse{ - Tasks: []Task{ - { - ID: "task-uuid-1", - InfoHash: "abc123def456abc123def456abc123def456abc1", - Title: "The Matrix (1999)", - PreferredMethod: "auto", - }, - }, - }) - })) - defer srv.Close() - - c := NewClient(srv.URL, "test-key", "unarr-test") - resp, err := c.ClaimTasks(context.Background(), "agent-123") - if err != nil { - t.Fatalf("ClaimTasks failed: %v", err) - } - if len(resp.Tasks) != 1 { - t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks)) - } - if resp.Tasks[0].ID != "task-uuid-1" { - t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID) - } - if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" { - t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash) - } - if resp.Tasks[0].PreferredMethod != "auto" { - t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod) - } -} - func TestReportStatus(t *testing.T) { var received StatusUpdate srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -173,22 +109,6 @@ func TestReportStatus(t *testing.T) { } } -func TestClaimTasksEmpty(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(TasksResponse{Tasks: []Task{}}) - })) - defer srv.Close() - - c := NewClient(srv.URL, "test-key", "unarr-test") - resp, err := c.ClaimTasks(context.Background(), "agent-123") - if err != nil { - t.Fatalf("ClaimTasks failed: %v", err) - } - if len(resp.Tasks) != 0 { - t.Errorf("expected empty tasks, got %d", len(resp.Tasks)) - } -} - func TestAPIError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -279,50 +199,12 @@ func TestUserAgent(t *testing.T) { if r.Header.Get("User-Agent") != "unarr/0.2.0" { t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent")) } - json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) + json.NewEncoder(w).Encode(RegisterResponse{Success: true}) })) defer srv.Close() c := NewClient(srv.URL, "test-key", "unarr/0.2.0") - c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"}) -} - -func TestHeartbeatWithUpgradeSignal(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(HeartbeatResponse{ - Success: true, - Upgrade: &UpgradeSignal{Version: "2.0.0"}, - }) - })) - defer srv.Close() - - c := NewClient(srv.URL, "test-key", "unarr-test") - resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"}) - if err != nil { - t.Fatalf("Heartbeat failed: %v", err) - } - if resp.Upgrade == nil { - t.Fatal("expected upgrade signal, got nil") - } - if resp.Upgrade.Version != "2.0.0" { - t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version) - } -} - -func TestHeartbeatWithoutUpgradeSignal(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) - })) - defer srv.Close() - - c := NewClient(srv.URL, "test-key", "unarr-test") - resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"}) - if err != nil { - t.Fatalf("Heartbeat failed: %v", err) - } - if resp.Upgrade != nil { - t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade) - } + c.Register(context.Background(), RegisterRequest{AgentID: "x"}) } func TestDeregister(t *testing.T) { diff --git a/internal/agent/daemon.go b/internal/agent/daemon.go index af967c4..225dde9 100644 --- a/internal/agent/daemon.go +++ b/internal/agent/daemon.go @@ -14,75 +14,62 @@ import ( // DaemonConfig holds daemon runtime settings. type DaemonConfig struct { - AgentID string - AgentName string - Version string - DownloadDir string - PollInterval time.Duration - HeartbeatInterval time.Duration - StreamPort int // port for the HTTP stream server (reported in heartbeat) - LanIP string // LAN IP (reported in heartbeat for stream URL resolution) - TailscaleIP string // Tailscale IP (reported in heartbeat for stream URL resolution) + AgentID string + AgentName string + Version string + DownloadDir string + StreamPort int // port for the HTTP stream server + LanIP string // LAN IP (reported in sync for stream URL resolution) + TailscaleIP string // Tailscale IP (reported in sync for stream URL resolution) } -// Daemon manages the main loop: register, heartbeat, poll tasks. +// Daemon manages agent registration and the sync loop. type Daemon struct { - cfg DaemonConfig - transport Transport + cfg DaemonConfig + client *Client + sync *SyncClient + state *LocalState - // Callbacks + // Callbacks — set by cmd/daemon.go before calling Run. OnTasksClaimed func(tasks []Task) OnStreamRequested func(req StreamRequest) - OnControlAction func(action, taskID string) + OnControlAction func(action, taskID string, deleteFiles bool) + GetActiveCount func() int // returns number of active downloads (wired from manager) // State User UserInfo Features FeatureFlags Info AgentInfo State DaemonState - heartbeatFailures int lastNotifiedVersion string - // Callbacks for state tracking (set by cmd/daemon.go) - GetActiveCount func() int - GetCleanableBytes func() int64 - // Watching tracks whether a user is viewing download progress in the web UI. - // When false, the progress reporter skips detailed updates (only sends final states). - // Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic. Watching atomic.Bool - // Exposed tickers for hot-reload - PollTicker *time.Ticker - HeartbeatTicker *time.Ticker - - // pollNow triggers an immediate poll (e.g. on resume) - pollNow chan struct{} - - // ScanNow triggers an immediate library scan (from heartbeat or WebSocket control event) + // ScanNow triggers an immediate library scan. ScanNow chan struct{} } -// NewDaemon creates a daemon with the given transport. -// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP. -func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon { - if cfg.PollInterval == 0 { - cfg.PollInterval = 30 * time.Second - } - if cfg.HeartbeatInterval == 0 { - cfg.HeartbeatInterval = 30 * time.Second - } - +// NewDaemon creates a daemon with an HTTP client for sync-based communication. +func NewDaemon(cfg DaemonConfig, client *Client) *Daemon { + state := NewLocalState() return &Daemon{ - cfg: cfg, - transport: transport, - pollNow: make(chan struct{}, 1), - ScanNow: make(chan struct{}, 1), + cfg: cfg, + client: client, + state: state, + sync: NewSyncClient(client, cfg, state), + ScanNow: make(chan struct{}, 1), } } -// Transport returns the configured transport. -func (d *Daemon) Transport() Transport { return d.transport } +// SyncClient returns the sync client for external wiring. +func (d *Daemon) SyncClient() *SyncClient { return d.sync } + +// UpdateStreamPort updates the stream port reported in sync requests. +func (d *Daemon) UpdateStreamPort(port int) { + d.cfg.StreamPort = port + d.sync.cfg.StreamPort = port +} // Register registers the agent and fetches user info + features. // Retries with exponential backoff on transient errors (429, 5xx, network). @@ -109,11 +96,10 @@ func (d *Daemon) Register(ctx context.Context) error { var resp *RegisterResponse var err error for attempt := range maxRetries { - resp, err = d.transport.Register(ctx, req) + resp, err = d.client.Register(ctx, req) if err == nil { break } - // Only retry on transient errors (429, 5xx, network failures) if !isTransientError(err) { return fmt.Errorf("register: %w", err) } @@ -154,14 +140,9 @@ func (d *Daemon) Register(ctx context.Context) error { return nil } -// Run connects the transport, registers the agent, and starts the main loop. -// Blocks until ctx is cancelled. Callers must NOT call transport.Connect before Run. +// Run registers the agent and starts the sync loop. +// Blocks until ctx is cancelled. func (d *Daemon) Run(ctx context.Context) error { - // Connect transport (establishes WebSocket if available, falls back to HTTP) - if err := d.transport.Connect(ctx); err != nil { - return fmt.Errorf("connect transport: %w", err) - } - // Register if err := d.Register(ctx); err != nil { return err @@ -169,163 +150,61 @@ func (d *Daemon) Run(ctx context.Context) error { log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan) log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet) - log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval) - d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval) - defer d.HeartbeatTicker.Stop() - - d.PollTicker = time.NewTicker(d.cfg.PollInterval) - defer d.PollTicker.Stop() - - heartbeatTicker := d.HeartbeatTicker - pollTicker := d.PollTicker - - // Initial poll immediately - d.poll(ctx) - - eventsCh := d.transport.Events() - - for { - select { - case <-ctx.Done(): - log.Println("Daemon shutting down...") - d.deregister() - return nil - - case event := <-eventsCh: - d.handleEvent(event) - - case <-heartbeatTicker.C: - d.heartbeat(ctx) - - case <-pollTicker.C: - // Only poll in HTTP mode — WS mode receives tasks via Events - if d.transport.Mode() == "http" { - d.poll(ctx) - } - - case <-d.pollNow: - d.poll(ctx) + // Wire sync callbacks + d.sync.OnNewTasks = func(tasks []Task) { + if d.OnTasksClaimed != nil { + d.OnTasksClaimed(tasks) } } -} - -func (d *Daemon) heartbeat(ctx context.Context) { - req := HeartbeatRequest{ - AgentID: d.cfg.AgentID, - Name: d.cfg.AgentName, - Version: d.cfg.Version, - OS: runtime.GOOS, - DownloadDir: d.cfg.DownloadDir, - StreamPort: d.cfg.StreamPort, - LanIP: d.cfg.LanIP, - TailscaleIP: d.cfg.TailscaleIP, - } - if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil { - req.DiskFreeBytes = free - req.DiskTotalBytes = total - } - - resp, err := d.transport.SendHeartbeat(ctx, req) - if err != nil { - d.heartbeatFailures++ - if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 { - log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures) - } else { - log.Printf("Heartbeat failed: %v", err) + d.sync.OnControl = func(action, taskID string, deleteFiles bool) { + if d.OnControlAction != nil { + d.OnControlAction(action, taskID, deleteFiles) } - return } - if d.heartbeatFailures > 0 { - log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures) - d.heartbeatFailures = 0 + d.sync.OnStreamRequest = func(req StreamRequest) { + if d.OnStreamRequested != nil { + d.OnStreamRequested(req) + } } - - // Update watching flag and state file - d.Watching.Store(resp.Watching) - d.State.LastHeartbeat = time.Now() - if d.GetActiveCount != nil { - d.State.ActiveTasks = d.GetActiveCount() + d.sync.OnUpgrade = func(version string) { + if version != d.lastNotifiedVersion { + d.lastNotifiedVersion = version + log.Printf("New version available: %s (run `unarr self-update` to upgrade)", version) + } } - WriteState(&d.State) - - // Trigger library scan if requested - if resp.Scan { + d.sync.OnScan = func() { log.Printf("Library scan requested by server") select { case d.ScanNow <- struct{}{}: - default: // scan already pending + default: } } - - // Log once per version when server suggests an upgrade - if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion { - d.lastNotifiedVersion = resp.Upgrade.Version - log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version) + d.sync.OnWatchingChange = func(watching bool) { + d.Watching.Store(watching) } -} - -// handleEvent processes a server-initiated event from the WebSocket transport. -func (d *Daemon) handleEvent(event ServerEvent) { - switch event.Type { - case "tasks": - if event.Tasks != nil && len(event.Tasks.Tasks) > 0 { - log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks)) - if d.OnTasksClaimed != nil { - d.OnTasksClaimed(event.Tasks.Tasks) - } + d.sync.OnSyncSuccess = func() { + d.State.LastHeartbeat = time.Now() + if d.GetActiveCount != nil { + d.State.ActiveTasks = d.GetActiveCount() } - if event.Tasks != nil && d.OnStreamRequested != nil { - for _, sr := range event.Tasks.StreamRequests { - d.OnStreamRequested(sr) - } - } - - case "upgrade": - if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion { - d.lastNotifiedVersion = event.Upgrade.Version - log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version) - } - - case "control": - if event.Control != nil { - log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID) - if event.Control.Action == "scan" { - select { - case d.ScanNow <- struct{}{}: - default: - } - } - if d.OnControlAction != nil { - d.OnControlAction(event.Control.Action, event.Control.TaskID) - } - } - - case "disconnected": - log.Println("WebSocket disconnected, switching to HTTP polling") + WriteState(&d.State) } + + // Start sync loop (blocks) + return d.sync.Run(ctx) } -// UpdateStreamPort updates the stream port reported in heartbeats. -// Called after the persistent stream server binds (actual port may differ from configured). -func (d *Daemon) UpdateStreamPort(port int) { - d.cfg.StreamPort = port +// TriggerSync requests an immediate sync cycle. +func (d *Daemon) TriggerSync() { + d.sync.TriggerSync() } -// TriggerPoll requests an immediate task poll cycle. -// Used when a resume event is received to pick up re-pending tasks faster. -func (d *Daemon) TriggerPoll() { - select { - case d.pollNow <- struct{}{}: - default: // already pending - } -} - -func (d *Daemon) deregister() { +// Deregister notifies the server of graceful shutdown. +func (d *Daemon) Deregister() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err := d.transport.Deregister(ctx, d.cfg.AgentID) - if err != nil { + if err := d.client.Deregister(ctx, d.cfg.AgentID); err != nil { log.Printf("Deregister failed: %v", err) } else { log.Println("Agent deregistered") @@ -338,12 +217,10 @@ func isTransientError(err error) bool { if err == nil { return false } - // Structured check: HTTPError carries the status code directly var httpErr *HTTPError if errors.As(err, &httpErr) { return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500 } - // Fallback: network-level errors (no HTTP response received) lower := strings.ToLower(err.Error()) for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} { if strings.Contains(lower, keyword) { @@ -352,27 +229,3 @@ func isTransientError(err error) bool { } return false } - -func (d *Daemon) poll(ctx context.Context) { - resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID) - if err != nil { - log.Printf("Poll failed: %v", err) - return - } - - d.Info.LastPollAt = time.Now() - - if len(resp.Tasks) > 0 { - log.Printf("Claimed %d task(s)", len(resp.Tasks)) - if d.OnTasksClaimed != nil { - d.OnTasksClaimed(resp.Tasks) - } - } - - // Handle stream requests for completed downloads - if d.OnStreamRequested != nil { - for _, sr := range resp.StreamRequests { - d.OnStreamRequested(sr) - } - } -} diff --git a/internal/agent/sync.go b/internal/agent/sync.go new file mode 100644 index 0000000..70129d4 --- /dev/null +++ b/internal/agent/sync.go @@ -0,0 +1,195 @@ +package agent + +import ( + "context" + "log" + "runtime" + "sync/atomic" + "time" +) + +const ( + // SyncIntervalWatching is the sync interval when someone is viewing the web UI. + SyncIntervalWatching = 3 * time.Second + // SyncIntervalIdle is the sync interval when nobody is watching. + SyncIntervalIdle = 60 * time.Second +) + +// SyncClient handles bidirectional state synchronization between the CLI and server. +// It sends the CLI's full execution state and receives all pending server actions +// in a single HTTP round-trip, at an adaptive interval. +type SyncClient struct { + client *Client + cfg DaemonConfig + state *LocalState + + // Callbacks — set by the daemon before calling Run. + OnNewTasks func(tasks []Task) + OnControl func(action, taskID string, deleteFiles bool) + OnStreamRequest func(req StreamRequest) + OnUpgrade func(version string) + OnScan func() + OnWatchingChange func(watching bool) + OnSyncSuccess func() // called after each successful sync (e.g. to update state file) + GetFreeSlots func() int + GetTaskStates func() []TaskState // returns current state of all active + recently finished tasks + + // SyncNow triggers an immediate sync (e.g., on task completion). + SyncNow chan struct{} + + watching atomic.Bool + interval atomic.Int64 // stored as nanoseconds +} + +// NewSyncClient creates a sync client. +func NewSyncClient(client *Client, cfg DaemonConfig, state *LocalState) *SyncClient { + sc := &SyncClient{ + client: client, + cfg: cfg, + state: state, + SyncNow: make(chan struct{}, 1), + } + sc.interval.Store(int64(SyncIntervalIdle)) + return sc +} + +// Watching returns whether someone is viewing the web UI. +func (sc *SyncClient) Watching() bool { + return sc.watching.Load() +} + +// TriggerSync requests an immediate sync cycle. +func (sc *SyncClient) TriggerSync() { + select { + case sc.SyncNow <- struct{}{}: + default: + } +} + +// Run starts the adaptive sync loop. Blocks until ctx is cancelled. +func (sc *SyncClient) Run(ctx context.Context) error { + // Initial sync immediately + sc.doSync(ctx) + + ticker := time.NewTicker(sc.currentInterval()) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + // Final sync to report latest state + finalCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + sc.doSync(finalCtx) + return nil + + case <-ticker.C: + sc.doSync(ctx) + ticker.Reset(sc.currentInterval()) + + case <-sc.SyncNow: + sc.doSync(ctx) + ticker.Reset(sc.currentInterval()) + } + } +} + +func (sc *SyncClient) currentInterval() time.Duration { + return time.Duration(sc.interval.Load()) +} + +func (sc *SyncClient) doSync(ctx context.Context) { + req := sc.buildRequest() + resp, err := sc.client.Sync(ctx, req) + if err != nil { + if ctx.Err() == nil { + log.Printf("sync failed: %v", err) + } + return + } + sc.processResponse(resp) + sc.adjustInterval(resp.Watching) + if sc.OnSyncSuccess != nil { + sc.OnSyncSuccess() + } +} + +func (sc *SyncClient) buildRequest() SyncRequest { + req := SyncRequest{ + AgentID: sc.cfg.AgentID, + Name: sc.cfg.AgentName, + Version: sc.cfg.Version, + OS: runtime.GOOS, + Arch: runtime.GOARCH, + DownloadDir: sc.cfg.DownloadDir, + StreamPort: sc.cfg.StreamPort, + LanIP: sc.cfg.LanIP, + TailscaleIP: sc.cfg.TailscaleIP, + } + if sc.GetTaskStates != nil { + req.Tasks = sc.GetTaskStates() + } else { + req.Tasks = sc.state.Snapshot() + } + if free, total, err := DiskInfo(sc.cfg.DownloadDir); err == nil { + req.DiskFreeBytes = free + req.DiskTotalBytes = total + } + if sc.GetFreeSlots != nil { + req.FreeSlots = sc.GetFreeSlots() + } + return req +} + +func (sc *SyncClient) processResponse(resp *SyncResponse) { + // New tasks + if len(resp.NewTasks) > 0 && sc.OnNewTasks != nil { + log.Printf("sync: received %d new task(s)", len(resp.NewTasks)) + sc.OnNewTasks(resp.NewTasks) + } + + // Control signals + for _, ctrl := range resp.Controls { + log.Printf("sync: control %s on task %s", ctrl.Action, ShortID(ctrl.TaskID)) + if sc.OnControl != nil { + sc.OnControl(ctrl.Action, ctrl.TaskID, ctrl.DeleteFiles) + } + } + + // Stream requests + for _, sr := range resp.StreamRequests { + if sc.OnStreamRequest != nil { + sc.OnStreamRequest(sr) + } + } + + // Upgrade + if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil { + sc.OnUpgrade(resp.Upgrade.Version) + } + + // Scan + if resp.Scan && sc.OnScan != nil { + sc.OnScan() + } +} + +func (sc *SyncClient) adjustInterval(watching bool) { + prev := sc.watching.Load() + sc.watching.Store(watching) + + var newInterval time.Duration + if watching { + newInterval = SyncIntervalWatching + } else { + newInterval = SyncIntervalIdle + } + + if sc.interval.Swap(int64(newInterval)) != int64(newInterval) { + log.Printf("sync: interval=%s (watching=%v)", newInterval, watching) + } + + if prev != watching && sc.OnWatchingChange != nil { + sc.OnWatchingChange(watching) + } +} diff --git a/internal/agent/sync_test.go b/internal/agent/sync_test.go new file mode 100644 index 0000000..ad3d9de --- /dev/null +++ b/internal/agent/sync_test.go @@ -0,0 +1,362 @@ +package agent + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func newTestSyncClient(url string) (*SyncClient, *Client) { + client := NewClient(url, "test-key", "test-agent/1.0") + cfg := DaemonConfig{ + AgentID: "test-agent", + AgentName: "Test", + Version: "1.0.0", + DownloadDir: "/tmp/downloads", + } + state := NewLocalState() + sc := NewSyncClient(client, cfg, state) + return sc, client +} + +func TestSyncClient_NewDefaults(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + if sc.Watching() { + t.Error("should not be watching initially") + } + if sc.currentInterval() != SyncIntervalIdle { + t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval()) + } +} + +func TestSyncClient_AdjustInterval_Watching(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + sc.adjustInterval(true) + + if sc.currentInterval() != SyncIntervalWatching { + t.Errorf("expected watching interval %v, got %v", SyncIntervalWatching, sc.currentInterval()) + } + if !sc.Watching() { + t.Error("expected watching=true") + } +} + +func TestSyncClient_AdjustInterval_NotWatching(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + // First set watching, then unset + sc.adjustInterval(true) + sc.adjustInterval(false) + + if sc.currentInterval() != SyncIntervalIdle { + t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval()) + } + if sc.Watching() { + t.Error("expected watching=false") + } +} + +func TestSyncClient_AdjustInterval_CallsOnWatchingChange(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var changes []bool + sc.OnWatchingChange = func(w bool) { changes = append(changes, w) } + + sc.adjustInterval(true) + sc.adjustInterval(true) // no change + sc.adjustInterval(false) // change + + if len(changes) != 2 { + t.Fatalf("expected 2 changes, got %d: %v", len(changes), changes) + } + if !changes[0] { + t.Error("first change should be true") + } + if changes[1] { + t.Error("second change should be false") + } +} + +func TestSyncClient_TriggerSync_NonBlocking(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + // Fill the channel + sc.TriggerSync() + // Should not block + sc.TriggerSync() + sc.TriggerSync() + + // Drain + select { + case <-sc.SyncNow: + default: + t.Error("expected a sync trigger in channel") + } +} + +func TestSyncClient_ProcessResponse_NewTasks(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var received []Task + sc.OnNewTasks = func(tasks []Task) { received = tasks } + + sc.processResponse(&SyncResponse{ + NewTasks: []Task{ + {ID: "t1", Title: "Movie 1", InfoHash: "abc"}, + {ID: "t2", Title: "Movie 2", InfoHash: "def"}, + }, + }) + + if len(received) != 2 { + t.Fatalf("expected 2 tasks, got %d", len(received)) + } + if received[0].Title != "Movie 1" { + t.Errorf("expected Movie 1, got %s", received[0].Title) + } +} + +func TestSyncClient_ProcessResponse_NoTasks(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var called bool + sc.OnNewTasks = func(tasks []Task) { called = true } + + sc.processResponse(&SyncResponse{NewTasks: nil}) + + if called { + t.Error("OnNewTasks should not be called with empty tasks") + } +} + +func TestSyncClient_ProcessResponse_Controls(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var actions []string + var taskIDs []string + sc.OnControl = func(action, taskID string, deleteFiles bool) { + actions = append(actions, action) + taskIDs = append(taskIDs, taskID) + } + + sc.processResponse(&SyncResponse{ + Controls: []ControlAction{ + {Action: "cancel", TaskID: "task-1234-5678"}, + {Action: "pause", TaskID: "task-abcd-efgh"}, + }, + }) + + if len(actions) != 2 { + t.Fatalf("expected 2 controls, got %d", len(actions)) + } + if actions[0] != "cancel" { + t.Errorf("expected cancel, got %s", actions[0]) + } + if actions[1] != "pause" { + t.Errorf("expected pause, got %s", actions[1]) + } +} + +func TestSyncClient_ProcessResponse_Upgrade(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var version string + sc.OnUpgrade = func(v string) { version = v } + + sc.processResponse(&SyncResponse{ + Upgrade: &UpgradeSignal{Version: "2.0.0"}, + }) + + if version != "2.0.0" { + t.Errorf("expected 2.0.0, got %s", version) + } +} + +func TestSyncClient_ProcessResponse_UpgradeEmpty(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var called bool + sc.OnUpgrade = func(v string) { called = true } + + sc.processResponse(&SyncResponse{ + Upgrade: &UpgradeSignal{Version: ""}, + }) + + if called { + t.Error("OnUpgrade should not be called with empty version") + } +} + +func TestSyncClient_ProcessResponse_Scan(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var called bool + sc.OnScan = func() { called = true } + + sc.processResponse(&SyncResponse{Scan: true}) + + if !called { + t.Error("OnScan should have been called") + } +} + +func TestSyncClient_ProcessResponse_StreamRequests(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + var received []StreamRequest + sc.OnStreamRequest = func(sr StreamRequest) { received = append(received, sr) } + + sc.processResponse(&SyncResponse{ + StreamRequests: []StreamRequest{ + {TaskID: "t1", FilePath: "/tmp/movie.mkv"}, + }, + }) + + if len(received) != 1 { + t.Fatalf("expected 1 stream request, got %d", len(received)) + } + if received[0].FilePath != "/tmp/movie.mkv" { + t.Errorf("expected /tmp/movie.mkv, got %s", received[0].FilePath) + } +} + +func TestSyncClient_BuildRequest_WithGetTaskStates(t *testing.T) { + sc, _ := newTestSyncClient("http://localhost") + + sc.GetTaskStates = func() []TaskState { + return []TaskState{ + {TaskID: "t1", Status: "downloading", Progress: 50}, + } + } + sc.GetFreeSlots = func() int { return 2 } + + req := sc.buildRequest() + + if req.AgentID != "test-agent" { + t.Errorf("expected test-agent, got %s", req.AgentID) + } + if len(req.Tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(req.Tasks)) + } + if req.Tasks[0].Progress != 50 { + t.Errorf("expected progress 50, got %d", req.Tasks[0].Progress) + } + if req.FreeSlots != 2 { + t.Errorf("expected 2 free slots, got %d", req.FreeSlots) + } +} + +func TestSyncClient_BuildRequest_FallbackToState(t *testing.T) { + client := NewClient("http://localhost", "key", "ua") + state := NewLocalState() + state.Update(TaskState{TaskID: "t1", Status: "completed", Progress: 100}) + + sc := NewSyncClient(client, DaemonConfig{AgentID: "a1", Version: "1.0"}, state) + // GetTaskStates is nil — should fall back to state.Snapshot() + + req := sc.buildRequest() + if len(req.Tasks) != 1 { + t.Fatalf("expected 1 task from state fallback, got %d", len(req.Tasks)) + } +} + +func TestSyncClient_DoSync_Success(t *testing.T) { + var syncCount atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + syncCount.Add(1) + json.NewEncoder(w).Encode(SyncResponse{ + Watching: true, + NewTasks: []Task{{ID: "t1", Title: "Test Movie", InfoHash: "abc"}}, + }) + })) + defer srv.Close() + + sc, _ := newTestSyncClient(srv.URL) + + var tasksReceived []Task + sc.OnNewTasks = func(tasks []Task) { tasksReceived = tasks } + + sc.doSync(context.Background()) + + if syncCount.Load() != 1 { + t.Errorf("expected 1 sync call, got %d", syncCount.Load()) + } + if len(tasksReceived) != 1 { + t.Fatalf("expected 1 task, got %d", len(tasksReceived)) + } + if !sc.Watching() { + t.Error("expected watching=true after sync") + } + if sc.currentInterval() != SyncIntervalWatching { + t.Errorf("expected watching interval after sync") + } +} + +func TestSyncClient_DoSync_Error(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + sc, _ := newTestSyncClient(srv.URL) + + // Should not panic on error + sc.doSync(context.Background()) +} + +func TestSyncClient_Run_CancelStopsLoop(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(SyncResponse{}) + })) + defer srv.Close() + + sc, _ := newTestSyncClient(srv.URL) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + err := sc.Run(ctx) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } +} + +func TestSyncClient_Run_ImmediateSyncOnTrigger(t *testing.T) { + var syncCount atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + syncCount.Add(1) + json.NewEncoder(w).Encode(SyncResponse{}) + })) + defer srv.Close() + + sc, _ := newTestSyncClient(srv.URL) + // Set interval to something long so only triggers cause syncs + sc.interval.Store(int64(10 * time.Second)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + // Wait for initial sync, then trigger 2 more + time.Sleep(50 * time.Millisecond) + sc.TriggerSync() + time.Sleep(50 * time.Millisecond) + sc.TriggerSync() + time.Sleep(50 * time.Millisecond) + cancel() + }() + + sc.Run(ctx) + + // Initial sync (1) + 2 triggers + final sync = 4 + count := syncCount.Load() + if count < 3 { + t.Errorf("expected at least 3 syncs (initial + 2 triggers), got %d", count) + } +} diff --git a/internal/agent/taskstate.go b/internal/agent/taskstate.go new file mode 100644 index 0000000..51eba8b --- /dev/null +++ b/internal/agent/taskstate.go @@ -0,0 +1,136 @@ +package agent + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" + + "github.com/torrentclaw/unarr/internal/config" +) + +// TaskState represents the execution state of a single download task. +// Written by the Task Engine, read by the Sync goroutine. +type TaskState struct { + TaskID string `json:"taskId"` + Status string `json:"status"` // resolving, downloading, verifying, organizing, completed, failed + Progress int `json:"progress"` + DownloadedBytes int64 `json:"downloadedBytes,omitempty"` + TotalBytes int64 `json:"totalBytes,omitempty"` + SpeedBps int64 `json:"speedBps,omitempty"` + ETA int `json:"eta,omitempty"` + ResolvedMethod string `json:"resolvedMethod,omitempty"` + FileName string `json:"fileName,omitempty"` + FilePath string `json:"filePath,omitempty"` + StreamURL string `json:"streamUrl,omitempty"` + ErrorMessage string `json:"errorMessage,omitempty"` + UpdatedAt int64 `json:"updatedAt"` +} + +// LocalState holds the CLI's local execution state (tasks.json). +// This is the CLI's source of truth for what it's doing right now. +type LocalState struct { + mu sync.RWMutex + tasks map[string]*TaskState +} + +// NewLocalState creates an empty local state. +func NewLocalState() *LocalState { + return &LocalState{ + tasks: make(map[string]*TaskState), + } +} + +// Update adds or updates a task in local state. +func (s *LocalState) Update(ts TaskState) { + s.mu.Lock() + defer s.mu.Unlock() + ts.UpdatedAt = time.Now().Unix() + copied := ts + s.tasks[ts.TaskID] = &copied +} + +// Remove removes a task from local state. +func (s *LocalState) Remove(taskID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tasks, taskID) +} + +// Snapshot returns a copy of all current task states. +func (s *LocalState) Snapshot() []TaskState { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]TaskState, 0, len(s.tasks)) + for _, ts := range s.tasks { + result = append(result, *ts) + } + return result +} + +// TaskStateFromUpdate converts a StatusUpdate into a TaskState. +func TaskStateFromUpdate(u StatusUpdate) TaskState { + return TaskState{ + TaskID: u.TaskID, + Status: u.Status, + Progress: u.Progress, + DownloadedBytes: u.DownloadedBytes, + TotalBytes: u.TotalBytes, + SpeedBps: u.SpeedBps, + ETA: u.ETA, + ResolvedMethod: u.ResolvedMethod, + FileName: u.FileName, + FilePath: u.FilePath, + StreamURL: u.StreamURL, + ErrorMessage: u.ErrorMessage, + } +} + +// ShortID returns the first 8 characters of an ID, or the full ID if shorter. +func ShortID(id string) string { + if len(id) > 8 { + return id[:8] + } + return id +} + +// taskStateFilePathFn is overridable for testing. +var taskStateFilePathFn = func() string { + return filepath.Join(config.DataDir(), "tasks.json") +} + +// WriteToDisk persists local state to disk atomically (best-effort). +func (s *LocalState) WriteToDisk() { + tasks := s.Snapshot() + data, err := json.MarshalIndent(tasks, "", " ") + if err != nil { + return + } + path := taskStateFilePathFn() + dir := filepath.Dir(path) + os.MkdirAll(dir, 0o755) + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return + } + os.Rename(tmp, path) +} + +// ReadFromDisk loads local state from disk. Returns empty state on error. +func (s *LocalState) ReadFromDisk() { + data, err := os.ReadFile(taskStateFilePathFn()) + if err != nil { + return + } + var tasks []TaskState + if json.Unmarshal(data, &tasks) != nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.tasks = make(map[string]*TaskState, len(tasks)) + for i := range tasks { + s.tasks[tasks[i].TaskID] = &tasks[i] + } +} diff --git a/internal/agent/taskstate_test.go b/internal/agent/taskstate_test.go new file mode 100644 index 0000000..18814f7 --- /dev/null +++ b/internal/agent/taskstate_test.go @@ -0,0 +1,217 @@ +package agent + +import ( + "os" + "path/filepath" + "sync" + "testing" +) + +func TestLocalState_UpdateAndSnapshot(t *testing.T) { + s := NewLocalState() + + s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50}) + s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100}) + + snap := s.Snapshot() + if len(snap) != 2 { + t.Fatalf("expected 2 tasks, got %d", len(snap)) + } + + byID := make(map[string]TaskState, len(snap)) + for _, ts := range snap { + byID[ts.TaskID] = ts + } + + if byID["t1"].Progress != 50 { + t.Errorf("expected progress 50, got %d", byID["t1"].Progress) + } + if byID["t2"].Status != "completed" { + t.Errorf("expected completed, got %s", byID["t2"].Status) + } +} + +func TestLocalState_UpdateOverwrites(t *testing.T) { + s := NewLocalState() + + s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 30}) + s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 70}) + + snap := s.Snapshot() + if len(snap) != 1 { + t.Fatalf("expected 1 task, got %d", len(snap)) + } + if snap[0].Progress != 70 { + t.Errorf("expected progress 70, got %d", snap[0].Progress) + } +} + +func TestLocalState_Remove(t *testing.T) { + s := NewLocalState() + + s.Update(TaskState{TaskID: "t1", Status: "downloading"}) + s.Update(TaskState{TaskID: "t2", Status: "downloading"}) + s.Remove("t1") + + snap := s.Snapshot() + if len(snap) != 1 { + t.Fatalf("expected 1 task, got %d", len(snap)) + } + if snap[0].TaskID != "t2" { + t.Errorf("expected t2, got %s", snap[0].TaskID) + } +} + +func TestLocalState_RemoveNonExistent(t *testing.T) { + s := NewLocalState() + s.Remove("nonexistent") // should not panic +} + +func TestLocalState_SnapshotIsACopy(t *testing.T) { + s := NewLocalState() + s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50}) + + snap := s.Snapshot() + snap[0].Progress = 999 + + snap2 := s.Snapshot() + if snap2[0].Progress != 50 { + t.Errorf("snapshot mutation leaked: got progress %d", snap2[0].Progress) + } +} + +func TestLocalState_UpdateSetsTimestamp(t *testing.T) { + s := NewLocalState() + s.Update(TaskState{TaskID: "t1", Status: "downloading"}) + + snap := s.Snapshot() + if snap[0].UpdatedAt == 0 { + t.Error("expected non-zero UpdatedAt") + } +} + +func TestLocalState_ConcurrentAccess(t *testing.T) { + s := NewLocalState() + var wg sync.WaitGroup + + for i := range 100 { + wg.Add(1) + go func(n int) { + defer wg.Done() + taskID := "t" + string(rune('0'+n%10)) + s.Update(TaskState{TaskID: taskID, Status: "downloading", Progress: n}) + s.Snapshot() + if n%3 == 0 { + s.Remove(taskID) + } + }(i) + } + + wg.Wait() + // No race condition = test passes +} + +func TestLocalState_WriteToDisk_ReadFromDisk(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tasks.json") + + // Override the file path for testing + orig := taskStateFilePathFn + taskStateFilePathFn = func() string { return path } + defer func() { taskStateFilePathFn = orig }() + + s := NewLocalState() + s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 45}) + s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100, FilePath: "/tmp/movie.mkv"}) + s.WriteToDisk() + + // Verify file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Fatal("tasks.json was not created") + } + + // Read into a new LocalState + s2 := NewLocalState() + s2.ReadFromDisk() + + snap := s2.Snapshot() + if len(snap) != 2 { + t.Fatalf("expected 2 tasks after read, got %d", len(snap)) + } + + byID := make(map[string]TaskState, len(snap)) + for _, ts := range snap { + byID[ts.TaskID] = ts + } + + if byID["t1"].Progress != 45 { + t.Errorf("expected progress 45, got %d", byID["t1"].Progress) + } + if byID["t2"].FilePath != "/tmp/movie.mkv" { + t.Errorf("expected /tmp/movie.mkv, got %s", byID["t2"].FilePath) + } +} + +func TestLocalState_ReadFromDisk_CorruptedFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tasks.json") + + orig := taskStateFilePathFn + taskStateFilePathFn = func() string { return path } + defer func() { taskStateFilePathFn = orig }() + + // Write corrupted JSON + os.WriteFile(path, []byte("{invalid json"), 0o644) + + s := NewLocalState() + s.ReadFromDisk() // should not panic + + snap := s.Snapshot() + if len(snap) != 0 { + t.Errorf("expected 0 tasks from corrupted file, got %d", len(snap)) + } +} + +func TestLocalState_ReadFromDisk_FileNotFound(t *testing.T) { + orig := taskStateFilePathFn + taskStateFilePathFn = func() string { return "/nonexistent/path/tasks.json" } + defer func() { taskStateFilePathFn = orig }() + + s := NewLocalState() + s.ReadFromDisk() // should not panic + + snap := s.Snapshot() + if len(snap) != 0 { + t.Errorf("expected 0 tasks, got %d", len(snap)) + } +} + +func TestLocalState_AtomicWrite(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tasks.json") + + orig := taskStateFilePathFn + taskStateFilePathFn = func() string { return path } + defer func() { taskStateFilePathFn = orig }() + + s := NewLocalState() + s.Update(TaskState{TaskID: "t1", Status: "downloading"}) + s.WriteToDisk() + + // Verify no .tmp file remains + tmpPath := path + ".tmp" + if _, err := os.Stat(tmpPath); !os.IsNotExist(err) { + t.Error("temp file should not exist after write") + } +} + +func TestLocalState_EmptySnapshot(t *testing.T) { + s := NewLocalState() + snap := s.Snapshot() + if snap == nil { + t.Error("snapshot should be non-nil empty slice") + } + if len(snap) != 0 { + t.Errorf("expected 0 tasks, got %d", len(snap)) + } +} diff --git a/internal/agent/transport.go b/internal/agent/transport.go deleted file mode 100644 index 5e223fb..0000000 --- a/internal/agent/transport.go +++ /dev/null @@ -1,51 +0,0 @@ -package agent - -import "context" - -// Transport abstracts the communication protocol between the agent and server. -// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this. -type Transport interface { - // Connect establishes the transport connection. - // Called internally by Daemon.Run — callers must NOT call Connect separately. - Connect(ctx context.Context) error - - // Close tears down the connection gracefully. - Close() error - - // Mode returns the current transport mode ("ws" or "http"). - Mode() string - - // Register sends agent registration and returns user info + features. - Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) - - // SendHeartbeat sends a periodic keep-alive. - SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) - - // SendProgress reports download progress for a task. - SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) - - // ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events). - ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) - - // Deregister notifies the server of graceful shutdown. - Deregister(ctx context.Context, agentID string) error - - // Events returns a channel that emits server-initiated events. - // In HTTP mode this channel is never written to (polling handles it). - // In WS mode, tasks/upgrade/control arrive here. - Events() <-chan ServerEvent -} - -// ServerEvent represents a server-initiated message received via WebSocket. -type ServerEvent struct { - Type string // "tasks", "upgrade", "control", "disconnected" - Tasks *TasksResponse // populated when Type == "tasks" - Upgrade *UpgradeSignal // populated when Type == "upgrade" - Control *ControlAction // populated when Type == "control" -} - -// ControlAction represents a server push for task control. -type ControlAction struct { - Action string `json:"action"` // "pause", "resume", "cancel", "stream" - TaskID string `json:"taskId"` -} diff --git a/internal/agent/transport_e2e_test.go b/internal/agent/transport_e2e_test.go deleted file mode 100644 index 01de3cb..0000000 --- a/internal/agent/transport_e2e_test.go +++ /dev/null @@ -1,285 +0,0 @@ -package agent - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "sync" - "testing" - "time" -) - -// TestE2EFullLifecycle tests the full lifecycle: -// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect -func TestE2EFullLifecycle(t *testing.T) { - var mu sync.Mutex - var receivedMessages []map[string]interface{} - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - for { - _, msg, err := conn.ReadMessage() - if err != nil { - return - } - - var parsed map[string]interface{} - json.Unmarshal(msg, &parsed) - - mu.Lock() - receivedMessages = append(receivedMessages, parsed) - mu.Unlock() - - msgType, _ := parsed["type"].(string) - switch msgType { - case "auth": - conn.WriteJSON(wsRegisteredMessage{ - Type: "registered", - User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true}, - Features: FeatureFlags{Torrent: true, Debrid: true}, - }) - - case "heartbeat": - // No response in WS mode - - case "progress": - // Simulate server-side cancel after progress - if progress, ok := parsed["progress"].(float64); ok && progress >= 50 { - conn.WriteJSON(map[string]string{ - "type": "control", - "action": "cancel", - "taskId": parsed["taskId"].(string), - }) - } - - case "upgrade-result": - // Acknowledged - } - } - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0") - - ctx := context.Background() - - // 1. Connect - if err := tr.Connect(ctx); err != nil { - t.Fatalf("Connect: %v", err) - } - defer tr.Close() - - // 2. Auth - resp, err := tr.Register(ctx, RegisterRequest{ - AgentID: "e2e-agent", - Name: "E2E Test Agent", - Version: "1.0.0", - OS: "linux", - Arch: "amd64", - }) - if err != nil { - t.Fatalf("Register: %v", err) - } - if resp.User.Name != "E2E User" { - t.Errorf("expected E2E User, got %s", resp.User.Name) - } - if !resp.Features.Debrid { - t.Error("expected debrid feature") - } - - // 3. Send heartbeat - _, err = tr.SendHeartbeat(ctx, HeartbeatRequest{ - AgentID: "e2e-agent", - DiskFreeBytes: 1000000000, - DiskTotalBytes: 5000000000, - }) - if err != nil { - t.Fatalf("SendHeartbeat: %v", err) - } - - // 4. Send progress (50% → should trigger cancel control) - _, err = tr.SendProgress(ctx, StatusUpdate{ - TaskID: "task-e2e-1", - Status: "downloading", - Progress: 50, - DownloadedBytes: 500, - TotalBytes: 1000, - SpeedBps: 100, - }) - if err != nil { - t.Fatalf("SendProgress: %v", err) - } - - // 5. Wait for control event (cancel) - select { - case event := <-tr.Events(): - if event.Type != "control" { - t.Errorf("expected control event, got %s", event.Type) - } - if event.Control.Action != "cancel" { - t.Errorf("expected cancel, got %s", event.Control.Action) - } - if event.Control.TaskID != "task-e2e-1" { - t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID) - } - case <-time.After(3 * time.Second): - t.Fatal("timeout waiting for cancel control") - } - - // Verify server received all messages - time.Sleep(100 * time.Millisecond) - mu.Lock() - defer mu.Unlock() - - if len(receivedMessages) < 3 { - t.Fatalf("expected at least 3 messages, got %d", len(receivedMessages)) - } - - types := make([]string, len(receivedMessages)) - for i, m := range receivedMessages { - types[i], _ = m["type"].(string) - } - - expected := []string{"auth", "heartbeat", "progress"} - for _, exp := range expected { - found := false - for _, got := range types { - if got == exp { - found = true - break - } - } - if !found { - t.Errorf("missing message type %q in %v", exp, types) - } - } -} - -// TestE2EHybridFailover tests the full failover scenario: -// WS connect → download → WS disconnect → switch to HTTP → continue working -func TestE2EHybridFailover(t *testing.T) { - connectionCount := 0 - var mu sync.Mutex - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - - mu.Lock() - connectionCount++ - connNum := connectionCount - mu.Unlock() - - // Read auth - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{ - Type: "registered", - User: UserInfo{Name: "Failover User"}, - }) - - if connNum == 1 { - // First connection: push tasks then disconnect after 200ms - time.Sleep(50 * time.Millisecond) - conn.WriteJSON(wsTasksMessage{ - Type: "tasks", - Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}}, - }) - time.Sleep(150 * time.Millisecond) - conn.Close() - } else { - // Second connection (after reconnect): push upgrade - time.Sleep(50 * time.Millisecond) - conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"}) - time.Sleep(500 * time.Millisecond) - conn.Close() - } - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - wsT := NewWSTransport(wsURL, "key", "a1", "ua") - - // HTTP mock for fallback - httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Simple heartbeat response - json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) - })) - defer httpSrv.Close() - - httpT := NewHTTPTransport(httpSrv.URL, "key", "ua") - h := NewHybridTransport(wsT, httpT) - - ctx := context.Background() - err := h.Connect(ctx) - if err != nil { - t.Fatalf("Connect: %v", err) - } - defer h.Close() - - // Should start in WS mode - if h.Mode() != "ws" { - t.Fatalf("expected ws mode, got %s", h.Mode()) - } - - // Register via WS - _, err = h.Register(ctx, RegisterRequest{AgentID: "a1"}) - if err != nil { - t.Fatalf("Register: %v", err) - } - - // Receive tasks via WS - var tasksReceived bool - var disconnected bool - - for i := 0; i < 3; i++ { - select { - case event := <-h.Events(): - switch event.Type { - case "tasks": - tasksReceived = true - if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" { - t.Errorf("unexpected tasks: %+v", event.Tasks) - } - case "disconnected": - disconnected = true - } - case <-time.After(2 * time.Second): - break - } - if disconnected { - break - } - } - - if !tasksReceived { - t.Error("did not receive tasks before disconnect") - } - if !disconnected { - t.Error("did not receive disconnect event") - } - - // Should now be in HTTP mode - time.Sleep(100 * time.Millisecond) - if h.Mode() != "http" { - t.Errorf("expected http mode after disconnect, got %s", h.Mode()) - } - - // Heartbeat should work via HTTP fallback - hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"}) - if err != nil { - t.Fatalf("SendHeartbeat via HTTP fallback: %v", err) - } - if !hbResp.Success { - t.Error("expected heartbeat success") - } -} diff --git a/internal/agent/transport_http.go b/internal/agent/transport_http.go deleted file mode 100644 index 6bce13b..0000000 --- a/internal/agent/transport_http.go +++ /dev/null @@ -1,50 +0,0 @@ -package agent - -import "context" - -// HTTPTransport wraps the existing Client to implement Transport. -// This is a thin adapter — no behavioral changes from the current HTTP protocol. -type HTTPTransport struct { - client *Client - events chan ServerEvent -} - -// NewHTTPTransport creates a new HTTP-based transport. -func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport { - return &HTTPTransport{ - client: NewClient(baseURL, apiKey, userAgent), - events: make(chan ServerEvent, 10), - } -} - -func (t *HTTPTransport) Connect(_ context.Context) error { return nil } -func (t *HTTPTransport) Close() error { return nil } -func (t *HTTPTransport) Mode() string { return "http" } -func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events } - -func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) { - return t.client.Register(ctx, req) -} - -func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) { - return t.client.Heartbeat(ctx, req) -} - -func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) { - return t.client.ReportStatus(ctx, update) -} - -func (t *HTTPTransport) BatchReportStatus(ctx context.Context, updates []StatusUpdate) (*BatchStatusResponse, error) { - return t.client.BatchReportStatus(ctx, updates) -} - -func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) { - return t.client.ClaimTasks(ctx, agentID) -} - -func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error { - return t.client.Deregister(ctx, agentID) -} - -// Client returns the underlying HTTP client for direct use if needed. -func (t *HTTPTransport) Client() *Client { return t.client } diff --git a/internal/agent/transport_hybrid.go b/internal/agent/transport_hybrid.go deleted file mode 100644 index 3a4b51e..0000000 --- a/internal/agent/transport_hybrid.go +++ /dev/null @@ -1,214 +0,0 @@ -package agent - -import ( - "context" - "log" - "sync" - "sync/atomic" - "time" -) - -// HybridTransport tries WebSocket first, falls back to HTTP if WS fails. -// Automatically reconnects WS in the background. -type HybridTransport struct { - ws *WSTransport - http *HTTPTransport - - mode atomic.Value // "ws" or "http" - events chan ServerEvent - - reconnectMu sync.Mutex - reconnectRunning bool - reconnectStop chan struct{} - closed atomic.Bool -} - -// NewHybridTransport creates a transport that prefers WS with HTTP fallback. -func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport { - h := &HybridTransport{ - ws: ws, - http: http, - events: make(chan ServerEvent, 50), - reconnectStop: make(chan struct{}), - } - h.mode.Store("http") // start in HTTP, upgrade to WS on Connect - return h -} - -func (h *HybridTransport) Mode() string { return h.mode.Load().(string) } -func (h *HybridTransport) Events() <-chan ServerEvent { return h.events } - -// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop. -func (h *HybridTransport) Connect(ctx context.Context) error { - // Try WebSocket first - if err := h.ws.Connect(ctx); err != nil { - log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err) - h.mode.Store("http") - h.startReconnectLoop() - return h.http.Connect(ctx) - } - - h.mode.Store("ws") - log.Println("[transport] Connected via WebSocket") - - // Forward WS events to unified channel + watch for disconnection - go h.forwardWSEvents() - - return nil -} - -// Close shuts down both transports and stops reconnection. -func (h *HybridTransport) Close() error { - h.closed.Store(true) - select { - case <-h.reconnectStop: - default: - close(h.reconnectStop) - } - _ = h.ws.Close() - return h.http.Close() -} - -// Register delegates to the active transport. -func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) { - if h.mode.Load() == "ws" { - return h.ws.Register(ctx, req) - } - return h.http.Register(ctx, req) -} - -// SendHeartbeat delegates to the active transport. -func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) { - if h.mode.Load() == "ws" { - resp, err := h.ws.SendHeartbeat(ctx, req) - if err != nil { - // WS write failed — switch to HTTP - h.switchToHTTP() - return h.http.SendHeartbeat(ctx, req) - } - return resp, nil - } - return h.http.SendHeartbeat(ctx, req) -} - -// SendProgress delegates to the active transport. -func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) { - if h.mode.Load() == "ws" { - resp, err := h.ws.SendProgress(ctx, update) - if err != nil { - h.switchToHTTP() - return h.http.SendProgress(ctx, update) - } - return resp, nil - } - return h.http.SendProgress(ctx, update) -} - -// ClaimTasks delegates to the active transport. -func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) { - if h.mode.Load() == "ws" { - return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode - } - return h.http.ClaimTasks(ctx, agentID) -} - -// Deregister delegates to the active transport. -func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error { - if h.mode.Load() == "ws" { - return h.ws.Deregister(ctx, agentID) - } - return h.http.Deregister(ctx, agentID) -} - -// ── Internal ───────────────────────────────────────────────────────────────── - -func (h *HybridTransport) switchToHTTP() { - if h.mode.Load() == "http" { - return - } - log.Println("[transport] Switching to HTTP fallback") - h.mode.Store("http") - _ = h.ws.Close() - h.startReconnectLoop() -} - -func (h *HybridTransport) forwardWSEvents() { - for { - select { - case <-h.reconnectStop: - return - case event, ok := <-h.ws.Events(): - if !ok { - return // channel closed - } - if event.Type == "disconnected" { - h.switchToHTTP() - select { - case h.events <- event: - default: - } - return - } - select { - case h.events <- event: - default: - log.Printf("[transport] events channel full, dropping %s event", event.Type) - } - } - } -} - -func (h *HybridTransport) startReconnectLoop() { - h.reconnectMu.Lock() - defer h.reconnectMu.Unlock() - if h.reconnectRunning { - return - } - h.reconnectRunning = true - go h.reconnectLoop() -} - -func (h *HybridTransport) reconnectLoop() { - backoff := 5 * time.Second - maxBackoff := 60 * time.Second - - for { - select { - case <-h.reconnectStop: - return - case <-time.After(backoff): - } - - if h.closed.Load() { - return - } - - // Already on WS? (someone else reconnected) - if h.mode.Load() == "ws" { - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - err := h.ws.Connect(ctx) - cancel() - - if err != nil { - log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff) - backoff = min(backoff*2, maxBackoff) - continue - } - - // WS reconnected — switch back - log.Println("[transport] WebSocket reconnected") - h.mode.Store("ws") - - // Reset reconnect flag so loop can start again if WS drops - h.reconnectMu.Lock() - h.reconnectRunning = false - h.reconnectMu.Unlock() - - // Forward events from new WS connection - go h.forwardWSEvents() - return - } -} diff --git a/internal/agent/transport_test.go b/internal/agent/transport_test.go deleted file mode 100644 index be2f6c6..0000000 --- a/internal/agent/transport_test.go +++ /dev/null @@ -1,1590 +0,0 @@ -package agent - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "sync" - "testing" - "time" - - "github.com/gorilla/websocket" -) - -// ── HTTP Transport Tests ───────────────────────────────────────────────────── - -func TestHTTPTransportMode(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - if tr.Mode() != "http" { - t.Errorf("expected http, got %s", tr.Mode()) - } -} - -func TestHTTPTransportEventsNeverEmit(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - select { - case <-tr.Events(): - t.Error("events channel should never emit in HTTP mode") - case <-time.After(50 * time.Millisecond): - // expected - } -} - -func TestHTTPTransportDelegates(t *testing.T) { - // Mock server for register - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(RegisterResponse{ - Success: true, - User: UserInfo{Name: "Test", Plan: "pro"}, - }) - })) - defer srv.Close() - - tr := NewHTTPTransport(srv.URL, "test-key", "test-agent") - resp, err := tr.Register(context.Background(), RegisterRequest{AgentID: "a1"}) - if err != nil { - t.Fatalf("Register failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } - if resp.User.Name != "Test" { - t.Errorf("expected Test, got %s", resp.User.Name) - } -} - -// ── WebSocket Transport Tests ──────────────────────────────────────────────── - -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, -} - -func TestWSTransportConnectAndAuth(t *testing.T) { - var received wsAuthMessage - var mu sync.Mutex - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - t.Fatalf("upgrade: %v", err) - } - defer conn.Close() - - // Read auth message - _, msg, err := conn.ReadMessage() - if err != nil { - return - } - mu.Lock() - json.Unmarshal(msg, &received) - mu.Unlock() - - // Send registered response - conn.WriteJSON(wsRegisteredMessage{ - Type: "registered", - User: UserInfo{Name: "WS User", Plan: "pro", IsPro: true}, - Features: FeatureFlags{Torrent: true}, - }) - - // Keep connection open - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "my-api-key", "agent-123", "test/1.0") - - ctx := context.Background() - if err := tr.Connect(ctx); err != nil { - t.Fatalf("Connect failed: %v", err) - } - defer tr.Close() - - resp, err := tr.Register(ctx, RegisterRequest{ - AgentID: "agent-123", - Name: "test-agent", - Version: "1.0.0", - }) - if err != nil { - t.Fatalf("Register failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } - if resp.User.Name != "WS User" { - t.Errorf("expected WS User, got %s", resp.User.Name) - } - - mu.Lock() - if received.APIKey != "my-api-key" { - t.Errorf("expected my-api-key, got %s", received.APIKey) - } - if received.AgentID != "agent-123" { - t.Errorf("expected agent-123, got %s", received.AgentID) - } - mu.Unlock() -} - -func TestWSTransportReceiveTasks(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - // Read auth - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{ - Type: "registered", - User: UserInfo{Name: "Test"}, - }) - - // Push tasks - time.Sleep(50 * time.Millisecond) - conn.WriteJSON(wsTasksMessage{ - Type: "tasks", - Tasks: []Task{ - {ID: "t1", InfoHash: "abc123", Title: "Test Movie"}, - {ID: "t2", InfoHash: "def456", Title: "Test Show"}, - }, - }) - - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "agent1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - - tr.Register(ctx, RegisterRequest{AgentID: "agent1"}) - - // Wait for tasks event - select { - case event := <-tr.Events(): - if event.Type != "tasks" { - t.Errorf("expected tasks, got %s", event.Type) - } - if len(event.Tasks.Tasks) != 2 { - t.Errorf("expected 2 tasks, got %d", len(event.Tasks.Tasks)) - } - if event.Tasks.Tasks[0].Title != "Test Movie" { - t.Errorf("expected Test Movie, got %s", event.Tasks.Tasks[0].Title) - } - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for tasks event") - } -} - -func TestWSTransportReceiveControl(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - time.Sleep(50 * time.Millisecond) - conn.WriteJSON(map[string]string{ - "type": "control", - "action": "cancel", - "taskId": "task-99", - }) - - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - select { - case event := <-tr.Events(): - if event.Type != "control" { - t.Errorf("expected control, got %s", event.Type) - } - if event.Control.Action != "cancel" { - t.Errorf("expected cancel, got %s", event.Control.Action) - } - if event.Control.TaskID != "task-99" { - t.Errorf("expected task-99, got %s", event.Control.TaskID) - } - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for control event") - } -} - -func TestWSTransportReceiveUpgrade(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - time.Sleep(50 * time.Millisecond) - conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "2.0.0"}) - - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - select { - case event := <-tr.Events(): - if event.Type != "upgrade" { - t.Errorf("expected upgrade, got %s", event.Type) - } - if event.Upgrade.Version != "2.0.0" { - t.Errorf("expected 2.0.0, got %s", event.Upgrade.Version) - } - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for upgrade event") - } -} - -func TestWSTransportDisconnect(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - // Close after a short delay to simulate disconnection - time.Sleep(100 * time.Millisecond) - conn.Close() - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - select { - case event := <-tr.Events(): - if event.Type != "disconnected" { - t.Errorf("expected disconnected, got %s", event.Type) - } - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for disconnected event") - } -} - -func TestWSTransportSendProgress(t *testing.T) { - var receivedMsg map[string]interface{} - var mu sync.Mutex - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - // Read auth - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - // Read progress - _, msg, err := conn.ReadMessage() - if err != nil { - return - } - mu.Lock() - json.Unmarshal(msg, &receivedMsg) - mu.Unlock() - - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - time.Sleep(50 * time.Millisecond) - resp, err := tr.SendProgress(ctx, StatusUpdate{ - TaskID: "t1", - Status: "downloading", - Progress: 42, - }) - if err != nil { - t.Fatalf("SendProgress failed: %v", err) - } - if !resp.Success { - t.Error("expected success response") - } - - time.Sleep(100 * time.Millisecond) - mu.Lock() - if receivedMsg["type"] != "progress" { - t.Errorf("expected progress, got %v", receivedMsg["type"]) - } - if receivedMsg["taskId"] != "t1" { - t.Errorf("expected t1, got %v", receivedMsg["taskId"]) - } - mu.Unlock() -} - -// ── Hybrid Transport Tests ─────────────────────────────────────────────────── - -func TestHybridTransportWSSuccess(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - wsT := NewWSTransport(wsURL, "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - err := h.Connect(context.Background()) - if err != nil { - t.Fatalf("Connect failed: %v", err) - } - defer h.Close() - - if h.Mode() != "ws" { - t.Errorf("expected ws mode, got %s", h.Mode()) - } -} - -func TestHybridTransportWSFailFallbackHTTP(t *testing.T) { - // WS URL points to nowhere - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - err := h.Connect(context.Background()) - if err != nil { - t.Fatalf("Connect should succeed with HTTP fallback: %v", err) - } - defer h.Close() - - if h.Mode() != "http" { - t.Errorf("expected http mode after WS failure, got %s", h.Mode()) - } -} - -func TestHybridTransportWSDisconnectSwitchesToHTTP(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - // Close immediately to trigger disconnect - time.Sleep(100 * time.Millisecond) - conn.Close() - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - wsT := NewWSTransport(wsURL, "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - h.Connect(context.Background()) - defer h.Close() - - // Wait for disconnect event - select { - case event := <-h.Events(): - if event.Type != "disconnected" { - t.Errorf("expected disconnected, got %s", event.Type) - } - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for disconnected event") - } - - // Mode should be HTTP now - time.Sleep(100 * time.Millisecond) - if h.Mode() != "http" { - t.Errorf("expected http after disconnect, got %s", h.Mode()) - } -} - -// ── Additional HTTP Transport Tests ───────────────────────────────────────── - -func TestNewHTTPTransportConstructor(t *testing.T) { - tr := NewHTTPTransport("http://example.com", "my-key", "my-agent/1.0") - - if tr.client == nil { - t.Fatal("expected client to be non-nil") - } - if tr.events == nil { - t.Fatal("expected events channel to be non-nil") - } - // events channel should have capacity 10 - if cap(tr.events) != 10 { - t.Errorf("expected events capacity 10, got %d", cap(tr.events)) - } -} - -func TestHTTPTransportConnectAndCloseAreNoOps(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - - if err := tr.Connect(context.Background()); err != nil { - t.Errorf("Connect should be a no-op, got error: %v", err) - } - if err := tr.Close(); err != nil { - t.Errorf("Close should be a no-op, got error: %v", err) - } -} - -func TestHTTPTransportClientAccessor(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - c := tr.Client() - if c == nil { - t.Fatal("Client() should return the underlying client") - } - if c != tr.client { - t.Error("Client() should return the same instance stored internally") - } -} - -func TestHTTPTransportSendHeartbeat(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Errorf("expected POST, got %s", r.Method) - } - if !strings.Contains(r.URL.Path, "heartbeat") { - t.Errorf("expected heartbeat path, got %s", r.URL.Path) - } - json.NewEncoder(w).Encode(HeartbeatResponse{ - Success: true, - Watching: true, - Upgrade: &UpgradeSignal{Version: "9.9.9"}, - }) - })) - defer srv.Close() - - tr := NewHTTPTransport(srv.URL, "key", "ua") - resp, err := tr.SendHeartbeat(context.Background(), HeartbeatRequest{ - AgentID: "a1", - Name: "test", - Version: "1.0", - }) - if err != nil { - t.Fatalf("SendHeartbeat failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } - if !resp.Watching { - t.Error("expected watching=true") - } - if resp.Upgrade == nil || resp.Upgrade.Version != "9.9.9" { - t.Error("expected upgrade version 9.9.9") - } -} - -func TestHTTPTransportSendProgress(t *testing.T) { - var received StatusUpdate - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewDecoder(r.Body).Decode(&received) - json.NewEncoder(w).Encode(StatusResponse{ - Success: true, - Cancelled: true, - }) - })) - defer srv.Close() - - tr := NewHTTPTransport(srv.URL, "key", "ua") - resp, err := tr.SendProgress(context.Background(), StatusUpdate{ - TaskID: "task-1", - Status: "downloading", - Progress: 55, - SpeedBps: 1024000, - }) - if err != nil { - t.Fatalf("SendProgress failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } - if !resp.Cancelled { - t.Error("expected cancelled flag") - } - if received.TaskID != "task-1" { - t.Errorf("expected task-1, got %s", received.TaskID) - } - if received.Progress != 55 { - t.Errorf("expected progress 55, got %d", received.Progress) - } -} - -func TestHTTPTransportClaimTasks(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - t.Errorf("expected GET, got %s", r.Method) - } - agentID := r.URL.Query().Get("agentId") - if agentID != "agent-42" { - t.Errorf("expected agentId=agent-42, got %s", agentID) - } - json.NewEncoder(w).Encode(TasksResponse{ - Tasks: []Task{ - {ID: "t1", Title: "Movie 1", InfoHash: "abc"}, - {ID: "t2", Title: "Movie 2", InfoHash: "def"}, - }, - }) - })) - defer srv.Close() - - tr := NewHTTPTransport(srv.URL, "key", "ua") - resp, err := tr.ClaimTasks(context.Background(), "agent-42") - if err != nil { - t.Fatalf("ClaimTasks failed: %v", err) - } - if len(resp.Tasks) != 2 { - t.Fatalf("expected 2 tasks, got %d", len(resp.Tasks)) - } - if resp.Tasks[0].Title != "Movie 1" { - t.Errorf("expected Movie 1, got %s", resp.Tasks[0].Title) - } -} - -func TestHTTPTransportDeregister(t *testing.T) { - var called bool - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - json.NewEncoder(w).Encode(StatusResponse{Success: true}) - })) - defer srv.Close() - - tr := NewHTTPTransport(srv.URL, "key", "ua") - err := tr.Deregister(context.Background(), "agent-1") - if err != nil { - t.Fatalf("Deregister failed: %v", err) - } - if !called { - t.Error("expected server to be called") - } -} - -func TestHTTPTransportBatchReportStatus(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(BatchStatusResponse{ - Results: []StatusResponse{ - {Success: true}, - {Success: true, Cancelled: true}, - }, - Watching: true, - }) - })) - defer srv.Close() - - tr := NewHTTPTransport(srv.URL, "key", "ua") - resp, err := tr.BatchReportStatus(context.Background(), []StatusUpdate{ - {TaskID: "t1", Status: "downloading", Progress: 10}, - {TaskID: "t2", Status: "completed", Progress: 100}, - }) - if err != nil { - t.Fatalf("BatchReportStatus failed: %v", err) - } - if len(resp.Results) != 2 { - t.Fatalf("expected 2 results, got %d", len(resp.Results)) - } - if !resp.Watching { - t.Error("expected watching=true") - } - if !resp.Results[1].Cancelled { - t.Error("expected second result to be cancelled") - } -} - -func TestHTTPTransportAuthHeader(t *testing.T) { - var gotAuth string - var gotUA string - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotAuth = r.Header.Get("Authorization") - gotUA = r.Header.Get("User-Agent") - json.NewEncoder(w).Encode(RegisterResponse{Success: true}) - })) - defer srv.Close() - - tr := NewHTTPTransport(srv.URL, "secret-key-123", "unarr/2.0") - tr.Register(context.Background(), RegisterRequest{AgentID: "a1"}) - - if gotAuth != "Bearer secret-key-123" { - t.Errorf("expected Bearer secret-key-123, got %s", gotAuth) - } - if gotUA != "unarr/2.0" { - t.Errorf("expected unarr/2.0, got %s", gotUA) - } -} - -// ── Additional WebSocket Transport Tests ──────────────────────────────────── - -func TestNewWSTransportConstructor(t *testing.T) { - tr := NewWSTransport("ws://example.com/ws", "api-key", "agent-1", "ua/1.0") - - if tr.Mode() != "ws" { - t.Errorf("expected ws mode, got %s", tr.Mode()) - } - if tr.wsURL != "ws://example.com/ws" { - t.Errorf("expected ws URL, got %s", tr.wsURL) - } - if tr.apiKey != "api-key" { - t.Errorf("expected api-key, got %s", tr.apiKey) - } - if tr.agentID != "agent-1" { - t.Errorf("expected agent-1, got %s", tr.agentID) - } - if tr.userAgent != "ua/1.0" { - t.Errorf("expected ua/1.0, got %s", tr.userAgent) - } - if cap(tr.events) != 50 { - t.Errorf("expected events capacity 50, got %d", cap(tr.events)) - } - if tr.authDone == nil { - t.Fatal("expected authDone channel to be non-nil") - } -} - -func TestWSTransportClaimTasksIsNoOp(t *testing.T) { - tr := NewWSTransport("ws://localhost", "key", "a1", "ua") - resp, err := tr.ClaimTasks(context.Background(), "a1") - if err != nil { - t.Fatalf("ClaimTasks should succeed (no-op): %v", err) - } - if resp == nil { - t.Fatal("expected non-nil response") - } - if len(resp.Tasks) != 0 { - t.Errorf("expected 0 tasks, got %d", len(resp.Tasks)) - } -} - -func TestWSTransportCloseWhenNotConnected(t *testing.T) { - tr := NewWSTransport("ws://localhost", "key", "a1", "ua") - // Close without ever connecting should not panic or error - if err := tr.Close(); err != nil { - t.Errorf("Close on unconnected transport should return nil, got %v", err) - } -} - -func TestWSTransportSendWhenNotConnected(t *testing.T) { - tr := NewWSTransport("ws://localhost", "key", "a1", "ua") - // Attempting to send a heartbeat without connecting should fail - _, err := tr.SendHeartbeat(context.Background(), HeartbeatRequest{AgentID: "a1"}) - if err == nil { - t.Error("expected error when sending without connection") - } -} - -func TestWSTransportConnectBadURL(t *testing.T) { - tr := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - err := tr.Connect(context.Background()) - if err == nil { - t.Error("expected error connecting to invalid address") - } -} - -func TestWSTransportSendHeartbeatWithDisk(t *testing.T) { - var receivedMsg map[string]interface{} - var mu sync.Mutex - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - // Read auth - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - // Read heartbeat - _, msg, err := conn.ReadMessage() - if err != nil { - return - } - mu.Lock() - json.Unmarshal(msg, &receivedMsg) - mu.Unlock() - - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - time.Sleep(50 * time.Millisecond) - resp, err := tr.SendHeartbeat(ctx, HeartbeatRequest{ - AgentID: "a1", - DiskFreeBytes: 500000000, - DiskTotalBytes: 1000000000, - }) - if err != nil { - t.Fatalf("SendHeartbeat failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } - - time.Sleep(100 * time.Millisecond) - mu.Lock() - defer mu.Unlock() - if receivedMsg["type"] != "heartbeat" { - t.Errorf("expected heartbeat, got %v", receivedMsg["type"]) - } - disk, ok := receivedMsg["disk"].(map[string]interface{}) - if !ok { - t.Fatal("expected disk field in heartbeat message") - } - if disk["free"].(float64) != 500000000 { - t.Errorf("expected free=500000000, got %v", disk["free"]) - } - if disk["total"].(float64) != 1000000000 { - t.Errorf("expected total=1000000000, got %v", disk["total"]) - } -} - -func TestWSTransportSendHeartbeatWithoutDisk(t *testing.T) { - var receivedMsg map[string]interface{} - var mu sync.Mutex - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - _, msg, err := conn.ReadMessage() - if err != nil { - return - } - mu.Lock() - json.Unmarshal(msg, &receivedMsg) - mu.Unlock() - - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - time.Sleep(50 * time.Millisecond) - resp, err := tr.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"}) - if err != nil { - t.Fatalf("SendHeartbeat failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } - - time.Sleep(100 * time.Millisecond) - mu.Lock() - defer mu.Unlock() - if receivedMsg["type"] != "heartbeat" { - t.Errorf("expected heartbeat, got %v", receivedMsg["type"]) - } - // disk field should be absent when no disk info provided - if _, exists := receivedMsg["disk"]; exists { - t.Error("expected no disk field when disk info is zero") - } -} - -func TestWSTransportDeregisterClosesConnection(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - err := tr.Deregister(ctx, "a1") - if err != nil { - t.Fatalf("Deregister failed: %v", err) - } - - // After deregister, send should fail (connection closed) - _, err = tr.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"}) - if err == nil { - t.Error("expected error sending after deregister") - } -} - -func TestWSTransportReceiveStreamRequests(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - time.Sleep(50 * time.Millisecond) - conn.WriteJSON(wsTasksMessage{ - Type: "tasks", - Tasks: []Task{}, - StreamRequests: []StreamRequest{ - {TaskID: "t1", FilePath: "/data/movie.mkv"}, - }, - }) - - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - select { - case event := <-tr.Events(): - if event.Type != "tasks" { - t.Errorf("expected tasks, got %s", event.Type) - } - if len(event.Tasks.StreamRequests) != 1 { - t.Fatalf("expected 1 stream request, got %d", len(event.Tasks.StreamRequests)) - } - if event.Tasks.StreamRequests[0].FilePath != "/data/movie.mkv" { - t.Errorf("expected /data/movie.mkv, got %s", event.Tasks.StreamRequests[0].FilePath) - } - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for tasks event with stream requests") - } -} - -func TestWSTransportReceiveErrorMessage(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - conn.ReadMessage() - conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) - - time.Sleep(50 * time.Millisecond) - // Send an error message (should be logged, not emitted as event) - conn.WriteJSON(map[string]string{ - "type": "error", - "message": "rate limited", - }) - - time.Sleep(200 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - tr.Register(ctx, RegisterRequest{AgentID: "a1"}) - - // Error messages are logged but not emitted — events channel should be quiet - select { - case event := <-tr.Events(): - // If we get disconnected, that's acceptable (server closes after delay) - if event.Type != "disconnected" { - t.Errorf("unexpected event type: %s", event.Type) - } - case <-time.After(300 * time.Millisecond): - // Expected: no event emitted for error messages - } -} - -func TestWSTransportRegisterTimeout(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - conn.ReadMessage() - // Never send registered response — should timeout - time.Sleep(20 * time.Second) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - tr := NewWSTransport(wsURL, "key", "a1", "ua") - - ctx := context.Background() - tr.Connect(ctx) - defer tr.Close() - - // Use a context with short timeout to avoid waiting 15s - ctxShort, cancel := context.WithTimeout(ctx, 200*time.Millisecond) - defer cancel() - - _, err := tr.Register(ctxShort, RegisterRequest{AgentID: "a1"}) - if err == nil { - t.Error("expected timeout error from Register") - } -} - -// ── Additional Hybrid Transport Tests ─────────────────────────────────────── - -func TestNewHybridTransportConstructor(t *testing.T) { - wsT := NewWSTransport("ws://localhost", "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - - if h.Mode() != "http" { - t.Errorf("expected initial mode http, got %s", h.Mode()) - } - if cap(h.events) != 50 { - t.Errorf("expected events capacity 50, got %d", cap(h.events)) - } - if h.ws != wsT { - t.Error("expected ws transport to match") - } - if h.http != httpT { - t.Error("expected http transport to match") - } - if h.reconnectStop == nil { - t.Error("expected reconnectStop channel to be non-nil") - } -} - -func TestHybridTransportCloseIsIdempotent(t *testing.T) { - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - // Close twice should not panic - if err := h.Close(); err != nil { - t.Errorf("first Close failed: %v", err) - } - if err := h.Close(); err != nil { - t.Errorf("second Close failed: %v", err) - } -} - -func TestHybridTransportHTTPModeRegister(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(RegisterResponse{ - Success: true, - User: UserInfo{Name: "HTTPUser", Plan: "free"}, - }) - })) - defer srv.Close() - - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport(srv.URL, "key", "ua") - - h := NewHybridTransport(wsT, httpT) - // Force HTTP mode (default) - h.mode.Store("http") - - resp, err := h.Register(context.Background(), RegisterRequest{AgentID: "a1"}) - if err != nil { - t.Fatalf("Register failed: %v", err) - } - if resp.User.Name != "HTTPUser" { - t.Errorf("expected HTTPUser, got %s", resp.User.Name) - } -} - -func TestHybridTransportHTTPModeClaimTasks(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(TasksResponse{ - Tasks: []Task{{ID: "t1", Title: "Test"}}, - }) - })) - defer srv.Close() - - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport(srv.URL, "key", "ua") - - h := NewHybridTransport(wsT, httpT) - h.mode.Store("http") - - resp, err := h.ClaimTasks(context.Background(), "a1") - if err != nil { - t.Fatalf("ClaimTasks failed: %v", err) - } - if len(resp.Tasks) != 1 { - t.Errorf("expected 1 task, got %d", len(resp.Tasks)) - } -} - -func TestHybridTransportHTTPModeDeregister(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(StatusResponse{Success: true}) - })) - defer srv.Close() - - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport(srv.URL, "key", "ua") - - h := NewHybridTransport(wsT, httpT) - h.mode.Store("http") - - err := h.Deregister(context.Background(), "a1") - if err != nil { - t.Fatalf("Deregister failed: %v", err) - } -} - -func TestHybridTransportHTTPModeSendHeartbeat(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(HeartbeatResponse{Success: true, Watching: true}) - })) - defer srv.Close() - - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport(srv.URL, "key", "ua") - - h := NewHybridTransport(wsT, httpT) - h.mode.Store("http") - - resp, err := h.SendHeartbeat(context.Background(), HeartbeatRequest{AgentID: "a1"}) - if err != nil { - t.Fatalf("SendHeartbeat failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } - if !resp.Watching { - t.Error("expected watching=true") - } -} - -func TestHybridTransportHTTPModeSendProgress(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(StatusResponse{Success: true}) - })) - defer srv.Close() - - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport(srv.URL, "key", "ua") - - h := NewHybridTransport(wsT, httpT) - h.mode.Store("http") - - resp, err := h.SendProgress(context.Background(), StatusUpdate{ - TaskID: "t1", - Status: "completed", - Progress: 100, - }) - if err != nil { - t.Fatalf("SendProgress failed: %v", err) - } - if !resp.Success { - t.Error("expected success") - } -} - -func TestHybridTransportWSModeClaimTasksIsNoOp(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - time.Sleep(500 * time.Millisecond) - })) - defer srv.Close() - - wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") - wsT := NewWSTransport(wsURL, "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - h.Connect(context.Background()) - defer h.Close() - - // In WS mode, ClaimTasks delegates to WS which is a no-op - resp, err := h.ClaimTasks(context.Background(), "a1") - if err != nil { - t.Fatalf("ClaimTasks failed: %v", err) - } - if len(resp.Tasks) != 0 { - t.Errorf("expected 0 tasks in WS mode, got %d", len(resp.Tasks)) - } -} - -func TestHybridTransportEventsChannel(t *testing.T) { - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - ch := h.Events() - if ch == nil { - t.Fatal("Events() should return non-nil channel") - } - // Verify it is the correct channel - if cap(ch) != 50 { - t.Errorf("expected events capacity 50, got %d", cap(ch)) - } -} - -func TestHybridTransportSwitchToHTTPIdempotent(t *testing.T) { - wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") - httpT := NewHTTPTransport("http://localhost", "key", "ua") - - h := NewHybridTransport(wsT, httpT) - // Already in HTTP mode, switchToHTTP should be a no-op - h.mode.Store("http") - h.switchToHTTP() // should not panic or start reconnect - - if h.Mode() != "http" { - t.Errorf("expected http, got %s", h.Mode()) - } -} - -// ── Daemon Constructor & Utility Tests ────────────────────────────────────── - -func TestNewDaemonDefaults(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - if d.cfg.PollInterval != 30*time.Second { - t.Errorf("expected default PollInterval 30s, got %v", d.cfg.PollInterval) - } - if d.cfg.HeartbeatInterval != 30*time.Second { - t.Errorf("expected default HeartbeatInterval 30s, got %v", d.cfg.HeartbeatInterval) - } - if d.Transport() != tr { - t.Error("Transport() should return the configured transport") - } - if d.pollNow == nil { - t.Error("pollNow channel should be initialized") - } -} - -func TestNewDaemonCustomIntervals(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - PollInterval: 10 * time.Second, - HeartbeatInterval: 15 * time.Second, - }, tr) - - if d.cfg.PollInterval != 10*time.Second { - t.Errorf("expected PollInterval 10s, got %v", d.cfg.PollInterval) - } - if d.cfg.HeartbeatInterval != 15*time.Second { - t.Errorf("expected HeartbeatInterval 15s, got %v", d.cfg.HeartbeatInterval) - } -} - -func TestDaemonTriggerPoll(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // First trigger should succeed - d.TriggerPoll() - - // Channel should have one signal - select { - case <-d.pollNow: - // good - default: - t.Error("expected signal on pollNow channel") - } - - // Second trigger when channel is empty should also succeed - d.TriggerPoll() - select { - case <-d.pollNow: - // good - default: - t.Error("expected signal on pollNow channel after second trigger") - } -} - -func TestDaemonTriggerPollNonBlocking(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // Fill the channel (capacity 1) - d.TriggerPoll() - // Second call should not block even though channel is full - done := make(chan struct{}) - go func() { - d.TriggerPoll() - close(done) - }() - - select { - case <-done: - // good, did not block - case <-time.After(1 * time.Second): - t.Fatal("TriggerPoll blocked on full channel") - } -} - -func TestDaemonHandleEventTasks(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - var claimedTasks []Task - d.OnTasksClaimed = func(tasks []Task) { - claimedTasks = tasks - } - - d.handleEvent(ServerEvent{ - Type: "tasks", - Tasks: &TasksResponse{ - Tasks: []Task{ - {ID: "t1", Title: "Movie 1"}, - {ID: "t2", Title: "Movie 2"}, - }, - }, - }) - - if len(claimedTasks) != 2 { - t.Fatalf("expected 2 claimed tasks, got %d", len(claimedTasks)) - } - if claimedTasks[0].Title != "Movie 1" { - t.Errorf("expected Movie 1, got %s", claimedTasks[0].Title) - } -} - -func TestDaemonHandleEventTasksWithStreamRequests(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - var streamReqs []StreamRequest - d.OnStreamRequested = func(req StreamRequest) { - streamReqs = append(streamReqs, req) - } - - d.handleEvent(ServerEvent{ - Type: "tasks", - Tasks: &TasksResponse{ - Tasks: []Task{}, - StreamRequests: []StreamRequest{ - {TaskID: "t1", FilePath: "/data/movie.mkv"}, - {TaskID: "t2", FilePath: "/data/show.mkv"}, - }, - }, - }) - - if len(streamReqs) != 2 { - t.Fatalf("expected 2 stream requests, got %d", len(streamReqs)) - } - if streamReqs[0].FilePath != "/data/movie.mkv" { - t.Errorf("expected /data/movie.mkv, got %s", streamReqs[0].FilePath) - } -} - -func TestDaemonHandleEventUpgrade(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - d.handleEvent(ServerEvent{ - Type: "upgrade", - Upgrade: &UpgradeSignal{Version: "2.0.0"}, - }) - - if d.lastNotifiedVersion != "2.0.0" { - t.Errorf("expected lastNotifiedVersion 2.0.0, got %s", d.lastNotifiedVersion) - } - - // Same version again should not update (already notified) - d.lastNotifiedVersion = "2.0.0" - d.handleEvent(ServerEvent{ - Type: "upgrade", - Upgrade: &UpgradeSignal{Version: "2.0.0"}, - }) - // Still 2.0.0, no change - if d.lastNotifiedVersion != "2.0.0" { - t.Errorf("expected lastNotifiedVersion unchanged at 2.0.0, got %s", d.lastNotifiedVersion) - } -} - -func TestDaemonHandleEventControl(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - var gotAction, gotTaskID string - d.OnControlAction = func(action, taskID string) { - gotAction = action - gotTaskID = taskID - } - - d.handleEvent(ServerEvent{ - Type: "control", - Control: &ControlAction{Action: "cancel", TaskID: "task-99"}, - }) - - if gotAction != "cancel" { - t.Errorf("expected cancel, got %s", gotAction) - } - if gotTaskID != "task-99" { - t.Errorf("expected task-99, got %s", gotTaskID) - } -} - -func TestDaemonHandleEventControlWithNilCallback(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // OnControlAction is nil — should not panic - d.handleEvent(ServerEvent{ - Type: "control", - Control: &ControlAction{Action: "pause", TaskID: "t1"}, - }) -} - -func TestDaemonHandleEventDisconnected(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // disconnected event should not panic (just logs) - d.handleEvent(ServerEvent{Type: "disconnected"}) -} - -func TestDaemonHandleEventTasksNilCallback(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // OnTasksClaimed is nil — should not panic - d.handleEvent(ServerEvent{ - Type: "tasks", - Tasks: &TasksResponse{ - Tasks: []Task{{ID: "t1", Title: "Test"}}, - }, - }) -} - -func TestDaemonHandleEventEmptyTasks(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - var called bool - d.OnTasksClaimed = func(tasks []Task) { - called = true - } - - // Empty tasks should not trigger callback - d.handleEvent(ServerEvent{ - Type: "tasks", - Tasks: &TasksResponse{Tasks: []Task{}}, - }) - - if called { - t.Error("OnTasksClaimed should not be called for empty task list") - } -} - -func TestDaemonHandleEventNilTasks(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // Nil Tasks field should not panic - d.handleEvent(ServerEvent{ - Type: "tasks", - Tasks: nil, - }) -} - -func TestDaemonHandleEventUpgradeNilSignal(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // Nil Upgrade should not panic - d.handleEvent(ServerEvent{ - Type: "upgrade", - Upgrade: nil, - }) - if d.lastNotifiedVersion != "" { - t.Errorf("expected empty lastNotifiedVersion, got %s", d.lastNotifiedVersion) - } -} - -func TestDaemonHandleEventUpgradeEmptyVersion(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - // Empty version should not update lastNotifiedVersion - d.handleEvent(ServerEvent{ - Type: "upgrade", - Upgrade: &UpgradeSignal{Version: ""}, - }) - if d.lastNotifiedVersion != "" { - t.Errorf("expected empty lastNotifiedVersion, got %s", d.lastNotifiedVersion) - } -} - -func TestDaemonWatchingFlag(t *testing.T) { - tr := NewHTTPTransport("http://localhost", "key", "ua") - d := NewDaemon(DaemonConfig{ - AgentID: "a1", - AgentName: "test", - Version: "1.0", - DownloadDir: "/tmp", - }, tr) - - if d.Watching.Load() { - t.Error("expected Watching to be false initially") - } - d.Watching.Store(true) - if !d.Watching.Load() { - t.Error("expected Watching to be true after Store(true)") - } -} - -// ── Transport Interface Compliance ────────────────────────────────────────── - -func TestHTTPTransportImplementsTransport(t *testing.T) { - var _ Transport = (*HTTPTransport)(nil) -} - -func TestWSTransportImplementsTransport(t *testing.T) { - var _ Transport = (*WSTransport)(nil) -} - -func TestHybridTransportImplementsTransport(t *testing.T) { - var _ Transport = (*HybridTransport)(nil) -} diff --git a/internal/agent/transport_ws.go b/internal/agent/transport_ws.go deleted file mode 100644 index 4860ca5..0000000 --- a/internal/agent/transport_ws.go +++ /dev/null @@ -1,395 +0,0 @@ -package agent - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" -) - -// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object. -type WSTransport struct { - wsURL string // wss://unarr.torrentclaw.com/ws/{agentId} - apiKey string - agentID string - userAgent string - - conn *websocket.Conn - mu sync.Mutex - events chan ServerEvent - closed atomic.Bool - - // Cached auth response from the DO - authResp *RegisterResponse - authMu sync.Mutex - authDone chan struct{} - authDoneOnce sync.Once -} - -// NewWSTransport creates a WebSocket-based transport. -func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport { - return &WSTransport{ - wsURL: wsURL, - apiKey: apiKey, - agentID: agentID, - userAgent: userAgent, - events: make(chan ServerEvent, 50), - authDone: make(chan struct{}), - } -} - -func (t *WSTransport) Mode() string { return "ws" } -func (t *WSTransport) Events() <-chan ServerEvent { return t.events } - -// Connect dials the WebSocket server and starts the read loop. -func (t *WSTransport) Connect(ctx context.Context) error { - dialer := websocket.Dialer{ - HandshakeTimeout: 10 * time.Second, - } - - header := http.Header{} - header.Set("User-Agent", t.userAgent) - - // Append API key as query param for auth on WS upgrade - wsURLWithKey := t.wsURL - if t.apiKey != "" { - sep := "?" - if strings.Contains(wsURLWithKey, "?") { - sep = "&" - } - wsURLWithKey += sep + "key=" + t.apiKey - } - - conn, wsResp, err := dialer.DialContext(ctx, wsURLWithKey, header) - if wsResp != nil && wsResp.Body != nil { - defer wsResp.Body.Close() - } - if err != nil { - return fmt.Errorf("ws dial: %w", err) - } - - t.mu.Lock() - t.conn = conn - t.closed.Store(false) - t.authDone = make(chan struct{}) - t.authDoneOnce = sync.Once{} - t.mu.Unlock() - - go t.readLoop(conn) - return nil -} - -// Close sends a close frame and shuts down the connection. -func (t *WSTransport) Close() error { - t.closed.Store(true) - t.mu.Lock() - defer t.mu.Unlock() - if t.conn != nil { - _ = t.conn.WriteMessage( - websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), - ) - err := t.conn.Close() - t.conn = nil - return err - } - return nil -} - -// Register sends auth message and waits for the registered response. -func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) { - msg := wsAuthMessage{ - Type: "auth", - APIKey: t.apiKey, - AgentID: req.AgentID, - Name: req.Name, - OS: req.OS, - Arch: req.Arch, - Version: req.Version, - DownloadDir: req.DownloadDir, - DiskFreeBytes: req.DiskFreeBytes, - DiskTotalBytes: req.DiskTotalBytes, - } - - if err := t.send(msg); err != nil { - return nil, fmt.Errorf("ws auth send: %w", err) - } - - // Wait for the auth response or context cancellation - select { - case <-t.authDone: - t.authMu.Lock() - resp := t.authResp - t.authMu.Unlock() - if resp == nil { - return nil, fmt.Errorf("ws auth: no response received") - } - return resp, nil - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(15 * time.Second): - return nil, fmt.Errorf("ws auth: timeout waiting for registered response") - } -} - -// SendHeartbeat sends a heartbeat message. No blocking response in WS mode. -func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) { - msg := struct { - Type string `json:"type"` - Disk *struct { - Free int64 `json:"free"` - Total int64 `json:"total"` - } `json:"disk,omitempty"` - }{Type: "heartbeat"} - - if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 { - msg.Disk = &struct { - Free int64 `json:"free"` - Total int64 `json:"total"` - }{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes} - } - - if err := t.send(msg); err != nil { - return nil, err - } - // WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events(). - return &HeartbeatResponse{Success: true}, nil -} - -// SendProgress sends a progress update. Control signals arrive async via Events(). -func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) { - msg := struct { - Type string `json:"type"` - TaskID string `json:"taskId"` - Status string `json:"status,omitempty"` - Progress int `json:"progress,omitempty"` - DownloadedBytes int64 `json:"downloadedBytes,omitempty"` - TotalBytes int64 `json:"totalBytes,omitempty"` - SpeedBps int64 `json:"speedBps,omitempty"` - ETA int `json:"eta,omitempty"` - ResolvedMethod string `json:"resolvedMethod,omitempty"` - FileName string `json:"fileName,omitempty"` - FilePath string `json:"filePath,omitempty"` - StreamURL string `json:"streamUrl,omitempty"` - StreamReady bool `json:"streamReady,omitempty"` - ErrorMessage string `json:"errorMessage,omitempty"` - }{ - Type: "progress", - TaskID: update.TaskID, - Status: update.Status, - Progress: update.Progress, - DownloadedBytes: update.DownloadedBytes, - TotalBytes: update.TotalBytes, - SpeedBps: update.SpeedBps, - ETA: update.ETA, - ResolvedMethod: update.ResolvedMethod, - FileName: update.FileName, - FilePath: update.FilePath, - StreamURL: update.StreamURL, - StreamReady: update.StreamReady, - ErrorMessage: update.ErrorMessage, - } - - if err := t.send(msg); err != nil { - return nil, err - } - // In WS mode, control signals come via Events(), not in the progress response. - return &StatusResponse{Success: true}, nil -} - -// ClaimTasks is a no-op in WS mode — tasks arrive via Events(). -func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) { - return &TasksResponse{}, nil -} - -// Deregister is handled by WebSocket close (DO detects disconnection). -func (t *WSTransport) Deregister(_ context.Context, _ string) error { - return t.Close() -} - -// ── Internal ───────────────────────────────────────────────────────────────── - -func (t *WSTransport) send(msg any) error { - t.mu.Lock() - defer t.mu.Unlock() - if t.conn == nil { - return fmt.Errorf("ws: not connected") - } - data, err := json.Marshal(msg) - if err != nil { - return err - } - _ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - return t.conn.WriteMessage(websocket.TextMessage, data) -} - -func (t *WSTransport) readLoop(conn *websocket.Conn) { - // Cloudflare idle timeout is 100s. We send pings every 30s and expect - // either a pong or a server message within 45s. If neither arrives, - // the read deadline fires and we detect the zombie connection. - const ( - pongWait = 45 * time.Second - pingPeriod = 30 * time.Second - ) - - _ = conn.SetReadDeadline(time.Now().Add(pongWait)) - conn.SetPongHandler(func(string) error { - _ = conn.SetReadDeadline(time.Now().Add(pongWait)) - return nil - }) - - // Ping ticker goroutine — stops when readLoop returns. - pingDone := make(chan struct{}) - go func() { - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - for { - select { - case <-ticker.C: - t.mu.Lock() - if t.conn != nil { - _ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - err := t.conn.WriteMessage(websocket.PingMessage, nil) - _ = t.conn.SetWriteDeadline(time.Time{}) - if err != nil { - t.mu.Unlock() - return - } - } - t.mu.Unlock() - case <-pingDone: - return - } - } - }() - defer close(pingDone) - - for { - _, msg, err := conn.ReadMessage() - if err != nil { - if !t.closed.Load() { - log.Printf("[ws] read error: %v", err) - // Signal disconnection to the daemon - select { - case t.events <- ServerEvent{Type: "disconnected"}: - default: - } - } - return - } - - // Any message (text or pong) proves the connection is alive. - _ = conn.SetReadDeadline(time.Now().Add(pongWait)) - - var envelope struct { - Type string `json:"type"` - } - if err := json.Unmarshal(msg, &envelope); err != nil { - log.Printf("[ws] invalid message: %v", err) - continue - } - - switch envelope.Type { - case "registered": - var resp wsRegisteredMessage - if json.Unmarshal(msg, &resp) == nil { - t.authMu.Lock() - t.authResp = &RegisterResponse{ - Success: true, - User: resp.User, - Features: resp.Features, - } - t.authMu.Unlock() - // Signal that auth is complete (sync.Once prevents double-close panic) - t.authDoneOnce.Do(func() { close(t.authDone) }) - } - - case "tasks": - var resp wsTasksMessage - if json.Unmarshal(msg, &resp) == nil { - select { - case t.events <- ServerEvent{ - Type: "tasks", - Tasks: &TasksResponse{ - Tasks: resp.Tasks, - StreamRequests: resp.StreamRequests, - }, - }: - default: - log.Printf("[ws] events channel full, dropping tasks message") - } - } - - case "upgrade": - var resp wsUpgradeMessage - if json.Unmarshal(msg, &resp) == nil { - select { - case t.events <- ServerEvent{ - Type: "upgrade", - Upgrade: &UpgradeSignal{Version: resp.Version}, - }: - default: - } - } - - case "control": - var resp ControlAction - if json.Unmarshal(msg, &resp) == nil { - select { - case t.events <- ServerEvent{ - Type: "control", - Control: &resp, - }: - default: - } - } - - case "error": - var resp struct { - Message string `json:"message"` - } - if json.Unmarshal(msg, &resp) == nil { - log.Printf("[ws] server error: %s", resp.Message) - } - } - } -} - -// ── WS message types ───────────────────────────────────────────────────────── - -type wsAuthMessage struct { - Type string `json:"type"` - APIKey string `json:"apiKey"` - AgentID string `json:"agentId"` - Name string `json:"name,omitempty"` - OS string `json:"os,omitempty"` - Arch string `json:"arch,omitempty"` - Version string `json:"version,omitempty"` - DownloadDir string `json:"downloadDir,omitempty"` - DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"` - DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"` -} - -type wsRegisteredMessage struct { - Type string `json:"type"` - User UserInfo `json:"user"` - Features FeatureFlags `json:"features"` -} - -type wsTasksMessage struct { - Type string `json:"type"` - Tasks []Task `json:"tasks"` - StreamRequests []StreamRequest `json:"streamRequests,omitempty"` -} - -type wsUpgradeMessage struct { - Type string `json:"type"` - Version string `json:"version"` -} diff --git a/internal/agent/types.go b/internal/agent/types.go index f1ab153..e7d07d6 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -50,20 +50,6 @@ type UsenetServerInfo struct { SSL bool `json:"ssl"` } -// HeartbeatRequest is sent every 30s to keep the agent alive. -type HeartbeatRequest struct { - AgentID string `json:"agentId"` - Name string `json:"name,omitempty"` - OS string `json:"os,omitempty"` - Version string `json:"version,omitempty"` - DownloadDir string `json:"downloadDir,omitempty"` - DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"` - DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"` - StreamPort int `json:"streamPort,omitempty"` - LanIP string `json:"lanIp,omitempty"` - TailscaleIP string `json:"tailscaleIp,omitempty"` -} - // Task represents a download task claimed from the server. type Task struct { ID string `json:"id"` @@ -88,12 +74,6 @@ type Task struct { CollectionName string `json:"collectionName,omitempty"` // Collection name (e.g., "Harry Potter Collection") } -// TasksResponse wraps the array of tasks returned by the server. -type TasksResponse struct { - Tasks []Task `json:"tasks"` - StreamRequests []StreamRequest `json:"streamRequests,omitempty"` -} - // StreamRequest is a request to stream a completed download from disk. type StreamRequest struct { TaskID string `json:"taskId"` @@ -139,14 +119,6 @@ type BatchStatusResponse struct { Watching bool `json:"watching,omitempty"` } -// HeartbeatResponse is returned by the server on heartbeat. -type HeartbeatResponse struct { - Success bool `json:"success"` - Upgrade *UpgradeSignal `json:"upgrade,omitempty"` - Watching bool `json:"watching,omitempty"` // true when a user is viewing download progress in the web UI - Scan bool `json:"scan,omitempty"` // true when user triggered a library scan from the web UI -} - // UpgradeSignal tells the agent to upgrade to a specific version. type UpgradeSignal struct { Version string `json:"version"` @@ -176,7 +148,6 @@ type AgentInfo struct { User UserInfo Features FeatureFlags StartedAt time.Time - LastPollAt time.Time ActiveTasks int } @@ -334,6 +305,45 @@ type LibrarySyncResponse struct { Removed int `json:"removed"` } +// --------------------------------------------------------------------------- +// Sync types (unified CLI ↔ Server communication) +// --------------------------------------------------------------------------- + +// SyncRequest is sent by the CLI periodically to synchronize state with the server. +// Contains the CLI's full execution state — the server responds with pending actions. +type SyncRequest struct { + AgentID string `json:"agentId"` + Version string `json:"version,omitempty"` + OS string `json:"os,omitempty"` + Arch string `json:"arch,omitempty"` + Name string `json:"name,omitempty"` + DownloadDir string `json:"downloadDir,omitempty"` + DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"` + DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"` + StreamPort int `json:"streamPort,omitempty"` + LanIP string `json:"lanIp,omitempty"` + TailscaleIP string `json:"tailscaleIp,omitempty"` + FreeSlots int `json:"freeSlots"` + Tasks []TaskState `json:"tasks"` +} + +// ControlAction represents a server-side control signal for a task. +type ControlAction struct { + Action string `json:"action"` // "pause", "resume", "cancel", "stream" + TaskID string `json:"taskId"` + DeleteFiles bool `json:"deleteFiles,omitempty"` +} + +// SyncResponse is returned by the server with all pending actions for the CLI. +type SyncResponse struct { + NewTasks []Task `json:"newTasks,omitempty"` + Controls []ControlAction `json:"controls,omitempty"` + StreamRequests []StreamRequest `json:"streamRequests,omitempty"` + Watching bool `json:"watching"` + Upgrade *UpgradeSignal `json:"upgrade,omitempty"` + Scan bool `json:"scan,omitempty"` +} + // --------------------------------------------------------------------------- // Watch progress types (used by stream tracking) // --------------------------------------------------------------------------- diff --git a/internal/cmd/config_menu.go b/internal/cmd/config_menu.go index 07297f7..9b1ddbf 100644 --- a/internal/cmd/config_menu.go +++ b/internal/cmd/config_menu.go @@ -311,21 +311,10 @@ func configConnection(cfg *config.Config) error { ).Run() } -func configAdvanced(cfg *config.Config) error { - return huh.NewForm( - huh.NewGroup( - huh.NewInput(). - Title("Poll interval"). - Description("How often to check for new tasks (e.g. 30s, 1m)"). - Value(&cfg.Daemon.PollInterval). - Validate(validateDuration), - huh.NewInput(). - Title("Heartbeat interval"). - Description("How often to send heartbeat to server (e.g. 30s, 1m)"). - Value(&cfg.Daemon.HeartbeatInterval). - Validate(validateDuration), - ), - ).Run() +func configAdvanced(_ *config.Config) error { + // Sync intervals are adaptive (3s watching, 60s idle) — no user-facing config needed. + fmt.Println("No advanced settings to configure. Sync intervals are automatic.") + return nil } // ── Validators ────────────────────────────────────────────────────── diff --git a/internal/cmd/daemon.go b/internal/cmd/daemon.go index a6abc4c..d050903 100644 --- a/internal/cmd/daemon.go +++ b/internal/cmd/daemon.go @@ -7,7 +7,6 @@ import ( "os" "os/signal" "path/filepath" - "strings" "syscall" "time" @@ -27,13 +26,13 @@ func newStartCmd() *cobra.Command { Short: "Start the download daemon (foreground)", Long: `Start the unarr daemon in the foreground. -Registers with the server, receives download tasks via WebSocket (with -HTTP fallback), and executes them using the configured download method. +Registers with the server, receives download tasks via periodic sync, +and executes them using the configured download method. Supports torrent, debrid, and usenet downloads concurrently. -The daemon sends periodic heartbeats and reports download progress back -to the web dashboard. Press Ctrl+C to stop gracefully — active downloads -get up to 30 seconds to finish. +The daemon syncs state with the server every 3s when someone is viewing +the web dashboard, or every 60s when idle. Press Ctrl+C to stop +gracefully — active downloads get up to 30 seconds to finish. Requires: API key, agent ID, and download directory (run 'unarr init' first). @@ -127,85 +126,59 @@ func runDaemonStart() error { bold.Println(" unarr Daemon") fmt.Println() - // Parse intervals - pollInterval, _ := time.ParseDuration(cfg.Daemon.PollInterval) - if pollInterval == 0 { - pollInterval = 30 * time.Second - } - heartbeatInterval, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval) - if heartbeatInterval == 0 { - heartbeatInterval = 30 * time.Second - } - statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval) - if statusInterval == 0 { - statusInterval = 3 * time.Second - } - userAgent := "unarr/" + Version // Create daemon config daemonCfg := agent.DaemonConfig{ - AgentID: cfg.Agent.ID, - AgentName: cfg.Agent.Name, - Version: Version, - DownloadDir: cfg.Download.Dir, - PollInterval: pollInterval, - HeartbeatInterval: heartbeatInterval, - StreamPort: cfg.Download.StreamPort, - LanIP: engine.LanIP(), - TailscaleIP: engine.TailscaleIP(), + AgentID: cfg.Agent.ID, + AgentName: cfg.Agent.Name, + Version: Version, + DownloadDir: cfg.Download.Dir, + StreamPort: cfg.Download.StreamPort, + LanIP: engine.LanIP(), + TailscaleIP: engine.TailscaleIP(), } - // Create transport: Hybrid (WS + HTTP fallback) or HTTP-only - httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent) - - wsURL := cfg.Auth.WSURL - if wsURL == "" { - wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID) - } - - var transport agent.Transport - if wsURL != "" { - wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent) - transport = agent.NewHybridTransport(wsT, httpT) - log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL) - } else { - transport = httpT - log.Println("Transport: HTTP only") - } - - // Create daemon — always uses Transport interface - d := agent.NewDaemon(daemonCfg, transport) - - // Create agent client for watch progress reporting + // Create HTTP client — single communication channel agentClient := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent) + log.Printf("Transport: HTTP sync → %s", cfg.Auth.APIURL) + + // Create daemon + d := agent.NewDaemon(daemonCfg, agentClient) + + // Start SIGUSR1 reload watcher (unix only, no-op on Windows) + startReloadWatcher(&ReloadableConfig{Daemon: d}) // Daemon-scoped context — cancelled on shutdown ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Create progress reporter using transport - reporter := engine.NewProgressReporterWithTransport(transport, statusInterval) - reporter.SetWatchingFunc(func() bool { return d.Watching.Load() }) - reporter.SetWatchingChangedHandler(func(watching bool) { d.Watching.Store(watching) }) - // Parse speed limits maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed) maxUl, _ := config.ParseSpeed(cfg.Download.MaxUploadSpeed) - // Parse torrent timeouts from config (default: 0 = unlimited, like qBittorrent) + // Parse torrent timeouts metaTimeout, _ := time.ParseDuration(cfg.Download.MetadataTimeout) stallTimeout, _ := time.ParseDuration(cfg.Download.StallTimeout) + // Create progress reporter — only used for stream tasks (handleStreamTask) + // The sync goroutine handles all regular progress reporting. + statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval) + if statusInterval == 0 { + statusInterval = 3 * time.Second + } + reporter := engine.NewProgressReporter(agentClient, statusInterval) + reporter.SetWatchingFunc(func() bool { return d.Watching.Load() }) + // Create torrent downloader torrentDl, err := engine.NewTorrentDownloader(engine.TorrentConfig{ DataDir: cfg.Download.Dir, - MetadataTimeout: metaTimeout, // 0 = unlimited (default) - StallTimeout: stallTimeout, // 0 = unlimited (default) - MaxTimeout: 0, // unlimited + MetadataTimeout: metaTimeout, + StallTimeout: stallTimeout, + MaxTimeout: 0, MaxDownloadRate: maxDl, MaxUploadRate: maxUl, - ListenPort: cfg.Download.ListenPort, // 0 = default 42069 + ListenPort: cfg.Download.ListenPort, SeedEnabled: false, }) if err != nil { @@ -223,7 +196,7 @@ func runDaemonStart() error { log.Printf("Speed limits: download=%s upload=%s", dlStr, ulStr) } - // Create debrid downloader (HTTPS-based, no provider interaction needed) + // Create debrid downloader debridDl := engine.NewDebridDownloader() // Create download manager @@ -237,170 +210,53 @@ func runDaemonStart() error { TVShowsDir: cfg.Organize.TVShowsDir, OutputDir: cfg.Download.Dir, }, - }, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client())) + }, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(agentClient)) - // Create persistent stream server — lives for the entire daemon lifecycle. - // One port, one server, swap files with SetFile(). No more port churn. + // Create persistent stream server streamSrv := engine.NewStreamServer(cfg.Download.StreamPort) if err := streamSrv.Listen(ctx); err != nil { return fmt.Errorf("start stream server: %w", err) } - // Update heartbeat with actual port (may differ if configured port was busy) d.UpdateStreamPort(streamSrv.Port()) - // Wire state tracking + // Wire sync client callbacks + sc := d.SyncClient() + sc.GetFreeSlots = manager.FreeSlots + sc.GetTaskStates = manager.TaskStates d.GetActiveCount = manager.ActiveCount - d.GetCleanableBytes = CleanableBytes - // Wire: server-side signals -> manager actions + stream tasks - reporter.SetCancelHandler(func(taskID string) { - manager.CancelTask(taskID) - cancelStreamTask(taskID) - }) - reporter.SetPauseHandler(func(taskID string) { - manager.PauseTask(taskID) - cancelStreamTask(taskID) - }) - reporter.SetDeleteFilesHandler(func(taskID string) { - manager.CancelAndDeleteFiles(taskID) - cancelStreamTask(taskID) - }) - - // Wire: stream requested on active download → set file on persistent server - reporter.SetStreamRequestedHandler(func(taskID string) { - task := manager.GetTask(taskID) - if task == nil { - log.Printf("[%s] stream requested but task not found in manager", taskID[:8]) - return - } - if task.GetStreamURL() != "" { - return // already streaming - } - provider, err := torrentDl.GetStreamProvider(taskID) - if err != nil { - log.Printf("[%s] stream failed: %v", taskID[:8], err) - return - } - cancelStreamContexts() - streamSrv.SetFile(provider, taskID) - task.SetStreamURL(streamSrv.URLsJSON()) - log.Printf("[%s] streaming active download: %s", taskID[:8], provider.FileName()) - - // Start watch progress reporter with cancellable context - watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts() - streamRegistry.mu.Lock() - streamRegistry.cancels["watch:"+taskID] = watchCancel - streamRegistry.mu.Unlock() - go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx) - }) - - // Wire: daemon claimed tasks -> manager + // Trigger immediate sync when a download slot frees up + manager.OnTaskDone = func() { d.TriggerSync() } + // Wire: sync receives new tasks → submit to manager or handle stream d.OnTasksClaimed = func(tasks []agent.Task) { for _, t := range tasks { if t.Mode == "stream" { - // Skip if already streaming this task if isStreamingTask(t.ID) { continue } - // Only 1 stream at a time: cancel existing stream goroutines + clear file cancelStreamContexts() streamSrv.ClearFile() - // Reserve slot before spawning goroutine to prevent TOCTOU race. - streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel ownership transferred to streamRegistry + streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel stored in registry streamRegistry.mu.Lock() streamRegistry.cancels[t.ID] = streamCancel streamRegistry.mu.Unlock() go handleStreamTask(streamCtx, t, reporter, cfg, agentClient, streamSrv) - } else if t.ForceStart || manager.HasCapacity() { - manager.Submit(ctx, t) } else { - log.Printf("[%s] skipped: no capacity (max %d)", t.ID[:8], cfg.Download.MaxConcurrent) + manager.Submit(ctx, t) } } } - // Wire: stream requests for completed downloads → set file on persistent server - d.OnStreamRequested = func(sr agent.StreamRequest) { - // Already serving this task — just notify server it's ready - if streamSrv.CurrentTaskID() == sr.TaskID { - go func() { - if _, err := transport.SendProgress(ctx, agent.StatusUpdate{ - TaskID: sr.TaskID, - StreamReady: true, - }); err != nil { - log.Printf("[%s] stream ready re-notify failed: %v", sr.TaskID[:8], err) - } - }() - return - } - - filePath := sr.FilePath - info, err := os.Stat(filePath) - if err != nil { - log.Printf("[%s] stream request: file not found: %s", sr.TaskID[:8], filePath) - go func() { - if _, err := transport.SendProgress(ctx, agent.StatusUpdate{ - TaskID: sr.TaskID, - Status: "failed", - ErrorMessage: fmt.Sprintf("file not found: %s", filePath), - }); err != nil { - log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err) - } - }() - return - } - - // If filePath is a directory, find the largest video file inside - if info.IsDir() { - found := engine.FindVideoFile(filePath) - if found == "" { - log.Printf("[%s] stream request: no video file in directory: %s", sr.TaskID[:8], filePath) - go func() { - if _, err := transport.SendProgress(ctx, agent.StatusUpdate{ - TaskID: sr.TaskID, - Status: "failed", - ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath), - }); err != nil { - log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err) - } - }() - return - } - filePath = found - log.Printf("[%s] resolved directory to video file: %s", sr.TaskID[:8], filepath.Base(filePath)) - } - - // Cancel any active stream goroutines and swap file on the persistent server - cancelStreamContexts() - streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID) - - log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(filePath), streamSrv.URL()) - - // Start watch progress reporter with a cancellable context - // so it stops when the user switches to a different stream. - watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts() - streamRegistry.mu.Lock() - streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel - streamRegistry.mu.Unlock() - go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx) - - // Notify server that stream is ready (clears streamRequested flag) - go func() { - if _, err := transport.SendProgress(ctx, agent.StatusUpdate{ - TaskID: sr.TaskID, - StreamReady: true, - }); err != nil { - log.Printf("[%s] stream ready report failed: %v", sr.TaskID[:8], err) - } - }() - } - - // Wire: WS control actions (pause/cancel/stream pushed from server) - d.OnControlAction = func(action, taskID string) { + // Wire: sync receives control signals → act on manager + d.OnControlAction = func(action, taskID string, deleteFiles bool) { switch action { case "cancel": - manager.CancelTask(taskID) + if deleteFiles { + manager.CancelAndDeleteFiles(taskID) + } else { + manager.CancelTask(taskID) + } cancelStreamTask(taskID) if streamSrv.CurrentTaskID() == taskID { streamSrv.ClearFile() @@ -412,10 +268,9 @@ func runDaemonStart() error { streamSrv.ClearFile() } case "resume": - log.Printf("[%s] resume requested via WebSocket, triggering poll", taskID[:8]) - d.TriggerPoll() + log.Printf("[%s] resume requested, triggering sync", agent.ShortID(taskID)) + d.TriggerSync() case "stream": - // Skip if already streaming this task if streamSrv.CurrentTaskID() == taskID { return } @@ -425,13 +280,19 @@ func runDaemonStart() error { } provider, err := torrentDl.GetStreamProvider(taskID) if err != nil { - log.Printf("[%s] stream failed: %v", taskID[:8], err) + log.Printf("[%s] stream failed: %v", agent.ShortID(taskID), err) return } cancelStreamContexts() streamSrv.SetFile(provider, taskID) task.SetStreamURL(streamSrv.URLsJSON()) - log.Printf("[%s] streaming via WS: %s", taskID[:8], provider.FileName()) + log.Printf("[%s] streaming: %s", agent.ShortID(taskID), provider.FileName()) + + watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118 + streamRegistry.mu.Lock() + streamRegistry.cancels["watch:"+taskID] = watchCancel + streamRegistry.mu.Unlock() + go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx) case "stop-stream": cancelStreamTask(taskID) if streamSrv.CurrentTaskID() == taskID { @@ -440,19 +301,77 @@ func runDaemonStart() error { } } - // Config hot-reload (SIGUSR1 on Unix, no-op on Windows) - // Tickers are initialized inside d.Run(), so we pass the daemon - // and the reload goroutine reads them when the signal arrives. - startReloadWatcher(&ReloadableConfig{Daemon: d}) + // Wire: sync receives stream requests for completed downloads + d.OnStreamRequested = func(sr agent.StreamRequest) { + if streamSrv.CurrentTaskID() == sr.TaskID { + // Already serving — notify server it's ready + go func() { + if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{ + TaskID: sr.TaskID, + StreamReady: true, + }); err != nil { + log.Printf("[%s] stream ready re-notify failed: %v", agent.ShortID(sr.TaskID), err) + } + }() + return + } - // Signal handling - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + filePath := sr.FilePath + info, err := os.Stat(filePath) + if err != nil { + log.Printf("[%s] stream request: file not found: %s", agent.ShortID(sr.TaskID), filePath) + go func() { + if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{ + TaskID: sr.TaskID, + Status: "failed", + ErrorMessage: fmt.Sprintf("file not found: %s", filePath), + }); err != nil { + log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err) + } + }() + return + } - // Start progress reporter in background - go reporter.Run(ctx) + if info.IsDir() { + found := engine.FindVideoFile(filePath) + if found == "" { + log.Printf("[%s] stream request: no video file in directory: %s", agent.ShortID(sr.TaskID), filePath) + go func() { + if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{ + TaskID: sr.TaskID, + Status: "failed", + ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath), + }); err != nil { + log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err) + } + }() + return + } + filePath = found + log.Printf("[%s] resolved directory to video file: %s", agent.ShortID(sr.TaskID), filepath.Base(filePath)) + } - // Periodic DHT node persistence (every 5 min) — protects against crash data loss + cancelStreamContexts() + streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID) + log.Printf("[%s] streaming from disk: %s → %s", agent.ShortID(sr.TaskID), filepath.Base(filePath), streamSrv.URL()) + + watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118 + streamRegistry.mu.Lock() + streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel + streamRegistry.mu.Unlock() + go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx) + + go func() { + if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{ + TaskID: sr.TaskID, + StreamReady: true, + }); err != nil { + log.Printf("[%s] stream ready report failed: %v", agent.ShortID(sr.TaskID), err) + } + }() + } + + // Periodic DHT node persistence (every 5 min) go func() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() @@ -466,8 +385,7 @@ func runDaemonStart() error { } }() - // Start auto-scan goroutine (daily library scan + sync) - // Default scan_path to download dir so auto-scan works out of the box. + // Start auto-scan goroutine scanPath := cfg.Library.ScanPath if scanPath == "" { scanPath = cfg.Download.Dir @@ -484,7 +402,10 @@ func runDaemonStart() error { go runAutoScan(ctx, scanCfg, scanInterval, agentClient, d.ScanNow) } - // Start daemon (blocks) + // Start reporter only for stream task handling + go reporter.Run(ctx) + + // Start daemon (blocks — runs sync loop) errCh := make(chan error, 1) go func() { errCh <- d.Run(ctx) @@ -493,6 +414,10 @@ func runDaemonStart() error { // Start idle guard for the persistent stream server go startIdleGuard(ctx, streamSrv) + // Signal handling + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + // Wait for signal or error select { case sig := <-sigCh: @@ -506,6 +431,7 @@ func runDaemonStart() error { defer shutdownCancel() manager.Shutdown(shutdownCtx) + d.Deregister() fmt.Println(" Daemon stopped.") return nil @@ -517,41 +443,6 @@ func runDaemonStart() error { } } -// deriveWSURL derives a WebSocket URL from the API URL. -// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId} -// Returns "" for localhost/dev environments where WS gateway isn't available. -func deriveWSURL(apiURL, agentID string) string { - if apiURL == "" || agentID == "" { - return "" - } - // Parse domain from API URL - domain := apiURL - for _, prefix := range []string{"https://", "http://"} { - if len(domain) > len(prefix) && domain[:len(prefix)] == prefix { - domain = domain[len(prefix):] - break - } - } - // Strip trailing slash/path - for i := 0; i < len(domain); i++ { - if domain[i] == '/' { - domain = domain[:i] - break - } - } - // Strip port if present - if idx := strings.LastIndex(domain, ":"); idx > 0 { - domain = domain[:idx] - } - - // Skip WS for localhost/dev — gateway only available in production - if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" { - return "" - } - - return "wss://unarr." + domain + "/ws/" + agentID -} - func formatSpeedLog(bps int64) string { switch { case bps >= 1024*1024*1024: @@ -569,11 +460,9 @@ func formatSpeedLog(bps int64) string { func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration, ac *agent.Client, scanNow <-chan struct{}) { log.Printf("[auto-scan] enabled: every %s, path: %s", interval, cfg.Library.ScanPath) - // Run first scan after a short delay (let daemon stabilize) select { case <-time.After(30 * time.Second): case <-scanNow: - // Immediate scan requested before initial delay case <-ctx.Done(): return } @@ -608,7 +497,6 @@ func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration, return } - // Sync to server items := library.BuildSyncItems(cache) if len(items) == 0 { log.Printf("[auto-scan] no items to sync") diff --git a/internal/cmd/daemon_test.go b/internal/cmd/daemon_test.go index fe1cdd4..09b5f49 100644 --- a/internal/cmd/daemon_test.go +++ b/internal/cmd/daemon_test.go @@ -2,32 +2,6 @@ package cmd import "testing" -func TestDeriveWSURL(t *testing.T) { - tests := []struct { - apiURL string - agentID string - want string - }{ - {"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"}, - {"http://localhost:3000", "a1", ""}, // localhost skipped - {"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped - {"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"}, - {"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"}, - {"", "agent-123", ""}, - {"https://torrentclaw.com", "", ""}, - {"", "", ""}, - } - - for _, tt := range tests { - t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) { - got := deriveWSURL(tt.apiURL, tt.agentID) - if got != tt.want { - t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want) - } - }) - } -} - func TestFormatSpeedLog(t *testing.T) { tests := []struct { bps int64 diff --git a/internal/cmd/reload_unix.go b/internal/cmd/reload_unix.go index 5577a76..8aa9177 100644 --- a/internal/cmd/reload_unix.go +++ b/internal/cmd/reload_unix.go @@ -7,7 +7,6 @@ import ( "os" "os/signal" "syscall" - "time" "github.com/torrentclaw/unarr/internal/agent" "github.com/torrentclaw/unarr/internal/config" @@ -19,7 +18,8 @@ type ReloadableConfig struct { } // startReloadWatcher listens for SIGUSR1 and reloads config. -// Only intervals are hot-reloadable (speeds require torrent client restart). +// With the sync-based architecture, intervals are fixed (3s watching, 60s idle). +// Hot-reload now mainly serves as a signal to re-read config for future settings. func startReloadWatcher(rc *ReloadableConfig) { sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGUSR1) @@ -28,24 +28,11 @@ func startReloadWatcher(rc *ReloadableConfig) { for range sigCh { log.Println("Received SIGUSR1, reloading config...") - cfg, err := config.Load("") + _, err := config.Load("") if err != nil { log.Printf("Config reload failed: %v", err) continue } - cfg.ApplyEnvOverrides() - - // Update poll interval - if d, _ := time.ParseDuration(cfg.Daemon.PollInterval); d > 0 && rc.Daemon.PollTicker != nil { - rc.Daemon.PollTicker.Reset(d) - log.Printf(" Poll interval: %s", d) - } - - // Update heartbeat interval - if d, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval); d > 0 && rc.Daemon.HeartbeatTicker != nil { - rc.Daemon.HeartbeatTicker.Reset(d) - log.Printf(" Heartbeat interval: %s", d) - } log.Println("Config reloaded successfully") } diff --git a/internal/cmd/version.go b/internal/cmd/version.go index e1b2837..86c4267 100644 --- a/internal/cmd/version.go +++ b/internal/cmd/version.go @@ -1,4 +1,4 @@ package cmd // Version is the CLI version. Overridden by goreleaser ldflags at release time. -var Version = "0.5.5" +var Version = "0.5.6" diff --git a/internal/config/config.go b/internal/config/config.go index 693f30d..cba221c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,7 +26,6 @@ type Config struct { type AuthConfig struct { APIKey string `toml:"api_key"` APIURL string `toml:"api_url"` - WSURL string `toml:"ws_url"` // optional, derived from api_url if empty } type AgentConfig struct { @@ -54,9 +53,7 @@ type OrganizeConfig struct { } type DaemonConfig struct { - PollInterval string `toml:"poll_interval"` - HeartbeatInterval string `toml:"heartbeat_interval"` - StatusInterval string `toml:"status_interval"` + StatusInterval string `toml:"status_interval"` } type NotificationsConfig struct { @@ -92,10 +89,7 @@ func Default() Config { Organize: OrganizeConfig{ Enabled: true, }, - Daemon: DaemonConfig{ - PollInterval: "30s", - HeartbeatInterval: "30s", - }, + Daemon: DaemonConfig{}, Notifications: NotificationsConfig{ Enabled: true, }, diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3190399..6685fbc 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -21,8 +21,8 @@ func TestDefault(t *testing.T) { if cfg.General.Country != "US" { t.Errorf("default Country = %q, want US", cfg.General.Country) } - if cfg.Daemon.HeartbeatInterval != "30s" { - t.Errorf("default HeartbeatInterval = %q, want 30s", cfg.Daemon.HeartbeatInterval) + if cfg.Daemon.StatusInterval != "" { + t.Errorf("default StatusInterval = %q, want empty", cfg.Daemon.StatusInterval) } } diff --git a/internal/engine/debrid.go b/internal/engine/debrid.go index 7aea0bf..fce60dd 100644 --- a/internal/engine/debrid.go +++ b/internal/engine/debrid.go @@ -10,6 +10,8 @@ import ( "path/filepath" "sync" "time" + + "github.com/torrentclaw/unarr/internal/agent" ) // httpClient is used for debrid HTTPS downloads with a reasonable header timeout. @@ -19,13 +21,6 @@ var httpClient = &http.Client{ }, } -func shortID(id string) string { - if len(id) > 8 { - return id[:8] - } - return id -} - // DebridDownloader downloads files via HTTPS direct URLs resolved by the server. // The server handles all debrid provider interaction; this downloader only needs // a plain HTTPS URL to fetch. @@ -129,7 +124,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s var serverSize int64 if _, err := fmt.Sscanf(cr, "bytes */%d", &serverSize); err == nil && serverSize > 0 && existingSize != serverSize { // Local file size doesn't match server — re-download from scratch - log.Printf("[%s] local size %s != server size %s, re-downloading", shortID(task.ID), formatBytes(existingSize), formatBytes(serverSize)) + log.Printf("[%s] local size %s != server size %s, re-downloading", agent.ShortID(task.ID), formatBytes(existingSize), formatBytes(serverSize)) resp.Body.Close() req2, err := http.NewRequestWithContext(dlCtx, http.MethodGet, task.DirectURL, nil) if err != nil { @@ -149,7 +144,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s break // continue to download loop } } - log.Printf("[%s] file already complete: %s (%s)", shortID(task.ID), fileName, formatBytes(existingSize)) + log.Printf("[%s] file already complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(existingSize)) return &Result{ FilePath: destPath, FileName: fileName, @@ -166,10 +161,10 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s var flags int if startOffset > 0 { flags = os.O_WRONLY | os.O_APPEND - log.Printf("[%s] resuming debrid download at %s: %s", shortID(task.ID), formatBytes(startOffset), fileName) + log.Printf("[%s] resuming debrid download at %s: %s", agent.ShortID(task.ID), formatBytes(startOffset), fileName) } else { flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC - log.Printf("[%s] starting debrid download: %s", shortID(task.ID), fileName) + log.Printf("[%s] starting debrid download: %s", agent.ShortID(task.ID), fileName) } if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { @@ -223,7 +218,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s } log.Printf("[%s] %d%% — %s/%s @ %s/s (debrid)", - shortID(task.ID), pct, + agent.ShortID(task.ID), pct, formatBytes(downloaded), formatBytes(totalBytes), formatBytes(speed)) p := Progress{ @@ -252,7 +247,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s } } - log.Printf("[%s] debrid download complete: %s (%s)", shortID(task.ID), fileName, formatBytes(downloaded)) + log.Printf("[%s] debrid download complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(downloaded)) return &Result{ FilePath: destPath, @@ -271,7 +266,7 @@ func (d *DebridDownloader) Pause(taskID string) error { if ok { cancel() - log.Printf("[%s] debrid download paused (file kept for resume)", shortID(taskID)) + log.Printf("[%s] debrid download paused (file kept for resume)", agent.ShortID(taskID)) } return nil } @@ -285,7 +280,7 @@ func (d *DebridDownloader) Cancel(taskID string) error { if ok { cancel() - log.Printf("[%s] debrid download cancelled", shortID(taskID)) + log.Printf("[%s] debrid download cancelled", agent.ShortID(taskID)) } return nil } diff --git a/internal/engine/manager.go b/internal/engine/manager.go index 12cfc06..2a07b6f 100644 --- a/internal/engine/manager.go +++ b/internal/engine/manager.go @@ -28,6 +28,15 @@ type Manager struct { sem chan struct{} wg sync.WaitGroup + + // OnTaskDone is called after a task completes or fails (slot freed). + // Used by the daemon to trigger an immediate sync. + OnTaskDone func() + + // recentlyFinished holds tasks that completed/failed since the last sync read. + // The sync goroutine reads and clears this to include final states in the next sync. + recentMu sync.Mutex + recentFinished []agent.TaskState } // NewManager creates a download manager. @@ -67,7 +76,7 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) { // Force start: bypass semaphore (like Transmission's "Force Start") if at.ForceStart { - log.Printf("[%s] force start: bypassing queue", task.ID[:8]) + log.Printf("[%s] force start: bypassing queue", agent.ShortID(task.ID)) m.wg.Add(1) go func() { defer m.wg.Done() @@ -88,7 +97,12 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) { m.wg.Add(1) go func() { defer m.wg.Done() - defer func() { <-m.sem }() + defer func() { + <-m.sem + if m.OnTaskDone != nil { + m.OnTaskDone() + } + }() defer taskCancel() m.processTask(taskCtx, task) }() @@ -99,6 +113,11 @@ func (m *Manager) HasCapacity() bool { return len(m.sem) < cap(m.sem) } +// FreeSlots returns the number of available download slots. +func (m *Manager) FreeSlots() int { + return cap(m.sem) - len(m.sem) +} + // ActiveCount returns the number of in-progress downloads. func (m *Manager) ActiveCount() int { m.activeMu.RLock() @@ -113,6 +132,17 @@ func (m *Manager) GetTask(taskID string) *Task { return m.active[taskID] } +// ActiveTaskIDs returns the IDs of all in-progress tasks. +func (m *Manager) ActiveTaskIDs() []string { + m.activeMu.RLock() + defer m.activeMu.RUnlock() + ids := make([]string, 0, len(m.active)) + for id := range m.active { + ids = append(ids, id) + } + return ids +} + // ActiveTasks returns a snapshot of all active tasks. func (m *Manager) ActiveTasks() []*Task { m.activeMu.RLock() @@ -124,6 +154,37 @@ func (m *Manager) ActiveTasks() []*Task { return tasks } +// TaskStates returns the current state of all active tasks plus any recently +// finished tasks that haven't been synced yet. Called by the sync goroutine. +func (m *Manager) TaskStates() []agent.TaskState { + // Collect active tasks + m.activeMu.RLock() + states := make([]agent.TaskState, 0, len(m.active)) + for _, t := range m.active { + states = append(states, agent.TaskStateFromUpdate(t.ToStatusUpdate())) + } + m.activeMu.RUnlock() + + // Drain recently finished tasks (consumed once per sync) + m.recentMu.Lock() + states = append(states, m.recentFinished...) + m.recentFinished = nil + m.recentMu.Unlock() + + return states +} + +// recordFinished stores a completed/failed task for the next sync cycle. +func (m *Manager) recordFinished(update agent.StatusUpdate) { + m.recentMu.Lock() + defer m.recentMu.Unlock() + m.recentFinished = append(m.recentFinished, agent.TaskStateFromUpdate(update)) + // Keep bounded + if len(m.recentFinished) > 20 { + m.recentFinished = m.recentFinished[len(m.recentFinished)-20:] + } +} + // CancelTask cancels an active download by task ID (keeps partial files). func (m *Manager) CancelTask(taskID string) { m.activeMu.RLock() @@ -150,7 +211,7 @@ func (m *Manager) CancelTask(taskID string) { task.mu.Unlock() task.Transition(StatusCancelled) - log.Printf("[%s] cancelled: %s", taskID[:8], task.Title) + log.Printf("[%s] cancelled: %s", agent.ShortID(taskID), task.Title) } // PauseTask pauses an active download (keeps partial files for resume). @@ -173,7 +234,7 @@ func (m *Manager) PauseTask(taskID string) { } task.Transition(StatusCancelled) // will be re-created as pending by server - log.Printf("[%s] paused: %s", taskID[:8], task.Title) + log.Printf("[%s] paused: %s", agent.ShortID(taskID), task.Title) } // CancelAndDeleteFiles cancels a download and removes its files from disk. @@ -200,7 +261,7 @@ func (m *Manager) CancelAndDeleteFiles(taskID string) { task.mu.Unlock() task.Transition(StatusCancelled) - log.Printf("[%s] cancelled + files deleted: %s", taskID[:8], task.Title) + log.Printf("[%s] cancelled + files deleted: %s", agent.ShortID(taskID), task.Title) } // Wait blocks until all active downloads finish. @@ -261,7 +322,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) { } task.ResolvedMethod = method - log.Printf("[%s] resolved method: %s", task.ID[:8], method) + log.Printf("[%s] resolved method: %s", agent.ShortID(task.ID), method) // 2. Download if err := task.Transition(StatusDownloading); err != nil { @@ -285,7 +346,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) { if err != nil { // Try fallback if tryFallback(task, m.downloaders) { - log.Printf("[%s] %s failed, trying fallback: %v", task.ID[:8], method, err) + log.Printf("[%s] %s failed, trying fallback: %v", agent.ShortID(task.ID), method, err) if err := task.Transition(StatusResolving); err == nil { m.processTaskRetry(ctx, task) return @@ -295,61 +356,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) { return } - // 3. Verify - if err := task.Transition(StatusVerifying); err != nil { - m.fail(ctx, task, "transition error: "+err.Error()) - return - } - - if err := verify(result); err != nil { - m.fail(ctx, task, "verification failed: "+err.Error()) - return - } - - // 4. Organize - if err := task.Transition(StatusOrganizing); err != nil { - m.fail(ctx, task, "transition error: "+err.Error()) - return - } - - finalPath, err := organize(result, task, m.cfg.Organize) - if err != nil { - log.Printf("[%s] organize warning: %v (keeping in download dir)", task.ID[:8], err) - finalPath = result.FilePath - } - - task.mu.Lock() - task.FilePath = finalPath - task.mu.Unlock() - - // 4b. Handle upgrade replacement (mode = "upgrade") - if task.ReplacePath != "" { - backupDir := "" // uses default ~/.local/share/unarr/replaced/ - if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil { - log.Printf("[%s] replace warning: %v (keeping new file at %s)", task.ID[:8], err, finalPath) - } else { - task.mu.Lock() - task.FilePath = task.ReplacePath - task.mu.Unlock() - log.Printf("[%s] upgraded: replaced %s", task.ID[:8], task.ReplacePath) - } - } - - // 5. Complete - if method == MethodTorrent && m.cfg.Organize.Enabled { - // Could add seeding here in the future - } - - if err := task.Transition(StatusCompleted); err != nil { - m.fail(ctx, task, "transition error: "+err.Error()) - return - } - - log.Printf("[%s] completed: %s -> %s", task.ID[:8], task.Title, finalPath) - if m.cfg.Notifications { - desktopNotify("Download complete", task.Title) - } - m.reporter.ReportFinal(ctx, task) + m.finalize(ctx, task, result) } // processTaskRetry handles fallback after a method failure. @@ -361,7 +368,7 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) { } task.ResolvedMethod = method - log.Printf("[%s] fallback to: %s", task.ID[:8], method) + log.Printf("[%s] fallback to: %s", agent.ShortID(task.ID), method) if err := task.Transition(StatusDownloading); err != nil { m.fail(ctx, task, "transition error: "+err.Error()) @@ -383,15 +390,31 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) { return } - // Verify + Organize + Complete (same as processTask) - task.Transition(StatusVerifying) + m.finalize(ctx, task, result) +} + +// finalize runs verify → organize → upgrade replacement → complete for a downloaded task. +func (m *Manager) finalize(ctx context.Context, task *Task, result *Result) { + // Verify + if err := task.Transition(StatusVerifying); err != nil { + m.fail(ctx, task, "transition error: "+err.Error()) + return + } if err := verify(result); err != nil { m.fail(ctx, task, "verification failed: "+err.Error()) return } - task.Transition(StatusOrganizing) - finalPath, _ := organize(result, task, m.cfg.Organize) + // Organize + if err := task.Transition(StatusOrganizing); err != nil { + m.fail(ctx, task, "transition error: "+err.Error()) + return + } + finalPath, err := organize(result, task, m.cfg.Organize) + if err != nil { + log.Printf("[%s] organize warning: %v (keeping in download dir)", agent.ShortID(task.ID), err) + finalPath = result.FilePath + } if finalPath == "" { finalPath = result.FilePath } @@ -399,8 +422,29 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) { task.FilePath = finalPath task.mu.Unlock() - task.Transition(StatusCompleted) - log.Printf("[%s] completed (fallback): %s -> %s", task.ID[:8], task.Title, finalPath) + // Handle upgrade replacement (mode = "upgrade") + if task.ReplacePath != "" { + backupDir := "" // uses default ~/.local/share/unarr/replaced/ + if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil { + log.Printf("[%s] replace warning: %v (keeping new file at %s)", agent.ShortID(task.ID), err, finalPath) + } else { + task.mu.Lock() + task.FilePath = task.ReplacePath + task.mu.Unlock() + log.Printf("[%s] upgraded: replaced %s", agent.ShortID(task.ID), task.ReplacePath) + } + } + + // Complete + if err := task.Transition(StatusCompleted); err != nil { + m.fail(ctx, task, "transition error: "+err.Error()) + return + } + log.Printf("[%s] completed: %s -> %s", agent.ShortID(task.ID), task.Title, finalPath) + if m.cfg.Notifications { + desktopNotify("Download complete", task.Title) + } + m.recordFinished(task.ToStatusUpdate()) m.reporter.ReportFinal(ctx, task) } @@ -409,9 +453,10 @@ func (m *Manager) fail(ctx context.Context, task *Task, msg string) { task.ErrorMessage = msg task.mu.Unlock() task.Transition(StatusFailed) - log.Printf("[%s] FAILED: %s — %s", task.ID[:8], task.Title, msg) + log.Printf("[%s] FAILED: %s — %s", agent.ShortID(task.ID), task.Title, msg) if m.cfg.Notifications { desktopNotify("Download failed", task.Title+": "+msg) } + m.recordFinished(task.ToStatusUpdate()) m.reporter.ReportFinal(ctx, task) } diff --git a/internal/engine/progress.go b/internal/engine/progress.go index 6f958c9..eba8814 100644 --- a/internal/engine/progress.go +++ b/internal/engine/progress.go @@ -13,13 +13,11 @@ import ( type ActionFunc func(taskID string) // StatusReporter is the interface used by ProgressReporter to send progress updates. -// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods. type StatusReporter interface { ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) } // BatchStatusReporter extends StatusReporter with batch support. -// Transports that implement this send all updates in a single request. type BatchStatusReporter interface { StatusReporter BatchReportStatus(ctx context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error) @@ -48,7 +46,6 @@ type ProgressReporter struct { } // NewProgressReporter creates a reporter that flushes every interval. -// Accepts *agent.Client directly (backwards compatible). func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter { return &ProgressReporter{ reporter: ac, @@ -58,25 +55,6 @@ func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressRepo } } -// NewProgressReporterWithTransport creates a reporter using a Transport. -func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter { - return &ProgressReporter{ - reporter: &transportStatusAdapter{t: t}, - interval: interval, - latest: make(map[string]*Task), - lastReported: make(map[string]TaskStatus), - } -} - -// transportStatusAdapter adapts agent.Transport to StatusReporter. -type transportStatusAdapter struct { - t agent.Transport -} - -func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) { - return a.t.SendProgress(ctx, update) -} - // SetCancelHandler sets the callback invoked when the server says a task is cancelled. func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }