290 lines
7.6 KiB
Go
290 lines
7.6 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"runtime"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// DaemonConfig carries the runtime settings a Daemon is constructed with.
type DaemonConfig struct {
	// Identity and environment values copied into the register and
	// heartbeat requests.
	AgentID, AgentName, Version, DownloadDir string

	// Periods of the poll and heartbeat tickers created by Run.
	// NewDaemon substitutes 30 seconds for a value left at zero.
	PollInterval, HeartbeatInterval time.Duration
}
|
|
|
|
// Daemon manages the main loop: register, heartbeat, poll tasks.
|
|
type Daemon struct {
|
|
cfg DaemonConfig
|
|
transport Transport
|
|
|
|
// Callbacks
|
|
OnTasksClaimed func(tasks []Task)
|
|
OnStreamRequested func(req StreamRequest)
|
|
OnControlAction func(action, taskID string)
|
|
|
|
// State
|
|
User UserInfo
|
|
Features FeatureFlags
|
|
Info AgentInfo
|
|
State DaemonState
|
|
heartbeatFailures int
|
|
lastNotifiedVersion string
|
|
|
|
// Callbacks for state tracking (set by cmd/daemon.go)
|
|
GetActiveCount func() int
|
|
GetCleanableBytes func() int64
|
|
|
|
// Watching tracks whether a user is viewing download progress in the web UI.
|
|
// When false, the progress reporter skips detailed updates (only sends final states).
|
|
// Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic.
|
|
Watching atomic.Bool
|
|
|
|
// Exposed tickers for hot-reload
|
|
PollTicker *time.Ticker
|
|
HeartbeatTicker *time.Ticker
|
|
|
|
// pollNow triggers an immediate poll (e.g. on resume)
|
|
pollNow chan struct{}
|
|
}
|
|
|
|
// NewDaemon creates a daemon with the given transport.
|
|
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
|
|
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
|
|
if cfg.PollInterval == 0 {
|
|
cfg.PollInterval = 30 * time.Second
|
|
}
|
|
if cfg.HeartbeatInterval == 0 {
|
|
cfg.HeartbeatInterval = 30 * time.Second
|
|
}
|
|
|
|
return &Daemon{
|
|
cfg: cfg,
|
|
transport: transport,
|
|
pollNow: make(chan struct{}, 1),
|
|
}
|
|
}
|
|
|
|
// Transport returns the configured transport.
|
|
func (d *Daemon) Transport() Transport { return d.transport }
|
|
|
|
// Register registers the agent and fetches user info + features.
|
|
func (d *Daemon) Register(ctx context.Context) error {
|
|
req := RegisterRequest{
|
|
AgentID: d.cfg.AgentID,
|
|
Name: d.cfg.AgentName,
|
|
OS: runtime.GOOS,
|
|
Arch: runtime.GOARCH,
|
|
Version: d.cfg.Version,
|
|
DownloadDir: d.cfg.DownloadDir,
|
|
}
|
|
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
|
|
req.DiskFreeBytes = free
|
|
req.DiskTotalBytes = total
|
|
}
|
|
|
|
resp, err := d.transport.Register(ctx, req)
|
|
if err != nil {
|
|
return fmt.Errorf("register: %w", err)
|
|
}
|
|
|
|
d.User = resp.User
|
|
d.Features = resp.Features
|
|
now := time.Now()
|
|
d.Info = AgentInfo{
|
|
ID: d.cfg.AgentID,
|
|
Name: d.cfg.AgentName,
|
|
User: resp.User,
|
|
Features: resp.Features,
|
|
StartedAt: now,
|
|
}
|
|
d.State = DaemonState{
|
|
AgentID: d.cfg.AgentID,
|
|
Status: "running",
|
|
Version: d.cfg.Version,
|
|
PID: os.Getpid(),
|
|
StartedAt: now,
|
|
MethodStats: make(map[string]int),
|
|
}
|
|
WriteState(&d.State)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Run starts the main daemon loop. Blocks until ctx is cancelled.
|
|
func (d *Daemon) Run(ctx context.Context) error {
|
|
// Register
|
|
if err := d.Register(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
|
|
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
|
|
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
|
|
|
|
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
|
|
defer d.HeartbeatTicker.Stop()
|
|
|
|
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
|
|
defer d.PollTicker.Stop()
|
|
|
|
heartbeatTicker := d.HeartbeatTicker
|
|
pollTicker := d.PollTicker
|
|
|
|
// Initial poll immediately
|
|
d.poll(ctx)
|
|
|
|
eventsCh := d.transport.Events()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Println("Daemon shutting down...")
|
|
d.deregister()
|
|
return nil
|
|
|
|
case event := <-eventsCh:
|
|
d.handleEvent(event)
|
|
|
|
case <-heartbeatTicker.C:
|
|
d.heartbeat(ctx)
|
|
|
|
case <-pollTicker.C:
|
|
// Only poll in HTTP mode — WS mode receives tasks via Events
|
|
if d.transport.Mode() == "http" {
|
|
d.poll(ctx)
|
|
}
|
|
|
|
case <-d.pollNow:
|
|
d.poll(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Daemon) heartbeat(ctx context.Context) {
|
|
req := HeartbeatRequest{
|
|
AgentID: d.cfg.AgentID,
|
|
Name: d.cfg.AgentName,
|
|
Version: d.cfg.Version,
|
|
OS: runtime.GOOS,
|
|
DownloadDir: d.cfg.DownloadDir,
|
|
}
|
|
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
|
|
req.DiskFreeBytes = free
|
|
req.DiskTotalBytes = total
|
|
}
|
|
|
|
resp, err := d.transport.SendHeartbeat(ctx, req)
|
|
if err != nil {
|
|
d.heartbeatFailures++
|
|
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
|
|
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
|
|
} else {
|
|
log.Printf("Heartbeat failed: %v", err)
|
|
}
|
|
return
|
|
}
|
|
if d.heartbeatFailures > 0 {
|
|
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
|
|
d.heartbeatFailures = 0
|
|
}
|
|
|
|
// Update watching flag and state file
|
|
d.Watching.Store(resp.Watching)
|
|
d.State.LastHeartbeat = time.Now()
|
|
if d.GetActiveCount != nil {
|
|
d.State.ActiveTasks = d.GetActiveCount()
|
|
}
|
|
WriteState(&d.State)
|
|
|
|
// Log once per version when server suggests an upgrade
|
|
if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion {
|
|
d.lastNotifiedVersion = resp.Upgrade.Version
|
|
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version)
|
|
}
|
|
}
|
|
|
|
// handleEvent processes a server-initiated event from the WebSocket transport.
|
|
func (d *Daemon) handleEvent(event ServerEvent) {
|
|
switch event.Type {
|
|
case "tasks":
|
|
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
|
|
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
|
|
if d.OnTasksClaimed != nil {
|
|
d.OnTasksClaimed(event.Tasks.Tasks)
|
|
}
|
|
}
|
|
if event.Tasks != nil && d.OnStreamRequested != nil {
|
|
for _, sr := range event.Tasks.StreamRequests {
|
|
d.OnStreamRequested(sr)
|
|
}
|
|
}
|
|
|
|
case "upgrade":
|
|
if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion {
|
|
d.lastNotifiedVersion = event.Upgrade.Version
|
|
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version)
|
|
}
|
|
|
|
case "control":
|
|
if event.Control != nil && d.OnControlAction != nil {
|
|
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
|
|
d.OnControlAction(event.Control.Action, event.Control.TaskID)
|
|
}
|
|
|
|
case "disconnected":
|
|
log.Println("WebSocket disconnected, switching to HTTP polling")
|
|
}
|
|
}
|
|
|
|
// TriggerPoll requests an immediate task poll cycle.
|
|
// Used when a resume event is received to pick up re-pending tasks faster.
|
|
func (d *Daemon) TriggerPoll() {
|
|
select {
|
|
case d.pollNow <- struct{}{}:
|
|
default: // already pending
|
|
}
|
|
}
|
|
|
|
func (d *Daemon) deregister() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
err := d.transport.Deregister(ctx, d.cfg.AgentID)
|
|
if err != nil {
|
|
log.Printf("Deregister failed: %v", err)
|
|
} else {
|
|
log.Println("Agent deregistered")
|
|
}
|
|
RemoveState()
|
|
}
|
|
|
|
func (d *Daemon) poll(ctx context.Context) {
|
|
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
|
|
if err != nil {
|
|
log.Printf("Poll failed: %v", err)
|
|
return
|
|
}
|
|
|
|
d.Info.LastPollAt = time.Now()
|
|
|
|
if len(resp.Tasks) > 0 {
|
|
log.Printf("Claimed %d task(s)", len(resp.Tasks))
|
|
if d.OnTasksClaimed != nil {
|
|
d.OnTasksClaimed(resp.Tasks)
|
|
}
|
|
}
|
|
|
|
// Handle stream requests for completed downloads
|
|
if d.OnStreamRequested != nil {
|
|
for _, sr := range resp.StreamRequests {
|
|
d.OnStreamRequested(sr)
|
|
}
|
|
}
|
|
}
|