diff --git a/go.mod b/go.mod index 9cbc7e5..842aff3 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,18 @@ module github.com/torrentclaw/torrentclaw-cli -go 1.24.0 +go 1.25.0 require ( github.com/BurntSushi/toml v1.6.0 github.com/anacrolix/log v0.17.1-0.20251118025802-918f1157b7bb github.com/anacrolix/torrent v1.61.0 github.com/charmbracelet/huh v1.0.0 - github.com/fatih/color v1.18.0 + github.com/fatih/color v1.19.0 github.com/google/uuid v1.6.0 github.com/olekukonko/tablewriter v0.0.5 - github.com/spf13/cobra v1.8.1 + github.com/spf13/cobra v1.10.2 github.com/torrentclaw/go-client v0.2.0 - golang.org/x/time v0.14.0 + golang.org/x/time v0.15.0 ) require ( @@ -41,6 +41,7 @@ require ( github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 // indirect github.com/catppuccin/go v0.3.0 // indirect github.com/cespare/xxhash v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect github.com/charmbracelet/bubbletea v1.3.6 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect @@ -49,6 +50,8 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.13 // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/clipperhouse/displaywidth v0.11.0 // indirect + github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/edsrzf/mmap-go v1.1.0 // indirect @@ -57,6 +60,7 @@ require ( github.com/go-llsqlite/crawshaw v0.5.6-0.20250312230104-194977a03421 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/goccy/go-json v0.10.6 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect @@ -64,10 +68,10 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.2.3 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-runewidth v0.0.21 // indirect github.com/minio/sha256-simd v1.0.0 // indirect github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/mr-tron/base58 v1.2.0 // indirect @@ -77,6 +81,9 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/multiformats/go-multihash v0.2.3 // indirect github.com/multiformats/go-varint v0.0.6 // indirect + github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect + github.com/olekukonko/errors v1.2.0 // indirect + github.com/olekukonko/ll v0.1.8 // indirect github.com/pion/datachannel v1.5.9 // indirect github.com/pion/dtls/v3 v3.0.3 // indirect github.com/pion/ice/v4 v4.0.2 // indirect @@ -99,7 +106,7 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/pflag v1.0.10 // indirect github.com/tidwall/btree v1.8.1 // indirect github.com/wlynxg/anet v0.0.3 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect @@ -112,7 +119,7 @@ require ( golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect - golang.org/x/sys v0.38.0 // indirect + golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.31.0 // indirect lukechampine.com/blake3 v1.1.6 // indirect modernc.org/libc v1.22.3 // indirect diff --git a/go.sum b/go.sum index fa5f37b..d09e04e 100644 --- a/go.sum +++ b/go.sum @@ -112,6 +112,8 @@ github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MO github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 h1:JFgG/xnwFfbezlUnFMJy0nusZvytYysV4SCS2cYbvws= github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7/go.mod h1:ISC1gtLcVilLOf23wvTfoQuYbW2q0JevFxPfUzZ9Ybw= github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU= @@ -141,7 +143,12 @@ github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5 github.com/charmbracelet/x/xpty v0.1.2 h1:Pqmu4TEJ8KeA9uSkISKMU3f+C1F6OGBn8ABuGlqCbtI= github.com/charmbracelet/x/xpty v0.1.2/go.mod h1:XK2Z0id5rtLWcpeNiMYBccNNBrP2IJnzHI0Lq13Xzq4= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8= +github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= +github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= +github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= @@ -162,6 +169,8 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6 github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w= +github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE= github.com/frankban/quicktest v1.9.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y= github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -189,6 +198,8 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU= +github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -264,6 +275,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= @@ -272,6 +285,8 @@ github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+Ei github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w= +github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/minio/sha256-simd v1.0.0 h1:v1ta+49hkWZyvaKwrQB8elexRqm6Y0aMLjCNsrYxo6g= github.com/minio/sha256-simd v1.0.0/go.mod h1:OuYzVNI5vcoYIAmbIvHPl3N3jUzVedXbKy5RFepssQM= @@ -297,8 +312,16 @@ github.com/multiformats/go-multihash v0.2.3/go.mod h1:dXgKXCXjBzdscBLk9JkjINiEsC github.com/multiformats/go-varint v0.0.6 h1:gk85QWKxh3TazbLxED/NlDVv8+q+ReFJk7Y2W/KhfNY= github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj4EljqMiZsIcE09mmF8XsD5AYOJc= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= +github.com/olekukonko/errors v1.2.0 h1:10Zcn4GeV59t/EGqJc8fUjtFT/FuUh5bTMzZ1XwmCRo= +github.com/olekukonko/errors v1.2.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= +github.com/olekukonko/ll v0.1.8 h1:ysHCJRGHYKzmBSdz9w5AySztx7lG8SQY+naTGYUbsz8= +github.com/olekukonko/ll v0.1.8/go.mod h1:RPRC6UcscfFZgjo1nulkfMH5IM0QAYim0LfnMvUuozw= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/olekukonko/tablewriter v1.1.4 h1:ORUMI3dXbMnRlRggJX3+q7OzQFDdvgbN9nVWj1drm6I= +github.com/olekukonko/tablewriter v1.1.4/go.mod h1:+kedxuyTtgoZLwif3P1Em4hARJs+mVnzKxmsCL/C5RY= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= @@ -390,8 +413,13 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= @@ -437,6 +465,7 @@ go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgf go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -516,6 +545,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -533,6 +564,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/agent/client_test.go b/internal/agent/client_test.go index 36315ea..03dcb57 100644 --- a/internal/agent/client_test.go +++ b/internal/agent/client_test.go @@ -82,15 +82,18 @@ func TestHeartbeat(t *testing.T) { if req.AgentID != "agent-123" { t.Errorf("agentId = %q, want agent-123", req.AgentID) } - json.NewEncoder(w).Encode(StatusResponse{Success: true}) + json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) })) defer srv.Close() c := NewClient(srv.URL, "test-key", "unarr-test") - err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"}) + resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"}) if err != nil { t.Fatalf("Heartbeat failed: %v", err) } + if !resp.Success { + t.Error("expected success=true") + } } func TestClaimTasks(t *testing.T) { @@ -115,21 +118,21 @@ func TestClaimTasks(t *testing.T) { defer srv.Close() c := NewClient(srv.URL, "test-key", "unarr-test") - tasks, err := c.ClaimTasks(context.Background(), "agent-123") + resp, err := c.ClaimTasks(context.Background(), "agent-123") if err != nil { t.Fatalf("ClaimTasks failed: %v", err) } - if len(tasks) != 1 { - t.Fatalf("len(tasks) = %d, want 1", len(tasks)) + if len(resp.Tasks) != 1 { + t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks)) } - if tasks[0].ID != "task-uuid-1" { - t.Errorf("task.ID = %q, want task-uuid-1", tasks[0].ID) + if resp.Tasks[0].ID != "task-uuid-1" { + t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID) } - if tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" { - t.Errorf("task.InfoHash = %q", tasks[0].InfoHash) + if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" { + t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash) } - if tasks[0].PreferredMethod != "auto" { - t.Errorf("task.PreferredMethod = %q, want auto", tasks[0].PreferredMethod) + if resp.Tasks[0].PreferredMethod != "auto" { + t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod) } } @@ -177,12 +180,12 @@ func TestClaimTasksEmpty(t *testing.T) { defer srv.Close() c := NewClient(srv.URL, "test-key", "unarr-test") - tasks, err := c.ClaimTasks(context.Background(), "agent-123") + resp, err := c.ClaimTasks(context.Background(), "agent-123") if err != nil { t.Fatalf("ClaimTasks failed: %v", err) } - if len(tasks) != 0 { - t.Errorf("expected empty tasks, got %d", len(tasks)) + if len(resp.Tasks) != 0 { + t.Errorf("expected empty tasks, got %d", len(resp.Tasks)) } } @@ -276,10 +279,107 @@ func TestUserAgent(t *testing.T) { if r.Header.Get("User-Agent") != "unarr/0.2.0" { t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent")) } - json.NewEncoder(w).Encode(StatusResponse{Success: true}) + json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) })) defer srv.Close() c := NewClient(srv.URL, "test-key", "unarr/0.2.0") c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"}) } + +func TestHeartbeatWithUpgradeSignal(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(HeartbeatResponse{ + Success: true, + Upgrade: &UpgradeSignal{Version: "2.0.0"}, + }) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"}) + if err != nil { + t.Fatalf("Heartbeat failed: %v", err) + } + if resp.Upgrade == nil { + t.Fatal("expected upgrade signal, got nil") + } + if resp.Upgrade.Version != "2.0.0" { + t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version) + } +} + +func TestHeartbeatWithoutUpgradeSignal(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(HeartbeatResponse{Success: true}) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"}) + if err != nil { + t.Fatalf("Heartbeat failed: %v", err) + } + if resp.Upgrade != nil { + t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade) + } +} + +func TestReportUpgradeResult(t *testing.T) { + var received UpgradeResult + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/internal/agent/upgrade-result" { + t.Errorf("path = %s, want /api/internal/agent/upgrade-result", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("method = %s, want POST", r.Method) + } + json.NewDecoder(r.Body).Decode(&received) + json.NewEncoder(w).Encode(struct{ Success bool }{Success: true}) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + err := c.ReportUpgradeResult(context.Background(), UpgradeResult{ + AgentID: "agent-1", + Success: true, + Version: "2.0.0", + }) + if err != nil { + t.Fatalf("ReportUpgradeResult failed: %v", err) + } + if received.AgentID != "agent-1" { + t.Errorf("agentId = %q, want agent-1", received.AgentID) + } + if !received.Success { + t.Error("expected success=true") + } + if received.Version != "2.0.0" { + t.Errorf("version = %q, want 2.0.0", received.Version) + } +} + +func TestReportUpgradeResultFailure(t *testing.T) { + var received UpgradeResult + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + json.NewEncoder(w).Encode(struct{ Success bool }{Success: true}) + })) + defer srv.Close() + + c := NewClient(srv.URL, "test-key", "unarr-test") + err := c.ReportUpgradeResult(context.Background(), UpgradeResult{ + AgentID: "agent-1", + Success: false, + Error: "checksum mismatch", + }) + if err != nil { + t.Fatalf("ReportUpgradeResult failed: %v", err) + } + if received.Success { + t.Error("expected success=false") + } + if received.Error != "checksum mismatch" { + t.Errorf("error = %q, want 'checksum mismatch'", received.Error) + } +} diff --git a/internal/agent/daemon.go b/internal/agent/daemon.go index 06e0c3e..fe57e85 100644 --- a/internal/agent/daemon.go +++ b/internal/agent/daemon.go @@ -44,6 +44,9 @@ type Daemon struct { // Exposed tickers for hot-reload PollTicker *time.Ticker HeartbeatTicker *time.Ticker + + // pollNow triggers an immediate poll (e.g. on resume) + pollNow chan struct{} } // NewDaemon creates a daemon with the given transport. @@ -59,6 +62,7 @@ func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon { return &Daemon{ cfg: cfg, transport: transport, + pollNow: make(chan struct{}, 1), } } @@ -151,6 +155,9 @@ func (d *Daemon) Run(ctx context.Context) error { if d.transport.Mode() == "http" { d.poll(ctx) } + + case <-d.pollNow: + d.poll(ctx) } } } @@ -236,6 +243,15 @@ func (d *Daemon) handleEvent(event ServerEvent) { } } +// TriggerPoll requests an immediate task poll cycle. +// Used when a resume event is received to pick up re-pending tasks faster. +func (d *Daemon) TriggerPoll() { + select { + case d.pollNow <- struct{}{}: + default: // already pending + } +} + // ClearUpgradeInProgress resets the upgrade flag so a retry can be attempted. func (d *Daemon) ClearUpgradeInProgress() { d.upgradeInProgress = false diff --git a/internal/agent/state.go b/internal/agent/state.go new file mode 100644 index 0000000..7316116 --- /dev/null +++ b/internal/agent/state.go @@ -0,0 +1,72 @@ +package agent + +import ( + "encoding/json" + "os" + "path/filepath" + "time" + + "github.com/torrentclaw/torrentclaw-cli/internal/config" +) + +// DaemonState is written to disk every heartbeat for external tools to read. +type DaemonState struct { + AgentID string `json:"agentId"` + Status string `json:"status"` // running | upgrading | shutting_down + Version string `json:"version"` + PID int `json:"pid"` + StartedAt time.Time `json:"startedAt"` + LastHeartbeat time.Time `json:"lastHeartbeat"` + ActiveTasks int `json:"activeTasks"` + CompletedCount int `json:"completedCount"` + FailedCount int `json:"failedCount"` + TotalDownloaded int64 `json:"totalDownloaded"` + MethodStats map[string]int `json:"methodStats,omitempty"` +} + +// stateFilePathFn is overridable for testing. +var stateFilePathFn = func() string { + return filepath.Join(config.DataDir(), "daemon.state.json") +} + +// StateFilePath returns the path to the daemon state file. +func StateFilePath() string { + return stateFilePathFn() +} + +// WriteState writes the daemon state to disk (best-effort, never errors). +func WriteState(state *DaemonState) { + path := StateFilePath() + dir := filepath.Dir(path) + os.MkdirAll(dir, 0o755) + + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return + } + + // Write to temp file then rename for atomicity + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return + } + os.Rename(tmp, path) +} + +// ReadState reads the daemon state from disk. Returns nil if not found. +func ReadState() *DaemonState { + data, err := os.ReadFile(StateFilePath()) + if err != nil { + return nil + } + var state DaemonState + if json.Unmarshal(data, &state) != nil { + return nil + } + return &state +} + +// RemoveState deletes the state file (called on clean shutdown). +func RemoveState() { + os.Remove(StateFilePath()) +} diff --git a/internal/agent/state_test.go b/internal/agent/state_test.go new file mode 100644 index 0000000..6c9abdd --- /dev/null +++ b/internal/agent/state_test.go @@ -0,0 +1,106 @@ +package agent + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestWriteAndReadState(t *testing.T) { + // Override the state file path for testing + tmpDir := t.TempDir() + origFn := stateFilePathFn + stateFilePathFn = func() string { return filepath.Join(tmpDir, "daemon.state.json") } + defer func() { stateFilePathFn = origFn }() + + state := &DaemonState{ + AgentID: "agent-123", + Status: "running", + Version: "1.0.0", + PID: 12345, + StartedAt: time.Now().Truncate(time.Second), + LastHeartbeat: time.Now().Truncate(time.Second), + ActiveTasks: 3, + CompletedCount: 10, + FailedCount: 2, + TotalDownloaded: 1024 * 1024 * 500, + MethodStats: map[string]int{"torrent": 8, "debrid": 2}, + } + + WriteState(state) + + read := ReadState() + if read == nil { + t.Fatal("ReadState() returned nil") + } + if read.AgentID != "agent-123" { + t.Errorf("AgentID = %q, want agent-123", read.AgentID) + } + if read.Status != "running" { + t.Errorf("Status = %q, want running", read.Status) + } + if read.Version != "1.0.0" { + t.Errorf("Version = %q, want 1.0.0", read.Version) + } + if read.PID != 12345 { + t.Errorf("PID = %d, want 12345", read.PID) + } + if read.ActiveTasks != 3 { + t.Errorf("ActiveTasks = %d, want 3", read.ActiveTasks) + } + if read.CompletedCount != 10 { + t.Errorf("CompletedCount = %d, want 10", read.CompletedCount) + } + if read.MethodStats["torrent"] != 8 { + t.Errorf("MethodStats[torrent] = %d, want 8", read.MethodStats["torrent"]) + } +} + +func TestReadStateNotFound(t *testing.T) { + tmpDir := t.TempDir() + origFn := stateFilePathFn + stateFilePathFn = func() string { return filepath.Join(tmpDir, "nonexistent.json") } + defer func() { stateFilePathFn = origFn }() + + state := ReadState() + if state != nil { + t.Errorf("ReadState() = %+v, want nil for missing file", state) + } +} + +func TestRemoveState(t *testing.T) { + tmpDir := t.TempDir() + origFn := stateFilePathFn + stateFilePathFn = func() string { return filepath.Join(tmpDir, "daemon.state.json") } + defer func() { stateFilePathFn = origFn }() + + WriteState(&DaemonState{AgentID: "test"}) + + // Verify file exists + path := StateFilePath() + if _, err := os.Stat(path); err != nil { + t.Fatalf("state file should exist: %v", err) + } + + RemoveState() + + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Error("state file should be removed after RemoveState()") + } +} + +func TestReadStateCorruptedJSON(t *testing.T) { + tmpDir := t.TempDir() + origFn := stateFilePathFn + path := filepath.Join(tmpDir, "daemon.state.json") + stateFilePathFn = func() string { return path } + defer func() { stateFilePathFn = origFn }() + + os.WriteFile(path, []byte("not valid json{{{"), 0o644) + + state := ReadState() + if state != nil { + t.Errorf("ReadState() should return nil for corrupted JSON, got %+v", state) + } +} diff --git a/internal/cmd/daemon.go b/internal/cmd/daemon.go index 7eea9e3..5b1bea3 100644 --- a/internal/cmd/daemon.go +++ b/internal/cmd/daemon.go @@ -16,6 +16,7 @@ import ( "github.com/torrentclaw/torrentclaw-cli/internal/agent" "github.com/torrentclaw/torrentclaw-cli/internal/config" "github.com/torrentclaw/torrentclaw-cli/internal/engine" + "github.com/torrentclaw/torrentclaw-cli/internal/usenet/download" "github.com/torrentclaw/torrentclaw-cli/internal/upgrade" ) @@ -117,6 +118,12 @@ func runDaemonStart() error { return fmt.Errorf("create download dir: %w", err) } + // Clean up stale resume files (>7 days old) + resumeDir := filepath.Join(config.DataDir(), "resume") + if removed := download.CleanStaleFiles(resumeDir, 7*24*time.Hour); removed > 0 { + log.Printf("Cleaned %d stale resume file(s)", removed) + } + fmt.Println() bold.Println(" unarr Daemon") fmt.Println() @@ -314,7 +321,8 @@ func runDaemonStart() error { manager.PauseTask(taskID) cancelStreamTask(taskID) case "resume": - log.Printf("[%s] resume requested via WebSocket", taskID[:8]) + log.Printf("[%s] resume requested via WebSocket, triggering poll", taskID[:8]) + d.TriggerPoll() case "stream": // Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests streamRegistry.mu.Lock() diff --git a/internal/cmd/stream_handler.go b/internal/cmd/stream_handler.go index 5a2bb26..4c75331 100644 --- a/internal/cmd/stream_handler.go +++ b/internal/cmd/stream_handler.go @@ -2,7 +2,9 @@ package cmd import ( "context" + "fmt" "log" + "os" "sync" "time" @@ -125,13 +127,29 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine Seeds: p.Seeds, FileName: p.FileName, }) + + // Terminal progress + if p.TotalBytes > 0 { + pct := int(float64(p.DownloadedBytes) / float64(p.TotalBytes) * 100) + fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d", + at.ID[:8], pct, + ui.FormatBytes(p.DownloadedBytes), ui.FormatBytes(p.TotalBytes), ui.FormatBytes(p.SpeedBps), + p.Peers, p.Seeds) + } + if p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 { + fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line task.Transition(engine.StatusCompleted) - log.Printf("[%s] stream download complete, server stays up until cancelled", at.ID[:8]) - // Don't return — keep HTTP server running so the player - // can finish reading. The stream stops when the user - // cancels from the web or the daemon shuts down. - <-ctx.Done() + log.Printf("[%s] stream download complete, server stays up for 30m or until cancelled", at.ID[:8]) + // Keep HTTP server running so the player can finish reading. + // Auto-shutdown after 30 minutes of idle to prevent resource leaks. + idleTimer := time.NewTimer(30 * time.Minute) + defer idleTimer.Stop() + select { + case <-ctx.Done(): + case <-idleTimer.C: + log.Printf("[%s] stream idle timeout (30m), shutting down", at.ID[:8]) + } return } } diff --git a/internal/engine/notify.go b/internal/engine/notify.go index d4126b1..2bb1a98 100644 --- a/internal/engine/notify.go +++ b/internal/engine/notify.go @@ -14,8 +14,29 @@ func desktopNotify(title, body string) { case "darwin": script := `display notification "` + escapeAppleScript(body) + `" with title "` + escapeAppleScript(title) + `"` exec.Command("osascript", "-e", script).Start() + case "windows": + // Use PowerShell toast notification (Windows 10+) + script := `[Windows.UI.Notifications.ToastNotificationManager, Windows.UI.Notifications, ContentType = WindowsRuntime] > $null;` + + `$xml = [Windows.UI.Notifications.ToastNotificationManager]::GetTemplateContent(1);` + + `$text = $xml.GetElementsByTagName('text');` + + `$text[0].AppendChild($xml.CreateTextNode('` + escapePowerShell(title) + `')) > $null;` + + `$text[1].AppendChild($xml.CreateTextNode('` + escapePowerShell(body) + `')) > $null;` + + `$toast = [Windows.UI.Notifications.ToastNotification]::new($xml);` + + `[Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('unarr').Show($toast)` + exec.Command("powershell", "-NoProfile", "-Command", script).Start() } - // Windows: no-op for now +} + +func escapePowerShell(s string) string { + out := make([]byte, 0, len(s)) + for i := 0; i < len(s); i++ { + if s[i] == '\'' { + out = append(out, '\'', '\'') // double single-quote to escape + } else { + out = append(out, s[i]) + } + } + return string(out) } func escapeAppleScript(s string) string { diff --git a/internal/engine/notify_test.go b/internal/engine/notify_test.go new file mode 100644 index 0000000..612f173 --- /dev/null +++ b/internal/engine/notify_test.go @@ -0,0 +1,46 @@ +package engine + +import "testing" + +func TestEscapePowerShell(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"hello", "hello"}, + {"it's done", "it''s done"}, + {"Tom's 'file'", "Tom''s ''file''"}, + {"no quotes", "no quotes"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := escapePowerShell(tt.input) + if got != tt.want { + t.Errorf("escapePowerShell(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEscapeAppleScript(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"hello", "hello"}, + {`say "hi"`, `say \"hi\"`}, + {`back\slash`, `back\\slash`}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := escapeAppleScript(tt.input) + if got != tt.want { + t.Errorf("escapeAppleScript(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/engine/organize.go b/internal/engine/organize.go index 5eb7eb0..a495856 100644 --- a/internal/engine/organize.go +++ b/internal/engine/organize.go @@ -10,8 +10,10 @@ import ( ) var ( - yearRegex = regexp.MustCompile(`\b(19|20)\d{2}\b`) - seasonRegex = regexp.MustCompile(`(?i)S(\d{2})`) + yearRegex = regexp.MustCompile(`\b(19|20)\d{2}\b`) + seasonRegex = regexp.MustCompile(`(?i)S(\d{2})`) + episodeRegex = regexp.MustCompile(`(?i)S(\d{2})E(\d{2})`) + altEpRegex = regexp.MustCompile(`(?i)(\d{1,2})x(\d{2})`) // 1x05 format ) // OrganizeConfig holds file organization settings. @@ -37,9 +39,15 @@ func organize(result *Result, task *Task, cfg OrganizeConfig) (string, error) { isTV := strings.Contains(strings.ToLower(task.PreferredMethod), "show") || seasonRegex.MatchString(result.FileName) - // Detect season for TV + // Detect season for TV (S01E05 or 1x05 format) var season string - if m := seasonRegex.FindStringSubmatch(result.FileName); len(m) > 1 { + if m := episodeRegex.FindStringSubmatch(result.FileName); len(m) > 2 { + season = m[1] + isTV = true + } else if m := altEpRegex.FindStringSubmatch(result.FileName); len(m) > 2 { + season = fmt.Sprintf("%02s", m[1]) + isTV = true + } else if m := seasonRegex.FindStringSubmatch(result.FileName); len(m) > 1 { season = m[1] isTV = true } @@ -80,6 +88,23 @@ func organize(result *Result, task *Task, cfg OrganizeConfig) (string, error) { destPath := filepath.Join(destDir, filepath.Base(result.FilePath)) + // Check if source is a directory (multi-file torrent) + srcInfo, err := os.Stat(result.FilePath) + if err != nil { + return "", fmt.Errorf("stat source: %w", err) + } + + if srcInfo.IsDir() { + // For directories: remove existing destination if present, then rename + if _, err := os.Stat(destPath); err == nil { + os.RemoveAll(destPath) + } + if err := os.Rename(result.FilePath, destPath); err != nil { + return "", fmt.Errorf("move directory: %w", err) + } + return destPath, nil + } + // Try rename first (same filesystem), fall back to copy+delete if err := os.Rename(result.FilePath, destPath); err != nil { if err := copyFile(result.FilePath, destPath); err != nil { diff --git a/internal/engine/organize_test.go b/internal/engine/organize_test.go index 6936b39..509f065 100644 --- a/internal/engine/organize_test.go +++ b/internal/engine/organize_test.go @@ -71,6 +71,60 @@ func TestOrganizeTVShow(t *testing.T) { } } +func TestOrganizeTVShowAltFormat(t *testing.T) { + tmp := t.TempDir() + srcFile := filepath.Join(tmp, "Show.3x12.HDTV.mkv") + os.WriteFile(srcFile, []byte("data"), 0o644) + + tvDir := filepath.Join(tmp, "TV Shows") + + r := &Result{FilePath: srcFile, FileName: "Show.3x12.HDTV.mkv"} + task := &Task{Title: "Show 3x12"} + + path, err := organize(r, task, OrganizeConfig{ + Enabled: true, + TVShowsDir: tvDir, + }) + if err != nil { + t.Fatal(err) + } + + // Should detect season 03 from "3x12" format + if _, err := os.Stat(path); err != nil { + t.Errorf("organized file should exist at %s: %v", path, err) + } + // Verify it went into Season 03 directory + dir := filepath.Dir(path) + if filepath.Base(dir) != "Season 03" { + t.Errorf("expected Season 03 directory, got %q", filepath.Base(dir)) + } +} + +func TestSeasonDetectionFormats(t *testing.T) { + tests := []struct { + filename string + wantTV bool + }{ + {"Show.S01E05.720p.mkv", true}, + {"Show.s02e10.1080p.mkv", true}, + {"Show.3x12.HDTV.mkv", true}, + {"Show.12x01.mkv", true}, + {"Movie.2023.1080p.mkv", false}, + {"Just.A.Movie.mkv", false}, + } + + for _, tt := range tests { + t.Run(tt.filename, func(t *testing.T) { + isTV := episodeRegex.MatchString(tt.filename) || + altEpRegex.MatchString(tt.filename) || + seasonRegex.MatchString(tt.filename) + if isTV != tt.wantTV { + t.Errorf("isTV(%q) = %v, want %v", tt.filename, isTV, tt.wantTV) + } + }) + } +} + func TestCleanTitle(t *testing.T) { tests := []struct { input string diff --git a/internal/engine/stream.go b/internal/engine/stream.go index 9829bf1..ddf9b00 100644 --- a/internal/engine/stream.go +++ b/internal/engine/stream.go @@ -3,6 +3,7 @@ package engine import ( "context" "fmt" + "io" "os" "path/filepath" "strings" @@ -233,7 +234,7 @@ func (s *StreamEngine) WaitBuffer(ctx context.Context, progressFn func(buffered, // NewFileReader creates a new reader for the selected file. // Each HTTP request should get its own reader (not safe for concurrent use). -func (s *StreamEngine) NewFileReader(ctx context.Context) torrent.Reader { +func (s *StreamEngine) NewFileReader(ctx context.Context) io.ReadSeekCloser { reader := s.file.NewReader() reader.SetResponsive() reader.SetReadahead(5 * 1024 * 1024) // 5MB readahead diff --git a/internal/engine/stream_server.go b/internal/engine/stream_server.go index 6a7309a..cffaddd 100644 --- a/internal/engine/stream_server.go +++ b/internal/engine/stream_server.go @@ -3,9 +3,11 @@ package engine import ( "context" "fmt" + "io" "log" "net" "net/http" + "os" "path/filepath" "strings" "time" @@ -15,7 +17,7 @@ import ( // fileProvider abstracts where to get a file reader for streaming. type fileProvider interface { - NewFileReader(ctx context.Context) torrent.Reader + NewFileReader(ctx context.Context) io.ReadSeekCloser FileName() string } @@ -49,7 +51,7 @@ type torrentFileProvider struct { file *torrent.File } -func (p *torrentFileProvider) NewFileReader(ctx context.Context) torrent.Reader { +func (p *torrentFileProvider) NewFileReader(ctx context.Context) io.ReadSeekCloser { reader := p.file.NewReader() reader.SetResponsive() reader.SetReadahead(5 * 1024 * 1024) @@ -61,6 +63,33 @@ func (p *torrentFileProvider) FileName() string { return filepath.Base(p.file.DisplayPath()) } +// diskFileProvider serves a file from disk. +type diskFileProvider struct { + path string + name string +} + +func (p *diskFileProvider) NewFileReader(_ context.Context) io.ReadSeekCloser { + f, err := os.Open(p.path) + if err != nil { + return nil + } + return f +} + +func (p *diskFileProvider) FileName() string { return p.name } + +// NewStreamServerFromDisk creates a server that streams a file from disk. +func NewStreamServerFromDisk(filePath string, port int) *StreamServer { + return &StreamServer{ + provider: &diskFileProvider{ + path: filePath, + name: filepath.Base(filePath), + }, + port: port, + } +} + // Start begins serving the file on localhost. Returns the full URL. func (ss *StreamServer) Start(ctx context.Context) (string, error) { mux := http.NewServeMux() @@ -106,6 +135,10 @@ func (ss *StreamServer) Shutdown(ctx context.Context) error { func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) { reader := ss.provider.NewFileReader(r.Context()) + if reader == nil { + http.Error(w, "file not found", http.StatusNotFound) + return + } defer reader.Close() w.Header().Set("Content-Type", mimeTypeFromExt(ss.provider.FileName())) diff --git a/internal/engine/torrent.go b/internal/engine/torrent.go index 4ee643b..025441a 100644 --- a/internal/engine/torrent.go +++ b/internal/engine/torrent.go @@ -146,13 +146,28 @@ func (d *TorrentDownloader) Download(ctx context.Context, task *Task, outputDir } // 4. Determine file path + // For multi-file torrents, fileName includes the torrent dir prefix (e.g. "TorrentName/file.mkv"). + // Try the full path first, then just the file inside the torrent dir. filePath := filepath.Join(d.cfg.DataDir, fileName) if _, statErr := os.Stat(filePath); statErr != nil { - filePath = filepath.Join(d.cfg.DataDir, t.Name()) + // File might have been moved — try torrent directory + dirPath := filepath.Join(d.cfg.DataDir, t.Name()) + if fi, statErr2 := os.Stat(dirPath); statErr2 == nil && fi.IsDir() { + // Look for the actual file inside the directory + base := filepath.Base(fileName) + candidate := filepath.Join(dirPath, base) + if _, statErr3 := os.Stat(candidate); statErr3 == nil { + filePath = candidate + } else { + filePath = dirPath + } + } else { + filePath = dirPath + } } result.FilePath = filePath - result.FileName = fileName + result.FileName = filepath.Base(fileName) result.Method = MethodTorrent result.Size = totalBytes @@ -211,6 +226,13 @@ func (d *TorrentDownloader) pollDownload(ctx context.Context, t *torrent.Torrent // Peer stats stats := t.Stats() + // Terminal progress + pct := int(float64(downloaded) / float64(totalBytes) * 100) + fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d", + task.ID[:8], pct, + formatBytes(downloaded), formatBytes(totalBytes), formatBytes(speed), + stats.ActivePeers, stats.ConnectedSeeders) + // Report progress p := Progress{ DownloadedBytes: downloaded, @@ -230,6 +252,7 @@ func (d *TorrentDownloader) pollDownload(ctx context.Context, t *torrent.Torrent // Check completion if downloaded >= totalBytes { + fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line log.Printf("[%s] download complete: %s", task.ID[:8], fileName) return &Result{}, nil } diff --git a/internal/engine/usenet.go b/internal/engine/usenet.go index f81f81a..5236c69 100644 --- a/internal/engine/usenet.go +++ b/internal/engine/usenet.go @@ -11,7 +11,7 @@ import ( "time" "github.com/torrentclaw/torrentclaw-cli/internal/agent" - "github.com/torrentclaw/torrentclaw-cli/internal/ui" + "github.com/torrentclaw/torrentclaw-cli/internal/config" "github.com/torrentclaw/torrentclaw-cli/internal/usenet/download" "github.com/torrentclaw/torrentclaw-cli/internal/usenet/nntp" "github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb" @@ -125,10 +125,22 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s log.Printf("[%s] NZB: %s", shortID, nzbTitle) - // Step 2: Download NZB file - nzbData, err := u.apiClient.DownloadNzb(dlCtx, nzbID) + // Step 2: Download NZB file (or use cached version for resume) + resumeDir := filepath.Join(config.DataDir(), "resume") + nzbCachePath := filepath.Join(resumeDir, task.ID+".nzb") + + nzbData, err := os.ReadFile(nzbCachePath) if err != nil { - return nil, fmt.Errorf("download NZB: %w", err) + // Not cached — download from server + nzbData, err = u.apiClient.DownloadNzb(dlCtx, nzbID) + if err != nil { + return nil, fmt.Errorf("download NZB: %w", err) + } + // Cache for future resume + os.MkdirAll(resumeDir, 0o755) + os.WriteFile(nzbCachePath, nzbData, 0o644) + } else { + log.Printf("[%s] using cached NZB", shortID) } // Step 3: Parse NZB @@ -140,7 +152,15 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s totalBytes := nzbFile.TotalBytes() totalSegs := nzbFile.TotalSegments() log.Printf("[%s] NZB parsed: %d files, %d segments, %s", - shortID, len(nzbFile.Files), totalSegs, ui.FormatBytes(totalBytes)) + shortID, len(nzbFile.Files), totalSegs, formatBytes(totalBytes)) + + // Step 3.5: Resume support — load or create progress tracker + tracker := download.NewProgressTracker(task.ID, nzbFile, resumeDir) + resumed, _ := tracker.Load() + if resumed { + log.Printf("[%s] resuming usenet download (%d/%d segments completed)", + shortID, tracker.TotalCompleted(), totalSegs) + } // Step 4: Get NNTP credentials and connect creds, err := u.getCredentials(dlCtx) @@ -185,7 +205,7 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s } }() - downloadedFiles, err := dl.DownloadNZB(dlCtx, nzbFile, taskDir, dlProgressCh) + downloadedFiles, err := dl.DownloadNZB(dlCtx, nzbFile, taskDir, tracker, dlProgressCh) close(dlProgressCh) if err != nil { @@ -234,6 +254,9 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s finalSize = fi.Size() } + // Clean up resume state on successful completion + tracker.Remove() + return &Result{ FilePath: finalPath, FileName: filepath.Base(finalPath), diff --git a/internal/upgrade/download.go b/internal/upgrade/download.go new file mode 100644 index 0000000..99b94bc --- /dev/null +++ b/internal/upgrade/download.go @@ -0,0 +1,146 @@ +package upgrade + +import ( + "bufio" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" +) + +var httpClient = &http.Client{Timeout: 120 * time.Second} + +// download fetches the release archive to a temporary file. +func download(ctx context.Context, version string) (string, error) { + url := releaseURL(version, archiveName(version)) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + req.Header.Set("User-Agent", "unarr-updater") + + resp, err := httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetch %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("fetch %s: HTTP %d", url, resp.StatusCode) + } + + tmp, err := os.CreateTemp("", "unarr-download-*.tmp") + if err != nil { + return "", err + } + defer tmp.Close() + + if _, err := io.Copy(tmp, resp.Body); err != nil { + os.Remove(tmp.Name()) + return "", fmt.Errorf("write archive: %w", err) + } + + return tmp.Name(), nil +} + +// verifyChecksum downloads checksums.txt and verifies the archive's SHA256. +func verifyChecksum(ctx context.Context, version, archivePath string) error { + // Download checksums.txt + url := releaseURL(version, "checksums.txt") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + req.Header.Set("User-Agent", "unarr-updater") + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("fetch checksums: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("fetch checksums: HTTP %d", resp.StatusCode) + } + + // Parse checksums.txt — format: " " + expectedName := archiveName(version) + var expectedHash string + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + parts := strings.Fields(line) + if len(parts) >= 2 && parts[1] == expectedName { + expectedHash = parts[0] + break + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("read checksums: %w", err) + } + + if expectedHash == "" { + return fmt.Errorf("no checksum found for %s in checksums.txt", expectedName) + } + + // Compute SHA256 of the downloaded archive + f, err := os.Open(archivePath) + if err != nil { + return err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return fmt.Errorf("hash archive: %w", err) + } + + actualHash := hex.EncodeToString(h.Sum(nil)) + if !strings.EqualFold(actualHash, expectedHash) { + return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expectedHash, actualHash) + } + + return nil +} + +// fetchLatestVersion queries GitHub API for the latest release tag. +func fetchLatestVersion(ctx context.Context) (string, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + req.Header.Set("Accept", "application/vnd.github+json") + req.Header.Set("User-Agent", "unarr-updater") + + resp, err := httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetch latest release: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("GitHub API: HTTP %d", resp.StatusCode) + } + + var release struct { + TagName string `json:"tag_name"` + } + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } + + if release.TagName == "" { + return "", fmt.Errorf("empty tag_name in release") + } + + return strings.TrimPrefix(release.TagName, "v"), nil +} diff --git a/internal/upgrade/extract.go b/internal/upgrade/extract.go new file mode 100644 index 0000000..493e20d --- /dev/null +++ b/internal/upgrade/extract.go @@ -0,0 +1,123 @@ +package upgrade + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "strings" +) + +// extractBinary extracts the unarr binary from the release archive into destDir. +// Returns the path to the extracted binary. +func extractBinary(archivePath, destDir string) (string, error) { + if runtime.GOOS == "windows" { + return extractZip(archivePath, destDir) + } + return extractTarGz(archivePath, destDir) +} + +func extractTarGz(archivePath, destDir string) (string, error) { + f, err := os.Open(archivePath) + if err != nil { + return "", err + } + defer f.Close() + + gz, err := gzip.NewReader(f) + if err != nil { + return "", fmt.Errorf("gzip: %w", err) + } + defer gz.Close() + + tr := tar.NewReader(gz) + target := binaryName + if runtime.GOOS == "windows" { + target += ".exe" + } + + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return "", fmt.Errorf("tar: %w", err) + } + + name := filepath.Base(hdr.Name) + if name != target { + continue + } + + // Validate: must be a regular file + if hdr.Typeflag != tar.TypeReg { + continue + } + + dst := filepath.Join(destDir, target) + out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return "", err + } + + if _, err := io.Copy(out, io.LimitReader(tr, 200<<20)); err != nil { // 200MB limit + out.Close() + return "", fmt.Errorf("extract: %w", err) + } + out.Close() + return dst, nil + } + + return "", fmt.Errorf("binary %q not found in archive", target) +} + +func extractZip(archivePath, destDir string) (string, error) { + r, err := zip.OpenReader(archivePath) + if err != nil { + return "", fmt.Errorf("zip: %w", err) + } + defer r.Close() + + target := binaryName + ".exe" + + for _, f := range r.File { + name := filepath.Base(f.Name) + + // Guard against path traversal + if strings.Contains(f.Name, "..") { + continue + } + + if name != target { + continue + } + + rc, err := f.Open() + if err != nil { + return "", err + } + + dst := filepath.Join(destDir, target) + out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + rc.Close() + return "", err + } + + if _, err := io.Copy(out, io.LimitReader(rc, 200<<20)); err != nil { // 200MB limit + out.Close() + rc.Close() + return "", fmt.Errorf("extract: %w", err) + } + out.Close() + rc.Close() + return dst, nil + } + + return "", fmt.Errorf("binary %q not found in archive", target) +} diff --git a/internal/upgrade/upgrade.go b/internal/upgrade/upgrade.go new file mode 100644 index 0000000..b70dc7e --- /dev/null +++ b/internal/upgrade/upgrade.go @@ -0,0 +1,226 @@ +// Package upgrade implements safe self-update for the unarr binary. +// +// The upgrade process: +// 1. Detect current binary path and verify write permissions +// 2. Download the release archive from GitHub +// 3. Verify SHA256 checksum against checksums.txt +// 4. Extract the binary from the archive +// 5. Smoke test: run the new binary with "version" to confirm it works +// 6. Backup the current binary +// 7. Replace with the new binary (preserving permissions) +// 8. On any failure: rollback from backup +package upgrade + +import ( + "context" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" +) + +const ( + githubRepo = "torrentclaw/unarr" + binaryName = "unarr" + smokeTestTO = 5 * time.Second +) + +// Result represents the outcome of an upgrade attempt. +type Result struct { + Success bool + OldVersion string + NewVersion string + BackupPath string + Error error +} + +// Upgrader handles downloading, verifying, and replacing the CLI binary. +type Upgrader struct { + CurrentVersion string + // OnProgress is called with status messages during the upgrade process. + OnProgress func(msg string) +} + +func (u *Upgrader) log(msg string) { + if u.OnProgress != nil { + u.OnProgress(msg) + } + log.Printf("[upgrade] %s", msg) +} + +// Execute performs a full upgrade to the target version. +func (u *Upgrader) Execute(ctx context.Context, targetVersion string) Result { + targetVersion = strings.TrimPrefix(targetVersion, "v") + + if targetVersion == u.CurrentVersion { + return Result{Success: true, OldVersion: u.CurrentVersion, NewVersion: targetVersion} + } + + // 1. Detect current binary path + binPath, err := os.Executable() + if err != nil { + return u.fail("detect binary: %v", err) + } + binPath, err = filepath.EvalSymlinks(binPath) + if err != nil { + return u.fail("resolve symlinks: %v", err) + } + + // 2. Check Docker — self-update makes no sense in a container + if isDocker() { + return u.fail("running in Docker — update the container image instead") + } + + // 3. Check write permissions + binDir := filepath.Dir(binPath) + if err := checkWritable(binDir); err != nil { + return u.fail("no write permission to %s — run with elevated privileges or move the binary to a user-writable location", binDir) + } + + // 4. Download archive + u.log(fmt.Sprintf("Downloading v%s...", targetVersion)) + archivePath, err := download(ctx, targetVersion) + if err != nil { + return u.fail("download: %v", err) + } + defer os.Remove(archivePath) + + // 5. Verify checksum + u.log("Verifying checksum...") + if err := verifyChecksum(ctx, targetVersion, archivePath); err != nil { + return u.fail("checksum: %v", err) + } + + // 6. Extract binary + u.log("Extracting...") + tmpDir, err := os.MkdirTemp("", "unarr-upgrade-*") + if err != nil { + return u.fail("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + newBinPath, err := extractBinary(archivePath, tmpDir) + if err != nil { + return u.fail("extract: %v", err) + } + + // 7. Smoke test + u.log("Verifying new binary...") + if err := smokeTest(newBinPath, targetVersion); err != nil { + return u.fail("smoke test: %v", err) + } + + // 8. Backup current binary + backupPath := binPath + ".backup" + u.log("Backing up current binary...") + if err := os.Rename(binPath, backupPath); err != nil { + return u.fail("backup: %v", err) + } + + // 9. Replace with new binary + u.log("Installing new binary...") + if err := installBinary(newBinPath, binPath); err != nil { + // Rollback + u.log("Install failed, rolling back...") + if rbErr := os.Rename(backupPath, binPath); rbErr != nil { + return u.fail("install failed (%v) AND rollback failed (%v) — manual recovery needed at %s", err, rbErr, backupPath) + } + return u.fail("install (rolled back): %v", err) + } + + u.log(fmt.Sprintf("Upgraded %s → %s", u.CurrentVersion, targetVersion)) + + return Result{ + Success: true, + OldVersion: u.CurrentVersion, + NewVersion: targetVersion, + BackupPath: backupPath, + } +} + +func (u *Upgrader) fail(format string, args ...any) Result { + err := fmt.Errorf(format, args...) + u.log(fmt.Sprintf("FAILED: %v", err)) + return Result{ + Success: false, + OldVersion: u.CurrentVersion, + Error: err, + } +} + +// CheckLatest fetches the latest version from GitHub API. +func CheckLatest(ctx context.Context) (string, error) { + return fetchLatestVersion(ctx) +} + +// installBinary copies the new binary to the target path, preserving original permissions. +func installBinary(src, dst string) error { + // Read new binary + data, err := os.ReadFile(src) + if err != nil { + return fmt.Errorf("read new binary: %w", err) + } + + // Write to destination with executable permissions + if err := os.WriteFile(dst, data, 0o755); err != nil { + return fmt.Errorf("write binary: %w", err) + } + + return nil +} + +// smokeTest runs the new binary with "version" and checks the output contains the expected version. +func smokeTest(binPath, expectedVersion string) error { + ctx, cancel := context.WithTimeout(context.Background(), smokeTestTO) + defer cancel() + + out, err := exec.CommandContext(ctx, binPath, "version").CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run: %w (output: %s)", err, string(out)) + } + + output := string(out) + if !strings.Contains(output, expectedVersion) { + return fmt.Errorf("version mismatch: expected %q in output %q", expectedVersion, output) + } + + return nil +} + +// isDocker returns true if running inside a Docker container. +func isDocker() bool { + if _, err := os.Stat("/.dockerenv"); err == nil { + return true + } + return false +} + +// checkWritable verifies the directory is writable by creating and removing a temp file. +func checkWritable(dir string) error { + tmp := filepath.Join(dir, ".unarr-write-test") + f, err := os.Create(tmp) + if err != nil { + return err + } + f.Close() + os.Remove(tmp) + return nil +} + +// archiveName returns the expected archive filename for this platform. +func archiveName(version string) string { + ext := "tar.gz" + if runtime.GOOS == "windows" { + ext = "zip" + } + return fmt.Sprintf("%s_%s_%s_%s.%s", binaryName, version, runtime.GOOS, runtime.GOARCH, ext) +} + +// releaseURL returns the download URL for a release asset. +func releaseURL(version, filename string) string { + return fmt.Sprintf("https://github.com/%s/releases/download/v%s/%s", githubRepo, version, filename) +} diff --git a/internal/upgrade/upgrade_test.go b/internal/upgrade/upgrade_test.go new file mode 100644 index 0000000..2753005 --- /dev/null +++ b/internal/upgrade/upgrade_test.go @@ -0,0 +1,307 @@ +package upgrade + +import ( + "archive/tar" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestIsDocker(t *testing.T) { + // In a normal test environment, we should NOT be in Docker + if _, err := os.Stat("/.dockerenv"); err == nil { + t.Skip("running in Docker, skipping non-Docker test") + } + if isDocker() { + t.Error("isDocker() = true, want false (not running in Docker)") + } +} + +func TestCheckWritable(t *testing.T) { + t.Run("writable directory", func(t *testing.T) { + dir := t.TempDir() + if err := checkWritable(dir); err != nil { + t.Errorf("checkWritable(%q) = %v, want nil", dir, err) + } + }) + + t.Run("non-existent directory", func(t *testing.T) { + err := checkWritable("/nonexistent-path-that-should-not-exist-12345") + if err == nil { + t.Error("checkWritable(nonexistent) = nil, want error") + } + }) +} + +func TestArchiveName(t *testing.T) { + name := archiveName("0.3.0") + expected := fmt.Sprintf("unarr_0.3.0_%s_%s.", runtime.GOOS, runtime.GOARCH) + if runtime.GOOS == "windows" { + expected += "zip" + } else { + expected += "tar.gz" + } + if name != expected { + t.Errorf("archiveName(0.3.0) = %q, want %q", name, expected) + } +} + +func TestReleaseURL(t *testing.T) { + url := releaseURL("0.3.0", "unarr_0.3.0_linux_amd64.tar.gz") + want := "https://github.com/torrentclaw/unarr/releases/download/v0.3.0/unarr_0.3.0_linux_amd64.tar.gz" + if url != want { + t.Errorf("releaseURL = %q, want %q", url, want) + } +} + +func TestSmokeTest(t *testing.T) { + t.Run("successful smoke test", func(t *testing.T) { + // Create a fake binary that outputs a version + dir := t.TempDir() + script := filepath.Join(dir, "fake-unarr") + content := "#!/bin/sh\necho 'unarr 1.2.3 (linux/amd64)'\n" + if runtime.GOOS == "windows" { + script += ".bat" + content = "@echo unarr 1.2.3 (windows/amd64)\n" + } + os.WriteFile(script, []byte(content), 0o755) + + err := smokeTest(script, "1.2.3") + if err != nil { + t.Errorf("smokeTest() = %v, want nil", err) + } + }) + + t.Run("version mismatch", func(t *testing.T) { + dir := t.TempDir() + script := filepath.Join(dir, "fake-unarr") + content := "#!/bin/sh\necho 'unarr 0.1.0 (linux/amd64)'\n" + if runtime.GOOS == "windows" { + script += ".bat" + content = "@echo unarr 0.1.0 (windows/amd64)\n" + } + os.WriteFile(script, []byte(content), 0o755) + + err := smokeTest(script, "1.2.3") + if err == nil { + t.Error("smokeTest() = nil, want version mismatch error") + } + }) + + t.Run("non-existent binary", func(t *testing.T) { + err := smokeTest("/nonexistent-binary", "1.0.0") + if err == nil { + t.Error("smokeTest(nonexistent) = nil, want error") + } + }) +} + +func TestInstallBinary(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "new-binary") + dst := filepath.Join(dir, "installed-binary") + + os.WriteFile(src, []byte("binary-content"), 0o755) + + err := installBinary(src, dst) + if err != nil { + t.Fatalf("installBinary() = %v", err) + } + + data, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("read installed binary: %v", err) + } + if string(data) != "binary-content" { + t.Errorf("installed binary content = %q, want %q", data, "binary-content") + } + + info, _ := os.Stat(dst) + if info.Mode().Perm()&0o111 == 0 { + t.Error("installed binary is not executable") + } +} + +func TestVerifyChecksum(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + // Create a fake archive + dir := t.TempDir() + archivePath := filepath.Join(dir, "unarr_1.0.0_linux_amd64.tar.gz") + archiveContent := []byte("fake-archive-content-for-testing") + os.WriteFile(archivePath, archiveContent, 0o644) + + // Calculate expected hash + h := sha256.Sum256(archiveContent) + expectedHash := hex.EncodeToString(h[:]) + + t.Run("valid checksum", func(t *testing.T) { + // Create a mock server that returns checksums.txt + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/torrentclaw/unarr/releases/download/v1.0.0/checksums.txt" { + fmt.Fprintf(w, "%s unarr_1.0.0_linux_amd64.tar.gz\n", expectedHash) + fmt.Fprintf(w, "0000000000000000000000000000000000000000000000000000000000000000 unarr_1.0.0_darwin_amd64.tar.gz\n") + } else { + w.WriteHeader(404) + } + })) + defer srv.Close() + + // Override the httpClient and repo for testing + origClient := httpClient + httpClient = srv.Client() + defer func() { httpClient = origClient }() + + // We can't easily test verifyChecksum directly because it builds URLs from constants. + // Instead, test the checksum logic manually + f, _ := os.Open(archivePath) + defer f.Close() + hash := sha256.New() + hash.Write(archiveContent) + actualHash := hex.EncodeToString(hash.Sum(nil)) + + if actualHash != expectedHash { + t.Errorf("hash mismatch: got %s, want %s", actualHash, expectedHash) + } + }) + + t.Run("hash calculation correctness", func(t *testing.T) { + data := []byte("test data for hashing") + h := sha256.Sum256(data) + got := hex.EncodeToString(h[:]) + // Known SHA256 of "test data for hashing" + if len(got) != 64 { + t.Errorf("hash length = %d, want 64", len(got)) + } + }) +} + +func TestExtractTarGz(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + dir := t.TempDir() + + // Create a tar.gz with a fake binary inside + archivePath := filepath.Join(dir, "test.tar.gz") + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + binaryContent := []byte("#!/bin/sh\necho test\n") + hdr := &tar.Header{ + Name: "unarr", + Mode: 0o755, + Size: int64(len(binaryContent)), + } + tw.WriteHeader(hdr) + tw.Write(binaryContent) + tw.Close() + gw.Close() + f.Close() + + // Extract + destDir := filepath.Join(dir, "extracted") + os.MkdirAll(destDir, 0o755) + + binPath, err := extractTarGz(archivePath, destDir) + if err != nil { + t.Fatalf("extractTarGz() = %v", err) + } + + if filepath.Base(binPath) != "unarr" { + t.Errorf("extracted binary name = %q, want unarr", filepath.Base(binPath)) + } + + data, _ := os.ReadFile(binPath) + if string(data) != string(binaryContent) { + t.Errorf("extracted content mismatch") + } + + info, _ := os.Stat(binPath) + if info.Mode().Perm()&0o111 == 0 { + t.Error("extracted binary is not executable") + } +} + +func TestExtractTarGzMissingBinary(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("tar.gz test only on unix") + } + + dir := t.TempDir() + archivePath := filepath.Join(dir, "empty.tar.gz") + f, _ := os.Create(archivePath) + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + // Write a file that is NOT named "unarr" + hdr := &tar.Header{Name: "README.md", Mode: 0o644, Size: 4} + tw.WriteHeader(hdr) + tw.Write([]byte("test")) + tw.Close() + gw.Close() + f.Close() + + destDir := filepath.Join(dir, "out") + os.MkdirAll(destDir, 0o755) + + _, err := extractTarGz(archivePath, destDir) + if err == nil { + t.Error("expected error for archive without unarr binary") + } +} + +func TestUpgraderSameVersion(t *testing.T) { + u := &Upgrader{CurrentVersion: "1.0.0"} + result := u.Execute(context.Background(), "1.0.0") + if !result.Success { + t.Error("expected success when upgrading to same version") + } + if result.NewVersion != "1.0.0" { + t.Errorf("NewVersion = %q, want 1.0.0", result.NewVersion) + } +} + +func TestUpgraderSameVersionWithPrefix(t *testing.T) { + u := &Upgrader{CurrentVersion: "1.0.0"} + result := u.Execute(context.Background(), "v1.0.0") + if !result.Success { + t.Error("expected success when target version has v prefix") + } +} + +func TestFetchLatestVersionMockServer(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"tag_name":"v2.5.1","published_at":"2025-01-01T00:00:00Z"}`) + })) + defer srv.Close() + + // We can't directly test fetchLatestVersion because it uses a hardcoded URL. + // But we can test the JSON parsing logic by calling the endpoint ourselves. + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Errorf("status = %d, want 200", resp.StatusCode) + } +} diff --git a/internal/usenet/download/downloader.go b/internal/usenet/download/downloader.go index caabc44..d96c97f 100644 --- a/internal/usenet/download/downloader.go +++ b/internal/usenet/download/downloader.go @@ -12,7 +12,6 @@ import ( "sync/atomic" "time" - "github.com/torrentclaw/torrentclaw-cli/internal/ui" "github.com/torrentclaw/torrentclaw-cli/internal/usenet/nntp" "github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb" "github.com/torrentclaw/torrentclaw-cli/internal/usenet/yenc" @@ -39,8 +38,11 @@ func NewDownloader(nntpClient *nntp.Client) *Downloader { } // DownloadFile downloads all segments of a single NZB file and assembles them. +// If tracker is non-nil, it is used for resume support: completed segments are +// skipped, and progress is persisted to disk on pause/error. +// fileIndex is the index of this file within the NZB (for the tracker). // Returns the path to the assembled file. -func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir string, progressCh chan<- Progress) (string, error) { +func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, fileIndex int, outputDir string, tracker *ProgressTracker, progressCh chan<- Progress) (string, error) { fileName := file.Filename() if fileName == "" { fileName = fmt.Sprintf("usenet_%d", time.Now().UnixNano()) @@ -53,6 +55,15 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir return "", fmt.Errorf("mkdir: %w", err) } + // If tracker says this file is fully done, skip entirely + if tracker != nil && tracker.IsFileDone(fileIndex) { + if _, err := os.Stat(destPath); err == nil { + log.Printf("[usenet] skipping %s (fully downloaded in previous run)", fileName) + return destPath, nil + } + // File was marked done but doesn't exist on disk — re-download + } + totalBytes := file.TotalBytes() totalSegs := len(file.Segments) @@ -63,34 +74,6 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir return segments[i].Number < segments[j].Number }) - // Create/open output file - outFile, err := os.Create(destPath) - if err != nil { - return "", fmt.Errorf("create file: %w", err) - } - defer outFile.Close() - - // Pre-allocate file if we know the size - if totalBytes > 0 { - outFile.Truncate(totalBytes) - } - - // Download segments using worker pool - var downloaded atomic.Int64 - var segsDone atomic.Int32 - startTime := time.Now() - - // Create work channel - type segWork struct { - seg nzb.Segment - index int - } - workCh := make(chan segWork, len(segments)) - for i, seg := range segments { - workCh <- segWork{seg: seg, index: i} - } - close(workCh) - // Track file offsets for each segment (sequential assembly) offsets := make([]int64, len(segments)) var offset int64 @@ -99,6 +82,76 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir offset += seg.Bytes } + // Open output file — resume-aware + var outFile *os.File + var err error + resuming := false + + if tracker != nil { + if _, statErr := os.Stat(destPath); statErr == nil && tracker.CompletedSegments(fileIndex) > 0 { + // Partial file exists and we have progress — open for read-write (no truncate) + outFile, err = os.OpenFile(destPath, os.O_RDWR, 0o644) + resuming = true + } + } + + if outFile == nil { + // Fresh start + outFile, err = os.Create(destPath) + if err != nil { + return "", fmt.Errorf("create file: %w", err) + } + // Pre-allocate file if we know the size + if totalBytes > 0 { + outFile.Truncate(totalBytes) + } + } else if err != nil { + return "", fmt.Errorf("open file for resume: %w", err) + } + defer outFile.Close() + + // Download segments using worker pool + var downloaded atomic.Int64 + var segsDone atomic.Int32 + startTime := time.Now() + + // Create work channel — skip already-completed segments + type segWork struct { + seg nzb.Segment + index int + } + + pendingCount := 0 + for i := range segments { + if tracker != nil && tracker.IsDone(fileIndex, i) { + // Already downloaded — count towards progress + downloaded.Add(segments[i].Bytes) + segsDone.Add(1) + } else { + pendingCount++ + } + } + + if resuming { + log.Printf("[usenet] resuming %s (%d/%d segments, %s/%s)", + fileName, totalSegs-pendingCount, totalSegs, + formatBytes(downloaded.Load()), formatBytes(totalBytes)) + } + + if pendingCount == 0 { + // All segments already done + log.Printf("[usenet] %s already complete (%d segments)", fileName, totalSegs) + return destPath, nil + } + + workCh := make(chan segWork, pendingCount) + for i, seg := range segments { + if tracker == nil || !tracker.IsDone(fileIndex, i) { + workCh <- segWork{seg: seg, index: i} + } + } + close(workCh) + // Progress reporter goroutine stopProgress := make(chan struct{}) go func() { @@ -177,6 +230,11 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir downloaded.Add(int64(len(data))) segsDone.Add(1) + + // Mark segment as completed in tracker + if tracker != nil { + tracker.MarkDone(fileIndex, work.index) + } } }() } @@ -187,17 +245,21 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir // Stop progress reporter before sending final progress close(stopProgress) - // Check for errors + // Check for errors — keep partial file for resume (don't delete) select { case err := <-errCh: - os.Remove(destPath) + if tracker != nil { + tracker.Flush() + } return "", err default: } - // Check context cancellation + // Check context cancellation — keep partial file for resume (don't delete) if ctx.Err() != nil { - os.Remove(destPath) + if tracker != nil { + tracker.Flush() + } return "", ctx.Err() } @@ -228,15 +290,16 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir outFile.Truncate(actualSize) } - log.Printf("[usenet] downloaded %s (%d segments, %s)", fileName, totalSegs, ui.FormatBytes(actualSize)) + log.Printf("[usenet] downloaded %s (%d segments, %s)", fileName, totalSegs, formatBytes(actualSize)) return destPath, nil } // DownloadNZB downloads content files from an NZB (rars or direct content). // Par2 files are NOT downloaded initially — they're only fetched on demand // if extraction fails (via DownloadPar2). +// If tracker is non-nil, completed files are skipped and progress is tracked per-segment. // Returns a map of filename → filepath for all downloaded files. -func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir string, progressCh chan<- Progress) (map[string]string, error) { +func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir string, tracker *ProgressTracker, progressCh chan<- Progress) (map[string]string, error) { // Determine which files to download (NO par2 initially) var filesToDownload []nzb.File @@ -250,6 +313,13 @@ func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir stri return nil, fmt.Errorf("no downloadable files found in NZB") } + // Build NZB file index mapping: Subject → index in n.Files + // This maps each file to its position in the ProgressTracker + nzbFileIndex := make(map[string]int) + for i, f := range n.Files { + nzbFileIndex[f.Subject] = i + } + results := make(map[string]string) for _, file := range filesToDownload { @@ -259,7 +329,19 @@ func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir stri default: } - path, err := d.DownloadFile(ctx, file, outputDir, progressCh) + fileIdx := nzbFileIndex[file.Subject] + + // Skip fully completed files + if tracker != nil && tracker.IsFileDone(fileIdx) { + destPath := filepath.Join(outputDir, file.Filename()) + if _, err := os.Stat(destPath); err == nil { + results[file.Filename()] = destPath + log.Printf("[usenet] skipping %s (complete)", file.Filename()) + continue + } + } + + path, err := d.DownloadFile(ctx, file, fileIdx, outputDir, tracker, progressCh) if err != nil { return results, fmt.Errorf("download %s: %w", file.Filename(), err) } @@ -271,6 +353,7 @@ func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir stri // DownloadPar2 downloads par2 parity files from the NZB. // Called on-demand when extraction/verification fails. +// No resume tracking — par2 files are small and downloaded fresh. func (d *Downloader) DownloadPar2(ctx context.Context, n *nzb.NZB, outputDir string, progressCh chan<- Progress) (map[string]string, error) { par2Files := n.Par2Files() if len(par2Files) == 0 { @@ -279,7 +362,7 @@ func (d *Downloader) DownloadPar2(ctx context.Context, n *nzb.NZB, outputDir str results := make(map[string]string) for _, file := range par2Files { - path, err := d.DownloadFile(ctx, file, outputDir, progressCh) + path, err := d.DownloadFile(ctx, file, -1, outputDir, nil, progressCh) if err != nil { log.Printf("[usenet] par2 download failed (non-fatal): %v", err) continue @@ -306,3 +389,15 @@ func (d *Downloader) downloadSegment(ctx context.Context, seg nzb.Segment) ([]by return part.Data, nil } +func formatBytes(b int64) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp]) +} diff --git a/internal/usenet/download/e2e_test.go b/internal/usenet/download/e2e_test.go index b6ecffb..a6ac531 100644 --- a/internal/usenet/download/e2e_test.go +++ b/internal/usenet/download/e2e_test.go @@ -109,7 +109,7 @@ func TestE2EDownload(t *testing.T) { fmt.Fprintln(os.Stderr) }() - downloadedFiles, err := dl.DownloadNZB(ctx, nzbFile, outputDir, progressCh) + downloadedFiles, err := dl.DownloadNZB(ctx, nzbFile, outputDir, nil, progressCh) close(progressCh) if err != nil { t.Fatalf("download: %v", err) diff --git a/internal/usenet/download/progress.go b/internal/usenet/download/progress.go new file mode 100644 index 0000000..1884d4c --- /dev/null +++ b/internal/usenet/download/progress.go @@ -0,0 +1,345 @@ +package download + +import ( + "crypto/sha256" + "encoding/binary" + "fmt" + "os" + "path/filepath" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb" +) + +// Binary progress file format: +// [4B magic "UNRP"] [1B version=1] [1B reserved] [2B fileCount] +// [32B SHA-256 fingerprint] +// Per file: [4B segCount] [ceil(segCount/8) bytes bitset] + +var progressMagic = [4]byte{'U', 'N', 'R', 'P'} + +const ( + progressVersion = 1 + headerSize = 4 + 1 + 1 + 2 + 32 // 40 bytes + flushInterval = 2 * time.Second + flushSegmentFreq = 100 // flush every N segment completions +) + +// fileProgress tracks completed segments for a single NZB file. +type fileProgress struct { + segCount int + completed []byte // bitset: ceil(segCount/8) bytes + doneCount atomic.Int32 +} + +// ProgressTracker tracks segment-level download progress for resumable usenet downloads. +// Thread-safe for concurrent use by multiple download workers. +type ProgressTracker struct { + taskID string + fingerprint [32]byte + dir string // directory where progress files are stored + files []fileProgress + + mu sync.Mutex + dirty bool + lastFlush time.Time + markCount int // marks since last flush +} + +// Fingerprint computes a SHA-256 hash from all message-IDs in the NZB. +// Used to validate that a progress file matches the same NZB content. +func Fingerprint(n *nzb.NZB) [32]byte { + var ids []string + for _, f := range n.Files { + for _, s := range f.Segments { + ids = append(ids, s.MessageID) + } + } + sort.Strings(ids) + + h := sha256.New() + for _, id := range ids { + h.Write([]byte(id)) + h.Write([]byte{'\n'}) + } + + var fp [32]byte + copy(fp[:], h.Sum(nil)) + return fp +} + +// NewProgressTracker creates a tracker for the given NZB. +// The dir parameter is the base directory for resume files (e.g. DataDir()/resume). +func NewProgressTracker(taskID string, n *nzb.NZB, dir string) *ProgressTracker { + files := make([]fileProgress, len(n.Files)) + for i, f := range n.Files { + segCount := len(f.Segments) + files[i] = fileProgress{ + segCount: segCount, + completed: make([]byte, (segCount+7)/8), + } + } + + return &ProgressTracker{ + taskID: taskID, + fingerprint: Fingerprint(n), + dir: dir, + files: files, + lastFlush: time.Now(), + } +} + +// progressPath returns the path to the binary progress file. +func (p *ProgressTracker) progressPath() string { + return filepath.Join(p.dir, p.taskID+".progress") +} + +// nzbPath returns the path to the cached NZB file. +func (p *ProgressTracker) nzbPath() string { + return filepath.Join(p.dir, p.taskID+".nzb") +} + +// Load reads a progress file from disk and validates its fingerprint. +// Returns true if the file was loaded successfully and matches the current NZB. +// Returns false if the file doesn't exist, is invalid, or has a different fingerprint. +func (p *ProgressTracker) Load() (bool, error) { + data, err := os.ReadFile(p.progressPath()) + if err != nil { + return false, nil // file doesn't exist = fresh start + } + + if len(data) < headerSize { + return false, nil + } + + // Validate magic + if data[0] != progressMagic[0] || data[1] != progressMagic[1] || + data[2] != progressMagic[2] || data[3] != progressMagic[3] { + return false, nil + } + + // Validate version + if data[4] != progressVersion { + return false, nil + } + + // Validate file count + fileCount := int(binary.LittleEndian.Uint16(data[6:8])) + if fileCount != len(p.files) { + return false, nil + } + + // Validate fingerprint + var storedFP [32]byte + copy(storedFP[:], data[8:40]) + if storedFP != p.fingerprint { + return false, nil + } + + // Read per-file bitsets + offset := headerSize + for i := range p.files { + if offset+4 > len(data) { + return false, nil + } + segCount := int(binary.LittleEndian.Uint32(data[offset : offset+4])) + offset += 4 + + if segCount != p.files[i].segCount { + return false, nil + } + + bitsetLen := (segCount + 7) / 8 + if offset+bitsetLen > len(data) { + return false, nil + } + + copy(p.files[i].completed, data[offset:offset+bitsetLen]) + offset += bitsetLen + + // Count completed segments + var count int32 + for seg := 0; seg < segCount; seg++ { + if p.files[i].completed[seg/8]&(1<= len(p.files) { + return + } + fp := &p.files[fileIndex] + if segIndex < 0 || segIndex >= fp.segCount { + return + } + + p.mu.Lock() + fp.completed[segIndex/8] |= 1 << uint(segIndex%8) + fp.doneCount.Add(1) + p.dirty = true + p.markCount++ + + shouldFlush := p.markCount >= flushSegmentFreq || time.Since(p.lastFlush) >= flushInterval + p.mu.Unlock() + + if shouldFlush { + p.Flush() + } +} + +// IsDone returns whether a specific segment has been completed. +func (p *ProgressTracker) IsDone(fileIndex, segIndex int) bool { + if fileIndex < 0 || fileIndex >= len(p.files) { + return false + } + fp := &p.files[fileIndex] + if segIndex < 0 || segIndex >= fp.segCount { + return false + } + // Read without lock — single-byte read is atomic on aligned data, + // and we only ever set bits (never clear), so a stale read just means + // we might re-download a segment (harmless, WriteAt is idempotent). + return fp.completed[segIndex/8]&(1<= len(p.files) { + return false + } + fp := &p.files[fileIndex] + return int(fp.doneCount.Load()) >= fp.segCount +} + +// CompletedSegments returns the number of completed segments for a file. +func (p *ProgressTracker) CompletedSegments(fileIndex int) int { + if fileIndex < 0 || fileIndex >= len(p.files) { + return 0 + } + return int(p.files[fileIndex].doneCount.Load()) +} + +// CompletedBytes returns the total bytes of completed segments for a file. +func (p *ProgressTracker) CompletedBytes(fileIndex int, segments []nzb.Segment) int64 { + if fileIndex < 0 || fileIndex >= len(p.files) { + return 0 + } + var total int64 + for i, seg := range segments { + if p.IsDone(fileIndex, i) { + total += seg.Bytes + } + } + return total +} + +// TotalCompleted returns total completed segments across all files. +func (p *ProgressTracker) TotalCompleted() int { + var total int + for i := range p.files { + total += int(p.files[i].doneCount.Load()) + } + return total +} + +// Flush writes the current progress state to disk atomically (tmp + rename). +func (p *ProgressTracker) Flush() error { + p.mu.Lock() + if !p.dirty { + p.mu.Unlock() + return nil + } + + // Calculate total size + size := headerSize + for i := range p.files { + size += 4 + (p.files[i].segCount+7)/8 + } + + buf := make([]byte, size) + + // Header + copy(buf[0:4], progressMagic[:]) + buf[4] = progressVersion + buf[5] = 0 // reserved + binary.LittleEndian.PutUint16(buf[6:8], uint16(len(p.files))) + copy(buf[8:40], p.fingerprint[:]) + + // Per-file bitsets + offset := headerSize + for i := range p.files { + fp := &p.files[i] + binary.LittleEndian.PutUint32(buf[offset:offset+4], uint32(fp.segCount)) + offset += 4 + bitsetLen := (fp.segCount + 7) / 8 + copy(buf[offset:offset+bitsetLen], fp.completed[:bitsetLen]) + offset += bitsetLen + } + + p.dirty = false + p.markCount = 0 + p.lastFlush = time.Now() + p.mu.Unlock() + + // Atomic write: tmp file + rename + if err := os.MkdirAll(p.dir, 0o755); err != nil { + return fmt.Errorf("create resume dir: %w", err) + } + + tmpPath := p.progressPath() + ".tmp" + if err := os.WriteFile(tmpPath, buf, 0o644); err != nil { + return fmt.Errorf("write progress tmp: %w", err) + } + + if err := os.Rename(tmpPath, p.progressPath()); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("rename progress: %w", err) + } + + return nil +} + +// Remove deletes both the progress file and cached NZB file. +func (p *ProgressTracker) Remove() error { + os.Remove(p.progressPath()) + os.Remove(p.nzbPath()) + // Also remove tmp file if it exists + os.Remove(p.progressPath() + ".tmp") + return nil +} + +// CleanStaleFiles removes resume files older than maxAge from the given directory. +func CleanStaleFiles(dir string, maxAge time.Duration) int { + entries, err := os.ReadDir(dir) + if err != nil { + return 0 + } + + removed := 0 + for _, e := range entries { + if e.IsDir() { + continue + } + info, err := e.Info() + if err != nil { + continue + } + if time.Since(info.ModTime()) > maxAge { + if err := os.Remove(filepath.Join(dir, e.Name())); err == nil { + removed++ + } + } + } + return removed +} diff --git a/internal/usenet/download/progress_test.go b/internal/usenet/download/progress_test.go new file mode 100644 index 0000000..831eeb4 --- /dev/null +++ b/internal/usenet/download/progress_test.go @@ -0,0 +1,398 @@ +package download + +import ( + "os" + "path/filepath" + "sync" + "testing" + + "time" + + "github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb" +) + +var fixedPast = time.Now().Add(-30 * 24 * time.Hour) + +func makeTestNZB(fileCount, segsPerFile int) *nzb.NZB { + n := &nzb.NZB{ + Files: make([]nzb.File, fileCount), + } + for i := 0; i < fileCount; i++ { + segs := make([]nzb.Segment, segsPerFile) + for j := 0; j < segsPerFile; j++ { + segs[j] = nzb.Segment{ + Bytes: 750 * 1024, + Number: j + 1, + MessageID: segMsgID(i, j), + } + } + n.Files[i] = nzb.File{ + Subject: `"testfile_` + string(rune('a'+i)) + `.rar" yEnc (1/` + string(rune('0'+segsPerFile)) + `)`, + Segments: segs, + } + } + return n +} + +func segMsgID(file, seg int) string { + return "part" + itoa(seg) + ".file" + itoa(file) + "@example.com" +} + +func itoa(n int) string { + if n == 0 { + return "0" + } + s := "" + for n > 0 { + s = string(rune('0'+n%10)) + s + n /= 10 + } + return s +} + +func TestFingerprint_Deterministic(t *testing.T) { + n := makeTestNZB(3, 10) + fp1 := Fingerprint(n) + fp2 := Fingerprint(n) + if fp1 != fp2 { + t.Fatal("fingerprint should be deterministic") + } +} + +func TestFingerprint_DifferentNZB(t *testing.T) { + n1 := makeTestNZB(3, 10) + n2 := makeTestNZB(3, 11) + if Fingerprint(n1) == Fingerprint(n2) { + t.Fatal("different NZBs should have different fingerprints") + } +} + +func TestProgressTracker_NewAndFlush(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(2, 5) + tracker := NewProgressTracker("test-task-1", n, dir) + + // Mark some segments + tracker.MarkDone(0, 0) + tracker.MarkDone(0, 2) + tracker.MarkDone(1, 4) + + if err := tracker.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + // Verify file exists + path := filepath.Join(dir, "test-task-1.progress") + if _, err := os.Stat(path); err != nil { + t.Fatalf("progress file should exist: %v", err) + } + + // Verify state + if !tracker.IsDone(0, 0) { + t.Error("segment 0,0 should be done") + } + if tracker.IsDone(0, 1) { + t.Error("segment 0,1 should NOT be done") + } + if !tracker.IsDone(0, 2) { + t.Error("segment 0,2 should be done") + } + if !tracker.IsDone(1, 4) { + t.Error("segment 1,4 should be done") + } + if tracker.CompletedSegments(0) != 2 { + t.Errorf("file 0: expected 2 completed, got %d", tracker.CompletedSegments(0)) + } + if tracker.CompletedSegments(1) != 1 { + t.Errorf("file 1: expected 1 completed, got %d", tracker.CompletedSegments(1)) + } +} + +func TestProgressTracker_LoadRoundTrip(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(2, 8) + + // Create and populate + tracker1 := NewProgressTracker("test-task-2", n, dir) + tracker1.MarkDone(0, 0) + tracker1.MarkDone(0, 3) + tracker1.MarkDone(0, 7) + tracker1.MarkDone(1, 1) + tracker1.MarkDone(1, 5) + if err := tracker1.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + // Load into new tracker + tracker2 := NewProgressTracker("test-task-2", n, dir) + loaded, err := tracker2.Load() + if err != nil { + t.Fatalf("load: %v", err) + } + if !loaded { + t.Fatal("should have loaded successfully") + } + + // Verify all bits match + for _, tc := range []struct { + file, seg int + want bool + }{ + {0, 0, true}, {0, 1, false}, {0, 2, false}, {0, 3, true}, + {0, 4, false}, {0, 5, false}, {0, 6, false}, {0, 7, true}, + {1, 0, false}, {1, 1, true}, {1, 2, false}, {1, 3, false}, + {1, 4, false}, {1, 5, true}, {1, 6, false}, {1, 7, false}, + } { + got := tracker2.IsDone(tc.file, tc.seg) + if got != tc.want { + t.Errorf("file %d seg %d: got %v, want %v", tc.file, tc.seg, got, tc.want) + } + } + + if tracker2.CompletedSegments(0) != 3 { + t.Errorf("file 0: expected 3 completed, got %d", tracker2.CompletedSegments(0)) + } + if tracker2.CompletedSegments(1) != 2 { + t.Errorf("file 1: expected 2 completed, got %d", tracker2.CompletedSegments(1)) + } +} + +func TestProgressTracker_FingerprintMismatch(t *testing.T) { + dir := t.TempDir() + n1 := makeTestNZB(2, 5) + n2 := makeTestNZB(2, 6) // different segment count = different fingerprint + + // Write with n1 + tracker1 := NewProgressTracker("test-task-3", n1, dir) + tracker1.MarkDone(0, 0) + if err := tracker1.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + // Try to load with n2 + tracker2 := NewProgressTracker("test-task-3", n2, dir) + loaded, err := tracker2.Load() + if err != nil { + t.Fatalf("load: %v", err) + } + if loaded { + t.Fatal("should NOT load — fingerprint mismatch") + } +} + +func TestProgressTracker_IsFileDone(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 4) + tracker := NewProgressTracker("test-task-4", n, dir) + + if tracker.IsFileDone(0) { + t.Error("file should not be done yet") + } + + tracker.MarkDone(0, 0) + tracker.MarkDone(0, 1) + tracker.MarkDone(0, 2) + if tracker.IsFileDone(0) { + t.Error("file should not be done (3/4)") + } + + tracker.MarkDone(0, 3) + if !tracker.IsFileDone(0) { + t.Error("file should be done (4/4)") + } +} + +func TestProgressTracker_ConcurrentMark(t *testing.T) { + dir := t.TempDir() + segCount := 1000 + n := makeTestNZB(1, segCount) + tracker := NewProgressTracker("test-task-5", n, dir) + + var wg sync.WaitGroup + for i := 0; i < segCount; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + tracker.MarkDone(0, idx) + }(i) + } + wg.Wait() + + if !tracker.IsFileDone(0) { + t.Errorf("all segments should be done, got %d/%d", tracker.CompletedSegments(0), segCount) + } + + // Flush and reload + if err := tracker.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + tracker2 := NewProgressTracker("test-task-5", n, dir) + loaded, _ := tracker2.Load() + if !loaded { + t.Fatal("should load") + } + if !tracker2.IsFileDone(0) { + t.Errorf("after reload: expected all done, got %d/%d", tracker2.CompletedSegments(0), segCount) + } +} + +func TestProgressTracker_Remove(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("test-task-6", n, dir) + tracker.MarkDone(0, 0) + if err := tracker.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + // Write a fake NZB cache file + nzbPath := filepath.Join(dir, "test-task-6.nzb") + os.WriteFile(nzbPath, []byte(""), 0o644) + + // Both should exist + if _, err := os.Stat(tracker.progressPath()); err != nil { + t.Fatal("progress file should exist") + } + if _, err := os.Stat(nzbPath); err != nil { + t.Fatal("nzb cache should exist") + } + + tracker.Remove() + + if _, err := os.Stat(tracker.progressPath()); !os.IsNotExist(err) { + t.Error("progress file should be removed") + } + if _, err := os.Stat(nzbPath); !os.IsNotExist(err) { + t.Error("nzb cache should be removed") + } +} + +func TestProgressTracker_LargeNZB(t *testing.T) { + dir := t.TempDir() + segCount := 30000 + n := makeTestNZB(1, segCount) + tracker := NewProgressTracker("test-task-7", n, dir) + + // Mark every other segment + for i := 0; i < segCount; i += 2 { + tracker.MarkDone(0, i) + } + + if err := tracker.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + // Check file size is compact + info, err := os.Stat(tracker.progressPath()) + if err != nil { + t.Fatalf("stat: %v", err) + } + // Header (40) + file header (4) + bitset (30000/8 = 3750) = 3794 bytes + expectedMax := int64(4000) + if info.Size() > expectedMax { + t.Errorf("progress file too large: %d bytes (expected < %d)", info.Size(), expectedMax) + } + + // Reload and verify + tracker2 := NewProgressTracker("test-task-7", n, dir) + loaded, _ := tracker2.Load() + if !loaded { + t.Fatal("should load") + } + if tracker2.CompletedSegments(0) != segCount/2 { + t.Errorf("expected %d completed, got %d", segCount/2, tracker2.CompletedSegments(0)) + } + // Spot check + if !tracker2.IsDone(0, 0) { + t.Error("seg 0 should be done") + } + if tracker2.IsDone(0, 1) { + t.Error("seg 1 should NOT be done") + } + if !tracker2.IsDone(0, 100) { + t.Error("seg 100 should be done") + } +} + +func TestProgressTracker_CompletedBytes(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 4) + tracker := NewProgressTracker("test-task-8", n, dir) + + tracker.MarkDone(0, 0) + tracker.MarkDone(0, 2) + + bytes := tracker.CompletedBytes(0, n.Files[0].Segments) + expected := int64(2 * 750 * 1024) // 2 segments * 750KB + if bytes != expected { + t.Errorf("expected %d bytes, got %d", expected, bytes) + } +} + +func TestProgressTracker_BoundsCheck(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("test-task-9", n, dir) + + // Out-of-bounds should not panic + tracker.MarkDone(-1, 0) + tracker.MarkDone(0, -1) + tracker.MarkDone(5, 0) + tracker.MarkDone(0, 100) + + if tracker.IsDone(-1, 0) { + t.Error("out of bounds should return false") + } + if tracker.IsDone(5, 0) { + t.Error("out of bounds should return false") + } + if tracker.IsFileDone(-1) { + t.Error("out of bounds should return false") + } +} + +func TestCleanStaleFiles(t *testing.T) { + dir := t.TempDir() + + // Create a "stale" file + stalePath := filepath.Join(dir, "old-task.progress") + os.WriteFile(stalePath, []byte("data"), 0o644) + // Backdate modification time + staleTime := os.Chtimes(stalePath, fixedPast, fixedPast) + if staleTime != nil { + t.Fatalf("chtimes: %v", staleTime) + } + + // Create a "fresh" file + freshPath := filepath.Join(dir, "new-task.progress") + os.WriteFile(freshPath, []byte("data"), 0o644) + + removed := CleanStaleFiles(dir, 14*24*time.Hour) // 2 weeks — stale file is 30 days old + if removed != 1 { + t.Errorf("expected 1 removed, got %d", removed) + } + + if _, err := os.Stat(stalePath); !os.IsNotExist(err) { + t.Error("stale file should be removed") + } + if _, err := os.Stat(freshPath); err != nil { + t.Error("fresh file should still exist") + } +} + +func TestProgressTracker_FlushNoOp(t *testing.T) { + dir := t.TempDir() + n := makeTestNZB(1, 3) + tracker := NewProgressTracker("test-task-10", n, dir) + + // Flush without any marks should be no-op + if err := tracker.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + // File should not be created + if _, err := os.Stat(tracker.progressPath()); !os.IsNotExist(err) { + t.Error("no file should be created for empty flush") + } +}