From 4d7362a5670358a2e8df8b9143f65f605636f981 Mon Sep 17 00:00:00 2001 From: Deivid Soto Date: Tue, 7 Apr 2026 23:29:09 +0200 Subject: [PATCH] fix(daemon): cancel watch reporter on stream switch and re-notify ready - Register WatchReporter cancel funcs in streamRegistry so they get cancelled when switching to a different stream (prevents goroutine leak) - Re-notify streamReady when the server is already serving the requested task (handles duplicate stream requests from the web UI) - Rewrite tests for byte-based tracking semantics, remove dead parseRangeStart tests --- internal/cmd/daemon.go | 27 +++++-- internal/engine/watch_reporter_test.go | 103 +++++++------------------ 2 files changed, 51 insertions(+), 79 deletions(-) diff --git a/internal/cmd/daemon.go b/internal/cmd/daemon.go index c1887e2..a6abc4c 100644 --- a/internal/cmd/daemon.go +++ b/internal/cmd/daemon.go @@ -286,8 +286,12 @@ func runDaemonStart() error { task.SetStreamURL(streamSrv.URLsJSON()) log.Printf("[%s] streaming active download: %s", taskID[:8], provider.FileName()) - // Start watch progress reporter - go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(ctx) + // Start watch progress reporter with cancellable context + watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts() + streamRegistry.mu.Lock() + streamRegistry.cancels["watch:"+taskID] = watchCancel + streamRegistry.mu.Unlock() + go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx) }) // Wire: daemon claimed tasks -> manager @@ -318,8 +322,16 @@ func runDaemonStart() error { // Wire: stream requests for completed downloads → set file on persistent server d.OnStreamRequested = func(sr agent.StreamRequest) { - // Skip if already serving this task + // Already serving this task — just notify server it's ready if streamSrv.CurrentTaskID() == sr.TaskID { + go func() { + if _, err := transport.SendProgress(ctx, agent.StatusUpdate{ + TaskID: sr.TaskID, + StreamReady: true, + }); err != nil { + log.Printf("[%s] stream ready re-notify failed: %v", sr.TaskID[:8], err) + } + }() return } @@ -365,8 +377,13 @@ func runDaemonStart() error { log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(filePath), streamSrv.URL()) - // Start watch progress reporter - go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(ctx) + // Start watch progress reporter with a cancellable context + // so it stops when the user switches to a different stream. + watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts() + streamRegistry.mu.Lock() + streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel + streamRegistry.mu.Unlock() + go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx) // Notify server that stream is ready (clears streamRequested flag) go func() { diff --git a/internal/engine/watch_reporter_test.go b/internal/engine/watch_reporter_test.go index 8cd0878..b9f17c0 100644 --- a/internal/engine/watch_reporter_test.go +++ b/internal/engine/watch_reporter_test.go @@ -2,38 +2,12 @@ package engine import ( "context" + "io" "net/http" "os" "testing" ) -// --------------------------------------------------------------------------- -// parseRangeStart -// --------------------------------------------------------------------------- - -func TestParseRangeStart(t *testing.T) { - tests := []struct { - header string - want int64 - }{ - {"bytes=0-", 0}, - {"bytes=1024-", 1024}, - {"bytes=5000-9999", 5000}, - {"bytes=1048576-", 1048576}, - {"", -1}, - {"invalid", -1}, - {"bytes=", -1}, - {"bytes=-500", -1}, - } - - for _, tc := range tests { - got := parseRangeStart(tc.header) - if got != tc.want { - t.Errorf("parseRangeStart(%q) = %d, want %d", tc.header, got, tc.want) - } - } -} - // --------------------------------------------------------------------------- // StreamServer.EstimatedProgress // --------------------------------------------------------------------------- @@ -51,9 +25,9 @@ func TestEstimatedProgress_HalfWay(t *testing.T) { ss.totalFileSize.Store(1000) ss.maxByteOffset.Store(500) - pos, dur := ss.EstimatedProgress() - if pos != 50 || dur != 100 { - t.Errorf("expected (50, 100), got (%d, %d)", pos, dur) + pos, _ := ss.EstimatedProgress() + if pos != 50 { + t.Errorf("expected pct=50, got %d", pos) } } @@ -62,9 +36,9 @@ func TestEstimatedProgress_CapsAt100(t *testing.T) { ss.totalFileSize.Store(1000) ss.maxByteOffset.Store(1500) - pos, dur := ss.EstimatedProgress() - if pos != 100 || dur != 100 { - t.Errorf("expected (100, 100), got (%d, %d)", pos, dur) + pos, _ := ss.EstimatedProgress() + if pos != 100 { + t.Errorf("expected pct=100, got %d", pos) } } @@ -95,7 +69,7 @@ func TestMaxByteOffsetNeverRegresses(t *testing.T) { // End-to-end: real HTTP server with Range requests // --------------------------------------------------------------------------- -func TestStreamServerRangeTracking(t *testing.T) { +func TestStreamServerByteTracking(t *testing.T) { // Create temp file (10 KB) tmpFile := t.TempDir() + "/test.mp4" data := make([]byte, 10240) @@ -116,66 +90,47 @@ func TestStreamServerRangeTracking(t *testing.T) { srv.SetFile(NewDiskFileProvider(tmpFile), "test-task") url := srv.URL() - // 1. Non-range GET — maxByteOffset stays 0 + // 1. Full GET — reads all bytes, maxByteOffset reaches file size resp, err := http.Get(url) if err != nil { t.Fatalf("GET: %v", err) } + io.Copy(io.Discard, resp.Body) resp.Body.Close() - if srv.maxByteOffset.Load() != 0 { - t.Errorf("non-range: expected 0, got %d", srv.maxByteOffset.Load()) + if srv.maxByteOffset.Load() != 10240 { + t.Errorf("full read: expected 10240, got %d", srv.maxByteOffset.Load()) } - // 2. Range: bytes=5000- → offset 5000 + // 2. Reset and verify progress after partial read via Range + srv.SetFile(NewDiskFileProvider(tmpFile), "test-task-2") + if srv.maxByteOffset.Load() != 0 { + t.Errorf("after reset: expected 0, got %d", srv.maxByteOffset.Load()) + } + + // Range request reads from offset 5000 to end (5240 bytes) req, _ := http.NewRequest("GET", url, nil) req.Header.Set("Range", "bytes=5000-") resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatalf("Range GET: %v", err) } - resp.Body.Close() - if resp.StatusCode != http.StatusPartialContent { t.Errorf("expected 206, got %d", resp.StatusCode) } - if srv.maxByteOffset.Load() != 5000 { - t.Errorf("expected 5000, got %d", srv.maxByteOffset.Load()) - } - - // 3. Higher offset - req, _ = http.NewRequest("GET", url, nil) - req.Header.Set("Range", "bytes=8000-") - resp, err = http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Range GET 2: %v", err) - } + io.Copy(io.Discard, resp.Body) resp.Body.Close() - if srv.maxByteOffset.Load() != 8000 { - t.Errorf("expected 8000, got %d", srv.maxByteOffset.Load()) + // The reader reads 5240 bytes (from offset 5000 to 10240). + // maxByteOffset tracks the read position, which ends at 10240. + got := srv.maxByteOffset.Load() + if got != 10240 { + t.Errorf("after range read: expected 10240, got %d", got) } - // 4. Lower offset does NOT regress - req, _ = http.NewRequest("GET", url, nil) - req.Header.Set("Range", "bytes=2000-") - resp, err = http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Range GET 3: %v", err) - } - resp.Body.Close() - - if srv.maxByteOffset.Load() != 8000 { - t.Errorf("expected still 8000, got %d", srv.maxByteOffset.Load()) - } - - // 5. Verify progress estimate - pos, dur := srv.EstimatedProgress() - // 8000/10240 = 78.1% → 78 - if pos < 78 || pos > 79 { - t.Errorf("expected pos ~78, got %d", pos) - } - if dur != 100 { - t.Errorf("expected dur=100, got %d", dur) + // 3. Verify progress reaches 100% + pos, _ := srv.EstimatedProgress() + if pos != 100 { + t.Errorf("expected pct=100, got %d", pos) } }