feat(sync): replace WS+DO transport with unified HTTP sync

Replace the WebSocket + Cloudflare Durable Object architecture with a
single POST /sync endpoint. The CLI now operates autonomously with local
state (tasks.json) and syncs bidirectionally via adaptive-interval HTTP
polling (3s watching, 60s idle).

- Remove transport_ws, transport_hybrid, transport_http (~2,600 lines)
- Add SyncClient with adaptive interval loop
- Add LocalState for CLI-side task persistence
- Add TaskStateFromUpdate() helper (DRY)
- Extract finalize() to deduplicate processTask/processTaskRetry
- Consolidate shortID() into agent.ShortID (was in 3 packages)
- Wire GetActiveCount so `unarr status` shows active tasks
- Remove poll_interval, heartbeat_interval, ws_url from config
- Simplify ProgressReporter (sync replaces direct HTTP reporting)
This commit is contained in:
Deivid Soto 2026-04-08 18:50:59 +02:00
parent 2398707cc1
commit 5d4a67c7a2
26 changed files with 1320 additions and 3400 deletions

View file

@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.5.6] - 2026-04-07
### Fixed
- **ws**: add ping/pong keepalive and read deadline to detect zombie connections
## [0.5.5] - 2026-04-07
@ -17,6 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- **daemon**: cancel watch reporter on stream switch and re-notify ready
### Other
- **release**: 0.5.5
## [0.5.4] - 2026-04-07
@ -153,6 +163,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- remove UPX compression (antivirus false positives, startup penalty)
- add -s -w -trimpath to Makefile, add build-small target with UPX
[0.5.6]: https://github.com/torrentclaw/unarr/compare/v0.5.5...v0.5.6
[0.5.5]: https://github.com/torrentclaw/unarr/compare/v0.5.4...v0.5.5
[0.5.4]: https://github.com/torrentclaw/unarr/compare/v0.5.3...v0.5.4
[0.5.3]: https://github.com/torrentclaw/unarr/compare/v0.5.2...v0.5.3

2
go.mod
View file

@ -11,7 +11,6 @@ require (
github.com/fatih/color v1.19.0
github.com/getsentry/sentry-go v0.44.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/huin/goupnp v1.3.0
github.com/olekukonko/tablewriter v1.1.4
github.com/spf13/cobra v1.10.2
@ -69,6 +68,7 @@ require (
github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/pprof v0.0.0-20260302011040-a15ffb7f9dcc // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect

View file

@ -40,27 +40,6 @@ func (c *Client) Register(ctx context.Context, req RegisterRequest) (*RegisterRe
return &resp, nil
}
// Heartbeat sends a periodic keep-alive signal and returns server directives.
func (c *Client) Heartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
var resp HeartbeatResponse
if err := c.doPost(ctx, "/api/internal/agent/heartbeat", req, &resp); err != nil {
return nil, fmt.Errorf("heartbeat: %w", err)
}
return &resp, nil
}
// ClaimTasks polls for pending download tasks and claims them atomically.
// Also returns any stream requests for completed downloads.
func (c *Client) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
url := fmt.Sprintf("/api/internal/agent/tasks?agentId=%s", agentID)
var resp TasksResponse
if err := c.doGet(ctx, url, &resp); err != nil {
return nil, fmt.Errorf("claim tasks: %w", err)
}
return &resp, nil
}
// ReportStatus reports download progress or completion for a task.
// Deregister notifies the server that the agent is shutting down.
func (c *Client) Deregister(ctx context.Context, agentID string) error {
req := struct {
@ -91,6 +70,16 @@ func (c *Client) BatchReportStatus(ctx context.Context, updates []StatusUpdate)
return &resp, nil
}
// Sync sends the CLI's full state and receives all pending server actions.
// This is the single endpoint for bidirectional state synchronization.
func (c *Client) Sync(ctx context.Context, req SyncRequest) (*SyncResponse, error) {
var resp SyncResponse
if err := c.doPost(ctx, "/api/internal/agent/sync", req, &resp); err != nil {
return nil, fmt.Errorf("sync: %w", err)
}
return &resp, nil
}
// ---------------------------------------------------------------------------
// Usenet endpoints
// ---------------------------------------------------------------------------

View file

@ -72,70 +72,6 @@ func TestRegister(t *testing.T) {
}
}
func TestHeartbeat(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/heartbeat" {
t.Errorf("path = %s, want /api/internal/agent/heartbeat", r.URL.Path)
}
var req HeartbeatRequest
json.NewDecoder(r.Body).Decode(&req)
if req.AgentID != "agent-123" {
t.Errorf("agentId = %q, want agent-123", req.AgentID)
}
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"})
if err != nil {
t.Fatalf("Heartbeat failed: %v", err)
}
if !resp.Success {
t.Error("expected success=true")
}
}
func TestClaimTasks(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("method = %s, want GET", r.Method)
}
if r.URL.Query().Get("agentId") != "agent-123" {
t.Errorf("agentId param = %q, want agent-123", r.URL.Query().Get("agentId"))
}
json.NewEncoder(w).Encode(TasksResponse{
Tasks: []Task{
{
ID: "task-uuid-1",
InfoHash: "abc123def456abc123def456abc123def456abc1",
Title: "The Matrix (1999)",
PreferredMethod: "auto",
},
},
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.ClaimTasks(context.Background(), "agent-123")
if err != nil {
t.Fatalf("ClaimTasks failed: %v", err)
}
if len(resp.Tasks) != 1 {
t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks))
}
if resp.Tasks[0].ID != "task-uuid-1" {
t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID)
}
if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" {
t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash)
}
if resp.Tasks[0].PreferredMethod != "auto" {
t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod)
}
}
func TestReportStatus(t *testing.T) {
var received StatusUpdate
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -173,22 +109,6 @@ func TestReportStatus(t *testing.T) {
}
}
func TestClaimTasksEmpty(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(TasksResponse{Tasks: []Task{}})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.ClaimTasks(context.Background(), "agent-123")
if err != nil {
t.Fatalf("ClaimTasks failed: %v", err)
}
if len(resp.Tasks) != 0 {
t.Errorf("expected empty tasks, got %d", len(resp.Tasks))
}
}
func TestAPIError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@ -279,50 +199,12 @@ func TestUserAgent(t *testing.T) {
if r.Header.Get("User-Agent") != "unarr/0.2.0" {
t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent"))
}
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
json.NewEncoder(w).Encode(RegisterResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr/0.2.0")
c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"})
}
func TestHeartbeatWithUpgradeSignal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(HeartbeatResponse{
Success: true,
Upgrade: &UpgradeSignal{Version: "2.0.0"},
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
if err != nil {
t.Fatalf("Heartbeat failed: %v", err)
}
if resp.Upgrade == nil {
t.Fatal("expected upgrade signal, got nil")
}
if resp.Upgrade.Version != "2.0.0" {
t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version)
}
}
func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
if err != nil {
t.Fatalf("Heartbeat failed: %v", err)
}
if resp.Upgrade != nil {
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
}
c.Register(context.Background(), RegisterRequest{AgentID: "x"})
}
func TestDeregister(t *testing.T) {

View file

@ -14,75 +14,62 @@ import (
// DaemonConfig holds daemon runtime settings.
type DaemonConfig struct {
AgentID string
AgentName string
Version string
DownloadDir string
PollInterval time.Duration
HeartbeatInterval time.Duration
StreamPort int // port for the HTTP stream server (reported in heartbeat)
LanIP string // LAN IP (reported in heartbeat for stream URL resolution)
TailscaleIP string // Tailscale IP (reported in heartbeat for stream URL resolution)
AgentID string
AgentName string
Version string
DownloadDir string
StreamPort int // port for the HTTP stream server
LanIP string // LAN IP (reported in sync for stream URL resolution)
TailscaleIP string // Tailscale IP (reported in sync for stream URL resolution)
}
// Daemon manages the main loop: register, heartbeat, poll tasks.
// Daemon manages agent registration and the sync loop.
type Daemon struct {
cfg DaemonConfig
transport Transport
cfg DaemonConfig
client *Client
sync *SyncClient
state *LocalState
// Callbacks
// Callbacks — set by cmd/daemon.go before calling Run.
OnTasksClaimed func(tasks []Task)
OnStreamRequested func(req StreamRequest)
OnControlAction func(action, taskID string)
OnControlAction func(action, taskID string, deleteFiles bool)
GetActiveCount func() int // returns number of active downloads (wired from manager)
// State
User UserInfo
Features FeatureFlags
Info AgentInfo
State DaemonState
heartbeatFailures int
lastNotifiedVersion string
// Callbacks for state tracking (set by cmd/daemon.go)
GetActiveCount func() int
GetCleanableBytes func() int64
// Watching tracks whether a user is viewing download progress in the web UI.
// When false, the progress reporter skips detailed updates (only sends final states).
// Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic.
Watching atomic.Bool
// Exposed tickers for hot-reload
PollTicker *time.Ticker
HeartbeatTicker *time.Ticker
// pollNow triggers an immediate poll (e.g. on resume)
pollNow chan struct{}
// ScanNow triggers an immediate library scan (from heartbeat or WebSocket control event)
// ScanNow triggers an immediate library scan.
ScanNow chan struct{}
}
// NewDaemon creates a daemon with the given transport.
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
if cfg.PollInterval == 0 {
cfg.PollInterval = 30 * time.Second
}
if cfg.HeartbeatInterval == 0 {
cfg.HeartbeatInterval = 30 * time.Second
}
// NewDaemon creates a daemon with an HTTP client for sync-based communication.
func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
state := NewLocalState()
return &Daemon{
cfg: cfg,
transport: transport,
pollNow: make(chan struct{}, 1),
ScanNow: make(chan struct{}, 1),
cfg: cfg,
client: client,
state: state,
sync: NewSyncClient(client, cfg, state),
ScanNow: make(chan struct{}, 1),
}
}
// Transport returns the configured transport.
func (d *Daemon) Transport() Transport { return d.transport }
// SyncClient returns the sync client for external wiring.
func (d *Daemon) SyncClient() *SyncClient { return d.sync }
// UpdateStreamPort updates the stream port reported in sync requests.
func (d *Daemon) UpdateStreamPort(port int) {
d.cfg.StreamPort = port
d.sync.cfg.StreamPort = port
}
// Register registers the agent and fetches user info + features.
// Retries with exponential backoff on transient errors (429, 5xx, network).
@ -109,11 +96,10 @@ func (d *Daemon) Register(ctx context.Context) error {
var resp *RegisterResponse
var err error
for attempt := range maxRetries {
resp, err = d.transport.Register(ctx, req)
resp, err = d.client.Register(ctx, req)
if err == nil {
break
}
// Only retry on transient errors (429, 5xx, network failures)
if !isTransientError(err) {
return fmt.Errorf("register: %w", err)
}
@ -154,14 +140,9 @@ func (d *Daemon) Register(ctx context.Context) error {
return nil
}
// Run connects the transport, registers the agent, and starts the main loop.
// Blocks until ctx is cancelled. Callers must NOT call transport.Connect before Run.
// Run registers the agent and starts the sync loop.
// Blocks until ctx is cancelled.
func (d *Daemon) Run(ctx context.Context) error {
// Connect transport (establishes WebSocket if available, falls back to HTTP)
if err := d.transport.Connect(ctx); err != nil {
return fmt.Errorf("connect transport: %w", err)
}
// Register
if err := d.Register(ctx); err != nil {
return err
@ -169,163 +150,61 @@ func (d *Daemon) Run(ctx context.Context) error {
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
defer d.HeartbeatTicker.Stop()
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
defer d.PollTicker.Stop()
heartbeatTicker := d.HeartbeatTicker
pollTicker := d.PollTicker
// Initial poll immediately
d.poll(ctx)
eventsCh := d.transport.Events()
for {
select {
case <-ctx.Done():
log.Println("Daemon shutting down...")
d.deregister()
return nil
case event := <-eventsCh:
d.handleEvent(event)
case <-heartbeatTicker.C:
d.heartbeat(ctx)
case <-pollTicker.C:
// Only poll in HTTP mode — WS mode receives tasks via Events
if d.transport.Mode() == "http" {
d.poll(ctx)
}
case <-d.pollNow:
d.poll(ctx)
// Wire sync callbacks
d.sync.OnNewTasks = func(tasks []Task) {
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(tasks)
}
}
}
func (d *Daemon) heartbeat(ctx context.Context) {
req := HeartbeatRequest{
AgentID: d.cfg.AgentID,
Name: d.cfg.AgentName,
Version: d.cfg.Version,
OS: runtime.GOOS,
DownloadDir: d.cfg.DownloadDir,
StreamPort: d.cfg.StreamPort,
LanIP: d.cfg.LanIP,
TailscaleIP: d.cfg.TailscaleIP,
}
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
req.DiskFreeBytes = free
req.DiskTotalBytes = total
}
resp, err := d.transport.SendHeartbeat(ctx, req)
if err != nil {
d.heartbeatFailures++
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
} else {
log.Printf("Heartbeat failed: %v", err)
d.sync.OnControl = func(action, taskID string, deleteFiles bool) {
if d.OnControlAction != nil {
d.OnControlAction(action, taskID, deleteFiles)
}
return
}
if d.heartbeatFailures > 0 {
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
d.heartbeatFailures = 0
d.sync.OnStreamRequest = func(req StreamRequest) {
if d.OnStreamRequested != nil {
d.OnStreamRequested(req)
}
}
// Update watching flag and state file
d.Watching.Store(resp.Watching)
d.State.LastHeartbeat = time.Now()
if d.GetActiveCount != nil {
d.State.ActiveTasks = d.GetActiveCount()
d.sync.OnUpgrade = func(version string) {
if version != d.lastNotifiedVersion {
d.lastNotifiedVersion = version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", version)
}
}
WriteState(&d.State)
// Trigger library scan if requested
if resp.Scan {
d.sync.OnScan = func() {
log.Printf("Library scan requested by server")
select {
case d.ScanNow <- struct{}{}:
default: // scan already pending
default:
}
}
// Log once per version when server suggests an upgrade
if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion {
d.lastNotifiedVersion = resp.Upgrade.Version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version)
d.sync.OnWatchingChange = func(watching bool) {
d.Watching.Store(watching)
}
}
// handleEvent processes a server-initiated event from the WebSocket transport.
func (d *Daemon) handleEvent(event ServerEvent) {
switch event.Type {
case "tasks":
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(event.Tasks.Tasks)
}
d.sync.OnSyncSuccess = func() {
d.State.LastHeartbeat = time.Now()
if d.GetActiveCount != nil {
d.State.ActiveTasks = d.GetActiveCount()
}
if event.Tasks != nil && d.OnStreamRequested != nil {
for _, sr := range event.Tasks.StreamRequests {
d.OnStreamRequested(sr)
}
}
case "upgrade":
if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion {
d.lastNotifiedVersion = event.Upgrade.Version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version)
}
case "control":
if event.Control != nil {
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
if event.Control.Action == "scan" {
select {
case d.ScanNow <- struct{}{}:
default:
}
}
if d.OnControlAction != nil {
d.OnControlAction(event.Control.Action, event.Control.TaskID)
}
}
case "disconnected":
log.Println("WebSocket disconnected, switching to HTTP polling")
WriteState(&d.State)
}
// Start sync loop (blocks)
return d.sync.Run(ctx)
}
// UpdateStreamPort updates the stream port reported in heartbeats.
// Called after the persistent stream server binds (actual port may differ from configured).
func (d *Daemon) UpdateStreamPort(port int) {
d.cfg.StreamPort = port
// TriggerSync requests an immediate sync cycle.
func (d *Daemon) TriggerSync() {
d.sync.TriggerSync()
}
// TriggerPoll requests an immediate task poll cycle.
// Used when a resume event is received to pick up re-pending tasks faster.
func (d *Daemon) TriggerPoll() {
select {
case d.pollNow <- struct{}{}:
default: // already pending
}
}
func (d *Daemon) deregister() {
// Deregister notifies the server of graceful shutdown.
func (d *Daemon) Deregister() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := d.transport.Deregister(ctx, d.cfg.AgentID)
if err != nil {
if err := d.client.Deregister(ctx, d.cfg.AgentID); err != nil {
log.Printf("Deregister failed: %v", err)
} else {
log.Println("Agent deregistered")
@ -338,12 +217,10 @@ func isTransientError(err error) bool {
if err == nil {
return false
}
// Structured check: HTTPError carries the status code directly
var httpErr *HTTPError
if errors.As(err, &httpErr) {
return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500
}
// Fallback: network-level errors (no HTTP response received)
lower := strings.ToLower(err.Error())
for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} {
if strings.Contains(lower, keyword) {
@ -352,27 +229,3 @@ func isTransientError(err error) bool {
}
return false
}
func (d *Daemon) poll(ctx context.Context) {
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
if err != nil {
log.Printf("Poll failed: %v", err)
return
}
d.Info.LastPollAt = time.Now()
if len(resp.Tasks) > 0 {
log.Printf("Claimed %d task(s)", len(resp.Tasks))
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(resp.Tasks)
}
}
// Handle stream requests for completed downloads
if d.OnStreamRequested != nil {
for _, sr := range resp.StreamRequests {
d.OnStreamRequested(sr)
}
}
}

195
internal/agent/sync.go Normal file
View file

@ -0,0 +1,195 @@
package agent
import (
"context"
"log"
"runtime"
"sync/atomic"
"time"
)
const (
// SyncIntervalWatching is the sync interval when someone is viewing the web UI.
SyncIntervalWatching = 3 * time.Second
// SyncIntervalIdle is the sync interval when nobody is watching.
SyncIntervalIdle = 60 * time.Second
)
// SyncClient handles bidirectional state synchronization between the CLI and server.
// It sends the CLI's full execution state and receives all pending server actions
// in a single HTTP round-trip, at an adaptive interval.
type SyncClient struct {
client *Client
cfg DaemonConfig
state *LocalState
// Callbacks — set by the daemon before calling Run.
OnNewTasks func(tasks []Task)
OnControl func(action, taskID string, deleteFiles bool)
OnStreamRequest func(req StreamRequest)
OnUpgrade func(version string)
OnScan func()
OnWatchingChange func(watching bool)
OnSyncSuccess func() // called after each successful sync (e.g. to update state file)
GetFreeSlots func() int
GetTaskStates func() []TaskState // returns current state of all active + recently finished tasks
// SyncNow triggers an immediate sync (e.g., on task completion).
SyncNow chan struct{}
watching atomic.Bool
interval atomic.Int64 // stored as nanoseconds
}
// NewSyncClient creates a sync client.
func NewSyncClient(client *Client, cfg DaemonConfig, state *LocalState) *SyncClient {
sc := &SyncClient{
client: client,
cfg: cfg,
state: state,
SyncNow: make(chan struct{}, 1),
}
sc.interval.Store(int64(SyncIntervalIdle))
return sc
}
// Watching returns whether someone is viewing the web UI.
func (sc *SyncClient) Watching() bool {
return sc.watching.Load()
}
// TriggerSync requests an immediate sync cycle.
func (sc *SyncClient) TriggerSync() {
select {
case sc.SyncNow <- struct{}{}:
default:
}
}
// Run starts the adaptive sync loop. Blocks until ctx is cancelled.
func (sc *SyncClient) Run(ctx context.Context) error {
// Initial sync immediately
sc.doSync(ctx)
ticker := time.NewTicker(sc.currentInterval())
defer ticker.Stop()
for {
select {
case <-ctx.Done():
// Final sync to report latest state
finalCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sc.doSync(finalCtx)
return nil
case <-ticker.C:
sc.doSync(ctx)
ticker.Reset(sc.currentInterval())
case <-sc.SyncNow:
sc.doSync(ctx)
ticker.Reset(sc.currentInterval())
}
}
}
func (sc *SyncClient) currentInterval() time.Duration {
return time.Duration(sc.interval.Load())
}
func (sc *SyncClient) doSync(ctx context.Context) {
req := sc.buildRequest()
resp, err := sc.client.Sync(ctx, req)
if err != nil {
if ctx.Err() == nil {
log.Printf("sync failed: %v", err)
}
return
}
sc.processResponse(resp)
sc.adjustInterval(resp.Watching)
if sc.OnSyncSuccess != nil {
sc.OnSyncSuccess()
}
}
func (sc *SyncClient) buildRequest() SyncRequest {
req := SyncRequest{
AgentID: sc.cfg.AgentID,
Name: sc.cfg.AgentName,
Version: sc.cfg.Version,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
DownloadDir: sc.cfg.DownloadDir,
StreamPort: sc.cfg.StreamPort,
LanIP: sc.cfg.LanIP,
TailscaleIP: sc.cfg.TailscaleIP,
}
if sc.GetTaskStates != nil {
req.Tasks = sc.GetTaskStates()
} else {
req.Tasks = sc.state.Snapshot()
}
if free, total, err := DiskInfo(sc.cfg.DownloadDir); err == nil {
req.DiskFreeBytes = free
req.DiskTotalBytes = total
}
if sc.GetFreeSlots != nil {
req.FreeSlots = sc.GetFreeSlots()
}
return req
}
func (sc *SyncClient) processResponse(resp *SyncResponse) {
// New tasks
if len(resp.NewTasks) > 0 && sc.OnNewTasks != nil {
log.Printf("sync: received %d new task(s)", len(resp.NewTasks))
sc.OnNewTasks(resp.NewTasks)
}
// Control signals
for _, ctrl := range resp.Controls {
log.Printf("sync: control %s on task %s", ctrl.Action, ShortID(ctrl.TaskID))
if sc.OnControl != nil {
sc.OnControl(ctrl.Action, ctrl.TaskID, ctrl.DeleteFiles)
}
}
// Stream requests
for _, sr := range resp.StreamRequests {
if sc.OnStreamRequest != nil {
sc.OnStreamRequest(sr)
}
}
// Upgrade
if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil {
sc.OnUpgrade(resp.Upgrade.Version)
}
// Scan
if resp.Scan && sc.OnScan != nil {
sc.OnScan()
}
}
func (sc *SyncClient) adjustInterval(watching bool) {
prev := sc.watching.Load()
sc.watching.Store(watching)
var newInterval time.Duration
if watching {
newInterval = SyncIntervalWatching
} else {
newInterval = SyncIntervalIdle
}
if sc.interval.Swap(int64(newInterval)) != int64(newInterval) {
log.Printf("sync: interval=%s (watching=%v)", newInterval, watching)
}
if prev != watching && sc.OnWatchingChange != nil {
sc.OnWatchingChange(watching)
}
}

362
internal/agent/sync_test.go Normal file
View file

@ -0,0 +1,362 @@
package agent
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)
func newTestSyncClient(url string) (*SyncClient, *Client) {
client := NewClient(url, "test-key", "test-agent/1.0")
cfg := DaemonConfig{
AgentID: "test-agent",
AgentName: "Test",
Version: "1.0.0",
DownloadDir: "/tmp/downloads",
}
state := NewLocalState()
sc := NewSyncClient(client, cfg, state)
return sc, client
}
func TestSyncClient_NewDefaults(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
if sc.Watching() {
t.Error("should not be watching initially")
}
if sc.currentInterval() != SyncIntervalIdle {
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
}
}
func TestSyncClient_AdjustInterval_Watching(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
sc.adjustInterval(true)
if sc.currentInterval() != SyncIntervalWatching {
t.Errorf("expected watching interval %v, got %v", SyncIntervalWatching, sc.currentInterval())
}
if !sc.Watching() {
t.Error("expected watching=true")
}
}
func TestSyncClient_AdjustInterval_NotWatching(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
// First set watching, then unset
sc.adjustInterval(true)
sc.adjustInterval(false)
if sc.currentInterval() != SyncIntervalIdle {
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
}
if sc.Watching() {
t.Error("expected watching=false")
}
}
func TestSyncClient_AdjustInterval_CallsOnWatchingChange(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var changes []bool
sc.OnWatchingChange = func(w bool) { changes = append(changes, w) }
sc.adjustInterval(true)
sc.adjustInterval(true) // no change
sc.adjustInterval(false) // change
if len(changes) != 2 {
t.Fatalf("expected 2 changes, got %d: %v", len(changes), changes)
}
if !changes[0] {
t.Error("first change should be true")
}
if changes[1] {
t.Error("second change should be false")
}
}
func TestSyncClient_TriggerSync_NonBlocking(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
// Fill the channel
sc.TriggerSync()
// Should not block
sc.TriggerSync()
sc.TriggerSync()
// Drain
select {
case <-sc.SyncNow:
default:
t.Error("expected a sync trigger in channel")
}
}
func TestSyncClient_ProcessResponse_NewTasks(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var received []Task
sc.OnNewTasks = func(tasks []Task) { received = tasks }
sc.processResponse(&SyncResponse{
NewTasks: []Task{
{ID: "t1", Title: "Movie 1", InfoHash: "abc"},
{ID: "t2", Title: "Movie 2", InfoHash: "def"},
},
})
if len(received) != 2 {
t.Fatalf("expected 2 tasks, got %d", len(received))
}
if received[0].Title != "Movie 1" {
t.Errorf("expected Movie 1, got %s", received[0].Title)
}
}
func TestSyncClient_ProcessResponse_NoTasks(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var called bool
sc.OnNewTasks = func(tasks []Task) { called = true }
sc.processResponse(&SyncResponse{NewTasks: nil})
if called {
t.Error("OnNewTasks should not be called with empty tasks")
}
}
func TestSyncClient_ProcessResponse_Controls(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var actions []string
var taskIDs []string
sc.OnControl = func(action, taskID string, deleteFiles bool) {
actions = append(actions, action)
taskIDs = append(taskIDs, taskID)
}
sc.processResponse(&SyncResponse{
Controls: []ControlAction{
{Action: "cancel", TaskID: "task-1234-5678"},
{Action: "pause", TaskID: "task-abcd-efgh"},
},
})
if len(actions) != 2 {
t.Fatalf("expected 2 controls, got %d", len(actions))
}
if actions[0] != "cancel" {
t.Errorf("expected cancel, got %s", actions[0])
}
if actions[1] != "pause" {
t.Errorf("expected pause, got %s", actions[1])
}
}
func TestSyncClient_ProcessResponse_Upgrade(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var version string
sc.OnUpgrade = func(v string) { version = v }
sc.processResponse(&SyncResponse{
Upgrade: &UpgradeSignal{Version: "2.0.0"},
})
if version != "2.0.0" {
t.Errorf("expected 2.0.0, got %s", version)
}
}
func TestSyncClient_ProcessResponse_UpgradeEmpty(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var called bool
sc.OnUpgrade = func(v string) { called = true }
sc.processResponse(&SyncResponse{
Upgrade: &UpgradeSignal{Version: ""},
})
if called {
t.Error("OnUpgrade should not be called with empty version")
}
}
func TestSyncClient_ProcessResponse_Scan(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var called bool
sc.OnScan = func() { called = true }
sc.processResponse(&SyncResponse{Scan: true})
if !called {
t.Error("OnScan should have been called")
}
}
func TestSyncClient_ProcessResponse_StreamRequests(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var received []StreamRequest
sc.OnStreamRequest = func(sr StreamRequest) { received = append(received, sr) }
sc.processResponse(&SyncResponse{
StreamRequests: []StreamRequest{
{TaskID: "t1", FilePath: "/tmp/movie.mkv"},
},
})
if len(received) != 1 {
t.Fatalf("expected 1 stream request, got %d", len(received))
}
if received[0].FilePath != "/tmp/movie.mkv" {
t.Errorf("expected /tmp/movie.mkv, got %s", received[0].FilePath)
}
}
func TestSyncClient_BuildRequest_WithGetTaskStates(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
sc.GetTaskStates = func() []TaskState {
return []TaskState{
{TaskID: "t1", Status: "downloading", Progress: 50},
}
}
sc.GetFreeSlots = func() int { return 2 }
req := sc.buildRequest()
if req.AgentID != "test-agent" {
t.Errorf("expected test-agent, got %s", req.AgentID)
}
if len(req.Tasks) != 1 {
t.Fatalf("expected 1 task, got %d", len(req.Tasks))
}
if req.Tasks[0].Progress != 50 {
t.Errorf("expected progress 50, got %d", req.Tasks[0].Progress)
}
if req.FreeSlots != 2 {
t.Errorf("expected 2 free slots, got %d", req.FreeSlots)
}
}
func TestSyncClient_BuildRequest_FallbackToState(t *testing.T) {
client := NewClient("http://localhost", "key", "ua")
state := NewLocalState()
state.Update(TaskState{TaskID: "t1", Status: "completed", Progress: 100})
sc := NewSyncClient(client, DaemonConfig{AgentID: "a1", Version: "1.0"}, state)
// GetTaskStates is nil — should fall back to state.Snapshot()
req := sc.buildRequest()
if len(req.Tasks) != 1 {
t.Fatalf("expected 1 task from state fallback, got %d", len(req.Tasks))
}
}
func TestSyncClient_DoSync_Success(t *testing.T) {
var syncCount atomic.Int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
syncCount.Add(1)
json.NewEncoder(w).Encode(SyncResponse{
Watching: true,
NewTasks: []Task{{ID: "t1", Title: "Test Movie", InfoHash: "abc"}},
})
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
var tasksReceived []Task
sc.OnNewTasks = func(tasks []Task) { tasksReceived = tasks }
sc.doSync(context.Background())
if syncCount.Load() != 1 {
t.Errorf("expected 1 sync call, got %d", syncCount.Load())
}
if len(tasksReceived) != 1 {
t.Fatalf("expected 1 task, got %d", len(tasksReceived))
}
if !sc.Watching() {
t.Error("expected watching=true after sync")
}
if sc.currentInterval() != SyncIntervalWatching {
t.Errorf("expected watching interval after sync")
}
}
func TestSyncClient_DoSync_Error(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
// Should not panic on error
sc.doSync(context.Background())
}
func TestSyncClient_Run_CancelStopsLoop(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(SyncResponse{})
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
err := sc.Run(ctx)
if err != nil {
t.Errorf("expected nil error, got %v", err)
}
}
func TestSyncClient_Run_ImmediateSyncOnTrigger(t *testing.T) {
var syncCount atomic.Int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
syncCount.Add(1)
json.NewEncoder(w).Encode(SyncResponse{})
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
// Set interval to something long so only triggers cause syncs
sc.interval.Store(int64(10 * time.Second))
ctx, cancel := context.WithCancel(context.Background())
go func() {
// Wait for initial sync, then trigger 2 more
time.Sleep(50 * time.Millisecond)
sc.TriggerSync()
time.Sleep(50 * time.Millisecond)
sc.TriggerSync()
time.Sleep(50 * time.Millisecond)
cancel()
}()
sc.Run(ctx)
// Initial sync (1) + 2 triggers + final sync = 4
count := syncCount.Load()
if count < 3 {
t.Errorf("expected at least 3 syncs (initial + 2 triggers), got %d", count)
}
}

136
internal/agent/taskstate.go Normal file
View file

@ -0,0 +1,136 @@
package agent
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"time"
"github.com/torrentclaw/unarr/internal/config"
)
// TaskState represents the execution state of a single download task.
// Written by the Task Engine, read by the Sync goroutine.
type TaskState struct {
TaskID string `json:"taskId"`
Status string `json:"status"` // resolving, downloading, verifying, organizing, completed, failed
Progress int `json:"progress"`
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
TotalBytes int64 `json:"totalBytes,omitempty"`
SpeedBps int64 `json:"speedBps,omitempty"`
ETA int `json:"eta,omitempty"`
ResolvedMethod string `json:"resolvedMethod,omitempty"`
FileName string `json:"fileName,omitempty"`
FilePath string `json:"filePath,omitempty"`
StreamURL string `json:"streamUrl,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
UpdatedAt int64 `json:"updatedAt"`
}
// LocalState holds the CLI's local execution state (tasks.json).
// This is the CLI's source of truth for what it's doing right now.
type LocalState struct {
mu sync.RWMutex
tasks map[string]*TaskState
}
// NewLocalState creates an empty local state.
func NewLocalState() *LocalState {
return &LocalState{
tasks: make(map[string]*TaskState),
}
}
// Update adds or updates a task in local state.
func (s *LocalState) Update(ts TaskState) {
s.mu.Lock()
defer s.mu.Unlock()
ts.UpdatedAt = time.Now().Unix()
copied := ts
s.tasks[ts.TaskID] = &copied
}
// Remove removes a task from local state.
func (s *LocalState) Remove(taskID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tasks, taskID)
}
// Snapshot returns a copy of all current task states.
func (s *LocalState) Snapshot() []TaskState {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]TaskState, 0, len(s.tasks))
for _, ts := range s.tasks {
result = append(result, *ts)
}
return result
}
// TaskStateFromUpdate converts a StatusUpdate into a TaskState.
func TaskStateFromUpdate(u StatusUpdate) TaskState {
return TaskState{
TaskID: u.TaskID,
Status: u.Status,
Progress: u.Progress,
DownloadedBytes: u.DownloadedBytes,
TotalBytes: u.TotalBytes,
SpeedBps: u.SpeedBps,
ETA: u.ETA,
ResolvedMethod: u.ResolvedMethod,
FileName: u.FileName,
FilePath: u.FilePath,
StreamURL: u.StreamURL,
ErrorMessage: u.ErrorMessage,
}
}
// ShortID returns the first 8 characters of an ID, or the full ID if shorter.
func ShortID(id string) string {
if len(id) > 8 {
return id[:8]
}
return id
}
// taskStateFilePathFn is overridable for testing.
var taskStateFilePathFn = func() string {
return filepath.Join(config.DataDir(), "tasks.json")
}
// WriteToDisk persists local state to disk atomically (best-effort).
func (s *LocalState) WriteToDisk() {
tasks := s.Snapshot()
data, err := json.MarshalIndent(tasks, "", " ")
if err != nil {
return
}
path := taskStateFilePathFn()
dir := filepath.Dir(path)
os.MkdirAll(dir, 0o755)
tmp := path + ".tmp"
if err := os.WriteFile(tmp, data, 0o644); err != nil {
return
}
os.Rename(tmp, path)
}
// ReadFromDisk loads local state from disk. Returns empty state on error.
func (s *LocalState) ReadFromDisk() {
data, err := os.ReadFile(taskStateFilePathFn())
if err != nil {
return
}
var tasks []TaskState
if json.Unmarshal(data, &tasks) != nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.tasks = make(map[string]*TaskState, len(tasks))
for i := range tasks {
s.tasks[tasks[i].TaskID] = &tasks[i]
}
}

View file

@ -0,0 +1,217 @@
package agent
import (
"os"
"path/filepath"
"sync"
"testing"
)
func TestLocalState_UpdateAndSnapshot(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100})
snap := s.Snapshot()
if len(snap) != 2 {
t.Fatalf("expected 2 tasks, got %d", len(snap))
}
byID := make(map[string]TaskState, len(snap))
for _, ts := range snap {
byID[ts.TaskID] = ts
}
if byID["t1"].Progress != 50 {
t.Errorf("expected progress 50, got %d", byID["t1"].Progress)
}
if byID["t2"].Status != "completed" {
t.Errorf("expected completed, got %s", byID["t2"].Status)
}
}
func TestLocalState_UpdateOverwrites(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 30})
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 70})
snap := s.Snapshot()
if len(snap) != 1 {
t.Fatalf("expected 1 task, got %d", len(snap))
}
if snap[0].Progress != 70 {
t.Errorf("expected progress 70, got %d", snap[0].Progress)
}
}
func TestLocalState_Remove(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
s.Update(TaskState{TaskID: "t2", Status: "downloading"})
s.Remove("t1")
snap := s.Snapshot()
if len(snap) != 1 {
t.Fatalf("expected 1 task, got %d", len(snap))
}
if snap[0].TaskID != "t2" {
t.Errorf("expected t2, got %s", snap[0].TaskID)
}
}
func TestLocalState_RemoveNonExistent(t *testing.T) {
s := NewLocalState()
s.Remove("nonexistent") // should not panic
}
func TestLocalState_SnapshotIsACopy(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
snap := s.Snapshot()
snap[0].Progress = 999
snap2 := s.Snapshot()
if snap2[0].Progress != 50 {
t.Errorf("snapshot mutation leaked: got progress %d", snap2[0].Progress)
}
}
func TestLocalState_UpdateSetsTimestamp(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
snap := s.Snapshot()
if snap[0].UpdatedAt == 0 {
t.Error("expected non-zero UpdatedAt")
}
}
func TestLocalState_ConcurrentAccess(t *testing.T) {
s := NewLocalState()
var wg sync.WaitGroup
for i := range 100 {
wg.Add(1)
go func(n int) {
defer wg.Done()
taskID := "t" + string(rune('0'+n%10))
s.Update(TaskState{TaskID: taskID, Status: "downloading", Progress: n})
s.Snapshot()
if n%3 == 0 {
s.Remove(taskID)
}
}(i)
}
wg.Wait()
// No race condition = test passes
}
func TestLocalState_WriteToDisk_ReadFromDisk(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tasks.json")
// Override the file path for testing
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return path }
defer func() { taskStateFilePathFn = orig }()
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 45})
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100, FilePath: "/tmp/movie.mkv"})
s.WriteToDisk()
// Verify file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Fatal("tasks.json was not created")
}
// Read into a new LocalState
s2 := NewLocalState()
s2.ReadFromDisk()
snap := s2.Snapshot()
if len(snap) != 2 {
t.Fatalf("expected 2 tasks after read, got %d", len(snap))
}
byID := make(map[string]TaskState, len(snap))
for _, ts := range snap {
byID[ts.TaskID] = ts
}
if byID["t1"].Progress != 45 {
t.Errorf("expected progress 45, got %d", byID["t1"].Progress)
}
if byID["t2"].FilePath != "/tmp/movie.mkv" {
t.Errorf("expected /tmp/movie.mkv, got %s", byID["t2"].FilePath)
}
}
func TestLocalState_ReadFromDisk_CorruptedFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tasks.json")
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return path }
defer func() { taskStateFilePathFn = orig }()
// Write corrupted JSON
os.WriteFile(path, []byte("{invalid json"), 0o644)
s := NewLocalState()
s.ReadFromDisk() // should not panic
snap := s.Snapshot()
if len(snap) != 0 {
t.Errorf("expected 0 tasks from corrupted file, got %d", len(snap))
}
}
func TestLocalState_ReadFromDisk_FileNotFound(t *testing.T) {
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return "/nonexistent/path/tasks.json" }
defer func() { taskStateFilePathFn = orig }()
s := NewLocalState()
s.ReadFromDisk() // should not panic
snap := s.Snapshot()
if len(snap) != 0 {
t.Errorf("expected 0 tasks, got %d", len(snap))
}
}
func TestLocalState_AtomicWrite(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tasks.json")
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return path }
defer func() { taskStateFilePathFn = orig }()
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
s.WriteToDisk()
// Verify no .tmp file remains
tmpPath := path + ".tmp"
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
t.Error("temp file should not exist after write")
}
}
func TestLocalState_EmptySnapshot(t *testing.T) {
s := NewLocalState()
snap := s.Snapshot()
if snap == nil {
t.Error("snapshot should be non-nil empty slice")
}
if len(snap) != 0 {
t.Errorf("expected 0 tasks, got %d", len(snap))
}
}

View file

@ -1,51 +0,0 @@
package agent
import "context"
// Transport abstracts the communication protocol between the agent and server.
// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this.
type Transport interface {
// Connect establishes the transport connection.
// Called internally by Daemon.Run — callers must NOT call Connect separately.
Connect(ctx context.Context) error
// Close tears down the connection gracefully.
Close() error
// Mode returns the current transport mode ("ws" or "http").
Mode() string
// Register sends agent registration and returns user info + features.
Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error)
// SendHeartbeat sends a periodic keep-alive.
SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error)
// SendProgress reports download progress for a task.
SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error)
// ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events).
ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error)
// Deregister notifies the server of graceful shutdown.
Deregister(ctx context.Context, agentID string) error
// Events returns a channel that emits server-initiated events.
// In HTTP mode this channel is never written to (polling handles it).
// In WS mode, tasks/upgrade/control arrive here.
Events() <-chan ServerEvent
}
// ServerEvent represents a server-initiated message received via WebSocket.
type ServerEvent struct {
Type string // "tasks", "upgrade", "control", "disconnected"
Tasks *TasksResponse // populated when Type == "tasks"
Upgrade *UpgradeSignal // populated when Type == "upgrade"
Control *ControlAction // populated when Type == "control"
}
// ControlAction represents a server push for task control.
type ControlAction struct {
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
TaskID string `json:"taskId"`
}

View file

@ -1,285 +0,0 @@
package agent
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// TestE2EFullLifecycle tests the full lifecycle:
// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect
func TestE2EFullLifecycle(t *testing.T) {
var mu sync.Mutex
var receivedMessages []map[string]interface{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
_, msg, err := conn.ReadMessage()
if err != nil {
return
}
var parsed map[string]interface{}
json.Unmarshal(msg, &parsed)
mu.Lock()
receivedMessages = append(receivedMessages, parsed)
mu.Unlock()
msgType, _ := parsed["type"].(string)
switch msgType {
case "auth":
conn.WriteJSON(wsRegisteredMessage{
Type: "registered",
User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
Features: FeatureFlags{Torrent: true, Debrid: true},
})
case "heartbeat":
// No response in WS mode
case "progress":
// Simulate server-side cancel after progress
if progress, ok := parsed["progress"].(float64); ok && progress >= 50 {
conn.WriteJSON(map[string]string{
"type": "control",
"action": "cancel",
"taskId": parsed["taskId"].(string),
})
}
case "upgrade-result":
// Acknowledged
}
}
}))
defer srv.Close()
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
ctx := context.Background()
// 1. Connect
if err := tr.Connect(ctx); err != nil {
t.Fatalf("Connect: %v", err)
}
defer tr.Close()
// 2. Auth
resp, err := tr.Register(ctx, RegisterRequest{
AgentID: "e2e-agent",
Name: "E2E Test Agent",
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
})
if err != nil {
t.Fatalf("Register: %v", err)
}
if resp.User.Name != "E2E User" {
t.Errorf("expected E2E User, got %s", resp.User.Name)
}
if !resp.Features.Debrid {
t.Error("expected debrid feature")
}
// 3. Send heartbeat
_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
AgentID: "e2e-agent",
DiskFreeBytes: 1000000000,
DiskTotalBytes: 5000000000,
})
if err != nil {
t.Fatalf("SendHeartbeat: %v", err)
}
// 4. Send progress (50% → should trigger cancel control)
_, err = tr.SendProgress(ctx, StatusUpdate{
TaskID: "task-e2e-1",
Status: "downloading",
Progress: 50,
DownloadedBytes: 500,
TotalBytes: 1000,
SpeedBps: 100,
})
if err != nil {
t.Fatalf("SendProgress: %v", err)
}
// 5. Wait for control event (cancel)
select {
case event := <-tr.Events():
if event.Type != "control" {
t.Errorf("expected control event, got %s", event.Type)
}
if event.Control.Action != "cancel" {
t.Errorf("expected cancel, got %s", event.Control.Action)
}
if event.Control.TaskID != "task-e2e-1" {
t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
}
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for cancel control")
}
// Verify server received all messages
time.Sleep(100 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
if len(receivedMessages) < 3 {
t.Fatalf("expected at least 3 messages, got %d", len(receivedMessages))
}
types := make([]string, len(receivedMessages))
for i, m := range receivedMessages {
types[i], _ = m["type"].(string)
}
expected := []string{"auth", "heartbeat", "progress"}
for _, exp := range expected {
found := false
for _, got := range types {
if got == exp {
found = true
break
}
}
if !found {
t.Errorf("missing message type %q in %v", exp, types)
}
}
}
// TestE2EHybridFailover tests the full failover scenario:
// WS connect → download → WS disconnect → switch to HTTP → continue working
func TestE2EHybridFailover(t *testing.T) {
connectionCount := 0
var mu sync.Mutex
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
mu.Lock()
connectionCount++
connNum := connectionCount
mu.Unlock()
// Read auth
conn.ReadMessage()
conn.WriteJSON(wsRegisteredMessage{
Type: "registered",
User: UserInfo{Name: "Failover User"},
})
if connNum == 1 {
// First connection: push tasks then disconnect after 200ms
time.Sleep(50 * time.Millisecond)
conn.WriteJSON(wsTasksMessage{
Type: "tasks",
Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
})
time.Sleep(150 * time.Millisecond)
conn.Close()
} else {
// Second connection (after reconnect): push upgrade
time.Sleep(50 * time.Millisecond)
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
time.Sleep(500 * time.Millisecond)
conn.Close()
}
}))
defer srv.Close()
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
// HTTP mock for fallback
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simple heartbeat response
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
}))
defer httpSrv.Close()
httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
h := NewHybridTransport(wsT, httpT)
ctx := context.Background()
err := h.Connect(ctx)
if err != nil {
t.Fatalf("Connect: %v", err)
}
defer h.Close()
// Should start in WS mode
if h.Mode() != "ws" {
t.Fatalf("expected ws mode, got %s", h.Mode())
}
// Register via WS
_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
if err != nil {
t.Fatalf("Register: %v", err)
}
// Receive tasks via WS
var tasksReceived bool
var disconnected bool
for i := 0; i < 3; i++ {
select {
case event := <-h.Events():
switch event.Type {
case "tasks":
tasksReceived = true
if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
t.Errorf("unexpected tasks: %+v", event.Tasks)
}
case "disconnected":
disconnected = true
}
case <-time.After(2 * time.Second):
break
}
if disconnected {
break
}
}
if !tasksReceived {
t.Error("did not receive tasks before disconnect")
}
if !disconnected {
t.Error("did not receive disconnect event")
}
// Should now be in HTTP mode
time.Sleep(100 * time.Millisecond)
if h.Mode() != "http" {
t.Errorf("expected http mode after disconnect, got %s", h.Mode())
}
// Heartbeat should work via HTTP fallback
hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
if err != nil {
t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
}
if !hbResp.Success {
t.Error("expected heartbeat success")
}
}

View file

@ -1,50 +0,0 @@
package agent
import "context"
// HTTPTransport wraps the existing Client to implement Transport.
// This is a thin adapter — no behavioral changes from the current HTTP protocol.
type HTTPTransport struct {
client *Client
events chan ServerEvent
}
// NewHTTPTransport creates a new HTTP-based transport.
func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport {
return &HTTPTransport{
client: NewClient(baseURL, apiKey, userAgent),
events: make(chan ServerEvent, 10),
}
}
func (t *HTTPTransport) Connect(_ context.Context) error { return nil }
func (t *HTTPTransport) Close() error { return nil }
func (t *HTTPTransport) Mode() string { return "http" }
func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events }
func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
return t.client.Register(ctx, req)
}
func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
return t.client.Heartbeat(ctx, req)
}
func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
return t.client.ReportStatus(ctx, update)
}
func (t *HTTPTransport) BatchReportStatus(ctx context.Context, updates []StatusUpdate) (*BatchStatusResponse, error) {
return t.client.BatchReportStatus(ctx, updates)
}
func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
return t.client.ClaimTasks(ctx, agentID)
}
func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error {
return t.client.Deregister(ctx, agentID)
}
// Client returns the underlying HTTP client for direct use if needed.
func (t *HTTPTransport) Client() *Client { return t.client }

View file

@ -1,214 +0,0 @@
package agent
import (
"context"
"log"
"sync"
"sync/atomic"
"time"
)
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
// Automatically reconnects WS in the background.
type HybridTransport struct {
ws *WSTransport
http *HTTPTransport
mode atomic.Value // "ws" or "http"
events chan ServerEvent
reconnectMu sync.Mutex
reconnectRunning bool
reconnectStop chan struct{}
closed atomic.Bool
}
// NewHybridTransport creates a transport that prefers WS with HTTP fallback.
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
h := &HybridTransport{
ws: ws,
http: http,
events: make(chan ServerEvent, 50),
reconnectStop: make(chan struct{}),
}
h.mode.Store("http") // start in HTTP, upgrade to WS on Connect
return h
}
func (h *HybridTransport) Mode() string { return h.mode.Load().(string) }
func (h *HybridTransport) Events() <-chan ServerEvent { return h.events }
// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop.
func (h *HybridTransport) Connect(ctx context.Context) error {
// Try WebSocket first
if err := h.ws.Connect(ctx); err != nil {
log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err)
h.mode.Store("http")
h.startReconnectLoop()
return h.http.Connect(ctx)
}
h.mode.Store("ws")
log.Println("[transport] Connected via WebSocket")
// Forward WS events to unified channel + watch for disconnection
go h.forwardWSEvents()
return nil
}
// Close shuts down both transports and stops reconnection.
func (h *HybridTransport) Close() error {
h.closed.Store(true)
select {
case <-h.reconnectStop:
default:
close(h.reconnectStop)
}
_ = h.ws.Close()
return h.http.Close()
}
// Register delegates to the active transport.
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
if h.mode.Load() == "ws" {
return h.ws.Register(ctx, req)
}
return h.http.Register(ctx, req)
}
// SendHeartbeat delegates to the active transport.
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
if h.mode.Load() == "ws" {
resp, err := h.ws.SendHeartbeat(ctx, req)
if err != nil {
// WS write failed — switch to HTTP
h.switchToHTTP()
return h.http.SendHeartbeat(ctx, req)
}
return resp, nil
}
return h.http.SendHeartbeat(ctx, req)
}
// SendProgress delegates to the active transport.
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
if h.mode.Load() == "ws" {
resp, err := h.ws.SendProgress(ctx, update)
if err != nil {
h.switchToHTTP()
return h.http.SendProgress(ctx, update)
}
return resp, nil
}
return h.http.SendProgress(ctx, update)
}
// ClaimTasks delegates to the active transport.
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
if h.mode.Load() == "ws" {
return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
}
return h.http.ClaimTasks(ctx, agentID)
}
// Deregister delegates to the active transport.
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
if h.mode.Load() == "ws" {
return h.ws.Deregister(ctx, agentID)
}
return h.http.Deregister(ctx, agentID)
}
// ── Internal ─────────────────────────────────────────────────────────────────
func (h *HybridTransport) switchToHTTP() {
if h.mode.Load() == "http" {
return
}
log.Println("[transport] Switching to HTTP fallback")
h.mode.Store("http")
_ = h.ws.Close()
h.startReconnectLoop()
}
func (h *HybridTransport) forwardWSEvents() {
for {
select {
case <-h.reconnectStop:
return
case event, ok := <-h.ws.Events():
if !ok {
return // channel closed
}
if event.Type == "disconnected" {
h.switchToHTTP()
select {
case h.events <- event:
default:
}
return
}
select {
case h.events <- event:
default:
log.Printf("[transport] events channel full, dropping %s event", event.Type)
}
}
}
}
func (h *HybridTransport) startReconnectLoop() {
h.reconnectMu.Lock()
defer h.reconnectMu.Unlock()
if h.reconnectRunning {
return
}
h.reconnectRunning = true
go h.reconnectLoop()
}
func (h *HybridTransport) reconnectLoop() {
backoff := 5 * time.Second
maxBackoff := 60 * time.Second
for {
select {
case <-h.reconnectStop:
return
case <-time.After(backoff):
}
if h.closed.Load() {
return
}
// Already on WS? (someone else reconnected)
if h.mode.Load() == "ws" {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
err := h.ws.Connect(ctx)
cancel()
if err != nil {
log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
backoff = min(backoff*2, maxBackoff)
continue
}
// WS reconnected — switch back
log.Println("[transport] WebSocket reconnected")
h.mode.Store("ws")
// Reset reconnect flag so loop can start again if WS drops
h.reconnectMu.Lock()
h.reconnectRunning = false
h.reconnectMu.Unlock()
// Forward events from new WS connection
go h.forwardWSEvents()
return
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,395 +0,0 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
type WSTransport struct {
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
apiKey string
agentID string
userAgent string
conn *websocket.Conn
mu sync.Mutex
events chan ServerEvent
closed atomic.Bool
// Cached auth response from the DO
authResp *RegisterResponse
authMu sync.Mutex
authDone chan struct{}
authDoneOnce sync.Once
}
// NewWSTransport creates a WebSocket-based transport.
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
return &WSTransport{
wsURL: wsURL,
apiKey: apiKey,
agentID: agentID,
userAgent: userAgent,
events: make(chan ServerEvent, 50),
authDone: make(chan struct{}),
}
}
func (t *WSTransport) Mode() string { return "ws" }
func (t *WSTransport) Events() <-chan ServerEvent { return t.events }
// Connect dials the WebSocket server and starts the read loop.
func (t *WSTransport) Connect(ctx context.Context) error {
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
header := http.Header{}
header.Set("User-Agent", t.userAgent)
// Append API key as query param for auth on WS upgrade
wsURLWithKey := t.wsURL
if t.apiKey != "" {
sep := "?"
if strings.Contains(wsURLWithKey, "?") {
sep = "&"
}
wsURLWithKey += sep + "key=" + t.apiKey
}
conn, wsResp, err := dialer.DialContext(ctx, wsURLWithKey, header)
if wsResp != nil && wsResp.Body != nil {
defer wsResp.Body.Close()
}
if err != nil {
return fmt.Errorf("ws dial: %w", err)
}
t.mu.Lock()
t.conn = conn
t.closed.Store(false)
t.authDone = make(chan struct{})
t.authDoneOnce = sync.Once{}
t.mu.Unlock()
go t.readLoop(conn)
return nil
}
// Close sends a close frame and shuts down the connection.
func (t *WSTransport) Close() error {
t.closed.Store(true)
t.mu.Lock()
defer t.mu.Unlock()
if t.conn != nil {
_ = t.conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
)
err := t.conn.Close()
t.conn = nil
return err
}
return nil
}
// Register sends auth message and waits for the registered response.
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
msg := wsAuthMessage{
Type: "auth",
APIKey: t.apiKey,
AgentID: req.AgentID,
Name: req.Name,
OS: req.OS,
Arch: req.Arch,
Version: req.Version,
DownloadDir: req.DownloadDir,
DiskFreeBytes: req.DiskFreeBytes,
DiskTotalBytes: req.DiskTotalBytes,
}
if err := t.send(msg); err != nil {
return nil, fmt.Errorf("ws auth send: %w", err)
}
// Wait for the auth response or context cancellation
select {
case <-t.authDone:
t.authMu.Lock()
resp := t.authResp
t.authMu.Unlock()
if resp == nil {
return nil, fmt.Errorf("ws auth: no response received")
}
return resp, nil
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(15 * time.Second):
return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
}
}
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
msg := struct {
Type string `json:"type"`
Disk *struct {
Free int64 `json:"free"`
Total int64 `json:"total"`
} `json:"disk,omitempty"`
}{Type: "heartbeat"}
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
msg.Disk = &struct {
Free int64 `json:"free"`
Total int64 `json:"total"`
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
}
if err := t.send(msg); err != nil {
return nil, err
}
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
return &HeartbeatResponse{Success: true}, nil
}
// SendProgress sends a progress update. Control signals arrive async via Events().
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
msg := struct {
Type string `json:"type"`
TaskID string `json:"taskId"`
Status string `json:"status,omitempty"`
Progress int `json:"progress,omitempty"`
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
TotalBytes int64 `json:"totalBytes,omitempty"`
SpeedBps int64 `json:"speedBps,omitempty"`
ETA int `json:"eta,omitempty"`
ResolvedMethod string `json:"resolvedMethod,omitempty"`
FileName string `json:"fileName,omitempty"`
FilePath string `json:"filePath,omitempty"`
StreamURL string `json:"streamUrl,omitempty"`
StreamReady bool `json:"streamReady,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
}{
Type: "progress",
TaskID: update.TaskID,
Status: update.Status,
Progress: update.Progress,
DownloadedBytes: update.DownloadedBytes,
TotalBytes: update.TotalBytes,
SpeedBps: update.SpeedBps,
ETA: update.ETA,
ResolvedMethod: update.ResolvedMethod,
FileName: update.FileName,
FilePath: update.FilePath,
StreamURL: update.StreamURL,
StreamReady: update.StreamReady,
ErrorMessage: update.ErrorMessage,
}
if err := t.send(msg); err != nil {
return nil, err
}
// In WS mode, control signals come via Events(), not in the progress response.
return &StatusResponse{Success: true}, nil
}
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
return &TasksResponse{}, nil
}
// Deregister is handled by WebSocket close (DO detects disconnection).
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
return t.Close()
}
// ── Internal ─────────────────────────────────────────────────────────────────
func (t *WSTransport) send(msg any) error {
t.mu.Lock()
defer t.mu.Unlock()
if t.conn == nil {
return fmt.Errorf("ws: not connected")
}
data, err := json.Marshal(msg)
if err != nil {
return err
}
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return t.conn.WriteMessage(websocket.TextMessage, data)
}
func (t *WSTransport) readLoop(conn *websocket.Conn) {
// Cloudflare idle timeout is 100s. We send pings every 30s and expect
// either a pong or a server message within 45s. If neither arrives,
// the read deadline fires and we detect the zombie connection.
const (
pongWait = 45 * time.Second
pingPeriod = 30 * time.Second
)
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error {
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
// Ping ticker goroutine — stops when readLoop returns.
pingDone := make(chan struct{})
go func() {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
t.mu.Lock()
if t.conn != nil {
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err := t.conn.WriteMessage(websocket.PingMessage, nil)
_ = t.conn.SetWriteDeadline(time.Time{})
if err != nil {
t.mu.Unlock()
return
}
}
t.mu.Unlock()
case <-pingDone:
return
}
}
}()
defer close(pingDone)
for {
_, msg, err := conn.ReadMessage()
if err != nil {
if !t.closed.Load() {
log.Printf("[ws] read error: %v", err)
// Signal disconnection to the daemon
select {
case t.events <- ServerEvent{Type: "disconnected"}:
default:
}
}
return
}
// Any message (text or pong) proves the connection is alive.
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
var envelope struct {
Type string `json:"type"`
}
if err := json.Unmarshal(msg, &envelope); err != nil {
log.Printf("[ws] invalid message: %v", err)
continue
}
switch envelope.Type {
case "registered":
var resp wsRegisteredMessage
if json.Unmarshal(msg, &resp) == nil {
t.authMu.Lock()
t.authResp = &RegisterResponse{
Success: true,
User: resp.User,
Features: resp.Features,
}
t.authMu.Unlock()
// Signal that auth is complete (sync.Once prevents double-close panic)
t.authDoneOnce.Do(func() { close(t.authDone) })
}
case "tasks":
var resp wsTasksMessage
if json.Unmarshal(msg, &resp) == nil {
select {
case t.events <- ServerEvent{
Type: "tasks",
Tasks: &TasksResponse{
Tasks: resp.Tasks,
StreamRequests: resp.StreamRequests,
},
}:
default:
log.Printf("[ws] events channel full, dropping tasks message")
}
}
case "upgrade":
var resp wsUpgradeMessage
if json.Unmarshal(msg, &resp) == nil {
select {
case t.events <- ServerEvent{
Type: "upgrade",
Upgrade: &UpgradeSignal{Version: resp.Version},
}:
default:
}
}
case "control":
var resp ControlAction
if json.Unmarshal(msg, &resp) == nil {
select {
case t.events <- ServerEvent{
Type: "control",
Control: &resp,
}:
default:
}
}
case "error":
var resp struct {
Message string `json:"message"`
}
if json.Unmarshal(msg, &resp) == nil {
log.Printf("[ws] server error: %s", resp.Message)
}
}
}
}
// ── WS message types ─────────────────────────────────────────────────────────
type wsAuthMessage struct {
Type string `json:"type"`
APIKey string `json:"apiKey"`
AgentID string `json:"agentId"`
Name string `json:"name,omitempty"`
OS string `json:"os,omitempty"`
Arch string `json:"arch,omitempty"`
Version string `json:"version,omitempty"`
DownloadDir string `json:"downloadDir,omitempty"`
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
}
type wsRegisteredMessage struct {
Type string `json:"type"`
User UserInfo `json:"user"`
Features FeatureFlags `json:"features"`
}
type wsTasksMessage struct {
Type string `json:"type"`
Tasks []Task `json:"tasks"`
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
}
type wsUpgradeMessage struct {
Type string `json:"type"`
Version string `json:"version"`
}

View file

@ -50,20 +50,6 @@ type UsenetServerInfo struct {
SSL bool `json:"ssl"`
}
// HeartbeatRequest is sent every 30s to keep the agent alive.
type HeartbeatRequest struct {
AgentID string `json:"agentId"`
Name string `json:"name,omitempty"`
OS string `json:"os,omitempty"`
Version string `json:"version,omitempty"`
DownloadDir string `json:"downloadDir,omitempty"`
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
StreamPort int `json:"streamPort,omitempty"`
LanIP string `json:"lanIp,omitempty"`
TailscaleIP string `json:"tailscaleIp,omitempty"`
}
// Task represents a download task claimed from the server.
type Task struct {
ID string `json:"id"`
@ -88,12 +74,6 @@ type Task struct {
CollectionName string `json:"collectionName,omitempty"` // Collection name (e.g., "Harry Potter Collection")
}
// TasksResponse wraps the array of tasks returned by the server.
type TasksResponse struct {
Tasks []Task `json:"tasks"`
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
}
// StreamRequest is a request to stream a completed download from disk.
type StreamRequest struct {
TaskID string `json:"taskId"`
@ -139,14 +119,6 @@ type BatchStatusResponse struct {
Watching bool `json:"watching,omitempty"`
}
// HeartbeatResponse is returned by the server on heartbeat.
type HeartbeatResponse struct {
Success bool `json:"success"`
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
Watching bool `json:"watching,omitempty"` // true when a user is viewing download progress in the web UI
Scan bool `json:"scan,omitempty"` // true when user triggered a library scan from the web UI
}
// UpgradeSignal tells the agent to upgrade to a specific version.
type UpgradeSignal struct {
Version string `json:"version"`
@ -176,7 +148,6 @@ type AgentInfo struct {
User UserInfo
Features FeatureFlags
StartedAt time.Time
LastPollAt time.Time
ActiveTasks int
}
@ -334,6 +305,45 @@ type LibrarySyncResponse struct {
Removed int `json:"removed"`
}
// ---------------------------------------------------------------------------
// Sync types (unified CLI ↔ Server communication)
// ---------------------------------------------------------------------------
// SyncRequest is sent by the CLI periodically to synchronize state with the server.
// Contains the CLI's full execution state — the server responds with pending actions.
type SyncRequest struct {
AgentID string `json:"agentId"`
Version string `json:"version,omitempty"`
OS string `json:"os,omitempty"`
Arch string `json:"arch,omitempty"`
Name string `json:"name,omitempty"`
DownloadDir string `json:"downloadDir,omitempty"`
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
StreamPort int `json:"streamPort,omitempty"`
LanIP string `json:"lanIp,omitempty"`
TailscaleIP string `json:"tailscaleIp,omitempty"`
FreeSlots int `json:"freeSlots"`
Tasks []TaskState `json:"tasks"`
}
// ControlAction represents a server-side control signal for a task.
type ControlAction struct {
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
TaskID string `json:"taskId"`
DeleteFiles bool `json:"deleteFiles,omitempty"`
}
// SyncResponse is returned by the server with all pending actions for the CLI.
type SyncResponse struct {
NewTasks []Task `json:"newTasks,omitempty"`
Controls []ControlAction `json:"controls,omitempty"`
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
Watching bool `json:"watching"`
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
Scan bool `json:"scan,omitempty"`
}
// ---------------------------------------------------------------------------
// Watch progress types (used by stream tracking)
// ---------------------------------------------------------------------------

View file

@ -311,21 +311,10 @@ func configConnection(cfg *config.Config) error {
).Run()
}
func configAdvanced(cfg *config.Config) error {
return huh.NewForm(
huh.NewGroup(
huh.NewInput().
Title("Poll interval").
Description("How often to check for new tasks (e.g. 30s, 1m)").
Value(&cfg.Daemon.PollInterval).
Validate(validateDuration),
huh.NewInput().
Title("Heartbeat interval").
Description("How often to send heartbeat to server (e.g. 30s, 1m)").
Value(&cfg.Daemon.HeartbeatInterval).
Validate(validateDuration),
),
).Run()
func configAdvanced(_ *config.Config) error {
// Sync intervals are adaptive (3s watching, 60s idle) — no user-facing config needed.
fmt.Println("No advanced settings to configure. Sync intervals are automatic.")
return nil
}
// ── Validators ──────────────────────────────────────────────────────

View file

@ -7,7 +7,6 @@ import (
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
@ -27,13 +26,13 @@ func newStartCmd() *cobra.Command {
Short: "Start the download daemon (foreground)",
Long: `Start the unarr daemon in the foreground.
Registers with the server, receives download tasks via WebSocket (with
HTTP fallback), and executes them using the configured download method.
Registers with the server, receives download tasks via periodic sync,
and executes them using the configured download method.
Supports torrent, debrid, and usenet downloads concurrently.
The daemon sends periodic heartbeats and reports download progress back
to the web dashboard. Press Ctrl+C to stop gracefully active downloads
get up to 30 seconds to finish.
The daemon syncs state with the server every 3s when someone is viewing
the web dashboard, or every 60s when idle. Press Ctrl+C to stop
gracefully active downloads get up to 30 seconds to finish.
Requires: API key, agent ID, and download directory (run 'unarr init' first).
@ -127,85 +126,59 @@ func runDaemonStart() error {
bold.Println(" unarr Daemon")
fmt.Println()
// Parse intervals
pollInterval, _ := time.ParseDuration(cfg.Daemon.PollInterval)
if pollInterval == 0 {
pollInterval = 30 * time.Second
}
heartbeatInterval, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval)
if heartbeatInterval == 0 {
heartbeatInterval = 30 * time.Second
}
statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval)
if statusInterval == 0 {
statusInterval = 3 * time.Second
}
userAgent := "unarr/" + Version
// Create daemon config
daemonCfg := agent.DaemonConfig{
AgentID: cfg.Agent.ID,
AgentName: cfg.Agent.Name,
Version: Version,
DownloadDir: cfg.Download.Dir,
PollInterval: pollInterval,
HeartbeatInterval: heartbeatInterval,
StreamPort: cfg.Download.StreamPort,
LanIP: engine.LanIP(),
TailscaleIP: engine.TailscaleIP(),
AgentID: cfg.Agent.ID,
AgentName: cfg.Agent.Name,
Version: Version,
DownloadDir: cfg.Download.Dir,
StreamPort: cfg.Download.StreamPort,
LanIP: engine.LanIP(),
TailscaleIP: engine.TailscaleIP(),
}
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
wsURL := cfg.Auth.WSURL
if wsURL == "" {
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
}
var transport agent.Transport
if wsURL != "" {
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
transport = agent.NewHybridTransport(wsT, httpT)
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
} else {
transport = httpT
log.Println("Transport: HTTP only")
}
// Create daemon — always uses Transport interface
d := agent.NewDaemon(daemonCfg, transport)
// Create agent client for watch progress reporting
// Create HTTP client — single communication channel
agentClient := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
log.Printf("Transport: HTTP sync → %s", cfg.Auth.APIURL)
// Create daemon
d := agent.NewDaemon(daemonCfg, agentClient)
// Start SIGUSR1 reload watcher (unix only, no-op on Windows)
startReloadWatcher(&ReloadableConfig{Daemon: d})
// Daemon-scoped context — cancelled on shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create progress reporter using transport
reporter := engine.NewProgressReporterWithTransport(transport, statusInterval)
reporter.SetWatchingFunc(func() bool { return d.Watching.Load() })
reporter.SetWatchingChangedHandler(func(watching bool) { d.Watching.Store(watching) })
// Parse speed limits
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
maxUl, _ := config.ParseSpeed(cfg.Download.MaxUploadSpeed)
// Parse torrent timeouts from config (default: 0 = unlimited, like qBittorrent)
// Parse torrent timeouts
metaTimeout, _ := time.ParseDuration(cfg.Download.MetadataTimeout)
stallTimeout, _ := time.ParseDuration(cfg.Download.StallTimeout)
// Create progress reporter — only used for stream tasks (handleStreamTask)
// The sync goroutine handles all regular progress reporting.
statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval)
if statusInterval == 0 {
statusInterval = 3 * time.Second
}
reporter := engine.NewProgressReporter(agentClient, statusInterval)
reporter.SetWatchingFunc(func() bool { return d.Watching.Load() })
// Create torrent downloader
torrentDl, err := engine.NewTorrentDownloader(engine.TorrentConfig{
DataDir: cfg.Download.Dir,
MetadataTimeout: metaTimeout, // 0 = unlimited (default)
StallTimeout: stallTimeout, // 0 = unlimited (default)
MaxTimeout: 0, // unlimited
MetadataTimeout: metaTimeout,
StallTimeout: stallTimeout,
MaxTimeout: 0,
MaxDownloadRate: maxDl,
MaxUploadRate: maxUl,
ListenPort: cfg.Download.ListenPort, // 0 = default 42069
ListenPort: cfg.Download.ListenPort,
SeedEnabled: false,
})
if err != nil {
@ -223,7 +196,7 @@ func runDaemonStart() error {
log.Printf("Speed limits: download=%s upload=%s", dlStr, ulStr)
}
// Create debrid downloader (HTTPS-based, no provider interaction needed)
// Create debrid downloader
debridDl := engine.NewDebridDownloader()
// Create download manager
@ -237,170 +210,53 @@ func runDaemonStart() error {
TVShowsDir: cfg.Organize.TVShowsDir,
OutputDir: cfg.Download.Dir,
},
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client()))
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(agentClient))
// Create persistent stream server — lives for the entire daemon lifecycle.
// One port, one server, swap files with SetFile(). No more port churn.
// Create persistent stream server
streamSrv := engine.NewStreamServer(cfg.Download.StreamPort)
if err := streamSrv.Listen(ctx); err != nil {
return fmt.Errorf("start stream server: %w", err)
}
// Update heartbeat with actual port (may differ if configured port was busy)
d.UpdateStreamPort(streamSrv.Port())
// Wire state tracking
// Wire sync client callbacks
sc := d.SyncClient()
sc.GetFreeSlots = manager.FreeSlots
sc.GetTaskStates = manager.TaskStates
d.GetActiveCount = manager.ActiveCount
d.GetCleanableBytes = CleanableBytes
// Wire: server-side signals -> manager actions + stream tasks
reporter.SetCancelHandler(func(taskID string) {
manager.CancelTask(taskID)
cancelStreamTask(taskID)
})
reporter.SetPauseHandler(func(taskID string) {
manager.PauseTask(taskID)
cancelStreamTask(taskID)
})
reporter.SetDeleteFilesHandler(func(taskID string) {
manager.CancelAndDeleteFiles(taskID)
cancelStreamTask(taskID)
})
// Wire: stream requested on active download → set file on persistent server
reporter.SetStreamRequestedHandler(func(taskID string) {
task := manager.GetTask(taskID)
if task == nil {
log.Printf("[%s] stream requested but task not found in manager", taskID[:8])
return
}
if task.GetStreamURL() != "" {
return // already streaming
}
provider, err := torrentDl.GetStreamProvider(taskID)
if err != nil {
log.Printf("[%s] stream failed: %v", taskID[:8], err)
return
}
cancelStreamContexts()
streamSrv.SetFile(provider, taskID)
task.SetStreamURL(streamSrv.URLsJSON())
log.Printf("[%s] streaming active download: %s", taskID[:8], provider.FileName())
// Start watch progress reporter with cancellable context
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
streamRegistry.mu.Lock()
streamRegistry.cancels["watch:"+taskID] = watchCancel
streamRegistry.mu.Unlock()
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx)
})
// Wire: daemon claimed tasks -> manager
// Trigger immediate sync when a download slot frees up
manager.OnTaskDone = func() { d.TriggerSync() }
// Wire: sync receives new tasks → submit to manager or handle stream
d.OnTasksClaimed = func(tasks []agent.Task) {
for _, t := range tasks {
if t.Mode == "stream" {
// Skip if already streaming this task
if isStreamingTask(t.ID) {
continue
}
// Only 1 stream at a time: cancel existing stream goroutines + clear file
cancelStreamContexts()
streamSrv.ClearFile()
// Reserve slot before spawning goroutine to prevent TOCTOU race.
streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel ownership transferred to streamRegistry
streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel stored in registry
streamRegistry.mu.Lock()
streamRegistry.cancels[t.ID] = streamCancel
streamRegistry.mu.Unlock()
go handleStreamTask(streamCtx, t, reporter, cfg, agentClient, streamSrv)
} else if t.ForceStart || manager.HasCapacity() {
manager.Submit(ctx, t)
} else {
log.Printf("[%s] skipped: no capacity (max %d)", t.ID[:8], cfg.Download.MaxConcurrent)
manager.Submit(ctx, t)
}
}
}
// Wire: stream requests for completed downloads → set file on persistent server
d.OnStreamRequested = func(sr agent.StreamRequest) {
// Already serving this task — just notify server it's ready
if streamSrv.CurrentTaskID() == sr.TaskID {
go func() {
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
StreamReady: true,
}); err != nil {
log.Printf("[%s] stream ready re-notify failed: %v", sr.TaskID[:8], err)
}
}()
return
}
filePath := sr.FilePath
info, err := os.Stat(filePath)
if err != nil {
log.Printf("[%s] stream request: file not found: %s", sr.TaskID[:8], filePath)
go func() {
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
Status: "failed",
ErrorMessage: fmt.Sprintf("file not found: %s", filePath),
}); err != nil {
log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err)
}
}()
return
}
// If filePath is a directory, find the largest video file inside
if info.IsDir() {
found := engine.FindVideoFile(filePath)
if found == "" {
log.Printf("[%s] stream request: no video file in directory: %s", sr.TaskID[:8], filePath)
go func() {
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
Status: "failed",
ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath),
}); err != nil {
log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err)
}
}()
return
}
filePath = found
log.Printf("[%s] resolved directory to video file: %s", sr.TaskID[:8], filepath.Base(filePath))
}
// Cancel any active stream goroutines and swap file on the persistent server
cancelStreamContexts()
streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID)
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(filePath), streamSrv.URL())
// Start watch progress reporter with a cancellable context
// so it stops when the user switches to a different stream.
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
streamRegistry.mu.Lock()
streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel
streamRegistry.mu.Unlock()
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx)
// Notify server that stream is ready (clears streamRequested flag)
go func() {
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
StreamReady: true,
}); err != nil {
log.Printf("[%s] stream ready report failed: %v", sr.TaskID[:8], err)
}
}()
}
// Wire: WS control actions (pause/cancel/stream pushed from server)
d.OnControlAction = func(action, taskID string) {
// Wire: sync receives control signals → act on manager
d.OnControlAction = func(action, taskID string, deleteFiles bool) {
switch action {
case "cancel":
manager.CancelTask(taskID)
if deleteFiles {
manager.CancelAndDeleteFiles(taskID)
} else {
manager.CancelTask(taskID)
}
cancelStreamTask(taskID)
if streamSrv.CurrentTaskID() == taskID {
streamSrv.ClearFile()
@ -412,10 +268,9 @@ func runDaemonStart() error {
streamSrv.ClearFile()
}
case "resume":
log.Printf("[%s] resume requested via WebSocket, triggering poll", taskID[:8])
d.TriggerPoll()
log.Printf("[%s] resume requested, triggering sync", agent.ShortID(taskID))
d.TriggerSync()
case "stream":
// Skip if already streaming this task
if streamSrv.CurrentTaskID() == taskID {
return
}
@ -425,13 +280,19 @@ func runDaemonStart() error {
}
provider, err := torrentDl.GetStreamProvider(taskID)
if err != nil {
log.Printf("[%s] stream failed: %v", taskID[:8], err)
log.Printf("[%s] stream failed: %v", agent.ShortID(taskID), err)
return
}
cancelStreamContexts()
streamSrv.SetFile(provider, taskID)
task.SetStreamURL(streamSrv.URLsJSON())
log.Printf("[%s] streaming via WS: %s", taskID[:8], provider.FileName())
log.Printf("[%s] streaming: %s", agent.ShortID(taskID), provider.FileName())
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118
streamRegistry.mu.Lock()
streamRegistry.cancels["watch:"+taskID] = watchCancel
streamRegistry.mu.Unlock()
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx)
case "stop-stream":
cancelStreamTask(taskID)
if streamSrv.CurrentTaskID() == taskID {
@ -440,19 +301,77 @@ func runDaemonStart() error {
}
}
// Config hot-reload (SIGUSR1 on Unix, no-op on Windows)
// Tickers are initialized inside d.Run(), so we pass the daemon
// and the reload goroutine reads them when the signal arrives.
startReloadWatcher(&ReloadableConfig{Daemon: d})
// Wire: sync receives stream requests for completed downloads
d.OnStreamRequested = func(sr agent.StreamRequest) {
if streamSrv.CurrentTaskID() == sr.TaskID {
// Already serving — notify server it's ready
go func() {
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
StreamReady: true,
}); err != nil {
log.Printf("[%s] stream ready re-notify failed: %v", agent.ShortID(sr.TaskID), err)
}
}()
return
}
// Signal handling
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
filePath := sr.FilePath
info, err := os.Stat(filePath)
if err != nil {
log.Printf("[%s] stream request: file not found: %s", agent.ShortID(sr.TaskID), filePath)
go func() {
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
Status: "failed",
ErrorMessage: fmt.Sprintf("file not found: %s", filePath),
}); err != nil {
log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err)
}
}()
return
}
// Start progress reporter in background
go reporter.Run(ctx)
if info.IsDir() {
found := engine.FindVideoFile(filePath)
if found == "" {
log.Printf("[%s] stream request: no video file in directory: %s", agent.ShortID(sr.TaskID), filePath)
go func() {
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
Status: "failed",
ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath),
}); err != nil {
log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err)
}
}()
return
}
filePath = found
log.Printf("[%s] resolved directory to video file: %s", agent.ShortID(sr.TaskID), filepath.Base(filePath))
}
// Periodic DHT node persistence (every 5 min) — protects against crash data loss
cancelStreamContexts()
streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID)
log.Printf("[%s] streaming from disk: %s → %s", agent.ShortID(sr.TaskID), filepath.Base(filePath), streamSrv.URL())
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118
streamRegistry.mu.Lock()
streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel
streamRegistry.mu.Unlock()
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx)
go func() {
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
TaskID: sr.TaskID,
StreamReady: true,
}); err != nil {
log.Printf("[%s] stream ready report failed: %v", agent.ShortID(sr.TaskID), err)
}
}()
}
// Periodic DHT node persistence (every 5 min)
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
@ -466,8 +385,7 @@ func runDaemonStart() error {
}
}()
// Start auto-scan goroutine (daily library scan + sync)
// Default scan_path to download dir so auto-scan works out of the box.
// Start auto-scan goroutine
scanPath := cfg.Library.ScanPath
if scanPath == "" {
scanPath = cfg.Download.Dir
@ -484,7 +402,10 @@ func runDaemonStart() error {
go runAutoScan(ctx, scanCfg, scanInterval, agentClient, d.ScanNow)
}
// Start daemon (blocks)
// Start reporter only for stream task handling
go reporter.Run(ctx)
// Start daemon (blocks — runs sync loop)
errCh := make(chan error, 1)
go func() {
errCh <- d.Run(ctx)
@ -493,6 +414,10 @@ func runDaemonStart() error {
// Start idle guard for the persistent stream server
go startIdleGuard(ctx, streamSrv)
// Signal handling
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
// Wait for signal or error
select {
case sig := <-sigCh:
@ -506,6 +431,7 @@ func runDaemonStart() error {
defer shutdownCancel()
manager.Shutdown(shutdownCtx)
d.Deregister()
fmt.Println(" Daemon stopped.")
return nil
@ -517,41 +443,6 @@ func runDaemonStart() error {
}
}
// deriveWSURL derives a WebSocket URL from the API URL.
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
// Returns "" for localhost/dev environments where WS gateway isn't available.
func deriveWSURL(apiURL, agentID string) string {
if apiURL == "" || agentID == "" {
return ""
}
// Parse domain from API URL
domain := apiURL
for _, prefix := range []string{"https://", "http://"} {
if len(domain) > len(prefix) && domain[:len(prefix)] == prefix {
domain = domain[len(prefix):]
break
}
}
// Strip trailing slash/path
for i := 0; i < len(domain); i++ {
if domain[i] == '/' {
domain = domain[:i]
break
}
}
// Strip port if present
if idx := strings.LastIndex(domain, ":"); idx > 0 {
domain = domain[:idx]
}
// Skip WS for localhost/dev — gateway only available in production
if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" {
return ""
}
return "wss://unarr." + domain + "/ws/" + agentID
}
func formatSpeedLog(bps int64) string {
switch {
case bps >= 1024*1024*1024:
@ -569,11 +460,9 @@ func formatSpeedLog(bps int64) string {
func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration, ac *agent.Client, scanNow <-chan struct{}) {
log.Printf("[auto-scan] enabled: every %s, path: %s", interval, cfg.Library.ScanPath)
// Run first scan after a short delay (let daemon stabilize)
select {
case <-time.After(30 * time.Second):
case <-scanNow:
// Immediate scan requested before initial delay
case <-ctx.Done():
return
}
@ -608,7 +497,6 @@ func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration,
return
}
// Sync to server
items := library.BuildSyncItems(cache)
if len(items) == 0 {
log.Printf("[auto-scan] no items to sync")

View file

@ -2,32 +2,6 @@ package cmd
import "testing"
func TestDeriveWSURL(t *testing.T) {
tests := []struct {
apiURL string
agentID string
want string
}{
{"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"},
{"http://localhost:3000", "a1", ""}, // localhost skipped
{"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped
{"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"},
{"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"},
{"", "agent-123", ""},
{"https://torrentclaw.com", "", ""},
{"", "", ""},
}
for _, tt := range tests {
t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) {
got := deriveWSURL(tt.apiURL, tt.agentID)
if got != tt.want {
t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want)
}
})
}
}
func TestFormatSpeedLog(t *testing.T) {
tests := []struct {
bps int64

View file

@ -7,7 +7,6 @@ import (
"os"
"os/signal"
"syscall"
"time"
"github.com/torrentclaw/unarr/internal/agent"
"github.com/torrentclaw/unarr/internal/config"
@ -19,7 +18,8 @@ type ReloadableConfig struct {
}
// startReloadWatcher listens for SIGUSR1 and reloads config.
// Only intervals are hot-reloadable (speeds require torrent client restart).
// With the sync-based architecture, intervals are fixed (3s watching, 60s idle).
// Hot-reload now mainly serves as a signal to re-read config for future settings.
func startReloadWatcher(rc *ReloadableConfig) {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGUSR1)
@ -28,24 +28,11 @@ func startReloadWatcher(rc *ReloadableConfig) {
for range sigCh {
log.Println("Received SIGUSR1, reloading config...")
cfg, err := config.Load("")
_, err := config.Load("")
if err != nil {
log.Printf("Config reload failed: %v", err)
continue
}
cfg.ApplyEnvOverrides()
// Update poll interval
if d, _ := time.ParseDuration(cfg.Daemon.PollInterval); d > 0 && rc.Daemon.PollTicker != nil {
rc.Daemon.PollTicker.Reset(d)
log.Printf(" Poll interval: %s", d)
}
// Update heartbeat interval
if d, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval); d > 0 && rc.Daemon.HeartbeatTicker != nil {
rc.Daemon.HeartbeatTicker.Reset(d)
log.Printf(" Heartbeat interval: %s", d)
}
log.Println("Config reloaded successfully")
}

View file

@ -1,4 +1,4 @@
package cmd
// Version is the CLI version. Overridden by goreleaser ldflags at release time.
var Version = "0.5.5"
var Version = "0.5.6"

View file

@ -26,7 +26,6 @@ type Config struct {
type AuthConfig struct {
APIKey string `toml:"api_key"`
APIURL string `toml:"api_url"`
WSURL string `toml:"ws_url"` // optional, derived from api_url if empty
}
type AgentConfig struct {
@ -54,9 +53,7 @@ type OrganizeConfig struct {
}
type DaemonConfig struct {
PollInterval string `toml:"poll_interval"`
HeartbeatInterval string `toml:"heartbeat_interval"`
StatusInterval string `toml:"status_interval"`
StatusInterval string `toml:"status_interval"`
}
type NotificationsConfig struct {
@ -92,10 +89,7 @@ func Default() Config {
Organize: OrganizeConfig{
Enabled: true,
},
Daemon: DaemonConfig{
PollInterval: "30s",
HeartbeatInterval: "30s",
},
Daemon: DaemonConfig{},
Notifications: NotificationsConfig{
Enabled: true,
},

View file

@ -21,8 +21,8 @@ func TestDefault(t *testing.T) {
if cfg.General.Country != "US" {
t.Errorf("default Country = %q, want US", cfg.General.Country)
}
if cfg.Daemon.HeartbeatInterval != "30s" {
t.Errorf("default HeartbeatInterval = %q, want 30s", cfg.Daemon.HeartbeatInterval)
if cfg.Daemon.StatusInterval != "" {
t.Errorf("default StatusInterval = %q, want empty", cfg.Daemon.StatusInterval)
}
}

View file

@ -10,6 +10,8 @@ import (
"path/filepath"
"sync"
"time"
"github.com/torrentclaw/unarr/internal/agent"
)
// httpClient is used for debrid HTTPS downloads with a reasonable header timeout.
@ -19,13 +21,6 @@ var httpClient = &http.Client{
},
}
func shortID(id string) string {
if len(id) > 8 {
return id[:8]
}
return id
}
// DebridDownloader downloads files via HTTPS direct URLs resolved by the server.
// The server handles all debrid provider interaction; this downloader only needs
// a plain HTTPS URL to fetch.
@ -129,7 +124,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
var serverSize int64
if _, err := fmt.Sscanf(cr, "bytes */%d", &serverSize); err == nil && serverSize > 0 && existingSize != serverSize {
// Local file size doesn't match server — re-download from scratch
log.Printf("[%s] local size %s != server size %s, re-downloading", shortID(task.ID), formatBytes(existingSize), formatBytes(serverSize))
log.Printf("[%s] local size %s != server size %s, re-downloading", agent.ShortID(task.ID), formatBytes(existingSize), formatBytes(serverSize))
resp.Body.Close()
req2, err := http.NewRequestWithContext(dlCtx, http.MethodGet, task.DirectURL, nil)
if err != nil {
@ -149,7 +144,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
break // continue to download loop
}
}
log.Printf("[%s] file already complete: %s (%s)", shortID(task.ID), fileName, formatBytes(existingSize))
log.Printf("[%s] file already complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(existingSize))
return &Result{
FilePath: destPath,
FileName: fileName,
@ -166,10 +161,10 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
var flags int
if startOffset > 0 {
flags = os.O_WRONLY | os.O_APPEND
log.Printf("[%s] resuming debrid download at %s: %s", shortID(task.ID), formatBytes(startOffset), fileName)
log.Printf("[%s] resuming debrid download at %s: %s", agent.ShortID(task.ID), formatBytes(startOffset), fileName)
} else {
flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
log.Printf("[%s] starting debrid download: %s", shortID(task.ID), fileName)
log.Printf("[%s] starting debrid download: %s", agent.ShortID(task.ID), fileName)
}
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
@ -223,7 +218,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
}
log.Printf("[%s] %d%% — %s/%s @ %s/s (debrid)",
shortID(task.ID), pct,
agent.ShortID(task.ID), pct,
formatBytes(downloaded), formatBytes(totalBytes), formatBytes(speed))
p := Progress{
@ -252,7 +247,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
}
}
log.Printf("[%s] debrid download complete: %s (%s)", shortID(task.ID), fileName, formatBytes(downloaded))
log.Printf("[%s] debrid download complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(downloaded))
return &Result{
FilePath: destPath,
@ -271,7 +266,7 @@ func (d *DebridDownloader) Pause(taskID string) error {
if ok {
cancel()
log.Printf("[%s] debrid download paused (file kept for resume)", shortID(taskID))
log.Printf("[%s] debrid download paused (file kept for resume)", agent.ShortID(taskID))
}
return nil
}
@ -285,7 +280,7 @@ func (d *DebridDownloader) Cancel(taskID string) error {
if ok {
cancel()
log.Printf("[%s] debrid download cancelled", shortID(taskID))
log.Printf("[%s] debrid download cancelled", agent.ShortID(taskID))
}
return nil
}

View file

@ -28,6 +28,15 @@ type Manager struct {
sem chan struct{}
wg sync.WaitGroup
// OnTaskDone is called after a task completes or fails (slot freed).
// Used by the daemon to trigger an immediate sync.
OnTaskDone func()
// recentlyFinished holds tasks that completed/failed since the last sync read.
// The sync goroutine reads and clears this to include final states in the next sync.
recentMu sync.Mutex
recentFinished []agent.TaskState
}
// NewManager creates a download manager.
@ -67,7 +76,7 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) {
// Force start: bypass semaphore (like Transmission's "Force Start")
if at.ForceStart {
log.Printf("[%s] force start: bypassing queue", task.ID[:8])
log.Printf("[%s] force start: bypassing queue", agent.ShortID(task.ID))
m.wg.Add(1)
go func() {
defer m.wg.Done()
@ -88,7 +97,12 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) {
m.wg.Add(1)
go func() {
defer m.wg.Done()
defer func() { <-m.sem }()
defer func() {
<-m.sem
if m.OnTaskDone != nil {
m.OnTaskDone()
}
}()
defer taskCancel()
m.processTask(taskCtx, task)
}()
@ -99,6 +113,11 @@ func (m *Manager) HasCapacity() bool {
return len(m.sem) < cap(m.sem)
}
// FreeSlots returns the number of available download slots.
func (m *Manager) FreeSlots() int {
return cap(m.sem) - len(m.sem)
}
// ActiveCount returns the number of in-progress downloads.
func (m *Manager) ActiveCount() int {
m.activeMu.RLock()
@ -113,6 +132,17 @@ func (m *Manager) GetTask(taskID string) *Task {
return m.active[taskID]
}
// ActiveTaskIDs returns the IDs of all in-progress tasks.
func (m *Manager) ActiveTaskIDs() []string {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
ids := make([]string, 0, len(m.active))
for id := range m.active {
ids = append(ids, id)
}
return ids
}
// ActiveTasks returns a snapshot of all active tasks.
func (m *Manager) ActiveTasks() []*Task {
m.activeMu.RLock()
@ -124,6 +154,37 @@ func (m *Manager) ActiveTasks() []*Task {
return tasks
}
// TaskStates returns the current state of all active tasks plus any recently
// finished tasks that haven't been synced yet. Called by the sync goroutine.
func (m *Manager) TaskStates() []agent.TaskState {
// Collect active tasks
m.activeMu.RLock()
states := make([]agent.TaskState, 0, len(m.active))
for _, t := range m.active {
states = append(states, agent.TaskStateFromUpdate(t.ToStatusUpdate()))
}
m.activeMu.RUnlock()
// Drain recently finished tasks (consumed once per sync)
m.recentMu.Lock()
states = append(states, m.recentFinished...)
m.recentFinished = nil
m.recentMu.Unlock()
return states
}
// recordFinished stores a completed/failed task for the next sync cycle.
func (m *Manager) recordFinished(update agent.StatusUpdate) {
m.recentMu.Lock()
defer m.recentMu.Unlock()
m.recentFinished = append(m.recentFinished, agent.TaskStateFromUpdate(update))
// Keep bounded
if len(m.recentFinished) > 20 {
m.recentFinished = m.recentFinished[len(m.recentFinished)-20:]
}
}
// CancelTask cancels an active download by task ID (keeps partial files).
func (m *Manager) CancelTask(taskID string) {
m.activeMu.RLock()
@ -150,7 +211,7 @@ func (m *Manager) CancelTask(taskID string) {
task.mu.Unlock()
task.Transition(StatusCancelled)
log.Printf("[%s] cancelled: %s", taskID[:8], task.Title)
log.Printf("[%s] cancelled: %s", agent.ShortID(taskID), task.Title)
}
// PauseTask pauses an active download (keeps partial files for resume).
@ -173,7 +234,7 @@ func (m *Manager) PauseTask(taskID string) {
}
task.Transition(StatusCancelled) // will be re-created as pending by server
log.Printf("[%s] paused: %s", taskID[:8], task.Title)
log.Printf("[%s] paused: %s", agent.ShortID(taskID), task.Title)
}
// CancelAndDeleteFiles cancels a download and removes its files from disk.
@ -200,7 +261,7 @@ func (m *Manager) CancelAndDeleteFiles(taskID string) {
task.mu.Unlock()
task.Transition(StatusCancelled)
log.Printf("[%s] cancelled + files deleted: %s", taskID[:8], task.Title)
log.Printf("[%s] cancelled + files deleted: %s", agent.ShortID(taskID), task.Title)
}
// Wait blocks until all active downloads finish.
@ -261,7 +322,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
}
task.ResolvedMethod = method
log.Printf("[%s] resolved method: %s", task.ID[:8], method)
log.Printf("[%s] resolved method: %s", agent.ShortID(task.ID), method)
// 2. Download
if err := task.Transition(StatusDownloading); err != nil {
@ -285,7 +346,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
if err != nil {
// Try fallback
if tryFallback(task, m.downloaders) {
log.Printf("[%s] %s failed, trying fallback: %v", task.ID[:8], method, err)
log.Printf("[%s] %s failed, trying fallback: %v", agent.ShortID(task.ID), method, err)
if err := task.Transition(StatusResolving); err == nil {
m.processTaskRetry(ctx, task)
return
@ -295,61 +356,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
return
}
// 3. Verify
if err := task.Transition(StatusVerifying); err != nil {
m.fail(ctx, task, "transition error: "+err.Error())
return
}
if err := verify(result); err != nil {
m.fail(ctx, task, "verification failed: "+err.Error())
return
}
// 4. Organize
if err := task.Transition(StatusOrganizing); err != nil {
m.fail(ctx, task, "transition error: "+err.Error())
return
}
finalPath, err := organize(result, task, m.cfg.Organize)
if err != nil {
log.Printf("[%s] organize warning: %v (keeping in download dir)", task.ID[:8], err)
finalPath = result.FilePath
}
task.mu.Lock()
task.FilePath = finalPath
task.mu.Unlock()
// 4b. Handle upgrade replacement (mode = "upgrade")
if task.ReplacePath != "" {
backupDir := "" // uses default ~/.local/share/unarr/replaced/
if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil {
log.Printf("[%s] replace warning: %v (keeping new file at %s)", task.ID[:8], err, finalPath)
} else {
task.mu.Lock()
task.FilePath = task.ReplacePath
task.mu.Unlock()
log.Printf("[%s] upgraded: replaced %s", task.ID[:8], task.ReplacePath)
}
}
// 5. Complete
if method == MethodTorrent && m.cfg.Organize.Enabled {
// Could add seeding here in the future
}
if err := task.Transition(StatusCompleted); err != nil {
m.fail(ctx, task, "transition error: "+err.Error())
return
}
log.Printf("[%s] completed: %s -> %s", task.ID[:8], task.Title, finalPath)
if m.cfg.Notifications {
desktopNotify("Download complete", task.Title)
}
m.reporter.ReportFinal(ctx, task)
m.finalize(ctx, task, result)
}
// processTaskRetry handles fallback after a method failure.
@ -361,7 +368,7 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
}
task.ResolvedMethod = method
log.Printf("[%s] fallback to: %s", task.ID[:8], method)
log.Printf("[%s] fallback to: %s", agent.ShortID(task.ID), method)
if err := task.Transition(StatusDownloading); err != nil {
m.fail(ctx, task, "transition error: "+err.Error())
@ -383,15 +390,31 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
return
}
// Verify + Organize + Complete (same as processTask)
task.Transition(StatusVerifying)
m.finalize(ctx, task, result)
}
// finalize runs verify → organize → upgrade replacement → complete for a downloaded task.
func (m *Manager) finalize(ctx context.Context, task *Task, result *Result) {
// Verify
if err := task.Transition(StatusVerifying); err != nil {
m.fail(ctx, task, "transition error: "+err.Error())
return
}
if err := verify(result); err != nil {
m.fail(ctx, task, "verification failed: "+err.Error())
return
}
task.Transition(StatusOrganizing)
finalPath, _ := organize(result, task, m.cfg.Organize)
// Organize
if err := task.Transition(StatusOrganizing); err != nil {
m.fail(ctx, task, "transition error: "+err.Error())
return
}
finalPath, err := organize(result, task, m.cfg.Organize)
if err != nil {
log.Printf("[%s] organize warning: %v (keeping in download dir)", agent.ShortID(task.ID), err)
finalPath = result.FilePath
}
if finalPath == "" {
finalPath = result.FilePath
}
@ -399,8 +422,29 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
task.FilePath = finalPath
task.mu.Unlock()
task.Transition(StatusCompleted)
log.Printf("[%s] completed (fallback): %s -> %s", task.ID[:8], task.Title, finalPath)
// Handle upgrade replacement (mode = "upgrade")
if task.ReplacePath != "" {
backupDir := "" // uses default ~/.local/share/unarr/replaced/
if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil {
log.Printf("[%s] replace warning: %v (keeping new file at %s)", agent.ShortID(task.ID), err, finalPath)
} else {
task.mu.Lock()
task.FilePath = task.ReplacePath
task.mu.Unlock()
log.Printf("[%s] upgraded: replaced %s", agent.ShortID(task.ID), task.ReplacePath)
}
}
// Complete
if err := task.Transition(StatusCompleted); err != nil {
m.fail(ctx, task, "transition error: "+err.Error())
return
}
log.Printf("[%s] completed: %s -> %s", agent.ShortID(task.ID), task.Title, finalPath)
if m.cfg.Notifications {
desktopNotify("Download complete", task.Title)
}
m.recordFinished(task.ToStatusUpdate())
m.reporter.ReportFinal(ctx, task)
}
@ -409,9 +453,10 @@ func (m *Manager) fail(ctx context.Context, task *Task, msg string) {
task.ErrorMessage = msg
task.mu.Unlock()
task.Transition(StatusFailed)
log.Printf("[%s] FAILED: %s — %s", task.ID[:8], task.Title, msg)
log.Printf("[%s] FAILED: %s — %s", agent.ShortID(task.ID), task.Title, msg)
if m.cfg.Notifications {
desktopNotify("Download failed", task.Title+": "+msg)
}
m.recordFinished(task.ToStatusUpdate())
m.reporter.ReportFinal(ctx, task)
}

View file

@ -13,13 +13,11 @@ import (
type ActionFunc func(taskID string)
// StatusReporter is the interface used by ProgressReporter to send progress updates.
// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods.
type StatusReporter interface {
ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error)
}
// BatchStatusReporter extends StatusReporter with batch support.
// Transports that implement this send all updates in a single request.
type BatchStatusReporter interface {
StatusReporter
BatchReportStatus(ctx context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error)
@ -48,7 +46,6 @@ type ProgressReporter struct {
}
// NewProgressReporter creates a reporter that flushes every interval.
// Accepts *agent.Client directly (backwards compatible).
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
return &ProgressReporter{
reporter: ac,
@ -58,25 +55,6 @@ func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressRepo
}
}
// NewProgressReporterWithTransport creates a reporter using a Transport.
func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter {
return &ProgressReporter{
reporter: &transportStatusAdapter{t: t},
interval: interval,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
}
// transportStatusAdapter adapts agent.Transport to StatusReporter.
type transportStatusAdapter struct {
t agent.Transport
}
func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
return a.t.SendProgress(ctx, update)
}
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }