diff --git a/internal/cmd/daemon.go b/internal/cmd/daemon.go index 2cd9125..55b37c5 100644 --- a/internal/cmd/daemon.go +++ b/internal/cmd/daemon.go @@ -284,9 +284,19 @@ func runDaemonStart() error { d.OnTasksClaimed = func(tasks []agent.Task) { for _, t := range tasks { if t.Mode == "stream" { + // Skip if already streaming this task + if isStreamingTask(t.ID) { + continue + } // Only 1 stream at a time: cancel all existing streams cancelAllStreams() - go handleStreamTask(ctx, t, reporter, cfg, agentClient) + // Reserve slot before spawning goroutine to prevent TOCTOU race. + // streamCancel is stored in streamRegistry and called by cancelAllStreams/cancelStreamTask. + streamCtx, streamCancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel ownership transferred to streamRegistry + streamRegistry.mu.Lock() + streamRegistry.cancels[t.ID] = streamCancel + streamRegistry.mu.Unlock() + go handleStreamTask(streamCtx, t, reporter, cfg, agentClient) } else if t.ForceStart || manager.HasCapacity() { manager.Submit(ctx, t) } else { @@ -297,6 +307,11 @@ func runDaemonStart() error { // Wire: stream requests for completed downloads → serve file from disk d.OnStreamRequested = func(sr agent.StreamRequest) { + // Skip if already streaming this task + if isStreamingTask(sr.TaskID) { + return + } + // Only 1 stream at a time: cancel all existing streams cancelAllStreams() @@ -337,7 +352,7 @@ func runDaemonStart() error { } srv := engine.NewStreamServerFromDisk(filePath, cfg.Download.StreamPort) - streamURL, err := srv.Start(context.Background()) + streamURL, err := srv.Start(ctx) if err != nil { log.Printf("[%s] stream failed: %v", sr.TaskID[:8], err) go func() { @@ -388,20 +403,16 @@ func runDaemonStart() error { log.Printf("[%s] resume requested via WebSocket, triggering poll", taskID[:8]) d.TriggerPoll() case "stream": - // Only 1 stream at a time: cancel all existing streams - cancelAllStreams() - // Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests - streamRegistry.mu.Lock() - if _, exists := streamRegistry.servers[taskID]; exists { - streamRegistry.mu.Unlock() + // Skip if already streaming this task + if isStreamingTask(taskID) { return } task := manager.GetTask(taskID) if task == nil || task.GetStreamURL() != "" { - streamRegistry.mu.Unlock() return } - streamRegistry.mu.Unlock() + // Only 1 stream at a time: cancel all existing streams + cancelAllStreams() srv, err := torrentDl.StartStream(taskID) if err != nil { log.Printf("[%s] stream failed: %v", taskID[:8], err) diff --git a/internal/cmd/stream_handler.go b/internal/cmd/stream_handler.go index def74ab..cd66e25 100644 --- a/internal/cmd/stream_handler.go +++ b/internal/cmd/stream_handler.go @@ -14,7 +14,9 @@ import ( "github.com/torrentclaw/unarr/internal/ui" ) -// startIdleGuard monitors a stream server and cancels the task after 30 minutes of inactivity. +const streamIdleTimeout = 30 * time.Minute + +// startIdleGuard monitors a stream server and cancels the task after inactivity. func startIdleGuard(ctx context.Context, srv *engine.StreamServer, taskID string) { ticker := time.NewTicker(60 * time.Second) defer ticker.Stop() @@ -23,8 +25,8 @@ func startIdleGuard(ctx context.Context, srv *engine.StreamServer, taskID string case <-ctx.Done(): return case <-ticker.C: - if srv.IdleSince() > 30*time.Minute { - log.Printf("[%s] stream idle timeout (30m no HTTP requests), shutting down", taskID[:8]) + if srv.IdleSince() > streamIdleTimeout { + log.Printf("[%s] stream idle timeout (%v no HTTP requests), shutting down", taskID[:8], streamIdleTimeout) cancelStreamTask(taskID) return } @@ -45,29 +47,50 @@ var streamRegistry = struct { // cancelAllStreams cancels all active stream tasks and servers (only 1 stream at a time). func cancelAllStreams() { streamRegistry.mu.Lock() - for taskID, cancel := range streamRegistry.cancels { - cancel() - delete(streamRegistry.cancels, taskID) + cancels := make(map[string]context.CancelFunc, len(streamRegistry.cancels)) + for k, v := range streamRegistry.cancels { + cancels[k] = v + delete(streamRegistry.cancels, k) } - for taskID, srv := range streamRegistry.servers { - srv.Shutdown(context.Background()) - delete(streamRegistry.servers, taskID) + servers := make(map[string]*engine.StreamServer, len(streamRegistry.servers)) + for k, v := range streamRegistry.servers { + servers[k] = v + delete(streamRegistry.servers, k) } streamRegistry.mu.Unlock() + + for _, cancel := range cancels { + cancel() + } + for _, srv := range servers { + srv.Shutdown(context.Background()) + } +} + +// isStreamingTask returns true if there is an active stream (goroutine or server) for the given task. +func isStreamingTask(taskID string) bool { + streamRegistry.mu.Lock() + defer streamRegistry.mu.Unlock() + _, inCancels := streamRegistry.cancels[taskID] + _, inServers := streamRegistry.servers[taskID] + return inCancels || inServers } // cancelStreamTask cancels a running stream task and shuts down any stream server. func cancelStreamTask(taskID string) { streamRegistry.mu.Lock() - if cancel, ok := streamRegistry.cancels[taskID]; ok { - cancel() - delete(streamRegistry.cancels, taskID) - } - if srv, ok := streamRegistry.servers[taskID]; ok { - srv.Shutdown(context.Background()) - delete(streamRegistry.servers, taskID) - } + cancel, hasCancel := streamRegistry.cancels[taskID] + delete(streamRegistry.cancels, taskID) + srv, hasSrv := streamRegistry.servers[taskID] + delete(streamRegistry.servers, taskID) streamRegistry.mu.Unlock() + + if hasCancel { + cancel() + } + if hasSrv { + srv.Shutdown(context.Background()) + } } // handleStreamTask manages a streaming task lifecycle outside the Manager. @@ -133,7 +156,15 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine task.Transition(engine.StatusFailed) return } - defer srv.Shutdown(context.Background()) + streamRegistry.mu.Lock() + streamRegistry.servers[at.ID] = srv + streamRegistry.mu.Unlock() + defer func() { + srv.Shutdown(context.Background()) + streamRegistry.mu.Lock() + delete(streamRegistry.servers, at.ID) + streamRegistry.mu.Unlock() + }() // 5. Report stream URL — the reporter will send this to the web task.StreamURL = streamURL diff --git a/internal/engine/stream_server.go b/internal/engine/stream_server.go index 97c7787..ed3f6d8 100644 --- a/internal/engine/stream_server.go +++ b/internal/engine/stream_server.go @@ -96,6 +96,7 @@ func (p *diskFileProvider) FileName() string { return p.name } func (p *diskFileProvider) FileSize() int64 { fi, err := os.Stat(p.path) if err != nil { + log.Printf("stream: failed to stat %q: %v", p.path, err) return 0 } return fi.Size() @@ -244,6 +245,14 @@ func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) { defer reader.Close() w.Header().Set("Content-Type", mimeTypeFromExt(ss.provider.FileName())) + // "inline" for play requests (VLC/mpv), "attachment" for download requests. + // Browser download via window.open() relies on "attachment" to trigger save dialog. + disposition := "inline" + if r.URL.Query().Get("download") == "1" { + disposition = "attachment" + } + w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%q", disposition, ss.provider.FileName())) + w.Header().Set("Accept-Ranges", "bytes") http.ServeContent(w, r, ss.provider.FileName(), time.Time{}, reader) }