diff --git a/internal/agent/client.go b/internal/agent/client.go index fe4e04a..ef0be81 100644 --- a/internal/agent/client.go +++ b/internal/agent/client.go @@ -16,6 +16,9 @@ type Client struct { baseURL string apiKey string httpClient *http.Client + // wakeClient has no built-in timeout — used exclusively for the long-poll + // wake endpoint where the context controls cancellation. + wakeClient *http.Client userAgent string } @@ -27,7 +30,10 @@ func NewClient(baseURL, apiKey, userAgent string) *Client { httpClient: &http.Client{ Timeout: 30 * time.Second, }, - userAgent: userAgent, + // wakeClient has no built-in timeout — the context controls it. + // The server holds the connection for up to 28s before responding. + wakeClient: &http.Client{}, + userAgent: userAgent, } } @@ -176,6 +182,36 @@ func (c *Client) ReportWatchProgress(ctx context.Context, update WatchProgressUp return nil } +// WaitForWake blocks until the server sends a wake signal, the long-poll +// timeout elapses, or ctx is cancelled. Returns true when a wake signal +// was received (caller should sync immediately), false on timeout/cancel. +func (c *Client) WaitForWake(ctx context.Context) (bool, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/internal/agent/wake", nil) + if err != nil { + return false, fmt.Errorf("create wake request: %w", err) + } + c.setHeaders(req) + + resp, err := c.wakeClient.Do(req) + if err != nil { + return false, fmt.Errorf("wake request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<10)) + return false, &HTTPError{StatusCode: resp.StatusCode, Message: string(body)} + } + + var result struct { + Wake bool `json:"wake"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return false, fmt.Errorf("decode wake response: %w", err) + } + return result.Wake, nil +} + // doPost sends a JSON POST request and decodes the response. func (c *Client) doPost(ctx context.Context, path string, body any, dst any) error { jsonBody, err := json.Marshal(body) diff --git a/internal/agent/sync.go b/internal/agent/sync.go index 70129d4..484472e 100644 --- a/internal/agent/sync.go +++ b/internal/agent/sync.go @@ -12,7 +12,8 @@ const ( // SyncIntervalWatching is the sync interval when someone is viewing the web UI. SyncIntervalWatching = 3 * time.Second // SyncIntervalIdle is the sync interval when nobody is watching. - SyncIntervalIdle = 60 * time.Second + // Keep this short enough to pick up stream requests quickly without hammering the server. + SyncIntervalIdle = 10 * time.Second ) // SyncClient handles bidirectional state synchronization between the CLI and server. @@ -68,6 +69,9 @@ func (sc *SyncClient) TriggerSync() { // Run starts the adaptive sync loop. Blocks until ctx is cancelled. func (sc *SyncClient) Run(ctx context.Context) error { + // Start wake listener in background — triggers immediate syncs on demand. + go sc.runWakeListener(ctx) + // Initial sync immediately sc.doSync(ctx) @@ -174,6 +178,38 @@ func (sc *SyncClient) processResponse(resp *SyncResponse) { } } +// runWakeListener holds a long-poll connection to /api/internal/agent/wake. +// When the server resolves it with wake=true (e.g., a stream was requested), +// it triggers an immediate sync so the CLI acts in <100ms instead of waiting +// for the next scheduled interval. Reconnects immediately after each response +// so coverage is continuous. Runs until ctx is cancelled. +func (sc *SyncClient) runWakeListener(ctx context.Context) { + const retryDelay = 2 * time.Second + for { + if ctx.Err() != nil { + return + } + woke, err := sc.client.WaitForWake(ctx) + if ctx.Err() != nil { + return + } + if err != nil { + log.Printf("wake listener: %v (retrying in %s)", err, retryDelay) + select { + case <-ctx.Done(): + return + case <-time.After(retryDelay): + } + continue + } + if woke { + log.Printf("wake signal received — syncing immediately") + sc.TriggerSync() + } + // On timeout (woke=false) or after a wake, reconnect immediately. + } +} + func (sc *SyncClient) adjustInterval(watching bool) { prev := sc.watching.Load() sc.watching.Store(watching) @@ -189,6 +225,12 @@ func (sc *SyncClient) adjustInterval(watching bool) { log.Printf("sync: interval=%s (watching=%v)", newInterval, watching) } + // Trigger an immediate sync when entering watching mode so stream requests + // are picked up right away without waiting for the next scheduled interval. + if watching && !prev { + sc.TriggerSync() + } + if prev != watching && sc.OnWatchingChange != nil { sc.OnWatchingChange(watching) } diff --git a/internal/cmd/daemon.go b/internal/cmd/daemon.go index d050903..a446a3e 100644 --- a/internal/cmd/daemon.go +++ b/internal/cmd/daemon.go @@ -7,6 +7,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" "time" @@ -316,7 +317,12 @@ func runDaemonStart() error { return } - filePath := sr.FilePath + filePath := filepath.Clean(sr.FilePath) + if !isAllowedStreamPath(filePath, cfg.Download.Dir, cfg.Library.ScanPath, + cfg.Organize.MoviesDir, cfg.Organize.TVShowsDir) { + log.Printf("[%s] stream request rejected: path outside allowed dirs: %s", agent.ShortID(sr.TaskID), filePath) + return + } info, err := os.Stat(filePath) if err != nil { log.Printf("[%s] stream request: file not found: %s", agent.ShortID(sr.TaskID), filePath) @@ -443,6 +449,25 @@ func runDaemonStart() error { } } +// isAllowedStreamPath checks that filePath is within one of the directories +// the daemon is configured to manage. This defends against a compromised API +// server sending a path traversal payload (e.g. /etc/passwd) in StreamRequest. +// isAllowedStreamPath reports whether filePath is contained within one of the +// allowedDirs. filePath must already be cleaned (filepath.Clean) by the caller. +// This defends against a compromised API server sending a path traversal payload. +func isAllowedStreamPath(filePath string, allowedDirs ...string) bool { + for _, dir := range allowedDirs { + if dir == "" { + continue + } + rel, err := filepath.Rel(filepath.Clean(dir), filePath) + if err == nil && !strings.HasPrefix(rel, "..") { + return true + } + } + return false +} + func formatSpeedLog(bps int64) string { switch { case bps >= 1024*1024*1024: diff --git a/internal/engine/task.go b/internal/engine/task.go index 27c7462..ceba6c9 100644 --- a/internal/engine/task.go +++ b/internal/engine/task.go @@ -207,10 +207,20 @@ func (t *Task) ToStatusUpdate() agent.StatusUpdate { // StatusPending, StatusClaimed, StatusCancelled — not reported } + // Compute percent inline — do NOT call t.Percent() here since we already hold RLock. + // Calling Percent() (which also RLocks) while holding RLock deadlocks when a writer is waiting. + percent := 0 + if t.TotalBytes > 0 { + percent = int(float64(t.DownloadedBytes) / float64(t.TotalBytes) * 100) + if percent > 100 { + percent = 100 + } + } + return agent.StatusUpdate{ TaskID: t.ID, Status: apiStatus, - Progress: t.Percent(), + Progress: percent, DownloadedBytes: t.DownloadedBytes, TotalBytes: t.TotalBytes, SpeedBps: t.SpeedBps, diff --git a/internal/engine/upnp.go b/internal/engine/upnp.go index 9361157..50587c9 100644 --- a/internal/engine/upnp.go +++ b/internal/engine/upnp.go @@ -338,16 +338,28 @@ func localIPFor(host string) string { } // Remove deletes the port mapping from the router. +// It runs in a goroutine with a 5-second deadline so it never blocks shutdown. func (m *UPnPMapping) Remove() { if m == nil { return } - switch m.protocol { - case "natpmp": - m.removeNATPMP() - case "upnp": - m.removeUPnP() + done := make(chan struct{}) + go func() { + defer close(done) + switch m.protocol { + case "natpmp": + m.removeNATPMP() + case "upnp": + m.removeUPnP() + } + }() + select { + case <-done: + case <-time.After(10 * time.Second): + // removeNATPMP worst case: 3s dial + 5s natpmpMapPort deadline = 8s. + // 10s gives enough margin without blocking shutdown indefinitely. + log.Printf("stream: UPnP/NAT-PMP cleanup timed out after 10s — port %d may remain mapped", m.ExternalPort) } } diff --git a/internal/engine/usenet.go b/internal/engine/usenet.go index fda121b..c39be86 100644 --- a/internal/engine/usenet.go +++ b/internal/engine/usenet.go @@ -300,8 +300,16 @@ func (u *UsenetDownloader) Pause(taskID string) error { // Cancel aborts an in-progress download and removes partial files + resume state. func (u *UsenetDownloader) Cancel(taskID string) error { + // Read all fields under the lock — Download() writes tracker and taskDir under + // the same lock, so we must hold it while reading to avoid a data race. u.mu.Lock() dl := u.active[taskID] + var tracker *download.ProgressTracker + var taskDir string + if dl != nil { + tracker = dl.tracker + taskDir = dl.taskDir + } u.mu.Unlock() if dl == nil { @@ -312,13 +320,13 @@ func (u *UsenetDownloader) Cancel(taskID string) error { dl.cancel() // Remove resume state (best-effort) - if dl.tracker != nil { - dl.tracker.Remove() + if tracker != nil { + tracker.Remove() } // Remove partial download directory in background (can be slow for large dirs) - if dl.taskDir != "" { - go os.RemoveAll(dl.taskDir) + if taskDir != "" { + go os.RemoveAll(taskDir) } return nil