feat(sync): replace WS+DO transport with unified HTTP sync
Replace the WebSocket + Cloudflare Durable Object architecture with a single POST /sync endpoint. The CLI now operates autonomously with local state (tasks.json) and syncs bidirectionally via adaptive-interval HTTP polling (3s watching, 60s idle). - Remove transport_ws, transport_hybrid, transport_http (~2,600 lines) - Add SyncClient with adaptive interval loop - Add LocalState for CLI-side task persistence - Add TaskStateFromUpdate() helper (DRY) - Extract finalize() to deduplicate processTask/processTaskRetry - Consolidate shortID() into agent.ShortID (was in 3 packages) - Wire GetActiveCount so `unarr status` shows active tasks - Remove poll_interval, heartbeat_interval, ws_url from config - Simplify ProgressReporter (sync replaces direct HTTP reporting)
This commit is contained in:
parent
2398707cc1
commit
5d4a67c7a2
26 changed files with 1320 additions and 3400 deletions
11
CHANGELOG.md
11
CHANGELOG.md
|
|
@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.5.6] - 2026-04-07
|
||||
|
||||
|
||||
### Fixed
|
||||
|
||||
- **ws**: add ping/pong keepalive and read deadline to detect zombie connections
|
||||
## [0.5.5] - 2026-04-07
|
||||
|
||||
|
||||
|
|
@ -17,6 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
### Fixed
|
||||
|
||||
- **daemon**: cancel watch reporter on stream switch and re-notify ready
|
||||
|
||||
### Other
|
||||
|
||||
- **release**: 0.5.5
|
||||
## [0.5.4] - 2026-04-07
|
||||
|
||||
|
||||
|
|
@ -153,6 +163,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
- remove UPX compression (antivirus false positives, startup penalty)
|
||||
- add -s -w -trimpath to Makefile, add build-small target with UPX
|
||||
[0.5.6]: https://github.com/torrentclaw/unarr/compare/v0.5.5...v0.5.6
|
||||
[0.5.5]: https://github.com/torrentclaw/unarr/compare/v0.5.4...v0.5.5
|
||||
[0.5.4]: https://github.com/torrentclaw/unarr/compare/v0.5.3...v0.5.4
|
||||
[0.5.3]: https://github.com/torrentclaw/unarr/compare/v0.5.2...v0.5.3
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -11,7 +11,6 @@ require (
|
|||
github.com/fatih/color v1.19.0
|
||||
github.com/getsentry/sentry-go v0.44.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/huin/goupnp v1.3.0
|
||||
github.com/olekukonko/tablewriter v1.1.4
|
||||
github.com/spf13/cobra v1.10.2
|
||||
|
|
@ -69,6 +68,7 @@ require (
|
|||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20260302011040-a15ffb7f9dcc // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
|
|
|
|||
|
|
@ -40,27 +40,6 @@ func (c *Client) Register(ctx context.Context, req RegisterRequest) (*RegisterRe
|
|||
return &resp, nil
|
||||
}
|
||||
|
||||
// Heartbeat sends a periodic keep-alive signal and returns server directives.
|
||||
func (c *Client) Heartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
var resp HeartbeatResponse
|
||||
if err := c.doPost(ctx, "/api/internal/agent/heartbeat", req, &resp); err != nil {
|
||||
return nil, fmt.Errorf("heartbeat: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ClaimTasks polls for pending download tasks and claims them atomically.
|
||||
// Also returns any stream requests for completed downloads.
|
||||
func (c *Client) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||
url := fmt.Sprintf("/api/internal/agent/tasks?agentId=%s", agentID)
|
||||
var resp TasksResponse
|
||||
if err := c.doGet(ctx, url, &resp); err != nil {
|
||||
return nil, fmt.Errorf("claim tasks: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ReportStatus reports download progress or completion for a task.
|
||||
// Deregister notifies the server that the agent is shutting down.
|
||||
func (c *Client) Deregister(ctx context.Context, agentID string) error {
|
||||
req := struct {
|
||||
|
|
@ -91,6 +70,16 @@ func (c *Client) BatchReportStatus(ctx context.Context, updates []StatusUpdate)
|
|||
return &resp, nil
|
||||
}
|
||||
|
||||
// Sync sends the CLI's full state and receives all pending server actions.
|
||||
// This is the single endpoint for bidirectional state synchronization.
|
||||
func (c *Client) Sync(ctx context.Context, req SyncRequest) (*SyncResponse, error) {
|
||||
var resp SyncResponse
|
||||
if err := c.doPost(ctx, "/api/internal/agent/sync", req, &resp); err != nil {
|
||||
return nil, fmt.Errorf("sync: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Usenet endpoints
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -72,70 +72,6 @@ func TestRegister(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHeartbeat(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/heartbeat" {
|
||||
t.Errorf("path = %s, want /api/internal/agent/heartbeat", r.URL.Path)
|
||||
}
|
||||
var req HeartbeatRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
if req.AgentID != "agent-123" {
|
||||
t.Errorf("agentId = %q, want agent-123", req.AgentID)
|
||||
}
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Error("expected success=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimTasks(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Errorf("method = %s, want GET", r.Method)
|
||||
}
|
||||
if r.URL.Query().Get("agentId") != "agent-123" {
|
||||
t.Errorf("agentId param = %q, want agent-123", r.URL.Query().Get("agentId"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(TasksResponse{
|
||||
Tasks: []Task{
|
||||
{
|
||||
ID: "task-uuid-1",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "The Matrix (1999)",
|
||||
PreferredMethod: "auto",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTasks failed: %v", err)
|
||||
}
|
||||
if len(resp.Tasks) != 1 {
|
||||
t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks))
|
||||
}
|
||||
if resp.Tasks[0].ID != "task-uuid-1" {
|
||||
t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID)
|
||||
}
|
||||
if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" {
|
||||
t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash)
|
||||
}
|
||||
if resp.Tasks[0].PreferredMethod != "auto" {
|
||||
t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportStatus(t *testing.T) {
|
||||
var received StatusUpdate
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -173,22 +109,6 @@ func TestReportStatus(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClaimTasksEmpty(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(TasksResponse{Tasks: []Task{}})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTasks failed: %v", err)
|
||||
}
|
||||
if len(resp.Tasks) != 0 {
|
||||
t.Errorf("expected empty tasks, got %d", len(resp.Tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
|
@ -279,50 +199,12 @@ func TestUserAgent(t *testing.T) {
|
|||
if r.Header.Get("User-Agent") != "unarr/0.2.0" {
|
||||
t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
json.NewEncoder(w).Encode(RegisterResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr/0.2.0")
|
||||
c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"})
|
||||
}
|
||||
|
||||
func TestHeartbeatWithUpgradeSignal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{
|
||||
Success: true,
|
||||
Upgrade: &UpgradeSignal{Version: "2.0.0"},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if resp.Upgrade == nil {
|
||||
t.Fatal("expected upgrade signal, got nil")
|
||||
}
|
||||
if resp.Upgrade.Version != "2.0.0" {
|
||||
t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if resp.Upgrade != nil {
|
||||
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
|
||||
}
|
||||
c.Register(context.Background(), RegisterRequest{AgentID: "x"})
|
||||
}
|
||||
|
||||
func TestDeregister(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -14,75 +14,62 @@ import (
|
|||
|
||||
// DaemonConfig holds daemon runtime settings.
|
||||
type DaemonConfig struct {
|
||||
AgentID string
|
||||
AgentName string
|
||||
Version string
|
||||
DownloadDir string
|
||||
PollInterval time.Duration
|
||||
HeartbeatInterval time.Duration
|
||||
StreamPort int // port for the HTTP stream server (reported in heartbeat)
|
||||
LanIP string // LAN IP (reported in heartbeat for stream URL resolution)
|
||||
TailscaleIP string // Tailscale IP (reported in heartbeat for stream URL resolution)
|
||||
AgentID string
|
||||
AgentName string
|
||||
Version string
|
||||
DownloadDir string
|
||||
StreamPort int // port for the HTTP stream server
|
||||
LanIP string // LAN IP (reported in sync for stream URL resolution)
|
||||
TailscaleIP string // Tailscale IP (reported in sync for stream URL resolution)
|
||||
}
|
||||
|
||||
// Daemon manages the main loop: register, heartbeat, poll tasks.
|
||||
// Daemon manages agent registration and the sync loop.
|
||||
type Daemon struct {
|
||||
cfg DaemonConfig
|
||||
transport Transport
|
||||
cfg DaemonConfig
|
||||
client *Client
|
||||
sync *SyncClient
|
||||
state *LocalState
|
||||
|
||||
// Callbacks
|
||||
// Callbacks — set by cmd/daemon.go before calling Run.
|
||||
OnTasksClaimed func(tasks []Task)
|
||||
OnStreamRequested func(req StreamRequest)
|
||||
OnControlAction func(action, taskID string)
|
||||
OnControlAction func(action, taskID string, deleteFiles bool)
|
||||
GetActiveCount func() int // returns number of active downloads (wired from manager)
|
||||
|
||||
// State
|
||||
User UserInfo
|
||||
Features FeatureFlags
|
||||
Info AgentInfo
|
||||
State DaemonState
|
||||
heartbeatFailures int
|
||||
lastNotifiedVersion string
|
||||
|
||||
// Callbacks for state tracking (set by cmd/daemon.go)
|
||||
GetActiveCount func() int
|
||||
GetCleanableBytes func() int64
|
||||
|
||||
// Watching tracks whether a user is viewing download progress in the web UI.
|
||||
// When false, the progress reporter skips detailed updates (only sends final states).
|
||||
// Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic.
|
||||
Watching atomic.Bool
|
||||
|
||||
// Exposed tickers for hot-reload
|
||||
PollTicker *time.Ticker
|
||||
HeartbeatTicker *time.Ticker
|
||||
|
||||
// pollNow triggers an immediate poll (e.g. on resume)
|
||||
pollNow chan struct{}
|
||||
|
||||
// ScanNow triggers an immediate library scan (from heartbeat or WebSocket control event)
|
||||
// ScanNow triggers an immediate library scan.
|
||||
ScanNow chan struct{}
|
||||
}
|
||||
|
||||
// NewDaemon creates a daemon with the given transport.
|
||||
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
|
||||
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
|
||||
if cfg.PollInterval == 0 {
|
||||
cfg.PollInterval = 30 * time.Second
|
||||
}
|
||||
if cfg.HeartbeatInterval == 0 {
|
||||
cfg.HeartbeatInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
// NewDaemon creates a daemon with an HTTP client for sync-based communication.
|
||||
func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
|
||||
state := NewLocalState()
|
||||
return &Daemon{
|
||||
cfg: cfg,
|
||||
transport: transport,
|
||||
pollNow: make(chan struct{}, 1),
|
||||
ScanNow: make(chan struct{}, 1),
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
state: state,
|
||||
sync: NewSyncClient(client, cfg, state),
|
||||
ScanNow: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the configured transport.
|
||||
func (d *Daemon) Transport() Transport { return d.transport }
|
||||
// SyncClient returns the sync client for external wiring.
|
||||
func (d *Daemon) SyncClient() *SyncClient { return d.sync }
|
||||
|
||||
// UpdateStreamPort updates the stream port reported in sync requests.
|
||||
func (d *Daemon) UpdateStreamPort(port int) {
|
||||
d.cfg.StreamPort = port
|
||||
d.sync.cfg.StreamPort = port
|
||||
}
|
||||
|
||||
// Register registers the agent and fetches user info + features.
|
||||
// Retries with exponential backoff on transient errors (429, 5xx, network).
|
||||
|
|
@ -109,11 +96,10 @@ func (d *Daemon) Register(ctx context.Context) error {
|
|||
var resp *RegisterResponse
|
||||
var err error
|
||||
for attempt := range maxRetries {
|
||||
resp, err = d.transport.Register(ctx, req)
|
||||
resp, err = d.client.Register(ctx, req)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
// Only retry on transient errors (429, 5xx, network failures)
|
||||
if !isTransientError(err) {
|
||||
return fmt.Errorf("register: %w", err)
|
||||
}
|
||||
|
|
@ -154,14 +140,9 @@ func (d *Daemon) Register(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Run connects the transport, registers the agent, and starts the main loop.
|
||||
// Blocks until ctx is cancelled. Callers must NOT call transport.Connect before Run.
|
||||
// Run registers the agent and starts the sync loop.
|
||||
// Blocks until ctx is cancelled.
|
||||
func (d *Daemon) Run(ctx context.Context) error {
|
||||
// Connect transport (establishes WebSocket if available, falls back to HTTP)
|
||||
if err := d.transport.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("connect transport: %w", err)
|
||||
}
|
||||
|
||||
// Register
|
||||
if err := d.Register(ctx); err != nil {
|
||||
return err
|
||||
|
|
@ -169,163 +150,61 @@ func (d *Daemon) Run(ctx context.Context) error {
|
|||
|
||||
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
|
||||
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
|
||||
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
|
||||
|
||||
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
|
||||
defer d.HeartbeatTicker.Stop()
|
||||
|
||||
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
|
||||
defer d.PollTicker.Stop()
|
||||
|
||||
heartbeatTicker := d.HeartbeatTicker
|
||||
pollTicker := d.PollTicker
|
||||
|
||||
// Initial poll immediately
|
||||
d.poll(ctx)
|
||||
|
||||
eventsCh := d.transport.Events()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("Daemon shutting down...")
|
||||
d.deregister()
|
||||
return nil
|
||||
|
||||
case event := <-eventsCh:
|
||||
d.handleEvent(event)
|
||||
|
||||
case <-heartbeatTicker.C:
|
||||
d.heartbeat(ctx)
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Only poll in HTTP mode — WS mode receives tasks via Events
|
||||
if d.transport.Mode() == "http" {
|
||||
d.poll(ctx)
|
||||
}
|
||||
|
||||
case <-d.pollNow:
|
||||
d.poll(ctx)
|
||||
// Wire sync callbacks
|
||||
d.sync.OnNewTasks = func(tasks []Task) {
|
||||
if d.OnTasksClaimed != nil {
|
||||
d.OnTasksClaimed(tasks)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) heartbeat(ctx context.Context) {
|
||||
req := HeartbeatRequest{
|
||||
AgentID: d.cfg.AgentID,
|
||||
Name: d.cfg.AgentName,
|
||||
Version: d.cfg.Version,
|
||||
OS: runtime.GOOS,
|
||||
DownloadDir: d.cfg.DownloadDir,
|
||||
StreamPort: d.cfg.StreamPort,
|
||||
LanIP: d.cfg.LanIP,
|
||||
TailscaleIP: d.cfg.TailscaleIP,
|
||||
}
|
||||
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
|
||||
req.DiskFreeBytes = free
|
||||
req.DiskTotalBytes = total
|
||||
}
|
||||
|
||||
resp, err := d.transport.SendHeartbeat(ctx, req)
|
||||
if err != nil {
|
||||
d.heartbeatFailures++
|
||||
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
|
||||
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
|
||||
} else {
|
||||
log.Printf("Heartbeat failed: %v", err)
|
||||
d.sync.OnControl = func(action, taskID string, deleteFiles bool) {
|
||||
if d.OnControlAction != nil {
|
||||
d.OnControlAction(action, taskID, deleteFiles)
|
||||
}
|
||||
return
|
||||
}
|
||||
if d.heartbeatFailures > 0 {
|
||||
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
|
||||
d.heartbeatFailures = 0
|
||||
d.sync.OnStreamRequest = func(req StreamRequest) {
|
||||
if d.OnStreamRequested != nil {
|
||||
d.OnStreamRequested(req)
|
||||
}
|
||||
}
|
||||
|
||||
// Update watching flag and state file
|
||||
d.Watching.Store(resp.Watching)
|
||||
d.State.LastHeartbeat = time.Now()
|
||||
if d.GetActiveCount != nil {
|
||||
d.State.ActiveTasks = d.GetActiveCount()
|
||||
d.sync.OnUpgrade = func(version string) {
|
||||
if version != d.lastNotifiedVersion {
|
||||
d.lastNotifiedVersion = version
|
||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", version)
|
||||
}
|
||||
}
|
||||
WriteState(&d.State)
|
||||
|
||||
// Trigger library scan if requested
|
||||
if resp.Scan {
|
||||
d.sync.OnScan = func() {
|
||||
log.Printf("Library scan requested by server")
|
||||
select {
|
||||
case d.ScanNow <- struct{}{}:
|
||||
default: // scan already pending
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Log once per version when server suggests an upgrade
|
||||
if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion {
|
||||
d.lastNotifiedVersion = resp.Upgrade.Version
|
||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version)
|
||||
d.sync.OnWatchingChange = func(watching bool) {
|
||||
d.Watching.Store(watching)
|
||||
}
|
||||
}
|
||||
|
||||
// handleEvent processes a server-initiated event from the WebSocket transport.
|
||||
func (d *Daemon) handleEvent(event ServerEvent) {
|
||||
switch event.Type {
|
||||
case "tasks":
|
||||
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
|
||||
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
|
||||
if d.OnTasksClaimed != nil {
|
||||
d.OnTasksClaimed(event.Tasks.Tasks)
|
||||
}
|
||||
d.sync.OnSyncSuccess = func() {
|
||||
d.State.LastHeartbeat = time.Now()
|
||||
if d.GetActiveCount != nil {
|
||||
d.State.ActiveTasks = d.GetActiveCount()
|
||||
}
|
||||
if event.Tasks != nil && d.OnStreamRequested != nil {
|
||||
for _, sr := range event.Tasks.StreamRequests {
|
||||
d.OnStreamRequested(sr)
|
||||
}
|
||||
}
|
||||
|
||||
case "upgrade":
|
||||
if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion {
|
||||
d.lastNotifiedVersion = event.Upgrade.Version
|
||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version)
|
||||
}
|
||||
|
||||
case "control":
|
||||
if event.Control != nil {
|
||||
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
|
||||
if event.Control.Action == "scan" {
|
||||
select {
|
||||
case d.ScanNow <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
if d.OnControlAction != nil {
|
||||
d.OnControlAction(event.Control.Action, event.Control.TaskID)
|
||||
}
|
||||
}
|
||||
|
||||
case "disconnected":
|
||||
log.Println("WebSocket disconnected, switching to HTTP polling")
|
||||
WriteState(&d.State)
|
||||
}
|
||||
|
||||
// Start sync loop (blocks)
|
||||
return d.sync.Run(ctx)
|
||||
}
|
||||
|
||||
// UpdateStreamPort updates the stream port reported in heartbeats.
|
||||
// Called after the persistent stream server binds (actual port may differ from configured).
|
||||
func (d *Daemon) UpdateStreamPort(port int) {
|
||||
d.cfg.StreamPort = port
|
||||
// TriggerSync requests an immediate sync cycle.
|
||||
func (d *Daemon) TriggerSync() {
|
||||
d.sync.TriggerSync()
|
||||
}
|
||||
|
||||
// TriggerPoll requests an immediate task poll cycle.
|
||||
// Used when a resume event is received to pick up re-pending tasks faster.
|
||||
func (d *Daemon) TriggerPoll() {
|
||||
select {
|
||||
case d.pollNow <- struct{}{}:
|
||||
default: // already pending
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) deregister() {
|
||||
// Deregister notifies the server of graceful shutdown.
|
||||
func (d *Daemon) Deregister() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := d.transport.Deregister(ctx, d.cfg.AgentID)
|
||||
if err != nil {
|
||||
if err := d.client.Deregister(ctx, d.cfg.AgentID); err != nil {
|
||||
log.Printf("Deregister failed: %v", err)
|
||||
} else {
|
||||
log.Println("Agent deregistered")
|
||||
|
|
@ -338,12 +217,10 @@ func isTransientError(err error) bool {
|
|||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Structured check: HTTPError carries the status code directly
|
||||
var httpErr *HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500
|
||||
}
|
||||
// Fallback: network-level errors (no HTTP response received)
|
||||
lower := strings.ToLower(err.Error())
|
||||
for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} {
|
||||
if strings.Contains(lower, keyword) {
|
||||
|
|
@ -352,27 +229,3 @@ func isTransientError(err error) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *Daemon) poll(ctx context.Context) {
|
||||
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
|
||||
if err != nil {
|
||||
log.Printf("Poll failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
d.Info.LastPollAt = time.Now()
|
||||
|
||||
if len(resp.Tasks) > 0 {
|
||||
log.Printf("Claimed %d task(s)", len(resp.Tasks))
|
||||
if d.OnTasksClaimed != nil {
|
||||
d.OnTasksClaimed(resp.Tasks)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle stream requests for completed downloads
|
||||
if d.OnStreamRequested != nil {
|
||||
for _, sr := range resp.StreamRequests {
|
||||
d.OnStreamRequested(sr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
195
internal/agent/sync.go
Normal file
195
internal/agent/sync.go
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// SyncIntervalWatching is the sync interval when someone is viewing the web UI.
|
||||
SyncIntervalWatching = 3 * time.Second
|
||||
// SyncIntervalIdle is the sync interval when nobody is watching.
|
||||
SyncIntervalIdle = 60 * time.Second
|
||||
)
|
||||
|
||||
// SyncClient handles bidirectional state synchronization between the CLI and server.
|
||||
// It sends the CLI's full execution state and receives all pending server actions
|
||||
// in a single HTTP round-trip, at an adaptive interval.
|
||||
type SyncClient struct {
|
||||
client *Client
|
||||
cfg DaemonConfig
|
||||
state *LocalState
|
||||
|
||||
// Callbacks — set by the daemon before calling Run.
|
||||
OnNewTasks func(tasks []Task)
|
||||
OnControl func(action, taskID string, deleteFiles bool)
|
||||
OnStreamRequest func(req StreamRequest)
|
||||
OnUpgrade func(version string)
|
||||
OnScan func()
|
||||
OnWatchingChange func(watching bool)
|
||||
OnSyncSuccess func() // called after each successful sync (e.g. to update state file)
|
||||
GetFreeSlots func() int
|
||||
GetTaskStates func() []TaskState // returns current state of all active + recently finished tasks
|
||||
|
||||
// SyncNow triggers an immediate sync (e.g., on task completion).
|
||||
SyncNow chan struct{}
|
||||
|
||||
watching atomic.Bool
|
||||
interval atomic.Int64 // stored as nanoseconds
|
||||
}
|
||||
|
||||
// NewSyncClient creates a sync client.
|
||||
func NewSyncClient(client *Client, cfg DaemonConfig, state *LocalState) *SyncClient {
|
||||
sc := &SyncClient{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
state: state,
|
||||
SyncNow: make(chan struct{}, 1),
|
||||
}
|
||||
sc.interval.Store(int64(SyncIntervalIdle))
|
||||
return sc
|
||||
}
|
||||
|
||||
// Watching returns whether someone is viewing the web UI.
|
||||
func (sc *SyncClient) Watching() bool {
|
||||
return sc.watching.Load()
|
||||
}
|
||||
|
||||
// TriggerSync requests an immediate sync cycle.
|
||||
func (sc *SyncClient) TriggerSync() {
|
||||
select {
|
||||
case sc.SyncNow <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the adaptive sync loop. Blocks until ctx is cancelled.
|
||||
func (sc *SyncClient) Run(ctx context.Context) error {
|
||||
// Initial sync immediately
|
||||
sc.doSync(ctx)
|
||||
|
||||
ticker := time.NewTicker(sc.currentInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Final sync to report latest state
|
||||
finalCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
sc.doSync(finalCtx)
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
sc.doSync(ctx)
|
||||
ticker.Reset(sc.currentInterval())
|
||||
|
||||
case <-sc.SyncNow:
|
||||
sc.doSync(ctx)
|
||||
ticker.Reset(sc.currentInterval())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *SyncClient) currentInterval() time.Duration {
|
||||
return time.Duration(sc.interval.Load())
|
||||
}
|
||||
|
||||
func (sc *SyncClient) doSync(ctx context.Context) {
|
||||
req := sc.buildRequest()
|
||||
resp, err := sc.client.Sync(ctx, req)
|
||||
if err != nil {
|
||||
if ctx.Err() == nil {
|
||||
log.Printf("sync failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
sc.processResponse(resp)
|
||||
sc.adjustInterval(resp.Watching)
|
||||
if sc.OnSyncSuccess != nil {
|
||||
sc.OnSyncSuccess()
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *SyncClient) buildRequest() SyncRequest {
|
||||
req := SyncRequest{
|
||||
AgentID: sc.cfg.AgentID,
|
||||
Name: sc.cfg.AgentName,
|
||||
Version: sc.cfg.Version,
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
DownloadDir: sc.cfg.DownloadDir,
|
||||
StreamPort: sc.cfg.StreamPort,
|
||||
LanIP: sc.cfg.LanIP,
|
||||
TailscaleIP: sc.cfg.TailscaleIP,
|
||||
}
|
||||
if sc.GetTaskStates != nil {
|
||||
req.Tasks = sc.GetTaskStates()
|
||||
} else {
|
||||
req.Tasks = sc.state.Snapshot()
|
||||
}
|
||||
if free, total, err := DiskInfo(sc.cfg.DownloadDir); err == nil {
|
||||
req.DiskFreeBytes = free
|
||||
req.DiskTotalBytes = total
|
||||
}
|
||||
if sc.GetFreeSlots != nil {
|
||||
req.FreeSlots = sc.GetFreeSlots()
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func (sc *SyncClient) processResponse(resp *SyncResponse) {
|
||||
// New tasks
|
||||
if len(resp.NewTasks) > 0 && sc.OnNewTasks != nil {
|
||||
log.Printf("sync: received %d new task(s)", len(resp.NewTasks))
|
||||
sc.OnNewTasks(resp.NewTasks)
|
||||
}
|
||||
|
||||
// Control signals
|
||||
for _, ctrl := range resp.Controls {
|
||||
log.Printf("sync: control %s on task %s", ctrl.Action, ShortID(ctrl.TaskID))
|
||||
if sc.OnControl != nil {
|
||||
sc.OnControl(ctrl.Action, ctrl.TaskID, ctrl.DeleteFiles)
|
||||
}
|
||||
}
|
||||
|
||||
// Stream requests
|
||||
for _, sr := range resp.StreamRequests {
|
||||
if sc.OnStreamRequest != nil {
|
||||
sc.OnStreamRequest(sr)
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade
|
||||
if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil {
|
||||
sc.OnUpgrade(resp.Upgrade.Version)
|
||||
}
|
||||
|
||||
// Scan
|
||||
if resp.Scan && sc.OnScan != nil {
|
||||
sc.OnScan()
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *SyncClient) adjustInterval(watching bool) {
|
||||
prev := sc.watching.Load()
|
||||
sc.watching.Store(watching)
|
||||
|
||||
var newInterval time.Duration
|
||||
if watching {
|
||||
newInterval = SyncIntervalWatching
|
||||
} else {
|
||||
newInterval = SyncIntervalIdle
|
||||
}
|
||||
|
||||
if sc.interval.Swap(int64(newInterval)) != int64(newInterval) {
|
||||
log.Printf("sync: interval=%s (watching=%v)", newInterval, watching)
|
||||
}
|
||||
|
||||
if prev != watching && sc.OnWatchingChange != nil {
|
||||
sc.OnWatchingChange(watching)
|
||||
}
|
||||
}
|
||||
362
internal/agent/sync_test.go
Normal file
362
internal/agent/sync_test.go
Normal file
|
|
@ -0,0 +1,362 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestSyncClient(url string) (*SyncClient, *Client) {
|
||||
client := NewClient(url, "test-key", "test-agent/1.0")
|
||||
cfg := DaemonConfig{
|
||||
AgentID: "test-agent",
|
||||
AgentName: "Test",
|
||||
Version: "1.0.0",
|
||||
DownloadDir: "/tmp/downloads",
|
||||
}
|
||||
state := NewLocalState()
|
||||
sc := NewSyncClient(client, cfg, state)
|
||||
return sc, client
|
||||
}
|
||||
|
||||
func TestSyncClient_NewDefaults(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
if sc.Watching() {
|
||||
t.Error("should not be watching initially")
|
||||
}
|
||||
if sc.currentInterval() != SyncIntervalIdle {
|
||||
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_AdjustInterval_Watching(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
sc.adjustInterval(true)
|
||||
|
||||
if sc.currentInterval() != SyncIntervalWatching {
|
||||
t.Errorf("expected watching interval %v, got %v", SyncIntervalWatching, sc.currentInterval())
|
||||
}
|
||||
if !sc.Watching() {
|
||||
t.Error("expected watching=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_AdjustInterval_NotWatching(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
// First set watching, then unset
|
||||
sc.adjustInterval(true)
|
||||
sc.adjustInterval(false)
|
||||
|
||||
if sc.currentInterval() != SyncIntervalIdle {
|
||||
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
|
||||
}
|
||||
if sc.Watching() {
|
||||
t.Error("expected watching=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_AdjustInterval_CallsOnWatchingChange(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var changes []bool
|
||||
sc.OnWatchingChange = func(w bool) { changes = append(changes, w) }
|
||||
|
||||
sc.adjustInterval(true)
|
||||
sc.adjustInterval(true) // no change
|
||||
sc.adjustInterval(false) // change
|
||||
|
||||
if len(changes) != 2 {
|
||||
t.Fatalf("expected 2 changes, got %d: %v", len(changes), changes)
|
||||
}
|
||||
if !changes[0] {
|
||||
t.Error("first change should be true")
|
||||
}
|
||||
if changes[1] {
|
||||
t.Error("second change should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_TriggerSync_NonBlocking(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
// Fill the channel
|
||||
sc.TriggerSync()
|
||||
// Should not block
|
||||
sc.TriggerSync()
|
||||
sc.TriggerSync()
|
||||
|
||||
// Drain
|
||||
select {
|
||||
case <-sc.SyncNow:
|
||||
default:
|
||||
t.Error("expected a sync trigger in channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_NewTasks(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var received []Task
|
||||
sc.OnNewTasks = func(tasks []Task) { received = tasks }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
NewTasks: []Task{
|
||||
{ID: "t1", Title: "Movie 1", InfoHash: "abc"},
|
||||
{ID: "t2", Title: "Movie 2", InfoHash: "def"},
|
||||
},
|
||||
})
|
||||
|
||||
if len(received) != 2 {
|
||||
t.Fatalf("expected 2 tasks, got %d", len(received))
|
||||
}
|
||||
if received[0].Title != "Movie 1" {
|
||||
t.Errorf("expected Movie 1, got %s", received[0].Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_NoTasks(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var called bool
|
||||
sc.OnNewTasks = func(tasks []Task) { called = true }
|
||||
|
||||
sc.processResponse(&SyncResponse{NewTasks: nil})
|
||||
|
||||
if called {
|
||||
t.Error("OnNewTasks should not be called with empty tasks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_Controls(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var actions []string
|
||||
var taskIDs []string
|
||||
sc.OnControl = func(action, taskID string, deleteFiles bool) {
|
||||
actions = append(actions, action)
|
||||
taskIDs = append(taskIDs, taskID)
|
||||
}
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
Controls: []ControlAction{
|
||||
{Action: "cancel", TaskID: "task-1234-5678"},
|
||||
{Action: "pause", TaskID: "task-abcd-efgh"},
|
||||
},
|
||||
})
|
||||
|
||||
if len(actions) != 2 {
|
||||
t.Fatalf("expected 2 controls, got %d", len(actions))
|
||||
}
|
||||
if actions[0] != "cancel" {
|
||||
t.Errorf("expected cancel, got %s", actions[0])
|
||||
}
|
||||
if actions[1] != "pause" {
|
||||
t.Errorf("expected pause, got %s", actions[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_Upgrade(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var version string
|
||||
sc.OnUpgrade = func(v string) { version = v }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
Upgrade: &UpgradeSignal{Version: "2.0.0"},
|
||||
})
|
||||
|
||||
if version != "2.0.0" {
|
||||
t.Errorf("expected 2.0.0, got %s", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_UpgradeEmpty(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var called bool
|
||||
sc.OnUpgrade = func(v string) { called = true }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
Upgrade: &UpgradeSignal{Version: ""},
|
||||
})
|
||||
|
||||
if called {
|
||||
t.Error("OnUpgrade should not be called with empty version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_Scan(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var called bool
|
||||
sc.OnScan = func() { called = true }
|
||||
|
||||
sc.processResponse(&SyncResponse{Scan: true})
|
||||
|
||||
if !called {
|
||||
t.Error("OnScan should have been called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_StreamRequests(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var received []StreamRequest
|
||||
sc.OnStreamRequest = func(sr StreamRequest) { received = append(received, sr) }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
StreamRequests: []StreamRequest{
|
||||
{TaskID: "t1", FilePath: "/tmp/movie.mkv"},
|
||||
},
|
||||
})
|
||||
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 stream request, got %d", len(received))
|
||||
}
|
||||
if received[0].FilePath != "/tmp/movie.mkv" {
|
||||
t.Errorf("expected /tmp/movie.mkv, got %s", received[0].FilePath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_BuildRequest_WithGetTaskStates(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
sc.GetTaskStates = func() []TaskState {
|
||||
return []TaskState{
|
||||
{TaskID: "t1", Status: "downloading", Progress: 50},
|
||||
}
|
||||
}
|
||||
sc.GetFreeSlots = func() int { return 2 }
|
||||
|
||||
req := sc.buildRequest()
|
||||
|
||||
if req.AgentID != "test-agent" {
|
||||
t.Errorf("expected test-agent, got %s", req.AgentID)
|
||||
}
|
||||
if len(req.Tasks) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(req.Tasks))
|
||||
}
|
||||
if req.Tasks[0].Progress != 50 {
|
||||
t.Errorf("expected progress 50, got %d", req.Tasks[0].Progress)
|
||||
}
|
||||
if req.FreeSlots != 2 {
|
||||
t.Errorf("expected 2 free slots, got %d", req.FreeSlots)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_BuildRequest_FallbackToState(t *testing.T) {
|
||||
client := NewClient("http://localhost", "key", "ua")
|
||||
state := NewLocalState()
|
||||
state.Update(TaskState{TaskID: "t1", Status: "completed", Progress: 100})
|
||||
|
||||
sc := NewSyncClient(client, DaemonConfig{AgentID: "a1", Version: "1.0"}, state)
|
||||
// GetTaskStates is nil — should fall back to state.Snapshot()
|
||||
|
||||
req := sc.buildRequest()
|
||||
if len(req.Tasks) != 1 {
|
||||
t.Fatalf("expected 1 task from state fallback, got %d", len(req.Tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_DoSync_Success(t *testing.T) {
|
||||
var syncCount atomic.Int32
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
syncCount.Add(1)
|
||||
json.NewEncoder(w).Encode(SyncResponse{
|
||||
Watching: true,
|
||||
NewTasks: []Task{{ID: "t1", Title: "Test Movie", InfoHash: "abc"}},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
|
||||
var tasksReceived []Task
|
||||
sc.OnNewTasks = func(tasks []Task) { tasksReceived = tasks }
|
||||
|
||||
sc.doSync(context.Background())
|
||||
|
||||
if syncCount.Load() != 1 {
|
||||
t.Errorf("expected 1 sync call, got %d", syncCount.Load())
|
||||
}
|
||||
if len(tasksReceived) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(tasksReceived))
|
||||
}
|
||||
if !sc.Watching() {
|
||||
t.Error("expected watching=true after sync")
|
||||
}
|
||||
if sc.currentInterval() != SyncIntervalWatching {
|
||||
t.Errorf("expected watching interval after sync")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_DoSync_Error(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
|
||||
// Should not panic on error
|
||||
sc.doSync(context.Background())
|
||||
}
|
||||
|
||||
func TestSyncClient_Run_CancelStopsLoop(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(SyncResponse{})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := sc.Run(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_Run_ImmediateSyncOnTrigger(t *testing.T) {
|
||||
var syncCount atomic.Int32
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
syncCount.Add(1)
|
||||
json.NewEncoder(w).Encode(SyncResponse{})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
// Set interval to something long so only triggers cause syncs
|
||||
sc.interval.Store(int64(10 * time.Second))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
// Wait for initial sync, then trigger 2 more
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
sc.TriggerSync()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
sc.TriggerSync()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
sc.Run(ctx)
|
||||
|
||||
// Initial sync (1) + 2 triggers + final sync = 4
|
||||
count := syncCount.Load()
|
||||
if count < 3 {
|
||||
t.Errorf("expected at least 3 syncs (initial + 2 triggers), got %d", count)
|
||||
}
|
||||
}
|
||||
136
internal/agent/taskstate.go
Normal file
136
internal/agent/taskstate.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/config"
|
||||
)
|
||||
|
||||
// TaskState represents the execution state of a single download task.
|
||||
// Written by the Task Engine, read by the Sync goroutine.
|
||||
type TaskState struct {
|
||||
TaskID string `json:"taskId"`
|
||||
Status string `json:"status"` // resolving, downloading, verifying, organizing, completed, failed
|
||||
Progress int `json:"progress"`
|
||||
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
||||
TotalBytes int64 `json:"totalBytes,omitempty"`
|
||||
SpeedBps int64 `json:"speedBps,omitempty"`
|
||||
ETA int `json:"eta,omitempty"`
|
||||
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
FilePath string `json:"filePath,omitempty"`
|
||||
StreamURL string `json:"streamUrl,omitempty"`
|
||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// LocalState holds the CLI's local execution state (tasks.json).
|
||||
// This is the CLI's source of truth for what it's doing right now.
|
||||
type LocalState struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*TaskState
|
||||
}
|
||||
|
||||
// NewLocalState creates an empty local state.
|
||||
func NewLocalState() *LocalState {
|
||||
return &LocalState{
|
||||
tasks: make(map[string]*TaskState),
|
||||
}
|
||||
}
|
||||
|
||||
// Update adds or updates a task in local state.
|
||||
func (s *LocalState) Update(ts TaskState) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
ts.UpdatedAt = time.Now().Unix()
|
||||
copied := ts
|
||||
s.tasks[ts.TaskID] = &copied
|
||||
}
|
||||
|
||||
// Remove removes a task from local state.
|
||||
func (s *LocalState) Remove(taskID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tasks, taskID)
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of all current task states.
|
||||
func (s *LocalState) Snapshot() []TaskState {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]TaskState, 0, len(s.tasks))
|
||||
for _, ts := range s.tasks {
|
||||
result = append(result, *ts)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// TaskStateFromUpdate converts a StatusUpdate into a TaskState.
|
||||
func TaskStateFromUpdate(u StatusUpdate) TaskState {
|
||||
return TaskState{
|
||||
TaskID: u.TaskID,
|
||||
Status: u.Status,
|
||||
Progress: u.Progress,
|
||||
DownloadedBytes: u.DownloadedBytes,
|
||||
TotalBytes: u.TotalBytes,
|
||||
SpeedBps: u.SpeedBps,
|
||||
ETA: u.ETA,
|
||||
ResolvedMethod: u.ResolvedMethod,
|
||||
FileName: u.FileName,
|
||||
FilePath: u.FilePath,
|
||||
StreamURL: u.StreamURL,
|
||||
ErrorMessage: u.ErrorMessage,
|
||||
}
|
||||
}
|
||||
|
||||
// ShortID returns the first 8 characters of an ID, or the full ID if shorter.
|
||||
func ShortID(id string) string {
|
||||
if len(id) > 8 {
|
||||
return id[:8]
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// taskStateFilePathFn is overridable for testing.
|
||||
var taskStateFilePathFn = func() string {
|
||||
return filepath.Join(config.DataDir(), "tasks.json")
|
||||
}
|
||||
|
||||
// WriteToDisk persists local state to disk atomically (best-effort).
|
||||
func (s *LocalState) WriteToDisk() {
|
||||
tasks := s.Snapshot()
|
||||
data, err := json.MarshalIndent(tasks, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
path := taskStateFilePathFn()
|
||||
dir := filepath.Dir(path)
|
||||
os.MkdirAll(dir, 0o755)
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||
return
|
||||
}
|
||||
os.Rename(tmp, path)
|
||||
}
|
||||
|
||||
// ReadFromDisk loads local state from disk. Returns empty state on error.
|
||||
func (s *LocalState) ReadFromDisk() {
|
||||
data, err := os.ReadFile(taskStateFilePathFn())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var tasks []TaskState
|
||||
if json.Unmarshal(data, &tasks) != nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tasks = make(map[string]*TaskState, len(tasks))
|
||||
for i := range tasks {
|
||||
s.tasks[tasks[i].TaskID] = &tasks[i]
|
||||
}
|
||||
}
|
||||
217
internal/agent/taskstate_test.go
Normal file
217
internal/agent/taskstate_test.go
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLocalState_UpdateAndSnapshot(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
|
||||
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100})
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 2 {
|
||||
t.Fatalf("expected 2 tasks, got %d", len(snap))
|
||||
}
|
||||
|
||||
byID := make(map[string]TaskState, len(snap))
|
||||
for _, ts := range snap {
|
||||
byID[ts.TaskID] = ts
|
||||
}
|
||||
|
||||
if byID["t1"].Progress != 50 {
|
||||
t.Errorf("expected progress 50, got %d", byID["t1"].Progress)
|
||||
}
|
||||
if byID["t2"].Status != "completed" {
|
||||
t.Errorf("expected completed, got %s", byID["t2"].Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_UpdateOverwrites(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 30})
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 70})
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(snap))
|
||||
}
|
||||
if snap[0].Progress != 70 {
|
||||
t.Errorf("expected progress 70, got %d", snap[0].Progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_Remove(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||
s.Update(TaskState{TaskID: "t2", Status: "downloading"})
|
||||
s.Remove("t1")
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(snap))
|
||||
}
|
||||
if snap[0].TaskID != "t2" {
|
||||
t.Errorf("expected t2, got %s", snap[0].TaskID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_RemoveNonExistent(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
s.Remove("nonexistent") // should not panic
|
||||
}
|
||||
|
||||
func TestLocalState_SnapshotIsACopy(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
|
||||
|
||||
snap := s.Snapshot()
|
||||
snap[0].Progress = 999
|
||||
|
||||
snap2 := s.Snapshot()
|
||||
if snap2[0].Progress != 50 {
|
||||
t.Errorf("snapshot mutation leaked: got progress %d", snap2[0].Progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_UpdateSetsTimestamp(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||
|
||||
snap := s.Snapshot()
|
||||
if snap[0].UpdatedAt == 0 {
|
||||
t.Error("expected non-zero UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_ConcurrentAccess(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
taskID := "t" + string(rune('0'+n%10))
|
||||
s.Update(TaskState{TaskID: taskID, Status: "downloading", Progress: n})
|
||||
s.Snapshot()
|
||||
if n%3 == 0 {
|
||||
s.Remove(taskID)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// No race condition = test passes
|
||||
}
|
||||
|
||||
func TestLocalState_WriteToDisk_ReadFromDisk(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tasks.json")
|
||||
|
||||
// Override the file path for testing
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return path }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 45})
|
||||
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100, FilePath: "/tmp/movie.mkv"})
|
||||
s.WriteToDisk()
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Fatal("tasks.json was not created")
|
||||
}
|
||||
|
||||
// Read into a new LocalState
|
||||
s2 := NewLocalState()
|
||||
s2.ReadFromDisk()
|
||||
|
||||
snap := s2.Snapshot()
|
||||
if len(snap) != 2 {
|
||||
t.Fatalf("expected 2 tasks after read, got %d", len(snap))
|
||||
}
|
||||
|
||||
byID := make(map[string]TaskState, len(snap))
|
||||
for _, ts := range snap {
|
||||
byID[ts.TaskID] = ts
|
||||
}
|
||||
|
||||
if byID["t1"].Progress != 45 {
|
||||
t.Errorf("expected progress 45, got %d", byID["t1"].Progress)
|
||||
}
|
||||
if byID["t2"].FilePath != "/tmp/movie.mkv" {
|
||||
t.Errorf("expected /tmp/movie.mkv, got %s", byID["t2"].FilePath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_ReadFromDisk_CorruptedFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tasks.json")
|
||||
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return path }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
// Write corrupted JSON
|
||||
os.WriteFile(path, []byte("{invalid json"), 0o644)
|
||||
|
||||
s := NewLocalState()
|
||||
s.ReadFromDisk() // should not panic
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 0 {
|
||||
t.Errorf("expected 0 tasks from corrupted file, got %d", len(snap))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_ReadFromDisk_FileNotFound(t *testing.T) {
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return "/nonexistent/path/tasks.json" }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
s := NewLocalState()
|
||||
s.ReadFromDisk() // should not panic
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 0 {
|
||||
t.Errorf("expected 0 tasks, got %d", len(snap))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_AtomicWrite(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tasks.json")
|
||||
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return path }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||
s.WriteToDisk()
|
||||
|
||||
// Verify no .tmp file remains
|
||||
tmpPath := path + ".tmp"
|
||||
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
|
||||
t.Error("temp file should not exist after write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_EmptySnapshot(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
snap := s.Snapshot()
|
||||
if snap == nil {
|
||||
t.Error("snapshot should be non-nil empty slice")
|
||||
}
|
||||
if len(snap) != 0 {
|
||||
t.Errorf("expected 0 tasks, got %d", len(snap))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// Transport abstracts the communication protocol between the agent and server.
|
||||
// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this.
|
||||
type Transport interface {
|
||||
// Connect establishes the transport connection.
|
||||
// Called internally by Daemon.Run — callers must NOT call Connect separately.
|
||||
Connect(ctx context.Context) error
|
||||
|
||||
// Close tears down the connection gracefully.
|
||||
Close() error
|
||||
|
||||
// Mode returns the current transport mode ("ws" or "http").
|
||||
Mode() string
|
||||
|
||||
// Register sends agent registration and returns user info + features.
|
||||
Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error)
|
||||
|
||||
// SendHeartbeat sends a periodic keep-alive.
|
||||
SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error)
|
||||
|
||||
// SendProgress reports download progress for a task.
|
||||
SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error)
|
||||
|
||||
// ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events).
|
||||
ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error)
|
||||
|
||||
// Deregister notifies the server of graceful shutdown.
|
||||
Deregister(ctx context.Context, agentID string) error
|
||||
|
||||
// Events returns a channel that emits server-initiated events.
|
||||
// In HTTP mode this channel is never written to (polling handles it).
|
||||
// In WS mode, tasks/upgrade/control arrive here.
|
||||
Events() <-chan ServerEvent
|
||||
}
|
||||
|
||||
// ServerEvent represents a server-initiated message received via WebSocket.
|
||||
type ServerEvent struct {
|
||||
Type string // "tasks", "upgrade", "control", "disconnected"
|
||||
Tasks *TasksResponse // populated when Type == "tasks"
|
||||
Upgrade *UpgradeSignal // populated when Type == "upgrade"
|
||||
Control *ControlAction // populated when Type == "control"
|
||||
}
|
||||
|
||||
// ControlAction represents a server push for task control.
|
||||
type ControlAction struct {
|
||||
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
|
||||
TaskID string `json:"taskId"`
|
||||
}
|
||||
|
|
@ -1,285 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestE2EFullLifecycle tests the full lifecycle:
|
||||
// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect
|
||||
func TestE2EFullLifecycle(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var receivedMessages []map[string]interface{}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var parsed map[string]interface{}
|
||||
json.Unmarshal(msg, &parsed)
|
||||
|
||||
mu.Lock()
|
||||
receivedMessages = append(receivedMessages, parsed)
|
||||
mu.Unlock()
|
||||
|
||||
msgType, _ := parsed["type"].(string)
|
||||
switch msgType {
|
||||
case "auth":
|
||||
conn.WriteJSON(wsRegisteredMessage{
|
||||
Type: "registered",
|
||||
User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
|
||||
Features: FeatureFlags{Torrent: true, Debrid: true},
|
||||
})
|
||||
|
||||
case "heartbeat":
|
||||
// No response in WS mode
|
||||
|
||||
case "progress":
|
||||
// Simulate server-side cancel after progress
|
||||
if progress, ok := parsed["progress"].(float64); ok && progress >= 50 {
|
||||
conn.WriteJSON(map[string]string{
|
||||
"type": "control",
|
||||
"action": "cancel",
|
||||
"taskId": parsed["taskId"].(string),
|
||||
})
|
||||
}
|
||||
|
||||
case "upgrade-result":
|
||||
// Acknowledged
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. Connect
|
||||
if err := tr.Connect(ctx); err != nil {
|
||||
t.Fatalf("Connect: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
// 2. Auth
|
||||
resp, err := tr.Register(ctx, RegisterRequest{
|
||||
AgentID: "e2e-agent",
|
||||
Name: "E2E Test Agent",
|
||||
Version: "1.0.0",
|
||||
OS: "linux",
|
||||
Arch: "amd64",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
if resp.User.Name != "E2E User" {
|
||||
t.Errorf("expected E2E User, got %s", resp.User.Name)
|
||||
}
|
||||
if !resp.Features.Debrid {
|
||||
t.Error("expected debrid feature")
|
||||
}
|
||||
|
||||
// 3. Send heartbeat
|
||||
_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
|
||||
AgentID: "e2e-agent",
|
||||
DiskFreeBytes: 1000000000,
|
||||
DiskTotalBytes: 5000000000,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendHeartbeat: %v", err)
|
||||
}
|
||||
|
||||
// 4. Send progress (50% → should trigger cancel control)
|
||||
_, err = tr.SendProgress(ctx, StatusUpdate{
|
||||
TaskID: "task-e2e-1",
|
||||
Status: "downloading",
|
||||
Progress: 50,
|
||||
DownloadedBytes: 500,
|
||||
TotalBytes: 1000,
|
||||
SpeedBps: 100,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendProgress: %v", err)
|
||||
}
|
||||
|
||||
// 5. Wait for control event (cancel)
|
||||
select {
|
||||
case event := <-tr.Events():
|
||||
if event.Type != "control" {
|
||||
t.Errorf("expected control event, got %s", event.Type)
|
||||
}
|
||||
if event.Control.Action != "cancel" {
|
||||
t.Errorf("expected cancel, got %s", event.Control.Action)
|
||||
}
|
||||
if event.Control.TaskID != "task-e2e-1" {
|
||||
t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for cancel control")
|
||||
}
|
||||
|
||||
// Verify server received all messages
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if len(receivedMessages) < 3 {
|
||||
t.Fatalf("expected at least 3 messages, got %d", len(receivedMessages))
|
||||
}
|
||||
|
||||
types := make([]string, len(receivedMessages))
|
||||
for i, m := range receivedMessages {
|
||||
types[i], _ = m["type"].(string)
|
||||
}
|
||||
|
||||
expected := []string{"auth", "heartbeat", "progress"}
|
||||
for _, exp := range expected {
|
||||
found := false
|
||||
for _, got := range types {
|
||||
if got == exp {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("missing message type %q in %v", exp, types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2EHybridFailover tests the full failover scenario:
|
||||
// WS connect → download → WS disconnect → switch to HTTP → continue working
|
||||
func TestE2EHybridFailover(t *testing.T) {
|
||||
connectionCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
connectionCount++
|
||||
connNum := connectionCount
|
||||
mu.Unlock()
|
||||
|
||||
// Read auth
|
||||
conn.ReadMessage()
|
||||
conn.WriteJSON(wsRegisteredMessage{
|
||||
Type: "registered",
|
||||
User: UserInfo{Name: "Failover User"},
|
||||
})
|
||||
|
||||
if connNum == 1 {
|
||||
// First connection: push tasks then disconnect after 200ms
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn.WriteJSON(wsTasksMessage{
|
||||
Type: "tasks",
|
||||
Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
|
||||
})
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
conn.Close()
|
||||
} else {
|
||||
// Second connection (after reconnect): push upgrade
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
conn.Close()
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||
|
||||
// HTTP mock for fallback
|
||||
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simple heartbeat response
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer httpSrv.Close()
|
||||
|
||||
httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
|
||||
h := NewHybridTransport(wsT, httpT)
|
||||
|
||||
ctx := context.Background()
|
||||
err := h.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Connect: %v", err)
|
||||
}
|
||||
defer h.Close()
|
||||
|
||||
// Should start in WS mode
|
||||
if h.Mode() != "ws" {
|
||||
t.Fatalf("expected ws mode, got %s", h.Mode())
|
||||
}
|
||||
|
||||
// Register via WS
|
||||
_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
|
||||
// Receive tasks via WS
|
||||
var tasksReceived bool
|
||||
var disconnected bool
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case event := <-h.Events():
|
||||
switch event.Type {
|
||||
case "tasks":
|
||||
tasksReceived = true
|
||||
if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
|
||||
t.Errorf("unexpected tasks: %+v", event.Tasks)
|
||||
}
|
||||
case "disconnected":
|
||||
disconnected = true
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
break
|
||||
}
|
||||
if disconnected {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !tasksReceived {
|
||||
t.Error("did not receive tasks before disconnect")
|
||||
}
|
||||
if !disconnected {
|
||||
t.Error("did not receive disconnect event")
|
||||
}
|
||||
|
||||
// Should now be in HTTP mode
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if h.Mode() != "http" {
|
||||
t.Errorf("expected http mode after disconnect, got %s", h.Mode())
|
||||
}
|
||||
|
||||
// Heartbeat should work via HTTP fallback
|
||||
hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
|
||||
if err != nil {
|
||||
t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
|
||||
}
|
||||
if !hbResp.Success {
|
||||
t.Error("expected heartbeat success")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// HTTPTransport wraps the existing Client to implement Transport.
|
||||
// This is a thin adapter — no behavioral changes from the current HTTP protocol.
|
||||
type HTTPTransport struct {
|
||||
client *Client
|
||||
events chan ServerEvent
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTP-based transport.
|
||||
func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport {
|
||||
return &HTTPTransport{
|
||||
client: NewClient(baseURL, apiKey, userAgent),
|
||||
events: make(chan ServerEvent, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) Connect(_ context.Context) error { return nil }
|
||||
func (t *HTTPTransport) Close() error { return nil }
|
||||
func (t *HTTPTransport) Mode() string { return "http" }
|
||||
func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events }
|
||||
|
||||
func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||
return t.client.Register(ctx, req)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
return t.client.Heartbeat(ctx, req)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||
return t.client.ReportStatus(ctx, update)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) BatchReportStatus(ctx context.Context, updates []StatusUpdate) (*BatchStatusResponse, error) {
|
||||
return t.client.BatchReportStatus(ctx, updates)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||
return t.client.ClaimTasks(ctx, agentID)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error {
|
||||
return t.client.Deregister(ctx, agentID)
|
||||
}
|
||||
|
||||
// Client returns the underlying HTTP client for direct use if needed.
|
||||
func (t *HTTPTransport) Client() *Client { return t.client }
|
||||
|
|
@ -1,214 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
|
||||
// Automatically reconnects WS in the background.
|
||||
type HybridTransport struct {
|
||||
ws *WSTransport
|
||||
http *HTTPTransport
|
||||
|
||||
mode atomic.Value // "ws" or "http"
|
||||
events chan ServerEvent
|
||||
|
||||
reconnectMu sync.Mutex
|
||||
reconnectRunning bool
|
||||
reconnectStop chan struct{}
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewHybridTransport creates a transport that prefers WS with HTTP fallback.
|
||||
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
|
||||
h := &HybridTransport{
|
||||
ws: ws,
|
||||
http: http,
|
||||
events: make(chan ServerEvent, 50),
|
||||
reconnectStop: make(chan struct{}),
|
||||
}
|
||||
h.mode.Store("http") // start in HTTP, upgrade to WS on Connect
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *HybridTransport) Mode() string { return h.mode.Load().(string) }
|
||||
func (h *HybridTransport) Events() <-chan ServerEvent { return h.events }
|
||||
|
||||
// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop.
|
||||
func (h *HybridTransport) Connect(ctx context.Context) error {
|
||||
// Try WebSocket first
|
||||
if err := h.ws.Connect(ctx); err != nil {
|
||||
log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err)
|
||||
h.mode.Store("http")
|
||||
h.startReconnectLoop()
|
||||
return h.http.Connect(ctx)
|
||||
}
|
||||
|
||||
h.mode.Store("ws")
|
||||
log.Println("[transport] Connected via WebSocket")
|
||||
|
||||
// Forward WS events to unified channel + watch for disconnection
|
||||
go h.forwardWSEvents()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down both transports and stops reconnection.
|
||||
func (h *HybridTransport) Close() error {
|
||||
h.closed.Store(true)
|
||||
select {
|
||||
case <-h.reconnectStop:
|
||||
default:
|
||||
close(h.reconnectStop)
|
||||
}
|
||||
_ = h.ws.Close()
|
||||
return h.http.Close()
|
||||
}
|
||||
|
||||
// Register delegates to the active transport.
|
||||
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
return h.ws.Register(ctx, req)
|
||||
}
|
||||
return h.http.Register(ctx, req)
|
||||
}
|
||||
|
||||
// SendHeartbeat delegates to the active transport.
|
||||
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
resp, err := h.ws.SendHeartbeat(ctx, req)
|
||||
if err != nil {
|
||||
// WS write failed — switch to HTTP
|
||||
h.switchToHTTP()
|
||||
return h.http.SendHeartbeat(ctx, req)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return h.http.SendHeartbeat(ctx, req)
|
||||
}
|
||||
|
||||
// SendProgress delegates to the active transport.
|
||||
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
resp, err := h.ws.SendProgress(ctx, update)
|
||||
if err != nil {
|
||||
h.switchToHTTP()
|
||||
return h.http.SendProgress(ctx, update)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return h.http.SendProgress(ctx, update)
|
||||
}
|
||||
|
||||
// ClaimTasks delegates to the active transport.
|
||||
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
|
||||
}
|
||||
return h.http.ClaimTasks(ctx, agentID)
|
||||
}
|
||||
|
||||
// Deregister delegates to the active transport.
|
||||
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
|
||||
if h.mode.Load() == "ws" {
|
||||
return h.ws.Deregister(ctx, agentID)
|
||||
}
|
||||
return h.http.Deregister(ctx, agentID)
|
||||
}
|
||||
|
||||
// ── Internal ─────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *HybridTransport) switchToHTTP() {
|
||||
if h.mode.Load() == "http" {
|
||||
return
|
||||
}
|
||||
log.Println("[transport] Switching to HTTP fallback")
|
||||
h.mode.Store("http")
|
||||
_ = h.ws.Close()
|
||||
h.startReconnectLoop()
|
||||
}
|
||||
|
||||
func (h *HybridTransport) forwardWSEvents() {
|
||||
for {
|
||||
select {
|
||||
case <-h.reconnectStop:
|
||||
return
|
||||
case event, ok := <-h.ws.Events():
|
||||
if !ok {
|
||||
return // channel closed
|
||||
}
|
||||
if event.Type == "disconnected" {
|
||||
h.switchToHTTP()
|
||||
select {
|
||||
case h.events <- event:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case h.events <- event:
|
||||
default:
|
||||
log.Printf("[transport] events channel full, dropping %s event", event.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HybridTransport) startReconnectLoop() {
|
||||
h.reconnectMu.Lock()
|
||||
defer h.reconnectMu.Unlock()
|
||||
if h.reconnectRunning {
|
||||
return
|
||||
}
|
||||
h.reconnectRunning = true
|
||||
go h.reconnectLoop()
|
||||
}
|
||||
|
||||
func (h *HybridTransport) reconnectLoop() {
|
||||
backoff := 5 * time.Second
|
||||
maxBackoff := 60 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.reconnectStop:
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
|
||||
if h.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
// Already on WS? (someone else reconnected)
|
||||
if h.mode.Load() == "ws" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
err := h.ws.Connect(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// WS reconnected — switch back
|
||||
log.Println("[transport] WebSocket reconnected")
|
||||
h.mode.Store("ws")
|
||||
|
||||
// Reset reconnect flag so loop can start again if WS drops
|
||||
h.reconnectMu.Lock()
|
||||
h.reconnectRunning = false
|
||||
h.reconnectMu.Unlock()
|
||||
|
||||
// Forward events from new WS connection
|
||||
go h.forwardWSEvents()
|
||||
return
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,395 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
|
||||
type WSTransport struct {
|
||||
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
|
||||
apiKey string
|
||||
agentID string
|
||||
userAgent string
|
||||
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
events chan ServerEvent
|
||||
closed atomic.Bool
|
||||
|
||||
// Cached auth response from the DO
|
||||
authResp *RegisterResponse
|
||||
authMu sync.Mutex
|
||||
authDone chan struct{}
|
||||
authDoneOnce sync.Once
|
||||
}
|
||||
|
||||
// NewWSTransport creates a WebSocket-based transport.
|
||||
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
|
||||
return &WSTransport{
|
||||
wsURL: wsURL,
|
||||
apiKey: apiKey,
|
||||
agentID: agentID,
|
||||
userAgent: userAgent,
|
||||
events: make(chan ServerEvent, 50),
|
||||
authDone: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WSTransport) Mode() string { return "ws" }
|
||||
func (t *WSTransport) Events() <-chan ServerEvent { return t.events }
|
||||
|
||||
// Connect dials the WebSocket server and starts the read loop.
|
||||
func (t *WSTransport) Connect(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
header := http.Header{}
|
||||
header.Set("User-Agent", t.userAgent)
|
||||
|
||||
// Append API key as query param for auth on WS upgrade
|
||||
wsURLWithKey := t.wsURL
|
||||
if t.apiKey != "" {
|
||||
sep := "?"
|
||||
if strings.Contains(wsURLWithKey, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
wsURLWithKey += sep + "key=" + t.apiKey
|
||||
}
|
||||
|
||||
conn, wsResp, err := dialer.DialContext(ctx, wsURLWithKey, header)
|
||||
if wsResp != nil && wsResp.Body != nil {
|
||||
defer wsResp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("ws dial: %w", err)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.conn = conn
|
||||
t.closed.Store(false)
|
||||
t.authDone = make(chan struct{})
|
||||
t.authDoneOnce = sync.Once{}
|
||||
t.mu.Unlock()
|
||||
|
||||
go t.readLoop(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close sends a close frame and shuts down the connection.
|
||||
func (t *WSTransport) Close() error {
|
||||
t.closed.Store(true)
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.conn != nil {
|
||||
_ = t.conn.WriteMessage(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
||||
)
|
||||
err := t.conn.Close()
|
||||
t.conn = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register sends auth message and waits for the registered response.
|
||||
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||
msg := wsAuthMessage{
|
||||
Type: "auth",
|
||||
APIKey: t.apiKey,
|
||||
AgentID: req.AgentID,
|
||||
Name: req.Name,
|
||||
OS: req.OS,
|
||||
Arch: req.Arch,
|
||||
Version: req.Version,
|
||||
DownloadDir: req.DownloadDir,
|
||||
DiskFreeBytes: req.DiskFreeBytes,
|
||||
DiskTotalBytes: req.DiskTotalBytes,
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, fmt.Errorf("ws auth send: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the auth response or context cancellation
|
||||
select {
|
||||
case <-t.authDone:
|
||||
t.authMu.Lock()
|
||||
resp := t.authResp
|
||||
t.authMu.Unlock()
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("ws auth: no response received")
|
||||
}
|
||||
return resp, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(15 * time.Second):
|
||||
return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
|
||||
}
|
||||
}
|
||||
|
||||
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
|
||||
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
msg := struct {
|
||||
Type string `json:"type"`
|
||||
Disk *struct {
|
||||
Free int64 `json:"free"`
|
||||
Total int64 `json:"total"`
|
||||
} `json:"disk,omitempty"`
|
||||
}{Type: "heartbeat"}
|
||||
|
||||
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
|
||||
msg.Disk = &struct {
|
||||
Free int64 `json:"free"`
|
||||
Total int64 `json:"total"`
|
||||
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
|
||||
return &HeartbeatResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// SendProgress sends a progress update. Control signals arrive async via Events().
|
||||
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||
msg := struct {
|
||||
Type string `json:"type"`
|
||||
TaskID string `json:"taskId"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Progress int `json:"progress,omitempty"`
|
||||
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
||||
TotalBytes int64 `json:"totalBytes,omitempty"`
|
||||
SpeedBps int64 `json:"speedBps,omitempty"`
|
||||
ETA int `json:"eta,omitempty"`
|
||||
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
FilePath string `json:"filePath,omitempty"`
|
||||
StreamURL string `json:"streamUrl,omitempty"`
|
||||
StreamReady bool `json:"streamReady,omitempty"`
|
||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||
}{
|
||||
Type: "progress",
|
||||
TaskID: update.TaskID,
|
||||
Status: update.Status,
|
||||
Progress: update.Progress,
|
||||
DownloadedBytes: update.DownloadedBytes,
|
||||
TotalBytes: update.TotalBytes,
|
||||
SpeedBps: update.SpeedBps,
|
||||
ETA: update.ETA,
|
||||
ResolvedMethod: update.ResolvedMethod,
|
||||
FileName: update.FileName,
|
||||
FilePath: update.FilePath,
|
||||
StreamURL: update.StreamURL,
|
||||
StreamReady: update.StreamReady,
|
||||
ErrorMessage: update.ErrorMessage,
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// In WS mode, control signals come via Events(), not in the progress response.
|
||||
return &StatusResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
|
||||
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
|
||||
return &TasksResponse{}, nil
|
||||
}
|
||||
|
||||
// Deregister is handled by WebSocket close (DO detects disconnection).
|
||||
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
|
||||
return t.Close()
|
||||
}
|
||||
|
||||
// ── Internal ─────────────────────────────────────────────────────────────────
|
||||
|
||||
func (t *WSTransport) send(msg any) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("ws: not connected")
|
||||
}
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return t.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (t *WSTransport) readLoop(conn *websocket.Conn) {
|
||||
// Cloudflare idle timeout is 100s. We send pings every 30s and expect
|
||||
// either a pong or a server message within 45s. If neither arrives,
|
||||
// the read deadline fires and we detect the zombie connection.
|
||||
const (
|
||||
pongWait = 45 * time.Second
|
||||
pingPeriod = 30 * time.Second
|
||||
)
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Ping ticker goroutine — stops when readLoop returns.
|
||||
pingDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.mu.Lock()
|
||||
if t.conn != nil {
|
||||
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
err := t.conn.WriteMessage(websocket.PingMessage, nil)
|
||||
_ = t.conn.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
case <-pingDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(pingDone)
|
||||
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if !t.closed.Load() {
|
||||
log.Printf("[ws] read error: %v", err)
|
||||
// Signal disconnection to the daemon
|
||||
select {
|
||||
case t.events <- ServerEvent{Type: "disconnected"}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Any message (text or pong) proves the connection is alive.
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(msg, &envelope); err != nil {
|
||||
log.Printf("[ws] invalid message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case "registered":
|
||||
var resp wsRegisteredMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
t.authMu.Lock()
|
||||
t.authResp = &RegisterResponse{
|
||||
Success: true,
|
||||
User: resp.User,
|
||||
Features: resp.Features,
|
||||
}
|
||||
t.authMu.Unlock()
|
||||
// Signal that auth is complete (sync.Once prevents double-close panic)
|
||||
t.authDoneOnce.Do(func() { close(t.authDone) })
|
||||
}
|
||||
|
||||
case "tasks":
|
||||
var resp wsTasksMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "tasks",
|
||||
Tasks: &TasksResponse{
|
||||
Tasks: resp.Tasks,
|
||||
StreamRequests: resp.StreamRequests,
|
||||
},
|
||||
}:
|
||||
default:
|
||||
log.Printf("[ws] events channel full, dropping tasks message")
|
||||
}
|
||||
}
|
||||
|
||||
case "upgrade":
|
||||
var resp wsUpgradeMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "upgrade",
|
||||
Upgrade: &UpgradeSignal{Version: resp.Version},
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case "control":
|
||||
var resp ControlAction
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "control",
|
||||
Control: &resp,
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case "error":
|
||||
var resp struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
log.Printf("[ws] server error: %s", resp.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── WS message types ─────────────────────────────────────────────────────────
|
||||
|
||||
type wsAuthMessage struct {
|
||||
Type string `json:"type"`
|
||||
APIKey string `json:"apiKey"`
|
||||
AgentID string `json:"agentId"`
|
||||
Name string `json:"name,omitempty"`
|
||||
OS string `json:"os,omitempty"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
DownloadDir string `json:"downloadDir,omitempty"`
|
||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||
}
|
||||
|
||||
type wsRegisteredMessage struct {
|
||||
Type string `json:"type"`
|
||||
User UserInfo `json:"user"`
|
||||
Features FeatureFlags `json:"features"`
|
||||
}
|
||||
|
||||
type wsTasksMessage struct {
|
||||
Type string `json:"type"`
|
||||
Tasks []Task `json:"tasks"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
}
|
||||
|
||||
type wsUpgradeMessage struct {
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
|
@ -50,20 +50,6 @@ type UsenetServerInfo struct {
|
|||
SSL bool `json:"ssl"`
|
||||
}
|
||||
|
||||
// HeartbeatRequest is sent every 30s to keep the agent alive.
|
||||
type HeartbeatRequest struct {
|
||||
AgentID string `json:"agentId"`
|
||||
Name string `json:"name,omitempty"`
|
||||
OS string `json:"os,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
DownloadDir string `json:"downloadDir,omitempty"`
|
||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||
StreamPort int `json:"streamPort,omitempty"`
|
||||
LanIP string `json:"lanIp,omitempty"`
|
||||
TailscaleIP string `json:"tailscaleIp,omitempty"`
|
||||
}
|
||||
|
||||
// Task represents a download task claimed from the server.
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
|
|
@ -88,12 +74,6 @@ type Task struct {
|
|||
CollectionName string `json:"collectionName,omitempty"` // Collection name (e.g., "Harry Potter Collection")
|
||||
}
|
||||
|
||||
// TasksResponse wraps the array of tasks returned by the server.
|
||||
type TasksResponse struct {
|
||||
Tasks []Task `json:"tasks"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
}
|
||||
|
||||
// StreamRequest is a request to stream a completed download from disk.
|
||||
type StreamRequest struct {
|
||||
TaskID string `json:"taskId"`
|
||||
|
|
@ -139,14 +119,6 @@ type BatchStatusResponse struct {
|
|||
Watching bool `json:"watching,omitempty"`
|
||||
}
|
||||
|
||||
// HeartbeatResponse is returned by the server on heartbeat.
|
||||
type HeartbeatResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
|
||||
Watching bool `json:"watching,omitempty"` // true when a user is viewing download progress in the web UI
|
||||
Scan bool `json:"scan,omitempty"` // true when user triggered a library scan from the web UI
|
||||
}
|
||||
|
||||
// UpgradeSignal tells the agent to upgrade to a specific version.
|
||||
type UpgradeSignal struct {
|
||||
Version string `json:"version"`
|
||||
|
|
@ -176,7 +148,6 @@ type AgentInfo struct {
|
|||
User UserInfo
|
||||
Features FeatureFlags
|
||||
StartedAt time.Time
|
||||
LastPollAt time.Time
|
||||
ActiveTasks int
|
||||
}
|
||||
|
||||
|
|
@ -334,6 +305,45 @@ type LibrarySyncResponse struct {
|
|||
Removed int `json:"removed"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sync types (unified CLI ↔ Server communication)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SyncRequest is sent by the CLI periodically to synchronize state with the server.
|
||||
// Contains the CLI's full execution state — the server responds with pending actions.
|
||||
type SyncRequest struct {
|
||||
AgentID string `json:"agentId"`
|
||||
Version string `json:"version,omitempty"`
|
||||
OS string `json:"os,omitempty"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
DownloadDir string `json:"downloadDir,omitempty"`
|
||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||
StreamPort int `json:"streamPort,omitempty"`
|
||||
LanIP string `json:"lanIp,omitempty"`
|
||||
TailscaleIP string `json:"tailscaleIp,omitempty"`
|
||||
FreeSlots int `json:"freeSlots"`
|
||||
Tasks []TaskState `json:"tasks"`
|
||||
}
|
||||
|
||||
// ControlAction represents a server-side control signal for a task.
|
||||
type ControlAction struct {
|
||||
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
|
||||
TaskID string `json:"taskId"`
|
||||
DeleteFiles bool `json:"deleteFiles,omitempty"`
|
||||
}
|
||||
|
||||
// SyncResponse is returned by the server with all pending actions for the CLI.
|
||||
type SyncResponse struct {
|
||||
NewTasks []Task `json:"newTasks,omitempty"`
|
||||
Controls []ControlAction `json:"controls,omitempty"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
Watching bool `json:"watching"`
|
||||
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
|
||||
Scan bool `json:"scan,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Watch progress types (used by stream tracking)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -311,21 +311,10 @@ func configConnection(cfg *config.Config) error {
|
|||
).Run()
|
||||
}
|
||||
|
||||
func configAdvanced(cfg *config.Config) error {
|
||||
return huh.NewForm(
|
||||
huh.NewGroup(
|
||||
huh.NewInput().
|
||||
Title("Poll interval").
|
||||
Description("How often to check for new tasks (e.g. 30s, 1m)").
|
||||
Value(&cfg.Daemon.PollInterval).
|
||||
Validate(validateDuration),
|
||||
huh.NewInput().
|
||||
Title("Heartbeat interval").
|
||||
Description("How often to send heartbeat to server (e.g. 30s, 1m)").
|
||||
Value(&cfg.Daemon.HeartbeatInterval).
|
||||
Validate(validateDuration),
|
||||
),
|
||||
).Run()
|
||||
func configAdvanced(_ *config.Config) error {
|
||||
// Sync intervals are adaptive (3s watching, 60s idle) — no user-facing config needed.
|
||||
fmt.Println("No advanced settings to configure. Sync intervals are automatic.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ── Validators ──────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
|
|
@ -27,13 +26,13 @@ func newStartCmd() *cobra.Command {
|
|||
Short: "Start the download daemon (foreground)",
|
||||
Long: `Start the unarr daemon in the foreground.
|
||||
|
||||
Registers with the server, receives download tasks via WebSocket (with
|
||||
HTTP fallback), and executes them using the configured download method.
|
||||
Registers with the server, receives download tasks via periodic sync,
|
||||
and executes them using the configured download method.
|
||||
Supports torrent, debrid, and usenet downloads concurrently.
|
||||
|
||||
The daemon sends periodic heartbeats and reports download progress back
|
||||
to the web dashboard. Press Ctrl+C to stop gracefully — active downloads
|
||||
get up to 30 seconds to finish.
|
||||
The daemon syncs state with the server every 3s when someone is viewing
|
||||
the web dashboard, or every 60s when idle. Press Ctrl+C to stop
|
||||
gracefully — active downloads get up to 30 seconds to finish.
|
||||
|
||||
Requires: API key, agent ID, and download directory (run 'unarr init' first).
|
||||
|
||||
|
|
@ -127,85 +126,59 @@ func runDaemonStart() error {
|
|||
bold.Println(" unarr Daemon")
|
||||
fmt.Println()
|
||||
|
||||
// Parse intervals
|
||||
pollInterval, _ := time.ParseDuration(cfg.Daemon.PollInterval)
|
||||
if pollInterval == 0 {
|
||||
pollInterval = 30 * time.Second
|
||||
}
|
||||
heartbeatInterval, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval)
|
||||
if heartbeatInterval == 0 {
|
||||
heartbeatInterval = 30 * time.Second
|
||||
}
|
||||
statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval)
|
||||
if statusInterval == 0 {
|
||||
statusInterval = 3 * time.Second
|
||||
}
|
||||
|
||||
userAgent := "unarr/" + Version
|
||||
|
||||
// Create daemon config
|
||||
daemonCfg := agent.DaemonConfig{
|
||||
AgentID: cfg.Agent.ID,
|
||||
AgentName: cfg.Agent.Name,
|
||||
Version: Version,
|
||||
DownloadDir: cfg.Download.Dir,
|
||||
PollInterval: pollInterval,
|
||||
HeartbeatInterval: heartbeatInterval,
|
||||
StreamPort: cfg.Download.StreamPort,
|
||||
LanIP: engine.LanIP(),
|
||||
TailscaleIP: engine.TailscaleIP(),
|
||||
AgentID: cfg.Agent.ID,
|
||||
AgentName: cfg.Agent.Name,
|
||||
Version: Version,
|
||||
DownloadDir: cfg.Download.Dir,
|
||||
StreamPort: cfg.Download.StreamPort,
|
||||
LanIP: engine.LanIP(),
|
||||
TailscaleIP: engine.TailscaleIP(),
|
||||
}
|
||||
|
||||
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
|
||||
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
||||
|
||||
wsURL := cfg.Auth.WSURL
|
||||
if wsURL == "" {
|
||||
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
|
||||
}
|
||||
|
||||
var transport agent.Transport
|
||||
if wsURL != "" {
|
||||
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
|
||||
transport = agent.NewHybridTransport(wsT, httpT)
|
||||
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
|
||||
} else {
|
||||
transport = httpT
|
||||
log.Println("Transport: HTTP only")
|
||||
}
|
||||
|
||||
// Create daemon — always uses Transport interface
|
||||
d := agent.NewDaemon(daemonCfg, transport)
|
||||
|
||||
// Create agent client for watch progress reporting
|
||||
// Create HTTP client — single communication channel
|
||||
agentClient := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
||||
log.Printf("Transport: HTTP sync → %s", cfg.Auth.APIURL)
|
||||
|
||||
// Create daemon
|
||||
d := agent.NewDaemon(daemonCfg, agentClient)
|
||||
|
||||
// Start SIGUSR1 reload watcher (unix only, no-op on Windows)
|
||||
startReloadWatcher(&ReloadableConfig{Daemon: d})
|
||||
|
||||
// Daemon-scoped context — cancelled on shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create progress reporter using transport
|
||||
reporter := engine.NewProgressReporterWithTransport(transport, statusInterval)
|
||||
reporter.SetWatchingFunc(func() bool { return d.Watching.Load() })
|
||||
reporter.SetWatchingChangedHandler(func(watching bool) { d.Watching.Store(watching) })
|
||||
|
||||
// Parse speed limits
|
||||
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
|
||||
maxUl, _ := config.ParseSpeed(cfg.Download.MaxUploadSpeed)
|
||||
|
||||
// Parse torrent timeouts from config (default: 0 = unlimited, like qBittorrent)
|
||||
// Parse torrent timeouts
|
||||
metaTimeout, _ := time.ParseDuration(cfg.Download.MetadataTimeout)
|
||||
stallTimeout, _ := time.ParseDuration(cfg.Download.StallTimeout)
|
||||
|
||||
// Create progress reporter — only used for stream tasks (handleStreamTask)
|
||||
// The sync goroutine handles all regular progress reporting.
|
||||
statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval)
|
||||
if statusInterval == 0 {
|
||||
statusInterval = 3 * time.Second
|
||||
}
|
||||
reporter := engine.NewProgressReporter(agentClient, statusInterval)
|
||||
reporter.SetWatchingFunc(func() bool { return d.Watching.Load() })
|
||||
|
||||
// Create torrent downloader
|
||||
torrentDl, err := engine.NewTorrentDownloader(engine.TorrentConfig{
|
||||
DataDir: cfg.Download.Dir,
|
||||
MetadataTimeout: metaTimeout, // 0 = unlimited (default)
|
||||
StallTimeout: stallTimeout, // 0 = unlimited (default)
|
||||
MaxTimeout: 0, // unlimited
|
||||
MetadataTimeout: metaTimeout,
|
||||
StallTimeout: stallTimeout,
|
||||
MaxTimeout: 0,
|
||||
MaxDownloadRate: maxDl,
|
||||
MaxUploadRate: maxUl,
|
||||
ListenPort: cfg.Download.ListenPort, // 0 = default 42069
|
||||
ListenPort: cfg.Download.ListenPort,
|
||||
SeedEnabled: false,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -223,7 +196,7 @@ func runDaemonStart() error {
|
|||
log.Printf("Speed limits: download=%s upload=%s", dlStr, ulStr)
|
||||
}
|
||||
|
||||
// Create debrid downloader (HTTPS-based, no provider interaction needed)
|
||||
// Create debrid downloader
|
||||
debridDl := engine.NewDebridDownloader()
|
||||
|
||||
// Create download manager
|
||||
|
|
@ -237,170 +210,53 @@ func runDaemonStart() error {
|
|||
TVShowsDir: cfg.Organize.TVShowsDir,
|
||||
OutputDir: cfg.Download.Dir,
|
||||
},
|
||||
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client()))
|
||||
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(agentClient))
|
||||
|
||||
// Create persistent stream server — lives for the entire daemon lifecycle.
|
||||
// One port, one server, swap files with SetFile(). No more port churn.
|
||||
// Create persistent stream server
|
||||
streamSrv := engine.NewStreamServer(cfg.Download.StreamPort)
|
||||
if err := streamSrv.Listen(ctx); err != nil {
|
||||
return fmt.Errorf("start stream server: %w", err)
|
||||
}
|
||||
// Update heartbeat with actual port (may differ if configured port was busy)
|
||||
d.UpdateStreamPort(streamSrv.Port())
|
||||
|
||||
// Wire state tracking
|
||||
// Wire sync client callbacks
|
||||
sc := d.SyncClient()
|
||||
sc.GetFreeSlots = manager.FreeSlots
|
||||
sc.GetTaskStates = manager.TaskStates
|
||||
d.GetActiveCount = manager.ActiveCount
|
||||
d.GetCleanableBytes = CleanableBytes
|
||||
|
||||
// Wire: server-side signals -> manager actions + stream tasks
|
||||
reporter.SetCancelHandler(func(taskID string) {
|
||||
manager.CancelTask(taskID)
|
||||
cancelStreamTask(taskID)
|
||||
})
|
||||
reporter.SetPauseHandler(func(taskID string) {
|
||||
manager.PauseTask(taskID)
|
||||
cancelStreamTask(taskID)
|
||||
})
|
||||
reporter.SetDeleteFilesHandler(func(taskID string) {
|
||||
manager.CancelAndDeleteFiles(taskID)
|
||||
cancelStreamTask(taskID)
|
||||
})
|
||||
|
||||
// Wire: stream requested on active download → set file on persistent server
|
||||
reporter.SetStreamRequestedHandler(func(taskID string) {
|
||||
task := manager.GetTask(taskID)
|
||||
if task == nil {
|
||||
log.Printf("[%s] stream requested but task not found in manager", taskID[:8])
|
||||
return
|
||||
}
|
||||
if task.GetStreamURL() != "" {
|
||||
return // already streaming
|
||||
}
|
||||
provider, err := torrentDl.GetStreamProvider(taskID)
|
||||
if err != nil {
|
||||
log.Printf("[%s] stream failed: %v", taskID[:8], err)
|
||||
return
|
||||
}
|
||||
cancelStreamContexts()
|
||||
streamSrv.SetFile(provider, taskID)
|
||||
task.SetStreamURL(streamSrv.URLsJSON())
|
||||
log.Printf("[%s] streaming active download: %s", taskID[:8], provider.FileName())
|
||||
|
||||
// Start watch progress reporter with cancellable context
|
||||
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
|
||||
streamRegistry.mu.Lock()
|
||||
streamRegistry.cancels["watch:"+taskID] = watchCancel
|
||||
streamRegistry.mu.Unlock()
|
||||
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx)
|
||||
})
|
||||
|
||||
// Wire: daemon claimed tasks -> manager
|
||||
// Trigger immediate sync when a download slot frees up
|
||||
manager.OnTaskDone = func() { d.TriggerSync() }
|
||||
|
||||
// Wire: sync receives new tasks → submit to manager or handle stream
|
||||
d.OnTasksClaimed = func(tasks []agent.Task) {
|
||||
for _, t := range tasks {
|
||||
if t.Mode == "stream" {
|
||||
// Skip if already streaming this task
|
||||
if isStreamingTask(t.ID) {
|
||||
continue
|
||||
}
|
||||
// Only 1 stream at a time: cancel existing stream goroutines + clear file
|
||||
cancelStreamContexts()
|
||||
streamSrv.ClearFile()
|
||||
// Reserve slot before spawning goroutine to prevent TOCTOU race.
|
||||
streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel ownership transferred to streamRegistry
|
||||
streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel stored in registry
|
||||
streamRegistry.mu.Lock()
|
||||
streamRegistry.cancels[t.ID] = streamCancel
|
||||
streamRegistry.mu.Unlock()
|
||||
go handleStreamTask(streamCtx, t, reporter, cfg, agentClient, streamSrv)
|
||||
} else if t.ForceStart || manager.HasCapacity() {
|
||||
manager.Submit(ctx, t)
|
||||
} else {
|
||||
log.Printf("[%s] skipped: no capacity (max %d)", t.ID[:8], cfg.Download.MaxConcurrent)
|
||||
manager.Submit(ctx, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wire: stream requests for completed downloads → set file on persistent server
|
||||
d.OnStreamRequested = func(sr agent.StreamRequest) {
|
||||
// Already serving this task — just notify server it's ready
|
||||
if streamSrv.CurrentTaskID() == sr.TaskID {
|
||||
go func() {
|
||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
StreamReady: true,
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream ready re-notify failed: %v", sr.TaskID[:8], err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
filePath := sr.FilePath
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
log.Printf("[%s] stream request: file not found: %s", sr.TaskID[:8], filePath)
|
||||
go func() {
|
||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
Status: "failed",
|
||||
ErrorMessage: fmt.Sprintf("file not found: %s", filePath),
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// If filePath is a directory, find the largest video file inside
|
||||
if info.IsDir() {
|
||||
found := engine.FindVideoFile(filePath)
|
||||
if found == "" {
|
||||
log.Printf("[%s] stream request: no video file in directory: %s", sr.TaskID[:8], filePath)
|
||||
go func() {
|
||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
Status: "failed",
|
||||
ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath),
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
filePath = found
|
||||
log.Printf("[%s] resolved directory to video file: %s", sr.TaskID[:8], filepath.Base(filePath))
|
||||
}
|
||||
|
||||
// Cancel any active stream goroutines and swap file on the persistent server
|
||||
cancelStreamContexts()
|
||||
streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID)
|
||||
|
||||
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(filePath), streamSrv.URL())
|
||||
|
||||
// Start watch progress reporter with a cancellable context
|
||||
// so it stops when the user switches to a different stream.
|
||||
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
|
||||
streamRegistry.mu.Lock()
|
||||
streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel
|
||||
streamRegistry.mu.Unlock()
|
||||
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx)
|
||||
|
||||
// Notify server that stream is ready (clears streamRequested flag)
|
||||
go func() {
|
||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
StreamReady: true,
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream ready report failed: %v", sr.TaskID[:8], err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wire: WS control actions (pause/cancel/stream pushed from server)
|
||||
d.OnControlAction = func(action, taskID string) {
|
||||
// Wire: sync receives control signals → act on manager
|
||||
d.OnControlAction = func(action, taskID string, deleteFiles bool) {
|
||||
switch action {
|
||||
case "cancel":
|
||||
manager.CancelTask(taskID)
|
||||
if deleteFiles {
|
||||
manager.CancelAndDeleteFiles(taskID)
|
||||
} else {
|
||||
manager.CancelTask(taskID)
|
||||
}
|
||||
cancelStreamTask(taskID)
|
||||
if streamSrv.CurrentTaskID() == taskID {
|
||||
streamSrv.ClearFile()
|
||||
|
|
@ -412,10 +268,9 @@ func runDaemonStart() error {
|
|||
streamSrv.ClearFile()
|
||||
}
|
||||
case "resume":
|
||||
log.Printf("[%s] resume requested via WebSocket, triggering poll", taskID[:8])
|
||||
d.TriggerPoll()
|
||||
log.Printf("[%s] resume requested, triggering sync", agent.ShortID(taskID))
|
||||
d.TriggerSync()
|
||||
case "stream":
|
||||
// Skip if already streaming this task
|
||||
if streamSrv.CurrentTaskID() == taskID {
|
||||
return
|
||||
}
|
||||
|
|
@ -425,13 +280,19 @@ func runDaemonStart() error {
|
|||
}
|
||||
provider, err := torrentDl.GetStreamProvider(taskID)
|
||||
if err != nil {
|
||||
log.Printf("[%s] stream failed: %v", taskID[:8], err)
|
||||
log.Printf("[%s] stream failed: %v", agent.ShortID(taskID), err)
|
||||
return
|
||||
}
|
||||
cancelStreamContexts()
|
||||
streamSrv.SetFile(provider, taskID)
|
||||
task.SetStreamURL(streamSrv.URLsJSON())
|
||||
log.Printf("[%s] streaming via WS: %s", taskID[:8], provider.FileName())
|
||||
log.Printf("[%s] streaming: %s", agent.ShortID(taskID), provider.FileName())
|
||||
|
||||
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118
|
||||
streamRegistry.mu.Lock()
|
||||
streamRegistry.cancels["watch:"+taskID] = watchCancel
|
||||
streamRegistry.mu.Unlock()
|
||||
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx)
|
||||
case "stop-stream":
|
||||
cancelStreamTask(taskID)
|
||||
if streamSrv.CurrentTaskID() == taskID {
|
||||
|
|
@ -440,19 +301,77 @@ func runDaemonStart() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Config hot-reload (SIGUSR1 on Unix, no-op on Windows)
|
||||
// Tickers are initialized inside d.Run(), so we pass the daemon
|
||||
// and the reload goroutine reads them when the signal arrives.
|
||||
startReloadWatcher(&ReloadableConfig{Daemon: d})
|
||||
// Wire: sync receives stream requests for completed downloads
|
||||
d.OnStreamRequested = func(sr agent.StreamRequest) {
|
||||
if streamSrv.CurrentTaskID() == sr.TaskID {
|
||||
// Already serving — notify server it's ready
|
||||
go func() {
|
||||
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
StreamReady: true,
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream ready re-notify failed: %v", agent.ShortID(sr.TaskID), err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// Signal handling
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
filePath := sr.FilePath
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
log.Printf("[%s] stream request: file not found: %s", agent.ShortID(sr.TaskID), filePath)
|
||||
go func() {
|
||||
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
Status: "failed",
|
||||
ErrorMessage: fmt.Sprintf("file not found: %s", filePath),
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// Start progress reporter in background
|
||||
go reporter.Run(ctx)
|
||||
if info.IsDir() {
|
||||
found := engine.FindVideoFile(filePath)
|
||||
if found == "" {
|
||||
log.Printf("[%s] stream request: no video file in directory: %s", agent.ShortID(sr.TaskID), filePath)
|
||||
go func() {
|
||||
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
Status: "failed",
|
||||
ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath),
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
filePath = found
|
||||
log.Printf("[%s] resolved directory to video file: %s", agent.ShortID(sr.TaskID), filepath.Base(filePath))
|
||||
}
|
||||
|
||||
// Periodic DHT node persistence (every 5 min) — protects against crash data loss
|
||||
cancelStreamContexts()
|
||||
streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID)
|
||||
log.Printf("[%s] streaming from disk: %s → %s", agent.ShortID(sr.TaskID), filepath.Base(filePath), streamSrv.URL())
|
||||
|
||||
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118
|
||||
streamRegistry.mu.Lock()
|
||||
streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel
|
||||
streamRegistry.mu.Unlock()
|
||||
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx)
|
||||
|
||||
go func() {
|
||||
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||
TaskID: sr.TaskID,
|
||||
StreamReady: true,
|
||||
}); err != nil {
|
||||
log.Printf("[%s] stream ready report failed: %v", agent.ShortID(sr.TaskID), err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Periodic DHT node persistence (every 5 min)
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
|
@ -466,8 +385,7 @@ func runDaemonStart() error {
|
|||
}
|
||||
}()
|
||||
|
||||
// Start auto-scan goroutine (daily library scan + sync)
|
||||
// Default scan_path to download dir so auto-scan works out of the box.
|
||||
// Start auto-scan goroutine
|
||||
scanPath := cfg.Library.ScanPath
|
||||
if scanPath == "" {
|
||||
scanPath = cfg.Download.Dir
|
||||
|
|
@ -484,7 +402,10 @@ func runDaemonStart() error {
|
|||
go runAutoScan(ctx, scanCfg, scanInterval, agentClient, d.ScanNow)
|
||||
}
|
||||
|
||||
// Start daemon (blocks)
|
||||
// Start reporter only for stream task handling
|
||||
go reporter.Run(ctx)
|
||||
|
||||
// Start daemon (blocks — runs sync loop)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
|
|
@ -493,6 +414,10 @@ func runDaemonStart() error {
|
|||
// Start idle guard for the persistent stream server
|
||||
go startIdleGuard(ctx, streamSrv)
|
||||
|
||||
// Signal handling
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Wait for signal or error
|
||||
select {
|
||||
case sig := <-sigCh:
|
||||
|
|
@ -506,6 +431,7 @@ func runDaemonStart() error {
|
|||
defer shutdownCancel()
|
||||
manager.Shutdown(shutdownCtx)
|
||||
|
||||
d.Deregister()
|
||||
fmt.Println(" Daemon stopped.")
|
||||
return nil
|
||||
|
||||
|
|
@ -517,41 +443,6 @@ func runDaemonStart() error {
|
|||
}
|
||||
}
|
||||
|
||||
// deriveWSURL derives a WebSocket URL from the API URL.
|
||||
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
|
||||
// Returns "" for localhost/dev environments where WS gateway isn't available.
|
||||
func deriveWSURL(apiURL, agentID string) string {
|
||||
if apiURL == "" || agentID == "" {
|
||||
return ""
|
||||
}
|
||||
// Parse domain from API URL
|
||||
domain := apiURL
|
||||
for _, prefix := range []string{"https://", "http://"} {
|
||||
if len(domain) > len(prefix) && domain[:len(prefix)] == prefix {
|
||||
domain = domain[len(prefix):]
|
||||
break
|
||||
}
|
||||
}
|
||||
// Strip trailing slash/path
|
||||
for i := 0; i < len(domain); i++ {
|
||||
if domain[i] == '/' {
|
||||
domain = domain[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
// Strip port if present
|
||||
if idx := strings.LastIndex(domain, ":"); idx > 0 {
|
||||
domain = domain[:idx]
|
||||
}
|
||||
|
||||
// Skip WS for localhost/dev — gateway only available in production
|
||||
if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "wss://unarr." + domain + "/ws/" + agentID
|
||||
}
|
||||
|
||||
func formatSpeedLog(bps int64) string {
|
||||
switch {
|
||||
case bps >= 1024*1024*1024:
|
||||
|
|
@ -569,11 +460,9 @@ func formatSpeedLog(bps int64) string {
|
|||
func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration, ac *agent.Client, scanNow <-chan struct{}) {
|
||||
log.Printf("[auto-scan] enabled: every %s, path: %s", interval, cfg.Library.ScanPath)
|
||||
|
||||
// Run first scan after a short delay (let daemon stabilize)
|
||||
select {
|
||||
case <-time.After(30 * time.Second):
|
||||
case <-scanNow:
|
||||
// Immediate scan requested before initial delay
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
|
@ -608,7 +497,6 @@ func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration,
|
|||
return
|
||||
}
|
||||
|
||||
// Sync to server
|
||||
items := library.BuildSyncItems(cache)
|
||||
if len(items) == 0 {
|
||||
log.Printf("[auto-scan] no items to sync")
|
||||
|
|
|
|||
|
|
@ -2,32 +2,6 @@ package cmd
|
|||
|
||||
import "testing"
|
||||
|
||||
func TestDeriveWSURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
apiURL string
|
||||
agentID string
|
||||
want string
|
||||
}{
|
||||
{"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"},
|
||||
{"http://localhost:3000", "a1", ""}, // localhost skipped
|
||||
{"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped
|
||||
{"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"},
|
||||
{"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"},
|
||||
{"", "agent-123", ""},
|
||||
{"https://torrentclaw.com", "", ""},
|
||||
{"", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) {
|
||||
got := deriveWSURL(tt.apiURL, tt.agentID)
|
||||
if got != tt.want {
|
||||
t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSpeedLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
bps int64
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/agent"
|
||||
"github.com/torrentclaw/unarr/internal/config"
|
||||
|
|
@ -19,7 +18,8 @@ type ReloadableConfig struct {
|
|||
}
|
||||
|
||||
// startReloadWatcher listens for SIGUSR1 and reloads config.
|
||||
// Only intervals are hot-reloadable (speeds require torrent client restart).
|
||||
// With the sync-based architecture, intervals are fixed (3s watching, 60s idle).
|
||||
// Hot-reload now mainly serves as a signal to re-read config for future settings.
|
||||
func startReloadWatcher(rc *ReloadableConfig) {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGUSR1)
|
||||
|
|
@ -28,24 +28,11 @@ func startReloadWatcher(rc *ReloadableConfig) {
|
|||
for range sigCh {
|
||||
log.Println("Received SIGUSR1, reloading config...")
|
||||
|
||||
cfg, err := config.Load("")
|
||||
_, err := config.Load("")
|
||||
if err != nil {
|
||||
log.Printf("Config reload failed: %v", err)
|
||||
continue
|
||||
}
|
||||
cfg.ApplyEnvOverrides()
|
||||
|
||||
// Update poll interval
|
||||
if d, _ := time.ParseDuration(cfg.Daemon.PollInterval); d > 0 && rc.Daemon.PollTicker != nil {
|
||||
rc.Daemon.PollTicker.Reset(d)
|
||||
log.Printf(" Poll interval: %s", d)
|
||||
}
|
||||
|
||||
// Update heartbeat interval
|
||||
if d, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval); d > 0 && rc.Daemon.HeartbeatTicker != nil {
|
||||
rc.Daemon.HeartbeatTicker.Reset(d)
|
||||
log.Printf(" Heartbeat interval: %s", d)
|
||||
}
|
||||
|
||||
log.Println("Config reloaded successfully")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
package cmd
|
||||
|
||||
// Version is the CLI version. Overridden by goreleaser ldflags at release time.
|
||||
var Version = "0.5.5"
|
||||
var Version = "0.5.6"
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ type Config struct {
|
|||
type AuthConfig struct {
|
||||
APIKey string `toml:"api_key"`
|
||||
APIURL string `toml:"api_url"`
|
||||
WSURL string `toml:"ws_url"` // optional, derived from api_url if empty
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
|
|
@ -54,9 +53,7 @@ type OrganizeConfig struct {
|
|||
}
|
||||
|
||||
type DaemonConfig struct {
|
||||
PollInterval string `toml:"poll_interval"`
|
||||
HeartbeatInterval string `toml:"heartbeat_interval"`
|
||||
StatusInterval string `toml:"status_interval"`
|
||||
StatusInterval string `toml:"status_interval"`
|
||||
}
|
||||
|
||||
type NotificationsConfig struct {
|
||||
|
|
@ -92,10 +89,7 @@ func Default() Config {
|
|||
Organize: OrganizeConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Daemon: DaemonConfig{
|
||||
PollInterval: "30s",
|
||||
HeartbeatInterval: "30s",
|
||||
},
|
||||
Daemon: DaemonConfig{},
|
||||
Notifications: NotificationsConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ func TestDefault(t *testing.T) {
|
|||
if cfg.General.Country != "US" {
|
||||
t.Errorf("default Country = %q, want US", cfg.General.Country)
|
||||
}
|
||||
if cfg.Daemon.HeartbeatInterval != "30s" {
|
||||
t.Errorf("default HeartbeatInterval = %q, want 30s", cfg.Daemon.HeartbeatInterval)
|
||||
if cfg.Daemon.StatusInterval != "" {
|
||||
t.Errorf("default StatusInterval = %q, want empty", cfg.Daemon.StatusInterval)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import (
|
|||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/agent"
|
||||
)
|
||||
|
||||
// httpClient is used for debrid HTTPS downloads with a reasonable header timeout.
|
||||
|
|
@ -19,13 +21,6 @@ var httpClient = &http.Client{
|
|||
},
|
||||
}
|
||||
|
||||
func shortID(id string) string {
|
||||
if len(id) > 8 {
|
||||
return id[:8]
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// DebridDownloader downloads files via HTTPS direct URLs resolved by the server.
|
||||
// The server handles all debrid provider interaction; this downloader only needs
|
||||
// a plain HTTPS URL to fetch.
|
||||
|
|
@ -129,7 +124,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
var serverSize int64
|
||||
if _, err := fmt.Sscanf(cr, "bytes */%d", &serverSize); err == nil && serverSize > 0 && existingSize != serverSize {
|
||||
// Local file size doesn't match server — re-download from scratch
|
||||
log.Printf("[%s] local size %s != server size %s, re-downloading", shortID(task.ID), formatBytes(existingSize), formatBytes(serverSize))
|
||||
log.Printf("[%s] local size %s != server size %s, re-downloading", agent.ShortID(task.ID), formatBytes(existingSize), formatBytes(serverSize))
|
||||
resp.Body.Close()
|
||||
req2, err := http.NewRequestWithContext(dlCtx, http.MethodGet, task.DirectURL, nil)
|
||||
if err != nil {
|
||||
|
|
@ -149,7 +144,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
break // continue to download loop
|
||||
}
|
||||
}
|
||||
log.Printf("[%s] file already complete: %s (%s)", shortID(task.ID), fileName, formatBytes(existingSize))
|
||||
log.Printf("[%s] file already complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(existingSize))
|
||||
return &Result{
|
||||
FilePath: destPath,
|
||||
FileName: fileName,
|
||||
|
|
@ -166,10 +161,10 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
var flags int
|
||||
if startOffset > 0 {
|
||||
flags = os.O_WRONLY | os.O_APPEND
|
||||
log.Printf("[%s] resuming debrid download at %s: %s", shortID(task.ID), formatBytes(startOffset), fileName)
|
||||
log.Printf("[%s] resuming debrid download at %s: %s", agent.ShortID(task.ID), formatBytes(startOffset), fileName)
|
||||
} else {
|
||||
flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
|
||||
log.Printf("[%s] starting debrid download: %s", shortID(task.ID), fileName)
|
||||
log.Printf("[%s] starting debrid download: %s", agent.ShortID(task.ID), fileName)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
|
||||
|
|
@ -223,7 +218,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
}
|
||||
|
||||
log.Printf("[%s] %d%% — %s/%s @ %s/s (debrid)",
|
||||
shortID(task.ID), pct,
|
||||
agent.ShortID(task.ID), pct,
|
||||
formatBytes(downloaded), formatBytes(totalBytes), formatBytes(speed))
|
||||
|
||||
p := Progress{
|
||||
|
|
@ -252,7 +247,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
}
|
||||
}
|
||||
|
||||
log.Printf("[%s] debrid download complete: %s (%s)", shortID(task.ID), fileName, formatBytes(downloaded))
|
||||
log.Printf("[%s] debrid download complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(downloaded))
|
||||
|
||||
return &Result{
|
||||
FilePath: destPath,
|
||||
|
|
@ -271,7 +266,7 @@ func (d *DebridDownloader) Pause(taskID string) error {
|
|||
|
||||
if ok {
|
||||
cancel()
|
||||
log.Printf("[%s] debrid download paused (file kept for resume)", shortID(taskID))
|
||||
log.Printf("[%s] debrid download paused (file kept for resume)", agent.ShortID(taskID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -285,7 +280,7 @@ func (d *DebridDownloader) Cancel(taskID string) error {
|
|||
|
||||
if ok {
|
||||
cancel()
|
||||
log.Printf("[%s] debrid download cancelled", shortID(taskID))
|
||||
log.Printf("[%s] debrid download cancelled", agent.ShortID(taskID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,15 @@ type Manager struct {
|
|||
|
||||
sem chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// OnTaskDone is called after a task completes or fails (slot freed).
|
||||
// Used by the daemon to trigger an immediate sync.
|
||||
OnTaskDone func()
|
||||
|
||||
// recentlyFinished holds tasks that completed/failed since the last sync read.
|
||||
// The sync goroutine reads and clears this to include final states in the next sync.
|
||||
recentMu sync.Mutex
|
||||
recentFinished []agent.TaskState
|
||||
}
|
||||
|
||||
// NewManager creates a download manager.
|
||||
|
|
@ -67,7 +76,7 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) {
|
|||
|
||||
// Force start: bypass semaphore (like Transmission's "Force Start")
|
||||
if at.ForceStart {
|
||||
log.Printf("[%s] force start: bypassing queue", task.ID[:8])
|
||||
log.Printf("[%s] force start: bypassing queue", agent.ShortID(task.ID))
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
|
|
@ -88,7 +97,12 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) {
|
|||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
defer func() { <-m.sem }()
|
||||
defer func() {
|
||||
<-m.sem
|
||||
if m.OnTaskDone != nil {
|
||||
m.OnTaskDone()
|
||||
}
|
||||
}()
|
||||
defer taskCancel()
|
||||
m.processTask(taskCtx, task)
|
||||
}()
|
||||
|
|
@ -99,6 +113,11 @@ func (m *Manager) HasCapacity() bool {
|
|||
return len(m.sem) < cap(m.sem)
|
||||
}
|
||||
|
||||
// FreeSlots returns the number of available download slots.
|
||||
func (m *Manager) FreeSlots() int {
|
||||
return cap(m.sem) - len(m.sem)
|
||||
}
|
||||
|
||||
// ActiveCount returns the number of in-progress downloads.
|
||||
func (m *Manager) ActiveCount() int {
|
||||
m.activeMu.RLock()
|
||||
|
|
@ -113,6 +132,17 @@ func (m *Manager) GetTask(taskID string) *Task {
|
|||
return m.active[taskID]
|
||||
}
|
||||
|
||||
// ActiveTaskIDs returns the IDs of all in-progress tasks.
|
||||
func (m *Manager) ActiveTaskIDs() []string {
|
||||
m.activeMu.RLock()
|
||||
defer m.activeMu.RUnlock()
|
||||
ids := make([]string, 0, len(m.active))
|
||||
for id := range m.active {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// ActiveTasks returns a snapshot of all active tasks.
|
||||
func (m *Manager) ActiveTasks() []*Task {
|
||||
m.activeMu.RLock()
|
||||
|
|
@ -124,6 +154,37 @@ func (m *Manager) ActiveTasks() []*Task {
|
|||
return tasks
|
||||
}
|
||||
|
||||
// TaskStates returns the current state of all active tasks plus any recently
|
||||
// finished tasks that haven't been synced yet. Called by the sync goroutine.
|
||||
func (m *Manager) TaskStates() []agent.TaskState {
|
||||
// Collect active tasks
|
||||
m.activeMu.RLock()
|
||||
states := make([]agent.TaskState, 0, len(m.active))
|
||||
for _, t := range m.active {
|
||||
states = append(states, agent.TaskStateFromUpdate(t.ToStatusUpdate()))
|
||||
}
|
||||
m.activeMu.RUnlock()
|
||||
|
||||
// Drain recently finished tasks (consumed once per sync)
|
||||
m.recentMu.Lock()
|
||||
states = append(states, m.recentFinished...)
|
||||
m.recentFinished = nil
|
||||
m.recentMu.Unlock()
|
||||
|
||||
return states
|
||||
}
|
||||
|
||||
// recordFinished stores a completed/failed task for the next sync cycle.
|
||||
func (m *Manager) recordFinished(update agent.StatusUpdate) {
|
||||
m.recentMu.Lock()
|
||||
defer m.recentMu.Unlock()
|
||||
m.recentFinished = append(m.recentFinished, agent.TaskStateFromUpdate(update))
|
||||
// Keep bounded
|
||||
if len(m.recentFinished) > 20 {
|
||||
m.recentFinished = m.recentFinished[len(m.recentFinished)-20:]
|
||||
}
|
||||
}
|
||||
|
||||
// CancelTask cancels an active download by task ID (keeps partial files).
|
||||
func (m *Manager) CancelTask(taskID string) {
|
||||
m.activeMu.RLock()
|
||||
|
|
@ -150,7 +211,7 @@ func (m *Manager) CancelTask(taskID string) {
|
|||
task.mu.Unlock()
|
||||
task.Transition(StatusCancelled)
|
||||
|
||||
log.Printf("[%s] cancelled: %s", taskID[:8], task.Title)
|
||||
log.Printf("[%s] cancelled: %s", agent.ShortID(taskID), task.Title)
|
||||
}
|
||||
|
||||
// PauseTask pauses an active download (keeps partial files for resume).
|
||||
|
|
@ -173,7 +234,7 @@ func (m *Manager) PauseTask(taskID string) {
|
|||
}
|
||||
|
||||
task.Transition(StatusCancelled) // will be re-created as pending by server
|
||||
log.Printf("[%s] paused: %s", taskID[:8], task.Title)
|
||||
log.Printf("[%s] paused: %s", agent.ShortID(taskID), task.Title)
|
||||
}
|
||||
|
||||
// CancelAndDeleteFiles cancels a download and removes its files from disk.
|
||||
|
|
@ -200,7 +261,7 @@ func (m *Manager) CancelAndDeleteFiles(taskID string) {
|
|||
task.mu.Unlock()
|
||||
task.Transition(StatusCancelled)
|
||||
|
||||
log.Printf("[%s] cancelled + files deleted: %s", taskID[:8], task.Title)
|
||||
log.Printf("[%s] cancelled + files deleted: %s", agent.ShortID(taskID), task.Title)
|
||||
}
|
||||
|
||||
// Wait blocks until all active downloads finish.
|
||||
|
|
@ -261,7 +322,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
|
|||
}
|
||||
|
||||
task.ResolvedMethod = method
|
||||
log.Printf("[%s] resolved method: %s", task.ID[:8], method)
|
||||
log.Printf("[%s] resolved method: %s", agent.ShortID(task.ID), method)
|
||||
|
||||
// 2. Download
|
||||
if err := task.Transition(StatusDownloading); err != nil {
|
||||
|
|
@ -285,7 +346,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
|
|||
if err != nil {
|
||||
// Try fallback
|
||||
if tryFallback(task, m.downloaders) {
|
||||
log.Printf("[%s] %s failed, trying fallback: %v", task.ID[:8], method, err)
|
||||
log.Printf("[%s] %s failed, trying fallback: %v", agent.ShortID(task.ID), method, err)
|
||||
if err := task.Transition(StatusResolving); err == nil {
|
||||
m.processTaskRetry(ctx, task)
|
||||
return
|
||||
|
|
@ -295,61 +356,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
|
|||
return
|
||||
}
|
||||
|
||||
// 3. Verify
|
||||
if err := task.Transition(StatusVerifying); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := verify(result); err != nil {
|
||||
m.fail(ctx, task, "verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Organize
|
||||
if err := task.Transition(StatusOrganizing); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
finalPath, err := organize(result, task, m.cfg.Organize)
|
||||
if err != nil {
|
||||
log.Printf("[%s] organize warning: %v (keeping in download dir)", task.ID[:8], err)
|
||||
finalPath = result.FilePath
|
||||
}
|
||||
|
||||
task.mu.Lock()
|
||||
task.FilePath = finalPath
|
||||
task.mu.Unlock()
|
||||
|
||||
// 4b. Handle upgrade replacement (mode = "upgrade")
|
||||
if task.ReplacePath != "" {
|
||||
backupDir := "" // uses default ~/.local/share/unarr/replaced/
|
||||
if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil {
|
||||
log.Printf("[%s] replace warning: %v (keeping new file at %s)", task.ID[:8], err, finalPath)
|
||||
} else {
|
||||
task.mu.Lock()
|
||||
task.FilePath = task.ReplacePath
|
||||
task.mu.Unlock()
|
||||
log.Printf("[%s] upgraded: replaced %s", task.ID[:8], task.ReplacePath)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Complete
|
||||
if method == MethodTorrent && m.cfg.Organize.Enabled {
|
||||
// Could add seeding here in the future
|
||||
}
|
||||
|
||||
if err := task.Transition(StatusCompleted); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[%s] completed: %s -> %s", task.ID[:8], task.Title, finalPath)
|
||||
if m.cfg.Notifications {
|
||||
desktopNotify("Download complete", task.Title)
|
||||
}
|
||||
m.reporter.ReportFinal(ctx, task)
|
||||
m.finalize(ctx, task, result)
|
||||
}
|
||||
|
||||
// processTaskRetry handles fallback after a method failure.
|
||||
|
|
@ -361,7 +368,7 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
|
|||
}
|
||||
|
||||
task.ResolvedMethod = method
|
||||
log.Printf("[%s] fallback to: %s", task.ID[:8], method)
|
||||
log.Printf("[%s] fallback to: %s", agent.ShortID(task.ID), method)
|
||||
|
||||
if err := task.Transition(StatusDownloading); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
|
|
@ -383,15 +390,31 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
|
|||
return
|
||||
}
|
||||
|
||||
// Verify + Organize + Complete (same as processTask)
|
||||
task.Transition(StatusVerifying)
|
||||
m.finalize(ctx, task, result)
|
||||
}
|
||||
|
||||
// finalize runs verify → organize → upgrade replacement → complete for a downloaded task.
|
||||
func (m *Manager) finalize(ctx context.Context, task *Task, result *Result) {
|
||||
// Verify
|
||||
if err := task.Transition(StatusVerifying); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
if err := verify(result); err != nil {
|
||||
m.fail(ctx, task, "verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
task.Transition(StatusOrganizing)
|
||||
finalPath, _ := organize(result, task, m.cfg.Organize)
|
||||
// Organize
|
||||
if err := task.Transition(StatusOrganizing); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
finalPath, err := organize(result, task, m.cfg.Organize)
|
||||
if err != nil {
|
||||
log.Printf("[%s] organize warning: %v (keeping in download dir)", agent.ShortID(task.ID), err)
|
||||
finalPath = result.FilePath
|
||||
}
|
||||
if finalPath == "" {
|
||||
finalPath = result.FilePath
|
||||
}
|
||||
|
|
@ -399,8 +422,29 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
|
|||
task.FilePath = finalPath
|
||||
task.mu.Unlock()
|
||||
|
||||
task.Transition(StatusCompleted)
|
||||
log.Printf("[%s] completed (fallback): %s -> %s", task.ID[:8], task.Title, finalPath)
|
||||
// Handle upgrade replacement (mode = "upgrade")
|
||||
if task.ReplacePath != "" {
|
||||
backupDir := "" // uses default ~/.local/share/unarr/replaced/
|
||||
if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil {
|
||||
log.Printf("[%s] replace warning: %v (keeping new file at %s)", agent.ShortID(task.ID), err, finalPath)
|
||||
} else {
|
||||
task.mu.Lock()
|
||||
task.FilePath = task.ReplacePath
|
||||
task.mu.Unlock()
|
||||
log.Printf("[%s] upgraded: replaced %s", agent.ShortID(task.ID), task.ReplacePath)
|
||||
}
|
||||
}
|
||||
|
||||
// Complete
|
||||
if err := task.Transition(StatusCompleted); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
log.Printf("[%s] completed: %s -> %s", agent.ShortID(task.ID), task.Title, finalPath)
|
||||
if m.cfg.Notifications {
|
||||
desktopNotify("Download complete", task.Title)
|
||||
}
|
||||
m.recordFinished(task.ToStatusUpdate())
|
||||
m.reporter.ReportFinal(ctx, task)
|
||||
}
|
||||
|
||||
|
|
@ -409,9 +453,10 @@ func (m *Manager) fail(ctx context.Context, task *Task, msg string) {
|
|||
task.ErrorMessage = msg
|
||||
task.mu.Unlock()
|
||||
task.Transition(StatusFailed)
|
||||
log.Printf("[%s] FAILED: %s — %s", task.ID[:8], task.Title, msg)
|
||||
log.Printf("[%s] FAILED: %s — %s", agent.ShortID(task.ID), task.Title, msg)
|
||||
if m.cfg.Notifications {
|
||||
desktopNotify("Download failed", task.Title+": "+msg)
|
||||
}
|
||||
m.recordFinished(task.ToStatusUpdate())
|
||||
m.reporter.ReportFinal(ctx, task)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,13 +13,11 @@ import (
|
|||
type ActionFunc func(taskID string)
|
||||
|
||||
// StatusReporter is the interface used by ProgressReporter to send progress updates.
|
||||
// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods.
|
||||
type StatusReporter interface {
|
||||
ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error)
|
||||
}
|
||||
|
||||
// BatchStatusReporter extends StatusReporter with batch support.
|
||||
// Transports that implement this send all updates in a single request.
|
||||
type BatchStatusReporter interface {
|
||||
StatusReporter
|
||||
BatchReportStatus(ctx context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error)
|
||||
|
|
@ -48,7 +46,6 @@ type ProgressReporter struct {
|
|||
}
|
||||
|
||||
// NewProgressReporter creates a reporter that flushes every interval.
|
||||
// Accepts *agent.Client directly (backwards compatible).
|
||||
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
|
||||
return &ProgressReporter{
|
||||
reporter: ac,
|
||||
|
|
@ -58,25 +55,6 @@ func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressRepo
|
|||
}
|
||||
}
|
||||
|
||||
// NewProgressReporterWithTransport creates a reporter using a Transport.
|
||||
func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter {
|
||||
return &ProgressReporter{
|
||||
reporter: &transportStatusAdapter{t: t},
|
||||
interval: interval,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
}
|
||||
|
||||
// transportStatusAdapter adapts agent.Transport to StatusReporter.
|
||||
type transportStatusAdapter struct {
|
||||
t agent.Transport
|
||||
}
|
||||
|
||||
func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
|
||||
return a.t.SendProgress(ctx, update)
|
||||
}
|
||||
|
||||
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
|
||||
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue