From 3e0f3a5a64d5bfc0dc98b0246ecd33142d0faa73 Mon Sep 17 00:00:00 2001 From: Deivid Soto Date: Tue, 31 Mar 2026 22:05:43 +0200 Subject: [PATCH] feat(cli): upgrade command, rich status, and version cache - Replace `upgrade` stub with real command (alias for `self-update`) - Also register `update` as alias: `unarr update` works too - Rewrite `status` to show full config, disk usage, daemon state, and update availability with colored sections - Add version check cache (1h TTL) so `status` is instant on repeat runs - Guard against division by zero on empty filesystems - Guard against negative durations from clock skew - Guard against stale PID via heartbeat recency check (2 min) - Add comprehensive test coverage across agent, engine, upgrade, usenet, arr, library, mediaserver, and UI packages - Improve Makefile coverage target to exclude cmd/ glue code - Fix stream handler resource cleanup and ffprobe error handling --- Makefile | 9 +- internal/agent/client_test.go | 262 ++++ internal/agent/transport_test.go | 1145 +++++++++++++++++ internal/arr/client_test.go | 396 ++++++ internal/arr/discovery_test.go | 158 +++ internal/cmd/config_menu_test.go | 55 + internal/cmd/daemon_test.go | 55 + internal/cmd/helpers_test.go | 43 + internal/cmd/root.go | 3 +- internal/cmd/status.go | 166 ++- internal/cmd/stream_handler.go | 53 +- internal/cmd/upgrade.go | 30 + internal/engine/manager_test.go | 221 ++++ internal/engine/method_test.go | 50 + internal/engine/organize_expand_test.go | 181 +++ internal/engine/progress_test.go | 419 ++++++ internal/engine/stream_server.go | 72 +- internal/library/mediainfo/ffprobe.go | 6 + internal/library/mediainfo/ffprobe_test.go | 430 +++++++ internal/library/scanner_test.go | 93 ++ internal/library/sync_test.go | 108 ++ internal/mediaserver/detect_test.go | 92 ++ internal/sentry/sentry_test.go | 47 + internal/ui/format_test.go | 214 ++- internal/ui/table_test.go | 122 ++ internal/upgrade/cache.go | 75 ++ internal/upgrade/upgrade.go | 8 +- internal/upgrade/upgrade_test.go | 766 +++++++++++ .../usenet/download/progress_expand_test.go | 632 +++++++++ internal/usenet/nntp/client_test.go | 131 ++ internal/usenet/nzb/parser_test.go | 781 +++++++++++ internal/usenet/postprocess/extract_test.go | 170 +++ internal/usenet/postprocess/pipeline_test.go | 156 +++ 33 files changed, 7084 insertions(+), 65 deletions(-) create mode 100644 internal/arr/client_test.go create mode 100644 internal/cmd/config_menu_test.go create mode 100644 internal/cmd/daemon_test.go create mode 100644 internal/cmd/helpers_test.go create mode 100644 internal/cmd/upgrade.go create mode 100644 internal/engine/method_test.go create mode 100644 internal/engine/organize_expand_test.go create mode 100644 internal/engine/progress_test.go create mode 100644 internal/library/mediainfo/ffprobe_test.go create mode 100644 internal/library/scanner_test.go create mode 100644 internal/library/sync_test.go create mode 100644 internal/sentry/sentry_test.go create mode 100644 internal/ui/table_test.go create mode 100644 internal/upgrade/cache.go create mode 100644 internal/usenet/download/progress_expand_test.go create mode 100644 internal/usenet/nntp/client_test.go create mode 100644 internal/usenet/postprocess/extract_test.go create mode 100644 internal/usenet/postprocess/pipeline_test.go diff --git a/Makefile b/Makefile index 4a8245e..6207d50 100644 --- a/Makefile +++ b/Makefile @@ -19,10 +19,13 @@ test: lint: golangci-lint run ./... -## Run tests with coverage report +## Run tests with coverage report (excludes CLI layer — cmd/ is glue code) +COVER_PKGS = $(shell go list ./... | grep -v '/cmd') coverage: - go test -race -coverprofile=coverage.out -covermode=atomic ./... - go tool cover -func=coverage.out + go test -race -coverprofile=coverage.out -covermode=atomic $(COVER_PKGS) + @echo "──────────────────────────────────────" + @go tool cover -func=coverage.out | tail -1 + @echo "──────────────────────────────────────" go tool cover -html=coverage.out -o coverage.html ## Format code diff --git a/internal/agent/client_test.go b/internal/agent/client_test.go index 9266b74..c8ce68d 100644 --- a/internal/agent/client_test.go +++ b/internal/agent/client_test.go @@ -324,3 +324,265 @@ func TestHeartbeatWithoutUpgradeSignal(t *testing.T) { t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade) } } + +func TestDeregister(t *testing.T) { + var received struct { + AgentID string `json:"agentId"` + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/deregister" { + t.Errorf("path = %s", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("method = %s, want POST", r.Method) + } + json.NewDecoder(r.Body).Decode(&received) + json.NewEncoder(w).Encode(StatusResponse{Success: true}) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + err := c.Deregister(context.Background(), "agent-42") + if err != nil { + t.Fatalf("Deregister failed: %v", err) + } + if received.AgentID != "agent-42" { + t.Errorf("agentId = %q, want agent-42", received.AgentID) + } +} + +func TestBatchReportStatus(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/status" { + t.Errorf("path = %s", r.URL.Path) + } + var req BatchStatusRequest + json.NewDecoder(r.Body).Decode(&req) + if len(req.Updates) != 2 { + t.Errorf("expected 2 updates, got %d", len(req.Updates)) + } + json.NewEncoder(w).Encode(BatchStatusResponse{ + Results: []StatusResponse{ + {Success: true}, + {Success: true, Cancelled: true}, + }, + }) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + resp, err := c.BatchReportStatus(context.Background(), []StatusUpdate{ + {TaskID: "t1", Status: "downloading"}, + {TaskID: "t2", Status: "completed"}, + }) + if err != nil { + t.Fatalf("BatchReportStatus failed: %v", err) + } + if len(resp.Results) != 2 { + t.Fatalf("expected 2 results, got %d", len(resp.Results)) + } + if !resp.Results[1].Cancelled { + t.Error("expected result[1].Cancelled=true") + } +} + +func TestSearchNzbs(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/nzb-search" { + t.Errorf("path = %s", r.URL.Path) + } + json.NewEncoder(w).Encode(NzbSearchResponse{ + Results: []NzbSearchResult{ + {NzbID: "nzb-1", Title: "Movie.2023.1080p"}, + }, + }) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + resp, err := c.SearchNzbs(context.Background(), NzbSearchParams{Query: "Movie"}) + if err != nil { + t.Fatalf("SearchNzbs failed: %v", err) + } + if len(resp.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(resp.Results)) + } + if resp.Results[0].NzbID != "nzb-1" { + t.Errorf("nzb ID = %q, want nzb-1", resp.Results[0].NzbID) + } +} + +func TestDownloadNzb(t *testing.T) { + nzbContent := []byte(`test`) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/nzb-download" { + t.Errorf("path = %s", r.URL.Path) + } + if r.URL.Query().Get("nzbId") != "nzb-42" { + t.Errorf("nzbId = %q, want nzb-42", r.URL.Query().Get("nzbId")) + } + w.Write(nzbContent) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + data, err := c.DownloadNzb(context.Background(), "nzb-42") + if err != nil { + t.Fatalf("DownloadNzb failed: %v", err) + } + if string(data) != string(nzbContent) { + t.Errorf("nzb content mismatch") + } +} + +func TestDownloadNzbError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("NZB not found")) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + _, err := c.DownloadNzb(context.Background(), "bad-id") + if err == nil { + t.Fatal("expected error for 404 response") + } +} + +func TestGetUsenetCredentials(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/usenet-credentials" { + t.Errorf("path = %s", r.URL.Path) + } + json.NewEncoder(w).Encode(UsenetCredentials{ + Host: "news.example.com", + Port: 563, + SSL: true, + Username: "user1", + Password: "pass1", + MaxConnections: 10, + }) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + creds, err := c.GetUsenetCredentials(context.Background()) + if err != nil { + t.Fatalf("GetUsenetCredentials failed: %v", err) + } + if creds.Host != "news.example.com" { + t.Errorf("host = %q, want news.example.com", creds.Host) + } + if creds.Username != "user1" { + t.Errorf("username = %q, want user1", creds.Username) + } + if creds.MaxConnections != 10 { + t.Errorf("maxConnections = %d, want 10", creds.MaxConnections) + } +} + +func TestGetUsenetUsage(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/usenet-usage" { + t.Errorf("path = %s", r.URL.Path) + } + json.NewEncoder(w).Encode(UsenetUsageResponse{ + UsedBytes: 5368709120, + QuotaBytes: 10737418240, + }) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + usage, err := c.GetUsenetUsage(context.Background()) + if err != nil { + t.Fatalf("GetUsenetUsage failed: %v", err) + } + if usage.UsedBytes != 5368709120 { + t.Errorf("usedBytes = %d", usage.UsedBytes) + } +} + +func TestConfigureDebrid(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/debrid-config" { + t.Errorf("path = %s", r.URL.Path) + } + json.NewEncoder(w).Encode(ConfigureDebridResponse{Success: true}) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + resp, err := c.ConfigureDebrid(context.Background(), ConfigureDebridRequest{ + Provider: "real-debrid", + Token: "rd-token-123", + }) + if err != nil { + t.Fatalf("ConfigureDebrid failed: %v", err) + } + if !resp.Success { + t.Error("expected success=true") + } +} + +func TestBatchDownload(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/batch-download" { + t.Errorf("path = %s", r.URL.Path) + } + json.NewEncoder(w).Encode(BatchDownloadResponse{ + Queued: 3, + NotFound: 1, + }) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + resp, err := c.BatchDownload(context.Background(), BatchDownloadRequest{}) + if err != nil { + t.Fatalf("BatchDownload failed: %v", err) + } + if resp.Queued != 3 { + t.Errorf("queued = %d, want 3", resp.Queued) + } +} + +func TestSyncLibrary(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/library-sync" { + t.Errorf("path = %s", r.URL.Path) + } + json.NewEncoder(w).Encode(LibrarySyncResponse{ + Matched: 10, + Synced: 15, + Removed: 2, + }) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + resp, err := c.SyncLibrary(context.Background(), LibrarySyncRequest{}) + if err != nil { + t.Fatalf("SyncLibrary failed: %v", err) + } + if resp.Matched != 10 { + t.Errorf("matched = %d, want 10", resp.Matched) + } + if resp.Synced != 15 { + t.Errorf("synced = %d, want 15", resp.Synced) + } +} + +func TestHTMLErrorResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte("502 Bad Gateway")) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + _, err := c.Register(context.Background(), RegisterRequest{AgentID: "x"}) + if err == nil { + t.Fatal("expected error for HTML error page") + } +} diff --git a/internal/agent/transport_test.go b/internal/agent/transport_test.go index a9c7e5d..be2f6c6 100644 --- a/internal/agent/transport_test.go +++ b/internal/agent/transport_test.go @@ -443,3 +443,1148 @@ func TestHybridTransportWSDisconnectSwitchesToHTTP(t *testing.T) { t.Errorf("expected http after disconnect, got %s", h.Mode()) } } + +// ── Additional HTTP Transport Tests ───────────────────────────────────────── + +func TestNewHTTPTransportConstructor(t *testing.T) { + tr := NewHTTPTransport("http://example.com", "my-key", "my-agent/1.0") + + if tr.client == nil { + t.Fatal("expected client to be non-nil") + } + if tr.events == nil { + t.Fatal("expected events channel to be non-nil") + } + // events channel should have capacity 10 + if cap(tr.events) != 10 { + t.Errorf("expected events capacity 10, got %d", cap(tr.events)) + } +} + +func TestHTTPTransportConnectAndCloseAreNoOps(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + + if err := tr.Connect(context.Background()); err != nil { + t.Errorf("Connect should be a no-op, got error: %v", err) + } + if err := tr.Close(); err != nil { + t.Errorf("Close should be a no-op, got error: %v", err) + } +} + +func TestHTTPTransportClientAccessor(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + c := tr.Client() + if c == nil { + t.Fatal("Client() should return the underlying client") + } + if c != tr.client { + t.Error("Client() should return the same instance stored internally") + } +} + +func TestHTTPTransportSendHeartbeat(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if !strings.Contains(r.URL.Path, "heartbeat") { + t.Errorf("expected heartbeat path, got %s", r.URL.Path) + } + json.NewEncoder(w).Encode(HeartbeatResponse{ + Success: true, + Watching: true, + Upgrade: &UpgradeSignal{Version: "9.9.9"}, + }) + })) + defer srv.Close() + + tr := NewHTTPTransport(srv.URL, "key", "ua") + resp, err := tr.SendHeartbeat(context.Background(), HeartbeatRequest{ + AgentID: "a1", + Name: "test", + Version: "1.0", + }) + if err != nil { + t.Fatalf("SendHeartbeat failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } + if !resp.Watching { + t.Error("expected watching=true") + } + if resp.Upgrade == nil || resp.Upgrade.Version != "9.9.9" { + t.Error("expected upgrade version 9.9.9") + } +} + +func TestHTTPTransportSendProgress(t *testing.T) { + var received StatusUpdate + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + json.NewEncoder(w).Encode(StatusResponse{ + Success: true, + Cancelled: true, + }) + })) + defer srv.Close() + + tr := NewHTTPTransport(srv.URL, "key", "ua") + resp, err := tr.SendProgress(context.Background(), StatusUpdate{ + TaskID: "task-1", + Status: "downloading", + Progress: 55, + SpeedBps: 1024000, + }) + if err != nil { + t.Fatalf("SendProgress failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } + if !resp.Cancelled { + t.Error("expected cancelled flag") + } + if received.TaskID != "task-1" { + t.Errorf("expected task-1, got %s", received.TaskID) + } + if received.Progress != 55 { + t.Errorf("expected progress 55, got %d", received.Progress) + } +} + +func TestHTTPTransportClaimTasks(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("expected GET, got %s", r.Method) + } + agentID := r.URL.Query().Get("agentId") + if agentID != "agent-42" { + t.Errorf("expected agentId=agent-42, got %s", agentID) + } + json.NewEncoder(w).Encode(TasksResponse{ + Tasks: []Task{ + {ID: "t1", Title: "Movie 1", InfoHash: "abc"}, + {ID: "t2", Title: "Movie 2", InfoHash: "def"}, + }, + }) + })) + defer srv.Close() + + tr := NewHTTPTransport(srv.URL, "key", "ua") + resp, err := tr.ClaimTasks(context.Background(), "agent-42") + if err != nil { + t.Fatalf("ClaimTasks failed: %v", err) + } + if len(resp.Tasks) != 2 { + t.Fatalf("expected 2 tasks, got %d", len(resp.Tasks)) + } + if resp.Tasks[0].Title != "Movie 1" { + t.Errorf("expected Movie 1, got %s", resp.Tasks[0].Title) + } +} + +func TestHTTPTransportDeregister(t *testing.T) { + var called bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + json.NewEncoder(w).Encode(StatusResponse{Success: true}) + })) + defer srv.Close() + + tr := NewHTTPTransport(srv.URL, "key", "ua") + err := tr.Deregister(context.Background(), "agent-1") + if err != nil { + t.Fatalf("Deregister failed: %v", err) + } + if !called { + t.Error("expected server to be called") + } +} + +func TestHTTPTransportBatchReportStatus(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(BatchStatusResponse{ + Results: []StatusResponse{ + {Success: true}, + {Success: true, Cancelled: true}, + }, + Watching: true, + }) + })) + defer srv.Close() + + tr := NewHTTPTransport(srv.URL, "key", "ua") + resp, err := tr.BatchReportStatus(context.Background(), []StatusUpdate{ + {TaskID: "t1", Status: "downloading", Progress: 10}, + {TaskID: "t2", Status: "completed", Progress: 100}, + }) + if err != nil { + t.Fatalf("BatchReportStatus failed: %v", err) + } + if len(resp.Results) != 2 { + t.Fatalf("expected 2 results, got %d", len(resp.Results)) + } + if !resp.Watching { + t.Error("expected watching=true") + } + if !resp.Results[1].Cancelled { + t.Error("expected second result to be cancelled") + } +} + +func TestHTTPTransportAuthHeader(t *testing.T) { + var gotAuth string + var gotUA string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotUA = r.Header.Get("User-Agent") + json.NewEncoder(w).Encode(RegisterResponse{Success: true}) + })) + defer srv.Close() + + tr := NewHTTPTransport(srv.URL, "secret-key-123", "unarr/2.0") + tr.Register(context.Background(), RegisterRequest{AgentID: "a1"}) + + if gotAuth != "Bearer secret-key-123" { + t.Errorf("expected Bearer secret-key-123, got %s", gotAuth) + } + if gotUA != "unarr/2.0" { + t.Errorf("expected unarr/2.0, got %s", gotUA) + } +} + +// ── Additional WebSocket Transport Tests ──────────────────────────────────── + +func TestNewWSTransportConstructor(t *testing.T) { + tr := NewWSTransport("ws://example.com/ws", "api-key", "agent-1", "ua/1.0") + + if tr.Mode() != "ws" { + t.Errorf("expected ws mode, got %s", tr.Mode()) + } + if tr.wsURL != "ws://example.com/ws" { + t.Errorf("expected ws URL, got %s", tr.wsURL) + } + if tr.apiKey != "api-key" { + t.Errorf("expected api-key, got %s", tr.apiKey) + } + if tr.agentID != "agent-1" { + t.Errorf("expected agent-1, got %s", tr.agentID) + } + if tr.userAgent != "ua/1.0" { + t.Errorf("expected ua/1.0, got %s", tr.userAgent) + } + if cap(tr.events) != 50 { + t.Errorf("expected events capacity 50, got %d", cap(tr.events)) + } + if tr.authDone == nil { + t.Fatal("expected authDone channel to be non-nil") + } +} + +func TestWSTransportClaimTasksIsNoOp(t *testing.T) { + tr := NewWSTransport("ws://localhost", "key", "a1", "ua") + resp, err := tr.ClaimTasks(context.Background(), "a1") + if err != nil { + t.Fatalf("ClaimTasks should succeed (no-op): %v", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if len(resp.Tasks) != 0 { + t.Errorf("expected 0 tasks, got %d", len(resp.Tasks)) + } +} + +func TestWSTransportCloseWhenNotConnected(t *testing.T) { + tr := NewWSTransport("ws://localhost", "key", "a1", "ua") + // Close without ever connecting should not panic or error + if err := tr.Close(); err != nil { + t.Errorf("Close on unconnected transport should return nil, got %v", err) + } +} + +func TestWSTransportSendWhenNotConnected(t *testing.T) { + tr := NewWSTransport("ws://localhost", "key", "a1", "ua") + // Attempting to send a heartbeat without connecting should fail + _, err := tr.SendHeartbeat(context.Background(), HeartbeatRequest{AgentID: "a1"}) + if err == nil { + t.Error("expected error when sending without connection") + } +} + +func TestWSTransportConnectBadURL(t *testing.T) { + tr := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + err := tr.Connect(context.Background()) + if err == nil { + t.Error("expected error connecting to invalid address") + } +} + +func TestWSTransportSendHeartbeatWithDisk(t *testing.T) { + var receivedMsg map[string]interface{} + var mu sync.Mutex + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + // Read auth + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + // Read heartbeat + _, msg, err := conn.ReadMessage() + if err != nil { + return + } + mu.Lock() + json.Unmarshal(msg, &receivedMsg) + mu.Unlock() + + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + time.Sleep(50 * time.Millisecond) + resp, err := tr.SendHeartbeat(ctx, HeartbeatRequest{ + AgentID: "a1", + DiskFreeBytes: 500000000, + DiskTotalBytes: 1000000000, + }) + if err != nil { + t.Fatalf("SendHeartbeat failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } + + time.Sleep(100 * time.Millisecond) + mu.Lock() + defer mu.Unlock() + if receivedMsg["type"] != "heartbeat" { + t.Errorf("expected heartbeat, got %v", receivedMsg["type"]) + } + disk, ok := receivedMsg["disk"].(map[string]interface{}) + if !ok { + t.Fatal("expected disk field in heartbeat message") + } + if disk["free"].(float64) != 500000000 { + t.Errorf("expected free=500000000, got %v", disk["free"]) + } + if disk["total"].(float64) != 1000000000 { + t.Errorf("expected total=1000000000, got %v", disk["total"]) + } +} + +func TestWSTransportSendHeartbeatWithoutDisk(t *testing.T) { + var receivedMsg map[string]interface{} + var mu sync.Mutex + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + _, msg, err := conn.ReadMessage() + if err != nil { + return + } + mu.Lock() + json.Unmarshal(msg, &receivedMsg) + mu.Unlock() + + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + time.Sleep(50 * time.Millisecond) + resp, err := tr.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"}) + if err != nil { + t.Fatalf("SendHeartbeat failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } + + time.Sleep(100 * time.Millisecond) + mu.Lock() + defer mu.Unlock() + if receivedMsg["type"] != "heartbeat" { + t.Errorf("expected heartbeat, got %v", receivedMsg["type"]) + } + // disk field should be absent when no disk info provided + if _, exists := receivedMsg["disk"]; exists { + t.Error("expected no disk field when disk info is zero") + } +} + +func TestWSTransportDeregisterClosesConnection(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + err := tr.Deregister(ctx, "a1") + if err != nil { + t.Fatalf("Deregister failed: %v", err) + } + + // After deregister, send should fail (connection closed) + _, err = tr.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"}) + if err == nil { + t.Error("expected error sending after deregister") + } +} + +func TestWSTransportReceiveStreamRequests(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + time.Sleep(50 * time.Millisecond) + conn.WriteJSON(wsTasksMessage{ + Type: "tasks", + Tasks: []Task{}, + StreamRequests: []StreamRequest{ + {TaskID: "t1", FilePath: "/data/movie.mkv"}, + }, + }) + + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + select { + case event := <-tr.Events(): + if event.Type != "tasks" { + t.Errorf("expected tasks, got %s", event.Type) + } + if len(event.Tasks.StreamRequests) != 1 { + t.Fatalf("expected 1 stream request, got %d", len(event.Tasks.StreamRequests)) + } + if event.Tasks.StreamRequests[0].FilePath != "/data/movie.mkv" { + t.Errorf("expected /data/movie.mkv, got %s", event.Tasks.StreamRequests[0].FilePath) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for tasks event with stream requests") + } +} + +func TestWSTransportReceiveErrorMessage(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + conn.ReadMessage() + conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}}) + + time.Sleep(50 * time.Millisecond) + // Send an error message (should be logged, not emitted as event) + conn.WriteJSON(map[string]string{ + "type": "error", + "message": "rate limited", + }) + + time.Sleep(200 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + tr.Register(ctx, RegisterRequest{AgentID: "a1"}) + + // Error messages are logged but not emitted — events channel should be quiet + select { + case event := <-tr.Events(): + // If we get disconnected, that's acceptable (server closes after delay) + if event.Type != "disconnected" { + t.Errorf("unexpected event type: %s", event.Type) + } + case <-time.After(300 * time.Millisecond): + // Expected: no event emitted for error messages + } +} + +func TestWSTransportRegisterTimeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + conn.ReadMessage() + // Never send registered response — should timeout + time.Sleep(20 * time.Second) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + tr := NewWSTransport(wsURL, "key", "a1", "ua") + + ctx := context.Background() + tr.Connect(ctx) + defer tr.Close() + + // Use a context with short timeout to avoid waiting 15s + ctxShort, cancel := context.WithTimeout(ctx, 200*time.Millisecond) + defer cancel() + + _, err := tr.Register(ctxShort, RegisterRequest{AgentID: "a1"}) + if err == nil { + t.Error("expected timeout error from Register") + } +} + +// ── Additional Hybrid Transport Tests ─────────────────────────────────────── + +func TestNewHybridTransportConstructor(t *testing.T) { + wsT := NewWSTransport("ws://localhost", "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + + if h.Mode() != "http" { + t.Errorf("expected initial mode http, got %s", h.Mode()) + } + if cap(h.events) != 50 { + t.Errorf("expected events capacity 50, got %d", cap(h.events)) + } + if h.ws != wsT { + t.Error("expected ws transport to match") + } + if h.http != httpT { + t.Error("expected http transport to match") + } + if h.reconnectStop == nil { + t.Error("expected reconnectStop channel to be non-nil") + } +} + +func TestHybridTransportCloseIsIdempotent(t *testing.T) { + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + // Close twice should not panic + if err := h.Close(); err != nil { + t.Errorf("first Close failed: %v", err) + } + if err := h.Close(); err != nil { + t.Errorf("second Close failed: %v", err) + } +} + +func TestHybridTransportHTTPModeRegister(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(RegisterResponse{ + Success: true, + User: UserInfo{Name: "HTTPUser", Plan: "free"}, + }) + })) + defer srv.Close() + + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport(srv.URL, "key", "ua") + + h := NewHybridTransport(wsT, httpT) + // Force HTTP mode (default) + h.mode.Store("http") + + resp, err := h.Register(context.Background(), RegisterRequest{AgentID: "a1"}) + if err != nil { + t.Fatalf("Register failed: %v", err) + } + if resp.User.Name != "HTTPUser" { + t.Errorf("expected HTTPUser, got %s", resp.User.Name) + } +} + +func TestHybridTransportHTTPModeClaimTasks(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(TasksResponse{ + Tasks: []Task{{ID: "t1", Title: "Test"}}, + }) + })) + defer srv.Close() + + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport(srv.URL, "key", "ua") + + h := NewHybridTransport(wsT, httpT) + h.mode.Store("http") + + resp, err := h.ClaimTasks(context.Background(), "a1") + if err != nil { + t.Fatalf("ClaimTasks failed: %v", err) + } + if len(resp.Tasks) != 1 { + t.Errorf("expected 1 task, got %d", len(resp.Tasks)) + } +} + +func TestHybridTransportHTTPModeDeregister(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(StatusResponse{Success: true}) + })) + defer srv.Close() + + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport(srv.URL, "key", "ua") + + h := NewHybridTransport(wsT, httpT) + h.mode.Store("http") + + err := h.Deregister(context.Background(), "a1") + if err != nil { + t.Fatalf("Deregister failed: %v", err) + } +} + +func TestHybridTransportHTTPModeSendHeartbeat(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(HeartbeatResponse{Success: true, Watching: true}) + })) + defer srv.Close() + + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport(srv.URL, "key", "ua") + + h := NewHybridTransport(wsT, httpT) + h.mode.Store("http") + + resp, err := h.SendHeartbeat(context.Background(), HeartbeatRequest{AgentID: "a1"}) + if err != nil { + t.Fatalf("SendHeartbeat failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } + if !resp.Watching { + t.Error("expected watching=true") + } +} + +func TestHybridTransportHTTPModeSendProgress(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(StatusResponse{Success: true}) + })) + defer srv.Close() + + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport(srv.URL, "key", "ua") + + h := NewHybridTransport(wsT, httpT) + h.mode.Store("http") + + resp, err := h.SendProgress(context.Background(), StatusUpdate{ + TaskID: "t1", + Status: "completed", + Progress: 100, + }) + if err != nil { + t.Fatalf("SendProgress failed: %v", err) + } + if !resp.Success { + t.Error("expected success") + } +} + +func TestHybridTransportWSModeClaimTasksIsNoOp(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + time.Sleep(500 * time.Millisecond) + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + wsT := NewWSTransport(wsURL, "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + h.Connect(context.Background()) + defer h.Close() + + // In WS mode, ClaimTasks delegates to WS which is a no-op + resp, err := h.ClaimTasks(context.Background(), "a1") + if err != nil { + t.Fatalf("ClaimTasks failed: %v", err) + } + if len(resp.Tasks) != 0 { + t.Errorf("expected 0 tasks in WS mode, got %d", len(resp.Tasks)) + } +} + +func TestHybridTransportEventsChannel(t *testing.T) { + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + ch := h.Events() + if ch == nil { + t.Fatal("Events() should return non-nil channel") + } + // Verify it is the correct channel + if cap(ch) != 50 { + t.Errorf("expected events capacity 50, got %d", cap(ch)) + } +} + +func TestHybridTransportSwitchToHTTPIdempotent(t *testing.T) { + wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua") + httpT := NewHTTPTransport("http://localhost", "key", "ua") + + h := NewHybridTransport(wsT, httpT) + // Already in HTTP mode, switchToHTTP should be a no-op + h.mode.Store("http") + h.switchToHTTP() // should not panic or start reconnect + + if h.Mode() != "http" { + t.Errorf("expected http, got %s", h.Mode()) + } +} + +// ── Daemon Constructor & Utility Tests ────────────────────────────────────── + +func TestNewDaemonDefaults(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + if d.cfg.PollInterval != 30*time.Second { + t.Errorf("expected default PollInterval 30s, got %v", d.cfg.PollInterval) + } + if d.cfg.HeartbeatInterval != 30*time.Second { + t.Errorf("expected default HeartbeatInterval 30s, got %v", d.cfg.HeartbeatInterval) + } + if d.Transport() != tr { + t.Error("Transport() should return the configured transport") + } + if d.pollNow == nil { + t.Error("pollNow channel should be initialized") + } +} + +func TestNewDaemonCustomIntervals(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + PollInterval: 10 * time.Second, + HeartbeatInterval: 15 * time.Second, + }, tr) + + if d.cfg.PollInterval != 10*time.Second { + t.Errorf("expected PollInterval 10s, got %v", d.cfg.PollInterval) + } + if d.cfg.HeartbeatInterval != 15*time.Second { + t.Errorf("expected HeartbeatInterval 15s, got %v", d.cfg.HeartbeatInterval) + } +} + +func TestDaemonTriggerPoll(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // First trigger should succeed + d.TriggerPoll() + + // Channel should have one signal + select { + case <-d.pollNow: + // good + default: + t.Error("expected signal on pollNow channel") + } + + // Second trigger when channel is empty should also succeed + d.TriggerPoll() + select { + case <-d.pollNow: + // good + default: + t.Error("expected signal on pollNow channel after second trigger") + } +} + +func TestDaemonTriggerPollNonBlocking(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // Fill the channel (capacity 1) + d.TriggerPoll() + // Second call should not block even though channel is full + done := make(chan struct{}) + go func() { + d.TriggerPoll() + close(done) + }() + + select { + case <-done: + // good, did not block + case <-time.After(1 * time.Second): + t.Fatal("TriggerPoll blocked on full channel") + } +} + +func TestDaemonHandleEventTasks(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + var claimedTasks []Task + d.OnTasksClaimed = func(tasks []Task) { + claimedTasks = tasks + } + + d.handleEvent(ServerEvent{ + Type: "tasks", + Tasks: &TasksResponse{ + Tasks: []Task{ + {ID: "t1", Title: "Movie 1"}, + {ID: "t2", Title: "Movie 2"}, + }, + }, + }) + + if len(claimedTasks) != 2 { + t.Fatalf("expected 2 claimed tasks, got %d", len(claimedTasks)) + } + if claimedTasks[0].Title != "Movie 1" { + t.Errorf("expected Movie 1, got %s", claimedTasks[0].Title) + } +} + +func TestDaemonHandleEventTasksWithStreamRequests(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + var streamReqs []StreamRequest + d.OnStreamRequested = func(req StreamRequest) { + streamReqs = append(streamReqs, req) + } + + d.handleEvent(ServerEvent{ + Type: "tasks", + Tasks: &TasksResponse{ + Tasks: []Task{}, + StreamRequests: []StreamRequest{ + {TaskID: "t1", FilePath: "/data/movie.mkv"}, + {TaskID: "t2", FilePath: "/data/show.mkv"}, + }, + }, + }) + + if len(streamReqs) != 2 { + t.Fatalf("expected 2 stream requests, got %d", len(streamReqs)) + } + if streamReqs[0].FilePath != "/data/movie.mkv" { + t.Errorf("expected /data/movie.mkv, got %s", streamReqs[0].FilePath) + } +} + +func TestDaemonHandleEventUpgrade(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + d.handleEvent(ServerEvent{ + Type: "upgrade", + Upgrade: &UpgradeSignal{Version: "2.0.0"}, + }) + + if d.lastNotifiedVersion != "2.0.0" { + t.Errorf("expected lastNotifiedVersion 2.0.0, got %s", d.lastNotifiedVersion) + } + + // Same version again should not update (already notified) + d.lastNotifiedVersion = "2.0.0" + d.handleEvent(ServerEvent{ + Type: "upgrade", + Upgrade: &UpgradeSignal{Version: "2.0.0"}, + }) + // Still 2.0.0, no change + if d.lastNotifiedVersion != "2.0.0" { + t.Errorf("expected lastNotifiedVersion unchanged at 2.0.0, got %s", d.lastNotifiedVersion) + } +} + +func TestDaemonHandleEventControl(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + var gotAction, gotTaskID string + d.OnControlAction = func(action, taskID string) { + gotAction = action + gotTaskID = taskID + } + + d.handleEvent(ServerEvent{ + Type: "control", + Control: &ControlAction{Action: "cancel", TaskID: "task-99"}, + }) + + if gotAction != "cancel" { + t.Errorf("expected cancel, got %s", gotAction) + } + if gotTaskID != "task-99" { + t.Errorf("expected task-99, got %s", gotTaskID) + } +} + +func TestDaemonHandleEventControlWithNilCallback(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // OnControlAction is nil — should not panic + d.handleEvent(ServerEvent{ + Type: "control", + Control: &ControlAction{Action: "pause", TaskID: "t1"}, + }) +} + +func TestDaemonHandleEventDisconnected(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // disconnected event should not panic (just logs) + d.handleEvent(ServerEvent{Type: "disconnected"}) +} + +func TestDaemonHandleEventTasksNilCallback(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // OnTasksClaimed is nil — should not panic + d.handleEvent(ServerEvent{ + Type: "tasks", + Tasks: &TasksResponse{ + Tasks: []Task{{ID: "t1", Title: "Test"}}, + }, + }) +} + +func TestDaemonHandleEventEmptyTasks(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + var called bool + d.OnTasksClaimed = func(tasks []Task) { + called = true + } + + // Empty tasks should not trigger callback + d.handleEvent(ServerEvent{ + Type: "tasks", + Tasks: &TasksResponse{Tasks: []Task{}}, + }) + + if called { + t.Error("OnTasksClaimed should not be called for empty task list") + } +} + +func TestDaemonHandleEventNilTasks(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // Nil Tasks field should not panic + d.handleEvent(ServerEvent{ + Type: "tasks", + Tasks: nil, + }) +} + +func TestDaemonHandleEventUpgradeNilSignal(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // Nil Upgrade should not panic + d.handleEvent(ServerEvent{ + Type: "upgrade", + Upgrade: nil, + }) + if d.lastNotifiedVersion != "" { + t.Errorf("expected empty lastNotifiedVersion, got %s", d.lastNotifiedVersion) + } +} + +func TestDaemonHandleEventUpgradeEmptyVersion(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + // Empty version should not update lastNotifiedVersion + d.handleEvent(ServerEvent{ + Type: "upgrade", + Upgrade: &UpgradeSignal{Version: ""}, + }) + if d.lastNotifiedVersion != "" { + t.Errorf("expected empty lastNotifiedVersion, got %s", d.lastNotifiedVersion) + } +} + +func TestDaemonWatchingFlag(t *testing.T) { + tr := NewHTTPTransport("http://localhost", "key", "ua") + d := NewDaemon(DaemonConfig{ + AgentID: "a1", + AgentName: "test", + Version: "1.0", + DownloadDir: "/tmp", + }, tr) + + if d.Watching.Load() { + t.Error("expected Watching to be false initially") + } + d.Watching.Store(true) + if !d.Watching.Load() { + t.Error("expected Watching to be true after Store(true)") + } +} + +// ── Transport Interface Compliance ────────────────────────────────────────── + +func TestHTTPTransportImplementsTransport(t *testing.T) { + var _ Transport = (*HTTPTransport)(nil) +} + +func TestWSTransportImplementsTransport(t *testing.T) { + var _ Transport = (*WSTransport)(nil) +} + +func TestHybridTransportImplementsTransport(t *testing.T) { + var _ Transport = (*HybridTransport)(nil) +} diff --git a/internal/arr/client_test.go b/internal/arr/client_test.go new file mode 100644 index 0000000..214dd12 --- /dev/null +++ b/internal/arr/client_test.go @@ -0,0 +1,396 @@ +package arr + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func newTestServer(t *testing.T, handlers map[string]any) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check API key header + if r.Header.Get("X-Api-Key") != "test-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + handler, ok := handlers[r.URL.Path] + if !ok { + w.WriteHeader(http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(handler) + })) +} + +func TestNewClient(t *testing.T) { + c := NewClient("http://localhost:8989/", "mykey") + if c.baseURL != "http://localhost:8989" { + t.Errorf("baseURL = %q, want trailing slash trimmed", c.baseURL) + } + if c.apiKey != "mykey" { + t.Errorf("apiKey = %q, want mykey", c.apiKey) + } +} + +func TestSystemStatus(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v3/system/status": SystemStatus{AppName: "Radarr", Version: "4.0.0"}, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + status, err := c.SystemStatus() + if err != nil { + t.Fatalf("SystemStatus: %v", err) + } + if status.AppName != "Radarr" { + t.Errorf("AppName = %q, want Radarr", status.AppName) + } + if status.Version != "4.0.0" { + t.Errorf("Version = %q, want 4.0.0", status.Version) + } +} + +func TestSystemStatusFallbackV1(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Api-Key") != "test-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + switch r.URL.Path { + case "/api/v3/system/status": + w.WriteHeader(http.StatusNotFound) + case "/api/v1/system/status": + json.NewEncoder(w).Encode(SystemStatus{AppName: "Prowlarr", Version: "1.0.0"}) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + status, err := c.SystemStatus() + if err != nil { + t.Fatalf("SystemStatus v1 fallback: %v", err) + } + if status.AppName != "Prowlarr" { + t.Errorf("AppName = %q, want Prowlarr", status.AppName) + } +} + +func TestMovies(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v3/movie": []Movie{ + {ID: 1, Title: "Inception", Year: 2010, TmdbID: 27205, Monitored: true}, + {ID: 2, Title: "Tenet", Year: 2020, TmdbID: 577922, HasFile: true}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + movies, err := c.Movies() + if err != nil { + t.Fatalf("Movies: %v", err) + } + if len(movies) != 2 { + t.Fatalf("expected 2 movies, got %d", len(movies)) + } + if movies[0].Title != "Inception" { + t.Errorf("movies[0].Title = %q, want Inception", movies[0].Title) + } +} + +func TestSeries(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v3/series": []Series{ + {ID: 1, Title: "Breaking Bad", Year: 2008, TvdbID: 81189}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + series, err := c.Series() + if err != nil { + t.Fatalf("Series: %v", err) + } + if len(series) != 1 { + t.Fatalf("expected 1 series, got %d", len(series)) + } + if series[0].Title != "Breaking Bad" { + t.Errorf("series[0].Title = %q, want Breaking Bad", series[0].Title) + } +} + +func TestQualityProfiles(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v3/qualityprofile": []QualityProfile{ + {ID: 1, Name: "HD-1080p"}, + {ID: 2, Name: "Ultra-HD"}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + profiles, err := c.QualityProfiles() + if err != nil { + t.Fatalf("QualityProfiles: %v", err) + } + if len(profiles) != 2 { + t.Fatalf("expected 2 profiles, got %d", len(profiles)) + } +} + +func TestRootFolders(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v3/rootfolder": []RootFolder{ + {ID: 1, Path: "/movies", FreeSpace: 500000000000}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + folders, err := c.RootFolders() + if err != nil { + t.Fatalf("RootFolders: %v", err) + } + if len(folders) != 1 { + t.Fatalf("expected 1 folder, got %d", len(folders)) + } + if folders[0].Path != "/movies" { + t.Errorf("path = %q, want /movies", folders[0].Path) + } +} + +func TestDownloadClients(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v3/downloadclient": []DownloadClient{ + {ID: 1, Name: "Transmission", Enable: true, Protocol: "torrent"}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + clients, err := c.DownloadClients() + if err != nil { + t.Fatalf("DownloadClients: %v", err) + } + if len(clients) != 1 || clients[0].Name != "Transmission" { + t.Errorf("unexpected clients: %+v", clients) + } +} + +func TestDownloadClientDetails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Api-Key") != "test-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if r.URL.Path == "/api/v3/downloadclient/5" { + json.NewEncoder(w).Encode(struct { + Fields []Field `json:"fields"` + }{ + Fields: []Field{ + {Name: "host", Value: "localhost"}, + {Name: "port", Value: 9091}, + }, + }) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + fields, err := c.DownloadClientDetails(5) + if err != nil { + t.Fatalf("DownloadClientDetails: %v", err) + } + if len(fields) != 2 { + t.Fatalf("expected 2 fields, got %d", len(fields)) + } + if fields[0].Name != "host" { + t.Errorf("fields[0].Name = %q, want host", fields[0].Name) + } +} + +func TestTags(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v3/tag": []Tag{ + {ID: 1, Label: "unarr"}, + {ID: 2, Label: "imported"}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + tags, err := c.Tags() + if err != nil { + t.Fatalf("Tags: %v", err) + } + if len(tags) != 2 { + t.Fatalf("expected 2 tags, got %d", len(tags)) + } +} + +func TestHistory(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Api-Key") != "test-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if r.URL.Path == "/api/v3/history" { + json.NewEncoder(w).Encode(HistoryResponse{ + Records: []HistoryRecord{ + {ID: 1, EventType: "grabbed", SourceTitle: "Inception.2010.1080p"}, + }, + TotalRecords: 1, + }) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + records, err := c.History(10) + if err != nil { + t.Fatalf("History: %v", err) + } + if len(records) != 1 { + t.Fatalf("expected 1 record, got %d", len(records)) + } + if records[0].SourceTitle != "Inception.2010.1080p" { + t.Errorf("sourceTitle = %q", records[0].SourceTitle) + } +} + +func TestHistoryDefaultPageSize(t *testing.T) { + var requestedPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Api-Key") != "test-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + requestedPath = r.URL.String() + json.NewEncoder(w).Encode(HistoryResponse{}) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + c.History(0) // should default to 250 + + if requestedPath == "" { + t.Fatal("no request made") + } + if !contains(requestedPath, "pageSize=250") { + t.Errorf("expected pageSize=250, got path: %s", requestedPath) + } +} + +func TestBlocklist(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Api-Key") != "test-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if r.URL.Path == "/api/v3/blocklist" { + json.NewEncoder(w).Encode(BlocklistResponse{ + Records: []BlocklistItem{ + {ID: 1, SourceTitle: "Bad.Release", Data: BlocklistData{InfoHash: "abc123"}}, + }, + }) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + items, err := c.Blocklist(50) + if err != nil { + t.Fatalf("Blocklist: %v", err) + } + if len(items) != 1 || items[0].Data.InfoHash != "abc123" { + t.Errorf("unexpected blocklist: %+v", items) + } +} + +func TestIndexers(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v1/indexer": []Indexer{ + {ID: 1, Name: "NZBGeek", Enable: true}, + {ID: 2, Name: "Torznab", Enable: false}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + indexers, err := c.Indexers() + if err != nil { + t.Fatalf("Indexers: %v", err) + } + if len(indexers) != 2 { + t.Fatalf("expected 2 indexers, got %d", len(indexers)) + } +} + +func TestApplications(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/v1/applications": []Application{ + {ID: 1, Name: "Radarr"}, + }, + }) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + apps, err := c.Applications() + if err != nil { + t.Fatalf("Applications: %v", err) + } + if len(apps) != 1 || apps[0].Name != "Radarr" { + t.Errorf("unexpected apps: %+v", apps) + } +} + +func TestUnauthorized(t *testing.T) { + srv := newTestServer(t, map[string]any{}) + defer srv.Close() + + c := NewClient(srv.URL, "wrong-key") + _, err := c.SystemStatus() + if err == nil { + t.Error("expected error for unauthorized request") + } +} + +func TestHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal error")) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key") + _, err := c.Movies() + if err == nil { + t.Error("expected error for 500 response") + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && searchStr(s, substr) +} + +func searchStr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/arr/discovery_test.go b/internal/arr/discovery_test.go index 499877c..6dc78b8 100644 --- a/internal/arr/discovery_test.go +++ b/internal/arr/discovery_test.go @@ -1,6 +1,9 @@ package arr import ( + "encoding/json" + "net/http" + "net/http/httptest" "strings" "testing" ) @@ -82,3 +85,158 @@ func TestDetectApp(t *testing.T) { }) } } + +func TestConfigDirs(t *testing.T) { + dirs := configDirs() + if len(dirs) == 0 { + t.Error("configDirs() returned empty") + } +} + +func TestParseConfigXMLEmpty(t *testing.T) { + port, apiKey, urlBase := parseConfigXML(strings.NewReader("")) + if port != "" || apiKey != "" || urlBase != "" { + t.Error("empty input should return empty values") + } +} + +func TestParseConfigXMLNoPort(t *testing.T) { + xml := `key123` + port, apiKey, _ := parseConfigXML(strings.NewReader(xml)) + if port != "" { + t.Errorf("port = %q, want empty", port) + } + if apiKey != "key123" { + t.Errorf("apiKey = %q, want key123", apiKey) + } +} + +func TestExtractHostPortMultipleMappings(t *testing.T) { + tests := []struct { + name string + ports string + container string + want string + }{ + {"ipv6 only", ":::8989->8989/tcp", "8989", "8989"}, + {"different host port", "0.0.0.0:9999->8989/tcp", "8989", "9999"}, + {"port in string but no mapping", "something 8989 somewhere", "8989", "8989"}, + {"no match at all", "0.0.0.0:3000->3000/tcp", "9999", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractHostPort(tt.ports, tt.container) + if got != tt.want { + t.Errorf("extractHostPort(%q, %q) = %q, want %q", tt.ports, tt.container, got, tt.want) + } + }) + } +} + +func TestDiscoverFromProwlarr(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/applications": + json.NewEncoder(w).Encode([]Application{ + { + ID: 1, + Name: "Radarr", + Fields: []Field{ + {Name: "baseUrl", Value: "http://localhost:7878"}, + {Name: "apiKey", Value: "radarr-key-123"}, + }, + }, + { + ID: 2, + Name: "Sonarr", + Fields: []Field{ + {Name: "baseUrl", Value: "http://localhost:8989"}, + {Name: "apiKey", Value: "sonarr-key-456"}, + }, + }, + { + ID: 3, + Name: "Unknown App", + Fields: []Field{ + {Name: "baseUrl", Value: "http://localhost:9000"}, + {Name: "apiKey", Value: "unknown-key"}, + }, + }, + { + ID: 4, + Name: "Incomplete", + Fields: []Field{ + {Name: "baseUrl", Value: "http://localhost:5000"}, + // no apiKey → should be skipped + }, + }, + }) + case "/api/v3/system/status": + json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "4.0.0"}) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + // DiscoverFromProwlarr will try to verify each instance, which will fail + // for localhost URLs (not our test server), but that's OK — we test the parsing + instances := DiscoverFromProwlarr(srv.URL, "prowlarr-key") + + // Should find Radarr and Sonarr (Unknown and Incomplete skipped) + if len(instances) != 2 { + t.Fatalf("expected 2 instances, got %d: %+v", len(instances), instances) + } + + found := map[string]bool{} + for _, inst := range instances { + found[inst.App] = true + if inst.Source != "prowlarr" { + t.Errorf("source = %q, want prowlarr", inst.Source) + } + } + if !found["radarr"] { + t.Error("expected radarr instance") + } + if !found["sonarr"] { + t.Error("expected sonarr instance") + } +} + +func TestVerify(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Api-Key") != "valid-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "5.0.0"}) + })) + defer srv.Close() + + t.Run("valid", func(t *testing.T) { + inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "valid-key"} + err := Verify(inst) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if inst.Version != "5.0.0" { + t.Errorf("version = %q, want 5.0.0", inst.Version) + } + }) + + t.Run("no api key", func(t *testing.T) { + inst := &Instance{App: "radarr", URL: srv.URL} + err := Verify(inst) + if err == nil { + t.Error("expected error for no API key") + } + }) + + t.Run("invalid key", func(t *testing.T) { + inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "wrong-key"} + err := Verify(inst) + if err == nil { + t.Error("expected error for invalid API key") + } + }) +} diff --git a/internal/cmd/config_menu_test.go b/internal/cmd/config_menu_test.go new file mode 100644 index 0000000..a63389b --- /dev/null +++ b/internal/cmd/config_menu_test.go @@ -0,0 +1,55 @@ +package cmd + +import "testing" + +func TestValidateSpeed(t *testing.T) { + tests := []struct { + input string + wantErr bool + }{ + {"", false}, + {"0", false}, + {" ", false}, + {"10MB", false}, + {"500KB", false}, + {"1GB", false}, + {"abc", true}, + {"10XB", true}, + {"-5MB", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + err := validateSpeed(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateSpeed(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateDuration(t *testing.T) { + tests := []struct { + input string + wantErr bool + }{ + {"", false}, + {"30s", false}, + {"1m", false}, + {"5m", false}, + {"1h", false}, + {"2h30m", false}, + {"abc", true}, + {"30", true}, + {"5x", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + err := validateDuration(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateDuration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} diff --git a/internal/cmd/daemon_test.go b/internal/cmd/daemon_test.go new file mode 100644 index 0000000..11ad2a5 --- /dev/null +++ b/internal/cmd/daemon_test.go @@ -0,0 +1,55 @@ +package cmd + +import "testing" + +func TestDeriveWSURL(t *testing.T) { + tests := []struct { + apiURL string + agentID string + want string + }{ + {"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"}, + {"http://localhost:3000", "a1", ""}, // localhost skipped + {"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped + {"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"}, + {"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"}, + {"", "agent-123", ""}, + {"https://torrentclaw.com", "", ""}, + {"", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) { + got := deriveWSURL(tt.apiURL, tt.agentID) + if got != tt.want { + t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want) + } + }) + } +} + +func TestFormatSpeedLog(t *testing.T) { + tests := []struct { + bps int64 + want string + }{ + {0, "0 B/s"}, + {500, "500 B/s"}, + {1023, "1023 B/s"}, + {1024, "1 KB/s"}, + {10240, "10 KB/s"}, + {1048576, "1.0 MB/s"}, + {5242880, "5.0 MB/s"}, + {1073741824, "1.0 GB/s"}, + {2147483648, "2.0 GB/s"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := formatSpeedLog(tt.bps) + if got != tt.want { + t.Errorf("formatSpeedLog(%d) = %q, want %q", tt.bps, got, tt.want) + } + }) + } +} diff --git a/internal/cmd/helpers_test.go b/internal/cmd/helpers_test.go new file mode 100644 index 0000000..a5badc3 --- /dev/null +++ b/internal/cmd/helpers_test.go @@ -0,0 +1,43 @@ +package cmd + +import ( + "os" + "strings" + "testing" +) + +func TestExpandHome(t *testing.T) { + home, _ := os.UserHomeDir() + + tests := []struct { + input string + want string + }{ + {"~/Documents", home + "/Documents"}, + {"~/", home}, + {"/absolute/path", "/absolute/path"}, + {"relative/path", "relative/path"}, + {"", ""}, + {"~notexpanded", "~notexpanded"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := expandHome(tt.input) + if got != tt.want { + t.Errorf("expandHome(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestDefaultDownloadDir(t *testing.T) { + dir := defaultDownloadDir() + if dir == "" { + t.Error("defaultDownloadDir() returned empty string") + } + home, _ := os.UserHomeDir() + if !strings.HasPrefix(dir, home) { + t.Errorf("defaultDownloadDir() = %q, expected to start with home dir %q", dir, home) + } +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index bd8f734..bcf3473 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -143,8 +143,9 @@ Source: https://github.com/torrentclaw/unarr`, completionCmd, // Library scanCmd, + // Alias: upgrade → self-update + newUpgradeCmd(), // Stubs for future commands - newStubCmd("upgrade", "Find a better version of a torrent"), newStubCmd("moreseed", "Find same quality with more seeders"), newStubCmd("compare", "Compare two torrents side by side"), newStubCmd("add", "Search and add torrents to your client"), diff --git a/internal/cmd/status.go b/internal/cmd/status.go index 0f989f4..e90354c 100644 --- a/internal/cmd/status.go +++ b/internal/cmd/status.go @@ -1,20 +1,26 @@ package cmd import ( + "context" "fmt" + "runtime" + "strings" + "time" "github.com/fatih/color" "github.com/spf13/cobra" + "github.com/torrentclaw/unarr/internal/agent" + "github.com/torrentclaw/unarr/internal/upgrade" ) func newStatusCmd() *cobra.Command { return &cobra.Command{ Use: "status", - Short: "Show daemon status and active downloads", - Long: `Display the current state of the daemon, active downloads, and recent activity. + Short: "Show daemon status, configuration, and update availability", + Long: `Display the current state of unarr: version, configuration, daemon status, +disk usage, and whether an update is available. -Shows the configured agent name, download directory, and preferred method. -When the daemon is running, also displays active downloads and their progress.`, +When the daemon is running, also displays uptime, active downloads, and stats.`, Example: ` unarr status`, RunE: func(cmd *cobra.Command, args []string) error { return runStatus() @@ -25,27 +31,167 @@ When the daemon is running, also displays active downloads and their progress.`, func runStatus() error { bold := color.New(color.Bold) dim := color.New(color.FgHiBlack) + green := color.New(color.FgGreen) + yellow := color.New(color.FgYellow) + cyan := color.New(color.FgCyan) fmt.Println() bold.Printf(" unarr %s\n", Version) + dim.Printf(" %s/%s\n", runtime.GOOS, runtime.GOARCH) fmt.Println() cfg := loadConfig() + // ── Configuration ── if cfg.Auth.APIKey == "" { - dim.Println(" Not configured. Run 'unarr init' first.") + yellow.Println(" ⚠ Not configured. Run 'unarr init' first.") fmt.Println() return nil } - fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, cfg.Agent.ID[:8]+"...") - fmt.Printf(" Downloads: %s\n", cfg.Download.Dir) - fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod) + cyan.Println(" Configuration") + agentID := cfg.Agent.ID + if len(agentID) > 8 { + agentID = agentID[:8] + "..." + } + fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, agentID) + fmt.Printf(" Server: %s\n", cfg.Auth.APIURL) + fmt.Printf(" Downloads: %s\n", cfg.Download.Dir) + fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod) + if cfg.Download.PreferredQuality != "" { + fmt.Printf(" Quality: %s\n", cfg.Download.PreferredQuality) + } + fmt.Printf(" Concurrent: %d\n", cfg.Download.MaxConcurrent) + if cfg.Organize.Enabled { + fmt.Printf(" Organize: on") + if cfg.Organize.MoviesDir != "" { + fmt.Printf(" (movies: %s", cfg.Organize.MoviesDir) + if cfg.Organize.TVShowsDir != "" { + fmt.Printf(", tv: %s", cfg.Organize.TVShowsDir) + } + fmt.Print(")") + } + fmt.Println() + } fmt.Println() - dim.Println(" Daemon not running. Start with 'unarr start'") - dim.Println(" (Live status will be shown here when daemon is running)") + // ── Disk ── + if cfg.Download.Dir != "" { + if free, total, err := agent.DiskInfo(cfg.Download.Dir); err == nil && total > 0 { + usedPct := float64(total-free) / float64(total) * 100 + cyan.Println(" Disk") + fmt.Printf(" Free: %s / %s (%.0f%% used)\n", formatBytes(free), formatBytes(total), usedPct) + if usedPct > 90 { + yellow.Println(" ⚠ Low disk space!") + } + fmt.Println() + } + } + + // ── Daemon ── + cyan.Println(" Daemon") + state := agent.ReadState() + if state != nil && isDaemonAlive(state) { + green.Printf(" Status: running (PID %d)\n", state.PID) + fmt.Printf(" Uptime: %s\n", formatDuration(time.Since(state.StartedAt))) + fmt.Printf(" Last beat: %s ago\n", formatDuration(time.Since(state.LastHeartbeat))) + fmt.Printf(" Active: %d task(s)\n", state.ActiveTasks) + fmt.Printf(" Completed: %d\n", state.CompletedCount) + if state.FailedCount > 0 { + fmt.Printf(" Failed: %d\n", state.FailedCount) + } + if state.TotalDownloaded > 0 { + fmt.Printf(" Downloaded: %s\n", formatBytes(state.TotalDownloaded)) + } + if len(state.MethodStats) > 0 { + parts := make([]string, 0, len(state.MethodStats)) + for method, count := range state.MethodStats { + parts = append(parts, fmt.Sprintf("%s:%d", method, count)) + } + fmt.Printf(" Methods: %s\n", strings.Join(parts, ", ")) + } + } else { + dim.Println(" Status: stopped") + dim.Println(" Start with: unarr start") + } + fmt.Println() + + // ── Update check (cached: instant if <1h, otherwise async 3s) ── + type versionResult struct { + version string + fromCache bool + err error + } + versionCh := make(chan versionResult, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + v, cached, err := upgrade.CheckLatestCached(ctx) + versionCh <- versionResult{v, cached, err} + }() + + cyan.Println(" Update") + fmt.Print(" Checking... ") + vr := <-versionCh + if vr.err != nil { + dim.Println("could not check (offline?)") + } else { + currentClean := strings.TrimPrefix(Version, "v") + if currentClean == vr.version { + green.Printf("✓ up to date (v%s)\n", vr.version) + } else { + yellow.Printf("v%s available! ", vr.version) + fmt.Printf("Run: unarr upgrade\n") + } + } fmt.Println() return nil } + +// isDaemonAlive checks if the daemon process from the state file is still running. +// Guards against PID reuse by also checking heartbeat recency. +func isDaemonAlive(state *agent.DaemonState) bool { + if state.PID == 0 { + return false + } + // Reject stale state: if last heartbeat is older than 2 minutes, the daemon + // likely crashed and the PID may have been reused by another process. + if !state.LastHeartbeat.IsZero() && time.Since(state.LastHeartbeat) > 2*time.Minute { + return false + } + return agent.IsProcessAlive(state.PID) +} + +// formatBytes formats bytes into human-readable string. +func formatBytes(b int64) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp]) +} + +// formatDuration formats a duration into a compact human-readable string. +func formatDuration(d time.Duration) string { + if d < 0 { + return "0s" + } + if d < time.Minute { + return fmt.Sprintf("%ds", int(d.Seconds())) + } + if d < time.Hour { + return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60) + } + if d < 24*time.Hour { + return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60) + } + days := int(d.Hours()) / 24 + hours := int(d.Hours()) % 24 + return fmt.Sprintf("%dd %dh", days, hours) +} diff --git a/internal/cmd/stream_handler.go b/internal/cmd/stream_handler.go index 9bb9657..88f3111 100644 --- a/internal/cmd/stream_handler.go +++ b/internal/cmd/stream_handler.go @@ -24,6 +24,20 @@ var streamRegistry = struct { servers: make(map[string]*engine.StreamServer), } +// cancelAllStreams cancels all active stream tasks and servers (only 1 stream at a time). +func cancelAllStreams() { + streamRegistry.mu.Lock() + for taskID, cancel := range streamRegistry.cancels { + cancel() + delete(streamRegistry.cancels, taskID) + } + for taskID, srv := range streamRegistry.servers { + srv.Shutdown(context.Background()) + delete(streamRegistry.servers, taskID) + } + streamRegistry.mu.Unlock() +} + // cancelStreamTask cancels a running stream task and shuts down any stream server. func cancelStreamTask(taskID string) { streamRegistry.mu.Lock() @@ -94,7 +108,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine } // 4. Start HTTP server - srv := engine.NewStreamServer(eng, 0) + srv := engine.NewStreamServer(eng, cfg.Download.StreamPort) streamURL, err := srv.Start(ctx) if err != nil { task.ErrorMessage = "start HTTP server: " + err.Error() @@ -107,17 +121,27 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine task.StreamURL = streamURL log.Printf("[%s] stream ready: %s", at.ID[:8], streamURL) - // 6. Progress loop + // 6. Unified progress + idle timeout loop eng.StartProgressLoop(ctx) - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() + progressTicker := time.NewTicker(3 * time.Second) + defer progressTicker.Stop() + idleCheck := time.NewTicker(60 * time.Second) + defer idleCheck.Stop() + completed := false for { select { case <-ctx.Done(): log.Printf("[%s] stream stopped", at.ID[:8]) return - case <-ticker.C: + + case <-idleCheck.C: + if srv.IdleSince() > 30*time.Minute { + log.Printf("[%s] stream idle timeout (30m no HTTP requests), shutting down", at.ID[:8]) + return + } + + case <-progressTicker.C: p := eng.Progress() task.UpdateProgress(engine.Progress{ DownloadedBytes: p.DownloadedBytes, @@ -129,7 +153,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine }) // Terminal progress - if p.TotalBytes > 0 { + if !completed && p.TotalBytes > 0 { pct := int(float64(p.DownloadedBytes) / float64(p.TotalBytes) * 100) fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d", at.ID[:8], pct, @@ -137,20 +161,11 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine p.Peers, p.Seeds) } - if p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 { - fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line + if !completed && p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 { + fmt.Fprint(os.Stderr, "\r\033[2K") task.Transition(engine.StatusCompleted) - log.Printf("[%s] stream download complete, server stays up for 30m or until cancelled", at.ID[:8]) - // Keep HTTP server running so the player can finish reading. - // Auto-shutdown after 30 minutes of idle to prevent resource leaks. - idleTimer := time.NewTimer(30 * time.Minute) - defer idleTimer.Stop() - select { - case <-ctx.Done(): - case <-idleTimer.C: - log.Printf("[%s] stream idle timeout (30m), shutting down", at.ID[:8]) - } - return + log.Printf("[%s] stream download complete, server stays up until idle (30m)", at.ID[:8]) + completed = true } } } diff --git a/internal/cmd/upgrade.go b/internal/cmd/upgrade.go new file mode 100644 index 0000000..c374603 --- /dev/null +++ b/internal/cmd/upgrade.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "github.com/spf13/cobra" +) + +// newUpgradeCmd creates the `unarr upgrade` command as an alias for `self-update`. +func newUpgradeCmd() *cobra.Command { + var force bool + + cmd := &cobra.Command{ + Use: "upgrade", + Aliases: []string{"update"}, + Short: "Update unarr to the latest version", + Long: `Download and install the latest version of unarr. + +This is an alias for 'unarr self-update'. Checks GitHub for the latest +release, verifies the checksum, and replaces the current binary. +A backup is kept at .backup.`, + Example: ` unarr upgrade + unarr upgrade --force`, + RunE: func(cmd *cobra.Command, args []string) error { + return runSelfUpdate(force) + }, + } + + cmd.Flags().BoolVarP(&force, "force", "f", false, "reinstall even if already up to date") + + return cmd +} diff --git a/internal/engine/manager_test.go b/internal/engine/manager_test.go index 7c9893f..84bcc18 100644 --- a/internal/engine/manager_test.go +++ b/internal/engine/manager_test.go @@ -2,6 +2,7 @@ package engine import ( "context" + "os" "testing" "time" @@ -83,3 +84,223 @@ func TestManagerShutdown(t *testing.T) { mgr.Shutdown(ctx) // Should not hang } + +func TestManagerDefaultConcurrency(t *testing.T) { + reporter := NewProgressReporter( + agent.NewClient("http://localhost", "test", "test"), + 1*time.Second, + ) + mgr := NewManager(ManagerConfig{MaxConcurrent: 0}, reporter) + if cap(mgr.sem) != 3 { + t.Errorf("default MaxConcurrent should be 3, got %d", cap(mgr.sem)) + } +} + +func TestManagerGetTask(t *testing.T) { + reporter := NewProgressReporter( + agent.NewClient("http://localhost", "test", "test"), + 1*time.Second, + ) + mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter) + + // No task added + if task := mgr.GetTask("nonexistent"); task != nil { + t.Error("expected nil for nonexistent task") + } +} + +func TestManagerActiveTasks(t *testing.T) { + reporter := NewProgressReporter( + agent.NewClient("http://localhost", "test", "test"), + 1*time.Second, + ) + mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter) + + tasks := mgr.ActiveTasks() + if len(tasks) != 0 { + t.Errorf("expected 0 active tasks, got %d", len(tasks)) + } +} + +func TestManagerSubmitCompletesWithValidFile(t *testing.T) { + dir := t.TempDir() + // Create a file that verify() will accept + filePath := dir + "/movie.mkv" + os.WriteFile(filePath, make([]byte, 1024), 0o644) + + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: 100 * time.Millisecond, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + dl := &resultMockDownloader{ + method: MethodTorrent, + result: &Result{ + FilePath: filePath, + FileName: "movie.mkv", + Method: MethodTorrent, + Size: 1024, + }, + } + + mgr := NewManager(ManagerConfig{ + MaxConcurrent: 2, + OutputDir: dir, + }, pr, dl) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go pr.Run(ctx) + + mgr.Submit(ctx, agent.Task{ + ID: "task-complete-test1", + InfoHash: "abc123def456abc123def456abc123def456abc1", + Title: "Test Movie", + PreferredMethod: "torrent", + }) + + mgr.Wait() + cancel() + + // Task should have completed successfully + // (we can't check directly since it's removed from active map after processing) +} + +func TestManagerCancelTask(t *testing.T) { + reporter := NewProgressReporter( + agent.NewClient("http://localhost", "test", "test"), + 1*time.Second, + ) + + dl := &slowMockDownloader{method: MethodTorrent} + mgr := NewManager(ManagerConfig{ + MaxConcurrent: 2, + OutputDir: t.TempDir(), + }, reporter, dl) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go reporter.Run(ctx) + + mgr.Submit(ctx, agent.Task{ + ID: "task-cancel-test12", + InfoHash: "abc123def456abc123def456abc123def456abc1", + Title: "Cancel Me", + PreferredMethod: "torrent", + }) + + // Give it time to start + time.Sleep(100 * time.Millisecond) + + mgr.CancelTask("task-cancel-test12") + mgr.Wait() +} + +func TestManagerPauseTask(t *testing.T) { + reporter := NewProgressReporter( + agent.NewClient("http://localhost", "test", "test"), + 1*time.Second, + ) + + dl := &slowMockDownloader{method: MethodTorrent} + mgr := NewManager(ManagerConfig{ + MaxConcurrent: 2, + OutputDir: t.TempDir(), + }, reporter, dl) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go reporter.Run(ctx) + + mgr.Submit(ctx, agent.Task{ + ID: "task-pause-test123", + InfoHash: "abc123def456abc123def456abc123def456abc1", + Title: "Pause Me", + PreferredMethod: "torrent", + }) + + time.Sleep(100 * time.Millisecond) + mgr.PauseTask("task-pause-test123") + mgr.Wait() +} + +func TestManagerCancelAndDeleteFiles(t *testing.T) { + reporter := NewProgressReporter( + agent.NewClient("http://localhost", "test", "test"), + 1*time.Second, + ) + + dl := &slowMockDownloader{method: MethodTorrent} + mgr := NewManager(ManagerConfig{ + MaxConcurrent: 2, + OutputDir: t.TempDir(), + }, reporter, dl) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go reporter.Run(ctx) + + mgr.Submit(ctx, agent.Task{ + ID: "task-delfile-test12", + InfoHash: "abc123def456abc123def456abc123def456abc1", + Title: "Delete Me", + PreferredMethod: "torrent", + }) + + time.Sleep(100 * time.Millisecond) + mgr.CancelAndDeleteFiles("task-delfile-test12") + mgr.Wait() +} + +func TestManagerCancelNonexistent(t *testing.T) { + reporter := NewProgressReporter( + agent.NewClient("http://localhost", "test", "test"), + 1*time.Second, + ) + mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter) + // Should not panic + mgr.CancelTask("nonexistent") + mgr.PauseTask("nonexistent") + mgr.CancelAndDeleteFiles("nonexistent") +} + +// resultMockDownloader returns a configurable result +type resultMockDownloader struct { + method DownloadMethod + result *Result +} + +func (m *resultMockDownloader) Method() DownloadMethod { return m.method } +func (m *resultMockDownloader) Available(_ context.Context, _ *Task) (bool, error) { + return true, nil +} +func (m *resultMockDownloader) Download(_ context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) { + return m.result, nil +} +func (m *resultMockDownloader) Pause(_ string) error { return nil } +func (m *resultMockDownloader) Cancel(_ string) error { return nil } +func (m *resultMockDownloader) Shutdown(_ context.Context) error { return nil } + +// slowMockDownloader blocks until context is cancelled +type slowMockDownloader struct { + method DownloadMethod +} + +func (m *slowMockDownloader) Method() DownloadMethod { return m.method } +func (m *slowMockDownloader) Available(_ context.Context, _ *Task) (bool, error) { + return true, nil +} +func (m *slowMockDownloader) Download(ctx context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) { + <-ctx.Done() + return nil, ctx.Err() +} +func (m *slowMockDownloader) Pause(_ string) error { return nil } +func (m *slowMockDownloader) Cancel(_ string) error { return nil } +func (m *slowMockDownloader) Shutdown(_ context.Context) error { return nil } diff --git a/internal/engine/method_test.go b/internal/engine/method_test.go new file mode 100644 index 0000000..e913d32 --- /dev/null +++ b/internal/engine/method_test.go @@ -0,0 +1,50 @@ +package engine + +import "testing" + +func TestDownloadMethodConstants(t *testing.T) { + if MethodTorrent != "torrent" { + t.Errorf("MethodTorrent = %q, want torrent", MethodTorrent) + } + if MethodDebrid != "debrid" { + t.Errorf("MethodDebrid = %q, want debrid", MethodDebrid) + } + if MethodUsenet != "usenet" { + t.Errorf("MethodUsenet = %q, want usenet", MethodUsenet) + } +} + +func TestProgressStruct(t *testing.T) { + p := Progress{ + DownloadedBytes: 1024, + TotalBytes: 2048, + SpeedBps: 512, + ETA: 10, + Peers: 5, + Seeds: 3, + FileName: "movie.mkv", + } + + if p.DownloadedBytes != 1024 { + t.Errorf("DownloadedBytes = %d, want 1024", p.DownloadedBytes) + } + if p.FileName != "movie.mkv" { + t.Errorf("FileName = %q, want movie.mkv", p.FileName) + } +} + +func TestResultStruct(t *testing.T) { + r := Result{ + FilePath: "/downloads/movie.mkv", + FileName: "movie.mkv", + Method: MethodTorrent, + Size: 1073741824, + } + + if r.Method != MethodTorrent { + t.Errorf("Method = %q, want torrent", r.Method) + } + if r.Size != 1073741824 { + t.Errorf("Size = %d, want 1073741824", r.Size) + } +} diff --git a/internal/engine/organize_expand_test.go b/internal/engine/organize_expand_test.go new file mode 100644 index 0000000..0a7d2f2 --- /dev/null +++ b/internal/engine/organize_expand_test.go @@ -0,0 +1,181 @@ +package engine + +import ( + "os" + "path/filepath" + "testing" +) + +func TestReplaceFile(t *testing.T) { + tmp := t.TempDir() + backupDir := filepath.Join(tmp, "backups") + + // Create "old" file + oldPath := filepath.Join(tmp, "movie.mkv") + os.WriteFile(oldPath, []byte("old content"), 0o644) + + // Create "new" file + newPath := filepath.Join(tmp, "movie-new.mkv") + os.WriteFile(newPath, []byte("new better content"), 0o644) + + err := replaceFile(oldPath, newPath, backupDir) + if err != nil { + t.Fatalf("replaceFile: %v", err) + } + + // Old path should now contain new content + data, err := os.ReadFile(oldPath) + if err != nil { + t.Fatalf("read old path: %v", err) + } + if string(data) != "new better content" { + t.Errorf("old path content = %q, want 'new better content'", string(data)) + } + + // Backup should exist + entries, _ := os.ReadDir(backupDir) + if len(entries) != 1 { + t.Errorf("expected 1 backup file, got %d", len(entries)) + } + + // New file should be gone + if _, err := os.Stat(newPath); !os.IsNotExist(err) { + t.Error("new file should have been moved/deleted") + } +} + +func TestReplaceFileOldNotFound(t *testing.T) { + tmp := t.TempDir() + err := replaceFile(filepath.Join(tmp, "nonexistent.mkv"), filepath.Join(tmp, "new.mkv"), "") + if err == nil { + t.Error("expected error when old file doesn't exist") + } +} + +func TestCopyFile(t *testing.T) { + tmp := t.TempDir() + + src := filepath.Join(tmp, "source.txt") + dst := filepath.Join(tmp, "dest.txt") + + content := []byte("hello world copy test") + os.WriteFile(src, content, 0o644) + + err := copyFile(src, dst) + if err != nil { + t.Fatalf("copyFile: %v", err) + } + + data, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("read dest: %v", err) + } + if string(data) != string(content) { + t.Errorf("dest content = %q, want %q", string(data), string(content)) + } +} + +func TestCopyFileSrcNotFound(t *testing.T) { + tmp := t.TempDir() + err := copyFile(filepath.Join(tmp, "nope.txt"), filepath.Join(tmp, "out.txt")) + if err == nil { + t.Error("expected error when source doesn't exist") + } +} + +func TestOrganizeNoDirs(t *testing.T) { + r := &Result{FilePath: "/tmp/file.mkv", FileName: "file.mkv"} + task := &Task{Title: "Movie"} + + path, err := organize(r, task, OrganizeConfig{Enabled: true}) + if err != nil { + t.Fatal(err) + } + if path != "/tmp/file.mkv" { + t.Errorf("should return original path when no dirs configured, got %q", path) + } +} + +func TestOrganizeNilResult(t *testing.T) { + task := &Task{Title: "Movie"} + path, err := organize(&Result{}, task, OrganizeConfig{Enabled: true}) + if err != nil { + t.Fatal(err) + } + if path != "" { + t.Errorf("expected empty path for empty result, got %q", path) + } +} + +func TestOrganizeMovieDirectory(t *testing.T) { + tmp := t.TempDir() + srcDir := filepath.Join(tmp, "src", "MovieDir") + os.MkdirAll(srcDir, 0o755) + os.WriteFile(filepath.Join(srcDir, "movie.mkv"), []byte("data"), 0o644) + + moviesDir := filepath.Join(tmp, "Movies") + + r := &Result{FilePath: srcDir, FileName: "MovieDir"} + task := &Task{Title: "My Movie 2023"} + + path, err := organize(r, task, OrganizeConfig{ + Enabled: true, + MoviesDir: moviesDir, + }) + if err != nil { + t.Fatal(err) + } + + if path == srcDir { + t.Error("directory should have moved") + } + if _, err := os.Stat(path); err != nil { + t.Errorf("organized directory should exist at %s", path) + } +} + +func TestOrganizeSeasonOnly(t *testing.T) { + tmp := t.TempDir() + srcFile := filepath.Join(tmp, "Show.S01.Complete.mkv") + os.WriteFile(srcFile, []byte("data"), 0o644) + + tvDir := filepath.Join(tmp, "TV") + + r := &Result{FilePath: srcFile, FileName: "Show.S01.Complete.mkv"} + task := &Task{Title: "Show S01"} + + path, err := organize(r, task, OrganizeConfig{ + Enabled: true, + TVShowsDir: tvDir, + }) + if err != nil { + t.Fatal(err) + } + + dir := filepath.Dir(path) + if filepath.Base(dir) != "Season 01" { + t.Errorf("expected Season 01 directory, got %q", filepath.Base(dir)) + } +} + +func TestCleanTitleEdgeCases(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"", ""}, + {"Simple Title", "Simple Title"}, + {"Title (2023) 1080p BluRay", "Title"}, + {"Title 720p HDTV", "Title"}, + {"Title x264 HEVC", "Title"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := cleanTitle(tt.input) + if got != tt.want { + t.Errorf("cleanTitle(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/engine/progress_test.go b/internal/engine/progress_test.go new file mode 100644 index 0000000..1bb36c6 --- /dev/null +++ b/internal/engine/progress_test.go @@ -0,0 +1,419 @@ +package engine + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/torrentclaw/unarr/internal/agent" +) + +// mockStatusReporter records calls to ReportStatus. +type mockStatusReporter struct { + mu sync.Mutex + calls []agent.StatusUpdate + resp *agent.StatusResponse + respErr error +} + +func (m *mockStatusReporter) ReportStatus(_ context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.calls = append(m.calls, update) + if m.resp != nil { + return m.resp, m.respErr + } + return &agent.StatusResponse{}, m.respErr +} + +// mockBatchReporter records batch calls. +type mockBatchReporter struct { + mockStatusReporter + batchCalls [][]agent.StatusUpdate + batchResp *agent.BatchStatusResponse +} + +func (m *mockBatchReporter) BatchReportStatus(_ context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.batchCalls = append(m.batchCalls, updates) + if m.batchResp != nil { + return m.batchResp, nil + } + results := make([]agent.StatusResponse, len(updates)) + return &agent.BatchStatusResponse{Results: results}, nil +} + +func TestProgressReporter_TrackUntrack(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + task := &Task{ID: "task-001", Status: StatusDownloading} + pr.Track(task) + + pr.mu.Lock() + if _, ok := pr.latest["task-001"]; !ok { + t.Error("task should be tracked") + } + pr.mu.Unlock() + + pr.Untrack("task-001") + + pr.mu.Lock() + if _, ok := pr.latest["task-001"]; ok { + t.Error("task should be untracked") + } + pr.mu.Unlock() +} + +func TestProgressReporter_FlushReportsFinalStates(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + completed := &Task{ID: "task-completed-1234", Status: StatusCompleted} + pr.Track(completed) + + pr.flush(context.Background()) + + reporter.mu.Lock() + defer reporter.mu.Unlock() + if len(reporter.calls) != 1 { + t.Fatalf("expected 1 report, got %d", len(reporter.calls)) + } + if reporter.calls[0].TaskID != "task-completed-1234" { + t.Errorf("reported wrong task: %s", reporter.calls[0].TaskID) + } +} + +func TestProgressReporter_FlushSkipsWhenNotWatching(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + isWatching: func() bool { return false }, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + lastCheckAt: time.Now(), // not due for control check + } + + // Active downloading task, already reported as downloading + task := &Task{ID: "task-active-12345678", Status: StatusDownloading} + pr.Track(task) + pr.mu.Lock() + pr.lastReported["task-active-12345678"] = StatusDownloading + pr.mu.Unlock() + + pr.flush(context.Background()) + + reporter.mu.Lock() + defer reporter.mu.Unlock() + if len(reporter.calls) != 0 { + t.Errorf("expected 0 reports when not watching (no transition), got %d", len(reporter.calls)) + } +} + +func TestProgressReporter_FlushReportsTransitions(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + isWatching: func() bool { return false }, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + lastCheckAt: time.Now(), + } + + // Task transitioning from resolving to downloading + task := &Task{ID: "task-trans-12345678", Status: StatusDownloading} + pr.Track(task) + pr.mu.Lock() + pr.lastReported["task-trans-12345678"] = StatusResolving + pr.mu.Unlock() + + pr.flush(context.Background()) + + reporter.mu.Lock() + defer reporter.mu.Unlock() + if len(reporter.calls) != 1 { + t.Fatalf("expected 1 report for transition, got %d", len(reporter.calls)) + } +} + +func TestProgressReporter_FlushActiveWhenWatching(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + isWatching: func() bool { return true }, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + task := &Task{ID: "task-watch-12345678", Status: StatusDownloading} + pr.Track(task) + pr.mu.Lock() + pr.lastReported["task-watch-12345678"] = StatusDownloading + pr.mu.Unlock() + + pr.flush(context.Background()) + + reporter.mu.Lock() + defer reporter.mu.Unlock() + if len(reporter.calls) != 1 { + t.Fatalf("expected 1 report when watching active task, got %d", len(reporter.calls)) + } +} + +func TestProgressReporter_HandleResponseCancel(t *testing.T) { + reporter := &mockStatusReporter{ + resp: &agent.StatusResponse{Cancelled: true}, + } + + var cancelledID string + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + onCancel: func(id string) { cancelledID = id }, + } + + task := &Task{ID: "task-cancel-1234567", Status: StatusCompleted} + pr.Track(task) + + pr.flush(context.Background()) + + if cancelledID != "task-cancel-1234567" { + t.Errorf("expected cancel handler called with task ID, got %q", cancelledID) + } +} + +func TestProgressReporter_HandleResponsePause(t *testing.T) { + reporter := &mockStatusReporter{ + resp: &agent.StatusResponse{Paused: true}, + } + + var pausedID string + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + onPause: func(id string) { pausedID = id }, + } + + task := &Task{ID: "task-paused-1234567", Status: StatusCompleted} + pr.Track(task) + + pr.flush(context.Background()) + + if pausedID != "task-paused-1234567" { + t.Errorf("expected pause handler called, got %q", pausedID) + } +} + +func TestProgressReporter_HandleResponseDeleteFiles(t *testing.T) { + reporter := &mockStatusReporter{ + resp: &agent.StatusResponse{Cancelled: true, DeleteFiles: true}, + } + + var deletedID string + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + onDeleteFiles: func(id string) { deletedID = id }, + } + + task := &Task{ID: "task-delete-1234567", Status: StatusCompleted} + pr.Track(task) + + pr.flush(context.Background()) + + if deletedID != "task-delete-1234567" { + t.Errorf("expected deleteFiles handler called, got %q", deletedID) + } +} + +func TestProgressReporter_HandleResponseStream(t *testing.T) { + reporter := &mockStatusReporter{ + resp: &agent.StatusResponse{StreamRequested: true}, + } + + var streamID string + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + onStreamRequested: func(id string) { streamID = id }, + } + + // Task with no stream URL yet + task := &Task{ID: "task-stream-1234567", Status: StatusCompleted} + pr.Track(task) + + pr.flush(context.Background()) + + if streamID != "task-stream-1234567" { + t.Errorf("expected stream handler called, got %q", streamID) + } +} + +func TestProgressReporter_HandleResponseWatchingChanged(t *testing.T) { + reporter := &mockStatusReporter{ + resp: &agent.StatusResponse{Watching: true}, + } + + var watchingValue bool + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + onWatchingChanged: func(w bool) { watchingValue = w }, + } + + task := &Task{ID: "task-watch2-1234567", Status: StatusCompleted} + pr.Track(task) + + pr.flush(context.Background()) + + if !watchingValue { + t.Error("expected watchingChanged called with true") + } +} + +func TestProgressReporter_BatchFlush(t *testing.T) { + batcher := &mockBatchReporter{ + batchResp: &agent.BatchStatusResponse{ + Results: []agent.StatusResponse{{}, {}}, + }, + } + + pr := &ProgressReporter{ + reporter: batcher, + interval: time.Second, + isWatching: func() bool { return true }, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + pr.Track(&Task{ID: "task-batch1-1234567", Status: StatusDownloading}) + pr.Track(&Task{ID: "task-batch2-1234567", Status: StatusDownloading}) + + pr.flush(context.Background()) + + batcher.mu.Lock() + defer batcher.mu.Unlock() + + if len(batcher.batchCalls) != 1 { + t.Fatalf("expected 1 batch call, got %d", len(batcher.batchCalls)) + } + if len(batcher.batchCalls[0]) != 2 { + t.Errorf("expected 2 updates in batch, got %d", len(batcher.batchCalls[0])) + } +} + +func TestProgressReporter_RunStopsOnCancel(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: 50 * time.Millisecond, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + err := pr.Run(ctx) + if err != nil { + t.Errorf("Run should return nil on context cancel, got: %v", err) + } +} + +func TestProgressReporter_ReportFinal(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + task := &Task{ID: "task-final-12345678", Status: StatusCompleted} + pr.Track(task) + + pr.ReportFinal(context.Background(), task) + + reporter.mu.Lock() + defer reporter.mu.Unlock() + if len(reporter.calls) != 1 { + t.Fatalf("expected 1 final report, got %d", len(reporter.calls)) + } + + // Should be untracked after final report + pr.mu.Lock() + if _, ok := pr.latest["task-final-12345678"]; ok { + t.Error("task should be untracked after ReportFinal") + } + pr.mu.Unlock() +} + +func TestProgressReporter_SetHandlers(t *testing.T) { + pr := &ProgressReporter{ + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + } + + pr.SetCancelHandler(func(id string) {}) + pr.SetPauseHandler(func(id string) {}) + pr.SetDeleteFilesHandler(func(id string) {}) + pr.SetStreamRequestedHandler(func(id string) {}) + pr.SetWatchingFunc(func() bool { return true }) + pr.SetWatchingChangedHandler(func(w bool) {}) + + if pr.onCancel == nil || pr.onPause == nil || pr.onDeleteFiles == nil || + pr.onStreamRequested == nil || pr.isWatching == nil || pr.onWatchingChanged == nil { + t.Error("expected all handlers to be set") + } +} + +func TestProgressReporter_ControlCheckDue(t *testing.T) { + reporter := &mockStatusReporter{} + pr := &ProgressReporter{ + reporter: reporter, + interval: time.Second, + isWatching: func() bool { return false }, + latest: make(map[string]*Task), + lastReported: make(map[string]TaskStatus), + lastCheckAt: time.Now().Add(-31 * time.Second), // 31s ago - due for control check + } + + task := &Task{ID: "task-ctrl-123456789", Status: StatusDownloading} + pr.Track(task) + pr.mu.Lock() + pr.lastReported["task-ctrl-123456789"] = StatusDownloading + pr.mu.Unlock() + + pr.flush(context.Background()) + + reporter.mu.Lock() + defer reporter.mu.Unlock() + if len(reporter.calls) != 1 { + t.Errorf("expected 1 report for control check, got %d", len(reporter.calls)) + } +} diff --git a/internal/engine/stream_server.go b/internal/engine/stream_server.go index 1635d69..e85cb13 100644 --- a/internal/engine/stream_server.go +++ b/internal/engine/stream_server.go @@ -11,6 +11,7 @@ import ( "os/exec" "path/filepath" "strings" + "sync/atomic" "time" "github.com/anacrolix/torrent" @@ -24,11 +25,12 @@ type fileProvider interface { // StreamServer serves a torrent file over HTTP with Range request support. type StreamServer struct { - provider fileProvider - server *http.Server - port int - url string - upnpMapping *UPnPMapping + provider fileProvider + server *http.Server + port int + url string + upnpMapping *UPnPMapping + lastActivity atomic.Int64 // UnixNano of last HTTP request } // NewStreamServer creates a new HTTP server for streaming via StreamEngine. @@ -93,11 +95,38 @@ func NewStreamServerFromDisk(filePath string, port int) *StreamServer { } } -// Start begins serving the file on all interfaces. Returns the best reachable URL: -// 1. UPnP public IP (accessible from anywhere on the internet) -// 2. Tailscale IP (accessible from any device in the tailnet) -// 3. LAN IP (accessible from local network) +// FindVideoFile scans a directory (recursively) for the largest video file. +// Returns empty string if no video file found. +func FindVideoFile(dir string) string { + var best string + var bestSize int64 + + filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error { + if err != nil || d.IsDir() { + return nil + } + ext := strings.ToLower(filepath.Ext(d.Name())) + if !VideoExts[ext] { + return nil + } + info, err := d.Info() + if err != nil { + return nil + } + if info.Size() > bestSize { + best = path + bestSize = info.Size() + } + return nil + }) + return best +} + +// Start begins serving the file on all interfaces. Returns the best reachable URL. +// The file is served as-is — the user's media player (VLC, mpv, etc.) handles decoding. func (ss *StreamServer) Start(ctx context.Context) (string, error) { + ss.lastActivity.Store(time.Now().UnixNano()) + mux := http.NewServeMux() mux.HandleFunc("/stream", ss.handler) @@ -107,19 +136,9 @@ func (ss *StreamServer) Start(ctx context.Context) (string, error) { return "", fmt.Errorf("listen on %s: %w", addr, err) } - // Extract actual port (important when port=0) ss.port = listener.Addr().(*net.TCPAddr).Port - - // Try UPnP for public internet access (like Plex Remote Access) - if mapping, upnpErr := setupUPnP(ss.port); upnpErr == nil { - ss.upnpMapping = mapping - ss.url = fmt.Sprintf("http://%s:%d/stream", mapping.ExternalIP, mapping.ExternalPort) - log.Printf("stream: UPnP mapped %s:%d -> local:%d", mapping.ExternalIP, mapping.ExternalPort, ss.port) - } else { - // Fallback: Tailscale IP > LAN IP > 127.0.0.1 - ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port) - log.Printf("stream: UPnP unavailable (%v), using %s", upnpErr, ss.url) - } + ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port) + log.Printf("stream: serving on %s", ss.url) ss.server = &http.Server{ Handler: mux, @@ -141,6 +160,15 @@ func (ss *StreamServer) URL() string { return ss.url } // Port returns the bound port. func (ss *StreamServer) Port() int { return ss.port } +// IdleSince returns how long since the last HTTP request was received. +func (ss *StreamServer) IdleSince() time.Duration { + last := ss.lastActivity.Load() + if last == 0 { + return 0 + } + return time.Since(time.Unix(0, last)) +} + // Shutdown gracefully stops the HTTP server and removes the UPnP port mapping. func (ss *StreamServer) Shutdown(ctx context.Context) error { ss.upnpMapping.Remove() @@ -151,6 +179,8 @@ func (ss *StreamServer) Shutdown(ctx context.Context) error { } func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) { + ss.lastActivity.Store(time.Now().UnixNano()) + // CORS headers — only when browser sends Origin (HTTPS site → localhost) if origin := r.Header.Get("Origin"); origin != "" { w.Header().Set("Access-Control-Allow-Origin", "*") diff --git a/internal/library/mediainfo/ffprobe.go b/internal/library/mediainfo/ffprobe.go index f2c70fb..723ef6f 100644 --- a/internal/library/mediainfo/ffprobe.go +++ b/internal/library/mediainfo/ffprobe.go @@ -78,6 +78,12 @@ func ExtractMediaInfo(ctx context.Context, ffprobePath, filePath string) (*Media return nil, fmt.Errorf("ffprobe JSON parse failed: %w", err) } + return parseFFprobeOutput(data) +} + +// parseFFprobeOutput converts parsed ffprobe JSON into MediaInfo. +// Separated from ExtractMediaInfo so it can be tested without running ffprobe. +func parseFFprobeOutput(data ffprobeOutput) (*MediaInfo, error) { if len(data.Streams) == 0 { return nil, fmt.Errorf("ffprobe returned no streams") } diff --git a/internal/library/mediainfo/ffprobe_test.go b/internal/library/mediainfo/ffprobe_test.go new file mode 100644 index 0000000..e29eed1 --- /dev/null +++ b/internal/library/mediainfo/ffprobe_test.go @@ -0,0 +1,430 @@ +package mediainfo + +import ( + "testing" +) + +func TestParseDuration(t *testing.T) { + tests := []struct { + input string + want float64 + }{ + {"", 0}, + {"0", 0}, + {"-5", 0}, + {"invalid", 0}, + {"7423.500000", 7423.5}, + {"120.123456", 120.123}, + {"3600", 3600}, + {"0.001", 0.001}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := parseDuration(tt.input) + if got != tt.want { + t.Errorf("parseDuration(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestTagValue(t *testing.T) { + tags := map[string]string{ + "language": "eng", + "title": "Main Audio", + "HANDLER": "VideoHandler", + } + + tests := []struct { + key string + want string + }{ + {"language", "eng"}, + {"title", "Main Audio"}, + {"handler", "VideoHandler"}, + {"missing", ""}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got := tagValue(tags, tt.key) + if got != tt.want { + t.Errorf("tagValue(tags, %q) = %q, want %q", tt.key, got, tt.want) + } + }) + } +} + +func TestTagValueNil(t *testing.T) { + got := tagValue(nil, "language") + if got != "" { + t.Errorf("tagValue(nil, language) = %q, want empty", got) + } +} + +func TestContainsAny(t *testing.T) { + tests := []struct { + s string + subs []string + want bool + }{ + {"yuv420p10le", []string{"10le", "10be", "p010"}, true}, + {"yuv420p12be", []string{"10le", "10be", "p010"}, false}, + {"yuv420p12be", []string{"12le", "12be"}, true}, + {"yuv420p", []string{"10le", "10be"}, false}, + {"", []string{"any"}, false}, + {"something", []string{}, false}, + } + + for _, tt := range tests { + got := containsAny(tt.s, tt.subs...) + if got != tt.want { + t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.subs, got, tt.want) + } + } +} + +func TestParseFFprobeOutput_BasicH264(t *testing.T) { + data := ffprobeOutput{ + Format: ffprobeFormat{Duration: "7423.5"}, + Streams: []ffprobeStream{ + { + CodecType: "video", + CodecName: "h264", + Profile: "High", + Width: 1920, + Height: 1080, + RFrameRate: "24000/1001", + }, + { + CodecType: "audio", + CodecName: "aac", + Channels: 2, + Tags: map[string]string{"language": "eng"}, + Disposition: map[string]int{"default": 1}, + }, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatalf("parseFFprobeOutput: %v", err) + } + if mi.Video == nil { + t.Fatal("expected video info") + } + if mi.Video.Codec != "h264" { + t.Errorf("codec = %q, want h264", mi.Video.Codec) + } + if mi.Video.Width != 1920 || mi.Video.Height != 1080 { + t.Errorf("dimensions = %dx%d, want 1920x1080", mi.Video.Width, mi.Video.Height) + } + if mi.Video.Profile != "High" { + t.Errorf("profile = %q, want High", mi.Video.Profile) + } + if mi.Video.Duration != 7423.5 { + t.Errorf("duration = %v, want 7423.5", mi.Video.Duration) + } + if mi.Video.FrameRate < 23.975 || mi.Video.FrameRate > 23.977 { + t.Errorf("frameRate = %v, want ~23.976", mi.Video.FrameRate) + } + if len(mi.Audio) != 1 { + t.Fatalf("audio tracks = %d, want 1", len(mi.Audio)) + } + if mi.Audio[0].Lang != "en" { + t.Errorf("audio lang = %q, want en", mi.Audio[0].Lang) + } + if !mi.Audio[0].Default { + t.Error("expected default audio track") + } +} + +func TestParseFFprobeOutput_HEVC_HDR10(t *testing.T) { + data := ffprobeOutput{ + Format: ffprobeFormat{Duration: "3600"}, + Streams: []ffprobeStream{ + { + CodecType: "video", + CodecName: "hevc", + Width: 3840, + Height: 2160, + BitsPerRaw: "10", + ColorSpace: "bt2020nc", + ColorTransfer: "smpte2084", + RFrameRate: "24/1", + }, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.HDR != "HDR10" { + t.Errorf("hdr = %q, want HDR10", mi.Video.HDR) + } + if mi.Video.BitDepth != 10 { + t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth) + } +} + +func TestParseFFprobeOutput_DolbyVisionWithHDR10(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + { + CodecType: "video", + CodecName: "hevc", + Width: 3840, + Height: 2160, + ColorSpace: "bt2020nc", + ColorTransfer: "smpte2084", + SideDataList: []sideData{{SideDataType: "DOVI configuration record"}}, + }, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.HDR != "DV+HDR10" { + t.Errorf("hdr = %q, want DV+HDR10", mi.Video.HDR) + } +} + +func TestParseFFprobeOutput_DolbyVisionOnly(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + { + CodecType: "video", + CodecName: "hevc", + Width: 3840, + Height: 2160, + SideDataList: []sideData{{SideDataType: "DOVI configuration record"}}, + }, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.HDR != "DV" { + t.Errorf("hdr = %q, want DV", mi.Video.HDR) + } +} + +func TestParseFFprobeOutput_HLG(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + { + CodecType: "video", + CodecName: "hevc", + Width: 3840, + Height: 2160, + ColorSpace: "bt2020nc", + ColorTransfer: "arib-std-b67", + }, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.HDR != "HLG" { + t.Errorf("hdr = %q, want HLG", mi.Video.HDR) + } +} + +func TestParseFFprobeOutput_MultiAudioAndSubtitles(t *testing.T) { + data := ffprobeOutput{ + Format: ffprobeFormat{Duration: "5400"}, + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080}, + { + CodecType: "audio", CodecName: "ac3", Channels: 6, + Tags: map[string]string{"language": "eng", "title": "English 5.1"}, + Disposition: map[string]int{"default": 1}, + }, + { + CodecType: "audio", CodecName: "aac", Channels: 2, + Tags: map[string]string{"language": "spa"}, + }, + { + CodecType: "subtitle", CodecName: "subrip", + Tags: map[string]string{"language": "eng"}, + }, + { + CodecType: "subtitle", CodecName: "ass", + Tags: map[string]string{"language": "spa"}, + Disposition: map[string]int{"forced": 1}, + }, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if len(mi.Audio) != 2 { + t.Fatalf("audio tracks = %d, want 2", len(mi.Audio)) + } + if mi.Audio[0].Title != "English 5.1" { + t.Errorf("audio[0].title = %q", mi.Audio[0].Title) + } + if len(mi.Subtitles) != 2 { + t.Fatalf("subtitle tracks = %d, want 2", len(mi.Subtitles)) + } + if !mi.Subtitles[1].Forced { + t.Error("expected subtitle[1] to be forced") + } + if len(mi.Languages) != 2 { + t.Errorf("languages = %v, want 2 entries", mi.Languages) + } +} + +func TestParseFFprobeOutput_BitDepthFromPixFmt(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p10le"}, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.BitDepth != 10 { + t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth) + } +} + +func TestParseFFprobeOutput_12BitFromPixFmt(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p12be"}, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.BitDepth != 12 { + t.Errorf("bitDepth = %d, want 12", mi.Video.BitDepth) + } +} + +func TestParseFFprobeOutput_DurationFromStreamFallback(t *testing.T) { + data := ffprobeOutput{ + Format: ffprobeFormat{Duration: ""}, + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "h264", Width: 1280, Height: 720, Duration: "1800.5"}, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.Duration != 1800.5 { + t.Errorf("duration = %v, want 1800.5", mi.Video.Duration) + } +} + +func TestParseFFprobeOutput_NoStreams(t *testing.T) { + data := ffprobeOutput{} + _, err := parseFFprobeOutput(data) + if err == nil { + t.Error("expected error for no streams") + } +} + +func TestParseFFprobeOutput_OnlyFirstVideoStream(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080}, + {CodecType: "video", CodecName: "mjpeg", Width: 320, Height: 240}, // cover art + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.Codec != "h264" { + t.Errorf("should use first video stream, got codec %q", mi.Video.Codec) + } + if mi.Video.Width != 1920 { + t.Errorf("width = %d, should be from first video stream", mi.Video.Width) + } +} + +func TestParseFFprobeOutput_SMPTE2084_WithoutBT2020(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "smpte2084"}, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.HDR != "HDR10" { + t.Errorf("hdr = %q, want HDR10", mi.Video.HDR) + } +} + +func TestParseFFprobeOutput_AribWithoutBT2020(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "arib-std-b67", ColorSpace: "other"}, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.HDR != "HLG" { + t.Errorf("hdr = %q, want HLG", mi.Video.HDR) + } +} + +func TestParseFFprobeOutput_AudioOnly(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + {CodecType: "audio", CodecName: "flac", Channels: 2, Tags: map[string]string{"language": "eng"}}, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video != nil { + t.Error("expected no video info for audio-only") + } + if len(mi.Audio) != 1 { + t.Errorf("audio tracks = %d, want 1", len(mi.Audio)) + } +} + +func TestParseFFprobeOutput_FrameRateNoSlash(t *testing.T) { + data := ffprobeOutput{ + Streams: []ffprobeStream{ + {CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080, RFrameRate: "30"}, + }, + } + + mi, err := parseFFprobeOutput(data) + if err != nil { + t.Fatal(err) + } + if mi.Video.FrameRate != 0 { + t.Errorf("frameRate = %v, want 0 (no slash)", mi.Video.FrameRate) + } +} diff --git a/internal/library/scanner_test.go b/internal/library/scanner_test.go new file mode 100644 index 0000000..43c84b2 --- /dev/null +++ b/internal/library/scanner_test.go @@ -0,0 +1,93 @@ +package library + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDiscoverFiles(t *testing.T) { + dir := t.TempDir() + + // Create video files (need to be >= 100MB to pass size check) + largeContent := make([]byte, 101*1024*1024) + + videoFiles := []string{"movie.mkv", "show.mp4", "clip.avi"} + for _, name := range videoFiles { + path := filepath.Join(dir, name) + if err := os.WriteFile(path, largeContent, 0o644); err != nil { + t.Fatalf("write %s: %v", name, err) + } + } + + // Non-video files (should be excluded) + nonVideo := []string{"readme.txt", "cover.jpg", "subs.srt"} + for _, name := range nonVideo { + if err := os.WriteFile(filepath.Join(dir, name), largeContent, 0o644); err != nil { + t.Fatalf("write %s: %v", name, err) + } + } + + // Small video file (should be excluded, < 100MB) + if err := os.WriteFile(filepath.Join(dir, "small.mkv"), []byte("small"), 0o644); err != nil { + t.Fatal(err) + } + + // Excluded pattern (sample) + sampleDir := filepath.Join(dir, "sample") + os.MkdirAll(sampleDir, 0o755) + if err := os.WriteFile(filepath.Join(sampleDir, "sample.mkv"), largeContent, 0o644); err != nil { + t.Fatal(err) + } + + files, err := discoverFiles(dir) + if err != nil { + t.Fatalf("discoverFiles: %v", err) + } + + if len(files) != 3 { + t.Errorf("expected 3 files, got %d: %v", len(files), files) + } + + // Check that all returned files are video extensions + for _, f := range files { + ext := filepath.Ext(f) + if ext != ".mkv" && ext != ".mp4" && ext != ".avi" { + t.Errorf("unexpected extension: %s", ext) + } + } +} + +func TestDiscoverFilesEmptyDir(t *testing.T) { + dir := t.TempDir() + + files, err := discoverFiles(dir) + if err != nil { + t.Fatalf("discoverFiles: %v", err) + } + if len(files) != 0 { + t.Errorf("expected 0 files, got %d", len(files)) + } +} + +func TestDiscoverFilesExcludePatterns(t *testing.T) { + dir := t.TempDir() + largeContent := make([]byte, 101*1024*1024) + + excludeDirs := []string{"trailer", "featurette", "extras", "bonus"} + for _, name := range excludeDirs { + sub := filepath.Join(dir, name) + os.MkdirAll(sub, 0o755) + if err := os.WriteFile(filepath.Join(sub, "video.mkv"), largeContent, 0o644); err != nil { + t.Fatal(err) + } + } + + files, err := discoverFiles(dir) + if err != nil { + t.Fatal(err) + } + if len(files) != 0 { + t.Errorf("expected 0 files (all excluded), got %d: %v", len(files), files) + } +} diff --git a/internal/library/sync_test.go b/internal/library/sync_test.go new file mode 100644 index 0000000..fe7a113 --- /dev/null +++ b/internal/library/sync_test.go @@ -0,0 +1,108 @@ +package library + +import ( + "testing" + + "github.com/torrentclaw/unarr/internal/library/mediainfo" +) + +func TestBuildSyncItems(t *testing.T) { + cache := &LibraryCache{ + Items: []LibraryItem{ + { + FilePath: "/media/movies/Inception.mkv", + FileName: "Inception.2010.1080p.mkv", + FileSize: 5000000000, + Title: "Inception", + Year: "2010", + MediaInfo: &mediainfo.MediaInfo{ + Video: &mediainfo.VideoInfo{ + Codec: "hevc", + Width: 1920, + Height: 1080, + BitDepth: 10, + HDR: "HDR10", + }, + Audio: []mediainfo.AudioTrack{ + {Lang: "en", Codec: "ac3", Channels: 6, Default: true}, + {Lang: "es", Codec: "aac", Channels: 2}, + }, + Subtitles: []mediainfo.SubtitleTrack{ + {Lang: "en", Codec: "subrip"}, + {Lang: "es", Codec: "subrip"}, + }, + }, + }, + { + FilePath: "/media/shows/Breaking.Bad.S01E01.mkv", + FileName: "Breaking.Bad.S01E01.mkv", + FileSize: 1000000000, + Title: "Breaking Bad", + Season: 1, + Episode: 1, + }, + { + // Item with scan error — should be skipped + FilePath: "/media/bad.mkv", + FileName: "bad.mkv", + ScanError: "ffprobe failed", + }, + }, + } + + items := BuildSyncItems(cache) + + if len(items) != 2 { + t.Fatalf("expected 2 items (1 skipped), got %d", len(items)) + } + + // First item: movie with full media info + movie := items[0] + if movie.Title != "Inception" { + t.Errorf("title = %q, want Inception", movie.Title) + } + if movie.ContentType != "movie" { + t.Errorf("contentType = %q, want movie", movie.ContentType) + } + if movie.Resolution != "1080p" { + t.Errorf("resolution = %q, want 1080p", movie.Resolution) + } + if movie.VideoCodec != "hevc" { + t.Errorf("videoCodec = %q, want hevc", movie.VideoCodec) + } + if movie.HDR != "HDR10" { + t.Errorf("hdr = %q, want HDR10", movie.HDR) + } + if movie.AudioCodec != "ac3" { + t.Errorf("audioCodec = %q, want ac3", movie.AudioCodec) + } + if movie.AudioChannels != 6 { + t.Errorf("audioChannels = %d, want 6", movie.AudioChannels) + } + if len(movie.AudioLanguages) != 2 { + t.Errorf("audioLanguages count = %d, want 2", len(movie.AudioLanguages)) + } + if len(movie.SubtitleLanguages) != 2 { + t.Errorf("subtitleLanguages count = %d, want 2", len(movie.SubtitleLanguages)) + } + + // Second item: show without media info + show := items[1] + if show.ContentType != "show" { + t.Errorf("contentType = %q, want show", show.ContentType) + } + if show.Season != 1 || show.Episode != 1 { + t.Errorf("season/episode = %d/%d, want 1/1", show.Season, show.Episode) + } + if show.Resolution != "" { + t.Errorf("resolution should be empty, got %q", show.Resolution) + } +} + +func TestBuildSyncItemsEmpty(t *testing.T) { + cache := &LibraryCache{Items: nil} + items := BuildSyncItems(cache) + if len(items) != 0 { + t.Errorf("expected 0 items, got %d", len(items)) + } +} diff --git a/internal/mediaserver/detect_test.go b/internal/mediaserver/detect_test.go index fc5b00e..19ba53c 100644 --- a/internal/mediaserver/detect_test.go +++ b/internal/mediaserver/detect_test.go @@ -2,6 +2,8 @@ package mediaserver import ( "encoding/json" + "os" + "path/filepath" "testing" ) @@ -69,6 +71,96 @@ func TestJellyfinParsing(t *testing.T) { } } +func TestPlexTokenFromPrefs(t *testing.T) { + t.Run("valid prefs", func(t *testing.T) { + dir := t.TempDir() + prefsPath := filepath.Join(dir, "Preferences.xml") + xml := ` +` + os.WriteFile(prefsPath, []byte(xml), 0o644) + + token := plexTokenFromPrefs(prefsPath) + if token != "my-secret-token" { + t.Errorf("token = %q, want my-secret-token", token) + } + }) + + t.Run("no token attr", func(t *testing.T) { + dir := t.TempDir() + prefsPath := filepath.Join(dir, "Preferences.xml") + xml := `` + os.WriteFile(prefsPath, []byte(xml), 0o644) + + token := plexTokenFromPrefs(prefsPath) + if token != "" { + t.Errorf("token = %q, want empty", token) + } + }) + + t.Run("file not found", func(t *testing.T) { + token := plexTokenFromPrefs("/nonexistent/Preferences.xml") + if token != "" { + t.Errorf("token = %q, want empty", token) + } + }) + + t.Run("invalid xml", func(t *testing.T) { + dir := t.TempDir() + prefsPath := filepath.Join(dir, "Preferences.xml") + os.WriteFile(prefsPath, []byte("not xml at all"), 0o644) + + token := plexTokenFromPrefs(prefsPath) + if token != "" { + t.Errorf("token = %q, want empty", token) + } + }) +} + +func TestParsePlexSectionsMultipleLocations(t *testing.T) { + body := `{ + "MediaContainer": { + "Directory": [ + { + "title": "Movies", + "Location": [ + {"path": "/media/movies"}, + {"path": "/media/movies2"} + ] + } + ] + } + }` + + paths := parsePlexSections([]byte(body)) + if len(paths) != 2 { + t.Fatalf("expected 2 paths, got %d", len(paths)) + } +} + +func TestParsePlexSectionsEmptyPath(t *testing.T) { + body := `{ + "MediaContainer": { + "Directory": [ + { + "Location": [{"path": ""}, {"path": "/valid"}] + } + ] + } + }` + + paths := parsePlexSections([]byte(body)) + if len(paths) != 1 { + t.Fatalf("expected 1 path (empty filtered), got %d: %v", len(paths), paths) + } +} + +func TestCommonMediaDirs(t *testing.T) { + dirs := commonMediaDirs() + if len(dirs) == 0 { + t.Error("expected at least some common media dirs") + } +} + func TestParentDir(t *testing.T) { tests := []struct { name string diff --git a/internal/sentry/sentry_test.go b/internal/sentry/sentry_test.go new file mode 100644 index 0000000..671e641 --- /dev/null +++ b/internal/sentry/sentry_test.go @@ -0,0 +1,47 @@ +package sentry + +import "testing" + +func TestEnvironment(t *testing.T) { + tests := []struct { + version string + want string + }{ + {"", "development"}, + {"dev", "development"}, + {"0.1.0-dev", "development"}, + {"1.0.0", "production"}, + {"0.3.5", "production"}, + {"2.0.0-beta", "production"}, + } + + for _, tt := range tests { + t.Run(tt.version, func(t *testing.T) { + got := environment(tt.version) + if got != tt.want { + t.Errorf("environment(%q) = %q, want %q", tt.version, got, tt.want) + } + }) + } +} + +func TestInitNoOp(t *testing.T) { + // With empty dsn (default in tests), Init should be a no-op + Init("1.0.0") + // Should not panic +} + +func TestCloseNoOp(t *testing.T) { + // Close should be safe to call without Init + Close() +} + +func TestCaptureErrorNil(t *testing.T) { + // Should not panic with nil error + CaptureError(nil, "test") +} + +func TestSetUser(t *testing.T) { + // Should not panic without initialization + SetUser("agent-123") +} diff --git a/internal/ui/format_test.go b/internal/ui/format_test.go index 7f3d2c5..e5c9eda 100644 --- a/internal/ui/format_test.go +++ b/internal/ui/format_test.go @@ -1,7 +1,9 @@ package ui import ( + "fmt" "testing" + "time" ) func TestFormatSize(t *testing.T) { @@ -127,6 +129,10 @@ func TestQualityIndicator(t *testing.T) { {"medium", intPtr(60), "🟡"}, {"high", intPtr(80), "🟢"}, {"perfect", intPtr(100), "🟢"}, + {"boundary_40", intPtr(40), "🟡"}, + {"boundary_70", intPtr(70), "🟢"}, + {"boundary_39", intPtr(39), "🔴"}, + {"zero", intPtr(0), "🔴"}, } for _, tt := range tests { @@ -139,6 +145,52 @@ func TestQualityIndicator(t *testing.T) { } } +func TestSeedHealthIndicator(t *testing.T) { + tests := []struct { + seeds int + want string + }{ + {0, "🔴"}, + {5, "🔴"}, + {9, "🔴"}, + {10, "🟡"}, + {50, "🟡"}, + {100, "🟡"}, + {101, "🟢"}, + {1000, "🟢"}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("seeds_%d", tt.seeds), func(t *testing.T) { + got := SeedHealthIndicator(tt.seeds) + if got != tt.want { + t.Errorf("SeedHealthIndicator(%d) = %q, want %q", tt.seeds, got, tt.want) + } + }) + } +} + +func TestFormatRating(t *testing.T) { + tests := []struct { + name string + input *string + want string + }{ + {"nil", nil, "-"}, + {"value", strPtr("8.5"), "8.5"}, + {"empty", strPtr(""), ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatRating(tt.input) + if got != tt.want { + t.Errorf("FormatRating(%v) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + func TestStringOrDash(t *testing.T) { s := "hello" if got := StringOrDash(&s); got != "hello" { @@ -150,16 +202,160 @@ func TestStringOrDash(t *testing.T) { } func TestFormatContentType(t *testing.T) { - if got := FormatContentType("movie"); got != "Movie" { - t.Errorf("FormatContentType(movie) = %q, want Movie", got) + tests := []struct { + input string + want string + }{ + {"movie", "Movie"}, + {"Movie", "Movie"}, + {"MOVIE", "Movie"}, + {"show", "Show"}, + {"Show", "Show"}, + {"other", "other"}, + {"", ""}, } - if got := FormatContentType("show"); got != "Show" { - t.Errorf("FormatContentType(show) = %q, want Show", got) - } - if got := FormatContentType("other"); got != "other" { - t.Errorf("FormatContentType(other) = %q, want other", got) + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := FormatContentType(tt.input) + if got != tt.want { + t.Errorf("FormatContentType(%q) = %q, want %q", tt.input, got, tt.want) + } + }) } } -func ptr[T any](v T) *T { return &v } -func intPtr(v int) *int { return &v } +func TestFormatLanguages(t *testing.T) { + tests := []struct { + name string + input []string + want string + }{ + {"nil", nil, "-"}, + {"empty", []string{}, "-"}, + {"single", []string{"en"}, "en"}, + {"multiple", []string{"en", "es", "fr"}, "en, es, fr"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatLanguages(tt.input) + if got != tt.want { + t.Errorf("FormatLanguages(%v) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFormatSeedRatio(t *testing.T) { + tests := []struct { + seeders int + leechers int + want string + }{ + {0, 0, "0:0"}, + {10, 0, "10:0"}, + {100, 10, "10:1"}, + {50, 50, "1:1"}, + {1, 3, "0:1"}, + {150, 10, "15:1"}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%d_%d", tt.seeders, tt.leechers), func(t *testing.T) { + got := FormatSeedRatio(tt.seeders, tt.leechers) + if got != tt.want { + t.Errorf("FormatSeedRatio(%d, %d) = %q, want %q", tt.seeders, tt.leechers, got, tt.want) + } + }) + } +} + +func TestFormatTimeAgo(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + input string + want string + }{ + {"invalid", "not-a-date", "not-a-date"}, + {"just_now", now.Add(-10 * time.Second).Format(time.RFC3339), "just now"}, + {"minutes", now.Add(-5 * time.Minute).Format(time.RFC3339), "5m ago"}, + {"hours", now.Add(-3 * time.Hour).Format(time.RFC3339), "3h ago"}, + {"days", now.Add(-7 * 24 * time.Hour).Format(time.RFC3339), "7d ago"}, + {"months", now.Add(-60 * 24 * time.Hour).Format(time.RFC3339), "2mo ago"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatTimeAgo(tt.input) + if got != tt.want { + t.Errorf("FormatTimeAgo(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFormatNumberExtended(t *testing.T) { + tests := []struct { + input int + want string + }{ + {-1000, "-1,000"}, + {-5, "-5"}, + {10000, "10,000"}, + {100000, "100,000"}, + {1000000, "1,000,000"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := FormatNumber(tt.input) + if got != tt.want { + t.Errorf("FormatNumber(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTruncateStringEdgeCases(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + want string + }{ + {"maxLen_1", "hello", 1, "h"}, + {"maxLen_3", "hello", 3, "hel"}, + {"empty", "", 5, ""}, + {"unicode", "こんにちは世界", 5, "こん..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := TruncateString(tt.input, tt.maxLen) + if got != tt.want { + t.Errorf("TruncateString(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want) + } + }) + } +} + +func TestPtr(t *testing.T) { + v := 42 + p := Ptr(v) + if *p != 42 { + t.Errorf("Ptr(42) = %d, want 42", *p) + } + + s := "hello" + sp := Ptr(s) + if *sp != "hello" { + t.Errorf("Ptr(hello) = %q, want hello", *sp) + } +} + +func ptr[T any](v T) *T { return &v } +func intPtr(v int) *int { return &v } +func strPtr(v string) *string { return &v } diff --git a/internal/ui/table_test.go b/internal/ui/table_test.go new file mode 100644 index 0000000..3b9abff --- /dev/null +++ b/internal/ui/table_test.go @@ -0,0 +1,122 @@ +package ui + +import ( + "bytes" + "strings" + "testing" + + tc "github.com/torrentclaw/go-client" +) + +func TestNewCleanTable(t *testing.T) { + var buf bytes.Buffer + tbl := newCleanTable(&buf) + tbl.Header([]string{"A", "B"}) + tbl.Append([]string{"1", "2"}) + tbl.Render() + + out := buf.String() + if !strings.Contains(out, "A") || !strings.Contains(out, "B") { + t.Errorf("expected headers in output, got: %s", out) + } + if !strings.Contains(out, "1") || !strings.Contains(out, "2") { + t.Errorf("expected row data in output, got: %s", out) + } +} + +func TestPrintSearchResultEntry(t *testing.T) { + var buf bytes.Buffer + year := 2010 + rating := "8.8" + quality := "1080p" + codec := "x265" + size := int64(4294967296) + score := 85 + + r := tc.SearchResult{ + Title: "Inception", + Year: &year, + ContentType: "movie", + RatingIMDb: &rating, + Genres: []string{"Sci-Fi", "Action"}, + Torrents: []tc.TorrentInfo{ + { + Quality: &quality, + SizeBytes: &size, + Seeders: 150, + Leechers: 10, + Source: "YTS", + Codec: &codec, + Languages: []string{"en", "es"}, + QualityScore: &score, + }, + }, + } + + printSearchResultEntry(&buf, r) + out := buf.String() + + if !strings.Contains(out, "Inception") { + t.Error("expected title in output") + } + if !strings.Contains(out, "2010") { + t.Error("expected year in output") + } + if !strings.Contains(out, "8.8") { + t.Error("expected rating in output") + } + if !strings.Contains(out, "1080p") { + t.Error("expected quality in output") + } + if !strings.Contains(out, "YTS") { + t.Error("expected source in output") + } +} + +func TestPrintSearchResultEntryNoTorrents(t *testing.T) { + var buf bytes.Buffer + year := 2020 + + r := tc.SearchResult{ + Title: "No Torrents Movie", + Year: &year, + ContentType: "movie", + Torrents: nil, + } + + printSearchResultEntry(&buf, r) + out := buf.String() + + if !strings.Contains(out, "No Torrents Movie") { + t.Error("expected title in output") + } + if !strings.Contains(out, "No torrents available") { + t.Error("expected no-torrents message") + } +} + +func TestPrintSearchResultEntryNilFields(t *testing.T) { + var buf bytes.Buffer + + r := tc.SearchResult{ + Title: "Minimal", + ContentType: "movie", + Torrents: []tc.TorrentInfo{ + { + Seeders: 5, + Leechers: 3, + Source: "RARBG", + }, + }, + } + + printSearchResultEntry(&buf, r) + out := buf.String() + + if !strings.Contains(out, "Minimal") { + t.Error("expected title in output") + } + if !strings.Contains(out, "-") { + t.Error("expected dash for nil fields") + } +} diff --git a/internal/upgrade/cache.go b/internal/upgrade/cache.go new file mode 100644 index 0000000..7bdcfb0 --- /dev/null +++ b/internal/upgrade/cache.go @@ -0,0 +1,75 @@ +package upgrade + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "time" + + "github.com/torrentclaw/unarr/internal/config" +) + +const cacheTTL = 1 * time.Hour + +// versionCache is the on-disk structure for cached version checks. +type versionCache struct { + Version string `json:"version"` + CheckedAt time.Time `json:"checkedAt"` +} + +// cacheFilePath returns the path to the version cache file. +func cacheFilePath() string { + return filepath.Join(config.DataDir(), "latest-version.json") +} + +// ReadCachedVersion returns the cached latest version if it's fresh (< cacheTTL). +// Returns empty string if cache is missing, stale, or corrupt. +func ReadCachedVersion() string { + data, err := os.ReadFile(cacheFilePath()) + if err != nil { + return "" + } + var c versionCache + if json.Unmarshal(data, &c) != nil { + return "" + } + if time.Since(c.CheckedAt) > cacheTTL { + return "" + } + return c.Version +} + +// writeCachedVersion writes the latest version to the cache file. +func writeCachedVersion(version string) { + c := versionCache{ + Version: version, + CheckedAt: time.Now(), + } + data, err := json.Marshal(c) + if err != nil { + return + } + path := cacheFilePath() + os.MkdirAll(filepath.Dir(path), 0o755) + // Best-effort write — ignore errors + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return + } + os.Rename(tmp, path) +} + +// CheckLatestCached returns the latest version, using cache when fresh. +// If cache is stale, fetches from GitHub and updates the cache. +func CheckLatestCached(ctx context.Context) (version string, fromCache bool, err error) { + if cached := ReadCachedVersion(); cached != "" { + return cached, true, nil + } + v, err := fetchLatestVersion(ctx) + if err != nil { + return "", false, err + } + writeCachedVersion(v) + return v, false, nil +} diff --git a/internal/upgrade/upgrade.go b/internal/upgrade/upgrade.go index b70dc7e..5d31308 100644 --- a/internal/upgrade/upgrade.go +++ b/internal/upgrade/upgrade.go @@ -152,9 +152,13 @@ func (u *Upgrader) fail(format string, args ...any) Result { } } -// CheckLatest fetches the latest version from GitHub API. +// CheckLatest fetches the latest version from GitHub API and updates the cache. func CheckLatest(ctx context.Context) (string, error) { - return fetchLatestVersion(ctx) + v, err := fetchLatestVersion(ctx) + if err == nil { + writeCachedVersion(v) + } + return v, err } // installBinary copies the new binary to the target path, preserving original permissions. diff --git a/internal/upgrade/upgrade_test.go b/internal/upgrade/upgrade_test.go index 2753005..b8805db 100644 --- a/internal/upgrade/upgrade_test.go +++ b/internal/upgrade/upgrade_test.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" ) @@ -305,3 +306,768 @@ func TestFetchLatestVersionMockServer(t *testing.T) { t.Errorf("status = %d, want 200", resp.StatusCode) } } + +// --- New tests below --- + +// swapHTTPClient replaces the package-level httpClient and returns a restore function. +func swapHTTPClient(c *http.Client) func() { + orig := httpClient + httpClient = c + return func() { httpClient = orig } +} + +// rewriteTransport redirects all requests to the given base URL, +// preserving path and query. +type rewriteTransport struct { + url string +} + +func (rt *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" + u := strings.TrimPrefix(rt.url, "http://") + req.URL.Host = u + return http.DefaultTransport.RoundTrip(req) +} + +// createTarGz is a test helper that creates a tar.gz file with a single file entry. +func createTarGz(t *testing.T, archivePath, entryName string, content []byte) { + t.Helper() + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + tw.WriteHeader(&tar.Header{ + Name: entryName, + Mode: 0o755, + Size: int64(len(content)), + }) + tw.Write(content) + + tw.Close() + gw.Close() + f.Close() +} + +func TestArchiveNameTableDriven(t *testing.T) { + // We can only run archiveName for the current GOOS/GOARCH, + // so we test several version strings and verify the pattern. + tests := []struct { + version string + }{ + {"0.1.0"}, + {"1.0.0-rc1"}, + {"2.5.10"}, + {"0.0.0"}, + } + for _, tt := range tests { + t.Run(tt.version, func(t *testing.T) { + got := archiveName(tt.version) + prefix := fmt.Sprintf("unarr_%s_%s_%s.", tt.version, runtime.GOOS, runtime.GOARCH) + if !strings.HasPrefix(got, prefix) { + t.Errorf("archiveName(%q) = %q, want prefix %q", tt.version, got, prefix) + } + if runtime.GOOS == "windows" { + if !strings.HasSuffix(got, ".zip") { + t.Errorf("archiveName on windows should end with .zip, got %q", got) + } + } else { + if !strings.HasSuffix(got, ".tar.gz") { + t.Errorf("archiveName on non-windows should end with .tar.gz, got %q", got) + } + } + }) + } +} + +func TestReleaseURLEdgeCases(t *testing.T) { + tests := []struct { + name string + version string + filename string + wantURL string + }{ + { + name: "pre-release version", + version: "2.0.0-beta.1", + filename: "unarr_2.0.0-beta.1_darwin_arm64.tar.gz", + wantURL: "https://github.com/torrentclaw/unarr/releases/download/v2.0.0-beta.1/unarr_2.0.0-beta.1_darwin_arm64.tar.gz", + }, + { + name: "checksums file", + version: "3.0.0", + filename: "checksums.txt", + wantURL: "https://github.com/torrentclaw/unarr/releases/download/v3.0.0/checksums.txt", + }, + { + name: "windows zip", + version: "1.2.3", + filename: "unarr_1.2.3_windows_amd64.zip", + wantURL: "https://github.com/torrentclaw/unarr/releases/download/v1.2.3/unarr_1.2.3_windows_amd64.zip", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := releaseURL(tt.version, tt.filename) + if got != tt.wantURL { + t.Errorf("releaseURL(%q, %q) = %q, want %q", tt.version, tt.filename, got, tt.wantURL) + } + }) + } +} + +func TestExtractBinaryDispatcher(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("extractBinary dispatcher test only on unix") + } + + dir := t.TempDir() + + // Create a valid tar.gz with the unarr binary + archivePath := filepath.Join(dir, "test.tar.gz") + binaryContent := []byte("#!/bin/sh\necho dispatcher test\n") + createTarGz(t, archivePath, "unarr", binaryContent) + + destDir := filepath.Join(dir, "out") + os.MkdirAll(destDir, 0o755) + + binPath, err := extractBinary(archivePath, destDir) + if err != nil { + t.Fatalf("extractBinary() error = %v", err) + } + if filepath.Base(binPath) != "unarr" { + t.Errorf("extractBinary() returned %q, want base name 'unarr'", binPath) + } + data, _ := os.ReadFile(binPath) + if string(data) != string(binaryContent) { + t.Error("extractBinary() content mismatch") + } +} + +func TestExtractBinaryInvalidArchive(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + dir := t.TempDir() + archivePath := filepath.Join(dir, "garbage.tar.gz") + os.WriteFile(archivePath, []byte("this is not a tar.gz"), 0o644) + + destDir := filepath.Join(dir, "out") + os.MkdirAll(destDir, 0o755) + + _, err := extractBinary(archivePath, destDir) + if err == nil { + t.Error("extractBinary with garbage data should return error") + } +} + +func TestExtractBinaryNonExistentArchive(t *testing.T) { + dir := t.TempDir() + _, err := extractBinary("/nonexistent-archive-file", filepath.Join(dir, "out")) + if err == nil { + t.Error("extractBinary with nonexistent file should return error") + } +} + +func TestUpgraderFail(t *testing.T) { + u := &Upgrader{CurrentVersion: "1.0.0"} + + // Capture log messages + var logged []string + u.OnProgress = func(msg string) { logged = append(logged, msg) } + + result := u.fail("something went wrong: %d", 42) + + if result.Success { + t.Error("fail() should return Success=false") + } + if result.OldVersion != "1.0.0" { + t.Errorf("fail() OldVersion = %q, want 1.0.0", result.OldVersion) + } + if result.Error == nil { + t.Fatal("fail() should set Error") + } + if !strings.Contains(result.Error.Error(), "something went wrong: 42") { + t.Errorf("fail() Error = %q, want to contain 'something went wrong: 42'", result.Error) + } + if len(logged) == 0 { + t.Error("fail() should call OnProgress") + } + if len(logged) > 0 && !strings.Contains(logged[0], "FAILED") { + t.Errorf("fail() logged %q, want to contain 'FAILED'", logged[0]) + } +} + +func TestUpgraderFailNilOnProgress(t *testing.T) { + u := &Upgrader{CurrentVersion: "2.0.0"} + // OnProgress is nil — should not panic + result := u.fail("error without listener") + if result.Success { + t.Error("fail() should return Success=false") + } +} + +func TestFetchLatestVersionWithHTTPTest(t *testing.T) { + tests := []struct { + name string + body string + statusCode int + wantVer string + wantErr bool + }{ + { + name: "valid response", + body: `{"tag_name":"v3.1.4"}`, + statusCode: 200, + wantVer: "3.1.4", + }, + { + name: "valid response without v prefix", + body: `{"tag_name":"2.0.0"}`, + statusCode: 200, + wantVer: "2.0.0", + }, + { + name: "empty tag_name", + body: `{"tag_name":""}`, + statusCode: 200, + wantErr: true, + }, + { + name: "server error", + body: `Internal Server Error`, + statusCode: 500, + wantErr: true, + }, + { + name: "invalid json", + body: `{invalid`, + statusCode: 200, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + fmt.Fprint(w, tt.body) + })) + defer srv.Close() + + // Use a custom transport that rewrites requests to our test server + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + ver, err := CheckLatest(context.Background()) + if tt.wantErr { + if err == nil { + t.Errorf("CheckLatest() error = nil, want error") + } + return + } + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if ver != tt.wantVer { + t.Errorf("CheckLatest() = %q, want %q", ver, tt.wantVer) + } + }) + } +} + +func TestDownloadWithHTTPTest(t *testing.T) { + archiveBody := "fake-archive-bytes-for-download-test" + + t.Run("successful download", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify user-agent header + if ua := r.Header.Get("User-Agent"); ua != "unarr-updater" { + t.Errorf("User-Agent = %q, want 'unarr-updater'", ua) + } + w.WriteHeader(200) + fmt.Fprint(w, archiveBody) + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + path, err := download(context.Background(), "1.0.0") + if err != nil { + t.Fatalf("download() error = %v", err) + } + defer os.Remove(path) + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read downloaded file: %v", err) + } + if string(data) != archiveBody { + t.Errorf("downloaded content = %q, want %q", data, archiveBody) + } + }) + + t.Run("server returns 404", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + _, err := download(context.Background(), "99.99.99") + if err == nil { + t.Error("download() with 404 should return error") + } + if !strings.Contains(err.Error(), "HTTP 404") { + t.Errorf("download() error = %q, want to contain 'HTTP 404'", err) + } + }) + + t.Run("cancelled context", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + fmt.Fprint(w, "data") + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + _, err := download(ctx, "1.0.0") + if err == nil { + t.Error("download() with cancelled context should return error") + } + }) +} + +func TestVerifyChecksumWithHTTPTest(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + // Create a fake archive file + dir := t.TempDir() + archiveContent := []byte("archive-content-for-checksum-test") + archivePath := filepath.Join(dir, "test-archive.tar.gz") + os.WriteFile(archivePath, archiveContent, 0o644) + + h := sha256.Sum256(archiveContent) + correctHash := hex.EncodeToString(h[:]) + + // The function builds the archive name using archiveName(), which uses runtime.GOOS/GOARCH. + expectedArchiveName := archiveName("1.0.0") + + t.Run("matching checksum", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "0000000000000000000000000000000000000000000000000000000000000000 other_file.tar.gz\n") + fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName) + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + err := verifyChecksum(context.Background(), "1.0.0", archivePath) + if err != nil { + t.Errorf("verifyChecksum() = %v, want nil", err) + } + }) + + t.Run("mismatched checksum", func(t *testing.T) { + wrongHash := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%s %s\n", wrongHash, expectedArchiveName) + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + err := verifyChecksum(context.Background(), "1.0.0", archivePath) + if err == nil { + t.Error("verifyChecksum() with wrong hash should return error") + } + if !strings.Contains(err.Error(), "SHA256 mismatch") { + t.Errorf("verifyChecksum() error = %q, want to contain 'SHA256 mismatch'", err) + } + }) + + t.Run("archive not in checksums", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890 some_other_file.tar.gz\n") + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + err := verifyChecksum(context.Background(), "1.0.0", archivePath) + if err == nil { + t.Error("verifyChecksum() with missing entry should return error") + } + if !strings.Contains(err.Error(), "no checksum found") { + t.Errorf("verifyChecksum() error = %q, want to contain 'no checksum found'", err) + } + }) + + t.Run("checksums server error", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + err := verifyChecksum(context.Background(), "1.0.0", archivePath) + if err == nil { + t.Error("verifyChecksum() with server error should return error") + } + if !strings.Contains(err.Error(), "HTTP 500") { + t.Errorf("verifyChecksum() error = %q, want to contain 'HTTP 500'", err) + } + }) + + t.Run("nonexistent archive file", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName) + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + err := verifyChecksum(context.Background(), "1.0.0", "/nonexistent-archive-path") + if err == nil { + t.Error("verifyChecksum() with nonexistent archive should return error") + } + }) +} + +func TestVerifyChecksumCaseInsensitive(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + dir := t.TempDir() + archiveContent := []byte("case-insensitive-hash-test") + archivePath := filepath.Join(dir, "archive.tar.gz") + os.WriteFile(archivePath, archiveContent, 0o644) + + h := sha256.Sum256(archiveContent) + // Use uppercase hash in checksums.txt — verifyChecksum uses EqualFold + upperHash := strings.ToUpper(hex.EncodeToString(h[:])) + expectedArchiveName := archiveName("1.0.0") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%s %s\n", upperHash, expectedArchiveName) + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + err := verifyChecksum(context.Background(), "1.0.0", archivePath) + if err != nil { + t.Errorf("verifyChecksum() with uppercase hash = %v, want nil", err) + } +} + +func TestExtractTarGzNestedDirectory(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + dir := t.TempDir() + archivePath := filepath.Join(dir, "nested.tar.gz") + + binaryContent := []byte("#!/bin/sh\necho nested\n") + + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + // Write a directory entry first + tw.WriteHeader(&tar.Header{ + Name: "unarr_1.0.0_linux_amd64/", + Typeflag: tar.TypeDir, + Mode: 0o755, + }) + + // Write a README in the subdirectory + readmeContent := []byte("This is a README") + tw.WriteHeader(&tar.Header{ + Name: "unarr_1.0.0_linux_amd64/README.md", + Mode: 0o644, + Size: int64(len(readmeContent)), + }) + tw.Write(readmeContent) + + // Write the binary nested inside the directory + tw.WriteHeader(&tar.Header{ + Name: "unarr_1.0.0_linux_amd64/unarr", + Mode: 0o755, + Size: int64(len(binaryContent)), + }) + tw.Write(binaryContent) + + // Write another unrelated file after the binary + licenseContent := []byte("MIT License") + tw.WriteHeader(&tar.Header{ + Name: "unarr_1.0.0_linux_amd64/LICENSE", + Mode: 0o644, + Size: int64(len(licenseContent)), + }) + tw.Write(licenseContent) + + tw.Close() + gw.Close() + f.Close() + + destDir := filepath.Join(dir, "out") + os.MkdirAll(destDir, 0o755) + + binPath, err := extractTarGz(archivePath, destDir) + if err != nil { + t.Fatalf("extractTarGz() with nested dir = %v", err) + } + + if filepath.Base(binPath) != "unarr" { + t.Errorf("extracted name = %q, want 'unarr'", filepath.Base(binPath)) + } + + data, _ := os.ReadFile(binPath) + if string(data) != string(binaryContent) { + t.Error("extracted content does not match") + } + + // Verify that only the binary was extracted (README and LICENSE should NOT be in destDir) + entries, _ := os.ReadDir(destDir) + if len(entries) != 1 { + names := make([]string, len(entries)) + for i, e := range entries { + names[i] = e.Name() + } + t.Errorf("destDir should contain only 'unarr', got %v", names) + } +} + +func TestExtractTarGzMultipleFiles(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + dir := t.TempDir() + archivePath := filepath.Join(dir, "multi.tar.gz") + + binaryContent := []byte("#!/bin/sh\necho multi\n") + + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + // Several non-binary files before the actual binary + for _, name := range []string{"README.md", "LICENSE", "config.yaml", "completions.bash"} { + content := []byte("content of " + name) + tw.WriteHeader(&tar.Header{ + Name: name, + Mode: 0o644, + Size: int64(len(content)), + }) + tw.Write(content) + } + + // The actual binary + tw.WriteHeader(&tar.Header{ + Name: "unarr", + Mode: 0o755, + Size: int64(len(binaryContent)), + }) + tw.Write(binaryContent) + + tw.Close() + gw.Close() + f.Close() + + destDir := filepath.Join(dir, "out") + os.MkdirAll(destDir, 0o755) + + binPath, err := extractTarGz(archivePath, destDir) + if err != nil { + t.Fatalf("extractTarGz() = %v", err) + } + + data, _ := os.ReadFile(binPath) + if string(data) != string(binaryContent) { + t.Error("binary content mismatch among multiple files") + } +} + +func TestExtractTarGzSymlinkSkipped(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + dir := t.TempDir() + archivePath := filepath.Join(dir, "symlink.tar.gz") + + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + // Write a symlink entry named "unarr" — should be skipped because Typeflag != TypeReg + tw.WriteHeader(&tar.Header{ + Name: "unarr", + Typeflag: tar.TypeSymlink, + Linkname: "/etc/passwd", + Mode: 0o755, + }) + + tw.Close() + gw.Close() + f.Close() + + destDir := filepath.Join(dir, "out") + os.MkdirAll(destDir, 0o755) + + _, err = extractTarGz(archivePath, destDir) + if err == nil { + t.Error("extractTarGz() should return error when binary is a symlink (not TypeReg)") + } + if !strings.Contains(err.Error(), "not found in archive") { + t.Errorf("error = %q, want to contain 'not found in archive'", err) + } +} + +func TestInstallBinaryNonExistentSource(t *testing.T) { + dir := t.TempDir() + dst := filepath.Join(dir, "output") + + err := installBinary("/nonexistent-source-binary", dst) + if err == nil { + t.Error("installBinary with nonexistent source should return error") + } + if !strings.Contains(err.Error(), "read new binary") { + t.Errorf("error = %q, want to contain 'read new binary'", err) + } + + // Verify destination was not created + if _, statErr := os.Stat(dst); statErr == nil { + t.Error("destination file should not exist after failed install") + } +} + +func TestInstallBinaryUnwritableDestination(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "source") + os.WriteFile(src, []byte("binary"), 0o755) + + // Try to write to a path inside a non-existent directory + dst := filepath.Join(dir, "nonexistent-subdir", "binary") + + err := installBinary(src, dst) + if err == nil { + t.Error("installBinary to non-writable destination should return error") + } + if !strings.Contains(err.Error(), "write binary") { + t.Errorf("error = %q, want to contain 'write binary'", err) + } +} + +func TestUpgraderLog(t *testing.T) { + var messages []string + u := &Upgrader{ + CurrentVersion: "1.0.0", + OnProgress: func(msg string) { messages = append(messages, msg) }, + } + + u.log("hello world") + if len(messages) != 1 || messages[0] != "hello world" { + t.Errorf("log() messages = %v, want [hello world]", messages) + } +} + +func TestUpgraderLogNilOnProgress(t *testing.T) { + u := &Upgrader{CurrentVersion: "1.0.0"} + // Should not panic + u.log("test message with nil OnProgress") +} + +func TestResultFields(t *testing.T) { + r := Result{ + Success: true, + OldVersion: "1.0.0", + NewVersion: "2.0.0", + BackupPath: "/tmp/backup", + } + if !r.Success || r.OldVersion != "1.0.0" || r.NewVersion != "2.0.0" || r.BackupPath != "/tmp/backup" { + t.Errorf("Result fields not set correctly: %+v", r) + } + + r2 := Result{Success: false, Error: fmt.Errorf("test error")} + if r2.Success || r2.Error == nil { + t.Errorf("Result error case not correct: %+v", r2) + } +} + +func TestDownloadSetsUserAgent(t *testing.T) { + var gotUA string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(200) + fmt.Fprint(w, "data") + })) + defer srv.Close() + + restore := swapHTTPClient(&http.Client{ + Transport: &rewriteTransport{url: srv.URL}, + }) + defer restore() + + path, err := download(context.Background(), "1.0.0") + if err != nil { + t.Fatalf("download() = %v", err) + } + defer os.Remove(path) + + if gotUA != "unarr-updater" { + t.Errorf("User-Agent = %q, want 'unarr-updater'", gotUA) + } +} diff --git a/internal/usenet/download/progress_expand_test.go b/internal/usenet/download/progress_expand_test.go new file mode 100644 index 0000000..0ce1b4f --- /dev/null +++ b/internal/usenet/download/progress_expand_test.go @@ -0,0 +1,632 @@ +package download + +import ( + "encoding/binary" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/torrentclaw/unarr/internal/usenet/nzb" +) + +// --- Fingerprint --- + +func TestFingerprint_EmptyNZB(t *testing.T) { + n := &nzb.NZB{} + fp := Fingerprint(n) + // Empty NZB should still produce a deterministic hash (of zero message IDs). + fp2 := Fingerprint(n) + if fp != fp2 { + t.Fatal("fingerprint of empty NZB should be deterministic") + } +} + +func TestFingerprint_OrderIndependent(t *testing.T) { + // Fingerprint sorts IDs, so different file order should produce the same hash. + n1 := &nzb.NZB{ + Files: []nzb.File{ + {Segments: []nzb.Segment{{MessageID: "a@x"}, {MessageID: "b@x"}}}, + {Segments: []nzb.Segment{{MessageID: "c@x"}}}, + }, + } + n2 := &nzb.NZB{ + Files: []nzb.File{ + {Segments: []nzb.Segment{{MessageID: "c@x"}}}, + {Segments: []nzb.Segment{{MessageID: "b@x"}, {MessageID: "a@x"}}}, + }, + } + if Fingerprint(n1) != Fingerprint(n2) { + t.Fatal("fingerprint should be order-independent (sorted by message ID)") + } +} + +func TestFingerprint_SingleSegment(t *testing.T) { + n := &nzb.NZB{ + Files: []nzb.File{ + {Segments: []nzb.Segment{{MessageID: "only@one"}}}, + }, + } + fp := Fingerprint(n) + if fp == [32]byte{} { + t.Fatal("fingerprint should not be zero for a non-empty NZB") + } +} + +// --- ProgressTracker MarkDone idempotency --- + +func TestProgressTracker_MarkDoneIdempotent(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 5) + tracker := NewProgressTracker("idem", n, dir) + + tracker.MarkDone(0, 2) + if tracker.CompletedSegments(0) != 1 { + t.Fatalf("expected 1, got %d", tracker.CompletedSegments(0)) + } + + // Mark the same segment again — count should not increase. + tracker.MarkDone(0, 2) + if tracker.CompletedSegments(0) != 1 { + t.Fatalf("idempotent mark: expected 1, got %d", tracker.CompletedSegments(0)) + } +} + +// --- ProgressTracker TotalCompleted --- + +func TestProgressTracker_TotalCompleted(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(3, 4) // 3 files, 4 segs each + tracker := NewProgressTracker("total", n, dir) + + tracker.MarkDone(0, 0) + tracker.MarkDone(0, 1) + tracker.MarkDone(1, 3) + tracker.MarkDone(2, 0) + tracker.MarkDone(2, 1) + tracker.MarkDone(2, 2) + + if got := tracker.TotalCompleted(); got != 6 { + t.Errorf("TotalCompleted: got %d, want 6", got) + } +} + +func TestProgressTracker_TotalCompleted_Empty(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(2, 3) + tracker := NewProgressTracker("empty-total", n, dir) + + if got := tracker.TotalCompleted(); got != 0 { + t.Errorf("TotalCompleted on fresh tracker: got %d, want 0", got) + } +} + +// --- CompletedBytes edge cases --- + +func TestProgressTracker_CompletedBytes_OutOfBounds(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("cb-oob", n, dir) + + if got := tracker.CompletedBytes(-1, n.Files[0].Segments); got != 0 { + t.Errorf("CompletedBytes with file -1: got %d, want 0", got) + } + if got := tracker.CompletedBytes(5, n.Files[0].Segments); got != 0 { + t.Errorf("CompletedBytes with file 5: got %d, want 0", got) + } +} + +func TestProgressTracker_CompletedBytes_AllDone(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("cb-all", n, dir) + + for i := 0; i < 3; i++ { + tracker.MarkDone(0, i) + } + + got := tracker.CompletedBytes(0, n.Files[0].Segments) + expected := int64(3 * 750 * 1024) + if got != expected { + t.Errorf("CompletedBytes all done: got %d, want %d", got, expected) + } +} + +// --- CompletedSegments out of bounds --- + +func TestProgressTracker_CompletedSegments_OutOfBounds(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("cs-oob", n, dir) + + if got := tracker.CompletedSegments(-1); got != 0 { + t.Errorf("CompletedSegments(-1) = %d, want 0", got) + } + if got := tracker.CompletedSegments(99); got != 0 { + t.Errorf("CompletedSegments(99) = %d, want 0", got) + } +} + +// --- Load with corrupted / truncated data --- + +func TestProgressTracker_Load_TruncatedHeader(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("trunc", n, dir) + + // Write too-short data + os.WriteFile(tracker.progressPath(), []byte("UNR"), 0o644) + + loaded, err := tracker.Load() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loaded { + t.Error("truncated header should not load") + } +} + +func TestProgressTracker_Load_BadMagic(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("badmagic", n, dir) + + // Write data with wrong magic bytes + data := make([]byte, headerSize+10) + copy(data[0:4], []byte("BAAD")) + os.WriteFile(tracker.progressPath(), data, 0o644) + + loaded, err := tracker.Load() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loaded { + t.Error("bad magic should not load") + } +} + +func TestProgressTracker_Load_BadVersion(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("badver", n, dir) + + data := make([]byte, headerSize+10) + copy(data[0:4], progressMagic[:]) + data[4] = 99 // unsupported version + os.WriteFile(tracker.progressPath(), data, 0o644) + + loaded, err := tracker.Load() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loaded { + t.Error("bad version should not load") + } +} + +func TestProgressTracker_Load_WrongFileCount(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(2, 3) + tracker := NewProgressTracker("wrongfc", n, dir) + + data := make([]byte, headerSize+20) + copy(data[0:4], progressMagic[:]) + data[4] = progressVersion + binary.LittleEndian.PutUint16(data[6:8], 99) // wrong file count + copy(data[8:40], tracker.fingerprint[:]) + os.WriteFile(tracker.progressPath(), data, 0o644) + + loaded, err := tracker.Load() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loaded { + t.Error("wrong file count should not load") + } +} + +func TestProgressTracker_Load_TruncatedBitset(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 16) + tracker := NewProgressTracker("truncbit", n, dir) + + // Build a valid header but truncate the bitset data + data := make([]byte, headerSize+4) // header + segCount but no bitset + copy(data[0:4], progressMagic[:]) + data[4] = progressVersion + binary.LittleEndian.PutUint16(data[6:8], 1) // 1 file + copy(data[8:40], tracker.fingerprint[:]) + binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 16) // 16 segs + // No bitset data follows — truncated + os.WriteFile(tracker.progressPath(), data, 0o644) + + loaded, err := tracker.Load() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loaded { + t.Error("truncated bitset should not load") + } +} + +func TestProgressTracker_Load_SegCountMismatch(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 5) + tracker := NewProgressTracker("segmis", n, dir) + + // Build valid header with correct file count and fingerprint, but wrong segCount + bitsetLen := (999 + 7) / 8 + data := make([]byte, headerSize+4+bitsetLen) + copy(data[0:4], progressMagic[:]) + data[4] = progressVersion + binary.LittleEndian.PutUint16(data[6:8], 1) + copy(data[8:40], tracker.fingerprint[:]) + binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 999) // wrong seg count + os.WriteFile(tracker.progressPath(), data, 0o644) + + loaded, err := tracker.Load() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loaded { + t.Error("segment count mismatch should not load") + } +} + +func TestProgressTracker_Load_NonexistentFile(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("nofile", n, dir) + + loaded, err := tracker.Load() + if err != nil { + t.Fatalf("unexpected error for missing file: %v", err) + } + if loaded { + t.Error("nonexistent file should return false") + } +} + +// --- Flush and Load round-trip with multiple files --- + +func TestProgressTracker_MultiFileRoundTrip(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(5, 10) // 5 files, 10 segments each + tracker := NewProgressTracker("multi-rt", n, dir) + + // Mark various segments across files + tracker.MarkDone(0, 0) + tracker.MarkDone(0, 9) + tracker.MarkDone(1, 5) + tracker.MarkDone(2, 0) + tracker.MarkDone(2, 1) + tracker.MarkDone(2, 2) + tracker.MarkDone(2, 3) + tracker.MarkDone(2, 4) + tracker.MarkDone(2, 5) + tracker.MarkDone(2, 6) + tracker.MarkDone(2, 7) + tracker.MarkDone(2, 8) + tracker.MarkDone(2, 9) // file 2 fully done + tracker.MarkDone(4, 7) + + if err := tracker.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + // Reload + tracker2 := NewProgressTracker("multi-rt", n, dir) + loaded, err := tracker2.Load() + if err != nil { + t.Fatalf("load: %v", err) + } + if !loaded { + t.Fatal("should load") + } + + if tracker2.CompletedSegments(0) != 2 { + t.Errorf("file 0: got %d, want 2", tracker2.CompletedSegments(0)) + } + if tracker2.CompletedSegments(1) != 1 { + t.Errorf("file 1: got %d, want 1", tracker2.CompletedSegments(1)) + } + if !tracker2.IsFileDone(2) { + t.Error("file 2 should be done") + } + if tracker2.CompletedSegments(3) != 0 { + t.Errorf("file 3: got %d, want 0", tracker2.CompletedSegments(3)) + } + if tracker2.CompletedSegments(4) != 1 { + t.Errorf("file 4: got %d, want 1", tracker2.CompletedSegments(4)) + } + if tracker2.TotalCompleted() != 14 { + t.Errorf("TotalCompleted: got %d, want 14", tracker2.TotalCompleted()) + } +} + +// --- Concurrent mark + IsDone reads --- + +func TestProgressTracker_ConcurrentMarkAndRead(t *testing.T) { + dir := t.TempDir() + segCount := 500 + n := makeTestNZB(2, segCount) + tracker := NewProgressTracker("conc-rw", n, dir) + + // Use separate WaitGroups for writers and readers + var writerWg sync.WaitGroup + stop := make(chan struct{}) + + // Writers + for file := 0; file < 2; file++ { + for seg := 0; seg < segCount; seg++ { + writerWg.Add(1) + go func(f, s int) { + defer writerWg.Done() + tracker.MarkDone(f, s) + }(file, seg) + } + } + + // Readers — continuously read while writes happen + var readerWg sync.WaitGroup + for i := 0; i < 4; i++ { + readerWg.Add(1) + go func() { + defer readerWg.Done() + for { + select { + case <-stop: + return + default: + // These should never panic + tracker.IsDone(0, 0) + tracker.IsDone(1, segCount-1) + tracker.IsFileDone(0) + tracker.CompletedSegments(1) + tracker.TotalCompleted() + } + } + }() + } + + // Wait for all writers to finish, then stop readers + writerWg.Wait() + close(stop) + readerWg.Wait() + + // After all goroutines complete, everything should be done + if !tracker.IsFileDone(0) { + t.Errorf("file 0 should be done, got %d/%d", tracker.CompletedSegments(0), segCount) + } + if !tracker.IsFileDone(1) { + t.Errorf("file 1 should be done, got %d/%d", tracker.CompletedSegments(1), segCount) + } +} + +// --- Concurrent flush safety --- + +func TestProgressTracker_ConcurrentFlush(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 100) + tracker := NewProgressTracker("conc-flush", n, dir) + + // Mark some segments + for i := 0; i < 50; i++ { + tracker.MarkDone(0, i) + } + + // Multiple concurrent flushes should not panic or corrupt + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + tracker.Flush() + }() + } + wg.Wait() + + // Verify state is loadable + tracker2 := NewProgressTracker("conc-flush", n, dir) + loaded, err := tracker2.Load() + if err != nil { + t.Fatalf("load: %v", err) + } + if !loaded { + t.Fatal("should load after concurrent flushes") + } + if tracker2.CompletedSegments(0) != 50 { + t.Errorf("after concurrent flush: got %d, want 50", tracker2.CompletedSegments(0)) + } +} + +// --- Remove with .tmp file --- + +func TestProgressTracker_Remove_WithTmpFile(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("rm-tmp", n, dir) + + // Create all three files that Remove should clean up + os.WriteFile(tracker.progressPath(), []byte("data"), 0o644) + os.WriteFile(tracker.nzbPath(), []byte(""), 0o644) + os.WriteFile(tracker.progressPath()+".tmp", []byte("tmp"), 0o644) + + tracker.Remove() + + for _, p := range []string{tracker.progressPath(), tracker.nzbPath(), tracker.progressPath() + ".tmp"} { + if _, err := os.Stat(p); !os.IsNotExist(err) { + t.Errorf("file should be removed: %s", p) + } + } +} + +// --- CleanStaleFiles edge cases --- + +func TestCleanStaleFiles_EmptyDir(t *testing.T) { + dir := t.TempDir() + if got := CleanStaleFiles(dir, time.Hour); got != 0 { + t.Errorf("empty dir: got %d removed, want 0", got) + } +} + +func TestCleanStaleFiles_NonexistentDir(t *testing.T) { + if got := CleanStaleFiles("/nonexistent/path/that/does/not/exist", time.Hour); got != 0 { + t.Errorf("nonexistent dir: got %d removed, want 0", got) + } +} + +func TestCleanStaleFiles_AllFresh(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "a.progress"), []byte("a"), 0o644) + os.WriteFile(filepath.Join(dir, "b.progress"), []byte("b"), 0o644) + + if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 { + t.Errorf("all fresh: got %d removed, want 0", got) + } +} + +func TestCleanStaleFiles_SkipsSubdirs(t *testing.T) { + dir := t.TempDir() + subDir := filepath.Join(dir, "subdir") + os.MkdirAll(subDir, 0o755) + + // Backdate the subdir (it should not be removed) + os.Chtimes(subDir, fixedPast, fixedPast) + + if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 { + t.Errorf("should skip subdirs: got %d removed, want 0", got) + } + if _, err := os.Stat(subDir); err != nil { + t.Error("subdir should still exist") + } +} + +func TestCleanStaleFiles_MixedAges(t *testing.T) { + dir := t.TempDir() + + stale1 := filepath.Join(dir, "old1.progress") + stale2 := filepath.Join(dir, "old2.nzb") + fresh := filepath.Join(dir, "new.progress") + + os.WriteFile(stale1, []byte("x"), 0o644) + os.WriteFile(stale2, []byte("x"), 0o644) + os.WriteFile(fresh, []byte("x"), 0o644) + + os.Chtimes(stale1, fixedPast, fixedPast) + os.Chtimes(stale2, fixedPast, fixedPast) + + if got := CleanStaleFiles(dir, 7*24*time.Hour); got != 2 { + t.Errorf("mixed ages: got %d removed, want 2", got) + } + if _, err := os.Stat(fresh); err != nil { + t.Error("fresh file should still exist") + } +} + +// --- progressPath / nzbPath --- + +func TestProgressTracker_Paths(t *testing.T) { + dir := "/some/dir" + n := makeTestNZB(1, 1) + tracker := NewProgressTracker("my-task", n, dir) + + if got := tracker.progressPath(); got != filepath.Join(dir, "my-task.progress") { + t.Errorf("progressPath: got %q", got) + } + if got := tracker.nzbPath(); got != filepath.Join(dir, "my-task.nzb") { + t.Errorf("nzbPath: got %q", got) + } +} + +// --- formatBytes --- + +func TestFormatBytes(t *testing.T) { + tests := []struct { + input int64 + want string + }{ + {0, "0 B"}, + {500, "500 B"}, + {1023, "1023 B"}, + {1024, "1.0 KB"}, + {1536, "1.5 KB"}, + {1048576, "1.0 MB"}, + {1073741824, "1.0 GB"}, + {1099511627776, "1.0 TB"}, + } + for _, tt := range tests { + got := formatBytes(tt.input) + if got != tt.want { + t.Errorf("formatBytes(%d) = %q, want %q", tt.input, got, tt.want) + } + } +} + +// --- Single file with 1 segment (boundary) --- + +func TestProgressTracker_SingleSegment(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 1) + tracker := NewProgressTracker("single-seg", n, dir) + + if tracker.IsFileDone(0) { + t.Error("should not be done initially") + } + + tracker.MarkDone(0, 0) + + if !tracker.IsFileDone(0) { + t.Error("should be done after marking the only segment") + } + + if err := tracker.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + tracker2 := NewProgressTracker("single-seg", n, dir) + loaded, _ := tracker2.Load() + if !loaded { + t.Fatal("should load") + } + if !tracker2.IsFileDone(0) { + t.Error("should be done after reload") + } +} + +// --- Flush creates directory if missing --- + +func TestProgressTracker_FlushCreatesDir(t *testing.T) { + base := t.TempDir() + dir := filepath.Join(base, "nested", "resume") + n := makeTestNZB(1, 2) + tracker := NewProgressTracker("mkdir-test", n, dir) + + tracker.MarkDone(0, 0) + if err := tracker.Flush(); err != nil { + t.Fatalf("flush should create dir: %v", err) + } + + if _, err := os.Stat(tracker.progressPath()); err != nil { + t.Fatalf("progress file should exist: %v", err) + } +} + +// --- Double flush after no new marks --- + +func TestProgressTracker_DoubleFlush(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("dbl-flush", n, dir) + + tracker.MarkDone(0, 0) + if err := tracker.Flush(); err != nil { + t.Fatalf("first flush: %v", err) + } + + // Second flush without new marks should be a no-op (dirty=false) + if err := tracker.Flush(); err != nil { + t.Fatalf("second flush: %v", err) + } +} diff --git a/internal/usenet/nntp/client_test.go b/internal/usenet/nntp/client_test.go new file mode 100644 index 0000000..41e61e3 --- /dev/null +++ b/internal/usenet/nntp/client_test.go @@ -0,0 +1,131 @@ +package nntp + +import ( + "bufio" + "bytes" + "testing" +) + +func TestNewClient(t *testing.T) { + c := NewClient(Config{Host: "news.example.com", Port: 563, SSL: true}) + if c.cfg.MaxConnections != 10 { + t.Errorf("default MaxConnections = %d, want 10", c.cfg.MaxConnections) + } + if c.cfg.Host != "news.example.com" { + t.Errorf("Host = %q", c.cfg.Host) + } +} + +func TestNewClientCustomConnections(t *testing.T) { + c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 20}) + if c.cfg.MaxConnections != 20 { + t.Errorf("MaxConnections = %d, want 20", c.cfg.MaxConnections) + } +} + +func TestNewClientZeroConnections(t *testing.T) { + c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 0}) + if c.cfg.MaxConnections != 10 { + t.Errorf("MaxConnections should default to 10, got %d", c.cfg.MaxConnections) + } +} + +func TestNewClientNegativeConnections(t *testing.T) { + c := NewClient(Config{MaxConnections: -5}) + if c.cfg.MaxConnections != 10 { + t.Errorf("MaxConnections should default to 10 for negative, got %d", c.cfg.MaxConnections) + } +} + +func TestActiveConnections(t *testing.T) { + c := NewClient(Config{Host: "localhost", Port: 119}) + if c.ActiveConnections() != 0 { + t.Errorf("ActiveConnections = %d, want 0", c.ActiveConnections()) + } +} + +func TestStatus(t *testing.T) { + c := NewClient(Config{Host: "news.example.com", Port: 563}) + s := c.Status() + if s != "0 connections (0 pooled) to news.example.com:563" { + t.Errorf("Status = %q", s) + } +} + +func TestCloseIdempotent(t *testing.T) { + c := NewClient(Config{Host: "localhost", Port: 119}) + // Close should be idempotent + if err := c.Close(); err != nil { + t.Errorf("first Close: %v", err) + } + if err := c.Close(); err != nil { + t.Errorf("second Close: %v", err) + } +} + +func TestArticleNotFoundError(t *testing.T) { + err := &ArticleNotFoundError{MessageID: "abc123@news.example.com"} + msg := err.Error() + if msg != "nntp: article not found: abc123@news.example.com" { + t.Errorf("Error() = %q", msg) + } +} + +func TestReadDotBody(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + "simple body", + "Hello World\r\n.\r\n", + "Hello World\n", + }, + { + "multiline", + "Line 1\r\nLine 2\r\nLine 3\r\n.\r\n", + "Line 1\nLine 2\nLine 3\n", + }, + { + "dot-stuffed line", + "..This starts with a dot\r\n.\r\n", + ".This starts with a dot\n", + }, + { + "empty body", + ".\r\n", + "", + }, + { + "binary-like data", + "=ybegin line=128 size=1024 name=test.bin\r\nsome encoded data\r\n=yend\r\n.\r\n", + "=ybegin line=128 size=1024 name=test.bin\nsome encoded data\n=yend\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bufio.NewReader(bytes.NewBufferString(tt.input)) + got, err := readDotBody(r) + if err != nil { + t.Fatalf("readDotBody: %v", err) + } + if string(got) != tt.want { + t.Errorf("readDotBody = %q, want %q", string(got), tt.want) + } + }) + } +} + +func TestReadDotBodyEOF(t *testing.T) { + // No dot terminator — should read until EOF + r := bufio.NewReader(bytes.NewBufferString("partial data\r\n")) + got, err := readDotBody(r) + if err != nil { + t.Fatalf("readDotBody EOF: %v", err) + } + if string(got) != "partial data\n" { + t.Errorf("readDotBody EOF = %q", string(got)) + } +} diff --git a/internal/usenet/nzb/parser_test.go b/internal/usenet/nzb/parser_test.go index 8d0d686..6afeef0 100644 --- a/internal/usenet/nzb/parser_test.go +++ b/internal/usenet/nzb/parser_test.go @@ -267,3 +267,784 @@ func TestStripAngleBrackets(t *testing.T) { t.Errorf("MessageID not stripped: got %q", nzb.Files[0].Segments[0].MessageID) } } + +// --- Malformed / edge-case XML inputs --- + +func TestParse_CompletelyEmpty(t *testing.T) { + _, err := Parse(strings.NewReader("")) + if err == nil { + t.Error("expected error for completely empty input") + } +} + +func TestParse_OnlyWhitespace(t *testing.T) { + _, err := Parse(strings.NewReader(" \n\t ")) + if err == nil { + t.Error("expected error for whitespace-only input") + } +} + +func TestParse_ValidXMLButNotNZB(t *testing.T) { + _, err := Parse(strings.NewReader(`Hello`)) + if err == nil { + t.Error("expected error for non-NZB XML") + } +} + +func TestParse_NZBWithNoSegments(t *testing.T) { + xml := ` + + + alt.test + + +` + _, err := Parse(strings.NewReader(xml)) + if err == nil { + t.Error("expected error for file with no segments") + } +} + +func TestParse_SegmentWithEmptyMessageID(t *testing.T) { + xml := ` + + + alt.test + + + + +` + _, err := Parse(strings.NewReader(xml)) + if err == nil { + t.Error("expected error: segment with empty/whitespace message ID should be skipped, leaving no valid files") + } +} + +func TestParse_MixedValidAndEmptySegments(t *testing.T) { + xml := ` + + + alt.test + + valid@id + + also-valid@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if len(nzb.Files[0].Segments) != 2 { + t.Errorf("expected 2 valid segments, got %d", len(nzb.Files[0].Segments)) + } +} + +// --- Metadata / Head parsing --- + +func TestParse_MetaPassword(t *testing.T) { + xml := ` + + + s3cr3t + My Movie + Movies + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if nzb.Password != "s3cr3t" { + t.Errorf("Password: got %q, want %q", nzb.Password, "s3cr3t") + } + if nzb.Meta["title"] != "My Movie" { + t.Errorf("Meta title: got %q", nzb.Meta["title"]) + } + if nzb.Meta["category"] != "Movies" { + t.Errorf("Meta category: got %q", nzb.Meta["category"]) + } +} + +func TestParse_MetaPasswordWithWhitespace(t *testing.T) { + xml := ` + + + padded + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if nzb.Password != "padded" { + t.Errorf("Password should be trimmed: got %q", nzb.Password) + } +} + +func TestParse_NoHead(t *testing.T) { + xml := ` + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if nzb.Password != "" { + t.Errorf("Password should be empty: got %q", nzb.Password) + } + if len(nzb.Meta) != 0 { + t.Errorf("Meta should be empty: got %v", nzb.Meta) + } +} + +func TestParse_MetaWithEmptyType(t *testing.T) { + xml := ` + + + ignored + kept + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if _, ok := nzb.Meta[""]; ok { + t.Error("empty-type meta should not be stored") + } + if nzb.Meta["name"] != "kept" { + t.Errorf("Meta name: got %q", nzb.Meta["name"]) + } +} + +// --- Multiple files --- + +func TestParse_MultipleFilesVariousTypes(t *testing.T) { + xml := ` + + + alt.binaries.movies + + mkv001@ex + mkv002@ex + + + + alt.binaries.movies + + nfo001@ex + + + + alt.binaries.movies + + par001@ex + + + + alt.binaries.movies + + parv001@ex + + + + alt.binaries.movies + + sample001@ex + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + if len(nzb.Files) != 5 { + t.Fatalf("expected 5 files, got %d", len(nzb.Files)) + } + + // ContentFiles should exclude nfo, par2, par2 vol, and sample + content := nzb.ContentFiles() + if len(content) != 1 { + t.Errorf("ContentFiles: got %d, want 1", len(content)) + } + if len(content) > 0 && content[0].Filename() != "movie.mkv" { + t.Errorf("ContentFiles[0]: got %q, want movie.mkv", content[0].Filename()) + } + + // Par2Files + par2 := nzb.Par2Files() + if len(par2) != 2 { + t.Errorf("Par2Files: got %d, want 2", len(par2)) + } + + if !nzb.HasPar2() { + t.Error("HasPar2 should be true") + } + if nzb.HasRars() { + t.Error("HasRars should be false for this NZB") + } +} + +// --- Segment ordering / number parsing --- + +func TestParse_SegmentNumberParsing(t *testing.T) { + xml := ` + + + alt.test + + c@id + a@id + b@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + segs := nzb.Files[0].Segments + if len(segs) != 3 { + t.Fatalf("expected 3 segments, got %d", len(segs)) + } + + // Parse preserves order from XML; sorting is done by the downloader + // Verify numbers are parsed correctly + numbers := make(map[int]bool) + for _, s := range segs { + numbers[s.Number] = true + } + for _, want := range []int{1, 2, 3} { + if !numbers[want] { + t.Errorf("missing segment number %d", want) + } + } +} + +func TestParse_SegmentBytesZero(t *testing.T) { + xml := ` + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if nzb.Files[0].Segments[0].Bytes != 0 { + t.Errorf("expected 0 bytes, got %d", nzb.Files[0].Segments[0].Bytes) + } +} + +func TestParse_SegmentBytesNonNumeric(t *testing.T) { + xml := ` + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + // Non-numeric bytes should parse as 0 + if nzb.Files[0].Segments[0].Bytes != 0 { + t.Errorf("non-numeric bytes should be 0, got %d", nzb.Files[0].Segments[0].Bytes) + } +} + +// --- File helper methods --- + +func TestFileTotalBytes(t *testing.T) { + f := File{ + Segments: []Segment{ + {Bytes: 100}, {Bytes: 200}, {Bytes: 300}, + }, + } + if got := f.TotalBytes(); got != 600 { + t.Errorf("TotalBytes: got %d, want 600", got) + } +} + +func TestFileTotalBytes_Empty(t *testing.T) { + f := File{} + if got := f.TotalBytes(); got != 0 { + t.Errorf("TotalBytes of empty file: got %d, want 0", got) + } +} + +func TestFileExtension_Various(t *testing.T) { + tests := []struct { + subject string + want string + }{ + {`"file.MKV" yEnc`, ".mkv"}, + {`"file.RAR" yEnc`, ".rar"}, + {`"file.Par2" yEnc`, ".par2"}, + {`"noext" yEnc`, ""}, + {`"file.tar.gz" yEnc`, ".gz"}, + } + for _, tt := range tests { + f := File{Subject: tt.subject} + if got := f.Extension(); got != tt.want { + t.Errorf("Extension(%q) = %q, want %q", tt.subject, got, tt.want) + } + } +} + +// --- LargestFile edge cases --- + +func TestLargestFile_EmptyNZB(t *testing.T) { + nzb := &NZB{} + if nzb.LargestFile() != nil { + t.Error("LargestFile should return nil for empty NZB") + } +} + +func TestLargestFile_SingleFile(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"only.bin"`, Segments: []Segment{{Bytes: 100}}}, + }, + } + largest := nzb.LargestFile() + if largest == nil { + t.Fatal("LargestFile should not be nil") + } + if largest.Filename() != "only.bin" { + t.Errorf("got %q", largest.Filename()) + } +} + +func TestLargestFile_MultipleSameSize(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"a.bin"`, Segments: []Segment{{Bytes: 100}}}, + {Subject: `"b.bin"`, Segments: []Segment{{Bytes: 100}}}, + }, + } + largest := nzb.LargestFile() + if largest == nil { + t.Fatal("LargestFile should not be nil") + } + // Should return the first one (stable) + if largest.Filename() != "a.bin" { + t.Errorf("got %q, expected first file for equal sizes", largest.Filename()) + } +} + +// --- IsObfuscated --- + +func TestIsObfuscated_Normal(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"Movie.2024.1080p.BluRay.x264-GROUP.mkv"`}, + }, + } + if nzb.IsObfuscated() { + t.Error("normal filename should not be obfuscated") + } +} + +func TestIsObfuscated_HexName(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"a1b2c3d4e5f6a7b8c9d0e1f2.mkv"`}, + }, + } + if !nzb.IsObfuscated() { + t.Error("hex-like filename should be obfuscated") + } +} + +func TestIsObfuscated_EmptyFiles(t *testing.T) { + nzb := &NZB{} + if nzb.IsObfuscated() { + t.Error("empty NZB should not be obfuscated") + } +} + +func TestIsObfuscated_ShortHex(t *testing.T) { + // Short name (<=10 chars) should not trigger obfuscation + nzb := &NZB{ + Files: []File{ + {Subject: `"abcdef.mkv"`}, + }, + } + if nzb.IsObfuscated() { + t.Error("short hex-like name should not be obfuscated") + } +} + +// --- isMetadataFile --- + +func TestIsMetadataFile(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"file.par2", true}, + {"file.nfo", true}, + {"file.sfv", true}, + {"file.nzb", true}, + {"file.txt", true}, + {"file.jpg", true}, + {"file.png", true}, + {"file.url", true}, + {"file.mkv", false}, + {"file.rar", false}, + {"file.avi", false}, + {"FILE.PAR2", true}, + {"FILE.NFO", true}, + } + for _, tt := range tests { + if got := isMetadataFile(tt.name); got != tt.want { + t.Errorf("isMetadataFile(%q) = %v, want %v", tt.name, got, tt.want) + } + } +} + +// --- isSampleFile --- + +func TestIsSampleFile(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"movie.sample.mkv", true}, + {"Sample.mkv", true}, + {"SAMPLE.avi", true}, + {"movie-sample-video.mkv", true}, + {"movie_sample.mkv", true}, + {"sample.mkv", true}, + {"resampled.mkv", false}, // "sample" is part of "resampled" + {"movie.mkv", false}, + {"my.samples.zip", false}, // "sample" followed by 's' (alphanumeric) + } + for _, tt := range tests { + if got := isSampleFile(tt.name); got != tt.want { + t.Errorf("isSampleFile(%q) = %v, want %v", tt.name, got, tt.want) + } + } +} + +// --- isHexLike --- + +func TestIsHexLike(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"abcdef0123456789", true}, + {"ABCDEF", true}, + {"Movie2024", false}, + {"aabbccdd", true}, + {"xyz_not_hex", false}, + } + for _, tt := range tests { + if got := isHexLike(tt.input); got != tt.want { + t.Errorf("isHexLike(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// --- sanitizeFilename --- + +func TestSanitizeFilename(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"simple name", "simple name"}, + {"name (1/50)", "name"}, + {"file yEnc (01/99)", "file"}, + {`path/with\special:chars*?`, `path_with_special_chars__`}, + {`"quoted" text`, `_quoted_ text`}, + {" spaces ", "spaces"}, + } + for _, tt := range tests { + if got := sanitizeFilename(tt.input); got != tt.want { + t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +// --- Filename fallback --- + +func TestFilename_Fallback_NoQuotes(t *testing.T) { + f := File{Subject: "No quotes here yEnc (1/50)"} + got := f.Filename() + if got != "No quotes here" { + t.Errorf("Filename fallback: got %q, want %q", got, "No quotes here") + } +} + +func TestFilename_EmptySubject(t *testing.T) { + f := File{Subject: ""} + got := f.Filename() + if got != "" { + t.Errorf("Filename empty subject: got %q, want empty", got) + } +} + +// --- NZB aggregate methods on mixed content --- + +func TestNZB_HasRars_NoRars(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"movie.mkv"`}, + {Subject: `"movie.par2"`}, + }, + } + if nzb.HasRars() { + t.Error("HasRars should be false") + } +} + +func TestNZB_HasPar2_NoPar2(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"movie.mkv"`}, + {Subject: `"movie.rar"`}, + }, + } + if nzb.HasPar2() { + t.Error("HasPar2 should be false") + } +} + +func TestNZB_TotalSegments_MultiFile(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Segments: []Segment{{}, {}, {}}}, + {Segments: []Segment{{}, {}}}, + }, + } + if got := nzb.TotalSegments(); got != 5 { + t.Errorf("TotalSegments: got %d, want 5", got) + } +} + +func TestNZB_TotalBytes_MultiFile(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Segments: []Segment{{Bytes: 100}, {Bytes: 200}}}, + {Segments: []Segment{{Bytes: 300}}}, + }, + } + if got := nzb.TotalBytes(); got != 600 { + t.Errorf("TotalBytes: got %d, want 600", got) + } +} + +// --- isRarFile extended --- + +func TestIsRarFile_Extended(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"file.RAR", true}, // case insensitive + {"file.Rar", true}, + {"file.s01", true}, + {"file.s99", true}, + {"file.002", true}, + {"file.999", true}, + {"file.r0", false}, // too short extension + {"file.rXX", false}, // non-numeric + {"file", false}, // no extension + {"file.mp4", false}, + } + for _, tt := range tests { + if got := isRarFile(tt.name); got != tt.want { + t.Errorf("isRarFile(%q) = %v, want %v", tt.name, got, tt.want) + } + } +} + +// --- Parse with date edge cases --- + +func TestParse_DateNonNumeric(t *testing.T) { + xml := ` + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if nzb.Files[0].Date != 0 { + t.Errorf("non-numeric date should be 0, got %d", nzb.Files[0].Date) + } +} + +func TestParse_DateEmpty(t *testing.T) { + xml := ` + + + alt.test + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if nzb.Files[0].Date != 0 { + t.Errorf("empty date should be 0, got %d", nzb.Files[0].Date) + } +} + +// --- Parse: file with all segments having empty IDs should be excluded --- + +func TestParse_AllEmptySegments(t *testing.T) { + xml := ` + + + alt.test + + + + + + + alt.test + + valid@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if len(nzb.Files) != 1 { + t.Fatalf("expected 1 valid file, got %d", len(nzb.Files)) + } + if nzb.Files[0].Filename() != "good.bin" { + t.Errorf("expected good.bin, got %q", nzb.Files[0].Filename()) + } +} + +// --- Groups --- + +func TestParse_NoGroups(t *testing.T) { + xml := ` + + + + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if len(nzb.Files[0].Groups) != 0 { + t.Errorf("expected 0 groups, got %d", len(nzb.Files[0].Groups)) + } +} + +func TestParse_MultipleGroups(t *testing.T) { + xml := ` + + + + alt.binaries.movies + alt.binaries.multimedia + alt.binaries.hdtv + + + seg@id + + +` + nzb, err := Parse(strings.NewReader(xml)) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if len(nzb.Files[0].Groups) != 3 { + t.Errorf("expected 3 groups, got %d", len(nzb.Files[0].Groups)) + } +} + +// --- ContentFiles with sample variations --- + +func TestContentFiles_ExcludesSamples(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"movie.mkv"`, Segments: []Segment{{Bytes: 1000, MessageID: "a"}}}, + {Subject: `"movie.sample.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "b"}}}, + {Subject: `"Sample/preview.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "c"}}}, + }, + } + content := nzb.ContentFiles() + if len(content) != 1 { + t.Errorf("ContentFiles should exclude samples: got %d, want 1", len(content)) + } +} + +// --- RarFiles with split naming --- + +func TestRarFiles_SplitRars(t *testing.T) { + nzb := &NZB{ + Files: []File{ + {Subject: `"movie.rar"`, Segments: []Segment{{MessageID: "a"}}}, + {Subject: `"movie.r00"`, Segments: []Segment{{MessageID: "b"}}}, + {Subject: `"movie.r01"`, Segments: []Segment{{MessageID: "c"}}}, + {Subject: `"movie.001"`, Segments: []Segment{{MessageID: "d"}}}, + {Subject: `"movie.002"`, Segments: []Segment{{MessageID: "e"}}}, + {Subject: `"movie.par2"`, Segments: []Segment{{MessageID: "f"}}}, + {Subject: `"movie.mkv"`, Segments: []Segment{{MessageID: "g"}}}, + }, + } + rars := nzb.RarFiles() + if len(rars) != 5 { + t.Errorf("RarFiles: got %d, want 5", len(rars)) + } +} diff --git a/internal/usenet/postprocess/extract_test.go b/internal/usenet/postprocess/extract_test.go new file mode 100644 index 0000000..5e61c75 --- /dev/null +++ b/internal/usenet/postprocess/extract_test.go @@ -0,0 +1,170 @@ +package postprocess + +import ( + "os" + "path/filepath" + "testing" +) + +func TestIsArchiveFile(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"movie.rar", true}, + {"movie.RAR", true}, + {"movie.part01.rar", true}, + {"movie.r00", true}, + {"movie.r99", true}, + {"movie.s00", true}, + {"movie.001", true}, + {"movie.099", true}, + {"movie.mkv", false}, + {"movie.mp4", false}, + {"movie.par2", false}, + {"movie.nfo", false}, + {"movie.txt", false}, + {"movie.r", false}, + {"movie.abc", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isArchiveFile(tt.name) + if got != tt.want { + t.Errorf("isArchiveFile(%q) = %v, want %v", tt.name, got, tt.want) + } + }) + } +} + +func TestIsCleanupTarget(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"content.par2", true}, + {"content.PAR2", true}, + {"info.nfo", true}, + {"checksum.sfv", true}, + {"content.nzb", true}, + {"content.srr", true}, + {"content.srs", true}, + {"cover.jpg", true}, + {"cover.png", true}, + {"readme.txt", true}, + {"link.url", true}, + {"movie.rar", true}, + {"movie.r00", true}, + {"movie.s01", true}, + {"movie.001", true}, + {"movie.mkv", false}, + {"movie.mp4", false}, + {"movie.avi", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCleanupTarget(tt.name) + if got != tt.want { + t.Errorf("isCleanupTarget(%q) = %v, want %v", tt.name, got, tt.want) + } + }) + } +} + +func TestIsNumeric(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"", false}, + {"0", true}, + {"123", true}, + {"00", true}, + {"12a", false}, + {"abc", false}, + {" 1", false}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := isNumeric(tt.input) + if got != tt.want { + t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestListExtractedFiles(t *testing.T) { + dir := t.TempDir() + + // Create some files + os.WriteFile(filepath.Join(dir, "movie.mkv"), []byte("video"), 0o644) + os.WriteFile(filepath.Join(dir, "subs.srt"), []byte("subs"), 0o644) + os.WriteFile(filepath.Join(dir, "movie.rar"), []byte("archive"), 0o644) + os.WriteFile(filepath.Join(dir, "movie.r00"), []byte("archive part"), 0o644) + + archivePath := filepath.Join(dir, "movie.rar") + files, err := listExtractedFiles(dir, archivePath) + if err != nil { + t.Fatalf("listExtractedFiles: %v", err) + } + + // Should exclude .rar and .r00 (archive files in same dir) + // Should include movie.mkv and subs.srt + if len(files) != 2 { + t.Errorf("expected 2 files, got %d: %v", len(files), files) + } + + for _, f := range files { + base := filepath.Base(f) + if base != "movie.mkv" && base != "subs.srt" { + t.Errorf("unexpected file: %s", base) + } + } +} + +func TestCleanup(t *testing.T) { + dir := t.TempDir() + + // Files that should be removed + cleanupFiles := []string{"content.par2", "info.nfo", "checksum.sfv", "movie.rar", "movie.r00"} + for _, name := range cleanupFiles { + os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644) + } + + // Files that should be kept + keepFiles := []string{"movie.mkv", "subs.srt"} + for _, name := range keepFiles { + os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644) + } + + err := Cleanup(dir) + if err != nil { + t.Fatalf("Cleanup: %v", err) + } + + // Verify cleanup files are gone + for _, name := range cleanupFiles { + if _, err := os.Stat(filepath.Join(dir, name)); !os.IsNotExist(err) { + t.Errorf("expected %s to be removed", name) + } + } + + // Verify kept files still exist + for _, name := range keepFiles { + if _, err := os.Stat(filepath.Join(dir, name)); err != nil { + t.Errorf("expected %s to exist, got: %v", name, err) + } + } +} + +func TestPasswordError(t *testing.T) { + err := &PasswordError{Archive: "/tmp/movie.rar"} + msg := err.Error() + if msg != "archive is password protected: /tmp/movie.rar" { + t.Errorf("PasswordError.Error() = %q", msg) + } +} diff --git a/internal/usenet/postprocess/pipeline_test.go b/internal/usenet/postprocess/pipeline_test.go new file mode 100644 index 0000000..f1a6f26 --- /dev/null +++ b/internal/usenet/postprocess/pipeline_test.go @@ -0,0 +1,156 @@ +package postprocess + +import ( + "os" + "path/filepath" + "testing" +) + +func TestFindPar2File(t *testing.T) { + dir := t.TempDir() + + // Create par2 files of different sizes + mainPar2 := filepath.Join(dir, "content.par2") + vol1 := filepath.Join(dir, "content.vol000+01.par2") + vol2 := filepath.Join(dir, "content.vol001+02.par2") + + os.WriteFile(mainPar2, make([]byte, 100), 0o644) // smallest + os.WriteFile(vol1, make([]byte, 10000), 0o644) + os.WriteFile(vol2, make([]byte, 50000), 0o644) + + files := map[string]string{ + "content.par2": mainPar2, + "content.vol000+01.par2": vol1, + "content.vol001+02.par2": vol2, + } + + result := findPar2File(files) + if result != mainPar2 { + t.Errorf("findPar2File() = %q, want %q (smallest par2)", result, mainPar2) + } +} + +func TestFindPar2FileNone(t *testing.T) { + files := map[string]string{ + "video.mkv": "/tmp/video.mkv", + "subs.srt": "/tmp/subs.srt", + } + + result := findPar2File(files) + if result != "" { + t.Errorf("findPar2File() = %q, want empty", result) + } +} + +func TestFindPar2FileEmpty(t *testing.T) { + result := findPar2File(map[string]string{}) + if result != "" { + t.Errorf("findPar2File() = %q, want empty", result) + } +} + +func TestFindFirstRarPart01(t *testing.T) { + files := map[string]string{ + "movie.part01.rar": "/tmp/movie.part01.rar", + "movie.part02.rar": "/tmp/movie.part02.rar", + "movie.part03.rar": "/tmp/movie.part03.rar", + } + + result := findFirstRar(files) + if result != "/tmp/movie.part01.rar" { + t.Errorf("findFirstRar() = %q, want part01.rar", result) + } +} + +func TestFindFirstRarSingle(t *testing.T) { + files := map[string]string{ + "movie.rar": "/tmp/movie.rar", + "movie.r00": "/tmp/movie.r00", + "movie.r01": "/tmp/movie.r01", + } + + result := findFirstRar(files) + if result != "/tmp/movie.rar" { + t.Errorf("findFirstRar() = %q, want movie.rar (shortest)", result) + } +} + +func TestFindFirstRarSplitFormat(t *testing.T) { + files := map[string]string{ + "movie.001": "/tmp/movie.001", + "movie.002": "/tmp/movie.002", + } + + result := findFirstRar(files) + if result != "/tmp/movie.001" { + t.Errorf("findFirstRar() = %q, want movie.001", result) + } +} + +func TestFindFirstRarNone(t *testing.T) { + files := map[string]string{ + "video.mkv": "/tmp/video.mkv", + "subs.srt": "/tmp/subs.srt", + } + + result := findFirstRar(files) + if result != "" { + t.Errorf("findFirstRar() = %q, want empty", result) + } +} + +func TestFindMainFile(t *testing.T) { + dir := t.TempDir() + + // Create video files of different sizes + small := filepath.Join(dir, "small.mkv") + large := filepath.Join(dir, "large.mkv") + nonVideo := filepath.Join(dir, "readme.txt") + + os.WriteFile(small, make([]byte, 1000), 0o644) + os.WriteFile(large, make([]byte, 5000), 0o644) + os.WriteFile(nonVideo, make([]byte, 9000), 0o644) + + result := findMainFile(dir, []string{small, large, nonVideo}) + if result != large { + t.Errorf("findMainFile() = %q, want %q (largest video)", result, large) + } +} + +func TestFindMainFileFallbackToDir(t *testing.T) { + dir := t.TempDir() + + video := filepath.Join(dir, "movie.mp4") + os.WriteFile(video, make([]byte, 5000), 0o644) + + // Pass empty file list — should fallback to scanning dir + result := findMainFile(dir, nil) + if result != video { + t.Errorf("findMainFile() = %q, want %q (dir scan fallback)", result, video) + } +} + +func TestFindMainFileEmpty(t *testing.T) { + dir := t.TempDir() + result := findMainFile(dir, nil) + if result != "" { + t.Errorf("findMainFile() = %q, want empty", result) + } +} + +func TestFindMainFileMultipleFormats(t *testing.T) { + dir := t.TempDir() + + mkv := filepath.Join(dir, "movie.mkv") + mp4 := filepath.Join(dir, "movie.mp4") + avi := filepath.Join(dir, "movie.avi") + + os.WriteFile(mkv, make([]byte, 3000), 0o644) + os.WriteFile(mp4, make([]byte, 5000), 0o644) // largest + os.WriteFile(avi, make([]byte, 2000), 0o644) + + result := findMainFile(dir, []string{mkv, mp4, avi}) + if result != mp4 { + t.Errorf("findMainFile() = %q, want %q", result, mp4) + } +}