unarr/internal/usenet/nntp/client.go

385 lines
8.1 KiB
Go

package nntp
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/textproto"
"sync"
"time"
)
// Config holds NNTP server connection parameters.
type Config struct {
	// Host is the server hostname; it is combined with Port to form the dial address.
	Host string
	// Port is the TCP port to connect to.
	Port int
	// SSL selects a TLS connection instead of plain TCP.
	SSL bool
	// TLSServerName overrides the name used for TLS cert validation
	// (e.g., "xsnews.nl" when Host is a CNAME). When empty, Host is used.
	TLSServerName string
	// Username is sent with AUTHINFO USER; when empty, authentication is skipped.
	Username string
	// Password is sent with AUTHINFO PASS when the server asks for one.
	Password string
	// MaxConnections is the pool size; NewClient substitutes 10 when it is <= 0.
	MaxConnections int
}
// Client manages a pool of authenticated NNTP connections.
type Client struct {
	// cfg is the connection configuration captured by NewClient.
	cfg Config
	// pool holds idle connections; NewClient sizes it to cfg.MaxConnections.
	pool chan *conn
	// mu guards open.
	mu sync.Mutex
	// open is the number of connections the client currently counts as alive.
	open int
	// done is closed on Close(); acquire watches it to stop handing out connections.
	done chan struct{}
}
// conn is a single NNTP connection.
type conn struct {
	// tp is the text-protocol wrapper used to send commands and read replies.
	tp *textproto.Conn
	// raw is the underlying socket, used for setting deadlines and closing.
	raw net.Conn
	// closed marks the connection as unusable; release and acquire refuse
	// to hand such a connection back out.
	closed bool
}
// NewClient builds a Client for cfg without opening any connections; call
// Connect to fill the pool. A non-positive cfg.MaxConnections falls back to 10.
func NewClient(cfg Config) *Client {
	size := cfg.MaxConnections
	if size <= 0 {
		size = 10
	}
	cfg.MaxConnections = size

	client := &Client{cfg: cfg}
	client.pool = make(chan *conn, size)
	client.done = make(chan struct{})
	return client
}
// Connect dials and authenticates cfg.MaxConnections connections and places
// each of them in the pool.
//
// If any dial fails, the connections pooled so far are closed and an error
// naming the failing attempt is returned; the client stays usable, so Connect
// may be called again afterwards.
func (c *Client) Connect(ctx context.Context) error {
	// A done channel that is already closed would make every acquire fail,
	// so swap in a fresh one before (re)filling the pool.
	select {
	case <-c.done:
		c.done = make(chan struct{})
	default:
	}

	total := c.cfg.MaxConnections
	for attempt := 1; attempt <= total; attempt++ {
		cn, err := c.dial(ctx)
		if err != nil {
			// Tear down what was opened so far, but keep the client reusable.
			c.drainPool()
			return fmt.Errorf("nntp: connect %d/%d: %w", attempt, total, err)
		}

		c.pool <- cn

		c.mu.Lock()
		c.open++
		c.mu.Unlock()
	}
	return nil
}
// drainPool closes every connection currently sitting in the pool and resets
// the open counter to zero. The done channel is left untouched, so the client
// can be connected again afterwards.
func (c *Client) drainPool() {
drain:
	for {
		select {
		case cn := <-c.pool:
			c.closeConn(cn)
		default:
			// Pool is empty.
			break drain
		}
	}

	c.mu.Lock()
	c.open = 0
	c.mu.Unlock()
}
// Body downloads the body of the NNTP article identified by messageID and
// returns it fully read into memory (typically yEnc encoded).
//
// The pooled connection used for the request is acquired and handed back
// internally; the caller has nothing to release.
//
// When the server reports that the article does not exist, the returned error
// is an *ArticleNotFoundError and no retry is attempted. Any other failure on
// the first attempt causes the connection to be replaced by a freshly dialed
// one and the request to be retried exactly once.
func (c *Client) Body(ctx context.Context, messageID string) ([]byte, error) {
	cn, err := c.acquire(ctx)
	if err != nil {
		return nil, err
	}

	data, err := c.bodyOnConn(ctx, cn, messageID)
	if err == nil {
		c.release(cn)
		return data, nil
	}

	// A missing article is a definitive answer given over a working
	// connection: reconnecting and asking again would only churn sockets.
	// bodyOnConn returns this error unwrapped, so a direct assertion suffices.
	if _, notFound := err.(*ArticleNotFoundError); notFound {
		c.release(cn)
		return nil, err
	}

	// Connection might be broken, replace it with a fresh one.
	cn2, reconErr := c.reconnect(ctx, cn)
	if reconErr != nil {
		// reconnect has already closed cn and decremented the open counter;
		// discarding cn here as well would count the loss twice.
		return nil, fmt.Errorf("nntp: body failed and reconnect failed: %w (original: %v)", reconErr, err)
	}

	// Retry once on the fresh connection.
	data, err = c.bodyOnConn(ctx, cn2, messageID)
	if cn2.closed {
		// The retry flagged the connection as dead. release silently skips
		// such connections, so discard it to keep the open counter accurate.
		c.discard(cn2)
	} else {
		c.release(cn2)
	}
	if err != nil {
		return nil, err
	}
	return data, nil
}
// ActiveConnections reports the current value of the open-connection counter.
func (c *Client) ActiveConnections() int {
	c.mu.Lock()
	n := c.open
	c.mu.Unlock()
	return n
}
// Close marks the client as closed and shuts down every connection that is
// currently idle in the pool. It always returns nil; calling it on an already
// closed client is a no-op.
//
// NOTE(review): the "already closed" check and the close of the done channel
// are not performed under a lock, so two goroutines calling Close at the same
// time are not protected against each other here.
func (c *Client) Close() error {
	select {
	case <-c.done:
		return nil // already closed
	default:
		close(c.done)
	}

	// Drain the pool. discard (rather than a bare closeConn) also decrements
	// the open counter, so ActiveConnections and Status stop reporting
	// connections that no longer exist, and a later Connect starts counting
	// from the right value.
	for {
		select {
		case cn := <-c.pool:
			c.discard(cn)
		default:
			return nil
		}
	}
}
// --- Internal ---
// dial opens a single connection to the configured server and prepares it for
// use: it connects (over TLS when cfg.SSL is set), reads the welcome banner
// and authenticates when cfg.Username is non-empty.
//
// ctx bounds connection establishment, including the TLS handshake. The
// banner and authentication exchange run under a socket deadline so that a
// server which accepts the connection but never answers cannot block the
// caller forever; the deadline is cleared before the connection is returned.
//
// On any failure the socket is closed and a non-nil error is returned.
func (c *Client) dial(ctx context.Context) (*conn, error) {
	// Upper bound for establishing the connection and, separately, for the
	// greeting/authentication phase that follows.
	const setupTimeout = 30 * time.Second

	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	addr := net.JoinHostPort(c.cfg.Host, fmt.Sprint(c.cfg.Port))
	dialer := &net.Dialer{Timeout: setupTimeout}

	var rawConn net.Conn
	var err error
	if c.cfg.SSL {
		// Use TLSServerName if set (e.g., cert is for xsnews.nl but host is
		// reader.torrentclaw.com).
		serverName := c.cfg.TLSServerName
		if serverName == "" {
			serverName = c.cfg.Host
		}
		// tls.Dialer honours ctx for both the TCP connect and the handshake.
		tlsDialer := &tls.Dialer{
			NetDialer: dialer,
			Config: &tls.Config{
				ServerName: serverName,
				MinVersion: tls.VersionTLS12,
			},
		}
		rawConn, err = tlsDialer.DialContext(ctx, "tcp", addr)
	} else {
		rawConn, err = dialer.DialContext(ctx, "tcp", addr)
	}
	if err != nil {
		return nil, fmt.Errorf("dial %s: %w", addr, err)
	}

	// Bound the banner and auth reads by the earlier of ctx's deadline and
	// the setup timeout.
	deadline := time.Now().Add(setupTimeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		deadline = d
	}
	if err := rawConn.SetDeadline(deadline); err != nil {
		rawConn.Close()
		return nil, fmt.Errorf("dial %s: set deadline: %w", addr, err)
	}

	tp := textproto.NewConn(rawConn)

	// Read welcome banner: 200, or 201 (posting not allowed). ReadCodeLine
	// reports 201 as an error because only 200 was requested, so accept it
	// explicitly.
	code, msg, err := tp.ReadCodeLine(200)
	if err != nil && code != 201 {
		rawConn.Close()
		return nil, fmt.Errorf("welcome: %d %s: %w", code, msg, err)
	}

	// Authenticate
	if c.cfg.Username != "" {
		if err := c.auth(tp); err != nil {
			rawConn.Close()
			return nil, fmt.Errorf("auth: %w", err)
		}
	}

	// Hand the connection over without a lingering deadline.
	if err := rawConn.SetDeadline(time.Time{}); err != nil {
		rawConn.Close()
		return nil, fmt.Errorf("dial %s: clear deadline: %w", addr, err)
	}
	return &conn{tp: tp, raw: rawConn}, nil
}
// auth runs the AUTHINFO USER / AUTHINFO PASS exchange on tp with the
// configured credentials. It returns nil once the server accepts them, which
// includes the case where the server answers the USER step with 281 and no
// password is needed.
func (c *Client) auth(tp *textproto.Conn) error {
	userID, err := tp.Cmd("AUTHINFO USER %s", c.cfg.Username)
	if err != nil {
		return err
	}
	tp.StartResponse(userID)
	code, msg, err := tp.ReadCodeLine(381)
	tp.EndResponse(userID)

	switch {
	case err == nil:
		// 381: the server wants the password next.
	case code == 281:
		// 281 means no password required (unlikely but valid).
		return nil
	default:
		return fmt.Errorf("AUTHINFO USER: %d %s: %w", code, msg, err)
	}

	passID, err := tp.Cmd("AUTHINFO PASS %s", c.cfg.Password)
	if err != nil {
		return err
	}
	tp.StartResponse(passID)
	code, msg, err = tp.ReadCodeLine(281)
	tp.EndResponse(passID)

	if err != nil {
		return fmt.Errorf("AUTHINFO PASS: %d %s: %w", code, msg, err)
	}
	return nil
}
// bodyOnConn sends a BODY request for messageID on cn and returns the
// dot-decoded article body.
//
// The whole exchange runs under a socket deadline: ctx's deadline when it has
// one, otherwise 60 seconds from now. The deadline is cleared on return.
//
// A 430 reply yields an *ArticleNotFoundError. If reading the body itself
// fails, the connection is closed and flagged as dead, because the protocol
// stream is then in an unknown position and cannot be reused.
func (c *Client) bodyOnConn(ctx context.Context, cn *conn, messageID string) ([]byte, error) {
	// Set deadline from context
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		deadline = time.Now().Add(60 * time.Second)
	}
	cn.raw.SetDeadline(deadline)
	defer cn.raw.SetDeadline(time.Time{})

	// Send BODY command
	id, err := cn.tp.Cmd("BODY <%s>", messageID)
	if err != nil {
		return nil, fmt.Errorf("BODY cmd: %w", err)
	}
	cn.tp.StartResponse(id)
	defer cn.tp.EndResponse(id)

	// Read response code
	code, msg, err := cn.tp.ReadCodeLine(222)
	if err != nil {
		if code == 430 {
			return nil, &ArticleNotFoundError{MessageID: messageID}
		}
		return nil, fmt.Errorf("BODY response: %d %s: %w", code, msg, err)
	}

	// Read dot-terminated body
	body, err := readDotBody(cn.tp.R)
	if err != nil {
		// Partial read leaves textproto in a broken state, so the connection
		// is dead. Close the socket right here: closeConn skips connections
		// whose closed flag is already set, so only setting the flag would
		// leak the underlying descriptor.
		cn.raw.Close()
		cn.closed = true
		return nil, fmt.Errorf("read body: %w", err)
	}
	return body, nil
}
// readDotBody reads a dot-terminated text block from the NNTP server.
//
// Each line has its line ending ("\r\n", or a bare "\n") replaced by "\n".
// Lines beginning with a dot have that dot removed (dot-unstuffing). A line
// consisting of a single dot ends the block and is not part of the result;
// the reader is left positioned right after it.
//
// If the stream ends before the terminating dot line, the block is truncated
// and io.ErrUnexpectedEOF is returned. Any other read error is returned as is.
// No partial data is returned alongside an error.
func readDotBody(r *bufio.Reader) ([]byte, error) {
	var buf bytes.Buffer
	for {
		line, err := r.ReadBytes('\n')
		if err != nil {
			if err == io.EOF {
				// EOF before the "." terminator means the server (or the
				// network) cut the block short; reporting success here would
				// hand a truncated article to the caller.
				return nil, io.ErrUnexpectedEOF
			}
			return nil, err
		}

		// Strip exactly one line ending. Trimming every trailing CR/LF byte
		// would also eat CR bytes that belong to the payload.
		line = bytes.TrimSuffix(line, []byte("\n"))
		line = bytes.TrimSuffix(line, []byte("\r"))

		// Check for terminator: single dot
		if len(line) == 1 && line[0] == '.' {
			return buf.Bytes(), nil
		}

		// Dot-unstuffing: remove leading dot if present
		if len(line) > 0 && line[0] == '.' {
			line = line[1:]
		}
		buf.Write(line)
		buf.WriteByte('\n')
	}
}
// acquire takes an idle connection out of the pool, waiting until one is
// available. It gives up with ctx's error when ctx is done, and with a
// "client closed" error when the client's done channel is closed.
//
// Connections found in the pool with their closed flag set are dropped and
// subtracted from the open counter, and the wait continues.
func (c *Client) acquire(ctx context.Context) (*conn, error) {
	for {
		select {
		case <-c.done:
			return nil, fmt.Errorf("nntp: client closed")
		case <-ctx.Done():
			return nil, ctx.Err()
		case cn := <-c.pool:
			if !cn.closed {
				return cn, nil
			}
			// Dead connection: forget it and try the next one.
			c.mu.Lock()
			c.open--
			c.mu.Unlock()
		}
	}
}
// release hands cn back to the pool so another request can use it.
//
// A nil connection or one already flagged closed is ignored. If the client
// has been closed in the meantime, or the pool has no free slot, the
// connection is shut down and removed from the open counter instead of being
// pooled.
func (c *Client) release(cn *conn) {
	if cn == nil || cn.closed {
		return
	}

	// After Close nobody drains the pool any more, so a connection put back
	// now would stay open indefinitely.
	select {
	case <-c.done:
		c.discard(cn)
		return
	default:
	}

	select {
	case c.pool <- cn:
	default:
		// Pool full: close the connection. discard also decrements the open
		// counter, which a bare closeConn would leave too high.
		c.discard(cn)
	}
}
// discard closes cn and subtracts one from the open-connection counter.
func (c *Client) discard(cn *conn) {
	c.closeConn(cn)

	c.mu.Lock()
	defer c.mu.Unlock()
	c.open--
}
// reconnect closes old and dials a replacement connection. When the dial
// succeeds the new connection is returned and the open counter is unchanged
// (one connection was swapped for another). When it fails, the counter is
// decremented to account for the lost connection and the dial error is
// returned.
func (c *Client) reconnect(ctx context.Context, old *conn) (*conn, error) {
	c.closeConn(old)

	fresh, err := c.dial(ctx)
	if err == nil {
		return fresh, nil
	}

	c.mu.Lock()
	c.open--
	c.mu.Unlock()
	return nil, err
}
// closeConn shuts cn down: it flags the connection as closed, sends a
// best-effort QUIT (the result is ignored) and closes the socket. A nil
// connection or one already flagged closed is left alone. The open counter
// is not touched here.
func (c *Client) closeConn(cn *conn) {
	if cn != nil && !cn.closed {
		cn.closed = true
		// Best-effort QUIT
		cn.tp.Cmd("QUIT")
		cn.raw.Close()
	}
}
// ArticleNotFoundError reports that the server answered a request with 430,
// i.e. it does not have the requested article.
type ArticleNotFoundError struct {
	// MessageID is the message-ID that was asked for.
	MessageID string
}

// Error implements the error interface.
func (e *ArticleNotFoundError) Error() string {
	const format = "nntp: article not found: %s"
	return fmt.Sprintf(format, e.MessageID)
}
// Status renders a one-line, human-readable summary: the open-connection
// count, how many connections are idle in the pool, and the server address.
func (c *Client) Status() string {
	return fmt.Sprintf(
		"%d connections (%d pooled) to %s:%d",
		c.ActiveConnections(), // takes the lock and reads the open counter
		len(c.pool),
		c.cfg.Host,
		c.cfg.Port,
	)
}