214 lines
5.3 KiB
Go
214 lines
5.3 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
|
|
// Automatically reconnects WS in the background.
|
|
type HybridTransport struct {
|
|
ws *WSTransport
|
|
http *HTTPTransport
|
|
|
|
mode atomic.Value // "ws" or "http"
|
|
events chan ServerEvent
|
|
|
|
reconnectMu sync.Mutex
|
|
reconnectRunning bool
|
|
reconnectStop chan struct{}
|
|
closed atomic.Bool
|
|
}
|
|
|
|
// NewHybridTransport creates a transport that prefers WS with HTTP fallback.
|
|
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
|
|
h := &HybridTransport{
|
|
ws: ws,
|
|
http: http,
|
|
events: make(chan ServerEvent, 50),
|
|
reconnectStop: make(chan struct{}),
|
|
}
|
|
h.mode.Store("http") // start in HTTP, upgrade to WS on Connect
|
|
return h
|
|
}
|
|
|
|
// Mode returns the name of the currently active transport: "ws" or "http".
func (h *HybridTransport) Mode() string {
	current := h.mode.Load()
	return current.(string)
}
|
|
// Events returns the receive-only view of the unified server event channel.
func (h *HybridTransport) Events() <-chan ServerEvent {
	return h.events
}
|
|
|
|
// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop.
|
|
func (h *HybridTransport) Connect(ctx context.Context) error {
|
|
// Try WebSocket first
|
|
if err := h.ws.Connect(ctx); err != nil {
|
|
log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err)
|
|
h.mode.Store("http")
|
|
h.startReconnectLoop()
|
|
return h.http.Connect(ctx)
|
|
}
|
|
|
|
h.mode.Store("ws")
|
|
log.Println("[transport] Connected via WebSocket")
|
|
|
|
// Forward WS events to unified channel + watch for disconnection
|
|
go h.forwardWSEvents()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close shuts down both transports and stops reconnection.
|
|
func (h *HybridTransport) Close() error {
|
|
h.closed.Store(true)
|
|
select {
|
|
case <-h.reconnectStop:
|
|
default:
|
|
close(h.reconnectStop)
|
|
}
|
|
_ = h.ws.Close()
|
|
return h.http.Close()
|
|
}
|
|
|
|
// Register delegates to the active transport.
|
|
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
|
if h.mode.Load() == "ws" {
|
|
return h.ws.Register(ctx, req)
|
|
}
|
|
return h.http.Register(ctx, req)
|
|
}
|
|
|
|
// SendHeartbeat delegates to the active transport.
|
|
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
|
if h.mode.Load() == "ws" {
|
|
resp, err := h.ws.SendHeartbeat(ctx, req)
|
|
if err != nil {
|
|
// WS write failed — switch to HTTP
|
|
h.switchToHTTP()
|
|
return h.http.SendHeartbeat(ctx, req)
|
|
}
|
|
return resp, nil
|
|
}
|
|
return h.http.SendHeartbeat(ctx, req)
|
|
}
|
|
|
|
// SendProgress delegates to the active transport.
|
|
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
|
if h.mode.Load() == "ws" {
|
|
resp, err := h.ws.SendProgress(ctx, update)
|
|
if err != nil {
|
|
h.switchToHTTP()
|
|
return h.http.SendProgress(ctx, update)
|
|
}
|
|
return resp, nil
|
|
}
|
|
return h.http.SendProgress(ctx, update)
|
|
}
|
|
|
|
// ClaimTasks delegates to the active transport.
|
|
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
|
if h.mode.Load() == "ws" {
|
|
return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
|
|
}
|
|
return h.http.ClaimTasks(ctx, agentID)
|
|
}
|
|
|
|
// Deregister delegates to the active transport.
|
|
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
|
|
if h.mode.Load() == "ws" {
|
|
return h.ws.Deregister(ctx, agentID)
|
|
}
|
|
return h.http.Deregister(ctx, agentID)
|
|
}
|
|
|
|
// ── Internal ─────────────────────────────────────────────────────────────────
|
|
|
|
func (h *HybridTransport) switchToHTTP() {
|
|
if h.mode.Load() == "http" {
|
|
return
|
|
}
|
|
log.Println("[transport] Switching to HTTP fallback")
|
|
h.mode.Store("http")
|
|
_ = h.ws.Close()
|
|
h.startReconnectLoop()
|
|
}
|
|
|
|
func (h *HybridTransport) forwardWSEvents() {
|
|
for {
|
|
select {
|
|
case <-h.reconnectStop:
|
|
return
|
|
case event, ok := <-h.ws.Events():
|
|
if !ok {
|
|
return // channel closed
|
|
}
|
|
if event.Type == "disconnected" {
|
|
h.switchToHTTP()
|
|
select {
|
|
case h.events <- event:
|
|
default:
|
|
}
|
|
return
|
|
}
|
|
select {
|
|
case h.events <- event:
|
|
default:
|
|
log.Printf("[transport] events channel full, dropping %s event", event.Type)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *HybridTransport) startReconnectLoop() {
|
|
h.reconnectMu.Lock()
|
|
defer h.reconnectMu.Unlock()
|
|
if h.reconnectRunning {
|
|
return
|
|
}
|
|
h.reconnectRunning = true
|
|
go h.reconnectLoop()
|
|
}
|
|
|
|
func (h *HybridTransport) reconnectLoop() {
|
|
backoff := 5 * time.Second
|
|
maxBackoff := 60 * time.Second
|
|
|
|
for {
|
|
select {
|
|
case <-h.reconnectStop:
|
|
return
|
|
case <-time.After(backoff):
|
|
}
|
|
|
|
if h.closed.Load() {
|
|
return
|
|
}
|
|
|
|
// Already on WS? (someone else reconnected)
|
|
if h.mode.Load() == "ws" {
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
err := h.ws.Connect(ctx)
|
|
cancel()
|
|
|
|
if err != nil {
|
|
log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
|
|
backoff = min(backoff*2, maxBackoff)
|
|
continue
|
|
}
|
|
|
|
// WS reconnected — switch back
|
|
log.Println("[transport] WebSocket reconnected")
|
|
h.mode.Store("ws")
|
|
|
|
// Reset reconnect flag so loop can start again if WS drops
|
|
h.reconnectMu.Lock()
|
|
h.reconnectRunning = false
|
|
h.reconnectMu.Unlock()
|
|
|
|
// Forward events from new WS connection
|
|
go h.forwardWSEvents()
|
|
return
|
|
}
|
|
}
|