feat(agent): add WebSocket transport with HTTP fallback
Add Transport interface abstraction supporting WebSocket (via CF Durable Objects) and HTTP (direct to origin) with automatic failover. - Transport interface: Register, SendHeartbeat, SendProgress, Events() - HTTPTransport: thin adapter over existing Client - WSTransport: gorilla/websocket with auth handshake, readLoop, reconnect - HybridTransport: tries WS first, falls back to HTTP, reconnects in bg - Daemon refactored to always use Transport (no dual-path forks) - ProgressReporter accepts StatusReporter interface - deriveWSURL skips localhost/dev (returns "" → HTTP-only) - API key passed in WS query param for connection auth - Fixed: reconnectOnce race (mutex+bool), authDone double-close (sync.Once) - Fixed: forwardWSEvents goroutine leak (select with stop signal) - 20 transport tests + 2 E2E tests (full lifecycle, hybrid failover)
This commit is contained in:
parent
5e80911501
commit
5f337eebd7
10 changed files with 1646 additions and 64 deletions
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -21,19 +22,33 @@ type DaemonConfig struct {
|
||||||
// Daemon manages the main loop: register, heartbeat, poll tasks.
|
// Daemon manages the main loop: register, heartbeat, poll tasks.
|
||||||
type Daemon struct {
|
type Daemon struct {
|
||||||
cfg DaemonConfig
|
cfg DaemonConfig
|
||||||
client *Client
|
transport Transport
|
||||||
|
|
||||||
// Callbacks
|
// Callbacks
|
||||||
OnTasksClaimed func(tasks []Task)
|
OnTasksClaimed func(tasks []Task)
|
||||||
|
OnStreamRequested func(req StreamRequest)
|
||||||
|
OnUpgradeRequested func(version string)
|
||||||
|
OnControlAction func(action, taskID string)
|
||||||
|
|
||||||
// State
|
// State
|
||||||
User UserInfo
|
User UserInfo
|
||||||
Features FeatureFlags
|
Features FeatureFlags
|
||||||
Info AgentInfo
|
Info AgentInfo
|
||||||
|
State DaemonState
|
||||||
|
upgradeInProgress bool
|
||||||
|
heartbeatFailures int
|
||||||
|
|
||||||
|
// Callbacks for state tracking (set by cmd/daemon.go)
|
||||||
|
GetActiveCount func() int
|
||||||
|
|
||||||
|
// Exposed tickers for hot-reload
|
||||||
|
PollTicker *time.Ticker
|
||||||
|
HeartbeatTicker *time.Ticker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDaemon creates a daemon with the given config and agent client.
|
// NewDaemon creates a daemon with the given transport.
|
||||||
func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
|
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
|
||||||
|
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
|
||||||
if cfg.PollInterval == 0 {
|
if cfg.PollInterval == 0 {
|
||||||
cfg.PollInterval = 30 * time.Second
|
cfg.PollInterval = 30 * time.Second
|
||||||
}
|
}
|
||||||
|
|
@ -43,10 +58,13 @@ func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
|
||||||
|
|
||||||
return &Daemon{
|
return &Daemon{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
client: client,
|
transport: transport,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transport returns the configured transport.
|
||||||
|
func (d *Daemon) Transport() Transport { return d.transport }
|
||||||
|
|
||||||
// Register registers the agent and fetches user info + features.
|
// Register registers the agent and fetches user info + features.
|
||||||
func (d *Daemon) Register(ctx context.Context) error {
|
func (d *Daemon) Register(ctx context.Context) error {
|
||||||
req := RegisterRequest{
|
req := RegisterRequest{
|
||||||
|
|
@ -62,20 +80,30 @@ func (d *Daemon) Register(ctx context.Context) error {
|
||||||
req.DiskTotalBytes = total
|
req.DiskTotalBytes = total
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := d.client.Register(ctx, req)
|
resp, err := d.transport.Register(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("register: %w", err)
|
return fmt.Errorf("register: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
d.User = resp.User
|
d.User = resp.User
|
||||||
d.Features = resp.Features
|
d.Features = resp.Features
|
||||||
|
now := time.Now()
|
||||||
d.Info = AgentInfo{
|
d.Info = AgentInfo{
|
||||||
ID: d.cfg.AgentID,
|
ID: d.cfg.AgentID,
|
||||||
Name: d.cfg.AgentName,
|
Name: d.cfg.AgentName,
|
||||||
User: resp.User,
|
User: resp.User,
|
||||||
Features: resp.Features,
|
Features: resp.Features,
|
||||||
StartedAt: time.Now(),
|
StartedAt: now,
|
||||||
}
|
}
|
||||||
|
d.State = DaemonState{
|
||||||
|
AgentID: d.cfg.AgentID,
|
||||||
|
Status: "running",
|
||||||
|
Version: d.cfg.Version,
|
||||||
|
PID: os.Getpid(),
|
||||||
|
StartedAt: now,
|
||||||
|
MethodStats: make(map[string]int),
|
||||||
|
}
|
||||||
|
WriteState(&d.State)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -91,28 +119,40 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||||
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
|
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
|
||||||
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
|
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
|
||||||
|
|
||||||
heartbeatTicker := time.NewTicker(d.cfg.HeartbeatInterval)
|
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
|
||||||
defer heartbeatTicker.Stop()
|
defer d.HeartbeatTicker.Stop()
|
||||||
|
|
||||||
pollTicker := time.NewTicker(d.cfg.PollInterval)
|
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
|
||||||
defer pollTicker.Stop()
|
defer d.PollTicker.Stop()
|
||||||
|
|
||||||
|
heartbeatTicker := d.HeartbeatTicker
|
||||||
|
pollTicker := d.PollTicker
|
||||||
|
|
||||||
// Initial poll immediately
|
// Initial poll immediately
|
||||||
d.poll(ctx)
|
d.poll(ctx)
|
||||||
|
|
||||||
|
eventsCh := d.transport.Events()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Println("Daemon shutting down...")
|
log.Println("Daemon shutting down...")
|
||||||
|
d.deregister()
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
case event := <-eventsCh:
|
||||||
|
d.handleEvent(event)
|
||||||
|
|
||||||
case <-heartbeatTicker.C:
|
case <-heartbeatTicker.C:
|
||||||
d.heartbeat(ctx)
|
d.heartbeat(ctx)
|
||||||
|
|
||||||
case <-pollTicker.C:
|
case <-pollTicker.C:
|
||||||
|
// Only poll in HTTP mode — WS mode receives tasks via Events
|
||||||
|
if d.transport.Mode() == "http" {
|
||||||
d.poll(ctx)
|
d.poll(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) heartbeat(ctx context.Context) {
|
func (d *Daemon) heartbeat(ctx context.Context) {
|
||||||
|
|
@ -128,13 +168,93 @@ func (d *Daemon) heartbeat(ctx context.Context) {
|
||||||
req.DiskTotalBytes = total
|
req.DiskTotalBytes = total
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.client.Heartbeat(ctx, req); err != nil {
|
resp, err := d.transport.SendHeartbeat(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
d.heartbeatFailures++
|
||||||
|
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
|
||||||
|
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
|
||||||
|
} else {
|
||||||
log.Printf("Heartbeat failed: %v", err)
|
log.Printf("Heartbeat failed: %v", err)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if d.heartbeatFailures > 0 {
|
||||||
|
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
|
||||||
|
d.heartbeatFailures = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update state file
|
||||||
|
d.State.LastHeartbeat = time.Now()
|
||||||
|
if d.GetActiveCount != nil {
|
||||||
|
d.State.ActiveTasks = d.GetActiveCount()
|
||||||
|
}
|
||||||
|
WriteState(&d.State)
|
||||||
|
|
||||||
|
// Check for upgrade signal from server
|
||||||
|
if resp.Upgrade != nil && resp.Upgrade.Version != "" && !d.upgradeInProgress {
|
||||||
|
d.upgradeInProgress = true
|
||||||
|
log.Printf("Upgrade requested by server: %s → %s", d.cfg.Version, resp.Upgrade.Version)
|
||||||
|
if d.OnUpgradeRequested != nil {
|
||||||
|
go d.OnUpgradeRequested(resp.Upgrade.Version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEvent processes a server-initiated event from the WebSocket transport.
|
||||||
|
func (d *Daemon) handleEvent(event ServerEvent) {
|
||||||
|
switch event.Type {
|
||||||
|
case "tasks":
|
||||||
|
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
|
||||||
|
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
|
||||||
|
if d.OnTasksClaimed != nil {
|
||||||
|
d.OnTasksClaimed(event.Tasks.Tasks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if event.Tasks != nil && d.OnStreamRequested != nil {
|
||||||
|
for _, sr := range event.Tasks.StreamRequests {
|
||||||
|
d.OnStreamRequested(sr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "upgrade":
|
||||||
|
if event.Upgrade != nil && event.Upgrade.Version != "" && !d.upgradeInProgress {
|
||||||
|
d.upgradeInProgress = true
|
||||||
|
log.Printf("Upgrade requested via WebSocket: %s → %s", d.cfg.Version, event.Upgrade.Version)
|
||||||
|
if d.OnUpgradeRequested != nil {
|
||||||
|
go d.OnUpgradeRequested(event.Upgrade.Version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "control":
|
||||||
|
if event.Control != nil && d.OnControlAction != nil {
|
||||||
|
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
|
||||||
|
d.OnControlAction(event.Control.Action, event.Control.TaskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "disconnected":
|
||||||
|
log.Println("WebSocket disconnected, switching to HTTP polling")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpgradeInProgress resets the upgrade flag so a retry can be attempted.
|
||||||
|
func (d *Daemon) ClearUpgradeInProgress() {
|
||||||
|
d.upgradeInProgress = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) deregister() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err := d.transport.Deregister(ctx, d.cfg.AgentID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Deregister failed: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Println("Agent deregistered")
|
||||||
|
}
|
||||||
|
RemoveState()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) poll(ctx context.Context) {
|
func (d *Daemon) poll(ctx context.Context) {
|
||||||
tasks, err := d.client.ClaimTasks(ctx, d.cfg.AgentID)
|
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Poll failed: %v", err)
|
log.Printf("Poll failed: %v", err)
|
||||||
return
|
return
|
||||||
|
|
@ -142,13 +262,17 @@ func (d *Daemon) poll(ctx context.Context) {
|
||||||
|
|
||||||
d.Info.LastPollAt = time.Now()
|
d.Info.LastPollAt = time.Now()
|
||||||
|
|
||||||
if len(tasks) == 0 {
|
if len(resp.Tasks) > 0 {
|
||||||
return
|
log.Printf("Claimed %d task(s)", len(resp.Tasks))
|
||||||
|
if d.OnTasksClaimed != nil {
|
||||||
|
d.OnTasksClaimed(resp.Tasks)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Claimed %d task(s)", len(tasks))
|
// Handle stream requests for completed downloads
|
||||||
|
if d.OnStreamRequested != nil {
|
||||||
if d.OnTasksClaimed != nil {
|
for _, sr := range resp.StreamRequests {
|
||||||
d.OnTasksClaimed(tasks)
|
d.OnStreamRequested(sr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
53
internal/agent/transport.go
Normal file
53
internal/agent/transport.go
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Transport abstracts the communication protocol between the agent and server.
|
||||||
|
// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this.
|
||||||
|
type Transport interface {
|
||||||
|
// Connect establishes the transport connection.
|
||||||
|
Connect(ctx context.Context) error
|
||||||
|
|
||||||
|
// Close tears down the connection gracefully.
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// Mode returns the current transport mode ("ws" or "http").
|
||||||
|
Mode() string
|
||||||
|
|
||||||
|
// Register sends agent registration and returns user info + features.
|
||||||
|
Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error)
|
||||||
|
|
||||||
|
// SendHeartbeat sends a periodic keep-alive.
|
||||||
|
SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error)
|
||||||
|
|
||||||
|
// SendProgress reports download progress for a task.
|
||||||
|
SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error)
|
||||||
|
|
||||||
|
// ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events).
|
||||||
|
ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error)
|
||||||
|
|
||||||
|
// Deregister notifies the server of graceful shutdown.
|
||||||
|
Deregister(ctx context.Context, agentID string) error
|
||||||
|
|
||||||
|
// ReportUpgradeResult reports upgrade outcome.
|
||||||
|
ReportUpgradeResult(ctx context.Context, result UpgradeResult) error
|
||||||
|
|
||||||
|
// Events returns a channel that emits server-initiated events.
|
||||||
|
// In HTTP mode this channel is never written to (polling handles it).
|
||||||
|
// In WS mode, tasks/upgrade/control arrive here.
|
||||||
|
Events() <-chan ServerEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerEvent represents a server-initiated message received via WebSocket.
|
||||||
|
type ServerEvent struct {
|
||||||
|
Type string // "tasks", "upgrade", "control", "disconnected"
|
||||||
|
Tasks *TasksResponse // populated when Type == "tasks"
|
||||||
|
Upgrade *UpgradeSignal // populated when Type == "upgrade"
|
||||||
|
Control *ControlAction // populated when Type == "control"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ControlAction represents a server push for task control.
|
||||||
|
type ControlAction struct {
|
||||||
|
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
|
||||||
|
TaskID string `json:"taskId"`
|
||||||
|
}
|
||||||
295
internal/agent/transport_e2e_test.go
Normal file
295
internal/agent/transport_e2e_test.go
Normal file
|
|
@ -0,0 +1,295 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestE2EFullLifecycle tests the full lifecycle:
|
||||||
|
// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect
|
||||||
|
func TestE2EFullLifecycle(t *testing.T) {
|
||||||
|
var mu sync.Mutex
|
||||||
|
var receivedMessages []map[string]interface{}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, msg, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed map[string]interface{}
|
||||||
|
json.Unmarshal(msg, &parsed)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
receivedMessages = append(receivedMessages, parsed)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
msgType, _ := parsed["type"].(string)
|
||||||
|
switch msgType {
|
||||||
|
case "auth":
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{
|
||||||
|
Type: "registered",
|
||||||
|
User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
|
||||||
|
Features: FeatureFlags{Torrent: true, Debrid: true},
|
||||||
|
})
|
||||||
|
|
||||||
|
case "heartbeat":
|
||||||
|
// No response in WS mode
|
||||||
|
|
||||||
|
case "progress":
|
||||||
|
// Simulate server-side cancel after progress
|
||||||
|
if progress, ok := parsed["progress"].(float64); ok && progress >= 50 {
|
||||||
|
conn.WriteJSON(map[string]string{
|
||||||
|
"type": "control",
|
||||||
|
"action": "cancel",
|
||||||
|
"taskId": parsed["taskId"].(string),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case "upgrade-result":
|
||||||
|
// Acknowledged
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 1. Connect
|
||||||
|
if err := tr.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Connect: %v", err)
|
||||||
|
}
|
||||||
|
defer tr.Close()
|
||||||
|
|
||||||
|
// 2. Auth
|
||||||
|
resp, err := tr.Register(ctx, RegisterRequest{
|
||||||
|
AgentID: "e2e-agent",
|
||||||
|
Name: "E2E Test Agent",
|
||||||
|
Version: "1.0.0",
|
||||||
|
OS: "linux",
|
||||||
|
Arch: "amd64",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
if resp.User.Name != "E2E User" {
|
||||||
|
t.Errorf("expected E2E User, got %s", resp.User.Name)
|
||||||
|
}
|
||||||
|
if !resp.Features.Debrid {
|
||||||
|
t.Error("expected debrid feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Send heartbeat
|
||||||
|
_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
|
||||||
|
AgentID: "e2e-agent",
|
||||||
|
DiskFreeBytes: 1000000000,
|
||||||
|
DiskTotalBytes: 5000000000,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendHeartbeat: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Send progress (50% → should trigger cancel control)
|
||||||
|
_, err = tr.SendProgress(ctx, StatusUpdate{
|
||||||
|
TaskID: "task-e2e-1",
|
||||||
|
Status: "downloading",
|
||||||
|
Progress: 50,
|
||||||
|
DownloadedBytes: 500,
|
||||||
|
TotalBytes: 1000,
|
||||||
|
SpeedBps: 100,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendProgress: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Wait for control event (cancel)
|
||||||
|
select {
|
||||||
|
case event := <-tr.Events():
|
||||||
|
if event.Type != "control" {
|
||||||
|
t.Errorf("expected control event, got %s", event.Type)
|
||||||
|
}
|
||||||
|
if event.Control.Action != "cancel" {
|
||||||
|
t.Errorf("expected cancel, got %s", event.Control.Action)
|
||||||
|
}
|
||||||
|
if event.Control.TaskID != "task-e2e-1" {
|
||||||
|
t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for cancel control")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Send upgrade result
|
||||||
|
err = tr.ReportUpgradeResult(ctx, UpgradeResult{
|
||||||
|
AgentID: "e2e-agent",
|
||||||
|
Success: true,
|
||||||
|
Version: "2.0.0",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReportUpgradeResult: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify server received all messages
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
if len(receivedMessages) < 4 {
|
||||||
|
t.Fatalf("expected at least 4 messages, got %d", len(receivedMessages))
|
||||||
|
}
|
||||||
|
|
||||||
|
types := make([]string, len(receivedMessages))
|
||||||
|
for i, m := range receivedMessages {
|
||||||
|
types[i], _ = m["type"].(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"auth", "heartbeat", "progress", "upgrade-result"}
|
||||||
|
for _, exp := range expected {
|
||||||
|
found := false
|
||||||
|
for _, got := range types {
|
||||||
|
if got == exp {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("missing message type %q in %v", exp, types)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestE2EHybridFailover tests the full failover scenario:
|
||||||
|
// WS connect → download → WS disconnect → switch to HTTP → continue working
|
||||||
|
func TestE2EHybridFailover(t *testing.T) {
|
||||||
|
connectionCount := 0
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
connectionCount++
|
||||||
|
connNum := connectionCount
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
// Read auth
|
||||||
|
conn.ReadMessage()
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{
|
||||||
|
Type: "registered",
|
||||||
|
User: UserInfo{Name: "Failover User"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if connNum == 1 {
|
||||||
|
// First connection: push tasks then disconnect after 200ms
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
conn.WriteJSON(wsTasksMessage{
|
||||||
|
Type: "tasks",
|
||||||
|
Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
|
||||||
|
})
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
conn.Close()
|
||||||
|
} else {
|
||||||
|
// Second connection (after reconnect): push upgrade
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||||
|
|
||||||
|
// HTTP mock for fallback
|
||||||
|
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Simple heartbeat response
|
||||||
|
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||||
|
}))
|
||||||
|
defer httpSrv.Close()
|
||||||
|
|
||||||
|
httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
|
||||||
|
h := NewHybridTransport(wsT, httpT)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := h.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Connect: %v", err)
|
||||||
|
}
|
||||||
|
defer h.Close()
|
||||||
|
|
||||||
|
// Should start in WS mode
|
||||||
|
if h.Mode() != "ws" {
|
||||||
|
t.Fatalf("expected ws mode, got %s", h.Mode())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register via WS
|
||||||
|
_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive tasks via WS
|
||||||
|
var tasksReceived bool
|
||||||
|
var disconnected bool
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
select {
|
||||||
|
case event := <-h.Events():
|
||||||
|
switch event.Type {
|
||||||
|
case "tasks":
|
||||||
|
tasksReceived = true
|
||||||
|
if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
|
||||||
|
t.Errorf("unexpected tasks: %+v", event.Tasks)
|
||||||
|
}
|
||||||
|
case "disconnected":
|
||||||
|
disconnected = true
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if disconnected {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tasksReceived {
|
||||||
|
t.Error("did not receive tasks before disconnect")
|
||||||
|
}
|
||||||
|
if !disconnected {
|
||||||
|
t.Error("did not receive disconnect event")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should now be in HTTP mode
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if h.Mode() != "http" {
|
||||||
|
t.Errorf("expected http mode after disconnect, got %s", h.Mode())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat should work via HTTP fallback
|
||||||
|
hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
|
||||||
|
}
|
||||||
|
if !hbResp.Success {
|
||||||
|
t.Error("expected heartbeat success")
|
||||||
|
}
|
||||||
|
}
|
||||||
50
internal/agent/transport_http.go
Normal file
50
internal/agent/transport_http.go
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// HTTPTransport wraps the existing Client to implement Transport.
|
||||||
|
// This is a thin adapter — no behavioral changes from the current HTTP protocol.
|
||||||
|
type HTTPTransport struct {
|
||||||
|
client *Client
|
||||||
|
events chan ServerEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPTransport creates a new HTTP-based transport.
|
||||||
|
func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport {
|
||||||
|
return &HTTPTransport{
|
||||||
|
client: NewClient(baseURL, apiKey, userAgent),
|
||||||
|
events: make(chan ServerEvent, 10),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *HTTPTransport) Connect(_ context.Context) error { return nil }
|
||||||
|
func (t *HTTPTransport) Close() error { return nil }
|
||||||
|
func (t *HTTPTransport) Mode() string { return "http" }
|
||||||
|
func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events }
|
||||||
|
|
||||||
|
func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||||
|
return t.client.Register(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||||
|
return t.client.Heartbeat(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||||
|
return t.client.ReportStatus(ctx, update)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||||
|
return t.client.ClaimTasks(ctx, agentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error {
|
||||||
|
return t.client.Deregister(ctx, agentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *HTTPTransport) ReportUpgradeResult(ctx context.Context, result UpgradeResult) error {
|
||||||
|
return t.client.ReportUpgradeResult(ctx, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client returns the underlying HTTP client for direct use if needed.
|
||||||
|
func (t *HTTPTransport) Client() *Client { return t.client }
|
||||||
226
internal/agent/transport_hybrid.go
Normal file
226
internal/agent/transport_hybrid.go
Normal file
|
|
@ -0,0 +1,226 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
|
||||||
|
// Automatically reconnects WS in the background.
|
||||||
|
type HybridTransport struct {
|
||||||
|
ws *WSTransport
|
||||||
|
http *HTTPTransport
|
||||||
|
|
||||||
|
mode atomic.Value // "ws" or "http"
|
||||||
|
events chan ServerEvent
|
||||||
|
|
||||||
|
reconnectMu sync.Mutex
|
||||||
|
reconnectRunning bool
|
||||||
|
reconnectStop chan struct{}
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHybridTransport creates a transport that prefers WS with HTTP fallback.
|
||||||
|
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
|
||||||
|
h := &HybridTransport{
|
||||||
|
ws: ws,
|
||||||
|
http: http,
|
||||||
|
events: make(chan ServerEvent, 50),
|
||||||
|
reconnectStop: make(chan struct{}),
|
||||||
|
}
|
||||||
|
h.mode.Store("http") // start in HTTP, upgrade to WS on Connect
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HybridTransport) Mode() string { return h.mode.Load().(string) }
|
||||||
|
func (h *HybridTransport) Events() <-chan ServerEvent { return h.events }
|
||||||
|
|
||||||
|
// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop.
|
||||||
|
func (h *HybridTransport) Connect(ctx context.Context) error {
|
||||||
|
// Try WebSocket first
|
||||||
|
if err := h.ws.Connect(ctx); err != nil {
|
||||||
|
log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err)
|
||||||
|
h.mode.Store("http")
|
||||||
|
h.startReconnectLoop()
|
||||||
|
return h.http.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mode.Store("ws")
|
||||||
|
log.Println("[transport] Connected via WebSocket")
|
||||||
|
|
||||||
|
// Forward WS events to unified channel + watch for disconnection
|
||||||
|
go h.forwardWSEvents()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down both transports and stops reconnection.
|
||||||
|
func (h *HybridTransport) Close() error {
|
||||||
|
h.closed.Store(true)
|
||||||
|
select {
|
||||||
|
case <-h.reconnectStop:
|
||||||
|
default:
|
||||||
|
close(h.reconnectStop)
|
||||||
|
}
|
||||||
|
_ = h.ws.Close()
|
||||||
|
return h.http.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register delegates to the active transport.
|
||||||
|
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||||
|
if h.mode.Load() == "ws" {
|
||||||
|
return h.ws.Register(ctx, req)
|
||||||
|
}
|
||||||
|
return h.http.Register(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendHeartbeat delegates to the active transport.
|
||||||
|
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||||
|
if h.mode.Load() == "ws" {
|
||||||
|
resp, err := h.ws.SendHeartbeat(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
// WS write failed — switch to HTTP
|
||||||
|
h.switchToHTTP()
|
||||||
|
return h.http.SendHeartbeat(ctx, req)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
return h.http.SendHeartbeat(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendProgress delegates to the active transport.
|
||||||
|
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||||
|
if h.mode.Load() == "ws" {
|
||||||
|
resp, err := h.ws.SendProgress(ctx, update)
|
||||||
|
if err != nil {
|
||||||
|
h.switchToHTTP()
|
||||||
|
return h.http.SendProgress(ctx, update)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
return h.http.SendProgress(ctx, update)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaimTasks delegates to the active transport.
|
||||||
|
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||||
|
if h.mode.Load() == "ws" {
|
||||||
|
return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
|
||||||
|
}
|
||||||
|
return h.http.ClaimTasks(ctx, agentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister delegates to the active transport.
|
||||||
|
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
|
||||||
|
if h.mode.Load() == "ws" {
|
||||||
|
return h.ws.Deregister(ctx, agentID)
|
||||||
|
}
|
||||||
|
return h.http.Deregister(ctx, agentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportUpgradeResult delegates to the active transport.
|
||||||
|
func (h *HybridTransport) ReportUpgradeResult(ctx context.Context, result UpgradeResult) error {
|
||||||
|
if h.mode.Load() == "ws" {
|
||||||
|
if err := h.ws.ReportUpgradeResult(ctx, result); err != nil {
|
||||||
|
h.switchToHTTP()
|
||||||
|
return h.http.ReportUpgradeResult(ctx, result)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return h.http.ReportUpgradeResult(ctx, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Internal ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func (h *HybridTransport) switchToHTTP() {
|
||||||
|
if h.mode.Load() == "http" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("[transport] Switching to HTTP fallback")
|
||||||
|
h.mode.Store("http")
|
||||||
|
_ = h.ws.Close()
|
||||||
|
h.startReconnectLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HybridTransport) forwardWSEvents() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.reconnectStop:
|
||||||
|
return
|
||||||
|
case event, ok := <-h.ws.Events():
|
||||||
|
if !ok {
|
||||||
|
return // channel closed
|
||||||
|
}
|
||||||
|
if event.Type == "disconnected" {
|
||||||
|
h.switchToHTTP()
|
||||||
|
select {
|
||||||
|
case h.events <- event:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case h.events <- event:
|
||||||
|
default:
|
||||||
|
log.Printf("[transport] events channel full, dropping %s event", event.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HybridTransport) startReconnectLoop() {
|
||||||
|
h.reconnectMu.Lock()
|
||||||
|
defer h.reconnectMu.Unlock()
|
||||||
|
if h.reconnectRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.reconnectRunning = true
|
||||||
|
go h.reconnectLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HybridTransport) reconnectLoop() {
|
||||||
|
backoff := 5 * time.Second
|
||||||
|
maxBackoff := 60 * time.Second
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.reconnectStop:
|
||||||
|
return
|
||||||
|
case <-time.After(backoff):
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.closed.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Already on WS? (someone else reconnected)
|
||||||
|
if h.mode.Load() == "ws" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
err := h.ws.Connect(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
|
||||||
|
backoff = min(backoff*2, maxBackoff)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// WS reconnected — switch back
|
||||||
|
log.Println("[transport] WebSocket reconnected")
|
||||||
|
h.mode.Store("ws")
|
||||||
|
|
||||||
|
// Reset reconnect flag so loop can start again if WS drops
|
||||||
|
h.reconnectMu.Lock()
|
||||||
|
h.reconnectRunning = false
|
||||||
|
h.reconnectMu.Unlock()
|
||||||
|
|
||||||
|
// Forward events from new WS connection
|
||||||
|
go h.forwardWSEvents()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
445
internal/agent/transport_test.go
Normal file
445
internal/agent/transport_test.go
Normal file
|
|
@ -0,0 +1,445 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ── HTTP Transport Tests ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestHTTPTransportMode(t *testing.T) {
|
||||||
|
tr := NewHTTPTransport("http://localhost", "key", "ua")
|
||||||
|
if tr.Mode() != "http" {
|
||||||
|
t.Errorf("expected http, got %s", tr.Mode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportEventsNeverEmit(t *testing.T) {
|
||||||
|
tr := NewHTTPTransport("http://localhost", "key", "ua")
|
||||||
|
select {
|
||||||
|
case <-tr.Events():
|
||||||
|
t.Error("events channel should never emit in HTTP mode")
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportDelegates(t *testing.T) {
|
||||||
|
// Mock server for register
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(RegisterResponse{
|
||||||
|
Success: true,
|
||||||
|
User: UserInfo{Name: "Test", Plan: "pro"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
tr := NewHTTPTransport(srv.URL, "test-key", "test-agent")
|
||||||
|
resp, err := tr.Register(context.Background(), RegisterRequest{AgentID: "a1"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register failed: %v", err)
|
||||||
|
}
|
||||||
|
if !resp.Success {
|
||||||
|
t.Error("expected success")
|
||||||
|
}
|
||||||
|
if resp.User.Name != "Test" {
|
||||||
|
t.Errorf("expected Test, got %s", resp.User.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── WebSocket Transport Tests ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSTransportConnectAndAuth(t *testing.T) {
|
||||||
|
var received wsAuthMessage
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("upgrade: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Read auth message
|
||||||
|
_, msg, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
json.Unmarshal(msg, &received)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
// Send registered response
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{
|
||||||
|
Type: "registered",
|
||||||
|
User: UserInfo{Name: "WS User", Plan: "pro", IsPro: true},
|
||||||
|
Features: FeatureFlags{Torrent: true},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Keep connection open
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
tr := NewWSTransport(wsURL, "my-api-key", "agent-123", "test/1.0")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := tr.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tr.Close()
|
||||||
|
|
||||||
|
resp, err := tr.Register(ctx, RegisterRequest{
|
||||||
|
AgentID: "agent-123",
|
||||||
|
Name: "test-agent",
|
||||||
|
Version: "1.0.0",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register failed: %v", err)
|
||||||
|
}
|
||||||
|
if !resp.Success {
|
||||||
|
t.Error("expected success")
|
||||||
|
}
|
||||||
|
if resp.User.Name != "WS User" {
|
||||||
|
t.Errorf("expected WS User, got %s", resp.User.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
if received.APIKey != "my-api-key" {
|
||||||
|
t.Errorf("expected my-api-key, got %s", received.APIKey)
|
||||||
|
}
|
||||||
|
if received.AgentID != "agent-123" {
|
||||||
|
t.Errorf("expected agent-123, got %s", received.AgentID)
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSTransportReceiveTasks(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Read auth
|
||||||
|
conn.ReadMessage()
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{
|
||||||
|
Type: "registered",
|
||||||
|
User: UserInfo{Name: "Test"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Push tasks
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
conn.WriteJSON(wsTasksMessage{
|
||||||
|
Type: "tasks",
|
||||||
|
Tasks: []Task{
|
||||||
|
{ID: "t1", InfoHash: "abc123", Title: "Test Movie"},
|
||||||
|
{ID: "t2", InfoHash: "def456", Title: "Test Show"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
tr := NewWSTransport(wsURL, "key", "agent1", "ua")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
tr.Connect(ctx)
|
||||||
|
defer tr.Close()
|
||||||
|
|
||||||
|
tr.Register(ctx, RegisterRequest{AgentID: "agent1"})
|
||||||
|
|
||||||
|
// Wait for tasks event
|
||||||
|
select {
|
||||||
|
case event := <-tr.Events():
|
||||||
|
if event.Type != "tasks" {
|
||||||
|
t.Errorf("expected tasks, got %s", event.Type)
|
||||||
|
}
|
||||||
|
if len(event.Tasks.Tasks) != 2 {
|
||||||
|
t.Errorf("expected 2 tasks, got %d", len(event.Tasks.Tasks))
|
||||||
|
}
|
||||||
|
if event.Tasks.Tasks[0].Title != "Test Movie" {
|
||||||
|
t.Errorf("expected Test Movie, got %s", event.Tasks.Tasks[0].Title)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for tasks event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSTransportReceiveControl(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.ReadMessage()
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
conn.WriteJSON(map[string]string{
|
||||||
|
"type": "control",
|
||||||
|
"action": "cancel",
|
||||||
|
"taskId": "task-99",
|
||||||
|
})
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
tr.Connect(ctx)
|
||||||
|
defer tr.Close()
|
||||||
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case event := <-tr.Events():
|
||||||
|
if event.Type != "control" {
|
||||||
|
t.Errorf("expected control, got %s", event.Type)
|
||||||
|
}
|
||||||
|
if event.Control.Action != "cancel" {
|
||||||
|
t.Errorf("expected cancel, got %s", event.Control.Action)
|
||||||
|
}
|
||||||
|
if event.Control.TaskID != "task-99" {
|
||||||
|
t.Errorf("expected task-99, got %s", event.Control.TaskID)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for control event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSTransportReceiveUpgrade(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.ReadMessage()
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "2.0.0"})
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
tr.Connect(ctx)
|
||||||
|
defer tr.Close()
|
||||||
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case event := <-tr.Events():
|
||||||
|
if event.Type != "upgrade" {
|
||||||
|
t.Errorf("expected upgrade, got %s", event.Type)
|
||||||
|
}
|
||||||
|
if event.Upgrade.Version != "2.0.0" {
|
||||||
|
t.Errorf("expected 2.0.0, got %s", event.Upgrade.Version)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for upgrade event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSTransportDisconnect(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.ReadMessage()
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
||||||
|
|
||||||
|
// Close after a short delay to simulate disconnection
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
conn.Close()
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
tr.Connect(ctx)
|
||||||
|
defer tr.Close()
|
||||||
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case event := <-tr.Events():
|
||||||
|
if event.Type != "disconnected" {
|
||||||
|
t.Errorf("expected disconnected, got %s", event.Type)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for disconnected event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSTransportSendProgress(t *testing.T) {
|
||||||
|
var receivedMsg map[string]interface{}
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Read auth
|
||||||
|
conn.ReadMessage()
|
||||||
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
||||||
|
|
||||||
|
// Read progress
|
||||||
|
_, msg, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
json.Unmarshal(msg, &receivedMsg)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
tr.Connect(ctx)
|
||||||
|
defer tr.Close()
|
||||||
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
resp, err := tr.SendProgress(ctx, StatusUpdate{
|
||||||
|
TaskID: "t1",
|
||||||
|
Status: "downloading",
|
||||||
|
Progress: 42,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendProgress failed: %v", err)
|
||||||
|
}
|
||||||
|
if !resp.Success {
|
||||||
|
t.Error("expected success response")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
mu.Lock()
|
||||||
|
if receivedMsg["type"] != "progress" {
|
||||||
|
t.Errorf("expected progress, got %v", receivedMsg["type"])
|
||||||
|
}
|
||||||
|
if receivedMsg["taskId"] != "t1" {
|
||||||
|
t.Errorf("expected t1, got %v", receivedMsg["taskId"])
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Hybrid Transport Tests ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestHybridTransportWSSuccess(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||||
|
httpT := NewHTTPTransport("http://localhost", "key", "ua")
|
||||||
|
|
||||||
|
h := NewHybridTransport(wsT, httpT)
|
||||||
|
err := h.Connect(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer h.Close()
|
||||||
|
|
||||||
|
if h.Mode() != "ws" {
|
||||||
|
t.Errorf("expected ws mode, got %s", h.Mode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHybridTransportWSFailFallbackHTTP(t *testing.T) {
|
||||||
|
// WS URL points to nowhere
|
||||||
|
wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
|
||||||
|
httpT := NewHTTPTransport("http://localhost", "key", "ua")
|
||||||
|
|
||||||
|
h := NewHybridTransport(wsT, httpT)
|
||||||
|
err := h.Connect(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Connect should succeed with HTTP fallback: %v", err)
|
||||||
|
}
|
||||||
|
defer h.Close()
|
||||||
|
|
||||||
|
if h.Mode() != "http" {
|
||||||
|
t.Errorf("expected http mode after WS failure, got %s", h.Mode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHybridTransportWSDisconnectSwitchesToHTTP(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Close immediately to trigger disconnect
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
conn.Close()
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||||
|
httpT := NewHTTPTransport("http://localhost", "key", "ua")
|
||||||
|
|
||||||
|
h := NewHybridTransport(wsT, httpT)
|
||||||
|
h.Connect(context.Background())
|
||||||
|
defer h.Close()
|
||||||
|
|
||||||
|
// Wait for disconnect event
|
||||||
|
select {
|
||||||
|
case event := <-h.Events():
|
||||||
|
if event.Type != "disconnected" {
|
||||||
|
t.Errorf("expected disconnected, got %s", event.Type)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for disconnected event")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mode should be HTTP now
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if h.Mode() != "http" {
|
||||||
|
t.Errorf("expected http after disconnect, got %s", h.Mode())
|
||||||
|
}
|
||||||
|
}
|
||||||
360
internal/agent/transport_ws.go
Normal file
360
internal/agent/transport_ws.go
Normal file
|
|
@ -0,0 +1,360 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
|
||||||
|
type WSTransport struct {
|
||||||
|
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
|
||||||
|
apiKey string
|
||||||
|
agentID string
|
||||||
|
userAgent string
|
||||||
|
|
||||||
|
conn *websocket.Conn
|
||||||
|
mu sync.Mutex
|
||||||
|
events chan ServerEvent
|
||||||
|
closed atomic.Bool
|
||||||
|
|
||||||
|
// Cached auth response from the DO
|
||||||
|
authResp *RegisterResponse
|
||||||
|
authMu sync.Mutex
|
||||||
|
authDone chan struct{}
|
||||||
|
authDoneOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWSTransport creates a WebSocket-based transport.
|
||||||
|
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
|
||||||
|
return &WSTransport{
|
||||||
|
wsURL: wsURL,
|
||||||
|
apiKey: apiKey,
|
||||||
|
agentID: agentID,
|
||||||
|
userAgent: userAgent,
|
||||||
|
events: make(chan ServerEvent, 50),
|
||||||
|
authDone: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *WSTransport) Mode() string { return "ws" }
|
||||||
|
func (t *WSTransport) Events() <-chan ServerEvent { return t.events }
|
||||||
|
|
||||||
|
// Connect dials the WebSocket server and starts the read loop.
|
||||||
|
func (t *WSTransport) Connect(ctx context.Context) error {
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
HandshakeTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
header := http.Header{}
|
||||||
|
header.Set("User-Agent", t.userAgent)
|
||||||
|
|
||||||
|
// Append API key as query param for auth on WS upgrade
|
||||||
|
wsURLWithKey := t.wsURL
|
||||||
|
if t.apiKey != "" {
|
||||||
|
sep := "?"
|
||||||
|
if strings.Contains(wsURLWithKey, "?") {
|
||||||
|
sep = "&"
|
||||||
|
}
|
||||||
|
wsURLWithKey += sep + "key=" + t.apiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := dialer.DialContext(ctx, wsURLWithKey, header)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ws dial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
t.conn = conn
|
||||||
|
t.closed.Store(false)
|
||||||
|
t.authDone = make(chan struct{})
|
||||||
|
t.authDoneOnce = sync.Once{}
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
go t.readLoop()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close sends a close frame and shuts down the connection.
|
||||||
|
func (t *WSTransport) Close() error {
|
||||||
|
t.closed.Store(true)
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if t.conn != nil {
|
||||||
|
_ = t.conn.WriteMessage(
|
||||||
|
websocket.CloseMessage,
|
||||||
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
||||||
|
)
|
||||||
|
err := t.conn.Close()
|
||||||
|
t.conn = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register sends auth message and waits for the registered response.
|
||||||
|
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||||
|
msg := wsAuthMessage{
|
||||||
|
Type: "auth",
|
||||||
|
APIKey: t.apiKey,
|
||||||
|
AgentID: req.AgentID,
|
||||||
|
Name: req.Name,
|
||||||
|
OS: req.OS,
|
||||||
|
Arch: req.Arch,
|
||||||
|
Version: req.Version,
|
||||||
|
DownloadDir: req.DownloadDir,
|
||||||
|
DiskFreeBytes: req.DiskFreeBytes,
|
||||||
|
DiskTotalBytes: req.DiskTotalBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := t.send(msg); err != nil {
|
||||||
|
return nil, fmt.Errorf("ws auth send: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the auth response or context cancellation
|
||||||
|
select {
|
||||||
|
case <-t.authDone:
|
||||||
|
t.authMu.Lock()
|
||||||
|
resp := t.authResp
|
||||||
|
t.authMu.Unlock()
|
||||||
|
if resp == nil {
|
||||||
|
return nil, fmt.Errorf("ws auth: no response received")
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
|
||||||
|
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||||
|
msg := struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Disk *struct {
|
||||||
|
Free int64 `json:"free"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
} `json:"disk,omitempty"`
|
||||||
|
}{Type: "heartbeat"}
|
||||||
|
|
||||||
|
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
|
||||||
|
msg.Disk = &struct {
|
||||||
|
Free int64 `json:"free"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := t.send(msg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
|
||||||
|
return &HeartbeatResponse{Success: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendProgress sends a progress update. Control signals arrive async via Events().
|
||||||
|
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||||
|
msg := struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
TaskID string `json:"taskId"`
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
Progress int `json:"progress,omitempty"`
|
||||||
|
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
||||||
|
TotalBytes int64 `json:"totalBytes,omitempty"`
|
||||||
|
SpeedBps int64 `json:"speedBps,omitempty"`
|
||||||
|
ETA int `json:"eta,omitempty"`
|
||||||
|
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
||||||
|
FileName string `json:"fileName,omitempty"`
|
||||||
|
FilePath string `json:"filePath,omitempty"`
|
||||||
|
StreamURL string `json:"streamUrl,omitempty"`
|
||||||
|
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||||
|
}{
|
||||||
|
Type: "progress",
|
||||||
|
TaskID: update.TaskID,
|
||||||
|
Status: update.Status,
|
||||||
|
Progress: update.Progress,
|
||||||
|
DownloadedBytes: update.DownloadedBytes,
|
||||||
|
TotalBytes: update.TotalBytes,
|
||||||
|
SpeedBps: update.SpeedBps,
|
||||||
|
ETA: update.ETA,
|
||||||
|
ResolvedMethod: update.ResolvedMethod,
|
||||||
|
FileName: update.FileName,
|
||||||
|
FilePath: update.FilePath,
|
||||||
|
StreamURL: update.StreamURL,
|
||||||
|
ErrorMessage: update.ErrorMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := t.send(msg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// In WS mode, control signals come via Events(), not in the progress response.
|
||||||
|
return &StatusResponse{Success: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
|
||||||
|
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
|
||||||
|
return &TasksResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister is handled by WebSocket close (DO detects disconnection).
|
||||||
|
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
|
||||||
|
return t.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportUpgradeResult sends upgrade result to the DO.
|
||||||
|
func (t *WSTransport) ReportUpgradeResult(_ context.Context, result UpgradeResult) error {
|
||||||
|
msg := struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}{
|
||||||
|
Type: "upgrade-result",
|
||||||
|
Success: result.Success,
|
||||||
|
Version: result.Version,
|
||||||
|
Error: result.Error,
|
||||||
|
}
|
||||||
|
return t.send(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Internal ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func (t *WSTransport) send(msg any) error {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if t.conn == nil {
|
||||||
|
return fmt.Errorf("ws: not connected")
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return t.conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *WSTransport) readLoop() {
|
||||||
|
for {
|
||||||
|
_, msg, err := t.conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if !t.closed.Load() {
|
||||||
|
log.Printf("[ws] read error: %v", err)
|
||||||
|
// Signal disconnection to the daemon
|
||||||
|
select {
|
||||||
|
case t.events <- ServerEvent{Type: "disconnected"}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var envelope struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(msg, &envelope); err != nil {
|
||||||
|
log.Printf("[ws] invalid message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch envelope.Type {
|
||||||
|
case "registered":
|
||||||
|
var resp wsRegisteredMessage
|
||||||
|
if json.Unmarshal(msg, &resp) == nil {
|
||||||
|
t.authMu.Lock()
|
||||||
|
t.authResp = &RegisterResponse{
|
||||||
|
Success: true,
|
||||||
|
User: resp.User,
|
||||||
|
Features: resp.Features,
|
||||||
|
}
|
||||||
|
t.authMu.Unlock()
|
||||||
|
// Signal that auth is complete (sync.Once prevents double-close panic)
|
||||||
|
t.authDoneOnce.Do(func() { close(t.authDone) })
|
||||||
|
}
|
||||||
|
|
||||||
|
case "tasks":
|
||||||
|
var resp wsTasksMessage
|
||||||
|
if json.Unmarshal(msg, &resp) == nil {
|
||||||
|
select {
|
||||||
|
case t.events <- ServerEvent{
|
||||||
|
Type: "tasks",
|
||||||
|
Tasks: &TasksResponse{
|
||||||
|
Tasks: resp.Tasks,
|
||||||
|
StreamRequests: resp.StreamRequests,
|
||||||
|
},
|
||||||
|
}:
|
||||||
|
default:
|
||||||
|
log.Printf("[ws] events channel full, dropping tasks message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "upgrade":
|
||||||
|
var resp wsUpgradeMessage
|
||||||
|
if json.Unmarshal(msg, &resp) == nil {
|
||||||
|
select {
|
||||||
|
case t.events <- ServerEvent{
|
||||||
|
Type: "upgrade",
|
||||||
|
Upgrade: &UpgradeSignal{Version: resp.Version},
|
||||||
|
}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "control":
|
||||||
|
var resp ControlAction
|
||||||
|
if json.Unmarshal(msg, &resp) == nil {
|
||||||
|
select {
|
||||||
|
case t.events <- ServerEvent{
|
||||||
|
Type: "control",
|
||||||
|
Control: &resp,
|
||||||
|
}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
var resp struct{ Message string `json:"message"` }
|
||||||
|
if json.Unmarshal(msg, &resp) == nil {
|
||||||
|
log.Printf("[ws] server error: %s", resp.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── WS message types ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
type wsAuthMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
APIKey string `json:"apiKey"`
|
||||||
|
AgentID string `json:"agentId"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
OS string `json:"os,omitempty"`
|
||||||
|
Arch string `json:"arch,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
DownloadDir string `json:"downloadDir,omitempty"`
|
||||||
|
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||||
|
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type wsRegisteredMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
User UserInfo `json:"user"`
|
||||||
|
Features FeatureFlags `json:"features"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type wsTasksMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Tasks []Task `json:"tasks"`
|
||||||
|
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type wsUpgradeMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -104,8 +105,6 @@ func runDaemonStart() error {
|
||||||
heartbeatInterval = 30 * time.Second
|
heartbeatInterval = 30 * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create agent client (direct HTTP — always available as fallback)
|
|
||||||
ac := agent.NewClient(cfg.Auth.APIURL, cfg.Auth.APIKey, "unarr/"+Version)
|
|
||||||
userAgent := "unarr/" + Version
|
userAgent := "unarr/" + Version
|
||||||
|
|
||||||
// Create daemon config
|
// Create daemon config
|
||||||
|
|
@ -119,6 +118,8 @@ func runDaemonStart() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
|
// Create transport: Hybrid (WS + HTTP fallback) or HTTP-only
|
||||||
|
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
||||||
|
|
||||||
wsURL := cfg.Auth.WSURL
|
wsURL := cfg.Auth.WSURL
|
||||||
if wsURL == "" {
|
if wsURL == "" {
|
||||||
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
|
wsURL = deriveWSURL(cfg.Auth.APIURL, cfg.Agent.ID)
|
||||||
|
|
@ -126,28 +127,19 @@ func runDaemonStart() error {
|
||||||
|
|
||||||
var transport agent.Transport
|
var transport agent.Transport
|
||||||
if wsURL != "" {
|
if wsURL != "" {
|
||||||
httpT := agent.NewHTTPTransport(cfg.Auth.APIURL, cfg.Auth.APIKey, userAgent)
|
|
||||||
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
|
wsT := agent.NewWSTransport(wsURL, cfg.Auth.APIKey, cfg.Agent.ID, userAgent)
|
||||||
transport = agent.NewHybridTransport(wsT, httpT)
|
transport = agent.NewHybridTransport(wsT, httpT)
|
||||||
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
|
log.Printf("Transport: WebSocket (fallback: HTTP) → %s", wsURL)
|
||||||
|
} else {
|
||||||
|
transport = httpT
|
||||||
|
log.Println("Transport: HTTP only")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create daemon
|
// Create daemon — always uses Transport interface
|
||||||
var d *agent.Daemon
|
d := agent.NewDaemon(daemonCfg, transport)
|
||||||
if transport != nil {
|
|
||||||
d = agent.NewDaemonWithTransport(daemonCfg, transport)
|
|
||||||
} else {
|
|
||||||
d = agent.NewDaemon(daemonCfg, ac)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wire state tracking (connected after manager creation below)
|
// Create progress reporter using transport
|
||||||
// Create progress reporter
|
reporter := engine.NewProgressReporterWithTransport(transport, 3*time.Second)
|
||||||
var reporter *engine.ProgressReporter
|
|
||||||
if transport != nil {
|
|
||||||
reporter = engine.NewProgressReporterWithTransport(transport, 3*time.Second)
|
|
||||||
} else {
|
|
||||||
reporter = engine.NewProgressReporter(ac, 3*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse speed limits
|
// Parse speed limits
|
||||||
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
|
maxDl, _ := config.ParseSpeed(cfg.Download.MaxDownloadSpeed)
|
||||||
|
|
@ -190,7 +182,7 @@ func runDaemonStart() error {
|
||||||
MoviesDir: cfg.Organize.MoviesDir,
|
MoviesDir: cfg.Organize.MoviesDir,
|
||||||
TVShowsDir: cfg.Organize.TVShowsDir,
|
TVShowsDir: cfg.Organize.TVShowsDir,
|
||||||
},
|
},
|
||||||
}, reporter, torrentDl, debridDl)
|
}, reporter, torrentDl, debridDl, engine.NewUsenetDownloader(httpT.Client()))
|
||||||
|
|
||||||
// Wire state tracking
|
// Wire state tracking
|
||||||
d.GetActiveCount = manager.ActiveCount
|
d.GetActiveCount = manager.ActiveCount
|
||||||
|
|
@ -275,9 +267,9 @@ func runDaemonStart() error {
|
||||||
|
|
||||||
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(sr.FilePath), streamURL)
|
log.Printf("[%s] streaming from disk: %s → %s", sr.TaskID[:8], filepath.Base(sr.FilePath), streamURL)
|
||||||
|
|
||||||
// Report stream URL back to the server
|
// Report stream URL back to the server via transport
|
||||||
go func() {
|
go func() {
|
||||||
if _, err := ac.ReportStatus(ctx, agent.StatusUpdate{
|
if _, err := transport.SendProgress(ctx, agent.StatusUpdate{
|
||||||
TaskID: sr.TaskID,
|
TaskID: sr.TaskID,
|
||||||
StreamURL: streamURL,
|
StreamURL: streamURL,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
|
@ -298,13 +290,18 @@ func runDaemonStart() error {
|
||||||
case "resume":
|
case "resume":
|
||||||
log.Printf("[%s] resume requested via WebSocket", taskID[:8])
|
log.Printf("[%s] resume requested via WebSocket", taskID[:8])
|
||||||
case "stream":
|
case "stream":
|
||||||
|
// Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests
|
||||||
|
streamRegistry.mu.Lock()
|
||||||
|
if _, exists := streamRegistry.servers[taskID]; exists {
|
||||||
|
streamRegistry.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
task := manager.GetTask(taskID)
|
task := manager.GetTask(taskID)
|
||||||
if task == nil {
|
if task == nil || task.GetStreamURL() != "" {
|
||||||
return
|
streamRegistry.mu.Unlock()
|
||||||
}
|
|
||||||
if task.GetStreamURL() != "" {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
streamRegistry.mu.Unlock()
|
||||||
srv, err := torrentDl.StartStream(taskID)
|
srv, err := torrentDl.StartStream(taskID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[%s] stream failed: %v", taskID[:8], err)
|
log.Printf("[%s] stream failed: %v", taskID[:8], err)
|
||||||
|
|
@ -342,11 +339,7 @@ func runDaemonStart() error {
|
||||||
Version: result.NewVersion,
|
Version: result.NewVersion,
|
||||||
Error: errMsg,
|
Error: errMsg,
|
||||||
}
|
}
|
||||||
if transport != nil {
|
|
||||||
_ = transport.ReportUpgradeResult(reportCtx, upgradeResult)
|
_ = transport.ReportUpgradeResult(reportCtx, upgradeResult)
|
||||||
} else {
|
|
||||||
_ = ac.ReportUpgradeResult(reportCtx, upgradeResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !result.Success {
|
if !result.Success {
|
||||||
log.Printf("Upgrade failed: %v", result.Error)
|
log.Printf("Upgrade failed: %v", result.Error)
|
||||||
|
|
@ -360,7 +353,7 @@ func runDaemonStart() error {
|
||||||
// Deregister first so the server knows we're restarting
|
// Deregister first so the server knows we're restarting
|
||||||
deregCtx, deregCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
deregCtx, deregCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer deregCancel()
|
defer deregCancel()
|
||||||
_ = ac.Deregister(deregCtx, cfg.Agent.ID)
|
_ = transport.Deregister(deregCtx, cfg.Agent.ID)
|
||||||
|
|
||||||
// Flush progress reporter
|
// Flush progress reporter
|
||||||
cancel()
|
cancel()
|
||||||
|
|
@ -418,6 +411,7 @@ func runDaemonStart() error {
|
||||||
|
|
||||||
// deriveWSURL derives a WebSocket URL from the API URL.
|
// deriveWSURL derives a WebSocket URL from the API URL.
|
||||||
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
|
// https://torrentclaw.com → wss://unarr.torrentclaw.com/ws/{agentId}
|
||||||
|
// Returns "" for localhost/dev environments where WS gateway isn't available.
|
||||||
func deriveWSURL(apiURL, agentID string) string {
|
func deriveWSURL(apiURL, agentID string) string {
|
||||||
if apiURL == "" || agentID == "" {
|
if apiURL == "" || agentID == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|
@ -437,6 +431,15 @@ func deriveWSURL(apiURL, agentID string) string {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Strip port if present
|
||||||
|
if idx := strings.LastIndex(domain, ":"); idx > 0 {
|
||||||
|
domain = domain[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip WS for localhost/dev — gateway only available in production
|
||||||
|
if domain == "localhost" || domain == "127.0.0.1" || domain == "0.0.0.0" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
return "wss://unarr." + domain + "/ws/" + agentID
|
return "wss://unarr." + domain + "/ws/" + agentID
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ type Config struct {
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
APIKey string `toml:"api_key"`
|
APIKey string `toml:"api_key"`
|
||||||
APIURL string `toml:"api_url"`
|
APIURL string `toml:"api_url"`
|
||||||
|
WSURL string `toml:"ws_url"` // optional, derived from api_url if empty
|
||||||
}
|
}
|
||||||
|
|
||||||
type AgentConfig struct {
|
type AgentConfig struct {
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,16 @@ import (
|
||||||
// ActionFunc is called when the server signals an action on a task.
|
// ActionFunc is called when the server signals an action on a task.
|
||||||
type ActionFunc func(taskID string)
|
type ActionFunc func(taskID string)
|
||||||
|
|
||||||
|
// StatusReporter is the interface used by ProgressReporter to send progress updates.
|
||||||
|
// Both *agent.Client and agent.Transport implement this via their ReportStatus/SendProgress methods.
|
||||||
|
type StatusReporter interface {
|
||||||
|
ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
// ProgressReporter aggregates progress from downloads and reports to the API.
|
// ProgressReporter aggregates progress from downloads and reports to the API.
|
||||||
// It batches updates to avoid flooding the server.
|
// It batches updates to avoid flooding the server.
|
||||||
type ProgressReporter struct {
|
type ProgressReporter struct {
|
||||||
agentClient *agent.Client
|
reporter StatusReporter
|
||||||
interval time.Duration
|
interval time.Duration
|
||||||
|
|
||||||
onCancel ActionFunc
|
onCancel ActionFunc
|
||||||
|
|
@ -28,14 +34,33 @@ type ProgressReporter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProgressReporter creates a reporter that flushes every interval.
|
// NewProgressReporter creates a reporter that flushes every interval.
|
||||||
|
// Accepts *agent.Client directly (backwards compatible).
|
||||||
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
|
func NewProgressReporter(ac *agent.Client, interval time.Duration) *ProgressReporter {
|
||||||
return &ProgressReporter{
|
return &ProgressReporter{
|
||||||
agentClient: ac,
|
reporter: ac,
|
||||||
interval: interval,
|
interval: interval,
|
||||||
latest: make(map[string]*Task),
|
latest: make(map[string]*Task),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewProgressReporterWithTransport creates a reporter using a Transport.
|
||||||
|
func NewProgressReporterWithTransport(t agent.Transport, interval time.Duration) *ProgressReporter {
|
||||||
|
return &ProgressReporter{
|
||||||
|
reporter: &transportStatusAdapter{t: t},
|
||||||
|
interval: interval,
|
||||||
|
latest: make(map[string]*Task),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// transportStatusAdapter adapts agent.Transport to StatusReporter.
|
||||||
|
type transportStatusAdapter struct {
|
||||||
|
t agent.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *transportStatusAdapter) ReportStatus(ctx context.Context, update agent.StatusUpdate) (*agent.StatusResponse, error) {
|
||||||
|
return a.t.SendProgress(ctx, update)
|
||||||
|
}
|
||||||
|
|
||||||
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
|
// SetCancelHandler sets the callback invoked when the server says a task is cancelled.
|
||||||
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
|
func (r *ProgressReporter) SetCancelHandler(fn ActionFunc) { r.onCancel = fn }
|
||||||
|
|
||||||
|
|
@ -95,7 +120,7 @@ func (r *ProgressReporter) flush(ctx context.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
update := task.ToStatusUpdate()
|
update := task.ToStatusUpdate()
|
||||||
resp, err := r.agentClient.ReportStatus(ctx, update)
|
resp, err := r.reporter.ReportStatus(ctx, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[%s] progress report failed: %v", task.ID[:8], err)
|
log.Printf("[%s] progress report failed: %v", task.ID[:8], err)
|
||||||
continue
|
continue
|
||||||
|
|
@ -130,7 +155,7 @@ func (r *ProgressReporter) flush(ctx context.Context) {
|
||||||
// ReportFinal sends a final status update for a completed/failed task.
|
// ReportFinal sends a final status update for a completed/failed task.
|
||||||
func (r *ProgressReporter) ReportFinal(ctx context.Context, task *Task) {
|
func (r *ProgressReporter) ReportFinal(ctx context.Context, task *Task) {
|
||||||
update := task.ToStatusUpdate()
|
update := task.ToStatusUpdate()
|
||||||
if _, err := r.agentClient.ReportStatus(ctx, update); err != nil {
|
if _, err := r.reporter.ReportStatus(ctx, update); err != nil {
|
||||||
log.Printf("[%s] final report failed: %v", task.ID[:8], err)
|
log.Printf("[%s] final report failed: %v", task.ID[:8], err)
|
||||||
}
|
}
|
||||||
r.Untrack(task.ID)
|
r.Untrack(task.ID)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue