feat(cli): upgrade command, rich status, and version cache
Some checks failed
Release / release (push) Failing after 0s
Release / docker (push) Has been skipped
Release / virustotal (push) Failing after 0s

- Replace `upgrade` stub with real command (alias for `self-update`)
- Also register `update` as alias: `unarr update` works too
- Rewrite `status` to show full config, disk usage, daemon state, and
  update availability with colored sections
- Add version check cache (1h TTL) so `status` is instant on repeat runs
- Guard against division by zero on empty filesystems
- Guard against negative durations from clock skew
- Guard against stale PID via heartbeat recency check (2 min)
- Add comprehensive test coverage across agent, engine, upgrade, usenet,
  arr, library, mediaserver, and UI packages
- Improve Makefile coverage target to exclude cmd/ glue code
- Fix stream handler resource cleanup and ffprobe error handling
This commit is contained in:
Deivid Soto 2026-03-31 22:05:43 +02:00
parent 01d62ffa13
commit 3e0f3a5a64
33 changed files with 7084 additions and 65 deletions

View file

@ -324,3 +324,265 @@ func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
}
}
func TestDeregister(t *testing.T) {
var received struct {
AgentID string `json:"agentId"`
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/deregister" {
t.Errorf("path = %s", r.URL.Path)
}
if r.Method != http.MethodPost {
t.Errorf("method = %s, want POST", r.Method)
}
json.NewDecoder(r.Body).Decode(&received)
json.NewEncoder(w).Encode(StatusResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
err := c.Deregister(context.Background(), "agent-42")
if err != nil {
t.Fatalf("Deregister failed: %v", err)
}
if received.AgentID != "agent-42" {
t.Errorf("agentId = %q, want agent-42", received.AgentID)
}
}
func TestBatchReportStatus(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/status" {
t.Errorf("path = %s", r.URL.Path)
}
var req BatchStatusRequest
json.NewDecoder(r.Body).Decode(&req)
if len(req.Updates) != 2 {
t.Errorf("expected 2 updates, got %d", len(req.Updates))
}
json.NewEncoder(w).Encode(BatchStatusResponse{
Results: []StatusResponse{
{Success: true},
{Success: true, Cancelled: true},
},
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.BatchReportStatus(context.Background(), []StatusUpdate{
{TaskID: "t1", Status: "downloading"},
{TaskID: "t2", Status: "completed"},
})
if err != nil {
t.Fatalf("BatchReportStatus failed: %v", err)
}
if len(resp.Results) != 2 {
t.Fatalf("expected 2 results, got %d", len(resp.Results))
}
if !resp.Results[1].Cancelled {
t.Error("expected result[1].Cancelled=true")
}
}
func TestSearchNzbs(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/nzb-search" {
t.Errorf("path = %s", r.URL.Path)
}
json.NewEncoder(w).Encode(NzbSearchResponse{
Results: []NzbSearchResult{
{NzbID: "nzb-1", Title: "Movie.2023.1080p"},
},
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.SearchNzbs(context.Background(), NzbSearchParams{Query: "Movie"})
if err != nil {
t.Fatalf("SearchNzbs failed: %v", err)
}
if len(resp.Results) != 1 {
t.Fatalf("expected 1 result, got %d", len(resp.Results))
}
if resp.Results[0].NzbID != "nzb-1" {
t.Errorf("nzb ID = %q, want nzb-1", resp.Results[0].NzbID)
}
}
func TestDownloadNzb(t *testing.T) {
nzbContent := []byte(`<?xml version="1.0"?><nzb><file>test</file></nzb>`)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/nzb-download" {
t.Errorf("path = %s", r.URL.Path)
}
if r.URL.Query().Get("nzbId") != "nzb-42" {
t.Errorf("nzbId = %q, want nzb-42", r.URL.Query().Get("nzbId"))
}
w.Write(nzbContent)
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
data, err := c.DownloadNzb(context.Background(), "nzb-42")
if err != nil {
t.Fatalf("DownloadNzb failed: %v", err)
}
if string(data) != string(nzbContent) {
t.Errorf("nzb content mismatch")
}
}
func TestDownloadNzbError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("NZB not found"))
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
_, err := c.DownloadNzb(context.Background(), "bad-id")
if err == nil {
t.Fatal("expected error for 404 response")
}
}
func TestGetUsenetCredentials(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/usenet-credentials" {
t.Errorf("path = %s", r.URL.Path)
}
json.NewEncoder(w).Encode(UsenetCredentials{
Host: "news.example.com",
Port: 563,
SSL: true,
Username: "user1",
Password: "pass1",
MaxConnections: 10,
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
creds, err := c.GetUsenetCredentials(context.Background())
if err != nil {
t.Fatalf("GetUsenetCredentials failed: %v", err)
}
if creds.Host != "news.example.com" {
t.Errorf("host = %q, want news.example.com", creds.Host)
}
if creds.Username != "user1" {
t.Errorf("username = %q, want user1", creds.Username)
}
if creds.MaxConnections != 10 {
t.Errorf("maxConnections = %d, want 10", creds.MaxConnections)
}
}
func TestGetUsenetUsage(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/usenet-usage" {
t.Errorf("path = %s", r.URL.Path)
}
json.NewEncoder(w).Encode(UsenetUsageResponse{
UsedBytes: 5368709120,
QuotaBytes: 10737418240,
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
usage, err := c.GetUsenetUsage(context.Background())
if err != nil {
t.Fatalf("GetUsenetUsage failed: %v", err)
}
if usage.UsedBytes != 5368709120 {
t.Errorf("usedBytes = %d", usage.UsedBytes)
}
}
func TestConfigureDebrid(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/debrid-config" {
t.Errorf("path = %s", r.URL.Path)
}
json.NewEncoder(w).Encode(ConfigureDebridResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.ConfigureDebrid(context.Background(), ConfigureDebridRequest{
Provider: "real-debrid",
Token: "rd-token-123",
})
if err != nil {
t.Fatalf("ConfigureDebrid failed: %v", err)
}
if !resp.Success {
t.Error("expected success=true")
}
}
func TestBatchDownload(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/batch-download" {
t.Errorf("path = %s", r.URL.Path)
}
json.NewEncoder(w).Encode(BatchDownloadResponse{
Queued: 3,
NotFound: 1,
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.BatchDownload(context.Background(), BatchDownloadRequest{})
if err != nil {
t.Fatalf("BatchDownload failed: %v", err)
}
if resp.Queued != 3 {
t.Errorf("queued = %d, want 3", resp.Queued)
}
}
func TestSyncLibrary(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/library-sync" {
t.Errorf("path = %s", r.URL.Path)
}
json.NewEncoder(w).Encode(LibrarySyncResponse{
Matched: 10,
Synced: 15,
Removed: 2,
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.SyncLibrary(context.Background(), LibrarySyncRequest{})
if err != nil {
t.Fatalf("SyncLibrary failed: %v", err)
}
if resp.Matched != 10 {
t.Errorf("matched = %d, want 10", resp.Matched)
}
if resp.Synced != 15 {
t.Errorf("synced = %d, want 15", resp.Synced)
}
}
func TestHTMLErrorResponse(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
w.Write([]byte("<html><body>502 Bad Gateway</body></html>"))
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
_, err := c.Register(context.Background(), RegisterRequest{AgentID: "x"})
if err == nil {
t.Fatal("expected error for HTML error page")
}
}

File diff suppressed because it is too large Load diff

396
internal/arr/client_test.go Normal file
View file

@ -0,0 +1,396 @@
package arr
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func newTestServer(t *testing.T, handlers map[string]any) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check API key header
if r.Header.Get("X-Api-Key") != "test-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
handler, ok := handlers[r.URL.Path]
if !ok {
w.WriteHeader(http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(handler)
}))
}
func TestNewClient(t *testing.T) {
c := NewClient("http://localhost:8989/", "mykey")
if c.baseURL != "http://localhost:8989" {
t.Errorf("baseURL = %q, want trailing slash trimmed", c.baseURL)
}
if c.apiKey != "mykey" {
t.Errorf("apiKey = %q, want mykey", c.apiKey)
}
}
func TestSystemStatus(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v3/system/status": SystemStatus{AppName: "Radarr", Version: "4.0.0"},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
status, err := c.SystemStatus()
if err != nil {
t.Fatalf("SystemStatus: %v", err)
}
if status.AppName != "Radarr" {
t.Errorf("AppName = %q, want Radarr", status.AppName)
}
if status.Version != "4.0.0" {
t.Errorf("Version = %q, want 4.0.0", status.Version)
}
}
func TestSystemStatusFallbackV1(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Api-Key") != "test-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
switch r.URL.Path {
case "/api/v3/system/status":
w.WriteHeader(http.StatusNotFound)
case "/api/v1/system/status":
json.NewEncoder(w).Encode(SystemStatus{AppName: "Prowlarr", Version: "1.0.0"})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key")
status, err := c.SystemStatus()
if err != nil {
t.Fatalf("SystemStatus v1 fallback: %v", err)
}
if status.AppName != "Prowlarr" {
t.Errorf("AppName = %q, want Prowlarr", status.AppName)
}
}
func TestMovies(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v3/movie": []Movie{
{ID: 1, Title: "Inception", Year: 2010, TmdbID: 27205, Monitored: true},
{ID: 2, Title: "Tenet", Year: 2020, TmdbID: 577922, HasFile: true},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
movies, err := c.Movies()
if err != nil {
t.Fatalf("Movies: %v", err)
}
if len(movies) != 2 {
t.Fatalf("expected 2 movies, got %d", len(movies))
}
if movies[0].Title != "Inception" {
t.Errorf("movies[0].Title = %q, want Inception", movies[0].Title)
}
}
func TestSeries(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v3/series": []Series{
{ID: 1, Title: "Breaking Bad", Year: 2008, TvdbID: 81189},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
series, err := c.Series()
if err != nil {
t.Fatalf("Series: %v", err)
}
if len(series) != 1 {
t.Fatalf("expected 1 series, got %d", len(series))
}
if series[0].Title != "Breaking Bad" {
t.Errorf("series[0].Title = %q, want Breaking Bad", series[0].Title)
}
}
func TestQualityProfiles(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v3/qualityprofile": []QualityProfile{
{ID: 1, Name: "HD-1080p"},
{ID: 2, Name: "Ultra-HD"},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
profiles, err := c.QualityProfiles()
if err != nil {
t.Fatalf("QualityProfiles: %v", err)
}
if len(profiles) != 2 {
t.Fatalf("expected 2 profiles, got %d", len(profiles))
}
}
func TestRootFolders(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v3/rootfolder": []RootFolder{
{ID: 1, Path: "/movies", FreeSpace: 500000000000},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
folders, err := c.RootFolders()
if err != nil {
t.Fatalf("RootFolders: %v", err)
}
if len(folders) != 1 {
t.Fatalf("expected 1 folder, got %d", len(folders))
}
if folders[0].Path != "/movies" {
t.Errorf("path = %q, want /movies", folders[0].Path)
}
}
func TestDownloadClients(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v3/downloadclient": []DownloadClient{
{ID: 1, Name: "Transmission", Enable: true, Protocol: "torrent"},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
clients, err := c.DownloadClients()
if err != nil {
t.Fatalf("DownloadClients: %v", err)
}
if len(clients) != 1 || clients[0].Name != "Transmission" {
t.Errorf("unexpected clients: %+v", clients)
}
}
func TestDownloadClientDetails(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Api-Key") != "test-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
if r.URL.Path == "/api/v3/downloadclient/5" {
json.NewEncoder(w).Encode(struct {
Fields []Field `json:"fields"`
}{
Fields: []Field{
{Name: "host", Value: "localhost"},
{Name: "port", Value: 9091},
},
})
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key")
fields, err := c.DownloadClientDetails(5)
if err != nil {
t.Fatalf("DownloadClientDetails: %v", err)
}
if len(fields) != 2 {
t.Fatalf("expected 2 fields, got %d", len(fields))
}
if fields[0].Name != "host" {
t.Errorf("fields[0].Name = %q, want host", fields[0].Name)
}
}
func TestTags(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v3/tag": []Tag{
{ID: 1, Label: "unarr"},
{ID: 2, Label: "imported"},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
tags, err := c.Tags()
if err != nil {
t.Fatalf("Tags: %v", err)
}
if len(tags) != 2 {
t.Fatalf("expected 2 tags, got %d", len(tags))
}
}
func TestHistory(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Api-Key") != "test-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
if r.URL.Path == "/api/v3/history" {
json.NewEncoder(w).Encode(HistoryResponse{
Records: []HistoryRecord{
{ID: 1, EventType: "grabbed", SourceTitle: "Inception.2010.1080p"},
},
TotalRecords: 1,
})
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key")
records, err := c.History(10)
if err != nil {
t.Fatalf("History: %v", err)
}
if len(records) != 1 {
t.Fatalf("expected 1 record, got %d", len(records))
}
if records[0].SourceTitle != "Inception.2010.1080p" {
t.Errorf("sourceTitle = %q", records[0].SourceTitle)
}
}
func TestHistoryDefaultPageSize(t *testing.T) {
var requestedPath string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Api-Key") != "test-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
requestedPath = r.URL.String()
json.NewEncoder(w).Encode(HistoryResponse{})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key")
c.History(0) // should default to 250
if requestedPath == "" {
t.Fatal("no request made")
}
if !contains(requestedPath, "pageSize=250") {
t.Errorf("expected pageSize=250, got path: %s", requestedPath)
}
}
func TestBlocklist(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Api-Key") != "test-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
if r.URL.Path == "/api/v3/blocklist" {
json.NewEncoder(w).Encode(BlocklistResponse{
Records: []BlocklistItem{
{ID: 1, SourceTitle: "Bad.Release", Data: BlocklistData{InfoHash: "abc123"}},
},
})
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key")
items, err := c.Blocklist(50)
if err != nil {
t.Fatalf("Blocklist: %v", err)
}
if len(items) != 1 || items[0].Data.InfoHash != "abc123" {
t.Errorf("unexpected blocklist: %+v", items)
}
}
func TestIndexers(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v1/indexer": []Indexer{
{ID: 1, Name: "NZBGeek", Enable: true},
{ID: 2, Name: "Torznab", Enable: false},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
indexers, err := c.Indexers()
if err != nil {
t.Fatalf("Indexers: %v", err)
}
if len(indexers) != 2 {
t.Fatalf("expected 2 indexers, got %d", len(indexers))
}
}
func TestApplications(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/v1/applications": []Application{
{ID: 1, Name: "Radarr"},
},
})
defer srv.Close()
c := NewClient(srv.URL, "test-key")
apps, err := c.Applications()
if err != nil {
t.Fatalf("Applications: %v", err)
}
if len(apps) != 1 || apps[0].Name != "Radarr" {
t.Errorf("unexpected apps: %+v", apps)
}
}
func TestUnauthorized(t *testing.T) {
srv := newTestServer(t, map[string]any{})
defer srv.Close()
c := NewClient(srv.URL, "wrong-key")
_, err := c.SystemStatus()
if err == nil {
t.Error("expected error for unauthorized request")
}
}
func TestHTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("internal error"))
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key")
_, err := c.Movies()
if err == nil {
t.Error("expected error for 500 response")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && searchStr(s, substr)
}
func searchStr(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}

View file

@ -1,6 +1,9 @@
package arr
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
@ -82,3 +85,158 @@ func TestDetectApp(t *testing.T) {
})
}
}
func TestConfigDirs(t *testing.T) {
dirs := configDirs()
if len(dirs) == 0 {
t.Error("configDirs() returned empty")
}
}
func TestParseConfigXMLEmpty(t *testing.T) {
port, apiKey, urlBase := parseConfigXML(strings.NewReader(""))
if port != "" || apiKey != "" || urlBase != "" {
t.Error("empty input should return empty values")
}
}
func TestParseConfigXMLNoPort(t *testing.T) {
xml := `<Config><ApiKey>key123</ApiKey></Config>`
port, apiKey, _ := parseConfigXML(strings.NewReader(xml))
if port != "" {
t.Errorf("port = %q, want empty", port)
}
if apiKey != "key123" {
t.Errorf("apiKey = %q, want key123", apiKey)
}
}
func TestExtractHostPortMultipleMappings(t *testing.T) {
tests := []struct {
name string
ports string
container string
want string
}{
{"ipv6 only", ":::8989->8989/tcp", "8989", "8989"},
{"different host port", "0.0.0.0:9999->8989/tcp", "8989", "9999"},
{"port in string but no mapping", "something 8989 somewhere", "8989", "8989"},
{"no match at all", "0.0.0.0:3000->3000/tcp", "9999", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractHostPort(tt.ports, tt.container)
if got != tt.want {
t.Errorf("extractHostPort(%q, %q) = %q, want %q", tt.ports, tt.container, got, tt.want)
}
})
}
}
func TestDiscoverFromProwlarr(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/applications":
json.NewEncoder(w).Encode([]Application{
{
ID: 1,
Name: "Radarr",
Fields: []Field{
{Name: "baseUrl", Value: "http://localhost:7878"},
{Name: "apiKey", Value: "radarr-key-123"},
},
},
{
ID: 2,
Name: "Sonarr",
Fields: []Field{
{Name: "baseUrl", Value: "http://localhost:8989"},
{Name: "apiKey", Value: "sonarr-key-456"},
},
},
{
ID: 3,
Name: "Unknown App",
Fields: []Field{
{Name: "baseUrl", Value: "http://localhost:9000"},
{Name: "apiKey", Value: "unknown-key"},
},
},
{
ID: 4,
Name: "Incomplete",
Fields: []Field{
{Name: "baseUrl", Value: "http://localhost:5000"},
// no apiKey → should be skipped
},
},
})
case "/api/v3/system/status":
json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "4.0.0"})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
// DiscoverFromProwlarr will try to verify each instance, which will fail
// for localhost URLs (not our test server), but that's OK — we test the parsing
instances := DiscoverFromProwlarr(srv.URL, "prowlarr-key")
// Should find Radarr and Sonarr (Unknown and Incomplete skipped)
if len(instances) != 2 {
t.Fatalf("expected 2 instances, got %d: %+v", len(instances), instances)
}
found := map[string]bool{}
for _, inst := range instances {
found[inst.App] = true
if inst.Source != "prowlarr" {
t.Errorf("source = %q, want prowlarr", inst.Source)
}
}
if !found["radarr"] {
t.Error("expected radarr instance")
}
if !found["sonarr"] {
t.Error("expected sonarr instance")
}
}
func TestVerify(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Api-Key") != "valid-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "5.0.0"})
}))
defer srv.Close()
t.Run("valid", func(t *testing.T) {
inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "valid-key"}
err := Verify(inst)
if err != nil {
t.Fatalf("Verify: %v", err)
}
if inst.Version != "5.0.0" {
t.Errorf("version = %q, want 5.0.0", inst.Version)
}
})
t.Run("no api key", func(t *testing.T) {
inst := &Instance{App: "radarr", URL: srv.URL}
err := Verify(inst)
if err == nil {
t.Error("expected error for no API key")
}
})
t.Run("invalid key", func(t *testing.T) {
inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "wrong-key"}
err := Verify(inst)
if err == nil {
t.Error("expected error for invalid API key")
}
})
}

View file

@ -0,0 +1,55 @@
package cmd
import "testing"
func TestValidateSpeed(t *testing.T) {
tests := []struct {
input string
wantErr bool
}{
{"", false},
{"0", false},
{" ", false},
{"10MB", false},
{"500KB", false},
{"1GB", false},
{"abc", true},
{"10XB", true},
{"-5MB", true},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
err := validateSpeed(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("validateSpeed(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
})
}
}
func TestValidateDuration(t *testing.T) {
tests := []struct {
input string
wantErr bool
}{
{"", false},
{"30s", false},
{"1m", false},
{"5m", false},
{"1h", false},
{"2h30m", false},
{"abc", true},
{"30", true},
{"5x", true},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
err := validateDuration(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("validateDuration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
})
}
}

View file

@ -0,0 +1,55 @@
package cmd
import "testing"
func TestDeriveWSURL(t *testing.T) {
tests := []struct {
apiURL string
agentID string
want string
}{
{"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"},
{"http://localhost:3000", "a1", ""}, // localhost skipped
{"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped
{"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"},
{"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"},
{"", "agent-123", ""},
{"https://torrentclaw.com", "", ""},
{"", "", ""},
}
for _, tt := range tests {
t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) {
got := deriveWSURL(tt.apiURL, tt.agentID)
if got != tt.want {
t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want)
}
})
}
}
func TestFormatSpeedLog(t *testing.T) {
tests := []struct {
bps int64
want string
}{
{0, "0 B/s"},
{500, "500 B/s"},
{1023, "1023 B/s"},
{1024, "1 KB/s"},
{10240, "10 KB/s"},
{1048576, "1.0 MB/s"},
{5242880, "5.0 MB/s"},
{1073741824, "1.0 GB/s"},
{2147483648, "2.0 GB/s"},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
got := formatSpeedLog(tt.bps)
if got != tt.want {
t.Errorf("formatSpeedLog(%d) = %q, want %q", tt.bps, got, tt.want)
}
})
}
}

View file

@ -0,0 +1,43 @@
package cmd
import (
"os"
"strings"
"testing"
)
func TestExpandHome(t *testing.T) {
home, _ := os.UserHomeDir()
tests := []struct {
input string
want string
}{
{"~/Documents", home + "/Documents"},
{"~/", home},
{"/absolute/path", "/absolute/path"},
{"relative/path", "relative/path"},
{"", ""},
{"~notexpanded", "~notexpanded"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := expandHome(tt.input)
if got != tt.want {
t.Errorf("expandHome(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestDefaultDownloadDir(t *testing.T) {
dir := defaultDownloadDir()
if dir == "" {
t.Error("defaultDownloadDir() returned empty string")
}
home, _ := os.UserHomeDir()
if !strings.HasPrefix(dir, home) {
t.Errorf("defaultDownloadDir() = %q, expected to start with home dir %q", dir, home)
}
}

View file

@ -143,8 +143,9 @@ Source: https://github.com/torrentclaw/unarr`,
completionCmd,
// Library
scanCmd,
// Alias: upgrade → self-update
newUpgradeCmd(),
// Stubs for future commands
newStubCmd("upgrade", "Find a better version of a torrent"),
newStubCmd("moreseed", "Find same quality with more seeders"),
newStubCmd("compare", "Compare two torrents side by side"),
newStubCmd("add", "Search and add torrents to your client"),

View file

@ -1,20 +1,26 @@
package cmd
import (
"context"
"fmt"
"runtime"
"strings"
"time"
"github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/torrentclaw/unarr/internal/agent"
"github.com/torrentclaw/unarr/internal/upgrade"
)
func newStatusCmd() *cobra.Command {
return &cobra.Command{
Use: "status",
Short: "Show daemon status and active downloads",
Long: `Display the current state of the daemon, active downloads, and recent activity.
Short: "Show daemon status, configuration, and update availability",
Long: `Display the current state of unarr: version, configuration, daemon status,
disk usage, and whether an update is available.
Shows the configured agent name, download directory, and preferred method.
When the daemon is running, also displays active downloads and their progress.`,
When the daemon is running, also displays uptime, active downloads, and stats.`,
Example: ` unarr status`,
RunE: func(cmd *cobra.Command, args []string) error {
return runStatus()
@ -25,27 +31,167 @@ When the daemon is running, also displays active downloads and their progress.`,
func runStatus() error {
bold := color.New(color.Bold)
dim := color.New(color.FgHiBlack)
green := color.New(color.FgGreen)
yellow := color.New(color.FgYellow)
cyan := color.New(color.FgCyan)
fmt.Println()
bold.Printf(" unarr %s\n", Version)
dim.Printf(" %s/%s\n", runtime.GOOS, runtime.GOARCH)
fmt.Println()
cfg := loadConfig()
// ── Configuration ──
if cfg.Auth.APIKey == "" {
dim.Println(" Not configured. Run 'unarr init' first.")
yellow.Println(" ⚠ Not configured. Run 'unarr init' first.")
fmt.Println()
return nil
}
fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, cfg.Agent.ID[:8]+"...")
fmt.Printf(" Downloads: %s\n", cfg.Download.Dir)
fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod)
cyan.Println(" Configuration")
agentID := cfg.Agent.ID
if len(agentID) > 8 {
agentID = agentID[:8] + "..."
}
fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, agentID)
fmt.Printf(" Server: %s\n", cfg.Auth.APIURL)
fmt.Printf(" Downloads: %s\n", cfg.Download.Dir)
fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod)
if cfg.Download.PreferredQuality != "" {
fmt.Printf(" Quality: %s\n", cfg.Download.PreferredQuality)
}
fmt.Printf(" Concurrent: %d\n", cfg.Download.MaxConcurrent)
if cfg.Organize.Enabled {
fmt.Printf(" Organize: on")
if cfg.Organize.MoviesDir != "" {
fmt.Printf(" (movies: %s", cfg.Organize.MoviesDir)
if cfg.Organize.TVShowsDir != "" {
fmt.Printf(", tv: %s", cfg.Organize.TVShowsDir)
}
fmt.Print(")")
}
fmt.Println()
}
fmt.Println()
dim.Println(" Daemon not running. Start with 'unarr start'")
dim.Println(" (Live status will be shown here when daemon is running)")
// ── Disk ──
if cfg.Download.Dir != "" {
if free, total, err := agent.DiskInfo(cfg.Download.Dir); err == nil && total > 0 {
usedPct := float64(total-free) / float64(total) * 100
cyan.Println(" Disk")
fmt.Printf(" Free: %s / %s (%.0f%% used)\n", formatBytes(free), formatBytes(total), usedPct)
if usedPct > 90 {
yellow.Println(" ⚠ Low disk space!")
}
fmt.Println()
}
}
// ── Daemon ──
cyan.Println(" Daemon")
state := agent.ReadState()
if state != nil && isDaemonAlive(state) {
green.Printf(" Status: running (PID %d)\n", state.PID)
fmt.Printf(" Uptime: %s\n", formatDuration(time.Since(state.StartedAt)))
fmt.Printf(" Last beat: %s ago\n", formatDuration(time.Since(state.LastHeartbeat)))
fmt.Printf(" Active: %d task(s)\n", state.ActiveTasks)
fmt.Printf(" Completed: %d\n", state.CompletedCount)
if state.FailedCount > 0 {
fmt.Printf(" Failed: %d\n", state.FailedCount)
}
if state.TotalDownloaded > 0 {
fmt.Printf(" Downloaded: %s\n", formatBytes(state.TotalDownloaded))
}
if len(state.MethodStats) > 0 {
parts := make([]string, 0, len(state.MethodStats))
for method, count := range state.MethodStats {
parts = append(parts, fmt.Sprintf("%s:%d", method, count))
}
fmt.Printf(" Methods: %s\n", strings.Join(parts, ", "))
}
} else {
dim.Println(" Status: stopped")
dim.Println(" Start with: unarr start")
}
fmt.Println()
// ── Update check (cached: instant if <1h, otherwise async 3s) ──
type versionResult struct {
version string
fromCache bool
err error
}
versionCh := make(chan versionResult, 1)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
v, cached, err := upgrade.CheckLatestCached(ctx)
versionCh <- versionResult{v, cached, err}
}()
cyan.Println(" Update")
fmt.Print(" Checking... ")
vr := <-versionCh
if vr.err != nil {
dim.Println("could not check (offline?)")
} else {
currentClean := strings.TrimPrefix(Version, "v")
if currentClean == vr.version {
green.Printf("✓ up to date (v%s)\n", vr.version)
} else {
yellow.Printf("v%s available! ", vr.version)
fmt.Printf("Run: unarr upgrade\n")
}
}
fmt.Println()
return nil
}
// isDaemonAlive checks if the daemon process from the state file is still running.
// Guards against PID reuse by also checking heartbeat recency.
func isDaemonAlive(state *agent.DaemonState) bool {
if state.PID == 0 {
return false
}
// Reject stale state: if last heartbeat is older than 2 minutes, the daemon
// likely crashed and the PID may have been reused by another process.
if !state.LastHeartbeat.IsZero() && time.Since(state.LastHeartbeat) > 2*time.Minute {
return false
}
return agent.IsProcessAlive(state.PID)
}
// formatBytes formats bytes into human-readable string.
func formatBytes(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
}
// formatDuration formats a duration into a compact human-readable string.
func formatDuration(d time.Duration) string {
if d < 0 {
return "0s"
}
if d < time.Minute {
return fmt.Sprintf("%ds", int(d.Seconds()))
}
if d < time.Hour {
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
}
if d < 24*time.Hour {
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
}
days := int(d.Hours()) / 24
hours := int(d.Hours()) % 24
return fmt.Sprintf("%dd %dh", days, hours)
}

View file

@ -24,6 +24,20 @@ var streamRegistry = struct {
servers: make(map[string]*engine.StreamServer),
}
// cancelAllStreams cancels all active stream tasks and servers (only 1 stream at a time).
func cancelAllStreams() {
streamRegistry.mu.Lock()
for taskID, cancel := range streamRegistry.cancels {
cancel()
delete(streamRegistry.cancels, taskID)
}
for taskID, srv := range streamRegistry.servers {
srv.Shutdown(context.Background())
delete(streamRegistry.servers, taskID)
}
streamRegistry.mu.Unlock()
}
// cancelStreamTask cancels a running stream task and shuts down any stream server.
func cancelStreamTask(taskID string) {
streamRegistry.mu.Lock()
@ -94,7 +108,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
}
// 4. Start HTTP server
srv := engine.NewStreamServer(eng, 0)
srv := engine.NewStreamServer(eng, cfg.Download.StreamPort)
streamURL, err := srv.Start(ctx)
if err != nil {
task.ErrorMessage = "start HTTP server: " + err.Error()
@ -107,17 +121,27 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
task.StreamURL = streamURL
log.Printf("[%s] stream ready: %s", at.ID[:8], streamURL)
// 6. Progress loop
// 6. Unified progress + idle timeout loop
eng.StartProgressLoop(ctx)
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
progressTicker := time.NewTicker(3 * time.Second)
defer progressTicker.Stop()
idleCheck := time.NewTicker(60 * time.Second)
defer idleCheck.Stop()
completed := false
for {
select {
case <-ctx.Done():
log.Printf("[%s] stream stopped", at.ID[:8])
return
case <-ticker.C:
case <-idleCheck.C:
if srv.IdleSince() > 30*time.Minute {
log.Printf("[%s] stream idle timeout (30m no HTTP requests), shutting down", at.ID[:8])
return
}
case <-progressTicker.C:
p := eng.Progress()
task.UpdateProgress(engine.Progress{
DownloadedBytes: p.DownloadedBytes,
@ -129,7 +153,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
})
// Terminal progress
if p.TotalBytes > 0 {
if !completed && p.TotalBytes > 0 {
pct := int(float64(p.DownloadedBytes) / float64(p.TotalBytes) * 100)
fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d",
at.ID[:8], pct,
@ -137,20 +161,11 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
p.Peers, p.Seeds)
}
if p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 {
fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line
if !completed && p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 {
fmt.Fprint(os.Stderr, "\r\033[2K")
task.Transition(engine.StatusCompleted)
log.Printf("[%s] stream download complete, server stays up for 30m or until cancelled", at.ID[:8])
// Keep HTTP server running so the player can finish reading.
// Auto-shutdown after 30 minutes of idle to prevent resource leaks.
idleTimer := time.NewTimer(30 * time.Minute)
defer idleTimer.Stop()
select {
case <-ctx.Done():
case <-idleTimer.C:
log.Printf("[%s] stream idle timeout (30m), shutting down", at.ID[:8])
}
return
log.Printf("[%s] stream download complete, server stays up until idle (30m)", at.ID[:8])
completed = true
}
}
}

30
internal/cmd/upgrade.go Normal file
View file

@ -0,0 +1,30 @@
package cmd
import (
"github.com/spf13/cobra"
)
// newUpgradeCmd creates the `unarr upgrade` command as an alias for `self-update`.
func newUpgradeCmd() *cobra.Command {
var force bool
cmd := &cobra.Command{
Use: "upgrade",
Aliases: []string{"update"},
Short: "Update unarr to the latest version",
Long: `Download and install the latest version of unarr.
This is an alias for 'unarr self-update'. Checks GitHub for the latest
release, verifies the checksum, and replaces the current binary.
A backup is kept at <binary>.backup.`,
Example: ` unarr upgrade
unarr upgrade --force`,
RunE: func(cmd *cobra.Command, args []string) error {
return runSelfUpdate(force)
},
}
cmd.Flags().BoolVarP(&force, "force", "f", false, "reinstall even if already up to date")
return cmd
}

View file

@ -2,6 +2,7 @@ package engine
import (
"context"
"os"
"testing"
"time"
@ -83,3 +84,223 @@ func TestManagerShutdown(t *testing.T) {
mgr.Shutdown(ctx)
// Should not hang
}
func TestManagerDefaultConcurrency(t *testing.T) {
reporter := NewProgressReporter(
agent.NewClient("http://localhost", "test", "test"),
1*time.Second,
)
mgr := NewManager(ManagerConfig{MaxConcurrent: 0}, reporter)
if cap(mgr.sem) != 3 {
t.Errorf("default MaxConcurrent should be 3, got %d", cap(mgr.sem))
}
}
func TestManagerGetTask(t *testing.T) {
reporter := NewProgressReporter(
agent.NewClient("http://localhost", "test", "test"),
1*time.Second,
)
mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
// No task added
if task := mgr.GetTask("nonexistent"); task != nil {
t.Error("expected nil for nonexistent task")
}
}
func TestManagerActiveTasks(t *testing.T) {
reporter := NewProgressReporter(
agent.NewClient("http://localhost", "test", "test"),
1*time.Second,
)
mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
tasks := mgr.ActiveTasks()
if len(tasks) != 0 {
t.Errorf("expected 0 active tasks, got %d", len(tasks))
}
}
func TestManagerSubmitCompletesWithValidFile(t *testing.T) {
dir := t.TempDir()
// Create a file that verify() will accept
filePath := dir + "/movie.mkv"
os.WriteFile(filePath, make([]byte, 1024), 0o644)
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: 100 * time.Millisecond,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
dl := &resultMockDownloader{
method: MethodTorrent,
result: &Result{
FilePath: filePath,
FileName: "movie.mkv",
Method: MethodTorrent,
Size: 1024,
},
}
mgr := NewManager(ManagerConfig{
MaxConcurrent: 2,
OutputDir: dir,
}, pr, dl)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go pr.Run(ctx)
mgr.Submit(ctx, agent.Task{
ID: "task-complete-test1",
InfoHash: "abc123def456abc123def456abc123def456abc1",
Title: "Test Movie",
PreferredMethod: "torrent",
})
mgr.Wait()
cancel()
// Task should have completed successfully
// (we can't check directly since it's removed from active map after processing)
}
func TestManagerCancelTask(t *testing.T) {
reporter := NewProgressReporter(
agent.NewClient("http://localhost", "test", "test"),
1*time.Second,
)
dl := &slowMockDownloader{method: MethodTorrent}
mgr := NewManager(ManagerConfig{
MaxConcurrent: 2,
OutputDir: t.TempDir(),
}, reporter, dl)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go reporter.Run(ctx)
mgr.Submit(ctx, agent.Task{
ID: "task-cancel-test12",
InfoHash: "abc123def456abc123def456abc123def456abc1",
Title: "Cancel Me",
PreferredMethod: "torrent",
})
// Give it time to start
time.Sleep(100 * time.Millisecond)
mgr.CancelTask("task-cancel-test12")
mgr.Wait()
}
func TestManagerPauseTask(t *testing.T) {
reporter := NewProgressReporter(
agent.NewClient("http://localhost", "test", "test"),
1*time.Second,
)
dl := &slowMockDownloader{method: MethodTorrent}
mgr := NewManager(ManagerConfig{
MaxConcurrent: 2,
OutputDir: t.TempDir(),
}, reporter, dl)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go reporter.Run(ctx)
mgr.Submit(ctx, agent.Task{
ID: "task-pause-test123",
InfoHash: "abc123def456abc123def456abc123def456abc1",
Title: "Pause Me",
PreferredMethod: "torrent",
})
time.Sleep(100 * time.Millisecond)
mgr.PauseTask("task-pause-test123")
mgr.Wait()
}
func TestManagerCancelAndDeleteFiles(t *testing.T) {
reporter := NewProgressReporter(
agent.NewClient("http://localhost", "test", "test"),
1*time.Second,
)
dl := &slowMockDownloader{method: MethodTorrent}
mgr := NewManager(ManagerConfig{
MaxConcurrent: 2,
OutputDir: t.TempDir(),
}, reporter, dl)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go reporter.Run(ctx)
mgr.Submit(ctx, agent.Task{
ID: "task-delfile-test12",
InfoHash: "abc123def456abc123def456abc123def456abc1",
Title: "Delete Me",
PreferredMethod: "torrent",
})
time.Sleep(100 * time.Millisecond)
mgr.CancelAndDeleteFiles("task-delfile-test12")
mgr.Wait()
}
func TestManagerCancelNonexistent(t *testing.T) {
reporter := NewProgressReporter(
agent.NewClient("http://localhost", "test", "test"),
1*time.Second,
)
mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
// Should not panic
mgr.CancelTask("nonexistent")
mgr.PauseTask("nonexistent")
mgr.CancelAndDeleteFiles("nonexistent")
}
// resultMockDownloader returns a configurable result
type resultMockDownloader struct {
method DownloadMethod
result *Result
}
func (m *resultMockDownloader) Method() DownloadMethod { return m.method }
func (m *resultMockDownloader) Available(_ context.Context, _ *Task) (bool, error) {
return true, nil
}
func (m *resultMockDownloader) Download(_ context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
return m.result, nil
}
func (m *resultMockDownloader) Pause(_ string) error { return nil }
func (m *resultMockDownloader) Cancel(_ string) error { return nil }
func (m *resultMockDownloader) Shutdown(_ context.Context) error { return nil }
// slowMockDownloader blocks until context is cancelled
type slowMockDownloader struct {
method DownloadMethod
}
func (m *slowMockDownloader) Method() DownloadMethod { return m.method }
func (m *slowMockDownloader) Available(_ context.Context, _ *Task) (bool, error) {
return true, nil
}
func (m *slowMockDownloader) Download(ctx context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
<-ctx.Done()
return nil, ctx.Err()
}
func (m *slowMockDownloader) Pause(_ string) error { return nil }
func (m *slowMockDownloader) Cancel(_ string) error { return nil }
func (m *slowMockDownloader) Shutdown(_ context.Context) error { return nil }

View file

@ -0,0 +1,50 @@
package engine
import "testing"
func TestDownloadMethodConstants(t *testing.T) {
if MethodTorrent != "torrent" {
t.Errorf("MethodTorrent = %q, want torrent", MethodTorrent)
}
if MethodDebrid != "debrid" {
t.Errorf("MethodDebrid = %q, want debrid", MethodDebrid)
}
if MethodUsenet != "usenet" {
t.Errorf("MethodUsenet = %q, want usenet", MethodUsenet)
}
}
func TestProgressStruct(t *testing.T) {
p := Progress{
DownloadedBytes: 1024,
TotalBytes: 2048,
SpeedBps: 512,
ETA: 10,
Peers: 5,
Seeds: 3,
FileName: "movie.mkv",
}
if p.DownloadedBytes != 1024 {
t.Errorf("DownloadedBytes = %d, want 1024", p.DownloadedBytes)
}
if p.FileName != "movie.mkv" {
t.Errorf("FileName = %q, want movie.mkv", p.FileName)
}
}
func TestResultStruct(t *testing.T) {
r := Result{
FilePath: "/downloads/movie.mkv",
FileName: "movie.mkv",
Method: MethodTorrent,
Size: 1073741824,
}
if r.Method != MethodTorrent {
t.Errorf("Method = %q, want torrent", r.Method)
}
if r.Size != 1073741824 {
t.Errorf("Size = %d, want 1073741824", r.Size)
}
}

View file

@ -0,0 +1,181 @@
package engine
import (
"os"
"path/filepath"
"testing"
)
func TestReplaceFile(t *testing.T) {
tmp := t.TempDir()
backupDir := filepath.Join(tmp, "backups")
// Create "old" file
oldPath := filepath.Join(tmp, "movie.mkv")
os.WriteFile(oldPath, []byte("old content"), 0o644)
// Create "new" file
newPath := filepath.Join(tmp, "movie-new.mkv")
os.WriteFile(newPath, []byte("new better content"), 0o644)
err := replaceFile(oldPath, newPath, backupDir)
if err != nil {
t.Fatalf("replaceFile: %v", err)
}
// Old path should now contain new content
data, err := os.ReadFile(oldPath)
if err != nil {
t.Fatalf("read old path: %v", err)
}
if string(data) != "new better content" {
t.Errorf("old path content = %q, want 'new better content'", string(data))
}
// Backup should exist
entries, _ := os.ReadDir(backupDir)
if len(entries) != 1 {
t.Errorf("expected 1 backup file, got %d", len(entries))
}
// New file should be gone
if _, err := os.Stat(newPath); !os.IsNotExist(err) {
t.Error("new file should have been moved/deleted")
}
}
func TestReplaceFileOldNotFound(t *testing.T) {
tmp := t.TempDir()
err := replaceFile(filepath.Join(tmp, "nonexistent.mkv"), filepath.Join(tmp, "new.mkv"), "")
if err == nil {
t.Error("expected error when old file doesn't exist")
}
}
func TestCopyFile(t *testing.T) {
tmp := t.TempDir()
src := filepath.Join(tmp, "source.txt")
dst := filepath.Join(tmp, "dest.txt")
content := []byte("hello world copy test")
os.WriteFile(src, content, 0o644)
err := copyFile(src, dst)
if err != nil {
t.Fatalf("copyFile: %v", err)
}
data, err := os.ReadFile(dst)
if err != nil {
t.Fatalf("read dest: %v", err)
}
if string(data) != string(content) {
t.Errorf("dest content = %q, want %q", string(data), string(content))
}
}
func TestCopyFileSrcNotFound(t *testing.T) {
tmp := t.TempDir()
err := copyFile(filepath.Join(tmp, "nope.txt"), filepath.Join(tmp, "out.txt"))
if err == nil {
t.Error("expected error when source doesn't exist")
}
}
func TestOrganizeNoDirs(t *testing.T) {
r := &Result{FilePath: "/tmp/file.mkv", FileName: "file.mkv"}
task := &Task{Title: "Movie"}
path, err := organize(r, task, OrganizeConfig{Enabled: true})
if err != nil {
t.Fatal(err)
}
if path != "/tmp/file.mkv" {
t.Errorf("should return original path when no dirs configured, got %q", path)
}
}
func TestOrganizeNilResult(t *testing.T) {
task := &Task{Title: "Movie"}
path, err := organize(&Result{}, task, OrganizeConfig{Enabled: true})
if err != nil {
t.Fatal(err)
}
if path != "" {
t.Errorf("expected empty path for empty result, got %q", path)
}
}
func TestOrganizeMovieDirectory(t *testing.T) {
tmp := t.TempDir()
srcDir := filepath.Join(tmp, "src", "MovieDir")
os.MkdirAll(srcDir, 0o755)
os.WriteFile(filepath.Join(srcDir, "movie.mkv"), []byte("data"), 0o644)
moviesDir := filepath.Join(tmp, "Movies")
r := &Result{FilePath: srcDir, FileName: "MovieDir"}
task := &Task{Title: "My Movie 2023"}
path, err := organize(r, task, OrganizeConfig{
Enabled: true,
MoviesDir: moviesDir,
})
if err != nil {
t.Fatal(err)
}
if path == srcDir {
t.Error("directory should have moved")
}
if _, err := os.Stat(path); err != nil {
t.Errorf("organized directory should exist at %s", path)
}
}
func TestOrganizeSeasonOnly(t *testing.T) {
tmp := t.TempDir()
srcFile := filepath.Join(tmp, "Show.S01.Complete.mkv")
os.WriteFile(srcFile, []byte("data"), 0o644)
tvDir := filepath.Join(tmp, "TV")
r := &Result{FilePath: srcFile, FileName: "Show.S01.Complete.mkv"}
task := &Task{Title: "Show S01"}
path, err := organize(r, task, OrganizeConfig{
Enabled: true,
TVShowsDir: tvDir,
})
if err != nil {
t.Fatal(err)
}
dir := filepath.Dir(path)
if filepath.Base(dir) != "Season 01" {
t.Errorf("expected Season 01 directory, got %q", filepath.Base(dir))
}
}
func TestCleanTitleEdgeCases(t *testing.T) {
tests := []struct {
input string
want string
}{
{"", ""},
{"Simple Title", "Simple Title"},
{"Title (2023) 1080p BluRay", "Title"},
{"Title 720p HDTV", "Title"},
{"Title x264 HEVC", "Title"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := cleanTitle(tt.input)
if got != tt.want {
t.Errorf("cleanTitle(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}

View file

@ -0,0 +1,419 @@
package engine
import (
"context"
"sync"
"testing"
"time"
"github.com/torrentclaw/unarr/internal/agent"
)
// mockStatusReporter records calls to ReportStatus.
type mockStatusReporter struct {
mu sync.Mutex
calls []agent.StatusUpdate
resp *agent.StatusResponse
respErr error
}
func (m *mockStatusReporter) ReportStatus(_ context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, update)
if m.resp != nil {
return m.resp, m.respErr
}
return &agent.StatusResponse{}, m.respErr
}
// mockBatchReporter records batch calls.
type mockBatchReporter struct {
mockStatusReporter
batchCalls [][]agent.StatusUpdate
batchResp *agent.BatchStatusResponse
}
func (m *mockBatchReporter) BatchReportStatus(_ context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.batchCalls = append(m.batchCalls, updates)
if m.batchResp != nil {
return m.batchResp, nil
}
results := make([]agent.StatusResponse, len(updates))
return &agent.BatchStatusResponse{Results: results}, nil
}
func TestProgressReporter_TrackUntrack(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
task := &Task{ID: "task-001", Status: StatusDownloading}
pr.Track(task)
pr.mu.Lock()
if _, ok := pr.latest["task-001"]; !ok {
t.Error("task should be tracked")
}
pr.mu.Unlock()
pr.Untrack("task-001")
pr.mu.Lock()
if _, ok := pr.latest["task-001"]; ok {
t.Error("task should be untracked")
}
pr.mu.Unlock()
}
func TestProgressReporter_FlushReportsFinalStates(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
completed := &Task{ID: "task-completed-1234", Status: StatusCompleted}
pr.Track(completed)
pr.flush(context.Background())
reporter.mu.Lock()
defer reporter.mu.Unlock()
if len(reporter.calls) != 1 {
t.Fatalf("expected 1 report, got %d", len(reporter.calls))
}
if reporter.calls[0].TaskID != "task-completed-1234" {
t.Errorf("reported wrong task: %s", reporter.calls[0].TaskID)
}
}
func TestProgressReporter_FlushSkipsWhenNotWatching(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
isWatching: func() bool { return false },
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
lastCheckAt: time.Now(), // not due for control check
}
// Active downloading task, already reported as downloading
task := &Task{ID: "task-active-12345678", Status: StatusDownloading}
pr.Track(task)
pr.mu.Lock()
pr.lastReported["task-active-12345678"] = StatusDownloading
pr.mu.Unlock()
pr.flush(context.Background())
reporter.mu.Lock()
defer reporter.mu.Unlock()
if len(reporter.calls) != 0 {
t.Errorf("expected 0 reports when not watching (no transition), got %d", len(reporter.calls))
}
}
func TestProgressReporter_FlushReportsTransitions(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
isWatching: func() bool { return false },
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
lastCheckAt: time.Now(),
}
// Task transitioning from resolving to downloading
task := &Task{ID: "task-trans-12345678", Status: StatusDownloading}
pr.Track(task)
pr.mu.Lock()
pr.lastReported["task-trans-12345678"] = StatusResolving
pr.mu.Unlock()
pr.flush(context.Background())
reporter.mu.Lock()
defer reporter.mu.Unlock()
if len(reporter.calls) != 1 {
t.Fatalf("expected 1 report for transition, got %d", len(reporter.calls))
}
}
func TestProgressReporter_FlushActiveWhenWatching(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
isWatching: func() bool { return true },
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
task := &Task{ID: "task-watch-12345678", Status: StatusDownloading}
pr.Track(task)
pr.mu.Lock()
pr.lastReported["task-watch-12345678"] = StatusDownloading
pr.mu.Unlock()
pr.flush(context.Background())
reporter.mu.Lock()
defer reporter.mu.Unlock()
if len(reporter.calls) != 1 {
t.Fatalf("expected 1 report when watching active task, got %d", len(reporter.calls))
}
}
func TestProgressReporter_HandleResponseCancel(t *testing.T) {
reporter := &mockStatusReporter{
resp: &agent.StatusResponse{Cancelled: true},
}
var cancelledID string
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
onCancel: func(id string) { cancelledID = id },
}
task := &Task{ID: "task-cancel-1234567", Status: StatusCompleted}
pr.Track(task)
pr.flush(context.Background())
if cancelledID != "task-cancel-1234567" {
t.Errorf("expected cancel handler called with task ID, got %q", cancelledID)
}
}
func TestProgressReporter_HandleResponsePause(t *testing.T) {
reporter := &mockStatusReporter{
resp: &agent.StatusResponse{Paused: true},
}
var pausedID string
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
onPause: func(id string) { pausedID = id },
}
task := &Task{ID: "task-paused-1234567", Status: StatusCompleted}
pr.Track(task)
pr.flush(context.Background())
if pausedID != "task-paused-1234567" {
t.Errorf("expected pause handler called, got %q", pausedID)
}
}
func TestProgressReporter_HandleResponseDeleteFiles(t *testing.T) {
reporter := &mockStatusReporter{
resp: &agent.StatusResponse{Cancelled: true, DeleteFiles: true},
}
var deletedID string
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
onDeleteFiles: func(id string) { deletedID = id },
}
task := &Task{ID: "task-delete-1234567", Status: StatusCompleted}
pr.Track(task)
pr.flush(context.Background())
if deletedID != "task-delete-1234567" {
t.Errorf("expected deleteFiles handler called, got %q", deletedID)
}
}
func TestProgressReporter_HandleResponseStream(t *testing.T) {
reporter := &mockStatusReporter{
resp: &agent.StatusResponse{StreamRequested: true},
}
var streamID string
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
onStreamRequested: func(id string) { streamID = id },
}
// Task with no stream URL yet
task := &Task{ID: "task-stream-1234567", Status: StatusCompleted}
pr.Track(task)
pr.flush(context.Background())
if streamID != "task-stream-1234567" {
t.Errorf("expected stream handler called, got %q", streamID)
}
}
func TestProgressReporter_HandleResponseWatchingChanged(t *testing.T) {
reporter := &mockStatusReporter{
resp: &agent.StatusResponse{Watching: true},
}
var watchingValue bool
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
onWatchingChanged: func(w bool) { watchingValue = w },
}
task := &Task{ID: "task-watch2-1234567", Status: StatusCompleted}
pr.Track(task)
pr.flush(context.Background())
if !watchingValue {
t.Error("expected watchingChanged called with true")
}
}
func TestProgressReporter_BatchFlush(t *testing.T) {
batcher := &mockBatchReporter{
batchResp: &agent.BatchStatusResponse{
Results: []agent.StatusResponse{{}, {}},
},
}
pr := &ProgressReporter{
reporter: batcher,
interval: time.Second,
isWatching: func() bool { return true },
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
pr.Track(&Task{ID: "task-batch1-1234567", Status: StatusDownloading})
pr.Track(&Task{ID: "task-batch2-1234567", Status: StatusDownloading})
pr.flush(context.Background())
batcher.mu.Lock()
defer batcher.mu.Unlock()
if len(batcher.batchCalls) != 1 {
t.Fatalf("expected 1 batch call, got %d", len(batcher.batchCalls))
}
if len(batcher.batchCalls[0]) != 2 {
t.Errorf("expected 2 updates in batch, got %d", len(batcher.batchCalls[0]))
}
}
func TestProgressReporter_RunStopsOnCancel(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: 50 * time.Millisecond,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
err := pr.Run(ctx)
if err != nil {
t.Errorf("Run should return nil on context cancel, got: %v", err)
}
}
func TestProgressReporter_ReportFinal(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
task := &Task{ID: "task-final-12345678", Status: StatusCompleted}
pr.Track(task)
pr.ReportFinal(context.Background(), task)
reporter.mu.Lock()
defer reporter.mu.Unlock()
if len(reporter.calls) != 1 {
t.Fatalf("expected 1 final report, got %d", len(reporter.calls))
}
// Should be untracked after final report
pr.mu.Lock()
if _, ok := pr.latest["task-final-12345678"]; ok {
t.Error("task should be untracked after ReportFinal")
}
pr.mu.Unlock()
}
func TestProgressReporter_SetHandlers(t *testing.T) {
pr := &ProgressReporter{
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
}
pr.SetCancelHandler(func(id string) {})
pr.SetPauseHandler(func(id string) {})
pr.SetDeleteFilesHandler(func(id string) {})
pr.SetStreamRequestedHandler(func(id string) {})
pr.SetWatchingFunc(func() bool { return true })
pr.SetWatchingChangedHandler(func(w bool) {})
if pr.onCancel == nil || pr.onPause == nil || pr.onDeleteFiles == nil ||
pr.onStreamRequested == nil || pr.isWatching == nil || pr.onWatchingChanged == nil {
t.Error("expected all handlers to be set")
}
}
func TestProgressReporter_ControlCheckDue(t *testing.T) {
reporter := &mockStatusReporter{}
pr := &ProgressReporter{
reporter: reporter,
interval: time.Second,
isWatching: func() bool { return false },
latest: make(map[string]*Task),
lastReported: make(map[string]TaskStatus),
lastCheckAt: time.Now().Add(-31 * time.Second), // 31s ago - due for control check
}
task := &Task{ID: "task-ctrl-123456789", Status: StatusDownloading}
pr.Track(task)
pr.mu.Lock()
pr.lastReported["task-ctrl-123456789"] = StatusDownloading
pr.mu.Unlock()
pr.flush(context.Background())
reporter.mu.Lock()
defer reporter.mu.Unlock()
if len(reporter.calls) != 1 {
t.Errorf("expected 1 report for control check, got %d", len(reporter.calls))
}
}

View file

@ -11,6 +11,7 @@ import (
"os/exec"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/anacrolix/torrent"
@ -24,11 +25,12 @@ type fileProvider interface {
// StreamServer serves a torrent file over HTTP with Range request support.
type StreamServer struct {
provider fileProvider
server *http.Server
port int
url string
upnpMapping *UPnPMapping
provider fileProvider
server *http.Server
port int
url string
upnpMapping *UPnPMapping
lastActivity atomic.Int64 // UnixNano of last HTTP request
}
// NewStreamServer creates a new HTTP server for streaming via StreamEngine.
@ -93,11 +95,38 @@ func NewStreamServerFromDisk(filePath string, port int) *StreamServer {
}
}
// Start begins serving the file on all interfaces. Returns the best reachable URL:
// 1. UPnP public IP (accessible from anywhere on the internet)
// 2. Tailscale IP (accessible from any device in the tailnet)
// 3. LAN IP (accessible from local network)
// FindVideoFile scans a directory (recursively) for the largest video file.
// Returns empty string if no video file found.
func FindVideoFile(dir string) string {
var best string
var bestSize int64
filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
if err != nil || d.IsDir() {
return nil
}
ext := strings.ToLower(filepath.Ext(d.Name()))
if !VideoExts[ext] {
return nil
}
info, err := d.Info()
if err != nil {
return nil
}
if info.Size() > bestSize {
best = path
bestSize = info.Size()
}
return nil
})
return best
}
// Start begins serving the file on all interfaces. Returns the best reachable URL.
// The file is served as-is — the user's media player (VLC, mpv, etc.) handles decoding.
func (ss *StreamServer) Start(ctx context.Context) (string, error) {
ss.lastActivity.Store(time.Now().UnixNano())
mux := http.NewServeMux()
mux.HandleFunc("/stream", ss.handler)
@ -107,19 +136,9 @@ func (ss *StreamServer) Start(ctx context.Context) (string, error) {
return "", fmt.Errorf("listen on %s: %w", addr, err)
}
// Extract actual port (important when port=0)
ss.port = listener.Addr().(*net.TCPAddr).Port
// Try UPnP for public internet access (like Plex Remote Access)
if mapping, upnpErr := setupUPnP(ss.port); upnpErr == nil {
ss.upnpMapping = mapping
ss.url = fmt.Sprintf("http://%s:%d/stream", mapping.ExternalIP, mapping.ExternalPort)
log.Printf("stream: UPnP mapped %s:%d -> local:%d", mapping.ExternalIP, mapping.ExternalPort, ss.port)
} else {
// Fallback: Tailscale IP > LAN IP > 127.0.0.1
ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port)
log.Printf("stream: UPnP unavailable (%v), using %s", upnpErr, ss.url)
}
ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port)
log.Printf("stream: serving on %s", ss.url)
ss.server = &http.Server{
Handler: mux,
@ -141,6 +160,15 @@ func (ss *StreamServer) URL() string { return ss.url }
// Port returns the bound port.
func (ss *StreamServer) Port() int { return ss.port }
// IdleSince returns how long since the last HTTP request was received.
func (ss *StreamServer) IdleSince() time.Duration {
last := ss.lastActivity.Load()
if last == 0 {
return 0
}
return time.Since(time.Unix(0, last))
}
// Shutdown gracefully stops the HTTP server and removes the UPnP port mapping.
func (ss *StreamServer) Shutdown(ctx context.Context) error {
ss.upnpMapping.Remove()
@ -151,6 +179,8 @@ func (ss *StreamServer) Shutdown(ctx context.Context) error {
}
func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) {
ss.lastActivity.Store(time.Now().UnixNano())
// CORS headers — only when browser sends Origin (HTTPS site → localhost)
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", "*")

View file

@ -78,6 +78,12 @@ func ExtractMediaInfo(ctx context.Context, ffprobePath, filePath string) (*Media
return nil, fmt.Errorf("ffprobe JSON parse failed: %w", err)
}
return parseFFprobeOutput(data)
}
// parseFFprobeOutput converts parsed ffprobe JSON into MediaInfo.
// Separated from ExtractMediaInfo so it can be tested without running ffprobe.
func parseFFprobeOutput(data ffprobeOutput) (*MediaInfo, error) {
if len(data.Streams) == 0 {
return nil, fmt.Errorf("ffprobe returned no streams")
}

View file

@ -0,0 +1,430 @@
package mediainfo
import (
"testing"
)
func TestParseDuration(t *testing.T) {
tests := []struct {
input string
want float64
}{
{"", 0},
{"0", 0},
{"-5", 0},
{"invalid", 0},
{"7423.500000", 7423.5},
{"120.123456", 120.123},
{"3600", 3600},
{"0.001", 0.001},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := parseDuration(tt.input)
if got != tt.want {
t.Errorf("parseDuration(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestTagValue(t *testing.T) {
tags := map[string]string{
"language": "eng",
"title": "Main Audio",
"HANDLER": "VideoHandler",
}
tests := []struct {
key string
want string
}{
{"language", "eng"},
{"title", "Main Audio"},
{"handler", "VideoHandler"},
{"missing", ""},
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
got := tagValue(tags, tt.key)
if got != tt.want {
t.Errorf("tagValue(tags, %q) = %q, want %q", tt.key, got, tt.want)
}
})
}
}
func TestTagValueNil(t *testing.T) {
got := tagValue(nil, "language")
if got != "" {
t.Errorf("tagValue(nil, language) = %q, want empty", got)
}
}
func TestContainsAny(t *testing.T) {
tests := []struct {
s string
subs []string
want bool
}{
{"yuv420p10le", []string{"10le", "10be", "p010"}, true},
{"yuv420p12be", []string{"10le", "10be", "p010"}, false},
{"yuv420p12be", []string{"12le", "12be"}, true},
{"yuv420p", []string{"10le", "10be"}, false},
{"", []string{"any"}, false},
{"something", []string{}, false},
}
for _, tt := range tests {
got := containsAny(tt.s, tt.subs...)
if got != tt.want {
t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.subs, got, tt.want)
}
}
}
func TestParseFFprobeOutput_BasicH264(t *testing.T) {
data := ffprobeOutput{
Format: ffprobeFormat{Duration: "7423.5"},
Streams: []ffprobeStream{
{
CodecType: "video",
CodecName: "h264",
Profile: "High",
Width: 1920,
Height: 1080,
RFrameRate: "24000/1001",
},
{
CodecType: "audio",
CodecName: "aac",
Channels: 2,
Tags: map[string]string{"language": "eng"},
Disposition: map[string]int{"default": 1},
},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatalf("parseFFprobeOutput: %v", err)
}
if mi.Video == nil {
t.Fatal("expected video info")
}
if mi.Video.Codec != "h264" {
t.Errorf("codec = %q, want h264", mi.Video.Codec)
}
if mi.Video.Width != 1920 || mi.Video.Height != 1080 {
t.Errorf("dimensions = %dx%d, want 1920x1080", mi.Video.Width, mi.Video.Height)
}
if mi.Video.Profile != "High" {
t.Errorf("profile = %q, want High", mi.Video.Profile)
}
if mi.Video.Duration != 7423.5 {
t.Errorf("duration = %v, want 7423.5", mi.Video.Duration)
}
if mi.Video.FrameRate < 23.975 || mi.Video.FrameRate > 23.977 {
t.Errorf("frameRate = %v, want ~23.976", mi.Video.FrameRate)
}
if len(mi.Audio) != 1 {
t.Fatalf("audio tracks = %d, want 1", len(mi.Audio))
}
if mi.Audio[0].Lang != "en" {
t.Errorf("audio lang = %q, want en", mi.Audio[0].Lang)
}
if !mi.Audio[0].Default {
t.Error("expected default audio track")
}
}
func TestParseFFprobeOutput_HEVC_HDR10(t *testing.T) {
data := ffprobeOutput{
Format: ffprobeFormat{Duration: "3600"},
Streams: []ffprobeStream{
{
CodecType: "video",
CodecName: "hevc",
Width: 3840,
Height: 2160,
BitsPerRaw: "10",
ColorSpace: "bt2020nc",
ColorTransfer: "smpte2084",
RFrameRate: "24/1",
},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.HDR != "HDR10" {
t.Errorf("hdr = %q, want HDR10", mi.Video.HDR)
}
if mi.Video.BitDepth != 10 {
t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth)
}
}
func TestParseFFprobeOutput_DolbyVisionWithHDR10(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{
CodecType: "video",
CodecName: "hevc",
Width: 3840,
Height: 2160,
ColorSpace: "bt2020nc",
ColorTransfer: "smpte2084",
SideDataList: []sideData{{SideDataType: "DOVI configuration record"}},
},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.HDR != "DV+HDR10" {
t.Errorf("hdr = %q, want DV+HDR10", mi.Video.HDR)
}
}
func TestParseFFprobeOutput_DolbyVisionOnly(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{
CodecType: "video",
CodecName: "hevc",
Width: 3840,
Height: 2160,
SideDataList: []sideData{{SideDataType: "DOVI configuration record"}},
},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.HDR != "DV" {
t.Errorf("hdr = %q, want DV", mi.Video.HDR)
}
}
func TestParseFFprobeOutput_HLG(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{
CodecType: "video",
CodecName: "hevc",
Width: 3840,
Height: 2160,
ColorSpace: "bt2020nc",
ColorTransfer: "arib-std-b67",
},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.HDR != "HLG" {
t.Errorf("hdr = %q, want HLG", mi.Video.HDR)
}
}
func TestParseFFprobeOutput_MultiAudioAndSubtitles(t *testing.T) {
data := ffprobeOutput{
Format: ffprobeFormat{Duration: "5400"},
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080},
{
CodecType: "audio", CodecName: "ac3", Channels: 6,
Tags: map[string]string{"language": "eng", "title": "English 5.1"},
Disposition: map[string]int{"default": 1},
},
{
CodecType: "audio", CodecName: "aac", Channels: 2,
Tags: map[string]string{"language": "spa"},
},
{
CodecType: "subtitle", CodecName: "subrip",
Tags: map[string]string{"language": "eng"},
},
{
CodecType: "subtitle", CodecName: "ass",
Tags: map[string]string{"language": "spa"},
Disposition: map[string]int{"forced": 1},
},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if len(mi.Audio) != 2 {
t.Fatalf("audio tracks = %d, want 2", len(mi.Audio))
}
if mi.Audio[0].Title != "English 5.1" {
t.Errorf("audio[0].title = %q", mi.Audio[0].Title)
}
if len(mi.Subtitles) != 2 {
t.Fatalf("subtitle tracks = %d, want 2", len(mi.Subtitles))
}
if !mi.Subtitles[1].Forced {
t.Error("expected subtitle[1] to be forced")
}
if len(mi.Languages) != 2 {
t.Errorf("languages = %v, want 2 entries", mi.Languages)
}
}
func TestParseFFprobeOutput_BitDepthFromPixFmt(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p10le"},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.BitDepth != 10 {
t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth)
}
}
func TestParseFFprobeOutput_12BitFromPixFmt(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p12be"},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.BitDepth != 12 {
t.Errorf("bitDepth = %d, want 12", mi.Video.BitDepth)
}
}
func TestParseFFprobeOutput_DurationFromStreamFallback(t *testing.T) {
data := ffprobeOutput{
Format: ffprobeFormat{Duration: ""},
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "h264", Width: 1280, Height: 720, Duration: "1800.5"},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.Duration != 1800.5 {
t.Errorf("duration = %v, want 1800.5", mi.Video.Duration)
}
}
func TestParseFFprobeOutput_NoStreams(t *testing.T) {
data := ffprobeOutput{}
_, err := parseFFprobeOutput(data)
if err == nil {
t.Error("expected error for no streams")
}
}
func TestParseFFprobeOutput_OnlyFirstVideoStream(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080},
{CodecType: "video", CodecName: "mjpeg", Width: 320, Height: 240}, // cover art
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.Codec != "h264" {
t.Errorf("should use first video stream, got codec %q", mi.Video.Codec)
}
if mi.Video.Width != 1920 {
t.Errorf("width = %d, should be from first video stream", mi.Video.Width)
}
}
func TestParseFFprobeOutput_SMPTE2084_WithoutBT2020(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "smpte2084"},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.HDR != "HDR10" {
t.Errorf("hdr = %q, want HDR10", mi.Video.HDR)
}
}
func TestParseFFprobeOutput_AribWithoutBT2020(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "arib-std-b67", ColorSpace: "other"},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.HDR != "HLG" {
t.Errorf("hdr = %q, want HLG", mi.Video.HDR)
}
}
func TestParseFFprobeOutput_AudioOnly(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{CodecType: "audio", CodecName: "flac", Channels: 2, Tags: map[string]string{"language": "eng"}},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video != nil {
t.Error("expected no video info for audio-only")
}
if len(mi.Audio) != 1 {
t.Errorf("audio tracks = %d, want 1", len(mi.Audio))
}
}
func TestParseFFprobeOutput_FrameRateNoSlash(t *testing.T) {
data := ffprobeOutput{
Streams: []ffprobeStream{
{CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080, RFrameRate: "30"},
},
}
mi, err := parseFFprobeOutput(data)
if err != nil {
t.Fatal(err)
}
if mi.Video.FrameRate != 0 {
t.Errorf("frameRate = %v, want 0 (no slash)", mi.Video.FrameRate)
}
}

View file

@ -0,0 +1,93 @@
package library
import (
"os"
"path/filepath"
"testing"
)
func TestDiscoverFiles(t *testing.T) {
dir := t.TempDir()
// Create video files (need to be >= 100MB to pass size check)
largeContent := make([]byte, 101*1024*1024)
videoFiles := []string{"movie.mkv", "show.mp4", "clip.avi"}
for _, name := range videoFiles {
path := filepath.Join(dir, name)
if err := os.WriteFile(path, largeContent, 0o644); err != nil {
t.Fatalf("write %s: %v", name, err)
}
}
// Non-video files (should be excluded)
nonVideo := []string{"readme.txt", "cover.jpg", "subs.srt"}
for _, name := range nonVideo {
if err := os.WriteFile(filepath.Join(dir, name), largeContent, 0o644); err != nil {
t.Fatalf("write %s: %v", name, err)
}
}
// Small video file (should be excluded, < 100MB)
if err := os.WriteFile(filepath.Join(dir, "small.mkv"), []byte("small"), 0o644); err != nil {
t.Fatal(err)
}
// Excluded pattern (sample)
sampleDir := filepath.Join(dir, "sample")
os.MkdirAll(sampleDir, 0o755)
if err := os.WriteFile(filepath.Join(sampleDir, "sample.mkv"), largeContent, 0o644); err != nil {
t.Fatal(err)
}
files, err := discoverFiles(dir)
if err != nil {
t.Fatalf("discoverFiles: %v", err)
}
if len(files) != 3 {
t.Errorf("expected 3 files, got %d: %v", len(files), files)
}
// Check that all returned files are video extensions
for _, f := range files {
ext := filepath.Ext(f)
if ext != ".mkv" && ext != ".mp4" && ext != ".avi" {
t.Errorf("unexpected extension: %s", ext)
}
}
}
func TestDiscoverFilesEmptyDir(t *testing.T) {
dir := t.TempDir()
files, err := discoverFiles(dir)
if err != nil {
t.Fatalf("discoverFiles: %v", err)
}
if len(files) != 0 {
t.Errorf("expected 0 files, got %d", len(files))
}
}
func TestDiscoverFilesExcludePatterns(t *testing.T) {
dir := t.TempDir()
largeContent := make([]byte, 101*1024*1024)
excludeDirs := []string{"trailer", "featurette", "extras", "bonus"}
for _, name := range excludeDirs {
sub := filepath.Join(dir, name)
os.MkdirAll(sub, 0o755)
if err := os.WriteFile(filepath.Join(sub, "video.mkv"), largeContent, 0o644); err != nil {
t.Fatal(err)
}
}
files, err := discoverFiles(dir)
if err != nil {
t.Fatal(err)
}
if len(files) != 0 {
t.Errorf("expected 0 files (all excluded), got %d: %v", len(files), files)
}
}

View file

@ -0,0 +1,108 @@
package library
import (
"testing"
"github.com/torrentclaw/unarr/internal/library/mediainfo"
)
func TestBuildSyncItems(t *testing.T) {
cache := &LibraryCache{
Items: []LibraryItem{
{
FilePath: "/media/movies/Inception.mkv",
FileName: "Inception.2010.1080p.mkv",
FileSize: 5000000000,
Title: "Inception",
Year: "2010",
MediaInfo: &mediainfo.MediaInfo{
Video: &mediainfo.VideoInfo{
Codec: "hevc",
Width: 1920,
Height: 1080,
BitDepth: 10,
HDR: "HDR10",
},
Audio: []mediainfo.AudioTrack{
{Lang: "en", Codec: "ac3", Channels: 6, Default: true},
{Lang: "es", Codec: "aac", Channels: 2},
},
Subtitles: []mediainfo.SubtitleTrack{
{Lang: "en", Codec: "subrip"},
{Lang: "es", Codec: "subrip"},
},
},
},
{
FilePath: "/media/shows/Breaking.Bad.S01E01.mkv",
FileName: "Breaking.Bad.S01E01.mkv",
FileSize: 1000000000,
Title: "Breaking Bad",
Season: 1,
Episode: 1,
},
{
// Item with scan error — should be skipped
FilePath: "/media/bad.mkv",
FileName: "bad.mkv",
ScanError: "ffprobe failed",
},
},
}
items := BuildSyncItems(cache)
if len(items) != 2 {
t.Fatalf("expected 2 items (1 skipped), got %d", len(items))
}
// First item: movie with full media info
movie := items[0]
if movie.Title != "Inception" {
t.Errorf("title = %q, want Inception", movie.Title)
}
if movie.ContentType != "movie" {
t.Errorf("contentType = %q, want movie", movie.ContentType)
}
if movie.Resolution != "1080p" {
t.Errorf("resolution = %q, want 1080p", movie.Resolution)
}
if movie.VideoCodec != "hevc" {
t.Errorf("videoCodec = %q, want hevc", movie.VideoCodec)
}
if movie.HDR != "HDR10" {
t.Errorf("hdr = %q, want HDR10", movie.HDR)
}
if movie.AudioCodec != "ac3" {
t.Errorf("audioCodec = %q, want ac3", movie.AudioCodec)
}
if movie.AudioChannels != 6 {
t.Errorf("audioChannels = %d, want 6", movie.AudioChannels)
}
if len(movie.AudioLanguages) != 2 {
t.Errorf("audioLanguages count = %d, want 2", len(movie.AudioLanguages))
}
if len(movie.SubtitleLanguages) != 2 {
t.Errorf("subtitleLanguages count = %d, want 2", len(movie.SubtitleLanguages))
}
// Second item: show without media info
show := items[1]
if show.ContentType != "show" {
t.Errorf("contentType = %q, want show", show.ContentType)
}
if show.Season != 1 || show.Episode != 1 {
t.Errorf("season/episode = %d/%d, want 1/1", show.Season, show.Episode)
}
if show.Resolution != "" {
t.Errorf("resolution should be empty, got %q", show.Resolution)
}
}
func TestBuildSyncItemsEmpty(t *testing.T) {
cache := &LibraryCache{Items: nil}
items := BuildSyncItems(cache)
if len(items) != 0 {
t.Errorf("expected 0 items, got %d", len(items))
}
}

View file

@ -2,6 +2,8 @@ package mediaserver
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
@ -69,6 +71,96 @@ func TestJellyfinParsing(t *testing.T) {
}
}
func TestPlexTokenFromPrefs(t *testing.T) {
t.Run("valid prefs", func(t *testing.T) {
dir := t.TempDir()
prefsPath := filepath.Join(dir, "Preferences.xml")
xml := `<?xml version="1.0" encoding="utf-8"?>
<Preferences PlexOnlineToken="my-secret-token" OldestPreviousVersion="1.0"/>`
os.WriteFile(prefsPath, []byte(xml), 0o644)
token := plexTokenFromPrefs(prefsPath)
if token != "my-secret-token" {
t.Errorf("token = %q, want my-secret-token", token)
}
})
t.Run("no token attr", func(t *testing.T) {
dir := t.TempDir()
prefsPath := filepath.Join(dir, "Preferences.xml")
xml := `<?xml version="1.0"?><Preferences/>`
os.WriteFile(prefsPath, []byte(xml), 0o644)
token := plexTokenFromPrefs(prefsPath)
if token != "" {
t.Errorf("token = %q, want empty", token)
}
})
t.Run("file not found", func(t *testing.T) {
token := plexTokenFromPrefs("/nonexistent/Preferences.xml")
if token != "" {
t.Errorf("token = %q, want empty", token)
}
})
t.Run("invalid xml", func(t *testing.T) {
dir := t.TempDir()
prefsPath := filepath.Join(dir, "Preferences.xml")
os.WriteFile(prefsPath, []byte("not xml at all"), 0o644)
token := plexTokenFromPrefs(prefsPath)
if token != "" {
t.Errorf("token = %q, want empty", token)
}
})
}
func TestParsePlexSectionsMultipleLocations(t *testing.T) {
body := `{
"MediaContainer": {
"Directory": [
{
"title": "Movies",
"Location": [
{"path": "/media/movies"},
{"path": "/media/movies2"}
]
}
]
}
}`
paths := parsePlexSections([]byte(body))
if len(paths) != 2 {
t.Fatalf("expected 2 paths, got %d", len(paths))
}
}
func TestParsePlexSectionsEmptyPath(t *testing.T) {
body := `{
"MediaContainer": {
"Directory": [
{
"Location": [{"path": ""}, {"path": "/valid"}]
}
]
}
}`
paths := parsePlexSections([]byte(body))
if len(paths) != 1 {
t.Fatalf("expected 1 path (empty filtered), got %d: %v", len(paths), paths)
}
}
func TestCommonMediaDirs(t *testing.T) {
dirs := commonMediaDirs()
if len(dirs) == 0 {
t.Error("expected at least some common media dirs")
}
}
func TestParentDir(t *testing.T) {
tests := []struct {
name string

View file

@ -0,0 +1,47 @@
package sentry
import "testing"
func TestEnvironment(t *testing.T) {
tests := []struct {
version string
want string
}{
{"", "development"},
{"dev", "development"},
{"0.1.0-dev", "development"},
{"1.0.0", "production"},
{"0.3.5", "production"},
{"2.0.0-beta", "production"},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
got := environment(tt.version)
if got != tt.want {
t.Errorf("environment(%q) = %q, want %q", tt.version, got, tt.want)
}
})
}
}
func TestInitNoOp(t *testing.T) {
// With empty dsn (default in tests), Init should be a no-op
Init("1.0.0")
// Should not panic
}
func TestCloseNoOp(t *testing.T) {
// Close should be safe to call without Init
Close()
}
func TestCaptureErrorNil(t *testing.T) {
// Should not panic with nil error
CaptureError(nil, "test")
}
func TestSetUser(t *testing.T) {
// Should not panic without initialization
SetUser("agent-123")
}

View file

@ -1,7 +1,9 @@
package ui
import (
"fmt"
"testing"
"time"
)
func TestFormatSize(t *testing.T) {
@ -127,6 +129,10 @@ func TestQualityIndicator(t *testing.T) {
{"medium", intPtr(60), "🟡"},
{"high", intPtr(80), "🟢"},
{"perfect", intPtr(100), "🟢"},
{"boundary_40", intPtr(40), "🟡"},
{"boundary_70", intPtr(70), "🟢"},
{"boundary_39", intPtr(39), "🔴"},
{"zero", intPtr(0), "🔴"},
}
for _, tt := range tests {
@ -139,6 +145,52 @@ func TestQualityIndicator(t *testing.T) {
}
}
func TestSeedHealthIndicator(t *testing.T) {
tests := []struct {
seeds int
want string
}{
{0, "🔴"},
{5, "🔴"},
{9, "🔴"},
{10, "🟡"},
{50, "🟡"},
{100, "🟡"},
{101, "🟢"},
{1000, "🟢"},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("seeds_%d", tt.seeds), func(t *testing.T) {
got := SeedHealthIndicator(tt.seeds)
if got != tt.want {
t.Errorf("SeedHealthIndicator(%d) = %q, want %q", tt.seeds, got, tt.want)
}
})
}
}
func TestFormatRating(t *testing.T) {
tests := []struct {
name string
input *string
want string
}{
{"nil", nil, "-"},
{"value", strPtr("8.5"), "8.5"},
{"empty", strPtr(""), ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FormatRating(tt.input)
if got != tt.want {
t.Errorf("FormatRating(%v) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestStringOrDash(t *testing.T) {
s := "hello"
if got := StringOrDash(&s); got != "hello" {
@ -150,16 +202,160 @@ func TestStringOrDash(t *testing.T) {
}
func TestFormatContentType(t *testing.T) {
if got := FormatContentType("movie"); got != "Movie" {
t.Errorf("FormatContentType(movie) = %q, want Movie", got)
tests := []struct {
input string
want string
}{
{"movie", "Movie"},
{"Movie", "Movie"},
{"MOVIE", "Movie"},
{"show", "Show"},
{"Show", "Show"},
{"other", "other"},
{"", ""},
}
if got := FormatContentType("show"); got != "Show" {
t.Errorf("FormatContentType(show) = %q, want Show", got)
}
if got := FormatContentType("other"); got != "other" {
t.Errorf("FormatContentType(other) = %q, want other", got)
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := FormatContentType(tt.input)
if got != tt.want {
t.Errorf("FormatContentType(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func ptr[T any](v T) *T { return &v }
func intPtr(v int) *int { return &v }
func TestFormatLanguages(t *testing.T) {
tests := []struct {
name string
input []string
want string
}{
{"nil", nil, "-"},
{"empty", []string{}, "-"},
{"single", []string{"en"}, "en"},
{"multiple", []string{"en", "es", "fr"}, "en, es, fr"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FormatLanguages(tt.input)
if got != tt.want {
t.Errorf("FormatLanguages(%v) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestFormatSeedRatio(t *testing.T) {
tests := []struct {
seeders int
leechers int
want string
}{
{0, 0, "0:0"},
{10, 0, "10:0"},
{100, 10, "10:1"},
{50, 50, "1:1"},
{1, 3, "0:1"},
{150, 10, "15:1"},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%d_%d", tt.seeders, tt.leechers), func(t *testing.T) {
got := FormatSeedRatio(tt.seeders, tt.leechers)
if got != tt.want {
t.Errorf("FormatSeedRatio(%d, %d) = %q, want %q", tt.seeders, tt.leechers, got, tt.want)
}
})
}
}
func TestFormatTimeAgo(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input string
want string
}{
{"invalid", "not-a-date", "not-a-date"},
{"just_now", now.Add(-10 * time.Second).Format(time.RFC3339), "just now"},
{"minutes", now.Add(-5 * time.Minute).Format(time.RFC3339), "5m ago"},
{"hours", now.Add(-3 * time.Hour).Format(time.RFC3339), "3h ago"},
{"days", now.Add(-7 * 24 * time.Hour).Format(time.RFC3339), "7d ago"},
{"months", now.Add(-60 * 24 * time.Hour).Format(time.RFC3339), "2mo ago"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FormatTimeAgo(tt.input)
if got != tt.want {
t.Errorf("FormatTimeAgo(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestFormatNumberExtended(t *testing.T) {
tests := []struct {
input int
want string
}{
{-1000, "-1,000"},
{-5, "-5"},
{10000, "10,000"},
{100000, "100,000"},
{1000000, "1,000,000"},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
got := FormatNumber(tt.input)
if got != tt.want {
t.Errorf("FormatNumber(%d) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestTruncateStringEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
want string
}{
{"maxLen_1", "hello", 1, "h"},
{"maxLen_3", "hello", 3, "hel"},
{"empty", "", 5, ""},
{"unicode", "こんにちは世界", 5, "こん..."},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := TruncateString(tt.input, tt.maxLen)
if got != tt.want {
t.Errorf("TruncateString(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
}
})
}
}
func TestPtr(t *testing.T) {
v := 42
p := Ptr(v)
if *p != 42 {
t.Errorf("Ptr(42) = %d, want 42", *p)
}
s := "hello"
sp := Ptr(s)
if *sp != "hello" {
t.Errorf("Ptr(hello) = %q, want hello", *sp)
}
}
func ptr[T any](v T) *T { return &v }
func intPtr(v int) *int { return &v }
func strPtr(v string) *string { return &v }

122
internal/ui/table_test.go Normal file
View file

@ -0,0 +1,122 @@
package ui
import (
"bytes"
"strings"
"testing"
tc "github.com/torrentclaw/go-client"
)
func TestNewCleanTable(t *testing.T) {
var buf bytes.Buffer
tbl := newCleanTable(&buf)
tbl.Header([]string{"A", "B"})
tbl.Append([]string{"1", "2"})
tbl.Render()
out := buf.String()
if !strings.Contains(out, "A") || !strings.Contains(out, "B") {
t.Errorf("expected headers in output, got: %s", out)
}
if !strings.Contains(out, "1") || !strings.Contains(out, "2") {
t.Errorf("expected row data in output, got: %s", out)
}
}
func TestPrintSearchResultEntry(t *testing.T) {
var buf bytes.Buffer
year := 2010
rating := "8.8"
quality := "1080p"
codec := "x265"
size := int64(4294967296)
score := 85
r := tc.SearchResult{
Title: "Inception",
Year: &year,
ContentType: "movie",
RatingIMDb: &rating,
Genres: []string{"Sci-Fi", "Action"},
Torrents: []tc.TorrentInfo{
{
Quality: &quality,
SizeBytes: &size,
Seeders: 150,
Leechers: 10,
Source: "YTS",
Codec: &codec,
Languages: []string{"en", "es"},
QualityScore: &score,
},
},
}
printSearchResultEntry(&buf, r)
out := buf.String()
if !strings.Contains(out, "Inception") {
t.Error("expected title in output")
}
if !strings.Contains(out, "2010") {
t.Error("expected year in output")
}
if !strings.Contains(out, "8.8") {
t.Error("expected rating in output")
}
if !strings.Contains(out, "1080p") {
t.Error("expected quality in output")
}
if !strings.Contains(out, "YTS") {
t.Error("expected source in output")
}
}
func TestPrintSearchResultEntryNoTorrents(t *testing.T) {
var buf bytes.Buffer
year := 2020
r := tc.SearchResult{
Title: "No Torrents Movie",
Year: &year,
ContentType: "movie",
Torrents: nil,
}
printSearchResultEntry(&buf, r)
out := buf.String()
if !strings.Contains(out, "No Torrents Movie") {
t.Error("expected title in output")
}
if !strings.Contains(out, "No torrents available") {
t.Error("expected no-torrents message")
}
}
func TestPrintSearchResultEntryNilFields(t *testing.T) {
var buf bytes.Buffer
r := tc.SearchResult{
Title: "Minimal",
ContentType: "movie",
Torrents: []tc.TorrentInfo{
{
Seeders: 5,
Leechers: 3,
Source: "RARBG",
},
},
}
printSearchResultEntry(&buf, r)
out := buf.String()
if !strings.Contains(out, "Minimal") {
t.Error("expected title in output")
}
if !strings.Contains(out, "-") {
t.Error("expected dash for nil fields")
}
}

75
internal/upgrade/cache.go Normal file
View file

@ -0,0 +1,75 @@
package upgrade
import (
"context"
"encoding/json"
"os"
"path/filepath"
"time"
"github.com/torrentclaw/unarr/internal/config"
)
const cacheTTL = 1 * time.Hour
// versionCache is the on-disk structure for cached version checks.
type versionCache struct {
Version string `json:"version"`
CheckedAt time.Time `json:"checkedAt"`
}
// cacheFilePath returns the path to the version cache file.
func cacheFilePath() string {
return filepath.Join(config.DataDir(), "latest-version.json")
}
// ReadCachedVersion returns the cached latest version if it's fresh (< cacheTTL).
// Returns empty string if cache is missing, stale, or corrupt.
func ReadCachedVersion() string {
data, err := os.ReadFile(cacheFilePath())
if err != nil {
return ""
}
var c versionCache
if json.Unmarshal(data, &c) != nil {
return ""
}
if time.Since(c.CheckedAt) > cacheTTL {
return ""
}
return c.Version
}
// writeCachedVersion writes the latest version to the cache file.
func writeCachedVersion(version string) {
c := versionCache{
Version: version,
CheckedAt: time.Now(),
}
data, err := json.Marshal(c)
if err != nil {
return
}
path := cacheFilePath()
os.MkdirAll(filepath.Dir(path), 0o755)
// Best-effort write — ignore errors
tmp := path + ".tmp"
if err := os.WriteFile(tmp, data, 0o644); err != nil {
return
}
os.Rename(tmp, path)
}
// CheckLatestCached returns the latest version, using cache when fresh.
// If cache is stale, fetches from GitHub and updates the cache.
func CheckLatestCached(ctx context.Context) (version string, fromCache bool, err error) {
if cached := ReadCachedVersion(); cached != "" {
return cached, true, nil
}
v, err := fetchLatestVersion(ctx)
if err != nil {
return "", false, err
}
writeCachedVersion(v)
return v, false, nil
}

View file

@ -152,9 +152,13 @@ func (u *Upgrader) fail(format string, args ...any) Result {
}
}
// CheckLatest fetches the latest version from GitHub API.
// CheckLatest fetches the latest version from GitHub API and updates the cache.
func CheckLatest(ctx context.Context) (string, error) {
return fetchLatestVersion(ctx)
v, err := fetchLatestVersion(ctx)
if err == nil {
writeCachedVersion(v)
}
return v, err
}
// installBinary copies the new binary to the target path, preserving original permissions.

View file

@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
@ -305,3 +306,768 @@ func TestFetchLatestVersionMockServer(t *testing.T) {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
}
// --- New tests below ---
// swapHTTPClient replaces the package-level httpClient and returns a restore function.
func swapHTTPClient(c *http.Client) func() {
orig := httpClient
httpClient = c
return func() { httpClient = orig }
}
// rewriteTransport redirects all requests to the given base URL,
// preserving path and query.
type rewriteTransport struct {
url string
}
func (rt *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
u := strings.TrimPrefix(rt.url, "http://")
req.URL.Host = u
return http.DefaultTransport.RoundTrip(req)
}
// createTarGz is a test helper that creates a tar.gz file with a single file entry.
func createTarGz(t *testing.T, archivePath, entryName string, content []byte) {
t.Helper()
f, err := os.Create(archivePath)
if err != nil {
t.Fatal(err)
}
gw := gzip.NewWriter(f)
tw := tar.NewWriter(gw)
tw.WriteHeader(&tar.Header{
Name: entryName,
Mode: 0o755,
Size: int64(len(content)),
})
tw.Write(content)
tw.Close()
gw.Close()
f.Close()
}
func TestArchiveNameTableDriven(t *testing.T) {
// We can only run archiveName for the current GOOS/GOARCH,
// so we test several version strings and verify the pattern.
tests := []struct {
version string
}{
{"0.1.0"},
{"1.0.0-rc1"},
{"2.5.10"},
{"0.0.0"},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
got := archiveName(tt.version)
prefix := fmt.Sprintf("unarr_%s_%s_%s.", tt.version, runtime.GOOS, runtime.GOARCH)
if !strings.HasPrefix(got, prefix) {
t.Errorf("archiveName(%q) = %q, want prefix %q", tt.version, got, prefix)
}
if runtime.GOOS == "windows" {
if !strings.HasSuffix(got, ".zip") {
t.Errorf("archiveName on windows should end with .zip, got %q", got)
}
} else {
if !strings.HasSuffix(got, ".tar.gz") {
t.Errorf("archiveName on non-windows should end with .tar.gz, got %q", got)
}
}
})
}
}
func TestReleaseURLEdgeCases(t *testing.T) {
tests := []struct {
name string
version string
filename string
wantURL string
}{
{
name: "pre-release version",
version: "2.0.0-beta.1",
filename: "unarr_2.0.0-beta.1_darwin_arm64.tar.gz",
wantURL: "https://github.com/torrentclaw/unarr/releases/download/v2.0.0-beta.1/unarr_2.0.0-beta.1_darwin_arm64.tar.gz",
},
{
name: "checksums file",
version: "3.0.0",
filename: "checksums.txt",
wantURL: "https://github.com/torrentclaw/unarr/releases/download/v3.0.0/checksums.txt",
},
{
name: "windows zip",
version: "1.2.3",
filename: "unarr_1.2.3_windows_amd64.zip",
wantURL: "https://github.com/torrentclaw/unarr/releases/download/v1.2.3/unarr_1.2.3_windows_amd64.zip",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := releaseURL(tt.version, tt.filename)
if got != tt.wantURL {
t.Errorf("releaseURL(%q, %q) = %q, want %q", tt.version, tt.filename, got, tt.wantURL)
}
})
}
}
func TestExtractBinaryDispatcher(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("extractBinary dispatcher test only on unix")
}
dir := t.TempDir()
// Create a valid tar.gz with the unarr binary
archivePath := filepath.Join(dir, "test.tar.gz")
binaryContent := []byte("#!/bin/sh\necho dispatcher test\n")
createTarGz(t, archivePath, "unarr", binaryContent)
destDir := filepath.Join(dir, "out")
os.MkdirAll(destDir, 0o755)
binPath, err := extractBinary(archivePath, destDir)
if err != nil {
t.Fatalf("extractBinary() error = %v", err)
}
if filepath.Base(binPath) != "unarr" {
t.Errorf("extractBinary() returned %q, want base name 'unarr'", binPath)
}
data, _ := os.ReadFile(binPath)
if string(data) != string(binaryContent) {
t.Error("extractBinary() content mismatch")
}
}
func TestExtractBinaryInvalidArchive(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("tar.gz test only on unix")
}
dir := t.TempDir()
archivePath := filepath.Join(dir, "garbage.tar.gz")
os.WriteFile(archivePath, []byte("this is not a tar.gz"), 0o644)
destDir := filepath.Join(dir, "out")
os.MkdirAll(destDir, 0o755)
_, err := extractBinary(archivePath, destDir)
if err == nil {
t.Error("extractBinary with garbage data should return error")
}
}
func TestExtractBinaryNonExistentArchive(t *testing.T) {
dir := t.TempDir()
_, err := extractBinary("/nonexistent-archive-file", filepath.Join(dir, "out"))
if err == nil {
t.Error("extractBinary with nonexistent file should return error")
}
}
func TestUpgraderFail(t *testing.T) {
u := &Upgrader{CurrentVersion: "1.0.0"}
// Capture log messages
var logged []string
u.OnProgress = func(msg string) { logged = append(logged, msg) }
result := u.fail("something went wrong: %d", 42)
if result.Success {
t.Error("fail() should return Success=false")
}
if result.OldVersion != "1.0.0" {
t.Errorf("fail() OldVersion = %q, want 1.0.0", result.OldVersion)
}
if result.Error == nil {
t.Fatal("fail() should set Error")
}
if !strings.Contains(result.Error.Error(), "something went wrong: 42") {
t.Errorf("fail() Error = %q, want to contain 'something went wrong: 42'", result.Error)
}
if len(logged) == 0 {
t.Error("fail() should call OnProgress")
}
if len(logged) > 0 && !strings.Contains(logged[0], "FAILED") {
t.Errorf("fail() logged %q, want to contain 'FAILED'", logged[0])
}
}
func TestUpgraderFailNilOnProgress(t *testing.T) {
u := &Upgrader{CurrentVersion: "2.0.0"}
// OnProgress is nil — should not panic
result := u.fail("error without listener")
if result.Success {
t.Error("fail() should return Success=false")
}
}
func TestFetchLatestVersionWithHTTPTest(t *testing.T) {
tests := []struct {
name string
body string
statusCode int
wantVer string
wantErr bool
}{
{
name: "valid response",
body: `{"tag_name":"v3.1.4"}`,
statusCode: 200,
wantVer: "3.1.4",
},
{
name: "valid response without v prefix",
body: `{"tag_name":"2.0.0"}`,
statusCode: 200,
wantVer: "2.0.0",
},
{
name: "empty tag_name",
body: `{"tag_name":""}`,
statusCode: 200,
wantErr: true,
},
{
name: "server error",
body: `Internal Server Error`,
statusCode: 500,
wantErr: true,
},
{
name: "invalid json",
body: `{invalid`,
statusCode: 200,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
fmt.Fprint(w, tt.body)
}))
defer srv.Close()
// Use a custom transport that rewrites requests to our test server
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
ver, err := CheckLatest(context.Background())
if tt.wantErr {
if err == nil {
t.Errorf("CheckLatest() error = nil, want error")
}
return
}
if err != nil {
t.Fatalf("CheckLatest() error = %v", err)
}
if ver != tt.wantVer {
t.Errorf("CheckLatest() = %q, want %q", ver, tt.wantVer)
}
})
}
}
func TestDownloadWithHTTPTest(t *testing.T) {
archiveBody := "fake-archive-bytes-for-download-test"
t.Run("successful download", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify user-agent header
if ua := r.Header.Get("User-Agent"); ua != "unarr-updater" {
t.Errorf("User-Agent = %q, want 'unarr-updater'", ua)
}
w.WriteHeader(200)
fmt.Fprint(w, archiveBody)
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
path, err := download(context.Background(), "1.0.0")
if err != nil {
t.Fatalf("download() error = %v", err)
}
defer os.Remove(path)
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("read downloaded file: %v", err)
}
if string(data) != archiveBody {
t.Errorf("downloaded content = %q, want %q", data, archiveBody)
}
})
t.Run("server returns 404", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
_, err := download(context.Background(), "99.99.99")
if err == nil {
t.Error("download() with 404 should return error")
}
if !strings.Contains(err.Error(), "HTTP 404") {
t.Errorf("download() error = %q, want to contain 'HTTP 404'", err)
}
})
t.Run("cancelled context", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
fmt.Fprint(w, "data")
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
_, err := download(ctx, "1.0.0")
if err == nil {
t.Error("download() with cancelled context should return error")
}
})
}
func TestVerifyChecksumWithHTTPTest(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("tar.gz test only on unix")
}
// Create a fake archive file
dir := t.TempDir()
archiveContent := []byte("archive-content-for-checksum-test")
archivePath := filepath.Join(dir, "test-archive.tar.gz")
os.WriteFile(archivePath, archiveContent, 0o644)
h := sha256.Sum256(archiveContent)
correctHash := hex.EncodeToString(h[:])
// The function builds the archive name using archiveName(), which uses runtime.GOOS/GOARCH.
expectedArchiveName := archiveName("1.0.0")
t.Run("matching checksum", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "0000000000000000000000000000000000000000000000000000000000000000 other_file.tar.gz\n")
fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName)
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
if err != nil {
t.Errorf("verifyChecksum() = %v, want nil", err)
}
})
t.Run("mismatched checksum", func(t *testing.T) {
wrongHash := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s %s\n", wrongHash, expectedArchiveName)
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
if err == nil {
t.Error("verifyChecksum() with wrong hash should return error")
}
if !strings.Contains(err.Error(), "SHA256 mismatch") {
t.Errorf("verifyChecksum() error = %q, want to contain 'SHA256 mismatch'", err)
}
})
t.Run("archive not in checksums", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890 some_other_file.tar.gz\n")
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
if err == nil {
t.Error("verifyChecksum() with missing entry should return error")
}
if !strings.Contains(err.Error(), "no checksum found") {
t.Errorf("verifyChecksum() error = %q, want to contain 'no checksum found'", err)
}
})
t.Run("checksums server error", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
if err == nil {
t.Error("verifyChecksum() with server error should return error")
}
if !strings.Contains(err.Error(), "HTTP 500") {
t.Errorf("verifyChecksum() error = %q, want to contain 'HTTP 500'", err)
}
})
t.Run("nonexistent archive file", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName)
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
err := verifyChecksum(context.Background(), "1.0.0", "/nonexistent-archive-path")
if err == nil {
t.Error("verifyChecksum() with nonexistent archive should return error")
}
})
}
func TestVerifyChecksumCaseInsensitive(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("tar.gz test only on unix")
}
dir := t.TempDir()
archiveContent := []byte("case-insensitive-hash-test")
archivePath := filepath.Join(dir, "archive.tar.gz")
os.WriteFile(archivePath, archiveContent, 0o644)
h := sha256.Sum256(archiveContent)
// Use uppercase hash in checksums.txt — verifyChecksum uses EqualFold
upperHash := strings.ToUpper(hex.EncodeToString(h[:]))
expectedArchiveName := archiveName("1.0.0")
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s %s\n", upperHash, expectedArchiveName)
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
if err != nil {
t.Errorf("verifyChecksum() with uppercase hash = %v, want nil", err)
}
}
func TestExtractTarGzNestedDirectory(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("tar.gz test only on unix")
}
dir := t.TempDir()
archivePath := filepath.Join(dir, "nested.tar.gz")
binaryContent := []byte("#!/bin/sh\necho nested\n")
f, err := os.Create(archivePath)
if err != nil {
t.Fatal(err)
}
gw := gzip.NewWriter(f)
tw := tar.NewWriter(gw)
// Write a directory entry first
tw.WriteHeader(&tar.Header{
Name: "unarr_1.0.0_linux_amd64/",
Typeflag: tar.TypeDir,
Mode: 0o755,
})
// Write a README in the subdirectory
readmeContent := []byte("This is a README")
tw.WriteHeader(&tar.Header{
Name: "unarr_1.0.0_linux_amd64/README.md",
Mode: 0o644,
Size: int64(len(readmeContent)),
})
tw.Write(readmeContent)
// Write the binary nested inside the directory
tw.WriteHeader(&tar.Header{
Name: "unarr_1.0.0_linux_amd64/unarr",
Mode: 0o755,
Size: int64(len(binaryContent)),
})
tw.Write(binaryContent)
// Write another unrelated file after the binary
licenseContent := []byte("MIT License")
tw.WriteHeader(&tar.Header{
Name: "unarr_1.0.0_linux_amd64/LICENSE",
Mode: 0o644,
Size: int64(len(licenseContent)),
})
tw.Write(licenseContent)
tw.Close()
gw.Close()
f.Close()
destDir := filepath.Join(dir, "out")
os.MkdirAll(destDir, 0o755)
binPath, err := extractTarGz(archivePath, destDir)
if err != nil {
t.Fatalf("extractTarGz() with nested dir = %v", err)
}
if filepath.Base(binPath) != "unarr" {
t.Errorf("extracted name = %q, want 'unarr'", filepath.Base(binPath))
}
data, _ := os.ReadFile(binPath)
if string(data) != string(binaryContent) {
t.Error("extracted content does not match")
}
// Verify that only the binary was extracted (README and LICENSE should NOT be in destDir)
entries, _ := os.ReadDir(destDir)
if len(entries) != 1 {
names := make([]string, len(entries))
for i, e := range entries {
names[i] = e.Name()
}
t.Errorf("destDir should contain only 'unarr', got %v", names)
}
}
func TestExtractTarGzMultipleFiles(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("tar.gz test only on unix")
}
dir := t.TempDir()
archivePath := filepath.Join(dir, "multi.tar.gz")
binaryContent := []byte("#!/bin/sh\necho multi\n")
f, err := os.Create(archivePath)
if err != nil {
t.Fatal(err)
}
gw := gzip.NewWriter(f)
tw := tar.NewWriter(gw)
// Several non-binary files before the actual binary
for _, name := range []string{"README.md", "LICENSE", "config.yaml", "completions.bash"} {
content := []byte("content of " + name)
tw.WriteHeader(&tar.Header{
Name: name,
Mode: 0o644,
Size: int64(len(content)),
})
tw.Write(content)
}
// The actual binary
tw.WriteHeader(&tar.Header{
Name: "unarr",
Mode: 0o755,
Size: int64(len(binaryContent)),
})
tw.Write(binaryContent)
tw.Close()
gw.Close()
f.Close()
destDir := filepath.Join(dir, "out")
os.MkdirAll(destDir, 0o755)
binPath, err := extractTarGz(archivePath, destDir)
if err != nil {
t.Fatalf("extractTarGz() = %v", err)
}
data, _ := os.ReadFile(binPath)
if string(data) != string(binaryContent) {
t.Error("binary content mismatch among multiple files")
}
}
func TestExtractTarGzSymlinkSkipped(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("tar.gz test only on unix")
}
dir := t.TempDir()
archivePath := filepath.Join(dir, "symlink.tar.gz")
f, err := os.Create(archivePath)
if err != nil {
t.Fatal(err)
}
gw := gzip.NewWriter(f)
tw := tar.NewWriter(gw)
// Write a symlink entry named "unarr" — should be skipped because Typeflag != TypeReg
tw.WriteHeader(&tar.Header{
Name: "unarr",
Typeflag: tar.TypeSymlink,
Linkname: "/etc/passwd",
Mode: 0o755,
})
tw.Close()
gw.Close()
f.Close()
destDir := filepath.Join(dir, "out")
os.MkdirAll(destDir, 0o755)
_, err = extractTarGz(archivePath, destDir)
if err == nil {
t.Error("extractTarGz() should return error when binary is a symlink (not TypeReg)")
}
if !strings.Contains(err.Error(), "not found in archive") {
t.Errorf("error = %q, want to contain 'not found in archive'", err)
}
}
func TestInstallBinaryNonExistentSource(t *testing.T) {
dir := t.TempDir()
dst := filepath.Join(dir, "output")
err := installBinary("/nonexistent-source-binary", dst)
if err == nil {
t.Error("installBinary with nonexistent source should return error")
}
if !strings.Contains(err.Error(), "read new binary") {
t.Errorf("error = %q, want to contain 'read new binary'", err)
}
// Verify destination was not created
if _, statErr := os.Stat(dst); statErr == nil {
t.Error("destination file should not exist after failed install")
}
}
func TestInstallBinaryUnwritableDestination(t *testing.T) {
dir := t.TempDir()
src := filepath.Join(dir, "source")
os.WriteFile(src, []byte("binary"), 0o755)
// Try to write to a path inside a non-existent directory
dst := filepath.Join(dir, "nonexistent-subdir", "binary")
err := installBinary(src, dst)
if err == nil {
t.Error("installBinary to non-writable destination should return error")
}
if !strings.Contains(err.Error(), "write binary") {
t.Errorf("error = %q, want to contain 'write binary'", err)
}
}
func TestUpgraderLog(t *testing.T) {
var messages []string
u := &Upgrader{
CurrentVersion: "1.0.0",
OnProgress: func(msg string) { messages = append(messages, msg) },
}
u.log("hello world")
if len(messages) != 1 || messages[0] != "hello world" {
t.Errorf("log() messages = %v, want [hello world]", messages)
}
}
func TestUpgraderLogNilOnProgress(t *testing.T) {
u := &Upgrader{CurrentVersion: "1.0.0"}
// Should not panic
u.log("test message with nil OnProgress")
}
func TestResultFields(t *testing.T) {
r := Result{
Success: true,
OldVersion: "1.0.0",
NewVersion: "2.0.0",
BackupPath: "/tmp/backup",
}
if !r.Success || r.OldVersion != "1.0.0" || r.NewVersion != "2.0.0" || r.BackupPath != "/tmp/backup" {
t.Errorf("Result fields not set correctly: %+v", r)
}
r2 := Result{Success: false, Error: fmt.Errorf("test error")}
if r2.Success || r2.Error == nil {
t.Errorf("Result error case not correct: %+v", r2)
}
}
func TestDownloadSetsUserAgent(t *testing.T) {
var gotUA string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUA = r.Header.Get("User-Agent")
w.WriteHeader(200)
fmt.Fprint(w, "data")
}))
defer srv.Close()
restore := swapHTTPClient(&http.Client{
Transport: &rewriteTransport{url: srv.URL},
})
defer restore()
path, err := download(context.Background(), "1.0.0")
if err != nil {
t.Fatalf("download() = %v", err)
}
defer os.Remove(path)
if gotUA != "unarr-updater" {
t.Errorf("User-Agent = %q, want 'unarr-updater'", gotUA)
}
}

View file

@ -0,0 +1,632 @@
package download
import (
"encoding/binary"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/torrentclaw/unarr/internal/usenet/nzb"
)
// --- Fingerprint ---
func TestFingerprint_EmptyNZB(t *testing.T) {
n := &nzb.NZB{}
fp := Fingerprint(n)
// Empty NZB should still produce a deterministic hash (of zero message IDs).
fp2 := Fingerprint(n)
if fp != fp2 {
t.Fatal("fingerprint of empty NZB should be deterministic")
}
}
func TestFingerprint_OrderIndependent(t *testing.T) {
// Fingerprint sorts IDs, so different file order should produce the same hash.
n1 := &nzb.NZB{
Files: []nzb.File{
{Segments: []nzb.Segment{{MessageID: "a@x"}, {MessageID: "b@x"}}},
{Segments: []nzb.Segment{{MessageID: "c@x"}}},
},
}
n2 := &nzb.NZB{
Files: []nzb.File{
{Segments: []nzb.Segment{{MessageID: "c@x"}}},
{Segments: []nzb.Segment{{MessageID: "b@x"}, {MessageID: "a@x"}}},
},
}
if Fingerprint(n1) != Fingerprint(n2) {
t.Fatal("fingerprint should be order-independent (sorted by message ID)")
}
}
func TestFingerprint_SingleSegment(t *testing.T) {
n := &nzb.NZB{
Files: []nzb.File{
{Segments: []nzb.Segment{{MessageID: "only@one"}}},
},
}
fp := Fingerprint(n)
if fp == [32]byte{} {
t.Fatal("fingerprint should not be zero for a non-empty NZB")
}
}
// --- ProgressTracker MarkDone idempotency ---
func TestProgressTracker_MarkDoneIdempotent(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 5)
tracker := NewProgressTracker("idem", n, dir)
tracker.MarkDone(0, 2)
if tracker.CompletedSegments(0) != 1 {
t.Fatalf("expected 1, got %d", tracker.CompletedSegments(0))
}
// Mark the same segment again — count should not increase.
tracker.MarkDone(0, 2)
if tracker.CompletedSegments(0) != 1 {
t.Fatalf("idempotent mark: expected 1, got %d", tracker.CompletedSegments(0))
}
}
// --- ProgressTracker TotalCompleted ---
func TestProgressTracker_TotalCompleted(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(3, 4) // 3 files, 4 segs each
tracker := NewProgressTracker("total", n, dir)
tracker.MarkDone(0, 0)
tracker.MarkDone(0, 1)
tracker.MarkDone(1, 3)
tracker.MarkDone(2, 0)
tracker.MarkDone(2, 1)
tracker.MarkDone(2, 2)
if got := tracker.TotalCompleted(); got != 6 {
t.Errorf("TotalCompleted: got %d, want 6", got)
}
}
func TestProgressTracker_TotalCompleted_Empty(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(2, 3)
tracker := NewProgressTracker("empty-total", n, dir)
if got := tracker.TotalCompleted(); got != 0 {
t.Errorf("TotalCompleted on fresh tracker: got %d, want 0", got)
}
}
// --- CompletedBytes edge cases ---
func TestProgressTracker_CompletedBytes_OutOfBounds(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("cb-oob", n, dir)
if got := tracker.CompletedBytes(-1, n.Files[0].Segments); got != 0 {
t.Errorf("CompletedBytes with file -1: got %d, want 0", got)
}
if got := tracker.CompletedBytes(5, n.Files[0].Segments); got != 0 {
t.Errorf("CompletedBytes with file 5: got %d, want 0", got)
}
}
func TestProgressTracker_CompletedBytes_AllDone(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("cb-all", n, dir)
for i := 0; i < 3; i++ {
tracker.MarkDone(0, i)
}
got := tracker.CompletedBytes(0, n.Files[0].Segments)
expected := int64(3 * 750 * 1024)
if got != expected {
t.Errorf("CompletedBytes all done: got %d, want %d", got, expected)
}
}
// --- CompletedSegments out of bounds ---
func TestProgressTracker_CompletedSegments_OutOfBounds(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("cs-oob", n, dir)
if got := tracker.CompletedSegments(-1); got != 0 {
t.Errorf("CompletedSegments(-1) = %d, want 0", got)
}
if got := tracker.CompletedSegments(99); got != 0 {
t.Errorf("CompletedSegments(99) = %d, want 0", got)
}
}
// --- Load with corrupted / truncated data ---
func TestProgressTracker_Load_TruncatedHeader(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("trunc", n, dir)
// Write too-short data
os.WriteFile(tracker.progressPath(), []byte("UNR"), 0o644)
loaded, err := tracker.Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if loaded {
t.Error("truncated header should not load")
}
}
func TestProgressTracker_Load_BadMagic(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("badmagic", n, dir)
// Write data with wrong magic bytes
data := make([]byte, headerSize+10)
copy(data[0:4], []byte("BAAD"))
os.WriteFile(tracker.progressPath(), data, 0o644)
loaded, err := tracker.Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if loaded {
t.Error("bad magic should not load")
}
}
func TestProgressTracker_Load_BadVersion(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("badver", n, dir)
data := make([]byte, headerSize+10)
copy(data[0:4], progressMagic[:])
data[4] = 99 // unsupported version
os.WriteFile(tracker.progressPath(), data, 0o644)
loaded, err := tracker.Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if loaded {
t.Error("bad version should not load")
}
}
func TestProgressTracker_Load_WrongFileCount(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(2, 3)
tracker := NewProgressTracker("wrongfc", n, dir)
data := make([]byte, headerSize+20)
copy(data[0:4], progressMagic[:])
data[4] = progressVersion
binary.LittleEndian.PutUint16(data[6:8], 99) // wrong file count
copy(data[8:40], tracker.fingerprint[:])
os.WriteFile(tracker.progressPath(), data, 0o644)
loaded, err := tracker.Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if loaded {
t.Error("wrong file count should not load")
}
}
func TestProgressTracker_Load_TruncatedBitset(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 16)
tracker := NewProgressTracker("truncbit", n, dir)
// Build a valid header but truncate the bitset data
data := make([]byte, headerSize+4) // header + segCount but no bitset
copy(data[0:4], progressMagic[:])
data[4] = progressVersion
binary.LittleEndian.PutUint16(data[6:8], 1) // 1 file
copy(data[8:40], tracker.fingerprint[:])
binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 16) // 16 segs
// No bitset data follows — truncated
os.WriteFile(tracker.progressPath(), data, 0o644)
loaded, err := tracker.Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if loaded {
t.Error("truncated bitset should not load")
}
}
func TestProgressTracker_Load_SegCountMismatch(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 5)
tracker := NewProgressTracker("segmis", n, dir)
// Build valid header with correct file count and fingerprint, but wrong segCount
bitsetLen := (999 + 7) / 8
data := make([]byte, headerSize+4+bitsetLen)
copy(data[0:4], progressMagic[:])
data[4] = progressVersion
binary.LittleEndian.PutUint16(data[6:8], 1)
copy(data[8:40], tracker.fingerprint[:])
binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 999) // wrong seg count
os.WriteFile(tracker.progressPath(), data, 0o644)
loaded, err := tracker.Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if loaded {
t.Error("segment count mismatch should not load")
}
}
func TestProgressTracker_Load_NonexistentFile(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("nofile", n, dir)
loaded, err := tracker.Load()
if err != nil {
t.Fatalf("unexpected error for missing file: %v", err)
}
if loaded {
t.Error("nonexistent file should return false")
}
}
// --- Flush and Load round-trip with multiple files ---
func TestProgressTracker_MultiFileRoundTrip(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(5, 10) // 5 files, 10 segments each
tracker := NewProgressTracker("multi-rt", n, dir)
// Mark various segments across files
tracker.MarkDone(0, 0)
tracker.MarkDone(0, 9)
tracker.MarkDone(1, 5)
tracker.MarkDone(2, 0)
tracker.MarkDone(2, 1)
tracker.MarkDone(2, 2)
tracker.MarkDone(2, 3)
tracker.MarkDone(2, 4)
tracker.MarkDone(2, 5)
tracker.MarkDone(2, 6)
tracker.MarkDone(2, 7)
tracker.MarkDone(2, 8)
tracker.MarkDone(2, 9) // file 2 fully done
tracker.MarkDone(4, 7)
if err := tracker.Flush(); err != nil {
t.Fatalf("flush: %v", err)
}
// Reload
tracker2 := NewProgressTracker("multi-rt", n, dir)
loaded, err := tracker2.Load()
if err != nil {
t.Fatalf("load: %v", err)
}
if !loaded {
t.Fatal("should load")
}
if tracker2.CompletedSegments(0) != 2 {
t.Errorf("file 0: got %d, want 2", tracker2.CompletedSegments(0))
}
if tracker2.CompletedSegments(1) != 1 {
t.Errorf("file 1: got %d, want 1", tracker2.CompletedSegments(1))
}
if !tracker2.IsFileDone(2) {
t.Error("file 2 should be done")
}
if tracker2.CompletedSegments(3) != 0 {
t.Errorf("file 3: got %d, want 0", tracker2.CompletedSegments(3))
}
if tracker2.CompletedSegments(4) != 1 {
t.Errorf("file 4: got %d, want 1", tracker2.CompletedSegments(4))
}
if tracker2.TotalCompleted() != 14 {
t.Errorf("TotalCompleted: got %d, want 14", tracker2.TotalCompleted())
}
}
// --- Concurrent mark + IsDone reads ---
func TestProgressTracker_ConcurrentMarkAndRead(t *testing.T) {
dir := t.TempDir()
segCount := 500
n := makeTestNZB(2, segCount)
tracker := NewProgressTracker("conc-rw", n, dir)
// Use separate WaitGroups for writers and readers
var writerWg sync.WaitGroup
stop := make(chan struct{})
// Writers
for file := 0; file < 2; file++ {
for seg := 0; seg < segCount; seg++ {
writerWg.Add(1)
go func(f, s int) {
defer writerWg.Done()
tracker.MarkDone(f, s)
}(file, seg)
}
}
// Readers — continuously read while writes happen
var readerWg sync.WaitGroup
for i := 0; i < 4; i++ {
readerWg.Add(1)
go func() {
defer readerWg.Done()
for {
select {
case <-stop:
return
default:
// These should never panic
tracker.IsDone(0, 0)
tracker.IsDone(1, segCount-1)
tracker.IsFileDone(0)
tracker.CompletedSegments(1)
tracker.TotalCompleted()
}
}
}()
}
// Wait for all writers to finish, then stop readers
writerWg.Wait()
close(stop)
readerWg.Wait()
// After all goroutines complete, everything should be done
if !tracker.IsFileDone(0) {
t.Errorf("file 0 should be done, got %d/%d", tracker.CompletedSegments(0), segCount)
}
if !tracker.IsFileDone(1) {
t.Errorf("file 1 should be done, got %d/%d", tracker.CompletedSegments(1), segCount)
}
}
// --- Concurrent flush safety ---
func TestProgressTracker_ConcurrentFlush(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 100)
tracker := NewProgressTracker("conc-flush", n, dir)
// Mark some segments
for i := 0; i < 50; i++ {
tracker.MarkDone(0, i)
}
// Multiple concurrent flushes should not panic or corrupt
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
tracker.Flush()
}()
}
wg.Wait()
// Verify state is loadable
tracker2 := NewProgressTracker("conc-flush", n, dir)
loaded, err := tracker2.Load()
if err != nil {
t.Fatalf("load: %v", err)
}
if !loaded {
t.Fatal("should load after concurrent flushes")
}
if tracker2.CompletedSegments(0) != 50 {
t.Errorf("after concurrent flush: got %d, want 50", tracker2.CompletedSegments(0))
}
}
// --- Remove with .tmp file ---
func TestProgressTracker_Remove_WithTmpFile(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("rm-tmp", n, dir)
// Create all three files that Remove should clean up
os.WriteFile(tracker.progressPath(), []byte("data"), 0o644)
os.WriteFile(tracker.nzbPath(), []byte("<nzb/>"), 0o644)
os.WriteFile(tracker.progressPath()+".tmp", []byte("tmp"), 0o644)
tracker.Remove()
for _, p := range []string{tracker.progressPath(), tracker.nzbPath(), tracker.progressPath() + ".tmp"} {
if _, err := os.Stat(p); !os.IsNotExist(err) {
t.Errorf("file should be removed: %s", p)
}
}
}
// --- CleanStaleFiles edge cases ---
func TestCleanStaleFiles_EmptyDir(t *testing.T) {
dir := t.TempDir()
if got := CleanStaleFiles(dir, time.Hour); got != 0 {
t.Errorf("empty dir: got %d removed, want 0", got)
}
}
func TestCleanStaleFiles_NonexistentDir(t *testing.T) {
if got := CleanStaleFiles("/nonexistent/path/that/does/not/exist", time.Hour); got != 0 {
t.Errorf("nonexistent dir: got %d removed, want 0", got)
}
}
func TestCleanStaleFiles_AllFresh(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "a.progress"), []byte("a"), 0o644)
os.WriteFile(filepath.Join(dir, "b.progress"), []byte("b"), 0o644)
if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 {
t.Errorf("all fresh: got %d removed, want 0", got)
}
}
func TestCleanStaleFiles_SkipsSubdirs(t *testing.T) {
dir := t.TempDir()
subDir := filepath.Join(dir, "subdir")
os.MkdirAll(subDir, 0o755)
// Backdate the subdir (it should not be removed)
os.Chtimes(subDir, fixedPast, fixedPast)
if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 {
t.Errorf("should skip subdirs: got %d removed, want 0", got)
}
if _, err := os.Stat(subDir); err != nil {
t.Error("subdir should still exist")
}
}
func TestCleanStaleFiles_MixedAges(t *testing.T) {
dir := t.TempDir()
stale1 := filepath.Join(dir, "old1.progress")
stale2 := filepath.Join(dir, "old2.nzb")
fresh := filepath.Join(dir, "new.progress")
os.WriteFile(stale1, []byte("x"), 0o644)
os.WriteFile(stale2, []byte("x"), 0o644)
os.WriteFile(fresh, []byte("x"), 0o644)
os.Chtimes(stale1, fixedPast, fixedPast)
os.Chtimes(stale2, fixedPast, fixedPast)
if got := CleanStaleFiles(dir, 7*24*time.Hour); got != 2 {
t.Errorf("mixed ages: got %d removed, want 2", got)
}
if _, err := os.Stat(fresh); err != nil {
t.Error("fresh file should still exist")
}
}
// --- progressPath / nzbPath ---
func TestProgressTracker_Paths(t *testing.T) {
dir := "/some/dir"
n := makeTestNZB(1, 1)
tracker := NewProgressTracker("my-task", n, dir)
if got := tracker.progressPath(); got != filepath.Join(dir, "my-task.progress") {
t.Errorf("progressPath: got %q", got)
}
if got := tracker.nzbPath(); got != filepath.Join(dir, "my-task.nzb") {
t.Errorf("nzbPath: got %q", got)
}
}
// --- formatBytes ---
func TestFormatBytes(t *testing.T) {
tests := []struct {
input int64
want string
}{
{0, "0 B"},
{500, "500 B"},
{1023, "1023 B"},
{1024, "1.0 KB"},
{1536, "1.5 KB"},
{1048576, "1.0 MB"},
{1073741824, "1.0 GB"},
{1099511627776, "1.0 TB"},
}
for _, tt := range tests {
got := formatBytes(tt.input)
if got != tt.want {
t.Errorf("formatBytes(%d) = %q, want %q", tt.input, got, tt.want)
}
}
}
// --- Single file with 1 segment (boundary) ---
func TestProgressTracker_SingleSegment(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 1)
tracker := NewProgressTracker("single-seg", n, dir)
if tracker.IsFileDone(0) {
t.Error("should not be done initially")
}
tracker.MarkDone(0, 0)
if !tracker.IsFileDone(0) {
t.Error("should be done after marking the only segment")
}
if err := tracker.Flush(); err != nil {
t.Fatalf("flush: %v", err)
}
tracker2 := NewProgressTracker("single-seg", n, dir)
loaded, _ := tracker2.Load()
if !loaded {
t.Fatal("should load")
}
if !tracker2.IsFileDone(0) {
t.Error("should be done after reload")
}
}
// --- Flush creates directory if missing ---
func TestProgressTracker_FlushCreatesDir(t *testing.T) {
base := t.TempDir()
dir := filepath.Join(base, "nested", "resume")
n := makeTestNZB(1, 2)
tracker := NewProgressTracker("mkdir-test", n, dir)
tracker.MarkDone(0, 0)
if err := tracker.Flush(); err != nil {
t.Fatalf("flush should create dir: %v", err)
}
if _, err := os.Stat(tracker.progressPath()); err != nil {
t.Fatalf("progress file should exist: %v", err)
}
}
// --- Double flush after no new marks ---
func TestProgressTracker_DoubleFlush(t *testing.T) {
dir := t.TempDir()
n := makeTestNZB(1, 3)
tracker := NewProgressTracker("dbl-flush", n, dir)
tracker.MarkDone(0, 0)
if err := tracker.Flush(); err != nil {
t.Fatalf("first flush: %v", err)
}
// Second flush without new marks should be a no-op (dirty=false)
if err := tracker.Flush(); err != nil {
t.Fatalf("second flush: %v", err)
}
}

View file

@ -0,0 +1,131 @@
package nntp
import (
"bufio"
"bytes"
"testing"
)
func TestNewClient(t *testing.T) {
c := NewClient(Config{Host: "news.example.com", Port: 563, SSL: true})
if c.cfg.MaxConnections != 10 {
t.Errorf("default MaxConnections = %d, want 10", c.cfg.MaxConnections)
}
if c.cfg.Host != "news.example.com" {
t.Errorf("Host = %q", c.cfg.Host)
}
}
func TestNewClientCustomConnections(t *testing.T) {
c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 20})
if c.cfg.MaxConnections != 20 {
t.Errorf("MaxConnections = %d, want 20", c.cfg.MaxConnections)
}
}
func TestNewClientZeroConnections(t *testing.T) {
c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 0})
if c.cfg.MaxConnections != 10 {
t.Errorf("MaxConnections should default to 10, got %d", c.cfg.MaxConnections)
}
}
func TestNewClientNegativeConnections(t *testing.T) {
c := NewClient(Config{MaxConnections: -5})
if c.cfg.MaxConnections != 10 {
t.Errorf("MaxConnections should default to 10 for negative, got %d", c.cfg.MaxConnections)
}
}
func TestActiveConnections(t *testing.T) {
c := NewClient(Config{Host: "localhost", Port: 119})
if c.ActiveConnections() != 0 {
t.Errorf("ActiveConnections = %d, want 0", c.ActiveConnections())
}
}
func TestStatus(t *testing.T) {
c := NewClient(Config{Host: "news.example.com", Port: 563})
s := c.Status()
if s != "0 connections (0 pooled) to news.example.com:563" {
t.Errorf("Status = %q", s)
}
}
func TestCloseIdempotent(t *testing.T) {
c := NewClient(Config{Host: "localhost", Port: 119})
// Close should be idempotent
if err := c.Close(); err != nil {
t.Errorf("first Close: %v", err)
}
if err := c.Close(); err != nil {
t.Errorf("second Close: %v", err)
}
}
func TestArticleNotFoundError(t *testing.T) {
err := &ArticleNotFoundError{MessageID: "abc123@news.example.com"}
msg := err.Error()
if msg != "nntp: article not found: abc123@news.example.com" {
t.Errorf("Error() = %q", msg)
}
}
func TestReadDotBody(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
"simple body",
"Hello World\r\n.\r\n",
"Hello World\n",
},
{
"multiline",
"Line 1\r\nLine 2\r\nLine 3\r\n.\r\n",
"Line 1\nLine 2\nLine 3\n",
},
{
"dot-stuffed line",
"..This starts with a dot\r\n.\r\n",
".This starts with a dot\n",
},
{
"empty body",
".\r\n",
"",
},
{
"binary-like data",
"=ybegin line=128 size=1024 name=test.bin\r\nsome encoded data\r\n=yend\r\n.\r\n",
"=ybegin line=128 size=1024 name=test.bin\nsome encoded data\n=yend\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bufio.NewReader(bytes.NewBufferString(tt.input))
got, err := readDotBody(r)
if err != nil {
t.Fatalf("readDotBody: %v", err)
}
if string(got) != tt.want {
t.Errorf("readDotBody = %q, want %q", string(got), tt.want)
}
})
}
}
func TestReadDotBodyEOF(t *testing.T) {
// No dot terminator — should read until EOF
r := bufio.NewReader(bytes.NewBufferString("partial data\r\n"))
got, err := readDotBody(r)
if err != nil {
t.Fatalf("readDotBody EOF: %v", err)
}
if string(got) != "partial data\n" {
t.Errorf("readDotBody EOF = %q", string(got))
}
}

View file

@ -267,3 +267,784 @@ func TestStripAngleBrackets(t *testing.T) {
t.Errorf("MessageID not stripped: got %q", nzb.Files[0].Segments[0].MessageID)
}
}
// --- Malformed / edge-case XML inputs ---
func TestParse_CompletelyEmpty(t *testing.T) {
_, err := Parse(strings.NewReader(""))
if err == nil {
t.Error("expected error for completely empty input")
}
}
func TestParse_OnlyWhitespace(t *testing.T) {
_, err := Parse(strings.NewReader(" \n\t "))
if err == nil {
t.Error("expected error for whitespace-only input")
}
}
func TestParse_ValidXMLButNotNZB(t *testing.T) {
_, err := Parse(strings.NewReader(`<?xml version="1.0"?><html><body>Hello</body></html>`))
if err == nil {
t.Error("expected error for non-NZB XML")
}
}
func TestParse_NZBWithNoSegments(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments></segments>
</file>
</nzb>`
_, err := Parse(strings.NewReader(xml))
if err == nil {
t.Error("expected error for file with no segments")
}
}
func TestParse_SegmentWithEmptyMessageID(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1"> </segment>
</segments>
</file>
</nzb>`
_, err := Parse(strings.NewReader(xml))
if err == nil {
t.Error("expected error: segment with empty/whitespace message ID should be skipped, leaving no valid files")
}
}
func TestParse_MixedValidAndEmptySegments(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">valid@id</segment>
<segment bytes="200" number="2"> </segment>
<segment bytes="300" number="3">also-valid@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if len(nzb.Files[0].Segments) != 2 {
t.Errorf("expected 2 valid segments, got %d", len(nzb.Files[0].Segments))
}
}
// --- Metadata / Head parsing ---
func TestParse_MetaPassword(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<head>
<meta type="password">s3cr3t</meta>
<meta type="title">My Movie</meta>
<meta type="category">Movies</meta>
</head>
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if nzb.Password != "s3cr3t" {
t.Errorf("Password: got %q, want %q", nzb.Password, "s3cr3t")
}
if nzb.Meta["title"] != "My Movie" {
t.Errorf("Meta title: got %q", nzb.Meta["title"])
}
if nzb.Meta["category"] != "Movies" {
t.Errorf("Meta category: got %q", nzb.Meta["category"])
}
}
func TestParse_MetaPasswordWithWhitespace(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<head>
<meta type="password"> padded </meta>
</head>
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if nzb.Password != "padded" {
t.Errorf("Password should be trimmed: got %q", nzb.Password)
}
}
func TestParse_NoHead(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if nzb.Password != "" {
t.Errorf("Password should be empty: got %q", nzb.Password)
}
if len(nzb.Meta) != 0 {
t.Errorf("Meta should be empty: got %v", nzb.Meta)
}
}
func TestParse_MetaWithEmptyType(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<head>
<meta type="">ignored</meta>
<meta type="name">kept</meta>
</head>
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if _, ok := nzb.Meta[""]; ok {
t.Error("empty-type meta should not be stored")
}
if nzb.Meta["name"] != "kept" {
t.Errorf("Meta name: got %q", nzb.Meta["name"])
}
}
// --- Multiple files ---
func TestParse_MultipleFilesVariousTypes(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="bot" date="1700000000" subject="&quot;movie.mkv&quot; yEnc (1/100)">
<groups><group>alt.binaries.movies</group></groups>
<segments>
<segment bytes="768000" number="1">mkv001@ex</segment>
<segment bytes="768000" number="2">mkv002@ex</segment>
</segments>
</file>
<file poster="bot" date="1700000000" subject="&quot;movie.nfo&quot; yEnc (1/1)">
<groups><group>alt.binaries.movies</group></groups>
<segments>
<segment bytes="4096" number="1">nfo001@ex</segment>
</segments>
</file>
<file poster="bot" date="1700000000" subject="&quot;movie.par2&quot; yEnc (1/1)">
<groups><group>alt.binaries.movies</group></groups>
<segments>
<segment bytes="32768" number="1">par001@ex</segment>
</segments>
</file>
<file poster="bot" date="1700000000" subject="&quot;movie.vol0+1.par2&quot; yEnc (1/1)">
<groups><group>alt.binaries.movies</group></groups>
<segments>
<segment bytes="65536" number="1">parv001@ex</segment>
</segments>
</file>
<file poster="bot" date="1700000000" subject="&quot;sample.mkv&quot; yEnc (1/1)">
<groups><group>alt.binaries.movies</group></groups>
<segments>
<segment bytes="10000" number="1">sample001@ex</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if len(nzb.Files) != 5 {
t.Fatalf("expected 5 files, got %d", len(nzb.Files))
}
// ContentFiles should exclude nfo, par2, par2 vol, and sample
content := nzb.ContentFiles()
if len(content) != 1 {
t.Errorf("ContentFiles: got %d, want 1", len(content))
}
if len(content) > 0 && content[0].Filename() != "movie.mkv" {
t.Errorf("ContentFiles[0]: got %q, want movie.mkv", content[0].Filename())
}
// Par2Files
par2 := nzb.Par2Files()
if len(par2) != 2 {
t.Errorf("Par2Files: got %d, want 2", len(par2))
}
if !nzb.HasPar2() {
t.Error("HasPar2 should be true")
}
if nzb.HasRars() {
t.Error("HasRars should be false for this NZB")
}
}
// --- Segment ordering / number parsing ---
func TestParse_SegmentNumberParsing(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="3">c@id</segment>
<segment bytes="200" number="1">a@id</segment>
<segment bytes="300" number="2">b@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
segs := nzb.Files[0].Segments
if len(segs) != 3 {
t.Fatalf("expected 3 segments, got %d", len(segs))
}
// Parse preserves order from XML; sorting is done by the downloader
// Verify numbers are parsed correctly
numbers := make(map[int]bool)
for _, s := range segs {
numbers[s.Number] = true
}
for _, want := range []int{1, 2, 3} {
if !numbers[want] {
t.Errorf("missing segment number %d", want)
}
}
}
func TestParse_SegmentBytesZero(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="0" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if nzb.Files[0].Segments[0].Bytes != 0 {
t.Errorf("expected 0 bytes, got %d", nzb.Files[0].Segments[0].Bytes)
}
}
func TestParse_SegmentBytesNonNumeric(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="abc" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
// Non-numeric bytes should parse as 0
if nzb.Files[0].Segments[0].Bytes != 0 {
t.Errorf("non-numeric bytes should be 0, got %d", nzb.Files[0].Segments[0].Bytes)
}
}
// --- File helper methods ---
func TestFileTotalBytes(t *testing.T) {
f := File{
Segments: []Segment{
{Bytes: 100}, {Bytes: 200}, {Bytes: 300},
},
}
if got := f.TotalBytes(); got != 600 {
t.Errorf("TotalBytes: got %d, want 600", got)
}
}
func TestFileTotalBytes_Empty(t *testing.T) {
f := File{}
if got := f.TotalBytes(); got != 0 {
t.Errorf("TotalBytes of empty file: got %d, want 0", got)
}
}
func TestFileExtension_Various(t *testing.T) {
tests := []struct {
subject string
want string
}{
{`"file.MKV" yEnc`, ".mkv"},
{`"file.RAR" yEnc`, ".rar"},
{`"file.Par2" yEnc`, ".par2"},
{`"noext" yEnc`, ""},
{`"file.tar.gz" yEnc`, ".gz"},
}
for _, tt := range tests {
f := File{Subject: tt.subject}
if got := f.Extension(); got != tt.want {
t.Errorf("Extension(%q) = %q, want %q", tt.subject, got, tt.want)
}
}
}
// --- LargestFile edge cases ---
func TestLargestFile_EmptyNZB(t *testing.T) {
nzb := &NZB{}
if nzb.LargestFile() != nil {
t.Error("LargestFile should return nil for empty NZB")
}
}
func TestLargestFile_SingleFile(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"only.bin"`, Segments: []Segment{{Bytes: 100}}},
},
}
largest := nzb.LargestFile()
if largest == nil {
t.Fatal("LargestFile should not be nil")
}
if largest.Filename() != "only.bin" {
t.Errorf("got %q", largest.Filename())
}
}
func TestLargestFile_MultipleSameSize(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"a.bin"`, Segments: []Segment{{Bytes: 100}}},
{Subject: `"b.bin"`, Segments: []Segment{{Bytes: 100}}},
},
}
largest := nzb.LargestFile()
if largest == nil {
t.Fatal("LargestFile should not be nil")
}
// Should return the first one (stable)
if largest.Filename() != "a.bin" {
t.Errorf("got %q, expected first file for equal sizes", largest.Filename())
}
}
// --- IsObfuscated ---
func TestIsObfuscated_Normal(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"Movie.2024.1080p.BluRay.x264-GROUP.mkv"`},
},
}
if nzb.IsObfuscated() {
t.Error("normal filename should not be obfuscated")
}
}
func TestIsObfuscated_HexName(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"a1b2c3d4e5f6a7b8c9d0e1f2.mkv"`},
},
}
if !nzb.IsObfuscated() {
t.Error("hex-like filename should be obfuscated")
}
}
func TestIsObfuscated_EmptyFiles(t *testing.T) {
nzb := &NZB{}
if nzb.IsObfuscated() {
t.Error("empty NZB should not be obfuscated")
}
}
func TestIsObfuscated_ShortHex(t *testing.T) {
// Short name (<=10 chars) should not trigger obfuscation
nzb := &NZB{
Files: []File{
{Subject: `"abcdef.mkv"`},
},
}
if nzb.IsObfuscated() {
t.Error("short hex-like name should not be obfuscated")
}
}
// --- isMetadataFile ---
func TestIsMetadataFile(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"file.par2", true},
{"file.nfo", true},
{"file.sfv", true},
{"file.nzb", true},
{"file.txt", true},
{"file.jpg", true},
{"file.png", true},
{"file.url", true},
{"file.mkv", false},
{"file.rar", false},
{"file.avi", false},
{"FILE.PAR2", true},
{"FILE.NFO", true},
}
for _, tt := range tests {
if got := isMetadataFile(tt.name); got != tt.want {
t.Errorf("isMetadataFile(%q) = %v, want %v", tt.name, got, tt.want)
}
}
}
// --- isSampleFile ---
func TestIsSampleFile(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"movie.sample.mkv", true},
{"Sample.mkv", true},
{"SAMPLE.avi", true},
{"movie-sample-video.mkv", true},
{"movie_sample.mkv", true},
{"sample.mkv", true},
{"resampled.mkv", false}, // "sample" is part of "resampled"
{"movie.mkv", false},
{"my.samples.zip", false}, // "sample" followed by 's' (alphanumeric)
}
for _, tt := range tests {
if got := isSampleFile(tt.name); got != tt.want {
t.Errorf("isSampleFile(%q) = %v, want %v", tt.name, got, tt.want)
}
}
}
// --- isHexLike ---
func TestIsHexLike(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"abcdef0123456789", true},
{"ABCDEF", true},
{"Movie2024", false},
{"aabbccdd", true},
{"xyz_not_hex", false},
}
for _, tt := range tests {
if got := isHexLike(tt.input); got != tt.want {
t.Errorf("isHexLike(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
// --- sanitizeFilename ---
func TestSanitizeFilename(t *testing.T) {
tests := []struct {
input string
want string
}{
{"simple name", "simple name"},
{"name (1/50)", "name"},
{"file yEnc (01/99)", "file"},
{`path/with\special:chars*?`, `path_with_special_chars__`},
{`"quoted" text`, `_quoted_ text`},
{" spaces ", "spaces"},
}
for _, tt := range tests {
if got := sanitizeFilename(tt.input); got != tt.want {
t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
// --- Filename fallback ---
func TestFilename_Fallback_NoQuotes(t *testing.T) {
f := File{Subject: "No quotes here yEnc (1/50)"}
got := f.Filename()
if got != "No quotes here" {
t.Errorf("Filename fallback: got %q, want %q", got, "No quotes here")
}
}
func TestFilename_EmptySubject(t *testing.T) {
f := File{Subject: ""}
got := f.Filename()
if got != "" {
t.Errorf("Filename empty subject: got %q, want empty", got)
}
}
// --- NZB aggregate methods on mixed content ---
func TestNZB_HasRars_NoRars(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"movie.mkv"`},
{Subject: `"movie.par2"`},
},
}
if nzb.HasRars() {
t.Error("HasRars should be false")
}
}
func TestNZB_HasPar2_NoPar2(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"movie.mkv"`},
{Subject: `"movie.rar"`},
},
}
if nzb.HasPar2() {
t.Error("HasPar2 should be false")
}
}
func TestNZB_TotalSegments_MultiFile(t *testing.T) {
nzb := &NZB{
Files: []File{
{Segments: []Segment{{}, {}, {}}},
{Segments: []Segment{{}, {}}},
},
}
if got := nzb.TotalSegments(); got != 5 {
t.Errorf("TotalSegments: got %d, want 5", got)
}
}
func TestNZB_TotalBytes_MultiFile(t *testing.T) {
nzb := &NZB{
Files: []File{
{Segments: []Segment{{Bytes: 100}, {Bytes: 200}}},
{Segments: []Segment{{Bytes: 300}}},
},
}
if got := nzb.TotalBytes(); got != 600 {
t.Errorf("TotalBytes: got %d, want 600", got)
}
}
// --- isRarFile extended ---
func TestIsRarFile_Extended(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"file.RAR", true}, // case insensitive
{"file.Rar", true},
{"file.s01", true},
{"file.s99", true},
{"file.002", true},
{"file.999", true},
{"file.r0", false}, // too short extension
{"file.rXX", false}, // non-numeric
{"file", false}, // no extension
{"file.mp4", false},
}
for _, tt := range tests {
if got := isRarFile(tt.name); got != tt.want {
t.Errorf("isRarFile(%q) = %v, want %v", tt.name, got, tt.want)
}
}
}
// --- Parse with date edge cases ---
func TestParse_DateNonNumeric(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="not-a-number" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if nzb.Files[0].Date != 0 {
t.Errorf("non-numeric date should be 0, got %d", nzb.Files[0].Date)
}
}
func TestParse_DateEmpty(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="" subject="&quot;test.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if nzb.Files[0].Date != 0 {
t.Errorf("empty date should be 0, got %d", nzb.Files[0].Date)
}
}
// --- Parse: file with all segments having empty IDs should be excluded ---
func TestParse_AllEmptySegments(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;bad.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1"> </segment>
<segment bytes="200" number="2"></segment>
</segments>
</file>
<file poster="test" date="0" subject="&quot;good.bin&quot;">
<groups><group>alt.test</group></groups>
<segments>
<segment bytes="100" number="1">valid@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if len(nzb.Files) != 1 {
t.Fatalf("expected 1 valid file, got %d", len(nzb.Files))
}
if nzb.Files[0].Filename() != "good.bin" {
t.Errorf("expected good.bin, got %q", nzb.Files[0].Filename())
}
}
// --- Groups ---
func TestParse_NoGroups(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups></groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if len(nzb.Files[0].Groups) != 0 {
t.Errorf("expected 0 groups, got %d", len(nzb.Files[0].Groups))
}
}
func TestParse_MultipleGroups(t *testing.T) {
xml := `<?xml version="1.0"?>
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
<file poster="test" date="0" subject="&quot;test.bin&quot;">
<groups>
<group>alt.binaries.movies</group>
<group>alt.binaries.multimedia</group>
<group>alt.binaries.hdtv</group>
</groups>
<segments>
<segment bytes="100" number="1">seg@id</segment>
</segments>
</file>
</nzb>`
nzb, err := Parse(strings.NewReader(xml))
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if len(nzb.Files[0].Groups) != 3 {
t.Errorf("expected 3 groups, got %d", len(nzb.Files[0].Groups))
}
}
// --- ContentFiles with sample variations ---
func TestContentFiles_ExcludesSamples(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"movie.mkv"`, Segments: []Segment{{Bytes: 1000, MessageID: "a"}}},
{Subject: `"movie.sample.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "b"}}},
{Subject: `"Sample/preview.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "c"}}},
},
}
content := nzb.ContentFiles()
if len(content) != 1 {
t.Errorf("ContentFiles should exclude samples: got %d, want 1", len(content))
}
}
// --- RarFiles with split naming ---
func TestRarFiles_SplitRars(t *testing.T) {
nzb := &NZB{
Files: []File{
{Subject: `"movie.rar"`, Segments: []Segment{{MessageID: "a"}}},
{Subject: `"movie.r00"`, Segments: []Segment{{MessageID: "b"}}},
{Subject: `"movie.r01"`, Segments: []Segment{{MessageID: "c"}}},
{Subject: `"movie.001"`, Segments: []Segment{{MessageID: "d"}}},
{Subject: `"movie.002"`, Segments: []Segment{{MessageID: "e"}}},
{Subject: `"movie.par2"`, Segments: []Segment{{MessageID: "f"}}},
{Subject: `"movie.mkv"`, Segments: []Segment{{MessageID: "g"}}},
},
}
rars := nzb.RarFiles()
if len(rars) != 5 {
t.Errorf("RarFiles: got %d, want 5", len(rars))
}
}

View file

@ -0,0 +1,170 @@
package postprocess
import (
"os"
"path/filepath"
"testing"
)
func TestIsArchiveFile(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"movie.rar", true},
{"movie.RAR", true},
{"movie.part01.rar", true},
{"movie.r00", true},
{"movie.r99", true},
{"movie.s00", true},
{"movie.001", true},
{"movie.099", true},
{"movie.mkv", false},
{"movie.mp4", false},
{"movie.par2", false},
{"movie.nfo", false},
{"movie.txt", false},
{"movie.r", false},
{"movie.abc", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isArchiveFile(tt.name)
if got != tt.want {
t.Errorf("isArchiveFile(%q) = %v, want %v", tt.name, got, tt.want)
}
})
}
}
func TestIsCleanupTarget(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"content.par2", true},
{"content.PAR2", true},
{"info.nfo", true},
{"checksum.sfv", true},
{"content.nzb", true},
{"content.srr", true},
{"content.srs", true},
{"cover.jpg", true},
{"cover.png", true},
{"readme.txt", true},
{"link.url", true},
{"movie.rar", true},
{"movie.r00", true},
{"movie.s01", true},
{"movie.001", true},
{"movie.mkv", false},
{"movie.mp4", false},
{"movie.avi", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isCleanupTarget(tt.name)
if got != tt.want {
t.Errorf("isCleanupTarget(%q) = %v, want %v", tt.name, got, tt.want)
}
})
}
}
func TestIsNumeric(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"", false},
{"0", true},
{"123", true},
{"00", true},
{"12a", false},
{"abc", false},
{" 1", false},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := isNumeric(tt.input)
if got != tt.want {
t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestListExtractedFiles(t *testing.T) {
dir := t.TempDir()
// Create some files
os.WriteFile(filepath.Join(dir, "movie.mkv"), []byte("video"), 0o644)
os.WriteFile(filepath.Join(dir, "subs.srt"), []byte("subs"), 0o644)
os.WriteFile(filepath.Join(dir, "movie.rar"), []byte("archive"), 0o644)
os.WriteFile(filepath.Join(dir, "movie.r00"), []byte("archive part"), 0o644)
archivePath := filepath.Join(dir, "movie.rar")
files, err := listExtractedFiles(dir, archivePath)
if err != nil {
t.Fatalf("listExtractedFiles: %v", err)
}
// Should exclude .rar and .r00 (archive files in same dir)
// Should include movie.mkv and subs.srt
if len(files) != 2 {
t.Errorf("expected 2 files, got %d: %v", len(files), files)
}
for _, f := range files {
base := filepath.Base(f)
if base != "movie.mkv" && base != "subs.srt" {
t.Errorf("unexpected file: %s", base)
}
}
}
func TestCleanup(t *testing.T) {
dir := t.TempDir()
// Files that should be removed
cleanupFiles := []string{"content.par2", "info.nfo", "checksum.sfv", "movie.rar", "movie.r00"}
for _, name := range cleanupFiles {
os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644)
}
// Files that should be kept
keepFiles := []string{"movie.mkv", "subs.srt"}
for _, name := range keepFiles {
os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644)
}
err := Cleanup(dir)
if err != nil {
t.Fatalf("Cleanup: %v", err)
}
// Verify cleanup files are gone
for _, name := range cleanupFiles {
if _, err := os.Stat(filepath.Join(dir, name)); !os.IsNotExist(err) {
t.Errorf("expected %s to be removed", name)
}
}
// Verify kept files still exist
for _, name := range keepFiles {
if _, err := os.Stat(filepath.Join(dir, name)); err != nil {
t.Errorf("expected %s to exist, got: %v", name, err)
}
}
}
func TestPasswordError(t *testing.T) {
err := &PasswordError{Archive: "/tmp/movie.rar"}
msg := err.Error()
if msg != "archive is password protected: /tmp/movie.rar" {
t.Errorf("PasswordError.Error() = %q", msg)
}
}

View file

@ -0,0 +1,156 @@
package postprocess
import (
"os"
"path/filepath"
"testing"
)
func TestFindPar2File(t *testing.T) {
dir := t.TempDir()
// Create par2 files of different sizes
mainPar2 := filepath.Join(dir, "content.par2")
vol1 := filepath.Join(dir, "content.vol000+01.par2")
vol2 := filepath.Join(dir, "content.vol001+02.par2")
os.WriteFile(mainPar2, make([]byte, 100), 0o644) // smallest
os.WriteFile(vol1, make([]byte, 10000), 0o644)
os.WriteFile(vol2, make([]byte, 50000), 0o644)
files := map[string]string{
"content.par2": mainPar2,
"content.vol000+01.par2": vol1,
"content.vol001+02.par2": vol2,
}
result := findPar2File(files)
if result != mainPar2 {
t.Errorf("findPar2File() = %q, want %q (smallest par2)", result, mainPar2)
}
}
func TestFindPar2FileNone(t *testing.T) {
files := map[string]string{
"video.mkv": "/tmp/video.mkv",
"subs.srt": "/tmp/subs.srt",
}
result := findPar2File(files)
if result != "" {
t.Errorf("findPar2File() = %q, want empty", result)
}
}
func TestFindPar2FileEmpty(t *testing.T) {
result := findPar2File(map[string]string{})
if result != "" {
t.Errorf("findPar2File() = %q, want empty", result)
}
}
func TestFindFirstRarPart01(t *testing.T) {
files := map[string]string{
"movie.part01.rar": "/tmp/movie.part01.rar",
"movie.part02.rar": "/tmp/movie.part02.rar",
"movie.part03.rar": "/tmp/movie.part03.rar",
}
result := findFirstRar(files)
if result != "/tmp/movie.part01.rar" {
t.Errorf("findFirstRar() = %q, want part01.rar", result)
}
}
func TestFindFirstRarSingle(t *testing.T) {
files := map[string]string{
"movie.rar": "/tmp/movie.rar",
"movie.r00": "/tmp/movie.r00",
"movie.r01": "/tmp/movie.r01",
}
result := findFirstRar(files)
if result != "/tmp/movie.rar" {
t.Errorf("findFirstRar() = %q, want movie.rar (shortest)", result)
}
}
func TestFindFirstRarSplitFormat(t *testing.T) {
files := map[string]string{
"movie.001": "/tmp/movie.001",
"movie.002": "/tmp/movie.002",
}
result := findFirstRar(files)
if result != "/tmp/movie.001" {
t.Errorf("findFirstRar() = %q, want movie.001", result)
}
}
func TestFindFirstRarNone(t *testing.T) {
files := map[string]string{
"video.mkv": "/tmp/video.mkv",
"subs.srt": "/tmp/subs.srt",
}
result := findFirstRar(files)
if result != "" {
t.Errorf("findFirstRar() = %q, want empty", result)
}
}
func TestFindMainFile(t *testing.T) {
dir := t.TempDir()
// Create video files of different sizes
small := filepath.Join(dir, "small.mkv")
large := filepath.Join(dir, "large.mkv")
nonVideo := filepath.Join(dir, "readme.txt")
os.WriteFile(small, make([]byte, 1000), 0o644)
os.WriteFile(large, make([]byte, 5000), 0o644)
os.WriteFile(nonVideo, make([]byte, 9000), 0o644)
result := findMainFile(dir, []string{small, large, nonVideo})
if result != large {
t.Errorf("findMainFile() = %q, want %q (largest video)", result, large)
}
}
func TestFindMainFileFallbackToDir(t *testing.T) {
dir := t.TempDir()
video := filepath.Join(dir, "movie.mp4")
os.WriteFile(video, make([]byte, 5000), 0o644)
// Pass empty file list — should fallback to scanning dir
result := findMainFile(dir, nil)
if result != video {
t.Errorf("findMainFile() = %q, want %q (dir scan fallback)", result, video)
}
}
func TestFindMainFileEmpty(t *testing.T) {
dir := t.TempDir()
result := findMainFile(dir, nil)
if result != "" {
t.Errorf("findMainFile() = %q, want empty", result)
}
}
func TestFindMainFileMultipleFormats(t *testing.T) {
dir := t.TempDir()
mkv := filepath.Join(dir, "movie.mkv")
mp4 := filepath.Join(dir, "movie.mp4")
avi := filepath.Join(dir, "movie.avi")
os.WriteFile(mkv, make([]byte, 3000), 0o644)
os.WriteFile(mp4, make([]byte, 5000), 0o644) // largest
os.WriteFile(avi, make([]byte, 2000), 0o644)
result := findMainFile(dir, []string{mkv, mp4, avi})
if result != mp4 {
t.Errorf("findMainFile() = %q, want %q", result, mp4)
}
}