feat(stream): pion-based WebRTC byte streamer for browser playback
Replaces the broken anacrolix WebTorrent path with a custom WebRTC peer that the browser drives directly. Architecture matches plan/clever- weaving-dove.md (Fase 2 + 3 + 6 of the streaming pivot). - engine/wire: shared 12-byte binary frame format (Hello / RangeReq / RangeData / RangeEnd / Cancel / Ping / Pong / SeekHint). Roundtrip + oversized-frame rejection tests. - agent/signal_client: SSE consumer + POST sender for SDP/ICE relay through /api/internal/stream/signal/<id>; auto-reconnects. - engine/webrtc_stream: pion v4 PeerConnection + DataChannel pump. Reads file via os.ReadAt, chunks RangeData at 16 KiB, honours app- level backpressure with SetBufferedAmountLowThreshold. - cmd/daemon dispatcher learns mode webrtc_stream + new webrtcSessionRegistry tracks per-session cancel funcs for clean shutdown. - engine/probe + hwaccel + transcoder: foundation for Fase 2.5 (codec detection, NVENC/QSV/VAAPI/VideoToolbox autodetection, ffmpeg pipe wrapper to fragmented MP4). Integration into webrtc_stream still pending. - pion/webrtc/v4 promoted from indirect to direct dep. End-to-end against unarr-dev confirms a 122 MB 1080p H.264 / AAC MP4 plays in Chrome with the new pipeline.
This commit is contained in:
parent
4c52d9b039
commit
4314c06c5c
17 changed files with 2308 additions and 1 deletions
|
|
@ -35,6 +35,7 @@ type Daemon struct {
|
|||
// Callbacks — set by cmd/daemon.go before calling Run.
|
||||
OnTasksClaimed func(tasks []Task)
|
||||
OnStreamRequested func(req StreamRequest)
|
||||
OnWebRTCSession func(sess WebRTCSession)
|
||||
OnControlAction func(action, taskID string, deleteFiles bool)
|
||||
GetActiveCount func() int // returns number of active downloads (wired from manager)
|
||||
|
||||
|
|
@ -169,6 +170,11 @@ func (d *Daemon) Run(ctx context.Context) error {
|
|||
d.OnStreamRequested(req)
|
||||
}
|
||||
}
|
||||
d.sync.OnWebRTCSession = func(sess WebRTCSession) {
|
||||
if d.OnWebRTCSession != nil {
|
||||
d.OnWebRTCSession(sess)
|
||||
}
|
||||
}
|
||||
d.sync.OnUpgrade = func(version string) {
|
||||
if version != d.lastNotifiedVersion {
|
||||
d.lastNotifiedVersion = version
|
||||
|
|
|
|||
233
internal/agent/signal_client.go
Normal file
233
internal/agent/signal_client.go
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SignalRole identifies who produced a signalling message. The opposite role
|
||||
// receives it.
|
||||
type SignalRole string
|
||||
|
||||
const (
|
||||
SignalRoleBrowser SignalRole = "browser"
|
||||
SignalRoleAgent SignalRole = "agent"
|
||||
)
|
||||
|
||||
// SignalMessageType matches the server-side z.enum on
|
||||
// /api/internal/stream/signal/[sessionId] route.
|
||||
type SignalMessageType string
|
||||
|
||||
const (
|
||||
SignalMsgOffer SignalMessageType = "offer"
|
||||
SignalMsgAnswer SignalMessageType = "answer"
|
||||
SignalMsgCandidate SignalMessageType = "candidate"
|
||||
SignalMsgCandidateEnd SignalMessageType = "candidate-end"
|
||||
SignalMsgBye SignalMessageType = "bye"
|
||||
)
|
||||
|
||||
// SignalMessage mirrors the bus envelope on the web side.
|
||||
type SignalMessage struct {
|
||||
From SignalRole `json:"from"`
|
||||
Type SignalMessageType `json:"type"`
|
||||
Payload string `json:"payload"`
|
||||
TS int64 `json:"ts"`
|
||||
}
|
||||
|
||||
// PostSignal enqueues a signalling message produced by this agent. The
|
||||
// browser receives it on its next SSE event push.
|
||||
func (c *Client) PostSignal(ctx context.Context, sessionID string, msg SignalMessage) error {
|
||||
body := map[string]any{
|
||||
"from": string(SignalRoleAgent),
|
||||
"type": string(msg.Type),
|
||||
"payload": msg.Payload,
|
||||
}
|
||||
path := fmt.Sprintf("/api/internal/stream/signal/%s", sessionID)
|
||||
return c.doPost(ctx, path, body, &struct {
|
||||
OK bool `json:"ok"`
|
||||
}{})
|
||||
}
|
||||
|
||||
// SignalEventStream wraps an open SSE connection. Read messages from Events()
|
||||
// until the channel closes (server timeout or context cancel). Always defer
|
||||
// Close() to release the underlying response body.
|
||||
type SignalEventStream struct {
|
||||
resp *http.Response
|
||||
cancel context.CancelFunc
|
||||
events chan SignalMessage
|
||||
errs chan error
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// Events streams browser-produced messages addressed to the agent.
|
||||
// The channel closes when the SSE connection ends; the caller should then
|
||||
// call Close() and reopen if it wants to keep listening.
|
||||
func (s *SignalEventStream) Events() <-chan SignalMessage { return s.events }
|
||||
|
||||
// Err returns the terminating error (if any) once Events() has closed.
|
||||
func (s *SignalEventStream) Err() error {
|
||||
select {
|
||||
case err := <-s.errs:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close cancels the underlying HTTP request and waits for the reader goroutine
|
||||
// to drain. Safe to call more than once.
|
||||
func (s *SignalEventStream) Close() error {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
if s.resp != nil {
|
||||
s.resp.Body.Close()
|
||||
}
|
||||
<-s.done
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenSignalStream opens a long-lived SSE connection to the signal events
|
||||
// endpoint. Caller MUST cancel ctx (or call Close()) to free resources.
|
||||
//
|
||||
// The server caps each response at ~25 s; OpenSignalStream surfaces the
|
||||
// disconnect by closing the events channel. Caller should reopen until the
|
||||
// session ends.
|
||||
func (c *Client) OpenSignalStream(ctx context.Context, sessionID string) (*SignalEventStream, error) {
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
url := fmt.Sprintf("%s/api/internal/stream/signal/%s/events", c.baseURL, sessionID)
|
||||
req, err := http.NewRequestWithContext(streamCtx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("open signal stream: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
req.Header.Set("User-Agent", c.userAgent)
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
// Use a per-call client with no timeout (SSE connections are long).
|
||||
sseClient := &http.Client{}
|
||||
resp, err := sseClient.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("open signal stream: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
resp.Body.Close()
|
||||
cancel()
|
||||
return nil, fmt.Errorf("open signal stream: HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
stream := &SignalEventStream{
|
||||
resp: resp,
|
||||
cancel: cancel,
|
||||
events: make(chan SignalMessage, 8),
|
||||
errs: make(chan error, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go stream.read()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (s *SignalEventStream) read() {
|
||||
defer close(s.done)
|
||||
defer close(s.events)
|
||||
|
||||
reader := bufio.NewReaderSize(s.resp.Body, 16*1024)
|
||||
var dataBuf bytes.Buffer
|
||||
var eventName string
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
select {
|
||||
case s.errs <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
// End of an event — dispatch if we have data.
|
||||
if dataBuf.Len() == 0 {
|
||||
eventName = ""
|
||||
continue
|
||||
}
|
||||
if eventName == "" || eventName == "signal" {
|
||||
var msg SignalMessage
|
||||
if err := json.Unmarshal(dataBuf.Bytes(), &msg); err == nil {
|
||||
s.events <- msg
|
||||
}
|
||||
}
|
||||
dataBuf.Reset()
|
||||
eventName = ""
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, ":") {
|
||||
// SSE comment (heartbeat); ignore.
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
eventName = strings.TrimSpace(line[len("event:"):])
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
payload := strings.TrimSpace(line[len("data:"):])
|
||||
if dataBuf.Len() > 0 {
|
||||
dataBuf.WriteByte('\n')
|
||||
}
|
||||
dataBuf.WriteString(payload)
|
||||
continue
|
||||
}
|
||||
// id:, retry:, anything else — ignore for now.
|
||||
}
|
||||
}
|
||||
|
||||
// SignalLoop runs an SSE consumer that reconnects automatically on disconnect.
|
||||
// onMessage is called for every browser-produced message. Returns when ctx is
|
||||
// cancelled. Reconnect backoff is fixed at 1 s — the server already paces
|
||||
// reconnects with `retry: 1500` headers so churn is bounded.
|
||||
func (c *Client) SignalLoop(ctx context.Context, sessionID string, onMessage func(SignalMessage)) error {
|
||||
for ctx.Err() == nil {
|
||||
stream, err := c.OpenSignalStream(ctx, sessionID)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
for msg := range stream.Events() {
|
||||
onMessage(msg)
|
||||
}
|
||||
streamErr := stream.Err()
|
||||
stream.Close()
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
// Server closes the SSE every ~25 s; reconnect immediately.
|
||||
// Hard error → small backoff so we don't hammer.
|
||||
if streamErr != nil {
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
153
internal/agent/signal_client_test.go
Normal file
153
internal/agent/signal_client_test.go
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeSSEServer streams a fixed set of SSE events then closes the connection.
|
||||
func fakeSSEServer(t *testing.T, msgs []SignalMessage, holdOpenAfter bool) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||
http.Error(w, "auth", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatal("server: ResponseWriter is not a Flusher")
|
||||
}
|
||||
fmt.Fprint(w, "retry: 1500\n\n")
|
||||
flusher.Flush()
|
||||
for _, m := range msgs {
|
||||
data, _ := json.Marshal(m)
|
||||
fmt.Fprintf(w, "id: %d\nevent: signal\ndata: %s\n\n", m.TS, data)
|
||||
flusher.Flush()
|
||||
}
|
||||
// Send a heartbeat comment to verify it's ignored.
|
||||
fmt.Fprint(w, ": heartbeat\n\n")
|
||||
flusher.Flush()
|
||||
if holdOpenAfter {
|
||||
// Hold the connection until the client disconnects so the test can
|
||||
// exercise stream.Close().
|
||||
<-r.Context().Done()
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func TestSignalStreamReadsMessages(t *testing.T) {
|
||||
want := []SignalMessage{
|
||||
{From: SignalRoleBrowser, Type: SignalMsgOffer, Payload: "{sdp:1}", TS: 1},
|
||||
{From: SignalRoleBrowser, Type: SignalMsgCandidate, Payload: "{cand:1}", TS: 2},
|
||||
}
|
||||
srv := fakeSSEServer(t, want, false)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "test-ua")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := c.OpenSignalStream(ctx, "session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("open: %v", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var got []SignalMessage
|
||||
for m := range stream.Events() {
|
||||
got = append(got, m)
|
||||
if len(got) == len(want) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("got %d messages, want %d", len(got), len(want))
|
||||
}
|
||||
for i, m := range got {
|
||||
if m.From != want[i].From || m.Type != want[i].Type || m.Payload != want[i].Payload {
|
||||
t.Errorf("[%d] mismatch: %+v want %+v", i, m, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalStreamPropagatesAuthError(t *testing.T) {
|
||||
srv := fakeSSEServer(t, nil, false)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "wrong-key", "test-ua")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.OpenSignalStream(ctx, "session-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected auth error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalStreamCloseCancelsRead(t *testing.T) {
|
||||
srv := fakeSSEServer(t, nil, true)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "test-ua")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := c.OpenSignalStream(ctx, "session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("open: %v", err)
|
||||
}
|
||||
|
||||
// Close on a separate goroutine then make sure the events channel drains.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
for range stream.Events() {
|
||||
// drain
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPostSignalSendsCorrectBody(t *testing.T) {
|
||||
var bodySeen map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||
http.Error(w, "auth", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&bodySeen)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"ok":true}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "test-ua")
|
||||
err := c.PostSignal(context.Background(), "sess-x", SignalMessage{
|
||||
Type: SignalMsgAnswer,
|
||||
Payload: "{sdp:answer}",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("post: %v", err)
|
||||
}
|
||||
if bodySeen["from"] != string(SignalRoleAgent) {
|
||||
t.Errorf("expected from=agent, got %v", bodySeen["from"])
|
||||
}
|
||||
if bodySeen["type"] != string(SignalMsgAnswer) {
|
||||
t.Errorf("expected type=answer, got %v", bodySeen["type"])
|
||||
}
|
||||
if bodySeen["payload"] != "{sdp:answer}" {
|
||||
t.Errorf("expected payload mismatch, got %v", bodySeen["payload"])
|
||||
}
|
||||
}
|
||||
|
|
@ -29,6 +29,7 @@ type SyncClient struct {
|
|||
OnNewTasks func(tasks []Task)
|
||||
OnControl func(action, taskID string, deleteFiles bool)
|
||||
OnStreamRequest func(req StreamRequest)
|
||||
OnWebRTCSession func(sess WebRTCSession)
|
||||
OnUpgrade func(version string)
|
||||
OnScan func()
|
||||
OnWatchingChange func(watching bool)
|
||||
|
|
@ -191,6 +192,13 @@ func (sc *SyncClient) processResponse(resp *SyncResponse) {
|
|||
}
|
||||
}
|
||||
|
||||
// WebRTC streaming sessions
|
||||
for _, ws := range resp.WebRTCSessions {
|
||||
if sc.OnWebRTCSession != nil {
|
||||
sc.OnWebRTCSession(ws)
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade
|
||||
if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil {
|
||||
sc.OnUpgrade(resp.Upgrade.Version)
|
||||
|
|
|
|||
|
|
@ -351,11 +351,25 @@ type LibraryDeleteRequest struct {
|
|||
FilePath string `json:"filePath"`
|
||||
}
|
||||
|
||||
// WebRTCSession is a request to open a custom WebRTC DataChannel byte-stream
|
||||
// to a browser player. The CLI must POST an SDP answer to
|
||||
// /api/internal/stream/signal/<sessionId> and serve bytes from FilePath
|
||||
// (or, when only InfoHash is set, from a download_task on disk).
|
||||
type WebRTCSession struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
FilePath string `json:"filePath,omitempty"`
|
||||
InfoHash string `json:"infoHash,omitempty"`
|
||||
TaskID string `json:"taskId,omitempty"`
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
FileSize int64 `json:"fileSize,omitempty"`
|
||||
}
|
||||
|
||||
// SyncResponse is returned by the server with all pending actions for the CLI.
|
||||
type SyncResponse struct {
|
||||
NewTasks []Task `json:"newTasks,omitempty"`
|
||||
Controls []ControlAction `json:"controls,omitempty"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
WebRTCSessions []WebRTCSession `json:"webrtcSessions,omitempty"`
|
||||
Watching bool `json:"watching"`
|
||||
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
|
||||
Scan bool `json:"scan,omitempty"`
|
||||
|
|
|
|||
|
|
@ -410,6 +410,65 @@ func runDaemonStart() error {
|
|||
}()
|
||||
}
|
||||
|
||||
// Wire: sync receives custom WebRTC streaming session requests.
|
||||
// Each session is a one-shot browser↔daemon DataChannel. Validate the
|
||||
// FilePath against allowed dirs to prevent path traversal abuse from a
|
||||
// compromised server, then spawn the pion peer in its own goroutine.
|
||||
d.OnWebRTCSession = func(sess agent.WebRTCSession) {
|
||||
if webrtcRegistry.has(sess.SessionID) {
|
||||
return // already running
|
||||
}
|
||||
if !cfg.Download.WebRTC.Enabled {
|
||||
log.Printf("webrtc session %s rejected: webrtc disabled in config", agent.ShortID(sess.SessionID))
|
||||
return
|
||||
}
|
||||
filePath := sess.FilePath
|
||||
if filePath == "" {
|
||||
log.Printf("webrtc session %s rejected: empty file path", agent.ShortID(sess.SessionID))
|
||||
return
|
||||
}
|
||||
filePath = filepath.Clean(filePath)
|
||||
if !isAllowedStreamPath(filePath, cfg.Download.Dir, cfg.Library.ScanPath,
|
||||
cfg.Organize.MoviesDir, cfg.Organize.TVShowsDir) {
|
||||
log.Printf("webrtc session %s rejected: path outside allowed dirs: %s",
|
||||
agent.ShortID(sess.SessionID), filePath)
|
||||
return
|
||||
}
|
||||
// Resolve directory → first video file (matches StreamRequest behavior).
|
||||
if info, err := os.Stat(filePath); err == nil && info.IsDir() {
|
||||
found := engine.FindVideoFile(filePath)
|
||||
if found == "" {
|
||||
log.Printf("webrtc session %s rejected: no video file in dir %s",
|
||||
agent.ShortID(sess.SessionID), filePath)
|
||||
return
|
||||
}
|
||||
filePath = found
|
||||
}
|
||||
sessCtx, sessCancel := context.WithCancel(ctx) //nolint:gosec // G118 cancel stored in registry
|
||||
webrtcRegistry.add(sess.SessionID, sessCancel)
|
||||
go func() {
|
||||
defer func() {
|
||||
webrtcRegistry.remove(sess.SessionID)
|
||||
sessCancel()
|
||||
}()
|
||||
runCfg := engine.WebRTCStreamConfig{
|
||||
SessionID: sess.SessionID,
|
||||
FilePath: filePath,
|
||||
FileName: sess.FileName,
|
||||
FileSize: sess.FileSize,
|
||||
ICEServers: engine.BuildICEServers(cfg.Download.WebRTC),
|
||||
Signal: agentClient,
|
||||
Logger: stdLogger{},
|
||||
}
|
||||
log.Printf("[wrtc %s] starting session: %s", agent.ShortID(sess.SessionID), filepath.Base(filePath))
|
||||
if err := engine.RunWebRTCStream(sessCtx, runCfg); err != nil {
|
||||
if sessCtx.Err() == nil {
|
||||
log.Printf("[wrtc %s] ended: %v", agent.ShortID(sess.SessionID), err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Periodic DHT node persistence (every 5 min)
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
|
|
@ -457,6 +516,7 @@ func runDaemonStart() error {
|
|||
case sig := <-sigCh:
|
||||
fmt.Printf("\n Received %s, shutting down...\n", sig)
|
||||
cancelStreamContexts()
|
||||
cancelAllWebRTCSessions()
|
||||
streamSrv.Shutdown(context.Background())
|
||||
cancel()
|
||||
|
||||
|
|
@ -471,6 +531,7 @@ func runDaemonStart() error {
|
|||
|
||||
case err := <-errCh:
|
||||
cancelStreamContexts()
|
||||
cancelAllWebRTCSessions()
|
||||
streamSrv.Shutdown(context.Background())
|
||||
cancel()
|
||||
return err
|
||||
|
|
|
|||
62
internal/cmd/webrtc_session_registry.go
Normal file
62
internal/cmd/webrtc_session_registry.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// webrtcRegistry tracks per-session cancel funcs for active custom WebRTC
|
||||
// streams (engine.RunWebRTCStream goroutines). Each session lives only as
|
||||
// long as its DataChannel; the registry exists so duplicate sync responses
|
||||
// don't double-spawn the same session and so daemon shutdown can drain.
|
||||
var webrtcRegistry = &webrtcSessionRegistry{
|
||||
cancels: make(map[string]context.CancelFunc),
|
||||
}
|
||||
|
||||
type webrtcSessionRegistry struct {
|
||||
mu sync.Mutex
|
||||
cancels map[string]context.CancelFunc
|
||||
}
|
||||
|
||||
func (r *webrtcSessionRegistry) has(sessionID string) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
_, ok := r.cancels[sessionID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *webrtcSessionRegistry) add(sessionID string, cancel context.CancelFunc) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.cancels[sessionID] = cancel
|
||||
}
|
||||
|
||||
func (r *webrtcSessionRegistry) remove(sessionID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.cancels, sessionID)
|
||||
}
|
||||
|
||||
// cancelAllWebRTCSessions cancels every running session. Called on daemon
|
||||
// shutdown so pion peers and SSE consumers exit cleanly.
|
||||
func cancelAllWebRTCSessions() {
|
||||
webrtcRegistry.mu.Lock()
|
||||
cancels := make([]context.CancelFunc, 0, len(webrtcRegistry.cancels))
|
||||
for _, c := range webrtcRegistry.cancels {
|
||||
cancels = append(cancels, c)
|
||||
}
|
||||
webrtcRegistry.cancels = make(map[string]context.CancelFunc)
|
||||
webrtcRegistry.mu.Unlock()
|
||||
for _, c := range cancels {
|
||||
c()
|
||||
}
|
||||
}
|
||||
|
||||
// stdLogger is a tiny adapter so engine.RunWebRTCStream can log through the
|
||||
// standard library logger without pulling in a logging dependency.
|
||||
type stdLogger struct{}
|
||||
|
||||
func (stdLogger) Infof(format string, args ...any) { log.Printf(format, args...) }
|
||||
func (stdLogger) Warnf(format string, args ...any) { log.Printf("WARN: "+format, args...) }
|
||||
func (stdLogger) Errorf(format string, args ...any) { log.Printf("ERROR: "+format, args...) }
|
||||
130
internal/engine/hwaccel.go
Normal file
130
internal/engine/hwaccel.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HWAccel identifies a hardware-accelerated ffmpeg encoder family.
|
||||
type HWAccel string
|
||||
|
||||
const (
|
||||
HWAccelNone HWAccel = "none"
|
||||
HWAccelNVENC HWAccel = "nvenc" // NVIDIA — h264_nvenc / hevc_nvenc
|
||||
HWAccelQSV HWAccel = "qsv" // Intel Quick Sync — h264_qsv / hevc_qsv
|
||||
HWAccelVAAPI HWAccel = "vaapi" // Linux open-source — h264_vaapi / hevc_vaapi
|
||||
HWAccelVideoToolbox HWAccel = "videotoolbox" // macOS — h264_videotoolbox
|
||||
)
|
||||
|
||||
var (
|
||||
hwOnce sync.Once
|
||||
hwCache HWAccel
|
||||
)
|
||||
|
||||
// DetectHWAccel returns the most capable hardware encoder available on this
|
||||
// host, or HWAccelNone if software-only. Cached after first call — adding /
|
||||
// removing a GPU at runtime is rare and the cost of probing isn't free.
|
||||
func DetectHWAccel(ctx context.Context, ffmpegPath string) HWAccel {
|
||||
hwOnce.Do(func() {
|
||||
hwCache = detectHWAccelFresh(ctx, ffmpegPath)
|
||||
})
|
||||
return hwCache
|
||||
}
|
||||
|
||||
// ResetHWAccelCache clears the singleton — only used in tests.
|
||||
func ResetHWAccelCache() {
|
||||
hwOnce = sync.Once{}
|
||||
hwCache = ""
|
||||
}
|
||||
|
||||
func detectHWAccelFresh(ctx context.Context, ffmpegPath string) HWAccel {
|
||||
if ffmpegPath == "" {
|
||||
return HWAccelNone
|
||||
}
|
||||
encoders := listFFmpegEncoders(ctx, ffmpegPath)
|
||||
if encoders == "" {
|
||||
return HWAccelNone
|
||||
}
|
||||
|
||||
// macOS — VideoToolbox is always available on Apple Silicon + recent Intel.
|
||||
if runtime.GOOS == "darwin" && strings.Contains(encoders, "h264_videotoolbox") {
|
||||
return HWAccelVideoToolbox
|
||||
}
|
||||
|
||||
// NVIDIA — encoder presence + a CUDA-capable device. We rely on the
|
||||
// existence of the device file rather than running nvidia-smi to keep
|
||||
// startup quick on hosts without nvidia tooling.
|
||||
if strings.Contains(encoders, "h264_nvenc") &&
|
||||
(fileExists("/dev/nvidia0") || hasNvidiaDriver()) {
|
||||
return HWAccelNVENC
|
||||
}
|
||||
|
||||
// Intel Quick Sync — needs /dev/dri (also used by VA-API). Distinguish by
|
||||
// checking whether the QSV-specific encoder is built in.
|
||||
if strings.Contains(encoders, "h264_qsv") && fileExists("/dev/dri/renderD128") {
|
||||
return HWAccelQSV
|
||||
}
|
||||
|
||||
// Linux generic VA-API — works on Intel + AMD with mesa drivers.
|
||||
if strings.Contains(encoders, "h264_vaapi") && fileExists("/dev/dri/renderD128") {
|
||||
return HWAccelVAAPI
|
||||
}
|
||||
|
||||
return HWAccelNone
|
||||
}
|
||||
|
||||
func listFFmpegEncoders(ctx context.Context, ffmpegPath string) string {
|
||||
cmd := exec.CommandContext(ctx, ffmpegPath, "-hide_banner", "-encoders")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func hasNvidiaDriver() bool {
|
||||
// Cheap proxy — if the user has nvidia-smi on PATH they presumably also
|
||||
// have a working driver / runtime libraries.
|
||||
_, err := exec.LookPath("nvidia-smi")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// FFmpegVideoCodec returns the encoder name to pass to `-c:v` for the
|
||||
// requested HW accel + target (h264 or hevc).
|
||||
func (h HWAccel) FFmpegVideoCodec(target string) string {
|
||||
target = strings.ToLower(target)
|
||||
switch h {
|
||||
case HWAccelNVENC:
|
||||
if target == "hevc" {
|
||||
return "hevc_nvenc"
|
||||
}
|
||||
return "h264_nvenc"
|
||||
case HWAccelQSV:
|
||||
if target == "hevc" {
|
||||
return "hevc_qsv"
|
||||
}
|
||||
return "h264_qsv"
|
||||
case HWAccelVAAPI:
|
||||
if target == "hevc" {
|
||||
return "hevc_vaapi"
|
||||
}
|
||||
return "h264_vaapi"
|
||||
case HWAccelVideoToolbox:
|
||||
if target == "hevc" {
|
||||
return "hevc_videotoolbox"
|
||||
}
|
||||
return "h264_videotoolbox"
|
||||
default:
|
||||
// Software fallback. libx264 ships with every ffmpeg build.
|
||||
return "libx264"
|
||||
}
|
||||
}
|
||||
34
internal/engine/hwaccel_test.go
Normal file
34
internal/engine/hwaccel_test.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package engine
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHWAccelFFmpegVideoCodec(t *testing.T) {
|
||||
cases := []struct {
|
||||
hw HWAccel
|
||||
target string
|
||||
want string
|
||||
}{
|
||||
{HWAccelNone, "h264", "libx264"},
|
||||
{HWAccelNone, "hevc", "libx264"},
|
||||
{HWAccelNVENC, "h264", "h264_nvenc"},
|
||||
{HWAccelNVENC, "hevc", "hevc_nvenc"},
|
||||
{HWAccelQSV, "h264", "h264_qsv"},
|
||||
{HWAccelQSV, "hevc", "hevc_qsv"},
|
||||
{HWAccelVAAPI, "h264", "h264_vaapi"},
|
||||
{HWAccelVAAPI, "hevc", "hevc_vaapi"},
|
||||
{HWAccelVideoToolbox, "h264", "h264_videotoolbox"},
|
||||
{HWAccelVideoToolbox, "hevc", "hevc_videotoolbox"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := tc.hw.FFmpegVideoCodec(tc.target); got != tc.want {
|
||||
t.Errorf("%s.FFmpegVideoCodec(%q) = %q want %q", tc.hw, tc.target, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectHWAccelEmptyPathReturnsNone(t *testing.T) {
|
||||
ResetHWAccelCache()
|
||||
if got := detectHWAccelFresh(t.Context(), ""); got != HWAccelNone {
|
||||
t.Errorf("got %s, want %s", got, HWAccelNone)
|
||||
}
|
||||
}
|
||||
116
internal/engine/probe.go
Normal file
116
internal/engine/probe.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/library/mediainfo"
|
||||
)
|
||||
|
||||
// StreamProbe summarises the codec / container shape of a file as it relates
|
||||
// to the WebRTC streaming pipeline. It tells the transcoder whether bytes can
|
||||
// be streamed as-is, just remuxed to fragmented MP4, or fully transcoded.
|
||||
type StreamProbe struct {
|
||||
// VideoCodec lowercased — e.g. "h264", "hevc", "av1", "vp9", "mpeg4".
|
||||
VideoCodec string
|
||||
// AudioCodec lowercased — e.g. "aac", "ac3", "dts", "eac3", "opus".
|
||||
AudioCodec string
|
||||
// Width / Height of the primary video stream.
|
||||
Width int
|
||||
Height int
|
||||
// BitDepth — 8, 10 or 12. 0 if unknown.
|
||||
BitDepth int
|
||||
// HDR signalling string ("HDR10" / "DV" / "HLG" / etc, or "" for SDR).
|
||||
HDR string
|
||||
// DurationSec is the file length, used to sanity-check seek targets.
|
||||
DurationSec float64
|
||||
// Container is the file extension lowercased (".mp4", ".mkv", ".avi").
|
||||
Container string
|
||||
}
|
||||
|
||||
// TranscodeAction tells the streaming pipeline how to feed the file to
|
||||
// the browser <video> element. The decision matrix is documented in the
|
||||
// project plan (Fase 2.5 — Transcoding on-the-fly).
|
||||
type TranscodeAction string
|
||||
|
||||
const (
|
||||
// ActionPassthrough — file is already browser-playable as-is. Stream the
|
||||
// raw bytes via ReadAt; no ffmpeg involved.
|
||||
ActionPassthrough TranscodeAction = "passthrough"
|
||||
// ActionRemux — codecs are browser-compatible but the container or moov
|
||||
// placement is not. Run ffmpeg with `-c copy -movflags frag_keyframe`.
|
||||
ActionRemux TranscodeAction = "remux"
|
||||
// ActionRemuxAudio — video is fine but audio needs a re-encode (AC3/DTS
|
||||
// → AAC). `-c:v copy -c:a aac`.
|
||||
ActionRemuxAudio TranscodeAction = "remux-audio"
|
||||
// ActionTranscodeVideo — full re-encode. Used for HEVC/AV1 and any
|
||||
// 10-bit content if the browser refuses the codec.
|
||||
ActionTranscodeVideo TranscodeAction = "transcode-video"
|
||||
)
|
||||
|
||||
// ProbeFile runs ffprobe and returns a StreamProbe view of the file.
|
||||
func ProbeFile(ctx context.Context, ffprobePath, filePath string) (*StreamProbe, error) {
|
||||
mi, err := mediainfo.ExtractMediaInfo(ctx, ffprobePath, filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("probe: %w", err)
|
||||
}
|
||||
probe := &StreamProbe{Container: lowerExt(filePath)}
|
||||
if mi.Video != nil {
|
||||
probe.VideoCodec = strings.ToLower(mi.Video.Codec)
|
||||
probe.Width = mi.Video.Width
|
||||
probe.Height = mi.Video.Height
|
||||
probe.BitDepth = mi.Video.BitDepth
|
||||
probe.HDR = mi.Video.HDR
|
||||
probe.DurationSec = mi.Video.Duration
|
||||
}
|
||||
if len(mi.Audio) > 0 {
|
||||
// Default to the first track marked "Default", else the first track.
|
||||
picked := mi.Audio[0]
|
||||
for _, a := range mi.Audio {
|
||||
if a.Default {
|
||||
picked = a
|
||||
break
|
||||
}
|
||||
}
|
||||
probe.AudioCodec = strings.ToLower(picked.Codec)
|
||||
}
|
||||
return probe, nil
|
||||
}
|
||||
|
||||
// DecideAction maps a probe to the transcoding action the streaming pipeline
|
||||
// should take. Browsers consume MP4/h264+AAC natively; everything else needs
|
||||
// some level of re-shaping.
|
||||
func DecideAction(p *StreamProbe) TranscodeAction {
|
||||
if p == nil {
|
||||
return ActionPassthrough
|
||||
}
|
||||
video := p.VideoCodec
|
||||
audio := p.AudioCodec
|
||||
container := p.Container
|
||||
|
||||
// 10-bit / HDR is a hard no for browser playback even if h264 — needs SW transcode.
|
||||
tenBitOrHDR := p.BitDepth >= 10 || p.HDR != ""
|
||||
|
||||
if !tenBitOrHDR && video == "h264" {
|
||||
if audio == "aac" {
|
||||
if container == ".mp4" {
|
||||
return ActionPassthrough
|
||||
}
|
||||
return ActionRemux
|
||||
}
|
||||
// Audio incompatible (AC3/DTS/TrueHD/EAC3) → remux video, transcode audio.
|
||||
return ActionRemuxAudio
|
||||
}
|
||||
|
||||
// HEVC / AV1 / VP9 / 10-bit / unknown → full re-encode video.
|
||||
return ActionTranscodeVideo
|
||||
}
|
||||
|
||||
func lowerExt(filePath string) string {
|
||||
dot := strings.LastIndex(filePath, ".")
|
||||
if dot < 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(filePath[dot:])
|
||||
}
|
||||
96
internal/engine/probe_test.go
Normal file
96
internal/engine/probe_test.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package engine
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDecideAction(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
p StreamProbe
|
||||
want TranscodeAction
|
||||
}{
|
||||
{
|
||||
name: "MP4 + h264 + AAC = passthrough",
|
||||
p: StreamProbe{VideoCodec: "h264", AudioCodec: "aac", Container: ".mp4"},
|
||||
want: ActionPassthrough,
|
||||
},
|
||||
{
|
||||
name: "MKV + h264 + AAC = remux",
|
||||
p: StreamProbe{VideoCodec: "h264", AudioCodec: "aac", Container: ".mkv"},
|
||||
want: ActionRemux,
|
||||
},
|
||||
{
|
||||
name: "MKV + h264 + AC3 = remux audio",
|
||||
p: StreamProbe{VideoCodec: "h264", AudioCodec: "ac3", Container: ".mkv"},
|
||||
want: ActionRemuxAudio,
|
||||
},
|
||||
{
|
||||
name: "MP4 + h264 + EAC3 = remux audio",
|
||||
p: StreamProbe{VideoCodec: "h264", AudioCodec: "eac3", Container: ".mp4"},
|
||||
want: ActionRemuxAudio,
|
||||
},
|
||||
{
|
||||
name: "MKV + HEVC = transcode video",
|
||||
p: StreamProbe{VideoCodec: "hevc", AudioCodec: "aac", Container: ".mkv"},
|
||||
want: ActionTranscodeVideo,
|
||||
},
|
||||
{
|
||||
name: "MP4 + AV1 = transcode video",
|
||||
p: StreamProbe{VideoCodec: "av1", AudioCodec: "aac", Container: ".mp4"},
|
||||
want: ActionTranscodeVideo,
|
||||
},
|
||||
{
|
||||
name: "h264 10-bit = transcode video (browser refuses)",
|
||||
p: StreamProbe{VideoCodec: "h264", AudioCodec: "aac", BitDepth: 10, Container: ".mp4"},
|
||||
want: ActionTranscodeVideo,
|
||||
},
|
||||
{
|
||||
name: "h264 + HDR10 = transcode video",
|
||||
p: StreamProbe{VideoCodec: "h264", AudioCodec: "aac", HDR: "HDR10", Container: ".mp4"},
|
||||
want: ActionTranscodeVideo,
|
||||
},
|
||||
{
|
||||
name: "AVI + h264 + AAC = remux",
|
||||
p: StreamProbe{VideoCodec: "h264", AudioCodec: "aac", Container: ".avi"},
|
||||
want: ActionRemux,
|
||||
},
|
||||
{
|
||||
name: "Unknown codec = transcode video",
|
||||
p: StreamProbe{VideoCodec: "mpeg4", AudioCodec: "mp3", Container: ".avi"},
|
||||
want: ActionTranscodeVideo,
|
||||
},
|
||||
{
|
||||
name: "Empty probe falls through to transcode (unknown codec)",
|
||||
p: StreamProbe{},
|
||||
want: ActionTranscodeVideo,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := DecideAction(&tc.p)
|
||||
if got != tc.want {
|
||||
t.Errorf("got %s, want %s", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecideActionNil(t *testing.T) {
|
||||
if DecideAction(nil) != ActionPassthrough {
|
||||
t.Error("nil probe should default passthrough")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLowerExt(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"foo.MP4": ".mp4",
|
||||
"path/to/movie.MKV": ".mkv",
|
||||
"weird.name.with.dots": ".dots",
|
||||
"": "",
|
||||
"noext": "",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := lowerExt(in); got != want {
|
||||
t.Errorf("lowerExt(%q) = %q want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
179
internal/engine/transcoder.go
Normal file
179
internal/engine/transcoder.go
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TranscodeOpts steers how Transcoder builds its ffmpeg command line. Defaults
|
||||
// match the project's plan/clever-weaving-dove.md (Fase 2.5):
|
||||
//
|
||||
// - Output: fragmented MP4 readable by browser <video> via MSE-less Range.
|
||||
// - Audio: AAC stereo @ 192kbps unless source already AAC (then -c:a copy).
|
||||
// - Video: copy when h264 8-bit; otherwise transcode to h264 with HW encode
|
||||
// when available, software fallback at "veryfast" preset.
|
||||
type TranscodeOpts struct {
|
||||
Action TranscodeAction
|
||||
HWAccel HWAccel
|
||||
Preset string // "veryfast" / "fast" / "medium"
|
||||
VideoBitrate string // e.g. "5M"
|
||||
AudioBitrate string // e.g. "192k"
|
||||
MaxHeight int // optional downscale cap (e.g. 720)
|
||||
StartSeconds float64
|
||||
FFmpegPath string
|
||||
}
|
||||
|
||||
// Transcoder wraps a long-running ffmpeg child process whose stdout streams
|
||||
// fragmented MP4 bytes for the WebRTC pump to forward to the browser.
|
||||
//
|
||||
// One Transcoder == one playback position. A seek beyond the buffered window
|
||||
// requires Close()ing this transcoder and starting a new one with a higher
|
||||
// StartSeconds (handled in webrtc_stream.go).
|
||||
type Transcoder struct {
|
||||
cmd *exec.Cmd
|
||||
out io.ReadCloser
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
stderr strings.Builder
|
||||
}
|
||||
|
||||
// NewTranscoder spawns ffmpeg and returns a Transcoder whose Read() yields
|
||||
// fragmented MP4 bytes from stdin. Callers MUST call Close() when done.
|
||||
func NewTranscoder(ctx context.Context, filePath string, opts TranscodeOpts) (*Transcoder, error) {
|
||||
if opts.FFmpegPath == "" {
|
||||
return nil, fmt.Errorf("transcoder: empty ffmpeg path")
|
||||
}
|
||||
args := buildFFmpegArgs(filePath, opts)
|
||||
cmd := exec.CommandContext(ctx, opts.FFmpegPath, args...)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transcoder: stdout pipe: %w", err)
|
||||
}
|
||||
t := &Transcoder{cmd: cmd, out: stdout}
|
||||
cmd.Stderr = &errWriter{t: t}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("transcoder: start ffmpeg: %w", err)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (t *Transcoder) Read(p []byte) (int, error) { return t.out.Read(p) }
|
||||
|
||||
// Close kills the child process if still running and waits up to 2s for exit.
|
||||
func (t *Transcoder) Close() error {
|
||||
t.mu.Lock()
|
||||
if t.closed {
|
||||
t.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
t.closed = true
|
||||
t.mu.Unlock()
|
||||
|
||||
_ = t.out.Close()
|
||||
if t.cmd.Process != nil {
|
||||
_ = t.cmd.Process.Kill()
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- t.cmd.Wait() }()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
// Process refused to die — leak it; the OS will clean up on exit.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stderr returns the accumulated ffmpeg stderr so far. Useful for surfacing
|
||||
// failure reasons in logs after Close().
|
||||
func (t *Transcoder) Stderr() string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.stderr.String()
|
||||
}
|
||||
|
||||
// errWriter funnels ffmpeg stderr into the Transcoder buffer so it can be
|
||||
// inspected post-mortem. Capped so a misbehaving ffmpeg can't grow memory.
|
||||
type errWriter struct{ t *Transcoder }
|
||||
|
||||
func (w *errWriter) Write(p []byte) (int, error) {
|
||||
w.t.mu.Lock()
|
||||
defer w.t.mu.Unlock()
|
||||
const maxBuf = 64 * 1024
|
||||
if w.t.stderr.Len() < maxBuf {
|
||||
w.t.stderr.Write(p)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// buildFFmpegArgs assembles the command line for the requested action.
|
||||
// Exposed package-level so tests can lock the flag matrix independently of
|
||||
// process spawning.
|
||||
func buildFFmpegArgs(filePath string, opts TranscodeOpts) []string {
|
||||
args := []string{"-hide_banner", "-loglevel", "warning"}
|
||||
|
||||
// Seek BEFORE input (-ss before -i) for fast keyframe-aligned start.
|
||||
if opts.StartSeconds > 0 {
|
||||
args = append(args, "-ss", strconv.FormatFloat(opts.StartSeconds, 'f', 3, 64))
|
||||
}
|
||||
|
||||
// HW accel hint on the demuxer side improves throughput for HEVC inputs
|
||||
// even when we end up encoding in software. Skip on macOS (videotoolbox
|
||||
// uses a different flag shape).
|
||||
switch opts.HWAccel {
|
||||
case HWAccelNVENC:
|
||||
args = append(args, "-hwaccel", "cuda")
|
||||
case HWAccelQSV:
|
||||
args = append(args, "-hwaccel", "qsv")
|
||||
case HWAccelVAAPI:
|
||||
args = append(args, "-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi")
|
||||
case HWAccelNone, HWAccelVideoToolbox:
|
||||
// No demuxer-side hint: software decode (None) or per-encoder flags
|
||||
// already applied separately by FFmpegVideoCodec (VideoToolbox).
|
||||
}
|
||||
|
||||
args = append(args, "-i", filePath)
|
||||
|
||||
switch opts.Action {
|
||||
case ActionPassthrough, ActionRemux:
|
||||
args = append(args, "-c:v", "copy", "-c:a", "copy")
|
||||
case ActionRemuxAudio:
|
||||
args = append(args, "-c:v", "copy", "-c:a", "aac", "-b:a", coalesce(opts.AudioBitrate, "192k"))
|
||||
case ActionTranscodeVideo:
|
||||
videoCodec := opts.HWAccel.FFmpegVideoCodec("h264")
|
||||
args = append(args, "-c:v", videoCodec)
|
||||
if videoCodec == "libx264" {
|
||||
args = append(args, "-preset", coalesce(opts.Preset, "veryfast"))
|
||||
}
|
||||
args = append(args, "-b:v", coalesce(opts.VideoBitrate, "5M"))
|
||||
if opts.MaxHeight > 0 {
|
||||
args = append(args,
|
||||
"-vf",
|
||||
fmt.Sprintf("scale='min(iw,iw*%d/ih)':'min(ih,%d)'", opts.MaxHeight, opts.MaxHeight),
|
||||
)
|
||||
}
|
||||
args = append(args, "-c:a", "aac", "-b:a", coalesce(opts.AudioBitrate, "192k"))
|
||||
}
|
||||
|
||||
// Common output flags — fragmented MP4 to a single pipe.
|
||||
args = append(args,
|
||||
"-movflags", "frag_keyframe+empty_moov+default_base_moof+faststart",
|
||||
"-f", "mp4",
|
||||
"pipe:1",
|
||||
)
|
||||
return args
|
||||
}
|
||||
|
||||
func coalesce(s, fallback string) string {
|
||||
if s == "" {
|
||||
return fallback
|
||||
}
|
||||
return s
|
||||
}
|
||||
151
internal/engine/transcoder_test.go
Normal file
151
internal/engine/transcoder_test.go
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func sliceContains(args []string, want string) bool {
|
||||
for _, a := range args {
|
||||
if a == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sliceContainsPair(args []string, key, val string) bool {
|
||||
for i := 0; i < len(args)-1; i++ {
|
||||
if args[i] == key && args[i+1] == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestBuildFFmpegArgsPassthroughCopy(t *testing.T) {
|
||||
args := buildFFmpegArgs("/tmp/movie.mp4", TranscodeOpts{
|
||||
Action: ActionPassthrough,
|
||||
HWAccel: HWAccelNone,
|
||||
FFmpegPath: "ffmpeg",
|
||||
})
|
||||
if !sliceContainsPair(args, "-c:v", "copy") {
|
||||
t.Errorf("passthrough should keep -c:v copy. args=%v", args)
|
||||
}
|
||||
if !sliceContainsPair(args, "-c:a", "copy") {
|
||||
t.Error("passthrough should keep -c:a copy")
|
||||
}
|
||||
if !sliceContainsPair(args, "-f", "mp4") {
|
||||
t.Error("output container must be mp4")
|
||||
}
|
||||
movflags := ""
|
||||
for i := 0; i < len(args)-1; i++ {
|
||||
if args[i] == "-movflags" {
|
||||
movflags = args[i+1]
|
||||
}
|
||||
}
|
||||
if !strings.Contains(movflags, "frag_keyframe") {
|
||||
t.Errorf("movflags must include frag_keyframe, got %q", movflags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFFmpegArgsRemuxAudio(t *testing.T) {
|
||||
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
|
||||
Action: ActionRemuxAudio,
|
||||
AudioBitrate: "256k",
|
||||
FFmpegPath: "ffmpeg",
|
||||
})
|
||||
if !sliceContainsPair(args, "-c:v", "copy") {
|
||||
t.Error("remux-audio keeps video copy")
|
||||
}
|
||||
if !sliceContainsPair(args, "-c:a", "aac") {
|
||||
t.Error("remux-audio must transcode audio to aac")
|
||||
}
|
||||
if !sliceContainsPair(args, "-b:a", "256k") {
|
||||
t.Error("audio bitrate override not honored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFFmpegArgsTranscodeVideoSoftware(t *testing.T) {
|
||||
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
|
||||
Action: ActionTranscodeVideo,
|
||||
HWAccel: HWAccelNone,
|
||||
Preset: "fast",
|
||||
VideoBitrate: "6M",
|
||||
FFmpegPath: "ffmpeg",
|
||||
})
|
||||
if !sliceContainsPair(args, "-c:v", "libx264") {
|
||||
t.Error("software fallback must use libx264")
|
||||
}
|
||||
if !sliceContainsPair(args, "-preset", "fast") {
|
||||
t.Error("custom preset not honored")
|
||||
}
|
||||
if !sliceContainsPair(args, "-b:v", "6M") {
|
||||
t.Error("video bitrate not honored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFFmpegArgsTranscodeVideoNVENC(t *testing.T) {
|
||||
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
|
||||
Action: ActionTranscodeVideo,
|
||||
HWAccel: HWAccelNVENC,
|
||||
FFmpegPath: "ffmpeg",
|
||||
})
|
||||
if !sliceContainsPair(args, "-hwaccel", "cuda") {
|
||||
t.Error("NVENC must request -hwaccel cuda")
|
||||
}
|
||||
if !sliceContainsPair(args, "-c:v", "h264_nvenc") {
|
||||
t.Error("NVENC must use h264_nvenc encoder")
|
||||
}
|
||||
if sliceContains(args, "-preset") {
|
||||
// HW encoders ignore software preset; we should NOT pass it.
|
||||
t.Error("HW encoder path should not include -preset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFFmpegArgsAddsStartSeek(t *testing.T) {
|
||||
args := buildFFmpegArgs("/tmp/movie.mp4", TranscodeOpts{
|
||||
Action: ActionPassthrough,
|
||||
StartSeconds: 90.5,
|
||||
FFmpegPath: "ffmpeg",
|
||||
})
|
||||
idxSs, idxIn := -1, -1
|
||||
for i, a := range args {
|
||||
if a == "-ss" {
|
||||
idxSs = i
|
||||
}
|
||||
if a == "-i" {
|
||||
idxIn = i
|
||||
}
|
||||
}
|
||||
if idxSs < 0 {
|
||||
t.Fatal("missing -ss flag")
|
||||
}
|
||||
if idxIn < 0 {
|
||||
t.Fatal("missing -i flag")
|
||||
}
|
||||
if idxSs >= idxIn {
|
||||
t.Errorf("expected -ss BEFORE -i for fast seek; got -ss@%d -i@%d", idxSs, idxIn)
|
||||
}
|
||||
if args[idxSs+1] != "90.500" {
|
||||
t.Errorf("expected seek 90.500s, got %q", args[idxSs+1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFFmpegArgsDownscale(t *testing.T) {
|
||||
args := buildFFmpegArgs("/tmp/movie.mkv", TranscodeOpts{
|
||||
Action: ActionTranscodeVideo,
|
||||
HWAccel: HWAccelNone,
|
||||
MaxHeight: 720,
|
||||
FFmpegPath: "ffmpeg",
|
||||
})
|
||||
hasVF := false
|
||||
for i := 0; i < len(args)-1; i++ {
|
||||
if args[i] == "-vf" && strings.Contains(args[i+1], "720") {
|
||||
hasVF = true
|
||||
}
|
||||
}
|
||||
if !hasVF {
|
||||
t.Errorf("expected -vf scale containing 720; args=%v", args)
|
||||
}
|
||||
}
|
||||
617
internal/engine/webrtc_stream.go
Normal file
617
internal/engine/webrtc_stream.go
Normal file
|
|
@ -0,0 +1,617 @@
|
|||
// Package engine — webrtc_stream.go implements the daemon side of the custom
|
||||
// WebRTC byte-streaming protocol. The browser opens an RTCDataChannel via
|
||||
// SDP exchange (signalled over the web's HTTP + SSE relay); this code:
|
||||
//
|
||||
// 1. Parses the browser's SDP offer.
|
||||
// 2. Creates a pion PeerConnection bound to the configured ICE servers.
|
||||
// 3. Answers + trickles its own ICE candidates back through the signal client.
|
||||
// 4. On DataChannel open, sends a HELLO frame describing the file.
|
||||
// 5. Services RangeReq frames by reading from disk and emitting RangeData
|
||||
// chunks (16 KiB each) followed by a RangeEnd.
|
||||
// 6. Honours app-level backpressure via SetBufferedAmountLowThreshold +
|
||||
// OnBufferedAmountLow — Chromium closes a DataChannel when bufferedAmount
|
||||
// exceeds 16 MiB, so we MUST pause the writer.
|
||||
//
|
||||
// No anacrolix, no torrent metadata. Just a peer-to-peer file server over
|
||||
// WebRTC. Pass-through path; transcoding lives in transcoder.go (Fase 2.5).
|
||||
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/agent"
|
||||
"github.com/torrentclaw/unarr/internal/engine/wire"
|
||||
)
|
||||
|
||||
// Tunables — values match the protocol spec in plan/clever-weaving-dove.md.
|
||||
const (
|
||||
// dcChunkPayload is the per-frame application payload size. Must match
|
||||
// wire.MaxChunkPayload so RangeData frames fit one SCTP message.
|
||||
dcChunkPayload = wire.MaxChunkPayload
|
||||
// dcHighWatermark is the bufferedAmount cap above which the writer pauses.
|
||||
// Chromium closes DCs above 16 MiB; pause well below.
|
||||
dcHighWatermark = 8 << 20
|
||||
// dcLowWatermark triggers OnBufferedAmountLow → resume the writer.
|
||||
dcLowWatermark = 1 << 20
|
||||
// rangeReqConcurrency is the cap on in-flight range responses per session.
|
||||
rangeReqConcurrency = 4
|
||||
// helloDeadline is the max wait for the DataChannel to open after answer.
|
||||
helloDeadline = 30 * time.Second
|
||||
)
|
||||
|
||||
// WebRTCStreamConfig describes a single browser ↔ daemon stream session.
|
||||
type WebRTCStreamConfig struct {
|
||||
SessionID string
|
||||
FilePath string
|
||||
FileName string
|
||||
FileSize int64
|
||||
ICEServers []webrtc.ICEServer
|
||||
Signal *agent.Client
|
||||
// Logger receives diagnostic events; a nil logger swallows everything.
|
||||
Logger StreamLogger
|
||||
}
|
||||
|
||||
// StreamLogger is an injectable logger so tests can capture events.
|
||||
type StreamLogger interface {
|
||||
Infof(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
type nopLogger struct{}
|
||||
|
||||
func (nopLogger) Infof(string, ...any) {}
|
||||
func (nopLogger) Warnf(string, ...any) {}
|
||||
func (nopLogger) Errorf(string, ...any) {}
|
||||
|
||||
func logger(l StreamLogger) StreamLogger {
|
||||
if l == nil {
|
||||
return nopLogger{}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// RunWebRTCStream blocks until the session ends — either the DataChannel
|
||||
// closes, the peer connection drops, or ctx is cancelled. Always returns a
|
||||
// non-nil error explaining the termination reason.
|
||||
func RunWebRTCStream(ctx context.Context, cfg WebRTCStreamConfig) error {
|
||||
log := logger(cfg.Logger)
|
||||
|
||||
if cfg.SessionID == "" {
|
||||
return errors.New("webrtc_stream: empty SessionID")
|
||||
}
|
||||
if cfg.FilePath == "" {
|
||||
return errors.New("webrtc_stream: empty FilePath")
|
||||
}
|
||||
|
||||
abs, err := filepath.Abs(cfg.FilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webrtc_stream: resolve path: %w", err)
|
||||
}
|
||||
file, err := os.Open(abs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webrtc_stream: open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("webrtc_stream: stat: %w", err)
|
||||
}
|
||||
fileSize := stat.Size()
|
||||
if cfg.FileSize > 0 && cfg.FileSize != fileSize {
|
||||
log.Warnf("webrtc_stream: declared size %d != actual %d", cfg.FileSize, fileSize)
|
||||
}
|
||||
fileName := cfg.FileName
|
||||
if fileName == "" {
|
||||
fileName = filepath.Base(abs)
|
||||
}
|
||||
|
||||
// 1. Build PeerConnection.
|
||||
api := webrtc.NewAPI()
|
||||
pc, err := api.NewPeerConnection(webrtc.Configuration{
|
||||
ICEServers: cfg.ICEServers,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("webrtc_stream: new peer connection: %w", err)
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
sessionCtx, cancelSession := context.WithCancel(ctx)
|
||||
defer cancelSession()
|
||||
|
||||
// Stop the session when ICE drops permanently. "Disconnected" is
|
||||
// transient per RFC 8445 (NAT rebind, brief packet loss) — wait for
|
||||
// "Failed" or "Closed" before tearing down.
|
||||
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
|
||||
log.Infof("[wrtc %s] ice=%s", agent.ShortID(cfg.SessionID), state.String())
|
||||
switch state {
|
||||
case webrtc.ICEConnectionStateFailed,
|
||||
webrtc.ICEConnectionStateClosed:
|
||||
cancelSession()
|
||||
case webrtc.ICEConnectionStateUnknown,
|
||||
webrtc.ICEConnectionStateNew,
|
||||
webrtc.ICEConnectionStateChecking,
|
||||
webrtc.ICEConnectionStateConnected,
|
||||
webrtc.ICEConnectionStateCompleted,
|
||||
webrtc.ICEConnectionStateDisconnected:
|
||||
// Disconnected is transient (RFC 8445 — NAT rebind / packet loss);
|
||||
// the others are normal progress states. Don't tear the session down.
|
||||
}
|
||||
})
|
||||
|
||||
// Trickle our ICE candidates back to the browser.
|
||||
// PostSignal runs on its own goroutine so a slow signal server can't
|
||||
// stall pion's ICE-gathering thread.
|
||||
pc.OnICECandidate(func(c *webrtc.ICECandidate) {
|
||||
if c == nil {
|
||||
go func() {
|
||||
_ = cfg.Signal.PostSignal(sessionCtx, cfg.SessionID, agent.SignalMessage{
|
||||
Type: agent.SignalMsgCandidateEnd,
|
||||
Payload: "",
|
||||
})
|
||||
}()
|
||||
return
|
||||
}
|
||||
init := c.ToJSON()
|
||||
payload, _ := json.Marshal(init)
|
||||
go func() {
|
||||
_ = cfg.Signal.PostSignal(sessionCtx, cfg.SessionID, agent.SignalMessage{
|
||||
Type: agent.SignalMsgCandidate,
|
||||
Payload: string(payload),
|
||||
})
|
||||
}()
|
||||
})
|
||||
|
||||
// Browser is the offerer — we react to the DataChannel it creates.
|
||||
dcReady := make(chan *webrtc.DataChannel, 1)
|
||||
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
log.Infof("[wrtc %s] data channel '%s' open", agent.ShortID(cfg.SessionID), dc.Label())
|
||||
select {
|
||||
case dcReady <- dc:
|
||||
default:
|
||||
// Browser opened a second DC — ignore, we only serve one.
|
||||
log.Warnf("[wrtc %s] extra data channel ignored", agent.ShortID(cfg.SessionID))
|
||||
}
|
||||
})
|
||||
|
||||
// 2. Drive the SDP exchange.
|
||||
sdpDone := make(chan error, 1)
|
||||
go func() {
|
||||
sdpDone <- runSDPExchange(sessionCtx, pc, cfg)
|
||||
}()
|
||||
|
||||
// 3. Wait for either SDP error or DataChannel open.
|
||||
var dc *webrtc.DataChannel
|
||||
select {
|
||||
case err := <-sdpDone:
|
||||
if err != nil {
|
||||
return fmt.Errorf("sdp exchange: %w", err)
|
||||
}
|
||||
// SDP complete — wait for the DC.
|
||||
select {
|
||||
case dc = <-dcReady:
|
||||
case <-time.After(helloDeadline):
|
||||
return errors.New("webrtc_stream: data channel never opened")
|
||||
case <-sessionCtx.Done():
|
||||
return sessionCtx.Err()
|
||||
}
|
||||
case dc = <-dcReady:
|
||||
// DC opened before SDP loop reported done (typical: the loop keeps
|
||||
// running to ferry remote ICE candidates).
|
||||
case <-sessionCtx.Done():
|
||||
return sessionCtx.Err()
|
||||
}
|
||||
|
||||
// 4. Wire up the data channel pump.
|
||||
pump := newDataChannelPump(dc, file, fileSize, fileName, log, cancelSession)
|
||||
dc.OnOpen(pump.onOpen)
|
||||
dc.OnMessage(pump.onMessage)
|
||||
dc.OnClose(func() {
|
||||
log.Infof("[wrtc %s] data channel closed", agent.ShortID(cfg.SessionID))
|
||||
cancelSession()
|
||||
})
|
||||
|
||||
<-sessionCtx.Done()
|
||||
pump.shutdown()
|
||||
return sessionCtx.Err()
|
||||
}
|
||||
|
||||
// runSDPExchange consumes signal events from the browser and answers the SDP
|
||||
// offer. Keeps running for the lifetime of sessionCtx so trickle candidates
|
||||
// flow in both directions. Reopens the SSE stream on every clean close — the
|
||||
// server caps each response at ~25 s.
|
||||
func runSDPExchange(ctx context.Context, pc *webrtc.PeerConnection, cfg WebRTCStreamConfig) error {
|
||||
gotOffer := false
|
||||
for ctx.Err() == nil {
|
||||
stream, err := cfg.Signal.OpenSignalStream(ctx, cfg.SessionID)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("open signal stream: %w", err)
|
||||
}
|
||||
err = consumeSignalStream(ctx, pc, cfg, stream, &gotOffer)
|
||||
stream.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// consumeSignalStream drains a single SSE connection until it closes or
|
||||
// produces a hard error. Returns nil on a clean server-side disconnect so the
|
||||
// caller can reopen.
|
||||
func consumeSignalStream(
|
||||
ctx context.Context,
|
||||
pc *webrtc.PeerConnection,
|
||||
cfg WebRTCStreamConfig,
|
||||
stream *agent.SignalEventStream,
|
||||
gotOffer *bool,
|
||||
) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case msg, ok := <-stream.Events():
|
||||
if !ok {
|
||||
if err := stream.Err(); err != nil {
|
||||
return fmt.Errorf("signal stream: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := handleSignal(ctx, pc, cfg, msg, gotOffer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleSignal(
|
||||
ctx context.Context,
|
||||
pc *webrtc.PeerConnection,
|
||||
cfg WebRTCStreamConfig,
|
||||
msg agent.SignalMessage,
|
||||
gotOffer *bool,
|
||||
) error {
|
||||
switch msg.Type {
|
||||
case agent.SignalMsgAnswer:
|
||||
// Browser is the offerer in our protocol — we never expect an answer
|
||||
// from the other side. Drop silently (also satisfies exhaustive lint).
|
||||
return nil
|
||||
case agent.SignalMsgOffer:
|
||||
if *gotOffer {
|
||||
return nil // ignore duplicates
|
||||
}
|
||||
var offer webrtc.SessionDescription
|
||||
if err := json.Unmarshal([]byte(msg.Payload), &offer); err != nil {
|
||||
return fmt.Errorf("decode offer: %w", err)
|
||||
}
|
||||
if err := pc.SetRemoteDescription(offer); err != nil {
|
||||
return fmt.Errorf("set remote description: %w", err)
|
||||
}
|
||||
answer, err := pc.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create answer: %w", err)
|
||||
}
|
||||
if err := pc.SetLocalDescription(answer); err != nil {
|
||||
return fmt.Errorf("set local description: %w", err)
|
||||
}
|
||||
// Send back the local description *with* gathered candidates so far —
|
||||
// remaining candidates trickle separately via OnICECandidate.
|
||||
ld := pc.LocalDescription()
|
||||
payload, _ := json.Marshal(ld)
|
||||
if err := cfg.Signal.PostSignal(ctx, cfg.SessionID, agent.SignalMessage{
|
||||
Type: agent.SignalMsgAnswer,
|
||||
Payload: string(payload),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("post answer: %w", err)
|
||||
}
|
||||
*gotOffer = true
|
||||
|
||||
case agent.SignalMsgCandidate:
|
||||
if !*gotOffer {
|
||||
// Browser may trickle candidates before we've seen the offer in
|
||||
// rare race conditions — drop. Browser will retransmit.
|
||||
return nil
|
||||
}
|
||||
var init webrtc.ICECandidateInit
|
||||
if err := json.Unmarshal([]byte(msg.Payload), &init); err != nil {
|
||||
return fmt.Errorf("decode candidate: %w", err)
|
||||
}
|
||||
if err := pc.AddICECandidate(init); err != nil {
|
||||
return fmt.Errorf("add ice candidate: %w", err)
|
||||
}
|
||||
|
||||
case agent.SignalMsgCandidateEnd:
|
||||
// No-op — pion gathers complete on its own.
|
||||
|
||||
case agent.SignalMsgBye:
|
||||
return errors.New("browser sent bye")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dataChannelPump owns the DC + file handle and serves wire-protocol frames.
|
||||
type dataChannelPump struct {
|
||||
dc *webrtc.DataChannel
|
||||
file *os.File
|
||||
fileSize int64
|
||||
fileName string
|
||||
log StreamLogger
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Flow control: writers wait on resumeCh when bufferedAmount goes high.
|
||||
paused atomic.Bool
|
||||
resumeCh chan struct{}
|
||||
|
||||
// Active range responses keyed by stream_id so CANCEL frames can stop them.
|
||||
activeMu sync.Mutex
|
||||
active map[uint32]context.CancelFunc
|
||||
|
||||
// Bound concurrent in-flight responses.
|
||||
sem chan struct{}
|
||||
|
||||
// closed once shutdown() has been called.
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func newDataChannelPump(
|
||||
dc *webrtc.DataChannel,
|
||||
file *os.File,
|
||||
fileSize int64,
|
||||
fileName string,
|
||||
log StreamLogger,
|
||||
cancel context.CancelFunc,
|
||||
) *dataChannelPump {
|
||||
p := &dataChannelPump{
|
||||
dc: dc,
|
||||
file: file,
|
||||
fileSize: fileSize,
|
||||
fileName: fileName,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
resumeCh: make(chan struct{}, 1),
|
||||
active: make(map[uint32]context.CancelFunc),
|
||||
sem: make(chan struct{}, rangeReqConcurrency),
|
||||
}
|
||||
dc.SetBufferedAmountLowThreshold(dcLowWatermark)
|
||||
dc.OnBufferedAmountLow(p.onBufferedAmountLow)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) onOpen() {
|
||||
hello := wire.HelloPayload{
|
||||
FileSize: uint64(p.fileSize),
|
||||
Transcoding: false,
|
||||
Seekable: true,
|
||||
FileName: p.fileName,
|
||||
}
|
||||
payload := wire.EncodeHello(hello)
|
||||
frame := wire.EncodeFrame(wire.Header{
|
||||
Type: wire.FrameHello,
|
||||
Flags: wire.HelloFlags(false, true),
|
||||
StreamID: 0,
|
||||
Length: uint32(len(payload)),
|
||||
}, payload)
|
||||
if err := p.dc.Send(frame); err != nil {
|
||||
p.log.Errorf("send hello: %v", err)
|
||||
p.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) onMessage(msg webrtc.DataChannelMessage) {
|
||||
if len(msg.Data) < wire.HeaderSize {
|
||||
p.log.Warnf("dc: short frame %d bytes", len(msg.Data))
|
||||
return
|
||||
}
|
||||
hdr, err := wire.DecodeHeader(msg.Data[:wire.HeaderSize])
|
||||
if err != nil {
|
||||
p.log.Warnf("dc: bad header: %v", err)
|
||||
return
|
||||
}
|
||||
payload := msg.Data[wire.HeaderSize:]
|
||||
if uint32(len(payload)) != hdr.Length {
|
||||
p.log.Warnf("dc: payload length mismatch: hdr=%d got=%d", hdr.Length, len(payload))
|
||||
return
|
||||
}
|
||||
|
||||
switch hdr.Type {
|
||||
case wire.FrameRangeReq:
|
||||
req, err := wire.DecodeRangeReq(payload)
|
||||
if err != nil {
|
||||
p.log.Warnf("dc: bad range_req: %v", err)
|
||||
return
|
||||
}
|
||||
go p.serveRange(hdr.StreamID, req)
|
||||
case wire.FrameCancel:
|
||||
p.cancelStream(hdr.StreamID)
|
||||
case wire.FramePing:
|
||||
p.sendSimpleFrame(wire.FramePong, hdr.StreamID, nil)
|
||||
case wire.FramePong:
|
||||
// no-op
|
||||
default:
|
||||
p.log.Warnf("dc: unknown frame type 0x%02x", hdr.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) cancelStream(streamID uint32) {
|
||||
p.activeMu.Lock()
|
||||
cancel, ok := p.active[streamID]
|
||||
delete(p.active, streamID)
|
||||
p.activeMu.Unlock()
|
||||
if ok {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) sendSimpleFrame(t wire.FrameType, streamID uint32, payload []byte) {
|
||||
frame := wire.EncodeFrame(wire.Header{
|
||||
Type: t,
|
||||
StreamID: streamID,
|
||||
Length: uint32(len(payload)),
|
||||
}, payload)
|
||||
if err := p.dc.Send(frame); err != nil {
|
||||
p.log.Warnf("dc: send type=0x%02x: %v", t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) serveRange(streamID uint32, req wire.RangeReqPayload) {
|
||||
if p.closed.Load() {
|
||||
return
|
||||
}
|
||||
// Bound concurrency.
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
case <-time.After(5 * time.Second):
|
||||
p.log.Warnf("dc: range_req sid=%d dropped (concurrency cap)", streamID)
|
||||
p.sendRangeEnd(streamID, 1)
|
||||
return
|
||||
}
|
||||
defer func() { <-p.sem }()
|
||||
|
||||
// Reject offsets above MaxInt64 — uint64→int64 narrowing would wrap to a
|
||||
// negative value and bypass the bounds check, then ReadAt would be called
|
||||
// with a negative offset.
|
||||
if req.Offset > math.MaxInt64 || int64(req.Offset) >= p.fileSize {
|
||||
p.sendRangeEnd(streamID, 2) // out of range
|
||||
return
|
||||
}
|
||||
|
||||
want := int64(req.Length)
|
||||
if req.Length > math.MaxInt64 {
|
||||
want = 0 // treat absurd length as "remainder of file"
|
||||
}
|
||||
remaining := p.fileSize - int64(req.Offset)
|
||||
if want <= 0 || want > remaining {
|
||||
want = remaining
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.activeMu.Lock()
|
||||
p.active[streamID] = cancel
|
||||
p.activeMu.Unlock()
|
||||
defer func() {
|
||||
p.activeMu.Lock()
|
||||
delete(p.active, streamID)
|
||||
p.activeMu.Unlock()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
buf := make([]byte, dcChunkPayload)
|
||||
offset := int64(req.Offset)
|
||||
end := offset + want
|
||||
for offset < end {
|
||||
if ctx.Err() != nil || p.closed.Load() {
|
||||
return
|
||||
}
|
||||
// Wait if the DC is buffering too much.
|
||||
if err := p.waitForLowWater(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
chunkLen := int64(len(buf))
|
||||
if end-offset < chunkLen {
|
||||
chunkLen = end - offset
|
||||
}
|
||||
n, rerr := p.file.ReadAt(buf[:chunkLen], offset)
|
||||
if n > 0 {
|
||||
// EOF on a short read means this is the final chunk — flag it so the
|
||||
// browser doesn't wait for more data before processing RangeEnd.
|
||||
isLast := offset+int64(n) >= end || rerr == io.EOF
|
||||
if err := p.sendRangeData(streamID, buf[:n], isLast); err != nil {
|
||||
p.log.Warnf("dc: send range_data sid=%d: %v", streamID, err)
|
||||
return
|
||||
}
|
||||
offset += int64(n)
|
||||
}
|
||||
if rerr != nil {
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
p.log.Errorf("dc: read sid=%d: %v", streamID, rerr)
|
||||
p.sendRangeEnd(streamID, 3)
|
||||
return
|
||||
}
|
||||
}
|
||||
p.sendRangeEnd(streamID, 0)
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) sendRangeData(streamID uint32, data []byte, last bool) error {
|
||||
var flags uint8
|
||||
if last {
|
||||
flags |= wire.FlagLastChunk
|
||||
}
|
||||
frame := wire.EncodeFrame(wire.Header{
|
||||
Type: wire.FrameRangeData,
|
||||
Flags: flags,
|
||||
StreamID: streamID,
|
||||
Length: uint32(len(data)),
|
||||
}, data)
|
||||
return p.dc.Send(frame)
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) sendRangeEnd(streamID uint32, status uint32) {
|
||||
payload := wire.EncodeRangeEnd(wire.RangeEndPayload{Status: status})
|
||||
p.sendSimpleFrame(wire.FrameRangeEnd, streamID, payload)
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) waitForLowWater(ctx context.Context) error {
|
||||
if p.dc.BufferedAmount() < dcHighWatermark {
|
||||
return nil
|
||||
}
|
||||
p.paused.Store(true)
|
||||
for {
|
||||
// Drain any stale resume signal first.
|
||||
select {
|
||||
case <-p.resumeCh:
|
||||
default:
|
||||
}
|
||||
if p.dc.BufferedAmount() < dcHighWatermark {
|
||||
p.paused.Store(false)
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.resumeCh:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
// Belt-and-braces poll in case OnBufferedAmountLow misses a fire.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) onBufferedAmountLow() {
|
||||
if !p.paused.Load() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case p.resumeCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (p *dataChannelPump) shutdown() {
|
||||
if !p.closed.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
p.activeMu.Lock()
|
||||
for _, cancel := range p.active {
|
||||
cancel()
|
||||
}
|
||||
p.active = nil
|
||||
p.activeMu.Unlock()
|
||||
}
|
||||
254
internal/engine/wire/proto.go
Normal file
254
internal/engine/wire/proto.go
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
// Package wire implements the binary frame format used over the WebRTC
|
||||
// DataChannel between the unarr daemon and the browser stream player.
|
||||
//
|
||||
// Header (12 bytes, big-endian):
|
||||
//
|
||||
// u8 Type
|
||||
// u8 Flags
|
||||
// u16 _reserved
|
||||
// u32 StreamID -- multiplex range requests
|
||||
// u32 Length -- payload bytes following the header
|
||||
//
|
||||
// Each side encodes one Frame at a time and writes it as a single SCTP
|
||||
// message (DataChannel send). Browsers cap message size at 64 KiB-ish, so
|
||||
// callers MUST split RANGE_DATA payloads into chunks <= MaxChunkPayload.
|
||||
package wire
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// FrameType identifies the wire message kind.
|
||||
type FrameType uint8
|
||||
|
||||
const (
|
||||
FrameHello FrameType = 0x00
|
||||
FrameRangeReq FrameType = 0x01
|
||||
FrameRangeData FrameType = 0x02
|
||||
FrameRangeEnd FrameType = 0x03
|
||||
FrameCancel FrameType = 0x04
|
||||
FramePing FrameType = 0x05
|
||||
FramePong FrameType = 0x06
|
||||
FrameSeekHint FrameType = 0x07
|
||||
)
|
||||
|
||||
// Flag bits — interpretation depends on FrameType.
|
||||
const (
|
||||
// FlagLastChunk on a RangeData frame marks the final chunk for a stream_id.
|
||||
FlagLastChunk uint8 = 1 << 0
|
||||
// FlagTranscoding on a Hello frame indicates the daemon will transcode.
|
||||
FlagTranscoding uint8 = 1 << 1
|
||||
// FlagSeekable on a Hello frame indicates random-access is supported.
|
||||
FlagSeekable uint8 = 1 << 2
|
||||
)
|
||||
|
||||
// HeaderSize is the fixed length of every frame header.
|
||||
const HeaderSize = 12
|
||||
|
||||
// MaxChunkPayload is the safe per-frame payload cap that works on every
|
||||
// browser implementation (Chromium fragments at 16 KiB internally above).
|
||||
// Callers MUST chunk RangeData payloads to <= this size.
|
||||
const MaxChunkPayload = 16 * 1024
|
||||
|
||||
// MaxFrameSize is the largest frame the parser will accept. Anything bigger
|
||||
// is treated as a corrupted stream — close the channel.
|
||||
const MaxFrameSize = HeaderSize + 64*1024
|
||||
|
||||
// Header is the parsed 12-byte frame header.
|
||||
type Header struct {
|
||||
Type FrameType
|
||||
Flags uint8
|
||||
StreamID uint32
|
||||
Length uint32
|
||||
}
|
||||
|
||||
// EncodeHeader writes h to dst (must be at least HeaderSize bytes).
|
||||
func EncodeHeader(dst []byte, h Header) {
|
||||
if len(dst) < HeaderSize {
|
||||
panic("wire: dst too small for header")
|
||||
}
|
||||
dst[0] = byte(h.Type)
|
||||
dst[1] = h.Flags
|
||||
dst[2] = 0
|
||||
dst[3] = 0
|
||||
binary.BigEndian.PutUint32(dst[4:8], h.StreamID)
|
||||
binary.BigEndian.PutUint32(dst[8:12], h.Length)
|
||||
}
|
||||
|
||||
// DecodeHeader parses src (must be at least HeaderSize bytes) into h.
|
||||
func DecodeHeader(src []byte) (Header, error) {
|
||||
if len(src) < HeaderSize {
|
||||
return Header{}, fmt.Errorf("wire: header needs %d bytes, got %d", HeaderSize, len(src))
|
||||
}
|
||||
h := Header{
|
||||
Type: FrameType(src[0]),
|
||||
Flags: src[1],
|
||||
StreamID: binary.BigEndian.Uint32(src[4:8]),
|
||||
Length: binary.BigEndian.Uint32(src[8:12]),
|
||||
}
|
||||
if h.Length > MaxFrameSize-HeaderSize {
|
||||
return Header{}, fmt.Errorf("wire: payload length %d exceeds max %d", h.Length, MaxFrameSize-HeaderSize)
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// EncodeFrame allocates and returns a complete frame (header + payload).
|
||||
// Use this for one-shot sends; for hot-path RangeData prefer EncodeHeader
|
||||
// into a pre-allocated buffer to avoid per-frame allocations.
|
||||
func EncodeFrame(h Header, payload []byte) []byte {
|
||||
if int(h.Length) != len(payload) {
|
||||
panic(fmt.Sprintf("wire: header length %d != payload len %d", h.Length, len(payload)))
|
||||
}
|
||||
buf := make([]byte, HeaderSize+len(payload))
|
||||
EncodeHeader(buf[:HeaderSize], h)
|
||||
copy(buf[HeaderSize:], payload)
|
||||
return buf
|
||||
}
|
||||
|
||||
// ReadFrame reads one full frame from r. Returns the parsed header and a
|
||||
// freshly allocated payload slice. On any size violation the connection
|
||||
// must be closed — the protocol has no resync.
|
||||
func ReadFrame(r io.Reader) (Header, []byte, error) {
|
||||
headerBuf := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(r, headerBuf); err != nil {
|
||||
return Header{}, nil, err
|
||||
}
|
||||
h, err := DecodeHeader(headerBuf)
|
||||
if err != nil {
|
||||
return Header{}, nil, err
|
||||
}
|
||||
if h.Length == 0 {
|
||||
return h, nil, nil
|
||||
}
|
||||
payload := make([]byte, h.Length)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return Header{}, nil, err
|
||||
}
|
||||
return h, payload, nil
|
||||
}
|
||||
|
||||
// HelloPayload describes the file the daemon is about to serve. It is the
|
||||
// first frame the daemon writes after the DataChannel opens.
|
||||
type HelloPayload struct {
|
||||
FileSize uint64
|
||||
Transcoding bool
|
||||
Seekable bool
|
||||
FileName string
|
||||
}
|
||||
|
||||
// EncodeHello marshals h into a payload byte slice.
|
||||
//
|
||||
// Layout: u64 file_size | u32 name_len | name_bytes
|
||||
func EncodeHello(h HelloPayload) []byte {
|
||||
nameBytes := []byte(h.FileName)
|
||||
buf := make([]byte, 8+4+len(nameBytes))
|
||||
binary.BigEndian.PutUint64(buf[0:8], h.FileSize)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(len(nameBytes)))
|
||||
copy(buf[12:], nameBytes)
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeHello parses a Hello payload. The transcoding/seekable bits live in
|
||||
// the frame Flags byte, not the payload — pass them in.
|
||||
func DecodeHello(payload []byte, flags uint8) (HelloPayload, error) {
|
||||
if len(payload) < 12 {
|
||||
return HelloPayload{}, errors.New("wire: hello payload too short")
|
||||
}
|
||||
size := binary.BigEndian.Uint64(payload[0:8])
|
||||
nameLen := binary.BigEndian.Uint32(payload[8:12])
|
||||
if int(nameLen) > len(payload)-12 {
|
||||
return HelloPayload{}, fmt.Errorf("wire: hello name_len %d exceeds payload", nameLen)
|
||||
}
|
||||
return HelloPayload{
|
||||
FileSize: size,
|
||||
Transcoding: flags&FlagTranscoding != 0,
|
||||
Seekable: flags&FlagSeekable != 0,
|
||||
FileName: string(payload[12 : 12+nameLen]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HelloFlags returns the flag byte for a Hello frame given the booleans.
|
||||
func HelloFlags(transcoding, seekable bool) uint8 {
|
||||
var f uint8
|
||||
if transcoding {
|
||||
f |= FlagTranscoding
|
||||
}
|
||||
if seekable {
|
||||
f |= FlagSeekable
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// RangeReqPayload is the browser → daemon request for bytes [Offset, Offset+Length).
|
||||
type RangeReqPayload struct {
|
||||
Offset uint64
|
||||
Length uint64
|
||||
}
|
||||
|
||||
// EncodeRangeReq marshals p. Layout: u64 offset | u64 length.
|
||||
func EncodeRangeReq(p RangeReqPayload) []byte {
|
||||
buf := make([]byte, 16)
|
||||
binary.BigEndian.PutUint64(buf[0:8], p.Offset)
|
||||
binary.BigEndian.PutUint64(buf[8:16], p.Length)
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeRangeReq parses a 16-byte range request payload.
|
||||
func DecodeRangeReq(payload []byte) (RangeReqPayload, error) {
|
||||
if len(payload) != 16 {
|
||||
return RangeReqPayload{}, fmt.Errorf("wire: range_req payload must be 16 bytes, got %d", len(payload))
|
||||
}
|
||||
return RangeReqPayload{
|
||||
Offset: binary.BigEndian.Uint64(payload[0:8]),
|
||||
Length: binary.BigEndian.Uint64(payload[8:16]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RangeEndPayload signals end-of-response for a stream_id with a status code.
|
||||
// Status 0 == OK; non-zero values are app-defined error codes.
|
||||
type RangeEndPayload struct {
|
||||
Status uint32
|
||||
}
|
||||
|
||||
// EncodeRangeEnd marshals p.
|
||||
func EncodeRangeEnd(p RangeEndPayload) []byte {
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf[0:4], p.Status)
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeRangeEnd parses a 4-byte range_end payload.
|
||||
func DecodeRangeEnd(payload []byte) (RangeEndPayload, error) {
|
||||
if len(payload) != 4 {
|
||||
return RangeEndPayload{}, fmt.Errorf("wire: range_end payload must be 4 bytes, got %d", len(payload))
|
||||
}
|
||||
return RangeEndPayload{
|
||||
Status: binary.BigEndian.Uint32(payload[0:4]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SeekHintPayload tells the daemon a seek to timestamp_ms is imminent so it
|
||||
// can pre-warm a transcoder pipeline before bytes are requested.
|
||||
type SeekHintPayload struct {
|
||||
TimestampMs uint64
|
||||
}
|
||||
|
||||
// EncodeSeekHint marshals p.
|
||||
func EncodeSeekHint(p SeekHintPayload) []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf[0:8], p.TimestampMs)
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeSeekHint parses an 8-byte seek_hint payload.
|
||||
func DecodeSeekHint(payload []byte) (SeekHintPayload, error) {
|
||||
if len(payload) != 8 {
|
||||
return SeekHintPayload{}, fmt.Errorf("wire: seek_hint payload must be 8 bytes, got %d", len(payload))
|
||||
}
|
||||
return SeekHintPayload{
|
||||
TimestampMs: binary.BigEndian.Uint64(payload[0:8]),
|
||||
}, nil
|
||||
}
|
||||
193
internal/engine/wire/proto_test.go
Normal file
193
internal/engine/wire/proto_test.go
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHeaderRoundtrip(t *testing.T) {
|
||||
cases := []Header{
|
||||
{Type: FrameHello, Flags: FlagSeekable, StreamID: 0, Length: 32},
|
||||
{Type: FrameRangeReq, Flags: 0, StreamID: 7, Length: 16},
|
||||
{Type: FrameRangeData, Flags: FlagLastChunk, StreamID: 4242, Length: 16380},
|
||||
{Type: FrameRangeEnd, Flags: 0, StreamID: 1, Length: 4},
|
||||
{Type: FrameCancel, Flags: 0, StreamID: 9, Length: 0},
|
||||
{Type: FramePing, Flags: 0, StreamID: 0, Length: 0},
|
||||
}
|
||||
for _, want := range cases {
|
||||
buf := make([]byte, HeaderSize)
|
||||
EncodeHeader(buf, want)
|
||||
got, err := DecodeHeader(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("decode: %v (want %+v)", err, want)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("roundtrip mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeHeaderShort(t *testing.T) {
|
||||
if _, err := DecodeHeader([]byte{0, 0, 0}); err == nil {
|
||||
t.Fatal("expected error on short header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeHeaderRejectsHugeLength(t *testing.T) {
|
||||
// Synthesize a header with payload length above MaxFrameSize.
|
||||
buf := make([]byte, HeaderSize)
|
||||
buf[0] = byte(FrameRangeData)
|
||||
buf[8] = 0xff
|
||||
buf[9] = 0xff
|
||||
buf[10] = 0xff
|
||||
buf[11] = 0xff
|
||||
if _, err := DecodeHeader(buf); err == nil {
|
||||
t.Fatal("expected error on oversized payload length")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeFramePanicsOnLengthMismatch(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic on header length / payload mismatch")
|
||||
}
|
||||
}()
|
||||
EncodeFrame(Header{Type: FrameRangeData, Length: 5}, []byte{1, 2, 3})
|
||||
}
|
||||
|
||||
func TestReadFrameRoundtrip(t *testing.T) {
|
||||
want := Header{Type: FrameRangeData, Flags: FlagLastChunk, StreamID: 99, Length: 5}
|
||||
payload := []byte{0xde, 0xad, 0xbe, 0xef, 0x42}
|
||||
frame := EncodeFrame(want, payload)
|
||||
|
||||
r := bytes.NewReader(frame)
|
||||
got, gotPayload, err := ReadFrame(r)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("header mismatch: %+v want %+v", got, want)
|
||||
}
|
||||
if !bytes.Equal(gotPayload, payload) {
|
||||
t.Errorf("payload mismatch: %x want %x", gotPayload, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameZeroPayload(t *testing.T) {
|
||||
want := Header{Type: FrameCancel, StreamID: 7}
|
||||
frame := EncodeFrame(want, nil)
|
||||
got, payload, err := ReadFrame(bytes.NewReader(frame))
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("header mismatch: %+v want %+v", got, want)
|
||||
}
|
||||
if len(payload) != 0 {
|
||||
t.Errorf("expected empty payload, got %d bytes", len(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelloRoundtrip(t *testing.T) {
|
||||
want := HelloPayload{
|
||||
FileSize: 1<<32 + 12345,
|
||||
Transcoding: false,
|
||||
Seekable: true,
|
||||
FileName: "Tangled.Ever.After.2025.1080p.WEB-DL.h264.mp4",
|
||||
}
|
||||
flags := HelloFlags(want.Transcoding, want.Seekable)
|
||||
payload := EncodeHello(want)
|
||||
got, err := DecodeHello(payload, flags)
|
||||
if err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("hello mismatch: %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelloRejectsTruncatedPayload(t *testing.T) {
|
||||
if _, err := DecodeHello([]byte{1, 2, 3}, 0); err == nil {
|
||||
t.Fatal("expected error on truncated hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelloRejectsNameLenOverrun(t *testing.T) {
|
||||
// file_size + name_len=999 but no name bytes → should fail.
|
||||
buf := make([]byte, 12)
|
||||
buf[8], buf[9], buf[10], buf[11] = 0, 0, 0x03, 0xe7 // 999
|
||||
if _, err := DecodeHello(buf, 0); err == nil {
|
||||
t.Fatal("expected error on name_len overrun")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRangeReqRoundtrip(t *testing.T) {
|
||||
want := RangeReqPayload{Offset: 1 << 30, Length: 1 << 20}
|
||||
got, err := DecodeRangeReq(EncodeRangeReq(want))
|
||||
if err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("range_req mismatch: %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRangeReqRejectsWrongLength(t *testing.T) {
|
||||
if _, err := DecodeRangeReq(make([]byte, 15)); err == nil {
|
||||
t.Fatal("expected error on 15-byte payload")
|
||||
}
|
||||
if _, err := DecodeRangeReq(make([]byte, 17)); err == nil {
|
||||
t.Fatal("expected error on 17-byte payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRangeEndRoundtrip(t *testing.T) {
|
||||
want := RangeEndPayload{Status: 42}
|
||||
got, err := DecodeRangeEnd(EncodeRangeEnd(want))
|
||||
if err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("range_end mismatch: %+v want %+v", got, want)
|
||||
}
|
||||
if _, err := DecodeRangeEnd(make([]byte, 3)); err == nil {
|
||||
t.Fatal("expected error on short range_end payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeekHintRoundtrip(t *testing.T) {
|
||||
want := SeekHintPayload{TimestampMs: 123_456}
|
||||
got, err := DecodeSeekHint(EncodeSeekHint(want))
|
||||
if err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("seek_hint mismatch: %+v want %+v", got, want)
|
||||
}
|
||||
if _, err := DecodeSeekHint(make([]byte, 7)); err == nil {
|
||||
t.Fatal("expected error on short seek_hint payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelloFlagsHelper(t *testing.T) {
|
||||
if HelloFlags(false, false) != 0 {
|
||||
t.Error("expected 0 for both false")
|
||||
}
|
||||
if HelloFlags(true, false) != FlagTranscoding {
|
||||
t.Error("expected FlagTranscoding only")
|
||||
}
|
||||
if HelloFlags(false, true) != FlagSeekable {
|
||||
t.Error("expected FlagSeekable only")
|
||||
}
|
||||
if HelloFlags(true, true) != (FlagTranscoding | FlagSeekable) {
|
||||
t.Error("expected both flags")
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check that MaxChunkPayload + HeaderSize fits inside MaxFrameSize so
|
||||
// callers can rely on the chunk cap without their own bookkeeping.
|
||||
func TestMaxChunkFitsInMaxFrame(t *testing.T) {
|
||||
if MaxChunkPayload+HeaderSize > MaxFrameSize {
|
||||
t.Fatalf("chunk %d + hdr %d > max frame %d", MaxChunkPayload, HeaderSize, MaxFrameSize)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue