445 lines
11 KiB
Go
445 lines
11 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// ── HTTP Transport Tests ─────────────────────────────────────────────────────
|
|
|
|
func TestHTTPTransportMode(t *testing.T) {
|
|
tr := NewHTTPTransport("http://localhost", "key", "ua")
|
|
if tr.Mode() != "http" {
|
|
t.Errorf("expected http, got %s", tr.Mode())
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportEventsNeverEmit(t *testing.T) {
|
|
tr := NewHTTPTransport("http://localhost", "key", "ua")
|
|
select {
|
|
case <-tr.Events():
|
|
t.Error("events channel should never emit in HTTP mode")
|
|
case <-time.After(50 * time.Millisecond):
|
|
// expected
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportDelegates(t *testing.T) {
|
|
// Mock server for register
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(RegisterResponse{
|
|
Success: true,
|
|
User: UserInfo{Name: "Test", Plan: "pro"},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
tr := NewHTTPTransport(srv.URL, "test-key", "test-agent")
|
|
resp, err := tr.Register(context.Background(), RegisterRequest{AgentID: "a1"})
|
|
if err != nil {
|
|
t.Fatalf("Register failed: %v", err)
|
|
}
|
|
if !resp.Success {
|
|
t.Error("expected success")
|
|
}
|
|
if resp.User.Name != "Test" {
|
|
t.Errorf("expected Test, got %s", resp.User.Name)
|
|
}
|
|
}
|
|
|
|
// ── WebSocket Transport Tests ────────────────────────────────────────────────
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
|
|
func TestWSTransportConnectAndAuth(t *testing.T) {
|
|
var received wsAuthMessage
|
|
var mu sync.Mutex
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
t.Fatalf("upgrade: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Read auth message
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
mu.Lock()
|
|
json.Unmarshal(msg, &received)
|
|
mu.Unlock()
|
|
|
|
// Send registered response
|
|
conn.WriteJSON(wsRegisteredMessage{
|
|
Type: "registered",
|
|
User: UserInfo{Name: "WS User", Plan: "pro", IsPro: true},
|
|
Features: FeatureFlags{Torrent: true},
|
|
})
|
|
|
|
// Keep connection open
|
|
time.Sleep(500 * time.Millisecond)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
tr := NewWSTransport(wsURL, "my-api-key", "agent-123", "test/1.0")
|
|
|
|
ctx := context.Background()
|
|
if err := tr.Connect(ctx); err != nil {
|
|
t.Fatalf("Connect failed: %v", err)
|
|
}
|
|
defer tr.Close()
|
|
|
|
resp, err := tr.Register(ctx, RegisterRequest{
|
|
AgentID: "agent-123",
|
|
Name: "test-agent",
|
|
Version: "1.0.0",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Register failed: %v", err)
|
|
}
|
|
if !resp.Success {
|
|
t.Error("expected success")
|
|
}
|
|
if resp.User.Name != "WS User" {
|
|
t.Errorf("expected WS User, got %s", resp.User.Name)
|
|
}
|
|
|
|
mu.Lock()
|
|
if received.APIKey != "my-api-key" {
|
|
t.Errorf("expected my-api-key, got %s", received.APIKey)
|
|
}
|
|
if received.AgentID != "agent-123" {
|
|
t.Errorf("expected agent-123, got %s", received.AgentID)
|
|
}
|
|
mu.Unlock()
|
|
}
|
|
|
|
func TestWSTransportReceiveTasks(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Read auth
|
|
conn.ReadMessage()
|
|
conn.WriteJSON(wsRegisteredMessage{
|
|
Type: "registered",
|
|
User: UserInfo{Name: "Test"},
|
|
})
|
|
|
|
// Push tasks
|
|
time.Sleep(50 * time.Millisecond)
|
|
conn.WriteJSON(wsTasksMessage{
|
|
Type: "tasks",
|
|
Tasks: []Task{
|
|
{ID: "t1", InfoHash: "abc123", Title: "Test Movie"},
|
|
{ID: "t2", InfoHash: "def456", Title: "Test Show"},
|
|
},
|
|
})
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
tr := NewWSTransport(wsURL, "key", "agent1", "ua")
|
|
|
|
ctx := context.Background()
|
|
tr.Connect(ctx)
|
|
defer tr.Close()
|
|
|
|
tr.Register(ctx, RegisterRequest{AgentID: "agent1"})
|
|
|
|
// Wait for tasks event
|
|
select {
|
|
case event := <-tr.Events():
|
|
if event.Type != "tasks" {
|
|
t.Errorf("expected tasks, got %s", event.Type)
|
|
}
|
|
if len(event.Tasks.Tasks) != 2 {
|
|
t.Errorf("expected 2 tasks, got %d", len(event.Tasks.Tasks))
|
|
}
|
|
if event.Tasks.Tasks[0].Title != "Test Movie" {
|
|
t.Errorf("expected Test Movie, got %s", event.Tasks.Tasks[0].Title)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for tasks event")
|
|
}
|
|
}
|
|
|
|
func TestWSTransportReceiveControl(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
conn.ReadMessage()
|
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
conn.WriteJSON(map[string]string{
|
|
"type": "control",
|
|
"action": "cancel",
|
|
"taskId": "task-99",
|
|
})
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
|
|
|
ctx := context.Background()
|
|
tr.Connect(ctx)
|
|
defer tr.Close()
|
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
|
|
|
select {
|
|
case event := <-tr.Events():
|
|
if event.Type != "control" {
|
|
t.Errorf("expected control, got %s", event.Type)
|
|
}
|
|
if event.Control.Action != "cancel" {
|
|
t.Errorf("expected cancel, got %s", event.Control.Action)
|
|
}
|
|
if event.Control.TaskID != "task-99" {
|
|
t.Errorf("expected task-99, got %s", event.Control.TaskID)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for control event")
|
|
}
|
|
}
|
|
|
|
func TestWSTransportReceiveUpgrade(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
conn.ReadMessage()
|
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "2.0.0"})
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
|
|
|
ctx := context.Background()
|
|
tr.Connect(ctx)
|
|
defer tr.Close()
|
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
|
|
|
select {
|
|
case event := <-tr.Events():
|
|
if event.Type != "upgrade" {
|
|
t.Errorf("expected upgrade, got %s", event.Type)
|
|
}
|
|
if event.Upgrade.Version != "2.0.0" {
|
|
t.Errorf("expected 2.0.0, got %s", event.Upgrade.Version)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for upgrade event")
|
|
}
|
|
}
|
|
|
|
func TestWSTransportDisconnect(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
conn.ReadMessage()
|
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
|
|
|
// Close after a short delay to simulate disconnection
|
|
time.Sleep(100 * time.Millisecond)
|
|
conn.Close()
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
|
|
|
ctx := context.Background()
|
|
tr.Connect(ctx)
|
|
defer tr.Close()
|
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
|
|
|
select {
|
|
case event := <-tr.Events():
|
|
if event.Type != "disconnected" {
|
|
t.Errorf("expected disconnected, got %s", event.Type)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for disconnected event")
|
|
}
|
|
}
|
|
|
|
func TestWSTransportSendProgress(t *testing.T) {
|
|
var receivedMsg map[string]interface{}
|
|
var mu sync.Mutex
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Read auth
|
|
conn.ReadMessage()
|
|
conn.WriteJSON(wsRegisteredMessage{Type: "registered", User: UserInfo{}})
|
|
|
|
// Read progress
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
mu.Lock()
|
|
json.Unmarshal(msg, &receivedMsg)
|
|
mu.Unlock()
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
tr := NewWSTransport(wsURL, "key", "a1", "ua")
|
|
|
|
ctx := context.Background()
|
|
tr.Connect(ctx)
|
|
defer tr.Close()
|
|
tr.Register(ctx, RegisterRequest{AgentID: "a1"})
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
resp, err := tr.SendProgress(ctx, StatusUpdate{
|
|
TaskID: "t1",
|
|
Status: "downloading",
|
|
Progress: 42,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("SendProgress failed: %v", err)
|
|
}
|
|
if !resp.Success {
|
|
t.Error("expected success response")
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
mu.Lock()
|
|
if receivedMsg["type"] != "progress" {
|
|
t.Errorf("expected progress, got %v", receivedMsg["type"])
|
|
}
|
|
if receivedMsg["taskId"] != "t1" {
|
|
t.Errorf("expected t1, got %v", receivedMsg["taskId"])
|
|
}
|
|
mu.Unlock()
|
|
}
|
|
|
|
// ── Hybrid Transport Tests ───────────────────────────────────────────────────
|
|
|
|
func TestHybridTransportWSSuccess(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
time.Sleep(500 * time.Millisecond)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
|
httpT := NewHTTPTransport("http://localhost", "key", "ua")
|
|
|
|
h := NewHybridTransport(wsT, httpT)
|
|
err := h.Connect(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Connect failed: %v", err)
|
|
}
|
|
defer h.Close()
|
|
|
|
if h.Mode() != "ws" {
|
|
t.Errorf("expected ws mode, got %s", h.Mode())
|
|
}
|
|
}
|
|
|
|
func TestHybridTransportWSFailFallbackHTTP(t *testing.T) {
|
|
// WS URL points to nowhere
|
|
wsT := NewWSTransport("ws://127.0.0.1:1", "key", "a1", "ua")
|
|
httpT := NewHTTPTransport("http://localhost", "key", "ua")
|
|
|
|
h := NewHybridTransport(wsT, httpT)
|
|
err := h.Connect(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Connect should succeed with HTTP fallback: %v", err)
|
|
}
|
|
defer h.Close()
|
|
|
|
if h.Mode() != "http" {
|
|
t.Errorf("expected http mode after WS failure, got %s", h.Mode())
|
|
}
|
|
}
|
|
|
|
func TestHybridTransportWSDisconnectSwitchesToHTTP(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
// Close immediately to trigger disconnect
|
|
time.Sleep(100 * time.Millisecond)
|
|
conn.Close()
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
|
httpT := NewHTTPTransport("http://localhost", "key", "ua")
|
|
|
|
h := NewHybridTransport(wsT, httpT)
|
|
h.Connect(context.Background())
|
|
defer h.Close()
|
|
|
|
// Wait for disconnect event
|
|
select {
|
|
case event := <-h.Events():
|
|
if event.Type != "disconnected" {
|
|
t.Errorf("expected disconnected, got %s", event.Type)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for disconnected event")
|
|
}
|
|
|
|
// Mode should be HTTP now
|
|
time.Sleep(100 * time.Millisecond)
|
|
if h.Mode() != "http" {
|
|
t.Errorf("expected http after disconnect, got %s", h.Mode())
|
|
}
|
|
}
|