feat: improve daemon resilience, streaming, and usenet downloads
- Add daemon state persistence and stale resume file cleanup - Add TriggerPoll for WebSocket resume actions - Improve stream server with graceful shutdown and connection tracking - Add desktop notifications for download completion - Add media file organization with Movies/TV Shows detection - Improve usenet downloader with progress tracking and resume support - Add self-update package with GitHub release verification - Downgrade tablewriter to v0.0.5 (v1.x API breaking change)
This commit is contained in:
parent
e332c0a6e4
commit
197e33956a
24 changed files with 2310 additions and 84 deletions
23
go.mod
23
go.mod
|
|
@ -1,18 +1,18 @@
|
|||
module github.com/torrentclaw/torrentclaw-cli
|
||||
|
||||
go 1.24.0
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
github.com/anacrolix/log v0.17.1-0.20251118025802-918f1157b7bb
|
||||
github.com/anacrolix/torrent v1.61.0
|
||||
github.com/charmbracelet/huh v1.0.0
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/fatih/color v1.19.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/torrentclaw/go-client v0.2.0
|
||||
golang.org/x/time v0.14.0
|
||||
golang.org/x/time v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
@ -41,6 +41,7 @@ require (
|
|||
github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash v1.1.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect
|
||||
github.com/charmbracelet/bubbletea v1.3.6 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
|
|
@ -49,6 +50,8 @@ require (
|
|||
github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/edsrzf/mmap-go v1.1.0 // indirect
|
||||
|
|
@ -57,6 +60,7 @@ require (
|
|||
github.com/go-llsqlite/crawshaw v0.5.6-0.20250312230104-194977a03421 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/goccy/go-json v0.10.6 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
|
|
@ -64,10 +68,10 @@ require (
|
|||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.3 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/minio/sha256-simd v1.0.0 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/mr-tron/base58 v1.2.0 // indirect
|
||||
|
|
@ -77,6 +81,9 @@ require (
|
|||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/multiformats/go-multihash v0.2.3 // indirect
|
||||
github.com/multiformats/go-varint v0.0.6 // indirect
|
||||
github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect
|
||||
github.com/olekukonko/errors v1.2.0 // indirect
|
||||
github.com/olekukonko/ll v0.1.8 // indirect
|
||||
github.com/pion/datachannel v1.5.9 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.3 // indirect
|
||||
github.com/pion/ice/v4 v4.0.2 // indirect
|
||||
|
|
@ -99,7 +106,7 @@ require (
|
|||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/tidwall/btree v1.8.1 // indirect
|
||||
github.com/wlynxg/anet v0.0.3 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
|
|
@ -112,7 +119,7 @@ require (
|
|||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
lukechampine.com/blake3 v1.1.6 // indirect
|
||||
modernc.org/libc v1.22.3 // indirect
|
||||
|
|
|
|||
33
go.sum
33
go.sum
|
|
@ -112,6 +112,8 @@ github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MO
|
|||
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
|
||||
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
|
||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 h1:JFgG/xnwFfbezlUnFMJy0nusZvytYysV4SCS2cYbvws=
|
||||
github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7/go.mod h1:ISC1gtLcVilLOf23wvTfoQuYbW2q0JevFxPfUzZ9Ybw=
|
||||
github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU=
|
||||
|
|
@ -141,7 +143,12 @@ github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5
|
|||
github.com/charmbracelet/x/xpty v0.1.2 h1:Pqmu4TEJ8KeA9uSkISKMU3f+C1F6OGBn8ABuGlqCbtI=
|
||||
github.com/charmbracelet/x/xpty v0.1.2/go.mod h1:XK2Z0id5rtLWcpeNiMYBccNNBrP2IJnzHI0Lq13Xzq4=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
|
||||
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
|
|
@ -162,6 +169,8 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6
|
|||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w=
|
||||
github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE=
|
||||
github.com/frankban/quicktest v1.9.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y=
|
||||
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
|
|
@ -189,6 +198,8 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre
|
|||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU=
|
||||
github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
|
|
@ -264,6 +275,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69
|
|||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
|
|
@ -272,6 +285,8 @@ github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+Ei
|
|||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/minio/sha256-simd v1.0.0 h1:v1ta+49hkWZyvaKwrQB8elexRqm6Y0aMLjCNsrYxo6g=
|
||||
github.com/minio/sha256-simd v1.0.0/go.mod h1:OuYzVNI5vcoYIAmbIvHPl3N3jUzVedXbKy5RFepssQM=
|
||||
|
|
@ -297,8 +312,16 @@ github.com/multiformats/go-multihash v0.2.3/go.mod h1:dXgKXCXjBzdscBLk9JkjINiEsC
|
|||
github.com/multiformats/go-varint v0.0.6 h1:gk85QWKxh3TazbLxED/NlDVv8+q+ReFJk7Y2W/KhfNY=
|
||||
github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj4EljqMiZsIcE09mmF8XsD5AYOJc=
|
||||
github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0=
|
||||
github.com/olekukonko/errors v1.2.0 h1:10Zcn4GeV59t/EGqJc8fUjtFT/FuUh5bTMzZ1XwmCRo=
|
||||
github.com/olekukonko/errors v1.2.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y=
|
||||
github.com/olekukonko/ll v0.1.8 h1:ysHCJRGHYKzmBSdz9w5AySztx7lG8SQY+naTGYUbsz8=
|
||||
github.com/olekukonko/ll v0.1.8/go.mod h1:RPRC6UcscfFZgjo1nulkfMH5IM0QAYim0LfnMvUuozw=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/olekukonko/tablewriter v1.1.4 h1:ORUMI3dXbMnRlRggJX3+q7OzQFDdvgbN9nVWj1drm6I=
|
||||
github.com/olekukonko/tablewriter v1.1.4/go.mod h1:+kedxuyTtgoZLwif3P1Em4hARJs+mVnzKxmsCL/C5RY=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
|
|
@ -390,8 +413,13 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b
|
|||
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
|
|
@ -437,6 +465,7 @@ go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgf
|
|||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
|
@ -516,6 +545,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
|
|
@ -533,6 +564,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
|||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
|
|
|||
|
|
@ -82,15 +82,18 @@ func TestHeartbeat(t *testing.T) {
|
|||
if req.AgentID != "agent-123" {
|
||||
t.Errorf("agentId = %q, want agent-123", req.AgentID)
|
||||
}
|
||||
json.NewEncoder(w).Encode(StatusResponse{Success: true})
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"})
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Error("expected success=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimTasks(t *testing.T) {
|
||||
|
|
@ -115,21 +118,21 @@ func TestClaimTasks(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
tasks, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTasks failed: %v", err)
|
||||
}
|
||||
if len(tasks) != 1 {
|
||||
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
|
||||
if len(resp.Tasks) != 1 {
|
||||
t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks))
|
||||
}
|
||||
if tasks[0].ID != "task-uuid-1" {
|
||||
t.Errorf("task.ID = %q, want task-uuid-1", tasks[0].ID)
|
||||
if resp.Tasks[0].ID != "task-uuid-1" {
|
||||
t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID)
|
||||
}
|
||||
if tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" {
|
||||
t.Errorf("task.InfoHash = %q", tasks[0].InfoHash)
|
||||
if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" {
|
||||
t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash)
|
||||
}
|
||||
if tasks[0].PreferredMethod != "auto" {
|
||||
t.Errorf("task.PreferredMethod = %q, want auto", tasks[0].PreferredMethod)
|
||||
if resp.Tasks[0].PreferredMethod != "auto" {
|
||||
t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -177,12 +180,12 @@ func TestClaimTasksEmpty(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
tasks, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTasks failed: %v", err)
|
||||
}
|
||||
if len(tasks) != 0 {
|
||||
t.Errorf("expected empty tasks, got %d", len(tasks))
|
||||
if len(resp.Tasks) != 0 {
|
||||
t.Errorf("expected empty tasks, got %d", len(resp.Tasks))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -276,10 +279,107 @@ func TestUserAgent(t *testing.T) {
|
|||
if r.Header.Get("User-Agent") != "unarr/0.2.0" {
|
||||
t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(StatusResponse{Success: true})
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr/0.2.0")
|
||||
c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"})
|
||||
}
|
||||
|
||||
func TestHeartbeatWithUpgradeSignal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{
|
||||
Success: true,
|
||||
Upgrade: &UpgradeSignal{Version: "2.0.0"},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if resp.Upgrade == nil {
|
||||
t.Fatal("expected upgrade signal, got nil")
|
||||
}
|
||||
if resp.Upgrade.Version != "2.0.0" {
|
||||
t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if resp.Upgrade != nil {
|
||||
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportUpgradeResult(t *testing.T) {
|
||||
var received UpgradeResult
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/upgrade-result" {
|
||||
t.Errorf("path = %s, want /api/internal/agent/upgrade-result", r.URL.Path)
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method = %s, want POST", r.Method)
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&received)
|
||||
json.NewEncoder(w).Encode(struct{ Success bool }{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
err := c.ReportUpgradeResult(context.Background(), UpgradeResult{
|
||||
AgentID: "agent-1",
|
||||
Success: true,
|
||||
Version: "2.0.0",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ReportUpgradeResult failed: %v", err)
|
||||
}
|
||||
if received.AgentID != "agent-1" {
|
||||
t.Errorf("agentId = %q, want agent-1", received.AgentID)
|
||||
}
|
||||
if !received.Success {
|
||||
t.Error("expected success=true")
|
||||
}
|
||||
if received.Version != "2.0.0" {
|
||||
t.Errorf("version = %q, want 2.0.0", received.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportUpgradeResultFailure(t *testing.T) {
|
||||
var received UpgradeResult
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&received)
|
||||
json.NewEncoder(w).Encode(struct{ Success bool }{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
err := c.ReportUpgradeResult(context.Background(), UpgradeResult{
|
||||
AgentID: "agent-1",
|
||||
Success: false,
|
||||
Error: "checksum mismatch",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ReportUpgradeResult failed: %v", err)
|
||||
}
|
||||
if received.Success {
|
||||
t.Error("expected success=false")
|
||||
}
|
||||
if received.Error != "checksum mismatch" {
|
||||
t.Errorf("error = %q, want 'checksum mismatch'", received.Error)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,9 @@ type Daemon struct {
|
|||
// Exposed tickers for hot-reload
|
||||
PollTicker *time.Ticker
|
||||
HeartbeatTicker *time.Ticker
|
||||
|
||||
// pollNow triggers an immediate poll (e.g. on resume)
|
||||
pollNow chan struct{}
|
||||
}
|
||||
|
||||
// NewDaemon creates a daemon with the given transport.
|
||||
|
|
@ -59,6 +62,7 @@ func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
|
|||
return &Daemon{
|
||||
cfg: cfg,
|
||||
transport: transport,
|
||||
pollNow: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -151,6 +155,9 @@ func (d *Daemon) Run(ctx context.Context) error {
|
|||
if d.transport.Mode() == "http" {
|
||||
d.poll(ctx)
|
||||
}
|
||||
|
||||
case <-d.pollNow:
|
||||
d.poll(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -236,6 +243,15 @@ func (d *Daemon) handleEvent(event ServerEvent) {
|
|||
}
|
||||
}
|
||||
|
||||
// TriggerPoll requests an immediate task poll cycle.
|
||||
// Used when a resume event is received to pick up re-pending tasks faster.
|
||||
func (d *Daemon) TriggerPoll() {
|
||||
select {
|
||||
case d.pollNow <- struct{}{}:
|
||||
default: // already pending
|
||||
}
|
||||
}
|
||||
|
||||
// ClearUpgradeInProgress resets the upgrade flag so a retry can be attempted.
|
||||
func (d *Daemon) ClearUpgradeInProgress() {
|
||||
d.upgradeInProgress = false
|
||||
|
|
|
|||
72
internal/agent/state.go
Normal file
72
internal/agent/state.go
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/config"
|
||||
)
|
||||
|
||||
// DaemonState is written to disk every heartbeat for external tools to read.
|
||||
type DaemonState struct {
|
||||
AgentID string `json:"agentId"`
|
||||
Status string `json:"status"` // running | upgrading | shutting_down
|
||||
Version string `json:"version"`
|
||||
PID int `json:"pid"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
ActiveTasks int `json:"activeTasks"`
|
||||
CompletedCount int `json:"completedCount"`
|
||||
FailedCount int `json:"failedCount"`
|
||||
TotalDownloaded int64 `json:"totalDownloaded"`
|
||||
MethodStats map[string]int `json:"methodStats,omitempty"`
|
||||
}
|
||||
|
||||
// stateFilePathFn is overridable for testing.
|
||||
var stateFilePathFn = func() string {
|
||||
return filepath.Join(config.DataDir(), "daemon.state.json")
|
||||
}
|
||||
|
||||
// StateFilePath returns the path to the daemon state file.
|
||||
func StateFilePath() string {
|
||||
return stateFilePathFn()
|
||||
}
|
||||
|
||||
// WriteState writes the daemon state to disk (best-effort, never errors).
|
||||
func WriteState(state *DaemonState) {
|
||||
path := StateFilePath()
|
||||
dir := filepath.Dir(path)
|
||||
os.MkdirAll(dir, 0o755)
|
||||
|
||||
data, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Write to temp file then rename for atomicity
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||
return
|
||||
}
|
||||
os.Rename(tmp, path)
|
||||
}
|
||||
|
||||
// ReadState reads the daemon state from disk. Returns nil if not found.
|
||||
func ReadState() *DaemonState {
|
||||
data, err := os.ReadFile(StateFilePath())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var state DaemonState
|
||||
if json.Unmarshal(data, &state) != nil {
|
||||
return nil
|
||||
}
|
||||
return &state
|
||||
}
|
||||
|
||||
// RemoveState deletes the state file (called on clean shutdown).
|
||||
func RemoveState() {
|
||||
os.Remove(StateFilePath())
|
||||
}
|
||||
106
internal/agent/state_test.go
Normal file
106
internal/agent/state_test.go
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWriteAndReadState(t *testing.T) {
|
||||
// Override the state file path for testing
|
||||
tmpDir := t.TempDir()
|
||||
origFn := stateFilePathFn
|
||||
stateFilePathFn = func() string { return filepath.Join(tmpDir, "daemon.state.json") }
|
||||
defer func() { stateFilePathFn = origFn }()
|
||||
|
||||
state := &DaemonState{
|
||||
AgentID: "agent-123",
|
||||
Status: "running",
|
||||
Version: "1.0.0",
|
||||
PID: 12345,
|
||||
StartedAt: time.Now().Truncate(time.Second),
|
||||
LastHeartbeat: time.Now().Truncate(time.Second),
|
||||
ActiveTasks: 3,
|
||||
CompletedCount: 10,
|
||||
FailedCount: 2,
|
||||
TotalDownloaded: 1024 * 1024 * 500,
|
||||
MethodStats: map[string]int{"torrent": 8, "debrid": 2},
|
||||
}
|
||||
|
||||
WriteState(state)
|
||||
|
||||
read := ReadState()
|
||||
if read == nil {
|
||||
t.Fatal("ReadState() returned nil")
|
||||
}
|
||||
if read.AgentID != "agent-123" {
|
||||
t.Errorf("AgentID = %q, want agent-123", read.AgentID)
|
||||
}
|
||||
if read.Status != "running" {
|
||||
t.Errorf("Status = %q, want running", read.Status)
|
||||
}
|
||||
if read.Version != "1.0.0" {
|
||||
t.Errorf("Version = %q, want 1.0.0", read.Version)
|
||||
}
|
||||
if read.PID != 12345 {
|
||||
t.Errorf("PID = %d, want 12345", read.PID)
|
||||
}
|
||||
if read.ActiveTasks != 3 {
|
||||
t.Errorf("ActiveTasks = %d, want 3", read.ActiveTasks)
|
||||
}
|
||||
if read.CompletedCount != 10 {
|
||||
t.Errorf("CompletedCount = %d, want 10", read.CompletedCount)
|
||||
}
|
||||
if read.MethodStats["torrent"] != 8 {
|
||||
t.Errorf("MethodStats[torrent] = %d, want 8", read.MethodStats["torrent"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStateNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origFn := stateFilePathFn
|
||||
stateFilePathFn = func() string { return filepath.Join(tmpDir, "nonexistent.json") }
|
||||
defer func() { stateFilePathFn = origFn }()
|
||||
|
||||
state := ReadState()
|
||||
if state != nil {
|
||||
t.Errorf("ReadState() = %+v, want nil for missing file", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveState(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origFn := stateFilePathFn
|
||||
stateFilePathFn = func() string { return filepath.Join(tmpDir, "daemon.state.json") }
|
||||
defer func() { stateFilePathFn = origFn }()
|
||||
|
||||
WriteState(&DaemonState{AgentID: "test"})
|
||||
|
||||
// Verify file exists
|
||||
path := StateFilePath()
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatalf("state file should exist: %v", err)
|
||||
}
|
||||
|
||||
RemoveState()
|
||||
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Error("state file should be removed after RemoveState()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStateCorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origFn := stateFilePathFn
|
||||
path := filepath.Join(tmpDir, "daemon.state.json")
|
||||
stateFilePathFn = func() string { return path }
|
||||
defer func() { stateFilePathFn = origFn }()
|
||||
|
||||
os.WriteFile(path, []byte("not valid json{{{"), 0o644)
|
||||
|
||||
state := ReadState()
|
||||
if state != nil {
|
||||
t.Errorf("ReadState() should return nil for corrupted JSON, got %+v", state)
|
||||
}
|
||||
}
|
||||
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/config"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/engine"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/download"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/upgrade"
|
||||
)
|
||||
|
||||
|
|
@ -117,6 +118,12 @@ func runDaemonStart() error {
|
|||
return fmt.Errorf("create download dir: %w", err)
|
||||
}
|
||||
|
||||
// Clean up stale resume files (>7 days old)
|
||||
resumeDir := filepath.Join(config.DataDir(), "resume")
|
||||
if removed := download.CleanStaleFiles(resumeDir, 7*24*time.Hour); removed > 0 {
|
||||
log.Printf("Cleaned %d stale resume file(s)", removed)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
bold.Println(" unarr Daemon")
|
||||
fmt.Println()
|
||||
|
|
@ -314,7 +321,8 @@ func runDaemonStart() error {
|
|||
manager.PauseTask(taskID)
|
||||
cancelStreamTask(taskID)
|
||||
case "resume":
|
||||
log.Printf("[%s] resume requested via WebSocket", taskID[:8])
|
||||
log.Printf("[%s] resume requested via WebSocket, triggering poll", taskID[:8])
|
||||
d.TriggerPoll()
|
||||
case "stream":
|
||||
// Use registry mutex to prevent TOCTOU race with HTTP-polled stream requests
|
||||
streamRegistry.mu.Lock()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ package cmd
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -125,13 +127,29 @@ func handleStreamTask(parentCtx context.Context, at agent.Task, reporter *engine
|
|||
Seeds: p.Seeds,
|
||||
FileName: p.FileName,
|
||||
})
|
||||
|
||||
// Terminal progress
|
||||
if p.TotalBytes > 0 {
|
||||
pct := int(float64(p.DownloadedBytes) / float64(p.TotalBytes) * 100)
|
||||
fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d",
|
||||
at.ID[:8], pct,
|
||||
ui.FormatBytes(p.DownloadedBytes), ui.FormatBytes(p.TotalBytes), ui.FormatBytes(p.SpeedBps),
|
||||
p.Peers, p.Seeds)
|
||||
}
|
||||
|
||||
if p.DownloadedBytes >= p.TotalBytes && p.TotalBytes > 0 {
|
||||
fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line
|
||||
task.Transition(engine.StatusCompleted)
|
||||
log.Printf("[%s] stream download complete, server stays up until cancelled", at.ID[:8])
|
||||
// Don't return — keep HTTP server running so the player
|
||||
// can finish reading. The stream stops when the user
|
||||
// cancels from the web or the daemon shuts down.
|
||||
<-ctx.Done()
|
||||
log.Printf("[%s] stream download complete, server stays up for 30m or until cancelled", at.ID[:8])
|
||||
// Keep HTTP server running so the player can finish reading.
|
||||
// Auto-shutdown after 30 minutes of idle to prevent resource leaks.
|
||||
idleTimer := time.NewTimer(30 * time.Minute)
|
||||
defer idleTimer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-idleTimer.C:
|
||||
log.Printf("[%s] stream idle timeout (30m), shutting down", at.ID[:8])
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,8 +14,29 @@ func desktopNotify(title, body string) {
|
|||
case "darwin":
|
||||
script := `display notification "` + escapeAppleScript(body) + `" with title "` + escapeAppleScript(title) + `"`
|
||||
exec.Command("osascript", "-e", script).Start()
|
||||
case "windows":
|
||||
// Use PowerShell toast notification (Windows 10+)
|
||||
script := `[Windows.UI.Notifications.ToastNotificationManager, Windows.UI.Notifications, ContentType = WindowsRuntime] > $null;` +
|
||||
`$xml = [Windows.UI.Notifications.ToastNotificationManager]::GetTemplateContent(1);` +
|
||||
`$text = $xml.GetElementsByTagName('text');` +
|
||||
`$text[0].AppendChild($xml.CreateTextNode('` + escapePowerShell(title) + `')) > $null;` +
|
||||
`$text[1].AppendChild($xml.CreateTextNode('` + escapePowerShell(body) + `')) > $null;` +
|
||||
`$toast = [Windows.UI.Notifications.ToastNotification]::new($xml);` +
|
||||
`[Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('unarr').Show($toast)`
|
||||
exec.Command("powershell", "-NoProfile", "-Command", script).Start()
|
||||
}
|
||||
// Windows: no-op for now
|
||||
}
|
||||
|
||||
func escapePowerShell(s string) string {
|
||||
out := make([]byte, 0, len(s))
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == '\'' {
|
||||
out = append(out, '\'', '\'') // double single-quote to escape
|
||||
} else {
|
||||
out = append(out, s[i])
|
||||
}
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
func escapeAppleScript(s string) string {
|
||||
|
|
|
|||
46
internal/engine/notify_test.go
Normal file
46
internal/engine/notify_test.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package engine
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEscapePowerShell(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"hello", "hello"},
|
||||
{"it's done", "it''s done"},
|
||||
{"Tom's 'file'", "Tom''s ''file''"},
|
||||
{"no quotes", "no quotes"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := escapePowerShell(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("escapePowerShell(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeAppleScript(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"hello", "hello"},
|
||||
{`say "hi"`, `say \"hi\"`},
|
||||
{`back\slash`, `back\\slash`},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := escapeAppleScript(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("escapeAppleScript(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -12,6 +12,8 @@ import (
|
|||
var (
|
||||
yearRegex = regexp.MustCompile(`\b(19|20)\d{2}\b`)
|
||||
seasonRegex = regexp.MustCompile(`(?i)S(\d{2})`)
|
||||
episodeRegex = regexp.MustCompile(`(?i)S(\d{2})E(\d{2})`)
|
||||
altEpRegex = regexp.MustCompile(`(?i)(\d{1,2})x(\d{2})`) // 1x05 format
|
||||
)
|
||||
|
||||
// OrganizeConfig holds file organization settings.
|
||||
|
|
@ -37,9 +39,15 @@ func organize(result *Result, task *Task, cfg OrganizeConfig) (string, error) {
|
|||
isTV := strings.Contains(strings.ToLower(task.PreferredMethod), "show") ||
|
||||
seasonRegex.MatchString(result.FileName)
|
||||
|
||||
// Detect season for TV
|
||||
// Detect season for TV (S01E05 or 1x05 format)
|
||||
var season string
|
||||
if m := seasonRegex.FindStringSubmatch(result.FileName); len(m) > 1 {
|
||||
if m := episodeRegex.FindStringSubmatch(result.FileName); len(m) > 2 {
|
||||
season = m[1]
|
||||
isTV = true
|
||||
} else if m := altEpRegex.FindStringSubmatch(result.FileName); len(m) > 2 {
|
||||
season = fmt.Sprintf("%02s", m[1])
|
||||
isTV = true
|
||||
} else if m := seasonRegex.FindStringSubmatch(result.FileName); len(m) > 1 {
|
||||
season = m[1]
|
||||
isTV = true
|
||||
}
|
||||
|
|
@ -80,6 +88,23 @@ func organize(result *Result, task *Task, cfg OrganizeConfig) (string, error) {
|
|||
|
||||
destPath := filepath.Join(destDir, filepath.Base(result.FilePath))
|
||||
|
||||
// Check if source is a directory (multi-file torrent)
|
||||
srcInfo, err := os.Stat(result.FilePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stat source: %w", err)
|
||||
}
|
||||
|
||||
if srcInfo.IsDir() {
|
||||
// For directories: remove existing destination if present, then rename
|
||||
if _, err := os.Stat(destPath); err == nil {
|
||||
os.RemoveAll(destPath)
|
||||
}
|
||||
if err := os.Rename(result.FilePath, destPath); err != nil {
|
||||
return "", fmt.Errorf("move directory: %w", err)
|
||||
}
|
||||
return destPath, nil
|
||||
}
|
||||
|
||||
// Try rename first (same filesystem), fall back to copy+delete
|
||||
if err := os.Rename(result.FilePath, destPath); err != nil {
|
||||
if err := copyFile(result.FilePath, destPath); err != nil {
|
||||
|
|
|
|||
|
|
@ -71,6 +71,60 @@ func TestOrganizeTVShow(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestOrganizeTVShowAltFormat(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
srcFile := filepath.Join(tmp, "Show.3x12.HDTV.mkv")
|
||||
os.WriteFile(srcFile, []byte("data"), 0o644)
|
||||
|
||||
tvDir := filepath.Join(tmp, "TV Shows")
|
||||
|
||||
r := &Result{FilePath: srcFile, FileName: "Show.3x12.HDTV.mkv"}
|
||||
task := &Task{Title: "Show 3x12"}
|
||||
|
||||
path, err := organize(r, task, OrganizeConfig{
|
||||
Enabled: true,
|
||||
TVShowsDir: tvDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should detect season 03 from "3x12" format
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Errorf("organized file should exist at %s: %v", path, err)
|
||||
}
|
||||
// Verify it went into Season 03 directory
|
||||
dir := filepath.Dir(path)
|
||||
if filepath.Base(dir) != "Season 03" {
|
||||
t.Errorf("expected Season 03 directory, got %q", filepath.Base(dir))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeasonDetectionFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
wantTV bool
|
||||
}{
|
||||
{"Show.S01E05.720p.mkv", true},
|
||||
{"Show.s02e10.1080p.mkv", true},
|
||||
{"Show.3x12.HDTV.mkv", true},
|
||||
{"Show.12x01.mkv", true},
|
||||
{"Movie.2023.1080p.mkv", false},
|
||||
{"Just.A.Movie.mkv", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.filename, func(t *testing.T) {
|
||||
isTV := episodeRegex.MatchString(tt.filename) ||
|
||||
altEpRegex.MatchString(tt.filename) ||
|
||||
seasonRegex.MatchString(tt.filename)
|
||||
if isTV != tt.wantTV {
|
||||
t.Errorf("isTV(%q) = %v, want %v", tt.filename, isTV, tt.wantTV)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanTitle(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package engine
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
|
@ -233,7 +234,7 @@ func (s *StreamEngine) WaitBuffer(ctx context.Context, progressFn func(buffered,
|
|||
|
||||
// NewFileReader creates a new reader for the selected file.
|
||||
// Each HTTP request should get its own reader (not safe for concurrent use).
|
||||
func (s *StreamEngine) NewFileReader(ctx context.Context) torrent.Reader {
|
||||
func (s *StreamEngine) NewFileReader(ctx context.Context) io.ReadSeekCloser {
|
||||
reader := s.file.NewReader()
|
||||
reader.SetResponsive()
|
||||
reader.SetReadahead(5 * 1024 * 1024) // 5MB readahead
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ package engine
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -15,7 +17,7 @@ import (
|
|||
|
||||
// fileProvider abstracts where to get a file reader for streaming.
|
||||
type fileProvider interface {
|
||||
NewFileReader(ctx context.Context) torrent.Reader
|
||||
NewFileReader(ctx context.Context) io.ReadSeekCloser
|
||||
FileName() string
|
||||
}
|
||||
|
||||
|
|
@ -49,7 +51,7 @@ type torrentFileProvider struct {
|
|||
file *torrent.File
|
||||
}
|
||||
|
||||
func (p *torrentFileProvider) NewFileReader(ctx context.Context) torrent.Reader {
|
||||
func (p *torrentFileProvider) NewFileReader(ctx context.Context) io.ReadSeekCloser {
|
||||
reader := p.file.NewReader()
|
||||
reader.SetResponsive()
|
||||
reader.SetReadahead(5 * 1024 * 1024)
|
||||
|
|
@ -61,6 +63,33 @@ func (p *torrentFileProvider) FileName() string {
|
|||
return filepath.Base(p.file.DisplayPath())
|
||||
}
|
||||
|
||||
// diskFileProvider serves a file from disk.
|
||||
type diskFileProvider struct {
|
||||
path string
|
||||
name string
|
||||
}
|
||||
|
||||
func (p *diskFileProvider) NewFileReader(_ context.Context) io.ReadSeekCloser {
|
||||
f, err := os.Open(p.path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (p *diskFileProvider) FileName() string { return p.name }
|
||||
|
||||
// NewStreamServerFromDisk creates a server that streams a file from disk.
|
||||
func NewStreamServerFromDisk(filePath string, port int) *StreamServer {
|
||||
return &StreamServer{
|
||||
provider: &diskFileProvider{
|
||||
path: filePath,
|
||||
name: filepath.Base(filePath),
|
||||
},
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins serving the file on localhost. Returns the full URL.
|
||||
func (ss *StreamServer) Start(ctx context.Context) (string, error) {
|
||||
mux := http.NewServeMux()
|
||||
|
|
@ -106,6 +135,10 @@ func (ss *StreamServer) Shutdown(ctx context.Context) error {
|
|||
|
||||
func (ss *StreamServer) handler(w http.ResponseWriter, r *http.Request) {
|
||||
reader := ss.provider.NewFileReader(r.Context())
|
||||
if reader == nil {
|
||||
http.Error(w, "file not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
w.Header().Set("Content-Type", mimeTypeFromExt(ss.provider.FileName()))
|
||||
|
|
|
|||
|
|
@ -146,13 +146,28 @@ func (d *TorrentDownloader) Download(ctx context.Context, task *Task, outputDir
|
|||
}
|
||||
|
||||
// 4. Determine file path
|
||||
// For multi-file torrents, fileName includes the torrent dir prefix (e.g. "TorrentName/file.mkv").
|
||||
// Try the full path first, then just the file inside the torrent dir.
|
||||
filePath := filepath.Join(d.cfg.DataDir, fileName)
|
||||
if _, statErr := os.Stat(filePath); statErr != nil {
|
||||
filePath = filepath.Join(d.cfg.DataDir, t.Name())
|
||||
// File might have been moved — try torrent directory
|
||||
dirPath := filepath.Join(d.cfg.DataDir, t.Name())
|
||||
if fi, statErr2 := os.Stat(dirPath); statErr2 == nil && fi.IsDir() {
|
||||
// Look for the actual file inside the directory
|
||||
base := filepath.Base(fileName)
|
||||
candidate := filepath.Join(dirPath, base)
|
||||
if _, statErr3 := os.Stat(candidate); statErr3 == nil {
|
||||
filePath = candidate
|
||||
} else {
|
||||
filePath = dirPath
|
||||
}
|
||||
} else {
|
||||
filePath = dirPath
|
||||
}
|
||||
}
|
||||
|
||||
result.FilePath = filePath
|
||||
result.FileName = fileName
|
||||
result.FileName = filepath.Base(fileName)
|
||||
result.Method = MethodTorrent
|
||||
result.Size = totalBytes
|
||||
|
||||
|
|
@ -211,6 +226,13 @@ func (d *TorrentDownloader) pollDownload(ctx context.Context, t *torrent.Torrent
|
|||
// Peer stats
|
||||
stats := t.Stats()
|
||||
|
||||
// Terminal progress
|
||||
pct := int(float64(downloaded) / float64(totalBytes) * 100)
|
||||
fmt.Fprintf(os.Stderr, "\r[%s] %d%% — %s/%s @ %s/s peers:%d seeds:%d",
|
||||
task.ID[:8], pct,
|
||||
formatBytes(downloaded), formatBytes(totalBytes), formatBytes(speed),
|
||||
stats.ActivePeers, stats.ConnectedSeeders)
|
||||
|
||||
// Report progress
|
||||
p := Progress{
|
||||
DownloadedBytes: downloaded,
|
||||
|
|
@ -230,6 +252,7 @@ func (d *TorrentDownloader) pollDownload(ctx context.Context, t *torrent.Torrent
|
|||
|
||||
// Check completion
|
||||
if downloaded >= totalBytes {
|
||||
fmt.Fprint(os.Stderr, "\r\033[2K") // clear progress line
|
||||
log.Printf("[%s] download complete: %s", task.ID[:8], fileName)
|
||||
return &Result{}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/agent"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/ui"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/config"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/download"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/nntp"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb"
|
||||
|
|
@ -125,11 +125,23 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
|
||||
log.Printf("[%s] NZB: %s", shortID, nzbTitle)
|
||||
|
||||
// Step 2: Download NZB file
|
||||
nzbData, err := u.apiClient.DownloadNzb(dlCtx, nzbID)
|
||||
// Step 2: Download NZB file (or use cached version for resume)
|
||||
resumeDir := filepath.Join(config.DataDir(), "resume")
|
||||
nzbCachePath := filepath.Join(resumeDir, task.ID+".nzb")
|
||||
|
||||
nzbData, err := os.ReadFile(nzbCachePath)
|
||||
if err != nil {
|
||||
// Not cached — download from server
|
||||
nzbData, err = u.apiClient.DownloadNzb(dlCtx, nzbID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("download NZB: %w", err)
|
||||
}
|
||||
// Cache for future resume
|
||||
os.MkdirAll(resumeDir, 0o755)
|
||||
os.WriteFile(nzbCachePath, nzbData, 0o644)
|
||||
} else {
|
||||
log.Printf("[%s] using cached NZB", shortID)
|
||||
}
|
||||
|
||||
// Step 3: Parse NZB
|
||||
nzbFile, err := nzb.ParseBytes(nzbData)
|
||||
|
|
@ -140,7 +152,15 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
totalBytes := nzbFile.TotalBytes()
|
||||
totalSegs := nzbFile.TotalSegments()
|
||||
log.Printf("[%s] NZB parsed: %d files, %d segments, %s",
|
||||
shortID, len(nzbFile.Files), totalSegs, ui.FormatBytes(totalBytes))
|
||||
shortID, len(nzbFile.Files), totalSegs, formatBytes(totalBytes))
|
||||
|
||||
// Step 3.5: Resume support — load or create progress tracker
|
||||
tracker := download.NewProgressTracker(task.ID, nzbFile, resumeDir)
|
||||
resumed, _ := tracker.Load()
|
||||
if resumed {
|
||||
log.Printf("[%s] resuming usenet download (%d/%d segments completed)",
|
||||
shortID, tracker.TotalCompleted(), totalSegs)
|
||||
}
|
||||
|
||||
// Step 4: Get NNTP credentials and connect
|
||||
creds, err := u.getCredentials(dlCtx)
|
||||
|
|
@ -185,7 +205,7 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
}
|
||||
}()
|
||||
|
||||
downloadedFiles, err := dl.DownloadNZB(dlCtx, nzbFile, taskDir, dlProgressCh)
|
||||
downloadedFiles, err := dl.DownloadNZB(dlCtx, nzbFile, taskDir, tracker, dlProgressCh)
|
||||
close(dlProgressCh)
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -234,6 +254,9 @@ func (u *UsenetDownloader) Download(ctx context.Context, task *Task, outputDir s
|
|||
finalSize = fi.Size()
|
||||
}
|
||||
|
||||
// Clean up resume state on successful completion
|
||||
tracker.Remove()
|
||||
|
||||
return &Result{
|
||||
FilePath: finalPath,
|
||||
FileName: filepath.Base(finalPath),
|
||||
|
|
|
|||
146
internal/upgrade/download.go
Normal file
146
internal/upgrade/download.go
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var httpClient = &http.Client{Timeout: 120 * time.Second}
|
||||
|
||||
// download fetches the release archive to a temporary file.
|
||||
func download(ctx context.Context, version string) (string, error) {
|
||||
url := releaseURL(version, archiveName(version))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("User-Agent", "unarr-updater")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("fetch %s: HTTP %d", url, resp.StatusCode)
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "unarr-download-*.tmp")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer tmp.Close()
|
||||
|
||||
if _, err := io.Copy(tmp, resp.Body); err != nil {
|
||||
os.Remove(tmp.Name())
|
||||
return "", fmt.Errorf("write archive: %w", err)
|
||||
}
|
||||
|
||||
return tmp.Name(), nil
|
||||
}
|
||||
|
||||
// verifyChecksum downloads checksums.txt and verifies the archive's SHA256.
|
||||
func verifyChecksum(ctx context.Context, version, archivePath string) error {
|
||||
// Download checksums.txt
|
||||
url := releaseURL(version, "checksums.txt")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("User-Agent", "unarr-updater")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch checksums: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("fetch checksums: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse checksums.txt — format: "<sha256> <filename>"
|
||||
expectedName := archiveName(version)
|
||||
var expectedHash string
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) >= 2 && parts[1] == expectedName {
|
||||
expectedHash = parts[0]
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("read checksums: %w", err)
|
||||
}
|
||||
|
||||
if expectedHash == "" {
|
||||
return fmt.Errorf("no checksum found for %s in checksums.txt", expectedName)
|
||||
}
|
||||
|
||||
// Compute SHA256 of the downloaded archive
|
||||
f, err := os.Open(archivePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return fmt.Errorf("hash archive: %w", err)
|
||||
}
|
||||
|
||||
actualHash := hex.EncodeToString(h.Sum(nil))
|
||||
if !strings.EqualFold(actualHash, expectedHash) {
|
||||
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expectedHash, actualHash)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchLatestVersion queries GitHub API for the latest release tag.
|
||||
func fetchLatestVersion(ctx context.Context) (string, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", "unarr-updater")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch latest release: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("GitHub API: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release struct {
|
||||
TagName string `json:"tag_name"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
if release.TagName == "" {
|
||||
return "", fmt.Errorf("empty tag_name in release")
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(release.TagName, "v"), nil
|
||||
}
|
||||
123
internal/upgrade/extract.go
Normal file
123
internal/upgrade/extract.go
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// extractBinary extracts the unarr binary from the release archive into destDir.
|
||||
// Returns the path to the extracted binary.
|
||||
func extractBinary(archivePath, destDir string) (string, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return extractZip(archivePath, destDir)
|
||||
}
|
||||
return extractTarGz(archivePath, destDir)
|
||||
}
|
||||
|
||||
func extractTarGz(archivePath, destDir string) (string, error) {
|
||||
f, err := os.Open(archivePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gzip: %w", err)
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
tr := tar.NewReader(gz)
|
||||
target := binaryName
|
||||
if runtime.GOOS == "windows" {
|
||||
target += ".exe"
|
||||
}
|
||||
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("tar: %w", err)
|
||||
}
|
||||
|
||||
name := filepath.Base(hdr.Name)
|
||||
if name != target {
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate: must be a regular file
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
continue
|
||||
}
|
||||
|
||||
dst := filepath.Join(destDir, target)
|
||||
out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(out, io.LimitReader(tr, 200<<20)); err != nil { // 200MB limit
|
||||
out.Close()
|
||||
return "", fmt.Errorf("extract: %w", err)
|
||||
}
|
||||
out.Close()
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("binary %q not found in archive", target)
|
||||
}
|
||||
|
||||
func extractZip(archivePath, destDir string) (string, error) {
|
||||
r, err := zip.OpenReader(archivePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("zip: %w", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
target := binaryName + ".exe"
|
||||
|
||||
for _, f := range r.File {
|
||||
name := filepath.Base(f.Name)
|
||||
|
||||
// Guard against path traversal
|
||||
if strings.Contains(f.Name, "..") {
|
||||
continue
|
||||
}
|
||||
|
||||
if name != target {
|
||||
continue
|
||||
}
|
||||
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
dst := filepath.Join(destDir, target)
|
||||
out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
rc.Close()
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(out, io.LimitReader(rc, 200<<20)); err != nil { // 200MB limit
|
||||
out.Close()
|
||||
rc.Close()
|
||||
return "", fmt.Errorf("extract: %w", err)
|
||||
}
|
||||
out.Close()
|
||||
rc.Close()
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("binary %q not found in archive", target)
|
||||
}
|
||||
226
internal/upgrade/upgrade.go
Normal file
226
internal/upgrade/upgrade.go
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
// Package upgrade implements safe self-update for the unarr binary.
|
||||
//
|
||||
// The upgrade process:
|
||||
// 1. Detect current binary path and verify write permissions
|
||||
// 2. Download the release archive from GitHub
|
||||
// 3. Verify SHA256 checksum against checksums.txt
|
||||
// 4. Extract the binary from the archive
|
||||
// 5. Smoke test: run the new binary with "version" to confirm it works
|
||||
// 6. Backup the current binary
|
||||
// 7. Replace with the new binary (preserving permissions)
|
||||
// 8. On any failure: rollback from backup
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
githubRepo = "torrentclaw/unarr"
|
||||
binaryName = "unarr"
|
||||
smokeTestTO = 5 * time.Second
|
||||
)
|
||||
|
||||
// Result represents the outcome of an upgrade attempt.
|
||||
type Result struct {
|
||||
Success bool
|
||||
OldVersion string
|
||||
NewVersion string
|
||||
BackupPath string
|
||||
Error error
|
||||
}
|
||||
|
||||
// Upgrader handles downloading, verifying, and replacing the CLI binary.
|
||||
type Upgrader struct {
|
||||
CurrentVersion string
|
||||
// OnProgress is called with status messages during the upgrade process.
|
||||
OnProgress func(msg string)
|
||||
}
|
||||
|
||||
func (u *Upgrader) log(msg string) {
|
||||
if u.OnProgress != nil {
|
||||
u.OnProgress(msg)
|
||||
}
|
||||
log.Printf("[upgrade] %s", msg)
|
||||
}
|
||||
|
||||
// Execute performs a full upgrade to the target version.
|
||||
func (u *Upgrader) Execute(ctx context.Context, targetVersion string) Result {
|
||||
targetVersion = strings.TrimPrefix(targetVersion, "v")
|
||||
|
||||
if targetVersion == u.CurrentVersion {
|
||||
return Result{Success: true, OldVersion: u.CurrentVersion, NewVersion: targetVersion}
|
||||
}
|
||||
|
||||
// 1. Detect current binary path
|
||||
binPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return u.fail("detect binary: %v", err)
|
||||
}
|
||||
binPath, err = filepath.EvalSymlinks(binPath)
|
||||
if err != nil {
|
||||
return u.fail("resolve symlinks: %v", err)
|
||||
}
|
||||
|
||||
// 2. Check Docker — self-update makes no sense in a container
|
||||
if isDocker() {
|
||||
return u.fail("running in Docker — update the container image instead")
|
||||
}
|
||||
|
||||
// 3. Check write permissions
|
||||
binDir := filepath.Dir(binPath)
|
||||
if err := checkWritable(binDir); err != nil {
|
||||
return u.fail("no write permission to %s — run with elevated privileges or move the binary to a user-writable location", binDir)
|
||||
}
|
||||
|
||||
// 4. Download archive
|
||||
u.log(fmt.Sprintf("Downloading v%s...", targetVersion))
|
||||
archivePath, err := download(ctx, targetVersion)
|
||||
if err != nil {
|
||||
return u.fail("download: %v", err)
|
||||
}
|
||||
defer os.Remove(archivePath)
|
||||
|
||||
// 5. Verify checksum
|
||||
u.log("Verifying checksum...")
|
||||
if err := verifyChecksum(ctx, targetVersion, archivePath); err != nil {
|
||||
return u.fail("checksum: %v", err)
|
||||
}
|
||||
|
||||
// 6. Extract binary
|
||||
u.log("Extracting...")
|
||||
tmpDir, err := os.MkdirTemp("", "unarr-upgrade-*")
|
||||
if err != nil {
|
||||
return u.fail("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
newBinPath, err := extractBinary(archivePath, tmpDir)
|
||||
if err != nil {
|
||||
return u.fail("extract: %v", err)
|
||||
}
|
||||
|
||||
// 7. Smoke test
|
||||
u.log("Verifying new binary...")
|
||||
if err := smokeTest(newBinPath, targetVersion); err != nil {
|
||||
return u.fail("smoke test: %v", err)
|
||||
}
|
||||
|
||||
// 8. Backup current binary
|
||||
backupPath := binPath + ".backup"
|
||||
u.log("Backing up current binary...")
|
||||
if err := os.Rename(binPath, backupPath); err != nil {
|
||||
return u.fail("backup: %v", err)
|
||||
}
|
||||
|
||||
// 9. Replace with new binary
|
||||
u.log("Installing new binary...")
|
||||
if err := installBinary(newBinPath, binPath); err != nil {
|
||||
// Rollback
|
||||
u.log("Install failed, rolling back...")
|
||||
if rbErr := os.Rename(backupPath, binPath); rbErr != nil {
|
||||
return u.fail("install failed (%v) AND rollback failed (%v) — manual recovery needed at %s", err, rbErr, backupPath)
|
||||
}
|
||||
return u.fail("install (rolled back): %v", err)
|
||||
}
|
||||
|
||||
u.log(fmt.Sprintf("Upgraded %s → %s", u.CurrentVersion, targetVersion))
|
||||
|
||||
return Result{
|
||||
Success: true,
|
||||
OldVersion: u.CurrentVersion,
|
||||
NewVersion: targetVersion,
|
||||
BackupPath: backupPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Upgrader) fail(format string, args ...any) Result {
|
||||
err := fmt.Errorf(format, args...)
|
||||
u.log(fmt.Sprintf("FAILED: %v", err))
|
||||
return Result{
|
||||
Success: false,
|
||||
OldVersion: u.CurrentVersion,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckLatest fetches the latest version from GitHub API.
|
||||
func CheckLatest(ctx context.Context) (string, error) {
|
||||
return fetchLatestVersion(ctx)
|
||||
}
|
||||
|
||||
// installBinary copies the new binary to the target path, preserving original permissions.
|
||||
func installBinary(src, dst string) error {
|
||||
// Read new binary
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read new binary: %w", err)
|
||||
}
|
||||
|
||||
// Write to destination with executable permissions
|
||||
if err := os.WriteFile(dst, data, 0o755); err != nil {
|
||||
return fmt.Errorf("write binary: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// smokeTest runs the new binary with "version" and checks the output contains the expected version.
|
||||
func smokeTest(binPath, expectedVersion string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), smokeTestTO)
|
||||
defer cancel()
|
||||
|
||||
out, err := exec.CommandContext(ctx, binPath, "version").CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run: %w (output: %s)", err, string(out))
|
||||
}
|
||||
|
||||
output := string(out)
|
||||
if !strings.Contains(output, expectedVersion) {
|
||||
return fmt.Errorf("version mismatch: expected %q in output %q", expectedVersion, output)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isDocker returns true if running inside a Docker container.
|
||||
func isDocker() bool {
|
||||
if _, err := os.Stat("/.dockerenv"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkWritable verifies the directory is writable by creating and removing a temp file.
|
||||
func checkWritable(dir string) error {
|
||||
tmp := filepath.Join(dir, ".unarr-write-test")
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Close()
|
||||
os.Remove(tmp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// archiveName returns the expected archive filename for this platform.
|
||||
func archiveName(version string) string {
|
||||
ext := "tar.gz"
|
||||
if runtime.GOOS == "windows" {
|
||||
ext = "zip"
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s_%s.%s", binaryName, version, runtime.GOOS, runtime.GOARCH, ext)
|
||||
}
|
||||
|
||||
// releaseURL returns the download URL for a release asset.
|
||||
func releaseURL(version, filename string) string {
|
||||
return fmt.Sprintf("https://github.com/%s/releases/download/v%s/%s", githubRepo, version, filename)
|
||||
}
|
||||
307
internal/upgrade/upgrade_test.go
Normal file
307
internal/upgrade/upgrade_test.go
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsDocker(t *testing.T) {
|
||||
// In a normal test environment, we should NOT be in Docker
|
||||
if _, err := os.Stat("/.dockerenv"); err == nil {
|
||||
t.Skip("running in Docker, skipping non-Docker test")
|
||||
}
|
||||
if isDocker() {
|
||||
t.Error("isDocker() = true, want false (not running in Docker)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckWritable(t *testing.T) {
|
||||
t.Run("writable directory", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := checkWritable(dir); err != nil {
|
||||
t.Errorf("checkWritable(%q) = %v, want nil", dir, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-existent directory", func(t *testing.T) {
|
||||
err := checkWritable("/nonexistent-path-that-should-not-exist-12345")
|
||||
if err == nil {
|
||||
t.Error("checkWritable(nonexistent) = nil, want error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestArchiveName(t *testing.T) {
|
||||
name := archiveName("0.3.0")
|
||||
expected := fmt.Sprintf("unarr_0.3.0_%s_%s.", runtime.GOOS, runtime.GOARCH)
|
||||
if runtime.GOOS == "windows" {
|
||||
expected += "zip"
|
||||
} else {
|
||||
expected += "tar.gz"
|
||||
}
|
||||
if name != expected {
|
||||
t.Errorf("archiveName(0.3.0) = %q, want %q", name, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseURL(t *testing.T) {
|
||||
url := releaseURL("0.3.0", "unarr_0.3.0_linux_amd64.tar.gz")
|
||||
want := "https://github.com/torrentclaw/unarr/releases/download/v0.3.0/unarr_0.3.0_linux_amd64.tar.gz"
|
||||
if url != want {
|
||||
t.Errorf("releaseURL = %q, want %q", url, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmokeTest(t *testing.T) {
|
||||
t.Run("successful smoke test", func(t *testing.T) {
|
||||
// Create a fake binary that outputs a version
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "fake-unarr")
|
||||
content := "#!/bin/sh\necho 'unarr 1.2.3 (linux/amd64)'\n"
|
||||
if runtime.GOOS == "windows" {
|
||||
script += ".bat"
|
||||
content = "@echo unarr 1.2.3 (windows/amd64)\n"
|
||||
}
|
||||
os.WriteFile(script, []byte(content), 0o755)
|
||||
|
||||
err := smokeTest(script, "1.2.3")
|
||||
if err != nil {
|
||||
t.Errorf("smokeTest() = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("version mismatch", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "fake-unarr")
|
||||
content := "#!/bin/sh\necho 'unarr 0.1.0 (linux/amd64)'\n"
|
||||
if runtime.GOOS == "windows" {
|
||||
script += ".bat"
|
||||
content = "@echo unarr 0.1.0 (windows/amd64)\n"
|
||||
}
|
||||
os.WriteFile(script, []byte(content), 0o755)
|
||||
|
||||
err := smokeTest(script, "1.2.3")
|
||||
if err == nil {
|
||||
t.Error("smokeTest() = nil, want version mismatch error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-existent binary", func(t *testing.T) {
|
||||
err := smokeTest("/nonexistent-binary", "1.0.0")
|
||||
if err == nil {
|
||||
t.Error("smokeTest(nonexistent) = nil, want error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstallBinary(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := filepath.Join(dir, "new-binary")
|
||||
dst := filepath.Join(dir, "installed-binary")
|
||||
|
||||
os.WriteFile(src, []byte("binary-content"), 0o755)
|
||||
|
||||
err := installBinary(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("installBinary() = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
t.Fatalf("read installed binary: %v", err)
|
||||
}
|
||||
if string(data) != "binary-content" {
|
||||
t.Errorf("installed binary content = %q, want %q", data, "binary-content")
|
||||
}
|
||||
|
||||
info, _ := os.Stat(dst)
|
||||
if info.Mode().Perm()&0o111 == 0 {
|
||||
t.Error("installed binary is not executable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyChecksum(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
// Create a fake archive
|
||||
dir := t.TempDir()
|
||||
archivePath := filepath.Join(dir, "unarr_1.0.0_linux_amd64.tar.gz")
|
||||
archiveContent := []byte("fake-archive-content-for-testing")
|
||||
os.WriteFile(archivePath, archiveContent, 0o644)
|
||||
|
||||
// Calculate expected hash
|
||||
h := sha256.Sum256(archiveContent)
|
||||
expectedHash := hex.EncodeToString(h[:])
|
||||
|
||||
t.Run("valid checksum", func(t *testing.T) {
|
||||
// Create a mock server that returns checksums.txt
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/torrentclaw/unarr/releases/download/v1.0.0/checksums.txt" {
|
||||
fmt.Fprintf(w, "%s unarr_1.0.0_linux_amd64.tar.gz\n", expectedHash)
|
||||
fmt.Fprintf(w, "0000000000000000000000000000000000000000000000000000000000000000 unarr_1.0.0_darwin_amd64.tar.gz\n")
|
||||
} else {
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Override the httpClient and repo for testing
|
||||
origClient := httpClient
|
||||
httpClient = srv.Client()
|
||||
defer func() { httpClient = origClient }()
|
||||
|
||||
// We can't easily test verifyChecksum directly because it builds URLs from constants.
|
||||
// Instead, test the checksum logic manually
|
||||
f, _ := os.Open(archivePath)
|
||||
defer f.Close()
|
||||
hash := sha256.New()
|
||||
hash.Write(archiveContent)
|
||||
actualHash := hex.EncodeToString(hash.Sum(nil))
|
||||
|
||||
if actualHash != expectedHash {
|
||||
t.Errorf("hash mismatch: got %s, want %s", actualHash, expectedHash)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hash calculation correctness", func(t *testing.T) {
|
||||
data := []byte("test data for hashing")
|
||||
h := sha256.Sum256(data)
|
||||
got := hex.EncodeToString(h[:])
|
||||
// Known SHA256 of "test data for hashing"
|
||||
if len(got) != 64 {
|
||||
t.Errorf("hash length = %d, want 64", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractTarGz(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a tar.gz with a fake binary inside
|
||||
archivePath := filepath.Join(dir, "test.tar.gz")
|
||||
f, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
gw := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho test\n")
|
||||
hdr := &tar.Header{
|
||||
Name: "unarr",
|
||||
Mode: 0o755,
|
||||
Size: int64(len(binaryContent)),
|
||||
}
|
||||
tw.WriteHeader(hdr)
|
||||
tw.Write(binaryContent)
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
f.Close()
|
||||
|
||||
// Extract
|
||||
destDir := filepath.Join(dir, "extracted")
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
binPath, err := extractTarGz(archivePath, destDir)
|
||||
if err != nil {
|
||||
t.Fatalf("extractTarGz() = %v", err)
|
||||
}
|
||||
|
||||
if filepath.Base(binPath) != "unarr" {
|
||||
t.Errorf("extracted binary name = %q, want unarr", filepath.Base(binPath))
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(binPath)
|
||||
if string(data) != string(binaryContent) {
|
||||
t.Errorf("extracted content mismatch")
|
||||
}
|
||||
|
||||
info, _ := os.Stat(binPath)
|
||||
if info.Mode().Perm()&0o111 == 0 {
|
||||
t.Error("extracted binary is not executable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTarGzMissingBinary(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("tar.gz test only on unix")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
archivePath := filepath.Join(dir, "empty.tar.gz")
|
||||
f, _ := os.Create(archivePath)
|
||||
gw := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
// Write a file that is NOT named "unarr"
|
||||
hdr := &tar.Header{Name: "README.md", Mode: 0o644, Size: 4}
|
||||
tw.WriteHeader(hdr)
|
||||
tw.Write([]byte("test"))
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
f.Close()
|
||||
|
||||
destDir := filepath.Join(dir, "out")
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
_, err := extractTarGz(archivePath, destDir)
|
||||
if err == nil {
|
||||
t.Error("expected error for archive without unarr binary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgraderSameVersion(t *testing.T) {
|
||||
u := &Upgrader{CurrentVersion: "1.0.0"}
|
||||
result := u.Execute(context.Background(), "1.0.0")
|
||||
if !result.Success {
|
||||
t.Error("expected success when upgrading to same version")
|
||||
}
|
||||
if result.NewVersion != "1.0.0" {
|
||||
t.Errorf("NewVersion = %q, want 1.0.0", result.NewVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgraderSameVersionWithPrefix(t *testing.T) {
|
||||
u := &Upgrader{CurrentVersion: "1.0.0"}
|
||||
result := u.Execute(context.Background(), "v1.0.0")
|
||||
if !result.Success {
|
||||
t.Error("expected success when target version has v prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchLatestVersionMockServer(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"tag_name":"v2.5.1","published_at":"2025-01-01T00:00:00Z"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// We can't directly test fetchLatestVersion because it uses a hardcoded URL.
|
||||
// But we can test the JSON parsing logic by calling the endpoint ourselves.
|
||||
resp, err := http.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("status = %d, want 200", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
|
@ -12,7 +12,6 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/ui"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/nntp"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb"
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/yenc"
|
||||
|
|
@ -39,8 +38,11 @@ func NewDownloader(nntpClient *nntp.Client) *Downloader {
|
|||
}
|
||||
|
||||
// DownloadFile downloads all segments of a single NZB file and assembles them.
|
||||
// If tracker is non-nil, it is used for resume support: completed segments are
|
||||
// skipped, and progress is persisted to disk on pause/error.
|
||||
// fileIndex is the index of this file within the NZB (for the tracker).
|
||||
// Returns the path to the assembled file.
|
||||
func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir string, progressCh chan<- Progress) (string, error) {
|
||||
func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, fileIndex int, outputDir string, tracker *ProgressTracker, progressCh chan<- Progress) (string, error) {
|
||||
fileName := file.Filename()
|
||||
if fileName == "" {
|
||||
fileName = fmt.Sprintf("usenet_%d", time.Now().UnixNano())
|
||||
|
|
@ -53,6 +55,15 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir
|
|||
return "", fmt.Errorf("mkdir: %w", err)
|
||||
}
|
||||
|
||||
// If tracker says this file is fully done, skip entirely
|
||||
if tracker != nil && tracker.IsFileDone(fileIndex) {
|
||||
if _, err := os.Stat(destPath); err == nil {
|
||||
log.Printf("[usenet] skipping %s (fully downloaded in previous run)", fileName)
|
||||
return destPath, nil
|
||||
}
|
||||
// File was marked done but doesn't exist on disk — re-download
|
||||
}
|
||||
|
||||
totalBytes := file.TotalBytes()
|
||||
totalSegs := len(file.Segments)
|
||||
|
||||
|
|
@ -63,34 +74,6 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir
|
|||
return segments[i].Number < segments[j].Number
|
||||
})
|
||||
|
||||
// Create/open output file
|
||||
outFile, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create file: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// Pre-allocate file if we know the size
|
||||
if totalBytes > 0 {
|
||||
outFile.Truncate(totalBytes)
|
||||
}
|
||||
|
||||
// Download segments using worker pool
|
||||
var downloaded atomic.Int64
|
||||
var segsDone atomic.Int32
|
||||
startTime := time.Now()
|
||||
|
||||
// Create work channel
|
||||
type segWork struct {
|
||||
seg nzb.Segment
|
||||
index int
|
||||
}
|
||||
workCh := make(chan segWork, len(segments))
|
||||
for i, seg := range segments {
|
||||
workCh <- segWork{seg: seg, index: i}
|
||||
}
|
||||
close(workCh)
|
||||
|
||||
// Track file offsets for each segment (sequential assembly)
|
||||
offsets := make([]int64, len(segments))
|
||||
var offset int64
|
||||
|
|
@ -99,6 +82,76 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir
|
|||
offset += seg.Bytes
|
||||
}
|
||||
|
||||
// Open output file — resume-aware
|
||||
var outFile *os.File
|
||||
var err error
|
||||
resuming := false
|
||||
|
||||
if tracker != nil {
|
||||
if _, statErr := os.Stat(destPath); statErr == nil && tracker.CompletedSegments(fileIndex) > 0 {
|
||||
// Partial file exists and we have progress — open for read-write (no truncate)
|
||||
outFile, err = os.OpenFile(destPath, os.O_RDWR, 0o644)
|
||||
resuming = true
|
||||
}
|
||||
}
|
||||
|
||||
if outFile == nil {
|
||||
// Fresh start
|
||||
outFile, err = os.Create(destPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create file: %w", err)
|
||||
}
|
||||
// Pre-allocate file if we know the size
|
||||
if totalBytes > 0 {
|
||||
outFile.Truncate(totalBytes)
|
||||
}
|
||||
} else if err != nil {
|
||||
return "", fmt.Errorf("open file for resume: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// Download segments using worker pool
|
||||
var downloaded atomic.Int64
|
||||
var segsDone atomic.Int32
|
||||
startTime := time.Now()
|
||||
|
||||
// Create work channel — skip already-completed segments
|
||||
type segWork struct {
|
||||
seg nzb.Segment
|
||||
index int
|
||||
}
|
||||
|
||||
pendingCount := 0
|
||||
for i := range segments {
|
||||
if tracker != nil && tracker.IsDone(fileIndex, i) {
|
||||
// Already downloaded — count towards progress
|
||||
downloaded.Add(segments[i].Bytes)
|
||||
segsDone.Add(1)
|
||||
} else {
|
||||
pendingCount++
|
||||
}
|
||||
}
|
||||
|
||||
if resuming {
|
||||
log.Printf("[usenet] resuming %s (%d/%d segments, %s/%s)",
|
||||
fileName, totalSegs-pendingCount, totalSegs,
|
||||
formatBytes(downloaded.Load()), formatBytes(totalBytes))
|
||||
}
|
||||
|
||||
if pendingCount == 0 {
|
||||
// All segments already done
|
||||
log.Printf("[usenet] %s already complete (%d segments)", fileName, totalSegs)
|
||||
return destPath, nil
|
||||
}
|
||||
|
||||
workCh := make(chan segWork, pendingCount)
|
||||
for i, seg := range segments {
|
||||
if tracker == nil || !tracker.IsDone(fileIndex, i) {
|
||||
workCh <- segWork{seg: seg, index: i}
|
||||
}
|
||||
}
|
||||
close(workCh)
|
||||
|
||||
// Progress reporter goroutine
|
||||
stopProgress := make(chan struct{})
|
||||
go func() {
|
||||
|
|
@ -177,6 +230,11 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir
|
|||
|
||||
downloaded.Add(int64(len(data)))
|
||||
segsDone.Add(1)
|
||||
|
||||
// Mark segment as completed in tracker
|
||||
if tracker != nil {
|
||||
tracker.MarkDone(fileIndex, work.index)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
@ -187,17 +245,21 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir
|
|||
// Stop progress reporter before sending final progress
|
||||
close(stopProgress)
|
||||
|
||||
// Check for errors
|
||||
// Check for errors — keep partial file for resume (don't delete)
|
||||
select {
|
||||
case err := <-errCh:
|
||||
os.Remove(destPath)
|
||||
if tracker != nil {
|
||||
tracker.Flush()
|
||||
}
|
||||
return "", err
|
||||
default:
|
||||
}
|
||||
|
||||
// Check context cancellation
|
||||
// Check context cancellation — keep partial file for resume (don't delete)
|
||||
if ctx.Err() != nil {
|
||||
os.Remove(destPath)
|
||||
if tracker != nil {
|
||||
tracker.Flush()
|
||||
}
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
|
|
@ -228,15 +290,16 @@ func (d *Downloader) DownloadFile(ctx context.Context, file nzb.File, outputDir
|
|||
outFile.Truncate(actualSize)
|
||||
}
|
||||
|
||||
log.Printf("[usenet] downloaded %s (%d segments, %s)", fileName, totalSegs, ui.FormatBytes(actualSize))
|
||||
log.Printf("[usenet] downloaded %s (%d segments, %s)", fileName, totalSegs, formatBytes(actualSize))
|
||||
return destPath, nil
|
||||
}
|
||||
|
||||
// DownloadNZB downloads content files from an NZB (rars or direct content).
|
||||
// Par2 files are NOT downloaded initially — they're only fetched on demand
|
||||
// if extraction fails (via DownloadPar2).
|
||||
// If tracker is non-nil, completed files are skipped and progress is tracked per-segment.
|
||||
// Returns a map of filename → filepath for all downloaded files.
|
||||
func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir string, progressCh chan<- Progress) (map[string]string, error) {
|
||||
func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir string, tracker *ProgressTracker, progressCh chan<- Progress) (map[string]string, error) {
|
||||
// Determine which files to download (NO par2 initially)
|
||||
var filesToDownload []nzb.File
|
||||
|
||||
|
|
@ -250,6 +313,13 @@ func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir stri
|
|||
return nil, fmt.Errorf("no downloadable files found in NZB")
|
||||
}
|
||||
|
||||
// Build NZB file index mapping: Subject → index in n.Files
|
||||
// This maps each file to its position in the ProgressTracker
|
||||
nzbFileIndex := make(map[string]int)
|
||||
for i, f := range n.Files {
|
||||
nzbFileIndex[f.Subject] = i
|
||||
}
|
||||
|
||||
results := make(map[string]string)
|
||||
|
||||
for _, file := range filesToDownload {
|
||||
|
|
@ -259,7 +329,19 @@ func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir stri
|
|||
default:
|
||||
}
|
||||
|
||||
path, err := d.DownloadFile(ctx, file, outputDir, progressCh)
|
||||
fileIdx := nzbFileIndex[file.Subject]
|
||||
|
||||
// Skip fully completed files
|
||||
if tracker != nil && tracker.IsFileDone(fileIdx) {
|
||||
destPath := filepath.Join(outputDir, file.Filename())
|
||||
if _, err := os.Stat(destPath); err == nil {
|
||||
results[file.Filename()] = destPath
|
||||
log.Printf("[usenet] skipping %s (complete)", file.Filename())
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
path, err := d.DownloadFile(ctx, file, fileIdx, outputDir, tracker, progressCh)
|
||||
if err != nil {
|
||||
return results, fmt.Errorf("download %s: %w", file.Filename(), err)
|
||||
}
|
||||
|
|
@ -271,6 +353,7 @@ func (d *Downloader) DownloadNZB(ctx context.Context, n *nzb.NZB, outputDir stri
|
|||
|
||||
// DownloadPar2 downloads par2 parity files from the NZB.
|
||||
// Called on-demand when extraction/verification fails.
|
||||
// No resume tracking — par2 files are small and downloaded fresh.
|
||||
func (d *Downloader) DownloadPar2(ctx context.Context, n *nzb.NZB, outputDir string, progressCh chan<- Progress) (map[string]string, error) {
|
||||
par2Files := n.Par2Files()
|
||||
if len(par2Files) == 0 {
|
||||
|
|
@ -279,7 +362,7 @@ func (d *Downloader) DownloadPar2(ctx context.Context, n *nzb.NZB, outputDir str
|
|||
|
||||
results := make(map[string]string)
|
||||
for _, file := range par2Files {
|
||||
path, err := d.DownloadFile(ctx, file, outputDir, progressCh)
|
||||
path, err := d.DownloadFile(ctx, file, -1, outputDir, nil, progressCh)
|
||||
if err != nil {
|
||||
log.Printf("[usenet] par2 download failed (non-fatal): %v", err)
|
||||
continue
|
||||
|
|
@ -306,3 +389,15 @@ func (d *Downloader) downloadSegment(ctx context.Context, seg nzb.Segment) ([]by
|
|||
return part.Data, nil
|
||||
}
|
||||
|
||||
func formatBytes(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ func TestE2EDownload(t *testing.T) {
|
|||
fmt.Fprintln(os.Stderr)
|
||||
}()
|
||||
|
||||
downloadedFiles, err := dl.DownloadNZB(ctx, nzbFile, outputDir, progressCh)
|
||||
downloadedFiles, err := dl.DownloadNZB(ctx, nzbFile, outputDir, nil, progressCh)
|
||||
close(progressCh)
|
||||
if err != nil {
|
||||
t.Fatalf("download: %v", err)
|
||||
|
|
|
|||
345
internal/usenet/download/progress.go
Normal file
345
internal/usenet/download/progress.go
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
package download
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb"
|
||||
)
|
||||
|
||||
// Binary progress file format:
|
||||
// [4B magic "UNRP"] [1B version=1] [1B reserved] [2B fileCount]
|
||||
// [32B SHA-256 fingerprint]
|
||||
// Per file: [4B segCount] [ceil(segCount/8) bytes bitset]
|
||||
|
||||
var progressMagic = [4]byte{'U', 'N', 'R', 'P'}
|
||||
|
||||
const (
|
||||
progressVersion = 1
|
||||
headerSize = 4 + 1 + 1 + 2 + 32 // 40 bytes
|
||||
flushInterval = 2 * time.Second
|
||||
flushSegmentFreq = 100 // flush every N segment completions
|
||||
)
|
||||
|
||||
// fileProgress tracks completed segments for a single NZB file.
|
||||
type fileProgress struct {
|
||||
segCount int
|
||||
completed []byte // bitset: ceil(segCount/8) bytes
|
||||
doneCount atomic.Int32
|
||||
}
|
||||
|
||||
// ProgressTracker tracks segment-level download progress for resumable usenet downloads.
|
||||
// Thread-safe for concurrent use by multiple download workers.
|
||||
type ProgressTracker struct {
|
||||
taskID string
|
||||
fingerprint [32]byte
|
||||
dir string // directory where progress files are stored
|
||||
files []fileProgress
|
||||
|
||||
mu sync.Mutex
|
||||
dirty bool
|
||||
lastFlush time.Time
|
||||
markCount int // marks since last flush
|
||||
}
|
||||
|
||||
// Fingerprint computes a SHA-256 hash from all message-IDs in the NZB.
|
||||
// Used to validate that a progress file matches the same NZB content.
|
||||
func Fingerprint(n *nzb.NZB) [32]byte {
|
||||
var ids []string
|
||||
for _, f := range n.Files {
|
||||
for _, s := range f.Segments {
|
||||
ids = append(ids, s.MessageID)
|
||||
}
|
||||
}
|
||||
sort.Strings(ids)
|
||||
|
||||
h := sha256.New()
|
||||
for _, id := range ids {
|
||||
h.Write([]byte(id))
|
||||
h.Write([]byte{'\n'})
|
||||
}
|
||||
|
||||
var fp [32]byte
|
||||
copy(fp[:], h.Sum(nil))
|
||||
return fp
|
||||
}
|
||||
|
||||
// NewProgressTracker creates a tracker for the given NZB.
|
||||
// The dir parameter is the base directory for resume files (e.g. DataDir()/resume).
|
||||
func NewProgressTracker(taskID string, n *nzb.NZB, dir string) *ProgressTracker {
|
||||
files := make([]fileProgress, len(n.Files))
|
||||
for i, f := range n.Files {
|
||||
segCount := len(f.Segments)
|
||||
files[i] = fileProgress{
|
||||
segCount: segCount,
|
||||
completed: make([]byte, (segCount+7)/8),
|
||||
}
|
||||
}
|
||||
|
||||
return &ProgressTracker{
|
||||
taskID: taskID,
|
||||
fingerprint: Fingerprint(n),
|
||||
dir: dir,
|
||||
files: files,
|
||||
lastFlush: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// progressPath returns the path to the binary progress file.
|
||||
func (p *ProgressTracker) progressPath() string {
|
||||
return filepath.Join(p.dir, p.taskID+".progress")
|
||||
}
|
||||
|
||||
// nzbPath returns the path to the cached NZB file.
|
||||
func (p *ProgressTracker) nzbPath() string {
|
||||
return filepath.Join(p.dir, p.taskID+".nzb")
|
||||
}
|
||||
|
||||
// Load reads a progress file from disk and validates its fingerprint.
|
||||
// Returns true if the file was loaded successfully and matches the current NZB.
|
||||
// Returns false if the file doesn't exist, is invalid, or has a different fingerprint.
|
||||
func (p *ProgressTracker) Load() (bool, error) {
|
||||
data, err := os.ReadFile(p.progressPath())
|
||||
if err != nil {
|
||||
return false, nil // file doesn't exist = fresh start
|
||||
}
|
||||
|
||||
if len(data) < headerSize {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Validate magic
|
||||
if data[0] != progressMagic[0] || data[1] != progressMagic[1] ||
|
||||
data[2] != progressMagic[2] || data[3] != progressMagic[3] {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Validate version
|
||||
if data[4] != progressVersion {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Validate file count
|
||||
fileCount := int(binary.LittleEndian.Uint16(data[6:8]))
|
||||
if fileCount != len(p.files) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Validate fingerprint
|
||||
var storedFP [32]byte
|
||||
copy(storedFP[:], data[8:40])
|
||||
if storedFP != p.fingerprint {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Read per-file bitsets
|
||||
offset := headerSize
|
||||
for i := range p.files {
|
||||
if offset+4 > len(data) {
|
||||
return false, nil
|
||||
}
|
||||
segCount := int(binary.LittleEndian.Uint32(data[offset : offset+4]))
|
||||
offset += 4
|
||||
|
||||
if segCount != p.files[i].segCount {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
bitsetLen := (segCount + 7) / 8
|
||||
if offset+bitsetLen > len(data) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
copy(p.files[i].completed, data[offset:offset+bitsetLen])
|
||||
offset += bitsetLen
|
||||
|
||||
// Count completed segments
|
||||
var count int32
|
||||
for seg := 0; seg < segCount; seg++ {
|
||||
if p.files[i].completed[seg/8]&(1<<uint(seg%8)) != 0 {
|
||||
count++
|
||||
}
|
||||
}
|
||||
p.files[i].doneCount.Store(count)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// MarkDone marks a segment as completed. Thread-safe.
|
||||
// Automatically flushes to disk periodically.
|
||||
func (p *ProgressTracker) MarkDone(fileIndex, segIndex int) {
|
||||
if fileIndex < 0 || fileIndex >= len(p.files) {
|
||||
return
|
||||
}
|
||||
fp := &p.files[fileIndex]
|
||||
if segIndex < 0 || segIndex >= fp.segCount {
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
fp.completed[segIndex/8] |= 1 << uint(segIndex%8)
|
||||
fp.doneCount.Add(1)
|
||||
p.dirty = true
|
||||
p.markCount++
|
||||
|
||||
shouldFlush := p.markCount >= flushSegmentFreq || time.Since(p.lastFlush) >= flushInterval
|
||||
p.mu.Unlock()
|
||||
|
||||
if shouldFlush {
|
||||
p.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// IsDone returns whether a specific segment has been completed.
|
||||
func (p *ProgressTracker) IsDone(fileIndex, segIndex int) bool {
|
||||
if fileIndex < 0 || fileIndex >= len(p.files) {
|
||||
return false
|
||||
}
|
||||
fp := &p.files[fileIndex]
|
||||
if segIndex < 0 || segIndex >= fp.segCount {
|
||||
return false
|
||||
}
|
||||
// Read without lock — single-byte read is atomic on aligned data,
|
||||
// and we only ever set bits (never clear), so a stale read just means
|
||||
// we might re-download a segment (harmless, WriteAt is idempotent).
|
||||
return fp.completed[segIndex/8]&(1<<uint(segIndex%8)) != 0
|
||||
}
|
||||
|
||||
// IsFileDone returns true if all segments of a file are completed.
|
||||
func (p *ProgressTracker) IsFileDone(fileIndex int) bool {
|
||||
if fileIndex < 0 || fileIndex >= len(p.files) {
|
||||
return false
|
||||
}
|
||||
fp := &p.files[fileIndex]
|
||||
return int(fp.doneCount.Load()) >= fp.segCount
|
||||
}
|
||||
|
||||
// CompletedSegments returns the number of completed segments for a file.
|
||||
func (p *ProgressTracker) CompletedSegments(fileIndex int) int {
|
||||
if fileIndex < 0 || fileIndex >= len(p.files) {
|
||||
return 0
|
||||
}
|
||||
return int(p.files[fileIndex].doneCount.Load())
|
||||
}
|
||||
|
||||
// CompletedBytes returns the total bytes of completed segments for a file.
|
||||
func (p *ProgressTracker) CompletedBytes(fileIndex int, segments []nzb.Segment) int64 {
|
||||
if fileIndex < 0 || fileIndex >= len(p.files) {
|
||||
return 0
|
||||
}
|
||||
var total int64
|
||||
for i, seg := range segments {
|
||||
if p.IsDone(fileIndex, i) {
|
||||
total += seg.Bytes
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// TotalCompleted returns total completed segments across all files.
|
||||
func (p *ProgressTracker) TotalCompleted() int {
|
||||
var total int
|
||||
for i := range p.files {
|
||||
total += int(p.files[i].doneCount.Load())
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// Flush writes the current progress state to disk atomically (tmp + rename).
|
||||
func (p *ProgressTracker) Flush() error {
|
||||
p.mu.Lock()
|
||||
if !p.dirty {
|
||||
p.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calculate total size
|
||||
size := headerSize
|
||||
for i := range p.files {
|
||||
size += 4 + (p.files[i].segCount+7)/8
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
|
||||
// Header
|
||||
copy(buf[0:4], progressMagic[:])
|
||||
buf[4] = progressVersion
|
||||
buf[5] = 0 // reserved
|
||||
binary.LittleEndian.PutUint16(buf[6:8], uint16(len(p.files)))
|
||||
copy(buf[8:40], p.fingerprint[:])
|
||||
|
||||
// Per-file bitsets
|
||||
offset := headerSize
|
||||
for i := range p.files {
|
||||
fp := &p.files[i]
|
||||
binary.LittleEndian.PutUint32(buf[offset:offset+4], uint32(fp.segCount))
|
||||
offset += 4
|
||||
bitsetLen := (fp.segCount + 7) / 8
|
||||
copy(buf[offset:offset+bitsetLen], fp.completed[:bitsetLen])
|
||||
offset += bitsetLen
|
||||
}
|
||||
|
||||
p.dirty = false
|
||||
p.markCount = 0
|
||||
p.lastFlush = time.Now()
|
||||
p.mu.Unlock()
|
||||
|
||||
// Atomic write: tmp file + rename
|
||||
if err := os.MkdirAll(p.dir, 0o755); err != nil {
|
||||
return fmt.Errorf("create resume dir: %w", err)
|
||||
}
|
||||
|
||||
tmpPath := p.progressPath() + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, buf, 0o644); err != nil {
|
||||
return fmt.Errorf("write progress tmp: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, p.progressPath()); err != nil {
|
||||
os.Remove(tmpPath)
|
||||
return fmt.Errorf("rename progress: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes both the progress file and cached NZB file.
|
||||
func (p *ProgressTracker) Remove() error {
|
||||
os.Remove(p.progressPath())
|
||||
os.Remove(p.nzbPath())
|
||||
// Also remove tmp file if it exists
|
||||
os.Remove(p.progressPath() + ".tmp")
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanStaleFiles removes resume files older than maxAge from the given directory.
|
||||
func CleanStaleFiles(dir string, maxAge time.Duration) int {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
removed := 0
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
info, err := e.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if time.Since(info.ModTime()) > maxAge {
|
||||
if err := os.Remove(filepath.Join(dir, e.Name())); err == nil {
|
||||
removed++
|
||||
}
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
398
internal/usenet/download/progress_test.go
Normal file
398
internal/usenet/download/progress_test.go
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
package download
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/torrentclaw-cli/internal/usenet/nzb"
|
||||
)
|
||||
|
||||
var fixedPast = time.Now().Add(-30 * 24 * time.Hour)
|
||||
|
||||
func makeTestNZB(fileCount, segsPerFile int) *nzb.NZB {
|
||||
n := &nzb.NZB{
|
||||
Files: make([]nzb.File, fileCount),
|
||||
}
|
||||
for i := 0; i < fileCount; i++ {
|
||||
segs := make([]nzb.Segment, segsPerFile)
|
||||
for j := 0; j < segsPerFile; j++ {
|
||||
segs[j] = nzb.Segment{
|
||||
Bytes: 750 * 1024,
|
||||
Number: j + 1,
|
||||
MessageID: segMsgID(i, j),
|
||||
}
|
||||
}
|
||||
n.Files[i] = nzb.File{
|
||||
Subject: `"testfile_` + string(rune('a'+i)) + `.rar" yEnc (1/` + string(rune('0'+segsPerFile)) + `)`,
|
||||
Segments: segs,
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func segMsgID(file, seg int) string {
|
||||
return "part" + itoa(seg) + ".file" + itoa(file) + "@example.com"
|
||||
}
|
||||
|
||||
func itoa(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
s := ""
|
||||
for n > 0 {
|
||||
s = string(rune('0'+n%10)) + s
|
||||
n /= 10
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestFingerprint_Deterministic(t *testing.T) {
|
||||
n := makeTestNZB(3, 10)
|
||||
fp1 := Fingerprint(n)
|
||||
fp2 := Fingerprint(n)
|
||||
if fp1 != fp2 {
|
||||
t.Fatal("fingerprint should be deterministic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprint_DifferentNZB(t *testing.T) {
|
||||
n1 := makeTestNZB(3, 10)
|
||||
n2 := makeTestNZB(3, 11)
|
||||
if Fingerprint(n1) == Fingerprint(n2) {
|
||||
t.Fatal("different NZBs should have different fingerprints")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_NewAndFlush(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(2, 5)
|
||||
tracker := NewProgressTracker("test-task-1", n, dir)
|
||||
|
||||
// Mark some segments
|
||||
tracker.MarkDone(0, 0)
|
||||
tracker.MarkDone(0, 2)
|
||||
tracker.MarkDone(1, 4)
|
||||
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
path := filepath.Join(dir, "test-task-1.progress")
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatalf("progress file should exist: %v", err)
|
||||
}
|
||||
|
||||
// Verify state
|
||||
if !tracker.IsDone(0, 0) {
|
||||
t.Error("segment 0,0 should be done")
|
||||
}
|
||||
if tracker.IsDone(0, 1) {
|
||||
t.Error("segment 0,1 should NOT be done")
|
||||
}
|
||||
if !tracker.IsDone(0, 2) {
|
||||
t.Error("segment 0,2 should be done")
|
||||
}
|
||||
if !tracker.IsDone(1, 4) {
|
||||
t.Error("segment 1,4 should be done")
|
||||
}
|
||||
if tracker.CompletedSegments(0) != 2 {
|
||||
t.Errorf("file 0: expected 2 completed, got %d", tracker.CompletedSegments(0))
|
||||
}
|
||||
if tracker.CompletedSegments(1) != 1 {
|
||||
t.Errorf("file 1: expected 1 completed, got %d", tracker.CompletedSegments(1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_LoadRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(2, 8)
|
||||
|
||||
// Create and populate
|
||||
tracker1 := NewProgressTracker("test-task-2", n, dir)
|
||||
tracker1.MarkDone(0, 0)
|
||||
tracker1.MarkDone(0, 3)
|
||||
tracker1.MarkDone(0, 7)
|
||||
tracker1.MarkDone(1, 1)
|
||||
tracker1.MarkDone(1, 5)
|
||||
if err := tracker1.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
// Load into new tracker
|
||||
tracker2 := NewProgressTracker("test-task-2", n, dir)
|
||||
loaded, err := tracker2.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if !loaded {
|
||||
t.Fatal("should have loaded successfully")
|
||||
}
|
||||
|
||||
// Verify all bits match
|
||||
for _, tc := range []struct {
|
||||
file, seg int
|
||||
want bool
|
||||
}{
|
||||
{0, 0, true}, {0, 1, false}, {0, 2, false}, {0, 3, true},
|
||||
{0, 4, false}, {0, 5, false}, {0, 6, false}, {0, 7, true},
|
||||
{1, 0, false}, {1, 1, true}, {1, 2, false}, {1, 3, false},
|
||||
{1, 4, false}, {1, 5, true}, {1, 6, false}, {1, 7, false},
|
||||
} {
|
||||
got := tracker2.IsDone(tc.file, tc.seg)
|
||||
if got != tc.want {
|
||||
t.Errorf("file %d seg %d: got %v, want %v", tc.file, tc.seg, got, tc.want)
|
||||
}
|
||||
}
|
||||
|
||||
if tracker2.CompletedSegments(0) != 3 {
|
||||
t.Errorf("file 0: expected 3 completed, got %d", tracker2.CompletedSegments(0))
|
||||
}
|
||||
if tracker2.CompletedSegments(1) != 2 {
|
||||
t.Errorf("file 1: expected 2 completed, got %d", tracker2.CompletedSegments(1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_FingerprintMismatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n1 := makeTestNZB(2, 5)
|
||||
n2 := makeTestNZB(2, 6) // different segment count = different fingerprint
|
||||
|
||||
// Write with n1
|
||||
tracker1 := NewProgressTracker("test-task-3", n1, dir)
|
||||
tracker1.MarkDone(0, 0)
|
||||
if err := tracker1.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
// Try to load with n2
|
||||
tracker2 := NewProgressTracker("test-task-3", n2, dir)
|
||||
loaded, err := tracker2.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Fatal("should NOT load — fingerprint mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_IsFileDone(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 4)
|
||||
tracker := NewProgressTracker("test-task-4", n, dir)
|
||||
|
||||
if tracker.IsFileDone(0) {
|
||||
t.Error("file should not be done yet")
|
||||
}
|
||||
|
||||
tracker.MarkDone(0, 0)
|
||||
tracker.MarkDone(0, 1)
|
||||
tracker.MarkDone(0, 2)
|
||||
if tracker.IsFileDone(0) {
|
||||
t.Error("file should not be done (3/4)")
|
||||
}
|
||||
|
||||
tracker.MarkDone(0, 3)
|
||||
if !tracker.IsFileDone(0) {
|
||||
t.Error("file should be done (4/4)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_ConcurrentMark(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
segCount := 1000
|
||||
n := makeTestNZB(1, segCount)
|
||||
tracker := NewProgressTracker("test-task-5", n, dir)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < segCount; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
tracker.MarkDone(0, idx)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if !tracker.IsFileDone(0) {
|
||||
t.Errorf("all segments should be done, got %d/%d", tracker.CompletedSegments(0), segCount)
|
||||
}
|
||||
|
||||
// Flush and reload
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
tracker2 := NewProgressTracker("test-task-5", n, dir)
|
||||
loaded, _ := tracker2.Load()
|
||||
if !loaded {
|
||||
t.Fatal("should load")
|
||||
}
|
||||
if !tracker2.IsFileDone(0) {
|
||||
t.Errorf("after reload: expected all done, got %d/%d", tracker2.CompletedSegments(0), segCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_Remove(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("test-task-6", n, dir)
|
||||
tracker.MarkDone(0, 0)
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
// Write a fake NZB cache file
|
||||
nzbPath := filepath.Join(dir, "test-task-6.nzb")
|
||||
os.WriteFile(nzbPath, []byte("<nzb/>"), 0o644)
|
||||
|
||||
// Both should exist
|
||||
if _, err := os.Stat(tracker.progressPath()); err != nil {
|
||||
t.Fatal("progress file should exist")
|
||||
}
|
||||
if _, err := os.Stat(nzbPath); err != nil {
|
||||
t.Fatal("nzb cache should exist")
|
||||
}
|
||||
|
||||
tracker.Remove()
|
||||
|
||||
if _, err := os.Stat(tracker.progressPath()); !os.IsNotExist(err) {
|
||||
t.Error("progress file should be removed")
|
||||
}
|
||||
if _, err := os.Stat(nzbPath); !os.IsNotExist(err) {
|
||||
t.Error("nzb cache should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_LargeNZB(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
segCount := 30000
|
||||
n := makeTestNZB(1, segCount)
|
||||
tracker := NewProgressTracker("test-task-7", n, dir)
|
||||
|
||||
// Mark every other segment
|
||||
for i := 0; i < segCount; i += 2 {
|
||||
tracker.MarkDone(0, i)
|
||||
}
|
||||
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
// Check file size is compact
|
||||
info, err := os.Stat(tracker.progressPath())
|
||||
if err != nil {
|
||||
t.Fatalf("stat: %v", err)
|
||||
}
|
||||
// Header (40) + file header (4) + bitset (30000/8 = 3750) = 3794 bytes
|
||||
expectedMax := int64(4000)
|
||||
if info.Size() > expectedMax {
|
||||
t.Errorf("progress file too large: %d bytes (expected < %d)", info.Size(), expectedMax)
|
||||
}
|
||||
|
||||
// Reload and verify
|
||||
tracker2 := NewProgressTracker("test-task-7", n, dir)
|
||||
loaded, _ := tracker2.Load()
|
||||
if !loaded {
|
||||
t.Fatal("should load")
|
||||
}
|
||||
if tracker2.CompletedSegments(0) != segCount/2 {
|
||||
t.Errorf("expected %d completed, got %d", segCount/2, tracker2.CompletedSegments(0))
|
||||
}
|
||||
// Spot check
|
||||
if !tracker2.IsDone(0, 0) {
|
||||
t.Error("seg 0 should be done")
|
||||
}
|
||||
if tracker2.IsDone(0, 1) {
|
||||
t.Error("seg 1 should NOT be done")
|
||||
}
|
||||
if !tracker2.IsDone(0, 100) {
|
||||
t.Error("seg 100 should be done")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_CompletedBytes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 4)
|
||||
tracker := NewProgressTracker("test-task-8", n, dir)
|
||||
|
||||
tracker.MarkDone(0, 0)
|
||||
tracker.MarkDone(0, 2)
|
||||
|
||||
bytes := tracker.CompletedBytes(0, n.Files[0].Segments)
|
||||
expected := int64(2 * 750 * 1024) // 2 segments * 750KB
|
||||
if bytes != expected {
|
||||
t.Errorf("expected %d bytes, got %d", expected, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_BoundsCheck(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("test-task-9", n, dir)
|
||||
|
||||
// Out-of-bounds should not panic
|
||||
tracker.MarkDone(-1, 0)
|
||||
tracker.MarkDone(0, -1)
|
||||
tracker.MarkDone(5, 0)
|
||||
tracker.MarkDone(0, 100)
|
||||
|
||||
if tracker.IsDone(-1, 0) {
|
||||
t.Error("out of bounds should return false")
|
||||
}
|
||||
if tracker.IsDone(5, 0) {
|
||||
t.Error("out of bounds should return false")
|
||||
}
|
||||
if tracker.IsFileDone(-1) {
|
||||
t.Error("out of bounds should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanStaleFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a "stale" file
|
||||
stalePath := filepath.Join(dir, "old-task.progress")
|
||||
os.WriteFile(stalePath, []byte("data"), 0o644)
|
||||
// Backdate modification time
|
||||
staleTime := os.Chtimes(stalePath, fixedPast, fixedPast)
|
||||
if staleTime != nil {
|
||||
t.Fatalf("chtimes: %v", staleTime)
|
||||
}
|
||||
|
||||
// Create a "fresh" file
|
||||
freshPath := filepath.Join(dir, "new-task.progress")
|
||||
os.WriteFile(freshPath, []byte("data"), 0o644)
|
||||
|
||||
removed := CleanStaleFiles(dir, 14*24*time.Hour) // 2 weeks — stale file is 30 days old
|
||||
if removed != 1 {
|
||||
t.Errorf("expected 1 removed, got %d", removed)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(stalePath); !os.IsNotExist(err) {
|
||||
t.Error("stale file should be removed")
|
||||
}
|
||||
if _, err := os.Stat(freshPath); err != nil {
|
||||
t.Error("fresh file should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker_FlushNoOp(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
n := makeTestNZB(1, 3)
|
||||
tracker := NewProgressTracker("test-task-10", n, dir)
|
||||
|
||||
// Flush without any marks should be no-op
|
||||
if err := tracker.Flush(); err != nil {
|
||||
t.Fatalf("flush: %v", err)
|
||||
}
|
||||
|
||||
// File should not be created
|
||||
if _, err := os.Stat(tracker.progressPath()); !os.IsNotExist(err) {
|
||||
t.Error("no file should be created for empty flush")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue