feat(stream): pion-based WebRTC byte streamer for browser playback

Replaces the broken anacrolix WebTorrent path with a custom WebRTC peer
that the browser drives directly. Architecture matches plan/clever-
weaving-dove.md (Fase 2 + 3 + 6 of the streaming pivot).

- engine/wire: shared 12-byte binary frame format (Hello / RangeReq /
  RangeData / RangeEnd / Cancel / Ping / Pong / SeekHint). Roundtrip +
  oversized-frame rejection tests.
- agent/signal_client: SSE consumer + POST sender for SDP/ICE relay
  through /api/internal/stream/signal/<id>; auto-reconnects.
- engine/webrtc_stream: pion v4 PeerConnection + DataChannel pump.
  Reads file via (*os.File).ReadAt, chunks RangeData at 16 KiB, honours app-
  level backpressure with SetBufferedAmountLowThreshold.
- cmd/daemon dispatcher learns mode webrtc_stream + new
  webrtcSessionRegistry tracks per-session cancel funcs for clean
  shutdown.
- engine/probe + hwaccel + transcoder: foundation for Fase 2.5
  (codec detection, NVENC/QSV/VAAPI/VideoToolbox autodetection,
  ffmpeg pipe wrapper to fragmented MP4). Integration into
  webrtc_stream still pending.
- pion/webrtc/v4 promoted from indirect to direct dep.

End-to-end against unarr-dev confirms a 122 MB 1080p H.264 / AAC MP4
plays in Chrome with the new pipeline.
This commit is contained in:
Deivid Soto 2026-05-06 23:12:38 +02:00
parent 4c52d9b039
commit 4314c06c5c
17 changed files with 2308 additions and 1 deletions

2
go.mod
View file

@ -13,6 +13,7 @@ require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/huin/goupnp v1.3.0 github.com/huin/goupnp v1.3.0
github.com/olekukonko/tablewriter v1.1.4 github.com/olekukonko/tablewriter v1.1.4
github.com/pion/webrtc/v4 v4.2.11
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
github.com/torrentclaw/go-client v0.2.0 github.com/torrentclaw/go-client v0.2.0
golang.org/x/term v0.41.0 golang.org/x/term v0.41.0
@ -105,7 +106,6 @@ require (
github.com/pion/stun/v3 v3.1.1 // indirect github.com/pion/stun/v3 v3.1.1 // indirect
github.com/pion/transport/v4 v4.0.1 // indirect github.com/pion/transport/v4 v4.0.1 // indirect
github.com/pion/turn/v4 v4.1.4 // indirect github.com/pion/turn/v4 v4.1.4 // indirect
github.com/pion/webrtc/v4 v4.2.11 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/protolambda/ctxlock v0.1.0 // indirect github.com/protolambda/ctxlock v0.1.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect

View file

@ -35,6 +35,7 @@ type Daemon struct {
// Callbacks — set by cmd/daemon.go before calling Run. // Callbacks — set by cmd/daemon.go before calling Run.
OnTasksClaimed func(tasks []Task) OnTasksClaimed func(tasks []Task)
OnStreamRequested func(req StreamRequest) OnStreamRequested func(req StreamRequest)
OnWebRTCSession func(sess WebRTCSession)
OnControlAction func(action, taskID string, deleteFiles bool) OnControlAction func(action, taskID string, deleteFiles bool)
GetActiveCount func() int // returns number of active downloads (wired from manager) GetActiveCount func() int // returns number of active downloads (wired from manager)
@ -169,6 +170,11 @@ func (d *Daemon) Run(ctx context.Context) error {
d.OnStreamRequested(req) d.OnStreamRequested(req)
} }
} }
d.sync.OnWebRTCSession = func(sess WebRTCSession) {
if d.OnWebRTCSession != nil {
d.OnWebRTCSession(sess)
}
}
d.sync.OnUpgrade = func(version string) { d.sync.OnUpgrade = func(version string) {
if version != d.lastNotifiedVersion { if version != d.lastNotifiedVersion {
d.lastNotifiedVersion = version d.lastNotifiedVersion = version

View file

@ -0,0 +1,233 @@
package agent
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// SignalRole names the side of a signalling exchange that produced a
// message. A message is delivered to the opposite role.
type SignalRole string

// The two participants of a signalling session.
const (
	SignalRoleBrowser = SignalRole("browser")
	SignalRoleAgent   = SignalRole("agent")
)
// SignalMessageType is the kind of a signalling message. The values mirror
// the server-side z.enum on the /api/internal/stream/signal/[sessionId]
// route.
type SignalMessageType string

// Message kinds exchanged over the signalling channel.
const (
	SignalMsgOffer        = SignalMessageType("offer")
	SignalMsgAnswer       = SignalMessageType("answer")
	SignalMsgCandidate    = SignalMessageType("candidate")
	SignalMsgCandidateEnd = SignalMessageType("candidate-end")
	SignalMsgBye          = SignalMessageType("bye")
)
// SignalMessage is one signalling envelope, shaped like the bus envelope used
// on the web side.
type SignalMessage struct {
	// From is the role that produced the message.
	From SignalRole `json:"from"`
	// Type is the kind of message (offer, answer, candidate, ...).
	Type SignalMessageType `json:"type"`
	// Payload is the message body, carried as an opaque string.
	Payload string `json:"payload"`
	// TS is the envelope's "ts" value (presumably a timestamp; the unit is
	// not visible from this file).
	TS int64 `json:"ts"`
}
// PostSignal sends a signalling message produced by this agent to
// /api/internal/stream/signal/<sessionID>.
//
// Only msg.Type and msg.Payload are transmitted; the "from" field is always
// set to SignalRoleAgent, and msg.From / msg.TS are not sent. The error from
// the underlying doPost call is returned unchanged.
func (c *Client) PostSignal(ctx context.Context, sessionID string, msg SignalMessage) error {
	// The response body is decoded into ack but its content is not inspected.
	var ack struct {
		OK bool `json:"ok"`
	}
	payload := map[string]any{
		"from":    string(SignalRoleAgent),
		"type":    string(msg.Type),
		"payload": msg.Payload,
	}
	endpoint := "/api/internal/stream/signal/" + sessionID
	return c.doPost(ctx, endpoint, payload, &ack)
}
// SignalEventStream is a handle on one open SSE connection. Consume messages
// from Events() until that channel closes, and always call Close() so the
// response body is released.
type SignalEventStream struct {
	// resp is the HTTP response whose body carries the event stream.
	resp *http.Response
	// cancel aborts the request context created for this stream.
	cancel context.CancelFunc

	// events delivers decoded messages; closed by the reader goroutine.
	events chan SignalMessage
	// errs holds the reader's terminating error, if there was one.
	errs chan error
	// done is closed once the reader goroutine has returned.
	done chan struct{}
}
// Events returns the receive-only channel of decoded messages. The channel is
// closed when the reader goroutine stops (connection ended or cancelled);
// after that the caller should Close() the stream and open a new one if it
// still wants to listen.
func (s *SignalEventStream) Events() <-chan SignalMessage {
	return s.events
}
// Err reports the error that terminated the reader, or nil when none is
// pending. It never blocks. The error is taken off an internal channel, so a
// given error is returned at most once; call it after Events() has closed.
func (s *SignalEventStream) Err() error {
	var pending error
	select {
	case pending = <-s.errs:
	default:
	}
	return pending
}
// Close cancels the underlying HTTP request, closes the response body and
// waits for the reader goroutine to exit. Safe to call more than once.
//
// Any messages still queued on Events() are discarded while waiting. Without
// that drain, a reader parked on a send to a full events channel (the
// consumer stopped receiving) would never observe the cancellation and Close
// would block forever. It always returns nil.
func (s *SignalEventStream) Close() error {
	if s.cancel != nil {
		s.cancel()
	}
	if s.resp != nil {
		// Best effort: the stream is being torn down, so a close error
		// carries no actionable information.
		_ = s.resp.Body.Close()
	}
	for {
		select {
		case <-s.done:
			return nil
		case _, open := <-s.events:
			if !open {
				// The reader closes events just before done; wait for the
				// final signal so the goroutine is fully gone on return.
				<-s.done
				return nil
			}
		}
	}
}
// OpenSignalStream performs a GET on the session's signal events endpoint and,
// on HTTP 200, starts a goroutine that decodes the SSE body into messages.
// The caller MUST cancel ctx or call Close() on the returned stream to free
// resources.
//
// The server is expected to end each response after roughly 25 s; that shows
// up as the events channel closing, after which the caller should reopen for
// as long as the session lasts. A non-200 status is returned as an error that
// includes up to 1 KiB of the response body.
func (c *Client) OpenSignalStream(ctx context.Context, sessionID string) (*SignalEventStream, error) {
	streamCtx, cancel := context.WithCancel(ctx)

	endpoint := fmt.Sprintf("%s/api/internal/stream/signal/%s/events", c.baseURL, sessionID)
	req, err := http.NewRequestWithContext(streamCtx, http.MethodGet, endpoint, nil)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("open signal stream: %w", err)
	}
	headers := [][2]string{
		{"Accept", "text/event-stream"},
		{"Authorization", "Bearer " + c.apiKey},
		{"User-Agent", c.userAgent},
		{"Cache-Control", "no-cache"},
	}
	for _, h := range headers {
		req.Header.Set(h[0], h[1])
	}

	// A dedicated client without a Timeout: SSE responses are long-lived and
	// their lifetime is governed by streamCtx instead.
	resp, err := (&http.Client{}).Do(req)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("open signal stream: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		resp.Body.Close()
		cancel()
		return nil, fmt.Errorf("open signal stream: HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(snippet)))
	}

	stream := &SignalEventStream{
		resp:   resp,
		cancel: cancel,
		events: make(chan SignalMessage, 8),
		errs:   make(chan error, 1),
		done:   make(chan struct{}),
	}
	go stream.read()
	return stream, nil
}
// read is the reader goroutine: it parses the SSE body line by line and
// pushes decoded messages onto s.events. On exit it closes s.events and then
// s.done. A read error other than io.EOF is stored (non-blocking) in s.errs.
//
// Only events with no name or the name "signal" are dispatched; payloads that
// fail to decode as JSON are dropped silently. Comment lines and any field
// other than "event:" / "data:" are ignored.
func (s *SignalEventStream) read() {
	defer close(s.done)
	defer close(s.events)

	var (
		pending bytes.Buffer // data lines of the event being assembled
		event   string       // value of the current "event:" field
	)
	br := bufio.NewReaderSize(s.resp.Body, 16*1024)

	// dispatch runs at a blank line: it delivers the assembled event (when
	// there is data and the name is acceptable) and resets the state.
	dispatch := func() {
		if pending.Len() > 0 {
			if event == "" || event == "signal" {
				var msg SignalMessage
				if json.Unmarshal(pending.Bytes(), &msg) == nil {
					s.events <- msg
				}
			}
			pending.Reset()
		}
		event = ""
	}

	for {
		raw, err := br.ReadString('\n')
		if err != nil {
			if err != io.EOF {
				// Keep only the first error; never block on a full slot.
				select {
				case s.errs <- err:
				default:
				}
			}
			return
		}

		switch line := strings.TrimRight(raw, "\r\n"); {
		case line == "":
			dispatch()
		case strings.HasPrefix(line, ":"):
			// SSE comment (heartbeat); ignore.
		case strings.HasPrefix(line, "event:"):
			event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
		case strings.HasPrefix(line, "data:"):
			chunk := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			// Multiple data lines of one event are joined with '\n'.
			if pending.Len() > 0 {
				pending.WriteByte('\n')
			}
			pending.WriteString(chunk)
		}
		// id:, retry:, anything else — ignored for now.
	}
}
// SignalLoop keeps an SSE consumer running for sessionID until ctx is
// cancelled, invoking onMessage for every message received. It always returns
// ctx.Err().
//
// Reconnect policy: when opening the stream fails, or a stream ends with a
// read error, the loop waits a fixed 1 s before retrying. When a stream ends
// cleanly (the server is expected to close each response after ~25 s) it
// reconnects immediately.
func (c *Client) SignalLoop(ctx context.Context, sessionID string, onMessage func(SignalMessage)) error {
	// backoff sleeps for one second; it reports false if ctx ended first.
	backoff := func() bool {
		select {
		case <-time.After(time.Second):
			return true
		case <-ctx.Done():
			return false
		}
	}

	for ctx.Err() == nil {
		stream, err := c.OpenSignalStream(ctx, sessionID)
		if err != nil {
			if !backoff() {
				return ctx.Err()
			}
			continue
		}

		for msg := range stream.Events() {
			onMessage(msg)
		}
		// Read the terminating error before Close tears the stream down.
		streamErr := stream.Err()
		stream.Close()

		if ctx.Err() != nil {
			return ctx.Err()
		}
		// Hard error → small backoff so we don't hammer the server.
		if streamErr != nil && !backoff() {
			return ctx.Err()
		}
	}
	return ctx.Err()
}

View file

@ -0,0 +1,153 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
)
// fakeSSEServer starts an httptest server that requires the header
// "Authorization: Bearer test-key" and then streams, in order: a
// "retry: 1500" field, one "signal" event per entry of msgs, and a heartbeat
// comment. When holdOpenAfter is true the handler then blocks until the
// client disconnects; otherwise it returns and the response ends.
//
// Failures inside the handler are reported with t.Error / t.Errorf rather
// than t.Fatal: the handler runs on a server goroutine, and FailNow must only
// be called from the goroutine running the test.
func fakeSSEServer(t *testing.T, msgs []SignalMessage, holdOpenAfter bool) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") != "Bearer test-key" {
			http.Error(w, "auth", http.StatusUnauthorized)
			return
		}
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Error("server: ResponseWriter is not a Flusher")
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")

		fmt.Fprint(w, "retry: 1500\n\n")
		flusher.Flush()

		for _, m := range msgs {
			data, err := json.Marshal(m)
			if err != nil {
				t.Errorf("server: marshal message: %v", err)
				return
			}
			fmt.Fprintf(w, "id: %d\nevent: signal\ndata: %s\n\n", m.TS, data)
			flusher.Flush()
		}

		// Send a heartbeat comment to verify it's ignored.
		fmt.Fprint(w, ": heartbeat\n\n")
		flusher.Flush()

		if holdOpenAfter {
			// Hold the connection until the client disconnects so the test can
			// exercise stream.Close().
			<-r.Context().Done()
		}
	}))
}
// TestSignalStreamReadsMessages checks that events streamed by the fake
// server arrive on Events() in order with From, Type and Payload intact.
func TestSignalStreamReadsMessages(t *testing.T) {
	want := []SignalMessage{
		{From: SignalRoleBrowser, Type: SignalMsgOffer, Payload: "{sdp:1}", TS: 1},
		{From: SignalRoleBrowser, Type: SignalMsgCandidate, Payload: "{cand:1}", TS: 2},
	}
	srv := fakeSSEServer(t, want, false)
	defer srv.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	client := NewClient(srv.URL, "test-key", "test-ua")
	stream, err := client.OpenSignalStream(ctx, "session-1")
	if err != nil {
		t.Fatalf("open: %v", err)
	}
	defer stream.Close()

	// Stop as soon as the expected number of messages has arrived.
	var got []SignalMessage
	for msg := range stream.Events() {
		got = append(got, msg)
		if len(got) == len(want) {
			break
		}
	}

	if len(got) != len(want) {
		t.Fatalf("got %d messages, want %d", len(got), len(want))
	}
	for i := range got {
		same := got[i].From == want[i].From &&
			got[i].Type == want[i].Type &&
			got[i].Payload == want[i].Payload
		if !same {
			t.Errorf("[%d] mismatch: %+v want %+v", i, got[i], want[i])
		}
	}
}
// TestSignalStreamPropagatesAuthError checks that a rejected API key makes
// OpenSignalStream return an error instead of a stream.
func TestSignalStreamPropagatesAuthError(t *testing.T) {
	srv := fakeSSEServer(t, nil, false)
	defer srv.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	client := NewClient(srv.URL, "wrong-key", "test-ua")
	if _, err := client.OpenSignalStream(ctx, "session-1"); err == nil {
		t.Fatal("expected auth error, got nil")
	}
}
// TestSignalStreamCloseCancelsRead checks that Close() on a connection the
// server keeps open ends the reader, which shows up as Events() closing.
func TestSignalStreamCloseCancelsRead(t *testing.T) {
	srv := fakeSSEServer(t, nil, true)
	defer srv.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	client := NewClient(srv.URL, "test-key", "test-ua")
	stream, err := client.OpenSignalStream(ctx, "session-1")
	if err != nil {
		t.Fatalf("open: %v", err)
	}

	// Close from another goroutine shortly after we start draining.
	var closer sync.WaitGroup
	closer.Add(1)
	go func() {
		defer closer.Done()
		time.Sleep(50 * time.Millisecond)
		stream.Close()
	}()

	// Blocks until the events channel is closed.
	for range stream.Events() {
	}
	closer.Wait()
}
// TestPostSignalSendsCorrectBody checks the JSON body PostSignal sends: the
// "from" field is forced to the agent role, and "type" / "payload" are taken
// from the message.
func TestPostSignalSendsCorrectBody(t *testing.T) {
	const wantPayload = "{sdp:answer}"

	var bodySeen map[string]any
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") != "Bearer test-key" {
			http.Error(w, "auth", http.StatusUnauthorized)
			return
		}
		// Surface a malformed body directly instead of letting it show up
		// later as confusing nil-field mismatches. t.Errorf is safe to call
		// from the handler goroutine.
		if err := json.NewDecoder(r.Body).Decode(&bodySeen); err != nil {
			t.Errorf("server: decode request body: %v", err)
		}
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{"ok":true}`)
	}))
	defer srv.Close()

	c := NewClient(srv.URL, "test-key", "test-ua")
	err := c.PostSignal(context.Background(), "sess-x", SignalMessage{
		Type:    SignalMsgAnswer,
		Payload: wantPayload,
	})
	if err != nil {
		t.Fatalf("post: %v", err)
	}

	if bodySeen["from"] != string(SignalRoleAgent) {
		t.Errorf("expected from=agent, got %v", bodySeen["from"])
	}
	if bodySeen["type"] != string(SignalMsgAnswer) {
		t.Errorf("expected type=answer, got %v", bodySeen["type"])
	}
	if bodySeen["payload"] != wantPayload {
		t.Errorf("expected payload=%q, got %v", wantPayload, bodySeen["payload"])
	}
}

View file

@ -29,6 +29,7 @@ type SyncClient struct {
OnNewTasks func(tasks []Task) OnNewTasks func(tasks []Task)
OnControl func(action, taskID string, deleteFiles bool) OnControl func(action, taskID string, deleteFiles bool)
OnStreamRequest func(req StreamRequest) OnStreamRequest func(req StreamRequest)
OnWebRTCSession func(sess WebRTCSession)
OnUpgrade func(version string) OnUpgrade func(version string)
OnScan func() OnScan func()
OnWatchingChange func(watching bool) OnWatchingChange func(watching bool)
@ -191,6 +192,13 @@ func (sc *SyncClient) processResponse(resp *SyncResponse) {
} }
} }
// WebRTC streaming sessions
for _, ws := range resp.WebRTCSessions {
if sc.OnWebRTCSession != nil {
sc.OnWebRTCSession(ws)
}
}
// Upgrade // Upgrade
if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil { if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil {
sc.OnUpgrade(resp.Upgrade.Version) sc.OnUpgrade(resp.Upgrade.Version)

View file

@ -351,11 +351,25 @@ type LibraryDeleteRequest struct {
FilePath string `json:"filePath"` FilePath string `json:"filePath"`
} }
// WebRTCSession asks the CLI to open a custom WebRTC DataChannel byte-stream
// towards a browser player. The CLI is expected to POST an SDP answer to
// /api/internal/stream/signal/<sessionId> and to serve the bytes of FilePath
// (or, when only InfoHash is set, of a download_task on disk).
type WebRTCSession struct {
	// SessionID identifies the session; it is also the signalling path key.
	SessionID string `json:"sessionId"`

	// Source selectors; all optional on the wire.
	FilePath string `json:"filePath,omitempty"`
	InfoHash string `json:"infoHash,omitempty"`
	TaskID   string `json:"taskId,omitempty"`

	// File metadata supplied with the request.
	FileName string `json:"fileName,omitempty"`
	FileSize int64  `json:"fileSize,omitempty"`
}
// SyncResponse is returned by the server with all pending actions for the CLI. // SyncResponse is returned by the server with all pending actions for the CLI.
type SyncResponse struct { type SyncResponse struct {
NewTasks []Task `json:"newTasks,omitempty"` NewTasks []Task `json:"newTasks,omitempty"`
Controls []ControlAction `json:"controls,omitempty"` Controls []ControlAction `json:"controls,omitempty"`
StreamRequests []StreamRequest `json:"streamRequests,omitempty"` StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
WebRTCSessions []WebRTCSession `json:"webrtcSessions,omitempty"`
Watching bool `json:"watching"` Watching bool `json:"watching"`
Upgrade *UpgradeSignal `json:"upgrade,omitempty"` Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
Scan bool `json:"scan,omitempty"` Scan bool `json:"scan,omitempty"`

View file

@ -410,6 +410,65 @@ func runDaemonStart() error {
}() }()
} }
// Wire: sync receives custom WebRTC streaming session requests.
// Each session is a one-shot browser↔daemon DataChannel. Validate the
// FilePath against allowed dirs to prevent path traversal abuse from a
// compromised server, then spawn the pion peer in its own goroutine.
//
// A request is dropped (with a log line, except for duplicates) when the
// session is already registered, WebRTC is disabled in the config, the file
// path is empty or outside the allowed directories, or the path is a
// directory that contains no video file.
d.OnWebRTCSession = func(sess agent.WebRTCSession) {
// Ignore repeated deliveries of a session that is still registered.
if webrtcRegistry.has(sess.SessionID) {
return // already running
}
if !cfg.Download.WebRTC.Enabled {
log.Printf("webrtc session %s rejected: webrtc disabled in config", agent.ShortID(sess.SessionID))
return
}
filePath := sess.FilePath
if filePath == "" {
log.Printf("webrtc session %s rejected: empty file path", agent.ShortID(sess.SessionID))
return
}
// Normalise the path before it is checked against the allowed directories.
filePath = filepath.Clean(filePath)
if !isAllowedStreamPath(filePath, cfg.Download.Dir, cfg.Library.ScanPath,
cfg.Organize.MoviesDir, cfg.Organize.TVShowsDir) {
log.Printf("webrtc session %s rejected: path outside allowed dirs: %s",
agent.ShortID(sess.SessionID), filePath)
return
}
// Resolve directory → first video file (matches StreamRequest behavior).
// A failing os.Stat is not handled here; the path is passed on unchanged.
if info, err := os.Stat(filePath); err == nil && info.IsDir() {
found := engine.FindVideoFile(filePath)
if found == "" {
log.Printf("webrtc session %s rejected: no video file in dir %s",
agent.ShortID(sess.SessionID), filePath)
return
}
filePath = found
}
// The session context derives from the daemon ctx; its cancel func is
// registered before the goroutine starts so shutdown can reach it.
sessCtx, sessCancel := context.WithCancel(ctx) //nolint:gosec // G118 cancel stored in registry
webrtcRegistry.add(sess.SessionID, sessCancel)
go func() {
// On exit: unregister the session and release its context.
defer func() {
webrtcRegistry.remove(sess.SessionID)
sessCancel()
}()
runCfg := engine.WebRTCStreamConfig{
SessionID: sess.SessionID,
FilePath: filePath,
FileName: sess.FileName,
FileSize: sess.FileSize,
ICEServers: engine.BuildICEServers(cfg.Download.WebRTC),
Signal: agentClient,
Logger: stdLogger{},
}
log.Printf("[wrtc %s] starting session: %s", agent.ShortID(sess.SessionID), filepath.Base(filePath))
if err := engine.RunWebRTCStream(sessCtx, runCfg); err != nil {
// Errors after the session context was cancelled are not logged.
if sessCtx.Err() == nil {
log.Printf("[wrtc %s] ended: %v", agent.ShortID(sess.SessionID), err)
}
}
}()
}
// Periodic DHT node persistence (every 5 min) // Periodic DHT node persistence (every 5 min)
go func() { go func() {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@ -457,6 +516,7 @@ func runDaemonStart() error {
case sig := <-sigCh: case sig := <-sigCh:
fmt.Printf("\n Received %s, shutting down...\n", sig) fmt.Printf("\n Received %s, shutting down...\n", sig)
cancelStreamContexts() cancelStreamContexts()
cancelAllWebRTCSessions()
streamSrv.Shutdown(context.Background()) streamSrv.Shutdown(context.Background())
cancel() cancel()
@ -471,6 +531,7 @@ func runDaemonStart() error {
case err := <-errCh: case err := <-errCh:
cancelStreamContexts() cancelStreamContexts()
cancelAllWebRTCSessions()
streamSrv.Shutdown(context.Background()) streamSrv.Shutdown(context.Background())
cancel() cancel()
return err return err

View file

@ -0,0 +1,62 @@
package cmd
import (
"context"
"log"
"sync"
)
// webrtcRegistry is the process-wide table of active custom WebRTC streams
// (engine.RunWebRTCStream goroutines), keyed by session ID. It lets the daemon
// skip sessions that are already running when a sync response repeats them,
// and lets shutdown cancel everything still in flight.
var webrtcRegistry = &webrtcSessionRegistry{
	cancels: map[string]context.CancelFunc{},
}

// webrtcSessionRegistry maps session IDs to the cancel func of the session's
// context. All access goes through mu.
type webrtcSessionRegistry struct {
	mu      sync.Mutex
	cancels map[string]context.CancelFunc
}

// has reports whether sessionID is currently registered.
func (r *webrtcSessionRegistry) has(sessionID string) bool {
	r.mu.Lock()
	_, found := r.cancels[sessionID]
	r.mu.Unlock()
	return found
}

// add registers cancel under sessionID, replacing any existing entry.
func (r *webrtcSessionRegistry) add(sessionID string, cancel context.CancelFunc) {
	r.mu.Lock()
	r.cancels[sessionID] = cancel
	r.mu.Unlock()
}

// remove drops sessionID from the registry without calling its cancel func.
func (r *webrtcSessionRegistry) remove(sessionID string) {
	r.mu.Lock()
	delete(r.cancels, sessionID)
	r.mu.Unlock()
}

// cancelAllWebRTCSessions empties the registry and then calls every cancel
// func that was registered. Called on daemon shutdown so pion peers and SSE
// consumers exit cleanly. The cancel funcs run after the lock is released.
func cancelAllWebRTCSessions() {
	reg := webrtcRegistry

	reg.mu.Lock()
	drained := reg.cancels
	reg.cancels = make(map[string]context.CancelFunc)
	reg.mu.Unlock()

	for _, cancel := range drained {
		cancel()
	}
}
// stdLogger adapts the standard library logger to the leveled logging methods
// engine.RunWebRTCStream is given, so no logging dependency is needed.
type stdLogger struct{}

// Infof logs the formatted message as-is.
func (stdLogger) Infof(format string, args ...any) {
	log.Printf(format, args...)
}

// Warnf logs the formatted message prefixed with "WARN: ".
func (stdLogger) Warnf(format string, args ...any) {
	const prefix = "WARN: "
	log.Printf(prefix+format, args...)
}

// Errorf logs the formatted message prefixed with "ERROR: ".
func (stdLogger) Errorf(format string, args ...any) {
	const prefix = "ERROR: "
	log.Printf(prefix+format, args...)
}

130
internal/engine/hwaccel.go Normal file
View file

@ -0,0 +1,130 @@
package engine
import (
"context"
"os"
"os/exec"
"runtime"
"strings"
"sync"
)
// HWAccel names a family of hardware-accelerated ffmpeg encoders.
type HWAccel string

// Supported encoder families. HWAccelNone means software encoding only.
const (
	HWAccelNone = HWAccel("none")
	// NVIDIA — h264_nvenc / hevc_nvenc
	HWAccelNVENC = HWAccel("nvenc")
	// Intel Quick Sync — h264_qsv / hevc_qsv
	HWAccelQSV = HWAccel("qsv")
	// Linux open-source — h264_vaapi / hevc_vaapi
	HWAccelVAAPI = HWAccel("vaapi")
	// macOS — h264_videotoolbox
	HWAccelVideoToolbox = HWAccel("videotoolbox")
)
// Cached result of the hardware-encoder probe. hwMu guards hwProbed and
// hwCache; hwProbed records whether hwCache holds a trustworthy result.
var (
	hwMu     sync.Mutex
	hwProbed bool
	hwCache  HWAccel
)

// DetectHWAccel returns the most capable hardware encoder available on this
// host, or HWAccelNone if software-only. The result is cached after the first
// successful probe — adding / removing a GPU at runtime is rare and the cost
// of probing isn't free. Once a result is cached, ctx and ffmpegPath of later
// calls are ignored.
//
// A probe that ran under a context which is already cancelled or expired when
// it returns is NOT cached: it may have been cut short, and remembering it
// would disable hardware encoding for the lifetime of the process. Such a
// call returns whatever the probe produced and the next call probes again.
// Concurrent callers are serialised, so at most one probe runs at a time.
func DetectHWAccel(ctx context.Context, ffmpegPath string) HWAccel {
	hwMu.Lock()
	defer hwMu.Unlock()

	if hwProbed {
		return hwCache
	}
	detected := detectHWAccelFresh(ctx, ffmpegPath)
	if ctx != nil && ctx.Err() != nil {
		return detected
	}
	hwCache, hwProbed = detected, true
	return detected
}

// ResetHWAccelCache clears the cached probe result — only used in tests.
func ResetHWAccelCache() {
	hwMu.Lock()
	defer hwMu.Unlock()
	hwProbed = false
	hwCache = ""
}
// detectHWAccelFresh probes for a usable hardware encoder without consulting
// any cache. It returns HWAccelNone when ffmpegPath is empty or the encoder
// list cannot be obtained. Otherwise the first match wins, checked in this
// order: VideoToolbox (macOS only), NVENC, QSV, VA-API.
func detectHWAccelFresh(ctx context.Context, ffmpegPath string) HWAccel {
	if ffmpegPath == "" {
		return HWAccelNone
	}
	encoders := listFFmpegEncoders(ctx, ffmpegPath)
	if encoders == "" {
		return HWAccelNone
	}

	// supports reports whether the ffmpeg build lists the given encoder.
	supports := func(name string) bool { return strings.Contains(encoders, name) }
	// Render node required by both QSV and VA-API.
	const renderNode = "/dev/dri/renderD128"

	switch {
	case runtime.GOOS == "darwin" && supports("h264_videotoolbox"):
		// macOS — VideoToolbox is always available on Apple Silicon + recent Intel.
		return HWAccelVideoToolbox
	case supports("h264_nvenc") && (fileExists("/dev/nvidia0") || hasNvidiaDriver()):
		// NVIDIA — encoder presence plus a device file or nvidia-smi on PATH,
		// instead of running nvidia-smi, to keep startup quick.
		return HWAccelNVENC
	case supports("h264_qsv") && fileExists(renderNode):
		// Intel Quick Sync — told apart from VA-API by the QSV-specific encoder.
		return HWAccelQSV
	case supports("h264_vaapi") && fileExists(renderNode):
		// Linux generic VA-API — works on Intel + AMD with mesa drivers.
		return HWAccelVAAPI
	default:
		return HWAccelNone
	}
}
// listFFmpegEncoders runs `<ffmpegPath> -hide_banner -encoders` and returns
// its combined stdout+stderr as a string. Any failure to run the command
// yields the empty string.
func listFFmpegEncoders(ctx context.Context, ffmpegPath string) string {
	output, err := exec.CommandContext(ctx, ffmpegPath, "-hide_banner", "-encoders").CombinedOutput()
	if err == nil {
		return string(output)
	}
	return ""
}
// fileExists reports whether os.Stat succeeds for path. Any Stat error
// (missing file, permission problem, ...) counts as "does not exist".
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
// hasNvidiaDriver reports whether nvidia-smi can be found on PATH. This is a
// cheap proxy: a host with nvidia-smi installed presumably also has a working
// driver and runtime libraries.
func hasNvidiaDriver() bool {
	if _, err := exec.LookPath("nvidia-smi"); err != nil {
		return false
	}
	return true
}
// FFmpegVideoCodec returns the encoder name to pass to ffmpeg's `-c:v` for
// this accelerator. target is matched case-insensitively: "hevc" selects the
// HEVC encoder, anything else the H.264 one. For HWAccelNone (or any
// unrecognised value) it returns the software encoder "libx264" regardless of
// target.
func (h HWAccel) FFmpegVideoCodec(target string) string {
	var h264Name, hevcName string
	switch h {
	case HWAccelNVENC:
		h264Name, hevcName = "h264_nvenc", "hevc_nvenc"
	case HWAccelQSV:
		h264Name, hevcName = "h264_qsv", "hevc_qsv"
	case HWAccelVAAPI:
		h264Name, hevcName = "h264_vaapi", "hevc_vaapi"
	case HWAccelVideoToolbox:
		h264Name, hevcName = "h264_videotoolbox", "hevc_videotoolbox"
	default:
		// Software fallback. libx264 ships with every ffmpeg build.
		return "libx264"
	}
	if strings.ToLower(target) == "hevc" {
		return hevcName
	}
	return h264Name
}

View file

@ -0,0 +1,34 @@
package engine
import "testing"
// TestHWAccelFFmpegVideoCodec checks the encoder name chosen for every
// accelerator / target pair, including the libx264 software fallback.
func TestHWAccelFFmpegVideoCodec(t *testing.T) {
	type row struct {
		accel  HWAccel
		target string
		expect string
	}
	table := []row{
		{accel: HWAccelNone, target: "h264", expect: "libx264"},
		{accel: HWAccelNone, target: "hevc", expect: "libx264"},
		{accel: HWAccelNVENC, target: "h264", expect: "h264_nvenc"},
		{accel: HWAccelNVENC, target: "hevc", expect: "hevc_nvenc"},
		{accel: HWAccelQSV, target: "h264", expect: "h264_qsv"},
		{accel: HWAccelQSV, target: "hevc", expect: "hevc_qsv"},
		{accel: HWAccelVAAPI, target: "h264", expect: "h264_vaapi"},
		{accel: HWAccelVAAPI, target: "hevc", expect: "hevc_vaapi"},
		{accel: HWAccelVideoToolbox, target: "h264", expect: "h264_videotoolbox"},
		{accel: HWAccelVideoToolbox, target: "hevc", expect: "hevc_videotoolbox"},
	}
	for _, r := range table {
		got := r.accel.FFmpegVideoCodec(r.target)
		if got == r.expect {
			continue
		}
		t.Errorf("%s.FFmpegVideoCodec(%q) = %q want %q", r.accel, r.target, got, r.expect)
	}
}
// TestDetectHWAccelEmptyPathReturnsNone checks that probing with an empty
// ffmpeg path reports software-only encoding.
func TestDetectHWAccelEmptyPathReturnsNone(t *testing.T) {
	ResetHWAccelCache()
	got := detectHWAccelFresh(t.Context(), "")
	if got != HWAccelNone {
		t.Errorf("got %s, want %s", got, HWAccelNone)
	}
}

116
internal/engine/probe.go Normal file
View file

@ -0,0 +1,116 @@
package engine
import (
"context"
"fmt"
"strings"
"github.com/torrentclaw/unarr/internal/library/mediainfo"
)
// StreamProbe is the condensed codec / container description of a file as far
// as the WebRTC streaming pipeline cares. The transcoder uses it to decide
// whether bytes can be streamed as-is, remuxed to fragmented MP4, or must be
// fully transcoded.
type StreamProbe struct {
	// Codec names, lowercased.
	VideoCodec string // e.g. "h264", "hevc", "av1", "vp9", "mpeg4"
	AudioCodec string // e.g. "aac", "ac3", "dts", "eac3", "opus"

	// Dimensions of the primary video stream.
	Width  int
	Height int

	// BitDepth is 8, 10 or 12; 0 if unknown.
	BitDepth int
	// HDR is the HDR signalling string ("HDR10" / "DV" / "HLG" / etc), or ""
	// for SDR.
	HDR string

	// DurationSec is the file length, used to sanity-check seek targets.
	DurationSec float64
	// Container is the lowercased file extension (".mp4", ".mkv", ".avi").
	Container string
}
// TranscodeAction is the way the streaming pipeline feeds a file to the
// browser <video> element. The decision matrix is documented in the project
// plan (Fase 2.5 — Transcoding on-the-fly).
type TranscodeAction string

const (
	// ActionPassthrough — the file is browser-playable as-is. Raw bytes are
	// streamed via ReadAt; no ffmpeg involved.
	ActionPassthrough = TranscodeAction("passthrough")

	// ActionRemux — codecs are browser-compatible but the container or moov
	// placement is not. ffmpeg runs with `-c copy -movflags frag_keyframe`.
	ActionRemux = TranscodeAction("remux")

	// ActionRemuxAudio — video is fine but audio needs a re-encode (AC3/DTS
	// → AAC): `-c:v copy -c:a aac`.
	ActionRemuxAudio = TranscodeAction("remux-audio")

	// ActionTranscodeVideo — full re-encode. Used for HEVC/AV1 and any
	// 10-bit content if the browser refuses the codec.
	ActionTranscodeVideo = TranscodeAction("transcode-video")
)
// ProbeFile extracts media information for filePath through
// mediainfo.ExtractMediaInfo (given ffprobePath) and condenses it into a
// StreamProbe. Extraction errors are returned wrapped with "probe:".
//
// Video fields are filled only when a video stream is reported. The audio
// codec is taken from the first track flagged Default, or from the first
// track when none is flagged. Container is the lowercased file extension.
func ProbeFile(ctx context.Context, ffprobePath, filePath string) (*StreamProbe, error) {
	info, err := mediainfo.ExtractMediaInfo(ctx, ffprobePath, filePath)
	if err != nil {
		return nil, fmt.Errorf("probe: %w", err)
	}

	result := &StreamProbe{Container: lowerExt(filePath)}

	if v := info.Video; v != nil {
		result.VideoCodec = strings.ToLower(v.Codec)
		result.Width = v.Width
		result.Height = v.Height
		result.BitDepth = v.BitDepth
		result.HDR = v.HDR
		result.DurationSec = v.Duration
	}

	if len(info.Audio) == 0 {
		return result, nil
	}
	// Index of the chosen track: first Default one, else track 0.
	chosen := 0
	for i := range info.Audio {
		if info.Audio[i].Default {
			chosen = i
			break
		}
	}
	result.AudioCodec = strings.ToLower(info.Audio[chosen].Codec)
	return result, nil
}
// DecideAction picks the transcoding action for a probe. Browsers consume
// MP4/h264+AAC natively; everything else needs some level of re-shaping.
//
// Rules, first match wins:
//   - nil probe                          → ActionPassthrough
//   - 10-bit+, HDR, or video not "h264"  → ActionTranscodeVideo
//   - audio not "aac"                    → ActionRemuxAudio
//   - container not ".mp4"               → ActionRemux
//   - otherwise                          → ActionPassthrough
func DecideAction(p *StreamProbe) TranscodeAction {
	switch {
	case p == nil:
		return ActionPassthrough
	case p.BitDepth >= 10, p.HDR != "", p.VideoCodec != "h264":
		// 10-bit / HDR is a hard no for browser playback even if h264;
		// HEVC / AV1 / VP9 / unknown codecs need a full video re-encode too.
		return ActionTranscodeVideo
	case p.AudioCodec != "aac":
		// Audio incompatible (AC3/DTS/TrueHD/EAC3) → copy video, transcode audio.
		return ActionRemuxAudio
	case p.Container != ".mp4":
		return ActionRemux
	default:
		return ActionPassthrough
	}
}
// lowerExt returns the lowercased extension of filePath, including the
// leading dot, or "" when the final path element has no dot.
//
// Only the last path element is inspected, so a dot in a directory name is
// not mistaken for an extension ("season.1/episode" → ""). Both '/' and '\'
// are treated as separators regardless of the host OS.
func lowerExt(filePath string) string {
	// LastIndexAny returns -1 when there is no separator, which makes the
	// slice start at 0 and keeps the whole string as the base name.
	base := filePath[strings.LastIndexAny(filePath, `/\`)+1:]
	dot := strings.LastIndex(base, ".")
	if dot < 0 {
		return ""
	}
	return strings.ToLower(base[dot:])
}

View file

@ -0,0 +1,96 @@
package engine
import "testing"
// TestDecideAction walks the decision matrix: one subtest per combination of
// container, codecs, bit depth and HDR flag.
func TestDecideAction(t *testing.T) {
	type scenario struct {
		label  string
		probe  StreamProbe
		expect TranscodeAction
	}
	scenarios := []scenario{
		{
			label:  "MP4 + h264 + AAC = passthrough",
			probe:  StreamProbe{VideoCodec: "h264", AudioCodec: "aac", Container: ".mp4"},
			expect: ActionPassthrough,
		},
		{
			label:  "MKV + h264 + AAC = remux",
			probe:  StreamProbe{VideoCodec: "h264", AudioCodec: "aac", Container: ".mkv"},
			expect: ActionRemux,
		},
		{
			label:  "MKV + h264 + AC3 = remux audio",
			probe:  StreamProbe{VideoCodec: "h264", AudioCodec: "ac3", Container: ".mkv"},
			expect: ActionRemuxAudio,
		},
		{
			label:  "MP4 + h264 + EAC3 = remux audio",
			probe:  StreamProbe{VideoCodec: "h264", AudioCodec: "eac3", Container: ".mp4"},
			expect: ActionRemuxAudio,
		},
		{
			label:  "MKV + HEVC = transcode video",
			probe:  StreamProbe{VideoCodec: "hevc", AudioCodec: "aac", Container: ".mkv"},
			expect: ActionTranscodeVideo,
		},
		{
			label:  "MP4 + AV1 = transcode video",
			probe:  StreamProbe{VideoCodec: "av1", AudioCodec: "aac", Container: ".mp4"},
			expect: ActionTranscodeVideo,
		},
		{
			label:  "h264 10-bit = transcode video (browser refuses)",
			probe:  StreamProbe{VideoCodec: "h264", AudioCodec: "aac", BitDepth: 10, Container: ".mp4"},
			expect: ActionTranscodeVideo,
		},
		{
			label:  "h264 + HDR10 = transcode video",
			probe:  StreamProbe{VideoCodec: "h264", AudioCodec: "aac", HDR: "HDR10", Container: ".mp4"},
			expect: ActionTranscodeVideo,
		},
		{
			label:  "AVI + h264 + AAC = remux",
			probe:  StreamProbe{VideoCodec: "h264", AudioCodec: "aac", Container: ".avi"},
			expect: ActionRemux,
		},
		{
			label:  "Unknown codec = transcode video",
			probe:  StreamProbe{VideoCodec: "mpeg4", AudioCodec: "mp3", Container: ".avi"},
			expect: ActionTranscodeVideo,
		},
		{
			label:  "Empty probe falls through to transcode (unknown codec)",
			probe:  StreamProbe{},
			expect: ActionTranscodeVideo,
		},
	}
	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			if got := DecideAction(&sc.probe); got != sc.expect {
				t.Errorf("got %s, want %s", got, sc.expect)
			}
		})
	}
}
// TestDecideActionNil checks that a missing probe yields passthrough.
func TestDecideActionNil(t *testing.T) {
	got := DecideAction(nil)
	if got != ActionPassthrough {
		t.Error("nil probe should default passthrough")
	}
}
// TestLowerExt checks extension extraction and lowercasing, including inputs
// with several dots, no dot, and the empty string.
func TestLowerExt(t *testing.T) {
	expected := map[string]string{
		"foo.MP4":              ".mp4",
		"path/to/movie.MKV":    ".mkv",
		"weird.name.with.dots": ".dots",
		"":                     "",
		"noext":                "",
	}
	for input, want := range expected {
		got := lowerExt(input)
		if got == want {
			continue
		}
		t.Errorf("lowerExt(%q) = %q want %q", input, got, want)
	}
}

View file

@ -0,0 +1,179 @@
package engine
import (
"context"
"fmt"
"io"
"os/exec"
"strconv"
"strings"
"sync"
"time"
)
// TranscodeOpts controls the ffmpeg command line that Transcoder builds.
// The intended defaults follow the project's plan/clever-weaving-dove.md
// (Fase 2.5):
//
//   - Output: fragmented MP4 readable by browser <video> via MSE-less Range.
//   - Audio: AAC stereo @ 192kbps unless source already AAC (then -c:a copy).
//   - Video: copy when h264 8-bit; otherwise transcode to h264 with HW encode
//     when available, software fallback at "veryfast" preset.
type TranscodeOpts struct {
	// Action is the re-shaping level to apply; HWAccel the encoder family.
	Action  TranscodeAction
	HWAccel HWAccel

	// Encoder tuning.
	Preset       string // "veryfast" / "fast" / "medium"
	VideoBitrate string // e.g. "5M"
	AudioBitrate string // e.g. "192k"
	MaxHeight    int    // optional downscale cap (e.g. 720)

	// StartSeconds is the playback position (in seconds) to start from.
	StartSeconds float64
	// FFmpegPath is the ffmpeg binary to run; NewTranscoder rejects "".
	FFmpegPath string
}
// Transcoder wraps a long-running ffmpeg child process whose stdout streams
// fragmented MP4 bytes for the WebRTC pump to forward to the browser.
//
// One Transcoder == one playback position. A seek beyond the buffered window
// requires Close()ing this transcoder and starting a new one with a higher
// StartSeconds (handled in webrtc_stream.go).
type Transcoder struct {
// cmd is the ffmpeg child process started by NewTranscoder.
cmd *exec.Cmd
// out is ffmpeg's stdout pipe; Read forwards to it and Close closes it.
out io.ReadCloser
// mu guards closed and stderr.
mu sync.Mutex
// closed is set by the first Close call so later calls return immediately.
closed bool
// stderr accumulates ffmpeg's stderr output, fed by errWriter.
stderr strings.Builder
}
// NewTranscoder starts ffmpeg on filePath and returns a Transcoder whose
// Read() yields the fragmented MP4 bytes ffmpeg writes to its stdout. ffmpeg's
// stderr is captured and available through Stderr(). The process is bound to
// ctx via exec.CommandContext. Callers MUST call Close() when done.
//
// It returns an error when opts.FFmpegPath is empty, when the stdout pipe
// cannot be created, or when the process fails to start.
func NewTranscoder(ctx context.Context, filePath string, opts TranscodeOpts) (*Transcoder, error) {
	if opts.FFmpegPath == "" {
		return nil, fmt.Errorf("transcoder: empty ffmpeg path")
	}

	proc := exec.CommandContext(ctx, opts.FFmpegPath, buildFFmpegArgs(filePath, opts)...)
	pipe, pipeErr := proc.StdoutPipe()
	if pipeErr != nil {
		return nil, fmt.Errorf("transcoder: stdout pipe: %w", pipeErr)
	}

	tr := &Transcoder{cmd: proc, out: pipe}
	// Stderr must be wired before Start so no early output is missed.
	proc.Stderr = &errWriter{t: tr}
	if startErr := proc.Start(); startErr != nil {
		return nil, fmt.Errorf("transcoder: start ffmpeg: %w", startErr)
	}
	return tr, nil
}
// Read implements io.Reader by forwarding to ffmpeg's stdout pipe.
func (t *Transcoder) Read(p []byte) (int, error) {
	n, err := t.out.Read(p)
	return n, err
}
// Close closes the stdout pipe, kills the child process if it was started and
// waits up to 2s for it to exit. Only the first call does any work; later
// calls return immediately. It always returns nil.
func (t *Transcoder) Close() error {
	// Flip the flag under the lock, but release it before waiting: errWriter
	// takes the same mutex while ffmpeg is still flushing stderr.
	t.mu.Lock()
	already := t.closed
	t.closed = true
	t.mu.Unlock()
	if already {
		return nil
	}

	_ = t.out.Close()
	if proc := t.cmd.Process; proc != nil {
		_ = proc.Kill()
	}

	exited := make(chan error, 1)
	go func() {
		exited <- t.cmd.Wait()
	}()

	deadline := time.After(2 * time.Second)
	select {
	case <-exited:
	case <-deadline:
		// Process refused to die — leak it; the OS will clean up on exit.
	}
	return nil
}
// Stderr returns whatever ffmpeg has written to stderr so far (as captured by
// errWriter). Useful for surfacing failure reasons in logs after Close().
func (t *Transcoder) Stderr() string {
	t.mu.Lock()
	captured := t.stderr.String()
	t.mu.Unlock()
	return captured
}
// errWriter is the io.Writer installed as ffmpeg's stderr. It appends output
// to the owning Transcoder's buffer so it can be inspected post-mortem.
type errWriter struct{ t *Transcoder }

// Write appends p to the Transcoder's stderr buffer while the buffer holds
// fewer than 64 KiB, and silently drops it afterwards. The size is checked
// before appending, so the buffer can exceed the cap by at most one write.
// It always reports len(p) bytes written so ffmpeg is never blocked or failed.
func (w *errWriter) Write(p []byte) (int, error) {
	const maxBuf = 64 * 1024

	tr := w.t
	tr.mu.Lock()
	if tr.stderr.Len() < maxBuf {
		tr.stderr.Write(p)
	}
	tr.mu.Unlock()
	return len(p), nil
}
// buildFFmpegArgs assembles the ffmpeg command line for the requested action.
// It is a standalone package-level function so tests can lock the flag matrix
// independently of process spawning.
//
// Layout of the result: global flags, optional "-ss" seek, optional hardware
// decode hint, "-i filePath", per-action codec flags, then the fragmented-MP4
// output flags writing to stdout ("pipe:1"). Empty Preset / VideoBitrate /
// AudioBitrate fall back to "veryfast" / "5M" / "192k".
func buildFFmpegArgs(filePath string, opts TranscodeOpts) []string {
	args := []string{"-hide_banner", "-loglevel", "warning"}

	// Seek BEFORE input (-ss before -i) for fast keyframe-aligned start.
	if opts.StartSeconds > 0 {
		args = append(args, "-ss", strconv.FormatFloat(opts.StartSeconds, 'f', 3, 64))
	}

	// Hardware decode hint, placed before -i so it applies to the input.
	switch opts.HWAccel {
	case HWAccelNVENC:
		args = append(args, "-hwaccel", "cuda")
	case HWAccelQSV:
		args = append(args, "-hwaccel", "qsv")
	case HWAccelVAAPI:
		args = append(args, "-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi")
	case HWAccelNone, HWAccelVideoToolbox:
		// No input-side hint: software decode (None) or per-encoder flags
		// already applied separately by FFmpegVideoCodec (VideoToolbox).
	}

	args = append(args, "-i", filePath)

	switch opts.Action {
	case ActionPassthrough, ActionRemux:
		args = append(args, "-c:v", "copy", "-c:a", "copy")
	case ActionRemuxAudio:
		args = append(args, "-c:v", "copy", "-c:a", "aac", "-b:a", coalesce(opts.AudioBitrate, "192k"))
	case ActionTranscodeVideo:
		videoCodec := opts.HWAccel.FFmpegVideoCodec("h264")
		args = append(args, "-c:v", videoCodec)
		// -preset is only passed to the software encoder.
		if videoCodec == "libx264" {
			args = append(args, "-preset", coalesce(opts.Preset, "veryfast"))
		}
		args = append(args, "-b:v", coalesce(opts.VideoBitrate, "5M"))
		if opts.MaxHeight > 0 {
			// Cap the height at MaxHeight (never upscale) and let ffmpeg
			// derive the width with -2: aspect ratio is preserved and the
			// width is rounded to an even number. The previous explicit
			// width expression (iw*H/ih) could evaluate to an odd value,
			// which H.264 4:2:0 encoders such as libx264 reject.
			args = append(args,
				"-vf",
				fmt.Sprintf("scale=-2:'min(ih,%d)'", opts.MaxHeight),
			)
		}
		args = append(args, "-c:a", "aac", "-b:a", coalesce(opts.AudioBitrate, "192k"))
	}

	// Common output flags — fragmented MP4 to a single pipe.
	// NOTE(review): faststart needs a seekable output to relocate the moov
	// atom; with fragmentation and a pipe it is presumably ignored — confirm
	// and consider dropping it.
	args = append(args,
		"-movflags", "frag_keyframe+empty_moov+default_base_moof+faststart",
		"-f", "mp4",
		"pipe:1",
	)
	return args
}
// coalesce returns s unless it is empty, in which case it returns fallback.
func coalesce(s, fallback string) string {
	if s != "" {
		return s
	}
	return fallback
}

View file

@ -0,0 +1,151 @@
package engine
import (
"strings"
"testing"
)
// sliceContains reports whether want appears anywhere in args.
func sliceContains(args []string, want string) bool {
	for i := range args {
		if args[i] == want {
			return true
		}
	}
	return false
}
// sliceContainsPair reports whether key is immediately followed by val
// somewhere in args.
func sliceContainsPair(args []string, key, val string) bool {
	for i := 1; i < len(args); i++ {
		if args[i-1] == key && args[i] == val {
			return true
		}
	}
	return false
}
func TestBuildFFmpegArgsPassthroughCopy(t *testing.T) {
args := buildFFmpegArgs("/tmp/movie.mp4", TranscodeOpts{
Action: ActionPassthrough,
HWAccel: HWAccelNone,
FFmpegPath: "ffmpeg",
})
if !sliceContainsPair(args, "-c:v", "copy") {
t.Errorf("passthrough should keep -c:v copy. args=%v", args)
}
if !sliceContainsPair(args, "-c:a", "copy") {
t.Error("passthrough should keep -c:a copy")
}
if !sliceContainsPair(args, "-f", "mp4") {
t.Error("output container must be mp4")
}
movflags := ""
for i := 0; i < len(args)-1; i++ {
if args[i] == "-movflags" {
movflags = args[i+1]
}
}
if !strings.Contains(movflags, "frag_keyframe") {
t.Errorf("movflags must include frag_keyframe, got %q", movflags)
}
}
func TestBuildFFmpegArgsRemuxAudio(t *testing.T) {
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
Action: ActionRemuxAudio,
AudioBitrate: "256k",
FFmpegPath: "ffmpeg",
})
if !sliceContainsPair(args, "-c:v", "copy") {
t.Error("remux-audio keeps video copy")
}
if !sliceContainsPair(args, "-c:a", "aac") {
t.Error("remux-audio must transcode audio to aac")
}
if !sliceContainsPair(args, "-b:a", "256k") {
t.Error("audio bitrate override not honored")
}
}
func TestBuildFFmpegArgsTranscodeVideoSoftware(t *testing.T) {
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
Action: ActionTranscodeVideo,
HWAccel: HWAccelNone,
Preset: "fast",
VideoBitrate: "6M",
FFmpegPath: "ffmpeg",
})
if !sliceContainsPair(args, "-c:v", "libx264") {
t.Error("software fallback must use libx264")
}
if !sliceContainsPair(args, "-preset", "fast") {
t.Error("custom preset not honored")
}
if !sliceContainsPair(args, "-b:v", "6M") {
t.Error("video bitrate not honored")
}
}
func TestBuildFFmpegArgsTranscodeVideoNVENC(t *testing.T) {
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
Action: ActionTranscodeVideo,
HWAccel: HWAccelNVENC,
FFmpegPath: "ffmpeg",
})
if !sliceContainsPair(args, "-hwaccel", "cuda") {
t.Error("NVENC must request -hwaccel cuda")
}
if !sliceContainsPair(args, "-c:v", "h264_nvenc") {
t.Error("NVENC must use h264_nvenc encoder")
}
if sliceContains(args, "-preset") {
// HW encoders ignore software preset; we should NOT pass it.
t.Error("HW encoder path should not include -preset")
}
}
func TestBuildFFmpegArgsAddsStartSeek(t *testing.T) {
args := buildFFmpegArgs("/tmp/movie.mp4", TranscodeOpts{
Action: ActionPassthrough,
StartSeconds: 90.5,
FFmpegPath: "ffmpeg",
})
idxSs, idxIn := -1, -1
for i, a := range args {
if a == "-ss" {
idxSs = i
}
if a == "-i" {
idxIn = i
}
}
if idxSs < 0 {
t.Fatal("missing -ss flag")
}
if idxIn < 0 {
t.Fatal("missing -i flag")
}
if idxSs >= idxIn {
t.Errorf("expected -ss BEFORE -i for fast seek; got -ss@%d -i@%d", idxSs, idxIn)
}
if args[idxSs+1] != "90.500" {
t.Errorf("expected seek 90.500s, got %q", args[idxSs+1])
}
}
func TestBuildFFmpegArgsDownscale(t *testing.T) {
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
Action: ActionTranscodeVideo,
HWAccel: HWAccelNone,
MaxHeight: 720,
FFmpegPath: "ffmpeg",
})
hasVF := false
for i := 0; i < len(args)-1; i++ {
if args[i] == "-vf" && strings.Contains(args[i+1], "720") {
hasVF = true
}
}
if !hasVF {
t.Errorf("expected -vf scale containing 720; args=%v", args)
}
}

View file

@ -0,0 +1,617 @@
// Package engine — webrtc_stream.go implements the daemon side of the custom
// WebRTC byte-streaming protocol. The browser opens an RTCDataChannel via
// SDP exchange (signalled over the web's HTTP + SSE relay); this code:
//
// 1. Parses the browser's SDP offer.
// 2. Creates a pion PeerConnection bound to the configured ICE servers.
// 3. Answers + trickles its own ICE candidates back through the signal client.
// 4. On DataChannel open, sends a HELLO frame describing the file.
// 5. Services RangeReq frames by reading from disk and emitting RangeData
// chunks (16 KiB each) followed by a RangeEnd.
// 6. Honours app-level backpressure via SetBufferedAmountLowThreshold +
// OnBufferedAmountLow — Chromium closes a DataChannel when bufferedAmount
// exceeds 16 MiB, so we MUST pause the writer.
//
// No anacrolix, no torrent metadata. Just a peer-to-peer file server over
// WebRTC. Pass-through path; transcoding lives in transcoder.go (Fase 2.5).
package engine
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/pion/webrtc/v4"
"github.com/torrentclaw/unarr/internal/agent"
"github.com/torrentclaw/unarr/internal/engine/wire"
)
// Tunables — values match the protocol spec in plan/clever-weaving-dove.md.
const (
// dcChunkPayload is the per-frame application payload size. Must match
// wire.MaxChunkPayload so RangeData frames fit one SCTP message. It sizes
// the read buffer used by serveRange.
dcChunkPayload = wire.MaxChunkPayload
// dcHighWatermark is the bufferedAmount cap above which the writer pauses.
// Chromium closes DCs above 16 MiB; pause well below.
dcHighWatermark = 8 << 20
// dcLowWatermark triggers OnBufferedAmountLow → resume the writer. It is
// installed on the channel via SetBufferedAmountLowThreshold.
dcLowWatermark = 1 << 20
// rangeReqConcurrency is the cap on in-flight range responses per session
// (the capacity of the pump's semaphore channel).
rangeReqConcurrency = 4
// helloDeadline is the max wait for the DataChannel to open after answer.
// NOTE(review): it is only consulted on the path where the SDP goroutine
// reports a nil error — confirm that path is reachable.
helloDeadline = 30 * time.Second
)
// WebRTCStreamConfig describes a single browser ↔ daemon stream session.
type WebRTCStreamConfig struct {
// SessionID identifies the session on the signalling relay. Required.
SessionID string
// FilePath is the file to serve; it is made absolute before opening. Required.
FilePath string
// FileName is the name advertised in the HELLO frame; when empty the base
// name of the resolved path is used.
FileName string
// FileSize is the size the caller expects. It is only compared against the
// on-disk size (a mismatch is logged); the on-disk size is what gets served.
FileSize int64
// ICEServers is passed through to the PeerConnection configuration.
ICEServers []webrtc.ICEServer
// Signal is the client used to receive the browser's SDP/ICE messages and
// to post the answer and local candidates.
Signal *agent.Client
// Logger receives diagnostic events; a nil logger swallows everything.
Logger StreamLogger
}
// StreamLogger is the logging seam for a stream session. Injecting it lets
// tests capture events instead of writing to a real logger.
type StreamLogger interface {
	Infof(format string, args ...any)
	Warnf(format string, args ...any)
	Errorf(format string, args ...any)
}

// nopLogger is the StreamLogger that discards every event.
type nopLogger struct{}

func (nopLogger) Infof(_ string, _ ...any)  {}
func (nopLogger) Warnf(_ string, _ ...any)  {}
func (nopLogger) Errorf(_ string, _ ...any) {}

// logger returns l, or a discarding logger when l is nil, so callers never
// need a nil check before logging.
func logger(l StreamLogger) StreamLogger {
	if l != nil {
		return l
	}
	return nopLogger{}
}
// RunWebRTCStream blocks until the session ends — either the DataChannel
// closes, the peer connection drops, or ctx is cancelled. Always returns a
// non-nil error explaining the termination reason.
//
// Steps: validate cfg, open and stat the file, build a PeerConnection with
// cfg.ICEServers, register ICE / DataChannel callbacks, run the SDP exchange
// on a background goroutine, wait for the browser's DataChannel, then attach
// a dataChannelPump to it and block until the session context is done.
func RunWebRTCStream(ctx context.Context, cfg WebRTCStreamConfig) error {
log := logger(cfg.Logger)
if cfg.SessionID == "" {
return errors.New("webrtc_stream: empty SessionID")
}
if cfg.FilePath == "" {
return errors.New("webrtc_stream: empty FilePath")
}
abs, err := filepath.Abs(cfg.FilePath)
if err != nil {
return fmt.Errorf("webrtc_stream: resolve path: %w", err)
}
file, err := os.Open(abs)
if err != nil {
return fmt.Errorf("webrtc_stream: open file: %w", err)
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("webrtc_stream: stat: %w", err)
}
// The on-disk size is authoritative; the declared size is only checked.
fileSize := stat.Size()
if cfg.FileSize > 0 && cfg.FileSize != fileSize {
log.Warnf("webrtc_stream: declared size %d != actual %d", cfg.FileSize, fileSize)
}
fileName := cfg.FileName
if fileName == "" {
fileName = filepath.Base(abs)
}
// 1. Build PeerConnection.
api := webrtc.NewAPI()
pc, err := api.NewPeerConnection(webrtc.Configuration{
ICEServers: cfg.ICEServers,
})
if err != nil {
return fmt.Errorf("webrtc_stream: new peer connection: %w", err)
}
defer pc.Close()
// sessionCtx is the single "session is over" signal: every callback below
// ends the session by calling cancelSession.
sessionCtx, cancelSession := context.WithCancel(ctx)
defer cancelSession()
// Stop the session when ICE drops permanently. "Disconnected" is
// transient per RFC 8445 (NAT rebind, brief packet loss) — wait for
// "Failed" or "Closed" before tearing down.
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
log.Infof("[wrtc %s] ice=%s", agent.ShortID(cfg.SessionID), state.String())
switch state {
case webrtc.ICEConnectionStateFailed,
webrtc.ICEConnectionStateClosed:
cancelSession()
case webrtc.ICEConnectionStateUnknown,
webrtc.ICEConnectionStateNew,
webrtc.ICEConnectionStateChecking,
webrtc.ICEConnectionStateConnected,
webrtc.ICEConnectionStateCompleted,
webrtc.ICEConnectionStateDisconnected:
// Disconnected is transient (RFC 8445 — NAT rebind / packet loss);
// the others are normal progress states. Don't tear the session down.
}
})
// Trickle our ICE candidates back to the browser.
// PostSignal runs on its own goroutine so a slow signal server can't
// stall pion's ICE-gathering thread. Post errors are deliberately ignored
// (best-effort trickle).
pc.OnICECandidate(func(c *webrtc.ICECandidate) {
// A nil candidate marks the end of gathering.
if c == nil {
go func() {
_ = cfg.Signal.PostSignal(sessionCtx, cfg.SessionID, agent.SignalMessage{
Type: agent.SignalMsgCandidateEnd,
Payload: "",
})
}()
return
}
init := c.ToJSON()
payload, _ := json.Marshal(init)
go func() {
_ = cfg.Signal.PostSignal(sessionCtx, cfg.SessionID, agent.SignalMessage{
Type: agent.SignalMsgCandidate,
Payload: string(payload),
})
}()
})
// Browser is the offerer — we react to the DataChannel it creates.
// dcReady has capacity 1 so the callback never blocks; only the first
// channel is kept.
dcReady := make(chan *webrtc.DataChannel, 1)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
log.Infof("[wrtc %s] data channel '%s' open", agent.ShortID(cfg.SessionID), dc.Label())
select {
case dcReady <- dc:
default:
// Browser opened a second DC — ignore, we only serve one.
log.Warnf("[wrtc %s] extra data channel ignored", agent.ShortID(cfg.SessionID))
}
})
// 2. Drive the SDP exchange.
// sdpDone is buffered so the goroutine can always deliver its result and
// exit, even when nobody reads it any more.
sdpDone := make(chan error, 1)
go func() {
sdpDone <- runSDPExchange(sessionCtx, pc, cfg)
}()
// 3. Wait for either SDP error or DataChannel open.
// NOTE(review): sdpDone is only read in this select. Once the data channel
// is ready, a later runSDPExchange error (including a browser "bye") is
// not observed here — confirm that is intended.
var dc *webrtc.DataChannel
select {
case err := <-sdpDone:
if err != nil {
return fmt.Errorf("sdp exchange: %w", err)
}
// SDP complete — wait for the DC.
// NOTE(review): runSDPExchange appears to return a non-nil error on every
// path, so this inner wait (the only user of helloDeadline) may be
// unreachable — confirm.
select {
case dc = <-dcReady:
case <-time.After(helloDeadline):
return errors.New("webrtc_stream: data channel never opened")
case <-sessionCtx.Done():
return sessionCtx.Err()
}
case dc = <-dcReady:
// DC opened before SDP loop reported done (typical: the loop keeps
// running to ferry remote ICE candidates).
case <-sessionCtx.Done():
return sessionCtx.Err()
}
// 4. Wire up the data channel pump.
// NOTE(review): these handlers are attached after the OnDataChannel
// callback has already returned; presumably frames delivered before
// OnMessage is set would be missed — confirm against pion's delivery order.
pump := newDataChannelPump(dc, file, fileSize, fileName, log, cancelSession)
dc.OnOpen(pump.onOpen)
dc.OnMessage(pump.onMessage)
dc.OnClose(func() {
log.Infof("[wrtc %s] data channel closed", agent.ShortID(cfg.SessionID))
cancelSession()
})
// Block until something ends the session, then stop in-flight responses
// before the deferred pc/file cleanup runs.
<-sessionCtx.Done()
pump.shutdown()
return sessionCtx.Err()
}
// runSDPExchange consumes the browser's signal events and answers its SDP
// offer. It keeps going for as long as ctx is alive so trickle candidates flow
// in both directions, reopening the SSE stream after every clean close (the
// server caps each response at ~25 s).
//
// It returns the first hard error from opening or consuming a stream, or
// ctx.Err() once ctx is done.
func runSDPExchange(ctx context.Context, pc *webrtc.PeerConnection, cfg WebRTCStreamConfig) error {
	// Offer state is shared across reconnects so a replayed offer is ignored.
	var gotOffer bool
	for {
		if ctx.Err() != nil {
			break
		}

		stream, openErr := cfg.Signal.OpenSignalStream(ctx, cfg.SessionID)
		if openErr != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return fmt.Errorf("open signal stream: %w", openErr)
		}

		consumeErr := consumeSignalStream(ctx, pc, cfg, stream, &gotOffer)
		stream.Close()
		if consumeErr != nil {
			return consumeErr
		}
	}
	return ctx.Err()
}
// consumeSignalStream handles the events of one SSE connection until the
// stream closes, a handler fails, or ctx is done. A clean server-side close
// yields nil so the caller can reopen; a stream error is wrapped and returned.
func consumeSignalStream(
	ctx context.Context,
	pc *webrtc.PeerConnection,
	cfg WebRTCStreamConfig,
	stream *agent.SignalEventStream,
	gotOffer *bool,
) error {
	for {
		select {
		case msg, open := <-stream.Events():
			if open {
				if err := handleSignal(ctx, pc, cfg, msg, gotOffer); err != nil {
					return err
				}
				continue
			}
			// Channel closed: distinguish failure from a clean disconnect.
			if err := stream.Err(); err != nil {
				return fmt.Errorf("signal stream: %w", err)
			}
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// handleSignal applies one signalling message from the browser to pc.
//
// - offer: sets it as the remote description, creates and installs the answer,
// posts the answer back and records it in *gotOffer. Later offers are ignored.
// - candidate: added to pc once an offer has been processed; dropped before.
// - candidate-end / answer: ignored.
// - bye: returned as an error so the signalling loop stops.
//
// Any decode or pion error is wrapped and returned, which also ends the
// signalling loop in the caller.
func handleSignal(
ctx context.Context,
pc *webrtc.PeerConnection,
cfg WebRTCStreamConfig,
msg agent.SignalMessage,
gotOffer *bool,
) error {
switch msg.Type {
case agent.SignalMsgAnswer:
// Browser is the offerer in our protocol — we never expect an answer
// from the other side. Drop silently (also satisfies exhaustive lint).
return nil
case agent.SignalMsgOffer:
if *gotOffer {
return nil // ignore duplicates
}
var offer webrtc.SessionDescription
if err := json.Unmarshal([]byte(msg.Payload), &offer); err != nil {
return fmt.Errorf("decode offer: %w", err)
}
if err := pc.SetRemoteDescription(offer); err != nil {
return fmt.Errorf("set remote description: %w", err)
}
answer, err := pc.CreateAnswer(nil)
if err != nil {
return fmt.Errorf("create answer: %w", err)
}
if err := pc.SetLocalDescription(answer); err != nil {
return fmt.Errorf("set local description: %w", err)
}
// Send back the local description *with* gathered candidates so far —
// remaining candidates trickle separately via OnICECandidate.
ld := pc.LocalDescription()
payload, _ := json.Marshal(ld)
if err := cfg.Signal.PostSignal(ctx, cfg.SessionID, agent.SignalMessage{
Type: agent.SignalMsgAnswer,
Payload: string(payload),
}); err != nil {
return fmt.Errorf("post answer: %w", err)
}
// Only mark the offer handled once the answer was actually delivered.
*gotOffer = true
case agent.SignalMsgCandidate:
if !*gotOffer {
// Browser may trickle candidates before we've seen the offer in
// rare race conditions — drop.
// NOTE(review): this relies on the browser re-sending dropped
// candidates; confirm the web client does, otherwise they are lost.
return nil
}
var init webrtc.ICECandidateInit
if err := json.Unmarshal([]byte(msg.Payload), &init); err != nil {
return fmt.Errorf("decode candidate: %w", err)
}
if err := pc.AddICECandidate(init); err != nil {
return fmt.Errorf("add ice candidate: %w", err)
}
case agent.SignalMsgCandidateEnd:
// No-op — pion gathers complete on its own.
case agent.SignalMsgBye:
return errors.New("browser sent bye")
}
return nil
}
// dataChannelPump owns the DC + file handle and serves wire-protocol frames.
// One pump is created per session by newDataChannelPump; its methods are the
// DataChannel callbacks plus the per-request serveRange workers.
type dataChannelPump struct {
// dc is the browser-created DataChannel all frames are sent on.
dc *webrtc.DataChannel
// file is the open source file; range workers read it with ReadAt.
file *os.File
// fileSize bounds range requests and is advertised in the HELLO frame.
fileSize int64
// fileName is advertised in the HELLO frame.
fileName string
log StreamLogger
// cancel ends the whole session; called when HELLO cannot be sent.
cancel context.CancelFunc
// Flow control: writers wait on resumeCh when bufferedAmount goes high.
// paused is set while a writer is waiting so onBufferedAmountLow knows a
// wake-up is wanted.
paused atomic.Bool
resumeCh chan struct{}
// Active range responses keyed by stream_id so CANCEL frames can stop them.
// shutdown() replaces the map with nil under activeMu.
activeMu sync.Mutex
active map[uint32]context.CancelFunc
// Bound concurrent in-flight responses.
sem chan struct{}
// closed once shutdown() has been called.
closed atomic.Bool
}
// newDataChannelPump builds the pump for dc and installs the low-watermark
// backpressure hook on the channel. The OnOpen / OnMessage / OnClose handlers
// are left for the caller to attach. cancel is the session-wide cancel func
// the pump invokes when it cannot continue.
func newDataChannelPump(
	dc *webrtc.DataChannel,
	file *os.File,
	fileSize int64,
	fileName string,
	log StreamLogger,
	cancel context.CancelFunc,
) *dataChannelPump {
	pump := new(dataChannelPump)
	pump.dc = dc
	pump.file = file
	pump.fileSize = fileSize
	pump.fileName = fileName
	pump.log = log
	pump.cancel = cancel
	// Capacity 1: a resume signal is a level, not a count.
	pump.resumeCh = make(chan struct{}, 1)
	pump.active = make(map[uint32]context.CancelFunc)
	pump.sem = make(chan struct{}, rangeReqConcurrency)

	dc.SetBufferedAmountLowThreshold(dcLowWatermark)
	dc.OnBufferedAmountLow(pump.onBufferedAmountLow)
	return pump
}
// onOpen is the DataChannel open handler. It sends the HELLO frame describing
// the file (size and name, flagged seekable and not transcoding) on stream 0.
// If the send fails the whole session is cancelled.
func (p *dataChannelPump) onOpen() {
	body := wire.EncodeHello(wire.HelloPayload{
		FileSize:    uint64(p.fileSize),
		Transcoding: false,
		Seekable:    true,
		FileName:    p.fileName,
	})
	hdr := wire.Header{
		Type:     wire.FrameHello,
		Flags:    wire.HelloFlags(false, true),
		StreamID: 0,
		Length:   uint32(len(body)),
	}

	err := p.dc.Send(wire.EncodeFrame(hdr, body))
	if err == nil {
		return
	}
	p.log.Errorf("send hello: %v", err)
	p.cancel()
}
// onMessage is the DataChannel message handler. It validates the frame (header
// present, header decodes, payload length matches the header) and dispatches
// on the frame type. Malformed or unknown frames are logged and dropped; they
// never end the session.
func (p *dataChannelPump) onMessage(msg webrtc.DataChannelMessage) {
	raw := msg.Data
	if len(raw) < wire.HeaderSize {
		p.log.Warnf("dc: short frame %d bytes", len(raw))
		return
	}

	hdr, err := wire.DecodeHeader(raw[:wire.HeaderSize])
	if err != nil {
		p.log.Warnf("dc: bad header: %v", err)
		return
	}

	body := raw[wire.HeaderSize:]
	if uint32(len(body)) != hdr.Length {
		p.log.Warnf("dc: payload length mismatch: hdr=%d got=%d", hdr.Length, len(body))
		return
	}

	switch hdr.Type {
	case wire.FrameRangeReq:
		req, decodeErr := wire.DecodeRangeReq(body)
		if decodeErr != nil {
			p.log.Warnf("dc: bad range_req: %v", decodeErr)
			return
		}
		// Serve on its own goroutine so this callback returns immediately.
		go p.serveRange(hdr.StreamID, req)
	case wire.FrameCancel:
		p.cancelStream(hdr.StreamID)
	case wire.FramePing:
		p.sendSimpleFrame(wire.FramePong, hdr.StreamID, nil)
	case wire.FramePong:
		// no-op
	default:
		p.log.Warnf("dc: unknown frame type 0x%02x", hdr.Type)
	}
}
// cancelStream stops the in-flight range response registered under streamID,
// if any, and removes it from the active set. Unknown IDs are ignored.
func (p *dataChannelPump) cancelStream(streamID uint32) {
	p.activeMu.Lock()
	stop, found := p.active[streamID]
	delete(p.active, streamID)
	p.activeMu.Unlock()

	// Call the cancel func outside the lock.
	if !found {
		return
	}
	stop()
}
// sendSimpleFrame encodes a flag-less frame of type t for streamID carrying
// payload (which may be nil) and sends it. Send failures are logged only.
func (p *dataChannelPump) sendSimpleFrame(t wire.FrameType, streamID uint32, payload []byte) {
	hdr := wire.Header{
		Type:     t,
		StreamID: streamID,
		Length:   uint32(len(payload)),
	}
	err := p.dc.Send(wire.EncodeFrame(hdr, payload))
	if err == nil {
		return
	}
	p.log.Warnf("dc: send type=0x%02x: %v", t, err)
}
// serveRange streams the bytes requested by req back to the browser as
// RangeData frames tagged with streamID, followed by a single RangeEnd.
//
// RangeEnd status codes emitted here:
//
//	0 — the whole (clamped) range was sent
//	1 — no concurrency slot became free within 5s
//	2 — req.Offset is at or beyond the end of the file
//	3 — reading the file failed
//
// A zero or oversized req.Length is clamped to "the rest of the file". When
// the stream is cancelled (CANCEL frame or shutdown) or a send fails, the
// function returns without emitting RangeEnd.
func (p *dataChannelPump) serveRange(streamID uint32, req wire.RangeReqPayload) {
	if p.closed.Load() {
		return
	}
	// Bound concurrency.
	select {
	case p.sem <- struct{}{}:
	case <-time.After(5 * time.Second):
		p.log.Warnf("dc: range_req sid=%d dropped (concurrency cap)", streamID)
		p.sendRangeEnd(streamID, 1)
		return
	}
	defer func() { <-p.sem }()

	// Reject offsets above MaxInt64 — uint64→int64 narrowing would wrap to a
	// negative value and bypass the bounds check, then ReadAt would be called
	// with a negative offset.
	if req.Offset > math.MaxInt64 || int64(req.Offset) >= p.fileSize {
		p.sendRangeEnd(streamID, 2) // out of range
		return
	}
	want := int64(req.Length)
	if req.Length > math.MaxInt64 {
		want = 0 // treat absurd length as "remainder of file"
	}
	remaining := p.fileSize - int64(req.Offset)
	if want <= 0 || want > remaining {
		want = remaining
	}

	ctx, cancel := context.WithCancel(context.Background())
	p.activeMu.Lock()
	if p.active == nil {
		// shutdown() ran after the closed check above (for example while
		// this goroutine was waiting for a semaphore slot) and replaced the
		// map with nil. Writing to a nil map panics, so bail out instead.
		p.activeMu.Unlock()
		cancel()
		return
	}
	p.active[streamID] = cancel
	p.activeMu.Unlock()
	defer func() {
		p.activeMu.Lock()
		delete(p.active, streamID) // delete on a nil map is a no-op
		p.activeMu.Unlock()
		cancel()
	}()

	buf := make([]byte, dcChunkPayload)
	offset := int64(req.Offset)
	end := offset + want
	for offset < end {
		if ctx.Err() != nil || p.closed.Load() {
			return
		}
		// Wait if the DC is buffering too much.
		if err := p.waitForLowWater(ctx); err != nil {
			return
		}
		chunkLen := int64(len(buf))
		if end-offset < chunkLen {
			chunkLen = end - offset
		}
		n, rerr := p.file.ReadAt(buf[:chunkLen], offset)
		if n > 0 {
			// EOF on a short read means this is the final chunk — flag it so the
			// browser doesn't wait for more data before processing RangeEnd.
			isLast := offset+int64(n) >= end || rerr == io.EOF
			if err := p.sendRangeData(streamID, buf[:n], isLast); err != nil {
				p.log.Warnf("dc: send range_data sid=%d: %v", streamID, err)
				return
			}
			offset += int64(n)
		}
		if rerr != nil {
			if rerr == io.EOF {
				break
			}
			p.log.Errorf("dc: read sid=%d: %v", streamID, rerr)
			p.sendRangeEnd(streamID, 3)
			return
		}
	}
	p.sendRangeEnd(streamID, 0)
}
// sendRangeData sends one RangeData frame carrying data for streamID. When
// last is true the frame carries wire.FlagLastChunk. It returns the
// DataChannel send error, if any.
func (p *dataChannelPump) sendRangeData(streamID uint32, data []byte, last bool) error {
	hdr := wire.Header{
		Type:     wire.FrameRangeData,
		StreamID: streamID,
		Length:   uint32(len(data)),
	}
	if last {
		hdr.Flags = wire.FlagLastChunk
	}
	return p.dc.Send(wire.EncodeFrame(hdr, data))
}
// sendRangeEnd sends the RangeEnd frame for streamID with the given status
// code. Send failures are logged by sendSimpleFrame.
func (p *dataChannelPump) sendRangeEnd(streamID uint32, status uint32) {
	end := wire.RangeEndPayload{Status: status}
	p.sendSimpleFrame(wire.FrameRangeEnd, streamID, wire.EncodeRangeEnd(end))
}
// waitForLowWater blocks while the DataChannel's bufferedAmount is at or above
// dcHighWatermark. It returns nil as soon as the amount is below the mark and
// ctx.Err() if ctx is cancelled while waiting.
//
// NOTE(review): paused and resumeCh are shared by all concurrent serveRange
// workers; one worker clearing paused presumably leaves the others relying on
// the 500 ms poll below — confirm that is acceptable.
func (p *dataChannelPump) waitForLowWater(ctx context.Context) error {
// Fast path: nothing to wait for.
if p.dc.BufferedAmount() < dcHighWatermark {
return nil
}
// Tell onBufferedAmountLow that a wake-up is wanted.
p.paused.Store(true)
for {
// Drain any stale resume signal first.
select {
case <-p.resumeCh:
default:
}
// Re-check after draining so a signal consumed above cannot cause a
// missed wake-up.
if p.dc.BufferedAmount() < dcHighWatermark {
p.paused.Store(false)
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-p.resumeCh:
case <-time.After(500 * time.Millisecond):
// Belt-and-braces poll in case OnBufferedAmountLow misses a fire.
}
}
}
// onBufferedAmountLow is the DataChannel low-watermark callback. When a writer
// has marked itself paused it posts a resume signal; the send never blocks,
// so an already-pending signal is simply left in place.
func (p *dataChannelPump) onBufferedAmountLow() {
	if p.paused.Load() {
		select {
		case p.resumeCh <- struct{}{}:
		default:
		}
	}
}
// shutdown marks the pump closed, cancels every in-flight range response and
// drops the active set. Only the first call has any effect.
func (p *dataChannelPump) shutdown() {
	firstCall := p.closed.CompareAndSwap(false, true)
	if !firstCall {
		return
	}

	p.activeMu.Lock()
	defer p.activeMu.Unlock()
	for _, stop := range p.active {
		stop()
	}
	p.active = nil
}

View file

@ -0,0 +1,254 @@
// Package wire implements the binary frame format used over the WebRTC
// DataChannel between the unarr daemon and the browser stream player.
//
// Header (12 bytes, big-endian):
//
// u8 Type
// u8 Flags
// u16 _reserved
// u32 StreamID -- multiplex range requests
// u32 Length -- payload bytes following the header
//
// Each side encodes one Frame at a time and writes it as a single SCTP
// message (DataChannel send). Browsers cap message size at 64 KiB-ish, so
// callers MUST split RANGE_DATA payloads into chunks <= MaxChunkPayload.
package wire
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// FrameType identifies the wire message kind. It is the first byte of every
// frame header.
type FrameType uint8
const (
// FrameHello describes the file being served; first frame the daemon writes.
FrameHello FrameType = 0x00
// FrameRangeReq is the browser's request for a byte range.
FrameRangeReq FrameType = 0x01
// FrameRangeData carries one chunk of a range response.
FrameRangeData FrameType = 0x02
// FrameRangeEnd closes a range response with a status code.
FrameRangeEnd FrameType = 0x03
// FrameCancel asks the daemon to stop the response for a stream_id.
FrameCancel FrameType = 0x04
// FramePing / FramePong are the keep-alive request and reply.
FramePing FrameType = 0x05
FramePong FrameType = 0x06
// FrameSeekHint announces an imminent seek (see SeekHintPayload).
FrameSeekHint FrameType = 0x07
)
// Flag bits — interpretation depends on FrameType. They live in the second
// byte of the frame header.
const (
// FlagLastChunk on a RangeData frame marks the final chunk for a stream_id.
FlagLastChunk uint8 = 1 << 0
// FlagTranscoding on a Hello frame indicates the daemon will transcode.
FlagTranscoding uint8 = 1 << 1
// FlagSeekable on a Hello frame indicates random-access is supported.
FlagSeekable uint8 = 1 << 2
)
// HeaderSize is the fixed length of every frame header.
const HeaderSize = 12
// MaxChunkPayload is the per-frame payload cap callers MUST respect when
// chunking RangeData. The stated rationale is that Chromium fragments
// messages internally above 16 KiB, so this size is safe on every browser
// implementation — NOTE(review): not verifiable from this file.
const MaxChunkPayload = 16 * 1024
// MaxFrameSize is the largest frame the parser will accept: DecodeHeader
// rejects any header whose Length exceeds MaxFrameSize-HeaderSize. Anything
// bigger is treated as a corrupted stream — close the channel.
const MaxFrameSize = HeaderSize + 64*1024
// Header is the parsed 12-byte frame header. The two reserved bytes of the
// wire layout are not represented: EncodeHeader writes them as zero and
// DecodeHeader ignores them.
type Header struct {
// Type is the frame kind (byte 0).
Type FrameType
// Flags holds the Flag* bits (byte 1); meaning depends on Type.
Flags uint8
// StreamID multiplexes range requests (bytes 4-7, big-endian).
StreamID uint32
// Length is the number of payload bytes following the header (bytes 8-11,
// big-endian).
Length uint32
}
// EncodeHeader serialises h into the first HeaderSize bytes of dst using the
// big-endian wire layout; the two reserved bytes are written as zero.
// It panics if dst is shorter than HeaderSize.
func EncodeHeader(dst []byte, h Header) {
	if len(dst) < HeaderSize {
		panic("wire: dst too small for header")
	}
	dst[0], dst[1] = byte(h.Type), h.Flags
	dst[2], dst[3] = 0, 0 // reserved
	binary.BigEndian.PutUint32(dst[4:8], h.StreamID)
	binary.BigEndian.PutUint32(dst[8:12], h.Length)
}
// DecodeHeader parses the first HeaderSize bytes of src. It returns an error
// when src is too short or when the declared payload length exceeds
// MaxFrameSize-HeaderSize. The reserved bytes are ignored.
func DecodeHeader(src []byte) (Header, error) {
	if n := len(src); n < HeaderSize {
		return Header{}, fmt.Errorf("wire: header needs %d bytes, got %d", HeaderSize, n)
	}

	const maxPayload = MaxFrameSize - HeaderSize
	length := binary.BigEndian.Uint32(src[8:12])
	if length > maxPayload {
		return Header{}, fmt.Errorf("wire: payload length %d exceeds max %d", length, maxPayload)
	}

	return Header{
		Type:     FrameType(src[0]),
		Flags:    src[1],
		StreamID: binary.BigEndian.Uint32(src[4:8]),
		Length:   length,
	}, nil
}
// EncodeFrame returns a freshly allocated frame: the encoded header followed
// by a copy of payload. It panics when h.Length does not equal len(payload).
// Use this for one-shot sends; for hot-path RangeData prefer EncodeHeader
// into a pre-allocated buffer to avoid per-frame allocations.
func EncodeFrame(h Header, payload []byte) []byte {
	if n := len(payload); int(h.Length) != n {
		panic(fmt.Sprintf("wire: header length %d != payload len %d", h.Length, n))
	}
	// Allocate the full frame up front so the append below never reallocates.
	frame := make([]byte, HeaderSize, HeaderSize+len(payload))
	EncodeHeader(frame, h)
	return append(frame, payload...)
}
// ReadFrame reads exactly one frame from r and returns its header together
// with a freshly allocated payload (nil when the header declares zero bytes).
// Read errors and header validation errors are returned as-is. On any size
// violation the connection must be closed — the protocol has no resync.
func ReadFrame(r io.Reader) (Header, []byte, error) {
	var raw [HeaderSize]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return Header{}, nil, err
	}

	h, err := DecodeHeader(raw[:])
	if err != nil {
		return Header{}, nil, err
	}
	if h.Length == 0 {
		return h, nil, nil
	}

	body := make([]byte, h.Length)
	if _, err := io.ReadFull(r, body); err != nil {
		return Header{}, nil, err
	}
	return h, body, nil
}
// HelloPayload describes the file the daemon is about to serve. It is the
// first frame the daemon writes after the DataChannel opens.
type HelloPayload struct {
// FileSize is the total size in bytes (encoded in the payload).
FileSize uint64
// Transcoding and Seekable are not part of the payload bytes: they travel
// in the frame's Flags byte (see HelloFlags / DecodeHello).
Transcoding bool
Seekable bool
// FileName is the display name (encoded in the payload, length-prefixed).
FileName string
}
// EncodeHello serialises the payload part of h. The Transcoding and Seekable
// fields are not written here; they belong in the frame's Flags byte.
//
// Layout: u64 file_size | u32 name_len | name_bytes
func EncodeHello(h HelloPayload) []byte {
	name := h.FileName
	out := make([]byte, 12+len(name))
	binary.BigEndian.PutUint64(out[:8], h.FileSize)
	binary.BigEndian.PutUint32(out[8:12], uint32(len(name)))
	copy(out[12:], name)
	return out
}
// DecodeHello parses a Hello payload (u64 file_size | u32 name_len |
// name_bytes). The transcoding/seekable bits live in the frame Flags byte,
// not the payload — pass them in as flags.
//
// It returns an error when the payload is shorter than the 12-byte fixed part
// or when name_len claims more bytes than the payload holds. Bytes after the
// name are ignored.
func DecodeHello(payload []byte, flags uint8) (HelloPayload, error) {
	const fixedLen = 12 // u64 file_size + u32 name_len
	if len(payload) < fixedLen {
		return HelloPayload{}, errors.New("wire: hello payload too short")
	}
	size := binary.BigEndian.Uint64(payload[0:8])
	nameLen := binary.BigEndian.Uint32(payload[8:12])
	// Compare in uint64: converting nameLen to int first would wrap negative
	// on 32-bit platforms for values >= 2^31, slip past the check and make
	// the slice expression below panic.
	if uint64(nameLen) > uint64(len(payload)-fixedLen) {
		return HelloPayload{}, fmt.Errorf("wire: hello name_len %d exceeds payload", nameLen)
	}
	nameEnd := fixedLen + int(nameLen) // safe: bounded by len(payload)
	return HelloPayload{
		FileSize:    size,
		Transcoding: flags&FlagTranscoding != 0,
		Seekable:    flags&FlagSeekable != 0,
		FileName:    string(payload[fixedLen:nameEnd]),
	}, nil
}
// HelloFlags returns the flag byte for a Hello frame given the booleans.
func HelloFlags(transcoding, seekable bool) uint8 {
var f uint8
if transcoding {
f |= FlagTranscoding
}
if seekable {
f |= FlagSeekable
}
return f
}
// RangeReqPayload is the browser → daemon request for bytes
// [Offset, Offset+Length).
type RangeReqPayload struct {
	Offset uint64
	Length uint64
}

// EncodeRangeReq serialises p into its fixed 16-byte form.
// Layout: u64 offset | u64 length (both big-endian).
func EncodeRangeReq(p RangeReqPayload) []byte {
	var out [16]byte
	binary.BigEndian.PutUint64(out[:8], p.Offset)
	binary.BigEndian.PutUint64(out[8:], p.Length)
	return out[:]
}

// DecodeRangeReq parses a range request payload; anything other than exactly
// 16 bytes is rejected.
func DecodeRangeReq(payload []byte) (RangeReqPayload, error) {
	if n := len(payload); n != 16 {
		return RangeReqPayload{}, fmt.Errorf("wire: range_req payload must be 16 bytes, got %d", n)
	}
	var req RangeReqPayload
	req.Offset = binary.BigEndian.Uint64(payload[0:8])
	req.Length = binary.BigEndian.Uint64(payload[8:16])
	return req, nil
}
// RangeEndPayload closes the response for a stream_id with a status code.
// Status 0 == OK; non-zero values are app-defined error codes.
type RangeEndPayload struct {
	Status uint32
}

// EncodeRangeEnd serialises p as a single big-endian u32.
func EncodeRangeEnd(p RangeEndPayload) []byte {
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], p.Status)
	return out[:]
}

// DecodeRangeEnd parses a range_end payload; anything other than exactly
// 4 bytes is rejected.
func DecodeRangeEnd(payload []byte) (RangeEndPayload, error) {
	if n := len(payload); n != 4 {
		return RangeEndPayload{}, fmt.Errorf("wire: range_end payload must be 4 bytes, got %d", n)
	}
	var end RangeEndPayload
	end.Status = binary.BigEndian.Uint32(payload[0:4])
	return end, nil
}
// SeekHintPayload tells the daemon a seek to timestamp_ms is imminent so it
// can pre-warm a transcoder pipeline before bytes are requested.
type SeekHintPayload struct {
	TimestampMs uint64
}

// EncodeSeekHint serialises p as a single big-endian u64.
func EncodeSeekHint(p SeekHintPayload) []byte {
	var out [8]byte
	binary.BigEndian.PutUint64(out[:], p.TimestampMs)
	return out[:]
}

// DecodeSeekHint parses a seek_hint payload; anything other than exactly
// 8 bytes is rejected.
func DecodeSeekHint(payload []byte) (SeekHintPayload, error) {
	if n := len(payload); n != 8 {
		return SeekHintPayload{}, fmt.Errorf("wire: seek_hint payload must be 8 bytes, got %d", n)
	}
	var hint SeekHintPayload
	hint.TimestampMs = binary.BigEndian.Uint64(payload[0:8])
	return hint, nil
}

View file

@ -0,0 +1,193 @@
package wire
import (
"bytes"
"testing"
)
// TestHeaderRoundtrip encodes one representative header per listed frame type
// into a HeaderSize buffer and checks that decoding yields an equal header.
func TestHeaderRoundtrip(t *testing.T) {
	headers := []Header{
		{Type: FrameHello, Flags: FlagSeekable, StreamID: 0, Length: 32},
		{Type: FrameRangeReq, Flags: 0, StreamID: 7, Length: 16},
		{Type: FrameRangeData, Flags: FlagLastChunk, StreamID: 4242, Length: 16380},
		{Type: FrameRangeEnd, Flags: 0, StreamID: 1, Length: 4},
		{Type: FrameCancel, Flags: 0, StreamID: 9, Length: 0},
		{Type: FramePing, Flags: 0, StreamID: 0, Length: 0},
	}
	for i := range headers {
		want := headers[i]
		raw := make([]byte, HeaderSize)
		EncodeHeader(raw, want)

		got, err := DecodeHeader(raw)
		switch {
		case err != nil:
			// A decode failure aborts the whole test, not just this case.
			t.Fatalf("decode: %v (want %+v)", err, want)
		case got != want:
			t.Errorf("roundtrip mismatch: got %+v want %+v", got, want)
		}
	}
}
// TestDecodeHeaderShort feeds DecodeHeader a 3-byte buffer and expects an
// error.
func TestDecodeHeaderShort(t *testing.T) {
	truncated := []byte{0, 0, 0}
	_, err := DecodeHeader(truncated)
	if err == nil {
		t.Fatal("expected error on short header")
	}
}
// TestDecodeHeaderRejectsHugeLength hand-builds a header whose payload length
// is above MaxFrameSize and expects DecodeHeader to refuse it.
func TestDecodeHeaderRejectsHugeLength(t *testing.T) {
	raw := make([]byte, HeaderSize)
	raw[0] = byte(FrameRangeData)
	// Saturate bytes 8..11 with 0xff to synthesize the oversized length.
	for i := 8; i <= 11; i++ {
		raw[i] = 0xff
	}

	_, err := DecodeHeader(raw)
	if err == nil {
		t.Fatal("expected error on oversized payload length")
	}
}
func TestEncodeFramePanicsOnLengthMismatch(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on header length / payload mismatch")
}
}()
EncodeFrame(Header{Type: FrameRangeData, Length: 5}, []byte{1, 2, 3})
}
// TestReadFrameRoundtrip encodes a frame with a 5-byte payload and checks
// that ReadFrame returns the same header and payload from a reader over it.
func TestReadFrameRoundtrip(t *testing.T) {
	body := []byte{0xde, 0xad, 0xbe, 0xef, 0x42}
	hdr := Header{Type: FrameRangeData, Flags: FlagLastChunk, StreamID: 99, Length: 5}

	src := bytes.NewReader(EncodeFrame(hdr, body))
	gotHdr, gotBody, err := ReadFrame(src)
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if gotHdr != hdr {
		t.Errorf("header mismatch: %+v want %+v", gotHdr, hdr)
	}
	if !bytes.Equal(gotBody, body) {
		t.Errorf("payload mismatch: %x want %x", gotBody, body)
	}
}
// TestReadFrameZeroPayload checks that a frame encoded with a nil payload
// reads back with the same header and an empty payload.
func TestReadFrameZeroPayload(t *testing.T) {
	hdr := Header{Type: FrameCancel, StreamID: 7}
	src := bytes.NewReader(EncodeFrame(hdr, nil))

	gotHdr, gotBody, err := ReadFrame(src)
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if gotHdr != hdr {
		t.Errorf("header mismatch: %+v want %+v", gotHdr, hdr)
	}
	if n := len(gotBody); n != 0 {
		t.Errorf("expected empty payload, got %d bytes", n)
	}
}
// TestHelloRoundtrip encodes a HelloPayload with a file size above 4 GiB,
// decodes it with the flags HelloFlags derives from the same payload, and
// checks the result equals the input.
func TestHelloRoundtrip(t *testing.T) {
	want := HelloPayload{
		FileName:    "Tangled.Ever.After.2025.1080p.WEB-DL.h264.mp4",
		FileSize:    1<<32 + 12345,
		Seekable:    true,
		Transcoding: false,
	}

	hdrFlags := HelloFlags(want.Transcoding, want.Seekable)
	encoded := EncodeHello(want)

	got, err := DecodeHello(encoded, hdrFlags)
	if err != nil {
		t.Fatalf("decode: %v", err)
	}
	if got != want {
		t.Errorf("hello mismatch: %+v want %+v", got, want)
	}
}
// TestHelloRejectsTruncatedPayload passes DecodeHello a 3-byte payload and
// expects an error.
func TestHelloRejectsTruncatedPayload(t *testing.T) {
	truncated := []byte{1, 2, 3}
	_, err := DecodeHello(truncated, 0)
	if err == nil {
		t.Fatal("expected error on truncated hello")
	}
}
// TestHelloRejectsNameLenOverrun builds a 12-byte payload (file_size plus
// name_len) that declares name_len=999 but carries no name bytes, and expects
// DecodeHello to fail.
func TestHelloRejectsNameLenOverrun(t *testing.T) {
	payload := make([]byte, 12)
	// 0x000003e7 == 999, written into bytes 8..11.
	copy(payload[8:], []byte{0, 0, 0x03, 0xe7})

	_, err := DecodeHello(payload, 0)
	if err == nil {
		t.Fatal("expected error on name_len overrun")
	}
}
// TestRangeReqRoundtrip checks that a range request survives an
// encode/decode cycle unchanged.
func TestRangeReqRoundtrip(t *testing.T) {
	want := RangeReqPayload{Offset: 1 << 30, Length: 1 << 20}
	encoded := EncodeRangeReq(want)

	got, err := DecodeRangeReq(encoded)
	switch {
	case err != nil:
		t.Fatalf("decode: %v", err)
	case got != want:
		t.Errorf("range_req mismatch: %+v want %+v", got, want)
	}
}
// TestRangeReqRejectsWrongLength checks that DecodeRangeReq refuses payloads
// one byte shorter and one byte longer than the required 16 bytes.
func TestRangeReqRejectsWrongLength(t *testing.T) {
	for _, size := range []int{15, 17} {
		_, err := DecodeRangeReq(make([]byte, size))
		if err == nil {
			t.Fatalf("expected error on %d-byte payload", size)
		}
	}
}
// TestRangeEndRoundtrip checks that a range_end payload survives an
// encode/decode cycle and that DecodeRangeEnd rejects payloads that are not
// exactly 4 bytes long, both too short and too long.
func TestRangeEndRoundtrip(t *testing.T) {
	want := RangeEndPayload{Status: 42}
	got, err := DecodeRangeEnd(EncodeRangeEnd(want))
	if err != nil {
		t.Fatalf("decode: %v", err)
	}
	if got != want {
		t.Errorf("range_end mismatch: %+v want %+v", got, want)
	}
	if _, err := DecodeRangeEnd(make([]byte, 3)); err == nil {
		t.Fatal("expected error on short range_end payload")
	}
	// The decoder demands an exact length, so trailing bytes must be
	// rejected too; without this case a ">= 4" regression would go unnoticed.
	if _, err := DecodeRangeEnd(make([]byte, 5)); err == nil {
		t.Fatal("expected error on oversized range_end payload")
	}
}
// TestSeekHintRoundtrip checks that a seek_hint payload survives an
// encode/decode cycle and that DecodeSeekHint rejects payloads that are not
// exactly 8 bytes long, both too short and too long.
func TestSeekHintRoundtrip(t *testing.T) {
	want := SeekHintPayload{TimestampMs: 123_456}
	got, err := DecodeSeekHint(EncodeSeekHint(want))
	if err != nil {
		t.Fatalf("decode: %v", err)
	}
	if got != want {
		t.Errorf("seek_hint mismatch: %+v want %+v", got, want)
	}
	if _, err := DecodeSeekHint(make([]byte, 7)); err == nil {
		t.Fatal("expected error on short seek_hint payload")
	}
	// The decoder demands an exact length, so trailing bytes must be
	// rejected too; without this case a ">= 8" regression would go unnoticed.
	if _, err := DecodeSeekHint(make([]byte, 9)); err == nil {
		t.Fatal("expected error on oversized seek_hint payload")
	}
}
func TestHelloFlagsHelper(t *testing.T) {
if HelloFlags(false, false) != 0 {
t.Error("expected 0 for both false")
}
if HelloFlags(true, false) != FlagTranscoding {
t.Error("expected FlagTranscoding only")
}
if HelloFlags(false, true) != FlagSeekable {
t.Error("expected FlagSeekable only")
}
if HelloFlags(true, true) != (FlagTranscoding | FlagSeekable) {
t.Error("expected both flags")
}
}
// TestMaxChunkFitsInMaxFrame is a sanity check that a maximum-size chunk
// payload plus its header still fits inside MaxFrameSize, so callers can rely
// on the chunk cap without doing their own bookkeeping.
func TestMaxChunkFitsInMaxFrame(t *testing.T) {
	if MaxChunkPayload+HeaderSize <= MaxFrameSize {
		return
	}
	t.Fatalf("chunk %d + hdr %d > max frame %d", MaxChunkPayload, HeaderSize, MaxFrameSize)
}