feat(agent): add WebSocket transport with HTTP fallback

Add Transport interface abstraction supporting WebSocket (via CF
Durable Objects) and HTTP (direct to origin) with automatic failover.

- Transport interface: Register, SendHeartbeat, SendProgress, Events()
- HTTPTransport: thin adapter over existing Client
- WSTransport: gorilla/websocket with auth handshake, readLoop, reconnect
- HybridTransport: tries WS first, falls back to HTTP, reconnects in bg
- Daemon refactored to always use Transport (no dual-path forks)
- ProgressReporter accepts StatusReporter interface
- deriveWSURL skips localhost/dev (returns "" → HTTP-only)
- API key passed in WS query param for connection auth
- Fixed: reconnectOnce race (mutex+bool), authDone double-close (sync.Once)
- Fixed: forwardWSEvents goroutine leak (select with stop signal)
- 20 transport tests + 2 E2E tests (full lifecycle, hybrid failover)
This commit is contained in:
Deivid Soto 2026-03-28 18:55:29 +01:00
parent 5e80911501
commit 5f337eebd7
10 changed files with 1646 additions and 64 deletions

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"os"
"runtime" "runtime"
"time" "time"
) )
@ -21,19 +22,33 @@ type DaemonConfig struct {
// Daemon manages the main loop: register, heartbeat, poll tasks. // Daemon manages the main loop: register, heartbeat, poll tasks.
type Daemon struct { type Daemon struct {
cfg DaemonConfig cfg DaemonConfig
client *Client transport Transport
// Callbacks // Callbacks
OnTasksClaimed func(tasks []Task) OnTasksClaimed func(tasks []Task)
OnStreamRequested func(req StreamRequest)
OnUpgradeRequested func(version string)
OnControlAction func(action, taskID string)
// State // State
User UserInfo User UserInfo
Features FeatureFlags Features FeatureFlags
Info AgentInfo Info AgentInfo
State DaemonState
upgradeInProgress bool
heartbeatFailures int
// Callbacks for state tracking (set by cmd/daemon.go)
GetActiveCount func() int
// Exposed tickers for hot-reload
PollTicker *time.Ticker
HeartbeatTicker *time.Ticker
} }
// NewDaemon creates a daemon with the given config and agent client. // NewDaemon creates a daemon with the given transport.
func NewDaemon(cfg DaemonConfig, client *Client) *Daemon { // Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
if cfg.PollInterval == 0 { if cfg.PollInterval == 0 {
cfg.PollInterval = 30 * time.Second cfg.PollInterval = 30 * time.Second
} }
@ -43,10 +58,13 @@ func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
return &Daemon{ return &Daemon{
cfg: cfg, cfg: cfg,
client: client, transport: transport,
} }
} }
// Transport returns the configured transport.
func (d *Daemon) Transport() Transport { return d.transport }
// Register registers the agent and fetches user info + features. // Register registers the agent and fetches user info + features.
func (d *Daemon) Register(ctx context.Context) error { func (d *Daemon) Register(ctx context.Context) error {
req := RegisterRequest{ req := RegisterRequest{
@ -62,20 +80,30 @@ func (d *Daemon) Register(ctx context.Context) error {
req.DiskTotalBytes = total req.DiskTotalBytes = total
} }
resp, err := d.client.Register(ctx, req) resp, err := d.transport.Register(ctx, req)
if err != nil { if err != nil {
return fmt.Errorf("register: %w", err) return fmt.Errorf("register: %w", err)
} }
d.User = resp.User d.User = resp.User
d.Features = resp.Features d.Features = resp.Features
now := time.Now()
d.Info = AgentInfo{ d.Info = AgentInfo{
ID: d.cfg.AgentID, ID: d.cfg.AgentID,
Name: d.cfg.AgentName, Name: d.cfg.AgentName,
User: resp.User, User: resp.User,
Features: resp.Features, Features: resp.Features,
StartedAt: time.Now(), StartedAt: now,
} }
d.State = DaemonState{
AgentID: d.cfg.AgentID,
Status: "running",
Version: d.cfg.Version,
PID: os.Getpid(),
StartedAt: now,
MethodStats: make(map[string]int),
}
WriteState(&d.State)
return nil return nil
} }
@ -91,29 +119,41 @@ func (d *Daemon) Run(ctx context.Context) error {
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet) log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval) log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
heartbeatTicker := time.NewTicker(d.cfg.HeartbeatInterval) d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
defer heartbeatTicker.Stop() defer d.HeartbeatTicker.Stop()
pollTicker := time.NewTicker(d.cfg.PollInterval) d.PollTicker = time.NewTicker(d.cfg.PollInterval)
defer pollTicker.Stop() defer d.PollTicker.Stop()
heartbeatTicker := d.HeartbeatTicker
pollTicker := d.PollTicker
// Initial poll immediately // Initial poll immediately
d.poll(ctx) d.poll(ctx)
eventsCh := d.transport.Events()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Println("Daemon shutting down...") log.Println("Daemon shutting down...")
d.deregister()
return nil return nil
case event := <-eventsCh:
d.handleEvent(event)
case <-heartbeatTicker.C: case <-heartbeatTicker.C:
d.heartbeat(ctx) d.heartbeat(ctx)
case <-pollTicker.C: case <-pollTicker.C:
// Only poll in HTTP mode — WS mode receives tasks via Events
if d.transport.Mode() == "http" {
d.poll(ctx) d.poll(ctx)
} }
} }
} }
}
func (d *Daemon) heartbeat(ctx context.Context) { func (d *Daemon) heartbeat(ctx context.Context) {
req := HeartbeatRequest{ req := HeartbeatRequest{
@ -128,13 +168,93 @@ func (d *Daemon) heartbeat(ctx context.Context) {
req.DiskTotalBytes = total req.DiskTotalBytes = total
} }
if err := d.client.Heartbeat(ctx, req); err != nil { resp, err := d.transport.SendHeartbeat(ctx, req)
if err != nil {
d.heartbeatFailures++
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
} else {
log.Printf("Heartbeat failed: %v", err) log.Printf("Heartbeat failed: %v", err)
} }
return
}
if d.heartbeatFailures > 0 {
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
d.heartbeatFailures = 0
}
// Update state file
d.State.LastHeartbeat = time.Now()
if d.GetActiveCount != nil {
d.State.ActiveTasks = d.GetActiveCount()
}
WriteState(&d.State)
// Check for upgrade signal from server
if resp.Upgrade != nil && resp.Upgrade.Version != "" && !d.upgradeInProgress {
d.upgradeInProgress = true
log.Printf("Upgrade requested by server: %s → %s", d.cfg.Version, resp.Upgrade.Version)
if d.OnUpgradeRequested != nil {
go d.OnUpgradeRequested(resp.Upgrade.Version)
}
}
}
// handleEvent dispatches a single server-initiated event taken from the
// transport's Events channel to the matching daemon callback.
//
//   - "tasks": hands claimed tasks to OnTasksClaimed, then each stream request
//     to OnStreamRequested.
//   - "upgrade": fires OnUpgradeRequested in its own goroutine, at most once
//     until the upgrade-in-progress flag is cleared.
//   - "control": forwards the action and task ID to OnControlAction.
//   - "disconnected": only logged.
//
// Events with a nil payload, an unset callback, or an unknown type are ignored.
func (d *Daemon) handleEvent(event ServerEvent) {
	switch event.Type {
	case "tasks":
		payload := event.Tasks
		if payload == nil {
			return
		}
		if count := len(payload.Tasks); count > 0 {
			log.Printf("Received %d task(s) via WebSocket", count)
			if d.OnTasksClaimed != nil {
				d.OnTasksClaimed(payload.Tasks)
			}
		}
		if d.OnStreamRequested != nil {
			for _, sr := range payload.StreamRequests {
				d.OnStreamRequested(sr)
			}
		}

	case "upgrade":
		signal := event.Upgrade
		if signal == nil || signal.Version == "" || d.upgradeInProgress {
			return
		}
		// Latch the flag before notifying so repeated signals are ignored.
		d.upgradeInProgress = true
		log.Printf("Upgrade requested via WebSocket: %s → %s", d.cfg.Version, signal.Version)
		if d.OnUpgradeRequested != nil {
			go d.OnUpgradeRequested(signal.Version)
		}

	case "control":
		ctl := event.Control
		if ctl == nil || d.OnControlAction == nil {
			return
		}
		log.Printf("Control action via WebSocket: %s task %s", ctl.Action, ctl.TaskID)
		d.OnControlAction(ctl.Action, ctl.TaskID)

	case "disconnected":
		log.Println("WebSocket disconnected, switching to HTTP polling")
	}
}
// ClearUpgradeInProgress resets the upgrade flag so a retry can be attempted.
// While the flag is set, heartbeat and handleEvent ignore further upgrade
// signals, so it must be cleared before another upgrade can be triggered.
// NOTE(review): upgradeInProgress is a plain bool written here without any
// locking; if this is called from the OnUpgradeRequested goroutine it may race
// with the daemon loop — confirm against the caller.
func (d *Daemon) ClearUpgradeInProgress() {
	d.upgradeInProgress = false
}
func (d *Daemon) deregister() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := d.transport.Deregister(ctx, d.cfg.AgentID)
if err != nil {
log.Printf("Deregister failed: %v", err)
} else {
log.Println("Agent deregistered")
}
RemoveState()
} }
func (d *Daemon) poll(ctx context.Context) { func (d *Daemon) poll(ctx context.Context) {
tasks, err := d.client.ClaimTasks(ctx, d.cfg.AgentID) resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
if err != nil { if err != nil {
log.Printf("Poll failed: %v", err) log.Printf("Poll failed: %v", err)
return return
@ -142,13 +262,17 @@ func (d *Daemon) poll(ctx context.Context) {
d.Info.LastPollAt = time.Now() d.Info.LastPollAt = time.Now()
if len(tasks) == 0 { if len(resp.Tasks) > 0 {
return log.Printf("Claimed %d task(s)", len(resp.Tasks))
}
log.Printf("Claimed %d task(s)", len(tasks))
if d.OnTasksClaimed != nil { if d.OnTasksClaimed != nil {
d.OnTasksClaimed(tasks) d.OnTasksClaimed(resp.Tasks)
}
}
// Handle stream requests for completed downloads
if d.OnStreamRequested != nil {
for _, sr := range resp.StreamRequests {
d.OnStreamRequested(sr)
}
} }
} }

View file

@ -0,0 +1,53 @@
package agent
import "context"
// Transport abstracts the communication protocol between the agent and server.
// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this.
// The daemon talks only to this interface, so implementations can be swapped
// (or combined, as HybridTransport does) without touching the daemon loop.
type Transport interface {
	// Connect establishes the transport connection.
	Connect(ctx context.Context) error

	// Close tears down the connection gracefully.
	Close() error

	// Mode returns the current transport mode ("ws" or "http").
	// The daemon uses this to decide whether it has to poll for tasks.
	Mode() string

	// Register sends agent registration and returns user info + features.
	Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error)

	// SendHeartbeat sends a periodic keep-alive.
	SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error)

	// SendProgress reports download progress for a task.
	SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error)

	// ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events).
	ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error)

	// Deregister notifies the server of graceful shutdown.
	Deregister(ctx context.Context, agentID string) error

	// ReportUpgradeResult reports upgrade outcome.
	ReportUpgradeResult(ctx context.Context, result UpgradeResult) error

	// Events returns a channel that emits server-initiated events.
	// In HTTP mode this channel is never written to (polling handles it).
	// In WS mode, tasks/upgrade/control arrive here.
	Events() <-chan ServerEvent
}
// ServerEvent represents a server-initiated message received via WebSocket.
// Type selects which payload pointer is meaningful; consumers should nil-check
// the payload before using it. A "disconnected" event carries no payload.
type ServerEvent struct {
	Type    string         // "tasks", "upgrade", "control", "disconnected"
	Tasks   *TasksResponse // populated when Type == "tasks"
	Upgrade *UpgradeSignal // populated when Type == "upgrade"
	Control *ControlAction // populated when Type == "control"
}
// ControlAction represents a server push for task control.
// It names the action to apply and the task it targets.
type ControlAction struct {
	Action string `json:"action"` // "pause", "resume", "cancel", "stream"
	TaskID string `json:"taskId"` // ID of the task the action applies to
}

View file

@ -0,0 +1,295 @@
package agent
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// TestE2EFullLifecycle drives a WSTransport against an in-process WebSocket
// server through the sequence:
// connect → auth → heartbeat → progress → server-pushed cancel → upgrade result,
// and then verifies that the server observed every client message type.
func TestE2EFullLifecycle(t *testing.T) {
	var mu sync.Mutex
	var receivedMessages []map[string]any

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		for {
			_, msg, err := conn.ReadMessage()
			if err != nil {
				return
			}

			var parsed map[string]any
			if err := json.Unmarshal(msg, &parsed); err != nil {
				// Not JSON: nothing to record or answer. The handler may outlive
				// the test, so it must not report through t here.
				continue
			}

			mu.Lock()
			receivedMessages = append(receivedMessages, parsed)
			mu.Unlock()

			msgType, _ := parsed["type"].(string)
			switch msgType {
			case "auth":
				conn.WriteJSON(wsRegisteredMessage{
					Type:     "registered",
					User:     UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
					Features: FeatureFlags{Torrent: true, Debrid: true},
				})
			case "heartbeat":
				// No response in WS mode
			case "progress":
				// Simulate server-side cancel after progress. Checked type
				// assertions keep a malformed message from panicking the handler.
				progress, _ := parsed["progress"].(float64)
				taskID, _ := parsed["taskId"].(string)
				if progress >= 50 {
					conn.WriteJSON(map[string]string{
						"type":   "control",
						"action": "cancel",
						"taskId": taskID,
					})
				}
			case "upgrade-result":
				// Acknowledged
			}
		}
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
	ctx := context.Background()

	// 1. Connect
	if err := tr.Connect(ctx); err != nil {
		t.Fatalf("Connect: %v", err)
	}
	defer tr.Close()

	// 2. Auth
	resp, err := tr.Register(ctx, RegisterRequest{
		AgentID: "e2e-agent",
		Name:    "E2E Test Agent",
		Version: "1.0.0",
		OS:      "linux",
		Arch:    "amd64",
	})
	if err != nil {
		t.Fatalf("Register: %v", err)
	}
	if resp.User.Name != "E2E User" {
		t.Errorf("expected E2E User, got %s", resp.User.Name)
	}
	if !resp.Features.Debrid {
		t.Error("expected debrid feature")
	}

	// 3. Send heartbeat
	_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
		AgentID:        "e2e-agent",
		DiskFreeBytes:  1000000000,
		DiskTotalBytes: 5000000000,
	})
	if err != nil {
		t.Fatalf("SendHeartbeat: %v", err)
	}

	// 4. Send progress (50% → should trigger cancel control)
	_, err = tr.SendProgress(ctx, StatusUpdate{
		TaskID:          "task-e2e-1",
		Status:          "downloading",
		Progress:        50,
		DownloadedBytes: 500,
		TotalBytes:      1000,
		SpeedBps:        100,
	})
	if err != nil {
		t.Fatalf("SendProgress: %v", err)
	}

	// 5. Wait for control event (cancel)
	select {
	case event := <-tr.Events():
		// Fatal here: the checks below dereference event.Control.
		if event.Type != "control" || event.Control == nil {
			t.Fatalf("expected control event with payload, got %+v", event)
		}
		if event.Control.Action != "cancel" {
			t.Errorf("expected cancel, got %s", event.Control.Action)
		}
		if event.Control.TaskID != "task-e2e-1" {
			t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
		}
	case <-time.After(3 * time.Second):
		t.Fatal("timeout waiting for cancel control")
	}

	// 6. Send upgrade result
	err = tr.ReportUpgradeResult(ctx, UpgradeResult{
		AgentID: "e2e-agent",
		Success: true,
		Version: "2.0.0",
	})
	if err != nil {
		t.Fatalf("ReportUpgradeResult: %v", err)
	}

	// Verify server received all messages. Poll with a deadline instead of a
	// fixed sleep so the check does not depend on scheduler timing.
	deadline := time.Now().Add(2 * time.Second)
	for {
		mu.Lock()
		seenCount := len(receivedMessages)
		mu.Unlock()
		if seenCount >= 4 || time.Now().After(deadline) {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}

	mu.Lock()
	defer mu.Unlock()

	if len(receivedMessages) < 4 {
		t.Fatalf("expected at least 4 messages, got %d", len(receivedMessages))
	}

	types := make([]string, len(receivedMessages))
	seen := make(map[string]bool, len(receivedMessages))
	for i, m := range receivedMessages {
		types[i], _ = m["type"].(string)
		seen[types[i]] = true
	}

	for _, exp := range []string{"auth", "heartbeat", "progress", "upgrade-result"} {
		if !seen[exp] {
			t.Errorf("missing message type %q in %v", exp, types)
		}
	}
}
// TestE2EHybridFailover tests the full failover scenario:
// WS connect → receive tasks → WS disconnect → switch to HTTP → continue working.
// The WebSocket server drops the first connection shortly after pushing one
// task; the hybrid transport must then report "http" mode and keep serving
// heartbeats through the HTTP mock.
func TestE2EHybridFailover(t *testing.T) {
	connectionCount := 0
	var mu sync.Mutex

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		mu.Lock()
		connectionCount++
		connNum := connectionCount
		mu.Unlock()

		// Read auth
		conn.ReadMessage()
		conn.WriteJSON(wsRegisteredMessage{
			Type: "registered",
			User: UserInfo{Name: "Failover User"},
		})

		if connNum == 1 {
			// First connection: push tasks then disconnect after 200ms
			time.Sleep(50 * time.Millisecond)
			conn.WriteJSON(wsTasksMessage{
				Type:  "tasks",
				Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
			})
			time.Sleep(150 * time.Millisecond)
			return
		}

		// Second connection (after reconnect): push upgrade
		time.Sleep(50 * time.Millisecond)
		conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
		time.Sleep(500 * time.Millisecond)
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	wsT := NewWSTransport(wsURL, "key", "a1", "ua")

	// HTTP mock for fallback
	httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simple heartbeat response; an encode failure surfaces as a client-side error.
		_ = json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
	}))
	defer httpSrv.Close()

	httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
	h := NewHybridTransport(wsT, httpT)
	ctx := context.Background()

	err := h.Connect(ctx)
	if err != nil {
		t.Fatalf("Connect: %v", err)
	}
	defer h.Close()

	// Should start in WS mode
	if h.Mode() != "ws" {
		t.Fatalf("expected ws mode, got %s", h.Mode())
	}

	// Register via WS
	_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
	if err != nil {
		t.Fatalf("Register: %v", err)
	}

	// Receive tasks via WS
	var tasksReceived bool
	var disconnected bool

waitLoop:
	for i := 0; i < 3; i++ {
		select {
		case event := <-h.Events():
			switch event.Type {
			case "tasks":
				tasksReceived = true
				if event.Tasks == nil || len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
					t.Errorf("unexpected tasks: %+v", event.Tasks)
				}
			case "disconnected":
				disconnected = true
			}
		case <-time.After(2 * time.Second):
			// A bare break would only leave the select; the label ends the wait.
			break waitLoop
		}
		if disconnected {
			break
		}
	}

	if !tasksReceived {
		t.Error("did not receive tasks before disconnect")
	}
	if !disconnected {
		t.Error("did not receive disconnect event")
	}

	// Should now be in HTTP mode
	time.Sleep(100 * time.Millisecond)
	if h.Mode() != "http" {
		t.Errorf("expected http mode after disconnect, got %s", h.Mode())
	}

	// Heartbeat should work via HTTP fallback
	hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
	if err != nil {
		t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
	}
	if hbResp == nil || !hbResp.Success {
		t.Error("expected heartbeat success")
	}
}

View file

@ -0,0 +1,50 @@
package agent
import "context"
// HTTPTransport adapts the existing Client to the Transport interface.
// Every call is forwarded to the Client unchanged, so the HTTP protocol behaves
// exactly as it did before the Transport abstraction was introduced.
type HTTPTransport struct {
	client *Client
	events chan ServerEvent
}

// NewHTTPTransport builds an HTTP transport around a new Client created from
// the given base URL, API key and user agent.
func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport {
	tr := new(HTTPTransport)
	tr.client = NewClient(baseURL, apiKey, userAgent)
	tr.events = make(chan ServerEvent, 10)
	return tr
}

// Connect is a no-op: there is no connection to establish for HTTP.
func (t *HTTPTransport) Connect(_ context.Context) error {
	return nil
}

// Close is a no-op and always returns nil.
func (t *HTTPTransport) Close() error {
	return nil
}

// Mode always reports "http".
func (t *HTTPTransport) Mode() string {
	return "http"
}

// Events returns the transport's event channel. Nothing in this type writes
// to it; in HTTP mode server-side work is discovered by polling instead.
func (t *HTTPTransport) Events() <-chan ServerEvent {
	return t.events
}

// Register forwards the registration request to the Client.
func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
	return t.client.Register(ctx, req)
}

// SendHeartbeat forwards the heartbeat to the Client.
func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
	return t.client.Heartbeat(ctx, req)
}

// SendProgress forwards the status update to the Client's ReportStatus.
func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
	return t.client.ReportStatus(ctx, update)
}

// ClaimTasks polls the server for tasks on behalf of agentID.
func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
	return t.client.ClaimTasks(ctx, agentID)
}

// Deregister forwards the deregistration of agentID to the Client.
func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error {
	return t.client.Deregister(ctx, agentID)
}

// ReportUpgradeResult forwards the upgrade outcome to the Client.
func (t *HTTPTransport) ReportUpgradeResult(ctx context.Context, result UpgradeResult) error {
	return t.client.ReportUpgradeResult(ctx, result)
}

// Client returns the underlying HTTP client for direct use if needed.
func (t *HTTPTransport) Client() *Client {
	return t.client
}

View file

@ -0,0 +1,226 @@
package agent
import (
"context"
"log"
"sync"
"sync/atomic"
"time"
)
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
// Automatically reconnects WS in the background.
type HybridTransport struct {
	ws   *WSTransport   // preferred transport
	http *HTTPTransport // fallback transport
	mode atomic.Value   // "ws" or "http"

	// events is the unified channel handed to callers; WS events are copied into it.
	events chan ServerEvent

	reconnectMu      sync.Mutex    // guards reconnectRunning
	reconnectRunning bool          // true while a reconnect loop goroutine is active
	reconnectStop    chan struct{} // closed by Close to stop background goroutines
	closed           atomic.Bool   // set once Close has been called
}
// NewHybridTransport builds a transport that prefers ws and falls back to http.
// The returned transport reports "http" mode until Connect manages to bring
// the WebSocket up.
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
	hybrid := new(HybridTransport)
	hybrid.ws = ws
	hybrid.http = http
	hybrid.events = make(chan ServerEvent, 50)
	hybrid.reconnectStop = make(chan struct{})
	// Start in HTTP; Connect upgrades the mode to WS on success.
	hybrid.mode.Store("http")
	return hybrid
}
// Mode reports the currently active transport, "ws" or "http".
func (h *HybridTransport) Mode() string {
	current := h.mode.Load()
	return current.(string)
}
// Events returns the unified event channel that WebSocket events are
// forwarded into.
func (h *HybridTransport) Events() <-chan ServerEvent {
	return h.events
}
// Connect attempts the WebSocket connection first. On success the transport
// switches to "ws" mode and starts forwarding WS events. On failure it stays
// in "http" mode, starts the background reconnect loop, and returns the result
// of connecting the HTTP transport.
func (h *HybridTransport) Connect(ctx context.Context) error {
	wsErr := h.ws.Connect(ctx)
	if wsErr == nil {
		h.mode.Store("ws")
		log.Println("[transport] Connected via WebSocket")

		// Forward WS events to unified channel + watch for disconnection
		go h.forwardWSEvents()
		return nil
	}

	log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", wsErr)
	h.mode.Store("http")
	h.startReconnectLoop()
	return h.http.Connect(ctx)
}
// Close shuts down both transports and stops reconnection.
// It is safe to call more than once and from several goroutines: the stop
// channel is closed only by the first caller.
// The WebSocket close error is deliberately discarded; the HTTP transport's
// Close result is returned.
func (h *HybridTransport) Close() error {
	// Swap both marks the transport closed and tells us whether we are the
	// first closer, so two concurrent calls cannot both close the channel
	// (a check-then-close on the channel could panic with "close of closed channel").
	if !h.closed.Swap(true) {
		close(h.reconnectStop)
	}
	_ = h.ws.Close()
	return h.http.Close()
}
// Register sends the registration through whichever transport is active.
// There is no fallback here: a WebSocket failure is returned to the caller.
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
	if h.mode.Load() != "ws" {
		return h.http.Register(ctx, req)
	}
	return h.ws.Register(ctx, req)
}
// SendHeartbeat sends the heartbeat through the active transport. If the
// WebSocket send fails, the transport switches to HTTP and the same heartbeat
// is retried there; the WebSocket error itself is not returned.
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
	if h.mode.Load() != "ws" {
		return h.http.SendHeartbeat(ctx, req)
	}

	resp, err := h.ws.SendHeartbeat(ctx, req)
	if err == nil {
		return resp, nil
	}

	// WS write failed — switch to HTTP
	h.switchToHTTP()
	return h.http.SendHeartbeat(ctx, req)
}
// SendProgress reports task progress through the active transport. If the
// WebSocket send fails, the transport switches to HTTP and the same update is
// retried there; the WebSocket error itself is not returned.
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
	if h.mode.Load() != "ws" {
		return h.http.SendProgress(ctx, update)
	}

	resp, err := h.ws.SendProgress(ctx, update)
	if err == nil {
		return resp, nil
	}

	h.switchToHTTP()
	return h.http.SendProgress(ctx, update)
}
// ClaimTasks asks the active transport for tasks. Only the HTTP transport
// actually polls; in WS mode the call goes to the WebSocket transport.
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
	if h.mode.Load() != "ws" {
		return h.http.ClaimTasks(ctx, agentID)
	}
	return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
}
// Deregister announces shutdown through whichever transport is active.
// There is no fallback here: a WebSocket failure is returned to the caller.
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
	if h.mode.Load() != "ws" {
		return h.http.Deregister(ctx, agentID)
	}
	return h.ws.Deregister(ctx, agentID)
}
// ReportUpgradeResult reports the upgrade outcome through the active
// transport. If the WebSocket send fails, the transport switches to HTTP and
// the report is retried there; the WebSocket error itself is not returned.
func (h *HybridTransport) ReportUpgradeResult(ctx context.Context, result UpgradeResult) error {
	if h.mode.Load() != "ws" {
		return h.http.ReportUpgradeResult(ctx, result)
	}

	wsErr := h.ws.ReportUpgradeResult(ctx, result)
	if wsErr == nil {
		return nil
	}

	h.switchToHTTP()
	return h.http.ReportUpgradeResult(ctx, result)
}
// ── Internal ─────────────────────────────────────────────────────────────────
// switchToHTTP moves the transport from "ws" to "http" mode, closes the
// WebSocket and starts the background reconnect loop. It does nothing when
// the transport is not currently in "ws" mode.
//
// The mode transition uses CompareAndSwap so that when several goroutines
// detect the failure at once (a failed send and the event forwarder, for
// example) exactly one of them performs the switch; a separate Load followed
// by Store would let both through.
func (h *HybridTransport) switchToHTTP() {
	if !h.mode.CompareAndSwap("ws", "http") {
		return
	}
	log.Println("[transport] Switching to HTTP fallback")
	// Best effort: the connection is already considered broken.
	_ = h.ws.Close()
	h.startReconnectLoop()
}
// forwardWSEvents copies events from the WebSocket transport into the unified
// events channel until the transport is closed, the WS channel is closed, or
// a "disconnected" event arrives.
//
// Sends are non-blocking: when the unified channel is full the event is
// dropped and the drop is logged, including for "disconnected" events.
// On "disconnected" the transport switches to HTTP before the event is
// forwarded, so a consumer reacting to it already sees "http" mode.
func (h *HybridTransport) forwardWSEvents() {
	for {
		select {
		case <-h.reconnectStop:
			return
		case event, ok := <-h.ws.Events():
			if !ok {
				return // channel closed
			}

			terminal := event.Type == "disconnected"
			if terminal {
				h.switchToHTTP()
			}

			select {
			case h.events <- event:
			default:
				log.Printf("[transport] events channel full, dropping %s event", event.Type)
			}

			if terminal {
				return
			}
		}
	}
}
// startReconnectLoop launches the background reconnect goroutine unless one
// is already marked as running. The running flag is read and set under
// reconnectMu so that concurrent callers start at most one loop.
func (h *HybridTransport) startReconnectLoop() {
	h.reconnectMu.Lock()
	alreadyRunning := h.reconnectRunning
	h.reconnectRunning = true
	h.reconnectMu.Unlock()

	if !alreadyRunning {
		go h.reconnectLoop()
	}
}
// reconnectLoop retries the WebSocket connection in the background until it
// succeeds, the transport is closed, or the transport is found to be back in
// "ws" mode. The wait between attempts starts at 5s and doubles after every
// failure up to 60s; each attempt is limited to 10s.
//
// On success the transport switches back to "ws" mode and event forwarding
// is restarted for the new connection.
func (h *HybridTransport) reconnectLoop() {
	const maxBackoff = 60 * time.Second
	backoff := 5 * time.Second

	for {
		select {
		case <-h.reconnectStop:
			return
		case <-time.After(backoff):
		}

		if h.closed.Load() {
			return
		}

		// Already on WS? (someone else reconnected)
		if h.mode.Load() == "ws" {
			// Release the running flag, otherwise a later WS drop could never
			// start a new reconnect loop.
			h.reconnectMu.Lock()
			h.reconnectRunning = false
			h.reconnectMu.Unlock()
			return
		}

		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		err := h.ws.Connect(ctx)
		cancel()

		if err != nil {
			// Grow the delay first so the log reports the wait that is actually used.
			backoff = min(backoff*2, maxBackoff)
			log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
			continue
		}

		// Close may have run while we were dialling; do not leave a fresh
		// connection behind on a closed transport.
		if h.closed.Load() {
			_ = h.ws.Close()
			return
		}

		// WS reconnected — switch back
		log.Println("[transport] WebSocket reconnected")

		// Clear the running flag BEFORE publishing "ws" mode. If mode were
		// stored first, a send failing in between would call switchToHTTP,
		// find the loop still marked as running, start nothing, and leave the
		// transport on HTTP with no reconnect loop.
		h.reconnectMu.Lock()
		h.reconnectRunning = false
		h.mode.Store("ws")
		h.reconnectMu.Unlock()

		// Forward events from new WS connection
		go h.forwardWSEvents()
		return
	}
}

View file

@ -0,0 +1,445 @@
package agent
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
)
// ── HTTP Transport Tests ─────────────────────────────────────────────────────
// TestHTTPTransportMode checks that an HTTP transport reports "http" mode.
func TestHTTPTransportMode(t *testing.T) {
	got := NewHTTPTransport("http://localhost", "key", "ua").Mode()
	if got != "http" {
		t.Errorf("expected http, got %s", got)
	}
}
// TestHTTPTransportEventsNeverEmit checks that nothing arrives on the events
// channel of an HTTP transport within a short window.
func TestHTTPTransportEventsNeverEmit(t *testing.T) {
	events := NewHTTPTransport("http://localhost", "key", "ua").Events()

	window := time.NewTimer(50 * time.Millisecond)
	defer window.Stop()

	select {
	case <-window.C:
		// expected
	case <-events:
		t.Error("events channel should never emit in HTTP mode")
	}
}
// TestHTTPTransportDelegates checks that Register on the HTTP transport
// reaches the server and returns the decoded response.
func TestHTTPTransportDelegates(t *testing.T) {
	// Mock server for register
	handler := func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(RegisterResponse{
			Success: true,
			User:    UserInfo{Name: "Test", Plan: "pro"},
		})
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	transport := NewHTTPTransport(srv.URL, "test-key", "test-agent")
	resp, err := transport.Register(context.Background(), RegisterRequest{AgentID: "a1"})
	if err != nil {
		t.Fatalf("Register failed: %v", err)
	}

	if !resp.Success {
		t.Error("expected success")
	}
	if name := resp.User.Name; name != "Test" {
		t.Errorf("expected Test, got %s", name)
	}
}
// ── WebSocket Transport Tests ────────────────────────────────────────────────
// upgrader turns the test servers' HTTP requests into WebSocket connections.
// CheckOrigin accepts every origin, which is only appropriate for these
// in-process test servers.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}
// TestWSTransportConnectAndAuth checks that Register sends an auth message
// carrying the API key and agent ID, and that the server's "registered" reply
// is returned as the register response.
func TestWSTransportConnectAndAuth(t *testing.T) {
	var received wsAuthMessage
	var mu sync.Mutex

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Fatalf must only be called from the test goroutine; from the
			// handler goroutine report the failure and bail out instead.
			t.Errorf("upgrade: %v", err)
			return
		}
		defer conn.Close()

		// Read auth message
		_, msg, err := conn.ReadMessage()
		if err != nil {
			return
		}
		mu.Lock()
		decodeErr := json.Unmarshal(msg, &received)
		mu.Unlock()
		if decodeErr != nil {
			t.Errorf("decode auth message: %v", decodeErr)
			return
		}

		// Send registered response
		conn.WriteJSON(wsRegisteredMessage{
			Type:     "registered",
			User:     UserInfo{Name: "WS User", Plan: "pro", IsPro: true},
			Features: FeatureFlags{Torrent: true},
		})

		// Keep connection open
		time.Sleep(500 * time.Millisecond)
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	tr := NewWSTransport(wsURL, "my-api-key", "agent-123", "test/1.0")

	ctx := context.Background()
	if err := tr.Connect(ctx); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer tr.Close()

	resp, err := tr.Register(ctx, RegisterRequest{
		AgentID: "agent-123",
		Name:    "test-agent",
		Version: "1.0.0",
	})
	if err != nil {
		t.Fatalf("Register failed: %v", err)
	}
	if !resp.Success {
		t.Error("expected success")
	}
	if resp.User.Name != "WS User" {
		t.Errorf("expected WS User, got %s", resp.User.Name)
	}

	mu.Lock()
	defer mu.Unlock()
	if received.APIKey != "my-api-key" {
		t.Errorf("expected my-api-key, got %s", received.APIKey)
	}
	if received.AgentID != "agent-123" {
		t.Errorf("expected agent-123, got %s", received.AgentID)
	}
}
// TestWSTransportReceiveTasks checks that a "tasks" message pushed by the
// server surfaces on Events with its task list intact.
func TestWSTransportReceiveTasks(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		// Read auth
		conn.ReadMessage()
		conn.WriteJSON(wsRegisteredMessage{
			Type: "registered",
			User: UserInfo{Name: "Test"},
		})

		// Push tasks
		time.Sleep(50 * time.Millisecond)
		conn.WriteJSON(wsTasksMessage{
			Type: "tasks",
			Tasks: []Task{
				{ID: "t1", InfoHash: "abc123", Title: "Test Movie"},
				{ID: "t2", InfoHash: "def456", Title: "Test Show"},
			},
		})

		time.Sleep(500 * time.Millisecond)
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	tr := NewWSTransport(wsURL, "key", "agent1", "ua")

	ctx := context.Background()
	// Fail fast on setup errors rather than timing out on Events later.
	if err := tr.Connect(ctx); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer tr.Close()
	if _, err := tr.Register(ctx, RegisterRequest{AgentID: "agent1"}); err != nil {
		t.Fatalf("Register failed: %v", err)
	}

	// Wait for tasks event
	select {
	case event := <-tr.Events():
		// Fatal: the checks below dereference and index the payload.
		if event.Type != "tasks" || event.Tasks == nil {
			t.Fatalf("expected tasks event with payload, got %+v", event)
		}
		if len(event.Tasks.Tasks) != 2 {
			t.Fatalf("expected 2 tasks, got %d", len(event.Tasks.Tasks))
		}
		if event.Tasks.Tasks[0].Title != "Test Movie" {
			t.Errorf("expected Test Movie, got %s", event.Tasks.Tasks[0].Title)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for tasks event")
	}
}
// TestWSTransportReceiveControl checks that a "control" message pushed by the
// server surfaces on Events with its action and task ID.
func TestWSTransportReceiveControl(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		conn.ReadMessage()
		conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})

		time.Sleep(50 * time.Millisecond)
		conn.WriteJSON(map[string]string{
			"type":   "control",
			"action": "cancel",
			"taskId": "task-99",
		})

		time.Sleep(500 * time.Millisecond)
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	tr := NewWSTransport(wsURL, "key", "a1", "ua")

	ctx := context.Background()
	// Fail fast on setup errors rather than timing out on Events later.
	if err := tr.Connect(ctx); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer tr.Close()
	if _, err := tr.Register(ctx, RegisterRequest{AgentID: "a1"}); err != nil {
		t.Fatalf("Register failed: %v", err)
	}

	select {
	case event := <-tr.Events():
		// Fatal: the checks below dereference event.Control.
		if event.Type != "control" || event.Control == nil {
			t.Fatalf("expected control event with payload, got %+v", event)
		}
		if event.Control.Action != "cancel" {
			t.Errorf("expected cancel, got %s", event.Control.Action)
		}
		if event.Control.TaskID != "task-99" {
			t.Errorf("expected task-99, got %s", event.Control.TaskID)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for control event")
	}
}
// TestWSTransportReceiveUpgrade checks that an "upgrade" message pushed by the
// server surfaces on Events with the requested version.
func TestWSTransportReceiveUpgrade(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		conn.ReadMessage()
		conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})

		time.Sleep(50 * time.Millisecond)
		conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "2.0.0"})

		time.Sleep(500 * time.Millisecond)
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	tr := NewWSTransport(wsURL, "key", "a1", "ua")

	ctx := context.Background()
	// Fail fast on setup errors rather than timing out on Events later.
	if err := tr.Connect(ctx); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer tr.Close()
	if _, err := tr.Register(ctx, RegisterRequest{AgentID: "a1"}); err != nil {
		t.Fatalf("Register failed: %v", err)
	}

	select {
	case event := <-tr.Events():
		// Fatal: the check below dereferences event.Upgrade.
		if event.Type != "upgrade" || event.Upgrade == nil {
			t.Fatalf("expected upgrade event with payload, got %+v", event)
		}
		if event.Upgrade.Version != "2.0.0" {
			t.Errorf("expected 2.0.0, got %s", event.Upgrade.Version)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for upgrade event")
	}
}
// TestWSTransportDisconnect checks that a "disconnected" event is emitted on
// Events after the server closes the connection.
func TestWSTransportDisconnect(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		conn.ReadMessage()
		conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})

		// Close after a short delay to simulate disconnection
		time.Sleep(100 * time.Millisecond)
		conn.Close()
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	tr := NewWSTransport(wsURL, "key", "a1", "ua")

	ctx := context.Background()
	// Fail fast on setup errors: without a live connection the "disconnected"
	// event below would not be testing what the test claims.
	if err := tr.Connect(ctx); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer tr.Close()
	if _, err := tr.Register(ctx, RegisterRequest{AgentID: "a1"}); err != nil {
		t.Fatalf("Register failed: %v", err)
	}

	select {
	case event := <-tr.Events():
		if event.Type != "disconnected" {
			t.Errorf("expected disconnected, got %s", event.Type)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for disconnected event")
	}
}
// TestWSTransportSendProgress checks that SendProgress succeeds and that the
// server receives a "progress" message carrying the task ID.
func TestWSTransportSendProgress(t *testing.T) {
	var receivedMsg map[string]any
	var mu sync.Mutex

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		// Read auth
		conn.ReadMessage()
		conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})

		// Read progress
		_, msg, err := conn.ReadMessage()
		if err != nil {
			return
		}
		mu.Lock()
		// On a decode error receivedMsg stays nil and the assertions below
		// report the missing fields.
		_ = json.Unmarshal(msg, &receivedMsg)
		mu.Unlock()

		time.Sleep(500 * time.Millisecond)
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	tr := NewWSTransport(wsURL, "key", "a1", "ua")

	ctx := context.Background()
	// Fail fast on setup errors instead of reporting a misleading send failure.
	if err := tr.Connect(ctx); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer tr.Close()
	if _, err := tr.Register(ctx, RegisterRequest{AgentID: "a1"}); err != nil {
		t.Fatalf("Register failed: %v", err)
	}

	time.Sleep(50 * time.Millisecond)

	resp, err := tr.SendProgress(ctx, StatusUpdate{
		TaskID:   "t1",
		Status:   "downloading",
		Progress: 42,
	})
	if err != nil {
		t.Fatalf("SendProgress failed: %v", err)
	}
	if !resp.Success {
		t.Error("expected success response")
	}

	// Give the server goroutine time to read and record the message.
	time.Sleep(100 * time.Millisecond)

	mu.Lock()
	defer mu.Unlock()
	if receivedMsg["type"] != "progress" {
		t.Errorf("expected progress, got %v", receivedMsg["type"])
	}
	if receivedMsg["taskId"] != "t1" {
		t.Errorf("expected t1, got %v", receivedMsg["taskId"])
	}
}
// ── Hybrid Transport Tests ───────────────────────────────────────────────────
// TestHybridTransportWSSuccess checks that the hybrid transport reports "ws"
// mode when the WebSocket endpoint accepts the connection.
func TestHybridTransportWSSuccess(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		c, upgradeErr := upgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			return
		}
		defer c.Close()
		// Hold the socket open for the duration of the assertions.
		time.Sleep(500 * time.Millisecond)
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	endpoint := "ws" + strings.TrimPrefix(server.URL, "http")
	hybrid := NewHybridTransport(
		NewWSTransport(endpoint, "key", "a1", "ua"),
		NewHTTPTransport("http://localhost", "key", "ua"),
	)
	if err := hybrid.Connect(context.Background()); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer hybrid.Close()

	if mode := hybrid.Mode(); mode != "ws" {
		t.Errorf("expected ws mode, got %s", mode)
	}
}
func TestHybridTransportWSFailFallbackHTTP(t *testing.T) {
// WS URL points to nowhere
wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
httpT := NewHTTPTransport("http://localhost", "key", "ua")
h := NewHybridTransport(wsT, httpT)
err := h.Connect(context.Background())
if err != nil {
t.Fatalf("Connect should succeed with HTTP fallback: %v", err)
}
defer h.Close()
if h.Mode() != "http" {
t.Errorf("expected http mode after WS failure, got %s", h.Mode())
}
}
// TestHybridTransportWSDisconnectSwitchesToHTTP checks that when the server
// drops an established WebSocket, the hybrid transport emits a "disconnected"
// event and then reports "http" mode.
func TestHybridTransportWSDisconnectSwitchesToHTTP(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		// Close shortly after the upgrade to trigger a disconnect on the client.
		time.Sleep(100 * time.Millisecond)
		conn.Close()
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	wsT := NewWSTransport(wsURL, "key", "a1", "ua")
	httpT := NewHTTPTransport("http://localhost", "key", "ua")
	h := NewHybridTransport(wsT, httpT)
	if err := h.Connect(context.Background()); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	defer h.Close()

	// Wait for disconnect event
	select {
	case event := <-h.Events():
		if event.Type != "disconnected" {
			t.Errorf("expected disconnected, got %s", event.Type)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for disconnected event")
	}

	// The mode switch may land slightly after the event is delivered, so poll
	// with a deadline instead of relying on a single fixed sleep.
	deadline := time.Now().Add(2 * time.Second)
	for h.Mode() != "http" && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	if mode := h.Mode(); mode != "http" {
		t.Errorf("expected http after disconnect, got %s", mode)
	}
}

View file

@ -0,0 +1,360 @@
package agent
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
)
// WSTransport talks to the server over a WebSocket connection that is
// fronted by a Cloudflare Durable Object.
type WSTransport struct {
	// Connection parameters, set once by NewWSTransport.
	wsURL     string // e.g. wss://unarr.torrentclaw.com/ws/{agentId}
	apiKey    string
	agentID   string
	userAgent string

	// mu is held while conn is replaced, cleared, or written to.
	mu   sync.Mutex
	conn *websocket.Conn

	// closed is set by Close so the read loop can tell a deliberate shutdown
	// apart from an unexpected read failure.
	closed atomic.Bool

	// events carries server-pushed messages to the consumer of Events().
	events chan ServerEvent

	// Auth handshake state: authResp caches the registration reply from the
	// DO (guarded by authMu); authDone is closed, at most once via
	// authDoneOnce, when that reply has been stored.
	authMu       sync.Mutex
	authResp     *RegisterResponse
	authDone     chan struct{}
	authDoneOnce sync.Once
}
// NewWSTransport builds a WebSocket transport for the given endpoint and
// credentials. No network activity happens until Connect is called.
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
	tr := new(WSTransport)
	tr.wsURL = wsURL
	tr.apiKey = apiKey
	tr.agentID = agentID
	tr.userAgent = userAgent
	// Buffered so the read loop can hand off events without blocking.
	tr.events = make(chan ServerEvent, 50)
	tr.authDone = make(chan struct{})
	return tr
}
// Mode identifies this transport; it always returns "ws".
func (t *WSTransport) Mode() string {
	return "ws"
}
// Events exposes the receive-only stream of server-pushed events.
func (t *WSTransport) Events() <-chan ServerEvent {
	return t.events
}
// Connect dials the WebSocket server and starts the read loop.
//
// When an API key is configured it is appended to the URL as the "key" query
// parameter so the gateway can authenticate the upgrade request. The key is
// query-escaped so that reserved characters (&, +, #, %, spaces) cannot
// corrupt the URL or be misread as additional parameters.
//
// On success the auth-handshake state is reset, so a subsequent Register
// waits for a fresh "registered" reply on the new connection. It returns a
// wrapped error if the dial or handshake fails.
func (t *WSTransport) Connect(ctx context.Context) error {
	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
	}
	header := http.Header{}
	header.Set("User-Agent", t.userAgent)

	// Append API key as query param for auth on WS upgrade
	wsURLWithKey := t.wsURL
	if t.apiKey != "" {
		sep := "?"
		if strings.Contains(wsURLWithKey, "?") {
			sep = "&"
		}
		wsURLWithKey += sep + "key=" + url.QueryEscape(t.apiKey)
	}

	conn, _, err := dialer.DialContext(ctx, wsURLWithKey, header)
	if err != nil {
		return fmt.Errorf("ws dial: %w", err)
	}

	t.mu.Lock()
	t.conn = conn
	t.closed.Store(false)
	t.authDone = make(chan struct{})
	t.authDoneOnce = sync.Once{}
	t.mu.Unlock()

	go t.readLoop()
	return nil
}
// Close marks the transport as deliberately closed, makes a best-effort
// attempt to send a normal-closure frame, and then closes the socket.
// It is a no-op returning nil when there is no active connection.
func (t *WSTransport) Close() error {
	t.closed.Store(true)

	t.mu.Lock()
	defer t.mu.Unlock()

	active := t.conn
	if active == nil {
		return nil
	}

	// Best effort: the peer may already be gone, so a failed write is ignored.
	farewell := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
	_ = active.WriteMessage(websocket.CloseMessage, farewell)

	closeErr := active.Close()
	t.conn = nil
	return closeErr
}
// Register sends the auth message and waits for the server's "registered"
// response.
//
// It returns the cached registration response once the read loop has stored
// it, ctx.Err() if ctx is cancelled first, or an error if the send fails or
// no response arrives within 15 seconds.
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
	msg := wsAuthMessage{
		Type:           "auth",
		APIKey:         t.apiKey,
		AgentID:        req.AgentID,
		Name:           req.Name,
		OS:             req.OS,
		Arch:           req.Arch,
		Version:        req.Version,
		DownloadDir:    req.DownloadDir,
		DiskFreeBytes:  req.DiskFreeBytes,
		DiskTotalBytes: req.DiskTotalBytes,
	}

	// Snapshot the completion channel under the same mutex Connect holds when
	// it replaces the field; reading it unlocked races with a reconnect. The
	// snapshot is taken before sending so a fast reply cannot be missed.
	t.mu.Lock()
	authDone := t.authDone
	t.mu.Unlock()

	if err := t.send(msg); err != nil {
		return nil, fmt.Errorf("ws auth send: %w", err)
	}

	// Stoppable timer so it is released as soon as Register returns.
	timeout := time.NewTimer(15 * time.Second)
	defer timeout.Stop()

	// Wait for the auth response or context cancellation
	select {
	case <-authDone:
		t.authMu.Lock()
		resp := t.authResp
		t.authMu.Unlock()
		if resp == nil {
			return nil, fmt.Errorf("ws auth: no response received")
		}
		return resp, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timeout.C:
		return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
	}
}
// SendHeartbeat writes a "heartbeat" frame, attaching disk usage only when
// the request carries a positive free or total byte count. The call does not
// wait for a server reply: on a successful write it returns a synthetic
// success response, and upgrade signals are delivered through Events().
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
	type diskUsage struct {
		Free  int64 `json:"free"`
		Total int64 `json:"total"`
	}
	type heartbeatFrame struct {
		Type string     `json:"type"`
		Disk *diskUsage `json:"disk,omitempty"`
	}

	frame := heartbeatFrame{Type: "heartbeat"}
	hasDiskInfo := req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0
	if hasDiskInfo {
		frame.Disk = &diskUsage{
			Free:  req.DiskFreeBytes,
			Total: req.DiskTotalBytes,
		}
	}

	sendErr := t.send(frame)
	if sendErr != nil {
		return nil, sendErr
	}
	return &HeartbeatResponse{Success: true}, nil
}
// SendProgress writes a "progress" frame mirroring the given status update.
// On a successful write it returns a synthetic success response; control
// signals are not part of that response and arrive asynchronously through
// Events().
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
	type progressFrame struct {
		Type            string `json:"type"`
		TaskID          string `json:"taskId"`
		Status          string `json:"status,omitempty"`
		Progress        int    `json:"progress,omitempty"`
		DownloadedBytes int64  `json:"downloadedBytes,omitempty"`
		TotalBytes      int64  `json:"totalBytes,omitempty"`
		SpeedBps        int64  `json:"speedBps,omitempty"`
		ETA             int    `json:"eta,omitempty"`
		ResolvedMethod  string `json:"resolvedMethod,omitempty"`
		FileName        string `json:"fileName,omitempty"`
		FilePath        string `json:"filePath,omitempty"`
		StreamURL       string `json:"streamUrl,omitempty"`
		ErrorMessage    string `json:"errorMessage,omitempty"`
	}

	var frame progressFrame
	frame.Type = "progress"
	frame.TaskID = update.TaskID
	frame.Status = update.Status
	frame.Progress = update.Progress
	frame.DownloadedBytes = update.DownloadedBytes
	frame.TotalBytes = update.TotalBytes
	frame.SpeedBps = update.SpeedBps
	frame.ETA = update.ETA
	frame.ResolvedMethod = update.ResolvedMethod
	frame.FileName = update.FileName
	frame.FilePath = update.FilePath
	frame.StreamURL = update.StreamURL
	frame.ErrorMessage = update.ErrorMessage

	sendErr := t.send(frame)
	if sendErr != nil {
		return nil, sendErr
	}
	return &StatusResponse{Success: true}, nil
}
// ClaimTasks does nothing over WebSocket and always returns an empty
// response with a nil error; tasks are delivered through Events() instead.
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
	empty := new(TasksResponse)
	return empty, nil
}
// Deregister closes the WebSocket and returns the result of Close. No
// explicit deregistration message is sent; the DO is expected to notice the
// disconnection.
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
	closeErr := t.Close()
	return closeErr
}
// ReportUpgradeResult writes an "upgrade-result" frame describing the outcome
// of an upgrade attempt and returns any error from the write.
func (t *WSTransport) ReportUpgradeResult(_ context.Context, result UpgradeResult) error {
	type upgradeResultFrame struct {
		Type    string `json:"type"`
		Success bool   `json:"success"`
		Version string `json:"version,omitempty"`
		Error   string `json:"error,omitempty"`
	}

	var frame upgradeResultFrame
	frame.Type = "upgrade-result"
	frame.Success = result.Success
	frame.Version = result.Version
	frame.Error = result.Error

	return t.send(frame)
}
// ── Internal ─────────────────────────────────────────────────────────────────
// send JSON-encodes msg and writes it as a single text frame. The mutex is
// held for the whole call, so writes from different goroutines never
// interleave. It fails with "ws: not connected" when there is no connection.
func (t *WSTransport) send(msg any) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	active := t.conn
	if active == nil {
		return fmt.Errorf("ws: not connected")
	}

	payload, encodeErr := json.Marshal(msg)
	if encodeErr != nil {
		return encodeErr
	}
	return active.WriteMessage(websocket.TextMessage, payload)
}
// readLoop reads messages from the connection until a read fails, decoding
// each one by its "type" field:
//
//   - "registered": caches the auth response and signals Register.
//   - "tasks", "upgrade", "control": forwarded on the events channel.
//   - "error": logged.
//
// Forwarding never blocks: if the events channel is full the event is dropped
// and the drop is logged. When a read fails and Close was not called, a
// "disconnected" event is emitted before the loop exits.
func (t *WSTransport) readLoop() {
	// Snapshot the connection once, under the lock. Close sets t.conn to nil
	// concurrently, so re-reading the field on every iteration could
	// dereference a nil pointer between two reads.
	t.mu.Lock()
	conn := t.conn
	t.mu.Unlock()
	if conn == nil {
		return
	}

	// emit forwards an event without blocking the read loop.
	emit := func(ev ServerEvent) {
		select {
		case t.events <- ev:
		default:
			log.Printf("[ws] events channel full, dropping %s message", ev.Type)
		}
	}

	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			// A read error after Close is the expected shutdown path.
			if !t.closed.Load() {
				log.Printf("[ws] read error: %v", err)
				// Signal disconnection to the daemon
				emit(ServerEvent{Type: "disconnected"})
			}
			return
		}

		var envelope struct {
			Type string `json:"type"`
		}
		if err := json.Unmarshal(msg, &envelope); err != nil {
			log.Printf("[ws] invalid message: %v", err)
			continue
		}

		switch envelope.Type {
		case "registered":
			var resp wsRegisteredMessage
			if err := json.Unmarshal(msg, &resp); err != nil {
				log.Printf("[ws] invalid registered message: %v", err)
				continue
			}
			t.authMu.Lock()
			t.authResp = &RegisterResponse{
				Success:  true,
				User:     resp.User,
				Features: resp.Features,
			}
			t.authMu.Unlock()
			// Signal that auth is complete (sync.Once prevents double-close panic)
			t.authDoneOnce.Do(func() { close(t.authDone) })

		case "tasks":
			var resp wsTasksMessage
			if err := json.Unmarshal(msg, &resp); err != nil {
				log.Printf("[ws] invalid tasks message: %v", err)
				continue
			}
			emit(ServerEvent{
				Type: "tasks",
				Tasks: &TasksResponse{
					Tasks:          resp.Tasks,
					StreamRequests: resp.StreamRequests,
				},
			})

		case "upgrade":
			var resp wsUpgradeMessage
			if err := json.Unmarshal(msg, &resp); err != nil {
				log.Printf("[ws] invalid upgrade message: %v", err)
				continue
			}
			emit(ServerEvent{
				Type:    "upgrade",
				Upgrade: &UpgradeSignal{Version: resp.Version},
			})

		case "control":
			var resp ControlAction
			if err := json.Unmarshal(msg, &resp); err != nil {
				log.Printf("[ws] invalid control message: %v", err)
				continue
			}
			emit(ServerEvent{
				Type:    "control",
				Control: &resp,
			})

		case "error":
			var resp struct {
				Message string `json:"message"`
			}
			if json.Unmarshal(msg, &resp) == nil {
				log.Printf("[ws] server error: %s", resp.Message)
			}
		}
	}
}
// ── WS message types ─────────────────────────────────────────────────────────
// wsAuthMessage is the first frame the agent sends on a new connection. It
// carries the credentials plus the agent's self-description; every field
// other than type, apiKey and agentId is omitted from the JSON when empty.
type wsAuthMessage struct {
	Type    string `json:"type"`
	APIKey  string `json:"apiKey"`
	AgentID string `json:"agentId"`

	// Optional agent description.
	Name    string `json:"name,omitempty"`
	OS      string `json:"os,omitempty"`
	Arch    string `json:"arch,omitempty"`
	Version string `json:"version,omitempty"`

	// Optional download location and its disk usage, in bytes.
	DownloadDir    string `json:"downloadDir,omitempty"`
	DiskFreeBytes  int64  `json:"diskFreeBytes,omitempty"`
	DiskTotalBytes int64  `json:"diskTotalBytes,omitempty"`
}
// wsRegisteredMessage is the "registered" frame the server sends once the
// auth message has been accepted.
type wsRegisteredMessage struct {
	Type     string       `json:"type"`
	User     UserInfo     `json:"user"`     // account the agent was registered under
	Features FeatureFlags `json:"features"` // feature flags returned with the registration
}
// wsTasksMessage is the "tasks" frame pushed by the server. It carries tasks
// and, optionally, stream requests.
type wsTasksMessage struct {
	Type           string          `json:"type"`
	Tasks          []Task          `json:"tasks"`
	StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
}
// wsUpgradeMessage is the "upgrade" frame pushed by the server, naming the
// version the agent is asked to upgrade to.
type wsUpgradeMessage struct {
	Type    string `json:"type"`
	Version string `json:"version"`
}

View file

@ -7,6 +7,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strings"
"syscall" "syscall"
"time" "time"
@ -104,8 +105,6 @@ func runDaemonStart() error {
heartbeatInterval = 30 * time.Second heartbeatInterval = 30 * time.Second
} }
// Create agent client (direct HTTP — always available as fallback)
ac := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, "unarr/"+Version)
userAgent := "unarr/" + Version userAgent := "unarr/" + Version
// Create daemon config // Create daemon config
@ -119,6 +118,8 @@ func runDaemonStart() error {
} }
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only // Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
wsURL := cfg.Auth.WSURL wsURL := cfg.Auth.WSURL
if wsURL == "" { if wsURL == "" {
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID) wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
@ -126,28 +127,19 @@ func runDaemonStart() error {
var transport agent.Transport var transport agent.Transport
if wsURL != "" { if wsURL != "" {
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent) wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
transport = agent.NewHybridTransport(wsT, httpT) transport = agent.NewHybridTransport(wsT, httpT)
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL) log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
} else {
transport = httpT
log.Println("Transport: HTTP only")
} }
// Create daemon // Create daemon — always uses Transport interface
var d *agent.Daemon d := agent.NewDaemon(daemonCfg, transport)
if transport != nil {
d = agent.NewDaemonWithTransport(daemonCfg, transport)
} else {
d = agent.NewDaemon(daemonCfg, ac)
}
// Wire state tracking (connected after manager creation below) // Create progress reporter using transport
// Create progress reporter reporter := engine.NewProgressReporterWithTransport(transport, 3*time.Second)
var reporter *engine.ProgressReporter
if transport != nil {
reporter = engine.NewProgressReporterWithTransport(transport, 3*time.Second)
} else {
reporter = engine.NewProgressReporter(ac, 3*time.Second)
}
// Parse speed limits // Parse speed limits
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed) maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
@ -190,7 +182,7 @@ func runDaemonStart() error {
MoviesDir: cfg.Organize.MoviesDir, MoviesDir: cfg.Organize.MoviesDir,
TVShowsDir: cfg.Organize.TVShowsDir, TVShowsDir: cfg.Organize.TVShowsDir,
}, },
}, reporter, torrentDl, debridDl) }, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client()))
// Wire state tracking // Wire state tracking
d.GetActiveCount = manager.ActiveCount d.GetActiveCount = manager.ActiveCount
@ -275,9 +267,9 @@ func runDaemonStart() error {
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(sr.FilePath), streamURL) log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(sr.FilePath), streamURL)
// Report stream URL back to the server // Report stream URL back to the server via transport
go func() { go func() {
if _, err := ac.ReportStatus(ctx, agent.StatusUpdate{ if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
TaskID: sr.TaskID, TaskID: sr.TaskID,
StreamURL: streamURL, StreamURL: streamURL,
}); err != nil { }); err != nil {
@ -298,13 +290,18 @@ func runDaemonStart() error {
case "resume": case "resume":
log.Printf("[%s] resume requested via WebSocket", taskID[:8]) log.Printf("[%s] resume requested via WebSocket", taskID[:8])
case "stream": case "stream":
// Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests
streamRegistry.mu.Lock()
if _, exists := streamRegistry.servers[taskID]; exists {
streamRegistry.mu.Unlock()
return
}
task := manager.GetTask(taskID) task := manager.GetTask(taskID)
if task == nil { if task == nil || task.GetStreamURL() != "" {
return streamRegistry.mu.Unlock()
}
if task.GetStreamURL() != "" {
return return
} }
streamRegistry.mu.Unlock()
srv, err := torrentDl.StartStream(taskID) srv, err := torrentDl.StartStream(taskID)
if err != nil { if err != nil {
log.Printf("[%s] stream failed: %v", taskID[:8], err) log.Printf("[%s] stream failed: %v", taskID[:8], err)
@ -342,11 +339,7 @@ func runDaemonStart() error {
Version: result.NewVersion, Version: result.NewVersion,
Error: errMsg, Error: errMsg,
} }
if transport != nil {
_ = transport.ReportUpgradeResult(reportCtx, upgradeResult) _ = transport.ReportUpgradeResult(reportCtx, upgradeResult)
} else {
_ = ac.ReportUpgradeResult(reportCtx, upgradeResult)
}
if !result.Success { if !result.Success {
log.Printf("Upgrade failed: %v", result.Error) log.Printf("Upgrade failed: %v", result.Error)
@ -360,7 +353,7 @@ func runDaemonStart() error {
// Deregister first so the server knows we're restarting // Deregister first so the server knows we're restarting
deregCtx, deregCancel := context.WithTimeout(context.Background(), 5*time.Second) deregCtx, deregCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer deregCancel() defer deregCancel()
_ = ac.Deregister(deregCtx, cfg.Agent.ID) _ = transport.Deregister(deregCtx, cfg.Agent.ID)
// Flush progress reporter // Flush progress reporter
cancel() cancel()
@ -418,6 +411,7 @@ func runDaemonStart() error {
// deriveWSURL derives a WebSocket URL from the API URL. // deriveWSURL derives a WebSocket URL from the API URL.
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId} // https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
// Returns "" for localhost/dev environments where WS gateway isn't available.
func deriveWSURL(apiURL, agentID string) string { func deriveWSURL(apiURL, agentID string) string {
if apiURL == "" || agentID == "" { if apiURL == "" || agentID == "" {
return "" return ""
@ -437,6 +431,15 @@ func deriveWSURL(apiURL, agentID string) string {
break break
} }
} }
// Strip port if present
if idx := strings.LastIndex(domain, ":"); idx > 0 {
domain = domain[:idx]
}
// Skip WS for localhost/dev — gateway only available in production
if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" {
return ""
}
return "wss://unarr." + domain + "/ws/" + agentID return "wss://unarr." + domain + "/ws/" + agentID
} }

View file

@ -25,6 +25,7 @@ type Config struct {
type AuthConfig struct { type AuthConfig struct {
APIKey string `toml:"api_key"` APIKey string `toml:"api_key"`
APIURL string `toml:"api_url"` APIURL string `toml:"api_url"`
WSURL string `toml:"ws_url"` // optional, derived from api_url if empty
} }
type AgentConfig struct { type AgentConfig struct {

View file

@ -12,10 +12,16 @@ import (
// ActionFunc is called when the server signals an action on a task. // ActionFunc is called when the server signals an action on a task.
type ActionFunc func(taskID string) type ActionFunc func(taskID string)
// StatusReporter is the interface used by ProgressReporter to send progress updates.
// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods.
type StatusReporter interface {
ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error)
}
// ProgressReporter aggregates progress from downloads and reports to the API. // ProgressReporter aggregates progress from downloads and reports to the API.
// It batches updates to avoid flooding the server. // It batches updates to avoid flooding the server.
type ProgressReporter struct { type ProgressReporter struct {
agentClient *agent.Client reporter StatusReporter
interval time.Duration interval time.Duration
onCancel ActionFunc onCancel ActionFunc
@ -28,14 +34,33 @@ type ProgressReporter struct {
} }
// NewProgressReporter creates a reporter that flushes every interval. // NewProgressReporter creates a reporter that flushes every interval.
// Accepts *agent.Client directly (backwards compatible).
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter { func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
return &ProgressReporter{ return &ProgressReporter{
agentClient: ac, reporter: ac,
interval: interval, interval: interval,
latest: make(map[string]*Task), latest: make(map[string]*Task),
} }
} }
// NewProgressReporterWithTransport creates a reporter using a Transport.
func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter {
return &ProgressReporter{
reporter: &transportStatusAdapter{t: t},
interval: interval,
latest: make(map[string]*Task),
}
}
// transportStatusAdapter adapts agent.Transport to StatusReporter.
type transportStatusAdapter struct {
t agent.Transport
}
func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
return a.t.SendProgress(ctx, update)
}
// SetCancelHandler sets the callback invoked when the server says a task is cancelled. // SetCancelHandler sets the callback invoked when the server says a task is cancelled.
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn } func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
@ -95,7 +120,7 @@ func (r *ProgressReporter) flush(ctx context.Context) {
} }
update := task.ToStatusUpdate() update := task.ToStatusUpdate()
resp, err := r.agentClient.ReportStatus(ctx, update) resp, err := r.reporter.ReportStatus(ctx, update)
if err != nil { if err != nil {
log.Printf("[%s] progress report failed: %v", task.ID[:8], err) log.Printf("[%s] progress report failed: %v", task.ID[:8], err)
continue continue
@ -130,7 +155,7 @@ func (r *ProgressReporter) flush(ctx context.Context) {
// ReportFinal sends a final status update for a completed/failed task. // ReportFinal sends a final status update for a completed/failed task.
func (r *ProgressReporter) ReportFinal(ctx context.Context, task *Task) { func (r *ProgressReporter) ReportFinal(ctx context.Context, task *Task) {
update := task.ToStatusUpdate() update := task.ToStatusUpdate()
if _, err := r.agentClient.ReportStatus(ctx, update); err != nil { if _, err := r.reporter.ReportStatus(ctx, update); err != nil {
log.Printf("[%s] final report failed: %v", task.ID[:8], err) log.Printf("[%s] final report failed: %v", task.ID[:8], err)
} }
r.Untrack(task.ID) r.Untrack(task.ID)