From 4314c06c5ce89923fcd6a50f3638a8dfbcf6cb8a Mon Sep 17 00:00:00 2001 From: Deivid Soto Date: Wed, 6 May 2026 23:12:38 +0200 Subject: [PATCH] feat(stream): pion-based WebRTC byte streamer for browser playback Replaces the broken anacrolix WebTorrent path with a custom WebRTC peer that the browser drives directly. Architecture matches plan/clever-weaving-dove.md (Fase 2 + 3 + 6 of the streaming pivot). - engine/wire: shared 12-byte binary frame format (Hello / RangeReq / RangeData / RangeEnd / Cancel / Ping / Pong / SeekHint). Roundtrip + oversized-frame rejection tests. - agent/signal_client: SSE consumer + POST sender for SDP/ICE relay through /api/internal/stream/signal/[sessionId]; auto-reconnects. - engine/webrtc_stream: pion v4 PeerConnection + DataChannel pump. Reads file via os.File.ReadAt, chunks RangeData at 16 KiB, honours app-level backpressure with SetBufferedAmountLowThreshold. - cmd/daemon dispatcher learns mode webrtc_stream + new webrtcSessionRegistry tracks per-session cancel funcs for clean shutdown. - engine/probe + hwaccel + transcoder: foundation for Fase 2.5 (codec detection, NVENC/QSV/VAAPI/VideoToolbox autodetection, ffmpeg pipe wrapper to fragmented MP4). Integration into webrtc_stream still pending. - pion/webrtc/v4 promoted from indirect to direct dep. End-to-end against unarr-dev confirms a 122 MB 1080p H.264 / AAC MP4 plays in Chrome with the new pipeline.
--- go.mod | 2 +- internal/agent/daemon.go | 6 + internal/agent/signal_client.go | 233 +++++++++ internal/agent/signal_client_test.go | 153 ++++++ internal/agent/sync.go | 8 + internal/agent/types.go | 14 + internal/cmd/daemon.go | 61 +++ internal/cmd/webrtc_session_registry.go | 62 +++ internal/engine/hwaccel.go | 130 +++++ internal/engine/hwaccel_test.go | 34 ++ internal/engine/probe.go | 116 +++++ internal/engine/probe_test.go | 96 ++++ internal/engine/transcoder.go | 179 +++++++ internal/engine/transcoder_test.go | 151 ++++++ internal/engine/webrtc_stream.go | 617 ++++++++++++++++++++++++ internal/engine/wire/proto.go | 254 ++++++++++ internal/engine/wire/proto_test.go | 193 ++++++++ 17 files changed, 2308 insertions(+), 1 deletion(-) create mode 100644 internal/agent/signal_client.go create mode 100644 internal/agent/signal_client_test.go create mode 100644 internal/cmd/webrtc_session_registry.go create mode 100644 internal/engine/hwaccel.go create mode 100644 internal/engine/hwaccel_test.go create mode 100644 internal/engine/probe.go create mode 100644 internal/engine/probe_test.go create mode 100644 internal/engine/transcoder.go create mode 100644 internal/engine/transcoder_test.go create mode 100644 internal/engine/webrtc_stream.go create mode 100644 internal/engine/wire/proto.go create mode 100644 internal/engine/wire/proto_test.go diff --git a/go.mod b/go.mod index 6439955..30c116e 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/google/uuid v1.6.0 github.com/huin/goupnp v1.3.0 github.com/olekukonko/tablewriter v1.1.4 + github.com/pion/webrtc/v4 v4.2.11 github.com/spf13/cobra v1.10.2 github.com/torrentclaw/go-client v0.2.0 golang.org/x/term v0.41.0 @@ -105,7 +106,6 @@ require ( github.com/pion/stun/v3 v3.1.1 // indirect github.com/pion/transport/v4 v4.0.1 // indirect github.com/pion/turn/v4 v4.1.4 // indirect - github.com/pion/webrtc/v4 v4.2.11 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/protolambda/ctxlock 
v0.1.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect diff --git a/internal/agent/daemon.go b/internal/agent/daemon.go index 4e53c48..5977ecb 100644 --- a/internal/agent/daemon.go +++ b/internal/agent/daemon.go @@ -35,6 +35,7 @@ type Daemon struct { // Callbacks — set by cmd/daemon.go before calling Run. OnTasksClaimed func(tasks []Task) OnStreamRequested func(req StreamRequest) + OnWebRTCSession func(sess WebRTCSession) OnControlAction func(action, taskID string, deleteFiles bool) GetActiveCount func() int // returns number of active downloads (wired from manager) @@ -169,6 +170,11 @@ func (d *Daemon) Run(ctx context.Context) error { d.OnStreamRequested(req) } } + d.sync.OnWebRTCSession = func(sess WebRTCSession) { + if d.OnWebRTCSession != nil { + d.OnWebRTCSession(sess) + } + } d.sync.OnUpgrade = func(version string) { if version != d.lastNotifiedVersion { d.lastNotifiedVersion = version diff --git a/internal/agent/signal_client.go b/internal/agent/signal_client.go new file mode 100644 index 0000000..b5424f6 --- /dev/null +++ b/internal/agent/signal_client.go @@ -0,0 +1,233 @@ +package agent + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// SignalRole identifies who produced a signalling message. The opposite role +// receives it. +type SignalRole string + +const ( + SignalRoleBrowser SignalRole = "browser" + SignalRoleAgent SignalRole = "agent" +) + +// SignalMessageType matches the server-side z.enum on +// /api/internal/stream/signal/[sessionId] route. +type SignalMessageType string + +const ( + SignalMsgOffer SignalMessageType = "offer" + SignalMsgAnswer SignalMessageType = "answer" + SignalMsgCandidate SignalMessageType = "candidate" + SignalMsgCandidateEnd SignalMessageType = "candidate-end" + SignalMsgBye SignalMessageType = "bye" +) + +// SignalMessage mirrors the bus envelope on the web side. 
+type SignalMessage struct { + From SignalRole `json:"from"` + Type SignalMessageType `json:"type"` + Payload string `json:"payload"` + TS int64 `json:"ts"` +} + +// PostSignal enqueues a signalling message produced by this agent. The +// browser receives it on its next SSE event push. +func (c *Client) PostSignal(ctx context.Context, sessionID string, msg SignalMessage) error { + body := map[string]any{ + "from": string(SignalRoleAgent), + "type": string(msg.Type), + "payload": msg.Payload, + } + path := fmt.Sprintf("/api/internal/stream/signal/%s", sessionID) + return c.doPost(ctx, path, body, &struct { + OK bool `json:"ok"` + }{}) +} + +// SignalEventStream wraps an open SSE connection. Read messages from Events() +// until the channel closes (server timeout or context cancel). Always defer +// Close() to release the underlying response body. +type SignalEventStream struct { + resp *http.Response + cancel context.CancelFunc + events chan SignalMessage + errs chan error + done chan struct{} +} + +// Events streams browser-produced messages addressed to the agent. +// The channel closes when the SSE connection ends; the caller should then +// call Close() and reopen if it wants to keep listening. +func (s *SignalEventStream) Events() <-chan SignalMessage { return s.events } + +// Err returns the terminating error (if any) once Events() has closed. The error is consumed: only the first call reports it, later calls return nil. +func (s *SignalEventStream) Err() error { + select { + case err := <-s.errs: + return err + default: + return nil + } +} + +// Close cancels the underlying HTTP request and waits for the reader goroutine +// to exit. Safe to call more than once. NOTE: it blocks for as long as the reader is stuck sending to a full Events() channel that nobody is draining. +func (s *SignalEventStream) Close() error { + if s.cancel != nil { + s.cancel() + } + if s.resp != nil { + s.resp.Body.Close() + } + <-s.done + return nil +} + +// OpenSignalStream opens a long-lived SSE connection to the signal events +// endpoint. Caller MUST cancel ctx (or call Close()) to free resources.
+// +// The server caps each response at ~25 s; OpenSignalStream surfaces the +// disconnect by closing the events channel. Caller should reopen until the +// session ends. +func (c *Client) OpenSignalStream(ctx context.Context, sessionID string) (*SignalEventStream, error) { + streamCtx, cancel := context.WithCancel(ctx) + + url := fmt.Sprintf("%s/api/internal/stream/signal/%s/events", c.baseURL, sessionID) + req, err := http.NewRequestWithContext(streamCtx, http.MethodGet, url, nil) + if err != nil { + cancel() + return nil, fmt.Errorf("open signal stream: %w", err) + } + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("User-Agent", c.userAgent) + req.Header.Set("Cache-Control", "no-cache") + + // Use a per-call client with no timeout (SSE connections are long). + sseClient := &http.Client{} + resp, err := sseClient.Do(req) + if err != nil { + cancel() + return nil, fmt.Errorf("open signal stream: %w", err) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + resp.Body.Close() + cancel() + return nil, fmt.Errorf("open signal stream: HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + stream := &SignalEventStream{ + resp: resp, + cancel: cancel, + events: make(chan SignalMessage, 8), + errs: make(chan error, 1), + done: make(chan struct{}), + } + + go stream.read() + return stream, nil +} + +func (s *SignalEventStream) read() { + defer close(s.done) + defer close(s.events) + + reader := bufio.NewReaderSize(s.resp.Body, 16*1024) + var dataBuf bytes.Buffer + var eventName string + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err != io.EOF { + select { + case s.errs <- err: + default: + } + } + return + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + // End of an event — dispatch if we have data. 
+ if dataBuf.Len() == 0 { + eventName = "" + continue + } + if eventName == "" || eventName == "signal" { + var msg SignalMessage + if err := json.Unmarshal(dataBuf.Bytes(), &msg); err == nil { + s.events <- msg + } + } + dataBuf.Reset() + eventName = "" + continue + } + if strings.HasPrefix(line, ":") { + // SSE comment (heartbeat); ignore. + continue + } + if strings.HasPrefix(line, "event:") { + eventName = strings.TrimSpace(line[len("event:"):]) + continue + } + if strings.HasPrefix(line, "data:") { + payload := strings.TrimSpace(line[len("data:"):]) + if dataBuf.Len() > 0 { + dataBuf.WriteByte('\n') + } + dataBuf.WriteString(payload) + continue + } + // id:, retry:, anything else — ignore for now. + } +} + +// SignalLoop runs an SSE consumer that reconnects automatically on disconnect. +// onMessage is called for every message the stream delivers. Returns when ctx is +// cancelled. A 1 s backoff applies only after an open failure or a read error; a clean +// server-side close reconnects immediately, and the SSE `retry:` field sent by the server is ignored by this client. +func (c *Client) SignalLoop(ctx context.Context, sessionID string, onMessage func(SignalMessage)) error { + for ctx.Err() == nil { + stream, err := c.OpenSignalStream(ctx, sessionID) + if err != nil { + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return ctx.Err() + } + continue + } + for msg := range stream.Events() { + onMessage(msg) + } + streamErr := stream.Err() + stream.Close() + if ctx.Err() != nil { + return ctx.Err() + } + // Server closes the SSE every ~25 s; reconnect immediately. + // Hard error → small backoff so we don't hammer.
+ if streamErr != nil { + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return ctx.Err() + } + } + } + return ctx.Err() +} diff --git a/internal/agent/signal_client_test.go b/internal/agent/signal_client_test.go new file mode 100644 index 0000000..2527890 --- /dev/null +++ b/internal/agent/signal_client_test.go @@ -0,0 +1,153 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +// fakeSSEServer streams a fixed set of SSE events then closes the connection. +func fakeSSEServer(t *testing.T, msgs []SignalMessage, holdOpenAfter bool) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-key" { + http.Error(w, "auth", http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("server: ResponseWriter is not a Flusher") + } + fmt.Fprint(w, "retry: 1500\n\n") + flusher.Flush() + for _, m := range msgs { + data, _ := json.Marshal(m) + fmt.Fprintf(w, "id: %d\nevent: signal\ndata: %s\n\n", m.TS, data) + flusher.Flush() + } + // Send a heartbeat comment to verify it's ignored. + fmt.Fprint(w, ": heartbeat\n\n") + flusher.Flush() + if holdOpenAfter { + // Hold the connection until the client disconnects so the test can + // exercise stream.Close(). 
+ <-r.Context().Done() + } + })) +} + +func TestSignalStreamReadsMessages(t *testing.T) { + want := []SignalMessage{ + {From: SignalRoleBrowser, Type: SignalMsgOffer, Payload: "{sdp:1}", TS: 1}, + {From: SignalRoleBrowser, Type: SignalMsgCandidate, Payload: "{cand:1}", TS: 2}, + } + srv := fakeSSEServer(t, want, false) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "test-ua") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + stream, err := c.OpenSignalStream(ctx, "session-1") + if err != nil { + t.Fatalf("open: %v", err) + } + defer stream.Close() + + var got []SignalMessage + for m := range stream.Events() { + got = append(got, m) + if len(got) == len(want) { + break + } + } + if len(got) != len(want) { + t.Fatalf("got %d messages, want %d", len(got), len(want)) + } + for i, m := range got { + if m.From != want[i].From || m.Type != want[i].Type || m.Payload != want[i].Payload { + t.Errorf("[%d] mismatch: %+v want %+v", i, m, want[i]) + } + } +} + +func TestSignalStreamPropagatesAuthError(t *testing.T) { + srv := fakeSSEServer(t, nil, false) + defer srv.Close() + + c := NewClient(srv.URL, "wrong-key", "test-ua") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err := c.OpenSignalStream(ctx, "session-1") + if err == nil { + t.Fatal("expected auth error, got nil") + } +} + +func TestSignalStreamCloseCancelsRead(t *testing.T) { + srv := fakeSSEServer(t, nil, true) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "test-ua") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := c.OpenSignalStream(ctx, "session-1") + if err != nil { + t.Fatalf("open: %v", err) + } + + // Close on a separate goroutine then make sure the events channel drains. 
+ var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(50 * time.Millisecond) + stream.Close() + }() + + for range stream.Events() { + // drain + } + wg.Wait() +} + +func TestPostSignalSendsCorrectBody(t *testing.T) { + var bodySeen map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-key" { + http.Error(w, "auth", http.StatusUnauthorized) + return + } + _ = json.NewDecoder(r.Body).Decode(&bodySeen) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"ok":true}`) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "test-ua") + err := c.PostSignal(context.Background(), "sess-x", SignalMessage{ + Type: SignalMsgAnswer, + Payload: "{sdp:answer}", + }) + if err != nil { + t.Fatalf("post: %v", err) + } + if bodySeen["from"] != string(SignalRoleAgent) { + t.Errorf("expected from=agent, got %v", bodySeen["from"]) + } + if bodySeen["type"] != string(SignalMsgAnswer) { + t.Errorf("expected type=answer, got %v", bodySeen["type"]) + } + if bodySeen["payload"] != "{sdp:answer}" { + t.Errorf("expected payload mismatch, got %v", bodySeen["payload"]) + } +} diff --git a/internal/agent/sync.go b/internal/agent/sync.go index 49f0e65..864de8a 100644 --- a/internal/agent/sync.go +++ b/internal/agent/sync.go @@ -29,6 +29,7 @@ type SyncClient struct { OnNewTasks func(tasks []Task) OnControl func(action, taskID string, deleteFiles bool) OnStreamRequest func(req StreamRequest) + OnWebRTCSession func(sess WebRTCSession) OnUpgrade func(version string) OnScan func() OnWatchingChange func(watching bool) @@ -191,6 +192,13 @@ func (sc *SyncClient) processResponse(resp *SyncResponse) { } } + // WebRTC streaming sessions + for _, ws := range resp.WebRTCSessions { + if sc.OnWebRTCSession != nil { + sc.OnWebRTCSession(ws) + } + } + // Upgrade if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil { 
sc.OnUpgrade(resp.Upgrade.Version) diff --git a/internal/agent/types.go b/internal/agent/types.go index eb88385..0a67a20 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -351,11 +351,25 @@ type LibraryDeleteRequest struct { FilePath string `json:"filePath"` } +// WebRTCSession is a request to open a custom WebRTC DataChannel byte-stream +// to a browser player. The CLI must POST an SDP answer to +// /api/internal/stream/signal/[sessionId] and serve bytes from FilePath +// (serving from a download_task when only InfoHash is set is not wired up yet: the daemon handler rejects sessions with an empty FilePath). +type WebRTCSession struct { + SessionID string `json:"sessionId"` + FilePath string `json:"filePath,omitempty"` + InfoHash string `json:"infoHash,omitempty"` + TaskID string `json:"taskId,omitempty"` + FileName string `json:"fileName,omitempty"` + FileSize int64 `json:"fileSize,omitempty"` +} + // SyncResponse is returned by the server with all pending actions for the CLI. type SyncResponse struct { NewTasks []Task `json:"newTasks,omitempty"` Controls []ControlAction `json:"controls,omitempty"` StreamRequests []StreamRequest `json:"streamRequests,omitempty"` + WebRTCSessions []WebRTCSession `json:"webrtcSessions,omitempty"` Watching bool `json:"watching"` Upgrade *UpgradeSignal `json:"upgrade,omitempty"` Scan bool `json:"scan,omitempty"` diff --git a/internal/cmd/daemon.go b/internal/cmd/daemon.go index 9bdb714..b85c9c2 100644 --- a/internal/cmd/daemon.go +++ b/internal/cmd/daemon.go @@ -410,6 +410,65 @@ func runDaemonStart() error { }() } + // Wire: sync receives custom WebRTC streaming session requests. + // Each session is a one-shot browser↔daemon DataChannel. Validate the + // FilePath against allowed dirs to prevent path traversal abuse from a + // compromised server, then spawn the pion peer in its own goroutine.
+ d.OnWebRTCSession = func(sess agent.WebRTCSession) { + if webrtcRegistry.has(sess.SessionID) { + return // already running + } + if !cfg.Download.WebRTC.Enabled { + log.Printf("webrtc session %s rejected: webrtc disabled in config", agent.ShortID(sess.SessionID)) + return + } + filePath := sess.FilePath + if filePath == "" { + log.Printf("webrtc session %s rejected: empty file path", agent.ShortID(sess.SessionID)) + return + } + filePath = filepath.Clean(filePath) + if !isAllowedStreamPath(filePath, cfg.Download.Dir, cfg.Library.ScanPath, + cfg.Organize.MoviesDir, cfg.Organize.TVShowsDir) { + log.Printf("webrtc session %s rejected: path outside allowed dirs: %s", + agent.ShortID(sess.SessionID), filePath) + return + } + // Resolve directory → first video file (matches StreamRequest behavior). + if info, err := os.Stat(filePath); err == nil && info.IsDir() { + found := engine.FindVideoFile(filePath) + if found == "" { + log.Printf("webrtc session %s rejected: no video file in dir %s", + agent.ShortID(sess.SessionID), filePath) + return + } + filePath = found + } + sessCtx, sessCancel := context.WithCancel(ctx) //nolint:gosec // G118 cancel stored in registry + webrtcRegistry.add(sess.SessionID, sessCancel) + go func() { + defer func() { + webrtcRegistry.remove(sess.SessionID) + sessCancel() + }() + runCfg := engine.WebRTCStreamConfig{ + SessionID: sess.SessionID, + FilePath: filePath, + FileName: sess.FileName, + FileSize: sess.FileSize, + ICEServers: engine.BuildICEServers(cfg.Download.WebRTC), + Signal: agentClient, + Logger: stdLogger{}, + } + log.Printf("[wrtc %s] starting session: %s", agent.ShortID(sess.SessionID), filepath.Base(filePath)) + if err := engine.RunWebRTCStream(sessCtx, runCfg); err != nil { + if sessCtx.Err() == nil { + log.Printf("[wrtc %s] ended: %v", agent.ShortID(sess.SessionID), err) + } + } + }() + } + // Periodic DHT node persistence (every 5 min) go func() { ticker := time.NewTicker(5 * time.Minute) @@ -457,6 +516,7 @@ func 
runDaemonStart() error { case sig := <-sigCh: fmt.Printf("\n Received %s, shutting down...\n", sig) cancelStreamContexts() + cancelAllWebRTCSessions() streamSrv.Shutdown(context.Background()) cancel() @@ -471,6 +531,7 @@ func runDaemonStart() error { case err := <-errCh: cancelStreamContexts() + cancelAllWebRTCSessions() streamSrv.Shutdown(context.Background()) cancel() return err diff --git a/internal/cmd/webrtc_session_registry.go b/internal/cmd/webrtc_session_registry.go new file mode 100644 index 0000000..b0ec0b7 --- /dev/null +++ b/internal/cmd/webrtc_session_registry.go @@ -0,0 +1,62 @@ +package cmd + +import ( + "context" + "log" + "sync" +) + +// webrtcRegistry tracks per-session cancel funcs for active custom WebRTC +// streams (engine.RunWebRTCStream goroutines). Each session lives only as +// long as its DataChannel; the registry exists so duplicate sync responses +// don't double-spawn the same session and so daemon shutdown can drain. +var webrtcRegistry = &webrtcSessionRegistry{ + cancels: make(map[string]context.CancelFunc), +} + +type webrtcSessionRegistry struct { + mu sync.Mutex + cancels map[string]context.CancelFunc +} + +func (r *webrtcSessionRegistry) has(sessionID string) bool { + r.mu.Lock() + defer r.mu.Unlock() + _, ok := r.cancels[sessionID] + return ok +} + +func (r *webrtcSessionRegistry) add(sessionID string, cancel context.CancelFunc) { + r.mu.Lock() + defer r.mu.Unlock() + r.cancels[sessionID] = cancel +} + +func (r *webrtcSessionRegistry) remove(sessionID string) { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.cancels, sessionID) +} + +// cancelAllWebRTCSessions cancels every running session. Called on daemon +// shutdown so pion peers and SSE consumers exit cleanly. 
+func cancelAllWebRTCSessions() { + webrtcRegistry.mu.Lock() + cancels := make([]context.CancelFunc, 0, len(webrtcRegistry.cancels)) + for _, c := range webrtcRegistry.cancels { + cancels = append(cancels, c) + } + webrtcRegistry.cancels = make(map[string]context.CancelFunc) + webrtcRegistry.mu.Unlock() + for _, c := range cancels { + c() + } +} + +// stdLogger is a tiny adapter so engine.RunWebRTCStream can log through the +// standard library logger without pulling in a logging dependency. +type stdLogger struct{} + +func (stdLogger) Infof(format string, args ...any) { log.Printf(format, args...) } +func (stdLogger) Warnf(format string, args ...any) { log.Printf("WARN: "+format, args...) } +func (stdLogger) Errorf(format string, args ...any) { log.Printf("ERROR: "+format, args...) } diff --git a/internal/engine/hwaccel.go b/internal/engine/hwaccel.go new file mode 100644 index 0000000..3d74c52 --- /dev/null +++ b/internal/engine/hwaccel.go @@ -0,0 +1,130 @@ +package engine + +import ( + "context" + "os" + "os/exec" + "runtime" + "strings" + "sync" +) + +// HWAccel identifies a hardware-accelerated ffmpeg encoder family. +type HWAccel string + +const ( + HWAccelNone HWAccel = "none" + HWAccelNVENC HWAccel = "nvenc" // NVIDIA — h264_nvenc / hevc_nvenc + HWAccelQSV HWAccel = "qsv" // Intel Quick Sync — h264_qsv / hevc_qsv + HWAccelVAAPI HWAccel = "vaapi" // Linux open-source — h264_vaapi / hevc_vaapi + HWAccelVideoToolbox HWAccel = "videotoolbox" // macOS — h264_videotoolbox +) + +var ( + hwOnce sync.Once + hwCache HWAccel +) + +// DetectHWAccel returns the most capable hardware encoder available on this +// host, or HWAccelNone if software-only. Cached after first call — adding / +// removing a GPU at runtime is rare and the cost of probing isn't free. 
+func DetectHWAccel(ctx context.Context, ffmpegPath string) HWAccel { + hwOnce.Do(func() { + hwCache = detectHWAccelFresh(ctx, ffmpegPath) + }) + return hwCache +} + +// ResetHWAccelCache clears the singleton — only used in tests. +func ResetHWAccelCache() { + hwOnce = sync.Once{} + hwCache = "" +} + +func detectHWAccelFresh(ctx context.Context, ffmpegPath string) HWAccel { + if ffmpegPath == "" { + return HWAccelNone + } + encoders := listFFmpegEncoders(ctx, ffmpegPath) + if encoders == "" { + return HWAccelNone + } + + // macOS — VideoToolbox is always available on Apple Silicon + recent Intel. + if runtime.GOOS == "darwin" && strings.Contains(encoders, "h264_videotoolbox") { + return HWAccelVideoToolbox + } + + // NVIDIA — encoder presence + a CUDA-capable device. We rely on the + // existence of the device file rather than running nvidia-smi to keep + // startup quick on hosts without nvidia tooling. + if strings.Contains(encoders, "h264_nvenc") && + (fileExists("/dev/nvidia0") || hasNvidiaDriver()) { + return HWAccelNVENC + } + + // Intel Quick Sync — needs /dev/dri (also used by VA-API). Distinguish by + // checking whether the QSV-specific encoder is built in. + if strings.Contains(encoders, "h264_qsv") && fileExists("/dev/dri/renderD128") { + return HWAccelQSV + } + + // Linux generic VA-API — works on Intel + AMD with mesa drivers. + if strings.Contains(encoders, "h264_vaapi") && fileExists("/dev/dri/renderD128") { + return HWAccelVAAPI + } + + return HWAccelNone +} + +func listFFmpegEncoders(ctx context.Context, ffmpegPath string) string { + cmd := exec.CommandContext(ctx, ffmpegPath, "-hide_banner", "-encoders") + out, err := cmd.CombinedOutput() + if err != nil { + return "" + } + return string(out) +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func hasNvidiaDriver() bool { + // Cheap proxy — if the user has nvidia-smi on PATH they presumably also + // have a working driver / runtime libraries. 
+ _, err := exec.LookPath("nvidia-smi") + return err == nil +} + +// FFmpegVideoCodec returns the encoder name to pass to `-c:v` for the +// requested HW accel + target; any target other than "hevc" selects the H.264 encoder. +func (h HWAccel) FFmpegVideoCodec(target string) string { + target = strings.ToLower(target) + switch h { + case HWAccelNVENC: + if target == "hevc" { + return "hevc_nvenc" + } + return "h264_nvenc" + case HWAccelQSV: + if target == "hevc" { + return "hevc_qsv" + } + return "h264_qsv" + case HWAccelVAAPI: + if target == "hevc" { + return "hevc_vaapi" + } + return "h264_vaapi" + case HWAccelVideoToolbox: + if target == "hevc" { + return "hevc_videotoolbox" + } + return "h264_videotoolbox" + default: + // Software fallback: always libx264, even when target is hevc. libx264 is only present in ffmpeg builds configured with --enable-libx264. + return "libx264" + } +} diff --git a/internal/engine/hwaccel_test.go b/internal/engine/hwaccel_test.go new file mode 100644 index 0000000..f022d29 --- /dev/null +++ b/internal/engine/hwaccel_test.go @@ -0,0 +1,34 @@ +package engine + +import "testing" + +func TestHWAccelFFmpegVideoCodec(t *testing.T) { + cases := []struct { + hw HWAccel + target string + want string + }{ + {HWAccelNone, "h264", "libx264"}, + {HWAccelNone, "hevc", "libx264"}, + {HWAccelNVENC, "h264", "h264_nvenc"}, + {HWAccelNVENC, "hevc", "hevc_nvenc"}, + {HWAccelQSV, "h264", "h264_qsv"}, + {HWAccelQSV, "hevc", "hevc_qsv"}, + {HWAccelVAAPI, "h264", "h264_vaapi"}, + {HWAccelVAAPI, "hevc", "hevc_vaapi"}, + {HWAccelVideoToolbox, "h264", "h264_videotoolbox"}, + {HWAccelVideoToolbox, "hevc", "hevc_videotoolbox"}, + } + for _, tc := range cases { + if got := tc.hw.FFmpegVideoCodec(tc.target); got != tc.want { + t.Errorf("%s.FFmpegVideoCodec(%q) = %q want %q", tc.hw, tc.target, got, tc.want) + } + } +} + +func TestDetectHWAccelEmptyPathReturnsNone(t *testing.T) { + ResetHWAccelCache() + if got := detectHWAccelFresh(t.Context(), ""); got != HWAccelNone { + t.Errorf("got %s, want %s", got, HWAccelNone) + } +} diff --git a/internal/engine/probe.go
b/internal/engine/probe.go new file mode 100644 index 0000000..8e3e654 --- /dev/null +++ b/internal/engine/probe.go @@ -0,0 +1,116 @@ +package engine + +import ( + "context" + "fmt" + "strings" + + "github.com/torrentclaw/unarr/internal/library/mediainfo" +) + +// StreamProbe summarises the codec / container shape of a file as it relates +// to the WebRTC streaming pipeline. It tells the transcoder whether bytes can +// be streamed as-is, just remuxed to fragmented MP4, or fully transcoded. +type StreamProbe struct { + // VideoCodec lowercased — e.g. "h264", "hevc", "av1", "vp9", "mpeg4". + VideoCodec string + // AudioCodec lowercased — e.g. "aac", "ac3", "dts", "eac3", "opus". + AudioCodec string + // Width / Height of the primary video stream. + Width int + Height int + // BitDepth — 8, 10 or 12. 0 if unknown. + BitDepth int + // HDR signalling string ("HDR10" / "DV" / "HLG" / etc, or "" for SDR). + HDR string + // DurationSec is the file length, used to sanity-check seek targets. + DurationSec float64 + // Container is the file extension lowercased (".mp4", ".mkv", ".avi"). + Container string +} + +// TranscodeAction tells the streaming pipeline how to feed the file to +// the browser