feat(agent): add WebSocket transport with HTTP fallback
Add Transport interface abstraction supporting WebSocket (via CF Durable Objects) and HTTP (direct to origin) with automatic failover. - Transport interface: Register, SendHeartbeat, SendProgress, Events() - HTTPTransport: thin adapter over existing Client - WSTransport: gorilla/websocket with auth handshake, readLoop, reconnect - HybridTransport: tries WS first, falls back to HTTP, reconnects in bg - Daemon refactored to always use Transport (no dual-path forks) - ProgressReporter accepts StatusReporter interface - deriveWSURL skips localhost/dev (returns "" → HTTP-only) - API key passed in WS query param for connection auth - Fixed: reconnectOnce race (mutex+bool), authDone double-close (sync.Once) - Fixed: forwardWSEvents goroutine leak (select with stop signal) - 20 transport tests + 2 E2E tests (full lifecycle, hybrid failover)
This commit is contained in:
parent
5e80911501
commit
5f337eebd7
10 changed files with 1646 additions and 64 deletions
360
internal/agent/transport_ws.go
Normal file
360
internal/agent/transport_ws.go
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
|
||||
type WSTransport struct {
|
||||
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
|
||||
apiKey string
|
||||
agentID string
|
||||
userAgent string
|
||||
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
events chan ServerEvent
|
||||
closed atomic.Bool
|
||||
|
||||
// Cached auth response from the DO
|
||||
authResp *RegisterResponse
|
||||
authMu sync.Mutex
|
||||
authDone chan struct{}
|
||||
authDoneOnce sync.Once
|
||||
}
|
||||
|
||||
// NewWSTransport creates a WebSocket-based transport.
|
||||
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
|
||||
return &WSTransport{
|
||||
wsURL: wsURL,
|
||||
apiKey: apiKey,
|
||||
agentID: agentID,
|
||||
userAgent: userAgent,
|
||||
events: make(chan ServerEvent, 50),
|
||||
authDone: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WSTransport) Mode() string { return "ws" }
|
||||
func (t *WSTransport) Events() <-chan ServerEvent { return t.events }
|
||||
|
||||
// Connect dials the WebSocket server and starts the read loop.
|
||||
func (t *WSTransport) Connect(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
header := http.Header{}
|
||||
header.Set("User-Agent", t.userAgent)
|
||||
|
||||
// Append API key as query param for auth on WS upgrade
|
||||
wsURLWithKey := t.wsURL
|
||||
if t.apiKey != "" {
|
||||
sep := "?"
|
||||
if strings.Contains(wsURLWithKey, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
wsURLWithKey += sep + "key=" + t.apiKey
|
||||
}
|
||||
|
||||
conn, _, err := dialer.DialContext(ctx, wsURLWithKey, header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ws dial: %w", err)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.conn = conn
|
||||
t.closed.Store(false)
|
||||
t.authDone = make(chan struct{})
|
||||
t.authDoneOnce = sync.Once{}
|
||||
t.mu.Unlock()
|
||||
|
||||
go t.readLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close sends a close frame and shuts down the connection.
|
||||
func (t *WSTransport) Close() error {
|
||||
t.closed.Store(true)
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.conn != nil {
|
||||
_ = t.conn.WriteMessage(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
||||
)
|
||||
err := t.conn.Close()
|
||||
t.conn = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register sends auth message and waits for the registered response.
|
||||
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||
msg := wsAuthMessage{
|
||||
Type: "auth",
|
||||
APIKey: t.apiKey,
|
||||
AgentID: req.AgentID,
|
||||
Name: req.Name,
|
||||
OS: req.OS,
|
||||
Arch: req.Arch,
|
||||
Version: req.Version,
|
||||
DownloadDir: req.DownloadDir,
|
||||
DiskFreeBytes: req.DiskFreeBytes,
|
||||
DiskTotalBytes: req.DiskTotalBytes,
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, fmt.Errorf("ws auth send: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the auth response or context cancellation
|
||||
select {
|
||||
case <-t.authDone:
|
||||
t.authMu.Lock()
|
||||
resp := t.authResp
|
||||
t.authMu.Unlock()
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("ws auth: no response received")
|
||||
}
|
||||
return resp, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(15 * time.Second):
|
||||
return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
|
||||
}
|
||||
}
|
||||
|
||||
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
|
||||
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
msg := struct {
|
||||
Type string `json:"type"`
|
||||
Disk *struct {
|
||||
Free int64 `json:"free"`
|
||||
Total int64 `json:"total"`
|
||||
} `json:"disk,omitempty"`
|
||||
}{Type: "heartbeat"}
|
||||
|
||||
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
|
||||
msg.Disk = &struct {
|
||||
Free int64 `json:"free"`
|
||||
Total int64 `json:"total"`
|
||||
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
|
||||
return &HeartbeatResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// SendProgress sends a progress update. Control signals arrive async via Events().
|
||||
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||
msg := struct {
|
||||
Type string `json:"type"`
|
||||
TaskID string `json:"taskId"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Progress int `json:"progress,omitempty"`
|
||||
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
||||
TotalBytes int64 `json:"totalBytes,omitempty"`
|
||||
SpeedBps int64 `json:"speedBps,omitempty"`
|
||||
ETA int `json:"eta,omitempty"`
|
||||
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
FilePath string `json:"filePath,omitempty"`
|
||||
StreamURL string `json:"streamUrl,omitempty"`
|
||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||
}{
|
||||
Type: "progress",
|
||||
TaskID: update.TaskID,
|
||||
Status: update.Status,
|
||||
Progress: update.Progress,
|
||||
DownloadedBytes: update.DownloadedBytes,
|
||||
TotalBytes: update.TotalBytes,
|
||||
SpeedBps: update.SpeedBps,
|
||||
ETA: update.ETA,
|
||||
ResolvedMethod: update.ResolvedMethod,
|
||||
FileName: update.FileName,
|
||||
FilePath: update.FilePath,
|
||||
StreamURL: update.StreamURL,
|
||||
ErrorMessage: update.ErrorMessage,
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// In WS mode, control signals come via Events(), not in the progress response.
|
||||
return &StatusResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
|
||||
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
|
||||
return &TasksResponse{}, nil
|
||||
}
|
||||
|
||||
// Deregister is handled by WebSocket close (DO detects disconnection).
|
||||
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
|
||||
return t.Close()
|
||||
}
|
||||
|
||||
// ReportUpgradeResult sends upgrade result to the DO.
|
||||
func (t *WSTransport) ReportUpgradeResult(_ context.Context, result UpgradeResult) error {
|
||||
msg := struct {
|
||||
Type string `json:"type"`
|
||||
Success bool `json:"success"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}{
|
||||
Type: "upgrade-result",
|
||||
Success: result.Success,
|
||||
Version: result.Version,
|
||||
Error: result.Error,
|
||||
}
|
||||
return t.send(msg)
|
||||
}
|
||||
|
||||
// ── Internal ─────────────────────────────────────────────────────────────────
|
||||
|
||||
func (t *WSTransport) send(msg any) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("ws: not connected")
|
||||
}
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return t.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (t *WSTransport) readLoop() {
|
||||
for {
|
||||
_, msg, err := t.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if !t.closed.Load() {
|
||||
log.Printf("[ws] read error: %v", err)
|
||||
// Signal disconnection to the daemon
|
||||
select {
|
||||
case t.events <- ServerEvent{Type: "disconnected"}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(msg, &envelope); err != nil {
|
||||
log.Printf("[ws] invalid message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case "registered":
|
||||
var resp wsRegisteredMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
t.authMu.Lock()
|
||||
t.authResp = &RegisterResponse{
|
||||
Success: true,
|
||||
User: resp.User,
|
||||
Features: resp.Features,
|
||||
}
|
||||
t.authMu.Unlock()
|
||||
// Signal that auth is complete (sync.Once prevents double-close panic)
|
||||
t.authDoneOnce.Do(func() { close(t.authDone) })
|
||||
}
|
||||
|
||||
case "tasks":
|
||||
var resp wsTasksMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "tasks",
|
||||
Tasks: &TasksResponse{
|
||||
Tasks: resp.Tasks,
|
||||
StreamRequests: resp.StreamRequests,
|
||||
},
|
||||
}:
|
||||
default:
|
||||
log.Printf("[ws] events channel full, dropping tasks message")
|
||||
}
|
||||
}
|
||||
|
||||
case "upgrade":
|
||||
var resp wsUpgradeMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "upgrade",
|
||||
Upgrade: &UpgradeSignal{Version: resp.Version},
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case "control":
|
||||
var resp ControlAction
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "control",
|
||||
Control: &resp,
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case "error":
|
||||
var resp struct{ Message string `json:"message"` }
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
log.Printf("[ws] server error: %s", resp.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── WS message types ─────────────────────────────────────────────────────────
|
||||
|
||||
type wsAuthMessage struct {
|
||||
Type string `json:"type"`
|
||||
APIKey string `json:"apiKey"`
|
||||
AgentID string `json:"agentId"`
|
||||
Name string `json:"name,omitempty"`
|
||||
OS string `json:"os,omitempty"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
DownloadDir string `json:"downloadDir,omitempty"`
|
||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||
}
|
||||
|
||||
type wsRegisteredMessage struct {
|
||||
Type string `json:"type"`
|
||||
User UserInfo `json:"user"`
|
||||
Features FeatureFlags `json:"features"`
|
||||
}
|
||||
|
||||
type wsTasksMessage struct {
|
||||
Type string `json:"type"`
|
||||
Tasks []Task `json:"tasks"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
}
|
||||
|
||||
type wsUpgradeMessage struct {
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue