diff --git a/Makefile b/Makefile
index 4a8245e..6207d50 100644
--- a/Makefile
+++ b/Makefile
@@ -19,10 +19,13 @@ test:
lint:
golangci-lint run ./...
-## Run tests with coverage report
+## Run tests with coverage report (excludes CLI layer — cmd/ is glue code)
+COVER_PKGS = $(shell go list ./... | grep -v '/cmd')
coverage:
- go test -race -coverprofile=coverage.out -covermode=atomic ./...
- go tool cover -func=coverage.out
+ go test -race -coverprofile=coverage.out -covermode=atomic $(COVER_PKGS)
+ @echo "──────────────────────────────────────"
+ @go tool cover -func=coverage.out | tail -1
+ @echo "──────────────────────────────────────"
go tool cover -html=coverage.out -o coverage.html
## Format code
diff --git a/internal/agent/client_test.go b/internal/agent/client_test.go
index 9266b74..c8ce68d 100644
--- a/internal/agent/client_test.go
+++ b/internal/agent/client_test.go
@@ -324,3 +324,265 @@ func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
}
}
+
+func TestDeregister(t *testing.T) {
+ var received struct {
+ AgentID string `json:"agentId"`
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/deregister" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ if r.Method != http.MethodPost {
+ t.Errorf("method = %s, want POST", r.Method)
+ }
+ json.NewDecoder(r.Body).Decode(&received)
+ json.NewEncoder(w).Encode(StatusResponse{Success: true})
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ err := c.Deregister(context.Background(), "agent-42")
+ if err != nil {
+ t.Fatalf("Deregister failed: %v", err)
+ }
+ if received.AgentID != "agent-42" {
+ t.Errorf("agentId = %q, want agent-42", received.AgentID)
+ }
+}
+
+func TestBatchReportStatus(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/status" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ var req BatchStatusRequest
+ json.NewDecoder(r.Body).Decode(&req)
+ if len(req.Updates) != 2 {
+ t.Errorf("expected 2 updates, got %d", len(req.Updates))
+ }
+ json.NewEncoder(w).Encode(BatchStatusResponse{
+ Results: []StatusResponse{
+ {Success: true},
+ {Success: true, Cancelled: true},
+ },
+ })
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ resp, err := c.BatchReportStatus(context.Background(), []StatusUpdate{
+ {TaskID: "t1", Status: "downloading"},
+ {TaskID: "t2", Status: "completed"},
+ })
+ if err != nil {
+ t.Fatalf("BatchReportStatus failed: %v", err)
+ }
+ if len(resp.Results) != 2 {
+ t.Fatalf("expected 2 results, got %d", len(resp.Results))
+ }
+ if !resp.Results[1].Cancelled {
+ t.Error("expected result[1].Cancelled=true")
+ }
+}
+
+func TestSearchNzbs(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/nzb-search" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ json.NewEncoder(w).Encode(NzbSearchResponse{
+ Results: []NzbSearchResult{
+ {NzbID: "nzb-1", Title: "Movie.2023.1080p"},
+ },
+ })
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ resp, err := c.SearchNzbs(context.Background(), NzbSearchParams{Query: "Movie"})
+ if err != nil {
+ t.Fatalf("SearchNzbs failed: %v", err)
+ }
+ if len(resp.Results) != 1 {
+ t.Fatalf("expected 1 result, got %d", len(resp.Results))
+ }
+ if resp.Results[0].NzbID != "nzb-1" {
+ t.Errorf("nzb ID = %q, want nzb-1", resp.Results[0].NzbID)
+ }
+}
+
+func TestDownloadNzb(t *testing.T) {
+ nzbContent := []byte(`test`)
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/nzb-download" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ if r.URL.Query().Get("nzbId") != "nzb-42" {
+ t.Errorf("nzbId = %q, want nzb-42", r.URL.Query().Get("nzbId"))
+ }
+ w.Write(nzbContent)
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ data, err := c.DownloadNzb(context.Background(), "nzb-42")
+ if err != nil {
+ t.Fatalf("DownloadNzb failed: %v", err)
+ }
+ if string(data) != string(nzbContent) {
+ t.Errorf("nzb content mismatch")
+ }
+}
+
+func TestDownloadNzbError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ w.Write([]byte("NZB not found"))
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ _, err := c.DownloadNzb(context.Background(), "bad-id")
+ if err == nil {
+ t.Fatal("expected error for 404 response")
+ }
+}
+
+func TestGetUsenetCredentials(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/usenet-credentials" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ json.NewEncoder(w).Encode(UsenetCredentials{
+ Host: "news.example.com",
+ Port: 563,
+ SSL: true,
+ Username: "user1",
+ Password: "pass1",
+ MaxConnections: 10,
+ })
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ creds, err := c.GetUsenetCredentials(context.Background())
+ if err != nil {
+ t.Fatalf("GetUsenetCredentials failed: %v", err)
+ }
+ if creds.Host != "news.example.com" {
+ t.Errorf("host = %q, want news.example.com", creds.Host)
+ }
+ if creds.Username != "user1" {
+ t.Errorf("username = %q, want user1", creds.Username)
+ }
+ if creds.MaxConnections != 10 {
+ t.Errorf("maxConnections = %d, want 10", creds.MaxConnections)
+ }
+}
+
+func TestGetUsenetUsage(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/usenet-usage" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ json.NewEncoder(w).Encode(UsenetUsageResponse{
+ UsedBytes: 5368709120,
+ QuotaBytes: 10737418240,
+ })
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ usage, err := c.GetUsenetUsage(context.Background())
+ if err != nil {
+ t.Fatalf("GetUsenetUsage failed: %v", err)
+ }
+ if usage.UsedBytes != 5368709120 {
+ t.Errorf("usedBytes = %d", usage.UsedBytes)
+ }
+}
+
+func TestConfigureDebrid(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/debrid-config" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ json.NewEncoder(w).Encode(ConfigureDebridResponse{Success: true})
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ resp, err := c.ConfigureDebrid(context.Background(), ConfigureDebridRequest{
+ Provider: "real-debrid",
+ Token: "rd-token-123",
+ })
+ if err != nil {
+ t.Fatalf("ConfigureDebrid failed: %v", err)
+ }
+ if !resp.Success {
+ t.Error("expected success=true")
+ }
+}
+
+func TestBatchDownload(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/batch-download" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ json.NewEncoder(w).Encode(BatchDownloadResponse{
+ Queued: 3,
+ NotFound: 1,
+ })
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ resp, err := c.BatchDownload(context.Background(), BatchDownloadRequest{})
+ if err != nil {
+ t.Fatalf("BatchDownload failed: %v", err)
+ }
+ if resp.Queued != 3 {
+ t.Errorf("queued = %d, want 3", resp.Queued)
+ }
+}
+
+func TestSyncLibrary(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/internal/agent/library-sync" {
+ t.Errorf("path = %s", r.URL.Path)
+ }
+ json.NewEncoder(w).Encode(LibrarySyncResponse{
+ Matched: 10,
+ Synced: 15,
+ Removed: 2,
+ })
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ resp, err := c.SyncLibrary(context.Background(), LibrarySyncRequest{})
+ if err != nil {
+ t.Fatalf("SyncLibrary failed: %v", err)
+ }
+ if resp.Matched != 10 {
+ t.Errorf("matched = %d, want 10", resp.Matched)
+ }
+ if resp.Synced != 15 {
+ t.Errorf("synced = %d, want 15", resp.Synced)
+ }
+}
+
+func TestHTMLErrorResponse(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadGateway)
+ w.Write([]byte("
502 Bad Gateway"))
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key", "unarr-test")
+ _, err := c.Register(context.Background(), RegisterRequest{AgentID: "x"})
+ if err == nil {
+ t.Fatal("expected error for HTML error page")
+ }
+}
diff --git a/internal/agent/transport_test.go b/internal/agent/transport_test.go
index a9c7e5d..be2f6c6 100644
--- a/internal/agent/transport_test.go
+++ b/internal/agent/transport_test.go
@@ -443,3 +443,1148 @@ func TestHybridTransportWSDisconnectSwitchesToHTTP(t *testing.T) {
t.Errorf("expected http after disconnect, got %s", h.Mode())
}
}
+
+// ── Additional HTTP Transport Tests ─────────────────────────────────────────
+
+func TestNewHTTPTransportConstructor(t *testing.T) {
+ tr := NewHTTPTransport("http://example.com", "my-key", "my-agent/1.0")
+
+ if tr.client == nil {
+ t.Fatal("expected client to be non-nil")
+ }
+ if tr.events == nil {
+ t.Fatal("expected events channel to be non-nil")
+ }
+ // events channel should have capacity 10
+ if cap(tr.events) != 10 {
+ t.Errorf("expected events capacity 10, got %d", cap(tr.events))
+ }
+}
+
+func TestHTTPTransportConnectAndCloseAreNoOps(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+
+ if err := tr.Connect(context.Background()); err != nil {
+ t.Errorf("Connect should be a no-op, got error: %v", err)
+ }
+ if err := tr.Close(); err != nil {
+ t.Errorf("Close should be a no-op, got error: %v", err)
+ }
+}
+
+func TestHTTPTransportClientAccessor(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ c := tr.Client()
+ if c == nil {
+ t.Fatal("Client() should return the underlying client")
+ }
+ if c != tr.client {
+ t.Error("Client() should return the same instance stored internally")
+ }
+}
+
+func TestHTTPTransportSendHeartbeat(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ t.Errorf("expected POST, got %s", r.Method)
+ }
+ if !strings.Contains(r.URL.Path, "heartbeat") {
+ t.Errorf("expected heartbeat path, got %s", r.URL.Path)
+ }
+ json.NewEncoder(w).Encode(HeartbeatResponse{
+ Success: true,
+ Watching: true,
+ Upgrade: &UpgradeSignal{Version: "9.9.9"},
+ })
+ }))
+ defer srv.Close()
+
+ tr := NewHTTPTransport(srv.URL, "key", "ua")
+ resp, err := tr.SendHeartbeat(context.Background(), HeartbeatRequest{
+ AgentID: "a1",
+ Name: "test",
+ Version: "1.0",
+ })
+ if err != nil {
+ t.Fatalf("SendHeartbeat failed: %v", err)
+ }
+ if !resp.Success {
+ t.Error("expected success")
+ }
+ if !resp.Watching {
+ t.Error("expected watching=true")
+ }
+ if resp.Upgrade == nil || resp.Upgrade.Version != "9.9.9" {
+ t.Error("expected upgrade version 9.9.9")
+ }
+}
+
+func TestHTTPTransportSendProgress(t *testing.T) {
+ var received StatusUpdate
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ json.NewEncoder(w).Encode(StatusResponse{
+ Success: true,
+ Cancelled: true,
+ })
+ }))
+ defer srv.Close()
+
+ tr := NewHTTPTransport(srv.URL, "key", "ua")
+ resp, err := tr.SendProgress(context.Background(), StatusUpdate{
+ TaskID: "task-1",
+ Status: "downloading",
+ Progress: 55,
+ SpeedBps: 1024000,
+ })
+ if err != nil {
+ t.Fatalf("SendProgress failed: %v", err)
+ }
+ if !resp.Success {
+ t.Error("expected success")
+ }
+ if !resp.Cancelled {
+ t.Error("expected cancelled flag")
+ }
+ if received.TaskID != "task-1" {
+ t.Errorf("expected task-1, got %s", received.TaskID)
+ }
+ if received.Progress != 55 {
+ t.Errorf("expected progress 55, got %d", received.Progress)
+ }
+}
+
+func TestHTTPTransportClaimTasks(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ t.Errorf("expected GET, got %s", r.Method)
+ }
+ agentID := r.URL.Query().Get("agentId")
+ if agentID != "agent-42" {
+ t.Errorf("expected agentId=agent-42, got %s", agentID)
+ }
+ json.NewEncoder(w).Encode(TasksResponse{
+ Tasks: []Task{
+ {ID: "t1", Title: "Movie 1", InfoHash: "abc"},
+ {ID: "t2", Title: "Movie 2", InfoHash: "def"},
+ },
+ })
+ }))
+ defer srv.Close()
+
+ tr := NewHTTPTransport(srv.URL, "key", "ua")
+ resp, err := tr.ClaimTasks(context.Background(), "agent-42")
+ if err != nil {
+ t.Fatalf("ClaimTasks failed: %v", err)
+ }
+ if len(resp.Tasks) != 2 {
+ t.Fatalf("expected 2 tasks, got %d", len(resp.Tasks))
+ }
+ if resp.Tasks[0].Title != "Movie 1" {
+ t.Errorf("expected Movie 1, got %s", resp.Tasks[0].Title)
+ }
+}
+
+func TestHTTPTransportDeregister(t *testing.T) {
+ var called bool
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ called = true
+ json.NewEncoder(w).Encode(StatusResponse{Success: true})
+ }))
+ defer srv.Close()
+
+ tr := NewHTTPTransport(srv.URL, "key", "ua")
+ err := tr.Deregister(context.Background(), "agent-1")
+ if err != nil {
+ t.Fatalf("Deregister failed: %v", err)
+ }
+ if !called {
+ t.Error("expected server to be called")
+ }
+}
+
+func TestHTTPTransportBatchReportStatus(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewEncoder(w).Encode(BatchStatusResponse{
+ Results: []StatusResponse{
+ {Success: true},
+ {Success: true, Cancelled: true},
+ },
+ Watching: true,
+ })
+ }))
+ defer srv.Close()
+
+ tr := NewHTTPTransport(srv.URL, "key", "ua")
+ resp, err := tr.BatchReportStatus(context.Background(), []StatusUpdate{
+ {TaskID: "t1", Status: "downloading", Progress: 10},
+ {TaskID: "t2", Status: "completed", Progress: 100},
+ })
+ if err != nil {
+ t.Fatalf("BatchReportStatus failed: %v", err)
+ }
+ if len(resp.Results) != 2 {
+ t.Fatalf("expected 2 results, got %d", len(resp.Results))
+ }
+ if !resp.Watching {
+ t.Error("expected watching=true")
+ }
+ if !resp.Results[1].Cancelled {
+ t.Error("expected second result to be cancelled")
+ }
+}
+
+func TestHTTPTransportAuthHeader(t *testing.T) {
+ var gotAuth string
+ var gotUA string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotAuth = r.Header.Get("Authorization")
+ gotUA = r.Header.Get("User-Agent")
+ json.NewEncoder(w).Encode(RegisterResponse{Success: true})
+ }))
+ defer srv.Close()
+
+ tr := NewHTTPTransport(srv.URL, "secret-key-123", "unarr/2.0")
+ tr.Register(context.Background(), RegisterRequest{AgentID: "a1"})
+
+ if gotAuth != "Bearer secret-key-123" {
+ t.Errorf("expected Bearer secret-key-123, got %s", gotAuth)
+ }
+ if gotUA != "unarr/2.0" {
+ t.Errorf("expected unarr/2.0, got %s", gotUA)
+ }
+}
+
+// ── Additional WebSocket Transport Tests ────────────────────────────────────
+
+func TestNewWSTransportConstructor(t *testing.T) {
+ tr := NewWSTransport("ws://example.com/ws", "api-key", "agent-1", "ua/1.0")
+
+ if tr.Mode() != "ws" {
+ t.Errorf("expected ws mode, got %s", tr.Mode())
+ }
+ if tr.wsURL != "ws://example.com/ws" {
+ t.Errorf("expected ws URL, got %s", tr.wsURL)
+ }
+ if tr.apiKey != "api-key" {
+ t.Errorf("expected api-key, got %s", tr.apiKey)
+ }
+ if tr.agentID != "agent-1" {
+ t.Errorf("expected agent-1, got %s", tr.agentID)
+ }
+ if tr.userAgent != "ua/1.0" {
+ t.Errorf("expected ua/1.0, got %s", tr.userAgent)
+ }
+ if cap(tr.events) != 50 {
+ t.Errorf("expected events capacity 50, got %d", cap(tr.events))
+ }
+ if tr.authDone == nil {
+ t.Fatal("expected authDone channel to be non-nil")
+ }
+}
+
+func TestWSTransportClaimTasksIsNoOp(t *testing.T) {
+ tr := NewWSTransport("ws://localhost", "key", "a1", "ua")
+ resp, err := tr.ClaimTasks(context.Background(), "a1")
+ if err != nil {
+ t.Fatalf("ClaimTasks should succeed (no-op): %v", err)
+ }
+ if resp == nil {
+ t.Fatal("expected non-nil response")
+ }
+ if len(resp.Tasks) != 0 {
+ t.Errorf("expected 0 tasks, got %d", len(resp.Tasks))
+ }
+}
+
+func TestWSTransportCloseWhenNotConnected(t *testing.T) {
+ tr := NewWSTransport("ws://localhost", "key", "a1", "ua")
+ // Close without ever connecting should not panic or error
+ if err := tr.Close(); err != nil {
+ t.Errorf("Close on unconnected transport should return nil, got %v", err)
+ }
+}
+
+func TestWSTransportSendWhenNotConnected(t *testing.T) {
+ tr := NewWSTransport("ws://localhost", "key", "a1", "ua")
+ // Attempting to send a heartbeat without connecting should fail
+ _, err := tr.SendHeartbeat(context.Background(), HeartbeatRequest{AgentID: "a1"})
+ if err == nil {
+ t.Error("expected error when sending without connection")
+ }
+}
+
+func TestWSTransportConnectBadURL(t *testing.T) {
+ tr := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ err := tr.Connect(context.Background())
+ if err == nil {
+ t.Error("expected error connecting to invalid address")
+ }
+}
+
+func TestWSTransportSendHeartbeatWithDisk(t *testing.T) {
+ var receivedMsg map[string]interface{}
+ var mu sync.Mutex
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ // Read auth
+ conn.ReadMessage()
+ conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
+
+ // Read heartbeat
+ _, msg, err := conn.ReadMessage()
+ if err != nil {
+ return
+ }
+ mu.Lock()
+ json.Unmarshal(msg, &receivedMsg)
+ mu.Unlock()
+
+ time.Sleep(500 * time.Millisecond)
+ }))
+ defer srv.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+ tr := NewWSTransport(wsURL, "key", "a1", "ua")
+
+ ctx := context.Background()
+ tr.Connect(ctx)
+ defer tr.Close()
+ tr.Register(ctx, RegisterRequest{AgentID: "a1"})
+
+ time.Sleep(50 * time.Millisecond)
+ resp, err := tr.SendHeartbeat(ctx, HeartbeatRequest{
+ AgentID: "a1",
+ DiskFreeBytes: 500000000,
+ DiskTotalBytes: 1000000000,
+ })
+ if err != nil {
+ t.Fatalf("SendHeartbeat failed: %v", err)
+ }
+ if !resp.Success {
+ t.Error("expected success")
+ }
+
+ time.Sleep(100 * time.Millisecond)
+ mu.Lock()
+ defer mu.Unlock()
+ if receivedMsg["type"] != "heartbeat" {
+ t.Errorf("expected heartbeat, got %v", receivedMsg["type"])
+ }
+ disk, ok := receivedMsg["disk"].(map[string]interface{})
+ if !ok {
+ t.Fatal("expected disk field in heartbeat message")
+ }
+ if disk["free"].(float64) != 500000000 {
+ t.Errorf("expected free=500000000, got %v", disk["free"])
+ }
+ if disk["total"].(float64) != 1000000000 {
+ t.Errorf("expected total=1000000000, got %v", disk["total"])
+ }
+}
+
+func TestWSTransportSendHeartbeatWithoutDisk(t *testing.T) {
+ var receivedMsg map[string]interface{}
+ var mu sync.Mutex
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ conn.ReadMessage()
+ conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
+
+ _, msg, err := conn.ReadMessage()
+ if err != nil {
+ return
+ }
+ mu.Lock()
+ json.Unmarshal(msg, &receivedMsg)
+ mu.Unlock()
+
+ time.Sleep(500 * time.Millisecond)
+ }))
+ defer srv.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+ tr := NewWSTransport(wsURL, "key", "a1", "ua")
+
+ ctx := context.Background()
+ tr.Connect(ctx)
+ defer tr.Close()
+ tr.Register(ctx, RegisterRequest{AgentID: "a1"})
+
+ time.Sleep(50 * time.Millisecond)
+ resp, err := tr.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
+ if err != nil {
+ t.Fatalf("SendHeartbeat failed: %v", err)
+ }
+ if !resp.Success {
+ t.Error("expected success")
+ }
+
+ time.Sleep(100 * time.Millisecond)
+ mu.Lock()
+ defer mu.Unlock()
+ if receivedMsg["type"] != "heartbeat" {
+ t.Errorf("expected heartbeat, got %v", receivedMsg["type"])
+ }
+ // disk field should be absent when no disk info provided
+ if _, exists := receivedMsg["disk"]; exists {
+ t.Error("expected no disk field when disk info is zero")
+ }
+}
+
+func TestWSTransportDeregisterClosesConnection(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ conn.ReadMessage()
+ conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
+ time.Sleep(500 * time.Millisecond)
+ }))
+ defer srv.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+ tr := NewWSTransport(wsURL, "key", "a1", "ua")
+
+ ctx := context.Background()
+ tr.Connect(ctx)
+ tr.Register(ctx, RegisterRequest{AgentID: "a1"})
+
+ err := tr.Deregister(ctx, "a1")
+ if err != nil {
+ t.Fatalf("Deregister failed: %v", err)
+ }
+
+ // After deregister, send should fail (connection closed)
+ _, err = tr.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
+ if err == nil {
+ t.Error("expected error sending after deregister")
+ }
+}
+
+func TestWSTransportReceiveStreamRequests(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ conn.ReadMessage()
+ conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
+
+ time.Sleep(50 * time.Millisecond)
+ conn.WriteJSON(wsTasksMessage{
+ Type: "tasks",
+ Tasks: []Task{},
+ StreamRequests: []StreamRequest{
+ {TaskID: "t1", FilePath: "/data/movie.mkv"},
+ },
+ })
+
+ time.Sleep(500 * time.Millisecond)
+ }))
+ defer srv.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+ tr := NewWSTransport(wsURL, "key", "a1", "ua")
+
+ ctx := context.Background()
+ tr.Connect(ctx)
+ defer tr.Close()
+ tr.Register(ctx, RegisterRequest{AgentID: "a1"})
+
+ select {
+ case event := <-tr.Events():
+ if event.Type != "tasks" {
+ t.Errorf("expected tasks, got %s", event.Type)
+ }
+ if len(event.Tasks.StreamRequests) != 1 {
+ t.Fatalf("expected 1 stream request, got %d", len(event.Tasks.StreamRequests))
+ }
+ if event.Tasks.StreamRequests[0].FilePath != "/data/movie.mkv" {
+ t.Errorf("expected /data/movie.mkv, got %s", event.Tasks.StreamRequests[0].FilePath)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for tasks event with stream requests")
+ }
+}
+
+func TestWSTransportReceiveErrorMessage(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ conn.ReadMessage()
+ conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
+
+ time.Sleep(50 * time.Millisecond)
+ // Send an error message (should be logged, not emitted as event)
+ conn.WriteJSON(map[string]string{
+ "type": "error",
+ "message": "rate limited",
+ })
+
+ time.Sleep(200 * time.Millisecond)
+ }))
+ defer srv.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+ tr := NewWSTransport(wsURL, "key", "a1", "ua")
+
+ ctx := context.Background()
+ tr.Connect(ctx)
+ defer tr.Close()
+ tr.Register(ctx, RegisterRequest{AgentID: "a1"})
+
+ // Error messages are logged but not emitted — events channel should be quiet
+ select {
+ case event := <-tr.Events():
+ // If we get disconnected, that's acceptable (server closes after delay)
+ if event.Type != "disconnected" {
+ t.Errorf("unexpected event type: %s", event.Type)
+ }
+ case <-time.After(300 * time.Millisecond):
+ // Expected: no event emitted for error messages
+ }
+}
+
+func TestWSTransportRegisterTimeout(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ conn.ReadMessage()
+ // Never send registered response — should timeout
+ time.Sleep(20 * time.Second)
+ }))
+ defer srv.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+ tr := NewWSTransport(wsURL, "key", "a1", "ua")
+
+ ctx := context.Background()
+ tr.Connect(ctx)
+ defer tr.Close()
+
+ // Use a context with short timeout to avoid waiting 15s
+ ctxShort, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
+ defer cancel()
+
+ _, err := tr.Register(ctxShort, RegisterRequest{AgentID: "a1"})
+ if err == nil {
+ t.Error("expected timeout error from Register")
+ }
+}
+
+// ── Additional Hybrid Transport Tests ───────────────────────────────────────
+
+func TestNewHybridTransportConstructor(t *testing.T) {
+ wsT := NewWSTransport("ws://localhost", "key", "a1", "ua")
+ httpT := NewHTTPTransport("http://localhost", "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+
+ if h.Mode() != "http" {
+ t.Errorf("expected initial mode http, got %s", h.Mode())
+ }
+ if cap(h.events) != 50 {
+ t.Errorf("expected events capacity 50, got %d", cap(h.events))
+ }
+ if h.ws != wsT {
+ t.Error("expected ws transport to match")
+ }
+ if h.http != httpT {
+ t.Error("expected http transport to match")
+ }
+ if h.reconnectStop == nil {
+ t.Error("expected reconnectStop channel to be non-nil")
+ }
+}
+
+func TestHybridTransportCloseIsIdempotent(t *testing.T) {
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport("http://localhost", "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ // Close twice should not panic
+ if err := h.Close(); err != nil {
+ t.Errorf("first Close failed: %v", err)
+ }
+ if err := h.Close(); err != nil {
+ t.Errorf("second Close failed: %v", err)
+ }
+}
+
+func TestHybridTransportHTTPModeRegister(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewEncoder(w).Encode(RegisterResponse{
+ Success: true,
+ User: UserInfo{Name: "HTTPUser", Plan: "free"},
+ })
+ }))
+ defer srv.Close()
+
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport(srv.URL, "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ // Force HTTP mode (default)
+ h.mode.Store("http")
+
+ resp, err := h.Register(context.Background(), RegisterRequest{AgentID: "a1"})
+ if err != nil {
+ t.Fatalf("Register failed: %v", err)
+ }
+ if resp.User.Name != "HTTPUser" {
+ t.Errorf("expected HTTPUser, got %s", resp.User.Name)
+ }
+}
+
+func TestHybridTransportHTTPModeClaimTasks(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewEncoder(w).Encode(TasksResponse{
+ Tasks: []Task{{ID: "t1", Title: "Test"}},
+ })
+ }))
+ defer srv.Close()
+
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport(srv.URL, "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ h.mode.Store("http")
+
+ resp, err := h.ClaimTasks(context.Background(), "a1")
+ if err != nil {
+ t.Fatalf("ClaimTasks failed: %v", err)
+ }
+ if len(resp.Tasks) != 1 {
+ t.Errorf("expected 1 task, got %d", len(resp.Tasks))
+ }
+}
+
+func TestHybridTransportHTTPModeDeregister(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewEncoder(w).Encode(StatusResponse{Success: true})
+ }))
+ defer srv.Close()
+
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport(srv.URL, "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ h.mode.Store("http")
+
+ err := h.Deregister(context.Background(), "a1")
+ if err != nil {
+ t.Fatalf("Deregister failed: %v", err)
+ }
+}
+
+func TestHybridTransportHTTPModeSendHeartbeat(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewEncoder(w).Encode(HeartbeatResponse{Success: true, Watching: true})
+ }))
+ defer srv.Close()
+
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport(srv.URL, "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ h.mode.Store("http")
+
+ resp, err := h.SendHeartbeat(context.Background(), HeartbeatRequest{AgentID: "a1"})
+ if err != nil {
+ t.Fatalf("SendHeartbeat failed: %v", err)
+ }
+ if !resp.Success {
+ t.Error("expected success")
+ }
+ if !resp.Watching {
+ t.Error("expected watching=true")
+ }
+}
+
+func TestHybridTransportHTTPModeSendProgress(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewEncoder(w).Encode(StatusResponse{Success: true})
+ }))
+ defer srv.Close()
+
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport(srv.URL, "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ h.mode.Store("http")
+
+ resp, err := h.SendProgress(context.Background(), StatusUpdate{
+ TaskID: "t1",
+ Status: "completed",
+ Progress: 100,
+ })
+ if err != nil {
+ t.Fatalf("SendProgress failed: %v", err)
+ }
+ if !resp.Success {
+ t.Error("expected success")
+ }
+}
+
+func TestHybridTransportWSModeClaimTasksIsNoOp(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ time.Sleep(500 * time.Millisecond)
+ }))
+ defer srv.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+ wsT := NewWSTransport(wsURL, "key", "a1", "ua")
+ httpT := NewHTTPTransport("http://localhost", "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ h.Connect(context.Background())
+ defer h.Close()
+
+ // In WS mode, ClaimTasks delegates to WS which is a no-op
+ resp, err := h.ClaimTasks(context.Background(), "a1")
+ if err != nil {
+ t.Fatalf("ClaimTasks failed: %v", err)
+ }
+ if len(resp.Tasks) != 0 {
+ t.Errorf("expected 0 tasks in WS mode, got %d", len(resp.Tasks))
+ }
+}
+
+func TestHybridTransportEventsChannel(t *testing.T) {
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport("http://localhost", "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ ch := h.Events()
+ if ch == nil {
+ t.Fatal("Events() should return non-nil channel")
+ }
+ // Verify it is the correct channel
+ if cap(ch) != 50 {
+ t.Errorf("expected events capacity 50, got %d", cap(ch))
+ }
+}
+
+func TestHybridTransportSwitchToHTTPIdempotent(t *testing.T) {
+ wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
+ httpT := NewHTTPTransport("http://localhost", "key", "ua")
+
+ h := NewHybridTransport(wsT, httpT)
+ // Already in HTTP mode, switchToHTTP should be a no-op
+ h.mode.Store("http")
+ h.switchToHTTP() // should not panic or start reconnect
+
+ if h.Mode() != "http" {
+ t.Errorf("expected http, got %s", h.Mode())
+ }
+}
+
+// ── Daemon Constructor & Utility Tests ──────────────────────────────────────
+
+func TestNewDaemonDefaults(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ if d.cfg.PollInterval != 30*time.Second {
+ t.Errorf("expected default PollInterval 30s, got %v", d.cfg.PollInterval)
+ }
+ if d.cfg.HeartbeatInterval != 30*time.Second {
+ t.Errorf("expected default HeartbeatInterval 30s, got %v", d.cfg.HeartbeatInterval)
+ }
+ if d.Transport() != tr {
+ t.Error("Transport() should return the configured transport")
+ }
+ if d.pollNow == nil {
+ t.Error("pollNow channel should be initialized")
+ }
+}
+
+func TestNewDaemonCustomIntervals(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ PollInterval: 10 * time.Second,
+ HeartbeatInterval: 15 * time.Second,
+ }, tr)
+
+ if d.cfg.PollInterval != 10*time.Second {
+ t.Errorf("expected PollInterval 10s, got %v", d.cfg.PollInterval)
+ }
+ if d.cfg.HeartbeatInterval != 15*time.Second {
+ t.Errorf("expected HeartbeatInterval 15s, got %v", d.cfg.HeartbeatInterval)
+ }
+}
+
+func TestDaemonTriggerPoll(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // First trigger should succeed
+ d.TriggerPoll()
+
+ // Channel should have one signal
+ select {
+ case <-d.pollNow:
+ // good
+ default:
+ t.Error("expected signal on pollNow channel")
+ }
+
+ // Second trigger when channel is empty should also succeed
+ d.TriggerPoll()
+ select {
+ case <-d.pollNow:
+ // good
+ default:
+ t.Error("expected signal on pollNow channel after second trigger")
+ }
+}
+
+func TestDaemonTriggerPollNonBlocking(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // Fill the channel (capacity 1)
+ d.TriggerPoll()
+ // Second call should not block even though channel is full
+ done := make(chan struct{})
+ go func() {
+ d.TriggerPoll()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ // good, did not block
+ case <-time.After(1 * time.Second):
+ t.Fatal("TriggerPoll blocked on full channel")
+ }
+}
+
+func TestDaemonHandleEventTasks(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ var claimedTasks []Task
+ d.OnTasksClaimed = func(tasks []Task) {
+ claimedTasks = tasks
+ }
+
+ d.handleEvent(ServerEvent{
+ Type: "tasks",
+ Tasks: &TasksResponse{
+ Tasks: []Task{
+ {ID: "t1", Title: "Movie 1"},
+ {ID: "t2", Title: "Movie 2"},
+ },
+ },
+ })
+
+ if len(claimedTasks) != 2 {
+ t.Fatalf("expected 2 claimed tasks, got %d", len(claimedTasks))
+ }
+ if claimedTasks[0].Title != "Movie 1" {
+ t.Errorf("expected Movie 1, got %s", claimedTasks[0].Title)
+ }
+}
+
+func TestDaemonHandleEventTasksWithStreamRequests(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ var streamReqs []StreamRequest
+ d.OnStreamRequested = func(req StreamRequest) {
+ streamReqs = append(streamReqs, req)
+ }
+
+ d.handleEvent(ServerEvent{
+ Type: "tasks",
+ Tasks: &TasksResponse{
+ Tasks: []Task{},
+ StreamRequests: []StreamRequest{
+ {TaskID: "t1", FilePath: "/data/movie.mkv"},
+ {TaskID: "t2", FilePath: "/data/show.mkv"},
+ },
+ },
+ })
+
+ if len(streamReqs) != 2 {
+ t.Fatalf("expected 2 stream requests, got %d", len(streamReqs))
+ }
+ if streamReqs[0].FilePath != "/data/movie.mkv" {
+ t.Errorf("expected /data/movie.mkv, got %s", streamReqs[0].FilePath)
+ }
+}
+
+func TestDaemonHandleEventUpgrade(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ d.handleEvent(ServerEvent{
+ Type: "upgrade",
+ Upgrade: &UpgradeSignal{Version: "2.0.0"},
+ })
+
+ if d.lastNotifiedVersion != "2.0.0" {
+ t.Errorf("expected lastNotifiedVersion 2.0.0, got %s", d.lastNotifiedVersion)
+ }
+
+ // Same version again should not update (already notified)
+ d.lastNotifiedVersion = "2.0.0"
+ d.handleEvent(ServerEvent{
+ Type: "upgrade",
+ Upgrade: &UpgradeSignal{Version: "2.0.0"},
+ })
+ // Still 2.0.0, no change
+ if d.lastNotifiedVersion != "2.0.0" {
+ t.Errorf("expected lastNotifiedVersion unchanged at 2.0.0, got %s", d.lastNotifiedVersion)
+ }
+}
+
+func TestDaemonHandleEventControl(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ var gotAction, gotTaskID string
+ d.OnControlAction = func(action, taskID string) {
+ gotAction = action
+ gotTaskID = taskID
+ }
+
+ d.handleEvent(ServerEvent{
+ Type: "control",
+ Control: &ControlAction{Action: "cancel", TaskID: "task-99"},
+ })
+
+ if gotAction != "cancel" {
+ t.Errorf("expected cancel, got %s", gotAction)
+ }
+ if gotTaskID != "task-99" {
+ t.Errorf("expected task-99, got %s", gotTaskID)
+ }
+}
+
+func TestDaemonHandleEventControlWithNilCallback(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // OnControlAction is nil — should not panic
+ d.handleEvent(ServerEvent{
+ Type: "control",
+ Control: &ControlAction{Action: "pause", TaskID: "t1"},
+ })
+}
+
+func TestDaemonHandleEventDisconnected(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // disconnected event should not panic (just logs)
+ d.handleEvent(ServerEvent{Type: "disconnected"})
+}
+
+func TestDaemonHandleEventTasksNilCallback(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // OnTasksClaimed is nil — should not panic
+ d.handleEvent(ServerEvent{
+ Type: "tasks",
+ Tasks: &TasksResponse{
+ Tasks: []Task{{ID: "t1", Title: "Test"}},
+ },
+ })
+}
+
+func TestDaemonHandleEventEmptyTasks(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ var called bool
+ d.OnTasksClaimed = func(tasks []Task) {
+ called = true
+ }
+
+ // Empty tasks should not trigger callback
+ d.handleEvent(ServerEvent{
+ Type: "tasks",
+ Tasks: &TasksResponse{Tasks: []Task{}},
+ })
+
+ if called {
+ t.Error("OnTasksClaimed should not be called for empty task list")
+ }
+}
+
+func TestDaemonHandleEventNilTasks(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // Nil Tasks field should not panic
+ d.handleEvent(ServerEvent{
+ Type: "tasks",
+ Tasks: nil,
+ })
+}
+
+func TestDaemonHandleEventUpgradeNilSignal(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // Nil Upgrade should not panic
+ d.handleEvent(ServerEvent{
+ Type: "upgrade",
+ Upgrade: nil,
+ })
+ if d.lastNotifiedVersion != "" {
+ t.Errorf("expected empty lastNotifiedVersion, got %s", d.lastNotifiedVersion)
+ }
+}
+
+func TestDaemonHandleEventUpgradeEmptyVersion(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ // Empty version should not update lastNotifiedVersion
+ d.handleEvent(ServerEvent{
+ Type: "upgrade",
+ Upgrade: &UpgradeSignal{Version: ""},
+ })
+ if d.lastNotifiedVersion != "" {
+ t.Errorf("expected empty lastNotifiedVersion, got %s", d.lastNotifiedVersion)
+ }
+}
+
+func TestDaemonWatchingFlag(t *testing.T) {
+ tr := NewHTTPTransport("http://localhost", "key", "ua")
+ d := NewDaemon(DaemonConfig{
+ AgentID: "a1",
+ AgentName: "test",
+ Version: "1.0",
+ DownloadDir: "/tmp",
+ }, tr)
+
+ if d.Watching.Load() {
+ t.Error("expected Watching to be false initially")
+ }
+ d.Watching.Store(true)
+ if !d.Watching.Load() {
+ t.Error("expected Watching to be true after Store(true)")
+ }
+}
+
+// ── Transport Interface Compliance ──────────────────────────────────────────
+
+func TestHTTPTransportImplementsTransport(t *testing.T) {
+ var _ Transport = (*HTTPTransport)(nil)
+}
+
+func TestWSTransportImplementsTransport(t *testing.T) {
+ var _ Transport = (*WSTransport)(nil)
+}
+
+func TestHybridTransportImplementsTransport(t *testing.T) {
+ var _ Transport = (*HybridTransport)(nil)
+}
diff --git a/internal/arr/client_test.go b/internal/arr/client_test.go
new file mode 100644
index 0000000..214dd12
--- /dev/null
+++ b/internal/arr/client_test.go
@@ -0,0 +1,396 @@
+package arr
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func newTestServer(t *testing.T, handlers map[string]any) *httptest.Server {
+ t.Helper()
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Check API key header
+ if r.Header.Get("X-Api-Key") != "test-key" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ handler, ok := handlers[r.URL.Path]
+ if !ok {
+ w.WriteHeader(http.StatusNotFound)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(handler)
+ }))
+}
+
+func TestNewClient(t *testing.T) {
+ c := NewClient("http://localhost:8989/", "mykey")
+ if c.baseURL != "http://localhost:8989" {
+ t.Errorf("baseURL = %q, want trailing slash trimmed", c.baseURL)
+ }
+ if c.apiKey != "mykey" {
+ t.Errorf("apiKey = %q, want mykey", c.apiKey)
+ }
+}
+
+func TestSystemStatus(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v3/system/status": SystemStatus{AppName: "Radarr", Version: "4.0.0"},
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ status, err := c.SystemStatus()
+ if err != nil {
+ t.Fatalf("SystemStatus: %v", err)
+ }
+ if status.AppName != "Radarr" {
+ t.Errorf("AppName = %q, want Radarr", status.AppName)
+ }
+ if status.Version != "4.0.0" {
+ t.Errorf("Version = %q, want 4.0.0", status.Version)
+ }
+}
+
+func TestSystemStatusFallbackV1(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Api-Key") != "test-key" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ switch r.URL.Path {
+ case "/api/v3/system/status":
+ w.WriteHeader(http.StatusNotFound)
+ case "/api/v1/system/status":
+ json.NewEncoder(w).Encode(SystemStatus{AppName: "Prowlarr", Version: "1.0.0"})
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ status, err := c.SystemStatus()
+ if err != nil {
+ t.Fatalf("SystemStatus v1 fallback: %v", err)
+ }
+ if status.AppName != "Prowlarr" {
+ t.Errorf("AppName = %q, want Prowlarr", status.AppName)
+ }
+}
+
+func TestMovies(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v3/movie": []Movie{
+ {ID: 1, Title: "Inception", Year: 2010, TmdbID: 27205, Monitored: true},
+ {ID: 2, Title: "Tenet", Year: 2020, TmdbID: 577922, HasFile: true},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ movies, err := c.Movies()
+ if err != nil {
+ t.Fatalf("Movies: %v", err)
+ }
+ if len(movies) != 2 {
+ t.Fatalf("expected 2 movies, got %d", len(movies))
+ }
+ if movies[0].Title != "Inception" {
+ t.Errorf("movies[0].Title = %q, want Inception", movies[0].Title)
+ }
+}
+
+func TestSeries(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v3/series": []Series{
+ {ID: 1, Title: "Breaking Bad", Year: 2008, TvdbID: 81189},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ series, err := c.Series()
+ if err != nil {
+ t.Fatalf("Series: %v", err)
+ }
+ if len(series) != 1 {
+ t.Fatalf("expected 1 series, got %d", len(series))
+ }
+ if series[0].Title != "Breaking Bad" {
+ t.Errorf("series[0].Title = %q, want Breaking Bad", series[0].Title)
+ }
+}
+
+func TestQualityProfiles(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v3/qualityprofile": []QualityProfile{
+ {ID: 1, Name: "HD-1080p"},
+ {ID: 2, Name: "Ultra-HD"},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ profiles, err := c.QualityProfiles()
+ if err != nil {
+ t.Fatalf("QualityProfiles: %v", err)
+ }
+ if len(profiles) != 2 {
+ t.Fatalf("expected 2 profiles, got %d", len(profiles))
+ }
+}
+
+func TestRootFolders(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v3/rootfolder": []RootFolder{
+ {ID: 1, Path: "/movies", FreeSpace: 500000000000},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ folders, err := c.RootFolders()
+ if err != nil {
+ t.Fatalf("RootFolders: %v", err)
+ }
+ if len(folders) != 1 {
+ t.Fatalf("expected 1 folder, got %d", len(folders))
+ }
+ if folders[0].Path != "/movies" {
+ t.Errorf("path = %q, want /movies", folders[0].Path)
+ }
+}
+
+func TestDownloadClients(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v3/downloadclient": []DownloadClient{
+ {ID: 1, Name: "Transmission", Enable: true, Protocol: "torrent"},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ clients, err := c.DownloadClients()
+ if err != nil {
+ t.Fatalf("DownloadClients: %v", err)
+ }
+ if len(clients) != 1 || clients[0].Name != "Transmission" {
+ t.Errorf("unexpected clients: %+v", clients)
+ }
+}
+
+func TestDownloadClientDetails(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Api-Key") != "test-key" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ if r.URL.Path == "/api/v3/downloadclient/5" {
+ json.NewEncoder(w).Encode(struct {
+ Fields []Field `json:"fields"`
+ }{
+ Fields: []Field{
+ {Name: "host", Value: "localhost"},
+ {Name: "port", Value: 9091},
+ },
+ })
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ fields, err := c.DownloadClientDetails(5)
+ if err != nil {
+ t.Fatalf("DownloadClientDetails: %v", err)
+ }
+ if len(fields) != 2 {
+ t.Fatalf("expected 2 fields, got %d", len(fields))
+ }
+ if fields[0].Name != "host" {
+ t.Errorf("fields[0].Name = %q, want host", fields[0].Name)
+ }
+}
+
+func TestTags(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v3/tag": []Tag{
+ {ID: 1, Label: "unarr"},
+ {ID: 2, Label: "imported"},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ tags, err := c.Tags()
+ if err != nil {
+ t.Fatalf("Tags: %v", err)
+ }
+ if len(tags) != 2 {
+ t.Fatalf("expected 2 tags, got %d", len(tags))
+ }
+}
+
+func TestHistory(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Api-Key") != "test-key" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ if r.URL.Path == "/api/v3/history" {
+ json.NewEncoder(w).Encode(HistoryResponse{
+ Records: []HistoryRecord{
+ {ID: 1, EventType: "grabbed", SourceTitle: "Inception.2010.1080p"},
+ },
+ TotalRecords: 1,
+ })
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ records, err := c.History(10)
+ if err != nil {
+ t.Fatalf("History: %v", err)
+ }
+ if len(records) != 1 {
+ t.Fatalf("expected 1 record, got %d", len(records))
+ }
+ if records[0].SourceTitle != "Inception.2010.1080p" {
+ t.Errorf("sourceTitle = %q", records[0].SourceTitle)
+ }
+}
+
+func TestHistoryDefaultPageSize(t *testing.T) {
+ var requestedPath string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Api-Key") != "test-key" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ requestedPath = r.URL.String()
+ json.NewEncoder(w).Encode(HistoryResponse{})
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ c.History(0) // should default to 250
+
+ if requestedPath == "" {
+ t.Fatal("no request made")
+ }
+ if !contains(requestedPath, "pageSize=250") {
+ t.Errorf("expected pageSize=250, got path: %s", requestedPath)
+ }
+}
+
+func TestBlocklist(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Api-Key") != "test-key" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ if r.URL.Path == "/api/v3/blocklist" {
+ json.NewEncoder(w).Encode(BlocklistResponse{
+ Records: []BlocklistItem{
+ {ID: 1, SourceTitle: "Bad.Release", Data: BlocklistData{InfoHash: "abc123"}},
+ },
+ })
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ items, err := c.Blocklist(50)
+ if err != nil {
+ t.Fatalf("Blocklist: %v", err)
+ }
+ if len(items) != 1 || items[0].Data.InfoHash != "abc123" {
+ t.Errorf("unexpected blocklist: %+v", items)
+ }
+}
+
+func TestIndexers(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v1/indexer": []Indexer{
+ {ID: 1, Name: "NZBGeek", Enable: true},
+ {ID: 2, Name: "Torznab", Enable: false},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ indexers, err := c.Indexers()
+ if err != nil {
+ t.Fatalf("Indexers: %v", err)
+ }
+ if len(indexers) != 2 {
+ t.Fatalf("expected 2 indexers, got %d", len(indexers))
+ }
+}
+
+func TestApplications(t *testing.T) {
+ srv := newTestServer(t, map[string]any{
+ "/api/v1/applications": []Application{
+ {ID: 1, Name: "Radarr"},
+ },
+ })
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ apps, err := c.Applications()
+ if err != nil {
+ t.Fatalf("Applications: %v", err)
+ }
+ if len(apps) != 1 || apps[0].Name != "Radarr" {
+ t.Errorf("unexpected apps: %+v", apps)
+ }
+}
+
+func TestUnauthorized(t *testing.T) {
+ srv := newTestServer(t, map[string]any{})
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "wrong-key")
+ _, err := c.SystemStatus()
+ if err == nil {
+ t.Error("expected error for unauthorized request")
+ }
+}
+
+func TestHTTPError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte("internal error"))
+ }))
+ defer srv.Close()
+
+ c := NewClient(srv.URL, "test-key")
+ _, err := c.Movies()
+ if err == nil {
+ t.Error("expected error for 500 response")
+ }
+}
+
+func contains(s, substr string) bool {
+ return len(s) >= len(substr) && searchStr(s, substr)
+}
+
+func searchStr(s, sub string) bool {
+ for i := 0; i <= len(s)-len(sub); i++ {
+ if s[i:i+len(sub)] == sub {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/arr/discovery_test.go b/internal/arr/discovery_test.go
index 499877c..6dc78b8 100644
--- a/internal/arr/discovery_test.go
+++ b/internal/arr/discovery_test.go
@@ -1,6 +1,9 @@
package arr
import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
"strings"
"testing"
)
@@ -82,3 +85,158 @@ func TestDetectApp(t *testing.T) {
})
}
}
+
+func TestConfigDirs(t *testing.T) {
+ dirs := configDirs()
+ if len(dirs) == 0 {
+ t.Error("configDirs() returned empty")
+ }
+}
+
+func TestParseConfigXMLEmpty(t *testing.T) {
+ port, apiKey, urlBase := parseConfigXML(strings.NewReader(""))
+ if port != "" || apiKey != "" || urlBase != "" {
+ t.Error("empty input should return empty values")
+ }
+}
+
+func TestParseConfigXMLNoPort(t *testing.T) {
+ xml := `key123`
+ port, apiKey, _ := parseConfigXML(strings.NewReader(xml))
+ if port != "" {
+ t.Errorf("port = %q, want empty", port)
+ }
+ if apiKey != "key123" {
+ t.Errorf("apiKey = %q, want key123", apiKey)
+ }
+}
+
+func TestExtractHostPortMultipleMappings(t *testing.T) {
+ tests := []struct {
+ name string
+ ports string
+ container string
+ want string
+ }{
+ {"ipv6 only", ":::8989->8989/tcp", "8989", "8989"},
+ {"different host port", "0.0.0.0:9999->8989/tcp", "8989", "9999"},
+ {"port in string but no mapping", "something 8989 somewhere", "8989", "8989"},
+ {"no match at all", "0.0.0.0:3000->3000/tcp", "9999", ""},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractHostPort(tt.ports, tt.container)
+ if got != tt.want {
+ t.Errorf("extractHostPort(%q, %q) = %q, want %q", tt.ports, tt.container, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestDiscoverFromProwlarr(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/applications":
+ json.NewEncoder(w).Encode([]Application{
+ {
+ ID: 1,
+ Name: "Radarr",
+ Fields: []Field{
+ {Name: "baseUrl", Value: "http://localhost:7878"},
+ {Name: "apiKey", Value: "radarr-key-123"},
+ },
+ },
+ {
+ ID: 2,
+ Name: "Sonarr",
+ Fields: []Field{
+ {Name: "baseUrl", Value: "http://localhost:8989"},
+ {Name: "apiKey", Value: "sonarr-key-456"},
+ },
+ },
+ {
+ ID: 3,
+ Name: "Unknown App",
+ Fields: []Field{
+ {Name: "baseUrl", Value: "http://localhost:9000"},
+ {Name: "apiKey", Value: "unknown-key"},
+ },
+ },
+ {
+ ID: 4,
+ Name: "Incomplete",
+ Fields: []Field{
+ {Name: "baseUrl", Value: "http://localhost:5000"},
+ // no apiKey → should be skipped
+ },
+ },
+ })
+ case "/api/v3/system/status":
+ json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "4.0.0"})
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ // DiscoverFromProwlarr will try to verify each instance, which will fail
+ // for localhost URLs (not our test server), but that's OK — we test the parsing
+ instances := DiscoverFromProwlarr(srv.URL, "prowlarr-key")
+
+ // Should find Radarr and Sonarr (Unknown and Incomplete skipped)
+ if len(instances) != 2 {
+ t.Fatalf("expected 2 instances, got %d: %+v", len(instances), instances)
+ }
+
+ found := map[string]bool{}
+ for _, inst := range instances {
+ found[inst.App] = true
+ if inst.Source != "prowlarr" {
+ t.Errorf("source = %q, want prowlarr", inst.Source)
+ }
+ }
+ if !found["radarr"] {
+ t.Error("expected radarr instance")
+ }
+ if !found["sonarr"] {
+ t.Error("expected sonarr instance")
+ }
+}
+
+func TestVerify(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Api-Key") != "valid-key" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "5.0.0"})
+ }))
+ defer srv.Close()
+
+ t.Run("valid", func(t *testing.T) {
+ inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "valid-key"}
+ err := Verify(inst)
+ if err != nil {
+ t.Fatalf("Verify: %v", err)
+ }
+ if inst.Version != "5.0.0" {
+ t.Errorf("version = %q, want 5.0.0", inst.Version)
+ }
+ })
+
+ t.Run("no api key", func(t *testing.T) {
+ inst := &Instance{App: "radarr", URL: srv.URL}
+ err := Verify(inst)
+ if err == nil {
+ t.Error("expected error for no API key")
+ }
+ })
+
+ t.Run("invalid key", func(t *testing.T) {
+ inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "wrong-key"}
+ err := Verify(inst)
+ if err == nil {
+ t.Error("expected error for invalid API key")
+ }
+ })
+}
diff --git a/internal/cmd/config_menu_test.go b/internal/cmd/config_menu_test.go
new file mode 100644
index 0000000..a63389b
--- /dev/null
+++ b/internal/cmd/config_menu_test.go
@@ -0,0 +1,55 @@
+package cmd
+
+import "testing"
+
+func TestValidateSpeed(t *testing.T) {
+ tests := []struct {
+ input string
+ wantErr bool
+ }{
+ {"", false},
+ {"0", false},
+ {" ", false},
+ {"10MB", false},
+ {"500KB", false},
+ {"1GB", false},
+ {"abc", true},
+ {"10XB", true},
+ {"-5MB", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ err := validateSpeed(tt.input)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateSpeed(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestValidateDuration(t *testing.T) {
+ tests := []struct {
+ input string
+ wantErr bool
+ }{
+ {"", false},
+ {"30s", false},
+ {"1m", false},
+ {"5m", false},
+ {"1h", false},
+ {"2h30m", false},
+ {"abc", true},
+ {"30", true},
+ {"5x", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ err := validateDuration(tt.input)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateDuration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
+ }
+ })
+ }
+}
diff --git a/internal/cmd/daemon_test.go b/internal/cmd/daemon_test.go
new file mode 100644
index 0000000..11ad2a5
--- /dev/null
+++ b/internal/cmd/daemon_test.go
@@ -0,0 +1,55 @@
+package cmd
+
+import "testing"
+
+func TestDeriveWSURL(t *testing.T) {
+ tests := []struct {
+ apiURL string
+ agentID string
+ want string
+ }{
+ {"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"},
+ {"http://localhost:3000", "a1", ""}, // localhost skipped
+ {"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped
+ {"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"},
+ {"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"},
+ {"", "agent-123", ""},
+ {"https://torrentclaw.com", "", ""},
+ {"", "", ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) {
+ got := deriveWSURL(tt.apiURL, tt.agentID)
+ if got != tt.want {
+ t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFormatSpeedLog(t *testing.T) {
+ tests := []struct {
+ bps int64
+ want string
+ }{
+ {0, "0 B/s"},
+ {500, "500 B/s"},
+ {1023, "1023 B/s"},
+ {1024, "1 KB/s"},
+ {10240, "10 KB/s"},
+ {1048576, "1.0 MB/s"},
+ {5242880, "5.0 MB/s"},
+ {1073741824, "1.0 GB/s"},
+ {2147483648, "2.0 GB/s"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.want, func(t *testing.T) {
+ got := formatSpeedLog(tt.bps)
+ if got != tt.want {
+ t.Errorf("formatSpeedLog(%d) = %q, want %q", tt.bps, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/internal/cmd/helpers_test.go b/internal/cmd/helpers_test.go
new file mode 100644
index 0000000..a5badc3
--- /dev/null
+++ b/internal/cmd/helpers_test.go
@@ -0,0 +1,43 @@
+package cmd
+
+import (
+ "os"
+ "strings"
+ "testing"
+)
+
+func TestExpandHome(t *testing.T) {
+ home, _ := os.UserHomeDir()
+
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"~/Documents", home + "/Documents"},
+ {"~/", home},
+ {"/absolute/path", "/absolute/path"},
+ {"relative/path", "relative/path"},
+ {"", ""},
+ {"~notexpanded", "~notexpanded"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := expandHome(tt.input)
+ if got != tt.want {
+ t.Errorf("expandHome(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestDefaultDownloadDir(t *testing.T) {
+ dir := defaultDownloadDir()
+ if dir == "" {
+ t.Error("defaultDownloadDir() returned empty string")
+ }
+ home, _ := os.UserHomeDir()
+ if !strings.HasPrefix(dir, home) {
+ t.Errorf("defaultDownloadDir() = %q, expected to start with home dir %q", dir, home)
+ }
+}
diff --git a/internal/cmd/root.go b/internal/cmd/root.go
index bd8f734..bcf3473 100644
--- a/internal/cmd/root.go
+++ b/internal/cmd/root.go
@@ -143,8 +143,9 @@ Source: https://github.com/torrentclaw/unarr`,
completionCmd,
// Library
scanCmd,
+ // Alias: upgrade → self-update
+ newUpgradeCmd(),
// Stubs for future commands
- newStubCmd("upgrade", "Find a better version of a torrent"),
newStubCmd("moreseed", "Find same quality with more seeders"),
newStubCmd("compare", "Compare two torrents side by side"),
newStubCmd("add", "Search and add torrents to your client"),
diff --git a/internal/cmd/status.go b/internal/cmd/status.go
index 0f989f4..e90354c 100644
--- a/internal/cmd/status.go
+++ b/internal/cmd/status.go
@@ -1,20 +1,26 @@
package cmd
import (
+ "context"
"fmt"
+ "runtime"
+ "strings"
+ "time"
"github.com/fatih/color"
"github.com/spf13/cobra"
+ "github.com/torrentclaw/unarr/internal/agent"
+ "github.com/torrentclaw/unarr/internal/upgrade"
)
func newStatusCmd() *cobra.Command {
return &cobra.Command{
Use: "status",
- Short: "Show daemon status and active downloads",
- Long: `Display the current state of the daemon, active downloads, and recent activity.
+ Short: "Show daemon status, configuration, and update availability",
+ Long: `Display the current state of unarr: version, configuration, daemon status,
+disk usage, and whether an update is available.
-Shows the configured agent name, download directory, and preferred method.
-When the daemon is running, also displays active downloads and their progress.`,
+When the daemon is running, also displays uptime, active downloads, and stats.`,
Example: ` unarr status`,
RunE: func(cmd *cobra.Command, args []string) error {
return runStatus()
@@ -25,27 +31,167 @@ When the daemon is running, also displays active downloads and their progress.`,
func runStatus() error {
bold := color.New(color.Bold)
dim := color.New(color.FgHiBlack)
+ green := color.New(color.FgGreen)
+ yellow := color.New(color.FgYellow)
+ cyan := color.New(color.FgCyan)
fmt.Println()
bold.Printf(" unarr %s\n", Version)
+ dim.Printf(" %s/%s\n", runtime.GOOS, runtime.GOARCH)
fmt.Println()
cfg := loadConfig()
+ // ── Configuration ──
if cfg.Auth.APIKey == "" {
- dim.Println(" Not configured. Run 'unarr init' first.")
+ yellow.Println(" ⚠ Not configured. Run 'unarr init' first.")
fmt.Println()
return nil
}
- fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, cfg.Agent.ID[:8]+"...")
- fmt.Printf(" Downloads: %s\n", cfg.Download.Dir)
- fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod)
+ cyan.Println(" Configuration")
+ agentID := cfg.Agent.ID
+ if len(agentID) > 8 {
+ agentID = agentID[:8] + "..."
+ }
+ fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, agentID)
+ fmt.Printf(" Server: %s\n", cfg.Auth.APIURL)
+ fmt.Printf(" Downloads: %s\n", cfg.Download.Dir)
+ fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod)
+ if cfg.Download.PreferredQuality != "" {
+ fmt.Printf(" Quality: %s\n", cfg.Download.PreferredQuality)
+ }
+ fmt.Printf(" Concurrent: %d\n", cfg.Download.MaxConcurrent)
+ if cfg.Organize.Enabled {
+ fmt.Printf(" Organize: on")
+ if cfg.Organize.MoviesDir != "" {
+ fmt.Printf(" (movies: %s", cfg.Organize.MoviesDir)
+ if cfg.Organize.TVShowsDir != "" {
+ fmt.Printf(", tv: %s", cfg.Organize.TVShowsDir)
+ }
+ fmt.Print(")")
+ }
+ fmt.Println()
+ }
fmt.Println()
- dim.Println(" Daemon not running. Start with 'unarr start'")
- dim.Println(" (Live status will be shown here when daemon is running)")
+ // ── Disk ──
+ if cfg.Download.Dir != "" {
+ if free, total, err := agent.DiskInfo(cfg.Download.Dir); err == nil && total > 0 {
+ usedPct := float64(total-free) / float64(total) * 100
+ cyan.Println(" Disk")
+ fmt.Printf(" Free: %s / %s (%.0f%% used)\n", formatBytes(free), formatBytes(total), usedPct)
+ if usedPct > 90 {
+ yellow.Println(" ⚠ Low disk space!")
+ }
+ fmt.Println()
+ }
+ }
+
+ // ── Daemon ──
+ cyan.Println(" Daemon")
+ state := agent.ReadState()
+ if state != nil && isDaemonAlive(state) {
+ green.Printf(" Status: running (PID %d)\n", state.PID)
+ fmt.Printf(" Uptime: %s\n", formatDuration(time.Since(state.StartedAt)))
+ fmt.Printf(" Last beat: %s ago\n", formatDuration(time.Since(state.LastHeartbeat)))
+ fmt.Printf(" Active: %d task(s)\n", state.ActiveTasks)
+ fmt.Printf(" Completed: %d\n", state.CompletedCount)
+ if state.FailedCount > 0 {
+ fmt.Printf(" Failed: %d\n", state.FailedCount)
+ }
+ if state.TotalDownloaded > 0 {
+ fmt.Printf(" Downloaded: %s\n", formatBytes(state.TotalDownloaded))
+ }
+ if len(state.MethodStats) > 0 {
+ parts := make([]string, 0, len(state.MethodStats))
+ for method, count := range state.MethodStats {
+ parts = append(parts, fmt.Sprintf("%s:%d", method, count))
+ }
+ fmt.Printf(" Methods: %s\n", strings.Join(parts, ", "))
+ }
+ } else {
+ dim.Println(" Status: stopped")
+ dim.Println(" Start with: unarr start")
+ }
+ fmt.Println()
+
+ // ── Update check (cached: instant if <1h, otherwise async 3s) ──
+ type versionResult struct {
+ version string
+ fromCache bool
+ err error
+ }
+ versionCh := make(chan versionResult, 1)
+ go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ v, cached, err := upgrade.CheckLatestCached(ctx)
+ versionCh <- versionResult{v, cached, err}
+ }()
+
+ cyan.Println(" Update")
+ fmt.Print(" Checking... ")
+ vr := <-versionCh
+ if vr.err != nil {
+ dim.Println("could not check (offline?)")
+ } else {
+ currentClean := strings.TrimPrefix(Version, "v")
+ if currentClean == vr.version {
+ green.Printf("✓ up to date (v%s)\n", vr.version)
+ } else {
+ yellow.Printf("v%s available! ", vr.version)
+ fmt.Printf("Run: unarr upgrade\n")
+ }
+ }
fmt.Println()
return nil
}
+
+// isDaemonAlive checks if the daemon process from the state file is still running.
+// Guards against PID reuse by also checking heartbeat recency.
+func isDaemonAlive(state *agent.DaemonState) bool {
+ if state.PID == 0 {
+ return false
+ }
+ // Reject stale state: if last heartbeat is older than 2 minutes, the daemon
+ // likely crashed and the PID may have been reused by another process.
+ if !state.LastHeartbeat.IsZero() && time.Since(state.LastHeartbeat) > 2*time.Minute {
+ return false
+ }
+ return agent.IsProcessAlive(state.PID)
+}
+
+// formatBytes formats bytes into human-readable string.
+func formatBytes(b int64) string {
+ const unit = 1024
+ if b < unit {
+ return fmt.Sprintf("%d B", b)
+ }
+ div, exp := int64(unit), 0
+ for n := b / unit; n >= unit; n /= unit {
+ div *= unit
+ exp++
+ }
+ return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
+}
+
+// formatDuration formats a duration into a compact human-readable string.
+func formatDuration(d time.Duration) string {
+ if d < 0 {
+ return "0s"
+ }
+ if d < time.Minute {
+ return fmt.Sprintf("%ds", int(d.Seconds()))
+ }
+ if d < time.Hour {
+ return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
+ }
+ if d < 24*time.Hour {
+ return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
+ }
+ days := int(d.Hours()) / 24
+ hours := int(d.Hours()) % 24
+ return fmt.Sprintf("%dd %dh", days, hours)
+}
diff --git a/internal/cmd/stream_handler.go b/internal/cmd/stream_handler.go
index 9bb9657..88f3111 100644
--- a/internal/cmd/stream_handler.go
+++ b/internal/cmd/stream_handler.go
@@ -24,6 +24,20 @@ var streamRegistry = struct {
servers: make(map[string]*engine.StreamServer),
}
+// cancelAllStreams cancels all active stream tasks and servers (only 1 stream at a time).
+func cancelAllStreams() {
+ streamRegistry.mu.Lock()
+ for taskID, cancel := range streamRegistry.cancels {
+ cancel()
+ delete(streamRegistry.cancels, taskID)
+ }
+ for taskID, srv := range streamRegistry.servers {
+ srv.Shutdown(context.Background())
+ delete(streamRegistry.servers, taskID)
+ }
+ streamRegistry.mu.Unlock()
+}
+
// cancelStreamTask cancels a running stream task and shuts down any stream server.
func cancelStreamTask(taskID string) {
streamRegistry.mu.Lock()
@@ -94,7 +108,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
}
// 4. Start HTTP server
- srv := engine.NewStreamServer(eng, 0)
+ srv := engine.NewStreamServer(eng, cfg.Download.StreamPort)
streamURL, err := srv.Start(ctx)
if err != nil {
task.ErrorMessage = "start HTTP server: " + err.Error()
@@ -107,17 +121,27 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
task.StreamURL = streamURL
log.Printf("[%s] stream ready: %s", at.ID[:8], streamURL)
- // 6. Progress loop
+ // 6. Unified progress + idle timeout loop
eng.StartProgressLoop(ctx)
- ticker := time.NewTicker(3 * time.Second)
- defer ticker.Stop()
+ progressTicker := time.NewTicker(3 * time.Second)
+ defer progressTicker.Stop()
+ idleCheck := time.NewTicker(60 * time.Second)
+ defer idleCheck.Stop()
+ completed := false
for {
select {
case <-ctx.Done():
log.Printf("[%s] stream stopped", at.ID[:8])
return
- case <-ticker.C:
+
+ case <-idleCheck.C:
+ if srv.IdleSince() > 30*time.Minute {
+ log.Printf("[%s] stream idle timeout (30m no HTTP requests), shutting down", at.ID[:8])
+ return
+ }
+
+ case <-progressTicker.C:
p := eng.Progress()
task.UpdateProgress(engine.Progress{
DownloadedBytes: p.DownloadedBytes,
@@ -129,7 +153,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
})
// Terminal progress
- if p.TotalBytes > 0 {
+ if !completed && p.TotalBytes > 0 {
pct := int(float64(p.DownloadedBytes) / float64(p.TotalBytes) * 100)
fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d",
at.ID[:8], pct,
@@ -137,20 +161,11 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
p.Peers, p.Seeds)
}
- if p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 {
- fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line
+ if !completed && p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 {
+ fmt.Fprint(os.Stderr, "\r\033[2K")
task.Transition(engine.StatusCompleted)
- log.Printf("[%s] stream download complete, server stays up for 30m or until cancelled", at.ID[:8])
- // Keep HTTP server running so the player can finish reading.
- // Auto-shutdown after 30 minutes of idle to prevent resource leaks.
- idleTimer := time.NewTimer(30 * time.Minute)
- defer idleTimer.Stop()
- select {
- case <-ctx.Done():
- case <-idleTimer.C:
- log.Printf("[%s] stream idle timeout (30m), shutting down", at.ID[:8])
- }
- return
+ log.Printf("[%s] stream download complete, server stays up until idle (30m)", at.ID[:8])
+ completed = true
}
}
}
diff --git a/internal/cmd/upgrade.go b/internal/cmd/upgrade.go
new file mode 100644
index 0000000..c374603
--- /dev/null
+++ b/internal/cmd/upgrade.go
@@ -0,0 +1,30 @@
+package cmd
+
+import (
+ "github.com/spf13/cobra"
+)
+
+// newUpgradeCmd creates the `unarr upgrade` command as an alias for `self-update`.
+func newUpgradeCmd() *cobra.Command {
+ var force bool
+
+ cmd := &cobra.Command{
+ Use: "upgrade",
+ Aliases: []string{"update"},
+ Short: "Update unarr to the latest version",
+ Long: `Download and install the latest version of unarr.
+
+This is an alias for 'unarr self-update'. Checks GitHub for the latest
+release, verifies the checksum, and replaces the current binary.
+A backup is kept at .backup.`,
+ Example: ` unarr upgrade
+ unarr upgrade --force`,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return runSelfUpdate(force)
+ },
+ }
+
+ cmd.Flags().BoolVarP(&force, "force", "f", false, "reinstall even if already up to date")
+
+ return cmd
+}
diff --git a/internal/engine/manager_test.go b/internal/engine/manager_test.go
index 7c9893f..84bcc18 100644
--- a/internal/engine/manager_test.go
+++ b/internal/engine/manager_test.go
@@ -2,6 +2,7 @@ package engine
import (
"context"
+ "os"
"testing"
"time"
@@ -83,3 +84,223 @@ func TestManagerShutdown(t *testing.T) {
mgr.Shutdown(ctx)
// Should not hang
}
+
+func TestManagerDefaultConcurrency(t *testing.T) {
+ reporter := NewProgressReporter(
+ agent.NewClient("http://localhost", "test", "test"),
+ 1*time.Second,
+ )
+ mgr := NewManager(ManagerConfig{MaxConcurrent: 0}, reporter)
+ if cap(mgr.sem) != 3 {
+ t.Errorf("default MaxConcurrent should be 3, got %d", cap(mgr.sem))
+ }
+}
+
+func TestManagerGetTask(t *testing.T) {
+ reporter := NewProgressReporter(
+ agent.NewClient("http://localhost", "test", "test"),
+ 1*time.Second,
+ )
+ mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
+
+ // No task added
+ if task := mgr.GetTask("nonexistent"); task != nil {
+ t.Error("expected nil for nonexistent task")
+ }
+}
+
+func TestManagerActiveTasks(t *testing.T) {
+ reporter := NewProgressReporter(
+ agent.NewClient("http://localhost", "test", "test"),
+ 1*time.Second,
+ )
+ mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
+
+ tasks := mgr.ActiveTasks()
+ if len(tasks) != 0 {
+ t.Errorf("expected 0 active tasks, got %d", len(tasks))
+ }
+}
+
+func TestManagerSubmitCompletesWithValidFile(t *testing.T) {
+ dir := t.TempDir()
+ // Create a file that verify() will accept
+ filePath := dir + "/movie.mkv"
+ os.WriteFile(filePath, make([]byte, 1024), 0o644)
+
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: 100 * time.Millisecond,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ dl := &resultMockDownloader{
+ method: MethodTorrent,
+ result: &Result{
+ FilePath: filePath,
+ FileName: "movie.mkv",
+ Method: MethodTorrent,
+ Size: 1024,
+ },
+ }
+
+ mgr := NewManager(ManagerConfig{
+ MaxConcurrent: 2,
+ OutputDir: dir,
+ }, pr, dl)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ go pr.Run(ctx)
+
+ mgr.Submit(ctx, agent.Task{
+ ID: "task-complete-test1",
+ InfoHash: "abc123def456abc123def456abc123def456abc1",
+ Title: "Test Movie",
+ PreferredMethod: "torrent",
+ })
+
+ mgr.Wait()
+ cancel()
+
+ // Task should have completed successfully
+ // (we can't check directly since it's removed from active map after processing)
+}
+
+func TestManagerCancelTask(t *testing.T) {
+ reporter := NewProgressReporter(
+ agent.NewClient("http://localhost", "test", "test"),
+ 1*time.Second,
+ )
+
+ dl := &slowMockDownloader{method: MethodTorrent}
+ mgr := NewManager(ManagerConfig{
+ MaxConcurrent: 2,
+ OutputDir: t.TempDir(),
+ }, reporter, dl)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ go reporter.Run(ctx)
+
+ mgr.Submit(ctx, agent.Task{
+ ID: "task-cancel-test12",
+ InfoHash: "abc123def456abc123def456abc123def456abc1",
+ Title: "Cancel Me",
+ PreferredMethod: "torrent",
+ })
+
+ // Give it time to start
+ time.Sleep(100 * time.Millisecond)
+
+ mgr.CancelTask("task-cancel-test12")
+ mgr.Wait()
+}
+
+func TestManagerPauseTask(t *testing.T) {
+ reporter := NewProgressReporter(
+ agent.NewClient("http://localhost", "test", "test"),
+ 1*time.Second,
+ )
+
+ dl := &slowMockDownloader{method: MethodTorrent}
+ mgr := NewManager(ManagerConfig{
+ MaxConcurrent: 2,
+ OutputDir: t.TempDir(),
+ }, reporter, dl)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ go reporter.Run(ctx)
+
+ mgr.Submit(ctx, agent.Task{
+ ID: "task-pause-test123",
+ InfoHash: "abc123def456abc123def456abc123def456abc1",
+ Title: "Pause Me",
+ PreferredMethod: "torrent",
+ })
+
+ time.Sleep(100 * time.Millisecond)
+ mgr.PauseTask("task-pause-test123")
+ mgr.Wait()
+}
+
+func TestManagerCancelAndDeleteFiles(t *testing.T) {
+ reporter := NewProgressReporter(
+ agent.NewClient("http://localhost", "test", "test"),
+ 1*time.Second,
+ )
+
+ dl := &slowMockDownloader{method: MethodTorrent}
+ mgr := NewManager(ManagerConfig{
+ MaxConcurrent: 2,
+ OutputDir: t.TempDir(),
+ }, reporter, dl)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ go reporter.Run(ctx)
+
+ mgr.Submit(ctx, agent.Task{
+ ID: "task-delfile-test12",
+ InfoHash: "abc123def456abc123def456abc123def456abc1",
+ Title: "Delete Me",
+ PreferredMethod: "torrent",
+ })
+
+ time.Sleep(100 * time.Millisecond)
+ mgr.CancelAndDeleteFiles("task-delfile-test12")
+ mgr.Wait()
+}
+
+func TestManagerCancelNonexistent(t *testing.T) {
+ reporter := NewProgressReporter(
+ agent.NewClient("http://localhost", "test", "test"),
+ 1*time.Second,
+ )
+ mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
+ // Should not panic
+ mgr.CancelTask("nonexistent")
+ mgr.PauseTask("nonexistent")
+ mgr.CancelAndDeleteFiles("nonexistent")
+}
+
+// resultMockDownloader returns a configurable result
+type resultMockDownloader struct {
+ method DownloadMethod
+ result *Result
+}
+
+func (m *resultMockDownloader) Method() DownloadMethod { return m.method }
+func (m *resultMockDownloader) Available(_ context.Context, _ *Task) (bool, error) {
+ return true, nil
+}
+func (m *resultMockDownloader) Download(_ context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
+ return m.result, nil
+}
+func (m *resultMockDownloader) Pause(_ string) error { return nil }
+func (m *resultMockDownloader) Cancel(_ string) error { return nil }
+func (m *resultMockDownloader) Shutdown(_ context.Context) error { return nil }
+
+// slowMockDownloader blocks until context is cancelled
+type slowMockDownloader struct {
+ method DownloadMethod
+}
+
+func (m *slowMockDownloader) Method() DownloadMethod { return m.method }
+func (m *slowMockDownloader) Available(_ context.Context, _ *Task) (bool, error) {
+ return true, nil
+}
+func (m *slowMockDownloader) Download(ctx context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
+ <-ctx.Done()
+ return nil, ctx.Err()
+}
+func (m *slowMockDownloader) Pause(_ string) error { return nil }
+func (m *slowMockDownloader) Cancel(_ string) error { return nil }
+func (m *slowMockDownloader) Shutdown(_ context.Context) error { return nil }
diff --git a/internal/engine/method_test.go b/internal/engine/method_test.go
new file mode 100644
index 0000000..e913d32
--- /dev/null
+++ b/internal/engine/method_test.go
@@ -0,0 +1,50 @@
+package engine
+
+import "testing"
+
+func TestDownloadMethodConstants(t *testing.T) {
+ if MethodTorrent != "torrent" {
+ t.Errorf("MethodTorrent = %q, want torrent", MethodTorrent)
+ }
+ if MethodDebrid != "debrid" {
+ t.Errorf("MethodDebrid = %q, want debrid", MethodDebrid)
+ }
+ if MethodUsenet != "usenet" {
+ t.Errorf("MethodUsenet = %q, want usenet", MethodUsenet)
+ }
+}
+
+func TestProgressStruct(t *testing.T) {
+ p := Progress{
+ DownloadedBytes: 1024,
+ TotalBytes: 2048,
+ SpeedBps: 512,
+ ETA: 10,
+ Peers: 5,
+ Seeds: 3,
+ FileName: "movie.mkv",
+ }
+
+ if p.DownloadedBytes != 1024 {
+ t.Errorf("DownloadedBytes = %d, want 1024", p.DownloadedBytes)
+ }
+ if p.FileName != "movie.mkv" {
+ t.Errorf("FileName = %q, want movie.mkv", p.FileName)
+ }
+}
+
+func TestResultStruct(t *testing.T) {
+ r := Result{
+ FilePath: "/downloads/movie.mkv",
+ FileName: "movie.mkv",
+ Method: MethodTorrent,
+ Size: 1073741824,
+ }
+
+ if r.Method != MethodTorrent {
+ t.Errorf("Method = %q, want torrent", r.Method)
+ }
+ if r.Size != 1073741824 {
+ t.Errorf("Size = %d, want 1073741824", r.Size)
+ }
+}
diff --git a/internal/engine/organize_expand_test.go b/internal/engine/organize_expand_test.go
new file mode 100644
index 0000000..0a7d2f2
--- /dev/null
+++ b/internal/engine/organize_expand_test.go
@@ -0,0 +1,181 @@
+package engine
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestReplaceFile(t *testing.T) {
+ tmp := t.TempDir()
+ backupDir := filepath.Join(tmp, "backups")
+
+ // Create "old" file
+ oldPath := filepath.Join(tmp, "movie.mkv")
+ os.WriteFile(oldPath, []byte("old content"), 0o644)
+
+ // Create "new" file
+ newPath := filepath.Join(tmp, "movie-new.mkv")
+ os.WriteFile(newPath, []byte("new better content"), 0o644)
+
+ err := replaceFile(oldPath, newPath, backupDir)
+ if err != nil {
+ t.Fatalf("replaceFile: %v", err)
+ }
+
+ // Old path should now contain new content
+ data, err := os.ReadFile(oldPath)
+ if err != nil {
+ t.Fatalf("read old path: %v", err)
+ }
+ if string(data) != "new better content" {
+ t.Errorf("old path content = %q, want 'new better content'", string(data))
+ }
+
+ // Backup should exist
+ entries, _ := os.ReadDir(backupDir)
+ if len(entries) != 1 {
+ t.Errorf("expected 1 backup file, got %d", len(entries))
+ }
+
+ // New file should be gone
+ if _, err := os.Stat(newPath); !os.IsNotExist(err) {
+ t.Error("new file should have been moved/deleted")
+ }
+}
+
+func TestReplaceFileOldNotFound(t *testing.T) {
+ tmp := t.TempDir()
+ err := replaceFile(filepath.Join(tmp, "nonexistent.mkv"), filepath.Join(tmp, "new.mkv"), "")
+ if err == nil {
+ t.Error("expected error when old file doesn't exist")
+ }
+}
+
+func TestCopyFile(t *testing.T) {
+ tmp := t.TempDir()
+
+ src := filepath.Join(tmp, "source.txt")
+ dst := filepath.Join(tmp, "dest.txt")
+
+ content := []byte("hello world copy test")
+ os.WriteFile(src, content, 0o644)
+
+ err := copyFile(src, dst)
+ if err != nil {
+ t.Fatalf("copyFile: %v", err)
+ }
+
+ data, err := os.ReadFile(dst)
+ if err != nil {
+ t.Fatalf("read dest: %v", err)
+ }
+ if string(data) != string(content) {
+ t.Errorf("dest content = %q, want %q", string(data), string(content))
+ }
+}
+
+func TestCopyFileSrcNotFound(t *testing.T) {
+ tmp := t.TempDir()
+ err := copyFile(filepath.Join(tmp, "nope.txt"), filepath.Join(tmp, "out.txt"))
+ if err == nil {
+ t.Error("expected error when source doesn't exist")
+ }
+}
+
+func TestOrganizeNoDirs(t *testing.T) {
+ r := &Result{FilePath: "/tmp/file.mkv", FileName: "file.mkv"}
+ task := &Task{Title: "Movie"}
+
+ path, err := organize(r, task, OrganizeConfig{Enabled: true})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if path != "/tmp/file.mkv" {
+ t.Errorf("should return original path when no dirs configured, got %q", path)
+ }
+}
+
+func TestOrganizeNilResult(t *testing.T) {
+ task := &Task{Title: "Movie"}
+ path, err := organize(&Result{}, task, OrganizeConfig{Enabled: true})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if path != "" {
+ t.Errorf("expected empty path for empty result, got %q", path)
+ }
+}
+
+func TestOrganizeMovieDirectory(t *testing.T) {
+ tmp := t.TempDir()
+ srcDir := filepath.Join(tmp, "src", "MovieDir")
+ os.MkdirAll(srcDir, 0o755)
+ os.WriteFile(filepath.Join(srcDir, "movie.mkv"), []byte("data"), 0o644)
+
+ moviesDir := filepath.Join(tmp, "Movies")
+
+ r := &Result{FilePath: srcDir, FileName: "MovieDir"}
+ task := &Task{Title: "My Movie 2023"}
+
+ path, err := organize(r, task, OrganizeConfig{
+ Enabled: true,
+ MoviesDir: moviesDir,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if path == srcDir {
+ t.Error("directory should have moved")
+ }
+ if _, err := os.Stat(path); err != nil {
+ t.Errorf("organized directory should exist at %s", path)
+ }
+}
+
+func TestOrganizeSeasonOnly(t *testing.T) {
+ tmp := t.TempDir()
+ srcFile := filepath.Join(tmp, "Show.S01.Complete.mkv")
+ os.WriteFile(srcFile, []byte("data"), 0o644)
+
+ tvDir := filepath.Join(tmp, "TV")
+
+ r := &Result{FilePath: srcFile, FileName: "Show.S01.Complete.mkv"}
+ task := &Task{Title: "Show S01"}
+
+ path, err := organize(r, task, OrganizeConfig{
+ Enabled: true,
+ TVShowsDir: tvDir,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ dir := filepath.Dir(path)
+ if filepath.Base(dir) != "Season 01" {
+ t.Errorf("expected Season 01 directory, got %q", filepath.Base(dir))
+ }
+}
+
+func TestCleanTitleEdgeCases(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"", ""},
+ {"Simple Title", "Simple Title"},
+ {"Title (2023) 1080p BluRay", "Title"},
+ {"Title 720p HDTV", "Title"},
+ {"Title x264 HEVC", "Title"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := cleanTitle(tt.input)
+ if got != tt.want {
+ t.Errorf("cleanTitle(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/internal/engine/progress_test.go b/internal/engine/progress_test.go
new file mode 100644
index 0000000..1bb36c6
--- /dev/null
+++ b/internal/engine/progress_test.go
@@ -0,0 +1,419 @@
+package engine
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/torrentclaw/unarr/internal/agent"
+)
+
+// mockStatusReporter records calls to ReportStatus.
+type mockStatusReporter struct {
+ mu sync.Mutex
+ calls []agent.StatusUpdate
+ resp *agent.StatusResponse
+ respErr error
+}
+
+func (m *mockStatusReporter) ReportStatus(_ context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.calls = append(m.calls, update)
+ if m.resp != nil {
+ return m.resp, m.respErr
+ }
+ return &agent.StatusResponse{}, m.respErr
+}
+
+// mockBatchReporter records batch calls.
+type mockBatchReporter struct {
+ mockStatusReporter
+ batchCalls [][]agent.StatusUpdate
+ batchResp *agent.BatchStatusResponse
+}
+
+func (m *mockBatchReporter) BatchReportStatus(_ context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.batchCalls = append(m.batchCalls, updates)
+ if m.batchResp != nil {
+ return m.batchResp, nil
+ }
+ results := make([]agent.StatusResponse, len(updates))
+ return &agent.BatchStatusResponse{Results: results}, nil
+}
+
+func TestProgressReporter_TrackUntrack(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ task := &Task{ID: "task-001", Status: StatusDownloading}
+ pr.Track(task)
+
+ pr.mu.Lock()
+ if _, ok := pr.latest["task-001"]; !ok {
+ t.Error("task should be tracked")
+ }
+ pr.mu.Unlock()
+
+ pr.Untrack("task-001")
+
+ pr.mu.Lock()
+ if _, ok := pr.latest["task-001"]; ok {
+ t.Error("task should be untracked")
+ }
+ pr.mu.Unlock()
+}
+
+func TestProgressReporter_FlushReportsFinalStates(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ completed := &Task{ID: "task-completed-1234", Status: StatusCompleted}
+ pr.Track(completed)
+
+ pr.flush(context.Background())
+
+ reporter.mu.Lock()
+ defer reporter.mu.Unlock()
+ if len(reporter.calls) != 1 {
+ t.Fatalf("expected 1 report, got %d", len(reporter.calls))
+ }
+ if reporter.calls[0].TaskID != "task-completed-1234" {
+ t.Errorf("reported wrong task: %s", reporter.calls[0].TaskID)
+ }
+}
+
+func TestProgressReporter_FlushSkipsWhenNotWatching(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ isWatching: func() bool { return false },
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ lastCheckAt: time.Now(), // not due for control check
+ }
+
+ // Active downloading task, already reported as downloading
+ task := &Task{ID: "task-active-12345678", Status: StatusDownloading}
+ pr.Track(task)
+ pr.mu.Lock()
+ pr.lastReported["task-active-12345678"] = StatusDownloading
+ pr.mu.Unlock()
+
+ pr.flush(context.Background())
+
+ reporter.mu.Lock()
+ defer reporter.mu.Unlock()
+ if len(reporter.calls) != 0 {
+ t.Errorf("expected 0 reports when not watching (no transition), got %d", len(reporter.calls))
+ }
+}
+
+func TestProgressReporter_FlushReportsTransitions(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ isWatching: func() bool { return false },
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ lastCheckAt: time.Now(),
+ }
+
+ // Task transitioning from resolving to downloading
+ task := &Task{ID: "task-trans-12345678", Status: StatusDownloading}
+ pr.Track(task)
+ pr.mu.Lock()
+ pr.lastReported["task-trans-12345678"] = StatusResolving
+ pr.mu.Unlock()
+
+ pr.flush(context.Background())
+
+ reporter.mu.Lock()
+ defer reporter.mu.Unlock()
+ if len(reporter.calls) != 1 {
+ t.Fatalf("expected 1 report for transition, got %d", len(reporter.calls))
+ }
+}
+
+func TestProgressReporter_FlushActiveWhenWatching(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ isWatching: func() bool { return true },
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ task := &Task{ID: "task-watch-12345678", Status: StatusDownloading}
+ pr.Track(task)
+ pr.mu.Lock()
+ pr.lastReported["task-watch-12345678"] = StatusDownloading
+ pr.mu.Unlock()
+
+ pr.flush(context.Background())
+
+ reporter.mu.Lock()
+ defer reporter.mu.Unlock()
+ if len(reporter.calls) != 1 {
+ t.Fatalf("expected 1 report when watching active task, got %d", len(reporter.calls))
+ }
+}
+
+func TestProgressReporter_HandleResponseCancel(t *testing.T) {
+ reporter := &mockStatusReporter{
+ resp: &agent.StatusResponse{Cancelled: true},
+ }
+
+ var cancelledID string
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ onCancel: func(id string) { cancelledID = id },
+ }
+
+ task := &Task{ID: "task-cancel-1234567", Status: StatusCompleted}
+ pr.Track(task)
+
+ pr.flush(context.Background())
+
+ if cancelledID != "task-cancel-1234567" {
+ t.Errorf("expected cancel handler called with task ID, got %q", cancelledID)
+ }
+}
+
+func TestProgressReporter_HandleResponsePause(t *testing.T) {
+ reporter := &mockStatusReporter{
+ resp: &agent.StatusResponse{Paused: true},
+ }
+
+ var pausedID string
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ onPause: func(id string) { pausedID = id },
+ }
+
+ task := &Task{ID: "task-paused-1234567", Status: StatusCompleted}
+ pr.Track(task)
+
+ pr.flush(context.Background())
+
+ if pausedID != "task-paused-1234567" {
+ t.Errorf("expected pause handler called, got %q", pausedID)
+ }
+}
+
+func TestProgressReporter_HandleResponseDeleteFiles(t *testing.T) {
+ reporter := &mockStatusReporter{
+ resp: &agent.StatusResponse{Cancelled: true, DeleteFiles: true},
+ }
+
+ var deletedID string
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ onDeleteFiles: func(id string) { deletedID = id },
+ }
+
+ task := &Task{ID: "task-delete-1234567", Status: StatusCompleted}
+ pr.Track(task)
+
+ pr.flush(context.Background())
+
+ if deletedID != "task-delete-1234567" {
+ t.Errorf("expected deleteFiles handler called, got %q", deletedID)
+ }
+}
+
+func TestProgressReporter_HandleResponseStream(t *testing.T) {
+ reporter := &mockStatusReporter{
+ resp: &agent.StatusResponse{StreamRequested: true},
+ }
+
+ var streamID string
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ onStreamRequested: func(id string) { streamID = id },
+ }
+
+ // Task with no stream URL yet
+ task := &Task{ID: "task-stream-1234567", Status: StatusCompleted}
+ pr.Track(task)
+
+ pr.flush(context.Background())
+
+ if streamID != "task-stream-1234567" {
+ t.Errorf("expected stream handler called, got %q", streamID)
+ }
+}
+
+func TestProgressReporter_HandleResponseWatchingChanged(t *testing.T) {
+ reporter := &mockStatusReporter{
+ resp: &agent.StatusResponse{Watching: true},
+ }
+
+ var watchingValue bool
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ onWatchingChanged: func(w bool) { watchingValue = w },
+ }
+
+ task := &Task{ID: "task-watch2-1234567", Status: StatusCompleted}
+ pr.Track(task)
+
+ pr.flush(context.Background())
+
+ if !watchingValue {
+ t.Error("expected watchingChanged called with true")
+ }
+}
+
+func TestProgressReporter_BatchFlush(t *testing.T) {
+ batcher := &mockBatchReporter{
+ batchResp: &agent.BatchStatusResponse{
+ Results: []agent.StatusResponse{{}, {}},
+ },
+ }
+
+ pr := &ProgressReporter{
+ reporter: batcher,
+ interval: time.Second,
+ isWatching: func() bool { return true },
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ pr.Track(&Task{ID: "task-batch1-1234567", Status: StatusDownloading})
+ pr.Track(&Task{ID: "task-batch2-1234567", Status: StatusDownloading})
+
+ pr.flush(context.Background())
+
+ batcher.mu.Lock()
+ defer batcher.mu.Unlock()
+
+ if len(batcher.batchCalls) != 1 {
+ t.Fatalf("expected 1 batch call, got %d", len(batcher.batchCalls))
+ }
+ if len(batcher.batchCalls[0]) != 2 {
+ t.Errorf("expected 2 updates in batch, got %d", len(batcher.batchCalls[0]))
+ }
+}
+
+func TestProgressReporter_RunStopsOnCancel(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: 50 * time.Millisecond,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ defer cancel()
+
+ err := pr.Run(ctx)
+ if err != nil {
+ t.Errorf("Run should return nil on context cancel, got: %v", err)
+ }
+}
+
+func TestProgressReporter_ReportFinal(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ task := &Task{ID: "task-final-12345678", Status: StatusCompleted}
+ pr.Track(task)
+
+ pr.ReportFinal(context.Background(), task)
+
+ reporter.mu.Lock()
+ defer reporter.mu.Unlock()
+ if len(reporter.calls) != 1 {
+ t.Fatalf("expected 1 final report, got %d", len(reporter.calls))
+ }
+
+ // Should be untracked after final report
+ pr.mu.Lock()
+ if _, ok := pr.latest["task-final-12345678"]; ok {
+ t.Error("task should be untracked after ReportFinal")
+ }
+ pr.mu.Unlock()
+}
+
+func TestProgressReporter_SetHandlers(t *testing.T) {
+ pr := &ProgressReporter{
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ }
+
+ pr.SetCancelHandler(func(id string) {})
+ pr.SetPauseHandler(func(id string) {})
+ pr.SetDeleteFilesHandler(func(id string) {})
+ pr.SetStreamRequestedHandler(func(id string) {})
+ pr.SetWatchingFunc(func() bool { return true })
+ pr.SetWatchingChangedHandler(func(w bool) {})
+
+ if pr.onCancel == nil || pr.onPause == nil || pr.onDeleteFiles == nil ||
+ pr.onStreamRequested == nil || pr.isWatching == nil || pr.onWatchingChanged == nil {
+ t.Error("expected all handlers to be set")
+ }
+}
+
+func TestProgressReporter_ControlCheckDue(t *testing.T) {
+ reporter := &mockStatusReporter{}
+ pr := &ProgressReporter{
+ reporter: reporter,
+ interval: time.Second,
+ isWatching: func() bool { return false },
+ latest: make(map[string]*Task),
+ lastReported: make(map[string]TaskStatus),
+ lastCheckAt: time.Now().Add(-31 * time.Second), // 31s ago - due for control check
+ }
+
+ task := &Task{ID: "task-ctrl-123456789", Status: StatusDownloading}
+ pr.Track(task)
+ pr.mu.Lock()
+ pr.lastReported["task-ctrl-123456789"] = StatusDownloading
+ pr.mu.Unlock()
+
+ pr.flush(context.Background())
+
+ reporter.mu.Lock()
+ defer reporter.mu.Unlock()
+ if len(reporter.calls) != 1 {
+ t.Errorf("expected 1 report for control check, got %d", len(reporter.calls))
+ }
+}
diff --git a/internal/engine/stream_server.go b/internal/engine/stream_server.go
index 1635d69..e85cb13 100644
--- a/internal/engine/stream_server.go
+++ b/internal/engine/stream_server.go
@@ -11,6 +11,7 @@ import (
"os/exec"
"path/filepath"
"strings"
+ "sync/atomic"
"time"
"github.com/anacrolix/torrent"
@@ -24,11 +25,12 @@ type fileProvider interface {
// StreamServer serves a torrent file over HTTP with Range request support.
type StreamServer struct {
- provider fileProvider
- server *http.Server
- port int
- url string
- upnpMapping *UPnPMapping
+ provider fileProvider
+ server *http.Server
+ port int
+ url string
+ upnpMapping *UPnPMapping
+ lastActivity atomic.Int64 // UnixNano of last HTTP request
}
// NewStreamServer creates a new HTTP server for streaming via StreamEngine.
@@ -93,11 +95,38 @@ func NewStreamServerFromDisk(filePath string, port int) *StreamServer {
}
}
-// Start begins serving the file on all interfaces. Returns the best reachable URL:
-// 1. UPnP public IP (accessible from anywhere on the internet)
-// 2. Tailscale IP (accessible from any device in the tailnet)
-// 3. LAN IP (accessible from local network)
+// FindVideoFile scans a directory (recursively) for the largest video file.
+// Returns empty string if no video file found.
+func FindVideoFile(dir string) string {
+ var best string
+ var bestSize int64
+
+ filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
+ if err != nil || d.IsDir() {
+ return nil
+ }
+ ext := strings.ToLower(filepath.Ext(d.Name()))
+ if !VideoExts[ext] {
+ return nil
+ }
+ info, err := d.Info()
+ if err != nil {
+ return nil
+ }
+ if info.Size() > bestSize {
+ best = path
+ bestSize = info.Size()
+ }
+ return nil
+ })
+ return best
+}
+
+// Start begins serving the file on all interfaces. Returns the best reachable URL.
+// The file is served as-is — the user's media player (VLC, mpv, etc.) handles decoding.
func (ss *StreamServer) Start(ctx context.Context) (string, error) {
+ ss.lastActivity.Store(time.Now().UnixNano())
+
mux := http.NewServeMux()
mux.HandleFunc("/stream", ss.handler)
@@ -107,19 +136,9 @@ func (ss *StreamServer) Start(ctx context.Context) (string, error) {
return "", fmt.Errorf("listen on %s: %w", addr, err)
}
- // Extract actual port (important when port=0)
ss.port = listener.Addr().(*net.TCPAddr).Port
-
- // Try UPnP for public internet access (like Plex Remote Access)
- if mapping, upnpErr := setupUPnP(ss.port); upnpErr == nil {
- ss.upnpMapping = mapping
- ss.url = fmt.Sprintf("http://%s:%d/stream", mapping.ExternalIP, mapping.ExternalPort)
- log.Printf("stream: UPnP mapped %s:%d -> local:%d", mapping.ExternalIP, mapping.ExternalPort, ss.port)
- } else {
- // Fallback: Tailscale IP > LAN IP > 127.0.0.1
- ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port)
- log.Printf("stream: UPnP unavailable (%v), using %s", upnpErr, ss.url)
- }
+ ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port)
+ log.Printf("stream: serving on %s", ss.url)
ss.server = &http.Server{
Handler: mux,
@@ -141,6 +160,15 @@ func (ss *StreamServer) URL() string { return ss.url }
// Port returns the bound port.
func (ss *StreamServer) Port() int { return ss.port }
+// IdleSince returns how long since the last HTTP request was received.
+func (ss *StreamServer) IdleSince() time.Duration {
+ last := ss.lastActivity.Load()
+ if last == 0 {
+ return 0
+ }
+ return time.Since(time.Unix(0, last))
+}
+
// Shutdown gracefully stops the HTTP server and removes the UPnP port mapping.
func (ss *StreamServer) Shutdown(ctx context.Context) error {
ss.upnpMapping.Remove()
@@ -151,6 +179,8 @@ func (ss *StreamServer) Shutdown(ctx context.Context) error {
}
func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) {
+ ss.lastActivity.Store(time.Now().UnixNano())
+
// CORS headers — only when browser sends Origin (HTTPS site → localhost)
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", "*")
diff --git a/internal/library/mediainfo/ffprobe.go b/internal/library/mediainfo/ffprobe.go
index f2c70fb..723ef6f 100644
--- a/internal/library/mediainfo/ffprobe.go
+++ b/internal/library/mediainfo/ffprobe.go
@@ -78,6 +78,12 @@ func ExtractMediaInfo(ctx context.Context, ffprobePath, filePath string) (*Media
return nil, fmt.Errorf("ffprobe JSON parse failed: %w", err)
}
+ return parseFFprobeOutput(data)
+}
+
+// parseFFprobeOutput converts parsed ffprobe JSON into MediaInfo.
+// Separated from ExtractMediaInfo so it can be tested without running ffprobe.
+func parseFFprobeOutput(data ffprobeOutput) (*MediaInfo, error) {
if len(data.Streams) == 0 {
return nil, fmt.Errorf("ffprobe returned no streams")
}
diff --git a/internal/library/mediainfo/ffprobe_test.go b/internal/library/mediainfo/ffprobe_test.go
new file mode 100644
index 0000000..e29eed1
--- /dev/null
+++ b/internal/library/mediainfo/ffprobe_test.go
@@ -0,0 +1,430 @@
+package mediainfo
+
+import (
+ "testing"
+)
+
+func TestParseDuration(t *testing.T) {
+ tests := []struct {
+ input string
+ want float64
+ }{
+ {"", 0},
+ {"0", 0},
+ {"-5", 0},
+ {"invalid", 0},
+ {"7423.500000", 7423.5},
+ {"120.123456", 120.123},
+ {"3600", 3600},
+ {"0.001", 0.001},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := parseDuration(tt.input)
+ if got != tt.want {
+ t.Errorf("parseDuration(%q) = %v, want %v", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestTagValue(t *testing.T) {
+ tags := map[string]string{
+ "language": "eng",
+ "title": "Main Audio",
+ "HANDLER": "VideoHandler",
+ }
+
+ tests := []struct {
+ key string
+ want string
+ }{
+ {"language", "eng"},
+ {"title", "Main Audio"},
+ {"handler", "VideoHandler"},
+ {"missing", ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.key, func(t *testing.T) {
+ got := tagValue(tags, tt.key)
+ if got != tt.want {
+ t.Errorf("tagValue(tags, %q) = %q, want %q", tt.key, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestTagValueNil(t *testing.T) {
+ got := tagValue(nil, "language")
+ if got != "" {
+ t.Errorf("tagValue(nil, language) = %q, want empty", got)
+ }
+}
+
+func TestContainsAny(t *testing.T) {
+ tests := []struct {
+ s string
+ subs []string
+ want bool
+ }{
+ {"yuv420p10le", []string{"10le", "10be", "p010"}, true},
+ {"yuv420p12be", []string{"10le", "10be", "p010"}, false},
+ {"yuv420p12be", []string{"12le", "12be"}, true},
+ {"yuv420p", []string{"10le", "10be"}, false},
+ {"", []string{"any"}, false},
+ {"something", []string{}, false},
+ }
+
+ for _, tt := range tests {
+ got := containsAny(tt.s, tt.subs...)
+ if got != tt.want {
+ t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.subs, got, tt.want)
+ }
+ }
+}
+
+func TestParseFFprobeOutput_BasicH264(t *testing.T) {
+ data := ffprobeOutput{
+ Format: ffprobeFormat{Duration: "7423.5"},
+ Streams: []ffprobeStream{
+ {
+ CodecType: "video",
+ CodecName: "h264",
+ Profile: "High",
+ Width: 1920,
+ Height: 1080,
+ RFrameRate: "24000/1001",
+ },
+ {
+ CodecType: "audio",
+ CodecName: "aac",
+ Channels: 2,
+ Tags: map[string]string{"language": "eng"},
+ Disposition: map[string]int{"default": 1},
+ },
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatalf("parseFFprobeOutput: %v", err)
+ }
+ if mi.Video == nil {
+ t.Fatal("expected video info")
+ }
+ if mi.Video.Codec != "h264" {
+ t.Errorf("codec = %q, want h264", mi.Video.Codec)
+ }
+ if mi.Video.Width != 1920 || mi.Video.Height != 1080 {
+ t.Errorf("dimensions = %dx%d, want 1920x1080", mi.Video.Width, mi.Video.Height)
+ }
+ if mi.Video.Profile != "High" {
+ t.Errorf("profile = %q, want High", mi.Video.Profile)
+ }
+ if mi.Video.Duration != 7423.5 {
+ t.Errorf("duration = %v, want 7423.5", mi.Video.Duration)
+ }
+ if mi.Video.FrameRate < 23.975 || mi.Video.FrameRate > 23.977 {
+ t.Errorf("frameRate = %v, want ~23.976", mi.Video.FrameRate)
+ }
+ if len(mi.Audio) != 1 {
+ t.Fatalf("audio tracks = %d, want 1", len(mi.Audio))
+ }
+ if mi.Audio[0].Lang != "en" {
+ t.Errorf("audio lang = %q, want en", mi.Audio[0].Lang)
+ }
+ if !mi.Audio[0].Default {
+ t.Error("expected default audio track")
+ }
+}
+
+func TestParseFFprobeOutput_HEVC_HDR10(t *testing.T) {
+ data := ffprobeOutput{
+ Format: ffprobeFormat{Duration: "3600"},
+ Streams: []ffprobeStream{
+ {
+ CodecType: "video",
+ CodecName: "hevc",
+ Width: 3840,
+ Height: 2160,
+ BitsPerRaw: "10",
+ ColorSpace: "bt2020nc",
+ ColorTransfer: "smpte2084",
+ RFrameRate: "24/1",
+ },
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.HDR != "HDR10" {
+ t.Errorf("hdr = %q, want HDR10", mi.Video.HDR)
+ }
+ if mi.Video.BitDepth != 10 {
+ t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth)
+ }
+}
+
+func TestParseFFprobeOutput_DolbyVisionWithHDR10(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {
+ CodecType: "video",
+ CodecName: "hevc",
+ Width: 3840,
+ Height: 2160,
+ ColorSpace: "bt2020nc",
+ ColorTransfer: "smpte2084",
+ SideDataList: []sideData{{SideDataType: "DOVI configuration record"}},
+ },
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.HDR != "DV+HDR10" {
+ t.Errorf("hdr = %q, want DV+HDR10", mi.Video.HDR)
+ }
+}
+
+func TestParseFFprobeOutput_DolbyVisionOnly(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {
+ CodecType: "video",
+ CodecName: "hevc",
+ Width: 3840,
+ Height: 2160,
+ SideDataList: []sideData{{SideDataType: "DOVI configuration record"}},
+ },
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.HDR != "DV" {
+ t.Errorf("hdr = %q, want DV", mi.Video.HDR)
+ }
+}
+
+func TestParseFFprobeOutput_HLG(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {
+ CodecType: "video",
+ CodecName: "hevc",
+ Width: 3840,
+ Height: 2160,
+ ColorSpace: "bt2020nc",
+ ColorTransfer: "arib-std-b67",
+ },
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.HDR != "HLG" {
+ t.Errorf("hdr = %q, want HLG", mi.Video.HDR)
+ }
+}
+
+func TestParseFFprobeOutput_MultiAudioAndSubtitles(t *testing.T) {
+ data := ffprobeOutput{
+ Format: ffprobeFormat{Duration: "5400"},
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080},
+ {
+ CodecType: "audio", CodecName: "ac3", Channels: 6,
+ Tags: map[string]string{"language": "eng", "title": "English 5.1"},
+ Disposition: map[string]int{"default": 1},
+ },
+ {
+ CodecType: "audio", CodecName: "aac", Channels: 2,
+ Tags: map[string]string{"language": "spa"},
+ },
+ {
+ CodecType: "subtitle", CodecName: "subrip",
+ Tags: map[string]string{"language": "eng"},
+ },
+ {
+ CodecType: "subtitle", CodecName: "ass",
+ Tags: map[string]string{"language": "spa"},
+ Disposition: map[string]int{"forced": 1},
+ },
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(mi.Audio) != 2 {
+ t.Fatalf("audio tracks = %d, want 2", len(mi.Audio))
+ }
+ if mi.Audio[0].Title != "English 5.1" {
+ t.Errorf("audio[0].title = %q", mi.Audio[0].Title)
+ }
+ if len(mi.Subtitles) != 2 {
+ t.Fatalf("subtitle tracks = %d, want 2", len(mi.Subtitles))
+ }
+ if !mi.Subtitles[1].Forced {
+ t.Error("expected subtitle[1] to be forced")
+ }
+ if len(mi.Languages) != 2 {
+ t.Errorf("languages = %v, want 2 entries", mi.Languages)
+ }
+}
+
+func TestParseFFprobeOutput_BitDepthFromPixFmt(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p10le"},
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.BitDepth != 10 {
+ t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth)
+ }
+}
+
+func TestParseFFprobeOutput_12BitFromPixFmt(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p12be"},
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.BitDepth != 12 {
+ t.Errorf("bitDepth = %d, want 12", mi.Video.BitDepth)
+ }
+}
+
+func TestParseFFprobeOutput_DurationFromStreamFallback(t *testing.T) {
+ data := ffprobeOutput{
+ Format: ffprobeFormat{Duration: ""},
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "h264", Width: 1280, Height: 720, Duration: "1800.5"},
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.Duration != 1800.5 {
+ t.Errorf("duration = %v, want 1800.5", mi.Video.Duration)
+ }
+}
+
+func TestParseFFprobeOutput_NoStreams(t *testing.T) {
+ data := ffprobeOutput{}
+ _, err := parseFFprobeOutput(data)
+ if err == nil {
+ t.Error("expected error for no streams")
+ }
+}
+
+func TestParseFFprobeOutput_OnlyFirstVideoStream(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080},
+ {CodecType: "video", CodecName: "mjpeg", Width: 320, Height: 240}, // cover art
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.Codec != "h264" {
+ t.Errorf("should use first video stream, got codec %q", mi.Video.Codec)
+ }
+ if mi.Video.Width != 1920 {
+ t.Errorf("width = %d, should be from first video stream", mi.Video.Width)
+ }
+}
+
+func TestParseFFprobeOutput_SMPTE2084_WithoutBT2020(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "smpte2084"},
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.HDR != "HDR10" {
+ t.Errorf("hdr = %q, want HDR10", mi.Video.HDR)
+ }
+}
+
+func TestParseFFprobeOutput_AribWithoutBT2020(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "arib-std-b67", ColorSpace: "other"},
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.HDR != "HLG" {
+ t.Errorf("hdr = %q, want HLG", mi.Video.HDR)
+ }
+}
+
+func TestParseFFprobeOutput_AudioOnly(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {CodecType: "audio", CodecName: "flac", Channels: 2, Tags: map[string]string{"language": "eng"}},
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video != nil {
+ t.Error("expected no video info for audio-only")
+ }
+ if len(mi.Audio) != 1 {
+ t.Errorf("audio tracks = %d, want 1", len(mi.Audio))
+ }
+}
+
+func TestParseFFprobeOutput_FrameRateNoSlash(t *testing.T) {
+ data := ffprobeOutput{
+ Streams: []ffprobeStream{
+ {CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080, RFrameRate: "30"},
+ },
+ }
+
+ mi, err := parseFFprobeOutput(data)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if mi.Video.FrameRate != 0 {
+ t.Errorf("frameRate = %v, want 0 (no slash)", mi.Video.FrameRate)
+ }
+}
diff --git a/internal/library/scanner_test.go b/internal/library/scanner_test.go
new file mode 100644
index 0000000..43c84b2
--- /dev/null
+++ b/internal/library/scanner_test.go
@@ -0,0 +1,93 @@
+package library
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestDiscoverFiles(t *testing.T) {
+ dir := t.TempDir()
+
+ // Create video files (need to be >= 100MB to pass size check)
+ largeContent := make([]byte, 101*1024*1024)
+
+ videoFiles := []string{"movie.mkv", "show.mp4", "clip.avi"}
+ for _, name := range videoFiles {
+ path := filepath.Join(dir, name)
+ if err := os.WriteFile(path, largeContent, 0o644); err != nil {
+ t.Fatalf("write %s: %v", name, err)
+ }
+ }
+
+ // Non-video files (should be excluded)
+ nonVideo := []string{"readme.txt", "cover.jpg", "subs.srt"}
+ for _, name := range nonVideo {
+ if err := os.WriteFile(filepath.Join(dir, name), largeContent, 0o644); err != nil {
+ t.Fatalf("write %s: %v", name, err)
+ }
+ }
+
+ // Small video file (should be excluded, < 100MB)
+ if err := os.WriteFile(filepath.Join(dir, "small.mkv"), []byte("small"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ // Excluded pattern (sample)
+ sampleDir := filepath.Join(dir, "sample")
+ os.MkdirAll(sampleDir, 0o755)
+ if err := os.WriteFile(filepath.Join(sampleDir, "sample.mkv"), largeContent, 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ files, err := discoverFiles(dir)
+ if err != nil {
+ t.Fatalf("discoverFiles: %v", err)
+ }
+
+ if len(files) != 3 {
+ t.Errorf("expected 3 files, got %d: %v", len(files), files)
+ }
+
+ // Check that all returned files are video extensions
+ for _, f := range files {
+ ext := filepath.Ext(f)
+ if ext != ".mkv" && ext != ".mp4" && ext != ".avi" {
+ t.Errorf("unexpected extension: %s", ext)
+ }
+ }
+}
+
+func TestDiscoverFilesEmptyDir(t *testing.T) {
+ dir := t.TempDir()
+
+ files, err := discoverFiles(dir)
+ if err != nil {
+ t.Fatalf("discoverFiles: %v", err)
+ }
+ if len(files) != 0 {
+ t.Errorf("expected 0 files, got %d", len(files))
+ }
+}
+
+func TestDiscoverFilesExcludePatterns(t *testing.T) {
+ dir := t.TempDir()
+ largeContent := make([]byte, 101*1024*1024)
+
+ excludeDirs := []string{"trailer", "featurette", "extras", "bonus"}
+ for _, name := range excludeDirs {
+ sub := filepath.Join(dir, name)
+ os.MkdirAll(sub, 0o755)
+ if err := os.WriteFile(filepath.Join(sub, "video.mkv"), largeContent, 0o644); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ files, err := discoverFiles(dir)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(files) != 0 {
+ t.Errorf("expected 0 files (all excluded), got %d: %v", len(files), files)
+ }
+}
diff --git a/internal/library/sync_test.go b/internal/library/sync_test.go
new file mode 100644
index 0000000..fe7a113
--- /dev/null
+++ b/internal/library/sync_test.go
@@ -0,0 +1,108 @@
+package library
+
+import (
+ "testing"
+
+ "github.com/torrentclaw/unarr/internal/library/mediainfo"
+)
+
+func TestBuildSyncItems(t *testing.T) {
+ cache := &LibraryCache{
+ Items: []LibraryItem{
+ {
+ FilePath: "/media/movies/Inception.mkv",
+ FileName: "Inception.2010.1080p.mkv",
+ FileSize: 5000000000,
+ Title: "Inception",
+ Year: "2010",
+ MediaInfo: &mediainfo.MediaInfo{
+ Video: &mediainfo.VideoInfo{
+ Codec: "hevc",
+ Width: 1920,
+ Height: 1080,
+ BitDepth: 10,
+ HDR: "HDR10",
+ },
+ Audio: []mediainfo.AudioTrack{
+ {Lang: "en", Codec: "ac3", Channels: 6, Default: true},
+ {Lang: "es", Codec: "aac", Channels: 2},
+ },
+ Subtitles: []mediainfo.SubtitleTrack{
+ {Lang: "en", Codec: "subrip"},
+ {Lang: "es", Codec: "subrip"},
+ },
+ },
+ },
+ {
+ FilePath: "/media/shows/Breaking.Bad.S01E01.mkv",
+ FileName: "Breaking.Bad.S01E01.mkv",
+ FileSize: 1000000000,
+ Title: "Breaking Bad",
+ Season: 1,
+ Episode: 1,
+ },
+ {
+ // Item with scan error — should be skipped
+ FilePath: "/media/bad.mkv",
+ FileName: "bad.mkv",
+ ScanError: "ffprobe failed",
+ },
+ },
+ }
+
+ items := BuildSyncItems(cache)
+
+ if len(items) != 2 {
+ t.Fatalf("expected 2 items (1 skipped), got %d", len(items))
+ }
+
+ // First item: movie with full media info
+ movie := items[0]
+ if movie.Title != "Inception" {
+ t.Errorf("title = %q, want Inception", movie.Title)
+ }
+ if movie.ContentType != "movie" {
+ t.Errorf("contentType = %q, want movie", movie.ContentType)
+ }
+ if movie.Resolution != "1080p" {
+ t.Errorf("resolution = %q, want 1080p", movie.Resolution)
+ }
+ if movie.VideoCodec != "hevc" {
+ t.Errorf("videoCodec = %q, want hevc", movie.VideoCodec)
+ }
+ if movie.HDR != "HDR10" {
+ t.Errorf("hdr = %q, want HDR10", movie.HDR)
+ }
+ if movie.AudioCodec != "ac3" {
+ t.Errorf("audioCodec = %q, want ac3", movie.AudioCodec)
+ }
+ if movie.AudioChannels != 6 {
+ t.Errorf("audioChannels = %d, want 6", movie.AudioChannels)
+ }
+ if len(movie.AudioLanguages) != 2 {
+ t.Errorf("audioLanguages count = %d, want 2", len(movie.AudioLanguages))
+ }
+ if len(movie.SubtitleLanguages) != 2 {
+ t.Errorf("subtitleLanguages count = %d, want 2", len(movie.SubtitleLanguages))
+ }
+
+ // Second item: show without media info
+ show := items[1]
+ if show.ContentType != "show" {
+ t.Errorf("contentType = %q, want show", show.ContentType)
+ }
+ if show.Season != 1 || show.Episode != 1 {
+ t.Errorf("season/episode = %d/%d, want 1/1", show.Season, show.Episode)
+ }
+ if show.Resolution != "" {
+ t.Errorf("resolution should be empty, got %q", show.Resolution)
+ }
+}
+
+func TestBuildSyncItemsEmpty(t *testing.T) {
+ cache := &LibraryCache{Items: nil}
+ items := BuildSyncItems(cache)
+ if len(items) != 0 {
+ t.Errorf("expected 0 items, got %d", len(items))
+ }
+}
diff --git a/internal/mediaserver/detect_test.go b/internal/mediaserver/detect_test.go
index fc5b00e..19ba53c 100644
--- a/internal/mediaserver/detect_test.go
+++ b/internal/mediaserver/detect_test.go
@@ -2,6 +2,8 @@ package mediaserver
import (
"encoding/json"
+ "os"
+ "path/filepath"
"testing"
)
@@ -69,6 +71,96 @@ func TestJellyfinParsing(t *testing.T) {
}
}
+func TestPlexTokenFromPrefs(t *testing.T) {
+ t.Run("valid prefs", func(t *testing.T) {
+ dir := t.TempDir()
+ prefsPath := filepath.Join(dir, "Preferences.xml")
+ xml := `
+`
+ os.WriteFile(prefsPath, []byte(xml), 0o644)
+
+ token := plexTokenFromPrefs(prefsPath)
+ if token != "my-secret-token" {
+ t.Errorf("token = %q, want my-secret-token", token)
+ }
+ })
+
+ t.Run("no token attr", func(t *testing.T) {
+ dir := t.TempDir()
+ prefsPath := filepath.Join(dir, "Preferences.xml")
+ xml := ``
+ os.WriteFile(prefsPath, []byte(xml), 0o644)
+
+ token := plexTokenFromPrefs(prefsPath)
+ if token != "" {
+ t.Errorf("token = %q, want empty", token)
+ }
+ })
+
+ t.Run("file not found", func(t *testing.T) {
+ token := plexTokenFromPrefs("/nonexistent/Preferences.xml")
+ if token != "" {
+ t.Errorf("token = %q, want empty", token)
+ }
+ })
+
+ t.Run("invalid xml", func(t *testing.T) {
+ dir := t.TempDir()
+ prefsPath := filepath.Join(dir, "Preferences.xml")
+ os.WriteFile(prefsPath, []byte("not xml at all"), 0o644)
+
+ token := plexTokenFromPrefs(prefsPath)
+ if token != "" {
+ t.Errorf("token = %q, want empty", token)
+ }
+ })
+}
+
+func TestParsePlexSectionsMultipleLocations(t *testing.T) {
+ body := `{
+ "MediaContainer": {
+ "Directory": [
+ {
+ "title": "Movies",
+ "Location": [
+ {"path": "/media/movies"},
+ {"path": "/media/movies2"}
+ ]
+ }
+ ]
+ }
+ }`
+
+ paths := parsePlexSections([]byte(body))
+ if len(paths) != 2 {
+ t.Fatalf("expected 2 paths, got %d", len(paths))
+ }
+}
+
+func TestParsePlexSectionsEmptyPath(t *testing.T) {
+ body := `{
+ "MediaContainer": {
+ "Directory": [
+ {
+ "Location": [{"path": ""}, {"path": "/valid"}]
+ }
+ ]
+ }
+ }`
+
+ paths := parsePlexSections([]byte(body))
+ if len(paths) != 1 {
+ t.Fatalf("expected 1 path (empty filtered), got %d: %v", len(paths), paths)
+ }
+}
+
+func TestCommonMediaDirs(t *testing.T) {
+ dirs := commonMediaDirs()
+ if len(dirs) == 0 {
+ t.Error("expected at least some common media dirs")
+ }
+}
+
func TestParentDir(t *testing.T) {
tests := []struct {
name string
diff --git a/internal/sentry/sentry_test.go b/internal/sentry/sentry_test.go
new file mode 100644
index 0000000..671e641
--- /dev/null
+++ b/internal/sentry/sentry_test.go
@@ -0,0 +1,47 @@
+package sentry
+
+import "testing"
+
+func TestEnvironment(t *testing.T) {
+ tests := []struct {
+ version string
+ want string
+ }{
+ {"", "development"},
+ {"dev", "development"},
+ {"0.1.0-dev", "development"},
+ {"1.0.0", "production"},
+ {"0.3.5", "production"},
+ {"2.0.0-beta", "production"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.version, func(t *testing.T) {
+ got := environment(tt.version)
+ if got != tt.want {
+ t.Errorf("environment(%q) = %q, want %q", tt.version, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestInitNoOp(t *testing.T) {
+ // With empty dsn (default in tests), Init should be a no-op
+ Init("1.0.0")
+ // Should not panic
+}
+
+func TestCloseNoOp(t *testing.T) {
+ // Close should be safe to call without Init
+ Close()
+}
+
+func TestCaptureErrorNil(t *testing.T) {
+ // Should not panic with nil error
+ CaptureError(nil, "test")
+}
+
+func TestSetUser(t *testing.T) {
+ // Should not panic without initialization
+ SetUser("agent-123")
+}
diff --git a/internal/ui/format_test.go b/internal/ui/format_test.go
index 7f3d2c5..e5c9eda 100644
--- a/internal/ui/format_test.go
+++ b/internal/ui/format_test.go
@@ -1,7 +1,9 @@
package ui
import (
+ "fmt"
"testing"
+ "time"
)
func TestFormatSize(t *testing.T) {
@@ -127,6 +129,10 @@ func TestQualityIndicator(t *testing.T) {
{"medium", intPtr(60), "🟡"},
{"high", intPtr(80), "🟢"},
{"perfect", intPtr(100), "🟢"},
+ {"boundary_40", intPtr(40), "🟡"},
+ {"boundary_70", intPtr(70), "🟢"},
+ {"boundary_39", intPtr(39), "🔴"},
+ {"zero", intPtr(0), "🔴"},
}
for _, tt := range tests {
@@ -139,6 +145,52 @@ func TestQualityIndicator(t *testing.T) {
}
}
+func TestSeedHealthIndicator(t *testing.T) {
+ tests := []struct {
+ seeds int
+ want string
+ }{
+ {0, "🔴"},
+ {5, "🔴"},
+ {9, "🔴"},
+ {10, "🟡"},
+ {50, "🟡"},
+ {100, "🟡"},
+ {101, "🟢"},
+ {1000, "🟢"},
+ }
+
+ for _, tt := range tests {
+ t.Run(fmt.Sprintf("seeds_%d", tt.seeds), func(t *testing.T) {
+ got := SeedHealthIndicator(tt.seeds)
+ if got != tt.want {
+ t.Errorf("SeedHealthIndicator(%d) = %q, want %q", tt.seeds, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFormatRating(t *testing.T) {
+ tests := []struct {
+ name string
+ input *string
+ want string
+ }{
+ {"nil", nil, "-"},
+ {"value", strPtr("8.5"), "8.5"},
+ {"empty", strPtr(""), ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := FormatRating(tt.input)
+ if got != tt.want {
+ t.Errorf("FormatRating(%v) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
func TestStringOrDash(t *testing.T) {
s := "hello"
if got := StringOrDash(&s); got != "hello" {
@@ -150,16 +202,160 @@ func TestStringOrDash(t *testing.T) {
}
func TestFormatContentType(t *testing.T) {
- if got := FormatContentType("movie"); got != "Movie" {
- t.Errorf("FormatContentType(movie) = %q, want Movie", got)
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"movie", "Movie"},
+ {"Movie", "Movie"},
+ {"MOVIE", "Movie"},
+ {"show", "Show"},
+ {"Show", "Show"},
+ {"other", "other"},
+ {"", ""},
}
- if got := FormatContentType("show"); got != "Show" {
- t.Errorf("FormatContentType(show) = %q, want Show", got)
- }
- if got := FormatContentType("other"); got != "other" {
- t.Errorf("FormatContentType(other) = %q, want other", got)
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := FormatContentType(tt.input)
+ if got != tt.want {
+ t.Errorf("FormatContentType(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
}
}
-func ptr[T any](v T) *T { return &v }
-func intPtr(v int) *int { return &v }
+func TestFormatLanguages(t *testing.T) {
+ tests := []struct {
+ name string
+ input []string
+ want string
+ }{
+ {"nil", nil, "-"},
+ {"empty", []string{}, "-"},
+ {"single", []string{"en"}, "en"},
+ {"multiple", []string{"en", "es", "fr"}, "en, es, fr"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := FormatLanguages(tt.input)
+ if got != tt.want {
+ t.Errorf("FormatLanguages(%v) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFormatSeedRatio(t *testing.T) {
+ tests := []struct {
+ seeders int
+ leechers int
+ want string
+ }{
+ {0, 0, "0:0"},
+ {10, 0, "10:0"},
+ {100, 10, "10:1"},
+ {50, 50, "1:1"},
+ {1, 3, "0:1"},
+ {150, 10, "15:1"},
+ }
+
+ for _, tt := range tests {
+ t.Run(fmt.Sprintf("%d_%d", tt.seeders, tt.leechers), func(t *testing.T) {
+ got := FormatSeedRatio(tt.seeders, tt.leechers)
+ if got != tt.want {
+ t.Errorf("FormatSeedRatio(%d, %d) = %q, want %q", tt.seeders, tt.leechers, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFormatTimeAgo(t *testing.T) {
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {"invalid", "not-a-date", "not-a-date"},
+ {"just_now", now.Add(-10 * time.Second).Format(time.RFC3339), "just now"},
+ {"minutes", now.Add(-5 * time.Minute).Format(time.RFC3339), "5m ago"},
+ {"hours", now.Add(-3 * time.Hour).Format(time.RFC3339), "3h ago"},
+ {"days", now.Add(-7 * 24 * time.Hour).Format(time.RFC3339), "7d ago"},
+ {"months", now.Add(-60 * 24 * time.Hour).Format(time.RFC3339), "2mo ago"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := FormatTimeAgo(tt.input)
+ if got != tt.want {
+ t.Errorf("FormatTimeAgo(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFormatNumberExtended(t *testing.T) {
+ tests := []struct {
+ input int
+ want string
+ }{
+ {-1000, "-1,000"},
+ {-5, "-5"},
+ {10000, "10,000"},
+ {100000, "100,000"},
+ {1000000, "1,000,000"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.want, func(t *testing.T) {
+ got := FormatNumber(tt.input)
+ if got != tt.want {
+ t.Errorf("FormatNumber(%d) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestTruncateStringEdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLen int
+ want string
+ }{
+ {"maxLen_1", "hello", 1, "h"},
+ {"maxLen_3", "hello", 3, "hel"},
+ {"empty", "", 5, ""},
+ {"unicode", "こんにちは世界", 5, "こん..."},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := TruncateString(tt.input, tt.maxLen)
+ if got != tt.want {
+ t.Errorf("TruncateString(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestPtr(t *testing.T) {
+ v := 42
+ p := Ptr(v)
+ if *p != 42 {
+ t.Errorf("Ptr(42) = %d, want 42", *p)
+ }
+
+ s := "hello"
+ sp := Ptr(s)
+ if *sp != "hello" {
+ t.Errorf("Ptr(hello) = %q, want hello", *sp)
+ }
+}
+
+func ptr[T any](v T) *T { return &v }
+func intPtr(v int) *int { return &v }
+func strPtr(v string) *string { return &v }
diff --git a/internal/ui/table_test.go b/internal/ui/table_test.go
new file mode 100644
index 0000000..3b9abff
--- /dev/null
+++ b/internal/ui/table_test.go
@@ -0,0 +1,122 @@
+package ui
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+
+ tc "github.com/torrentclaw/go-client"
+)
+
+func TestNewCleanTable(t *testing.T) {
+ var buf bytes.Buffer
+ tbl := newCleanTable(&buf)
+ tbl.Header([]string{"A", "B"})
+ tbl.Append([]string{"1", "2"})
+ tbl.Render()
+
+ out := buf.String()
+ if !strings.Contains(out, "A") || !strings.Contains(out, "B") {
+ t.Errorf("expected headers in output, got: %s", out)
+ }
+ if !strings.Contains(out, "1") || !strings.Contains(out, "2") {
+ t.Errorf("expected row data in output, got: %s", out)
+ }
+}
+
+func TestPrintSearchResultEntry(t *testing.T) {
+ var buf bytes.Buffer
+ year := 2010
+ rating := "8.8"
+ quality := "1080p"
+ codec := "x265"
+ size := int64(4294967296)
+ score := 85
+
+ r := tc.SearchResult{
+ Title: "Inception",
+ Year: &year,
+ ContentType: "movie",
+ RatingIMDb: &rating,
+ Genres: []string{"Sci-Fi", "Action"},
+ Torrents: []tc.TorrentInfo{
+ {
+ Quality: &quality,
+ SizeBytes: &size,
+ Seeders: 150,
+ Leechers: 10,
+ Source: "YTS",
+ Codec: &codec,
+ Languages: []string{"en", "es"},
+ QualityScore: &score,
+ },
+ },
+ }
+
+ printSearchResultEntry(&buf, r)
+ out := buf.String()
+
+ if !strings.Contains(out, "Inception") {
+ t.Error("expected title in output")
+ }
+ if !strings.Contains(out, "2010") {
+ t.Error("expected year in output")
+ }
+ if !strings.Contains(out, "8.8") {
+ t.Error("expected rating in output")
+ }
+ if !strings.Contains(out, "1080p") {
+ t.Error("expected quality in output")
+ }
+ if !strings.Contains(out, "YTS") {
+ t.Error("expected source in output")
+ }
+}
+
+func TestPrintSearchResultEntryNoTorrents(t *testing.T) {
+ var buf bytes.Buffer
+ year := 2020
+
+ r := tc.SearchResult{
+ Title: "No Torrents Movie",
+ Year: &year,
+ ContentType: "movie",
+ Torrents: nil,
+ }
+
+ printSearchResultEntry(&buf, r)
+ out := buf.String()
+
+ if !strings.Contains(out, "No Torrents Movie") {
+ t.Error("expected title in output")
+ }
+ if !strings.Contains(out, "No torrents available") {
+ t.Error("expected no-torrents message")
+ }
+}
+
+func TestPrintSearchResultEntryNilFields(t *testing.T) {
+ var buf bytes.Buffer
+
+ r := tc.SearchResult{
+ Title: "Minimal",
+ ContentType: "movie",
+ Torrents: []tc.TorrentInfo{
+ {
+ Seeders: 5,
+ Leechers: 3,
+ Source: "RARBG",
+ },
+ },
+ }
+
+ printSearchResultEntry(&buf, r)
+ out := buf.String()
+
+ if !strings.Contains(out, "Minimal") {
+ t.Error("expected title in output")
+ }
+ if !strings.Contains(out, "-") {
+ t.Error("expected dash for nil fields")
+ }
+}
diff --git a/internal/upgrade/cache.go b/internal/upgrade/cache.go
new file mode 100644
index 0000000..7bdcfb0
--- /dev/null
+++ b/internal/upgrade/cache.go
@@ -0,0 +1,75 @@
+package upgrade
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/torrentclaw/unarr/internal/config"
+)
+
+const cacheTTL = 1 * time.Hour
+
+// versionCache is the on-disk structure for cached version checks.
+type versionCache struct {
+ Version string `json:"version"`
+ CheckedAt time.Time `json:"checkedAt"`
+}
+
+// cacheFilePath returns the path to the version cache file.
+func cacheFilePath() string {
+ return filepath.Join(config.DataDir(), "latest-version.json")
+}
+
+// ReadCachedVersion returns the cached latest version if it's fresh (< cacheTTL).
+// Returns empty string if cache is missing, stale, or corrupt.
+func ReadCachedVersion() string {
+ data, err := os.ReadFile(cacheFilePath())
+ if err != nil {
+ return ""
+ }
+ var c versionCache
+ if json.Unmarshal(data, &c) != nil {
+ return ""
+ }
+ if time.Since(c.CheckedAt) > cacheTTL {
+ return ""
+ }
+ return c.Version
+}
+
+// writeCachedVersion writes the latest version to the cache file.
+func writeCachedVersion(version string) {
+ c := versionCache{
+ Version: version,
+ CheckedAt: time.Now(),
+ }
+ data, err := json.Marshal(c)
+ if err != nil {
+ return
+ }
+ path := cacheFilePath()
+ os.MkdirAll(filepath.Dir(path), 0o755)
+ // Best-effort write — ignore errors
+ tmp := path + ".tmp"
+ if err := os.WriteFile(tmp, data, 0o644); err != nil {
+ return
+ }
+ os.Rename(tmp, path)
+}
+
+// CheckLatestCached returns the latest version, using cache when fresh.
+// If cache is stale, fetches from GitHub and updates the cache.
+func CheckLatestCached(ctx context.Context) (version string, fromCache bool, err error) {
+ if cached := ReadCachedVersion(); cached != "" {
+ return cached, true, nil
+ }
+ v, err := fetchLatestVersion(ctx)
+ if err != nil {
+ return "", false, err
+ }
+ writeCachedVersion(v)
+ return v, false, nil
+}
diff --git a/internal/upgrade/upgrade.go b/internal/upgrade/upgrade.go
index b70dc7e..5d31308 100644
--- a/internal/upgrade/upgrade.go
+++ b/internal/upgrade/upgrade.go
@@ -152,9 +152,13 @@ func (u *Upgrader) fail(format string, args ...any) Result {
}
}
-// CheckLatest fetches the latest version from GitHub API.
+// CheckLatest fetches the latest version from GitHub API and updates the cache.
func CheckLatest(ctx context.Context) (string, error) {
- return fetchLatestVersion(ctx)
+ v, err := fetchLatestVersion(ctx)
+ if err == nil {
+ writeCachedVersion(v)
+ }
+ return v, err
}
// installBinary copies the new binary to the target path, preserving original permissions.
diff --git a/internal/upgrade/upgrade_test.go b/internal/upgrade/upgrade_test.go
index 2753005..b8805db 100644
--- a/internal/upgrade/upgrade_test.go
+++ b/internal/upgrade/upgrade_test.go
@@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "strings"
"testing"
)
@@ -305,3 +306,768 @@ func TestFetchLatestVersionMockServer(t *testing.T) {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
}
+
+// --- New tests below ---
+
+// swapHTTPClient replaces the package-level httpClient and returns a restore function.
+func swapHTTPClient(c *http.Client) func() {
+ orig := httpClient
+ httpClient = c
+ return func() { httpClient = orig }
+}
+
+// rewriteTransport redirects all requests to the given base URL,
+// preserving path and query.
+type rewriteTransport struct {
+ url string
+}
+
+func (rt *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ req.URL.Scheme = "http"
+ u := strings.TrimPrefix(rt.url, "http://")
+ req.URL.Host = u
+ return http.DefaultTransport.RoundTrip(req)
+}
+
+// createTarGz is a test helper that creates a tar.gz file with a single file entry.
+func createTarGz(t *testing.T, archivePath, entryName string, content []byte) {
+ t.Helper()
+ f, err := os.Create(archivePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gw := gzip.NewWriter(f)
+ tw := tar.NewWriter(gw)
+
+ tw.WriteHeader(&tar.Header{
+ Name: entryName,
+ Mode: 0o755,
+ Size: int64(len(content)),
+ })
+ tw.Write(content)
+
+ tw.Close()
+ gw.Close()
+ f.Close()
+}
+
+func TestArchiveNameTableDriven(t *testing.T) {
+ // We can only run archiveName for the current GOOS/GOARCH,
+ // so we test several version strings and verify the pattern.
+ tests := []struct {
+ version string
+ }{
+ {"0.1.0"},
+ {"1.0.0-rc1"},
+ {"2.5.10"},
+ {"0.0.0"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.version, func(t *testing.T) {
+ got := archiveName(tt.version)
+ prefix := fmt.Sprintf("unarr_%s_%s_%s.", tt.version, runtime.GOOS, runtime.GOARCH)
+ if !strings.HasPrefix(got, prefix) {
+ t.Errorf("archiveName(%q) = %q, want prefix %q", tt.version, got, prefix)
+ }
+ if runtime.GOOS == "windows" {
+ if !strings.HasSuffix(got, ".zip") {
+ t.Errorf("archiveName on windows should end with .zip, got %q", got)
+ }
+ } else {
+ if !strings.HasSuffix(got, ".tar.gz") {
+ t.Errorf("archiveName on non-windows should end with .tar.gz, got %q", got)
+ }
+ }
+ })
+ }
+}
+
+func TestReleaseURLEdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ version string
+ filename string
+ wantURL string
+ }{
+ {
+ name: "pre-release version",
+ version: "2.0.0-beta.1",
+ filename: "unarr_2.0.0-beta.1_darwin_arm64.tar.gz",
+ wantURL: "https://github.com/torrentclaw/unarr/releases/download/v2.0.0-beta.1/unarr_2.0.0-beta.1_darwin_arm64.tar.gz",
+ },
+ {
+ name: "checksums file",
+ version: "3.0.0",
+ filename: "checksums.txt",
+ wantURL: "https://github.com/torrentclaw/unarr/releases/download/v3.0.0/checksums.txt",
+ },
+ {
+ name: "windows zip",
+ version: "1.2.3",
+ filename: "unarr_1.2.3_windows_amd64.zip",
+ wantURL: "https://github.com/torrentclaw/unarr/releases/download/v1.2.3/unarr_1.2.3_windows_amd64.zip",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := releaseURL(tt.version, tt.filename)
+ if got != tt.wantURL {
+ t.Errorf("releaseURL(%q, %q) = %q, want %q", tt.version, tt.filename, got, tt.wantURL)
+ }
+ })
+ }
+}
+
+func TestExtractBinaryDispatcher(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("extractBinary dispatcher test only on unix")
+ }
+
+ dir := t.TempDir()
+
+ // Create a valid tar.gz with the unarr binary
+ archivePath := filepath.Join(dir, "test.tar.gz")
+ binaryContent := []byte("#!/bin/sh\necho dispatcher test\n")
+ createTarGz(t, archivePath, "unarr", binaryContent)
+
+ destDir := filepath.Join(dir, "out")
+ os.MkdirAll(destDir, 0o755)
+
+ binPath, err := extractBinary(archivePath, destDir)
+ if err != nil {
+ t.Fatalf("extractBinary() error = %v", err)
+ }
+ if filepath.Base(binPath) != "unarr" {
+ t.Errorf("extractBinary() returned %q, want base name 'unarr'", binPath)
+ }
+ data, _ := os.ReadFile(binPath)
+ if string(data) != string(binaryContent) {
+ t.Error("extractBinary() content mismatch")
+ }
+}
+
+func TestExtractBinaryInvalidArchive(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("tar.gz test only on unix")
+ }
+
+ dir := t.TempDir()
+ archivePath := filepath.Join(dir, "garbage.tar.gz")
+ os.WriteFile(archivePath, []byte("this is not a tar.gz"), 0o644)
+
+ destDir := filepath.Join(dir, "out")
+ os.MkdirAll(destDir, 0o755)
+
+ _, err := extractBinary(archivePath, destDir)
+ if err == nil {
+ t.Error("extractBinary with garbage data should return error")
+ }
+}
+
+func TestExtractBinaryNonExistentArchive(t *testing.T) {
+ dir := t.TempDir()
+ _, err := extractBinary("/nonexistent-archive-file", filepath.Join(dir, "out"))
+ if err == nil {
+ t.Error("extractBinary with nonexistent file should return error")
+ }
+}
+
+func TestUpgraderFail(t *testing.T) {
+ u := &Upgrader{CurrentVersion: "1.0.0"}
+
+ // Capture log messages
+ var logged []string
+ u.OnProgress = func(msg string) { logged = append(logged, msg) }
+
+ result := u.fail("something went wrong: %d", 42)
+
+ if result.Success {
+ t.Error("fail() should return Success=false")
+ }
+ if result.OldVersion != "1.0.0" {
+ t.Errorf("fail() OldVersion = %q, want 1.0.0", result.OldVersion)
+ }
+ if result.Error == nil {
+ t.Fatal("fail() should set Error")
+ }
+ if !strings.Contains(result.Error.Error(), "something went wrong: 42") {
+ t.Errorf("fail() Error = %q, want to contain 'something went wrong: 42'", result.Error)
+ }
+ if len(logged) == 0 {
+ t.Error("fail() should call OnProgress")
+ }
+ if len(logged) > 0 && !strings.Contains(logged[0], "FAILED") {
+ t.Errorf("fail() logged %q, want to contain 'FAILED'", logged[0])
+ }
+}
+
+func TestUpgraderFailNilOnProgress(t *testing.T) {
+ u := &Upgrader{CurrentVersion: "2.0.0"}
+ // OnProgress is nil — should not panic
+ result := u.fail("error without listener")
+ if result.Success {
+ t.Error("fail() should return Success=false")
+ }
+}
+
+func TestFetchLatestVersionWithHTTPTest(t *testing.T) {
+ tests := []struct {
+ name string
+ body string
+ statusCode int
+ wantVer string
+ wantErr bool
+ }{
+ {
+ name: "valid response",
+ body: `{"tag_name":"v3.1.4"}`,
+ statusCode: 200,
+ wantVer: "3.1.4",
+ },
+ {
+ name: "valid response without v prefix",
+ body: `{"tag_name":"2.0.0"}`,
+ statusCode: 200,
+ wantVer: "2.0.0",
+ },
+ {
+ name: "empty tag_name",
+ body: `{"tag_name":""}`,
+ statusCode: 200,
+ wantErr: true,
+ },
+ {
+ name: "server error",
+ body: `Internal Server Error`,
+ statusCode: 500,
+ wantErr: true,
+ },
+ {
+ name: "invalid json",
+ body: `{invalid`,
+ statusCode: 200,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(tt.statusCode)
+ fmt.Fprint(w, tt.body)
+ }))
+ defer srv.Close()
+
+ // Use a custom transport that rewrites requests to our test server
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ ver, err := CheckLatest(context.Background())
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("CheckLatest() error = nil, want error")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("CheckLatest() error = %v", err)
+ }
+ if ver != tt.wantVer {
+ t.Errorf("CheckLatest() = %q, want %q", ver, tt.wantVer)
+ }
+ })
+ }
+}
+
+func TestDownloadWithHTTPTest(t *testing.T) {
+ archiveBody := "fake-archive-bytes-for-download-test"
+
+ t.Run("successful download", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Verify user-agent header
+ if ua := r.Header.Get("User-Agent"); ua != "unarr-updater" {
+ t.Errorf("User-Agent = %q, want 'unarr-updater'", ua)
+ }
+ w.WriteHeader(200)
+ fmt.Fprint(w, archiveBody)
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ path, err := download(context.Background(), "1.0.0")
+ if err != nil {
+ t.Fatalf("download() error = %v", err)
+ }
+ defer os.Remove(path)
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("read downloaded file: %v", err)
+ }
+ if string(data) != archiveBody {
+ t.Errorf("downloaded content = %q, want %q", data, archiveBody)
+ }
+ })
+
+ t.Run("server returns 404", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(404)
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ _, err := download(context.Background(), "99.99.99")
+ if err == nil {
+ t.Error("download() with 404 should return error")
+ }
+ if !strings.Contains(err.Error(), "HTTP 404") {
+ t.Errorf("download() error = %q, want to contain 'HTTP 404'", err)
+ }
+ })
+
+ t.Run("cancelled context", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ fmt.Fprint(w, "data")
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel() // cancel immediately
+
+ _, err := download(ctx, "1.0.0")
+ if err == nil {
+ t.Error("download() with cancelled context should return error")
+ }
+ })
+}
+
+func TestVerifyChecksumWithHTTPTest(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("tar.gz test only on unix")
+ }
+
+ // Create a fake archive file
+ dir := t.TempDir()
+ archiveContent := []byte("archive-content-for-checksum-test")
+ archivePath := filepath.Join(dir, "test-archive.tar.gz")
+ os.WriteFile(archivePath, archiveContent, 0o644)
+
+ h := sha256.Sum256(archiveContent)
+ correctHash := hex.EncodeToString(h[:])
+
+ // The function builds the archive name using archiveName(), which uses runtime.GOOS/GOARCH.
+ expectedArchiveName := archiveName("1.0.0")
+
+ t.Run("matching checksum", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "0000000000000000000000000000000000000000000000000000000000000000 other_file.tar.gz\n")
+ fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName)
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ err := verifyChecksum(context.Background(), "1.0.0", archivePath)
+ if err != nil {
+ t.Errorf("verifyChecksum() = %v, want nil", err)
+ }
+ })
+
+ t.Run("mismatched checksum", func(t *testing.T) {
+ wrongHash := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "%s %s\n", wrongHash, expectedArchiveName)
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ err := verifyChecksum(context.Background(), "1.0.0", archivePath)
+ if err == nil {
+ t.Error("verifyChecksum() with wrong hash should return error")
+ }
+ if !strings.Contains(err.Error(), "SHA256 mismatch") {
+ t.Errorf("verifyChecksum() error = %q, want to contain 'SHA256 mismatch'", err)
+ }
+ })
+
+ t.Run("archive not in checksums", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890 some_other_file.tar.gz\n")
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ err := verifyChecksum(context.Background(), "1.0.0", archivePath)
+ if err == nil {
+ t.Error("verifyChecksum() with missing entry should return error")
+ }
+ if !strings.Contains(err.Error(), "no checksum found") {
+ t.Errorf("verifyChecksum() error = %q, want to contain 'no checksum found'", err)
+ }
+ })
+
+ t.Run("checksums server error", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(500)
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ err := verifyChecksum(context.Background(), "1.0.0", archivePath)
+ if err == nil {
+ t.Error("verifyChecksum() with server error should return error")
+ }
+ if !strings.Contains(err.Error(), "HTTP 500") {
+ t.Errorf("verifyChecksum() error = %q, want to contain 'HTTP 500'", err)
+ }
+ })
+
+ t.Run("nonexistent archive file", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName)
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ err := verifyChecksum(context.Background(), "1.0.0", "/nonexistent-archive-path")
+ if err == nil {
+ t.Error("verifyChecksum() with nonexistent archive should return error")
+ }
+ })
+}
+
+func TestVerifyChecksumCaseInsensitive(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("tar.gz test only on unix")
+ }
+
+ dir := t.TempDir()
+ archiveContent := []byte("case-insensitive-hash-test")
+ archivePath := filepath.Join(dir, "archive.tar.gz")
+ os.WriteFile(archivePath, archiveContent, 0o644)
+
+ h := sha256.Sum256(archiveContent)
+ // Use uppercase hash in checksums.txt — verifyChecksum uses EqualFold
+ upperHash := strings.ToUpper(hex.EncodeToString(h[:]))
+ expectedArchiveName := archiveName("1.0.0")
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "%s %s\n", upperHash, expectedArchiveName)
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ err := verifyChecksum(context.Background(), "1.0.0", archivePath)
+ if err != nil {
+ t.Errorf("verifyChecksum() with uppercase hash = %v, want nil", err)
+ }
+}
+
+func TestExtractTarGzNestedDirectory(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("tar.gz test only on unix")
+ }
+
+ dir := t.TempDir()
+ archivePath := filepath.Join(dir, "nested.tar.gz")
+
+ binaryContent := []byte("#!/bin/sh\necho nested\n")
+
+ f, err := os.Create(archivePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gw := gzip.NewWriter(f)
+ tw := tar.NewWriter(gw)
+
+ // Write a directory entry first
+ tw.WriteHeader(&tar.Header{
+ Name: "unarr_1.0.0_linux_amd64/",
+ Typeflag: tar.TypeDir,
+ Mode: 0o755,
+ })
+
+ // Write a README in the subdirectory
+ readmeContent := []byte("This is a README")
+ tw.WriteHeader(&tar.Header{
+ Name: "unarr_1.0.0_linux_amd64/README.md",
+ Mode: 0o644,
+ Size: int64(len(readmeContent)),
+ })
+ tw.Write(readmeContent)
+
+ // Write the binary nested inside the directory
+ tw.WriteHeader(&tar.Header{
+ Name: "unarr_1.0.0_linux_amd64/unarr",
+ Mode: 0o755,
+ Size: int64(len(binaryContent)),
+ })
+ tw.Write(binaryContent)
+
+ // Write another unrelated file after the binary
+ licenseContent := []byte("MIT License")
+ tw.WriteHeader(&tar.Header{
+ Name: "unarr_1.0.0_linux_amd64/LICENSE",
+ Mode: 0o644,
+ Size: int64(len(licenseContent)),
+ })
+ tw.Write(licenseContent)
+
+ tw.Close()
+ gw.Close()
+ f.Close()
+
+ destDir := filepath.Join(dir, "out")
+ os.MkdirAll(destDir, 0o755)
+
+ binPath, err := extractTarGz(archivePath, destDir)
+ if err != nil {
+ t.Fatalf("extractTarGz() with nested dir = %v", err)
+ }
+
+ if filepath.Base(binPath) != "unarr" {
+ t.Errorf("extracted name = %q, want 'unarr'", filepath.Base(binPath))
+ }
+
+ data, _ := os.ReadFile(binPath)
+ if string(data) != string(binaryContent) {
+ t.Error("extracted content does not match")
+ }
+
+ // Verify that only the binary was extracted (README and LICENSE should NOT be in destDir)
+ entries, _ := os.ReadDir(destDir)
+ if len(entries) != 1 {
+ names := make([]string, len(entries))
+ for i, e := range entries {
+ names[i] = e.Name()
+ }
+ t.Errorf("destDir should contain only 'unarr', got %v", names)
+ }
+}
+
+func TestExtractTarGzMultipleFiles(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("tar.gz test only on unix")
+ }
+
+ dir := t.TempDir()
+ archivePath := filepath.Join(dir, "multi.tar.gz")
+
+ binaryContent := []byte("#!/bin/sh\necho multi\n")
+
+ f, err := os.Create(archivePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gw := gzip.NewWriter(f)
+ tw := tar.NewWriter(gw)
+
+ // Several non-binary files before the actual binary
+ for _, name := range []string{"README.md", "LICENSE", "config.yaml", "completions.bash"} {
+ content := []byte("content of " + name)
+ tw.WriteHeader(&tar.Header{
+ Name: name,
+ Mode: 0o644,
+ Size: int64(len(content)),
+ })
+ tw.Write(content)
+ }
+
+ // The actual binary
+ tw.WriteHeader(&tar.Header{
+ Name: "unarr",
+ Mode: 0o755,
+ Size: int64(len(binaryContent)),
+ })
+ tw.Write(binaryContent)
+
+ tw.Close()
+ gw.Close()
+ f.Close()
+
+ destDir := filepath.Join(dir, "out")
+ os.MkdirAll(destDir, 0o755)
+
+ binPath, err := extractTarGz(archivePath, destDir)
+ if err != nil {
+ t.Fatalf("extractTarGz() = %v", err)
+ }
+
+ data, _ := os.ReadFile(binPath)
+ if string(data) != string(binaryContent) {
+ t.Error("binary content mismatch among multiple files")
+ }
+}
+
+func TestExtractTarGzSymlinkSkipped(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("tar.gz test only on unix")
+ }
+
+ dir := t.TempDir()
+ archivePath := filepath.Join(dir, "symlink.tar.gz")
+
+ f, err := os.Create(archivePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gw := gzip.NewWriter(f)
+ tw := tar.NewWriter(gw)
+
+ // Write a symlink entry named "unarr" — should be skipped because Typeflag != TypeReg
+ tw.WriteHeader(&tar.Header{
+ Name: "unarr",
+ Typeflag: tar.TypeSymlink,
+ Linkname: "/etc/passwd",
+ Mode: 0o755,
+ })
+
+ tw.Close()
+ gw.Close()
+ f.Close()
+
+ destDir := filepath.Join(dir, "out")
+ os.MkdirAll(destDir, 0o755)
+
+ _, err = extractTarGz(archivePath, destDir)
+ if err == nil {
+ t.Error("extractTarGz() should return error when binary is a symlink (not TypeReg)")
+ }
+ if !strings.Contains(err.Error(), "not found in archive") {
+ t.Errorf("error = %q, want to contain 'not found in archive'", err)
+ }
+}
+
+func TestInstallBinaryNonExistentSource(t *testing.T) {
+ dir := t.TempDir()
+ dst := filepath.Join(dir, "output")
+
+ err := installBinary("/nonexistent-source-binary", dst)
+ if err == nil {
+ t.Error("installBinary with nonexistent source should return error")
+ }
+ if !strings.Contains(err.Error(), "read new binary") {
+ t.Errorf("error = %q, want to contain 'read new binary'", err)
+ }
+
+ // Verify destination was not created
+ if _, statErr := os.Stat(dst); statErr == nil {
+ t.Error("destination file should not exist after failed install")
+ }
+}
+
+func TestInstallBinaryUnwritableDestination(t *testing.T) {
+ dir := t.TempDir()
+ src := filepath.Join(dir, "source")
+ os.WriteFile(src, []byte("binary"), 0o755)
+
+ // Try to write to a path inside a non-existent directory
+ dst := filepath.Join(dir, "nonexistent-subdir", "binary")
+
+ err := installBinary(src, dst)
+ if err == nil {
+ t.Error("installBinary to non-writable destination should return error")
+ }
+ if !strings.Contains(err.Error(), "write binary") {
+ t.Errorf("error = %q, want to contain 'write binary'", err)
+ }
+}
+
+func TestUpgraderLog(t *testing.T) {
+ var messages []string
+ u := &Upgrader{
+ CurrentVersion: "1.0.0",
+ OnProgress: func(msg string) { messages = append(messages, msg) },
+ }
+
+ u.log("hello world")
+ if len(messages) != 1 || messages[0] != "hello world" {
+ t.Errorf("log() messages = %v, want [hello world]", messages)
+ }
+}
+
+func TestUpgraderLogNilOnProgress(t *testing.T) {
+ u := &Upgrader{CurrentVersion: "1.0.0"}
+ // Should not panic
+ u.log("test message with nil OnProgress")
+}
+
+func TestResultFields(t *testing.T) {
+ r := Result{
+ Success: true,
+ OldVersion: "1.0.0",
+ NewVersion: "2.0.0",
+ BackupPath: "/tmp/backup",
+ }
+ if !r.Success || r.OldVersion != "1.0.0" || r.NewVersion != "2.0.0" || r.BackupPath != "/tmp/backup" {
+ t.Errorf("Result fields not set correctly: %+v", r)
+ }
+
+ r2 := Result{Success: false, Error: fmt.Errorf("test error")}
+ if r2.Success || r2.Error == nil {
+ t.Errorf("Result error case not correct: %+v", r2)
+ }
+}
+
+func TestDownloadSetsUserAgent(t *testing.T) {
+ var gotUA string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotUA = r.Header.Get("User-Agent")
+ w.WriteHeader(200)
+ fmt.Fprint(w, "data")
+ }))
+ defer srv.Close()
+
+ restore := swapHTTPClient(&http.Client{
+ Transport: &rewriteTransport{url: srv.URL},
+ })
+ defer restore()
+
+ path, err := download(context.Background(), "1.0.0")
+ if err != nil {
+ t.Fatalf("download() = %v", err)
+ }
+ defer os.Remove(path)
+
+ if gotUA != "unarr-updater" {
+ t.Errorf("User-Agent = %q, want 'unarr-updater'", gotUA)
+ }
+}
diff --git a/internal/usenet/download/progress_expand_test.go b/internal/usenet/download/progress_expand_test.go
new file mode 100644
index 0000000..0ce1b4f
--- /dev/null
+++ b/internal/usenet/download/progress_expand_test.go
@@ -0,0 +1,632 @@
+package download
+
+import (
+ "encoding/binary"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/torrentclaw/unarr/internal/usenet/nzb"
+)
+
+// --- Fingerprint ---
+
+func TestFingerprint_EmptyNZB(t *testing.T) {
+ n := &nzb.NZB{}
+ fp := Fingerprint(n)
+ // Empty NZB should still produce a deterministic hash (of zero message IDs).
+ fp2 := Fingerprint(n)
+ if fp != fp2 {
+ t.Fatal("fingerprint of empty NZB should be deterministic")
+ }
+}
+
+func TestFingerprint_OrderIndependent(t *testing.T) {
+ // Fingerprint sorts IDs, so different file order should produce the same hash.
+ n1 := &nzb.NZB{
+ Files: []nzb.File{
+ {Segments: []nzb.Segment{{MessageID: "a@x"}, {MessageID: "b@x"}}},
+ {Segments: []nzb.Segment{{MessageID: "c@x"}}},
+ },
+ }
+ n2 := &nzb.NZB{
+ Files: []nzb.File{
+ {Segments: []nzb.Segment{{MessageID: "c@x"}}},
+ {Segments: []nzb.Segment{{MessageID: "b@x"}, {MessageID: "a@x"}}},
+ },
+ }
+ if Fingerprint(n1) != Fingerprint(n2) {
+ t.Fatal("fingerprint should be order-independent (sorted by message ID)")
+ }
+}
+
+func TestFingerprint_SingleSegment(t *testing.T) {
+ n := &nzb.NZB{
+ Files: []nzb.File{
+ {Segments: []nzb.Segment{{MessageID: "only@one"}}},
+ },
+ }
+ fp := Fingerprint(n)
+ if fp == [32]byte{} {
+ t.Fatal("fingerprint should not be zero for a non-empty NZB")
+ }
+}
+
+// --- ProgressTracker MarkDone idempotency ---
+
+func TestProgressTracker_MarkDoneIdempotent(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 5)
+ tracker := NewProgressTracker("idem", n, dir)
+
+ tracker.MarkDone(0, 2)
+ if tracker.CompletedSegments(0) != 1 {
+ t.Fatalf("expected 1, got %d", tracker.CompletedSegments(0))
+ }
+
+ // Mark the same segment again — count should not increase.
+ tracker.MarkDone(0, 2)
+ if tracker.CompletedSegments(0) != 1 {
+ t.Fatalf("idempotent mark: expected 1, got %d", tracker.CompletedSegments(0))
+ }
+}
+
+// --- ProgressTracker TotalCompleted ---
+
+func TestProgressTracker_TotalCompleted(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(3, 4) // 3 files, 4 segs each
+ tracker := NewProgressTracker("total", n, dir)
+
+ tracker.MarkDone(0, 0)
+ tracker.MarkDone(0, 1)
+ tracker.MarkDone(1, 3)
+ tracker.MarkDone(2, 0)
+ tracker.MarkDone(2, 1)
+ tracker.MarkDone(2, 2)
+
+ if got := tracker.TotalCompleted(); got != 6 {
+ t.Errorf("TotalCompleted: got %d, want 6", got)
+ }
+}
+
+func TestProgressTracker_TotalCompleted_Empty(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(2, 3)
+ tracker := NewProgressTracker("empty-total", n, dir)
+
+ if got := tracker.TotalCompleted(); got != 0 {
+ t.Errorf("TotalCompleted on fresh tracker: got %d, want 0", got)
+ }
+}
+
+// --- CompletedBytes edge cases ---
+
+func TestProgressTracker_CompletedBytes_OutOfBounds(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("cb-oob", n, dir)
+
+ if got := tracker.CompletedBytes(-1, n.Files[0].Segments); got != 0 {
+ t.Errorf("CompletedBytes with file -1: got %d, want 0", got)
+ }
+ if got := tracker.CompletedBytes(5, n.Files[0].Segments); got != 0 {
+ t.Errorf("CompletedBytes with file 5: got %d, want 0", got)
+ }
+}
+
+func TestProgressTracker_CompletedBytes_AllDone(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("cb-all", n, dir)
+
+ for i := 0; i < 3; i++ {
+ tracker.MarkDone(0, i)
+ }
+
+ got := tracker.CompletedBytes(0, n.Files[0].Segments)
+ expected := int64(3 * 750 * 1024)
+ if got != expected {
+ t.Errorf("CompletedBytes all done: got %d, want %d", got, expected)
+ }
+}
+
+// --- CompletedSegments out of bounds ---
+
+func TestProgressTracker_CompletedSegments_OutOfBounds(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("cs-oob", n, dir)
+
+ if got := tracker.CompletedSegments(-1); got != 0 {
+ t.Errorf("CompletedSegments(-1) = %d, want 0", got)
+ }
+ if got := tracker.CompletedSegments(99); got != 0 {
+ t.Errorf("CompletedSegments(99) = %d, want 0", got)
+ }
+}
+
+// --- Load with corrupted / truncated data ---
+
+func TestProgressTracker_Load_TruncatedHeader(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("trunc", n, dir)
+
+ // Write too-short data
+ os.WriteFile(tracker.progressPath(), []byte("UNR"), 0o644)
+
+ loaded, err := tracker.Load()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if loaded {
+ t.Error("truncated header should not load")
+ }
+}
+
+func TestProgressTracker_Load_BadMagic(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("badmagic", n, dir)
+
+ // Write data with wrong magic bytes
+ data := make([]byte, headerSize+10)
+ copy(data[0:4], []byte("BAAD"))
+ os.WriteFile(tracker.progressPath(), data, 0o644)
+
+ loaded, err := tracker.Load()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if loaded {
+ t.Error("bad magic should not load")
+ }
+}
+
+func TestProgressTracker_Load_BadVersion(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("badver", n, dir)
+
+ data := make([]byte, headerSize+10)
+ copy(data[0:4], progressMagic[:])
+ data[4] = 99 // unsupported version
+ os.WriteFile(tracker.progressPath(), data, 0o644)
+
+ loaded, err := tracker.Load()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if loaded {
+ t.Error("bad version should not load")
+ }
+}
+
+func TestProgressTracker_Load_WrongFileCount(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(2, 3)
+ tracker := NewProgressTracker("wrongfc", n, dir)
+
+ data := make([]byte, headerSize+20)
+ copy(data[0:4], progressMagic[:])
+ data[4] = progressVersion
+ binary.LittleEndian.PutUint16(data[6:8], 99) // wrong file count
+ copy(data[8:40], tracker.fingerprint[:])
+ os.WriteFile(tracker.progressPath(), data, 0o644)
+
+ loaded, err := tracker.Load()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if loaded {
+ t.Error("wrong file count should not load")
+ }
+}
+
+func TestProgressTracker_Load_TruncatedBitset(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 16)
+ tracker := NewProgressTracker("truncbit", n, dir)
+
+ // Build a valid header but truncate the bitset data
+ data := make([]byte, headerSize+4) // header + segCount but no bitset
+ copy(data[0:4], progressMagic[:])
+ data[4] = progressVersion
+ binary.LittleEndian.PutUint16(data[6:8], 1) // 1 file
+ copy(data[8:40], tracker.fingerprint[:])
+ binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 16) // 16 segs
+ // No bitset data follows — truncated
+ os.WriteFile(tracker.progressPath(), data, 0o644)
+
+ loaded, err := tracker.Load()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if loaded {
+ t.Error("truncated bitset should not load")
+ }
+}
+
+func TestProgressTracker_Load_SegCountMismatch(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 5)
+ tracker := NewProgressTracker("segmis", n, dir)
+
+ // Build valid header with correct file count and fingerprint, but wrong segCount
+ bitsetLen := (999 + 7) / 8
+ data := make([]byte, headerSize+4+bitsetLen)
+ copy(data[0:4], progressMagic[:])
+ data[4] = progressVersion
+ binary.LittleEndian.PutUint16(data[6:8], 1)
+ copy(data[8:40], tracker.fingerprint[:])
+ binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 999) // wrong seg count
+ os.WriteFile(tracker.progressPath(), data, 0o644)
+
+ loaded, err := tracker.Load()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if loaded {
+ t.Error("segment count mismatch should not load")
+ }
+}
+
+func TestProgressTracker_Load_NonexistentFile(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("nofile", n, dir)
+
+ loaded, err := tracker.Load()
+ if err != nil {
+ t.Fatalf("unexpected error for missing file: %v", err)
+ }
+ if loaded {
+ t.Error("nonexistent file should return false")
+ }
+}
+
+// --- Flush and Load round-trip with multiple files ---
+
+func TestProgressTracker_MultiFileRoundTrip(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(5, 10) // 5 files, 10 segments each
+ tracker := NewProgressTracker("multi-rt", n, dir)
+
+ // Mark various segments across files
+ tracker.MarkDone(0, 0)
+ tracker.MarkDone(0, 9)
+ tracker.MarkDone(1, 5)
+ tracker.MarkDone(2, 0)
+ tracker.MarkDone(2, 1)
+ tracker.MarkDone(2, 2)
+ tracker.MarkDone(2, 3)
+ tracker.MarkDone(2, 4)
+ tracker.MarkDone(2, 5)
+ tracker.MarkDone(2, 6)
+ tracker.MarkDone(2, 7)
+ tracker.MarkDone(2, 8)
+ tracker.MarkDone(2, 9) // file 2 fully done
+ tracker.MarkDone(4, 7)
+
+ if err := tracker.Flush(); err != nil {
+ t.Fatalf("flush: %v", err)
+ }
+
+ // Reload
+ tracker2 := NewProgressTracker("multi-rt", n, dir)
+ loaded, err := tracker2.Load()
+ if err != nil {
+ t.Fatalf("load: %v", err)
+ }
+ if !loaded {
+ t.Fatal("should load")
+ }
+
+ if tracker2.CompletedSegments(0) != 2 {
+ t.Errorf("file 0: got %d, want 2", tracker2.CompletedSegments(0))
+ }
+ if tracker2.CompletedSegments(1) != 1 {
+ t.Errorf("file 1: got %d, want 1", tracker2.CompletedSegments(1))
+ }
+ if !tracker2.IsFileDone(2) {
+ t.Error("file 2 should be done")
+ }
+ if tracker2.CompletedSegments(3) != 0 {
+ t.Errorf("file 3: got %d, want 0", tracker2.CompletedSegments(3))
+ }
+ if tracker2.CompletedSegments(4) != 1 {
+ t.Errorf("file 4: got %d, want 1", tracker2.CompletedSegments(4))
+ }
+ if tracker2.TotalCompleted() != 14 {
+ t.Errorf("TotalCompleted: got %d, want 14", tracker2.TotalCompleted())
+ }
+}
+
+// --- Concurrent mark + IsDone reads ---
+
+func TestProgressTracker_ConcurrentMarkAndRead(t *testing.T) {
+ dir := t.TempDir()
+ segCount := 500
+ n := makeTestNZB(2, segCount)
+ tracker := NewProgressTracker("conc-rw", n, dir)
+
+ // Use separate WaitGroups for writers and readers
+ var writerWg sync.WaitGroup
+ stop := make(chan struct{})
+
+ // Writers
+ for file := 0; file < 2; file++ {
+ for seg := 0; seg < segCount; seg++ {
+ writerWg.Add(1)
+ go func(f, s int) {
+ defer writerWg.Done()
+ tracker.MarkDone(f, s)
+ }(file, seg)
+ }
+ }
+
+ // Readers — continuously read while writes happen
+ var readerWg sync.WaitGroup
+ for i := 0; i < 4; i++ {
+ readerWg.Add(1)
+ go func() {
+ defer readerWg.Done()
+ for {
+ select {
+ case <-stop:
+ return
+ default:
+ // These should never panic
+ tracker.IsDone(0, 0)
+ tracker.IsDone(1, segCount-1)
+ tracker.IsFileDone(0)
+ tracker.CompletedSegments(1)
+ tracker.TotalCompleted()
+ }
+ }
+ }()
+ }
+
+ // Wait for all writers to finish, then stop readers
+ writerWg.Wait()
+ close(stop)
+ readerWg.Wait()
+
+ // After all goroutines complete, everything should be done
+ if !tracker.IsFileDone(0) {
+ t.Errorf("file 0 should be done, got %d/%d", tracker.CompletedSegments(0), segCount)
+ }
+ if !tracker.IsFileDone(1) {
+ t.Errorf("file 1 should be done, got %d/%d", tracker.CompletedSegments(1), segCount)
+ }
+}
+
+// --- Concurrent flush safety ---
+
+func TestProgressTracker_ConcurrentFlush(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 100)
+ tracker := NewProgressTracker("conc-flush", n, dir)
+
+ // Mark some segments
+ for i := 0; i < 50; i++ {
+ tracker.MarkDone(0, i)
+ }
+
+ // Multiple concurrent flushes should not panic or corrupt
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ tracker.Flush()
+ }()
+ }
+ wg.Wait()
+
+ // Verify state is loadable
+ tracker2 := NewProgressTracker("conc-flush", n, dir)
+ loaded, err := tracker2.Load()
+ if err != nil {
+ t.Fatalf("load: %v", err)
+ }
+ if !loaded {
+ t.Fatal("should load after concurrent flushes")
+ }
+ if tracker2.CompletedSegments(0) != 50 {
+ t.Errorf("after concurrent flush: got %d, want 50", tracker2.CompletedSegments(0))
+ }
+}
+
+// --- Remove with .tmp file ---
+
+func TestProgressTracker_Remove_WithTmpFile(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("rm-tmp", n, dir)
+
+ // Create all three files that Remove should clean up
+ os.WriteFile(tracker.progressPath(), []byte("data"), 0o644)
+ os.WriteFile(tracker.nzbPath(), []byte(""), 0o644)
+ os.WriteFile(tracker.progressPath()+".tmp", []byte("tmp"), 0o644)
+
+ tracker.Remove()
+
+ for _, p := range []string{tracker.progressPath(), tracker.nzbPath(), tracker.progressPath() + ".tmp"} {
+ if _, err := os.Stat(p); !os.IsNotExist(err) {
+ t.Errorf("file should be removed: %s", p)
+ }
+ }
+}
+
+// --- CleanStaleFiles edge cases ---
+
+func TestCleanStaleFiles_EmptyDir(t *testing.T) {
+ dir := t.TempDir()
+ if got := CleanStaleFiles(dir, time.Hour); got != 0 {
+ t.Errorf("empty dir: got %d removed, want 0", got)
+ }
+}
+
+func TestCleanStaleFiles_NonexistentDir(t *testing.T) {
+ if got := CleanStaleFiles("/nonexistent/path/that/does/not/exist", time.Hour); got != 0 {
+ t.Errorf("nonexistent dir: got %d removed, want 0", got)
+ }
+}
+
+func TestCleanStaleFiles_AllFresh(t *testing.T) {
+ dir := t.TempDir()
+ os.WriteFile(filepath.Join(dir, "a.progress"), []byte("a"), 0o644)
+ os.WriteFile(filepath.Join(dir, "b.progress"), []byte("b"), 0o644)
+
+ if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 {
+ t.Errorf("all fresh: got %d removed, want 0", got)
+ }
+}
+
+func TestCleanStaleFiles_SkipsSubdirs(t *testing.T) {
+ dir := t.TempDir()
+ subDir := filepath.Join(dir, "subdir")
+ os.MkdirAll(subDir, 0o755)
+
+ // Backdate the subdir (it should not be removed)
+ os.Chtimes(subDir, fixedPast, fixedPast)
+
+ if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 {
+ t.Errorf("should skip subdirs: got %d removed, want 0", got)
+ }
+ if _, err := os.Stat(subDir); err != nil {
+ t.Error("subdir should still exist")
+ }
+}
+
+func TestCleanStaleFiles_MixedAges(t *testing.T) {
+ dir := t.TempDir()
+
+ stale1 := filepath.Join(dir, "old1.progress")
+ stale2 := filepath.Join(dir, "old2.nzb")
+ fresh := filepath.Join(dir, "new.progress")
+
+ os.WriteFile(stale1, []byte("x"), 0o644)
+ os.WriteFile(stale2, []byte("x"), 0o644)
+ os.WriteFile(fresh, []byte("x"), 0o644)
+
+ os.Chtimes(stale1, fixedPast, fixedPast)
+ os.Chtimes(stale2, fixedPast, fixedPast)
+
+ if got := CleanStaleFiles(dir, 7*24*time.Hour); got != 2 {
+ t.Errorf("mixed ages: got %d removed, want 2", got)
+ }
+ if _, err := os.Stat(fresh); err != nil {
+ t.Error("fresh file should still exist")
+ }
+}
+
+// --- progressPath / nzbPath ---
+
+func TestProgressTracker_Paths(t *testing.T) {
+ dir := "/some/dir"
+ n := makeTestNZB(1, 1)
+ tracker := NewProgressTracker("my-task", n, dir)
+
+ if got := tracker.progressPath(); got != filepath.Join(dir, "my-task.progress") {
+ t.Errorf("progressPath: got %q", got)
+ }
+ if got := tracker.nzbPath(); got != filepath.Join(dir, "my-task.nzb") {
+ t.Errorf("nzbPath: got %q", got)
+ }
+}
+
+// --- formatBytes ---
+
+func TestFormatBytes(t *testing.T) {
+ tests := []struct {
+ input int64
+ want string
+ }{
+ {0, "0 B"},
+ {500, "500 B"},
+ {1023, "1023 B"},
+ {1024, "1.0 KB"},
+ {1536, "1.5 KB"},
+ {1048576, "1.0 MB"},
+ {1073741824, "1.0 GB"},
+ {1099511627776, "1.0 TB"},
+ }
+ for _, tt := range tests {
+ got := formatBytes(tt.input)
+ if got != tt.want {
+ t.Errorf("formatBytes(%d) = %q, want %q", tt.input, got, tt.want)
+ }
+ }
+}
+
+// --- Single file with 1 segment (boundary) ---
+
+func TestProgressTracker_SingleSegment(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 1)
+ tracker := NewProgressTracker("single-seg", n, dir)
+
+ if tracker.IsFileDone(0) {
+ t.Error("should not be done initially")
+ }
+
+ tracker.MarkDone(0, 0)
+
+ if !tracker.IsFileDone(0) {
+ t.Error("should be done after marking the only segment")
+ }
+
+ if err := tracker.Flush(); err != nil {
+ t.Fatalf("flush: %v", err)
+ }
+
+ tracker2 := NewProgressTracker("single-seg", n, dir)
+ loaded, _ := tracker2.Load()
+ if !loaded {
+ t.Fatal("should load")
+ }
+ if !tracker2.IsFileDone(0) {
+ t.Error("should be done after reload")
+ }
+}
+
+// --- Flush creates directory if missing ---
+
+func TestProgressTracker_FlushCreatesDir(t *testing.T) {
+ base := t.TempDir()
+ dir := filepath.Join(base, "nested", "resume")
+ n := makeTestNZB(1, 2)
+ tracker := NewProgressTracker("mkdir-test", n, dir)
+
+ tracker.MarkDone(0, 0)
+ if err := tracker.Flush(); err != nil {
+ t.Fatalf("flush should create dir: %v", err)
+ }
+
+ if _, err := os.Stat(tracker.progressPath()); err != nil {
+ t.Fatalf("progress file should exist: %v", err)
+ }
+}
+
+// --- Double flush after no new marks ---
+
+func TestProgressTracker_DoubleFlush(t *testing.T) {
+ dir := t.TempDir()
+ n := makeTestNZB(1, 3)
+ tracker := NewProgressTracker("dbl-flush", n, dir)
+
+ tracker.MarkDone(0, 0)
+ if err := tracker.Flush(); err != nil {
+ t.Fatalf("first flush: %v", err)
+ }
+
+ // Second flush without new marks should be a no-op (dirty=false)
+ if err := tracker.Flush(); err != nil {
+ t.Fatalf("second flush: %v", err)
+ }
+}
diff --git a/internal/usenet/nntp/client_test.go b/internal/usenet/nntp/client_test.go
new file mode 100644
index 0000000..41e61e3
--- /dev/null
+++ b/internal/usenet/nntp/client_test.go
@@ -0,0 +1,131 @@
+package nntp
+
+import (
+ "bufio"
+ "bytes"
+ "testing"
+)
+
+func TestNewClient(t *testing.T) {
+ c := NewClient(Config{Host: "news.example.com", Port: 563, SSL: true})
+ if c.cfg.MaxConnections != 10 {
+ t.Errorf("default MaxConnections = %d, want 10", c.cfg.MaxConnections)
+ }
+ if c.cfg.Host != "news.example.com" {
+ t.Errorf("Host = %q", c.cfg.Host)
+ }
+}
+
+func TestNewClientCustomConnections(t *testing.T) {
+ c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 20})
+ if c.cfg.MaxConnections != 20 {
+ t.Errorf("MaxConnections = %d, want 20", c.cfg.MaxConnections)
+ }
+}
+
+func TestNewClientZeroConnections(t *testing.T) {
+ c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 0})
+ if c.cfg.MaxConnections != 10 {
+ t.Errorf("MaxConnections should default to 10, got %d", c.cfg.MaxConnections)
+ }
+}
+
+func TestNewClientNegativeConnections(t *testing.T) {
+ c := NewClient(Config{MaxConnections: -5})
+ if c.cfg.MaxConnections != 10 {
+ t.Errorf("MaxConnections should default to 10 for negative, got %d", c.cfg.MaxConnections)
+ }
+}
+
+func TestActiveConnections(t *testing.T) {
+ c := NewClient(Config{Host: "localhost", Port: 119})
+ if c.ActiveConnections() != 0 {
+ t.Errorf("ActiveConnections = %d, want 0", c.ActiveConnections())
+ }
+}
+
+func TestStatus(t *testing.T) {
+ c := NewClient(Config{Host: "news.example.com", Port: 563})
+ s := c.Status()
+ if s != "0 connections (0 pooled) to news.example.com:563" {
+ t.Errorf("Status = %q", s)
+ }
+}
+
+func TestCloseIdempotent(t *testing.T) {
+ c := NewClient(Config{Host: "localhost", Port: 119})
+ // Close should be idempotent
+ if err := c.Close(); err != nil {
+ t.Errorf("first Close: %v", err)
+ }
+ if err := c.Close(); err != nil {
+ t.Errorf("second Close: %v", err)
+ }
+}
+
+func TestArticleNotFoundError(t *testing.T) {
+ err := &ArticleNotFoundError{MessageID: "abc123@news.example.com"}
+ msg := err.Error()
+ if msg != "nntp: article not found: abc123@news.example.com" {
+ t.Errorf("Error() = %q", msg)
+ }
+}
+
+func TestReadDotBody(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {
+ "simple body",
+ "Hello World\r\n.\r\n",
+ "Hello World\n",
+ },
+ {
+ "multiline",
+ "Line 1\r\nLine 2\r\nLine 3\r\n.\r\n",
+ "Line 1\nLine 2\nLine 3\n",
+ },
+ {
+ "dot-stuffed line",
+ "..This starts with a dot\r\n.\r\n",
+ ".This starts with a dot\n",
+ },
+ {
+ "empty body",
+ ".\r\n",
+ "",
+ },
+ {
+ "binary-like data",
+ "=ybegin line=128 size=1024 name=test.bin\r\nsome encoded data\r\n=yend\r\n.\r\n",
+ "=ybegin line=128 size=1024 name=test.bin\nsome encoded data\n=yend\n",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := bufio.NewReader(bytes.NewBufferString(tt.input))
+ got, err := readDotBody(r)
+ if err != nil {
+ t.Fatalf("readDotBody: %v", err)
+ }
+ if string(got) != tt.want {
+ t.Errorf("readDotBody = %q, want %q", string(got), tt.want)
+ }
+ })
+ }
+}
+
+func TestReadDotBodyEOF(t *testing.T) {
+ // No dot terminator — should read until EOF
+ r := bufio.NewReader(bytes.NewBufferString("partial data\r\n"))
+ got, err := readDotBody(r)
+ if err != nil {
+ t.Fatalf("readDotBody EOF: %v", err)
+ }
+ if string(got) != "partial data\n" {
+ t.Errorf("readDotBody EOF = %q", string(got))
+ }
+}
diff --git a/internal/usenet/nzb/parser_test.go b/internal/usenet/nzb/parser_test.go
index 8d0d686..6afeef0 100644
--- a/internal/usenet/nzb/parser_test.go
+++ b/internal/usenet/nzb/parser_test.go
@@ -267,3 +267,784 @@ func TestStripAngleBrackets(t *testing.T) {
t.Errorf("MessageID not stripped: got %q", nzb.Files[0].Segments[0].MessageID)
}
}
+
+// --- Malformed / edge-case XML inputs ---
+
+func TestParse_CompletelyEmpty(t *testing.T) {
+ _, err := Parse(strings.NewReader(""))
+ if err == nil {
+ t.Error("expected error for completely empty input")
+ }
+}
+
+func TestParse_OnlyWhitespace(t *testing.T) {
+ _, err := Parse(strings.NewReader(" \n\t "))
+ if err == nil {
+ t.Error("expected error for whitespace-only input")
+ }
+}
+
+func TestParse_ValidXMLButNotNZB(t *testing.T) {
+ _, err := Parse(strings.NewReader(`Hello`))
+ if err == nil {
+ t.Error("expected error for non-NZB XML")
+ }
+}
+
+func TestParse_NZBWithNoSegments(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+
+`
+ _, err := Parse(strings.NewReader(xml))
+ if err == nil {
+ t.Error("expected error for file with no segments")
+ }
+}
+
+func TestParse_SegmentWithEmptyMessageID(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+
+
+
+`
+ _, err := Parse(strings.NewReader(xml))
+ if err == nil {
+ t.Error("expected error: segment with empty/whitespace message ID should be skipped, leaving no valid files")
+ }
+}
+
+func TestParse_MixedValidAndEmptySegments(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+ valid@id
+
+ also-valid@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if len(nzb.Files[0].Segments) != 2 {
+ t.Errorf("expected 2 valid segments, got %d", len(nzb.Files[0].Segments))
+ }
+}
+
+// --- Metadata / Head parsing ---
+
+func TestParse_MetaPassword(t *testing.T) {
+ xml := `
+
+
+ s3cr3t
+ My Movie
+ Movies
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if nzb.Password != "s3cr3t" {
+ t.Errorf("Password: got %q, want %q", nzb.Password, "s3cr3t")
+ }
+ if nzb.Meta["title"] != "My Movie" {
+ t.Errorf("Meta title: got %q", nzb.Meta["title"])
+ }
+ if nzb.Meta["category"] != "Movies" {
+ t.Errorf("Meta category: got %q", nzb.Meta["category"])
+ }
+}
+
+func TestParse_MetaPasswordWithWhitespace(t *testing.T) {
+ xml := `
+
+
+ padded
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if nzb.Password != "padded" {
+ t.Errorf("Password should be trimmed: got %q", nzb.Password)
+ }
+}
+
+func TestParse_NoHead(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if nzb.Password != "" {
+ t.Errorf("Password should be empty: got %q", nzb.Password)
+ }
+ if len(nzb.Meta) != 0 {
+ t.Errorf("Meta should be empty: got %v", nzb.Meta)
+ }
+}
+
+func TestParse_MetaWithEmptyType(t *testing.T) {
+ xml := `
+
+
+ ignored
+ kept
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if _, ok := nzb.Meta[""]; ok {
+ t.Error("empty-type meta should not be stored")
+ }
+ if nzb.Meta["name"] != "kept" {
+ t.Errorf("Meta name: got %q", nzb.Meta["name"])
+ }
+}
+
+// --- Multiple files ---
+
+func TestParse_MultipleFilesVariousTypes(t *testing.T) {
+ xml := `
+
+
+ alt.binaries.movies
+
+ mkv001@ex
+ mkv002@ex
+
+
+
+ alt.binaries.movies
+
+ nfo001@ex
+
+
+
+ alt.binaries.movies
+
+ par001@ex
+
+
+
+ alt.binaries.movies
+
+ parv001@ex
+
+
+
+ alt.binaries.movies
+
+ sample001@ex
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+
+ if len(nzb.Files) != 5 {
+ t.Fatalf("expected 5 files, got %d", len(nzb.Files))
+ }
+
+ // ContentFiles should exclude nfo, par2, par2 vol, and sample
+ content := nzb.ContentFiles()
+ if len(content) != 1 {
+ t.Errorf("ContentFiles: got %d, want 1", len(content))
+ }
+ if len(content) > 0 && content[0].Filename() != "movie.mkv" {
+ t.Errorf("ContentFiles[0]: got %q, want movie.mkv", content[0].Filename())
+ }
+
+ // Par2Files
+ par2 := nzb.Par2Files()
+ if len(par2) != 2 {
+ t.Errorf("Par2Files: got %d, want 2", len(par2))
+ }
+
+ if !nzb.HasPar2() {
+ t.Error("HasPar2 should be true")
+ }
+ if nzb.HasRars() {
+ t.Error("HasRars should be false for this NZB")
+ }
+}
+
+// --- Segment ordering / number parsing ---
+
+func TestParse_SegmentNumberParsing(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+ c@id
+ a@id
+ b@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+
+ segs := nzb.Files[0].Segments
+ if len(segs) != 3 {
+ t.Fatalf("expected 3 segments, got %d", len(segs))
+ }
+
+ // Parse preserves order from XML; sorting is done by the downloader
+ // Verify numbers are parsed correctly
+ numbers := make(map[int]bool)
+ for _, s := range segs {
+ numbers[s.Number] = true
+ }
+ for _, want := range []int{1, 2, 3} {
+ if !numbers[want] {
+ t.Errorf("missing segment number %d", want)
+ }
+ }
+}
+
+func TestParse_SegmentBytesZero(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if nzb.Files[0].Segments[0].Bytes != 0 {
+ t.Errorf("expected 0 bytes, got %d", nzb.Files[0].Segments[0].Bytes)
+ }
+}
+
+func TestParse_SegmentBytesNonNumeric(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ // Non-numeric bytes should parse as 0
+ if nzb.Files[0].Segments[0].Bytes != 0 {
+ t.Errorf("non-numeric bytes should be 0, got %d", nzb.Files[0].Segments[0].Bytes)
+ }
+}
+
+// --- File helper methods ---
+
+func TestFileTotalBytes(t *testing.T) {
+ f := File{
+ Segments: []Segment{
+ {Bytes: 100}, {Bytes: 200}, {Bytes: 300},
+ },
+ }
+ if got := f.TotalBytes(); got != 600 {
+ t.Errorf("TotalBytes: got %d, want 600", got)
+ }
+}
+
+func TestFileTotalBytes_Empty(t *testing.T) {
+ f := File{}
+ if got := f.TotalBytes(); got != 0 {
+ t.Errorf("TotalBytes of empty file: got %d, want 0", got)
+ }
+}
+
+func TestFileExtension_Various(t *testing.T) {
+ tests := []struct {
+ subject string
+ want string
+ }{
+ {`"file.MKV" yEnc`, ".mkv"},
+ {`"file.RAR" yEnc`, ".rar"},
+ {`"file.Par2" yEnc`, ".par2"},
+ {`"noext" yEnc`, ""},
+ {`"file.tar.gz" yEnc`, ".gz"},
+ }
+ for _, tt := range tests {
+ f := File{Subject: tt.subject}
+ if got := f.Extension(); got != tt.want {
+ t.Errorf("Extension(%q) = %q, want %q", tt.subject, got, tt.want)
+ }
+ }
+}
+
+// --- LargestFile edge cases ---
+
+func TestLargestFile_EmptyNZB(t *testing.T) {
+ nzb := &NZB{}
+ if nzb.LargestFile() != nil {
+ t.Error("LargestFile should return nil for empty NZB")
+ }
+}
+
+func TestLargestFile_SingleFile(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"only.bin"`, Segments: []Segment{{Bytes: 100}}},
+ },
+ }
+ largest := nzb.LargestFile()
+ if largest == nil {
+ t.Fatal("LargestFile should not be nil")
+ }
+ if largest.Filename() != "only.bin" {
+ t.Errorf("got %q", largest.Filename())
+ }
+}
+
+func TestLargestFile_MultipleSameSize(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"a.bin"`, Segments: []Segment{{Bytes: 100}}},
+ {Subject: `"b.bin"`, Segments: []Segment{{Bytes: 100}}},
+ },
+ }
+ largest := nzb.LargestFile()
+ if largest == nil {
+ t.Fatal("LargestFile should not be nil")
+ }
+ // Should return the first one (stable)
+ if largest.Filename() != "a.bin" {
+ t.Errorf("got %q, expected first file for equal sizes", largest.Filename())
+ }
+}
+
+// --- IsObfuscated ---
+
+func TestIsObfuscated_Normal(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"Movie.2024.1080p.BluRay.x264-GROUP.mkv"`},
+ },
+ }
+ if nzb.IsObfuscated() {
+ t.Error("normal filename should not be obfuscated")
+ }
+}
+
+func TestIsObfuscated_HexName(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"a1b2c3d4e5f6a7b8c9d0e1f2.mkv"`},
+ },
+ }
+ if !nzb.IsObfuscated() {
+ t.Error("hex-like filename should be obfuscated")
+ }
+}
+
+func TestIsObfuscated_EmptyFiles(t *testing.T) {
+ nzb := &NZB{}
+ if nzb.IsObfuscated() {
+ t.Error("empty NZB should not be obfuscated")
+ }
+}
+
+func TestIsObfuscated_ShortHex(t *testing.T) {
+ // Short name (<=10 chars) should not trigger obfuscation
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"abcdef.mkv"`},
+ },
+ }
+ if nzb.IsObfuscated() {
+ t.Error("short hex-like name should not be obfuscated")
+ }
+}
+
+// --- isMetadataFile ---
+
+func TestIsMetadataFile(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"file.par2", true},
+ {"file.nfo", true},
+ {"file.sfv", true},
+ {"file.nzb", true},
+ {"file.txt", true},
+ {"file.jpg", true},
+ {"file.png", true},
+ {"file.url", true},
+ {"file.mkv", false},
+ {"file.rar", false},
+ {"file.avi", false},
+ {"FILE.PAR2", true},
+ {"FILE.NFO", true},
+ }
+ for _, tt := range tests {
+ if got := isMetadataFile(tt.name); got != tt.want {
+ t.Errorf("isMetadataFile(%q) = %v, want %v", tt.name, got, tt.want)
+ }
+ }
+}
+
+// --- isSampleFile ---
+
+func TestIsSampleFile(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"movie.sample.mkv", true},
+ {"Sample.mkv", true},
+ {"SAMPLE.avi", true},
+ {"movie-sample-video.mkv", true},
+ {"movie_sample.mkv", true},
+ {"sample.mkv", true},
+ {"resampled.mkv", false}, // "sample" is part of "resampled"
+ {"movie.mkv", false},
+ {"my.samples.zip", false}, // "sample" followed by 's' (alphanumeric)
+ }
+ for _, tt := range tests {
+ if got := isSampleFile(tt.name); got != tt.want {
+ t.Errorf("isSampleFile(%q) = %v, want %v", tt.name, got, tt.want)
+ }
+ }
+}
+
+// --- isHexLike ---
+
+func TestIsHexLike(t *testing.T) {
+ tests := []struct {
+ input string
+ want bool
+ }{
+ {"abcdef0123456789", true},
+ {"ABCDEF", true},
+ {"Movie2024", false},
+ {"aabbccdd", true},
+ {"xyz_not_hex", false},
+ }
+ for _, tt := range tests {
+ if got := isHexLike(tt.input); got != tt.want {
+ t.Errorf("isHexLike(%q) = %v, want %v", tt.input, got, tt.want)
+ }
+ }
+}
+
+// --- sanitizeFilename ---
+
+func TestSanitizeFilename(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"simple name", "simple name"},
+ {"name (1/50)", "name"},
+ {"file yEnc (01/99)", "file"},
+ {`path/with\special:chars*?`, `path_with_special_chars__`},
+ {`"quoted" text`, `_quoted_ text`},
+ {" spaces ", "spaces"},
+ }
+ for _, tt := range tests {
+ if got := sanitizeFilename(tt.input); got != tt.want {
+ t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ }
+}
+
+// --- Filename fallback ---
+
+func TestFilename_Fallback_NoQuotes(t *testing.T) {
+ f := File{Subject: "No quotes here yEnc (1/50)"}
+ got := f.Filename()
+ if got != "No quotes here" {
+ t.Errorf("Filename fallback: got %q, want %q", got, "No quotes here")
+ }
+}
+
+func TestFilename_EmptySubject(t *testing.T) {
+ f := File{Subject: ""}
+ got := f.Filename()
+ if got != "" {
+ t.Errorf("Filename empty subject: got %q, want empty", got)
+ }
+}
+
+// --- NZB aggregate methods on mixed content ---
+
+func TestNZB_HasRars_NoRars(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"movie.mkv"`},
+ {Subject: `"movie.par2"`},
+ },
+ }
+ if nzb.HasRars() {
+ t.Error("HasRars should be false")
+ }
+}
+
+func TestNZB_HasPar2_NoPar2(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"movie.mkv"`},
+ {Subject: `"movie.rar"`},
+ },
+ }
+ if nzb.HasPar2() {
+ t.Error("HasPar2 should be false")
+ }
+}
+
+func TestNZB_TotalSegments_MultiFile(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Segments: []Segment{{}, {}, {}}},
+ {Segments: []Segment{{}, {}}},
+ },
+ }
+ if got := nzb.TotalSegments(); got != 5 {
+ t.Errorf("TotalSegments: got %d, want 5", got)
+ }
+}
+
+func TestNZB_TotalBytes_MultiFile(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Segments: []Segment{{Bytes: 100}, {Bytes: 200}}},
+ {Segments: []Segment{{Bytes: 300}}},
+ },
+ }
+ if got := nzb.TotalBytes(); got != 600 {
+ t.Errorf("TotalBytes: got %d, want 600", got)
+ }
+}
+
+// --- isRarFile extended ---
+
+func TestIsRarFile_Extended(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"file.RAR", true}, // case insensitive
+ {"file.Rar", true},
+ {"file.s01", true},
+ {"file.s99", true},
+ {"file.002", true},
+ {"file.999", true},
+ {"file.r0", false}, // too short extension
+ {"file.rXX", false}, // non-numeric
+ {"file", false}, // no extension
+ {"file.mp4", false},
+ }
+ for _, tt := range tests {
+ if got := isRarFile(tt.name); got != tt.want {
+ t.Errorf("isRarFile(%q) = %v, want %v", tt.name, got, tt.want)
+ }
+ }
+}
+
+// --- Parse with date edge cases ---
+
+func TestParse_DateNonNumeric(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if nzb.Files[0].Date != 0 {
+ t.Errorf("non-numeric date should be 0, got %d", nzb.Files[0].Date)
+ }
+}
+
+func TestParse_DateEmpty(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if nzb.Files[0].Date != 0 {
+ t.Errorf("empty date should be 0, got %d", nzb.Files[0].Date)
+ }
+}
+
+// --- Parse: file with all segments having empty IDs should be excluded ---
+
+func TestParse_AllEmptySegments(t *testing.T) {
+ xml := `
+
+
+ alt.test
+
+
+
+
+
+
+ alt.test
+
+ valid@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if len(nzb.Files) != 1 {
+ t.Fatalf("expected 1 valid file, got %d", len(nzb.Files))
+ }
+ if nzb.Files[0].Filename() != "good.bin" {
+ t.Errorf("expected good.bin, got %q", nzb.Files[0].Filename())
+ }
+}
+
+// --- Groups ---
+
+func TestParse_NoGroups(t *testing.T) {
+ xml := `
+
+
+
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if len(nzb.Files[0].Groups) != 0 {
+ t.Errorf("expected 0 groups, got %d", len(nzb.Files[0].Groups))
+ }
+}
+
+func TestParse_MultipleGroups(t *testing.T) {
+ xml := `
+
+
+
+ alt.binaries.movies
+ alt.binaries.multimedia
+ alt.binaries.hdtv
+
+
+ seg@id
+
+
+`
+ nzb, err := Parse(strings.NewReader(xml))
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+ if len(nzb.Files[0].Groups) != 3 {
+ t.Errorf("expected 3 groups, got %d", len(nzb.Files[0].Groups))
+ }
+}
+
+// --- ContentFiles with sample variations ---
+
+func TestContentFiles_ExcludesSamples(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"movie.mkv"`, Segments: []Segment{{Bytes: 1000, MessageID: "a"}}},
+ {Subject: `"movie.sample.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "b"}}},
+ {Subject: `"Sample/preview.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "c"}}},
+ },
+ }
+ content := nzb.ContentFiles()
+ if len(content) != 1 {
+ t.Errorf("ContentFiles should exclude samples: got %d, want 1", len(content))
+ }
+}
+
+// --- RarFiles with split naming ---
+
+func TestRarFiles_SplitRars(t *testing.T) {
+ nzb := &NZB{
+ Files: []File{
+ {Subject: `"movie.rar"`, Segments: []Segment{{MessageID: "a"}}},
+ {Subject: `"movie.r00"`, Segments: []Segment{{MessageID: "b"}}},
+ {Subject: `"movie.r01"`, Segments: []Segment{{MessageID: "c"}}},
+ {Subject: `"movie.001"`, Segments: []Segment{{MessageID: "d"}}},
+ {Subject: `"movie.002"`, Segments: []Segment{{MessageID: "e"}}},
+ {Subject: `"movie.par2"`, Segments: []Segment{{MessageID: "f"}}},
+ {Subject: `"movie.mkv"`, Segments: []Segment{{MessageID: "g"}}},
+ },
+ }
+ rars := nzb.RarFiles()
+ if len(rars) != 5 {
+ t.Errorf("RarFiles: got %d, want 5", len(rars))
+ }
+}
diff --git a/internal/usenet/postprocess/extract_test.go b/internal/usenet/postprocess/extract_test.go
new file mode 100644
index 0000000..5e61c75
--- /dev/null
+++ b/internal/usenet/postprocess/extract_test.go
@@ -0,0 +1,170 @@
+package postprocess
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestIsArchiveFile(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"movie.rar", true},
+ {"movie.RAR", true},
+ {"movie.part01.rar", true},
+ {"movie.r00", true},
+ {"movie.r99", true},
+ {"movie.s00", true},
+ {"movie.001", true},
+ {"movie.099", true},
+ {"movie.mkv", false},
+ {"movie.mp4", false},
+ {"movie.par2", false},
+ {"movie.nfo", false},
+ {"movie.txt", false},
+ {"movie.r", false},
+ {"movie.abc", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isArchiveFile(tt.name)
+ if got != tt.want {
+ t.Errorf("isArchiveFile(%q) = %v, want %v", tt.name, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsCleanupTarget(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"content.par2", true},
+ {"content.PAR2", true},
+ {"info.nfo", true},
+ {"checksum.sfv", true},
+ {"content.nzb", true},
+ {"content.srr", true},
+ {"content.srs", true},
+ {"cover.jpg", true},
+ {"cover.png", true},
+ {"readme.txt", true},
+ {"link.url", true},
+ {"movie.rar", true},
+ {"movie.r00", true},
+ {"movie.s01", true},
+ {"movie.001", true},
+ {"movie.mkv", false},
+ {"movie.mp4", false},
+ {"movie.avi", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isCleanupTarget(tt.name)
+ if got != tt.want {
+ t.Errorf("isCleanupTarget(%q) = %v, want %v", tt.name, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsNumeric(t *testing.T) {
+ tests := []struct {
+ input string
+ want bool
+ }{
+ {"", false},
+ {"0", true},
+ {"123", true},
+ {"00", true},
+ {"12a", false},
+ {"abc", false},
+ {" 1", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := isNumeric(tt.input)
+ if got != tt.want {
+ t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestListExtractedFiles(t *testing.T) {
+ dir := t.TempDir()
+
+ // Create some files
+ os.WriteFile(filepath.Join(dir, "movie.mkv"), []byte("video"), 0o644)
+ os.WriteFile(filepath.Join(dir, "subs.srt"), []byte("subs"), 0o644)
+ os.WriteFile(filepath.Join(dir, "movie.rar"), []byte("archive"), 0o644)
+ os.WriteFile(filepath.Join(dir, "movie.r00"), []byte("archive part"), 0o644)
+
+ archivePath := filepath.Join(dir, "movie.rar")
+ files, err := listExtractedFiles(dir, archivePath)
+ if err != nil {
+ t.Fatalf("listExtractedFiles: %v", err)
+ }
+
+ // Should exclude .rar and .r00 (archive files in same dir)
+ // Should include movie.mkv and subs.srt
+ if len(files) != 2 {
+ t.Errorf("expected 2 files, got %d: %v", len(files), files)
+ }
+
+ for _, f := range files {
+ base := filepath.Base(f)
+ if base != "movie.mkv" && base != "subs.srt" {
+ t.Errorf("unexpected file: %s", base)
+ }
+ }
+}
+
+func TestCleanup(t *testing.T) {
+ dir := t.TempDir()
+
+ // Files that should be removed
+ cleanupFiles := []string{"content.par2", "info.nfo", "checksum.sfv", "movie.rar", "movie.r00"}
+ for _, name := range cleanupFiles {
+ os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644)
+ }
+
+ // Files that should be kept
+ keepFiles := []string{"movie.mkv", "subs.srt"}
+ for _, name := range keepFiles {
+ os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644)
+ }
+
+ err := Cleanup(dir)
+ if err != nil {
+ t.Fatalf("Cleanup: %v", err)
+ }
+
+ // Verify cleanup files are gone
+ for _, name := range cleanupFiles {
+ if _, err := os.Stat(filepath.Join(dir, name)); !os.IsNotExist(err) {
+ t.Errorf("expected %s to be removed", name)
+ }
+ }
+
+ // Verify kept files still exist
+ for _, name := range keepFiles {
+ if _, err := os.Stat(filepath.Join(dir, name)); err != nil {
+ t.Errorf("expected %s to exist, got: %v", name, err)
+ }
+ }
+}
+
+func TestPasswordError(t *testing.T) {
+ err := &PasswordError{Archive: "/tmp/movie.rar"}
+ msg := err.Error()
+ if msg != "archive is password protected: /tmp/movie.rar" {
+ t.Errorf("PasswordError.Error() = %q", msg)
+ }
+}
diff --git a/internal/usenet/postprocess/pipeline_test.go b/internal/usenet/postprocess/pipeline_test.go
new file mode 100644
index 0000000..f1a6f26
--- /dev/null
+++ b/internal/usenet/postprocess/pipeline_test.go
@@ -0,0 +1,156 @@
+package postprocess
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestFindPar2File(t *testing.T) {
+ dir := t.TempDir()
+
+ // Create par2 files of different sizes
+ mainPar2 := filepath.Join(dir, "content.par2")
+ vol1 := filepath.Join(dir, "content.vol000+01.par2")
+ vol2 := filepath.Join(dir, "content.vol001+02.par2")
+
+ os.WriteFile(mainPar2, make([]byte, 100), 0o644) // smallest
+ os.WriteFile(vol1, make([]byte, 10000), 0o644)
+ os.WriteFile(vol2, make([]byte, 50000), 0o644)
+
+ files := map[string]string{
+ "content.par2": mainPar2,
+ "content.vol000+01.par2": vol1,
+ "content.vol001+02.par2": vol2,
+ }
+
+ result := findPar2File(files)
+ if result != mainPar2 {
+ t.Errorf("findPar2File() = %q, want %q (smallest par2)", result, mainPar2)
+ }
+}
+
+func TestFindPar2FileNone(t *testing.T) {
+ files := map[string]string{
+ "video.mkv": "/tmp/video.mkv",
+ "subs.srt": "/tmp/subs.srt",
+ }
+
+ result := findPar2File(files)
+ if result != "" {
+ t.Errorf("findPar2File() = %q, want empty", result)
+ }
+}
+
+func TestFindPar2FileEmpty(t *testing.T) {
+ result := findPar2File(map[string]string{})
+ if result != "" {
+ t.Errorf("findPar2File() = %q, want empty", result)
+ }
+}
+
+func TestFindFirstRarPart01(t *testing.T) {
+ files := map[string]string{
+ "movie.part01.rar": "/tmp/movie.part01.rar",
+ "movie.part02.rar": "/tmp/movie.part02.rar",
+ "movie.part03.rar": "/tmp/movie.part03.rar",
+ }
+
+ result := findFirstRar(files)
+ if result != "/tmp/movie.part01.rar" {
+ t.Errorf("findFirstRar() = %q, want part01.rar", result)
+ }
+}
+
+func TestFindFirstRarSingle(t *testing.T) {
+ files := map[string]string{
+ "movie.rar": "/tmp/movie.rar",
+ "movie.r00": "/tmp/movie.r00",
+ "movie.r01": "/tmp/movie.r01",
+ }
+
+ result := findFirstRar(files)
+ if result != "/tmp/movie.rar" {
+ t.Errorf("findFirstRar() = %q, want movie.rar (shortest)", result)
+ }
+}
+
+func TestFindFirstRarSplitFormat(t *testing.T) {
+ files := map[string]string{
+ "movie.001": "/tmp/movie.001",
+ "movie.002": "/tmp/movie.002",
+ }
+
+ result := findFirstRar(files)
+ if result != "/tmp/movie.001" {
+ t.Errorf("findFirstRar() = %q, want movie.001", result)
+ }
+}
+
+func TestFindFirstRarNone(t *testing.T) {
+ files := map[string]string{
+ "video.mkv": "/tmp/video.mkv",
+ "subs.srt": "/tmp/subs.srt",
+ }
+
+ result := findFirstRar(files)
+ if result != "" {
+ t.Errorf("findFirstRar() = %q, want empty", result)
+ }
+}
+
+func TestFindMainFile(t *testing.T) {
+ dir := t.TempDir()
+
+ // Create video files of different sizes
+ small := filepath.Join(dir, "small.mkv")
+ large := filepath.Join(dir, "large.mkv")
+ nonVideo := filepath.Join(dir, "readme.txt")
+
+ os.WriteFile(small, make([]byte, 1000), 0o644)
+ os.WriteFile(large, make([]byte, 5000), 0o644)
+ os.WriteFile(nonVideo, make([]byte, 9000), 0o644)
+
+ result := findMainFile(dir, []string{small, large, nonVideo})
+ if result != large {
+ t.Errorf("findMainFile() = %q, want %q (largest video)", result, large)
+ }
+}
+
+func TestFindMainFileFallbackToDir(t *testing.T) {
+ dir := t.TempDir()
+
+ video := filepath.Join(dir, "movie.mp4")
+ os.WriteFile(video, make([]byte, 5000), 0o644)
+
+ // Pass empty file list — should fallback to scanning dir
+ result := findMainFile(dir, nil)
+ if result != video {
+ t.Errorf("findMainFile() = %q, want %q (dir scan fallback)", result, video)
+ }
+}
+
+func TestFindMainFileEmpty(t *testing.T) {
+ dir := t.TempDir()
+ result := findMainFile(dir, nil)
+ if result != "" {
+ t.Errorf("findMainFile() = %q, want empty", result)
+ }
+}
+
+func TestFindMainFileMultipleFormats(t *testing.T) {
+ dir := t.TempDir()
+
+ mkv := filepath.Join(dir, "movie.mkv")
+ mp4 := filepath.Join(dir, "movie.mp4")
+ avi := filepath.Join(dir, "movie.avi")
+
+ os.WriteFile(mkv, make([]byte, 3000), 0o644)
+ os.WriteFile(mp4, make([]byte, 5000), 0o644) // largest
+ os.WriteFile(avi, make([]byte, 2000), 0o644)
+
+ result := findMainFile(dir, []string{mkv, mp4, avi})
+ if result != mp4 {
+ t.Errorf("findMainFile() = %q, want %q", result, mp4)
+ }
+}