diff --git a/internal/cmd/auth_browser.go b/internal/cmd/auth_browser.go new file mode 100644 index 0000000..a74090f --- /dev/null +++ b/internal/cmd/auth_browser.go @@ -0,0 +1,151 @@ +package cmd + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +const browserAuthTimeout = 5 * time.Minute + +// browserAuth opens a browser for the user to authorize the CLI. +// Returns the API key on success, or an error if the flow fails/times out. +// +// Flow: +// 1. Start a temporary HTTP server on a random localhost port +// 2. Open browser to {apiURL}/cli/auth?state={state}&port={port} +// 3. User logs in and clicks "Authorize" on the web page +// 4. Web redirects to localhost:{port}/callback?token=tc_...&state={state} +// 5. CLI validates state, extracts token, closes server +func browserAuth(apiURL string) (string, error) { + // Validate apiURL is a well-formed HTTP(S) URL + parsed, err := url.Parse(apiURL) + if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Host == "" { + return "", fmt.Errorf("invalid API URL: %s", apiURL) + } + + state, err := generateState() + if err != nil { + return "", fmt.Errorf("generate state: %w", err) + } + + // Find a free port + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", fmt.Errorf("listen: %w", err) + } + port := listener.Addr().(*net.TCPAddr).Port + + // Channel to receive the token from the callback + tokenCh := make(chan string, 1) + errCh := make(chan error, 1) + + var once sync.Once + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + handled := false + once.Do(func() { + handled = true + + // Validate state to prevent CSRF + if r.URL.Query().Get("state") != state { + http.Error(w, "Invalid state parameter", http.StatusBadRequest) + errCh <- fmt.Errorf("state mismatch") + return + } + + token := r.URL.Query().Get("token") + if token == "" { + http.Error(w, "No token received", http.StatusBadRequest) + errCh <- fmt.Errorf("empty token in callback") + return + } + + // Respond with a success page + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprint(w, callbackHTML) + + tokenCh <- token + }) + if !handled { + http.Error(w, "Already processed", http.StatusConflict) + } + }) + + server := &http.Server{Handler: mux} + + // Start server in background + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + errCh <- err + } + }() + + // Open browser + authURL := fmt.Sprintf("%s/cli/auth?state=%s&port=%d", apiURL, url.QueryEscape(state), port) + openBrowser(authURL) + + // Wait for callback, error, or timeout + ctx, cancel := context.WithTimeout(context.Background(), browserAuthTimeout) + defer cancel() + + var token string + select { + case token = <-tokenCh: + // Success + case err := <-errCh: + shutdownCtx2, cancel2 := context.WithTimeout(context.Background(), 2*time.Second) + _ = server.Shutdown(shutdownCtx2) + cancel2() + return "", err + case <-ctx.Done(): + shutdownCtx2, cancel2 := context.WithTimeout(context.Background(), 2*time.Second) + _ = server.Shutdown(shutdownCtx2) + cancel2() + return "", fmt.Errorf("timed out waiting for browser authorization") + } + + // Shutdown the server + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer shutdownCancel() + _ = server.Shutdown(shutdownCtx) + + return token, nil +} + +func generateState() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// callbackHTML is the page shown in the browser after successful authorization. +const callbackHTML = ` + + + + unarr — Connected + + + +
+
+

Connected to torrentclaw

+

You can close this tab and return to your terminal.

+
+ +` diff --git a/internal/cmd/auth_browser_test.go b/internal/cmd/auth_browser_test.go new file mode 100644 index 0000000..d832266 --- /dev/null +++ b/internal/cmd/auth_browser_test.go @@ -0,0 +1,175 @@ +package cmd + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGenerateState(t *testing.T) { + state, err := generateState() + if err != nil { + t.Fatalf("generateState: %v", err) + } + if len(state) != 64 { // 32 bytes = 64 hex chars + t.Errorf("state length = %d, want 64", len(state)) + } + for _, c := range state { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("state contains non-hex char: %c", c) + break + } + } + + // Two states should differ + state2, _ := generateState() + if state == state2 { + t.Error("consecutive states should differ") + } +} + +func TestCallbackHTML(t *testing.T) { + if !strings.Contains(callbackHTML, "Connected to torrentclaw") { + t.Error("missing success message") + } + if !strings.Contains(callbackHTML, "close this tab") { + t.Error("missing close instruction") + } +} + +func TestCallbackHandler_ValidState(t *testing.T) { + state := "abc123def456abc123def456abc123def456abc123def456abc123def456abcd" + tokenCh := make(chan string, 1) + errCh := make(chan error, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + http.Error(w, "Invalid state", http.StatusBadRequest) + errCh <- fmt.Errorf("state mismatch") + return + } + token := r.URL.Query().Get("token") + if token == "" { + http.Error(w, "No token", http.StatusBadRequest) + errCh <- fmt.Errorf("empty token") + return + } + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, callbackHTML) + tokenCh <- token + }) + + server := httptest.NewServer(mux) + defer server.Close() + + // Simulate browser redirect to callback + resp, err := http.Get(server.URL + "/callback?token=tc_test_key_123&state=" + state) + if err != nil { + t.Fatalf("callback request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Errorf("status = %d, want 200", resp.StatusCode) + } + + select { + case token := <-tokenCh: + if token != "tc_test_key_123" { + t.Errorf("token = %q, want tc_test_key_123", token) + } + case err := <-errCh: + t.Fatalf("unexpected error: %v", err) + case <-time.After(time.Second): + t.Fatal("timeout waiting for token") + } +} + +func TestCallbackHandler_InvalidState(t *testing.T) { + tokenCh := make(chan string, 1) + errCh := make(chan error, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != "correct_state" { + http.Error(w, "Invalid state", http.StatusBadRequest) + errCh <- fmt.Errorf("state mismatch") + return + } + tokenCh <- r.URL.Query().Get("token") + }) + + server := httptest.NewServer(mux) + defer server.Close() + + resp, err := http.Get(server.URL + "/callback?token=tc_test&state=wrong_state") + if err != nil { + t.Fatalf("callback request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + + select { + case <-errCh: + // Expected — state mismatch + case <-tokenCh: + t.Fatal("should not have received token with wrong state") + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestCallbackHandler_MissingToken(t *testing.T) { + state := "valid_state_0123456789abcdef0123456789abcdef" + errCh := make(chan error, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + http.Error(w, "Invalid state", http.StatusBadRequest) + errCh <- fmt.Errorf("state mismatch") + return + } + token := r.URL.Query().Get("token") + if token == "" { + http.Error(w, "No token", http.StatusBadRequest) + errCh <- fmt.Errorf("empty token") + return + } + }) + + server := httptest.NewServer(mux) + defer server.Close() + + resp, err := http.Get(server.URL + "/callback?state=" + state) + if err != nil { + t.Fatalf("callback request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestBrowserAuth_ServerBinds(t *testing.T) { + // Verify browserAuth can bind to a free port + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + if port < 1024 { + t.Errorf("port %d < 1024", port) + } +} diff --git a/internal/cmd/init.go b/internal/cmd/init.go index 3bca2c3..845d81f 100644 --- a/internal/cmd/init.go +++ b/internal/cmd/init.go @@ -73,39 +73,56 @@ func runInit(apiURLOverride string) error { // ── Step 1/3: Connect account ─────────────────────────────────── - keysURL := apiURL + "/profile?tab=apikey" - fmt.Printf(" Opening %s ...\n", keysURL) - openBrowser(keysURL) - fmt.Println() - apiKey := cfg.Auth.APIKey - err := huh.NewForm( - huh.NewGroup( - huh.NewInput(). - Title("Step 1/3 — API Key"). - Description("Copy it from the page that just opened in your browser"). - Placeholder("tc_..."). - Value(&apiKey). - Validate(func(s string) error { - s = strings.TrimSpace(s) - if s == "" { - return fmt.Errorf("API key is required") - } - if !strings.HasPrefix(s, "tc_") { - return fmt.Errorf("API key should start with tc_") - } + + if apiKey == "" { + // Try browser-based auth first (like Claude Code / GitHub CLI) + fmt.Println(" Opening browser to connect your account...") + fmt.Println() + + browserKey, browserErr := browserAuth(apiURL) + if browserErr == nil && strings.HasPrefix(browserKey, "tc_") { + apiKey = browserKey + green.Println(" ✓ Connected via browser") + fmt.Println() + } else { + // Fallback to manual API key entry + if browserErr != nil { + dim.Printf(" Could not connect automatically: %s\n", browserErr) + } + fmt.Println(" Paste your API key instead:") + dim.Printf(" (get it from %s/profile?tab=apikey)\n", apiURL) + fmt.Println() + + var err error + err = huh.NewForm( + huh.NewGroup( + huh.NewInput(). + Title("Step 1/3 — API Key"). + Placeholder("tc_..."). + Value(&apiKey). + Validate(func(s string) error { + s = strings.TrimSpace(s) + if s == "" { + return fmt.Errorf("API key is required") + } + if !strings.HasPrefix(s, "tc_") { + return fmt.Errorf("API key should start with tc_") + } + return nil + }), + ), + ).Run() + if err != nil { + if errors.Is(err, huh.ErrUserAborted) { + fmt.Println("\n Init cancelled.") return nil - }), - ), - ).Run() - if err != nil { - if errors.Is(err, huh.ErrUserAborted) { - fmt.Println("\n Init cancelled.") - return nil + } + return err + } + apiKey = strings.TrimSpace(apiKey) } - return err } - apiKey = strings.TrimSpace(apiKey) // Validate API key by registering with the server fmt.Print(" Verifying API key... ")