220 lines
6.6 KiB
Go
220 lines
6.6 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/torrentclaw/unarr/internal/agent"
|
|
)
|
|
|
|
// ActionFunc is called when the server signals an action on a task.
|
|
type ActionFunc func(taskID string)
|
|
|
|
// StatusReporter is the interface used by ProgressReporter to send progress updates.
|
|
// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods.
|
|
type StatusReporter interface {
|
|
ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error)
|
|
}
|
|
|
|
// BatchStatusReporter extends StatusReporter with batch support.
|
|
// Transports that implement this send all updates in a single request.
|
|
type BatchStatusReporter interface {
|
|
StatusReporter
|
|
BatchReportStatus(ctx context.Context, updates []agent.StatusUpdate) (*agent.BatchStatusResponse, error)
|
|
}
|
|
|
|
// WatchingFunc returns whether a user is actively viewing download progress.
|
|
type WatchingFunc func() bool
|
|
|
|
// ProgressReporter aggregates progress from downloads and reports to the API.
|
|
// It batches updates to avoid flooding the server.
|
|
type ProgressReporter struct {
|
|
reporter StatusReporter
|
|
interval time.Duration
|
|
isWatching WatchingFunc // nil = always report (backwards compatible)
|
|
|
|
onCancel ActionFunc
|
|
onPause ActionFunc
|
|
onDeleteFiles ActionFunc
|
|
onStreamRequested ActionFunc
|
|
|
|
mu sync.Mutex
|
|
latest map[string]*Task // taskID -> task with latest progress
|
|
}
|
|
|
|
// NewProgressReporter creates a reporter that flushes every interval.
|
|
// Accepts *agent.Client directly (backwards compatible).
|
|
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
|
|
return &ProgressReporter{
|
|
reporter: ac,
|
|
interval: interval,
|
|
latest: make(map[string]*Task),
|
|
}
|
|
}
|
|
|
|
// NewProgressReporterWithTransport creates a reporter using a Transport.
|
|
func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter {
|
|
return &ProgressReporter{
|
|
reporter: &transportStatusAdapter{t: t},
|
|
interval: interval,
|
|
latest: make(map[string]*Task),
|
|
}
|
|
}
|
|
|
|
// transportStatusAdapter adapts agent.Transport to StatusReporter.
|
|
type transportStatusAdapter struct {
|
|
t agent.Transport
|
|
}
|
|
|
|
func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
|
|
return a.t.SendProgress(ctx, update)
|
|
}
|
|
|
|
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
|
|
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
|
|
|
|
// SetPauseHandler sets the callback invoked when the server says a task is paused.
|
|
func (r *ProgressReporter) SetPauseHandler(fn ActionFunc) { r.onPause = fn }
|
|
|
|
// SetDeleteFilesHandler sets the callback for cancel+delete files.
|
|
func (r *ProgressReporter) SetDeleteFilesHandler(fn ActionFunc) { r.onDeleteFiles = fn }
|
|
|
|
// SetStreamRequestedHandler sets the callback for stream activation.
|
|
func (r *ProgressReporter) SetStreamRequestedHandler(fn ActionFunc) { r.onStreamRequested = fn }
|
|
|
|
// SetWatchingFunc sets the function that checks if someone is viewing downloads.
|
|
func (r *ProgressReporter) SetWatchingFunc(fn WatchingFunc) { r.isWatching = fn }
|
|
|
|
// Track registers a task for progress tracking.
|
|
func (r *ProgressReporter) Track(task *Task) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.latest[task.ID] = task
|
|
}
|
|
|
|
// Untrack removes a task from progress tracking.
|
|
func (r *ProgressReporter) Untrack(taskID string) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
delete(r.latest, taskID)
|
|
}
|
|
|
|
// Run starts the periodic flush loop. Blocks until ctx is cancelled.
|
|
func (r *ProgressReporter) Run(ctx context.Context) error {
|
|
ticker := time.NewTicker(r.interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
r.flush(context.Background())
|
|
return nil
|
|
case <-ticker.C:
|
|
r.flush(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *ProgressReporter) flush(ctx context.Context) {
|
|
r.mu.Lock()
|
|
tasks := make([]*Task, 0, len(r.latest))
|
|
for _, t := range r.latest {
|
|
tasks = append(tasks, t)
|
|
}
|
|
r.mu.Unlock()
|
|
|
|
// When nobody is watching, only report final states (completed/failed).
|
|
// This saves ~99% of API requests when the user isn't on the downloads page.
|
|
watching := r.isWatching == nil || r.isWatching()
|
|
|
|
var reportable []*Task
|
|
for _, task := range tasks {
|
|
status := task.GetStatus()
|
|
isFinal := status == StatusCompleted || status == StatusFailed
|
|
isActive := status == StatusDownloading || status == StatusVerifying ||
|
|
status == StatusOrganizing || status == StatusSeeding
|
|
if isFinal || (watching && isActive) {
|
|
reportable = append(reportable, task)
|
|
}
|
|
}
|
|
|
|
if len(reportable) == 0 {
|
|
return
|
|
}
|
|
|
|
// Use batch when transport supports it
|
|
if batcher, ok := r.reporter.(BatchStatusReporter); ok {
|
|
r.flushBatch(ctx, batcher, reportable)
|
|
return
|
|
}
|
|
|
|
// Fallback: individual requests
|
|
for _, task := range reportable {
|
|
update := task.ToStatusUpdate()
|
|
resp, err := r.reporter.ReportStatus(ctx, update)
|
|
if err != nil {
|
|
log.Printf("[%s] progress report failed: %v", task.ID[:8], err)
|
|
continue
|
|
}
|
|
r.handleResponse(task, resp)
|
|
}
|
|
}
|
|
|
|
func (r *ProgressReporter) flushBatch(ctx context.Context, batcher BatchStatusReporter, tasks []*Task) {
|
|
updates := make([]agent.StatusUpdate, len(tasks))
|
|
for i, task := range tasks {
|
|
updates[i] = task.ToStatusUpdate()
|
|
}
|
|
|
|
resp, err := batcher.BatchReportStatus(ctx, updates)
|
|
if err != nil {
|
|
log.Printf("batch progress report failed: %v", err)
|
|
return
|
|
}
|
|
|
|
// Match results back to tasks by index (server returns in same order)
|
|
if len(resp.Results) != len(tasks) {
|
|
log.Printf("batch response mismatch: sent %d updates, got %d results", len(tasks), len(resp.Results))
|
|
}
|
|
for i, result := range resp.Results {
|
|
if i < len(tasks) {
|
|
r.handleResponse(tasks[i], &result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *ProgressReporter) handleResponse(task *Task, resp *agent.StatusResponse) {
|
|
if resp.Cancelled {
|
|
log.Printf("[%s] cancelled by user (via web)", task.ID[:8])
|
|
r.Untrack(task.ID)
|
|
if resp.DeleteFiles && r.onDeleteFiles != nil {
|
|
r.onDeleteFiles(task.ID)
|
|
} else if r.onCancel != nil {
|
|
r.onCancel(task.ID)
|
|
}
|
|
} else if resp.Paused {
|
|
log.Printf("[%s] paused by user (via web)", task.ID[:8])
|
|
r.Untrack(task.ID)
|
|
if r.onPause != nil {
|
|
r.onPause(task.ID)
|
|
}
|
|
}
|
|
|
|
if resp.StreamRequested && task.GetStreamURL() == "" {
|
|
log.Printf("[%s] stream requested by user (via web)", task.ID[:8])
|
|
if r.onStreamRequested != nil {
|
|
r.onStreamRequested(task.ID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ReportFinal sends a final status update for a completed/failed task.
|
|
func (r *ProgressReporter) ReportFinal(ctx context.Context, task *Task) {
|
|
update := task.ToStatusUpdate()
|
|
if _, err := r.reporter.ReportStatus(ctx, update); err != nil {
|
|
log.Printf("[%s] final report failed: %v", task.ID[:8], err)
|
|
}
|
|
r.Untrack(task.ID)
|
|
}
|