fix(daemon): cancel watch reporter on stream switch and re-notify ready
- Register WatchReporter cancel funcs in streamRegistry so they get cancelled when switching to a different stream (prevents goroutine leak) - Re-notify streamReady when the server is already serving the requested task (handles duplicate stream requests from the web UI) - Rewrite tests for byte-based tracking semantics, remove dead parseRangeStart tests
This commit is contained in:
parent
c612ebb2e4
commit
4d7362a567
2 changed files with 51 additions and 79 deletions
|
|
@ -286,8 +286,12 @@ func runDaemonStart() error {
|
||||||
task.SetStreamURL(streamSrv.URLsJSON())
|
task.SetStreamURL(streamSrv.URLsJSON())
|
||||||
log.Printf("[%s] streaming active download: %s", taskID[:8], provider.FileName())
|
log.Printf("[%s] streaming active download: %s", taskID[:8], provider.FileName())
|
||||||
|
|
||||||
// Start watch progress reporter
|
// Start watch progress reporter with cancellable context
|
||||||
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(ctx)
|
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
|
||||||
|
streamRegistry.mu.Lock()
|
||||||
|
streamRegistry.cancels["watch:"+taskID] = watchCancel
|
||||||
|
streamRegistry.mu.Unlock()
|
||||||
|
go engine.NewWatchReporter(agentClient, streamSrv, taskID).Run(watchCtx)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Wire: daemon claimed tasks -> manager
|
// Wire: daemon claimed tasks -> manager
|
||||||
|
|
@ -318,8 +322,16 @@ func runDaemonStart() error {
|
||||||
|
|
||||||
// Wire: stream requests for completed downloads → set file on persistent server
|
// Wire: stream requests for completed downloads → set file on persistent server
|
||||||
d.OnStreamRequested = func(sr agent.StreamRequest) {
|
d.OnStreamRequested = func(sr agent.StreamRequest) {
|
||||||
// Skip if already serving this task
|
// Already serving this task — just notify server it's ready
|
||||||
if streamSrv.CurrentTaskID() == sr.TaskID {
|
if streamSrv.CurrentTaskID() == sr.TaskID {
|
||||||
|
go func() {
|
||||||
|
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
||||||
|
TaskID: sr.TaskID,
|
||||||
|
StreamReady: true,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("[%s] stream ready re-notify failed: %v", sr.TaskID[:8], err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -365,8 +377,13 @@ func runDaemonStart() error {
|
||||||
|
|
||||||
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(filePath), streamSrv.URL())
|
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(filePath), streamSrv.URL())
|
||||||
|
|
||||||
// Start watch progress reporter
|
// Start watch progress reporter with a cancellable context
|
||||||
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(ctx)
|
// so it stops when the user switches to a different stream.
|
||||||
|
watchCtx, watchCancel := context.WithCancel(ctx) //nolint:gosec // cancel stored in streamRegistry, called by cancelStreamContexts()
|
||||||
|
streamRegistry.mu.Lock()
|
||||||
|
streamRegistry.cancels["watch:"+sr.TaskID] = watchCancel
|
||||||
|
streamRegistry.mu.Unlock()
|
||||||
|
go engine.NewWatchReporter(agentClient, streamSrv, sr.TaskID).Run(watchCtx)
|
||||||
|
|
||||||
// Notify server that stream is ready (clears streamRequested flag)
|
// Notify server that stream is ready (clears streamRequested flag)
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
||||||
|
|
@ -2,38 +2,12 @@ package engine
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// parseRangeStart
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
func TestParseRangeStart(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
header string
|
|
||||||
want int64
|
|
||||||
}{
|
|
||||||
{"bytes=0-", 0},
|
|
||||||
{"bytes=1024-", 1024},
|
|
||||||
{"bytes=5000-9999", 5000},
|
|
||||||
{"bytes=1048576-", 1048576},
|
|
||||||
{"", -1},
|
|
||||||
{"invalid", -1},
|
|
||||||
{"bytes=", -1},
|
|
||||||
{"bytes=-500", -1},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
got := parseRangeStart(tc.header)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("parseRangeStart(%q) = %d, want %d", tc.header, got, tc.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// StreamServer.EstimatedProgress
|
// StreamServer.EstimatedProgress
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
@ -51,9 +25,9 @@ func TestEstimatedProgress_HalfWay(t *testing.T) {
|
||||||
ss.totalFileSize.Store(1000)
|
ss.totalFileSize.Store(1000)
|
||||||
ss.maxByteOffset.Store(500)
|
ss.maxByteOffset.Store(500)
|
||||||
|
|
||||||
pos, dur := ss.EstimatedProgress()
|
pos, _ := ss.EstimatedProgress()
|
||||||
if pos != 50 || dur != 100 {
|
if pos != 50 {
|
||||||
t.Errorf("expected (50, 100), got (%d, %d)", pos, dur)
|
t.Errorf("expected pct=50, got %d", pos)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -62,9 +36,9 @@ func TestEstimatedProgress_CapsAt100(t *testing.T) {
|
||||||
ss.totalFileSize.Store(1000)
|
ss.totalFileSize.Store(1000)
|
||||||
ss.maxByteOffset.Store(1500)
|
ss.maxByteOffset.Store(1500)
|
||||||
|
|
||||||
pos, dur := ss.EstimatedProgress()
|
pos, _ := ss.EstimatedProgress()
|
||||||
if pos != 100 || dur != 100 {
|
if pos != 100 {
|
||||||
t.Errorf("expected (100, 100), got (%d, %d)", pos, dur)
|
t.Errorf("expected pct=100, got %d", pos)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -95,7 +69,7 @@ func TestMaxByteOffsetNeverRegresses(t *testing.T) {
|
||||||
// End-to-end: real HTTP server with Range requests
|
// End-to-end: real HTTP server with Range requests
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
func TestStreamServerRangeTracking(t *testing.T) {
|
func TestStreamServerByteTracking(t *testing.T) {
|
||||||
// Create temp file (10 KB)
|
// Create temp file (10 KB)
|
||||||
tmpFile := t.TempDir() + "/test.mp4"
|
tmpFile := t.TempDir() + "/test.mp4"
|
||||||
data := make([]byte, 10240)
|
data := make([]byte, 10240)
|
||||||
|
|
@ -116,66 +90,47 @@ func TestStreamServerRangeTracking(t *testing.T) {
|
||||||
srv.SetFile(NewDiskFileProvider(tmpFile), "test-task")
|
srv.SetFile(NewDiskFileProvider(tmpFile), "test-task")
|
||||||
url := srv.URL()
|
url := srv.URL()
|
||||||
|
|
||||||
// 1. Non-range GET — maxByteOffset stays 0
|
// 1. Full GET — reads all bytes, maxByteOffset reaches file size
|
||||||
resp, err := http.Get(url)
|
resp, err := http.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET: %v", err)
|
t.Fatalf("GET: %v", err)
|
||||||
}
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
if srv.maxByteOffset.Load() != 0 {
|
if srv.maxByteOffset.Load() != 10240 {
|
||||||
t.Errorf("non-range: expected 0, got %d", srv.maxByteOffset.Load())
|
t.Errorf("full read: expected 10240, got %d", srv.maxByteOffset.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Range: bytes=5000- → offset 5000
|
// 2. Reset and verify progress after partial read via Range
|
||||||
|
srv.SetFile(NewDiskFileProvider(tmpFile), "test-task-2")
|
||||||
|
if srv.maxByteOffset.Load() != 0 {
|
||||||
|
t.Errorf("after reset: expected 0, got %d", srv.maxByteOffset.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Range request reads from offset 5000 to end (5240 bytes)
|
||||||
req, _ := http.NewRequest("GET", url, nil)
|
req, _ := http.NewRequest("GET", url, nil)
|
||||||
req.Header.Set("Range", "bytes=5000-")
|
req.Header.Set("Range", "bytes=5000-")
|
||||||
resp, err = http.DefaultClient.Do(req)
|
resp, err = http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Range GET: %v", err)
|
t.Fatalf("Range GET: %v", err)
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusPartialContent {
|
if resp.StatusCode != http.StatusPartialContent {
|
||||||
t.Errorf("expected 206, got %d", resp.StatusCode)
|
t.Errorf("expected 206, got %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
if srv.maxByteOffset.Load() != 5000 {
|
io.Copy(io.Discard, resp.Body)
|
||||||
t.Errorf("expected 5000, got %d", srv.maxByteOffset.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Higher offset
|
|
||||||
req, _ = http.NewRequest("GET", url, nil)
|
|
||||||
req.Header.Set("Range", "bytes=8000-")
|
|
||||||
resp, err = http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Range GET 2: %v", err)
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
if srv.maxByteOffset.Load() != 8000 {
|
// The reader reads 5240 bytes (from offset 5000 to 10240).
|
||||||
t.Errorf("expected 8000, got %d", srv.maxByteOffset.Load())
|
// maxByteOffset tracks the read position, which ends at 10240.
|
||||||
|
got := srv.maxByteOffset.Load()
|
||||||
|
if got != 10240 {
|
||||||
|
t.Errorf("after range read: expected 10240, got %d", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Lower offset does NOT regress
|
// 3. Verify progress reaches 100%
|
||||||
req, _ = http.NewRequest("GET", url, nil)
|
pos, _ := srv.EstimatedProgress()
|
||||||
req.Header.Set("Range", "bytes=2000-")
|
if pos != 100 {
|
||||||
resp, err = http.DefaultClient.Do(req)
|
t.Errorf("expected pct=100, got %d", pos)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Range GET 3: %v", err)
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
if srv.maxByteOffset.Load() != 8000 {
|
|
||||||
t.Errorf("expected still 8000, got %d", srv.maxByteOffset.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. Verify progress estimate
|
|
||||||
pos, dur := srv.EstimatedProgress()
|
|
||||||
// 8000/10240 = 78.1% → 78
|
|
||||||
if pos < 78 || pos > 79 {
|
|
||||||
t.Errorf("expected pos ~78, got %d", pos)
|
|
||||||
}
|
|
||||||
if dur != 100 {
|
|
||||||
t.Errorf("expected dur=100, got %d", dur)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue