unarr/internal/agent/daemon.go
Deivid Soto 64734cad1f feat(agent): send stream port and IPs in register request
Include StreamPort, LanIP, and TailscaleIP in RegisterRequest so the
server knows the agent's stream endpoints from the moment it registers,
not just after the first heartbeat. Align HeartbeatRequest field order
with RegisterRequest for consistency.
2026-04-07 23:28:41 +02:00

378 lines
10 KiB
Go

package agent
import (
"context"
"errors"
"fmt"
"log"
"os"
"runtime"
"strings"
"sync/atomic"
"time"
)
// DaemonConfig holds daemon runtime settings.
type DaemonConfig struct {
AgentID string
AgentName string
Version string
DownloadDir string
PollInterval time.Duration
HeartbeatInterval time.Duration
StreamPort int // port for the HTTP stream server (reported in heartbeat)
LanIP string // LAN IP (reported in heartbeat for stream URL resolution)
TailscaleIP string // Tailscale IP (reported in heartbeat for stream URL resolution)
}
// Daemon manages the main loop: register, heartbeat, poll tasks.
type Daemon struct {
cfg DaemonConfig
transport Transport
// Callbacks
OnTasksClaimed func(tasks []Task)
OnStreamRequested func(req StreamRequest)
OnControlAction func(action, taskID string)
// State
User UserInfo
Features FeatureFlags
Info AgentInfo
State DaemonState
heartbeatFailures int
lastNotifiedVersion string
// Callbacks for state tracking (set by cmd/daemon.go)
GetActiveCount func() int
GetCleanableBytes func() int64
// Watching tracks whether a user is viewing download progress in the web UI.
// When false, the progress reporter skips detailed updates (only sends final states).
// Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic.
Watching atomic.Bool
// Exposed tickers for hot-reload
PollTicker *time.Ticker
HeartbeatTicker *time.Ticker
// pollNow triggers an immediate poll (e.g. on resume)
pollNow chan struct{}
// ScanNow triggers an immediate library scan (from heartbeat or WebSocket control event)
ScanNow chan struct{}
}
// NewDaemon creates a daemon with the given transport.
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
if cfg.PollInterval == 0 {
cfg.PollInterval = 30 * time.Second
}
if cfg.HeartbeatInterval == 0 {
cfg.HeartbeatInterval = 30 * time.Second
}
return &Daemon{
cfg: cfg,
transport: transport,
pollNow: make(chan struct{}, 1),
ScanNow: make(chan struct{}, 1),
}
}
// Transport returns the configured transport.
func (d *Daemon) Transport() Transport { return d.transport }
// Register registers the agent and fetches user info + features.
// Retries with exponential backoff on transient errors (429, 5xx, network).
func (d *Daemon) Register(ctx context.Context) error {
req := RegisterRequest{
AgentID: d.cfg.AgentID,
Name: d.cfg.AgentName,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
Version: d.cfg.Version,
DownloadDir: d.cfg.DownloadDir,
StreamPort: d.cfg.StreamPort,
LanIP: d.cfg.LanIP,
TailscaleIP: d.cfg.TailscaleIP,
}
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
req.DiskFreeBytes = free
req.DiskTotalBytes = total
}
const maxRetries = 5
backoff := 5 * time.Second
var resp *RegisterResponse
var err error
for attempt := range maxRetries {
resp, err = d.transport.Register(ctx, req)
if err == nil {
break
}
// Only retry on transient errors (429, 5xx, network failures)
if !isTransientError(err) {
return fmt.Errorf("register: %w", err)
}
log.Printf("Register failed (attempt %d/%d): %v - retrying in %v", attempt+1, maxRetries, err, backoff)
timer := time.NewTimer(backoff)
select {
case <-ctx.Done():
timer.Stop()
return fmt.Errorf("register: %w", ctx.Err())
case <-timer.C:
}
backoff = min(backoff*2, 60*time.Second)
}
if err != nil {
return fmt.Errorf("register: %w (after %d retries)", err, maxRetries)
}
d.User = resp.User
d.Features = resp.Features
now := time.Now()
d.Info = AgentInfo{
ID: d.cfg.AgentID,
Name: d.cfg.AgentName,
User: resp.User,
Features: resp.Features,
StartedAt: now,
}
d.State = DaemonState{
AgentID: d.cfg.AgentID,
Status: "running",
Version: d.cfg.Version,
PID: os.Getpid(),
StartedAt: now,
MethodStats: make(map[string]int),
}
WriteState(&d.State)
return nil
}
// Run connects the transport, registers the agent, and starts the main loop.
// Blocks until ctx is cancelled. Callers must NOT call transport.Connect before Run.
func (d *Daemon) Run(ctx context.Context) error {
// Connect transport (establishes WebSocket if available, falls back to HTTP)
if err := d.transport.Connect(ctx); err != nil {
return fmt.Errorf("connect transport: %w", err)
}
// Register
if err := d.Register(ctx); err != nil {
return err
}
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
defer d.HeartbeatTicker.Stop()
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
defer d.PollTicker.Stop()
heartbeatTicker := d.HeartbeatTicker
pollTicker := d.PollTicker
// Initial poll immediately
d.poll(ctx)
eventsCh := d.transport.Events()
for {
select {
case <-ctx.Done():
log.Println("Daemon shutting down...")
d.deregister()
return nil
case event := <-eventsCh:
d.handleEvent(event)
case <-heartbeatTicker.C:
d.heartbeat(ctx)
case <-pollTicker.C:
// Only poll in HTTP mode — WS mode receives tasks via Events
if d.transport.Mode() == "http" {
d.poll(ctx)
}
case <-d.pollNow:
d.poll(ctx)
}
}
}
func (d *Daemon) heartbeat(ctx context.Context) {
req := HeartbeatRequest{
AgentID: d.cfg.AgentID,
Name: d.cfg.AgentName,
Version: d.cfg.Version,
OS: runtime.GOOS,
DownloadDir: d.cfg.DownloadDir,
StreamPort: d.cfg.StreamPort,
LanIP: d.cfg.LanIP,
TailscaleIP: d.cfg.TailscaleIP,
}
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
req.DiskFreeBytes = free
req.DiskTotalBytes = total
}
resp, err := d.transport.SendHeartbeat(ctx, req)
if err != nil {
d.heartbeatFailures++
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
} else {
log.Printf("Heartbeat failed: %v", err)
}
return
}
if d.heartbeatFailures > 0 {
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
d.heartbeatFailures = 0
}
// Update watching flag and state file
d.Watching.Store(resp.Watching)
d.State.LastHeartbeat = time.Now()
if d.GetActiveCount != nil {
d.State.ActiveTasks = d.GetActiveCount()
}
WriteState(&d.State)
// Trigger library scan if requested
if resp.Scan {
log.Printf("Library scan requested by server")
select {
case d.ScanNow <- struct{}{}:
default: // scan already pending
}
}
// Log once per version when server suggests an upgrade
if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion {
d.lastNotifiedVersion = resp.Upgrade.Version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version)
}
}
// handleEvent processes a server-initiated event from the WebSocket transport.
func (d *Daemon) handleEvent(event ServerEvent) {
switch event.Type {
case "tasks":
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(event.Tasks.Tasks)
}
}
if event.Tasks != nil && d.OnStreamRequested != nil {
for _, sr := range event.Tasks.StreamRequests {
d.OnStreamRequested(sr)
}
}
case "upgrade":
if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion {
d.lastNotifiedVersion = event.Upgrade.Version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version)
}
case "control":
if event.Control != nil {
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
if event.Control.Action == "scan" {
select {
case d.ScanNow <- struct{}{}:
default:
}
}
if d.OnControlAction != nil {
d.OnControlAction(event.Control.Action, event.Control.TaskID)
}
}
case "disconnected":
log.Println("WebSocket disconnected, switching to HTTP polling")
}
}
// UpdateStreamPort updates the stream port reported in heartbeats.
// Called after the persistent stream server binds (actual port may differ from configured).
func (d *Daemon) UpdateStreamPort(port int) {
d.cfg.StreamPort = port
}
// TriggerPoll requests an immediate task poll cycle.
// Used when a resume event is received to pick up re-pending tasks faster.
func (d *Daemon) TriggerPoll() {
select {
case d.pollNow <- struct{}{}:
default: // already pending
}
}
func (d *Daemon) deregister() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := d.transport.Deregister(ctx, d.cfg.AgentID)
if err != nil {
log.Printf("Deregister failed: %v", err)
} else {
log.Println("Agent deregistered")
}
RemoveState()
}
// isTransientError returns true for errors worth retrying (429, 5xx, network).
func isTransientError(err error) bool {
if err == nil {
return false
}
// Structured check: HTTPError carries the status code directly
var httpErr *HTTPError
if errors.As(err, &httpErr) {
return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500
}
// Fallback: network-level errors (no HTTP response received)
lower := strings.ToLower(err.Error())
for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} {
if strings.Contains(lower, keyword) {
return true
}
}
return false
}
func (d *Daemon) poll(ctx context.Context) {
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
if err != nil {
log.Printf("Poll failed: %v", err)
return
}
d.Info.LastPollAt = time.Now()
if len(resp.Tasks) > 0 {
log.Printf("Claimed %d task(s)", len(resp.Tasks))
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(resp.Tasks)
}
}
// Handle stream requests for completed downloads
if d.OnStreamRequested != nil {
for _, sr := range resp.StreamRequests {
d.OnStreamRequested(sr)
}
}
}