349 lines
9.3 KiB
Go
349 lines
9.3 KiB
Go
package agent
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
|
|
type WSTransport struct {
|
|
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
|
|
apiKey string
|
|
agentID string
|
|
userAgent string
|
|
|
|
conn *websocket.Conn
|
|
mu sync.Mutex
|
|
events chan ServerEvent
|
|
closed atomic.Bool
|
|
|
|
// Cached auth response from the DO
|
|
authResp *RegisterResponse
|
|
authMu sync.Mutex
|
|
authDone chan struct{}
|
|
authDoneOnce sync.Once
|
|
}
|
|
|
|
// NewWSTransport creates a WebSocket-based transport.
|
|
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
|
|
return &WSTransport{
|
|
wsURL: wsURL,
|
|
apiKey: apiKey,
|
|
agentID: agentID,
|
|
userAgent: userAgent,
|
|
events: make(chan ServerEvent, 50),
|
|
authDone: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// Mode identifies this transport; WSTransport always reports "ws".
func (t *WSTransport) Mode() string {
	return "ws"
}
|
|
// Events returns the receive-only channel on which the read loop publishes
// server-pushed events ("tasks", "upgrade", "control", "disconnected").
func (t *WSTransport) Events() <-chan ServerEvent {
	return t.events
}
|
|
|
|
// Connect dials the WebSocket server and starts the read loop.
//
// When an API key is configured it is appended to t.wsURL as a URL-escaped
// "key" query parameter, because the upgrade request itself is authenticated.
// On success the new connection replaces t.conn, the transport is marked open
// again, the per-connection auth signal (authDone/authDoneOnce) is reset, and
// a readLoop goroutine is started for the connection. Agent authentication is
// a separate step (Register).
//
// NOTE(review): a previous t.conn, if any, is not closed here; callers are
// presumably expected to Close the old connection first — confirm.
func (t *WSTransport) Connect(ctx context.Context) error {
	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
	}

	header := http.Header{}
	header.Set("User-Agent", t.userAgent)

	// Escape the key so reserved characters (&, +, #, %, ...) cannot corrupt
	// the query string or be misread as additional parameters.
	// NOTE(review): a key carried in the URL may end up in proxy/access logs.
	dialURL := t.wsURL
	if t.apiKey != "" {
		sep := "?"
		if strings.Contains(dialURL, "?") {
			sep = "&"
		}
		dialURL += sep + "key=" + url.QueryEscape(t.apiKey)
	}

	conn, resp, err := dialer.DialContext(ctx, dialURL, header)
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		// On a failed handshake gorilla/websocket also returns the HTTP
		// response; its status code is the most useful diagnostic (for
		// example an authentication rejection).
		if resp != nil {
			return fmt.Errorf("ws dial: HTTP %d: %w", resp.StatusCode, err)
		}
		return fmt.Errorf("ws dial: %w", err)
	}

	t.mu.Lock()
	t.conn = conn
	t.closed.Store(false)
	// Fresh auth signal for this connection so Register waits for its own
	// "registered" reply rather than a previous connection's.
	t.authDone = make(chan struct{})
	t.authDoneOnce = sync.Once{}
	t.mu.Unlock()

	go t.readLoop(conn)
	return nil
}
|
|
|
|
// Close sends a close frame and shuts down the connection.
|
|
func (t *WSTransport) Close() error {
|
|
t.closed.Store(true)
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
if t.conn != nil {
|
|
_ = t.conn.WriteMessage(
|
|
websocket.CloseMessage,
|
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
|
)
|
|
err := t.conn.Close()
|
|
t.conn = nil
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Register sends auth message and waits for the registered response.
//
// It blocks until the read loop stores the "registered" reply and signals it,
// ctx is done, or 15 seconds pass. It returns an error if the auth message
// cannot be written, on cancellation or timeout, or if the signal fired
// without a cached response.
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
	// Connect replaces t.authDone under t.mu; take the current connection's
	// channel under the same lock so a concurrent reconnect cannot race.
	t.mu.Lock()
	authDone := t.authDone
	t.mu.Unlock()

	msg := wsAuthMessage{
		Type:           "auth",
		APIKey:         t.apiKey,
		AgentID:        req.AgentID,
		Name:           req.Name,
		OS:             req.OS,
		Arch:           req.Arch,
		Version:        req.Version,
		DownloadDir:    req.DownloadDir,
		DiskFreeBytes:  req.DiskFreeBytes,
		DiskTotalBytes: req.DiskTotalBytes,
	}
	if err := t.send(msg); err != nil {
		return nil, fmt.Errorf("ws auth send: %w", err)
	}

	timer := time.NewTimer(15 * time.Second)
	defer timer.Stop()

	// Wait for the auth response or context cancellation.
	select {
	case <-authDone:
		t.authMu.Lock()
		resp := t.authResp
		t.authMu.Unlock()
		if resp == nil {
			return nil, fmt.Errorf("ws auth: no response received")
		}
		return resp, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timer.C:
		return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
	}
}
|
|
|
|
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
|
|
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
|
msg := struct {
|
|
Type string `json:"type"`
|
|
Disk *struct {
|
|
Free int64 `json:"free"`
|
|
Total int64 `json:"total"`
|
|
} `json:"disk,omitempty"`
|
|
}{Type: "heartbeat"}
|
|
|
|
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
|
|
msg.Disk = &struct {
|
|
Free int64 `json:"free"`
|
|
Total int64 `json:"total"`
|
|
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
|
|
}
|
|
|
|
if err := t.send(msg); err != nil {
|
|
return nil, err
|
|
}
|
|
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
|
|
return &HeartbeatResponse{Success: true}, nil
|
|
}
|
|
|
|
// SendProgress sends a progress update. Control signals arrive async via Events().
|
|
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
|
|
msg := struct {
|
|
Type string `json:"type"`
|
|
TaskID string `json:"taskId"`
|
|
Status string `json:"status,omitempty"`
|
|
Progress int `json:"progress,omitempty"`
|
|
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
|
TotalBytes int64 `json:"totalBytes,omitempty"`
|
|
SpeedBps int64 `json:"speedBps,omitempty"`
|
|
ETA int `json:"eta,omitempty"`
|
|
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
|
FileName string `json:"fileName,omitempty"`
|
|
FilePath string `json:"filePath,omitempty"`
|
|
StreamURL string `json:"streamUrl,omitempty"`
|
|
ErrorMessage string `json:"errorMessage,omitempty"`
|
|
}{
|
|
Type: "progress",
|
|
TaskID: update.TaskID,
|
|
Status: update.Status,
|
|
Progress: update.Progress,
|
|
DownloadedBytes: update.DownloadedBytes,
|
|
TotalBytes: update.TotalBytes,
|
|
SpeedBps: update.SpeedBps,
|
|
ETA: update.ETA,
|
|
ResolvedMethod: update.ResolvedMethod,
|
|
FileName: update.FileName,
|
|
FilePath: update.FilePath,
|
|
StreamURL: update.StreamURL,
|
|
ErrorMessage: update.ErrorMessage,
|
|
}
|
|
|
|
if err := t.send(msg); err != nil {
|
|
return nil, err
|
|
}
|
|
// In WS mode, control signals come via Events(), not in the progress response.
|
|
return &StatusResponse{Success: true}, nil
|
|
}
|
|
|
|
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
|
|
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
|
|
return &TasksResponse{}, nil
|
|
}
|
|
|
|
// Deregister has no dedicated message in WS mode. It simply closes the
// WebSocket (the DO detects the disconnection) and returns Close's error.
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
	err := t.Close()
	return err
}
|
|
|
|
// ── Internal ─────────────────────────────────────────────────────────────────
|
|
|
|
func (t *WSTransport) send(msg any) error {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
if t.conn == nil {
|
|
return fmt.Errorf("ws: not connected")
|
|
}
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return t.conn.WriteMessage(websocket.TextMessage, data)
|
|
}
|
|
|
|
// readLoop consumes frames from conn until a read fails, dispatching each
// frame through handleMessage. Connect starts one readLoop goroutine per
// connection; the loop ends on the first read error, including the one
// caused by Close.
func (t *WSTransport) readLoop(conn *websocket.Conn) {
	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			// gorilla/websocket read errors are permanent. Close conn right
			// away to release its network resources and to unblock any
			// writer stuck on it (which then releases t.mu). Closing a conn
			// that Close already closed only returns an error, ignored here.
			_ = conn.Close()
			t.reportDisconnect(conn, err)
			return
		}
		t.handleMessage(msg)
	}
}

// reportDisconnect logs err and publishes a "disconnected" event for an
// unexpected loss of conn. It stays silent when the transport was closed via
// Close, or when conn is no longer t.conn (a newer Connect replaced it), so a
// stale loop does not signal a disconnect for a connection nobody uses.
func (t *WSTransport) reportDisconnect(conn *websocket.Conn, err error) {
	if t.closed.Load() {
		return
	}

	t.mu.Lock()
	current := t.conn == conn
	t.mu.Unlock()
	if !current {
		return
	}

	log.Printf("[ws] read error: %v", err)
	// Signal disconnection to the daemon.
	t.emitEvent(ServerEvent{Type: "disconnected"})
}

// emitEvent hands ev to the events channel without ever blocking the read
// loop. When the buffer is full the event is dropped and the drop is logged.
func (t *WSTransport) emitEvent(ev ServerEvent) {
	select {
	case t.events <- ev:
	default:
		log.Printf("[ws] events channel full, dropping %s message", ev.Type)
	}
}

// handleMessage decodes one server frame and dispatches it by its "type"
// field: "registered" completes Register, "tasks"/"upgrade"/"control" become
// events, and "error" is logged. Malformed frames are logged and skipped;
// unknown types are ignored.
func (t *WSTransport) handleMessage(msg []byte) {
	var envelope struct {
		Type string `json:"type"`
	}
	if err := json.Unmarshal(msg, &envelope); err != nil {
		log.Printf("[ws] invalid message: %v", err)
		return
	}

	// decode unmarshals the whole frame into v, logging and reporting false
	// when the payload does not match the expected shape.
	decode := func(v any) bool {
		if err := json.Unmarshal(msg, v); err != nil {
			log.Printf("[ws] invalid %s message: %v", envelope.Type, err)
			return false
		}
		return true
	}

	switch envelope.Type {
	case "registered":
		var reg wsRegisteredMessage
		if !decode(&reg) {
			return
		}
		t.authMu.Lock()
		t.authResp = &RegisterResponse{
			Success:  true,
			User:     reg.User,
			Features: reg.Features,
		}
		t.authMu.Unlock()
		// Signal that auth is complete (sync.Once prevents double-close panic).
		t.authDoneOnce.Do(func() { close(t.authDone) })

	case "tasks":
		var tm wsTasksMessage
		if !decode(&tm) {
			return
		}
		t.emitEvent(ServerEvent{
			Type: "tasks",
			Tasks: &TasksResponse{
				Tasks:          tm.Tasks,
				StreamRequests: tm.StreamRequests,
			},
		})

	case "upgrade":
		var um wsUpgradeMessage
		if !decode(&um) {
			return
		}
		t.emitEvent(ServerEvent{
			Type:    "upgrade",
			Upgrade: &UpgradeSignal{Version: um.Version},
		})

	case "control":
		var action ControlAction
		if !decode(&action) {
			return
		}
		t.emitEvent(ServerEvent{
			Type:    "control",
			Control: &action,
		})

	case "error":
		var em struct {
			Message string `json:"message"`
		}
		if decode(&em) {
			log.Printf("[ws] server error: %s", em.Message)
		}
	}
}
|
|
|
|
// ── WS message types ─────────────────────────────────────────────────────────
|
|
|
|
// wsAuthMessage is the "auth" frame Register sends right after connecting.
// It carries the API key, the agent identity, and host details; empty
// optional fields are omitted from the JSON.
type wsAuthMessage struct {
	Type    string `json:"type"`
	APIKey  string `json:"apiKey"`
	AgentID string `json:"agentId"`

	// Optional host/agent details.
	Name           string `json:"name,omitempty"`
	OS             string `json:"os,omitempty"`
	Arch           string `json:"arch,omitempty"`
	Version        string `json:"version,omitempty"`
	DownloadDir    string `json:"downloadDir,omitempty"`
	DiskFreeBytes  int64  `json:"diskFreeBytes,omitempty"`
	DiskTotalBytes int64  `json:"diskTotalBytes,omitempty"`
}
|
|
|
|
type wsRegisteredMessage struct {
|
|
Type string `json:"type"`
|
|
User UserInfo `json:"user"`
|
|
Features FeatureFlags `json:"features"`
|
|
}
|
|
|
|
type wsTasksMessage struct {
|
|
Type string `json:"type"`
|
|
Tasks []Task `json:"tasks"`
|
|
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
|
}
|
|
|
|
// wsUpgradeMessage is a server-pushed "upgrade" frame naming the agent
// version to move to; the read loop turns it into an UpgradeSignal event.
type wsUpgradeMessage struct {
	Type    string `json:"type"`
	Version string `json:"version"`
}
|