feat(cli): upgrade command, rich status, and version cache
- Replace `upgrade` stub with real command (alias for `self-update`) - Also register `update` as alias: `unarr update` works too - Rewrite `status` to show full config, disk usage, daemon state, and update availability with colored sections - Add version check cache (1h TTL) so `status` is instant on repeat runs - Guard against division by zero on empty filesystems - Guard against negative durations from clock skew - Guard against stale PID via heartbeat recency check (2 min) - Add comprehensive test coverage across agent, engine, upgrade, usenet, arr, library, mediaserver, and UI packages - Improve Makefile coverage target to exclude cmd/ glue code - Fix stream handler resource cleanup and ffprobe error handling
This commit is contained in:
parent
01d62ffa13
commit
3e0f3a5a64
33 changed files with 7084 additions and 65 deletions
|
|
@ -324,3 +324,265 @@ func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
|
|||
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeregister(t *testing.T) {
|
||||
var received struct {
|
||||
AgentID string `json:"agentId"`
|
||||
}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/deregister" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method = %s, want POST", r.Method)
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&received)
|
||||
json.NewEncoder(w).Encode(StatusResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
err := c.Deregister(context.Background(), "agent-42")
|
||||
if err != nil {
|
||||
t.Fatalf("Deregister failed: %v", err)
|
||||
}
|
||||
if received.AgentID != "agent-42" {
|
||||
t.Errorf("agentId = %q, want agent-42", received.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchReportStatus(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/status" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
var req BatchStatusRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
if len(req.Updates) != 2 {
|
||||
t.Errorf("expected 2 updates, got %d", len(req.Updates))
|
||||
}
|
||||
json.NewEncoder(w).Encode(BatchStatusResponse{
|
||||
Results: []StatusResponse{
|
||||
{Success: true},
|
||||
{Success: true, Cancelled: true},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.BatchReportStatus(context.Background(), []StatusUpdate{
|
||||
{TaskID: "t1", Status: "downloading"},
|
||||
{TaskID: "t2", Status: "completed"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchReportStatus failed: %v", err)
|
||||
}
|
||||
if len(resp.Results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(resp.Results))
|
||||
}
|
||||
if !resp.Results[1].Cancelled {
|
||||
t.Error("expected result[1].Cancelled=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchNzbs(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/nzb-search" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(NzbSearchResponse{
|
||||
Results: []NzbSearchResult{
|
||||
{NzbID: "nzb-1", Title: "Movie.2023.1080p"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.SearchNzbs(context.Background(), NzbSearchParams{Query: "Movie"})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchNzbs failed: %v", err)
|
||||
}
|
||||
if len(resp.Results) != 1 {
|
||||
t.Fatalf("expected 1 result, got %d", len(resp.Results))
|
||||
}
|
||||
if resp.Results[0].NzbID != "nzb-1" {
|
||||
t.Errorf("nzb ID = %q, want nzb-1", resp.Results[0].NzbID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadNzb(t *testing.T) {
|
||||
nzbContent := []byte(`<?xml version="1.0"?><nzb><file>test</file></nzb>`)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/nzb-download" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("nzbId") != "nzb-42" {
|
||||
t.Errorf("nzbId = %q, want nzb-42", r.URL.Query().Get("nzbId"))
|
||||
}
|
||||
w.Write(nzbContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
data, err := c.DownloadNzb(context.Background(), "nzb-42")
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadNzb failed: %v", err)
|
||||
}
|
||||
if string(data) != string(nzbContent) {
|
||||
t.Errorf("nzb content mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadNzbError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte("NZB not found"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
_, err := c.DownloadNzb(context.Background(), "bad-id")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404 response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsenetCredentials(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/usenet-credentials" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(UsenetCredentials{
|
||||
Host: "news.example.com",
|
||||
Port: 563,
|
||||
SSL: true,
|
||||
Username: "user1",
|
||||
Password: "pass1",
|
||||
MaxConnections: 10,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
creds, err := c.GetUsenetCredentials(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetUsenetCredentials failed: %v", err)
|
||||
}
|
||||
if creds.Host != "news.example.com" {
|
||||
t.Errorf("host = %q, want news.example.com", creds.Host)
|
||||
}
|
||||
if creds.Username != "user1" {
|
||||
t.Errorf("username = %q, want user1", creds.Username)
|
||||
}
|
||||
if creds.MaxConnections != 10 {
|
||||
t.Errorf("maxConnections = %d, want 10", creds.MaxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsenetUsage(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/usenet-usage" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(UsenetUsageResponse{
|
||||
UsedBytes: 5368709120,
|
||||
QuotaBytes: 10737418240,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
usage, err := c.GetUsenetUsage(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetUsenetUsage failed: %v", err)
|
||||
}
|
||||
if usage.UsedBytes != 5368709120 {
|
||||
t.Errorf("usedBytes = %d", usage.UsedBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureDebrid(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/debrid-config" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(ConfigureDebridResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.ConfigureDebrid(context.Background(), ConfigureDebridRequest{
|
||||
Provider: "real-debrid",
|
||||
Token: "rd-token-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConfigureDebrid failed: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Error("expected success=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchDownload(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/batch-download" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(BatchDownloadResponse{
|
||||
Queued: 3,
|
||||
NotFound: 1,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.BatchDownload(context.Background(), BatchDownloadRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchDownload failed: %v", err)
|
||||
}
|
||||
if resp.Queued != 3 {
|
||||
t.Errorf("queued = %d, want 3", resp.Queued)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncLibrary(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/library-sync" {
|
||||
t.Errorf("path = %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(LibrarySyncResponse{
|
||||
Matched: 10,
|
||||
Synced: 15,
|
||||
Removed: 2,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.SyncLibrary(context.Background(), LibrarySyncRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("SyncLibrary failed: %v", err)
|
||||
}
|
||||
if resp.Matched != 10 {
|
||||
t.Errorf("matched = %d, want 10", resp.Matched)
|
||||
}
|
||||
if resp.Synced != 15 {
|
||||
t.Errorf("synced = %d, want 15", resp.Synced)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTMLErrorResponse(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
w.Write([]byte("<html><body>502 Bad Gateway</body></html>"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
_, err := c.Register(context.Background(), RegisterRequest{AgentID: "x"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTML error page")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
396
internal/arr/client_test.go
Normal file
396
internal/arr/client_test.go
Normal file
|
|
@ -0,0 +1,396 @@
|
|||
package arr
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T, handlers map[string]any) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check API key header
|
||||
if r.Header.Get("X-Api-Key") != "test-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
handler, ok := handlers[r.URL.Path]
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(handler)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
c := NewClient("http://localhost:8989/", "mykey")
|
||||
if c.baseURL != "http://localhost:8989" {
|
||||
t.Errorf("baseURL = %q, want trailing slash trimmed", c.baseURL)
|
||||
}
|
||||
if c.apiKey != "mykey" {
|
||||
t.Errorf("apiKey = %q, want mykey", c.apiKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemStatus(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v3/system/status": SystemStatus{AppName: "Radarr", Version: "4.0.0"},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
status, err := c.SystemStatus()
|
||||
if err != nil {
|
||||
t.Fatalf("SystemStatus: %v", err)
|
||||
}
|
||||
if status.AppName != "Radarr" {
|
||||
t.Errorf("AppName = %q, want Radarr", status.AppName)
|
||||
}
|
||||
if status.Version != "4.0.0" {
|
||||
t.Errorf("Version = %q, want 4.0.0", status.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemStatusFallbackV1(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Api-Key") != "test-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
switch r.URL.Path {
|
||||
case "/api/v3/system/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
case "/api/v1/system/status":
|
||||
json.NewEncoder(w).Encode(SystemStatus{AppName: "Prowlarr", Version: "1.0.0"})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
status, err := c.SystemStatus()
|
||||
if err != nil {
|
||||
t.Fatalf("SystemStatus v1 fallback: %v", err)
|
||||
}
|
||||
if status.AppName != "Prowlarr" {
|
||||
t.Errorf("AppName = %q, want Prowlarr", status.AppName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMovies(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v3/movie": []Movie{
|
||||
{ID: 1, Title: "Inception", Year: 2010, TmdbID: 27205, Monitored: true},
|
||||
{ID: 2, Title: "Tenet", Year: 2020, TmdbID: 577922, HasFile: true},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
movies, err := c.Movies()
|
||||
if err != nil {
|
||||
t.Fatalf("Movies: %v", err)
|
||||
}
|
||||
if len(movies) != 2 {
|
||||
t.Fatalf("expected 2 movies, got %d", len(movies))
|
||||
}
|
||||
if movies[0].Title != "Inception" {
|
||||
t.Errorf("movies[0].Title = %q, want Inception", movies[0].Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeries(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v3/series": []Series{
|
||||
{ID: 1, Title: "Breaking Bad", Year: 2008, TvdbID: 81189},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
series, err := c.Series()
|
||||
if err != nil {
|
||||
t.Fatalf("Series: %v", err)
|
||||
}
|
||||
if len(series) != 1 {
|
||||
t.Fatalf("expected 1 series, got %d", len(series))
|
||||
}
|
||||
if series[0].Title != "Breaking Bad" {
|
||||
t.Errorf("series[0].Title = %q, want Breaking Bad", series[0].Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQualityProfiles(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v3/qualityprofile": []QualityProfile{
|
||||
{ID: 1, Name: "HD-1080p"},
|
||||
{ID: 2, Name: "Ultra-HD"},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
profiles, err := c.QualityProfiles()
|
||||
if err != nil {
|
||||
t.Fatalf("QualityProfiles: %v", err)
|
||||
}
|
||||
if len(profiles) != 2 {
|
||||
t.Fatalf("expected 2 profiles, got %d", len(profiles))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRootFolders(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v3/rootfolder": []RootFolder{
|
||||
{ID: 1, Path: "/movies", FreeSpace: 500000000000},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
folders, err := c.RootFolders()
|
||||
if err != nil {
|
||||
t.Fatalf("RootFolders: %v", err)
|
||||
}
|
||||
if len(folders) != 1 {
|
||||
t.Fatalf("expected 1 folder, got %d", len(folders))
|
||||
}
|
||||
if folders[0].Path != "/movies" {
|
||||
t.Errorf("path = %q, want /movies", folders[0].Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadClients(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v3/downloadclient": []DownloadClient{
|
||||
{ID: 1, Name: "Transmission", Enable: true, Protocol: "torrent"},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
clients, err := c.DownloadClients()
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadClients: %v", err)
|
||||
}
|
||||
if len(clients) != 1 || clients[0].Name != "Transmission" {
|
||||
t.Errorf("unexpected clients: %+v", clients)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadClientDetails(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Api-Key") != "test-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/api/v3/downloadclient/5" {
|
||||
json.NewEncoder(w).Encode(struct {
|
||||
Fields []Field `json:"fields"`
|
||||
}{
|
||||
Fields: []Field{
|
||||
{Name: "host", Value: "localhost"},
|
||||
{Name: "port", Value: 9091},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
fields, err := c.DownloadClientDetails(5)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadClientDetails: %v", err)
|
||||
}
|
||||
if len(fields) != 2 {
|
||||
t.Fatalf("expected 2 fields, got %d", len(fields))
|
||||
}
|
||||
if fields[0].Name != "host" {
|
||||
t.Errorf("fields[0].Name = %q, want host", fields[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTags(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v3/tag": []Tag{
|
||||
{ID: 1, Label: "unarr"},
|
||||
{ID: 2, Label: "imported"},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
tags, err := c.Tags()
|
||||
if err != nil {
|
||||
t.Fatalf("Tags: %v", err)
|
||||
}
|
||||
if len(tags) != 2 {
|
||||
t.Fatalf("expected 2 tags, got %d", len(tags))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistory(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Api-Key") != "test-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/api/v3/history" {
|
||||
json.NewEncoder(w).Encode(HistoryResponse{
|
||||
Records: []HistoryRecord{
|
||||
{ID: 1, EventType: "grabbed", SourceTitle: "Inception.2010.1080p"},
|
||||
},
|
||||
TotalRecords: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
records, err := c.History(10)
|
||||
if err != nil {
|
||||
t.Fatalf("History: %v", err)
|
||||
}
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("expected 1 record, got %d", len(records))
|
||||
}
|
||||
if records[0].SourceTitle != "Inception.2010.1080p" {
|
||||
t.Errorf("sourceTitle = %q", records[0].SourceTitle)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistoryDefaultPageSize(t *testing.T) {
|
||||
var requestedPath string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Api-Key") != "test-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
requestedPath = r.URL.String()
|
||||
json.NewEncoder(w).Encode(HistoryResponse{})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
c.History(0) // should default to 250
|
||||
|
||||
if requestedPath == "" {
|
||||
t.Fatal("no request made")
|
||||
}
|
||||
if !contains(requestedPath, "pageSize=250") {
|
||||
t.Errorf("expected pageSize=250, got path: %s", requestedPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlocklist(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Api-Key") != "test-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/api/v3/blocklist" {
|
||||
json.NewEncoder(w).Encode(BlocklistResponse{
|
||||
Records: []BlocklistItem{
|
||||
{ID: 1, SourceTitle: "Bad.Release", Data: BlocklistData{InfoHash: "abc123"}},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
items, err := c.Blocklist(50)
|
||||
if err != nil {
|
||||
t.Fatalf("Blocklist: %v", err)
|
||||
}
|
||||
if len(items) != 1 || items[0].Data.InfoHash != "abc123" {
|
||||
t.Errorf("unexpected blocklist: %+v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexers(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v1/indexer": []Indexer{
|
||||
{ID: 1, Name: "NZBGeek", Enable: true},
|
||||
{ID: 2, Name: "Torznab", Enable: false},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
indexers, err := c.Indexers()
|
||||
if err != nil {
|
||||
t.Fatalf("Indexers: %v", err)
|
||||
}
|
||||
if len(indexers) != 2 {
|
||||
t.Fatalf("expected 2 indexers, got %d", len(indexers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplications(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/v1/applications": []Application{
|
||||
{ID: 1, Name: "Radarr"},
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
apps, err := c.Applications()
|
||||
if err != nil {
|
||||
t.Fatalf("Applications: %v", err)
|
||||
}
|
||||
if len(apps) != 1 || apps[0].Name != "Radarr" {
|
||||
t.Errorf("unexpected apps: %+v", apps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnauthorized(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "wrong-key")
|
||||
_, err := c.SystemStatus()
|
||||
if err == nil {
|
||||
t.Error("expected error for unauthorized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("internal error"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key")
|
||||
_, err := c.Movies()
|
||||
if err == nil {
|
||||
t.Error("expected error for 500 response")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchStr(s, substr)
|
||||
}
|
||||
|
||||
func searchStr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -1,6 +1,9 @@
|
|||
package arr
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -82,3 +85,158 @@ func TestDetectApp(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigDirs(t *testing.T) {
|
||||
dirs := configDirs()
|
||||
if len(dirs) == 0 {
|
||||
t.Error("configDirs() returned empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigXMLEmpty(t *testing.T) {
|
||||
port, apiKey, urlBase := parseConfigXML(strings.NewReader(""))
|
||||
if port != "" || apiKey != "" || urlBase != "" {
|
||||
t.Error("empty input should return empty values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigXMLNoPort(t *testing.T) {
|
||||
xml := `<Config><ApiKey>key123</ApiKey></Config>`
|
||||
port, apiKey, _ := parseConfigXML(strings.NewReader(xml))
|
||||
if port != "" {
|
||||
t.Errorf("port = %q, want empty", port)
|
||||
}
|
||||
if apiKey != "key123" {
|
||||
t.Errorf("apiKey = %q, want key123", apiKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractHostPortMultipleMappings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ports string
|
||||
container string
|
||||
want string
|
||||
}{
|
||||
{"ipv6 only", ":::8989->8989/tcp", "8989", "8989"},
|
||||
{"different host port", "0.0.0.0:9999->8989/tcp", "8989", "9999"},
|
||||
{"port in string but no mapping", "something 8989 somewhere", "8989", "8989"},
|
||||
{"no match at all", "0.0.0.0:3000->3000/tcp", "9999", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractHostPort(tt.ports, tt.container)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractHostPort(%q, %q) = %q, want %q", tt.ports, tt.container, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverFromProwlarr(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/applications":
|
||||
json.NewEncoder(w).Encode([]Application{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Radarr",
|
||||
Fields: []Field{
|
||||
{Name: "baseUrl", Value: "http://localhost:7878"},
|
||||
{Name: "apiKey", Value: "radarr-key-123"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "Sonarr",
|
||||
Fields: []Field{
|
||||
{Name: "baseUrl", Value: "http://localhost:8989"},
|
||||
{Name: "apiKey", Value: "sonarr-key-456"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Name: "Unknown App",
|
||||
Fields: []Field{
|
||||
{Name: "baseUrl", Value: "http://localhost:9000"},
|
||||
{Name: "apiKey", Value: "unknown-key"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 4,
|
||||
Name: "Incomplete",
|
||||
Fields: []Field{
|
||||
{Name: "baseUrl", Value: "http://localhost:5000"},
|
||||
// no apiKey → should be skipped
|
||||
},
|
||||
},
|
||||
})
|
||||
case "/api/v3/system/status":
|
||||
json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "4.0.0"})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// DiscoverFromProwlarr will try to verify each instance, which will fail
|
||||
// for localhost URLs (not our test server), but that's OK — we test the parsing
|
||||
instances := DiscoverFromProwlarr(srv.URL, "prowlarr-key")
|
||||
|
||||
// Should find Radarr and Sonarr (Unknown and Incomplete skipped)
|
||||
if len(instances) != 2 {
|
||||
t.Fatalf("expected 2 instances, got %d: %+v", len(instances), instances)
|
||||
}
|
||||
|
||||
found := map[string]bool{}
|
||||
for _, inst := range instances {
|
||||
found[inst.App] = true
|
||||
if inst.Source != "prowlarr" {
|
||||
t.Errorf("source = %q, want prowlarr", inst.Source)
|
||||
}
|
||||
}
|
||||
if !found["radarr"] {
|
||||
t.Error("expected radarr instance")
|
||||
}
|
||||
if !found["sonarr"] {
|
||||
t.Error("expected sonarr instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Api-Key") != "valid-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(SystemStatus{AppName: "Radarr", Version: "5.0.0"})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "valid-key"}
|
||||
err := Verify(inst)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify: %v", err)
|
||||
}
|
||||
if inst.Version != "5.0.0" {
|
||||
t.Errorf("version = %q, want 5.0.0", inst.Version)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no api key", func(t *testing.T) {
|
||||
inst := &Instance{App: "radarr", URL: srv.URL}
|
||||
err := Verify(inst)
|
||||
if err == nil {
|
||||
t.Error("expected error for no API key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid key", func(t *testing.T) {
|
||||
inst := &Instance{App: "radarr", URL: srv.URL, APIKey: "wrong-key"}
|
||||
err := Verify(inst)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid API key")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
55
internal/cmd/config_menu_test.go
Normal file
55
internal/cmd/config_menu_test.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
package cmd
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateSpeed(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"", false},
|
||||
{"0", false},
|
||||
{" ", false},
|
||||
{"10MB", false},
|
||||
{"500KB", false},
|
||||
{"1GB", false},
|
||||
{"abc", true},
|
||||
{"10XB", true},
|
||||
{"-5MB", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
err := validateSpeed(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateSpeed(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"", false},
|
||||
{"30s", false},
|
||||
{"1m", false},
|
||||
{"5m", false},
|
||||
{"1h", false},
|
||||
{"2h30m", false},
|
||||
{"abc", true},
|
||||
{"30", true},
|
||||
{"5x", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
err := validateDuration(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateDuration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
55
internal/cmd/daemon_test.go
Normal file
55
internal/cmd/daemon_test.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
package cmd
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDeriveWSURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
apiURL string
|
||||
agentID string
|
||||
want string
|
||||
}{
|
||||
{"https://torrentclaw.com", "agent-123", "wss://unarr.torrentclaw.com/ws/agent-123"},
|
||||
{"http://localhost:3000", "a1", ""}, // localhost skipped
|
||||
{"http://127.0.0.1:3000", "a1", ""}, // 127.0.0.1 skipped
|
||||
{"https://torrentclaw.com/", "a1", "wss://unarr.torrentclaw.com/ws/a1"},
|
||||
{"https://api.example.io", "x", "wss://unarr.api.example.io/ws/x"},
|
||||
{"", "agent-123", ""},
|
||||
{"https://torrentclaw.com", "", ""},
|
||||
{"", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.apiURL+"_"+tt.agentID, func(t *testing.T) {
|
||||
got := deriveWSURL(tt.apiURL, tt.agentID)
|
||||
if got != tt.want {
|
||||
t.Errorf("deriveWSURL(%q, %q) = %q, want %q", tt.apiURL, tt.agentID, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSpeedLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
bps int64
|
||||
want string
|
||||
}{
|
||||
{0, "0 B/s"},
|
||||
{500, "500 B/s"},
|
||||
{1023, "1023 B/s"},
|
||||
{1024, "1 KB/s"},
|
||||
{10240, "10 KB/s"},
|
||||
{1048576, "1.0 MB/s"},
|
||||
{5242880, "5.0 MB/s"},
|
||||
{1073741824, "1.0 GB/s"},
|
||||
{2147483648, "2.0 GB/s"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
got := formatSpeedLog(tt.bps)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatSpeedLog(%d) = %q, want %q", tt.bps, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
43
internal/cmd/helpers_test.go
Normal file
43
internal/cmd/helpers_test.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExpandHome(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"~/Documents", home + "/Documents"},
|
||||
{"~/", home},
|
||||
{"/absolute/path", "/absolute/path"},
|
||||
{"relative/path", "relative/path"},
|
||||
{"", ""},
|
||||
{"~notexpanded", "~notexpanded"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := expandHome(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("expandHome(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultDownloadDir(t *testing.T) {
|
||||
dir := defaultDownloadDir()
|
||||
if dir == "" {
|
||||
t.Error("defaultDownloadDir() returned empty string")
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
if !strings.HasPrefix(dir, home) {
|
||||
t.Errorf("defaultDownloadDir() = %q, expected to start with home dir %q", dir, home)
|
||||
}
|
||||
}
|
||||
|
|
@ -143,8 +143,9 @@ Source: https://github.com/torrentclaw/unarr`,
|
|||
completionCmd,
|
||||
// Library
|
||||
scanCmd,
|
||||
// Alias: upgrade → self-update
|
||||
newUpgradeCmd(),
|
||||
// Stubs for future commands
|
||||
newStubCmd("upgrade", "Find a better version of a torrent"),
|
||||
newStubCmd("moreseed", "Find same quality with more seeders"),
|
||||
newStubCmd("compare", "Compare two torrents side by side"),
|
||||
newStubCmd("add", "Search and add torrents to your client"),
|
||||
|
|
|
|||
|
|
@ -1,20 +1,26 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/torrentclaw/unarr/internal/agent"
|
||||
"github.com/torrentclaw/unarr/internal/upgrade"
|
||||
)
|
||||
|
||||
func newStatusCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show daemon status and active downloads",
|
||||
Long: `Display the current state of the daemon, active downloads, and recent activity.
|
||||
Short: "Show daemon status, configuration, and update availability",
|
||||
Long: `Display the current state of unarr: version, configuration, daemon status,
|
||||
disk usage, and whether an update is available.
|
||||
|
||||
Shows the configured agent name, download directory, and preferred method.
|
||||
When the daemon is running, also displays active downloads and their progress.`,
|
||||
When the daemon is running, also displays uptime, active downloads, and stats.`,
|
||||
Example: ` unarr status`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return runStatus()
|
||||
|
|
@ -25,27 +31,167 @@ When the daemon is running, also displays active downloads and their progress.`,
|
|||
func runStatus() error {
|
||||
bold := color.New(color.Bold)
|
||||
dim := color.New(color.FgHiBlack)
|
||||
green := color.New(color.FgGreen)
|
||||
yellow := color.New(color.FgYellow)
|
||||
cyan := color.New(color.FgCyan)
|
||||
|
||||
fmt.Println()
|
||||
bold.Printf(" unarr %s\n", Version)
|
||||
dim.Printf(" %s/%s\n", runtime.GOOS, runtime.GOARCH)
|
||||
fmt.Println()
|
||||
|
||||
cfg := loadConfig()
|
||||
|
||||
// ── Configuration ──
|
||||
if cfg.Auth.APIKey == "" {
|
||||
dim.Println(" Not configured. Run 'unarr init' first.")
|
||||
yellow.Println(" ⚠ Not configured. Run 'unarr init' first.")
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, cfg.Agent.ID[:8]+"...")
|
||||
fmt.Printf(" Downloads: %s\n", cfg.Download.Dir)
|
||||
fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod)
|
||||
cyan.Println(" Configuration")
|
||||
agentID := cfg.Agent.ID
|
||||
if len(agentID) > 8 {
|
||||
agentID = agentID[:8] + "..."
|
||||
}
|
||||
fmt.Printf(" Agent: %s (%s)\n", cfg.Agent.Name, agentID)
|
||||
fmt.Printf(" Server: %s\n", cfg.Auth.APIURL)
|
||||
fmt.Printf(" Downloads: %s\n", cfg.Download.Dir)
|
||||
fmt.Printf(" Method: %s\n", cfg.Download.PreferredMethod)
|
||||
if cfg.Download.PreferredQuality != "" {
|
||||
fmt.Printf(" Quality: %s\n", cfg.Download.PreferredQuality)
|
||||
}
|
||||
fmt.Printf(" Concurrent: %d\n", cfg.Download.MaxConcurrent)
|
||||
if cfg.Organize.Enabled {
|
||||
fmt.Printf(" Organize: on")
|
||||
if cfg.Organize.MoviesDir != "" {
|
||||
fmt.Printf(" (movies: %s", cfg.Organize.MoviesDir)
|
||||
if cfg.Organize.TVShowsDir != "" {
|
||||
fmt.Printf(", tv: %s", cfg.Organize.TVShowsDir)
|
||||
}
|
||||
fmt.Print(")")
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
dim.Println(" Daemon not running. Start with 'unarr start'")
|
||||
dim.Println(" (Live status will be shown here when daemon is running)")
|
||||
// ── Disk ──
|
||||
if cfg.Download.Dir != "" {
|
||||
if free, total, err := agent.DiskInfo(cfg.Download.Dir); err == nil && total > 0 {
|
||||
usedPct := float64(total-free) / float64(total) * 100
|
||||
cyan.Println(" Disk")
|
||||
fmt.Printf(" Free: %s / %s (%.0f%% used)\n", formatBytes(free), formatBytes(total), usedPct)
|
||||
if usedPct > 90 {
|
||||
yellow.Println(" ⚠ Low disk space!")
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Daemon ──
|
||||
cyan.Println(" Daemon")
|
||||
state := agent.ReadState()
|
||||
if state != nil && isDaemonAlive(state) {
|
||||
green.Printf(" Status: running (PID %d)\n", state.PID)
|
||||
fmt.Printf(" Uptime: %s\n", formatDuration(time.Since(state.StartedAt)))
|
||||
fmt.Printf(" Last beat: %s ago\n", formatDuration(time.Since(state.LastHeartbeat)))
|
||||
fmt.Printf(" Active: %d task(s)\n", state.ActiveTasks)
|
||||
fmt.Printf(" Completed: %d\n", state.CompletedCount)
|
||||
if state.FailedCount > 0 {
|
||||
fmt.Printf(" Failed: %d\n", state.FailedCount)
|
||||
}
|
||||
if state.TotalDownloaded > 0 {
|
||||
fmt.Printf(" Downloaded: %s\n", formatBytes(state.TotalDownloaded))
|
||||
}
|
||||
if len(state.MethodStats) > 0 {
|
||||
parts := make([]string, 0, len(state.MethodStats))
|
||||
for method, count := range state.MethodStats {
|
||||
parts = append(parts, fmt.Sprintf("%s:%d", method, count))
|
||||
}
|
||||
fmt.Printf(" Methods: %s\n", strings.Join(parts, ", "))
|
||||
}
|
||||
} else {
|
||||
dim.Println(" Status: stopped")
|
||||
dim.Println(" Start with: unarr start")
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// ── Update check (cached: instant if <1h, otherwise async 3s) ──
|
||||
type versionResult struct {
|
||||
version string
|
||||
fromCache bool
|
||||
err error
|
||||
}
|
||||
versionCh := make(chan versionResult, 1)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
v, cached, err := upgrade.CheckLatestCached(ctx)
|
||||
versionCh <- versionResult{v, cached, err}
|
||||
}()
|
||||
|
||||
cyan.Println(" Update")
|
||||
fmt.Print(" Checking... ")
|
||||
vr := <-versionCh
|
||||
if vr.err != nil {
|
||||
dim.Println("could not check (offline?)")
|
||||
} else {
|
||||
currentClean := strings.TrimPrefix(Version, "v")
|
||||
if currentClean == vr.version {
|
||||
green.Printf("✓ up to date (v%s)\n", vr.version)
|
||||
} else {
|
||||
yellow.Printf("v%s available! ", vr.version)
|
||||
fmt.Printf("Run: unarr upgrade\n")
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isDaemonAlive checks if the daemon process from the state file is still running.
|
||||
// Guards against PID reuse by also checking heartbeat recency.
|
||||
func isDaemonAlive(state *agent.DaemonState) bool {
|
||||
if state.PID == 0 {
|
||||
return false
|
||||
}
|
||||
// Reject stale state: if last heartbeat is older than 2 minutes, the daemon
|
||||
// likely crashed and the PID may have been reused by another process.
|
||||
if !state.LastHeartbeat.IsZero() && time.Since(state.LastHeartbeat) > 2*time.Minute {
|
||||
return false
|
||||
}
|
||||
return agent.IsProcessAlive(state.PID)
|
||||
}
|
||||
|
||||
// formatBytes formats bytes into human-readable string.
|
||||
func formatBytes(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
// formatDuration formats a duration into a compact human-readable string.
|
||||
func formatDuration(d time.Duration) string {
|
||||
if d < 0 {
|
||||
return "0s"
|
||||
}
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%ds", int(d.Seconds()))
|
||||
}
|
||||
if d < time.Hour {
|
||||
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
|
||||
}
|
||||
if d < 24*time.Hour {
|
||||
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
|
||||
}
|
||||
days := int(d.Hours()) / 24
|
||||
hours := int(d.Hours()) % 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,20 @@ var streamRegistry = struct {
|
|||
servers: make(map[string]*engine.StreamServer),
|
||||
}
|
||||
|
||||
// cancelAllStreams cancels all active stream tasks and servers (only 1 stream at a time).
|
||||
func cancelAllStreams() {
|
||||
streamRegistry.mu.Lock()
|
||||
for taskID, cancel := range streamRegistry.cancels {
|
||||
cancel()
|
||||
delete(streamRegistry.cancels, taskID)
|
||||
}
|
||||
for taskID, srv := range streamRegistry.servers {
|
||||
srv.Shutdown(context.Background())
|
||||
delete(streamRegistry.servers, taskID)
|
||||
}
|
||||
streamRegistry.mu.Unlock()
|
||||
}
|
||||
|
||||
// cancelStreamTask cancels a running stream task and shuts down any stream server.
|
||||
func cancelStreamTask(taskID string) {
|
||||
streamRegistry.mu.Lock()
|
||||
|
|
@ -94,7 +108,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
|
|||
}
|
||||
|
||||
// 4. Start HTTP server
|
||||
srv := engine.NewStreamServer(eng, 0)
|
||||
srv := engine.NewStreamServer(eng, cfg.Download.StreamPort)
|
||||
streamURL, err := srv.Start(ctx)
|
||||
if err != nil {
|
||||
task.ErrorMessage = "start HTTP server: " + err.Error()
|
||||
|
|
@ -107,17 +121,27 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
|
|||
task.StreamURL = streamURL
|
||||
log.Printf("[%s] stream ready: %s", at.ID[:8], streamURL)
|
||||
|
||||
// 6. Progress loop
|
||||
// 6. Unified progress + idle timeout loop
|
||||
eng.StartProgressLoop(ctx)
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
progressTicker := time.NewTicker(3 * time.Second)
|
||||
defer progressTicker.Stop()
|
||||
idleCheck := time.NewTicker(60 * time.Second)
|
||||
defer idleCheck.Stop()
|
||||
completed := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[%s] stream stopped", at.ID[:8])
|
||||
return
|
||||
case <-ticker.C:
|
||||
|
||||
case <-idleCheck.C:
|
||||
if srv.IdleSince() > 30*time.Minute {
|
||||
log.Printf("[%s] stream idle timeout (30m no HTTP requests), shutting down", at.ID[:8])
|
||||
return
|
||||
}
|
||||
|
||||
case <-progressTicker.C:
|
||||
p := eng.Progress()
|
||||
task.UpdateProgress(engine.Progress{
|
||||
DownloadedBytes: p.DownloadedBytes,
|
||||
|
|
@ -129,7 +153,7 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
|
|||
})
|
||||
|
||||
// Terminal progress
|
||||
if p.TotalBytes > 0 {
|
||||
if !completed && p.TotalBytes > 0 {
|
||||
pct := int(float64(p.DownloadedBytes) / float64(p.TotalBytes) * 100)
|
||||
fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d",
|
||||
at.ID[:8], pct,
|
||||
|
|
@ -137,20 +161,11 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
|
|||
p.Peers, p.Seeds)
|
||||
}
|
||||
|
||||
if p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 {
|
||||
fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line
|
||||
if !completed && p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 {
|
||||
fmt.Fprint(os.Stderr, "\r\033[2K")
|
||||
task.Transition(engine.StatusCompleted)
|
||||
log.Printf("[%s] stream download complete, server stays up for 30m or until cancelled", at.ID[:8])
|
||||
// Keep HTTP server running so the player can finish reading.
|
||||
// Auto-shutdown after 30 minutes of idle to prevent resource leaks.
|
||||
idleTimer := time.NewTimer(30 * time.Minute)
|
||||
defer idleTimer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-idleTimer.C:
|
||||
log.Printf("[%s] stream idle timeout (30m), shutting down", at.ID[:8])
|
||||
}
|
||||
return
|
||||
log.Printf("[%s] stream download complete, server stays up until idle (30m)", at.ID[:8])
|
||||
completed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
30
internal/cmd/upgrade.go
Normal file
30
internal/cmd/upgrade.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// newUpgradeCmd creates the `unarr upgrade` command as an alias for `self-update`.
|
||||
func newUpgradeCmd() *cobra.Command {
|
||||
var force bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "upgrade",
|
||||
Aliases: []string{"update"},
|
||||
Short: "Update unarr to the latest version",
|
||||
Long: `Download and install the latest version of unarr.
|
||||
|
||||
This is an alias for 'unarr self-update'. Checks GitHub for the latest
|
||||
release, verifies the checksum, and replaces the current binary.
|
||||
A backup is kept at <binary>.backup.`,
|
||||
Example: ` unarr upgrade
|
||||
unarr upgrade --force`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return runSelfUpdate(force)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&force, "force", "f", false, "reinstall even if already up to date")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package engine
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -83,3 +84,223 @@ func TestManagerShutdown(t *testing.T) {
|
|||
mgr.Shutdown(ctx)
|
||||
// Should not hang
|
||||
}
|
||||
|
||||
func TestManagerDefaultConcurrency(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
mgr := NewManager(ManagerConfig{MaxConcurrent: 0}, reporter)
|
||||
if cap(mgr.sem) != 3 {
|
||||
t.Errorf("default MaxConcurrent should be 3, got %d", cap(mgr.sem))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGetTask(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
|
||||
|
||||
// No task added
|
||||
if task := mgr.GetTask("nonexistent"); task != nil {
|
||||
t.Error("expected nil for nonexistent task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerActiveTasks(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
|
||||
|
||||
tasks := mgr.ActiveTasks()
|
||||
if len(tasks) != 0 {
|
||||
t.Errorf("expected 0 active tasks, got %d", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerSubmitCompletesWithValidFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Create a file that verify() will accept
|
||||
filePath := dir + "/movie.mkv"
|
||||
os.WriteFile(filePath, make([]byte, 1024), 0o644)
|
||||
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: 100 * time.Millisecond,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
dl := &resultMockDownloader{
|
||||
method: MethodTorrent,
|
||||
result: &Result{
|
||||
FilePath: filePath,
|
||||
FileName: "movie.mkv",
|
||||
Method: MethodTorrent,
|
||||
Size: 1024,
|
||||
},
|
||||
}
|
||||
|
||||
mgr := NewManager(ManagerConfig{
|
||||
MaxConcurrent: 2,
|
||||
OutputDir: dir,
|
||||
}, pr, dl)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go pr.Run(ctx)
|
||||
|
||||
mgr.Submit(ctx, agent.Task{
|
||||
ID: "task-complete-test1",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "Test Movie",
|
||||
PreferredMethod: "torrent",
|
||||
})
|
||||
|
||||
mgr.Wait()
|
||||
cancel()
|
||||
|
||||
// Task should have completed successfully
|
||||
// (we can't check directly since it's removed from active map after processing)
|
||||
}
|
||||
|
||||
func TestManagerCancelTask(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
|
||||
dl := &slowMockDownloader{method: MethodTorrent}
|
||||
mgr := NewManager(ManagerConfig{
|
||||
MaxConcurrent: 2,
|
||||
OutputDir: t.TempDir(),
|
||||
}, reporter, dl)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go reporter.Run(ctx)
|
||||
|
||||
mgr.Submit(ctx, agent.Task{
|
||||
ID: "task-cancel-test12",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "Cancel Me",
|
||||
PreferredMethod: "torrent",
|
||||
})
|
||||
|
||||
// Give it time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mgr.CancelTask("task-cancel-test12")
|
||||
mgr.Wait()
|
||||
}
|
||||
|
||||
func TestManagerPauseTask(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
|
||||
dl := &slowMockDownloader{method: MethodTorrent}
|
||||
mgr := NewManager(ManagerConfig{
|
||||
MaxConcurrent: 2,
|
||||
OutputDir: t.TempDir(),
|
||||
}, reporter, dl)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go reporter.Run(ctx)
|
||||
|
||||
mgr.Submit(ctx, agent.Task{
|
||||
ID: "task-pause-test123",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "Pause Me",
|
||||
PreferredMethod: "torrent",
|
||||
})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mgr.PauseTask("task-pause-test123")
|
||||
mgr.Wait()
|
||||
}
|
||||
|
||||
func TestManagerCancelAndDeleteFiles(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
|
||||
dl := &slowMockDownloader{method: MethodTorrent}
|
||||
mgr := NewManager(ManagerConfig{
|
||||
MaxConcurrent: 2,
|
||||
OutputDir: t.TempDir(),
|
||||
}, reporter, dl)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go reporter.Run(ctx)
|
||||
|
||||
mgr.Submit(ctx, agent.Task{
|
||||
ID: "task-delfile-test12",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "Delete Me",
|
||||
PreferredMethod: "torrent",
|
||||
})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mgr.CancelAndDeleteFiles("task-delfile-test12")
|
||||
mgr.Wait()
|
||||
}
|
||||
|
||||
func TestManagerCancelNonexistent(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
|
||||
// Should not panic
|
||||
mgr.CancelTask("nonexistent")
|
||||
mgr.PauseTask("nonexistent")
|
||||
mgr.CancelAndDeleteFiles("nonexistent")
|
||||
}
|
||||
|
||||
// resultMockDownloader returns a configurable result
|
||||
type resultMockDownloader struct {
|
||||
method DownloadMethod
|
||||
result *Result
|
||||
}
|
||||
|
||||
func (m *resultMockDownloader) Method() DownloadMethod { return m.method }
|
||||
func (m *resultMockDownloader) Available(_ context.Context, _ *Task) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (m *resultMockDownloader) Download(_ context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
|
||||
return m.result, nil
|
||||
}
|
||||
func (m *resultMockDownloader) Pause(_ string) error { return nil }
|
||||
func (m *resultMockDownloader) Cancel(_ string) error { return nil }
|
||||
func (m *resultMockDownloader) Shutdown(_ context.Context) error { return nil }
|
||||
|
||||
// slowMockDownloader blocks until context is cancelled
|
||||
type slowMockDownloader struct {
|
||||
method DownloadMethod
|
||||
}
|
||||
|
||||
func (m *slowMockDownloader) Method() DownloadMethod { return m.method }
|
||||
func (m *slowMockDownloader) Available(_ context.Context, _ *Task) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (m *slowMockDownloader) Download(ctx context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
func (m *slowMockDownloader) Pause(_ string) error { return nil }
|
||||
func (m *slowMockDownloader) Cancel(_ string) error { return nil }
|
||||
func (m *slowMockDownloader) Shutdown(_ context.Context) error { return nil }
|
||||
|
|
|
|||
50
internal/engine/method_test.go
Normal file
50
internal/engine/method_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package engine
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDownloadMethodConstants(t *testing.T) {
|
||||
if MethodTorrent != "torrent" {
|
||||
t.Errorf("MethodTorrent = %q, want torrent", MethodTorrent)
|
||||
}
|
||||
if MethodDebrid != "debrid" {
|
||||
t.Errorf("MethodDebrid = %q, want debrid", MethodDebrid)
|
||||
}
|
||||
if MethodUsenet != "usenet" {
|
||||
t.Errorf("MethodUsenet = %q, want usenet", MethodUsenet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressStruct(t *testing.T) {
|
||||
p := Progress{
|
||||
DownloadedBytes: 1024,
|
||||
TotalBytes: 2048,
|
||||
SpeedBps: 512,
|
||||
ETA: 10,
|
||||
Peers: 5,
|
||||
Seeds: 3,
|
||||
FileName: "movie.mkv",
|
||||
}
|
||||
|
||||
if p.DownloadedBytes != 1024 {
|
||||
t.Errorf("DownloadedBytes = %d, want 1024", p.DownloadedBytes)
|
||||
}
|
||||
if p.FileName != "movie.mkv" {
|
||||
t.Errorf("FileName = %q, want movie.mkv", p.FileName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResultStruct(t *testing.T) {
|
||||
r := Result{
|
||||
FilePath: "/downloads/movie.mkv",
|
||||
FileName: "movie.mkv",
|
||||
Method: MethodTorrent,
|
||||
Size: 1073741824,
|
||||
}
|
||||
|
||||
if r.Method != MethodTorrent {
|
||||
t.Errorf("Method = %q, want torrent", r.Method)
|
||||
}
|
||||
if r.Size != 1073741824 {
|
||||
t.Errorf("Size = %d, want 1073741824", r.Size)
|
||||
}
|
||||
}
|
||||
181
internal/engine/organize_expand_test.go
Normal file
181
internal/engine/organize_expand_test.go
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReplaceFile(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
backupDir := filepath.Join(tmp, "backups")
|
||||
|
||||
// Create "old" file
|
||||
oldPath := filepath.Join(tmp, "movie.mkv")
|
||||
os.WriteFile(oldPath, []byte("old content"), 0o644)
|
||||
|
||||
// Create "new" file
|
||||
newPath := filepath.Join(tmp, "movie-new.mkv")
|
||||
os.WriteFile(newPath, []byte("new better content"), 0o644)
|
||||
|
||||
err := replaceFile(oldPath, newPath, backupDir)
|
||||
if err != nil {
|
||||
t.Fatalf("replaceFile: %v", err)
|
||||
}
|
||||
|
||||
// Old path should now contain new content
|
||||
data, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read old path: %v", err)
|
||||
}
|
||||
if string(data) != "new better content" {
|
||||
t.Errorf("old path content = %q, want 'new better content'", string(data))
|
||||
}
|
||||
|
||||
// Backup should exist
|
||||
entries, _ := os.ReadDir(backupDir)
|
||||
if len(entries) != 1 {
|
||||
t.Errorf("expected 1 backup file, got %d", len(entries))
|
||||
}
|
||||
|
||||
// New file should be gone
|
||||
if _, err := os.Stat(newPath); !os.IsNotExist(err) {
|
||||
t.Error("new file should have been moved/deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceFileOldNotFound(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
err := replaceFile(filepath.Join(tmp, "nonexistent.mkv"), filepath.Join(tmp, "new.mkv"), "")
|
||||
if err == nil {
|
||||
t.Error("expected error when old file doesn't exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFile(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
src := filepath.Join(tmp, "source.txt")
|
||||
dst := filepath.Join(tmp, "dest.txt")
|
||||
|
||||
content := []byte("hello world copy test")
|
||||
os.WriteFile(src, content, 0o644)
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("copyFile: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
t.Fatalf("read dest: %v", err)
|
||||
}
|
||||
if string(data) != string(content) {
|
||||
t.Errorf("dest content = %q, want %q", string(data), string(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFileSrcNotFound(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
err := copyFile(filepath.Join(tmp, "nope.txt"), filepath.Join(tmp, "out.txt"))
|
||||
if err == nil {
|
||||
t.Error("expected error when source doesn't exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizeNoDirs(t *testing.T) {
|
||||
r := &Result{FilePath: "/tmp/file.mkv", FileName: "file.mkv"}
|
||||
task := &Task{Title: "Movie"}
|
||||
|
||||
path, err := organize(r, task, OrganizeConfig{Enabled: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if path != "/tmp/file.mkv" {
|
||||
t.Errorf("should return original path when no dirs configured, got %q", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizeNilResult(t *testing.T) {
|
||||
task := &Task{Title: "Movie"}
|
||||
path, err := organize(&Result{}, task, OrganizeConfig{Enabled: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if path != "" {
|
||||
t.Errorf("expected empty path for empty result, got %q", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizeMovieDirectory(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
srcDir := filepath.Join(tmp, "src", "MovieDir")
|
||||
os.MkdirAll(srcDir, 0o755)
|
||||
os.WriteFile(filepath.Join(srcDir, "movie.mkv"), []byte("data"), 0o644)
|
||||
|
||||
moviesDir := filepath.Join(tmp, "Movies")
|
||||
|
||||
r := &Result{FilePath: srcDir, FileName: "MovieDir"}
|
||||
task := &Task{Title: "My Movie 2023"}
|
||||
|
||||
path, err := organize(r, task, OrganizeConfig{
|
||||
Enabled: true,
|
||||
MoviesDir: moviesDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if path == srcDir {
|
||||
t.Error("directory should have moved")
|
||||
}
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Errorf("organized directory should exist at %s", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizeSeasonOnly(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
srcFile := filepath.Join(tmp, "Show.S01.Complete.mkv")
|
||||
os.WriteFile(srcFile, []byte("data"), 0o644)
|
||||
|
||||
tvDir := filepath.Join(tmp, "TV")
|
||||
|
||||
r := &Result{FilePath: srcFile, FileName: "Show.S01.Complete.mkv"}
|
||||
task := &Task{Title: "Show S01"}
|
||||
|
||||
path, err := organize(r, task, OrganizeConfig{
|
||||
Enabled: true,
|
||||
TVShowsDir: tvDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
if filepath.Base(dir) != "Season 01" {
|
||||
t.Errorf("expected Season 01 directory, got %q", filepath.Base(dir))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanTitleEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"", ""},
|
||||
{"Simple Title", "Simple Title"},
|
||||
{"Title (2023) 1080p BluRay", "Title"},
|
||||
{"Title 720p HDTV", "Title"},
|
||||
{"Title x264 HEVC", "Title"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := cleanTitle(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("cleanTitle(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
419
internal/engine/progress_test.go
Normal file
419
internal/engine/progress_test.go
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/agent"
|
||||
)
|
||||
|
||||
// mockStatusReporter records calls to ReportStatus.
|
||||
type mockStatusReporter struct {
|
||||
mu sync.Mutex
|
||||
calls []agent.StatusUpdate
|
||||
resp *agent.StatusResponse
|
||||
respErr error
|
||||
}
|
||||
|
||||
func (m *mockStatusReporter) ReportStatus(_ context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.calls = append(m.calls, update)
|
||||
if m.resp != nil {
|
||||
return m.resp, m.respErr
|
||||
}
|
||||
return &agent.StatusResponse{}, m.respErr
|
||||
}
|
||||
|
||||
// mockBatchReporter records batch calls.
|
||||
type mockBatchReporter struct {
|
||||
mockStatusReporter
|
||||
batchCalls [][]agent.StatusUpdate
|
||||
batchResp *agent.BatchStatusResponse
|
||||
}
|
||||
|
||||
func (m *mockBatchReporter) BatchReportStatus(_ context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.batchCalls = append(m.batchCalls, updates)
|
||||
if m.batchResp != nil {
|
||||
return m.batchResp, nil
|
||||
}
|
||||
results := make([]agent.StatusResponse, len(updates))
|
||||
return &agent.BatchStatusResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
func TestProgressReporter_TrackUntrack(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-001", Status: StatusDownloading}
|
||||
pr.Track(task)
|
||||
|
||||
pr.mu.Lock()
|
||||
if _, ok := pr.latest["task-001"]; !ok {
|
||||
t.Error("task should be tracked")
|
||||
}
|
||||
pr.mu.Unlock()
|
||||
|
||||
pr.Untrack("task-001")
|
||||
|
||||
pr.mu.Lock()
|
||||
if _, ok := pr.latest["task-001"]; ok {
|
||||
t.Error("task should be untracked")
|
||||
}
|
||||
pr.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestProgressReporter_FlushReportsFinalStates(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
completed := &Task{ID: "task-completed-1234", Status: StatusCompleted}
|
||||
pr.Track(completed)
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
reporter.mu.Lock()
|
||||
defer reporter.mu.Unlock()
|
||||
if len(reporter.calls) != 1 {
|
||||
t.Fatalf("expected 1 report, got %d", len(reporter.calls))
|
||||
}
|
||||
if reporter.calls[0].TaskID != "task-completed-1234" {
|
||||
t.Errorf("reported wrong task: %s", reporter.calls[0].TaskID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_FlushSkipsWhenNotWatching(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
isWatching: func() bool { return false },
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
lastCheckAt: time.Now(), // not due for control check
|
||||
}
|
||||
|
||||
// Active downloading task, already reported as downloading
|
||||
task := &Task{ID: "task-active-12345678", Status: StatusDownloading}
|
||||
pr.Track(task)
|
||||
pr.mu.Lock()
|
||||
pr.lastReported["task-active-12345678"] = StatusDownloading
|
||||
pr.mu.Unlock()
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
reporter.mu.Lock()
|
||||
defer reporter.mu.Unlock()
|
||||
if len(reporter.calls) != 0 {
|
||||
t.Errorf("expected 0 reports when not watching (no transition), got %d", len(reporter.calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_FlushReportsTransitions(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
isWatching: func() bool { return false },
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
lastCheckAt: time.Now(),
|
||||
}
|
||||
|
||||
// Task transitioning from resolving to downloading
|
||||
task := &Task{ID: "task-trans-12345678", Status: StatusDownloading}
|
||||
pr.Track(task)
|
||||
pr.mu.Lock()
|
||||
pr.lastReported["task-trans-12345678"] = StatusResolving
|
||||
pr.mu.Unlock()
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
reporter.mu.Lock()
|
||||
defer reporter.mu.Unlock()
|
||||
if len(reporter.calls) != 1 {
|
||||
t.Fatalf("expected 1 report for transition, got %d", len(reporter.calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_FlushActiveWhenWatching(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
isWatching: func() bool { return true },
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-watch-12345678", Status: StatusDownloading}
|
||||
pr.Track(task)
|
||||
pr.mu.Lock()
|
||||
pr.lastReported["task-watch-12345678"] = StatusDownloading
|
||||
pr.mu.Unlock()
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
reporter.mu.Lock()
|
||||
defer reporter.mu.Unlock()
|
||||
if len(reporter.calls) != 1 {
|
||||
t.Fatalf("expected 1 report when watching active task, got %d", len(reporter.calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_HandleResponseCancel(t *testing.T) {
|
||||
reporter := &mockStatusReporter{
|
||||
resp: &agent.StatusResponse{Cancelled: true},
|
||||
}
|
||||
|
||||
var cancelledID string
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
onCancel: func(id string) { cancelledID = id },
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-cancel-1234567", Status: StatusCompleted}
|
||||
pr.Track(task)
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
if cancelledID != "task-cancel-1234567" {
|
||||
t.Errorf("expected cancel handler called with task ID, got %q", cancelledID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_HandleResponsePause(t *testing.T) {
|
||||
reporter := &mockStatusReporter{
|
||||
resp: &agent.StatusResponse{Paused: true},
|
||||
}
|
||||
|
||||
var pausedID string
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
onPause: func(id string) { pausedID = id },
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-paused-1234567", Status: StatusCompleted}
|
||||
pr.Track(task)
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
if pausedID != "task-paused-1234567" {
|
||||
t.Errorf("expected pause handler called, got %q", pausedID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_HandleResponseDeleteFiles(t *testing.T) {
|
||||
reporter := &mockStatusReporter{
|
||||
resp: &agent.StatusResponse{Cancelled: true, DeleteFiles: true},
|
||||
}
|
||||
|
||||
var deletedID string
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
onDeleteFiles: func(id string) { deletedID = id },
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-delete-1234567", Status: StatusCompleted}
|
||||
pr.Track(task)
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
if deletedID != "task-delete-1234567" {
|
||||
t.Errorf("expected deleteFiles handler called, got %q", deletedID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_HandleResponseStream(t *testing.T) {
|
||||
reporter := &mockStatusReporter{
|
||||
resp: &agent.StatusResponse{StreamRequested: true},
|
||||
}
|
||||
|
||||
var streamID string
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
onStreamRequested: func(id string) { streamID = id },
|
||||
}
|
||||
|
||||
// Task with no stream URL yet
|
||||
task := &Task{ID: "task-stream-1234567", Status: StatusCompleted}
|
||||
pr.Track(task)
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
if streamID != "task-stream-1234567" {
|
||||
t.Errorf("expected stream handler called, got %q", streamID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_HandleResponseWatchingChanged(t *testing.T) {
|
||||
reporter := &mockStatusReporter{
|
||||
resp: &agent.StatusResponse{Watching: true},
|
||||
}
|
||||
|
||||
var watchingValue bool
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
onWatchingChanged: func(w bool) { watchingValue = w },
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-watch2-1234567", Status: StatusCompleted}
|
||||
pr.Track(task)
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
if !watchingValue {
|
||||
t.Error("expected watchingChanged called with true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_BatchFlush(t *testing.T) {
|
||||
batcher := &mockBatchReporter{
|
||||
batchResp: &agent.BatchStatusResponse{
|
||||
Results: []agent.StatusResponse{{}, {}},
|
||||
},
|
||||
}
|
||||
|
||||
pr := &ProgressReporter{
|
||||
reporter: batcher,
|
||||
interval: time.Second,
|
||||
isWatching: func() bool { return true },
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
pr.Track(&Task{ID: "task-batch1-1234567", Status: StatusDownloading})
|
||||
pr.Track(&Task{ID: "task-batch2-1234567", Status: StatusDownloading})
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
batcher.mu.Lock()
|
||||
defer batcher.mu.Unlock()
|
||||
|
||||
if len(batcher.batchCalls) != 1 {
|
||||
t.Fatalf("expected 1 batch call, got %d", len(batcher.batchCalls))
|
||||
}
|
||||
if len(batcher.batchCalls[0]) != 2 {
|
||||
t.Errorf("expected 2 updates in batch, got %d", len(batcher.batchCalls[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_RunStopsOnCancel(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: 50 * time.Millisecond,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := pr.Run(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Run should return nil on context cancel, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_ReportFinal(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-final-12345678", Status: StatusCompleted}
|
||||
pr.Track(task)
|
||||
|
||||
pr.ReportFinal(context.Background(), task)
|
||||
|
||||
reporter.mu.Lock()
|
||||
defer reporter.mu.Unlock()
|
||||
if len(reporter.calls) != 1 {
|
||||
t.Fatalf("expected 1 final report, got %d", len(reporter.calls))
|
||||
}
|
||||
|
||||
// Should be untracked after final report
|
||||
pr.mu.Lock()
|
||||
if _, ok := pr.latest["task-final-12345678"]; ok {
|
||||
t.Error("task should be untracked after ReportFinal")
|
||||
}
|
||||
pr.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestProgressReporter_SetHandlers(t *testing.T) {
|
||||
pr := &ProgressReporter{
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
}
|
||||
|
||||
pr.SetCancelHandler(func(id string) {})
|
||||
pr.SetPauseHandler(func(id string) {})
|
||||
pr.SetDeleteFilesHandler(func(id string) {})
|
||||
pr.SetStreamRequestedHandler(func(id string) {})
|
||||
pr.SetWatchingFunc(func() bool { return true })
|
||||
pr.SetWatchingChangedHandler(func(w bool) {})
|
||||
|
||||
if pr.onCancel == nil || pr.onPause == nil || pr.onDeleteFiles == nil ||
|
||||
pr.onStreamRequested == nil || pr.isWatching == nil || pr.onWatchingChanged == nil {
|
||||
t.Error("expected all handlers to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_ControlCheckDue(t *testing.T) {
|
||||
reporter := &mockStatusReporter{}
|
||||
pr := &ProgressReporter{
|
||||
reporter: reporter,
|
||||
interval: time.Second,
|
||||
isWatching: func() bool { return false },
|
||||
latest: make(map[string]*Task),
|
||||
lastReported: make(map[string]TaskStatus),
|
||||
lastCheckAt: time.Now().Add(-31 * time.Second), // 31s ago - due for control check
|
||||
}
|
||||
|
||||
task := &Task{ID: "task-ctrl-123456789", Status: StatusDownloading}
|
||||
pr.Track(task)
|
||||
pr.mu.Lock()
|
||||
pr.lastReported["task-ctrl-123456789"] = StatusDownloading
|
||||
pr.mu.Unlock()
|
||||
|
||||
pr.flush(context.Background())
|
||||
|
||||
reporter.mu.Lock()
|
||||
defer reporter.mu.Unlock()
|
||||
if len(reporter.calls) != 1 {
|
||||
t.Errorf("expected 1 report for control check, got %d", len(reporter.calls))
|
||||
}
|
||||
}
|
||||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/anacrolix/torrent"
|
||||
|
|
@ -24,11 +25,12 @@ type fileProvider interface {
|
|||
|
||||
// StreamServer serves a torrent file over HTTP with Range request support.
|
||||
type StreamServer struct {
|
||||
provider fileProvider
|
||||
server *http.Server
|
||||
port int
|
||||
url string
|
||||
upnpMapping *UPnPMapping
|
||||
provider fileProvider
|
||||
server *http.Server
|
||||
port int
|
||||
url string
|
||||
upnpMapping *UPnPMapping
|
||||
lastActivity atomic.Int64 // UnixNano of last HTTP request
|
||||
}
|
||||
|
||||
// NewStreamServer creates a new HTTP server for streaming via StreamEngine.
|
||||
|
|
@ -93,11 +95,38 @@ func NewStreamServerFromDisk(filePath string, port int) *StreamServer {
|
|||
}
|
||||
}
|
||||
|
||||
// Start begins serving the file on all interfaces. Returns the best reachable URL:
|
||||
// 1. UPnP public IP (accessible from anywhere on the internet)
|
||||
// 2. Tailscale IP (accessible from any device in the tailnet)
|
||||
// 3. LAN IP (accessible from local network)
|
||||
// FindVideoFile scans a directory (recursively) for the largest video file.
|
||||
// Returns empty string if no video file found.
|
||||
func FindVideoFile(dir string) string {
|
||||
var best string
|
||||
var bestSize int64
|
||||
|
||||
filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil || d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
ext := strings.ToLower(filepath.Ext(d.Name()))
|
||||
if !VideoExts[ext] {
|
||||
return nil
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if info.Size() > bestSize {
|
||||
best = path
|
||||
bestSize = info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return best
|
||||
}
|
||||
|
||||
// Start begins serving the file on all interfaces. Returns the best reachable URL.
|
||||
// The file is served as-is — the user's media player (VLC, mpv, etc.) handles decoding.
|
||||
func (ss *StreamServer) Start(ctx context.Context) (string, error) {
|
||||
ss.lastActivity.Store(time.Now().UnixNano())
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/stream", ss.handler)
|
||||
|
||||
|
|
@ -107,19 +136,9 @@ func (ss *StreamServer) Start(ctx context.Context) (string, error) {
|
|||
return "", fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
// Extract actual port (important when port=0)
|
||||
ss.port = listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// Try UPnP for public internet access (like Plex Remote Access)
|
||||
if mapping, upnpErr := setupUPnP(ss.port); upnpErr == nil {
|
||||
ss.upnpMapping = mapping
|
||||
ss.url = fmt.Sprintf("http://%s:%d/stream", mapping.ExternalIP, mapping.ExternalPort)
|
||||
log.Printf("stream: UPnP mapped %s:%d -> local:%d", mapping.ExternalIP, mapping.ExternalPort, ss.port)
|
||||
} else {
|
||||
// Fallback: Tailscale IP > LAN IP > 127.0.0.1
|
||||
ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port)
|
||||
log.Printf("stream: UPnP unavailable (%v), using %s", upnpErr, ss.url)
|
||||
}
|
||||
ss.url = fmt.Sprintf("http://%s:%d/stream", reachableIP(), ss.port)
|
||||
log.Printf("stream: serving on %s", ss.url)
|
||||
|
||||
ss.server = &http.Server{
|
||||
Handler: mux,
|
||||
|
|
@ -141,6 +160,15 @@ func (ss *StreamServer) URL() string { return ss.url }
|
|||
// Port returns the bound port.
|
||||
func (ss *StreamServer) Port() int { return ss.port }
|
||||
|
||||
// IdleSince returns how long since the last HTTP request was received.
|
||||
func (ss *StreamServer) IdleSince() time.Duration {
|
||||
last := ss.lastActivity.Load()
|
||||
if last == 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Since(time.Unix(0, last))
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the HTTP server and removes the UPnP port mapping.
|
||||
func (ss *StreamServer) Shutdown(ctx context.Context) error {
|
||||
ss.upnpMapping.Remove()
|
||||
|
|
@ -151,6 +179,8 @@ func (ss *StreamServer) Shutdown(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) {
|
||||
ss.lastActivity.Store(time.Now().UnixNano())
|
||||
|
||||
// CORS headers — only when browser sends Origin (HTTPS site → localhost)
|
||||
if origin := r.Header.Get("Origin"); origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
|
|
|||
|
|
@ -78,6 +78,12 @@ func ExtractMediaInfo(ctx context.Context, ffprobePath, filePath string) (*Media
|
|||
return nil, fmt.Errorf("ffprobe JSON parse failed: %w", err)
|
||||
}
|
||||
|
||||
return parseFFprobeOutput(data)
|
||||
}
|
||||
|
||||
// parseFFprobeOutput converts parsed ffprobe JSON into MediaInfo.
|
||||
// Separated from ExtractMediaInfo so it can be tested without running ffprobe.
|
||||
func parseFFprobeOutput(data ffprobeOutput) (*MediaInfo, error) {
|
||||
if len(data.Streams) == 0 {
|
||||
return nil, fmt.Errorf("ffprobe returned no streams")
|
||||
}
|
||||
|
|
|
|||
430
internal/library/mediainfo/ffprobe_test.go
Normal file
430
internal/library/mediainfo/ffprobe_test.go
Normal file
|
|
@ -0,0 +1,430 @@
|
|||
package mediainfo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want float64
|
||||
}{
|
||||
{"", 0},
|
||||
{"0", 0},
|
||||
{"-5", 0},
|
||||
{"invalid", 0},
|
||||
{"7423.500000", 7423.5},
|
||||
{"120.123456", 120.123},
|
||||
{"3600", 3600},
|
||||
{"0.001", 0.001},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := parseDuration(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("parseDuration(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagValue(t *testing.T) {
|
||||
tags := map[string]string{
|
||||
"language": "eng",
|
||||
"title": "Main Audio",
|
||||
"HANDLER": "VideoHandler",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
key string
|
||||
want string
|
||||
}{
|
||||
{"language", "eng"},
|
||||
{"title", "Main Audio"},
|
||||
{"handler", "VideoHandler"},
|
||||
{"missing", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
got := tagValue(tags, tt.key)
|
||||
if got != tt.want {
|
||||
t.Errorf("tagValue(tags, %q) = %q, want %q", tt.key, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagValueNil(t *testing.T) {
|
||||
got := tagValue(nil, "language")
|
||||
if got != "" {
|
||||
t.Errorf("tagValue(nil, language) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsAny(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
subs []string
|
||||
want bool
|
||||
}{
|
||||
{"yuv420p10le", []string{"10le", "10be", "p010"}, true},
|
||||
{"yuv420p12be", []string{"10le", "10be", "p010"}, false},
|
||||
{"yuv420p12be", []string{"12le", "12be"}, true},
|
||||
{"yuv420p", []string{"10le", "10be"}, false},
|
||||
{"", []string{"any"}, false},
|
||||
{"something", []string{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := containsAny(tt.s, tt.subs...)
|
||||
if got != tt.want {
|
||||
t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.subs, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_BasicH264(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Format: ffprobeFormat{Duration: "7423.5"},
|
||||
Streams: []ffprobeStream{
|
||||
{
|
||||
CodecType: "video",
|
||||
CodecName: "h264",
|
||||
Profile: "High",
|
||||
Width: 1920,
|
||||
Height: 1080,
|
||||
RFrameRate: "24000/1001",
|
||||
},
|
||||
{
|
||||
CodecType: "audio",
|
||||
CodecName: "aac",
|
||||
Channels: 2,
|
||||
Tags: map[string]string{"language": "eng"},
|
||||
Disposition: map[string]int{"default": 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseFFprobeOutput: %v", err)
|
||||
}
|
||||
if mi.Video == nil {
|
||||
t.Fatal("expected video info")
|
||||
}
|
||||
if mi.Video.Codec != "h264" {
|
||||
t.Errorf("codec = %q, want h264", mi.Video.Codec)
|
||||
}
|
||||
if mi.Video.Width != 1920 || mi.Video.Height != 1080 {
|
||||
t.Errorf("dimensions = %dx%d, want 1920x1080", mi.Video.Width, mi.Video.Height)
|
||||
}
|
||||
if mi.Video.Profile != "High" {
|
||||
t.Errorf("profile = %q, want High", mi.Video.Profile)
|
||||
}
|
||||
if mi.Video.Duration != 7423.5 {
|
||||
t.Errorf("duration = %v, want 7423.5", mi.Video.Duration)
|
||||
}
|
||||
if mi.Video.FrameRate < 23.975 || mi.Video.FrameRate > 23.977 {
|
||||
t.Errorf("frameRate = %v, want ~23.976", mi.Video.FrameRate)
|
||||
}
|
||||
if len(mi.Audio) != 1 {
|
||||
t.Fatalf("audio tracks = %d, want 1", len(mi.Audio))
|
||||
}
|
||||
if mi.Audio[0].Lang != "en" {
|
||||
t.Errorf("audio lang = %q, want en", mi.Audio[0].Lang)
|
||||
}
|
||||
if !mi.Audio[0].Default {
|
||||
t.Error("expected default audio track")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_HEVC_HDR10(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Format: ffprobeFormat{Duration: "3600"},
|
||||
Streams: []ffprobeStream{
|
||||
{
|
||||
CodecType: "video",
|
||||
CodecName: "hevc",
|
||||
Width: 3840,
|
||||
Height: 2160,
|
||||
BitsPerRaw: "10",
|
||||
ColorSpace: "bt2020nc",
|
||||
ColorTransfer: "smpte2084",
|
||||
RFrameRate: "24/1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.HDR != "HDR10" {
|
||||
t.Errorf("hdr = %q, want HDR10", mi.Video.HDR)
|
||||
}
|
||||
if mi.Video.BitDepth != 10 {
|
||||
t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_DolbyVisionWithHDR10(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{
|
||||
CodecType: "video",
|
||||
CodecName: "hevc",
|
||||
Width: 3840,
|
||||
Height: 2160,
|
||||
ColorSpace: "bt2020nc",
|
||||
ColorTransfer: "smpte2084",
|
||||
SideDataList: []sideData{{SideDataType: "DOVI configuration record"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.HDR != "DV+HDR10" {
|
||||
t.Errorf("hdr = %q, want DV+HDR10", mi.Video.HDR)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_DolbyVisionOnly(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{
|
||||
CodecType: "video",
|
||||
CodecName: "hevc",
|
||||
Width: 3840,
|
||||
Height: 2160,
|
||||
SideDataList: []sideData{{SideDataType: "DOVI configuration record"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.HDR != "DV" {
|
||||
t.Errorf("hdr = %q, want DV", mi.Video.HDR)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_HLG(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{
|
||||
CodecType: "video",
|
||||
CodecName: "hevc",
|
||||
Width: 3840,
|
||||
Height: 2160,
|
||||
ColorSpace: "bt2020nc",
|
||||
ColorTransfer: "arib-std-b67",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.HDR != "HLG" {
|
||||
t.Errorf("hdr = %q, want HLG", mi.Video.HDR)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_MultiAudioAndSubtitles(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Format: ffprobeFormat{Duration: "5400"},
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080},
|
||||
{
|
||||
CodecType: "audio", CodecName: "ac3", Channels: 6,
|
||||
Tags: map[string]string{"language": "eng", "title": "English 5.1"},
|
||||
Disposition: map[string]int{"default": 1},
|
||||
},
|
||||
{
|
||||
CodecType: "audio", CodecName: "aac", Channels: 2,
|
||||
Tags: map[string]string{"language": "spa"},
|
||||
},
|
||||
{
|
||||
CodecType: "subtitle", CodecName: "subrip",
|
||||
Tags: map[string]string{"language": "eng"},
|
||||
},
|
||||
{
|
||||
CodecType: "subtitle", CodecName: "ass",
|
||||
Tags: map[string]string{"language": "spa"},
|
||||
Disposition: map[string]int{"forced": 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(mi.Audio) != 2 {
|
||||
t.Fatalf("audio tracks = %d, want 2", len(mi.Audio))
|
||||
}
|
||||
if mi.Audio[0].Title != "English 5.1" {
|
||||
t.Errorf("audio[0].title = %q", mi.Audio[0].Title)
|
||||
}
|
||||
if len(mi.Subtitles) != 2 {
|
||||
t.Fatalf("subtitle tracks = %d, want 2", len(mi.Subtitles))
|
||||
}
|
||||
if !mi.Subtitles[1].Forced {
|
||||
t.Error("expected subtitle[1] to be forced")
|
||||
}
|
||||
if len(mi.Languages) != 2 {
|
||||
t.Errorf("languages = %v, want 2 entries", mi.Languages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_BitDepthFromPixFmt(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p10le"},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.BitDepth != 10 {
|
||||
t.Errorf("bitDepth = %d, want 10", mi.Video.BitDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_12BitFromPixFmt(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "hevc", Width: 1920, Height: 1080, PixFmt: "yuv420p12be"},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.BitDepth != 12 {
|
||||
t.Errorf("bitDepth = %d, want 12", mi.Video.BitDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_DurationFromStreamFallback(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Format: ffprobeFormat{Duration: ""},
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "h264", Width: 1280, Height: 720, Duration: "1800.5"},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.Duration != 1800.5 {
|
||||
t.Errorf("duration = %v, want 1800.5", mi.Video.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_NoStreams(t *testing.T) {
|
||||
data := ffprobeOutput{}
|
||||
_, err := parseFFprobeOutput(data)
|
||||
if err == nil {
|
||||
t.Error("expected error for no streams")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_OnlyFirstVideoStream(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080},
|
||||
{CodecType: "video", CodecName: "mjpeg", Width: 320, Height: 240}, // cover art
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.Codec != "h264" {
|
||||
t.Errorf("should use first video stream, got codec %q", mi.Video.Codec)
|
||||
}
|
||||
if mi.Video.Width != 1920 {
|
||||
t.Errorf("width = %d, should be from first video stream", mi.Video.Width)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_SMPTE2084_WithoutBT2020(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "smpte2084"},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.HDR != "HDR10" {
|
||||
t.Errorf("hdr = %q, want HDR10", mi.Video.HDR)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_AribWithoutBT2020(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "hevc", Width: 3840, Height: 2160, ColorTransfer: "arib-std-b67", ColorSpace: "other"},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.HDR != "HLG" {
|
||||
t.Errorf("hdr = %q, want HLG", mi.Video.HDR)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_AudioOnly(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "audio", CodecName: "flac", Channels: 2, Tags: map[string]string{"language": "eng"}},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video != nil {
|
||||
t.Error("expected no video info for audio-only")
|
||||
}
|
||||
if len(mi.Audio) != 1 {
|
||||
t.Errorf("audio tracks = %d, want 1", len(mi.Audio))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFFprobeOutput_FrameRateNoSlash(t *testing.T) {
|
||||
data := ffprobeOutput{
|
||||
Streams: []ffprobeStream{
|
||||
{CodecType: "video", CodecName: "h264", Width: 1920, Height: 1080, RFrameRate: "30"},
|
||||
},
|
||||
}
|
||||
|
||||
mi, err := parseFFprobeOutput(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mi.Video.FrameRate != 0 {
|
||||
t.Errorf("frameRate = %v, want 0 (no slash)", mi.Video.FrameRate)
|
||||
}
|
||||
}
|
||||
93
internal/library/scanner_test.go
Normal file
93
internal/library/scanner_test.go
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
package library
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDiscoverFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create video files (need to be >= 100MB to pass size check)
|
||||
largeContent := make([]byte, 101*1024*1024)
|
||||
|
||||
videoFiles := []string{"movie.mkv", "show.mp4", "clip.avi"}
|
||||
for _, name := range videoFiles {
|
||||
path := filepath.Join(dir, name)
|
||||
if err := os.WriteFile(path, largeContent, 0o644); err != nil {
|
||||
t.Fatalf("write %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Non-video files (should be excluded)
|
||||
nonVideo := []string{"readme.txt", "cover.jpg", "subs.srt"}
|
||||
for _, name := range nonVideo {
|
||||
if err := os.WriteFile(filepath.Join(dir, name), largeContent, 0o644); err != nil {
|
||||
t.Fatalf("write %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Small video file (should be excluded, < 100MB)
|
||||
if err := os.WriteFile(filepath.Join(dir, "small.mkv"), []byte("small"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Excluded pattern (sample)
|
||||
sampleDir := filepath.Join(dir, "sample")
|
||||
os.MkdirAll(sampleDir, 0o755)
|
||||
if err := os.WriteFile(filepath.Join(sampleDir, "sample.mkv"), largeContent, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
files, err := discoverFiles(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("discoverFiles: %v", err)
|
||||
}
|
||||
|
||||
if len(files) != 3 {
|
||||
t.Errorf("expected 3 files, got %d: %v", len(files), files)
|
||||
}
|
||||
|
||||
// Check that all returned files are video extensions
|
||||
for _, f := range files {
|
||||
ext := filepath.Ext(f)
|
||||
if ext != ".mkv" && ext != ".mp4" && ext != ".avi" {
|
||||
t.Errorf("unexpected extension: %s", ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverFilesEmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
files, err := discoverFiles(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("discoverFiles: %v", err)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Errorf("expected 0 files, got %d", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverFilesExcludePatterns(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
largeContent := make([]byte, 101*1024*1024)
|
||||
|
||||
excludeDirs := []string{"trailer", "featurette", "extras", "bonus"}
|
||||
for _, name := range excludeDirs {
|
||||
sub := filepath.Join(dir, name)
|
||||
os.MkdirAll(sub, 0o755)
|
||||
if err := os.WriteFile(filepath.Join(sub, "video.mkv"), largeContent, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
files, err := discoverFiles(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Errorf("expected 0 files (all excluded), got %d: %v", len(files), files)
|
||||
}
|
||||
}
|
||||
108
internal/library/sync_test.go
Normal file
108
internal/library/sync_test.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package library
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/library/mediainfo"
|
||||
)
|
||||
|
||||
func TestBuildSyncItems(t *testing.T) {
|
||||
cache := &LibraryCache{
|
||||
Items: []LibraryItem{
|
||||
{
|
||||
FilePath: "/media/movies/Inception.mkv",
|
||||
FileName: "Inception.2010.1080p.mkv",
|
||||
FileSize: 5000000000,
|
||||
Title: "Inception",
|
||||
Year: "2010",
|
||||
MediaInfo: &mediainfo.MediaInfo{
|
||||
Video: &mediainfo.VideoInfo{
|
||||
Codec: "hevc",
|
||||
Width: 1920,
|
||||
Height: 1080,
|
||||
BitDepth: 10,
|
||||
HDR: "HDR10",
|
||||
},
|
||||
Audio: []mediainfo.AudioTrack{
|
||||
{Lang: "en", Codec: "ac3", Channels: 6, Default: true},
|
||||
{Lang: "es", Codec: "aac", Channels: 2},
|
||||
},
|
||||
Subtitles: []mediainfo.SubtitleTrack{
|
||||
{Lang: "en", Codec: "subrip"},
|
||||
{Lang: "es", Codec: "subrip"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FilePath: "/media/shows/Breaking.Bad.S01E01.mkv",
|
||||
FileName: "Breaking.Bad.S01E01.mkv",
|
||||
FileSize: 1000000000,
|
||||
Title: "Breaking Bad",
|
||||
Season: 1,
|
||||
Episode: 1,
|
||||
},
|
||||
{
|
||||
// Item with scan error — should be skipped
|
||||
FilePath: "/media/bad.mkv",
|
||||
FileName: "bad.mkv",
|
||||
ScanError: "ffprobe failed",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
items := BuildSyncItems(cache)
|
||||
|
||||
if len(items) != 2 {
|
||||
t.Fatalf("expected 2 items (1 skipped), got %d", len(items))
|
||||
}
|
||||
|
||||
// First item: movie with full media info
|
||||
movie := items[0]
|
||||
if movie.Title != "Inception" {
|
||||
t.Errorf("title = %q, want Inception", movie.Title)
|
||||
}
|
||||
if movie.ContentType != "movie" {
|
||||
t.Errorf("contentType = %q, want movie", movie.ContentType)
|
||||
}
|
||||
if movie.Resolution != "1080p" {
|
||||
t.Errorf("resolution = %q, want 1080p", movie.Resolution)
|
||||
}
|
||||
if movie.VideoCodec != "hevc" {
|
||||
t.Errorf("videoCodec = %q, want hevc", movie.VideoCodec)
|
||||
}
|
||||
if movie.HDR != "HDR10" {
|
||||
t.Errorf("hdr = %q, want HDR10", movie.HDR)
|
||||
}
|
||||
if movie.AudioCodec != "ac3" {
|
||||
t.Errorf("audioCodec = %q, want ac3", movie.AudioCodec)
|
||||
}
|
||||
if movie.AudioChannels != 6 {
|
||||
t.Errorf("audioChannels = %d, want 6", movie.AudioChannels)
|
||||
}
|
||||
if len(movie.AudioLanguages) != 2 {
|
||||
t.Errorf("audioLanguages count = %d, want 2", len(movie.AudioLanguages))
|
||||
}
|
||||
if len(movie.SubtitleLanguages) != 2 {
|
||||
t.Errorf("subtitleLanguages count = %d, want 2", len(movie.SubtitleLanguages))
|
||||
}
|
||||
|
||||
// Second item: show without media info
|
||||
show := items[1]
|
||||
if show.ContentType != "show" {
|
||||
t.Errorf("contentType = %q, want show", show.ContentType)
|
||||
}
|
||||
if show.Season != 1 || show.Episode != 1 {
|
||||
t.Errorf("season/episode = %d/%d, want 1/1", show.Season, show.Episode)
|
||||
}
|
||||
if show.Resolution != "" {
|
||||
t.Errorf("resolution should be empty, got %q", show.Resolution)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSyncItemsEmpty(t *testing.T) {
|
||||
cache := &LibraryCache{Items: nil}
|
||||
items := BuildSyncItems(cache)
|
||||
if len(items) != 0 {
|
||||
t.Errorf("expected 0 items, got %d", len(items))
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,8 @@ package mediaserver
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
|
@ -69,6 +71,96 @@ func TestJellyfinParsing(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPlexTokenFromPrefs(t *testing.T) {
|
||||
t.Run("valid prefs", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
prefsPath := filepath.Join(dir, "Preferences.xml")
|
||||
xml := `<?xml version="1.0" encoding="utf-8"?>
|
||||
<Preferences PlexOnlineToken="my-secret-token" OldestPreviousVersion="1.0"/>`
|
||||
os.WriteFile(prefsPath, []byte(xml), 0o644)
|
||||
|
||||
token := plexTokenFromPrefs(prefsPath)
|
||||
if token != "my-secret-token" {
|
||||
t.Errorf("token = %q, want my-secret-token", token)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no token attr", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
prefsPath := filepath.Join(dir, "Preferences.xml")
|
||||
xml := `<?xml version="1.0"?><Preferences/>`
|
||||
os.WriteFile(prefsPath, []byte(xml), 0o644)
|
||||
|
||||
token := plexTokenFromPrefs(prefsPath)
|
||||
if token != "" {
|
||||
t.Errorf("token = %q, want empty", token)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("file not found", func(t *testing.T) {
|
||||
token := plexTokenFromPrefs("/nonexistent/Preferences.xml")
|
||||
if token != "" {
|
||||
t.Errorf("token = %q, want empty", token)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid xml", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
prefsPath := filepath.Join(dir, "Preferences.xml")
|
||||
os.WriteFile(prefsPath, []byte("not xml at all"), 0o644)
|
||||
|
||||
token := plexTokenFromPrefs(prefsPath)
|
||||
if token != "" {
|
||||
t.Errorf("token = %q, want empty", token)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParsePlexSectionsMultipleLocations(t *testing.T) {
|
||||
body := `{
|
||||
"MediaContainer": {
|
||||
"Directory": [
|
||||
{
|
||||
"title": "Movies",
|
||||
"Location": [
|
||||
{"path": "/media/movies"},
|
||||
{"path": "/media/movies2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
paths := parsePlexSections([]byte(body))
|
||||
if len(paths) != 2 {
|
||||
t.Fatalf("expected 2 paths, got %d", len(paths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePlexSectionsEmptyPath(t *testing.T) {
|
||||
body := `{
|
||||
"MediaContainer": {
|
||||
"Directory": [
|
||||
{
|
||||
"Location": [{"path": ""}, {"path": "/valid"}]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
paths := parsePlexSections([]byte(body))
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("expected 1 path (empty filtered), got %d: %v", len(paths), paths)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommonMediaDirs(t *testing.T) {
|
||||
dirs := commonMediaDirs()
|
||||
if len(dirs) == 0 {
|
||||
t.Error("expected at least some common media dirs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParentDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
|||
47
internal/sentry/sentry_test.go
Normal file
47
internal/sentry/sentry_test.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
package sentry
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEnvironment(t *testing.T) {
|
||||
tests := []struct {
|
||||
version string
|
||||
want string
|
||||
}{
|
||||
{"", "development"},
|
||||
{"dev", "development"},
|
||||
{"0.1.0-dev", "development"},
|
||||
{"1.0.0", "production"},
|
||||
{"0.3.5", "production"},
|
||||
{"2.0.0-beta", "production"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.version, func(t *testing.T) {
|
||||
got := environment(tt.version)
|
||||
if got != tt.want {
|
||||
t.Errorf("environment(%q) = %q, want %q", tt.version, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitNoOp(t *testing.T) {
|
||||
// With empty dsn (default in tests), Init should be a no-op
|
||||
Init("1.0.0")
|
||||
// Should not panic
|
||||
}
|
||||
|
||||
func TestCloseNoOp(t *testing.T) {
|
||||
// Close should be safe to call without Init
|
||||
Close()
|
||||
}
|
||||
|
||||
func TestCaptureErrorNil(t *testing.T) {
|
||||
// Should not panic with nil error
|
||||
CaptureError(nil, "test")
|
||||
}
|
||||
|
||||
func TestSetUser(t *testing.T) {
|
||||
// Should not panic without initialization
|
||||
SetUser("agent-123")
|
||||
}
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFormatSize(t *testing.T) {
|
||||
|
|
@ -127,6 +129,10 @@ func TestQualityIndicator(t *testing.T) {
|
|||
{"medium", intPtr(60), "🟡"},
|
||||
{"high", intPtr(80), "🟢"},
|
||||
{"perfect", intPtr(100), "🟢"},
|
||||
{"boundary_40", intPtr(40), "🟡"},
|
||||
{"boundary_70", intPtr(70), "🟢"},
|
||||
{"boundary_39", intPtr(39), "🔴"},
|
||||
{"zero", intPtr(0), "🔴"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -139,6 +145,52 @@ func TestQualityIndicator(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSeedHealthIndicator(t *testing.T) {
|
||||
tests := []struct {
|
||||
seeds int
|
||||
want string
|
||||
}{
|
||||
{0, "🔴"},
|
||||
{5, "🔴"},
|
||||
{9, "🔴"},
|
||||
{10, "🟡"},
|
||||
{50, "🟡"},
|
||||
{100, "🟡"},
|
||||
{101, "🟢"},
|
||||
{1000, "🟢"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("seeds_%d", tt.seeds), func(t *testing.T) {
|
||||
got := SeedHealthIndicator(tt.seeds)
|
||||
if got != tt.want {
|
||||
t.Errorf("SeedHealthIndicator(%d) = %q, want %q", tt.seeds, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatRating(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input *string
|
||||
want string
|
||||
}{
|
||||
{"nil", nil, "-"},
|
||||
{"value", strPtr("8.5"), "8.5"},
|
||||
{"empty", strPtr(""), ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FormatRating(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("FormatRating(%v) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringOrDash(t *testing.T) {
|
||||
s := "hello"
|
||||
if got := StringOrDash(&s); got != "hello" {
|
||||
|
|
@ -150,16 +202,160 @@ func TestStringOrDash(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFormatContentType(t *testing.T) {
|
||||
if got := FormatContentType("movie"); got != "Movie" {
|
||||
t.Errorf("FormatContentType(movie) = %q, want Movie", got)
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"movie", "Movie"},
|
||||
{"Movie", "Movie"},
|
||||
{"MOVIE", "Movie"},
|
||||
{"show", "Show"},
|
||||
{"Show", "Show"},
|
||||
{"other", "other"},
|
||||
{"", ""},
|
||||
}
|
||||
if got := FormatContentType("show"); got != "Show" {
|
||||
t.Errorf("FormatContentType(show) = %q, want Show", got)
|
||||
}
|
||||
if got := FormatContentType("other"); got != "other" {
|
||||
t.Errorf("FormatContentType(other) = %q, want other", got)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := FormatContentType(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("FormatContentType(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
func intPtr(v int) *int { return &v }
|
||||
func TestFormatLanguages(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
want string
|
||||
}{
|
||||
{"nil", nil, "-"},
|
||||
{"empty", []string{}, "-"},
|
||||
{"single", []string{"en"}, "en"},
|
||||
{"multiple", []string{"en", "es", "fr"}, "en, es, fr"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FormatLanguages(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("FormatLanguages(%v) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSeedRatio(t *testing.T) {
|
||||
tests := []struct {
|
||||
seeders int
|
||||
leechers int
|
||||
want string
|
||||
}{
|
||||
{0, 0, "0:0"},
|
||||
{10, 0, "10:0"},
|
||||
{100, 10, "10:1"},
|
||||
{50, 50, "1:1"},
|
||||
{1, 3, "0:1"},
|
||||
{150, 10, "15:1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%d_%d", tt.seeders, tt.leechers), func(t *testing.T) {
|
||||
got := FormatSeedRatio(tt.seeders, tt.leechers)
|
||||
if got != tt.want {
|
||||
t.Errorf("FormatSeedRatio(%d, %d) = %q, want %q", tt.seeders, tt.leechers, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTimeAgo(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"invalid", "not-a-date", "not-a-date"},
|
||||
{"just_now", now.Add(-10 * time.Second).Format(time.RFC3339), "just now"},
|
||||
{"minutes", now.Add(-5 * time.Minute).Format(time.RFC3339), "5m ago"},
|
||||
{"hours", now.Add(-3 * time.Hour).Format(time.RFC3339), "3h ago"},
|
||||
{"days", now.Add(-7 * 24 * time.Hour).Format(time.RFC3339), "7d ago"},
|
||||
{"months", now.Add(-60 * 24 * time.Hour).Format(time.RFC3339), "2mo ago"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FormatTimeAgo(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("FormatTimeAgo(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatNumberExtended(t *testing.T) {
|
||||
tests := []struct {
|
||||
input int
|
||||
want string
|
||||
}{
|
||||
{-1000, "-1,000"},
|
||||
{-5, "-5"},
|
||||
{10000, "10,000"},
|
||||
{100000, "100,000"},
|
||||
{1000000, "1,000,000"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
got := FormatNumber(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("FormatNumber(%d) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateStringEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{"maxLen_1", "hello", 1, "h"},
|
||||
{"maxLen_3", "hello", 3, "hel"},
|
||||
{"empty", "", 5, ""},
|
||||
{"unicode", "こんにちは世界", 5, "こん..."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := TruncateString(tt.input, tt.maxLen)
|
||||
if got != tt.want {
|
||||
t.Errorf("TruncateString(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPtr(t *testing.T) {
|
||||
v := 42
|
||||
p := Ptr(v)
|
||||
if *p != 42 {
|
||||
t.Errorf("Ptr(42) = %d, want 42", *p)
|
||||
}
|
||||
|
||||
s := "hello"
|
||||
sp := Ptr(s)
|
||||
if *sp != "hello" {
|
||||
t.Errorf("Ptr(hello) = %q, want hello", *sp)
|
||||
}
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
func intPtr(v int) *int { return &v }
|
||||
func strPtr(v string) *string { return &v }
|
||||
|
|
|
|||
122
internal/ui/table_test.go
Normal file
122
internal/ui/table_test.go
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
package ui
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tc "github.com/torrentclaw/go-client"
|
||||
)
|
||||
|
||||
func TestNewCleanTable(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
tbl := newCleanTable(&buf)
|
||||
tbl.Header([]string{"A", "B"})
|
||||
tbl.Append([]string{"1", "2"})
|
||||
tbl.Render()
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "A") || !strings.Contains(out, "B") {
|
||||
t.Errorf("expected headers in output, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "1") || !strings.Contains(out, "2") {
|
||||
t.Errorf("expected row data in output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintSearchResultEntry(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
year := 2010
|
||||
rating := "8.8"
|
||||
quality := "1080p"
|
||||
codec := "x265"
|
||||
size := int64(4294967296)
|
||||
score := 85
|
||||
|
||||
r := tc.SearchResult{
|
||||
Title: "Inception",
|
||||
Year: &year,
|
||||
ContentType: "movie",
|
||||
RatingIMDb: &rating,
|
||||
Genres: []string{"Sci-Fi", "Action"},
|
||||
Torrents: []tc.TorrentInfo{
|
||||
{
|
||||
Quality: &quality,
|
||||
SizeBytes: &size,
|
||||
Seeders: 150,
|
||||
Leechers: 10,
|
||||
Source: "YTS",
|
||||
Codec: &codec,
|
||||
Languages: []string{"en", "es"},
|
||||
QualityScore: &score,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
printSearchResultEntry(&buf, r)
|
||||
out := buf.String()
|
||||
|
||||
if !strings.Contains(out, "Inception") {
|
||||
t.Error("expected title in output")
|
||||
}
|
||||
if !strings.Contains(out, "2010") {
|
||||
t.Error("expected year in output")
|
||||
}
|
||||
if !strings.Contains(out, "8.8") {
|
||||
t.Error("expected rating in output")
|
||||
}
|
||||
if !strings.Contains(out, "1080p") {
|
||||
t.Error("expected quality in output")
|
||||
}
|
||||
if !strings.Contains(out, "YTS") {
|
||||
t.Error("expected source in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintSearchResultEntryNoTorrents(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
year := 2020
|
||||
|
||||
r := tc.SearchResult{
|
||||
Title: "No Torrents Movie",
|
||||
Year: &year,
|
||||
ContentType: "movie",
|
||||
Torrents: nil,
|
||||
}
|
||||
|
||||
printSearchResultEntry(&buf, r)
|
||||
out := buf.String()
|
||||
|
||||
if !strings.Contains(out, "No Torrents Movie") {
|
||||
t.Error("expected title in output")
|
||||
}
|
||||
if !strings.Contains(out, "No torrents available") {
|
||||
t.Error("expected no-torrents message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintSearchResultEntryNilFields(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
r := tc.SearchResult{
|
||||
Title: "Minimal",
|
||||
ContentType: "movie",
|
||||
Torrents: []tc.TorrentInfo{
|
||||
{
|
||||
Seeders: 5,
|
||||
Leechers: 3,
|
||||
Source: "RARBG",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
printSearchResultEntry(&buf, r)
|
||||
out := buf.String()
|
||||
|
||||
if !strings.Contains(out, "Minimal") {
|
||||
t.Error("expected title in output")
|
||||
}
|
||||
if !strings.Contains(out, "-") {
|
||||
t.Error("expected dash for nil fields")
|
||||
}
|
||||
}
|
||||
75
internal/upgrade/cache.go
Normal file
75
internal/upgrade/cache.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/config"
|
||||
)
|
||||
|
||||
const cacheTTL = 1 * time.Hour
|
||||
|
||||
// versionCache is the on-disk structure for cached version checks.
|
||||
type versionCache struct {
|
||||
Version string `json:"version"`
|
||||
CheckedAt time.Time `json:"checkedAt"`
|
||||
}
|
||||
|
||||
// cacheFilePath returns the path to the version cache file.
|
||||
func cacheFilePath() string {
|
||||
return filepath.Join(config.DataDir(), "latest-version.json")
|
||||
}
|
||||
|
||||
// ReadCachedVersion returns the cached latest version if it's fresh (< cacheTTL).
|
||||
// Returns empty string if cache is missing, stale, or corrupt.
|
||||
func ReadCachedVersion() string {
|
||||
data, err := os.ReadFile(cacheFilePath())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var c versionCache
|
||||
if json.Unmarshal(data, &c) != nil {
|
||||
return ""
|
||||
}
|
||||
if time.Since(c.CheckedAt) > cacheTTL {
|
||||
return ""
|
||||
}
|
||||
return c.Version
|
||||
}
|
||||
|
||||
// writeCachedVersion writes the latest version to the cache file.
|
||||
func writeCachedVersion(version string) {
|
||||
c := versionCache{
|
||||
Version: version,
|
||||
CheckedAt: time.Now(),
|
||||
}
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
path := cacheFilePath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
// Best-effort write — ignore errors
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||
return
|
||||
}
|
||||
os.Rename(tmp, path)
|
||||
}
|
||||
|
||||
// CheckLatestCached returns the latest version, using cache when fresh.
|
||||
// If cache is stale, fetches from GitHub and updates the cache.
|
||||
func CheckLatestCached(ctx context.Context) (version string, fromCache bool, err error) {
|
||||
if cached := ReadCachedVersion(); cached != "" {
|
||||
return cached, true, nil
|
||||
}
|
||||
v, err := fetchLatestVersion(ctx)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
writeCachedVersion(v)
|
||||
return v, false, nil
|
||||
}
|
||||
|
|
@ -152,9 +152,13 @@ func (u *Upgrader) fail(format string, args ...any) Result {
|
|||
}
|
||||
}
|
||||
|
||||
// CheckLatest fetches the latest version from GitHub API.
|
||||
// CheckLatest fetches the latest version from GitHub API and updates the cache.
|
||||
func CheckLatest(ctx context.Context) (string, error) {
|
||||
return fetchLatestVersion(ctx)
|
||||
v, err := fetchLatestVersion(ctx)
|
||||
if err == nil {
|
||||
writeCachedVersion(v)
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// installBinary copies the new binary to the target path, preserving original permissions.
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
|
@ -305,3 +306,768 @@ func TestFetchLatestVersionMockServer(t *testing.T) {
|
|||
t.Errorf("status = %d, want 200", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// --- New tests below ---
|
||||
|
||||
// swapHTTPClient replaces the package-level httpClient and returns a restore function.
|
||||
func swapHTTPClient(c *http.Client) func() {
|
||||
orig := httpClient
|
||||
httpClient = c
|
||||
return func() { httpClient = orig }
|
||||
}
|
||||
|
||||
// rewriteTransport redirects all requests to the given base URL,
|
||||
// preserving path and query.
|
||||
type rewriteTransport struct {
|
||||
url string
|
||||
}
|
||||
|
||||
func (rt *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
u := strings.TrimPrefix(rt.url, "http://")
|
||||
req.URL.Host = u
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// createTarGz is a test helper that creates a tar.gz file with a single file entry.
|
||||
func createTarGz(t *testing.T, archivePath, entryName string, content []byte) {
|
||||
t.Helper()
|
||||
f, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gw := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: entryName,
|
||||
Mode: 0o755,
|
||||
Size: int64(len(content)),
|
||||
})
|
||||
tw.Write(content)
|
||||
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
f.Close()
|
||||
}
|
||||
|
||||
func TestArchiveNameTableDriven(t *testing.T) {
|
||||
// We can only run archiveName for the current GOOS/GOARCH,
|
||||
// so we test several version strings and verify the pattern.
|
||||
tests := []struct {
|
||||
version string
|
||||
}{
|
||||
{"0.1.0"},
|
||||
{"1.0.0-rc1"},
|
||||
{"2.5.10"},
|
||||
{"0.0.0"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.version, func(t *testing.T) {
|
||||
got := archiveName(tt.version)
|
||||
prefix := fmt.Sprintf("unarr_%s_%s_%s.", tt.version, runtime.GOOS, runtime.GOARCH)
|
||||
if !strings.HasPrefix(got, prefix) {
|
||||
t.Errorf("archiveName(%q) = %q, want prefix %q", tt.version, got, prefix)
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
if !strings.HasSuffix(got, ".zip") {
|
||||
t.Errorf("archiveName on windows should end with .zip, got %q", got)
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(got, ".tar.gz") {
|
||||
t.Errorf("archiveName on non-windows should end with .tar.gz, got %q", got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseURLEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
version string
|
||||
filename string
|
||||
wantURL string
|
||||
}{
|
||||
{
|
||||
name: "pre-release version",
|
||||
version: "2.0.0-beta.1",
|
||||
filename: "unarr_2.0.0-beta.1_darwin_arm64.tar.gz",
|
||||
wantURL: "https://github.com/torrentclaw/unarr/releases/download/v2.0.0-beta.1/unarr_2.0.0-beta.1_darwin_arm64.tar.gz",
|
||||
},
|
||||
{
|
||||
name: "checksums file",
|
||||
version: "3.0.0",
|
||||
filename: "checksums.txt",
|
||||
wantURL: "https://github.com/torrentclaw/unarr/releases/download/v3.0.0/checksums.txt",
|
||||
},
|
||||
{
|
||||
name: "windows zip",
|
||||
version: "1.2.3",
|
||||
filename: "unarr_1.2.3_windows_amd64.zip",
|
||||
wantURL: "https://github.com/torrentclaw/unarr/releases/download/v1.2.3/unarr_1.2.3_windows_amd64.zip",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := releaseURL(tt.version, tt.filename)
|
||||
if got != tt.wantURL {
|
||||
t.Errorf("releaseURL(%q, %q) = %q, want %q", tt.version, tt.filename, got, tt.wantURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBinaryDispatcher(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("extractBinary dispatcher test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a valid tar.gz with the unarr binary
|
||||
archivePath := filepath.Join(dir, "test.tar.gz")
|
||||
binaryContent := []byte("#!/bin/sh\necho dispatcher test\n")
|
||||
createTarGz(t, archivePath, "unarr", binaryContent)
|
||||
|
||||
destDir := filepath.Join(dir, "out")
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
binPath, err := extractBinary(archivePath, destDir)
|
||||
if err != nil {
|
||||
t.Fatalf("extractBinary() error = %v", err)
|
||||
}
|
||||
if filepath.Base(binPath) != "unarr" {
|
||||
t.Errorf("extractBinary() returned %q, want base name 'unarr'", binPath)
|
||||
}
|
||||
data, _ := os.ReadFile(binPath)
|
||||
if string(data) != string(binaryContent) {
|
||||
t.Error("extractBinary() content mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBinaryInvalidArchive(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
archivePath := filepath.Join(dir, "garbage.tar.gz")
|
||||
os.WriteFile(archivePath, []byte("this is not a tar.gz"), 0o644)
|
||||
|
||||
destDir := filepath.Join(dir, "out")
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
_, err := extractBinary(archivePath, destDir)
|
||||
if err == nil {
|
||||
t.Error("extractBinary with garbage data should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBinaryNonExistentArchive(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
_, err := extractBinary("/nonexistent-archive-file", filepath.Join(dir, "out"))
|
||||
if err == nil {
|
||||
t.Error("extractBinary with nonexistent file should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgraderFail(t *testing.T) {
|
||||
u := &Upgrader{CurrentVersion: "1.0.0"}
|
||||
|
||||
// Capture log messages
|
||||
var logged []string
|
||||
u.OnProgress = func(msg string) { logged = append(logged, msg) }
|
||||
|
||||
result := u.fail("something went wrong: %d", 42)
|
||||
|
||||
if result.Success {
|
||||
t.Error("fail() should return Success=false")
|
||||
}
|
||||
if result.OldVersion != "1.0.0" {
|
||||
t.Errorf("fail() OldVersion = %q, want 1.0.0", result.OldVersion)
|
||||
}
|
||||
if result.Error == nil {
|
||||
t.Fatal("fail() should set Error")
|
||||
}
|
||||
if !strings.Contains(result.Error.Error(), "something went wrong: 42") {
|
||||
t.Errorf("fail() Error = %q, want to contain 'something went wrong: 42'", result.Error)
|
||||
}
|
||||
if len(logged) == 0 {
|
||||
t.Error("fail() should call OnProgress")
|
||||
}
|
||||
if len(logged) > 0 && !strings.Contains(logged[0], "FAILED") {
|
||||
t.Errorf("fail() logged %q, want to contain 'FAILED'", logged[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgraderFailNilOnProgress(t *testing.T) {
|
||||
u := &Upgrader{CurrentVersion: "2.0.0"}
|
||||
// OnProgress is nil — should not panic
|
||||
result := u.fail("error without listener")
|
||||
if result.Success {
|
||||
t.Error("fail() should return Success=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchLatestVersionWithHTTPTest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
statusCode int
|
||||
wantVer string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid response",
|
||||
body: `{"tag_name":"v3.1.4"}`,
|
||||
statusCode: 200,
|
||||
wantVer: "3.1.4",
|
||||
},
|
||||
{
|
||||
name: "valid response without v prefix",
|
||||
body: `{"tag_name":"2.0.0"}`,
|
||||
statusCode: 200,
|
||||
wantVer: "2.0.0",
|
||||
},
|
||||
{
|
||||
name: "empty tag_name",
|
||||
body: `{"tag_name":""}`,
|
||||
statusCode: 200,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
body: `Internal Server Error`,
|
||||
statusCode: 500,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
body: `{invalid`,
|
||||
statusCode: 200,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tt.statusCode)
|
||||
fmt.Fprint(w, tt.body)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Use a custom transport that rewrites requests to our test server
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
ver, err := CheckLatest(context.Background())
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("CheckLatest() error = nil, want error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("CheckLatest() error = %v", err)
|
||||
}
|
||||
if ver != tt.wantVer {
|
||||
t.Errorf("CheckLatest() = %q, want %q", ver, tt.wantVer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadWithHTTPTest(t *testing.T) {
|
||||
archiveBody := "fake-archive-bytes-for-download-test"
|
||||
|
||||
t.Run("successful download", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify user-agent header
|
||||
if ua := r.Header.Get("User-Agent"); ua != "unarr-updater" {
|
||||
t.Errorf("User-Agent = %q, want 'unarr-updater'", ua)
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, archiveBody)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
path, err := download(context.Background(), "1.0.0")
|
||||
if err != nil {
|
||||
t.Fatalf("download() error = %v", err)
|
||||
}
|
||||
defer os.Remove(path)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read downloaded file: %v", err)
|
||||
}
|
||||
if string(data) != archiveBody {
|
||||
t.Errorf("downloaded content = %q, want %q", data, archiveBody)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server returns 404", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(404)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
_, err := download(context.Background(), "99.99.99")
|
||||
if err == nil {
|
||||
t.Error("download() with 404 should return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTTP 404") {
|
||||
t.Errorf("download() error = %q, want to contain 'HTTP 404'", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cancelled context", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, "data")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
_, err := download(ctx, "1.0.0")
|
||||
if err == nil {
|
||||
t.Error("download() with cancelled context should return error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifyChecksumWithHTTPTest(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
// Create a fake archive file
|
||||
dir := t.TempDir()
|
||||
archiveContent := []byte("archive-content-for-checksum-test")
|
||||
archivePath := filepath.Join(dir, "test-archive.tar.gz")
|
||||
os.WriteFile(archivePath, archiveContent, 0o644)
|
||||
|
||||
h := sha256.Sum256(archiveContent)
|
||||
correctHash := hex.EncodeToString(h[:])
|
||||
|
||||
// The function builds the archive name using archiveName(), which uses runtime.GOOS/GOARCH.
|
||||
expectedArchiveName := archiveName("1.0.0")
|
||||
|
||||
t.Run("matching checksum", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "0000000000000000000000000000000000000000000000000000000000000000 other_file.tar.gz\n")
|
||||
fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
|
||||
if err != nil {
|
||||
t.Errorf("verifyChecksum() = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mismatched checksum", func(t *testing.T) {
|
||||
wrongHash := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%s %s\n", wrongHash, expectedArchiveName)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
|
||||
if err == nil {
|
||||
t.Error("verifyChecksum() with wrong hash should return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "SHA256 mismatch") {
|
||||
t.Errorf("verifyChecksum() error = %q, want to contain 'SHA256 mismatch'", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("archive not in checksums", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890 some_other_file.tar.gz\n")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
|
||||
if err == nil {
|
||||
t.Error("verifyChecksum() with missing entry should return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no checksum found") {
|
||||
t.Errorf("verifyChecksum() error = %q, want to contain 'no checksum found'", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("checksums server error", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(500)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
|
||||
if err == nil {
|
||||
t.Error("verifyChecksum() with server error should return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTTP 500") {
|
||||
t.Errorf("verifyChecksum() error = %q, want to contain 'HTTP 500'", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonexistent archive file", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%s %s\n", correctHash, expectedArchiveName)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
err := verifyChecksum(context.Background(), "1.0.0", "/nonexistent-archive-path")
|
||||
if err == nil {
|
||||
t.Error("verifyChecksum() with nonexistent archive should return error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifyChecksumCaseInsensitive(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
archiveContent := []byte("case-insensitive-hash-test")
|
||||
archivePath := filepath.Join(dir, "archive.tar.gz")
|
||||
os.WriteFile(archivePath, archiveContent, 0o644)
|
||||
|
||||
h := sha256.Sum256(archiveContent)
|
||||
// Use uppercase hash in checksums.txt — verifyChecksum uses EqualFold
|
||||
upperHash := strings.ToUpper(hex.EncodeToString(h[:]))
|
||||
expectedArchiveName := archiveName("1.0.0")
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%s %s\n", upperHash, expectedArchiveName)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
err := verifyChecksum(context.Background(), "1.0.0", archivePath)
|
||||
if err != nil {
|
||||
t.Errorf("verifyChecksum() with uppercase hash = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTarGzNestedDirectory(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
archivePath := filepath.Join(dir, "nested.tar.gz")
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho nested\n")
|
||||
|
||||
f, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gw := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
// Write a directory entry first
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: "unarr_1.0.0_linux_amd64/",
|
||||
Typeflag: tar.TypeDir,
|
||||
Mode: 0o755,
|
||||
})
|
||||
|
||||
// Write a README in the subdirectory
|
||||
readmeContent := []byte("This is a README")
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: "unarr_1.0.0_linux_amd64/README.md",
|
||||
Mode: 0o644,
|
||||
Size: int64(len(readmeContent)),
|
||||
})
|
||||
tw.Write(readmeContent)
|
||||
|
||||
// Write the binary nested inside the directory
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: "unarr_1.0.0_linux_amd64/unarr",
|
||||
Mode: 0o755,
|
||||
Size: int64(len(binaryContent)),
|
||||
})
|
||||
tw.Write(binaryContent)
|
||||
|
||||
// Write another unrelated file after the binary
|
||||
licenseContent := []byte("MIT License")
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: "unarr_1.0.0_linux_amd64/LICENSE",
|
||||
Mode: 0o644,
|
||||
Size: int64(len(licenseContent)),
|
||||
})
|
||||
tw.Write(licenseContent)
|
||||
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
f.Close()
|
||||
|
||||
destDir := filepath.Join(dir, "out")
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
binPath, err := extractTarGz(archivePath, destDir)
|
||||
if err != nil {
|
||||
t.Fatalf("extractTarGz() with nested dir = %v", err)
|
||||
}
|
||||
|
||||
if filepath.Base(binPath) != "unarr" {
|
||||
t.Errorf("extracted name = %q, want 'unarr'", filepath.Base(binPath))
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(binPath)
|
||||
if string(data) != string(binaryContent) {
|
||||
t.Error("extracted content does not match")
|
||||
}
|
||||
|
||||
// Verify that only the binary was extracted (README and LICENSE should NOT be in destDir)
|
||||
entries, _ := os.ReadDir(destDir)
|
||||
if len(entries) != 1 {
|
||||
names := make([]string, len(entries))
|
||||
for i, e := range entries {
|
||||
names[i] = e.Name()
|
||||
}
|
||||
t.Errorf("destDir should contain only 'unarr', got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTarGzMultipleFiles(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
archivePath := filepath.Join(dir, "multi.tar.gz")
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho multi\n")
|
||||
|
||||
f, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gw := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
// Several non-binary files before the actual binary
|
||||
for _, name := range []string{"README.md", "LICENSE", "config.yaml", "completions.bash"} {
|
||||
content := []byte("content of " + name)
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: name,
|
||||
Mode: 0o644,
|
||||
Size: int64(len(content)),
|
||||
})
|
||||
tw.Write(content)
|
||||
}
|
||||
|
||||
// The actual binary
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: "unarr",
|
||||
Mode: 0o755,
|
||||
Size: int64(len(binaryContent)),
|
||||
})
|
||||
tw.Write(binaryContent)
|
||||
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
f.Close()
|
||||
|
||||
destDir := filepath.Join(dir, "out")
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
binPath, err := extractTarGz(archivePath, destDir)
|
||||
if err != nil {
|
||||
t.Fatalf("extractTarGz() = %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(binPath)
|
||||
if string(data) != string(binaryContent) {
|
||||
t.Error("binary content mismatch among multiple files")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTarGzSymlinkSkipped(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
archivePath := filepath.Join(dir, "symlink.tar.gz")
|
||||
|
||||
f, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gw := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
// Write a symlink entry named "unarr" — should be skipped because Typeflag != TypeReg
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: "unarr",
|
||||
Typeflag: tar.TypeSymlink,
|
||||
Linkname: "/etc/passwd",
|
||||
Mode: 0o755,
|
||||
})
|
||||
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
f.Close()
|
||||
|
||||
destDir := filepath.Join(dir, "out")
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
_, err = extractTarGz(archivePath, destDir)
|
||||
if err == nil {
|
||||
t.Error("extractTarGz() should return error when binary is a symlink (not TypeReg)")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found in archive") {
|
||||
t.Errorf("error = %q, want to contain 'not found in archive'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallBinaryNonExistentSource(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dst := filepath.Join(dir, "output")
|
||||
|
||||
err := installBinary("/nonexistent-source-binary", dst)
|
||||
if err == nil {
|
||||
t.Error("installBinary with nonexistent source should return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "read new binary") {
|
||||
t.Errorf("error = %q, want to contain 'read new binary'", err)
|
||||
}
|
||||
|
||||
// Verify destination was not created
|
||||
if _, statErr := os.Stat(dst); statErr == nil {
|
||||
t.Error("destination file should not exist after failed install")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallBinaryUnwritableDestination(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := filepath.Join(dir, "source")
|
||||
os.WriteFile(src, []byte("binary"), 0o755)
|
||||
|
||||
// Try to write to a path inside a non-existent directory
|
||||
dst := filepath.Join(dir, "nonexistent-subdir", "binary")
|
||||
|
||||
err := installBinary(src, dst)
|
||||
if err == nil {
|
||||
t.Error("installBinary to non-writable destination should return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "write binary") {
|
||||
t.Errorf("error = %q, want to contain 'write binary'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgraderLog(t *testing.T) {
|
||||
var messages []string
|
||||
u := &Upgrader{
|
||||
CurrentVersion: "1.0.0",
|
||||
OnProgress: func(msg string) { messages = append(messages, msg) },
|
||||
}
|
||||
|
||||
u.log("hello world")
|
||||
if len(messages) != 1 || messages[0] != "hello world" {
|
||||
t.Errorf("log() messages = %v, want [hello world]", messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgraderLogNilOnProgress(t *testing.T) {
|
||||
u := &Upgrader{CurrentVersion: "1.0.0"}
|
||||
// Should not panic
|
||||
u.log("test message with nil OnProgress")
|
||||
}
|
||||
|
||||
func TestResultFields(t *testing.T) {
|
||||
r := Result{
|
||||
Success: true,
|
||||
OldVersion: "1.0.0",
|
||||
NewVersion: "2.0.0",
|
||||
BackupPath: "/tmp/backup",
|
||||
}
|
||||
if !r.Success || r.OldVersion != "1.0.0" || r.NewVersion != "2.0.0" || r.BackupPath != "/tmp/backup" {
|
||||
t.Errorf("Result fields not set correctly: %+v", r)
|
||||
}
|
||||
|
||||
r2 := Result{Success: false, Error: fmt.Errorf("test error")}
|
||||
if r2.Success || r2.Error == nil {
|
||||
t.Errorf("Result error case not correct: %+v", r2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadSetsUserAgent(t *testing.T) {
|
||||
var gotUA string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotUA = r.Header.Get("User-Agent")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, "data")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
restore := swapHTTPClient(&http.Client{
|
||||
Transport: &rewriteTransport{url: srv.URL},
|
||||
})
|
||||
defer restore()
|
||||
|
||||
path, err := download(context.Background(), "1.0.0")
|
||||
if err != nil {
|
||||
t.Fatalf("download() = %v", err)
|
||||
}
|
||||
defer os.Remove(path)
|
||||
|
||||
if gotUA != "unarr-updater" {
|
||||
t.Errorf("User-Agent = %q, want 'unarr-updater'", gotUA)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
632
internal/usenet/download/progress_expand_test.go
Normal file
632
internal/usenet/download/progress_expand_test.go
Normal file
|
|
@ -0,0 +1,632 @@
|
|||
package download
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/usenet/nzb"
|
||||
)
|
||||
|
||||
// --- Fingerprint ---
|
||||
|
||||
func TestFingerprint_EmptyNZB(t *testing.T) {
|
||||
n := &nzb.NZB{}
|
||||
fp := Fingerprint(n)
|
||||
// Empty NZB should still produce a deterministic hash (of zero message IDs).
|
||||
fp2 := Fingerprint(n)
|
||||
if fp != fp2 {
|
||||
t.Fatal("fingerprint of empty NZB should be deterministic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprint_OrderIndependent(t *testing.T) {
|
||||
// Fingerprint sorts IDs, so different file order should produce the same hash.
|
||||
n1 := &nzb.NZB{
|
||||
Files: []nzb.File{
|
||||
{Segments: []nzb.Segment{{MessageID: "a@x"}, {MessageID: "b@x"}}},
|
||||
{Segments: []nzb.Segment{{MessageID: "c@x"}}},
|
||||
},
|
||||
}
|
||||
n2 := &nzb.NZB{
|
||||
Files: []nzb.File{
|
||||
{Segments: []nzb.Segment{{MessageID: "c@x"}}},
|
||||
{Segments: []nzb.Segment{{MessageID: "b@x"}, {MessageID: "a@x"}}},
|
||||
},
|
||||
}
|
||||
if Fingerprint(n1) != Fingerprint(n2) {
|
||||
t.Fatal("fingerprint should be order-independent (sorted by message ID)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprint_SingleSegment(t *testing.T) {
|
||||
n := &nzb.NZB{
|
||||
Files: []nzb.File{
|
||||
{Segments: []nzb.Segment{{MessageID: "only@one"}}},
|
||||
},
|
||||
}
|
||||
fp := Fingerprint(n)
|
||||
if fp == [32]byte{} {
|
||||
t.Fatal("fingerprint should not be zero for a non-empty NZB")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ProgressTracker MarkDone idempotency ---
|
||||
|
||||
func TestProgressTracker_MarkDoneIdempotent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 5)
|
||||
tracker := NewProgressTracker("idem", n, dir)
|
||||
|
||||
tracker.MarkDone(0, 2)
|
||||
if tracker.CompletedSegments(0) != 1 {
|
||||
t.Fatalf("expected 1, got %d", tracker.CompletedSegments(0))
|
||||
}
|
||||
|
||||
// Mark the same segment again — count should not increase.
|
||||
tracker.MarkDone(0, 2)
|
||||
if tracker.CompletedSegments(0) != 1 {
|
||||
t.Fatalf("idempotent mark: expected 1, got %d", tracker.CompletedSegments(0))
|
||||
}
|
||||
}
|
||||
|
||||
// --- ProgressTracker TotalCompleted ---
|
||||
|
||||
func TestProgressTracker_TotalCompleted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(3, 4) // 3 files, 4 segs each
|
||||
tracker := NewProgressTracker("total", n, dir)
|
||||
|
||||
tracker.MarkDone(0, 0)
|
||||
tracker.MarkDone(0, 1)
|
||||
tracker.MarkDone(1, 3)
|
||||
tracker.MarkDone(2, 0)
|
||||
tracker.MarkDone(2, 1)
|
||||
tracker.MarkDone(2, 2)
|
||||
|
||||
if got := tracker.TotalCompleted(); got != 6 {
|
||||
t.Errorf("TotalCompleted: got %d, want 6", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_TotalCompleted_Empty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(2, 3)
|
||||
tracker := NewProgressTracker("empty-total", n, dir)
|
||||
|
||||
if got := tracker.TotalCompleted(); got != 0 {
|
||||
t.Errorf("TotalCompleted on fresh tracker: got %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- CompletedBytes edge cases ---
|
||||
|
||||
func TestProgressTracker_CompletedBytes_OutOfBounds(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("cb-oob", n, dir)
|
||||
|
||||
if got := tracker.CompletedBytes(-1, n.Files[0].Segments); got != 0 {
|
||||
t.Errorf("CompletedBytes with file -1: got %d, want 0", got)
|
||||
}
|
||||
if got := tracker.CompletedBytes(5, n.Files[0].Segments); got != 0 {
|
||||
t.Errorf("CompletedBytes with file 5: got %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_CompletedBytes_AllDone(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("cb-all", n, dir)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
tracker.MarkDone(0, i)
|
||||
}
|
||||
|
||||
got := tracker.CompletedBytes(0, n.Files[0].Segments)
|
||||
expected := int64(3 * 750 * 1024)
|
||||
if got != expected {
|
||||
t.Errorf("CompletedBytes all done: got %d, want %d", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// --- CompletedSegments out of bounds ---
|
||||
|
||||
func TestProgressTracker_CompletedSegments_OutOfBounds(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("cs-oob", n, dir)
|
||||
|
||||
if got := tracker.CompletedSegments(-1); got != 0 {
|
||||
t.Errorf("CompletedSegments(-1) = %d, want 0", got)
|
||||
}
|
||||
if got := tracker.CompletedSegments(99); got != 0 {
|
||||
t.Errorf("CompletedSegments(99) = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Load with corrupted / truncated data ---
|
||||
|
||||
func TestProgressTracker_Load_TruncatedHeader(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("trunc", n, dir)
|
||||
|
||||
// Write too-short data
|
||||
os.WriteFile(tracker.progressPath(), []byte("UNR"), 0o644)
|
||||
|
||||
loaded, err := tracker.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("truncated header should not load")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_Load_BadMagic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("badmagic", n, dir)
|
||||
|
||||
// Write data with wrong magic bytes
|
||||
data := make([]byte, headerSize+10)
|
||||
copy(data[0:4], []byte("BAAD"))
|
||||
os.WriteFile(tracker.progressPath(), data, 0o644)
|
||||
|
||||
loaded, err := tracker.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("bad magic should not load")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_Load_BadVersion(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("badver", n, dir)
|
||||
|
||||
data := make([]byte, headerSize+10)
|
||||
copy(data[0:4], progressMagic[:])
|
||||
data[4] = 99 // unsupported version
|
||||
os.WriteFile(tracker.progressPath(), data, 0o644)
|
||||
|
||||
loaded, err := tracker.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("bad version should not load")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_Load_WrongFileCount(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(2, 3)
|
||||
tracker := NewProgressTracker("wrongfc", n, dir)
|
||||
|
||||
data := make([]byte, headerSize+20)
|
||||
copy(data[0:4], progressMagic[:])
|
||||
data[4] = progressVersion
|
||||
binary.LittleEndian.PutUint16(data[6:8], 99) // wrong file count
|
||||
copy(data[8:40], tracker.fingerprint[:])
|
||||
os.WriteFile(tracker.progressPath(), data, 0o644)
|
||||
|
||||
loaded, err := tracker.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("wrong file count should not load")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_Load_TruncatedBitset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 16)
|
||||
tracker := NewProgressTracker("truncbit", n, dir)
|
||||
|
||||
// Build a valid header but truncate the bitset data
|
||||
data := make([]byte, headerSize+4) // header + segCount but no bitset
|
||||
copy(data[0:4], progressMagic[:])
|
||||
data[4] = progressVersion
|
||||
binary.LittleEndian.PutUint16(data[6:8], 1) // 1 file
|
||||
copy(data[8:40], tracker.fingerprint[:])
|
||||
binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 16) // 16 segs
|
||||
// No bitset data follows — truncated
|
||||
os.WriteFile(tracker.progressPath(), data, 0o644)
|
||||
|
||||
loaded, err := tracker.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("truncated bitset should not load")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_Load_SegCountMismatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 5)
|
||||
tracker := NewProgressTracker("segmis", n, dir)
|
||||
|
||||
// Build valid header with correct file count and fingerprint, but wrong segCount
|
||||
bitsetLen := (999 + 7) / 8
|
||||
data := make([]byte, headerSize+4+bitsetLen)
|
||||
copy(data[0:4], progressMagic[:])
|
||||
data[4] = progressVersion
|
||||
binary.LittleEndian.PutUint16(data[6:8], 1)
|
||||
copy(data[8:40], tracker.fingerprint[:])
|
||||
binary.LittleEndian.PutUint32(data[headerSize:headerSize+4], 999) // wrong seg count
|
||||
os.WriteFile(tracker.progressPath(), data, 0o644)
|
||||
|
||||
loaded, err := tracker.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("segment count mismatch should not load")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_Load_NonexistentFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("nofile", n, dir)
|
||||
|
||||
loaded, err := tracker.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for missing file: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("nonexistent file should return false")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Flush and Load round-trip with multiple files ---
|
||||
|
||||
func TestProgressTracker_MultiFileRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(5, 10) // 5 files, 10 segments each
|
||||
tracker := NewProgressTracker("multi-rt", n, dir)
|
||||
|
||||
// Mark various segments across files
|
||||
tracker.MarkDone(0, 0)
|
||||
tracker.MarkDone(0, 9)
|
||||
tracker.MarkDone(1, 5)
|
||||
tracker.MarkDone(2, 0)
|
||||
tracker.MarkDone(2, 1)
|
||||
tracker.MarkDone(2, 2)
|
||||
tracker.MarkDone(2, 3)
|
||||
tracker.MarkDone(2, 4)
|
||||
tracker.MarkDone(2, 5)
|
||||
tracker.MarkDone(2, 6)
|
||||
tracker.MarkDone(2, 7)
|
||||
tracker.MarkDone(2, 8)
|
||||
tracker.MarkDone(2, 9) // file 2 fully done
|
||||
tracker.MarkDone(4, 7)
|
||||
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
// Reload
|
||||
tracker2 := NewProgressTracker("multi-rt", n, dir)
|
||||
loaded, err := tracker2.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if !loaded {
|
||||
t.Fatal("should load")
|
||||
}
|
||||
|
||||
if tracker2.CompletedSegments(0) != 2 {
|
||||
t.Errorf("file 0: got %d, want 2", tracker2.CompletedSegments(0))
|
||||
}
|
||||
if tracker2.CompletedSegments(1) != 1 {
|
||||
t.Errorf("file 1: got %d, want 1", tracker2.CompletedSegments(1))
|
||||
}
|
||||
if !tracker2.IsFileDone(2) {
|
||||
t.Error("file 2 should be done")
|
||||
}
|
||||
if tracker2.CompletedSegments(3) != 0 {
|
||||
t.Errorf("file 3: got %d, want 0", tracker2.CompletedSegments(3))
|
||||
}
|
||||
if tracker2.CompletedSegments(4) != 1 {
|
||||
t.Errorf("file 4: got %d, want 1", tracker2.CompletedSegments(4))
|
||||
}
|
||||
if tracker2.TotalCompleted() != 14 {
|
||||
t.Errorf("TotalCompleted: got %d, want 14", tracker2.TotalCompleted())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Concurrent mark + IsDone reads ---
|
||||
|
||||
func TestProgressTracker_ConcurrentMarkAndRead(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
segCount := 500
|
||||
n := makeTestNZB(2, segCount)
|
||||
tracker := NewProgressTracker("conc-rw", n, dir)
|
||||
|
||||
// Use separate WaitGroups for writers and readers
|
||||
var writerWg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
// Writers
|
||||
for file := 0; file < 2; file++ {
|
||||
for seg := 0; seg < segCount; seg++ {
|
||||
writerWg.Add(1)
|
||||
go func(f, s int) {
|
||||
defer writerWg.Done()
|
||||
tracker.MarkDone(f, s)
|
||||
}(file, seg)
|
||||
}
|
||||
}
|
||||
|
||||
// Readers — continuously read while writes happen
|
||||
var readerWg sync.WaitGroup
|
||||
for i := 0; i < 4; i++ {
|
||||
readerWg.Add(1)
|
||||
go func() {
|
||||
defer readerWg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
// These should never panic
|
||||
tracker.IsDone(0, 0)
|
||||
tracker.IsDone(1, segCount-1)
|
||||
tracker.IsFileDone(0)
|
||||
tracker.CompletedSegments(1)
|
||||
tracker.TotalCompleted()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all writers to finish, then stop readers
|
||||
writerWg.Wait()
|
||||
close(stop)
|
||||
readerWg.Wait()
|
||||
|
||||
// After all goroutines complete, everything should be done
|
||||
if !tracker.IsFileDone(0) {
|
||||
t.Errorf("file 0 should be done, got %d/%d", tracker.CompletedSegments(0), segCount)
|
||||
}
|
||||
if !tracker.IsFileDone(1) {
|
||||
t.Errorf("file 1 should be done, got %d/%d", tracker.CompletedSegments(1), segCount)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Concurrent flush safety ---
|
||||
|
||||
func TestProgressTracker_ConcurrentFlush(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 100)
|
||||
tracker := NewProgressTracker("conc-flush", n, dir)
|
||||
|
||||
// Mark some segments
|
||||
for i := 0; i < 50; i++ {
|
||||
tracker.MarkDone(0, i)
|
||||
}
|
||||
|
||||
// Multiple concurrent flushes should not panic or corrupt
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.Flush()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify state is loadable
|
||||
tracker2 := NewProgressTracker("conc-flush", n, dir)
|
||||
loaded, err := tracker2.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if !loaded {
|
||||
t.Fatal("should load after concurrent flushes")
|
||||
}
|
||||
if tracker2.CompletedSegments(0) != 50 {
|
||||
t.Errorf("after concurrent flush: got %d, want 50", tracker2.CompletedSegments(0))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Remove with .tmp file ---
|
||||
|
||||
func TestProgressTracker_Remove_WithTmpFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("rm-tmp", n, dir)
|
||||
|
||||
// Create all three files that Remove should clean up
|
||||
os.WriteFile(tracker.progressPath(), []byte("data"), 0o644)
|
||||
os.WriteFile(tracker.nzbPath(), []byte("<nzb/>"), 0o644)
|
||||
os.WriteFile(tracker.progressPath()+".tmp", []byte("tmp"), 0o644)
|
||||
|
||||
tracker.Remove()
|
||||
|
||||
for _, p := range []string{tracker.progressPath(), tracker.nzbPath(), tracker.progressPath() + ".tmp"} {
|
||||
if _, err := os.Stat(p); !os.IsNotExist(err) {
|
||||
t.Errorf("file should be removed: %s", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- CleanStaleFiles edge cases ---
|
||||
|
||||
func TestCleanStaleFiles_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if got := CleanStaleFiles(dir, time.Hour); got != 0 {
|
||||
t.Errorf("empty dir: got %d removed, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanStaleFiles_NonexistentDir(t *testing.T) {
|
||||
if got := CleanStaleFiles("/nonexistent/path/that/does/not/exist", time.Hour); got != 0 {
|
||||
t.Errorf("nonexistent dir: got %d removed, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanStaleFiles_AllFresh(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "a.progress"), []byte("a"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "b.progress"), []byte("b"), 0o644)
|
||||
|
||||
if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 {
|
||||
t.Errorf("all fresh: got %d removed, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanStaleFiles_SkipsSubdirs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
subDir := filepath.Join(dir, "subdir")
|
||||
os.MkdirAll(subDir, 0o755)
|
||||
|
||||
// Backdate the subdir (it should not be removed)
|
||||
os.Chtimes(subDir, fixedPast, fixedPast)
|
||||
|
||||
if got := CleanStaleFiles(dir, 24*time.Hour); got != 0 {
|
||||
t.Errorf("should skip subdirs: got %d removed, want 0", got)
|
||||
}
|
||||
if _, err := os.Stat(subDir); err != nil {
|
||||
t.Error("subdir should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanStaleFiles_MixedAges(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
stale1 := filepath.Join(dir, "old1.progress")
|
||||
stale2 := filepath.Join(dir, "old2.nzb")
|
||||
fresh := filepath.Join(dir, "new.progress")
|
||||
|
||||
os.WriteFile(stale1, []byte("x"), 0o644)
|
||||
os.WriteFile(stale2, []byte("x"), 0o644)
|
||||
os.WriteFile(fresh, []byte("x"), 0o644)
|
||||
|
||||
os.Chtimes(stale1, fixedPast, fixedPast)
|
||||
os.Chtimes(stale2, fixedPast, fixedPast)
|
||||
|
||||
if got := CleanStaleFiles(dir, 7*24*time.Hour); got != 2 {
|
||||
t.Errorf("mixed ages: got %d removed, want 2", got)
|
||||
}
|
||||
if _, err := os.Stat(fresh); err != nil {
|
||||
t.Error("fresh file should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
// --- progressPath / nzbPath ---
|
||||
|
||||
func TestProgressTracker_Paths(t *testing.T) {
|
||||
dir := "/some/dir"
|
||||
n := makeTestNZB(1, 1)
|
||||
tracker := NewProgressTracker("my-task", n, dir)
|
||||
|
||||
if got := tracker.progressPath(); got != filepath.Join(dir, "my-task.progress") {
|
||||
t.Errorf("progressPath: got %q", got)
|
||||
}
|
||||
if got := tracker.nzbPath(); got != filepath.Join(dir, "my-task.nzb") {
|
||||
t.Errorf("nzbPath: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- formatBytes ---
|
||||
|
||||
func TestFormatBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
input int64
|
||||
want string
|
||||
}{
|
||||
{0, "0 B"},
|
||||
{500, "500 B"},
|
||||
{1023, "1023 B"},
|
||||
{1024, "1.0 KB"},
|
||||
{1536, "1.5 KB"},
|
||||
{1048576, "1.0 MB"},
|
||||
{1073741824, "1.0 GB"},
|
||||
{1099511627776, "1.0 TB"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := formatBytes(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatBytes(%d) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Single file with 1 segment (boundary) ---
|
||||
|
||||
func TestProgressTracker_SingleSegment(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 1)
|
||||
tracker := NewProgressTracker("single-seg", n, dir)
|
||||
|
||||
if tracker.IsFileDone(0) {
|
||||
t.Error("should not be done initially")
|
||||
}
|
||||
|
||||
tracker.MarkDone(0, 0)
|
||||
|
||||
if !tracker.IsFileDone(0) {
|
||||
t.Error("should be done after marking the only segment")
|
||||
}
|
||||
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
tracker2 := NewProgressTracker("single-seg", n, dir)
|
||||
loaded, _ := tracker2.Load()
|
||||
if !loaded {
|
||||
t.Fatal("should load")
|
||||
}
|
||||
if !tracker2.IsFileDone(0) {
|
||||
t.Error("should be done after reload")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Flush creates directory if missing ---
|
||||
|
||||
func TestProgressTracker_FlushCreatesDir(t *testing.T) {
|
||||
base := t.TempDir()
|
||||
dir := filepath.Join(base, "nested", "resume")
|
||||
n := makeTestNZB(1, 2)
|
||||
tracker := NewProgressTracker("mkdir-test", n, dir)
|
||||
|
||||
tracker.MarkDone(0, 0)
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush should create dir: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(tracker.progressPath()); err != nil {
|
||||
t.Fatalf("progress file should exist: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Double flush after no new marks ---
|
||||
|
||||
func TestProgressTracker_DoubleFlush(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("dbl-flush", n, dir)
|
||||
|
||||
tracker.MarkDone(0, 0)
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("first flush: %v", err)
|
||||
}
|
||||
|
||||
// Second flush without new marks should be a no-op (dirty=false)
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("second flush: %v", err)
|
||||
}
|
||||
}
|
||||
131
internal/usenet/nntp/client_test.go
Normal file
131
internal/usenet/nntp/client_test.go
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
package nntp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
c := NewClient(Config{Host: "news.example.com", Port: 563, SSL: true})
|
||||
if c.cfg.MaxConnections != 10 {
|
||||
t.Errorf("default MaxConnections = %d, want 10", c.cfg.MaxConnections)
|
||||
}
|
||||
if c.cfg.Host != "news.example.com" {
|
||||
t.Errorf("Host = %q", c.cfg.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientCustomConnections(t *testing.T) {
|
||||
c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 20})
|
||||
if c.cfg.MaxConnections != 20 {
|
||||
t.Errorf("MaxConnections = %d, want 20", c.cfg.MaxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientZeroConnections(t *testing.T) {
|
||||
c := NewClient(Config{Host: "news.example.com", Port: 563, MaxConnections: 0})
|
||||
if c.cfg.MaxConnections != 10 {
|
||||
t.Errorf("MaxConnections should default to 10, got %d", c.cfg.MaxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientNegativeConnections(t *testing.T) {
|
||||
c := NewClient(Config{MaxConnections: -5})
|
||||
if c.cfg.MaxConnections != 10 {
|
||||
t.Errorf("MaxConnections should default to 10 for negative, got %d", c.cfg.MaxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveConnections(t *testing.T) {
|
||||
c := NewClient(Config{Host: "localhost", Port: 119})
|
||||
if c.ActiveConnections() != 0 {
|
||||
t.Errorf("ActiveConnections = %d, want 0", c.ActiveConnections())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatus(t *testing.T) {
|
||||
c := NewClient(Config{Host: "news.example.com", Port: 563})
|
||||
s := c.Status()
|
||||
if s != "0 connections (0 pooled) to news.example.com:563" {
|
||||
t.Errorf("Status = %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseIdempotent(t *testing.T) {
|
||||
c := NewClient(Config{Host: "localhost", Port: 119})
|
||||
// Close should be idempotent
|
||||
if err := c.Close(); err != nil {
|
||||
t.Errorf("first Close: %v", err)
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
t.Errorf("second Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArticleNotFoundError(t *testing.T) {
|
||||
err := &ArticleNotFoundError{MessageID: "abc123@news.example.com"}
|
||||
msg := err.Error()
|
||||
if msg != "nntp: article not found: abc123@news.example.com" {
|
||||
t.Errorf("Error() = %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDotBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"simple body",
|
||||
"Hello World\r\n.\r\n",
|
||||
"Hello World\n",
|
||||
},
|
||||
{
|
||||
"multiline",
|
||||
"Line 1\r\nLine 2\r\nLine 3\r\n.\r\n",
|
||||
"Line 1\nLine 2\nLine 3\n",
|
||||
},
|
||||
{
|
||||
"dot-stuffed line",
|
||||
"..This starts with a dot\r\n.\r\n",
|
||||
".This starts with a dot\n",
|
||||
},
|
||||
{
|
||||
"empty body",
|
||||
".\r\n",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"binary-like data",
|
||||
"=ybegin line=128 size=1024 name=test.bin\r\nsome encoded data\r\n=yend\r\n.\r\n",
|
||||
"=ybegin line=128 size=1024 name=test.bin\nsome encoded data\n=yend\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := bufio.NewReader(bytes.NewBufferString(tt.input))
|
||||
got, err := readDotBody(r)
|
||||
if err != nil {
|
||||
t.Fatalf("readDotBody: %v", err)
|
||||
}
|
||||
if string(got) != tt.want {
|
||||
t.Errorf("readDotBody = %q, want %q", string(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDotBodyEOF(t *testing.T) {
|
||||
// No dot terminator — should read until EOF
|
||||
r := bufio.NewReader(bytes.NewBufferString("partial data\r\n"))
|
||||
got, err := readDotBody(r)
|
||||
if err != nil {
|
||||
t.Fatalf("readDotBody EOF: %v", err)
|
||||
}
|
||||
if string(got) != "partial data\n" {
|
||||
t.Errorf("readDotBody EOF = %q", string(got))
|
||||
}
|
||||
}
|
||||
|
|
@ -267,3 +267,784 @@ func TestStripAngleBrackets(t *testing.T) {
|
|||
t.Errorf("MessageID not stripped: got %q", nzb.Files[0].Segments[0].MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Malformed / edge-case XML inputs ---
|
||||
|
||||
func TestParse_CompletelyEmpty(t *testing.T) {
|
||||
_, err := Parse(strings.NewReader(""))
|
||||
if err == nil {
|
||||
t.Error("expected error for completely empty input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_OnlyWhitespace(t *testing.T) {
|
||||
_, err := Parse(strings.NewReader(" \n\t "))
|
||||
if err == nil {
|
||||
t.Error("expected error for whitespace-only input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ValidXMLButNotNZB(t *testing.T) {
|
||||
_, err := Parse(strings.NewReader(`<?xml version="1.0"?><html><body>Hello</body></html>`))
|
||||
if err == nil {
|
||||
t.Error("expected error for non-NZB XML")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_NZBWithNoSegments(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments></segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
_, err := Parse(strings.NewReader(xml))
|
||||
if err == nil {
|
||||
t.Error("expected error for file with no segments")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_SegmentWithEmptyMessageID(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1"> </segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
_, err := Parse(strings.NewReader(xml))
|
||||
if err == nil {
|
||||
t.Error("expected error: segment with empty/whitespace message ID should be skipped, leaving no valid files")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_MixedValidAndEmptySegments(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">valid@id</segment>
|
||||
<segment bytes="200" number="2"> </segment>
|
||||
<segment bytes="300" number="3">also-valid@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if len(nzb.Files[0].Segments) != 2 {
|
||||
t.Errorf("expected 2 valid segments, got %d", len(nzb.Files[0].Segments))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Metadata / Head parsing ---
|
||||
|
||||
func TestParse_MetaPassword(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<head>
|
||||
<meta type="password">s3cr3t</meta>
|
||||
<meta type="title">My Movie</meta>
|
||||
<meta type="category">Movies</meta>
|
||||
</head>
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if nzb.Password != "s3cr3t" {
|
||||
t.Errorf("Password: got %q, want %q", nzb.Password, "s3cr3t")
|
||||
}
|
||||
if nzb.Meta["title"] != "My Movie" {
|
||||
t.Errorf("Meta title: got %q", nzb.Meta["title"])
|
||||
}
|
||||
if nzb.Meta["category"] != "Movies" {
|
||||
t.Errorf("Meta category: got %q", nzb.Meta["category"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_MetaPasswordWithWhitespace(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<head>
|
||||
<meta type="password"> padded </meta>
|
||||
</head>
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if nzb.Password != "padded" {
|
||||
t.Errorf("Password should be trimmed: got %q", nzb.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_NoHead(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if nzb.Password != "" {
|
||||
t.Errorf("Password should be empty: got %q", nzb.Password)
|
||||
}
|
||||
if len(nzb.Meta) != 0 {
|
||||
t.Errorf("Meta should be empty: got %v", nzb.Meta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_MetaWithEmptyType(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<head>
|
||||
<meta type="">ignored</meta>
|
||||
<meta type="name">kept</meta>
|
||||
</head>
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if _, ok := nzb.Meta[""]; ok {
|
||||
t.Error("empty-type meta should not be stored")
|
||||
}
|
||||
if nzb.Meta["name"] != "kept" {
|
||||
t.Errorf("Meta name: got %q", nzb.Meta["name"])
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multiple files ---
|
||||
|
||||
func TestParse_MultipleFilesVariousTypes(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="bot" date="1700000000" subject=""movie.mkv" yEnc (1/100)">
|
||||
<groups><group>alt.binaries.movies</group></groups>
|
||||
<segments>
|
||||
<segment bytes="768000" number="1">mkv001@ex</segment>
|
||||
<segment bytes="768000" number="2">mkv002@ex</segment>
|
||||
</segments>
|
||||
</file>
|
||||
<file poster="bot" date="1700000000" subject=""movie.nfo" yEnc (1/1)">
|
||||
<groups><group>alt.binaries.movies</group></groups>
|
||||
<segments>
|
||||
<segment bytes="4096" number="1">nfo001@ex</segment>
|
||||
</segments>
|
||||
</file>
|
||||
<file poster="bot" date="1700000000" subject=""movie.par2" yEnc (1/1)">
|
||||
<groups><group>alt.binaries.movies</group></groups>
|
||||
<segments>
|
||||
<segment bytes="32768" number="1">par001@ex</segment>
|
||||
</segments>
|
||||
</file>
|
||||
<file poster="bot" date="1700000000" subject=""movie.vol0+1.par2" yEnc (1/1)">
|
||||
<groups><group>alt.binaries.movies</group></groups>
|
||||
<segments>
|
||||
<segment bytes="65536" number="1">parv001@ex</segment>
|
||||
</segments>
|
||||
</file>
|
||||
<file poster="bot" date="1700000000" subject=""sample.mkv" yEnc (1/1)">
|
||||
<groups><group>alt.binaries.movies</group></groups>
|
||||
<segments>
|
||||
<segment bytes="10000" number="1">sample001@ex</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
|
||||
if len(nzb.Files) != 5 {
|
||||
t.Fatalf("expected 5 files, got %d", len(nzb.Files))
|
||||
}
|
||||
|
||||
// ContentFiles should exclude nfo, par2, par2 vol, and sample
|
||||
content := nzb.ContentFiles()
|
||||
if len(content) != 1 {
|
||||
t.Errorf("ContentFiles: got %d, want 1", len(content))
|
||||
}
|
||||
if len(content) > 0 && content[0].Filename() != "movie.mkv" {
|
||||
t.Errorf("ContentFiles[0]: got %q, want movie.mkv", content[0].Filename())
|
||||
}
|
||||
|
||||
// Par2Files
|
||||
par2 := nzb.Par2Files()
|
||||
if len(par2) != 2 {
|
||||
t.Errorf("Par2Files: got %d, want 2", len(par2))
|
||||
}
|
||||
|
||||
if !nzb.HasPar2() {
|
||||
t.Error("HasPar2 should be true")
|
||||
}
|
||||
if nzb.HasRars() {
|
||||
t.Error("HasRars should be false for this NZB")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Segment ordering / number parsing ---
|
||||
|
||||
func TestParse_SegmentNumberParsing(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="3">c@id</segment>
|
||||
<segment bytes="200" number="1">a@id</segment>
|
||||
<segment bytes="300" number="2">b@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
|
||||
segs := nzb.Files[0].Segments
|
||||
if len(segs) != 3 {
|
||||
t.Fatalf("expected 3 segments, got %d", len(segs))
|
||||
}
|
||||
|
||||
// Parse preserves order from XML; sorting is done by the downloader
|
||||
// Verify numbers are parsed correctly
|
||||
numbers := make(map[int]bool)
|
||||
for _, s := range segs {
|
||||
numbers[s.Number] = true
|
||||
}
|
||||
for _, want := range []int{1, 2, 3} {
|
||||
if !numbers[want] {
|
||||
t.Errorf("missing segment number %d", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_SegmentBytesZero(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="0" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if nzb.Files[0].Segments[0].Bytes != 0 {
|
||||
t.Errorf("expected 0 bytes, got %d", nzb.Files[0].Segments[0].Bytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_SegmentBytesNonNumeric(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="abc" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
// Non-numeric bytes should parse as 0
|
||||
if nzb.Files[0].Segments[0].Bytes != 0 {
|
||||
t.Errorf("non-numeric bytes should be 0, got %d", nzb.Files[0].Segments[0].Bytes)
|
||||
}
|
||||
}
|
||||
|
||||
// --- File helper methods ---
|
||||
|
||||
func TestFileTotalBytes(t *testing.T) {
|
||||
f := File{
|
||||
Segments: []Segment{
|
||||
{Bytes: 100}, {Bytes: 200}, {Bytes: 300},
|
||||
},
|
||||
}
|
||||
if got := f.TotalBytes(); got != 600 {
|
||||
t.Errorf("TotalBytes: got %d, want 600", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTotalBytes_Empty(t *testing.T) {
|
||||
f := File{}
|
||||
if got := f.TotalBytes(); got != 0 {
|
||||
t.Errorf("TotalBytes of empty file: got %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileExtension_Various(t *testing.T) {
|
||||
tests := []struct {
|
||||
subject string
|
||||
want string
|
||||
}{
|
||||
{`"file.MKV" yEnc`, ".mkv"},
|
||||
{`"file.RAR" yEnc`, ".rar"},
|
||||
{`"file.Par2" yEnc`, ".par2"},
|
||||
{`"noext" yEnc`, ""},
|
||||
{`"file.tar.gz" yEnc`, ".gz"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
f := File{Subject: tt.subject}
|
||||
if got := f.Extension(); got != tt.want {
|
||||
t.Errorf("Extension(%q) = %q, want %q", tt.subject, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- LargestFile edge cases ---
|
||||
|
||||
func TestLargestFile_EmptyNZB(t *testing.T) {
|
||||
nzb := &NZB{}
|
||||
if nzb.LargestFile() != nil {
|
||||
t.Error("LargestFile should return nil for empty NZB")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargestFile_SingleFile(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"only.bin"`, Segments: []Segment{{Bytes: 100}}},
|
||||
},
|
||||
}
|
||||
largest := nzb.LargestFile()
|
||||
if largest == nil {
|
||||
t.Fatal("LargestFile should not be nil")
|
||||
}
|
||||
if largest.Filename() != "only.bin" {
|
||||
t.Errorf("got %q", largest.Filename())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargestFile_MultipleSameSize(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"a.bin"`, Segments: []Segment{{Bytes: 100}}},
|
||||
{Subject: `"b.bin"`, Segments: []Segment{{Bytes: 100}}},
|
||||
},
|
||||
}
|
||||
largest := nzb.LargestFile()
|
||||
if largest == nil {
|
||||
t.Fatal("LargestFile should not be nil")
|
||||
}
|
||||
// Should return the first one (stable)
|
||||
if largest.Filename() != "a.bin" {
|
||||
t.Errorf("got %q, expected first file for equal sizes", largest.Filename())
|
||||
}
|
||||
}
|
||||
|
||||
// --- IsObfuscated ---
|
||||
|
||||
func TestIsObfuscated_Normal(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"Movie.2024.1080p.BluRay.x264-GROUP.mkv"`},
|
||||
},
|
||||
}
|
||||
if nzb.IsObfuscated() {
|
||||
t.Error("normal filename should not be obfuscated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsObfuscated_HexName(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"a1b2c3d4e5f6a7b8c9d0e1f2.mkv"`},
|
||||
},
|
||||
}
|
||||
if !nzb.IsObfuscated() {
|
||||
t.Error("hex-like filename should be obfuscated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsObfuscated_EmptyFiles(t *testing.T) {
|
||||
nzb := &NZB{}
|
||||
if nzb.IsObfuscated() {
|
||||
t.Error("empty NZB should not be obfuscated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsObfuscated_ShortHex(t *testing.T) {
|
||||
// Short name (<=10 chars) should not trigger obfuscation
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"abcdef.mkv"`},
|
||||
},
|
||||
}
|
||||
if nzb.IsObfuscated() {
|
||||
t.Error("short hex-like name should not be obfuscated")
|
||||
}
|
||||
}
|
||||
|
||||
// --- isMetadataFile ---
|
||||
|
||||
func TestIsMetadataFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"file.par2", true},
|
||||
{"file.nfo", true},
|
||||
{"file.sfv", true},
|
||||
{"file.nzb", true},
|
||||
{"file.txt", true},
|
||||
{"file.jpg", true},
|
||||
{"file.png", true},
|
||||
{"file.url", true},
|
||||
{"file.mkv", false},
|
||||
{"file.rar", false},
|
||||
{"file.avi", false},
|
||||
{"FILE.PAR2", true},
|
||||
{"FILE.NFO", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isMetadataFile(tt.name); got != tt.want {
|
||||
t.Errorf("isMetadataFile(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- isSampleFile ---
|
||||
|
||||
func TestIsSampleFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"movie.sample.mkv", true},
|
||||
{"Sample.mkv", true},
|
||||
{"SAMPLE.avi", true},
|
||||
{"movie-sample-video.mkv", true},
|
||||
{"movie_sample.mkv", true},
|
||||
{"sample.mkv", true},
|
||||
{"resampled.mkv", false}, // "sample" is part of "resampled"
|
||||
{"movie.mkv", false},
|
||||
{"my.samples.zip", false}, // "sample" followed by 's' (alphanumeric)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isSampleFile(tt.name); got != tt.want {
|
||||
t.Errorf("isSampleFile(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- isHexLike ---
|
||||
|
||||
func TestIsHexLike(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"abcdef0123456789", true},
|
||||
{"ABCDEF", true},
|
||||
{"Movie2024", false},
|
||||
{"aabbccdd", true},
|
||||
{"xyz_not_hex", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isHexLike(tt.input); got != tt.want {
|
||||
t.Errorf("isHexLike(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- sanitizeFilename ---
|
||||
|
||||
func TestSanitizeFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"simple name", "simple name"},
|
||||
{"name (1/50)", "name"},
|
||||
{"file yEnc (01/99)", "file"},
|
||||
{`path/with\special:chars*?`, `path_with_special_chars__`},
|
||||
{`"quoted" text`, `_quoted_ text`},
|
||||
{" spaces ", "spaces"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := sanitizeFilename(tt.input); got != tt.want {
|
||||
t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Filename fallback ---
|
||||
|
||||
func TestFilename_Fallback_NoQuotes(t *testing.T) {
|
||||
f := File{Subject: "No quotes here yEnc (1/50)"}
|
||||
got := f.Filename()
|
||||
if got != "No quotes here" {
|
||||
t.Errorf("Filename fallback: got %q, want %q", got, "No quotes here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilename_EmptySubject(t *testing.T) {
|
||||
f := File{Subject: ""}
|
||||
got := f.Filename()
|
||||
if got != "" {
|
||||
t.Errorf("Filename empty subject: got %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- NZB aggregate methods on mixed content ---
|
||||
|
||||
func TestNZB_HasRars_NoRars(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"movie.mkv"`},
|
||||
{Subject: `"movie.par2"`},
|
||||
},
|
||||
}
|
||||
if nzb.HasRars() {
|
||||
t.Error("HasRars should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNZB_HasPar2_NoPar2(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"movie.mkv"`},
|
||||
{Subject: `"movie.rar"`},
|
||||
},
|
||||
}
|
||||
if nzb.HasPar2() {
|
||||
t.Error("HasPar2 should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNZB_TotalSegments_MultiFile(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Segments: []Segment{{}, {}, {}}},
|
||||
{Segments: []Segment{{}, {}}},
|
||||
},
|
||||
}
|
||||
if got := nzb.TotalSegments(); got != 5 {
|
||||
t.Errorf("TotalSegments: got %d, want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNZB_TotalBytes_MultiFile(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Segments: []Segment{{Bytes: 100}, {Bytes: 200}}},
|
||||
{Segments: []Segment{{Bytes: 300}}},
|
||||
},
|
||||
}
|
||||
if got := nzb.TotalBytes(); got != 600 {
|
||||
t.Errorf("TotalBytes: got %d, want 600", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- isRarFile extended ---
|
||||
|
||||
func TestIsRarFile_Extended(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"file.RAR", true}, // case insensitive
|
||||
{"file.Rar", true},
|
||||
{"file.s01", true},
|
||||
{"file.s99", true},
|
||||
{"file.002", true},
|
||||
{"file.999", true},
|
||||
{"file.r0", false}, // too short extension
|
||||
{"file.rXX", false}, // non-numeric
|
||||
{"file", false}, // no extension
|
||||
{"file.mp4", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isRarFile(tt.name); got != tt.want {
|
||||
t.Errorf("isRarFile(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Parse with date edge cases ---
|
||||
|
||||
func TestParse_DateNonNumeric(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="not-a-number" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if nzb.Files[0].Date != 0 {
|
||||
t.Errorf("non-numeric date should be 0, got %d", nzb.Files[0].Date)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_DateEmpty(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="" subject=""test.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if nzb.Files[0].Date != 0 {
|
||||
t.Errorf("empty date should be 0, got %d", nzb.Files[0].Date)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Parse: file with all segments having empty IDs should be excluded ---
|
||||
|
||||
func TestParse_AllEmptySegments(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""bad.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1"> </segment>
|
||||
<segment bytes="200" number="2"></segment>
|
||||
</segments>
|
||||
</file>
|
||||
<file poster="test" date="0" subject=""good.bin"">
|
||||
<groups><group>alt.test</group></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">valid@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if len(nzb.Files) != 1 {
|
||||
t.Fatalf("expected 1 valid file, got %d", len(nzb.Files))
|
||||
}
|
||||
if nzb.Files[0].Filename() != "good.bin" {
|
||||
t.Errorf("expected good.bin, got %q", nzb.Files[0].Filename())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Groups ---
|
||||
|
||||
func TestParse_NoGroups(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups></groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if len(nzb.Files[0].Groups) != 0 {
|
||||
t.Errorf("expected 0 groups, got %d", len(nzb.Files[0].Groups))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_MultipleGroups(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<nzb xmlns="http://www.newzbin.com/DTD/2003/nzb">
|
||||
<file poster="test" date="0" subject=""test.bin"">
|
||||
<groups>
|
||||
<group>alt.binaries.movies</group>
|
||||
<group>alt.binaries.multimedia</group>
|
||||
<group>alt.binaries.hdtv</group>
|
||||
</groups>
|
||||
<segments>
|
||||
<segment bytes="100" number="1">seg@id</segment>
|
||||
</segments>
|
||||
</file>
|
||||
</nzb>`
|
||||
nzb, err := Parse(strings.NewReader(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if len(nzb.Files[0].Groups) != 3 {
|
||||
t.Errorf("expected 3 groups, got %d", len(nzb.Files[0].Groups))
|
||||
}
|
||||
}
|
||||
|
||||
// --- ContentFiles with sample variations ---
|
||||
|
||||
func TestContentFiles_ExcludesSamples(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"movie.mkv"`, Segments: []Segment{{Bytes: 1000, MessageID: "a"}}},
|
||||
{Subject: `"movie.sample.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "b"}}},
|
||||
{Subject: `"Sample/preview.mkv"`, Segments: []Segment{{Bytes: 100, MessageID: "c"}}},
|
||||
},
|
||||
}
|
||||
content := nzb.ContentFiles()
|
||||
if len(content) != 1 {
|
||||
t.Errorf("ContentFiles should exclude samples: got %d, want 1", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// --- RarFiles with split naming ---
|
||||
|
||||
func TestRarFiles_SplitRars(t *testing.T) {
|
||||
nzb := &NZB{
|
||||
Files: []File{
|
||||
{Subject: `"movie.rar"`, Segments: []Segment{{MessageID: "a"}}},
|
||||
{Subject: `"movie.r00"`, Segments: []Segment{{MessageID: "b"}}},
|
||||
{Subject: `"movie.r01"`, Segments: []Segment{{MessageID: "c"}}},
|
||||
{Subject: `"movie.001"`, Segments: []Segment{{MessageID: "d"}}},
|
||||
{Subject: `"movie.002"`, Segments: []Segment{{MessageID: "e"}}},
|
||||
{Subject: `"movie.par2"`, Segments: []Segment{{MessageID: "f"}}},
|
||||
{Subject: `"movie.mkv"`, Segments: []Segment{{MessageID: "g"}}},
|
||||
},
|
||||
}
|
||||
rars := nzb.RarFiles()
|
||||
if len(rars) != 5 {
|
||||
t.Errorf("RarFiles: got %d, want 5", len(rars))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
170
internal/usenet/postprocess/extract_test.go
Normal file
170
internal/usenet/postprocess/extract_test.go
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
package postprocess
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsArchiveFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"movie.rar", true},
|
||||
{"movie.RAR", true},
|
||||
{"movie.part01.rar", true},
|
||||
{"movie.r00", true},
|
||||
{"movie.r99", true},
|
||||
{"movie.s00", true},
|
||||
{"movie.001", true},
|
||||
{"movie.099", true},
|
||||
{"movie.mkv", false},
|
||||
{"movie.mp4", false},
|
||||
{"movie.par2", false},
|
||||
{"movie.nfo", false},
|
||||
{"movie.txt", false},
|
||||
{"movie.r", false},
|
||||
{"movie.abc", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isArchiveFile(tt.name)
|
||||
if got != tt.want {
|
||||
t.Errorf("isArchiveFile(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCleanupTarget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"content.par2", true},
|
||||
{"content.PAR2", true},
|
||||
{"info.nfo", true},
|
||||
{"checksum.sfv", true},
|
||||
{"content.nzb", true},
|
||||
{"content.srr", true},
|
||||
{"content.srs", true},
|
||||
{"cover.jpg", true},
|
||||
{"cover.png", true},
|
||||
{"readme.txt", true},
|
||||
{"link.url", true},
|
||||
{"movie.rar", true},
|
||||
{"movie.r00", true},
|
||||
{"movie.s01", true},
|
||||
{"movie.001", true},
|
||||
{"movie.mkv", false},
|
||||
{"movie.mp4", false},
|
||||
{"movie.avi", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isCleanupTarget(tt.name)
|
||||
if got != tt.want {
|
||||
t.Errorf("isCleanupTarget(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNumeric(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"", false},
|
||||
{"0", true},
|
||||
{"123", true},
|
||||
{"00", true},
|
||||
{"12a", false},
|
||||
{"abc", false},
|
||||
{" 1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := isNumeric(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListExtractedFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create some files
|
||||
os.WriteFile(filepath.Join(dir, "movie.mkv"), []byte("video"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "subs.srt"), []byte("subs"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "movie.rar"), []byte("archive"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "movie.r00"), []byte("archive part"), 0o644)
|
||||
|
||||
archivePath := filepath.Join(dir, "movie.rar")
|
||||
files, err := listExtractedFiles(dir, archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("listExtractedFiles: %v", err)
|
||||
}
|
||||
|
||||
// Should exclude .rar and .r00 (archive files in same dir)
|
||||
// Should include movie.mkv and subs.srt
|
||||
if len(files) != 2 {
|
||||
t.Errorf("expected 2 files, got %d: %v", len(files), files)
|
||||
}
|
||||
|
||||
for _, f := range files {
|
||||
base := filepath.Base(f)
|
||||
if base != "movie.mkv" && base != "subs.srt" {
|
||||
t.Errorf("unexpected file: %s", base)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Files that should be removed
|
||||
cleanupFiles := []string{"content.par2", "info.nfo", "checksum.sfv", "movie.rar", "movie.r00"}
|
||||
for _, name := range cleanupFiles {
|
||||
os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644)
|
||||
}
|
||||
|
||||
// Files that should be kept
|
||||
keepFiles := []string{"movie.mkv", "subs.srt"}
|
||||
for _, name := range keepFiles {
|
||||
os.WriteFile(filepath.Join(dir, name), []byte("data"), 0o644)
|
||||
}
|
||||
|
||||
err := Cleanup(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Cleanup: %v", err)
|
||||
}
|
||||
|
||||
// Verify cleanup files are gone
|
||||
for _, name := range cleanupFiles {
|
||||
if _, err := os.Stat(filepath.Join(dir, name)); !os.IsNotExist(err) {
|
||||
t.Errorf("expected %s to be removed", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify kept files still exist
|
||||
for _, name := range keepFiles {
|
||||
if _, err := os.Stat(filepath.Join(dir, name)); err != nil {
|
||||
t.Errorf("expected %s to exist, got: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordError(t *testing.T) {
|
||||
err := &PasswordError{Archive: "/tmp/movie.rar"}
|
||||
msg := err.Error()
|
||||
if msg != "archive is password protected: /tmp/movie.rar" {
|
||||
t.Errorf("PasswordError.Error() = %q", msg)
|
||||
}
|
||||
}
|
||||
156
internal/usenet/postprocess/pipeline_test.go
Normal file
156
internal/usenet/postprocess/pipeline_test.go
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
package postprocess
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFindPar2File(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create par2 files of different sizes
|
||||
mainPar2 := filepath.Join(dir, "content.par2")
|
||||
vol1 := filepath.Join(dir, "content.vol000+01.par2")
|
||||
vol2 := filepath.Join(dir, "content.vol001+02.par2")
|
||||
|
||||
os.WriteFile(mainPar2, make([]byte, 100), 0o644) // smallest
|
||||
os.WriteFile(vol1, make([]byte, 10000), 0o644)
|
||||
os.WriteFile(vol2, make([]byte, 50000), 0o644)
|
||||
|
||||
files := map[string]string{
|
||||
"content.par2": mainPar2,
|
||||
"content.vol000+01.par2": vol1,
|
||||
"content.vol001+02.par2": vol2,
|
||||
}
|
||||
|
||||
result := findPar2File(files)
|
||||
if result != mainPar2 {
|
||||
t.Errorf("findPar2File() = %q, want %q (smallest par2)", result, mainPar2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPar2FileNone(t *testing.T) {
|
||||
files := map[string]string{
|
||||
"video.mkv": "/tmp/video.mkv",
|
||||
"subs.srt": "/tmp/subs.srt",
|
||||
}
|
||||
|
||||
result := findPar2File(files)
|
||||
if result != "" {
|
||||
t.Errorf("findPar2File() = %q, want empty", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPar2FileEmpty(t *testing.T) {
|
||||
result := findPar2File(map[string]string{})
|
||||
if result != "" {
|
||||
t.Errorf("findPar2File() = %q, want empty", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindFirstRarPart01(t *testing.T) {
|
||||
files := map[string]string{
|
||||
"movie.part01.rar": "/tmp/movie.part01.rar",
|
||||
"movie.part02.rar": "/tmp/movie.part02.rar",
|
||||
"movie.part03.rar": "/tmp/movie.part03.rar",
|
||||
}
|
||||
|
||||
result := findFirstRar(files)
|
||||
if result != "/tmp/movie.part01.rar" {
|
||||
t.Errorf("findFirstRar() = %q, want part01.rar", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindFirstRarSingle(t *testing.T) {
|
||||
files := map[string]string{
|
||||
"movie.rar": "/tmp/movie.rar",
|
||||
"movie.r00": "/tmp/movie.r00",
|
||||
"movie.r01": "/tmp/movie.r01",
|
||||
}
|
||||
|
||||
result := findFirstRar(files)
|
||||
if result != "/tmp/movie.rar" {
|
||||
t.Errorf("findFirstRar() = %q, want movie.rar (shortest)", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindFirstRarSplitFormat(t *testing.T) {
|
||||
files := map[string]string{
|
||||
"movie.001": "/tmp/movie.001",
|
||||
"movie.002": "/tmp/movie.002",
|
||||
}
|
||||
|
||||
result := findFirstRar(files)
|
||||
if result != "/tmp/movie.001" {
|
||||
t.Errorf("findFirstRar() = %q, want movie.001", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindFirstRarNone(t *testing.T) {
|
||||
files := map[string]string{
|
||||
"video.mkv": "/tmp/video.mkv",
|
||||
"subs.srt": "/tmp/subs.srt",
|
||||
}
|
||||
|
||||
result := findFirstRar(files)
|
||||
if result != "" {
|
||||
t.Errorf("findFirstRar() = %q, want empty", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindMainFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create video files of different sizes
|
||||
small := filepath.Join(dir, "small.mkv")
|
||||
large := filepath.Join(dir, "large.mkv")
|
||||
nonVideo := filepath.Join(dir, "readme.txt")
|
||||
|
||||
os.WriteFile(small, make([]byte, 1000), 0o644)
|
||||
os.WriteFile(large, make([]byte, 5000), 0o644)
|
||||
os.WriteFile(nonVideo, make([]byte, 9000), 0o644)
|
||||
|
||||
result := findMainFile(dir, []string{small, large, nonVideo})
|
||||
if result != large {
|
||||
t.Errorf("findMainFile() = %q, want %q (largest video)", result, large)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindMainFileFallbackToDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
video := filepath.Join(dir, "movie.mp4")
|
||||
os.WriteFile(video, make([]byte, 5000), 0o644)
|
||||
|
||||
// Pass empty file list — should fallback to scanning dir
|
||||
result := findMainFile(dir, nil)
|
||||
if result != video {
|
||||
t.Errorf("findMainFile() = %q, want %q (dir scan fallback)", result, video)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindMainFileEmpty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
result := findMainFile(dir, nil)
|
||||
if result != "" {
|
||||
t.Errorf("findMainFile() = %q, want empty", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindMainFileMultipleFormats(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
mkv := filepath.Join(dir, "movie.mkv")
|
||||
mp4 := filepath.Join(dir, "movie.mp4")
|
||||
avi := filepath.Join(dir, "movie.avi")
|
||||
|
||||
os.WriteFile(mkv, make([]byte, 3000), 0o644)
|
||||
os.WriteFile(mp4, make([]byte, 5000), 0o644) // largest
|
||||
os.WriteFile(avi, make([]byte, 2000), 0o644)
|
||||
|
||||
result := findMainFile(dir, []string{mkv, mp4, avi})
|
||||
if result != mp4 {
|
||||
t.Errorf("findMainFile() = %q, want %q", result, mp4)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue