feat: initial commit — unarr CLI
Search, inspect, stream, and download torrents from the terminal. Replaces the entire *arr stack with a single binary.
This commit is contained in:
commit
29cf0a0126
85 changed files with 10178 additions and 0 deletions
41
internal/engine/debrid.go
Normal file
41
internal/engine/debrid.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
tc "github.com/torrentclaw/go-client"
|
||||
)
|
||||
|
||||
// DebridDownloader downloads via debrid services (Real-Debrid, AllDebrid, etc.).
|
||||
// Currently a stub — Available() works, Download() returns not-implemented.
|
||||
type DebridDownloader struct {
|
||||
apiClient *tc.Client
|
||||
}
|
||||
|
||||
// NewDebridDownloader creates a debrid downloader stub.
|
||||
func NewDebridDownloader(apiClient *tc.Client) *DebridDownloader {
|
||||
return &DebridDownloader{apiClient: apiClient}
|
||||
}
|
||||
|
||||
func (d *DebridDownloader) Method() DownloadMethod { return MethodDebrid }
|
||||
|
||||
func (d *DebridDownloader) Available(ctx context.Context, task *Task) (bool, error) {
|
||||
if d.apiClient == nil {
|
||||
return false, nil
|
||||
}
|
||||
resp, err := d.apiClient.DebridCheckCache(ctx, "", "", []string{task.InfoHash})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
cached, ok := resp.Cached[task.InfoHash]
|
||||
return ok && cached, nil
|
||||
}
|
||||
|
||||
func (d *DebridDownloader) Download(_ context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
|
||||
return nil, fmt.Errorf("debrid download not implemented yet (coming in a future release)")
|
||||
}
|
||||
|
||||
func (d *DebridDownloader) Pause(_ string) error { return nil }
|
||||
func (d *DebridDownloader) Cancel(_ string) error { return nil }
|
||||
func (d *DebridDownloader) Shutdown(_ context.Context) error { return nil }
|
||||
362
internal/engine/manager.go
Normal file
362
internal/engine/manager.go
Normal file
|
|
@ -0,0 +1,362 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
)
|
||||
|
||||
// ManagerConfig holds download manager settings.
|
||||
type ManagerConfig struct {
|
||||
MaxConcurrent int
|
||||
OutputDir string
|
||||
Organize OrganizeConfig
|
||||
Notifications bool // send desktop notifications on complete/fail
|
||||
}
|
||||
|
||||
// Manager orchestrates concurrent downloads with method resolution and fallback.
|
||||
type Manager struct {
|
||||
cfg ManagerConfig
|
||||
reporter *ProgressReporter
|
||||
downloaders map[DownloadMethod]Downloader
|
||||
|
||||
activeMu sync.RWMutex
|
||||
active map[string]*Task
|
||||
|
||||
sem chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewManager creates a download manager.
|
||||
func NewManager(cfg ManagerConfig, reporter *ProgressReporter, downloaders ...Downloader) *Manager {
|
||||
if cfg.MaxConcurrent <= 0 {
|
||||
cfg.MaxConcurrent = 3
|
||||
}
|
||||
|
||||
dlMap := make(map[DownloadMethod]Downloader)
|
||||
for _, d := range downloaders {
|
||||
dlMap[d.Method()] = d
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
cfg: cfg,
|
||||
reporter: reporter,
|
||||
downloaders: dlMap,
|
||||
active: make(map[string]*Task),
|
||||
sem: make(chan struct{}, cfg.MaxConcurrent),
|
||||
}
|
||||
}
|
||||
|
||||
// Submit queues a task for download. Non-blocking if capacity available.
|
||||
func (m *Manager) Submit(ctx context.Context, at agent.Task) {
|
||||
task := NewTaskFromAgent(at)
|
||||
|
||||
m.activeMu.Lock()
|
||||
m.active[task.ID] = task
|
||||
m.activeMu.Unlock()
|
||||
|
||||
m.reporter.Track(task)
|
||||
|
||||
// Acquire semaphore slot
|
||||
select {
|
||||
case m.sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
defer func() { <-m.sem }()
|
||||
m.processTask(ctx, task)
|
||||
}()
|
||||
}
|
||||
|
||||
// HasCapacity returns true if there's room for more downloads.
|
||||
func (m *Manager) HasCapacity() bool {
|
||||
return len(m.sem) < cap(m.sem)
|
||||
}
|
||||
|
||||
// ActiveCount returns the number of in-progress downloads.
|
||||
func (m *Manager) ActiveCount() int {
|
||||
m.activeMu.RLock()
|
||||
defer m.activeMu.RUnlock()
|
||||
return len(m.active)
|
||||
}
|
||||
|
||||
// GetTask returns a single active task by ID, or nil.
|
||||
func (m *Manager) GetTask(taskID string) *Task {
|
||||
m.activeMu.RLock()
|
||||
defer m.activeMu.RUnlock()
|
||||
return m.active[taskID]
|
||||
}
|
||||
|
||||
// ActiveTasks returns a snapshot of all active tasks.
|
||||
func (m *Manager) ActiveTasks() []*Task {
|
||||
m.activeMu.RLock()
|
||||
defer m.activeMu.RUnlock()
|
||||
tasks := make([]*Task, 0, len(m.active))
|
||||
for _, t := range m.active {
|
||||
tasks = append(tasks, t)
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
// CancelTask cancels an active download by task ID (keeps partial files).
|
||||
func (m *Manager) CancelTask(taskID string) {
|
||||
m.activeMu.RLock()
|
||||
task, ok := m.active[taskID]
|
||||
m.activeMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if dl, exists := m.downloaders[task.ResolvedMethod]; exists {
|
||||
dl.Pause(taskID) // stop download, keep files
|
||||
}
|
||||
|
||||
task.mu.Lock()
|
||||
task.ErrorMessage = "cancelled by user"
|
||||
task.mu.Unlock()
|
||||
task.Transition(StatusCancelled)
|
||||
|
||||
log.Printf("[%s] cancelled: %s", taskID[:8], task.Title)
|
||||
}
|
||||
|
||||
// PauseTask pauses an active download (keeps partial files for resume).
|
||||
func (m *Manager) PauseTask(taskID string) {
|
||||
m.activeMu.RLock()
|
||||
task, ok := m.active[taskID]
|
||||
m.activeMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if dl, exists := m.downloaders[task.ResolvedMethod]; exists {
|
||||
dl.Pause(taskID) // stop download, keep files for resume
|
||||
}
|
||||
|
||||
task.Transition(StatusCancelled) // will be re-created as pending by server
|
||||
log.Printf("[%s] paused: %s", taskID[:8], task.Title)
|
||||
}
|
||||
|
||||
// CancelAndDeleteFiles cancels a download and removes its files from disk.
|
||||
func (m *Manager) CancelAndDeleteFiles(taskID string) {
|
||||
m.activeMu.RLock()
|
||||
task, ok := m.active[taskID]
|
||||
m.activeMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if dl, exists := m.downloaders[task.ResolvedMethod]; exists {
|
||||
dl.Cancel(taskID) // stop download + delete files
|
||||
}
|
||||
|
||||
task.mu.Lock()
|
||||
task.ErrorMessage = "cancelled by user"
|
||||
task.mu.Unlock()
|
||||
task.Transition(StatusCancelled)
|
||||
|
||||
log.Printf("[%s] cancelled + files deleted: %s", taskID[:8], task.Title)
|
||||
}
|
||||
|
||||
// Wait blocks until all active downloads finish.
|
||||
func (m *Manager) Wait() {
|
||||
m.wg.Wait()
|
||||
}
|
||||
|
||||
// Shutdown stops accepting tasks and waits for active downloads to finish.
|
||||
func (m *Manager) Shutdown(ctx context.Context) {
|
||||
// Wait for goroutines with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
m.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
log.Println("shutdown timeout, cancelling active downloads")
|
||||
}
|
||||
|
||||
// Shutdown all downloaders
|
||||
for _, d := range m.downloaders {
|
||||
if err := d.Shutdown(ctx); err != nil {
|
||||
log.Printf("downloader shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean active map
|
||||
m.activeMu.Lock()
|
||||
m.active = make(map[string]*Task)
|
||||
m.activeMu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) processTask(ctx context.Context, task *Task) {
|
||||
defer func() {
|
||||
m.activeMu.Lock()
|
||||
delete(m.active, task.ID)
|
||||
m.activeMu.Unlock()
|
||||
}()
|
||||
|
||||
// 1. Resolve method
|
||||
if err := task.Transition(StatusResolving); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
method, err := resolveMethod(ctx, task, m.downloaders)
|
||||
if err != nil {
|
||||
m.fail(ctx, task, "no method available: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
task.ResolvedMethod = method
|
||||
log.Printf("[%s] resolved method: %s", task.ID[:8], method)
|
||||
|
||||
// 2. Download
|
||||
if err := task.Transition(StatusDownloading); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
progressCh := make(chan Progress, 16)
|
||||
|
||||
// Drain progress channel (just for logging; reporter reads directly from task)
|
||||
go func() {
|
||||
for range progressCh {
|
||||
// Progress already applied via task.UpdateProgress in the downloader
|
||||
}
|
||||
}()
|
||||
|
||||
dl := m.downloaders[method]
|
||||
result, err := dl.Download(ctx, task, m.cfg.OutputDir, progressCh)
|
||||
close(progressCh)
|
||||
|
||||
if err != nil {
|
||||
// Try fallback
|
||||
if tryFallback(task, m.downloaders) {
|
||||
log.Printf("[%s] %s failed, trying fallback: %v", task.ID[:8], method, err)
|
||||
if err := task.Transition(StatusResolving); err == nil {
|
||||
m.processTaskRetry(ctx, task)
|
||||
return
|
||||
}
|
||||
}
|
||||
m.fail(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Verify
|
||||
if err := task.Transition(StatusVerifying); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := verify(result); err != nil {
|
||||
m.fail(ctx, task, "verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Organize
|
||||
if err := task.Transition(StatusOrganizing); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
finalPath, err := organize(result, task, m.cfg.Organize)
|
||||
if err != nil {
|
||||
log.Printf("[%s] organize warning: %v (keeping in download dir)", task.ID[:8], err)
|
||||
finalPath = result.FilePath
|
||||
}
|
||||
|
||||
task.mu.Lock()
|
||||
task.FilePath = finalPath
|
||||
task.mu.Unlock()
|
||||
|
||||
// 5. Complete
|
||||
if method == MethodTorrent && m.cfg.Organize.Enabled {
|
||||
// Could add seeding here in the future
|
||||
}
|
||||
|
||||
if err := task.Transition(StatusCompleted); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[%s] completed: %s -> %s", task.ID[:8], task.Title, finalPath)
|
||||
if m.cfg.Notifications {
|
||||
desktopNotify("Download complete", task.Title)
|
||||
}
|
||||
m.reporter.ReportFinal(ctx, task)
|
||||
}
|
||||
|
||||
// processTaskRetry handles fallback after a method failure.
|
||||
func (m *Manager) processTaskRetry(ctx context.Context, task *Task) {
|
||||
method, err := resolveMethod(ctx, task, m.downloaders)
|
||||
if err != nil {
|
||||
m.fail(ctx, task, "fallback failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
task.ResolvedMethod = method
|
||||
log.Printf("[%s] fallback to: %s", task.ID[:8], method)
|
||||
|
||||
if err := task.Transition(StatusDownloading); err != nil {
|
||||
m.fail(ctx, task, "transition error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
progressCh := make(chan Progress, 16)
|
||||
go func() {
|
||||
for range progressCh {
|
||||
}
|
||||
}()
|
||||
|
||||
dl := m.downloaders[method]
|
||||
result, err := dl.Download(ctx, task, m.cfg.OutputDir, progressCh)
|
||||
close(progressCh)
|
||||
|
||||
if err != nil {
|
||||
m.fail(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Verify + Organize + Complete (same as processTask)
|
||||
task.Transition(StatusVerifying)
|
||||
if err := verify(result); err != nil {
|
||||
m.fail(ctx, task, "verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
task.Transition(StatusOrganizing)
|
||||
finalPath, _ := organize(result, task, m.cfg.Organize)
|
||||
if finalPath == "" {
|
||||
finalPath = result.FilePath
|
||||
}
|
||||
task.mu.Lock()
|
||||
task.FilePath = finalPath
|
||||
task.mu.Unlock()
|
||||
|
||||
task.Transition(StatusCompleted)
|
||||
log.Printf("[%s] completed (fallback): %s -> %s", task.ID[:8], task.Title, finalPath)
|
||||
m.reporter.ReportFinal(ctx, task)
|
||||
}
|
||||
|
||||
func (m *Manager) fail(ctx context.Context, task *Task, msg string) {
|
||||
task.mu.Lock()
|
||||
task.ErrorMessage = msg
|
||||
task.mu.Unlock()
|
||||
task.Transition(StatusFailed)
|
||||
log.Printf("[%s] FAILED: %s — %s", task.ID[:8], task.Title, msg)
|
||||
if m.cfg.Notifications {
|
||||
desktopNotify("Download failed", task.Title+": "+msg)
|
||||
}
|
||||
m.reporter.ReportFinal(ctx, task)
|
||||
}
|
||||
85
internal/engine/manager_test.go
Normal file
85
internal/engine/manager_test.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
)
|
||||
|
||||
func TestManagerSubmitAndWait(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
|
||||
dl := &mockDownloader{method: MethodTorrent, available: true}
|
||||
mgr := NewManager(ManagerConfig{
|
||||
MaxConcurrent: 2,
|
||||
OutputDir: t.TempDir(),
|
||||
}, reporter, dl)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go reporter.Run(ctx)
|
||||
|
||||
mgr.Submit(ctx, agent.Task{
|
||||
ID: "test-task-1",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "Test Movie",
|
||||
PreferredMethod: "torrent",
|
||||
})
|
||||
|
||||
mgr.Wait()
|
||||
|
||||
// Task should have been processed (completed or failed depending on verify)
|
||||
// Since mock returns a file that doesn't exist, it may fail at verify
|
||||
// This is expected — we're testing the pipeline works
|
||||
}
|
||||
|
||||
func TestManagerHasCapacity(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
|
||||
mgr := NewManager(ManagerConfig{MaxConcurrent: 2}, reporter)
|
||||
|
||||
if !mgr.HasCapacity() {
|
||||
t.Error("new manager should have capacity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerActiveCount(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
|
||||
mgr := NewManager(ManagerConfig{MaxConcurrent: 3}, reporter)
|
||||
|
||||
if mgr.ActiveCount() != 0 {
|
||||
t.Errorf("ActiveCount = %d, want 0", mgr.ActiveCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerShutdown(t *testing.T) {
|
||||
reporter := NewProgressReporter(
|
||||
agent.NewClient("http://localhost", "test", "test"),
|
||||
1*time.Second,
|
||||
)
|
||||
|
||||
dl := &mockDownloader{method: MethodTorrent, available: true}
|
||||
mgr := NewManager(ManagerConfig{
|
||||
MaxConcurrent: 1,
|
||||
OutputDir: t.TempDir(),
|
||||
}, reporter, dl)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mgr.Shutdown(ctx)
|
||||
// Should not hang
|
||||
}
|
||||
58
internal/engine/method.go
Normal file
58
internal/engine/method.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package engine
|
||||
|
||||
import "context"
|
||||
|
||||
// DownloadMethod identifies a download strategy.
|
||||
type DownloadMethod string
|
||||
|
||||
const (
|
||||
MethodTorrent DownloadMethod = "torrent"
|
||||
MethodDebrid DownloadMethod = "debrid"
|
||||
MethodUsenet DownloadMethod = "usenet"
|
||||
)
|
||||
|
||||
// Progress is emitted by downloaders during a download.
|
||||
type Progress struct {
|
||||
DownloadedBytes int64
|
||||
TotalBytes int64
|
||||
SpeedBps int64 // bytes per second
|
||||
ETA int // seconds remaining
|
||||
Peers int // connected peers (torrent only)
|
||||
Seeds int // connected seeds (torrent only)
|
||||
FileName string
|
||||
}
|
||||
|
||||
// Result is returned when a download completes successfully.
|
||||
type Result struct {
|
||||
FilePath string
|
||||
FileName string
|
||||
Method DownloadMethod
|
||||
Size int64
|
||||
}
|
||||
|
||||
// Downloader is the interface every download method must implement.
|
||||
type Downloader interface {
|
||||
// Method returns which method this downloader implements.
|
||||
Method() DownloadMethod
|
||||
|
||||
// Available reports whether this method can handle the given task.
|
||||
// For torrent: always true if infoHash is set.
|
||||
// For debrid: checks if cached on debrid service.
|
||||
// For usenet: checks if NZB is available.
|
||||
Available(ctx context.Context, task *Task) (bool, error)
|
||||
|
||||
// Download starts the download. It blocks until completion or error.
|
||||
// Progress is reported via progressCh at regular intervals.
|
||||
// outputDir is where files should be written.
|
||||
Download(ctx context.Context, task *Task, outputDir string, progressCh chan<- Progress) (*Result, error)
|
||||
|
||||
// Pause suspends an in-progress download but keeps partial files on disk
|
||||
// so the download can be resumed later.
|
||||
Pause(taskID string) error
|
||||
|
||||
// Cancel aborts an in-progress download and removes partial files.
|
||||
Cancel(taskID string) error
|
||||
|
||||
// Shutdown gracefully shuts down the downloader.
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
30
internal/engine/notify.go
Normal file
30
internal/engine/notify.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// desktopNotify sends a best-effort desktop notification.
|
||||
// Silent failure — never blocks or errors.
|
||||
func desktopNotify(title, body string) {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
exec.Command("notify-send", title, body, "--icon=dialog-information", "--app-name=unarr").Start()
|
||||
case "darwin":
|
||||
script := `display notification "` + escapeAppleScript(body) + `" with title "` + escapeAppleScript(title) + `"`
|
||||
exec.Command("osascript", "-e", script).Start()
|
||||
}
|
||||
// Windows: no-op for now
|
||||
}
|
||||
|
||||
func escapeAppleScript(s string) string {
|
||||
out := make([]byte, 0, len(s))
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == '"' || s[i] == '\\' {
|
||||
out = append(out, '\\')
|
||||
}
|
||||
out = append(out, s[i])
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
129
internal/engine/organize.go
Normal file
129
internal/engine/organize.go
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
yearRegex = regexp.MustCompile(`\b(19|20)\d{2}\b`)
|
||||
seasonRegex = regexp.MustCompile(`(?i)S(\d{2})`)
|
||||
)
|
||||
|
||||
// OrganizeConfig holds file organization settings.
|
||||
type OrganizeConfig struct {
|
||||
Enabled bool
|
||||
MoviesDir string
|
||||
TVShowsDir string
|
||||
}
|
||||
|
||||
// organize moves a downloaded file into the proper directory structure.
|
||||
// Movies: MoviesDir/Title (Year)/filename.ext
|
||||
// TV: TVShowsDir/Title/Season XX/filename.ext
|
||||
func organize(result *Result, task *Task, cfg OrganizeConfig) (string, error) {
|
||||
if !cfg.Enabled || result == nil || result.FilePath == "" {
|
||||
return result.FilePath, nil
|
||||
}
|
||||
|
||||
title := task.Title
|
||||
if title == "" {
|
||||
title = result.FileName
|
||||
}
|
||||
|
||||
isTV := strings.Contains(strings.ToLower(task.PreferredMethod), "show") ||
|
||||
seasonRegex.MatchString(result.FileName)
|
||||
|
||||
// Detect season for TV
|
||||
var season string
|
||||
if m := seasonRegex.FindStringSubmatch(result.FileName); len(m) > 1 {
|
||||
season = m[1]
|
||||
isTV = true
|
||||
}
|
||||
|
||||
var destDir string
|
||||
if isTV && cfg.TVShowsDir != "" {
|
||||
showName := cleanTitle(title)
|
||||
destDir = filepath.Join(cfg.TVShowsDir, showName)
|
||||
if season != "" {
|
||||
destDir = filepath.Join(destDir, fmt.Sprintf("Season %s", season))
|
||||
}
|
||||
} else if cfg.MoviesDir != "" {
|
||||
movieName := cleanTitle(title)
|
||||
year := yearRegex.FindString(title)
|
||||
if year != "" {
|
||||
destDir = filepath.Join(cfg.MoviesDir, fmt.Sprintf("%s (%s)", movieName, year))
|
||||
} else {
|
||||
destDir = filepath.Join(cfg.MoviesDir, movieName)
|
||||
}
|
||||
} else {
|
||||
return result.FilePath, nil // no organize dirs configured
|
||||
}
|
||||
|
||||
// Validate destination is within the expected base directory
|
||||
var baseDir string
|
||||
if isTV && cfg.TVShowsDir != "" {
|
||||
baseDir = cfg.TVShowsDir
|
||||
} else {
|
||||
baseDir = cfg.MoviesDir
|
||||
}
|
||||
if !isWithinDir(baseDir, destDir) {
|
||||
return "", fmt.Errorf("path traversal blocked: %q escapes %q", destDir, baseDir)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(destDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("create dir: %w", err)
|
||||
}
|
||||
|
||||
destPath := filepath.Join(destDir, filepath.Base(result.FilePath))
|
||||
|
||||
// Try rename first (same filesystem), fall back to copy+delete
|
||||
if err := os.Rename(result.FilePath, destPath); err != nil {
|
||||
if err := copyFile(result.FilePath, destPath); err != nil {
|
||||
return "", fmt.Errorf("move file: %w", err)
|
||||
}
|
||||
os.Remove(result.FilePath)
|
||||
}
|
||||
|
||||
return destPath, nil
|
||||
}
|
||||
|
||||
// cleanTitle extracts a clean title from a torrent title string.
|
||||
func cleanTitle(title string) string {
|
||||
// Remove year and everything after common separators
|
||||
t := title
|
||||
if idx := strings.Index(t, " ("); idx > 0 {
|
||||
t = t[:idx]
|
||||
}
|
||||
// Remove resolution and codec markers
|
||||
for _, pattern := range []string{"1080p", "720p", "2160p", "480p", "BluRay", "WEB-DL", "HDTV", "x264", "x265", "HEVC"} {
|
||||
if idx := strings.Index(strings.ToLower(t), strings.ToLower(pattern)); idx > 0 {
|
||||
t = t[:idx]
|
||||
}
|
||||
}
|
||||
t = strings.TrimRight(t, " .-_")
|
||||
if t == "" {
|
||||
return title
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
s, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
d, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer d.Close()
|
||||
|
||||
_, err = io.Copy(d, s)
|
||||
return err
|
||||
}
|
||||
92
internal/engine/organize_test.go
Normal file
92
internal/engine/organize_test.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOrganizeDisabled(t *testing.T) {
|
||||
r := &Result{FilePath: "/tmp/file.mkv", FileName: "file.mkv"}
|
||||
task := &Task{Title: "Movie"}
|
||||
path, err := organize(r, task, OrganizeConfig{Enabled: false})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if path != "/tmp/file.mkv" {
|
||||
t.Errorf("path = %q, want original path when disabled", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizeMovie(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
srcDir := filepath.Join(tmp, "src")
|
||||
os.MkdirAll(srcDir, 0o755)
|
||||
srcFile := filepath.Join(srcDir, "Movie.2023.1080p.mkv")
|
||||
os.WriteFile(srcFile, []byte("data"), 0o644)
|
||||
|
||||
moviesDir := filepath.Join(tmp, "Movies")
|
||||
|
||||
r := &Result{FilePath: srcFile, FileName: "Movie.2023.1080p.mkv"}
|
||||
task := &Task{Title: "Movie 2023"}
|
||||
|
||||
path, err := organize(r, task, OrganizeConfig{
|
||||
Enabled: true,
|
||||
MoviesDir: moviesDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should be in Movies/Movie (2023)/
|
||||
if path == srcFile {
|
||||
t.Error("file should have moved")
|
||||
}
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Errorf("organized file should exist at %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizeTVShow(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
srcFile := filepath.Join(tmp, "Show.S02E05.1080p.mkv")
|
||||
os.WriteFile(srcFile, []byte("data"), 0o644)
|
||||
|
||||
tvDir := filepath.Join(tmp, "TV Shows")
|
||||
|
||||
r := &Result{FilePath: srcFile, FileName: "Show.S02E05.1080p.mkv"}
|
||||
task := &Task{Title: "Show S02E05"}
|
||||
|
||||
path, err := organize(r, task, OrganizeConfig{
|
||||
Enabled: true,
|
||||
TVShowsDir: tvDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should detect season from filename S02
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Errorf("organized file should exist at %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanTitle(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"The Matrix (1999)", "The Matrix"},
|
||||
{"Oppenheimer 2023 1080p BluRay", "Oppenheimer 2023"},
|
||||
{"Movie", "Movie"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := cleanTitle(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("cleanTitle(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
137
internal/engine/progress.go
Normal file
137
internal/engine/progress.go
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
)
|
||||
|
||||
// ActionFunc is called when the server signals an action on a task.
|
||||
type ActionFunc func(taskID string)
|
||||
|
||||
// ProgressReporter aggregates progress from downloads and reports to the API.
|
||||
// It batches updates to avoid flooding the server.
|
||||
type ProgressReporter struct {
|
||||
agentClient *agent.Client
|
||||
interval time.Duration
|
||||
|
||||
onCancel ActionFunc
|
||||
onPause ActionFunc
|
||||
onDeleteFiles ActionFunc
|
||||
onStreamRequested ActionFunc
|
||||
|
||||
mu sync.Mutex
|
||||
latest map[string]*Task // taskID -> task with latest progress
|
||||
}
|
||||
|
||||
// NewProgressReporter creates a reporter that flushes every interval.
|
||||
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
|
||||
return &ProgressReporter{
|
||||
agentClient: ac,
|
||||
interval: interval,
|
||||
latest: make(map[string]*Task),
|
||||
}
|
||||
}
|
||||
|
||||
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
|
||||
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
|
||||
|
||||
// SetPauseHandler sets the callback invoked when the server says a task is paused.
|
||||
func (r *ProgressReporter) SetPauseHandler(fn ActionFunc) { r.onPause = fn }
|
||||
|
||||
// SetDeleteFilesHandler sets the callback for cancel+delete files.
|
||||
func (r *ProgressReporter) SetDeleteFilesHandler(fn ActionFunc) { r.onDeleteFiles = fn }
|
||||
|
||||
// SetStreamRequestedHandler sets the callback for stream activation.
|
||||
func (r *ProgressReporter) SetStreamRequestedHandler(fn ActionFunc) { r.onStreamRequested = fn }
|
||||
|
||||
// Track registers a task for progress tracking.
|
||||
func (r *ProgressReporter) Track(task *Task) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.latest[task.ID] = task
|
||||
}
|
||||
|
||||
// Untrack removes a task from progress tracking.
|
||||
func (r *ProgressReporter) Untrack(taskID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.latest, taskID)
|
||||
}
|
||||
|
||||
// Run starts the periodic flush loop. Blocks until ctx is cancelled.
|
||||
func (r *ProgressReporter) Run(ctx context.Context) error {
|
||||
ticker := time.NewTicker(r.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.flush(context.Background())
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
r.flush(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ProgressReporter) flush(ctx context.Context) {
|
||||
r.mu.Lock()
|
||||
tasks := make([]*Task, 0, len(r.latest))
|
||||
for _, t := range r.latest {
|
||||
tasks = append(tasks, t)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, task := range tasks {
|
||||
status := task.GetStatus()
|
||||
if status != StatusDownloading && status != StatusVerifying &&
|
||||
status != StatusOrganizing && status != StatusSeeding &&
|
||||
status != StatusCompleted && status != StatusFailed {
|
||||
continue
|
||||
}
|
||||
|
||||
update := task.ToStatusUpdate()
|
||||
resp, err := r.agentClient.ReportStatus(ctx, update)
|
||||
if err != nil {
|
||||
log.Printf("[%s] progress report failed: %v", task.ID[:8], err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle server-side signals
|
||||
if resp.Cancelled {
|
||||
log.Printf("[%s] cancelled by user (via web)", task.ID[:8])
|
||||
r.Untrack(task.ID)
|
||||
if resp.DeleteFiles && r.onDeleteFiles != nil {
|
||||
r.onDeleteFiles(task.ID)
|
||||
} else if r.onCancel != nil {
|
||||
r.onCancel(task.ID)
|
||||
}
|
||||
} else if resp.Paused {
|
||||
log.Printf("[%s] paused by user (via web)", task.ID[:8])
|
||||
r.Untrack(task.ID)
|
||||
if r.onPause != nil {
|
||||
r.onPause(task.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StreamRequested && task.GetStreamURL() == "" {
|
||||
log.Printf("[%s] stream requested by user (via web)", task.ID[:8])
|
||||
if r.onStreamRequested != nil {
|
||||
r.onStreamRequested(task.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReportFinal sends a final status update for a completed/failed task.
|
||||
func (r *ProgressReporter) ReportFinal(ctx context.Context, task *Task) {
|
||||
update := task.ToStatusUpdate()
|
||||
if _, err := r.agentClient.ReportStatus(ctx, update); err != nil {
|
||||
log.Printf("[%s] final report failed: %v", task.ID[:8], err)
|
||||
}
|
||||
r.Untrack(task.ID)
|
||||
}
|
||||
75
internal/engine/resolve.go
Normal file
75
internal/engine/resolve.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// resolveMethod determines which download method to use for a task.
|
||||
// For "auto": tries available methods in priority order (torrent > debrid > usenet).
|
||||
// For specific method: uses only that method.
|
||||
func resolveMethod(ctx context.Context, task *Task, downloaders map[DownloadMethod]Downloader) (DownloadMethod, error) {
|
||||
var order []DownloadMethod
|
||||
switch task.PreferredMethod {
|
||||
case "torrent":
|
||||
order = []DownloadMethod{MethodTorrent}
|
||||
case "debrid":
|
||||
order = []DownloadMethod{MethodDebrid}
|
||||
case "usenet":
|
||||
order = []DownloadMethod{MethodUsenet}
|
||||
default: // "auto"
|
||||
order = []DownloadMethod{MethodTorrent, MethodDebrid, MethodUsenet}
|
||||
}
|
||||
|
||||
for _, method := range order {
|
||||
// Skip already-tried methods
|
||||
tried := false
|
||||
for _, tm := range task.TriedMethods {
|
||||
if tm == method {
|
||||
tried = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if tried {
|
||||
continue
|
||||
}
|
||||
|
||||
dl, ok := downloaders[method]
|
||||
if !ok {
|
||||
continue // downloader not registered
|
||||
}
|
||||
|
||||
available, err := dl.Available(ctx, task)
|
||||
if err != nil {
|
||||
taskID := task.ID
|
||||
if len(taskID) > 8 {
|
||||
taskID = taskID[:8]
|
||||
}
|
||||
log.Printf("[%s] %s availability check failed: %v", taskID, method, err)
|
||||
continue
|
||||
}
|
||||
if available {
|
||||
return method, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no download method available (tried: %v)", task.TriedMethods)
|
||||
}
|
||||
|
||||
// tryFallback attempts to fall back to the next untried download method.
|
||||
// Returns true if fallback was initiated, false if no more methods.
|
||||
func tryFallback(task *Task, downloaders map[DownloadMethod]Downloader) bool {
|
||||
if task.PreferredMethod != "auto" {
|
||||
return false // specific method requested, no fallback
|
||||
}
|
||||
|
||||
task.TriedMethods = append(task.TriedMethods, task.ResolvedMethod)
|
||||
|
||||
available := make([]DownloadMethod, 0, len(downloaders))
|
||||
for m := range downloaders {
|
||||
available = append(available, m)
|
||||
}
|
||||
|
||||
return task.HasUntried(available)
|
||||
}
|
||||
141
internal/engine/resolve_test.go
Normal file
141
internal/engine/resolve_test.go
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockDownloader implements Downloader for testing.
|
||||
type mockDownloader struct {
|
||||
method DownloadMethod
|
||||
available bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockDownloader) Method() DownloadMethod { return m.method }
|
||||
func (m *mockDownloader) Available(_ context.Context, _ *Task) (bool, error) {
|
||||
return m.available, m.err
|
||||
}
|
||||
func (m *mockDownloader) Download(_ context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
|
||||
return &Result{Method: m.method, FileName: "test.mkv", FilePath: "/tmp/test.mkv"}, nil
|
||||
}
|
||||
func (m *mockDownloader) Pause(_ string) error { return nil }
|
||||
func (m *mockDownloader) Cancel(_ string) error { return nil }
|
||||
func (m *mockDownloader) Shutdown(_ context.Context) error { return nil }
|
||||
|
||||
func TestResolveMethodAuto(t *testing.T) {
|
||||
downloaders := map[DownloadMethod]Downloader{
|
||||
MethodTorrent: &mockDownloader{method: MethodTorrent, available: true},
|
||||
MethodDebrid: &mockDownloader{method: MethodDebrid, available: true},
|
||||
}
|
||||
|
||||
task := &Task{PreferredMethod: "auto"}
|
||||
method, err := resolveMethod(context.Background(), task, downloaders)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Torrent is first in auto order
|
||||
if method != MethodTorrent {
|
||||
t.Errorf("method = %q, want torrent (first in auto order)", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMethodSpecific(t *testing.T) {
|
||||
downloaders := map[DownloadMethod]Downloader{
|
||||
MethodTorrent: &mockDownloader{method: MethodTorrent, available: true},
|
||||
MethodDebrid: &mockDownloader{method: MethodDebrid, available: true},
|
||||
}
|
||||
|
||||
task := &Task{PreferredMethod: "debrid"}
|
||||
method, err := resolveMethod(context.Background(), task, downloaders)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if method != MethodDebrid {
|
||||
t.Errorf("method = %q, want debrid", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMethodSkipsTried(t *testing.T) {
|
||||
downloaders := map[DownloadMethod]Downloader{
|
||||
MethodTorrent: &mockDownloader{method: MethodTorrent, available: true},
|
||||
MethodDebrid: &mockDownloader{method: MethodDebrid, available: true},
|
||||
}
|
||||
|
||||
task := &Task{
|
||||
PreferredMethod: "auto",
|
||||
TriedMethods: []DownloadMethod{MethodTorrent},
|
||||
}
|
||||
method, err := resolveMethod(context.Background(), task, downloaders)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if method != MethodDebrid {
|
||||
t.Errorf("method = %q, want debrid (torrent already tried)", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMethodNoneAvailable(t *testing.T) {
|
||||
downloaders := map[DownloadMethod]Downloader{
|
||||
MethodTorrent: &mockDownloader{method: MethodTorrent, available: false},
|
||||
}
|
||||
|
||||
task := &Task{PreferredMethod: "auto"}
|
||||
_, err := resolveMethod(context.Background(), task, downloaders)
|
||||
if err == nil {
|
||||
t.Error("expected error when no method available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMethodAvailabilityError(t *testing.T) {
|
||||
downloaders := map[DownloadMethod]Downloader{
|
||||
MethodTorrent: &mockDownloader{method: MethodTorrent, available: false, err: fmt.Errorf("network error")},
|
||||
MethodDebrid: &mockDownloader{method: MethodDebrid, available: true},
|
||||
}
|
||||
|
||||
task := &Task{ID: "test-resolve-err", PreferredMethod: "auto"}
|
||||
method, err := resolveMethod(context.Background(), task, downloaders)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Should fallback to debrid when torrent has error
|
||||
if method != MethodDebrid {
|
||||
t.Errorf("method = %q, want debrid (torrent errored)", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryFallbackAutoMode(t *testing.T) {
|
||||
downloaders := map[DownloadMethod]Downloader{
|
||||
MethodTorrent: &mockDownloader{method: MethodTorrent, available: true},
|
||||
MethodDebrid: &mockDownloader{method: MethodDebrid, available: true},
|
||||
}
|
||||
|
||||
task := &Task{
|
||||
PreferredMethod: "auto",
|
||||
ResolvedMethod: MethodTorrent,
|
||||
}
|
||||
|
||||
if !tryFallback(task, downloaders) {
|
||||
t.Error("should have fallback available")
|
||||
}
|
||||
if len(task.TriedMethods) != 1 || task.TriedMethods[0] != MethodTorrent {
|
||||
t.Error("torrent should be in tried methods")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryFallbackSpecificMode(t *testing.T) {
|
||||
downloaders := map[DownloadMethod]Downloader{
|
||||
MethodTorrent: &mockDownloader{method: MethodTorrent, available: true},
|
||||
MethodDebrid: &mockDownloader{method: MethodDebrid, available: true},
|
||||
}
|
||||
|
||||
task := &Task{
|
||||
PreferredMethod: "torrent",
|
||||
ResolvedMethod: MethodTorrent,
|
||||
}
|
||||
|
||||
if tryFallback(task, downloaders) {
|
||||
t.Error("should not fallback in specific mode")
|
||||
}
|
||||
}
|
||||
37
internal/engine/safepath.go
Normal file
37
internal/engine/safepath.go
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isWithinDir checks that resolved is a child of baseDir (prevents path traversal).
|
||||
// Both paths must be absolute and clean.
|
||||
func isWithinDir(baseDir, resolved string) bool {
|
||||
base := filepath.Clean(baseDir)
|
||||
target := filepath.Clean(resolved)
|
||||
return target == base || strings.HasPrefix(target, base+string(filepath.Separator))
|
||||
}
|
||||
|
||||
// safePath constructs a path under baseDir and validates it doesn't escape.
|
||||
// Returns an error if the resulting path is outside baseDir.
|
||||
// If the resulting path exists and is a symlink that resolves outside baseDir,
|
||||
// it is also rejected.
|
||||
func safePath(baseDir, untrusted string) (string, error) {
|
||||
resolved := filepath.Join(baseDir, untrusted) // Join already cleans
|
||||
|
||||
if !isWithinDir(baseDir, resolved) {
|
||||
return "", fmt.Errorf("path traversal blocked: %q escapes %q", untrusted, baseDir)
|
||||
}
|
||||
|
||||
// Resolve symlinks if the path already exists on disk
|
||||
if real, err := filepath.EvalSymlinks(resolved); err == nil {
|
||||
if !isWithinDir(baseDir, real) {
|
||||
return "", fmt.Errorf("path traversal blocked: %q resolves outside %q via symlink", untrusted, baseDir)
|
||||
}
|
||||
return real, nil
|
||||
}
|
||||
|
||||
return resolved, nil
|
||||
}
|
||||
47
internal/engine/safepath_test.go
Normal file
47
internal/engine/safepath_test.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
package engine
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsWithinDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
base string
|
||||
target string
|
||||
want bool
|
||||
}{
|
||||
{"/data", "/data/file.txt", true},
|
||||
{"/data", "/data/sub/file.txt", true},
|
||||
{"/data", "/data", true},
|
||||
{"/data", "/data/../etc/passwd", false},
|
||||
{"/data", "/etc/passwd", false},
|
||||
{"/data", "/", false},
|
||||
{"/data", "/datafoo", false}, // not a child, just a prefix
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := isWithinDir(tt.base, tt.target)
|
||||
if got != tt.want {
|
||||
t.Errorf("isWithinDir(%q, %q) = %v, want %v", tt.base, tt.target, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
base string
|
||||
untrusted string
|
||||
wantErr bool
|
||||
}{
|
||||
{"/data", "movie.mkv", false},
|
||||
{"/data", "sub/file.mkv", false},
|
||||
{"/data", "../etc/passwd", true},
|
||||
{"/data", "../../root/.ssh", true},
|
||||
{"/data", "normal/../still-ok", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, err := safePath(tt.base, tt.untrusted)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("safePath(%q, %q) error = %v, wantErr %v", tt.base, tt.untrusted, err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
316
internal/engine/stream.go
Normal file
316
internal/engine/stream.go
Normal file
|
|
@ -0,0 +1,316 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
alog "github.com/anacrolix/log"
|
||||
"github.com/anacrolix/torrent"
|
||||
)
|
||||
|
||||
|
||||
|
||||
// StreamConfig holds settings for the streaming engine.
|
||||
type StreamConfig struct {
|
||||
DataDir string
|
||||
Port int
|
||||
BufferBytes int64
|
||||
MetaTimeout time.Duration
|
||||
NoOpen bool
|
||||
PlayerCmd string
|
||||
}
|
||||
|
||||
// StreamStatus represents the current state of the streaming session.
|
||||
type StreamStatus int
|
||||
|
||||
const (
|
||||
StreamStatusMetadata StreamStatus = iota
|
||||
StreamStatusBuffering
|
||||
StreamStatusReady
|
||||
StreamStatusError
|
||||
)
|
||||
|
||||
// StreamProgress is a snapshot of current streaming stats.
|
||||
type StreamProgress struct {
|
||||
Status StreamStatus
|
||||
DownloadedBytes int64
|
||||
TotalBytes int64
|
||||
SpeedBps int64
|
||||
Peers int
|
||||
Seeds int
|
||||
FileName string
|
||||
}
|
||||
|
||||
// StreamEngine manages a single streaming torrent session.
|
||||
type StreamEngine struct {
|
||||
client *torrent.Client
|
||||
cfg StreamConfig
|
||||
tor *torrent.Torrent
|
||||
file *torrent.File
|
||||
|
||||
bufferTarget int64
|
||||
totalBytes int64
|
||||
fileName string
|
||||
|
||||
mu sync.RWMutex
|
||||
status StreamStatus
|
||||
lastBytes int64
|
||||
lastTime time.Time
|
||||
speedBps int64
|
||||
}
|
||||
|
||||
// NewStreamEngine creates a streaming engine with its own torrent client.
|
||||
func NewStreamEngine(cfg StreamConfig) (*StreamEngine, error) {
|
||||
if cfg.MetaTimeout == 0 {
|
||||
cfg.MetaTimeout = 60 * time.Second
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create data dir: %w", err)
|
||||
}
|
||||
|
||||
tcfg := torrent.NewDefaultClientConfig()
|
||||
tcfg.DataDir = cfg.DataDir
|
||||
tcfg.Seed = false
|
||||
tcfg.NoUpload = true
|
||||
tcfg.ListenPort = 0
|
||||
tcfg.Logger = alog.Default.FilterLevel(alog.Disabled)
|
||||
|
||||
client, err := torrent.NewClient(tcfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create torrent client: %w", err)
|
||||
}
|
||||
|
||||
return &StreamEngine{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
status: StreamStatusMetadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start adds the torrent, waits for metadata, selects the video file,
|
||||
// and prepares for streaming.
|
||||
func (s *StreamEngine) Start(ctx context.Context, magnetOrHash string) error {
|
||||
magnet := magnetOrHash
|
||||
if !strings.HasPrefix(magnet, "magnet:") {
|
||||
magnet = buildMagnet(strings.TrimSpace(magnetOrHash))
|
||||
}
|
||||
|
||||
t, err := s.client.AddMagnet(magnet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add magnet: %w", err)
|
||||
}
|
||||
s.tor = t
|
||||
|
||||
metaCtx, metaCancel := context.WithTimeout(ctx, s.cfg.MetaTimeout)
|
||||
defer metaCancel()
|
||||
|
||||
select {
|
||||
case <-t.GotInfo():
|
||||
case <-metaCtx.Done():
|
||||
return fmt.Errorf("metadata timeout after %s: no peers found", s.cfg.MetaTimeout)
|
||||
}
|
||||
|
||||
if err := s.selectFile(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.totalBytes = s.file.Length()
|
||||
s.fileName = filepath.Base(s.file.DisplayPath())
|
||||
s.bufferTarget = s.calculateBufferTarget()
|
||||
s.lastTime = time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
s.status = StreamStatusBuffering
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// selectFile picks the best video file from the torrent.
|
||||
// Falls back to the largest file if no video is found.
|
||||
func (s *StreamEngine) selectFile() error {
|
||||
files := s.tor.Files()
|
||||
if len(files) == 0 {
|
||||
return fmt.Errorf("torrent has no files")
|
||||
}
|
||||
|
||||
var bestVideo *torrent.File
|
||||
var bestAny *torrent.File
|
||||
|
||||
for _, f := range files {
|
||||
ext := strings.ToLower(filepath.Ext(f.DisplayPath()))
|
||||
if VideoExts[ext] {
|
||||
if bestVideo == nil || f.Length() > bestVideo.Length() {
|
||||
bestVideo = f
|
||||
}
|
||||
}
|
||||
if bestAny == nil || f.Length() > bestAny.Length() {
|
||||
bestAny = f
|
||||
}
|
||||
}
|
||||
|
||||
if bestVideo != nil {
|
||||
s.file = bestVideo
|
||||
} else {
|
||||
s.file = bestAny
|
||||
}
|
||||
|
||||
// Cancel all other files, download only the selected one
|
||||
for _, f := range files {
|
||||
if f == s.file {
|
||||
f.Download()
|
||||
} else {
|
||||
f.SetPriority(torrent.PiecePriorityNone)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsVideoFile returns true if the selected file has a video extension.
|
||||
func (s *StreamEngine) IsVideoFile() bool {
|
||||
ext := strings.ToLower(filepath.Ext(s.fileName))
|
||||
return VideoExts[ext]
|
||||
}
|
||||
|
||||
func (s *StreamEngine) calculateBufferTarget() int64 {
|
||||
if s.cfg.BufferBytes > 0 {
|
||||
return s.cfg.BufferBytes
|
||||
}
|
||||
fivePercent := s.totalBytes / 20
|
||||
tenMB := int64(10 * 1024 * 1024)
|
||||
if fivePercent < tenMB {
|
||||
return fivePercent
|
||||
}
|
||||
return tenMB
|
||||
}
|
||||
|
||||
// contiguousBytes returns the number of bytes completed contiguously
|
||||
// from the start of the file.
|
||||
func (s *StreamEngine) contiguousBytes() int64 {
|
||||
states := s.file.State()
|
||||
var total int64
|
||||
for _, ps := range states {
|
||||
if ps.Complete {
|
||||
total += ps.Bytes
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// WaitBuffer blocks until enough contiguous bytes from the file start
|
||||
// are downloaded, or the context is cancelled.
|
||||
func (s *StreamEngine) WaitBuffer(ctx context.Context, progressFn func(buffered, target int64)) error {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
buffered := s.contiguousBytes()
|
||||
if progressFn != nil {
|
||||
progressFn(buffered, s.bufferTarget)
|
||||
}
|
||||
if buffered >= s.bufferTarget {
|
||||
s.mu.Lock()
|
||||
s.status = StreamStatusReady
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewFileReader creates a new reader for the selected file.
|
||||
// Each HTTP request should get its own reader (not safe for concurrent use).
|
||||
func (s *StreamEngine) NewFileReader(ctx context.Context) torrent.Reader {
|
||||
reader := s.file.NewReader()
|
||||
reader.SetResponsive()
|
||||
reader.SetReadahead(5 * 1024 * 1024) // 5MB readahead
|
||||
reader.SetContext(ctx)
|
||||
return reader
|
||||
}
|
||||
|
||||
// StartProgressLoop starts a goroutine that updates speed/peer stats every second.
|
||||
// It stops when the context is cancelled.
|
||||
func (s *StreamEngine) StartProgressLoop(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
downloaded := s.file.BytesCompleted()
|
||||
|
||||
s.mu.Lock()
|
||||
elapsed := now.Sub(s.lastTime).Seconds()
|
||||
if elapsed > 0 {
|
||||
s.speedBps = int64(float64(downloaded-s.lastBytes) / elapsed)
|
||||
if s.speedBps < 0 {
|
||||
s.speedBps = 0
|
||||
}
|
||||
}
|
||||
s.lastBytes = downloaded
|
||||
s.lastTime = now
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Progress returns a snapshot of the current streaming stats.
|
||||
func (s *StreamEngine) Progress() StreamProgress {
|
||||
s.mu.RLock()
|
||||
status := s.status
|
||||
speed := s.speedBps
|
||||
s.mu.RUnlock()
|
||||
|
||||
stats := s.tor.Stats()
|
||||
|
||||
return StreamProgress{
|
||||
Status: status,
|
||||
DownloadedBytes: s.file.BytesCompleted(),
|
||||
TotalBytes: s.totalBytes,
|
||||
SpeedBps: speed,
|
||||
Peers: stats.ActivePeers,
|
||||
Seeds: stats.ConnectedSeeders,
|
||||
FileName: s.fileName,
|
||||
}
|
||||
}
|
||||
|
||||
// FileName returns the name of the selected file.
|
||||
func (s *StreamEngine) FileName() string { return s.fileName }
|
||||
|
||||
// FileLength returns the total size of the selected file in bytes.
|
||||
func (s *StreamEngine) FileLength() int64 { return s.totalBytes }
|
||||
|
||||
// BufferTarget returns the buffer threshold in bytes.
|
||||
func (s *StreamEngine) BufferTarget() int64 { return s.bufferTarget }
|
||||
|
||||
// Shutdown gracefully closes the torrent and client.
|
||||
func (s *StreamEngine) Shutdown(_ context.Context) error {
|
||||
if s.tor != nil {
|
||||
s.tor.Drop()
|
||||
}
|
||||
if s.client != nil {
|
||||
errs := s.client.Close()
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("close client: %v", errs[0])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
74
internal/engine/stream_player.go
Normal file
74
internal/engine/stream_player.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// OpenPlayer attempts to open a media player with the given stream URL.
|
||||
// Returns the player name and the running command.
|
||||
// If override is set, it uses that command directly.
|
||||
func OpenPlayer(url, override string) (string, *exec.Cmd, error) {
|
||||
if override != "" {
|
||||
cmd := exec.Command(override, url)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return override, nil, fmt.Errorf("start %s: %w", override, err)
|
||||
}
|
||||
return override, cmd, nil
|
||||
}
|
||||
|
||||
// Try mpv first (best streaming support)
|
||||
if path, err := exec.LookPath("mpv"); err == nil {
|
||||
cmd := exec.Command(path, "--no-terminal", url)
|
||||
if err := cmd.Start(); err == nil {
|
||||
return "mpv", cmd, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try VLC
|
||||
if path, err := exec.LookPath("vlc"); err == nil {
|
||||
cmd := exec.Command(path, url)
|
||||
if err := cmd.Start(); err == nil {
|
||||
return "vlc", cmd, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try cvlc (VLC headless)
|
||||
if path, err := exec.LookPath("cvlc"); err == nil {
|
||||
cmd := exec.Command(path, url)
|
||||
if err := cmd.Start(); err == nil {
|
||||
return "vlc (headless)", cmd, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Browser fallback
|
||||
name, cmd, err := openBrowser(url)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("no player found: install mpv or vlc, or open %s manually", url)
|
||||
}
|
||||
return name, cmd, nil
|
||||
}
|
||||
|
||||
func openBrowser(url string) (string, *exec.Cmd, error) {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if path, err := exec.LookPath("xdg-open"); err == nil {
|
||||
cmd := exec.Command(path, url)
|
||||
if err := cmd.Start(); err == nil {
|
||||
return "browser", cmd, nil
|
||||
}
|
||||
}
|
||||
case "darwin":
|
||||
cmd := exec.Command("/usr/bin/open", url)
|
||||
if err := cmd.Start(); err == nil {
|
||||
return "browser", cmd, nil
|
||||
}
|
||||
case "windows":
|
||||
cmd := exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||
if err := cmd.Start(); err == nil {
|
||||
return "browser", cmd, nil
|
||||
}
|
||||
}
|
||||
return "", nil, fmt.Errorf("no browser opener found")
|
||||
}
|
||||
142
internal/engine/stream_server.go
Normal file
142
internal/engine/stream_server.go
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/anacrolix/torrent"
|
||||
)
|
||||
|
||||
// fileProvider abstracts where to get a file reader for streaming.
|
||||
type fileProvider interface {
|
||||
NewFileReader(ctx context.Context) torrent.Reader
|
||||
FileName() string
|
||||
}
|
||||
|
||||
// StreamServer serves a torrent file over HTTP with Range request support.
|
||||
type StreamServer struct {
|
||||
provider fileProvider
|
||||
server *http.Server
|
||||
port int
|
||||
url string
|
||||
}
|
||||
|
||||
// NewStreamServer creates a new HTTP server for streaming via StreamEngine.
|
||||
func NewStreamServer(engine *StreamEngine, port int) *StreamServer {
|
||||
return &StreamServer{
|
||||
provider: engine,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// NewStreamServerFromFile creates a server that streams directly from a torrent.File.
|
||||
// Used for streaming an active download without a separate StreamEngine.
|
||||
func NewStreamServerFromFile(file *torrent.File, port int) *StreamServer {
|
||||
return &StreamServer{
|
||||
provider: &torrentFileProvider{file: file},
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// torrentFileProvider wraps a torrent.File to implement fileProvider.
|
||||
type torrentFileProvider struct {
|
||||
file *torrent.File
|
||||
}
|
||||
|
||||
func (p *torrentFileProvider) NewFileReader(ctx context.Context) torrent.Reader {
|
||||
reader := p.file.NewReader()
|
||||
reader.SetResponsive()
|
||||
reader.SetReadahead(5 * 1024 * 1024)
|
||||
reader.SetContext(ctx)
|
||||
return reader
|
||||
}
|
||||
|
||||
func (p *torrentFileProvider) FileName() string {
|
||||
return filepath.Base(p.file.DisplayPath())
|
||||
}
|
||||
|
||||
// Start begins serving the file on localhost. Returns the full URL.
|
||||
func (ss *StreamServer) Start(ctx context.Context) (string, error) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/stream", ss.handler)
|
||||
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", ss.port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
// Extract actual port (important when port=0)
|
||||
ss.port = listener.Addr().(*net.TCPAddr).Port
|
||||
ss.url = fmt.Sprintf("http://127.0.0.1:%d/stream", ss.port)
|
||||
|
||||
ss.server = &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := ss.server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Printf("stream server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ss.url, nil
|
||||
}
|
||||
|
||||
// URL returns the full stream URL.
|
||||
func (ss *StreamServer) URL() string { return ss.url }
|
||||
|
||||
// Port returns the bound port.
|
||||
func (ss *StreamServer) Port() int { return ss.port }
|
||||
|
||||
// Shutdown gracefully stops the HTTP server.
|
||||
func (ss *StreamServer) Shutdown(ctx context.Context) error {
|
||||
if ss.server != nil {
|
||||
return ss.server.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) {
|
||||
reader := ss.provider.NewFileReader(r.Context())
|
||||
defer reader.Close()
|
||||
|
||||
w.Header().Set("Content-Type", mimeTypeFromExt(ss.provider.FileName()))
|
||||
|
||||
http.ServeContent(w, r, ss.provider.FileName(), time.Time{}, reader)
|
||||
}
|
||||
|
||||
func mimeTypeFromExt(filename string) string {
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
switch ext {
|
||||
case ".mp4", ".m4v":
|
||||
return "video/mp4"
|
||||
case ".mkv":
|
||||
return "video/x-matroska"
|
||||
case ".avi":
|
||||
return "video/x-msvideo"
|
||||
case ".webm":
|
||||
return "video/webm"
|
||||
case ".mov":
|
||||
return "video/quicktime"
|
||||
case ".ts":
|
||||
return "video/mp2t"
|
||||
case ".flv":
|
||||
return "video/x-flv"
|
||||
case ".mpg", ".mpeg":
|
||||
return "video/mpeg"
|
||||
case ".wmv":
|
||||
return "video/x-ms-wmv"
|
||||
case ".vob":
|
||||
return "video/x-ms-vob"
|
||||
default:
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
||||
370
internal/engine/stream_test.go
Normal file
370
internal/engine/stream_test.go
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// StreamEngine unit tests (no network)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestStreamBuildMagnet(t *testing.T) {
|
||||
hash := "abc123def456abc123def456abc123def456abc1"
|
||||
magnet := buildMagnet(hash)
|
||||
|
||||
if !strings.HasPrefix(magnet, "magnet:?xt=urn:btih:"+hash) {
|
||||
t.Errorf("magnet should start with btih, got: %s", magnet[:60])
|
||||
}
|
||||
|
||||
// Should contain trackers
|
||||
for _, tracker := range defaultTrackers {
|
||||
if !strings.Contains(magnet, "tr=") {
|
||||
t.Errorf("magnet should contain tracker param for %s", tracker)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamBuildMagnetPassthrough(t *testing.T) {
|
||||
// If input already is a magnet, Start should use it directly
|
||||
// Here we test that buildMagnet produces a valid magnet from a hash
|
||||
hash := "0000000000000000000000000000000000000000"
|
||||
magnet := buildMagnet(hash)
|
||||
if !strings.Contains(magnet, hash) {
|
||||
t.Error("magnet should contain the info hash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVideoExtensions(t *testing.T) {
|
||||
exts := []string{".mkv", ".mp4", ".avi", ".webm", ".mov", ".ts", ".flv", ".m4v", ".mpg", ".mpeg", ".vob", ".wmv"}
|
||||
for _, ext := range exts {
|
||||
if !VideoExts[ext] {
|
||||
t.Errorf("expected %s to be a video extension", ext)
|
||||
}
|
||||
}
|
||||
|
||||
nonVideo := []string{".txt", ".zip", ".nfo", ".srt", ".jpg", ".exe"}
|
||||
for _, ext := range nonVideo {
|
||||
if VideoExts[ext] {
|
||||
t.Errorf("expected %s to NOT be a video extension", ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBufferTarget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
totalBytes int64
|
||||
bufferBytes int64
|
||||
want int64
|
||||
}{
|
||||
{"small file (<200MB) uses 5%", 100 * 1024 * 1024, 0, 100 * 1024 * 1024 / 20},
|
||||
{"large file (10GB) caps at 10MB", 10 * 1024 * 1024 * 1024, 0, 10 * 1024 * 1024},
|
||||
{"medium file (500MB) caps at 10MB", 500 * 1024 * 1024, 0, 10 * 1024 * 1024}, // 5% of 500MB = 25MB > 10MB cap
|
||||
{"override takes precedence", 10 * 1024 * 1024 * 1024, 5 * 1024 * 1024, 5 * 1024 * 1024},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &StreamEngine{
|
||||
totalBytes: tt.totalBytes,
|
||||
cfg: StreamConfig{BufferBytes: tt.bufferBytes},
|
||||
}
|
||||
got := s.calculateBufferTarget()
|
||||
if got != tt.want {
|
||||
t.Errorf("calculateBufferTarget() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsVideoFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fileName string
|
||||
want bool
|
||||
}{
|
||||
{"mp4", "movie.mp4", true},
|
||||
{"mkv", "movie.mkv", true},
|
||||
{"avi", "movie.avi", true},
|
||||
{"nfo", "movie.nfo", false},
|
||||
{"txt", "readme.txt", false},
|
||||
{"srt", "subtitles.srt", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &StreamEngine{fileName: tt.fileName}
|
||||
if got := s.IsVideoFile(); got != tt.want {
|
||||
t.Errorf("IsVideoFile(%q) = %v, want %v", tt.fileName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamStatusConstants(t *testing.T) {
|
||||
// Verify status constants are distinct
|
||||
statuses := []StreamStatus{
|
||||
StreamStatusMetadata,
|
||||
StreamStatusBuffering,
|
||||
StreamStatusReady,
|
||||
StreamStatusError,
|
||||
}
|
||||
seen := map[StreamStatus]bool{}
|
||||
for _, s := range statuses {
|
||||
if seen[s] {
|
||||
t.Errorf("duplicate status value: %d", s)
|
||||
}
|
||||
seen[s] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamEngineGetters(t *testing.T) {
|
||||
s := &StreamEngine{
|
||||
fileName: "movie.mkv",
|
||||
totalBytes: 4 * 1024 * 1024 * 1024,
|
||||
bufferTarget: 10 * 1024 * 1024,
|
||||
}
|
||||
|
||||
if s.FileName() != "movie.mkv" {
|
||||
t.Errorf("FileName() = %q", s.FileName())
|
||||
}
|
||||
if s.FileLength() != 4*1024*1024*1024 {
|
||||
t.Errorf("FileLength() = %d", s.FileLength())
|
||||
}
|
||||
if s.BufferTarget() != 10*1024*1024 {
|
||||
t.Errorf("BufferTarget() = %d", s.BufferTarget())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// StreamServer unit tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestMimeTypeFromExt(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
want string
|
||||
}{
|
||||
{"movie.mp4", "video/mp4"},
|
||||
{"movie.m4v", "video/mp4"},
|
||||
{"movie.mkv", "video/x-matroska"},
|
||||
{"movie.avi", "video/x-msvideo"},
|
||||
{"movie.webm", "video/webm"},
|
||||
{"movie.mov", "video/quicktime"},
|
||||
{"movie.ts", "video/mp2t"},
|
||||
{"movie.flv", "video/x-flv"},
|
||||
{"movie.mpg", "video/mpeg"},
|
||||
{"movie.mpeg", "video/mpeg"},
|
||||
{"movie.wmv", "video/x-ms-wmv"},
|
||||
{"movie.vob", "video/x-ms-vob"},
|
||||
{"unknown.xyz", "application/octet-stream"},
|
||||
{"file.MP4", "video/mp4"}, // case insensitive
|
||||
{"FILE.MKV", "video/x-matroska"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.filename, func(t *testing.T) {
|
||||
got := mimeTypeFromExt(tt.filename)
|
||||
if got != tt.want {
|
||||
t.Errorf("mimeTypeFromExt(%q) = %q, want %q", tt.filename, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamServerStartShutdown(t *testing.T) {
|
||||
// Test server lifecycle without a real StreamEngine
|
||||
// We can't test actual streaming, but we can test the HTTP server mechanics
|
||||
|
||||
// Create a minimal engine with just enough state for the server
|
||||
s := &StreamEngine{
|
||||
fileName: "test.mp4",
|
||||
totalBytes: 1024,
|
||||
}
|
||||
|
||||
srv := NewStreamServer(s, 0)
|
||||
if srv.Port() != 0 {
|
||||
t.Errorf("initial port should be 0, got %d", srv.Port())
|
||||
}
|
||||
|
||||
// We can't Start() because NewFileReader needs a real torrent File
|
||||
// But we can test that Shutdown on an un-started server doesn't panic
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("shutdown of un-started server should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Task integration with stream fields
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewTaskFromAgentWithMode(t *testing.T) {
|
||||
at := agent.Task{
|
||||
ID: "stream-task-1",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "Movie (2024)",
|
||||
PreferredMethod: "auto",
|
||||
Mode: "stream",
|
||||
}
|
||||
task := NewTaskFromAgent(at)
|
||||
|
||||
if task.Mode != "stream" {
|
||||
t.Errorf("Mode = %q, want stream", task.Mode)
|
||||
}
|
||||
if task.Status != StatusClaimed {
|
||||
t.Errorf("Status = %q, want claimed", task.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskFromAgentDefaultMode(t *testing.T) {
|
||||
at := agent.Task{
|
||||
ID: "download-task-1",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
PreferredMethod: "auto",
|
||||
// Mode not set
|
||||
}
|
||||
task := NewTaskFromAgent(at)
|
||||
|
||||
if task.Mode != "download" {
|
||||
t.Errorf("Mode = %q, want download (default)", task.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStatusUpdateIncludesStreamURL(t *testing.T) {
|
||||
task := &Task{
|
||||
ID: "stream-task-2",
|
||||
Status: StatusDownloading,
|
||||
ResolvedMethod: MethodTorrent,
|
||||
Mode: "stream",
|
||||
StreamURL: "http://127.0.0.1:43210/stream",
|
||||
DownloadedBytes: 500,
|
||||
TotalBytes: 1000,
|
||||
SpeedBps: 100,
|
||||
FileName: "movie.mkv",
|
||||
}
|
||||
|
||||
update := task.ToStatusUpdate()
|
||||
if update.StreamURL != "http://127.0.0.1:43210/stream" {
|
||||
t.Errorf("StreamURL = %q, want http://127.0.0.1:43210/stream", update.StreamURL)
|
||||
}
|
||||
if update.Status != "downloading" {
|
||||
t.Errorf("Status = %q", update.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStatusUpdateNoStreamURL(t *testing.T) {
|
||||
task := &Task{
|
||||
ID: "download-task-2",
|
||||
Status: StatusDownloading,
|
||||
ResolvedMethod: MethodTorrent,
|
||||
Mode: "download",
|
||||
}
|
||||
|
||||
update := task.ToStatusUpdate()
|
||||
if update.StreamURL != "" {
|
||||
t.Errorf("StreamURL should be empty for download tasks, got %q", update.StreamURL)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// StreamServer HTTP test (with mock ReadSeeker)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestStreamHTTPHandler(t *testing.T) {
|
||||
// We create an HTTP handler manually to test Range request support
|
||||
// This simulates what StreamServer.handler does, but with a string reader
|
||||
content := strings.Repeat("X", 1000) // 1KB of data
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reader := strings.NewReader(content)
|
||||
w.Header().Set("Content-Type", "video/mp4")
|
||||
http.ServeContent(w, r, "test.mp4", time.Time{}, reader)
|
||||
})
|
||||
|
||||
// Test full content request
|
||||
t.Run("full request", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/stream", nil)
|
||||
rr := &responseRecorder{headers: http.Header{}, body: &strings.Builder{}}
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.statusCode != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", rr.statusCode)
|
||||
}
|
||||
if ct := rr.headers.Get("Content-Type"); ct != "video/mp4" {
|
||||
t.Errorf("Content-Type = %q, want video/mp4", ct)
|
||||
}
|
||||
if rr.body.Len() != 1000 {
|
||||
t.Errorf("body length = %d, want 1000", rr.body.Len())
|
||||
}
|
||||
})
|
||||
|
||||
// Test Range request
|
||||
t.Run("range request", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/stream", nil)
|
||||
req.Header.Set("Range", "bytes=0-99")
|
||||
rr := &responseRecorder{headers: http.Header{}, body: &strings.Builder{}}
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.statusCode != http.StatusPartialContent {
|
||||
t.Errorf("status = %d, want 206 Partial Content", rr.statusCode)
|
||||
}
|
||||
if rr.body.Len() != 100 {
|
||||
t.Errorf("body length = %d, want 100", rr.body.Len())
|
||||
}
|
||||
})
|
||||
|
||||
// Test Range request middle
|
||||
t.Run("range request middle", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/stream", nil)
|
||||
req.Header.Set("Range", "bytes=500-599")
|
||||
rr := &responseRecorder{headers: http.Header{}, body: &strings.Builder{}}
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.statusCode != http.StatusPartialContent {
|
||||
t.Errorf("status = %d, want 206", rr.statusCode)
|
||||
}
|
||||
if rr.body.Len() != 100 {
|
||||
t.Errorf("body length = %d, want 100", rr.body.Len())
|
||||
}
|
||||
})
|
||||
|
||||
// Test HEAD request
|
||||
t.Run("HEAD request", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("HEAD", "/stream", nil)
|
||||
rr := &responseRecorder{headers: http.Header{}, body: &strings.Builder{}}
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.statusCode != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", rr.statusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// responseRecorder is a minimal http.ResponseWriter for testing
|
||||
type responseRecorder struct {
|
||||
statusCode int
|
||||
headers http.Header
|
||||
body *strings.Builder
|
||||
}
|
||||
|
||||
func (r *responseRecorder) Header() http.Header { return r.headers }
|
||||
func (r *responseRecorder) WriteHeader(code int) { r.statusCode = code }
|
||||
func (r *responseRecorder) Write(b []byte) (int, error) {
|
||||
if r.statusCode == 0 {
|
||||
r.statusCode = http.StatusOK
|
||||
}
|
||||
return r.body.Write(b)
|
||||
}
|
||||
|
||||
// Ensure responseRecorder implements ReadSeeker expectations
|
||||
func (r *responseRecorder) ReadFrom(src io.Reader) (int64, error) {
|
||||
n, err := io.Copy(r.body, src)
|
||||
return n, err
|
||||
}
|
||||
212
internal/engine/task.go
Normal file
212
internal/engine/task.go
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
)
|
||||
|
||||
// TaskStatus represents the current state of a download task.
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
StatusPending TaskStatus = "pending"
|
||||
StatusClaimed TaskStatus = "claimed"
|
||||
StatusResolving TaskStatus = "resolving"
|
||||
StatusDownloading TaskStatus = "downloading"
|
||||
StatusVerifying TaskStatus = "verifying"
|
||||
StatusOrganizing TaskStatus = "organizing"
|
||||
StatusSeeding TaskStatus = "seeding"
|
||||
StatusCompleted TaskStatus = "completed"
|
||||
StatusFailed TaskStatus = "failed"
|
||||
StatusCancelled TaskStatus = "cancelled"
|
||||
)
|
||||
|
||||
// validTransitions defines allowed state changes.
|
||||
var validTransitions = map[TaskStatus][]TaskStatus{
|
||||
StatusPending: {StatusClaimed},
|
||||
StatusClaimed: {StatusResolving, StatusCancelled},
|
||||
StatusResolving: {StatusDownloading, StatusFailed, StatusCancelled},
|
||||
StatusDownloading: {StatusVerifying, StatusFailed, StatusResolving, StatusCancelled},
|
||||
StatusVerifying: {StatusOrganizing, StatusFailed},
|
||||
StatusOrganizing: {StatusSeeding, StatusCompleted},
|
||||
StatusSeeding: {StatusCompleted},
|
||||
}
|
||||
|
||||
// Task represents a download task with its full lifecycle state.
|
||||
type Task struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// From server
|
||||
ID string
|
||||
InfoHash string
|
||||
Title string
|
||||
ContentID *int
|
||||
IMDbID string
|
||||
PreferredMethod string // auto | torrent | debrid | usenet
|
||||
|
||||
// Runtime state
|
||||
Status TaskStatus
|
||||
Mode string // download | stream
|
||||
ResolvedMethod DownloadMethod
|
||||
TriedMethods []DownloadMethod
|
||||
DownloadedBytes int64
|
||||
TotalBytes int64
|
||||
SpeedBps int64
|
||||
ETA int
|
||||
FileName string
|
||||
FilePath string
|
||||
StreamURL string
|
||||
ErrorMessage string
|
||||
|
||||
// Timestamps
|
||||
ClaimedAt time.Time
|
||||
StartedAt time.Time
|
||||
CompletedAt time.Time
|
||||
}
|
||||
|
||||
// NewTaskFromAgent creates a Task from a server-claimed agent.Task.
|
||||
func NewTaskFromAgent(at agent.Task) *Task {
|
||||
mode := at.Mode
|
||||
if mode == "" {
|
||||
mode = "download"
|
||||
}
|
||||
return &Task{
|
||||
ID: at.ID,
|
||||
InfoHash: at.InfoHash,
|
||||
Title: at.Title,
|
||||
ContentID: at.ContentID,
|
||||
IMDbID: at.IMDbID,
|
||||
PreferredMethod: at.PreferredMethod,
|
||||
Mode: mode,
|
||||
Status: StatusClaimed,
|
||||
ClaimedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Transition validates and performs a state transition.
|
||||
func (t *Task) Transition(to TaskStatus) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
allowed, ok := validTransitions[t.Status]
|
||||
if !ok {
|
||||
return fmt.Errorf("no transitions from %s", t.Status)
|
||||
}
|
||||
for _, a := range allowed {
|
||||
if a == to {
|
||||
t.Status = to
|
||||
if to == StatusDownloading {
|
||||
t.StartedAt = time.Now()
|
||||
}
|
||||
if to == StatusCompleted || to == StatusFailed {
|
||||
t.CompletedAt = time.Now()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid transition: %s -> %s", t.Status, to)
|
||||
}
|
||||
|
||||
// GetStatus returns current status thread-safely.
|
||||
func (t *Task) GetStatus() TaskStatus {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.Status
|
||||
}
|
||||
|
||||
// SetStreamURL sets the stream URL thread-safely.
|
||||
func (t *Task) SetStreamURL(url string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.StreamURL = url
|
||||
}
|
||||
|
||||
// GetStreamURL returns the stream URL thread-safely.
|
||||
func (t *Task) GetStreamURL() string {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.StreamURL
|
||||
}
|
||||
|
||||
// UpdateProgress updates download metrics thread-safely.
|
||||
func (t *Task) UpdateProgress(p Progress) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.DownloadedBytes = p.DownloadedBytes
|
||||
t.TotalBytes = p.TotalBytes
|
||||
t.SpeedBps = p.SpeedBps
|
||||
t.ETA = p.ETA
|
||||
if p.FileName != "" {
|
||||
t.FileName = p.FileName
|
||||
}
|
||||
}
|
||||
|
||||
// Percent returns download progress as 0-100.
|
||||
func (t *Task) Percent() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
if t.TotalBytes <= 0 {
|
||||
return 0
|
||||
}
|
||||
p := int(float64(t.DownloadedBytes) / float64(t.TotalBytes) * 100)
|
||||
if p > 100 {
|
||||
return 100
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ToStatusUpdate converts task state to an API status update.
|
||||
func (t *Task) ToStatusUpdate() agent.StatusUpdate {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
apiStatus := ""
|
||||
switch t.Status {
|
||||
case StatusResolving, StatusDownloading, StatusVerifying, StatusOrganizing, StatusSeeding:
|
||||
apiStatus = "downloading"
|
||||
case StatusCompleted:
|
||||
apiStatus = "completed"
|
||||
case StatusFailed:
|
||||
apiStatus = "failed"
|
||||
}
|
||||
|
||||
return agent.StatusUpdate{
|
||||
TaskID: t.ID,
|
||||
Status: apiStatus,
|
||||
Progress: t.Percent(),
|
||||
DownloadedBytes: t.DownloadedBytes,
|
||||
TotalBytes: t.TotalBytes,
|
||||
SpeedBps: t.SpeedBps,
|
||||
ETA: t.ETA,
|
||||
ResolvedMethod: string(t.ResolvedMethod),
|
||||
FileName: t.FileName,
|
||||
FilePath: t.FilePath,
|
||||
StreamURL: t.StreamURL,
|
||||
ErrorMessage: t.ErrorMessage,
|
||||
}
|
||||
}
|
||||
|
||||
// MagnetURI builds a magnet link from the info hash.
|
||||
func (t *Task) MagnetURI() string {
|
||||
return "magnet:?xt=urn:btih:" + t.InfoHash
|
||||
}
|
||||
|
||||
// HasUntried returns true if there are download methods not yet attempted.
|
||||
func (t *Task) HasUntried(available []DownloadMethod) bool {
|
||||
for _, m := range available {
|
||||
tried := false
|
||||
for _, tm := range t.TriedMethods {
|
||||
if tm == m {
|
||||
tried = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tried {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
190
internal/engine/task_test.go
Normal file
190
internal/engine/task_test.go
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
)
|
||||
|
||||
func TestNewTaskFromAgent(t *testing.T) {
|
||||
at := agent.Task{
|
||||
ID: "uuid-123",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "The Matrix (1999)",
|
||||
PreferredMethod: "auto",
|
||||
}
|
||||
task := NewTaskFromAgent(at)
|
||||
|
||||
if task.ID != "uuid-123" {
|
||||
t.Errorf("ID = %q, want uuid-123", task.ID)
|
||||
}
|
||||
if task.Status != StatusClaimed {
|
||||
t.Errorf("Status = %q, want claimed", task.Status)
|
||||
}
|
||||
if task.ClaimedAt.IsZero() {
|
||||
t.Error("ClaimedAt should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionValid(t *testing.T) {
|
||||
transitions := []struct {
|
||||
from TaskStatus
|
||||
to TaskStatus
|
||||
}{
|
||||
{StatusClaimed, StatusResolving},
|
||||
{StatusResolving, StatusDownloading},
|
||||
{StatusDownloading, StatusVerifying},
|
||||
{StatusVerifying, StatusOrganizing},
|
||||
{StatusOrganizing, StatusCompleted},
|
||||
}
|
||||
|
||||
for _, tt := range transitions {
|
||||
t.Run(string(tt.from)+"->"+string(tt.to), func(t *testing.T) {
|
||||
task := &Task{Status: tt.from}
|
||||
if err := task.Transition(tt.to); err != nil {
|
||||
t.Errorf("valid transition %s -> %s failed: %v", tt.from, tt.to, err)
|
||||
}
|
||||
if task.Status != tt.to {
|
||||
t.Errorf("Status = %q, want %q", task.Status, tt.to)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionInvalid(t *testing.T) {
|
||||
invalid := []struct {
|
||||
from TaskStatus
|
||||
to TaskStatus
|
||||
}{
|
||||
{StatusPending, StatusDownloading},
|
||||
{StatusClaimed, StatusCompleted},
|
||||
{StatusCompleted, StatusDownloading},
|
||||
{StatusFailed, StatusCompleted},
|
||||
{StatusVerifying, StatusResolving},
|
||||
}
|
||||
|
||||
for _, tt := range invalid {
|
||||
t.Run(string(tt.from)+"->"+string(tt.to), func(t *testing.T) {
|
||||
task := &Task{Status: tt.from}
|
||||
if err := task.Transition(tt.to); err == nil {
|
||||
t.Errorf("invalid transition %s -> %s should fail", tt.from, tt.to)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionDownloadingSetsStartedAt(t *testing.T) {
|
||||
task := &Task{Status: StatusResolving}
|
||||
task.Transition(StatusDownloading)
|
||||
if task.StartedAt.IsZero() {
|
||||
t.Error("StartedAt should be set on downloading transition")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionCompletedSetsCompletedAt(t *testing.T) {
|
||||
task := &Task{Status: StatusOrganizing}
|
||||
task.Transition(StatusCompleted)
|
||||
if task.CompletedAt.IsZero() {
|
||||
t.Error("CompletedAt should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionFailedSetsCompletedAt(t *testing.T) {
|
||||
task := &Task{Status: StatusResolving}
|
||||
task.Transition(StatusFailed)
|
||||
if task.CompletedAt.IsZero() {
|
||||
t.Error("CompletedAt should be set on failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackTransition(t *testing.T) {
|
||||
// downloading -> resolving (fallback)
|
||||
task := &Task{Status: StatusDownloading}
|
||||
if err := task.Transition(StatusResolving); err != nil {
|
||||
t.Errorf("fallback transition should work: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelFromMultipleStates(t *testing.T) {
|
||||
for _, from := range []TaskStatus{StatusClaimed, StatusResolving, StatusDownloading} {
|
||||
t.Run(string(from), func(t *testing.T) {
|
||||
task := &Task{Status: from}
|
||||
if err := task.Transition(StatusCancelled); err != nil {
|
||||
t.Errorf("cancel from %s should work: %v", from, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPercent(t *testing.T) {
|
||||
task := &Task{DownloadedBytes: 500, TotalBytes: 1000}
|
||||
if p := task.Percent(); p != 50 {
|
||||
t.Errorf("Percent = %d, want 50", p)
|
||||
}
|
||||
|
||||
task2 := &Task{DownloadedBytes: 0, TotalBytes: 0}
|
||||
if p := task2.Percent(); p != 0 {
|
||||
t.Errorf("Percent = %d, want 0 for zero total", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateProgress(t *testing.T) {
|
||||
task := &Task{}
|
||||
task.UpdateProgress(Progress{
|
||||
DownloadedBytes: 1024,
|
||||
TotalBytes: 2048,
|
||||
SpeedBps: 512,
|
||||
ETA: 2,
|
||||
FileName: "movie.mkv",
|
||||
})
|
||||
if task.DownloadedBytes != 1024 {
|
||||
t.Errorf("DownloadedBytes = %d", task.DownloadedBytes)
|
||||
}
|
||||
if task.FileName != "movie.mkv" {
|
||||
t.Errorf("FileName = %q", task.FileName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStatusUpdate(t *testing.T) {
|
||||
task := &Task{
|
||||
ID: "task-123",
|
||||
Status: StatusDownloading,
|
||||
ResolvedMethod: MethodTorrent,
|
||||
DownloadedBytes: 500,
|
||||
TotalBytes: 1000,
|
||||
SpeedBps: 100,
|
||||
ETA: 5,
|
||||
FileName: "file.mkv",
|
||||
}
|
||||
update := task.ToStatusUpdate()
|
||||
if update.TaskID != "task-123" {
|
||||
t.Errorf("TaskID = %q", update.TaskID)
|
||||
}
|
||||
if update.Status != "downloading" {
|
||||
t.Errorf("Status = %q, want downloading", update.Status)
|
||||
}
|
||||
if update.Progress != 50 {
|
||||
t.Errorf("Progress = %d, want 50", update.Progress)
|
||||
}
|
||||
if update.ResolvedMethod != "torrent" {
|
||||
t.Errorf("ResolvedMethod = %q", update.ResolvedMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMagnetURI(t *testing.T) {
|
||||
task := &Task{InfoHash: "abc123"}
|
||||
m := task.MagnetURI()
|
||||
if m != "magnet:?xt=urn:btih:abc123" {
|
||||
t.Errorf("MagnetURI = %q", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasUntried(t *testing.T) {
|
||||
task := &Task{TriedMethods: []DownloadMethod{MethodTorrent}}
|
||||
if !task.HasUntried([]DownloadMethod{MethodTorrent, MethodDebrid}) {
|
||||
t.Error("should have untried (debrid)")
|
||||
}
|
||||
if task.HasUntried([]DownloadMethod{MethodTorrent}) {
|
||||
t.Error("all methods tried")
|
||||
}
|
||||
}
|
||||
433
internal/engine/torrent.go
Normal file
433
internal/engine/torrent.go
Normal file
|
|
@ -0,0 +1,433 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
alog "github.com/anacrolix/log"
|
||||
"github.com/anacrolix/torrent"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var defaultTrackers = []string{
|
||||
"udp://tracker.opentrackr.org:1337/announce",
|
||||
"udp://open.stealth.si:80/announce",
|
||||
"udp://tracker.torrent.eu.org:451/announce",
|
||||
"udp://open.demonii.com:1337/announce",
|
||||
"udp://exodus.desync.com:6969/announce",
|
||||
}
|
||||
|
||||
// TorrentConfig holds settings for the BitTorrent downloader.
|
||||
type TorrentConfig struct {
|
||||
DataDir string
|
||||
StallTimeout time.Duration // no progress for this long = stall (default 90s)
|
||||
MaxTimeout time.Duration // absolute maximum per torrent (default 30m)
|
||||
MaxDownloadRate int64 // bytes/s, 0 = unlimited
|
||||
MaxUploadRate int64 // bytes/s, 0 = unlimited
|
||||
SeedEnabled bool
|
||||
SeedRatio float64 // target seed ratio (default 0, meaning seed until SeedTime)
|
||||
SeedTime time.Duration // min seed time after completion (default 0)
|
||||
}
|
||||
|
||||
// TorrentDownloader downloads torrents via BitTorrent P2P.
|
||||
type TorrentDownloader struct {
|
||||
client *torrent.Client
|
||||
cfg TorrentConfig
|
||||
|
||||
activeMu sync.Mutex
|
||||
active map[string]*torrent.Torrent // taskID -> torrent handle
|
||||
}
|
||||
|
||||
// NewTorrentDownloader creates a BitTorrent downloader with a long-lived client.
|
||||
func NewTorrentDownloader(cfg TorrentConfig) (*TorrentDownloader, error) {
|
||||
if cfg.StallTimeout == 0 {
|
||||
cfg.StallTimeout = 90 * time.Second
|
||||
}
|
||||
if cfg.MaxTimeout == 0 {
|
||||
cfg.MaxTimeout = 30 * time.Minute
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create data dir: %w", err)
|
||||
}
|
||||
|
||||
tcfg := torrent.NewDefaultClientConfig()
|
||||
tcfg.DataDir = cfg.DataDir
|
||||
tcfg.Seed = cfg.SeedEnabled
|
||||
tcfg.NoUpload = !cfg.SeedEnabled
|
||||
tcfg.ListenPort = 0
|
||||
tcfg.Logger = alog.Default.FilterLevel(alog.Disabled)
|
||||
|
||||
if cfg.MaxDownloadRate > 0 {
|
||||
burst := int(cfg.MaxDownloadRate)
|
||||
if burst < 256*1024 {
|
||||
burst = 256 * 1024
|
||||
}
|
||||
tcfg.DownloadRateLimiter = rate.NewLimiter(rate.Limit(cfg.MaxDownloadRate), burst)
|
||||
}
|
||||
if cfg.MaxUploadRate > 0 {
|
||||
burst := int(cfg.MaxUploadRate)
|
||||
if burst < 256*1024 {
|
||||
burst = 256 * 1024
|
||||
}
|
||||
tcfg.UploadRateLimiter = rate.NewLimiter(rate.Limit(cfg.MaxUploadRate), burst)
|
||||
}
|
||||
|
||||
client, err := torrent.NewClient(tcfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create torrent client: %w", err)
|
||||
}
|
||||
|
||||
return &TorrentDownloader{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
active: make(map[string]*torrent.Torrent),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *TorrentDownloader) Method() DownloadMethod { return MethodTorrent }
|
||||
|
||||
func (d *TorrentDownloader) Available(_ context.Context, task *Task) (bool, error) {
|
||||
return task.InfoHash != "", nil
|
||||
}
|
||||
|
||||
func (d *TorrentDownloader) Download(ctx context.Context, task *Task, outputDir string, progressCh chan<- Progress) (*Result, error) {
|
||||
magnet := buildMagnet(task.InfoHash)
|
||||
|
||||
t, err := d.client.AddMagnet(magnet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add magnet: %w", err)
|
||||
}
|
||||
|
||||
// Track active torrent
|
||||
d.activeMu.Lock()
|
||||
d.active[task.ID] = t
|
||||
d.activeMu.Unlock()
|
||||
|
||||
cleanup := func() {
|
||||
d.activeMu.Lock()
|
||||
delete(d.active, task.ID)
|
||||
d.activeMu.Unlock()
|
||||
if !d.cfg.SeedEnabled {
|
||||
t.Drop()
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Wait for metadata
|
||||
log.Printf("[%s] waiting for metadata...", task.ID[:8])
|
||||
metaCtx, metaCancel := context.WithTimeout(ctx, d.cfg.StallTimeout)
|
||||
defer metaCancel()
|
||||
|
||||
select {
|
||||
case <-t.GotInfo():
|
||||
log.Printf("[%s] metadata received: %s (%d files)", task.ID[:8], t.Name(), len(t.Files()))
|
||||
case <-metaCtx.Done():
|
||||
cleanup()
|
||||
return nil, fmt.Errorf("metadata timeout after %s", d.cfg.StallTimeout)
|
||||
}
|
||||
|
||||
// 2. Select files to download (prefer largest video + matching subs)
|
||||
totalBytes, fileName := d.selectFiles(t, task.ID)
|
||||
|
||||
log.Printf("[%s] downloading %s (%s)", task.ID[:8], fileName, formatBytes(totalBytes))
|
||||
|
||||
// 3. Poll progress with stall detection
|
||||
result, err := d.pollDownload(ctx, t, task, totalBytes, fileName, progressCh)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. Determine file path
|
||||
filePath := filepath.Join(d.cfg.DataDir, fileName)
|
||||
if _, statErr := os.Stat(filePath); statErr != nil {
|
||||
filePath = filepath.Join(d.cfg.DataDir, t.Name())
|
||||
}
|
||||
|
||||
result.FilePath = filePath
|
||||
result.FileName = fileName
|
||||
result.Method = MethodTorrent
|
||||
result.Size = totalBytes
|
||||
|
||||
// If seeding enabled, keep alive (don't cleanup).
|
||||
// The manager handles seeding lifecycle.
|
||||
if !d.cfg.SeedEnabled {
|
||||
cleanup()
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (d *TorrentDownloader) pollDownload(ctx context.Context, t *torrent.Torrent, task *Task, totalBytes int64, fileName string, progressCh chan<- Progress) (*Result, error) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
deadline := time.After(d.cfg.MaxTimeout)
|
||||
lastBytesAt := time.Now()
|
||||
lastBytes := int64(0)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("cancelled")
|
||||
|
||||
case <-deadline:
|
||||
return nil, fmt.Errorf("max timeout %s exceeded", d.cfg.MaxTimeout)
|
||||
|
||||
case <-ticker.C:
|
||||
downloaded := t.BytesCompleted()
|
||||
now := time.Now()
|
||||
|
||||
// Speed calculation
|
||||
speed := downloaded - lastBytes
|
||||
if speed < 0 {
|
||||
speed = 0
|
||||
}
|
||||
|
||||
// Stall detection (dual-level like TrueSpec)
|
||||
if downloaded > lastBytes {
|
||||
lastBytesAt = now
|
||||
lastBytes = downloaded
|
||||
} else if now.Sub(lastBytesAt) > d.cfg.StallTimeout {
|
||||
stats := t.Stats()
|
||||
return nil, fmt.Errorf("stalled: no progress for %s (peers: %d, seeds: %d)",
|
||||
d.cfg.StallTimeout, stats.ActivePeers, stats.ConnectedSeeders)
|
||||
}
|
||||
|
||||
// ETA
|
||||
var eta int
|
||||
if speed > 0 {
|
||||
remaining := totalBytes - downloaded
|
||||
eta = int(remaining / speed)
|
||||
}
|
||||
|
||||
// Peer stats
|
||||
stats := t.Stats()
|
||||
|
||||
// Report progress
|
||||
p := Progress{
|
||||
DownloadedBytes: downloaded,
|
||||
TotalBytes: totalBytes,
|
||||
SpeedBps: speed,
|
||||
ETA: eta,
|
||||
Peers: stats.ActivePeers,
|
||||
Seeds: stats.ConnectedSeeders,
|
||||
FileName: fileName,
|
||||
}
|
||||
task.UpdateProgress(p)
|
||||
|
||||
select {
|
||||
case progressCh <- p:
|
||||
default: // don't block if channel full
|
||||
}
|
||||
|
||||
// Check completion
|
||||
if downloaded >= totalBytes {
|
||||
log.Printf("[%s] download complete: %s", task.ID[:8], fileName)
|
||||
return &Result{}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pause drops the torrent handle but keeps partial files on disk for resume.
|
||||
func (d *TorrentDownloader) Pause(taskID string) error {
|
||||
d.activeMu.Lock()
|
||||
t, ok := d.active[taskID]
|
||||
delete(d.active, taskID)
|
||||
d.activeMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Drop()
|
||||
log.Printf("[%s] paused (files kept for resume)", taskID[:8])
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel drops the torrent handle and removes partial files from disk.
|
||||
func (d *TorrentDownloader) Cancel(taskID string) error {
|
||||
d.activeMu.Lock()
|
||||
t, ok := d.active[taskID]
|
||||
delete(d.active, taskID)
|
||||
d.activeMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := t.Name()
|
||||
t.Drop()
|
||||
|
||||
if name != "" {
|
||||
path, err := safePath(d.cfg.DataDir, name)
|
||||
if err != nil {
|
||||
log.Printf("[%s] cancel blocked: %v", taskID[:8], err)
|
||||
return nil
|
||||
}
|
||||
if fi, statErr := os.Stat(path); statErr == nil {
|
||||
if fi.IsDir() {
|
||||
os.RemoveAll(path)
|
||||
} else {
|
||||
os.Remove(path)
|
||||
}
|
||||
log.Printf("[%s] cleaned up partial download: %s", taskID[:8], name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *TorrentDownloader) Shutdown(ctx context.Context) error {
|
||||
d.activeMu.Lock()
|
||||
for id, t := range d.active {
|
||||
t.Drop()
|
||||
delete(d.active, id)
|
||||
}
|
||||
d.activeMu.Unlock()
|
||||
|
||||
errs := d.client.Close()
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("close client: %v", errs[0])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartStream starts an HTTP server for an active torrent download.
|
||||
// It selects the largest video file and serves it via HTTP Range requests.
|
||||
// Returns the running server (caller is responsible for shutdown).
|
||||
func (d *TorrentDownloader) StartStream(taskID string) (*StreamServer, error) {
|
||||
d.activeMu.Lock()
|
||||
t, ok := d.active[taskID]
|
||||
d.activeMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no active torrent for task %s", taskID[:8])
|
||||
}
|
||||
|
||||
// Select largest video file
|
||||
files := t.Files()
|
||||
var video *torrent.File
|
||||
for _, f := range files {
|
||||
ext := strings.ToLower(filepath.Ext(f.DisplayPath()))
|
||||
if VideoExts[ext] && (video == nil || f.Length() > video.Length()) {
|
||||
video = f
|
||||
}
|
||||
}
|
||||
if video == nil {
|
||||
// No video — use largest file
|
||||
for _, f := range files {
|
||||
if video == nil || f.Length() > video.Length() {
|
||||
video = f
|
||||
}
|
||||
}
|
||||
}
|
||||
if video == nil {
|
||||
return nil, fmt.Errorf("torrent has no files")
|
||||
}
|
||||
|
||||
srv := NewStreamServerFromFile(video, 0)
|
||||
url, err := srv.Start(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start stream server: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[%s] stream started: %s → %s", taskID[:8], filepath.Base(video.DisplayPath()), url)
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// VideoExts is the canonical set of video file extensions used for file selection.
|
||||
var VideoExts = map[string]bool{
|
||||
".mkv": true, ".mp4": true, ".avi": true, ".m4v": true,
|
||||
".wmv": true, ".ts": true, ".webm": true, ".mov": true,
|
||||
".mpg": true, ".mpeg": true, ".vob": true, ".flv": true,
|
||||
}
|
||||
|
||||
var subExts = map[string]bool{
|
||||
".srt": true, ".ass": true, ".sub": true, ".ssa": true, ".vtt": true,
|
||||
}
|
||||
|
||||
// selectFiles picks the largest video file + matching subtitles.
|
||||
// Falls back to downloading everything if no video file is found.
|
||||
// Returns the total bytes to download and the primary file name.
|
||||
func (d *TorrentDownloader) selectFiles(t *torrent.Torrent, taskID string) (totalBytes int64, fileName string) {
|
||||
files := t.Files()
|
||||
|
||||
if len(files) <= 1 {
|
||||
t.DownloadAll()
|
||||
return t.Length(), t.Name()
|
||||
}
|
||||
|
||||
// Find largest video file
|
||||
var video *torrent.File
|
||||
for _, f := range files {
|
||||
ext := strings.ToLower(filepath.Ext(f.DisplayPath()))
|
||||
if VideoExts[ext] && (video == nil || f.Length() > video.Length()) {
|
||||
video = f
|
||||
}
|
||||
}
|
||||
|
||||
if video == nil {
|
||||
// No video (music, software, etc.) — download everything
|
||||
t.DownloadAll()
|
||||
return t.Length(), t.Name()
|
||||
}
|
||||
|
||||
// Download only the video
|
||||
video.Download()
|
||||
totalBytes = video.Length()
|
||||
fileName = video.DisplayPath()
|
||||
|
||||
// Also download matching subtitles
|
||||
videoBase := strings.TrimSuffix(video.DisplayPath(), filepath.Ext(video.DisplayPath()))
|
||||
var subCount int
|
||||
for _, f := range files {
|
||||
ext := strings.ToLower(filepath.Ext(f.DisplayPath()))
|
||||
if subExts[ext] {
|
||||
fBase := strings.TrimSuffix(f.DisplayPath(), filepath.Ext(f.DisplayPath()))
|
||||
// Match by prefix (handles Movie.en.srt, Movie.es.srt)
|
||||
if strings.HasPrefix(fBase, videoBase) || filepath.Dir(f.DisplayPath()) == filepath.Dir(video.DisplayPath()) {
|
||||
f.Download()
|
||||
totalBytes += f.Length()
|
||||
subCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
skipped := len(files) - 1 - subCount
|
||||
if skipped > 0 {
|
||||
log.Printf("[%s] selected: %s (%s) + %d subs, skipped %d files",
|
||||
taskID[:8], filepath.Base(fileName), formatBytes(video.Length()), subCount, skipped)
|
||||
}
|
||||
|
||||
return totalBytes, fileName
|
||||
}
|
||||
|
||||
func buildMagnet(infoHash string) string {
|
||||
params := []string{"xt=urn:btih:" + infoHash}
|
||||
for _, tracker := range defaultTrackers {
|
||||
params = append(params, "tr="+url.QueryEscape(tracker))
|
||||
}
|
||||
return "magnet:?" + strings.Join(params, "&")
|
||||
}
|
||||
|
||||
func formatBytes(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %s", float64(b)/float64(div), []string{"KB", "MB", "GB", "TB"}[exp])
|
||||
}
|
||||
26
internal/engine/usenet.go
Normal file
26
internal/engine/usenet.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// UsenetDownloader downloads via Usenet/NZB protocol.
|
||||
// Currently a stub — not implemented.
|
||||
type UsenetDownloader struct{}
|
||||
|
||||
func NewUsenetDownloader() *UsenetDownloader { return &UsenetDownloader{} }
|
||||
|
||||
func (u *UsenetDownloader) Method() DownloadMethod { return MethodUsenet }
|
||||
|
||||
func (u *UsenetDownloader) Available(_ context.Context, _ *Task) (bool, error) {
|
||||
return false, nil // always unavailable until implemented
|
||||
}
|
||||
|
||||
func (u *UsenetDownloader) Download(_ context.Context, _ *Task, _ string, _ chan<- Progress) (*Result, error) {
|
||||
return nil, fmt.Errorf("usenet download not implemented yet (coming in a future release)")
|
||||
}
|
||||
|
||||
func (u *UsenetDownloader) Pause(_ string) error { return nil }
|
||||
func (u *UsenetDownloader) Cancel(_ string) error { return nil }
|
||||
func (u *UsenetDownloader) Shutdown(_ context.Context) error { return nil }
|
||||
59
internal/engine/verify.go
Normal file
59
internal/engine/verify.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// verify checks that a downloaded file or directory is valid.
|
||||
func verify(result *Result) error {
|
||||
if result == nil || result.FilePath == "" {
|
||||
return fmt.Errorf("no file path in result")
|
||||
}
|
||||
|
||||
fi, err := os.Stat(result.FilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file not found: %w", err)
|
||||
}
|
||||
|
||||
// Get actual size — handle both files and directories (multi-file torrents)
|
||||
var actualSize int64
|
||||
if fi.IsDir() {
|
||||
actualSize, err = dirSize(result.FilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not calculate dir size: %w", err)
|
||||
}
|
||||
} else {
|
||||
actualSize = fi.Size()
|
||||
}
|
||||
|
||||
if actualSize == 0 {
|
||||
return fmt.Errorf("download is empty: %s", result.FilePath)
|
||||
}
|
||||
|
||||
// If we know the expected size, check within 2% tolerance
|
||||
if result.Size > 0 {
|
||||
tolerance := int64(float64(result.Size) * 0.02)
|
||||
if actualSize < result.Size-tolerance {
|
||||
return fmt.Errorf("size mismatch: expected %d, got %d", result.Size, actualSize)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dirSize returns total size of all files in a directory.
|
||||
func dirSize(path string) (int64, error) {
|
||||
var total int64
|
||||
err := filepath.Walk(path, func(_ string, fi os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !fi.IsDir() {
|
||||
total += fi.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return total, err
|
||||
}
|
||||
71
internal/engine/verify_test.go
Normal file
71
internal/engine/verify_test.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVerifyNilResult(t *testing.T) {
|
||||
if err := verify(nil); err == nil {
|
||||
t.Error("expected error for nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyEmptyPath(t *testing.T) {
|
||||
if err := verify(&Result{}); err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyMissingFile(t *testing.T) {
|
||||
err := verify(&Result{FilePath: "/nonexistent/file.mkv"})
|
||||
if err == nil {
|
||||
t.Error("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyEmptyFile(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "empty.mkv")
|
||||
os.WriteFile(path, []byte{}, 0o644)
|
||||
|
||||
err := verify(&Result{FilePath: path})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyValidFile(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "movie.mkv")
|
||||
os.WriteFile(path, make([]byte, 1024), 0o644)
|
||||
|
||||
err := verify(&Result{FilePath: path, Size: 1024})
|
||||
if err != nil {
|
||||
t.Errorf("valid file should pass: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifySizeMismatch(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "movie.mkv")
|
||||
os.WriteFile(path, make([]byte, 500), 0o644)
|
||||
|
||||
err := verify(&Result{FilePath: path, Size: 1000})
|
||||
if err == nil {
|
||||
t.Error("expected error for size mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyNoExpectedSize(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "movie.mkv")
|
||||
os.WriteFile(path, make([]byte, 1024), 0o644)
|
||||
|
||||
// Size=0 means unknown, should pass
|
||||
err := verify(&Result{FilePath: path, Size: 0})
|
||||
if err != nil {
|
||||
t.Errorf("no expected size should pass: %v", err)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue