feat(auth): browser-based CLI authentication (like Claude Code)
- New browser auth flow: CLI opens localhost server, browser redirects token back via callback — zero copy/paste needed - Automatic fallback to manual API key entry if browser flow fails - Server-side state validation with TTL to prevent phishing - sync.Once guard on callback to prevent goroutine leaks - Localhost-only redirect validation (regex + url.Parse) - URL-escaped state parameter for safety
This commit is contained in:
parent
677a8fe083
commit
20d4d34dfc
3 changed files with 372 additions and 29 deletions
151
internal/cmd/auth_browser.go
Normal file
151
internal/cmd/auth_browser.go
Normal file
|
|
@ -0,0 +1,151 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const browserAuthTimeout = 5 * time.Minute
|
||||||
|
|
||||||
|
// browserAuth opens a browser for the user to authorize the CLI.
|
||||||
|
// Returns the API key on success, or an error if the flow fails/times out.
|
||||||
|
//
|
||||||
|
// Flow:
|
||||||
|
// 1. Start a temporary HTTP server on a random localhost port
|
||||||
|
// 2. Open browser to {apiURL}/cli/auth?state={state}&port={port}
|
||||||
|
// 3. User logs in and clicks "Authorize" on the web page
|
||||||
|
// 4. Web redirects to localhost:{port}/callback?token=tc_...&state={state}
|
||||||
|
// 5. CLI validates state, extracts token, closes server
|
||||||
|
func browserAuth(apiURL string) (string, error) {
|
||||||
|
// Validate apiURL is a well-formed HTTP(S) URL
|
||||||
|
parsed, err := url.Parse(apiURL)
|
||||||
|
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Host == "" {
|
||||||
|
return "", fmt.Errorf("invalid API URL: %s", apiURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := generateState()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("generate state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find a free port
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("listen: %w", err)
|
||||||
|
}
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
// Channel to receive the token from the callback
|
||||||
|
tokenCh := make(chan string, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
var once sync.Once
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
handled := false
|
||||||
|
once.Do(func() {
|
||||||
|
handled = true
|
||||||
|
|
||||||
|
// Validate state to prevent CSRF
|
||||||
|
if r.URL.Query().Get("state") != state {
|
||||||
|
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
||||||
|
errCh <- fmt.Errorf("state mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := r.URL.Query().Get("token")
|
||||||
|
if token == "" {
|
||||||
|
http.Error(w, "No token received", http.StatusBadRequest)
|
||||||
|
errCh <- fmt.Errorf("empty token in callback")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Respond with a success page
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
fmt.Fprint(w, callbackHTML)
|
||||||
|
|
||||||
|
tokenCh <- token
|
||||||
|
})
|
||||||
|
if !handled {
|
||||||
|
http.Error(w, "Already processed", http.StatusConflict)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
server := &http.Server{Handler: mux}
|
||||||
|
|
||||||
|
// Start server in background
|
||||||
|
go func() {
|
||||||
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Open browser
|
||||||
|
authURL := fmt.Sprintf("%s/cli/auth?state=%s&port=%d", apiURL, url.QueryEscape(state), port)
|
||||||
|
openBrowser(authURL)
|
||||||
|
|
||||||
|
// Wait for callback, error, or timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), browserAuthTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var token string
|
||||||
|
select {
|
||||||
|
case token = <-tokenCh:
|
||||||
|
// Success
|
||||||
|
case err := <-errCh:
|
||||||
|
shutdownCtx2, cancel2 := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
_ = server.Shutdown(shutdownCtx2)
|
||||||
|
cancel2()
|
||||||
|
return "", err
|
||||||
|
case <-ctx.Done():
|
||||||
|
shutdownCtx2, cancel2 := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
_ = server.Shutdown(shutdownCtx2)
|
||||||
|
cancel2()
|
||||||
|
return "", fmt.Errorf("timed out waiting for browser authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown the server
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer shutdownCancel()
|
||||||
|
_ = server.Shutdown(shutdownCtx)
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateState() (string, error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// callbackHTML is the page shown in the browser after successful authorization.
|
||||||
|
const callbackHTML = `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>unarr — Connected</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: -apple-system, system-ui, sans-serif; display: flex; justify-content: center; align-items: center; min-height: 100vh; margin: 0; background: #0a0a0a; color: #fafafa; }
|
||||||
|
.card { text-align: center; padding: 3rem; }
|
||||||
|
.check { font-size: 4rem; margin-bottom: 1rem; }
|
||||||
|
h1 { font-size: 1.5rem; margin-bottom: 0.5rem; }
|
||||||
|
p { color: #888; font-size: 0.95rem; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="card">
|
||||||
|
<div class="check">✓</div>
|
||||||
|
<h1>Connected to torrentclaw</h1>
|
||||||
|
<p>You can close this tab and return to your terminal.</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
175
internal/cmd/auth_browser_test.go
Normal file
175
internal/cmd/auth_browser_test.go
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateState(t *testing.T) {
|
||||||
|
state, err := generateState()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateState: %v", err)
|
||||||
|
}
|
||||||
|
if len(state) != 64 { // 32 bytes = 64 hex chars
|
||||||
|
t.Errorf("state length = %d, want 64", len(state))
|
||||||
|
}
|
||||||
|
for _, c := range state {
|
||||||
|
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||||
|
t.Errorf("state contains non-hex char: %c", c)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two states should differ
|
||||||
|
state2, _ := generateState()
|
||||||
|
if state == state2 {
|
||||||
|
t.Error("consecutive states should differ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackHTML(t *testing.T) {
|
||||||
|
if !strings.Contains(callbackHTML, "Connected to torrentclaw") {
|
||||||
|
t.Error("missing success message")
|
||||||
|
}
|
||||||
|
if !strings.Contains(callbackHTML, "close this tab") {
|
||||||
|
t.Error("missing close instruction")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackHandler_ValidState(t *testing.T) {
|
||||||
|
state := "abc123def456abc123def456abc123def456abc123def456abc123def456abcd"
|
||||||
|
tokenCh := make(chan string, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("state") != state {
|
||||||
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||||
|
errCh <- fmt.Errorf("state mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := r.URL.Query().Get("token")
|
||||||
|
if token == "" {
|
||||||
|
http.Error(w, "No token", http.StatusBadRequest)
|
||||||
|
errCh <- fmt.Errorf("empty token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
fmt.Fprint(w, callbackHTML)
|
||||||
|
tokenCh <- token
|
||||||
|
})
|
||||||
|
|
||||||
|
server := httptest.NewServer(mux)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Simulate browser redirect to callback
|
||||||
|
resp, err := http.Get(server.URL + "/callback?token=tc_test_key_123&state=" + state)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("callback request: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("status = %d, want 200", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case token := <-tokenCh:
|
||||||
|
if token != "tc_test_key_123" {
|
||||||
|
t.Errorf("token = %q, want tc_test_key_123", token)
|
||||||
|
}
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timeout waiting for token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackHandler_InvalidState(t *testing.T) {
|
||||||
|
tokenCh := make(chan string, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("state") != "correct_state" {
|
||||||
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||||
|
errCh <- fmt.Errorf("state mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenCh <- r.URL.Query().Get("token")
|
||||||
|
})
|
||||||
|
|
||||||
|
server := httptest.NewServer(mux)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(server.URL + "/callback?token=tc_test&state=wrong_state")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("callback request: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-errCh:
|
||||||
|
// Expected — state mismatch
|
||||||
|
case <-tokenCh:
|
||||||
|
t.Fatal("should not have received token with wrong state")
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackHandler_MissingToken(t *testing.T) {
|
||||||
|
state := "valid_state_0123456789abcdef0123456789abcdef"
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("state") != state {
|
||||||
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||||
|
errCh <- fmt.Errorf("state mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := r.URL.Query().Get("token")
|
||||||
|
if token == "" {
|
||||||
|
http.Error(w, "No token", http.StatusBadRequest)
|
||||||
|
errCh <- fmt.Errorf("empty token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
server := httptest.NewServer(mux)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(server.URL + "/callback?state=" + state)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("callback request: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrowserAuth_ServerBinds(t *testing.T) {
|
||||||
|
// Verify browserAuth can bind to a free port
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
listener.Close()
|
||||||
|
|
||||||
|
if port < 1024 {
|
||||||
|
t.Errorf("port %d < 1024", port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -73,39 +73,56 @@ func runInit(apiURLOverride string) error {
|
||||||
|
|
||||||
// ── Step 1/3: Connect account ───────────────────────────────────
|
// ── Step 1/3: Connect account ───────────────────────────────────
|
||||||
|
|
||||||
keysURL := apiURL + "/profile?tab=apikey"
|
|
||||||
fmt.Printf(" Opening %s ...\n", keysURL)
|
|
||||||
openBrowser(keysURL)
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
apiKey := cfg.Auth.APIKey
|
apiKey := cfg.Auth.APIKey
|
||||||
err := huh.NewForm(
|
|
||||||
huh.NewGroup(
|
if apiKey == "" {
|
||||||
huh.NewInput().
|
// Try browser-based auth first (like Claude Code / GitHub CLI)
|
||||||
Title("Step 1/3 — API Key").
|
fmt.Println(" Opening browser to connect your account...")
|
||||||
Description("Copy it from the page that just opened in your browser").
|
fmt.Println()
|
||||||
Placeholder("tc_...").
|
|
||||||
Value(&apiKey).
|
browserKey, browserErr := browserAuth(apiURL)
|
||||||
Validate(func(s string) error {
|
if browserErr == nil && strings.HasPrefix(browserKey, "tc_") {
|
||||||
s = strings.TrimSpace(s)
|
apiKey = browserKey
|
||||||
if s == "" {
|
green.Println(" ✓ Connected via browser")
|
||||||
return fmt.Errorf("API key is required")
|
fmt.Println()
|
||||||
}
|
} else {
|
||||||
if !strings.HasPrefix(s, "tc_") {
|
// Fallback to manual API key entry
|
||||||
return fmt.Errorf("API key should start with tc_")
|
if browserErr != nil {
|
||||||
}
|
dim.Printf(" Could not connect automatically: %s\n", browserErr)
|
||||||
|
}
|
||||||
|
fmt.Println(" Paste your API key instead:")
|
||||||
|
dim.Printf(" (get it from %s/profile?tab=apikey)\n", apiURL)
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
err = huh.NewForm(
|
||||||
|
huh.NewGroup(
|
||||||
|
huh.NewInput().
|
||||||
|
Title("Step 1/3 — API Key").
|
||||||
|
Placeholder("tc_...").
|
||||||
|
Value(&apiKey).
|
||||||
|
Validate(func(s string) error {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return fmt.Errorf("API key is required")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(s, "tc_") {
|
||||||
|
return fmt.Errorf("API key should start with tc_")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).Run()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, huh.ErrUserAborted) {
|
||||||
|
fmt.Println("\n Init cancelled.")
|
||||||
return nil
|
return nil
|
||||||
}),
|
}
|
||||||
),
|
return err
|
||||||
).Run()
|
}
|
||||||
if err != nil {
|
apiKey = strings.TrimSpace(apiKey)
|
||||||
if errors.Is(err, huh.ErrUserAborted) {
|
|
||||||
fmt.Println("\n Init cancelled.")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
apiKey = strings.TrimSpace(apiKey)
|
|
||||||
|
|
||||||
// Validate API key by registering with the server
|
// Validate API key by registering with the server
|
||||||
fmt.Print(" Verifying API key... ")
|
fmt.Print(" Verifying API key... ")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue