feat(sync): replace WS+DO transport with unified HTTP sync
Replace the WebSocket + Cloudflare Durable Object architecture with a single POST /sync endpoint. The CLI now operates autonomously with local state (tasks.json) and syncs bidirectionally via adaptive-interval HTTP polling (3s watching, 60s idle). - Remove transport_ws, transport_hybrid, transport_http (~2,600 lines) - Add SyncClient with adaptive interval loop - Add LocalState for CLI-side task persistence - Add TaskStateFromUpdate() helper (DRY) - Extract finalize() to deduplicate processTask/processTaskRetry - Consolidate shortID() into agent.ShortID (was in 3 packages) - Wire GetActiveCount so `unarr status` shows active tasks - Remove poll_interval, heartbeat_interval, ws_url from config - Simplify ProgressReporter (sync replaces direct HTTP reporting)
This commit is contained in:
parent
2398707cc1
commit
5d4a67c7a2
26 changed files with 1320 additions and 3400 deletions
11
CHANGELOG.md
11
CHANGELOG.md
|
|
@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
|
||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [0.5.6] - 2026-04-07
|
||||||
|
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- **ws**: add ping/pong keepalive and read deadline to detect zombie connections
|
||||||
## [0.5.5] - 2026-04-07
|
## [0.5.5] - 2026-04-07
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -17,6 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
- **daemon**: cancel watch reporter on stream switch and re-notify ready
|
- **daemon**: cancel watch reporter on stream switch and re-notify ready
|
||||||
|
|
||||||
|
### Other
|
||||||
|
|
||||||
|
- **release**: 0.5.5
|
||||||
## [0.5.4] - 2026-04-07
|
## [0.5.4] - 2026-04-07
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -153,6 +163,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
|
|
||||||
- remove UPX compression (antivirus false positives, startup penalty)
|
- remove UPX compression (antivirus false positives, startup penalty)
|
||||||
- add -s -w -trimpath to Makefile, add build-small target with UPX
|
- add -s -w -trimpath to Makefile, add build-small target with UPX
|
||||||
|
[0.5.6]: https://github.com/torrentclaw/unarr/compare/v0.5.5...v0.5.6
|
||||||
[0.5.5]: https://github.com/torrentclaw/unarr/compare/v0.5.4...v0.5.5
|
[0.5.5]: https://github.com/torrentclaw/unarr/compare/v0.5.4...v0.5.5
|
||||||
[0.5.4]: https://github.com/torrentclaw/unarr/compare/v0.5.3...v0.5.4
|
[0.5.4]: https://github.com/torrentclaw/unarr/compare/v0.5.3...v0.5.4
|
||||||
[0.5.3]: https://github.com/torrentclaw/unarr/compare/v0.5.2...v0.5.3
|
[0.5.3]: https://github.com/torrentclaw/unarr/compare/v0.5.2...v0.5.3
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -11,7 +11,6 @@ require (
|
||||||
github.com/fatih/color v1.19.0
|
github.com/fatih/color v1.19.0
|
||||||
github.com/getsentry/sentry-go v0.44.1
|
github.com/getsentry/sentry-go v0.44.1
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
|
||||||
github.com/huin/goupnp v1.3.0
|
github.com/huin/goupnp v1.3.0
|
||||||
github.com/olekukonko/tablewriter v1.1.4
|
github.com/olekukonko/tablewriter v1.1.4
|
||||||
github.com/spf13/cobra v1.10.2
|
github.com/spf13/cobra v1.10.2
|
||||||
|
|
@ -69,6 +68,7 @@ require (
|
||||||
github.com/google/btree v1.1.3 // indirect
|
github.com/google/btree v1.1.3 // indirect
|
||||||
github.com/google/go-cmp v0.7.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/google/pprof v0.0.0-20260302011040-a15ffb7f9dcc // indirect
|
github.com/google/pprof v0.0.0-20260302011040-a15ffb7f9dcc // indirect
|
||||||
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
github.com/huandu/xstrings v1.5.0 // indirect
|
github.com/huandu/xstrings v1.5.0 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
|
|
||||||
|
|
@ -40,27 +40,6 @@ func (c *Client) Register(ctx context.Context, req RegisterRequest) (*RegisterRe
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat sends a periodic keep-alive signal and returns server directives.
|
|
||||||
func (c *Client) Heartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
|
||||||
var resp HeartbeatResponse
|
|
||||||
if err := c.doPost(ctx, "/api/internal/agent/heartbeat", req, &resp); err != nil {
|
|
||||||
return nil, fmt.Errorf("heartbeat: %w", err)
|
|
||||||
}
|
|
||||||
return &resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClaimTasks polls for pending download tasks and claims them atomically.
|
|
||||||
// Also returns any stream requests for completed downloads.
|
|
||||||
func (c *Client) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
|
||||||
url := fmt.Sprintf("/api/internal/agent/tasks?agentId=%s", agentID)
|
|
||||||
var resp TasksResponse
|
|
||||||
if err := c.doGet(ctx, url, &resp); err != nil {
|
|
||||||
return nil, fmt.Errorf("claim tasks: %w", err)
|
|
||||||
}
|
|
||||||
return &resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReportStatus reports download progress or completion for a task.
|
|
||||||
// Deregister notifies the server that the agent is shutting down.
|
// Deregister notifies the server that the agent is shutting down.
|
||||||
func (c *Client) Deregister(ctx context.Context, agentID string) error {
|
func (c *Client) Deregister(ctx context.Context, agentID string) error {
|
||||||
req := struct {
|
req := struct {
|
||||||
|
|
@ -91,6 +70,16 @@ func (c *Client) BatchReportStatus(ctx context.Context, updates []StatusUpdate)
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sync sends the CLI's full state and receives all pending server actions.
|
||||||
|
// This is the single endpoint for bidirectional state synchronization.
|
||||||
|
func (c *Client) Sync(ctx context.Context, req SyncRequest) (*SyncResponse, error) {
|
||||||
|
var resp SyncResponse
|
||||||
|
if err := c.doPost(ctx, "/api/internal/agent/sync", req, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("sync: %w", err)
|
||||||
|
}
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Usenet endpoints
|
// Usenet endpoints
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -72,70 +72,6 @@ func TestRegister(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHeartbeat(t *testing.T) {
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path != "/api/internal/agent/heartbeat" {
|
|
||||||
t.Errorf("path = %s, want /api/internal/agent/heartbeat", r.URL.Path)
|
|
||||||
}
|
|
||||||
var req HeartbeatRequest
|
|
||||||
json.NewDecoder(r.Body).Decode(&req)
|
|
||||||
if req.AgentID != "agent-123" {
|
|
||||||
t.Errorf("agentId = %q, want agent-123", req.AgentID)
|
|
||||||
}
|
|
||||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
|
||||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Heartbeat failed: %v", err)
|
|
||||||
}
|
|
||||||
if !resp.Success {
|
|
||||||
t.Error("expected success=true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClaimTasks(t *testing.T) {
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
t.Errorf("method = %s, want GET", r.Method)
|
|
||||||
}
|
|
||||||
if r.URL.Query().Get("agentId") != "agent-123" {
|
|
||||||
t.Errorf("agentId param = %q, want agent-123", r.URL.Query().Get("agentId"))
|
|
||||||
}
|
|
||||||
json.NewEncoder(w).Encode(TasksResponse{
|
|
||||||
Tasks: []Task{
|
|
||||||
{
|
|
||||||
ID: "task-uuid-1",
|
|
||||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
|
||||||
Title: "The Matrix (1999)",
|
|
||||||
PreferredMethod: "auto",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
|
||||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ClaimTasks failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(resp.Tasks) != 1 {
|
|
||||||
t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks))
|
|
||||||
}
|
|
||||||
if resp.Tasks[0].ID != "task-uuid-1" {
|
|
||||||
t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID)
|
|
||||||
}
|
|
||||||
if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" {
|
|
||||||
t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash)
|
|
||||||
}
|
|
||||||
if resp.Tasks[0].PreferredMethod != "auto" {
|
|
||||||
t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReportStatus(t *testing.T) {
|
func TestReportStatus(t *testing.T) {
|
||||||
var received StatusUpdate
|
var received StatusUpdate
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
@ -173,22 +109,6 @@ func TestReportStatus(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClaimTasksEmpty(t *testing.T) {
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(TasksResponse{Tasks: []Task{}})
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
|
||||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ClaimTasks failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(resp.Tasks) != 0 {
|
|
||||||
t.Errorf("expected empty tasks, got %d", len(resp.Tasks))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIError(t *testing.T) {
|
func TestAPIError(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
|
@ -279,50 +199,12 @@ func TestUserAgent(t *testing.T) {
|
||||||
if r.Header.Get("User-Agent") != "unarr/0.2.0" {
|
if r.Header.Get("User-Agent") != "unarr/0.2.0" {
|
||||||
t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent"))
|
t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent"))
|
||||||
}
|
}
|
||||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
json.NewEncoder(w).Encode(RegisterResponse{Success: true})
|
||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
c := NewClient(srv.URL, "test-key", "unarr/0.2.0")
|
c := NewClient(srv.URL, "test-key", "unarr/0.2.0")
|
||||||
c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"})
|
c.Register(context.Background(), RegisterRequest{AgentID: "x"})
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeartbeatWithUpgradeSignal(t *testing.T) {
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(HeartbeatResponse{
|
|
||||||
Success: true,
|
|
||||||
Upgrade: &UpgradeSignal{Version: "2.0.0"},
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
|
||||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Heartbeat failed: %v", err)
|
|
||||||
}
|
|
||||||
if resp.Upgrade == nil {
|
|
||||||
t.Fatal("expected upgrade signal, got nil")
|
|
||||||
}
|
|
||||||
if resp.Upgrade.Version != "2.0.0" {
|
|
||||||
t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
|
||||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Heartbeat failed: %v", err)
|
|
||||||
}
|
|
||||||
if resp.Upgrade != nil {
|
|
||||||
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeregister(t *testing.T) {
|
func TestDeregister(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -14,75 +14,62 @@ import (
|
||||||
|
|
||||||
// DaemonConfig holds daemon runtime settings.
|
// DaemonConfig holds daemon runtime settings.
|
||||||
type DaemonConfig struct {
|
type DaemonConfig struct {
|
||||||
AgentID string
|
AgentID string
|
||||||
AgentName string
|
AgentName string
|
||||||
Version string
|
Version string
|
||||||
DownloadDir string
|
DownloadDir string
|
||||||
PollInterval time.Duration
|
StreamPort int // port for the HTTP stream server
|
||||||
HeartbeatInterval time.Duration
|
LanIP string // LAN IP (reported in sync for stream URL resolution)
|
||||||
StreamPort int // port for the HTTP stream server (reported in heartbeat)
|
TailscaleIP string // Tailscale IP (reported in sync for stream URL resolution)
|
||||||
LanIP string // LAN IP (reported in heartbeat for stream URL resolution)
|
|
||||||
TailscaleIP string // Tailscale IP (reported in heartbeat for stream URL resolution)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Daemon manages the main loop: register, heartbeat, poll tasks.
|
// Daemon manages agent registration and the sync loop.
|
||||||
type Daemon struct {
|
type Daemon struct {
|
||||||
cfg DaemonConfig
|
cfg DaemonConfig
|
||||||
transport Transport
|
client *Client
|
||||||
|
sync *SyncClient
|
||||||
|
state *LocalState
|
||||||
|
|
||||||
// Callbacks
|
// Callbacks — set by cmd/daemon.go before calling Run.
|
||||||
OnTasksClaimed func(tasks []Task)
|
OnTasksClaimed func(tasks []Task)
|
||||||
OnStreamRequested func(req StreamRequest)
|
OnStreamRequested func(req StreamRequest)
|
||||||
OnControlAction func(action, taskID string)
|
OnControlAction func(action, taskID string, deleteFiles bool)
|
||||||
|
GetActiveCount func() int // returns number of active downloads (wired from manager)
|
||||||
|
|
||||||
// State
|
// State
|
||||||
User UserInfo
|
User UserInfo
|
||||||
Features FeatureFlags
|
Features FeatureFlags
|
||||||
Info AgentInfo
|
Info AgentInfo
|
||||||
State DaemonState
|
State DaemonState
|
||||||
heartbeatFailures int
|
|
||||||
lastNotifiedVersion string
|
lastNotifiedVersion string
|
||||||
|
|
||||||
// Callbacks for state tracking (set by cmd/daemon.go)
|
|
||||||
GetActiveCount func() int
|
|
||||||
GetCleanableBytes func() int64
|
|
||||||
|
|
||||||
// Watching tracks whether a user is viewing download progress in the web UI.
|
// Watching tracks whether a user is viewing download progress in the web UI.
|
||||||
// When false, the progress reporter skips detailed updates (only sends final states).
|
|
||||||
// Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic.
|
|
||||||
Watching atomic.Bool
|
Watching atomic.Bool
|
||||||
|
|
||||||
// Exposed tickers for hot-reload
|
// ScanNow triggers an immediate library scan.
|
||||||
PollTicker *time.Ticker
|
|
||||||
HeartbeatTicker *time.Ticker
|
|
||||||
|
|
||||||
// pollNow triggers an immediate poll (e.g. on resume)
|
|
||||||
pollNow chan struct{}
|
|
||||||
|
|
||||||
// ScanNow triggers an immediate library scan (from heartbeat or WebSocket control event)
|
|
||||||
ScanNow chan struct{}
|
ScanNow chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDaemon creates a daemon with the given transport.
|
// NewDaemon creates a daemon with an HTTP client for sync-based communication.
|
||||||
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
|
func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
|
||||||
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
|
state := NewLocalState()
|
||||||
if cfg.PollInterval == 0 {
|
|
||||||
cfg.PollInterval = 30 * time.Second
|
|
||||||
}
|
|
||||||
if cfg.HeartbeatInterval == 0 {
|
|
||||||
cfg.HeartbeatInterval = 30 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Daemon{
|
return &Daemon{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
transport: transport,
|
client: client,
|
||||||
pollNow: make(chan struct{}, 1),
|
state: state,
|
||||||
ScanNow: make(chan struct{}, 1),
|
sync: NewSyncClient(client, cfg, state),
|
||||||
|
ScanNow: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transport returns the configured transport.
|
// SyncClient returns the sync client for external wiring.
|
||||||
func (d *Daemon) Transport() Transport { return d.transport }
|
func (d *Daemon) SyncClient() *SyncClient { return d.sync }
|
||||||
|
|
||||||
|
// UpdateStreamPort updates the stream port reported in sync requests.
|
||||||
|
func (d *Daemon) UpdateStreamPort(port int) {
|
||||||
|
d.cfg.StreamPort = port
|
||||||
|
d.sync.cfg.StreamPort = port
|
||||||
|
}
|
||||||
|
|
||||||
// Register registers the agent and fetches user info + features.
|
// Register registers the agent and fetches user info + features.
|
||||||
// Retries with exponential backoff on transient errors (429, 5xx, network).
|
// Retries with exponential backoff on transient errors (429, 5xx, network).
|
||||||
|
|
@ -109,11 +96,10 @@ func (d *Daemon) Register(ctx context.Context) error {
|
||||||
var resp *RegisterResponse
|
var resp *RegisterResponse
|
||||||
var err error
|
var err error
|
||||||
for attempt := range maxRetries {
|
for attempt := range maxRetries {
|
||||||
resp, err = d.transport.Register(ctx, req)
|
resp, err = d.client.Register(ctx, req)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
// Only retry on transient errors (429, 5xx, network failures)
|
|
||||||
if !isTransientError(err) {
|
if !isTransientError(err) {
|
||||||
return fmt.Errorf("register: %w", err)
|
return fmt.Errorf("register: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -154,14 +140,9 @@ func (d *Daemon) Register(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run connects the transport, registers the agent, and starts the main loop.
|
// Run registers the agent and starts the sync loop.
|
||||||
// Blocks until ctx is cancelled. Callers must NOT call transport.Connect before Run.
|
// Blocks until ctx is cancelled.
|
||||||
func (d *Daemon) Run(ctx context.Context) error {
|
func (d *Daemon) Run(ctx context.Context) error {
|
||||||
// Connect transport (establishes WebSocket if available, falls back to HTTP)
|
|
||||||
if err := d.transport.Connect(ctx); err != nil {
|
|
||||||
return fmt.Errorf("connect transport: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register
|
// Register
|
||||||
if err := d.Register(ctx); err != nil {
|
if err := d.Register(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -169,163 +150,61 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||||
|
|
||||||
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
|
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
|
||||||
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
|
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
|
||||||
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
|
|
||||||
|
|
||||||
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
|
// Wire sync callbacks
|
||||||
defer d.HeartbeatTicker.Stop()
|
d.sync.OnNewTasks = func(tasks []Task) {
|
||||||
|
if d.OnTasksClaimed != nil {
|
||||||
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
|
d.OnTasksClaimed(tasks)
|
||||||
defer d.PollTicker.Stop()
|
|
||||||
|
|
||||||
heartbeatTicker := d.HeartbeatTicker
|
|
||||||
pollTicker := d.PollTicker
|
|
||||||
|
|
||||||
// Initial poll immediately
|
|
||||||
d.poll(ctx)
|
|
||||||
|
|
||||||
eventsCh := d.transport.Events()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.Println("Daemon shutting down...")
|
|
||||||
d.deregister()
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case event := <-eventsCh:
|
|
||||||
d.handleEvent(event)
|
|
||||||
|
|
||||||
case <-heartbeatTicker.C:
|
|
||||||
d.heartbeat(ctx)
|
|
||||||
|
|
||||||
case <-pollTicker.C:
|
|
||||||
// Only poll in HTTP mode — WS mode receives tasks via Events
|
|
||||||
if d.transport.Mode() == "http" {
|
|
||||||
d.poll(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-d.pollNow:
|
|
||||||
d.poll(ctx)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
d.sync.OnControl = func(action, taskID string, deleteFiles bool) {
|
||||||
|
if d.OnControlAction != nil {
|
||||||
func (d *Daemon) heartbeat(ctx context.Context) {
|
d.OnControlAction(action, taskID, deleteFiles)
|
||||||
req := HeartbeatRequest{
|
|
||||||
AgentID: d.cfg.AgentID,
|
|
||||||
Name: d.cfg.AgentName,
|
|
||||||
Version: d.cfg.Version,
|
|
||||||
OS: runtime.GOOS,
|
|
||||||
DownloadDir: d.cfg.DownloadDir,
|
|
||||||
StreamPort: d.cfg.StreamPort,
|
|
||||||
LanIP: d.cfg.LanIP,
|
|
||||||
TailscaleIP: d.cfg.TailscaleIP,
|
|
||||||
}
|
|
||||||
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
|
|
||||||
req.DiskFreeBytes = free
|
|
||||||
req.DiskTotalBytes = total
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := d.transport.SendHeartbeat(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
d.heartbeatFailures++
|
|
||||||
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
|
|
||||||
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
|
|
||||||
} else {
|
|
||||||
log.Printf("Heartbeat failed: %v", err)
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if d.heartbeatFailures > 0 {
|
d.sync.OnStreamRequest = func(req StreamRequest) {
|
||||||
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
|
if d.OnStreamRequested != nil {
|
||||||
d.heartbeatFailures = 0
|
d.OnStreamRequested(req)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
d.sync.OnUpgrade = func(version string) {
|
||||||
// Update watching flag and state file
|
if version != d.lastNotifiedVersion {
|
||||||
d.Watching.Store(resp.Watching)
|
d.lastNotifiedVersion = version
|
||||||
d.State.LastHeartbeat = time.Now()
|
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", version)
|
||||||
if d.GetActiveCount != nil {
|
}
|
||||||
d.State.ActiveTasks = d.GetActiveCount()
|
|
||||||
}
|
}
|
||||||
WriteState(&d.State)
|
d.sync.OnScan = func() {
|
||||||
|
|
||||||
// Trigger library scan if requested
|
|
||||||
if resp.Scan {
|
|
||||||
log.Printf("Library scan requested by server")
|
log.Printf("Library scan requested by server")
|
||||||
select {
|
select {
|
||||||
case d.ScanNow <- struct{}{}:
|
case d.ScanNow <- struct{}{}:
|
||||||
default: // scan already pending
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
d.sync.OnWatchingChange = func(watching bool) {
|
||||||
// Log once per version when server suggests an upgrade
|
d.Watching.Store(watching)
|
||||||
if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion {
|
|
||||||
d.lastNotifiedVersion = resp.Upgrade.Version
|
|
||||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version)
|
|
||||||
}
|
}
|
||||||
}
|
d.sync.OnSyncSuccess = func() {
|
||||||
|
d.State.LastHeartbeat = time.Now()
|
||||||
// handleEvent processes a server-initiated event from the WebSocket transport.
|
if d.GetActiveCount != nil {
|
||||||
func (d *Daemon) handleEvent(event ServerEvent) {
|
d.State.ActiveTasks = d.GetActiveCount()
|
||||||
switch event.Type {
|
|
||||||
case "tasks":
|
|
||||||
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
|
|
||||||
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
|
|
||||||
if d.OnTasksClaimed != nil {
|
|
||||||
d.OnTasksClaimed(event.Tasks.Tasks)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if event.Tasks != nil && d.OnStreamRequested != nil {
|
WriteState(&d.State)
|
||||||
for _, sr := range event.Tasks.StreamRequests {
|
|
||||||
d.OnStreamRequested(sr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "upgrade":
|
|
||||||
if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion {
|
|
||||||
d.lastNotifiedVersion = event.Upgrade.Version
|
|
||||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "control":
|
|
||||||
if event.Control != nil {
|
|
||||||
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
|
|
||||||
if event.Control.Action == "scan" {
|
|
||||||
select {
|
|
||||||
case d.ScanNow <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if d.OnControlAction != nil {
|
|
||||||
d.OnControlAction(event.Control.Action, event.Control.TaskID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "disconnected":
|
|
||||||
log.Println("WebSocket disconnected, switching to HTTP polling")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start sync loop (blocks)
|
||||||
|
return d.sync.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateStreamPort updates the stream port reported in heartbeats.
|
// TriggerSync requests an immediate sync cycle.
|
||||||
// Called after the persistent stream server binds (actual port may differ from configured).
|
func (d *Daemon) TriggerSync() {
|
||||||
func (d *Daemon) UpdateStreamPort(port int) {
|
d.sync.TriggerSync()
|
||||||
d.cfg.StreamPort = port
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerPoll requests an immediate task poll cycle.
|
// Deregister notifies the server of graceful shutdown.
|
||||||
// Used when a resume event is received to pick up re-pending tasks faster.
|
func (d *Daemon) Deregister() {
|
||||||
func (d *Daemon) TriggerPoll() {
|
|
||||||
select {
|
|
||||||
case d.pollNow <- struct{}{}:
|
|
||||||
default: // already pending
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Daemon) deregister() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := d.transport.Deregister(ctx, d.cfg.AgentID)
|
if err := d.client.Deregister(ctx, d.cfg.AgentID); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Printf("Deregister failed: %v", err)
|
log.Printf("Deregister failed: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Println("Agent deregistered")
|
log.Println("Agent deregistered")
|
||||||
|
|
@ -338,12 +217,10 @@ func isTransientError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// Structured check: HTTPError carries the status code directly
|
|
||||||
var httpErr *HTTPError
|
var httpErr *HTTPError
|
||||||
if errors.As(err, &httpErr) {
|
if errors.As(err, &httpErr) {
|
||||||
return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500
|
return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500
|
||||||
}
|
}
|
||||||
// Fallback: network-level errors (no HTTP response received)
|
|
||||||
lower := strings.ToLower(err.Error())
|
lower := strings.ToLower(err.Error())
|
||||||
for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} {
|
for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} {
|
||||||
if strings.Contains(lower, keyword) {
|
if strings.Contains(lower, keyword) {
|
||||||
|
|
@ -352,27 +229,3 @@ func isTransientError(err error) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) poll(ctx context.Context) {
|
|
||||||
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Poll failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d.Info.LastPollAt = time.Now()
|
|
||||||
|
|
||||||
if len(resp.Tasks) > 0 {
|
|
||||||
log.Printf("Claimed %d task(s)", len(resp.Tasks))
|
|
||||||
if d.OnTasksClaimed != nil {
|
|
||||||
d.OnTasksClaimed(resp.Tasks)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle stream requests for completed downloads
|
|
||||||
if d.OnStreamRequested != nil {
|
|
||||||
for _, sr := range resp.StreamRequests {
|
|
||||||
d.OnStreamRequested(sr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
195
internal/agent/sync.go
Normal file
195
internal/agent/sync.go
Normal file
|
|
@ -0,0 +1,195 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SyncIntervalWatching is the sync interval when someone is viewing the web UI.
|
||||||
|
SyncIntervalWatching = 3 * time.Second
|
||||||
|
// SyncIntervalIdle is the sync interval when nobody is watching.
|
||||||
|
SyncIntervalIdle = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// SyncClient handles bidirectional state synchronization between the CLI and server.
|
||||||
|
// It sends the CLI's full execution state and receives all pending server actions
|
||||||
|
// in a single HTTP round-trip, at an adaptive interval.
|
||||||
|
type SyncClient struct {
|
||||||
|
client *Client
|
||||||
|
cfg DaemonConfig
|
||||||
|
state *LocalState
|
||||||
|
|
||||||
|
// Callbacks — set by the daemon before calling Run.
|
||||||
|
OnNewTasks func(tasks []Task)
|
||||||
|
OnControl func(action, taskID string, deleteFiles bool)
|
||||||
|
OnStreamRequest func(req StreamRequest)
|
||||||
|
OnUpgrade func(version string)
|
||||||
|
OnScan func()
|
||||||
|
OnWatchingChange func(watching bool)
|
||||||
|
OnSyncSuccess func() // called after each successful sync (e.g. to update state file)
|
||||||
|
GetFreeSlots func() int
|
||||||
|
GetTaskStates func() []TaskState // returns current state of all active + recently finished tasks
|
||||||
|
|
||||||
|
// SyncNow triggers an immediate sync (e.g., on task completion).
|
||||||
|
SyncNow chan struct{}
|
||||||
|
|
||||||
|
watching atomic.Bool
|
||||||
|
interval atomic.Int64 // stored as nanoseconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSyncClient creates a sync client.
|
||||||
|
func NewSyncClient(client *Client, cfg DaemonConfig, state *LocalState) *SyncClient {
|
||||||
|
sc := &SyncClient{
|
||||||
|
client: client,
|
||||||
|
cfg: cfg,
|
||||||
|
state: state,
|
||||||
|
SyncNow: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
sc.interval.Store(int64(SyncIntervalIdle))
|
||||||
|
return sc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watching returns whether someone is viewing the web UI.
|
||||||
|
func (sc *SyncClient) Watching() bool {
|
||||||
|
return sc.watching.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerSync requests an immediate sync cycle.
|
||||||
|
func (sc *SyncClient) TriggerSync() {
|
||||||
|
select {
|
||||||
|
case sc.SyncNow <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the adaptive sync loop. Blocks until ctx is cancelled.
|
||||||
|
func (sc *SyncClient) Run(ctx context.Context) error {
|
||||||
|
// Initial sync immediately
|
||||||
|
sc.doSync(ctx)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(sc.currentInterval())
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Final sync to report latest state
|
||||||
|
finalCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
sc.doSync(finalCtx)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
sc.doSync(ctx)
|
||||||
|
ticker.Reset(sc.currentInterval())
|
||||||
|
|
||||||
|
case <-sc.SyncNow:
|
||||||
|
sc.doSync(ctx)
|
||||||
|
ticker.Reset(sc.currentInterval())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncClient) currentInterval() time.Duration {
|
||||||
|
return time.Duration(sc.interval.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncClient) doSync(ctx context.Context) {
|
||||||
|
req := sc.buildRequest()
|
||||||
|
resp, err := sc.client.Sync(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
log.Printf("sync failed: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sc.processResponse(resp)
|
||||||
|
sc.adjustInterval(resp.Watching)
|
||||||
|
if sc.OnSyncSuccess != nil {
|
||||||
|
sc.OnSyncSuccess()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncClient) buildRequest() SyncRequest {
|
||||||
|
req := SyncRequest{
|
||||||
|
AgentID: sc.cfg.AgentID,
|
||||||
|
Name: sc.cfg.AgentName,
|
||||||
|
Version: sc.cfg.Version,
|
||||||
|
OS: runtime.GOOS,
|
||||||
|
Arch: runtime.GOARCH,
|
||||||
|
DownloadDir: sc.cfg.DownloadDir,
|
||||||
|
StreamPort: sc.cfg.StreamPort,
|
||||||
|
LanIP: sc.cfg.LanIP,
|
||||||
|
TailscaleIP: sc.cfg.TailscaleIP,
|
||||||
|
}
|
||||||
|
if sc.GetTaskStates != nil {
|
||||||
|
req.Tasks = sc.GetTaskStates()
|
||||||
|
} else {
|
||||||
|
req.Tasks = sc.state.Snapshot()
|
||||||
|
}
|
||||||
|
if free, total, err := DiskInfo(sc.cfg.DownloadDir); err == nil {
|
||||||
|
req.DiskFreeBytes = free
|
||||||
|
req.DiskTotalBytes = total
|
||||||
|
}
|
||||||
|
if sc.GetFreeSlots != nil {
|
||||||
|
req.FreeSlots = sc.GetFreeSlots()
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncClient) processResponse(resp *SyncResponse) {
|
||||||
|
// New tasks
|
||||||
|
if len(resp.NewTasks) > 0 && sc.OnNewTasks != nil {
|
||||||
|
log.Printf("sync: received %d new task(s)", len(resp.NewTasks))
|
||||||
|
sc.OnNewTasks(resp.NewTasks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Control signals
|
||||||
|
for _, ctrl := range resp.Controls {
|
||||||
|
log.Printf("sync: control %s on task %s", ctrl.Action, ShortID(ctrl.TaskID))
|
||||||
|
if sc.OnControl != nil {
|
||||||
|
sc.OnControl(ctrl.Action, ctrl.TaskID, ctrl.DeleteFiles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream requests
|
||||||
|
for _, sr := range resp.StreamRequests {
|
||||||
|
if sc.OnStreamRequest != nil {
|
||||||
|
sc.OnStreamRequest(sr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade
|
||||||
|
if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil {
|
||||||
|
sc.OnUpgrade(resp.Upgrade.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan
|
||||||
|
if resp.Scan && sc.OnScan != nil {
|
||||||
|
sc.OnScan()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncClient) adjustInterval(watching bool) {
|
||||||
|
prev := sc.watching.Load()
|
||||||
|
sc.watching.Store(watching)
|
||||||
|
|
||||||
|
var newInterval time.Duration
|
||||||
|
if watching {
|
||||||
|
newInterval = SyncIntervalWatching
|
||||||
|
} else {
|
||||||
|
newInterval = SyncIntervalIdle
|
||||||
|
}
|
||||||
|
|
||||||
|
if sc.interval.Swap(int64(newInterval)) != int64(newInterval) {
|
||||||
|
log.Printf("sync: interval=%s (watching=%v)", newInterval, watching)
|
||||||
|
}
|
||||||
|
|
||||||
|
if prev != watching && sc.OnWatchingChange != nil {
|
||||||
|
sc.OnWatchingChange(watching)
|
||||||
|
}
|
||||||
|
}
|
||||||
362
internal/agent/sync_test.go
Normal file
362
internal/agent/sync_test.go
Normal file
|
|
@ -0,0 +1,362 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestSyncClient(url string) (*SyncClient, *Client) {
|
||||||
|
client := NewClient(url, "test-key", "test-agent/1.0")
|
||||||
|
cfg := DaemonConfig{
|
||||||
|
AgentID: "test-agent",
|
||||||
|
AgentName: "Test",
|
||||||
|
Version: "1.0.0",
|
||||||
|
DownloadDir: "/tmp/downloads",
|
||||||
|
}
|
||||||
|
state := NewLocalState()
|
||||||
|
sc := NewSyncClient(client, cfg, state)
|
||||||
|
return sc, client
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_NewDefaults(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
if sc.Watching() {
|
||||||
|
t.Error("should not be watching initially")
|
||||||
|
}
|
||||||
|
if sc.currentInterval() != SyncIntervalIdle {
|
||||||
|
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_AdjustInterval_Watching(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
sc.adjustInterval(true)
|
||||||
|
|
||||||
|
if sc.currentInterval() != SyncIntervalWatching {
|
||||||
|
t.Errorf("expected watching interval %v, got %v", SyncIntervalWatching, sc.currentInterval())
|
||||||
|
}
|
||||||
|
if !sc.Watching() {
|
||||||
|
t.Error("expected watching=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_AdjustInterval_NotWatching(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
// First set watching, then unset
|
||||||
|
sc.adjustInterval(true)
|
||||||
|
sc.adjustInterval(false)
|
||||||
|
|
||||||
|
if sc.currentInterval() != SyncIntervalIdle {
|
||||||
|
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
|
||||||
|
}
|
||||||
|
if sc.Watching() {
|
||||||
|
t.Error("expected watching=false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_AdjustInterval_CallsOnWatchingChange(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var changes []bool
|
||||||
|
sc.OnWatchingChange = func(w bool) { changes = append(changes, w) }
|
||||||
|
|
||||||
|
sc.adjustInterval(true)
|
||||||
|
sc.adjustInterval(true) // no change
|
||||||
|
sc.adjustInterval(false) // change
|
||||||
|
|
||||||
|
if len(changes) != 2 {
|
||||||
|
t.Fatalf("expected 2 changes, got %d: %v", len(changes), changes)
|
||||||
|
}
|
||||||
|
if !changes[0] {
|
||||||
|
t.Error("first change should be true")
|
||||||
|
}
|
||||||
|
if changes[1] {
|
||||||
|
t.Error("second change should be false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_TriggerSync_NonBlocking(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
// Fill the channel
|
||||||
|
sc.TriggerSync()
|
||||||
|
// Should not block
|
||||||
|
sc.TriggerSync()
|
||||||
|
sc.TriggerSync()
|
||||||
|
|
||||||
|
// Drain
|
||||||
|
select {
|
||||||
|
case <-sc.SyncNow:
|
||||||
|
default:
|
||||||
|
t.Error("expected a sync trigger in channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_ProcessResponse_NewTasks(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var received []Task
|
||||||
|
sc.OnNewTasks = func(tasks []Task) { received = tasks }
|
||||||
|
|
||||||
|
sc.processResponse(&SyncResponse{
|
||||||
|
NewTasks: []Task{
|
||||||
|
{ID: "t1", Title: "Movie 1", InfoHash: "abc"},
|
||||||
|
{ID: "t2", Title: "Movie 2", InfoHash: "def"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(received) != 2 {
|
||||||
|
t.Fatalf("expected 2 tasks, got %d", len(received))
|
||||||
|
}
|
||||||
|
if received[0].Title != "Movie 1" {
|
||||||
|
t.Errorf("expected Movie 1, got %s", received[0].Title)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_ProcessResponse_NoTasks(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var called bool
|
||||||
|
sc.OnNewTasks = func(tasks []Task) { called = true }
|
||||||
|
|
||||||
|
sc.processResponse(&SyncResponse{NewTasks: nil})
|
||||||
|
|
||||||
|
if called {
|
||||||
|
t.Error("OnNewTasks should not be called with empty tasks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_ProcessResponse_Controls(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var actions []string
|
||||||
|
var taskIDs []string
|
||||||
|
sc.OnControl = func(action, taskID string, deleteFiles bool) {
|
||||||
|
actions = append(actions, action)
|
||||||
|
taskIDs = append(taskIDs, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.processResponse(&SyncResponse{
|
||||||
|
Controls: []ControlAction{
|
||||||
|
{Action: "cancel", TaskID: "task-1234-5678"},
|
||||||
|
{Action: "pause", TaskID: "task-abcd-efgh"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(actions) != 2 {
|
||||||
|
t.Fatalf("expected 2 controls, got %d", len(actions))
|
||||||
|
}
|
||||||
|
if actions[0] != "cancel" {
|
||||||
|
t.Errorf("expected cancel, got %s", actions[0])
|
||||||
|
}
|
||||||
|
if actions[1] != "pause" {
|
||||||
|
t.Errorf("expected pause, got %s", actions[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_ProcessResponse_Upgrade(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var version string
|
||||||
|
sc.OnUpgrade = func(v string) { version = v }
|
||||||
|
|
||||||
|
sc.processResponse(&SyncResponse{
|
||||||
|
Upgrade: &UpgradeSignal{Version: "2.0.0"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if version != "2.0.0" {
|
||||||
|
t.Errorf("expected 2.0.0, got %s", version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_ProcessResponse_UpgradeEmpty(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var called bool
|
||||||
|
sc.OnUpgrade = func(v string) { called = true }
|
||||||
|
|
||||||
|
sc.processResponse(&SyncResponse{
|
||||||
|
Upgrade: &UpgradeSignal{Version: ""},
|
||||||
|
})
|
||||||
|
|
||||||
|
if called {
|
||||||
|
t.Error("OnUpgrade should not be called with empty version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_ProcessResponse_Scan(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var called bool
|
||||||
|
sc.OnScan = func() { called = true }
|
||||||
|
|
||||||
|
sc.processResponse(&SyncResponse{Scan: true})
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("OnScan should have been called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_ProcessResponse_StreamRequests(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
var received []StreamRequest
|
||||||
|
sc.OnStreamRequest = func(sr StreamRequest) { received = append(received, sr) }
|
||||||
|
|
||||||
|
sc.processResponse(&SyncResponse{
|
||||||
|
StreamRequests: []StreamRequest{
|
||||||
|
{TaskID: "t1", FilePath: "/tmp/movie.mkv"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(received) != 1 {
|
||||||
|
t.Fatalf("expected 1 stream request, got %d", len(received))
|
||||||
|
}
|
||||||
|
if received[0].FilePath != "/tmp/movie.mkv" {
|
||||||
|
t.Errorf("expected /tmp/movie.mkv, got %s", received[0].FilePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_BuildRequest_WithGetTaskStates(t *testing.T) {
|
||||||
|
sc, _ := newTestSyncClient("http://localhost")
|
||||||
|
|
||||||
|
sc.GetTaskStates = func() []TaskState {
|
||||||
|
return []TaskState{
|
||||||
|
{TaskID: "t1", Status: "downloading", Progress: 50},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sc.GetFreeSlots = func() int { return 2 }
|
||||||
|
|
||||||
|
req := sc.buildRequest()
|
||||||
|
|
||||||
|
if req.AgentID != "test-agent" {
|
||||||
|
t.Errorf("expected test-agent, got %s", req.AgentID)
|
||||||
|
}
|
||||||
|
if len(req.Tasks) != 1 {
|
||||||
|
t.Fatalf("expected 1 task, got %d", len(req.Tasks))
|
||||||
|
}
|
||||||
|
if req.Tasks[0].Progress != 50 {
|
||||||
|
t.Errorf("expected progress 50, got %d", req.Tasks[0].Progress)
|
||||||
|
}
|
||||||
|
if req.FreeSlots != 2 {
|
||||||
|
t.Errorf("expected 2 free slots, got %d", req.FreeSlots)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_BuildRequest_FallbackToState(t *testing.T) {
|
||||||
|
client := NewClient("http://localhost", "key", "ua")
|
||||||
|
state := NewLocalState()
|
||||||
|
state.Update(TaskState{TaskID: "t1", Status: "completed", Progress: 100})
|
||||||
|
|
||||||
|
sc := NewSyncClient(client, DaemonConfig{AgentID: "a1", Version: "1.0"}, state)
|
||||||
|
// GetTaskStates is nil — should fall back to state.Snapshot()
|
||||||
|
|
||||||
|
req := sc.buildRequest()
|
||||||
|
if len(req.Tasks) != 1 {
|
||||||
|
t.Fatalf("expected 1 task from state fallback, got %d", len(req.Tasks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_DoSync_Success(t *testing.T) {
|
||||||
|
var syncCount atomic.Int32
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
syncCount.Add(1)
|
||||||
|
json.NewEncoder(w).Encode(SyncResponse{
|
||||||
|
Watching: true,
|
||||||
|
NewTasks: []Task{{ID: "t1", Title: "Test Movie", InfoHash: "abc"}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
sc, _ := newTestSyncClient(srv.URL)
|
||||||
|
|
||||||
|
var tasksReceived []Task
|
||||||
|
sc.OnNewTasks = func(tasks []Task) { tasksReceived = tasks }
|
||||||
|
|
||||||
|
sc.doSync(context.Background())
|
||||||
|
|
||||||
|
if syncCount.Load() != 1 {
|
||||||
|
t.Errorf("expected 1 sync call, got %d", syncCount.Load())
|
||||||
|
}
|
||||||
|
if len(tasksReceived) != 1 {
|
||||||
|
t.Fatalf("expected 1 task, got %d", len(tasksReceived))
|
||||||
|
}
|
||||||
|
if !sc.Watching() {
|
||||||
|
t.Error("expected watching=true after sync")
|
||||||
|
}
|
||||||
|
if sc.currentInterval() != SyncIntervalWatching {
|
||||||
|
t.Errorf("expected watching interval after sync")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_DoSync_Error(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
sc, _ := newTestSyncClient(srv.URL)
|
||||||
|
|
||||||
|
// Should not panic on error
|
||||||
|
sc.doSync(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_Run_CancelStopsLoop(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(SyncResponse{})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
sc, _ := newTestSyncClient(srv.URL)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := sc.Run(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncClient_Run_ImmediateSyncOnTrigger(t *testing.T) {
|
||||||
|
var syncCount atomic.Int32
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
syncCount.Add(1)
|
||||||
|
json.NewEncoder(w).Encode(SyncResponse{})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
sc, _ := newTestSyncClient(srv.URL)
|
||||||
|
// Set interval to something long so only triggers cause syncs
|
||||||
|
sc.interval.Store(int64(10 * time.Second))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Wait for initial sync, then trigger 2 more
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
sc.TriggerSync()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
sc.TriggerSync()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
sc.Run(ctx)
|
||||||
|
|
||||||
|
// Initial sync (1) + 2 triggers + final sync = 4
|
||||||
|
count := syncCount.Load()
|
||||||
|
if count < 3 {
|
||||||
|
t.Errorf("expected at least 3 syncs (initial + 2 triggers), got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
136
internal/agent/taskstate.go
Normal file
136
internal/agent/taskstate.go
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/torrentclaw/unarr/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskState represents the execution state of a single download task.
|
||||||
|
// Written by the Task Engine, read by the Sync goroutine.
|
||||||
|
type TaskState struct {
|
||||||
|
TaskID string `json:"taskId"`
|
||||||
|
Status string `json:"status"` // resolving, downloading, verifying, organizing, completed, failed
|
||||||
|
Progress int `json:"progress"`
|
||||||
|
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
||||||
|
TotalBytes int64 `json:"totalBytes,omitempty"`
|
||||||
|
SpeedBps int64 `json:"speedBps,omitempty"`
|
||||||
|
ETA int `json:"eta,omitempty"`
|
||||||
|
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
||||||
|
FileName string `json:"fileName,omitempty"`
|
||||||
|
FilePath string `json:"filePath,omitempty"`
|
||||||
|
StreamURL string `json:"streamUrl,omitempty"`
|
||||||
|
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||||
|
UpdatedAt int64 `json:"updatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalState holds the CLI's local execution state (tasks.json).
|
||||||
|
// This is the CLI's source of truth for what it's doing right now.
|
||||||
|
type LocalState struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
tasks map[string]*TaskState
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLocalState creates an empty local state.
|
||||||
|
func NewLocalState() *LocalState {
|
||||||
|
return &LocalState{
|
||||||
|
tasks: make(map[string]*TaskState),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update adds or updates a task in local state.
|
||||||
|
func (s *LocalState) Update(ts TaskState) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
ts.UpdatedAt = time.Now().Unix()
|
||||||
|
copied := ts
|
||||||
|
s.tasks[ts.TaskID] = &copied
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes a task from local state.
|
||||||
|
func (s *LocalState) Remove(taskID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.tasks, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot returns a copy of all current task states.
|
||||||
|
func (s *LocalState) Snapshot() []TaskState {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
result := make([]TaskState, 0, len(s.tasks))
|
||||||
|
for _, ts := range s.tasks {
|
||||||
|
result = append(result, *ts)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskStateFromUpdate converts a StatusUpdate into a TaskState.
|
||||||
|
func TaskStateFromUpdate(u StatusUpdate) TaskState {
|
||||||
|
return TaskState{
|
||||||
|
TaskID: u.TaskID,
|
||||||
|
Status: u.Status,
|
||||||
|
Progress: u.Progress,
|
||||||
|
DownloadedBytes: u.DownloadedBytes,
|
||||||
|
TotalBytes: u.TotalBytes,
|
||||||
|
SpeedBps: u.SpeedBps,
|
||||||
|
ETA: u.ETA,
|
||||||
|
ResolvedMethod: u.ResolvedMethod,
|
||||||
|
FileName: u.FileName,
|
||||||
|
FilePath: u.FilePath,
|
||||||
|
StreamURL: u.StreamURL,
|
||||||
|
ErrorMessage: u.ErrorMessage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShortID returns the first 8 characters of an ID, or the full ID if shorter.
|
||||||
|
func ShortID(id string) string {
|
||||||
|
if len(id) > 8 {
|
||||||
|
return id[:8]
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskStateFilePathFn is overridable for testing.
|
||||||
|
var taskStateFilePathFn = func() string {
|
||||||
|
return filepath.Join(config.DataDir(), "tasks.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteToDisk persists local state to disk atomically (best-effort).
|
||||||
|
func (s *LocalState) WriteToDisk() {
|
||||||
|
tasks := s.Snapshot()
|
||||||
|
data, err := json.MarshalIndent(tasks, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
path := taskStateFilePathFn()
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
os.MkdirAll(dir, 0o755)
|
||||||
|
tmp := path + ".tmp"
|
||||||
|
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
os.Rename(tmp, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFromDisk loads local state from disk. Returns empty state on error.
|
||||||
|
func (s *LocalState) ReadFromDisk() {
|
||||||
|
data, err := os.ReadFile(taskStateFilePathFn())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var tasks []TaskState
|
||||||
|
if json.Unmarshal(data, &tasks) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.tasks = make(map[string]*TaskState, len(tasks))
|
||||||
|
for i := range tasks {
|
||||||
|
s.tasks[tasks[i].TaskID] = &tasks[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
217
internal/agent/taskstate_test.go
Normal file
217
internal/agent/taskstate_test.go
Normal file
|
|
@ -0,0 +1,217 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalState_UpdateAndSnapshot(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
|
||||||
|
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100})
|
||||||
|
|
||||||
|
snap := s.Snapshot()
|
||||||
|
if len(snap) != 2 {
|
||||||
|
t.Fatalf("expected 2 tasks, got %d", len(snap))
|
||||||
|
}
|
||||||
|
|
||||||
|
byID := make(map[string]TaskState, len(snap))
|
||||||
|
for _, ts := range snap {
|
||||||
|
byID[ts.TaskID] = ts
|
||||||
|
}
|
||||||
|
|
||||||
|
if byID["t1"].Progress != 50 {
|
||||||
|
t.Errorf("expected progress 50, got %d", byID["t1"].Progress)
|
||||||
|
}
|
||||||
|
if byID["t2"].Status != "completed" {
|
||||||
|
t.Errorf("expected completed, got %s", byID["t2"].Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_UpdateOverwrites(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 30})
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 70})
|
||||||
|
|
||||||
|
snap := s.Snapshot()
|
||||||
|
if len(snap) != 1 {
|
||||||
|
t.Fatalf("expected 1 task, got %d", len(snap))
|
||||||
|
}
|
||||||
|
if snap[0].Progress != 70 {
|
||||||
|
t.Errorf("expected progress 70, got %d", snap[0].Progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_Remove(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||||
|
s.Update(TaskState{TaskID: "t2", Status: "downloading"})
|
||||||
|
s.Remove("t1")
|
||||||
|
|
||||||
|
snap := s.Snapshot()
|
||||||
|
if len(snap) != 1 {
|
||||||
|
t.Fatalf("expected 1 task, got %d", len(snap))
|
||||||
|
}
|
||||||
|
if snap[0].TaskID != "t2" {
|
||||||
|
t.Errorf("expected t2, got %s", snap[0].TaskID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_RemoveNonExistent(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
s.Remove("nonexistent") // should not panic
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_SnapshotIsACopy(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
|
||||||
|
|
||||||
|
snap := s.Snapshot()
|
||||||
|
snap[0].Progress = 999
|
||||||
|
|
||||||
|
snap2 := s.Snapshot()
|
||||||
|
if snap2[0].Progress != 50 {
|
||||||
|
t.Errorf("snapshot mutation leaked: got progress %d", snap2[0].Progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_UpdateSetsTimestamp(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||||
|
|
||||||
|
snap := s.Snapshot()
|
||||||
|
if snap[0].UpdatedAt == 0 {
|
||||||
|
t.Error("expected non-zero UpdatedAt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_ConcurrentAccess(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := range 100 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(n int) {
|
||||||
|
defer wg.Done()
|
||||||
|
taskID := "t" + string(rune('0'+n%10))
|
||||||
|
s.Update(TaskState{TaskID: taskID, Status: "downloading", Progress: n})
|
||||||
|
s.Snapshot()
|
||||||
|
if n%3 == 0 {
|
||||||
|
s.Remove(taskID)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
// No race condition = test passes
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_WriteToDisk_ReadFromDisk(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "tasks.json")
|
||||||
|
|
||||||
|
// Override the file path for testing
|
||||||
|
orig := taskStateFilePathFn
|
||||||
|
taskStateFilePathFn = func() string { return path }
|
||||||
|
defer func() { taskStateFilePathFn = orig }()
|
||||||
|
|
||||||
|
s := NewLocalState()
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 45})
|
||||||
|
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100, FilePath: "/tmp/movie.mkv"})
|
||||||
|
s.WriteToDisk()
|
||||||
|
|
||||||
|
// Verify file exists
|
||||||
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
|
t.Fatal("tasks.json was not created")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read into a new LocalState
|
||||||
|
s2 := NewLocalState()
|
||||||
|
s2.ReadFromDisk()
|
||||||
|
|
||||||
|
snap := s2.Snapshot()
|
||||||
|
if len(snap) != 2 {
|
||||||
|
t.Fatalf("expected 2 tasks after read, got %d", len(snap))
|
||||||
|
}
|
||||||
|
|
||||||
|
byID := make(map[string]TaskState, len(snap))
|
||||||
|
for _, ts := range snap {
|
||||||
|
byID[ts.TaskID] = ts
|
||||||
|
}
|
||||||
|
|
||||||
|
if byID["t1"].Progress != 45 {
|
||||||
|
t.Errorf("expected progress 45, got %d", byID["t1"].Progress)
|
||||||
|
}
|
||||||
|
if byID["t2"].FilePath != "/tmp/movie.mkv" {
|
||||||
|
t.Errorf("expected /tmp/movie.mkv, got %s", byID["t2"].FilePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_ReadFromDisk_CorruptedFile(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "tasks.json")
|
||||||
|
|
||||||
|
orig := taskStateFilePathFn
|
||||||
|
taskStateFilePathFn = func() string { return path }
|
||||||
|
defer func() { taskStateFilePathFn = orig }()
|
||||||
|
|
||||||
|
// Write corrupted JSON
|
||||||
|
os.WriteFile(path, []byte("{invalid json"), 0o644)
|
||||||
|
|
||||||
|
s := NewLocalState()
|
||||||
|
s.ReadFromDisk() // should not panic
|
||||||
|
|
||||||
|
snap := s.Snapshot()
|
||||||
|
if len(snap) != 0 {
|
||||||
|
t.Errorf("expected 0 tasks from corrupted file, got %d", len(snap))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_ReadFromDisk_FileNotFound(t *testing.T) {
|
||||||
|
orig := taskStateFilePathFn
|
||||||
|
taskStateFilePathFn = func() string { return "/nonexistent/path/tasks.json" }
|
||||||
|
defer func() { taskStateFilePathFn = orig }()
|
||||||
|
|
||||||
|
s := NewLocalState()
|
||||||
|
s.ReadFromDisk() // should not panic
|
||||||
|
|
||||||
|
snap := s.Snapshot()
|
||||||
|
if len(snap) != 0 {
|
||||||
|
t.Errorf("expected 0 tasks, got %d", len(snap))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_AtomicWrite(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "tasks.json")
|
||||||
|
|
||||||
|
orig := taskStateFilePathFn
|
||||||
|
taskStateFilePathFn = func() string { return path }
|
||||||
|
defer func() { taskStateFilePathFn = orig }()
|
||||||
|
|
||||||
|
s := NewLocalState()
|
||||||
|
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||||
|
s.WriteToDisk()
|
||||||
|
|
||||||
|
// Verify no .tmp file remains
|
||||||
|
tmpPath := path + ".tmp"
|
||||||
|
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
|
||||||
|
t.Error("temp file should not exist after write")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalState_EmptySnapshot(t *testing.T) {
|
||||||
|
s := NewLocalState()
|
||||||
|
snap := s.Snapshot()
|
||||||
|
if snap == nil {
|
||||||
|
t.Error("snapshot should be non-nil empty slice")
|
||||||
|
}
|
||||||
|
if len(snap) != 0 {
|
||||||
|
t.Errorf("expected 0 tasks, got %d", len(snap))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,51 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
// Transport abstracts the communication protocol between the agent and server.
|
|
||||||
// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this.
|
|
||||||
type Transport interface {
|
|
||||||
// Connect establishes the transport connection.
|
|
||||||
// Called internally by Daemon.Run — callers must NOT call Connect separately.
|
|
||||||
Connect(ctx context.Context) error
|
|
||||||
|
|
||||||
// Close tears down the connection gracefully.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// Mode returns the current transport mode ("ws" or "http").
|
|
||||||
Mode() string
|
|
||||||
|
|
||||||
// Register sends agent registration and returns user info + features.
|
|
||||||
Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error)
|
|
||||||
|
|
||||||
// SendHeartbeat sends a periodic keep-alive.
|
|
||||||
SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error)
|
|
||||||
|
|
||||||
// SendProgress reports download progress for a task.
|
|
||||||
SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error)
|
|
||||||
|
|
||||||
// ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events).
|
|
||||||
ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error)
|
|
||||||
|
|
||||||
// Deregister notifies the server of graceful shutdown.
|
|
||||||
Deregister(ctx context.Context, agentID string) error
|
|
||||||
|
|
||||||
// Events returns a channel that emits server-initiated events.
|
|
||||||
// In HTTP mode this channel is never written to (polling handles it).
|
|
||||||
// In WS mode, tasks/upgrade/control arrive here.
|
|
||||||
Events() <-chan ServerEvent
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerEvent represents a server-initiated message received via WebSocket.
|
|
||||||
type ServerEvent struct {
|
|
||||||
Type string // "tasks", "upgrade", "control", "disconnected"
|
|
||||||
Tasks *TasksResponse // populated when Type == "tasks"
|
|
||||||
Upgrade *UpgradeSignal // populated when Type == "upgrade"
|
|
||||||
Control *ControlAction // populated when Type == "control"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ControlAction represents a server push for task control.
|
|
||||||
type ControlAction struct {
|
|
||||||
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
|
|
||||||
TaskID string `json:"taskId"`
|
|
||||||
}
|
|
||||||
|
|
@ -1,285 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestE2EFullLifecycle tests the full lifecycle:
|
|
||||||
// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect
|
|
||||||
func TestE2EFullLifecycle(t *testing.T) {
|
|
||||||
var mu sync.Mutex
|
|
||||||
var receivedMessages []map[string]interface{}
|
|
||||||
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
for {
|
|
||||||
_, msg, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var parsed map[string]interface{}
|
|
||||||
json.Unmarshal(msg, &parsed)
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
receivedMessages = append(receivedMessages, parsed)
|
|
||||||
mu.Unlock()
|
|
||||||
|
|
||||||
msgType, _ := parsed["type"].(string)
|
|
||||||
switch msgType {
|
|
||||||
case "auth":
|
|
||||||
conn.WriteJSON(wsRegisteredMessage{
|
|
||||||
Type: "registered",
|
|
||||||
User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
|
|
||||||
Features: FeatureFlags{Torrent: true, Debrid: true},
|
|
||||||
})
|
|
||||||
|
|
||||||
case "heartbeat":
|
|
||||||
// No response in WS mode
|
|
||||||
|
|
||||||
case "progress":
|
|
||||||
// Simulate server-side cancel after progress
|
|
||||||
if progress, ok := parsed["progress"].(float64); ok && progress >= 50 {
|
|
||||||
conn.WriteJSON(map[string]string{
|
|
||||||
"type": "control",
|
|
||||||
"action": "cancel",
|
|
||||||
"taskId": parsed["taskId"].(string),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
case "upgrade-result":
|
|
||||||
// Acknowledged
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
||||||
tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 1. Connect
|
|
||||||
if err := tr.Connect(ctx); err != nil {
|
|
||||||
t.Fatalf("Connect: %v", err)
|
|
||||||
}
|
|
||||||
defer tr.Close()
|
|
||||||
|
|
||||||
// 2. Auth
|
|
||||||
resp, err := tr.Register(ctx, RegisterRequest{
|
|
||||||
AgentID: "e2e-agent",
|
|
||||||
Name: "E2E Test Agent",
|
|
||||||
Version: "1.0.0",
|
|
||||||
OS: "linux",
|
|
||||||
Arch: "amd64",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Register: %v", err)
|
|
||||||
}
|
|
||||||
if resp.User.Name != "E2E User" {
|
|
||||||
t.Errorf("expected E2E User, got %s", resp.User.Name)
|
|
||||||
}
|
|
||||||
if !resp.Features.Debrid {
|
|
||||||
t.Error("expected debrid feature")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Send heartbeat
|
|
||||||
_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
|
|
||||||
AgentID: "e2e-agent",
|
|
||||||
DiskFreeBytes: 1000000000,
|
|
||||||
DiskTotalBytes: 5000000000,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("SendHeartbeat: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. Send progress (50% → should trigger cancel control)
|
|
||||||
_, err = tr.SendProgress(ctx, StatusUpdate{
|
|
||||||
TaskID: "task-e2e-1",
|
|
||||||
Status: "downloading",
|
|
||||||
Progress: 50,
|
|
||||||
DownloadedBytes: 500,
|
|
||||||
TotalBytes: 1000,
|
|
||||||
SpeedBps: 100,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("SendProgress: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. Wait for control event (cancel)
|
|
||||||
select {
|
|
||||||
case event := <-tr.Events():
|
|
||||||
if event.Type != "control" {
|
|
||||||
t.Errorf("expected control event, got %s", event.Type)
|
|
||||||
}
|
|
||||||
if event.Control.Action != "cancel" {
|
|
||||||
t.Errorf("expected cancel, got %s", event.Control.Action)
|
|
||||||
}
|
|
||||||
if event.Control.TaskID != "task-e2e-1" {
|
|
||||||
t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
|
|
||||||
}
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for cancel control")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify server received all messages
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
if len(receivedMessages) < 3 {
|
|
||||||
t.Fatalf("expected at least 3 messages, got %d", len(receivedMessages))
|
|
||||||
}
|
|
||||||
|
|
||||||
types := make([]string, len(receivedMessages))
|
|
||||||
for i, m := range receivedMessages {
|
|
||||||
types[i], _ = m["type"].(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := []string{"auth", "heartbeat", "progress"}
|
|
||||||
for _, exp := range expected {
|
|
||||||
found := false
|
|
||||||
for _, got := range types {
|
|
||||||
if got == exp {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Errorf("missing message type %q in %v", exp, types)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestE2EHybridFailover tests the full failover scenario:
|
|
||||||
// WS connect → download → WS disconnect → switch to HTTP → continue working
|
|
||||||
func TestE2EHybridFailover(t *testing.T) {
|
|
||||||
connectionCount := 0
|
|
||||||
var mu sync.Mutex
|
|
||||||
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
connectionCount++
|
|
||||||
connNum := connectionCount
|
|
||||||
mu.Unlock()
|
|
||||||
|
|
||||||
// Read auth
|
|
||||||
conn.ReadMessage()
|
|
||||||
conn.WriteJSON(wsRegisteredMessage{
|
|
||||||
Type: "registered",
|
|
||||||
User: UserInfo{Name: "Failover User"},
|
|
||||||
})
|
|
||||||
|
|
||||||
if connNum == 1 {
|
|
||||||
// First connection: push tasks then disconnect after 200ms
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
conn.WriteJSON(wsTasksMessage{
|
|
||||||
Type: "tasks",
|
|
||||||
Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
|
|
||||||
})
|
|
||||||
time.Sleep(150 * time.Millisecond)
|
|
||||||
conn.Close()
|
|
||||||
} else {
|
|
||||||
// Second connection (after reconnect): push upgrade
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
||||||
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
|
||||||
|
|
||||||
// HTTP mock for fallback
|
|
||||||
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Simple heartbeat response
|
|
||||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
|
||||||
}))
|
|
||||||
defer httpSrv.Close()
|
|
||||||
|
|
||||||
httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
|
|
||||||
h := NewHybridTransport(wsT, httpT)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
err := h.Connect(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Connect: %v", err)
|
|
||||||
}
|
|
||||||
defer h.Close()
|
|
||||||
|
|
||||||
// Should start in WS mode
|
|
||||||
if h.Mode() != "ws" {
|
|
||||||
t.Fatalf("expected ws mode, got %s", h.Mode())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register via WS
|
|
||||||
_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Register: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive tasks via WS
|
|
||||||
var tasksReceived bool
|
|
||||||
var disconnected bool
|
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
select {
|
|
||||||
case event := <-h.Events():
|
|
||||||
switch event.Type {
|
|
||||||
case "tasks":
|
|
||||||
tasksReceived = true
|
|
||||||
if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
|
|
||||||
t.Errorf("unexpected tasks: %+v", event.Tasks)
|
|
||||||
}
|
|
||||||
case "disconnected":
|
|
||||||
disconnected = true
|
|
||||||
}
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if disconnected {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tasksReceived {
|
|
||||||
t.Error("did not receive tasks before disconnect")
|
|
||||||
}
|
|
||||||
if !disconnected {
|
|
||||||
t.Error("did not receive disconnect event")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should now be in HTTP mode
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
if h.Mode() != "http" {
|
|
||||||
t.Errorf("expected http mode after disconnect, got %s", h.Mode())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Heartbeat should work via HTTP fallback
|
|
||||||
hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
|
|
||||||
}
|
|
||||||
if !hbResp.Success {
|
|
||||||
t.Error("expected heartbeat success")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
// HTTPTransport wraps the existing Client to implement Transport.
|
|
||||||
// This is a thin adapter — no behavioral changes from the current HTTP protocol.
|
|
||||||
type HTTPTransport struct {
|
|
||||||
client *Client
|
|
||||||
events chan ServerEvent
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHTTPTransport creates a new HTTP-based transport.
|
|
||||||
func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport {
|
|
||||||
return &HTTPTransport{
|
|
||||||
client: NewClient(baseURL, apiKey, userAgent),
|
|
||||||
events: make(chan ServerEvent, 10),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTransport) Connect(_ context.Context) error { return nil }
|
|
||||||
func (t *HTTPTransport) Close() error { return nil }
|
|
||||||
func (t *HTTPTransport) Mode() string { return "http" }
|
|
||||||
func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events }
|
|
||||||
|
|
||||||
func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
|
||||||
return t.client.Register(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
|
||||||
return t.client.Heartbeat(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
|
||||||
return t.client.ReportStatus(ctx, update)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTransport) BatchReportStatus(ctx context.Context, updates []StatusUpdate) (*BatchStatusResponse, error) {
|
|
||||||
return t.client.BatchReportStatus(ctx, updates)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
|
||||||
return t.client.ClaimTasks(ctx, agentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error {
|
|
||||||
return t.client.Deregister(ctx, agentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client returns the underlying HTTP client for direct use if needed.
|
|
||||||
func (t *HTTPTransport) Client() *Client { return t.client }
|
|
||||||
|
|
@ -1,214 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
|
|
||||||
// Automatically reconnects WS in the background.
|
|
||||||
type HybridTransport struct {
|
|
||||||
ws *WSTransport
|
|
||||||
http *HTTPTransport
|
|
||||||
|
|
||||||
mode atomic.Value // "ws" or "http"
|
|
||||||
events chan ServerEvent
|
|
||||||
|
|
||||||
reconnectMu sync.Mutex
|
|
||||||
reconnectRunning bool
|
|
||||||
reconnectStop chan struct{}
|
|
||||||
closed atomic.Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHybridTransport creates a transport that prefers WS with HTTP fallback.
|
|
||||||
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
|
|
||||||
h := &HybridTransport{
|
|
||||||
ws: ws,
|
|
||||||
http: http,
|
|
||||||
events: make(chan ServerEvent, 50),
|
|
||||||
reconnectStop: make(chan struct{}),
|
|
||||||
}
|
|
||||||
h.mode.Store("http") // start in HTTP, upgrade to WS on Connect
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HybridTransport) Mode() string { return h.mode.Load().(string) }
|
|
||||||
func (h *HybridTransport) Events() <-chan ServerEvent { return h.events }
|
|
||||||
|
|
||||||
// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop.
|
|
||||||
func (h *HybridTransport) Connect(ctx context.Context) error {
|
|
||||||
// Try WebSocket first
|
|
||||||
if err := h.ws.Connect(ctx); err != nil {
|
|
||||||
log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err)
|
|
||||||
h.mode.Store("http")
|
|
||||||
h.startReconnectLoop()
|
|
||||||
return h.http.Connect(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mode.Store("ws")
|
|
||||||
log.Println("[transport] Connected via WebSocket")
|
|
||||||
|
|
||||||
// Forward WS events to unified channel + watch for disconnection
|
|
||||||
go h.forwardWSEvents()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close shuts down both transports and stops reconnection.
|
|
||||||
func (h *HybridTransport) Close() error {
|
|
||||||
h.closed.Store(true)
|
|
||||||
select {
|
|
||||||
case <-h.reconnectStop:
|
|
||||||
default:
|
|
||||||
close(h.reconnectStop)
|
|
||||||
}
|
|
||||||
_ = h.ws.Close()
|
|
||||||
return h.http.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register delegates to the active transport.
|
|
||||||
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
|
||||||
if h.mode.Load() == "ws" {
|
|
||||||
return h.ws.Register(ctx, req)
|
|
||||||
}
|
|
||||||
return h.http.Register(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendHeartbeat delegates to the active transport.
|
|
||||||
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
|
||||||
if h.mode.Load() == "ws" {
|
|
||||||
resp, err := h.ws.SendHeartbeat(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
// WS write failed — switch to HTTP
|
|
||||||
h.switchToHTTP()
|
|
||||||
return h.http.SendHeartbeat(ctx, req)
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
return h.http.SendHeartbeat(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendProgress delegates to the active transport.
|
|
||||||
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
|
||||||
if h.mode.Load() == "ws" {
|
|
||||||
resp, err := h.ws.SendProgress(ctx, update)
|
|
||||||
if err != nil {
|
|
||||||
h.switchToHTTP()
|
|
||||||
return h.http.SendProgress(ctx, update)
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
return h.http.SendProgress(ctx, update)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClaimTasks delegates to the active transport.
|
|
||||||
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
|
||||||
if h.mode.Load() == "ws" {
|
|
||||||
return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
|
|
||||||
}
|
|
||||||
return h.http.ClaimTasks(ctx, agentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deregister delegates to the active transport.
|
|
||||||
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
|
|
||||||
if h.mode.Load() == "ws" {
|
|
||||||
return h.ws.Deregister(ctx, agentID)
|
|
||||||
}
|
|
||||||
return h.http.Deregister(ctx, agentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Internal ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
func (h *HybridTransport) switchToHTTP() {
|
|
||||||
if h.mode.Load() == "http" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Println("[transport] Switching to HTTP fallback")
|
|
||||||
h.mode.Store("http")
|
|
||||||
_ = h.ws.Close()
|
|
||||||
h.startReconnectLoop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HybridTransport) forwardWSEvents() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-h.reconnectStop:
|
|
||||||
return
|
|
||||||
case event, ok := <-h.ws.Events():
|
|
||||||
if !ok {
|
|
||||||
return // channel closed
|
|
||||||
}
|
|
||||||
if event.Type == "disconnected" {
|
|
||||||
h.switchToHTTP()
|
|
||||||
select {
|
|
||||||
case h.events <- event:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case h.events <- event:
|
|
||||||
default:
|
|
||||||
log.Printf("[transport] events channel full, dropping %s event", event.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HybridTransport) startReconnectLoop() {
|
|
||||||
h.reconnectMu.Lock()
|
|
||||||
defer h.reconnectMu.Unlock()
|
|
||||||
if h.reconnectRunning {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.reconnectRunning = true
|
|
||||||
go h.reconnectLoop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HybridTransport) reconnectLoop() {
|
|
||||||
backoff := 5 * time.Second
|
|
||||||
maxBackoff := 60 * time.Second
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-h.reconnectStop:
|
|
||||||
return
|
|
||||||
case <-time.After(backoff):
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.closed.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Already on WS? (someone else reconnected)
|
|
||||||
if h.mode.Load() == "ws" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
err := h.ws.Connect(ctx)
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
|
|
||||||
backoff = min(backoff*2, maxBackoff)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// WS reconnected — switch back
|
|
||||||
log.Println("[transport] WebSocket reconnected")
|
|
||||||
h.mode.Store("ws")
|
|
||||||
|
|
||||||
// Reset reconnect flag so loop can start again if WS drops
|
|
||||||
h.reconnectMu.Lock()
|
|
||||||
h.reconnectRunning = false
|
|
||||||
h.reconnectMu.Unlock()
|
|
||||||
|
|
||||||
// Forward events from new WS connection
|
|
||||||
go h.forwardWSEvents()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,395 +0,0 @@
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
|
|
||||||
type WSTransport struct {
|
|
||||||
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
|
|
||||||
apiKey string
|
|
||||||
agentID string
|
|
||||||
userAgent string
|
|
||||||
|
|
||||||
conn *websocket.Conn
|
|
||||||
mu sync.Mutex
|
|
||||||
events chan ServerEvent
|
|
||||||
closed atomic.Bool
|
|
||||||
|
|
||||||
// Cached auth response from the DO
|
|
||||||
authResp *RegisterResponse
|
|
||||||
authMu sync.Mutex
|
|
||||||
authDone chan struct{}
|
|
||||||
authDoneOnce sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWSTransport creates a WebSocket-based transport.
|
|
||||||
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
|
|
||||||
return &WSTransport{
|
|
||||||
wsURL: wsURL,
|
|
||||||
apiKey: apiKey,
|
|
||||||
agentID: agentID,
|
|
||||||
userAgent: userAgent,
|
|
||||||
events: make(chan ServerEvent, 50),
|
|
||||||
authDone: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WSTransport) Mode() string { return "ws" }
|
|
||||||
func (t *WSTransport) Events() <-chan ServerEvent { return t.events }
|
|
||||||
|
|
||||||
// Connect dials the WebSocket server and starts the read loop.
|
|
||||||
func (t *WSTransport) Connect(ctx context.Context) error {
|
|
||||||
dialer := websocket.Dialer{
|
|
||||||
HandshakeTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
header := http.Header{}
|
|
||||||
header.Set("User-Agent", t.userAgent)
|
|
||||||
|
|
||||||
// Append API key as query param for auth on WS upgrade
|
|
||||||
wsURLWithKey := t.wsURL
|
|
||||||
if t.apiKey != "" {
|
|
||||||
sep := "?"
|
|
||||||
if strings.Contains(wsURLWithKey, "?") {
|
|
||||||
sep = "&"
|
|
||||||
}
|
|
||||||
wsURLWithKey += sep + "key=" + t.apiKey
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, wsResp, err := dialer.DialContext(ctx, wsURLWithKey, header)
|
|
||||||
if wsResp != nil && wsResp.Body != nil {
|
|
||||||
defer wsResp.Body.Close()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ws dial: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
t.conn = conn
|
|
||||||
t.closed.Store(false)
|
|
||||||
t.authDone = make(chan struct{})
|
|
||||||
t.authDoneOnce = sync.Once{}
|
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
go t.readLoop(conn)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close sends a close frame and shuts down the connection.
|
|
||||||
func (t *WSTransport) Close() error {
|
|
||||||
t.closed.Store(true)
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
if t.conn != nil {
|
|
||||||
_ = t.conn.WriteMessage(
|
|
||||||
websocket.CloseMessage,
|
|
||||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
|
||||||
)
|
|
||||||
err := t.conn.Close()
|
|
||||||
t.conn = nil
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register sends auth message and waits for the registered response.
|
|
||||||
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
|
||||||
msg := wsAuthMessage{
|
|
||||||
Type: "auth",
|
|
||||||
APIKey: t.apiKey,
|
|
||||||
AgentID: req.AgentID,
|
|
||||||
Name: req.Name,
|
|
||||||
OS: req.OS,
|
|
||||||
Arch: req.Arch,
|
|
||||||
Version: req.Version,
|
|
||||||
DownloadDir: req.DownloadDir,
|
|
||||||
DiskFreeBytes: req.DiskFreeBytes,
|
|
||||||
DiskTotalBytes: req.DiskTotalBytes,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := t.send(msg); err != nil {
|
|
||||||
return nil, fmt.Errorf("ws auth send: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the auth response or context cancellation
|
|
||||||
select {
|
|
||||||
case <-t.authDone:
|
|
||||||
t.authMu.Lock()
|
|
||||||
resp := t.authResp
|
|
||||||
t.authMu.Unlock()
|
|
||||||
if resp == nil {
|
|
||||||
return nil, fmt.Errorf("ws auth: no response received")
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-time.After(15 * time.Second):
|
|
||||||
return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
|
|
||||||
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
|
||||||
msg := struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Disk *struct {
|
|
||||||
Free int64 `json:"free"`
|
|
||||||
Total int64 `json:"total"`
|
|
||||||
} `json:"disk,omitempty"`
|
|
||||||
}{Type: "heartbeat"}
|
|
||||||
|
|
||||||
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
|
|
||||||
msg.Disk = &struct {
|
|
||||||
Free int64 `json:"free"`
|
|
||||||
Total int64 `json:"total"`
|
|
||||||
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := t.send(msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
|
|
||||||
return &HeartbeatResponse{Success: true}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendProgress sends a progress update. Control signals arrive async via Events().
|
|
||||||
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
|
|
||||||
msg := struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
TaskID string `json:"taskId"`
|
|
||||||
Status string `json:"status,omitempty"`
|
|
||||||
Progress int `json:"progress,omitempty"`
|
|
||||||
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
|
||||||
TotalBytes int64 `json:"totalBytes,omitempty"`
|
|
||||||
SpeedBps int64 `json:"speedBps,omitempty"`
|
|
||||||
ETA int `json:"eta,omitempty"`
|
|
||||||
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
|
||||||
FileName string `json:"fileName,omitempty"`
|
|
||||||
FilePath string `json:"filePath,omitempty"`
|
|
||||||
StreamURL string `json:"streamUrl,omitempty"`
|
|
||||||
StreamReady bool `json:"streamReady,omitempty"`
|
|
||||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
|
||||||
}{
|
|
||||||
Type: "progress",
|
|
||||||
TaskID: update.TaskID,
|
|
||||||
Status: update.Status,
|
|
||||||
Progress: update.Progress,
|
|
||||||
DownloadedBytes: update.DownloadedBytes,
|
|
||||||
TotalBytes: update.TotalBytes,
|
|
||||||
SpeedBps: update.SpeedBps,
|
|
||||||
ETA: update.ETA,
|
|
||||||
ResolvedMethod: update.ResolvedMethod,
|
|
||||||
FileName: update.FileName,
|
|
||||||
FilePath: update.FilePath,
|
|
||||||
StreamURL: update.StreamURL,
|
|
||||||
StreamReady: update.StreamReady,
|
|
||||||
ErrorMessage: update.ErrorMessage,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := t.send(msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// In WS mode, control signals come via Events(), not in the progress response.
|
|
||||||
return &StatusResponse{Success: true}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
|
|
||||||
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
|
|
||||||
return &TasksResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deregister is handled by WebSocket close (DO detects disconnection).
|
|
||||||
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
|
|
||||||
return t.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Internal ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
func (t *WSTransport) send(msg any) error {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
if t.conn == nil {
|
|
||||||
return fmt.Errorf("ws: not connected")
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
return t.conn.WriteMessage(websocket.TextMessage, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WSTransport) readLoop(conn *websocket.Conn) {
|
|
||||||
// Cloudflare idle timeout is 100s. We send pings every 30s and expect
|
|
||||||
// either a pong or a server message within 45s. If neither arrives,
|
|
||||||
// the read deadline fires and we detect the zombie connection.
|
|
||||||
const (
|
|
||||||
pongWait = 45 * time.Second
|
|
||||||
pingPeriod = 30 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
||||||
conn.SetPongHandler(func(string) error {
|
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// Ping ticker goroutine — stops when readLoop returns.
|
|
||||||
pingDone := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(pingPeriod)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
t.mu.Lock()
|
|
||||||
if t.conn != nil {
|
|
||||||
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
err := t.conn.WriteMessage(websocket.PingMessage, nil)
|
|
||||||
_ = t.conn.SetWriteDeadline(time.Time{})
|
|
||||||
if err != nil {
|
|
||||||
t.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.mu.Unlock()
|
|
||||||
case <-pingDone:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
defer close(pingDone)
|
|
||||||
|
|
||||||
for {
|
|
||||||
_, msg, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if !t.closed.Load() {
|
|
||||||
log.Printf("[ws] read error: %v", err)
|
|
||||||
// Signal disconnection to the daemon
|
|
||||||
select {
|
|
||||||
case t.events <- ServerEvent{Type: "disconnected"}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Any message (text or pong) proves the connection is alive.
|
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
||||||
|
|
||||||
var envelope struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(msg, &envelope); err != nil {
|
|
||||||
log.Printf("[ws] invalid message: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch envelope.Type {
|
|
||||||
case "registered":
|
|
||||||
var resp wsRegisteredMessage
|
|
||||||
if json.Unmarshal(msg, &resp) == nil {
|
|
||||||
t.authMu.Lock()
|
|
||||||
t.authResp = &RegisterResponse{
|
|
||||||
Success: true,
|
|
||||||
User: resp.User,
|
|
||||||
Features: resp.Features,
|
|
||||||
}
|
|
||||||
t.authMu.Unlock()
|
|
||||||
// Signal that auth is complete (sync.Once prevents double-close panic)
|
|
||||||
t.authDoneOnce.Do(func() { close(t.authDone) })
|
|
||||||
}
|
|
||||||
|
|
||||||
case "tasks":
|
|
||||||
var resp wsTasksMessage
|
|
||||||
if json.Unmarshal(msg, &resp) == nil {
|
|
||||||
select {
|
|
||||||
case t.events <- ServerEvent{
|
|
||||||
Type: "tasks",
|
|
||||||
Tasks: &TasksResponse{
|
|
||||||
Tasks: resp.Tasks,
|
|
||||||
StreamRequests: resp.StreamRequests,
|
|
||||||
},
|
|
||||||
}:
|
|
||||||
default:
|
|
||||||
log.Printf("[ws] events channel full, dropping tasks message")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "upgrade":
|
|
||||||
var resp wsUpgradeMessage
|
|
||||||
if json.Unmarshal(msg, &resp) == nil {
|
|
||||||
select {
|
|
||||||
case t.events <- ServerEvent{
|
|
||||||
Type: "upgrade",
|
|
||||||
Upgrade: &UpgradeSignal{Version: resp.Version},
|
|
||||||
}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "control":
|
|
||||||
var resp ControlAction
|
|
||||||
if json.Unmarshal(msg, &resp) == nil {
|
|
||||||
select {
|
|
||||||
case t.events <- ServerEvent{
|
|
||||||
Type: "control",
|
|
||||||
Control: &resp,
|
|
||||||
}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "error":
|
|
||||||
var resp struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
if json.Unmarshal(msg, &resp) == nil {
|
|
||||||
log.Printf("[ws] server error: %s", resp.Message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── WS message types ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
type wsAuthMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
APIKey string `json:"apiKey"`
|
|
||||||
AgentID string `json:"agentId"`
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
OS string `json:"os,omitempty"`
|
|
||||||
Arch string `json:"arch,omitempty"`
|
|
||||||
Version string `json:"version,omitempty"`
|
|
||||||
DownloadDir string `json:"downloadDir,omitempty"`
|
|
||||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
|
||||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type wsRegisteredMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
User UserInfo `json:"user"`
|
|
||||||
Features FeatureFlags `json:"features"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type wsTasksMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Tasks []Task `json:"tasks"`
|
|
||||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type wsUpgradeMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Version string `json:"version"`
|
|
||||||
}
|
|
||||||
|
|
@ -50,20 +50,6 @@ type UsenetServerInfo struct {
|
||||||
SSL bool `json:"ssl"`
|
SSL bool `json:"ssl"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// HeartbeatRequest is sent every 30s to keep the agent alive.
|
|
||||||
type HeartbeatRequest struct {
|
|
||||||
AgentID string `json:"agentId"`
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
OS string `json:"os,omitempty"`
|
|
||||||
Version string `json:"version,omitempty"`
|
|
||||||
DownloadDir string `json:"downloadDir,omitempty"`
|
|
||||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
|
||||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
|
||||||
StreamPort int `json:"streamPort,omitempty"`
|
|
||||||
LanIP string `json:"lanIp,omitempty"`
|
|
||||||
TailscaleIP string `json:"tailscaleIp,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Task represents a download task claimed from the server.
|
// Task represents a download task claimed from the server.
|
||||||
type Task struct {
|
type Task struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
|
@ -88,12 +74,6 @@ type Task struct {
|
||||||
CollectionName string `json:"collectionName,omitempty"` // Collection name (e.g., "Harry Potter Collection")
|
CollectionName string `json:"collectionName,omitempty"` // Collection name (e.g., "Harry Potter Collection")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TasksResponse wraps the array of tasks returned by the server.
|
|
||||||
type TasksResponse struct {
|
|
||||||
Tasks []Task `json:"tasks"`
|
|
||||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// StreamRequest is a request to stream a completed download from disk.
|
// StreamRequest is a request to stream a completed download from disk.
|
||||||
type StreamRequest struct {
|
type StreamRequest struct {
|
||||||
TaskID string `json:"taskId"`
|
TaskID string `json:"taskId"`
|
||||||
|
|
@ -139,14 +119,6 @@ type BatchStatusResponse struct {
|
||||||
Watching bool `json:"watching,omitempty"`
|
Watching bool `json:"watching,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// HeartbeatResponse is returned by the server on heartbeat.
|
|
||||||
type HeartbeatResponse struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
|
|
||||||
Watching bool `json:"watching,omitempty"` // true when a user is viewing download progress in the web UI
|
|
||||||
Scan bool `json:"scan,omitempty"` // true when user triggered a library scan from the web UI
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpgradeSignal tells the agent to upgrade to a specific version.
|
// UpgradeSignal tells the agent to upgrade to a specific version.
|
||||||
type UpgradeSignal struct {
|
type UpgradeSignal struct {
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
|
|
@ -176,7 +148,6 @@ type AgentInfo struct {
|
||||||
User UserInfo
|
User UserInfo
|
||||||
Features FeatureFlags
|
Features FeatureFlags
|
||||||
StartedAt time.Time
|
StartedAt time.Time
|
||||||
LastPollAt time.Time
|
|
||||||
ActiveTasks int
|
ActiveTasks int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -334,6 +305,45 @@ type LibrarySyncResponse struct {
|
||||||
Removed int `json:"removed"`
|
Removed int `json:"removed"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Sync types (unified CLI ↔ Server communication)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// SyncRequest is sent by the CLI periodically to synchronize state with the server.
|
||||||
|
// Contains the CLI's full execution state — the server responds with pending actions.
|
||||||
|
type SyncRequest struct {
|
||||||
|
AgentID string `json:"agentId"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
OS string `json:"os,omitempty"`
|
||||||
|
Arch string `json:"arch,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
DownloadDir string `json:"downloadDir,omitempty"`
|
||||||
|
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||||
|
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||||
|
StreamPort int `json:"streamPort,omitempty"`
|
||||||
|
LanIP string `json:"lanIp,omitempty"`
|
||||||
|
TailscaleIP string `json:"tailscaleIp,omitempty"`
|
||||||
|
FreeSlots int `json:"freeSlots"`
|
||||||
|
Tasks []TaskState `json:"tasks"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ControlAction represents a server-side control signal for a task.
|
||||||
|
type ControlAction struct {
|
||||||
|
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
|
||||||
|
TaskID string `json:"taskId"`
|
||||||
|
DeleteFiles bool `json:"deleteFiles,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncResponse is returned by the server with all pending actions for the CLI.
|
||||||
|
type SyncResponse struct {
|
||||||
|
NewTasks []Task `json:"newTasks,omitempty"`
|
||||||
|
Controls []ControlAction `json:"controls,omitempty"`
|
||||||
|
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||||
|
Watching bool `json:"watching"`
|
||||||
|
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
|
||||||
|
Scan bool `json:"scan,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Watch progress types (used by stream tracking)
|
// Watch progress types (used by stream tracking)
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -311,21 +311,10 @@ func configConnection(cfg *config.Config) error {
|
||||||
).Run()
|
).Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
func configAdvanced(cfg *config.Config) error {
|
func configAdvanced(_ *config.Config) error {
|
||||||
return huh.NewForm(
|
// Sync intervals are adaptive (3s watching, 60s idle) — no user-facing config needed.
|
||||||
huh.NewGroup(
|
fmt.Println("No advanced settings to configure. Sync intervals are automatic.")
|
||||||
huh.NewInput().
|
return nil
|
||||||
Title("Poll interval").
|
|
||||||
Description("How often to check for new tasks (e.g. 30s, 1m)").
|
|
||||||
Value(&cfg.Daemon.PollInterval).
|
|
||||||
Validate(validateDuration),
|
|
||||||
huh.NewInput().
|
|
||||||
Title("Heartbeat interval").
|
|
||||||
Description("How often to send heartbeat to server (e.g. 30s, 1m)").
|
|
||||||
Value(&cfg.Daemon.HeartbeatInterval).
|
|
||||||
Validate(validateDuration),
|
|
||||||
),
|
|
||||||
).Run()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Validators ──────────────────────────────────────────────────────
|
// ── Validators ──────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -27,13 +26,13 @@ func newStartCmd() *cobra.Command {
|
||||||
Short: "Start the download daemon (foreground)",
|
Short: "Start the download daemon (foreground)",
|
||||||
Long: `Start the unarr daemon in the foreground.
|
Long: `Start the unarr daemon in the foreground.
|
||||||
|
|
||||||
Registers with the server, receives download tasks via WebSocket (with
|
Registers with the server, receives download tasks via periodic sync,
|
||||||
HTTP fallback), and executes them using the configured download method.
|
and executes them using the configured download method.
|
||||||
Supports torrent, debrid, and usenet downloads concurrently.
|
Supports torrent, debrid, and usenet downloads concurrently.
|
||||||
|
|
||||||
The daemon sends periodic heartbeats and reports download progress back
|
The daemon syncs state with the server every 3s when someone is viewing
|
||||||
to the web dashboard. Press Ctrl+C to stop gracefully — active downloads
|
the web dashboard, or every 60s when idle. Press Ctrl+C to stop
|
||||||
get up to 30 seconds to finish.
|
gracefully — active downloads get up to 30 seconds to finish.
|
||||||
|
|
||||||
Requires: API key, agent ID, and download directory (run 'unarr init' first).
|
Requires: API key, agent ID, and download directory (run 'unarr init' first).
|
||||||
|
|
||||||
|
|
@ -127,85 +126,59 @@ func runDaemonStart() error {
|
||||||
bold.Println(" unarr Daemon")
|
bold.Println(" unarr Daemon")
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
// Parse intervals
|
|
||||||
pollInterval, _ := time.ParseDuration(cfg.Daemon.PollInterval)
|
|
||||||
if pollInterval == 0 {
|
|
||||||
pollInterval = 30 * time.Second
|
|
||||||
}
|
|
||||||
heartbeatInterval, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval)
|
|
||||||
if heartbeatInterval == 0 {
|
|
||||||
heartbeatInterval = 30 * time.Second
|
|
||||||
}
|
|
||||||
statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval)
|
|
||||||
if statusInterval == 0 {
|
|
||||||
statusInterval = 3 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
userAgent := "unarr/" + Version
|
userAgent := "unarr/" + Version
|
||||||
|
|
||||||
// Create daemon config
|
// Create daemon config
|
||||||
daemonCfg := agent.DaemonConfig{
|
daemonCfg := agent.DaemonConfig{
|
||||||
AgentID: cfg.Agent.ID,
|
AgentID: cfg.Agent.ID,
|
||||||
AgentName: cfg.Agent.Name,
|
AgentName: cfg.Agent.Name,
|
||||||
Version: Version,
|
Version: Version,
|
||||||
DownloadDir: cfg.Download.Dir,
|
DownloadDir: cfg.Download.Dir,
|
||||||
PollInterval: pollInterval,
|
StreamPort: cfg.Download.StreamPort,
|
||||||
HeartbeatInterval: heartbeatInterval,
|
LanIP: engine.LanIP(),
|
||||||
StreamPort: cfg.Download.StreamPort,
|
TailscaleIP: engine.TailscaleIP(),
|
||||||
LanIP: engine.LanIP(),
|
|
||||||
TailscaleIP: engine.TailscaleIP(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
|
// Create HTTP client — single communication channel
|
||||||
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
|
||||||
|
|
||||||
wsURL := cfg.Auth.WSURL
|
|
||||||
if wsURL == "" {
|
|
||||||
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
var transport agent.Transport
|
|
||||||
if wsURL != "" {
|
|
||||||
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
|
|
||||||
transport = agent.NewHybridTransport(wsT, httpT)
|
|
||||||
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
|
|
||||||
} else {
|
|
||||||
transport = httpT
|
|
||||||
log.Println("Transport: HTTP only")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create daemon — always uses Transport interface
|
|
||||||
d := agent.NewDaemon(daemonCfg, transport)
|
|
||||||
|
|
||||||
// Create agent client for watch progress reporting
|
|
||||||
agentClient := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
agentClient := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
||||||
|
log.Printf("Transport: HTTP sync → %s", cfg.Auth.APIURL)
|
||||||
|
|
||||||
|
// Create daemon
|
||||||
|
d := agent.NewDaemon(daemonCfg, agentClient)
|
||||||
|
|
||||||
|
// Start SIGUSR1 reload watcher (unix only, no-op on Windows)
|
||||||
|
startReloadWatcher(&ReloadableConfig{Daemon: d})
|
||||||
|
|
||||||
// Daemon-scoped context — cancelled on shutdown
|
// Daemon-scoped context — cancelled on shutdown
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Create progress reporter using transport
|
|
||||||
reporter := engine.NewProgressReporterWithTransport(transport, statusInterval)
|
|
||||||
reporter.SetWatchingFunc(func() bool { return d.Watching.Load() })
|
|
||||||
reporter.SetWatchingChangedHandler(func(watching bool) { d.Watching.Store(watching) })
|
|
||||||
|
|
||||||
// Parse speed limits
|
// Parse speed limits
|
||||||
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
|
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
|
||||||
maxUl, _ := config.ParseSpeed(cfg.Download.MaxUploadSpeed)
|
maxUl, _ := config.ParseSpeed(cfg.Download.MaxUploadSpeed)
|
||||||
|
|
||||||
// Parse torrent timeouts from config (default: 0 = unlimited, like qBittorrent)
|
// Parse torrent timeouts
|
||||||
metaTimeout, _ := time.ParseDuration(cfg.Download.MetadataTimeout)
|
metaTimeout, _ := time.ParseDuration(cfg.Download.MetadataTimeout)
|
||||||
stallTimeout, _ := time.ParseDuration(cfg.Download.StallTimeout)
|
stallTimeout, _ := time.ParseDuration(cfg.Download.StallTimeout)
|
||||||
|
|
||||||
|
// Create progress reporter — only used for stream tasks (handleStreamTask)
|
||||||
|
// The sync goroutine handles all regular progress reporting.
|
||||||
|
statusInterval, _ := time.ParseDuration(cfg.Daemon.StatusInterval)
|
||||||
|
if statusInterval == 0 {
|
||||||
|
statusInterval = 3 * time.Second
|
||||||
|
}
|
||||||
|
reporter := engine.NewProgressReporter(agentClient, statusInterval)
|
||||||
|
reporter.SetWatchingFunc(func() bool { return d.Watching.Load() })
|
||||||
|
|
||||||
// Create torrent downloader
|
// Create torrent downloader
|
||||||
torrentDl, err := engine.NewTorrentDownloader(engine.TorrentConfig{
|
torrentDl, err := engine.NewTorrentDownloader(engine.TorrentConfig{
|
||||||
DataDir: cfg.Download.Dir,
|
DataDir: cfg.Download.Dir,
|
||||||
MetadataTimeout: metaTimeout, // 0 = unlimited (default)
|
MetadataTimeout: metaTimeout,
|
||||||
StallTimeout: stallTimeout, // 0 = unlimited (default)
|
StallTimeout: stallTimeout,
|
||||||
MaxTimeout: 0, // unlimited
|
MaxTimeout: 0,
|
||||||
MaxDownloadRate: maxDl,
|
MaxDownloadRate: maxDl,
|
||||||
MaxUploadRate: maxUl,
|
MaxUploadRate: maxUl,
|
||||||
ListenPort: cfg.Download.ListenPort, // 0 = default 42069
|
ListenPort: cfg.Download.ListenPort,
|
||||||
SeedEnabled: false,
|
SeedEnabled: false,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -223,7 +196,7 @@ func runDaemonStart() error {
|
||||||
log.Printf("Speed limits: download=%s upload=%s", dlStr, ulStr)
|
log.Printf("Speed limits: download=%s upload=%s", dlStr, ulStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create debrid downloader (HTTPS-based, no provider interaction needed)
|
// Create debrid downloader
|
||||||
debridDl := engine.NewDebridDownloader()
|
debridDl := engine.NewDebridDownloader()
|
||||||
|
|
||||||
// Create download manager
|
// Create download manager
|
||||||
|
|
@ -237,170 +210,53 @@ func runDaemonStart() error {
|
||||||
TVShowsDir: cfg.Organize.TVShowsDir,
|
TVShowsDir: cfg.Organize.TVShowsDir,
|
||||||
OutputDir: cfg.Download.Dir,
|
OutputDir: cfg.Download.Dir,
|
||||||
},
|
},
|
||||||
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client()))
|
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(agentClient))
|
||||||
|
|
||||||
// Create persistent stream server — lives for the entire daemon lifecycle.
|
// Create persistent stream server
|
||||||
// One port, one server, swap files with SetFile(). No more port churn.
|
|
||||||
streamSrv := engine.NewStreamServer(cfg.Download.StreamPort)
|
streamSrv := engine.NewStreamServer(cfg.Download.StreamPort)
|
||||||
if err := streamSrv.Listen(ctx); err != nil {
|
if err := streamSrv.Listen(ctx); err != nil {
|
||||||
return fmt.Errorf("start stream server: %w", err)
|
return fmt.Errorf("start stream server: %w", err)
|
||||||
}
|
}
|
||||||
// Update heartbeat with actual port (may differ if configured port was busy)
|
|
||||||
d.UpdateStreamPort(streamSrv.Port())
|
d.UpdateStreamPort(streamSrv.Port())
|
||||||
|
|
||||||
// Wire state tracking
|
// Wire sync client callbacks
|
||||||
|
sc := d.SyncClient()
|
||||||
|
sc.GetFreeSlots = manager.FreeSlots
|
||||||
|
sc.GetTaskStates = manager.TaskStates
|
||||||
d.GetActiveCount = manager.ActiveCount
|
d.GetActiveCount = manager.ActiveCount
|
||||||
d.GetCleanableBytes = CleanableBytes
|
|
||||||
|
|
||||||
// Wire: server-side signals -> manager actions + stream tasks
|
// Trigger immediate sync when a download slot frees up
|
||||||
reporter.SetCancelHandler(func(taskID string) {
|
manager.OnTaskDone = func() { d.TriggerSync() }
|
||||||
manager.CancelTask(taskID)
|
|
||||||
cancelStreamTask(taskID)
|
|
||||||
})
|
|
||||||
reporter.SetPauseHandler(func(taskID string) {
|
|
||||||
manager.PauseTask(taskID)
|
|
||||||
cancelStreamTask(taskID)
|
|
||||||
})
|
|
||||||
reporter.SetDeleteFilesHandler(func(taskID string) {
|
|
||||||
manager.CancelAndDeleteFiles(taskID)
|
|
||||||
cancelStreamTask(taskID)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Wire: stream requested on active download → set file on persistent server
|
|
||||||
reporter.SetStreamRequestedHandler(func(taskID string) {
|
|
||||||
task := manager.GetTask(taskID)
|
|
||||||
if task == nil {
|
|
||||||
log.Printf("[%s] stream requested but task not found in manager", taskID[:8])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if task.GetStreamURL() != "" {
|
|
||||||
return // already streaming
|
|
||||||
}
|
|
||||||
provider, err := torrentDl.GetStreamProvider(taskID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[%s] stream failed: %v", taskID[:8], err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cancelStreamContexts()
|
|
||||||
streamSrv.SetFile(provider, taskID)
|
|
||||||
task.SetStreamURL(streamSrv.URLsJSON())
|
|
||||||
log.Printf("[%s] streaming active download: %s", taskID[:8], provider.FileName())
|
|
||||||
|
|
||||||
// Start watch progress reporter with cancellable context
|
|
||||||
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
|
|
||||||
streamRegistry.mu.Lock()
|
|
||||||
streamRegistry.cancels["watch:"+taskID] = watchCancel
|
|
||||||
streamRegistry.mu.Unlock()
|
|
||||||
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Wire: daemon claimed tasks -> manager
|
|
||||||
|
|
||||||
|
// Wire: sync receives new tasks → submit to manager or handle stream
|
||||||
d.OnTasksClaimed = func(tasks []agent.Task) {
|
d.OnTasksClaimed = func(tasks []agent.Task) {
|
||||||
for _, t := range tasks {
|
for _, t := range tasks {
|
||||||
if t.Mode == "stream" {
|
if t.Mode == "stream" {
|
||||||
// Skip if already streaming this task
|
|
||||||
if isStreamingTask(t.ID) {
|
if isStreamingTask(t.ID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Only 1 stream at a time: cancel existing stream goroutines + clear file
|
|
||||||
cancelStreamContexts()
|
cancelStreamContexts()
|
||||||
streamSrv.ClearFile()
|
streamSrv.ClearFile()
|
||||||
// Reserve slot before spawning goroutine to prevent TOCTOU race.
|
streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel stored in registry
|
||||||
streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel ownership transferred to streamRegistry
|
|
||||||
streamRegistry.mu.Lock()
|
streamRegistry.mu.Lock()
|
||||||
streamRegistry.cancels[t.ID] = streamCancel
|
streamRegistry.cancels[t.ID] = streamCancel
|
||||||
streamRegistry.mu.Unlock()
|
streamRegistry.mu.Unlock()
|
||||||
go handleStreamTask(streamCtx, t, reporter, cfg, agentClient, streamSrv)
|
go handleStreamTask(streamCtx, t, reporter, cfg, agentClient, streamSrv)
|
||||||
} else if t.ForceStart || manager.HasCapacity() {
|
|
||||||
manager.Submit(ctx, t)
|
|
||||||
} else {
|
} else {
|
||||||
log.Printf("[%s] skipped: no capacity (max %d)", t.ID[:8], cfg.Download.MaxConcurrent)
|
manager.Submit(ctx, t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wire: stream requests for completed downloads → set file on persistent server
|
// Wire: sync receives control signals → act on manager
|
||||||
d.OnStreamRequested = func(sr agent.StreamRequest) {
|
d.OnControlAction = func(action, taskID string, deleteFiles bool) {
|
||||||
// Already serving this task — just notify server it's ready
|
|
||||||
if streamSrv.CurrentTaskID() == sr.TaskID {
|
|
||||||
go func() {
|
|
||||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
|
||||||
TaskID: sr.TaskID,
|
|
||||||
StreamReady: true,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("[%s] stream ready re-notify failed: %v", sr.TaskID[:8], err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath := sr.FilePath
|
|
||||||
info, err := os.Stat(filePath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[%s] stream request: file not found: %s", sr.TaskID[:8], filePath)
|
|
||||||
go func() {
|
|
||||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
|
||||||
TaskID: sr.TaskID,
|
|
||||||
Status: "failed",
|
|
||||||
ErrorMessage: fmt.Sprintf("file not found: %s", filePath),
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If filePath is a directory, find the largest video file inside
|
|
||||||
if info.IsDir() {
|
|
||||||
found := engine.FindVideoFile(filePath)
|
|
||||||
if found == "" {
|
|
||||||
log.Printf("[%s] stream request: no video file in directory: %s", sr.TaskID[:8], filePath)
|
|
||||||
go func() {
|
|
||||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
|
||||||
TaskID: sr.TaskID,
|
|
||||||
Status: "failed",
|
|
||||||
ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath),
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("[%s] stream error report failed: %v", sr.TaskID[:8], err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
filePath = found
|
|
||||||
log.Printf("[%s] resolved directory to video file: %s", sr.TaskID[:8], filepath.Base(filePath))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cancel any active stream goroutines and swap file on the persistent server
|
|
||||||
cancelStreamContexts()
|
|
||||||
streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID)
|
|
||||||
|
|
||||||
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(filePath), streamSrv.URL())
|
|
||||||
|
|
||||||
// Start watch progress reporter with a cancellable context
|
|
||||||
// so it stops when the user switches to a different stream.
|
|
||||||
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
|
|
||||||
streamRegistry.mu.Lock()
|
|
||||||
streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel
|
|
||||||
streamRegistry.mu.Unlock()
|
|
||||||
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx)
|
|
||||||
|
|
||||||
// Notify server that stream is ready (clears streamRequested flag)
|
|
||||||
go func() {
|
|
||||||
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
|
||||||
TaskID: sr.TaskID,
|
|
||||||
StreamReady: true,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("[%s] stream ready report failed: %v", sr.TaskID[:8], err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wire: WS control actions (pause/cancel/stream pushed from server)
|
|
||||||
d.OnControlAction = func(action, taskID string) {
|
|
||||||
switch action {
|
switch action {
|
||||||
case "cancel":
|
case "cancel":
|
||||||
manager.CancelTask(taskID)
|
if deleteFiles {
|
||||||
|
manager.CancelAndDeleteFiles(taskID)
|
||||||
|
} else {
|
||||||
|
manager.CancelTask(taskID)
|
||||||
|
}
|
||||||
cancelStreamTask(taskID)
|
cancelStreamTask(taskID)
|
||||||
if streamSrv.CurrentTaskID() == taskID {
|
if streamSrv.CurrentTaskID() == taskID {
|
||||||
streamSrv.ClearFile()
|
streamSrv.ClearFile()
|
||||||
|
|
@ -412,10 +268,9 @@ func runDaemonStart() error {
|
||||||
streamSrv.ClearFile()
|
streamSrv.ClearFile()
|
||||||
}
|
}
|
||||||
case "resume":
|
case "resume":
|
||||||
log.Printf("[%s] resume requested via WebSocket, triggering poll", taskID[:8])
|
log.Printf("[%s] resume requested, triggering sync", agent.ShortID(taskID))
|
||||||
d.TriggerPoll()
|
d.TriggerSync()
|
||||||
case "stream":
|
case "stream":
|
||||||
// Skip if already streaming this task
|
|
||||||
if streamSrv.CurrentTaskID() == taskID {
|
if streamSrv.CurrentTaskID() == taskID {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -425,13 +280,19 @@ func runDaemonStart() error {
|
||||||
}
|
}
|
||||||
provider, err := torrentDl.GetStreamProvider(taskID)
|
provider, err := torrentDl.GetStreamProvider(taskID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[%s] stream failed: %v", taskID[:8], err)
|
log.Printf("[%s] stream failed: %v", agent.ShortID(taskID), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cancelStreamContexts()
|
cancelStreamContexts()
|
||||||
streamSrv.SetFile(provider, taskID)
|
streamSrv.SetFile(provider, taskID)
|
||||||
task.SetStreamURL(streamSrv.URLsJSON())
|
task.SetStreamURL(streamSrv.URLsJSON())
|
||||||
log.Printf("[%s] streaming via WS: %s", taskID[:8], provider.FileName())
|
log.Printf("[%s] streaming: %s", agent.ShortID(taskID), provider.FileName())
|
||||||
|
|
||||||
|
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118
|
||||||
|
streamRegistry.mu.Lock()
|
||||||
|
streamRegistry.cancels["watch:"+taskID] = watchCancel
|
||||||
|
streamRegistry.mu.Unlock()
|
||||||
|
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx)
|
||||||
case "stop-stream":
|
case "stop-stream":
|
||||||
cancelStreamTask(taskID)
|
cancelStreamTask(taskID)
|
||||||
if streamSrv.CurrentTaskID() == taskID {
|
if streamSrv.CurrentTaskID() == taskID {
|
||||||
|
|
@ -440,19 +301,77 @@ func runDaemonStart() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config hot-reload (SIGUSR1 on Unix, no-op on Windows)
|
// Wire: sync receives stream requests for completed downloads
|
||||||
// Tickers are initialized inside d.Run(), so we pass the daemon
|
d.OnStreamRequested = func(sr agent.StreamRequest) {
|
||||||
// and the reload goroutine reads them when the signal arrives.
|
if streamSrv.CurrentTaskID() == sr.TaskID {
|
||||||
startReloadWatcher(&ReloadableConfig{Daemon: d})
|
// Already serving — notify server it's ready
|
||||||
|
go func() {
|
||||||
|
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||||
|
TaskID: sr.TaskID,
|
||||||
|
StreamReady: true,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("[%s] stream ready re-notify failed: %v", agent.ShortID(sr.TaskID), err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Signal handling
|
filePath := sr.FilePath
|
||||||
sigCh := make(chan os.Signal, 1)
|
info, err := os.Stat(filePath)
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
if err != nil {
|
||||||
|
log.Printf("[%s] stream request: file not found: %s", agent.ShortID(sr.TaskID), filePath)
|
||||||
|
go func() {
|
||||||
|
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||||
|
TaskID: sr.TaskID,
|
||||||
|
Status: "failed",
|
||||||
|
ErrorMessage: fmt.Sprintf("file not found: %s", filePath),
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Start progress reporter in background
|
if info.IsDir() {
|
||||||
go reporter.Run(ctx)
|
found := engine.FindVideoFile(filePath)
|
||||||
|
if found == "" {
|
||||||
|
log.Printf("[%s] stream request: no video file in directory: %s", agent.ShortID(sr.TaskID), filePath)
|
||||||
|
go func() {
|
||||||
|
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||||
|
TaskID: sr.TaskID,
|
||||||
|
Status: "failed",
|
||||||
|
ErrorMessage: fmt.Sprintf("no video file in directory: %s", filePath),
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("[%s] stream error report failed: %v", agent.ShortID(sr.TaskID), err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filePath = found
|
||||||
|
log.Printf("[%s] resolved directory to video file: %s", agent.ShortID(sr.TaskID), filepath.Base(filePath))
|
||||||
|
}
|
||||||
|
|
||||||
// Periodic DHT node persistence (every 5 min) — protects against crash data loss
|
cancelStreamContexts()
|
||||||
|
streamSrv.SetFile(engine.NewDiskFileProvider(filePath), sr.TaskID)
|
||||||
|
log.Printf("[%s] streaming from disk: %s → %s", agent.ShortID(sr.TaskID), filepath.Base(filePath), streamSrv.URL())
|
||||||
|
|
||||||
|
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // G118
|
||||||
|
streamRegistry.mu.Lock()
|
||||||
|
streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel
|
||||||
|
streamRegistry.mu.Unlock()
|
||||||
|
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if _, err := agentClient.ReportStatus(ctx, agent.StatusUpdate{
|
||||||
|
TaskID: sr.TaskID,
|
||||||
|
StreamReady: true,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("[%s] stream ready report failed: %v", agent.ShortID(sr.TaskID), err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Periodic DHT node persistence (every 5 min)
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
@ -466,8 +385,7 @@ func runDaemonStart() error {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Start auto-scan goroutine (daily library scan + sync)
|
// Start auto-scan goroutine
|
||||||
// Default scan_path to download dir so auto-scan works out of the box.
|
|
||||||
scanPath := cfg.Library.ScanPath
|
scanPath := cfg.Library.ScanPath
|
||||||
if scanPath == "" {
|
if scanPath == "" {
|
||||||
scanPath = cfg.Download.Dir
|
scanPath = cfg.Download.Dir
|
||||||
|
|
@ -484,7 +402,10 @@ func runDaemonStart() error {
|
||||||
go runAutoScan(ctx, scanCfg, scanInterval, agentClient, d.ScanNow)
|
go runAutoScan(ctx, scanCfg, scanInterval, agentClient, d.ScanNow)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start daemon (blocks)
|
// Start reporter only for stream task handling
|
||||||
|
go reporter.Run(ctx)
|
||||||
|
|
||||||
|
// Start daemon (blocks — runs sync loop)
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errCh <- d.Run(ctx)
|
errCh <- d.Run(ctx)
|
||||||
|
|
@ -493,6 +414,10 @@ func runDaemonStart() error {
|
||||||
// Start idle guard for the persistent stream server
|
// Start idle guard for the persistent stream server
|
||||||
go startIdleGuard(ctx, streamSrv)
|
go startIdleGuard(ctx, streamSrv)
|
||||||
|
|
||||||
|
// Signal handling
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// Wait for signal or error
|
// Wait for signal or error
|
||||||
select {
|
select {
|
||||||
case sig := <-sigCh:
|
case sig := <-sigCh:
|
||||||
|
|
@ -506,6 +431,7 @@ func runDaemonStart() error {
|
||||||
defer shutdownCancel()
|
defer shutdownCancel()
|
||||||
manager.Shutdown(shutdownCtx)
|
manager.Shutdown(shutdownCtx)
|
||||||
|
|
||||||
|
d.Deregister()
|
||||||
fmt.Println(" Daemon stopped.")
|
fmt.Println(" Daemon stopped.")
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
|
@ -517,41 +443,6 @@ func runDaemonStart() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// deriveWSURL derives a WebSocket URL from the API URL.
|
|
||||||
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
|
|
||||||
// Returns "" for localhost/dev environments where WS gateway isn't available.
|
|
||||||
func deriveWSURL(apiURL, agentID string) string {
|
|
||||||
if apiURL == "" || agentID == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
// Parse domain from API URL
|
|
||||||
domain := apiURL
|
|
||||||
for _, prefix := range []string{"https://", "http://"} {
|
|
||||||
if len(domain) > len(prefix) && domain[:len(prefix)] == prefix {
|
|
||||||
domain = domain[len(prefix):]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Strip trailing slash/path
|
|
||||||
for i := 0; i < len(domain); i++ {
|
|
||||||
if domain[i] == '/' {
|
|
||||||
domain = domain[:i]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Strip port if present
|
|
||||||
if idx := strings.LastIndex(domain, ":"); idx > 0 {
|
|
||||||
domain = domain[:idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip WS for localhost/dev — gateway only available in production
|
|
||||||
if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return "wss://unarr." + domain + "/ws/" + agentID
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatSpeedLog(bps int64) string {
|
func formatSpeedLog(bps int64) string {
|
||||||
switch {
|
switch {
|
||||||
case bps >= 1024*1024*1024:
|
case bps >= 1024*1024*1024:
|
||||||
|
|
@ -569,11 +460,9 @@ func formatSpeedLog(bps int64) string {
|
||||||
func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration, ac *agent.Client, scanNow <-chan struct{}) {
|
func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration, ac *agent.Client, scanNow <-chan struct{}) {
|
||||||
log.Printf("[auto-scan] enabled: every %s, path: %s", interval, cfg.Library.ScanPath)
|
log.Printf("[auto-scan] enabled: every %s, path: %s", interval, cfg.Library.ScanPath)
|
||||||
|
|
||||||
// Run first scan after a short delay (let daemon stabilize)
|
|
||||||
select {
|
select {
|
||||||
case <-time.After(30 * time.Second):
|
case <-time.After(30 * time.Second):
|
||||||
case <-scanNow:
|
case <-scanNow:
|
||||||
// Immediate scan requested before initial delay
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -608,7 +497,6 @@ func runAutoScan(ctx context.Context, cfg config.Config, interval time.Duration,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync to server
|
|
||||||
items := library.BuildSyncItems(cache)
|
items := library.BuildSyncItems(cache)
|
||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
log.Printf("[auto-scan] no items to sync")
|
log.Printf("[auto-scan] no items to sync")
|
||||||
|
|
|
||||||
|
|
@ -2,32 +2,6 @@ package cmd
|
||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestDeriveWSURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
apiURL string
|
|
||||||
agentID string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"},
|
|
||||||
{"http://localhost:3000", "a1", ""}, // localhost skipped
|
|
||||||
{"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped
|
|
||||||
{"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"},
|
|
||||||
{"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"},
|
|
||||||
{"", "agent-123", ""},
|
|
||||||
{"https://torrentclaw.com", "", ""},
|
|
||||||
{"", "", ""},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) {
|
|
||||||
got := deriveWSURL(tt.apiURL, tt.agentID)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFormatSpeedLog(t *testing.T) {
|
func TestFormatSpeedLog(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
bps int64
|
bps int64
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/torrentclaw/unarr/internal/agent"
|
"github.com/torrentclaw/unarr/internal/agent"
|
||||||
"github.com/torrentclaw/unarr/internal/config"
|
"github.com/torrentclaw/unarr/internal/config"
|
||||||
|
|
@ -19,7 +18,8 @@ type ReloadableConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// startReloadWatcher listens for SIGUSR1 and reloads config.
|
// startReloadWatcher listens for SIGUSR1 and reloads config.
|
||||||
// Only intervals are hot-reloadable (speeds require torrent client restart).
|
// With the sync-based architecture, intervals are fixed (3s watching, 60s idle).
|
||||||
|
// Hot-reload now mainly serves as a signal to re-read config for future settings.
|
||||||
func startReloadWatcher(rc *ReloadableConfig) {
|
func startReloadWatcher(rc *ReloadableConfig) {
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigCh, syscall.SIGUSR1)
|
signal.Notify(sigCh, syscall.SIGUSR1)
|
||||||
|
|
@ -28,24 +28,11 @@ func startReloadWatcher(rc *ReloadableConfig) {
|
||||||
for range sigCh {
|
for range sigCh {
|
||||||
log.Println("Received SIGUSR1, reloading config...")
|
log.Println("Received SIGUSR1, reloading config...")
|
||||||
|
|
||||||
cfg, err := config.Load("")
|
_, err := config.Load("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Config reload failed: %v", err)
|
log.Printf("Config reload failed: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
cfg.ApplyEnvOverrides()
|
|
||||||
|
|
||||||
// Update poll interval
|
|
||||||
if d, _ := time.ParseDuration(cfg.Daemon.PollInterval); d > 0 && rc.Daemon.PollTicker != nil {
|
|
||||||
rc.Daemon.PollTicker.Reset(d)
|
|
||||||
log.Printf(" Poll interval: %s", d)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update heartbeat interval
|
|
||||||
if d, _ := time.ParseDuration(cfg.Daemon.HeartbeatInterval); d > 0 && rc.Daemon.HeartbeatTicker != nil {
|
|
||||||
rc.Daemon.HeartbeatTicker.Reset(d)
|
|
||||||
log.Printf(" Heartbeat interval: %s", d)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println("Config reloaded successfully")
|
log.Println("Config reloaded successfully")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
// Version is the CLI version. Overridden by goreleaser ldflags at release time.
|
// Version is the CLI version. Overridden by goreleaser ldflags at release time.
|
||||||
var Version = "0.5.5"
|
var Version = "0.5.6"
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ type Config struct {
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
APIKey string `toml:"api_key"`
|
APIKey string `toml:"api_key"`
|
||||||
APIURL string `toml:"api_url"`
|
APIURL string `toml:"api_url"`
|
||||||
WSURL string `toml:"ws_url"` // optional, derived from api_url if empty
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AgentConfig struct {
|
type AgentConfig struct {
|
||||||
|
|
@ -54,9 +53,7 @@ type OrganizeConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DaemonConfig struct {
|
type DaemonConfig struct {
|
||||||
PollInterval string `toml:"poll_interval"`
|
StatusInterval string `toml:"status_interval"`
|
||||||
HeartbeatInterval string `toml:"heartbeat_interval"`
|
|
||||||
StatusInterval string `toml:"status_interval"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type NotificationsConfig struct {
|
type NotificationsConfig struct {
|
||||||
|
|
@ -92,10 +89,7 @@ func Default() Config {
|
||||||
Organize: OrganizeConfig{
|
Organize: OrganizeConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
Daemon: DaemonConfig{
|
Daemon: DaemonConfig{},
|
||||||
PollInterval: "30s",
|
|
||||||
HeartbeatInterval: "30s",
|
|
||||||
},
|
|
||||||
Notifications: NotificationsConfig{
|
Notifications: NotificationsConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -21,8 +21,8 @@ func TestDefault(t *testing.T) {
|
||||||
if cfg.General.Country != "US" {
|
if cfg.General.Country != "US" {
|
||||||
t.Errorf("default Country = %q, want US", cfg.General.Country)
|
t.Errorf("default Country = %q, want US", cfg.General.Country)
|
||||||
}
|
}
|
||||||
if cfg.Daemon.HeartbeatInterval != "30s" {
|
if cfg.Daemon.StatusInterval != "" {
|
||||||
t.Errorf("default HeartbeatInterval = %q, want 30s", cfg.Daemon.HeartbeatInterval)
|
t.Errorf("default StatusInterval = %q, want empty", cfg.Daemon.StatusInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/torrentclaw/unarr/internal/agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// httpClient is used for debrid HTTPS downloads with a reasonable header timeout.
|
// httpClient is used for debrid HTTPS downloads with a reasonable header timeout.
|
||||||
|
|
@ -19,13 +21,6 @@ var httpClient = &http.Client{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func shortID(id string) string {
|
|
||||||
if len(id) > 8 {
|
|
||||||
return id[:8]
|
|
||||||
}
|
|
||||||
return id
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebridDownloader downloads files via HTTPS direct URLs resolved by the server.
|
// DebridDownloader downloads files via HTTPS direct URLs resolved by the server.
|
||||||
// The server handles all debrid provider interaction; this downloader only needs
|
// The server handles all debrid provider interaction; this downloader only needs
|
||||||
// a plain HTTPS URL to fetch.
|
// a plain HTTPS URL to fetch.
|
||||||
|
|
@ -129,7 +124,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
||||||
var serverSize int64
|
var serverSize int64
|
||||||
if _, err := fmt.Sscanf(cr, "bytes */%d", &serverSize); err == nil && serverSize > 0 && existingSize != serverSize {
|
if _, err := fmt.Sscanf(cr, "bytes */%d", &serverSize); err == nil && serverSize > 0 && existingSize != serverSize {
|
||||||
// Local file size doesn't match server — re-download from scratch
|
// Local file size doesn't match server — re-download from scratch
|
||||||
log.Printf("[%s] local size %s != server size %s, re-downloading", shortID(task.ID), formatBytes(existingSize), formatBytes(serverSize))
|
log.Printf("[%s] local size %s != server size %s, re-downloading", agent.ShortID(task.ID), formatBytes(existingSize), formatBytes(serverSize))
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
req2, err := http.NewRequestWithContext(dlCtx, http.MethodGet, task.DirectURL, nil)
|
req2, err := http.NewRequestWithContext(dlCtx, http.MethodGet, task.DirectURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -149,7 +144,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
||||||
break // continue to download loop
|
break // continue to download loop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("[%s] file already complete: %s (%s)", shortID(task.ID), fileName, formatBytes(existingSize))
|
log.Printf("[%s] file already complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(existingSize))
|
||||||
return &Result{
|
return &Result{
|
||||||
FilePath: destPath,
|
FilePath: destPath,
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
|
|
@ -166,10 +161,10 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
||||||
var flags int
|
var flags int
|
||||||
if startOffset > 0 {
|
if startOffset > 0 {
|
||||||
flags = os.O_WRONLY | os.O_APPEND
|
flags = os.O_WRONLY | os.O_APPEND
|
||||||
log.Printf("[%s] resuming debrid download at %s: %s", shortID(task.ID), formatBytes(startOffset), fileName)
|
log.Printf("[%s] resuming debrid download at %s: %s", agent.ShortID(task.ID), formatBytes(startOffset), fileName)
|
||||||
} else {
|
} else {
|
||||||
flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
|
flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
|
||||||
log.Printf("[%s] starting debrid download: %s", shortID(task.ID), fileName)
|
log.Printf("[%s] starting debrid download: %s", agent.ShortID(task.ID), fileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
|
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
|
||||||
|
|
@ -223,7 +218,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[%s] %d%% — %s/%s @ %s/s (debrid)",
|
log.Printf("[%s] %d%% — %s/%s @ %s/s (debrid)",
|
||||||
shortID(task.ID), pct,
|
agent.ShortID(task.ID), pct,
|
||||||
formatBytes(downloaded), formatBytes(totalBytes), formatBytes(speed))
|
formatBytes(downloaded), formatBytes(totalBytes), formatBytes(speed))
|
||||||
|
|
||||||
p := Progress{
|
p := Progress{
|
||||||
|
|
@ -252,7 +247,7 @@ func (d *DebridDownloader) Download(ctx context.Context, task *Task, outputDir s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[%s] debrid download complete: %s (%s)", shortID(task.ID), fileName, formatBytes(downloaded))
|
log.Printf("[%s] debrid download complete: %s (%s)", agent.ShortID(task.ID), fileName, formatBytes(downloaded))
|
||||||
|
|
||||||
return &Result{
|
return &Result{
|
||||||
FilePath: destPath,
|
FilePath: destPath,
|
||||||
|
|
@ -271,7 +266,7 @@ func (d *DebridDownloader) Pause(taskID string) error {
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
cancel()
|
cancel()
|
||||||
log.Printf("[%s] debrid download paused (file kept for resume)", shortID(taskID))
|
log.Printf("[%s] debrid download paused (file kept for resume)", agent.ShortID(taskID))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -285,7 +280,7 @@ func (d *DebridDownloader) Cancel(taskID string) error {
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
cancel()
|
cancel()
|
||||||
log.Printf("[%s] debrid download cancelled", shortID(taskID))
|
log.Printf("[%s] debrid download cancelled", agent.ShortID(taskID))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,15 @@ type Manager struct {
|
||||||
|
|
||||||
sem chan struct{}
|
sem chan struct{}
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// OnTaskDone is called after a task completes or fails (slot freed).
|
||||||
|
// Used by the daemon to trigger an immediate sync.
|
||||||
|
OnTaskDone func()
|
||||||
|
|
||||||
|
// recentlyFinished holds tasks that completed/failed since the last sync read.
|
||||||
|
// The sync goroutine reads and clears this to include final states in the next sync.
|
||||||
|
recentMu sync.Mutex
|
||||||
|
recentFinished []agent.TaskState
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a download manager.
|
// NewManager creates a download manager.
|
||||||
|
|
@ -67,7 +76,7 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) {
|
||||||
|
|
||||||
// Force start: bypass semaphore (like Transmission's "Force Start")
|
// Force start: bypass semaphore (like Transmission's "Force Start")
|
||||||
if at.ForceStart {
|
if at.ForceStart {
|
||||||
log.Printf("[%s] force start: bypassing queue", task.ID[:8])
|
log.Printf("[%s] force start: bypassing queue", agent.ShortID(task.ID))
|
||||||
m.wg.Add(1)
|
m.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer m.wg.Done()
|
defer m.wg.Done()
|
||||||
|
|
@ -88,7 +97,12 @@ func (m *Manager) Submit(ctx context.Context, at agent.Task) {
|
||||||
m.wg.Add(1)
|
m.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer m.wg.Done()
|
defer m.wg.Done()
|
||||||
defer func() { <-m.sem }()
|
defer func() {
|
||||||
|
<-m.sem
|
||||||
|
if m.OnTaskDone != nil {
|
||||||
|
m.OnTaskDone()
|
||||||
|
}
|
||||||
|
}()
|
||||||
defer taskCancel()
|
defer taskCancel()
|
||||||
m.processTask(taskCtx, task)
|
m.processTask(taskCtx, task)
|
||||||
}()
|
}()
|
||||||
|
|
@ -99,6 +113,11 @@ func (m *Manager) HasCapacity() bool {
|
||||||
return len(m.sem) < cap(m.sem)
|
return len(m.sem) < cap(m.sem)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FreeSlots returns the number of available download slots.
|
||||||
|
func (m *Manager) FreeSlots() int {
|
||||||
|
return cap(m.sem) - len(m.sem)
|
||||||
|
}
|
||||||
|
|
||||||
// ActiveCount returns the number of in-progress downloads.
|
// ActiveCount returns the number of in-progress downloads.
|
||||||
func (m *Manager) ActiveCount() int {
|
func (m *Manager) ActiveCount() int {
|
||||||
m.activeMu.RLock()
|
m.activeMu.RLock()
|
||||||
|
|
@ -113,6 +132,17 @@ func (m *Manager) GetTask(taskID string) *Task {
|
||||||
return m.active[taskID]
|
return m.active[taskID]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ActiveTaskIDs returns the IDs of all in-progress tasks.
|
||||||
|
func (m *Manager) ActiveTaskIDs() []string {
|
||||||
|
m.activeMu.RLock()
|
||||||
|
defer m.activeMu.RUnlock()
|
||||||
|
ids := make([]string, 0, len(m.active))
|
||||||
|
for id := range m.active {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
// ActiveTasks returns a snapshot of all active tasks.
|
// ActiveTasks returns a snapshot of all active tasks.
|
||||||
func (m *Manager) ActiveTasks() []*Task {
|
func (m *Manager) ActiveTasks() []*Task {
|
||||||
m.activeMu.RLock()
|
m.activeMu.RLock()
|
||||||
|
|
@ -124,6 +154,37 @@ func (m *Manager) ActiveTasks() []*Task {
|
||||||
return tasks
|
return tasks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TaskStates returns the current state of all active tasks plus any recently
|
||||||
|
// finished tasks that haven't been synced yet. Called by the sync goroutine.
|
||||||
|
func (m *Manager) TaskStates() []agent.TaskState {
|
||||||
|
// Collect active tasks
|
||||||
|
m.activeMu.RLock()
|
||||||
|
states := make([]agent.TaskState, 0, len(m.active))
|
||||||
|
for _, t := range m.active {
|
||||||
|
states = append(states, agent.TaskStateFromUpdate(t.ToStatusUpdate()))
|
||||||
|
}
|
||||||
|
m.activeMu.RUnlock()
|
||||||
|
|
||||||
|
// Drain recently finished tasks (consumed once per sync)
|
||||||
|
m.recentMu.Lock()
|
||||||
|
states = append(states, m.recentFinished...)
|
||||||
|
m.recentFinished = nil
|
||||||
|
m.recentMu.Unlock()
|
||||||
|
|
||||||
|
return states
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordFinished stores a completed/failed task for the next sync cycle.
|
||||||
|
func (m *Manager) recordFinished(update agent.StatusUpdate) {
|
||||||
|
m.recentMu.Lock()
|
||||||
|
defer m.recentMu.Unlock()
|
||||||
|
m.recentFinished = append(m.recentFinished, agent.TaskStateFromUpdate(update))
|
||||||
|
// Keep bounded
|
||||||
|
if len(m.recentFinished) > 20 {
|
||||||
|
m.recentFinished = m.recentFinished[len(m.recentFinished)-20:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CancelTask cancels an active download by task ID (keeps partial files).
|
// CancelTask cancels an active download by task ID (keeps partial files).
|
||||||
func (m *Manager) CancelTask(taskID string) {
|
func (m *Manager) CancelTask(taskID string) {
|
||||||
m.activeMu.RLock()
|
m.activeMu.RLock()
|
||||||
|
|
@ -150,7 +211,7 @@ func (m *Manager) CancelTask(taskID string) {
|
||||||
task.mu.Unlock()
|
task.mu.Unlock()
|
||||||
task.Transition(StatusCancelled)
|
task.Transition(StatusCancelled)
|
||||||
|
|
||||||
log.Printf("[%s] cancelled: %s", taskID[:8], task.Title)
|
log.Printf("[%s] cancelled: %s", agent.ShortID(taskID), task.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PauseTask pauses an active download (keeps partial files for resume).
|
// PauseTask pauses an active download (keeps partial files for resume).
|
||||||
|
|
@ -173,7 +234,7 @@ func (m *Manager) PauseTask(taskID string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
task.Transition(StatusCancelled) // will be re-created as pending by server
|
task.Transition(StatusCancelled) // will be re-created as pending by server
|
||||||
log.Printf("[%s] paused: %s", taskID[:8], task.Title)
|
log.Printf("[%s] paused: %s", agent.ShortID(taskID), task.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelAndDeleteFiles cancels a download and removes its files from disk.
|
// CancelAndDeleteFiles cancels a download and removes its files from disk.
|
||||||
|
|
@ -200,7 +261,7 @@ func (m *Manager) CancelAndDeleteFiles(taskID string) {
|
||||||
task.mu.Unlock()
|
task.mu.Unlock()
|
||||||
task.Transition(StatusCancelled)
|
task.Transition(StatusCancelled)
|
||||||
|
|
||||||
log.Printf("[%s] cancelled + files deleted: %s", taskID[:8], task.Title)
|
log.Printf("[%s] cancelled + files deleted: %s", agent.ShortID(taskID), task.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait blocks until all active downloads finish.
|
// Wait blocks until all active downloads finish.
|
||||||
|
|
@ -261,7 +322,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
|
||||||
}
|
}
|
||||||
|
|
||||||
task.ResolvedMethod = method
|
task.ResolvedMethod = method
|
||||||
log.Printf("[%s] resolved method: %s", task.ID[:8], method)
|
log.Printf("[%s] resolved method: %s", agent.ShortID(task.ID), method)
|
||||||
|
|
||||||
// 2. Download
|
// 2. Download
|
||||||
if err := task.Transition(StatusDownloading); err != nil {
|
if err := task.Transition(StatusDownloading); err != nil {
|
||||||
|
|
@ -285,7 +346,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Try fallback
|
// Try fallback
|
||||||
if tryFallback(task, m.downloaders) {
|
if tryFallback(task, m.downloaders) {
|
||||||
log.Printf("[%s] %s failed, trying fallback: %v", task.ID[:8], method, err)
|
log.Printf("[%s] %s failed, trying fallback: %v", agent.ShortID(task.ID), method, err)
|
||||||
if err := task.Transition(StatusResolving); err == nil {
|
if err := task.Transition(StatusResolving); err == nil {
|
||||||
m.processTaskRetry(ctx, task)
|
m.processTaskRetry(ctx, task)
|
||||||
return
|
return
|
||||||
|
|
@ -295,61 +356,7 @@ func (m *Manager) processTask(ctx context.Context, task *Task) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Verify
|
m.finalize(ctx, task, result)
|
||||||
if err := task.Transition(StatusVerifying); err != nil {
|
|
||||||
m.fail(ctx, task, "transition error: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := verify(result); err != nil {
|
|
||||||
m.fail(ctx, task, "verification failed: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. Organize
|
|
||||||
if err := task.Transition(StatusOrganizing); err != nil {
|
|
||||||
m.fail(ctx, task, "transition error: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
finalPath, err := organize(result, task, m.cfg.Organize)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[%s] organize warning: %v (keeping in download dir)", task.ID[:8], err)
|
|
||||||
finalPath = result.FilePath
|
|
||||||
}
|
|
||||||
|
|
||||||
task.mu.Lock()
|
|
||||||
task.FilePath = finalPath
|
|
||||||
task.mu.Unlock()
|
|
||||||
|
|
||||||
// 4b. Handle upgrade replacement (mode = "upgrade")
|
|
||||||
if task.ReplacePath != "" {
|
|
||||||
backupDir := "" // uses default ~/.local/share/unarr/replaced/
|
|
||||||
if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil {
|
|
||||||
log.Printf("[%s] replace warning: %v (keeping new file at %s)", task.ID[:8], err, finalPath)
|
|
||||||
} else {
|
|
||||||
task.mu.Lock()
|
|
||||||
task.FilePath = task.ReplacePath
|
|
||||||
task.mu.Unlock()
|
|
||||||
log.Printf("[%s] upgraded: replaced %s", task.ID[:8], task.ReplacePath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. Complete
|
|
||||||
if method == MethodTorrent && m.cfg.Organize.Enabled {
|
|
||||||
// Could add seeding here in the future
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := task.Transition(StatusCompleted); err != nil {
|
|
||||||
m.fail(ctx, task, "transition error: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[%s] completed: %s -> %s", task.ID[:8], task.Title, finalPath)
|
|
||||||
if m.cfg.Notifications {
|
|
||||||
desktopNotify("Download complete", task.Title)
|
|
||||||
}
|
|
||||||
m.reporter.ReportFinal(ctx, task)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processTaskRetry handles fallback after a method failure.
|
// processTaskRetry handles fallback after a method failure.
|
||||||
|
|
@ -361,7 +368,7 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
|
||||||
}
|
}
|
||||||
|
|
||||||
task.ResolvedMethod = method
|
task.ResolvedMethod = method
|
||||||
log.Printf("[%s] fallback to: %s", task.ID[:8], method)
|
log.Printf("[%s] fallback to: %s", agent.ShortID(task.ID), method)
|
||||||
|
|
||||||
if err := task.Transition(StatusDownloading); err != nil {
|
if err := task.Transition(StatusDownloading); err != nil {
|
||||||
m.fail(ctx, task, "transition error: "+err.Error())
|
m.fail(ctx, task, "transition error: "+err.Error())
|
||||||
|
|
@ -383,15 +390,31 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify + Organize + Complete (same as processTask)
|
m.finalize(ctx, task, result)
|
||||||
task.Transition(StatusVerifying)
|
}
|
||||||
|
|
||||||
|
// finalize runs verify → organize → upgrade replacement → complete for a downloaded task.
|
||||||
|
func (m *Manager) finalize(ctx context.Context, task *Task, result *Result) {
|
||||||
|
// Verify
|
||||||
|
if err := task.Transition(StatusVerifying); err != nil {
|
||||||
|
m.fail(ctx, task, "transition error: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := verify(result); err != nil {
|
if err := verify(result); err != nil {
|
||||||
m.fail(ctx, task, "verification failed: "+err.Error())
|
m.fail(ctx, task, "verification failed: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
task.Transition(StatusOrganizing)
|
// Organize
|
||||||
finalPath, _ := organize(result, task, m.cfg.Organize)
|
if err := task.Transition(StatusOrganizing); err != nil {
|
||||||
|
m.fail(ctx, task, "transition error: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
finalPath, err := organize(result, task, m.cfg.Organize)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[%s] organize warning: %v (keeping in download dir)", agent.ShortID(task.ID), err)
|
||||||
|
finalPath = result.FilePath
|
||||||
|
}
|
||||||
if finalPath == "" {
|
if finalPath == "" {
|
||||||
finalPath = result.FilePath
|
finalPath = result.FilePath
|
||||||
}
|
}
|
||||||
|
|
@ -399,8 +422,29 @@ func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
|
||||||
task.FilePath = finalPath
|
task.FilePath = finalPath
|
||||||
task.mu.Unlock()
|
task.mu.Unlock()
|
||||||
|
|
||||||
task.Transition(StatusCompleted)
|
// Handle upgrade replacement (mode = "upgrade")
|
||||||
log.Printf("[%s] completed (fallback): %s -> %s", task.ID[:8], task.Title, finalPath)
|
if task.ReplacePath != "" {
|
||||||
|
backupDir := "" // uses default ~/.local/share/unarr/replaced/
|
||||||
|
if err := replaceFile(task.ReplacePath, finalPath, backupDir); err != nil {
|
||||||
|
log.Printf("[%s] replace warning: %v (keeping new file at %s)", agent.ShortID(task.ID), err, finalPath)
|
||||||
|
} else {
|
||||||
|
task.mu.Lock()
|
||||||
|
task.FilePath = task.ReplacePath
|
||||||
|
task.mu.Unlock()
|
||||||
|
log.Printf("[%s] upgraded: replaced %s", agent.ShortID(task.ID), task.ReplacePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete
|
||||||
|
if err := task.Transition(StatusCompleted); err != nil {
|
||||||
|
m.fail(ctx, task, "transition error: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("[%s] completed: %s -> %s", agent.ShortID(task.ID), task.Title, finalPath)
|
||||||
|
if m.cfg.Notifications {
|
||||||
|
desktopNotify("Download complete", task.Title)
|
||||||
|
}
|
||||||
|
m.recordFinished(task.ToStatusUpdate())
|
||||||
m.reporter.ReportFinal(ctx, task)
|
m.reporter.ReportFinal(ctx, task)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -409,9 +453,10 @@ func (m *Manager) fail(ctx context.Context, task *Task, msg string) {
|
||||||
task.ErrorMessage = msg
|
task.ErrorMessage = msg
|
||||||
task.mu.Unlock()
|
task.mu.Unlock()
|
||||||
task.Transition(StatusFailed)
|
task.Transition(StatusFailed)
|
||||||
log.Printf("[%s] FAILED: %s — %s", task.ID[:8], task.Title, msg)
|
log.Printf("[%s] FAILED: %s — %s", agent.ShortID(task.ID), task.Title, msg)
|
||||||
if m.cfg.Notifications {
|
if m.cfg.Notifications {
|
||||||
desktopNotify("Download failed", task.Title+": "+msg)
|
desktopNotify("Download failed", task.Title+": "+msg)
|
||||||
}
|
}
|
||||||
|
m.recordFinished(task.ToStatusUpdate())
|
||||||
m.reporter.ReportFinal(ctx, task)
|
m.reporter.ReportFinal(ctx, task)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,13 +13,11 @@ import (
|
||||||
type ActionFunc func(taskID string)
|
type ActionFunc func(taskID string)
|
||||||
|
|
||||||
// StatusReporter is the interface used by ProgressReporter to send progress updates.
|
// StatusReporter is the interface used by ProgressReporter to send progress updates.
|
||||||
// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods.
|
|
||||||
type StatusReporter interface {
|
type StatusReporter interface {
|
||||||
ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error)
|
ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchStatusReporter extends StatusReporter with batch support.
|
// BatchStatusReporter extends StatusReporter with batch support.
|
||||||
// Transports that implement this send all updates in a single request.
|
|
||||||
type BatchStatusReporter interface {
|
type BatchStatusReporter interface {
|
||||||
StatusReporter
|
StatusReporter
|
||||||
BatchReportStatus(ctx context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error)
|
BatchReportStatus(ctx context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error)
|
||||||
|
|
@ -48,7 +46,6 @@ type ProgressReporter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProgressReporter creates a reporter that flushes every interval.
|
// NewProgressReporter creates a reporter that flushes every interval.
|
||||||
// Accepts *agent.Client directly (backwards compatible).
|
|
||||||
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
|
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
|
||||||
return &ProgressReporter{
|
return &ProgressReporter{
|
||||||
reporter: ac,
|
reporter: ac,
|
||||||
|
|
@ -58,25 +55,6 @@ func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressRepo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProgressReporterWithTransport creates a reporter using a Transport.
|
|
||||||
func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter {
|
|
||||||
return &ProgressReporter{
|
|
||||||
reporter: &transportStatusAdapter{t: t},
|
|
||||||
interval: interval,
|
|
||||||
latest: make(map[string]*Task),
|
|
||||||
lastReported: make(map[string]TaskStatus),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// transportStatusAdapter adapts agent.Transport to StatusReporter.
|
|
||||||
type transportStatusAdapter struct {
|
|
||||||
t agent.Transport
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
|
|
||||||
return a.t.SendProgress(ctx, update)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
|
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
|
||||||
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
|
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue