diff --git a/Cargo.lock b/Cargo.lock index 184eeb51fd93..03eb317111b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,7 +52,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "const-random", "getrandom 0.2.15", "once_cell", "serde", @@ -220,220 +219,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "arrow" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" -dependencies = [ - "arrow-arith", - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-csv", - "arrow-data", - "arrow-ipc", - "arrow-json", - "arrow-ord", - "arrow-row", - "arrow-schema", - "arrow-select", - "arrow-string", -] - -[[package]] -name = "arrow-arith" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "num", -] - -[[package]] -name = "arrow-array" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" -dependencies = [ - "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "hashbrown 0.14.5", - "num", -] - -[[package]] -name = "arrow-buffer" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" -dependencies = [ - "bytes", - "half", - "num", -] - -[[package]] -name = "arrow-cast" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", - "atoi", - "base64 0.22.1", - "chrono", - "half", - "lexical-core", - "num", - "ryu", -] - -[[package]] -name = "arrow-csv" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c13c36dc5ddf8c128df19bab27898eea64bf9da2b555ec1cd17a8ff57fba9ec2" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-schema", - "chrono", - "csv", - "csv-core", - "lazy_static", - "lexical-core", - "regex", -] - -[[package]] -name = "arrow-data" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" -dependencies = [ - "arrow-buffer", - "arrow-schema", - "half", - "num", -] - -[[package]] -name = "arrow-ipc" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e786e1cdd952205d9a8afc69397b317cfbb6e0095e445c69cda7e8da5c1eeb0f" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-schema", - "flatbuffers", -] - -[[package]] -name = "arrow-json" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb22284c5a2a01d73cebfd88a33511a3234ab45d66086b2ca2d1228c3498e445" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-cast", - 
"arrow-data", - "arrow-schema", - "chrono", - "half", - "indexmap 2.7.1", - "lexical-core", - "num", - "serde", - "serde_json", -] - -[[package]] -name = "arrow-ord" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", - "half", - "num", -] - -[[package]] -name = "arrow-row" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" -dependencies = [ - "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "half", -] - -[[package]] -name = "arrow-schema" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" - -[[package]] -name = "arrow-select" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" -dependencies = [ - "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "num", -] - -[[package]] -name = "arrow-string" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", - "memchr", - "num", - "regex", - "regex-syntax 0.8.5", -] - [[package]] name = "assert-json-diff" version = "2.0.2" @@ -450,7 +235,7 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "435a87a52755b8f27fcf321ac4f04b2802e337c8c4872923137471ec39c37532" dependencies = [ - "event-listener", + "event-listener 5.4.1", "event-listener-strategy", "futures-core", "pin-project-lite", @@ -1148,6 +933,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + [[package]] name = "bat" version = "0.24.0" @@ -1578,7 +1369,7 @@ version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.99", @@ -1716,6 +1507,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-random" version = "0.1.18" @@ -1848,6 +1645,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc32fast" version = "1.4.2" @@ -1930,6 +1742,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1951,27 +1772,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "csv" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" -dependencies = [ - "csv-core", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "csv-core" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" -dependencies = [ - "memchr", -] - [[package]] name = "ctor" version = "0.2.9" @@ -2115,6 +1915,17 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -2153,6 +1964,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -2257,6 +2069,9 @@ name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "email_address" @@ -2344,6 +2159,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "event-listener" +version = "2.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" + [[package]] name = "event-listener" version = "5.4.1" @@ -2361,7 +2182,7 @@ version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" dependencies = [ - "event-listener", + "event-listener 5.4.1", "pin-project-lite", ] @@ -2456,16 +2277,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "flatbuffers" -version = "24.12.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f1baf0dbf96932ec9a3038d57900329c015b0bfb7b63d904f3bc27e2b02a096" -dependencies = [ - "bitflags 1.3.2", - "rustc_version", -] - [[package]] name = "flate2" version = "1.1.0" @@ -2487,6 +2298,17 @@ dependencies = [ "serde", ] +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -2624,6 +2446,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -2780,7 +2613,6 @@ 
version = "1.9.0" dependencies = [ "ahash", "anyhow", - "arrow", "async-stream", "async-trait", "aws-config", @@ -2826,6 +2658,7 @@ dependencies = [ "serde_yaml", "serial_test", "sha2", + "sqlx", "temp-env", "tempfile", "test-case", @@ -3097,7 +2930,6 @@ checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" dependencies = [ "cfg-if", "crunchy", - "num-traits", ] [[package]] @@ -3136,6 +2968,15 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "heck" version = "0.5.0" @@ -3160,6 +3001,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -3344,7 +3194,7 @@ dependencies = [ "tokio", "tokio-rustls 0.26.2", "tower-service", - "webpki-roots", + "webpki-roots 0.26.8", ] [[package]] @@ -3921,6 +3771,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "lazycell" @@ -3934,70 +3787,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" -[[package]] -name = "lexical-core" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" -dependencies = [ - "lexical-parse-float", - "lexical-parse-integer", - "lexical-util", - "lexical-write-float", - "lexical-write-integer", -] - -[[package]] -name = "lexical-parse-float" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" -dependencies = [ - "lexical-parse-integer", - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-parse-integer" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" -dependencies = [ - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-util" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" -dependencies = [ - "static_assertions", -] - -[[package]] -name = "lexical-write-float" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" -dependencies = [ - "lexical-util", - "lexical-write-integer", - "static_assertions", -] - -[[package]] -name = "lexical-write-integer" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" -dependencies = [ - "lexical-util", - "static_assertions", -] - [[package]] name = 
"libc" version = "0.2.172" @@ -4063,6 +3852,17 @@ dependencies = [ "redox_syscall", ] +[[package]] +name = "libsqlite3-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "libz-sys" version = "1.1.21" @@ -4526,6 +4326,23 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-cmp" version = "0.1.0" @@ -4935,6 +4752,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -5018,6 +4844,27 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.32" @@ -5700,7 +5547,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 0.26.8", "windows-registry", ] @@ -5780,6 +5627,26 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "rsa" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78928ac1ed176a5ca1d17e578a1825f3d81ca54cf41053a592584b020cfd691b" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust-ini" version = "0.20.0" @@ -6311,6 +6178,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -6385,6 +6262,236 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "sqlformat" +version = "0.2.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" +dependencies = [ + "nom", + "unicode_categories", +] + +[[package]] +name = "sqlx" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6" +dependencies = [ + "ahash", + "atoi", + "byteorder", + "bytes", + "chrono", + "crc", + "crossbeam-queue", + "either", + "event-listener 2.5.3", + "futures-channel", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashlink", + "hex", + "indexmap 2.7.1", + "log", + "memchr", + "once_cell", + "paste", + "percent-encoding", + "rustls 0.21.12", + "rustls-pemfile 1.0.4", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlformat", + "thiserror 1.0.69", + "tokio", + "tokio-stream", + "tracing", + "url", + "webpki-roots 0.25.4", +] + +[[package]] +name = "sqlx-macros" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 1.0.109", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" +dependencies = [ + "dotenvy", + "either", + "heck 0.4.1", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn 1.0.109", + "tempfile", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418" +dependencies = [ + "atoi", + "base64 0.21.7", + "bitflags 2.9.0", + "byteorder", + "bytes", + "chrono", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand 0.8.5", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 1.0.69", + "tracing", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e" +dependencies = [ + "atoi", + "base64 0.21.7", + "bitflags 2.9.0", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand 0.8.5", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 1.0.69", + "tracing", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" +dependencies = [ + "atoi", + "chrono", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "tracing", + "url", + "urlencoding", +] + [[package]] name = "sse-stream" version = "0.2.1" @@ -6404,12 +6511,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "std_prelude" version = "0.2.12" @@ -6422,6 +6523,17 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.11.1" @@ -6441,6 +6553,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", + "quote", "unicode-ident", ] @@ -6559,7 +6672,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck", + "heck 0.5.0", "pkg-config", "toml", "version-compare", @@ -7333,6 +7446,12 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -7354,6 +7473,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-properties" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -7372,6 +7497,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "unsafe-libyaml" version = "0.2.11" @@ -7546,6 +7677,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -7684,6 +7821,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki-roots" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" + [[package]] name = "webpki-roots" version = "0.26.8" @@ -7723,6 +7866,16 @@ dependencies = [ "winsafe", ] +[[package]] +name = "whoami" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +dependencies = [ + "libredox", + "wasite", +] + [[package]] name = "wild" version = "2.2.1" diff --git a/crates/goose-bench/src/bench_session.rs b/crates/goose-bench/src/bench_session.rs index 30dc1dd7cea9..972ad99f832e 100644 --- a/crates/goose-bench/src/bench_session.rs +++ b/crates/goose-bench/src/bench_session.rs @@ -3,7 +3,6 @@ use chrono::{DateTime, Utc}; use goose::conversation::Conversation; use serde::{Deserialize, Serialize}; -use std::path::PathBuf; use std::sync::Arc; use tokio::sync::Mutex; @@ -18,9 +17,9 @@ pub struct BenchAgentError { #[async_trait] pub trait BenchBaseSession: Send + Sync { async fn headless(&mut self, message: String) -> anyhow::Result<()>; - fn session_file(&self) -> Option; fn message_history(&self) -> Conversation; fn get_total_token_usage(&self) -> anyhow::Result>; + fn get_session_id(&self) -> anyhow::Result; } // struct for managing agent-session-access. to be passed to evals for benchmarking pub struct BenchAgent { @@ -52,7 +51,8 @@ impl BenchAgent { pub(crate) async fn get_token_usage(&self) -> Option { self.session.get_total_token_usage().ok().flatten() } - pub(crate) fn session_file(&self) -> Option { - self.session.session_file() + + pub(crate) fn get_session_id(&self) -> anyhow::Result { + self.session.get_session_id() } } diff --git a/crates/goose-bench/src/runners/eval_runner.rs b/crates/goose-bench/src/runners/eval_runner.rs index ec1764a3e7f0..88eed45e63f8 100644 --- a/crates/goose-bench/src/runners/eval_runner.rs +++ b/crates/goose-bench/src/runners/eval_runner.rs @@ -5,6 +5,7 @@ use crate::eval_suites::{EvaluationSuite, ExtensionRequirements}; use crate::reporting::EvaluationResult; use crate::utilities::await_process_exits; use anyhow::{bail, Context, Result}; +use goose::session::SessionManager; use std::env; use std::fs; use std::future::Future; @@ -155,15 +156,14 @@ impl EvalRunner { .canonicalize() .context("Failed to canonicalize current directory path")?; - BenchmarkWorkDir::deep_copy( - agent - .session_file() - .expect("Failed to get session file") - .as_path(), - here.as_path(), - false, - ) - .context("Failed to copy session file to evaluation directory")?; + let session_id = agent.get_session_id()?.to_string(); + let session = SessionManager::get_session(&session_id, true).await?; + + let session_json = serde_json::to_string_pretty(&session) + .context("Failed to serialize session to JSON")?; + + fs::write(here.join("session.json"), session_json) + .context("Failed to write session JSON to evaluation directory")?; tracing::info!("Evaluation completed successfully"); } else { diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index 1b09fd111407..a760a8b98fd1 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -18,8 +18,8 @@ use crate::commands::schedule::{ use crate::commands::session::{handle_session_list, handle_session_remove}; use crate::recipes::extract_from_cli::extract_recipe_info_from_cli; use crate::recipes::recipe::{explain_recipe, render_recipe_as_yaml}; -use crate::session; use crate::session::{build_session, SessionBuilderConfig, SessionSettings}; +use goose::session::SessionManager; use 
goose_bench::bench_config::BenchRunConfig; use goose_bench::runners::bench_runner::BenchRunner; use goose_bench::runners::eval_runner::EvalRunner; @@ -48,26 +48,45 @@ struct Identifier { )] name: Option, + #[arg( + long = "session-id", + value_name = "SESSION_ID", + help = "Session ID (e.g., '20250921_143022')", + long_help = "Specify a session ID directly. When used with --resume, will resume this specific session if it exists." + )] + session_id: Option, + #[arg( short, long, value_name = "PATH", - help = "Path for the chat session (e.g., './playground.jsonl')", - long_help = "Specify a path for your chat session. When used with --resume, will resume this specific session if it exists." + help = "Legacy: Path for the chat session", + long_help = "Legacy parameter for backward compatibility. Extracts session ID from the file path (e.g., '/path/to/20250325_200615. +jsonl' -> '20250325_200615')." )] path: Option, } -fn extract_identifier(identifier: Identifier) -> session::Identifier { - if let Some(name) = identifier.name { - session::Identifier::Name(name) +async fn get_session_id(identifier: Identifier) -> Result { + if let Some(session_id) = identifier.session_id { + Ok(session_id) + } else if let Some(name) = identifier.name { + let sessions = SessionManager::list_sessions().await?; + + sessions + .into_iter() + .find(|s| s.description == name) + .map(|s| s.id) + .ok_or_else(|| anyhow::anyhow!("No session found with name '{}'", name)) } else if let Some(path) = identifier.path { - session::Identifier::Path(path) + path.file_stem() + .and_then(|s| s.to_str()) + .map(|s| s.to_string()) + .ok_or_else(|| anyhow::anyhow!("Could not extract session ID from path: {:?}", path)) } else { unreachable!() } } - fn parse_key_val(s: &str) -> Result<(String, String), String> { match s.split_once('=') { Some((key, value)) => Ok((key.to_string(), value.to_string())), @@ -121,6 +140,14 @@ enum SessionCommand { long_help = "Path to save the exported Markdown. If not provided, output will be sent to stdout" )] output: Option, + + #[arg( + long = "format", + value_name = "FORMAT", + help = "Output format (markdown, json, yaml)", + default_value = "markdown" + )] + format: String, }, } @@ -768,19 +795,24 @@ pub async fn cli() -> Result<()> { format, ascending, }) => { - handle_session_list(verbose, format, ascending)?; + handle_session_list(verbose, format, ascending).await?; Ok(()) } Some(SessionCommand::Remove { id, regex }) => { - handle_session_remove(id, regex)?; + handle_session_remove(id, regex).await?; return Ok(()); } - Some(SessionCommand::Export { identifier, output }) => { + Some(SessionCommand::Export { + identifier, + output, + format, + }) => { let session_identifier = if let Some(id) = identifier { - extract_identifier(id) + get_session_id(id).await? } else { // If no identifier is provided, prompt for interactive selection - match crate::commands::session::prompt_interactive_session_selection() { + match crate::commands::session::prompt_interactive_session_selection().await + { Ok(id) => id, Err(e) => { eprintln!("Error: {}", e); @@ -789,7 +821,12 @@ pub async fn cli() -> Result<()> { } }; - crate::commands::session::handle_session_export(session_identifier, output)?; + crate::commands::session::handle_session_export( + session_identifier, + output, + format, + ) + .await?; Ok(()) } None => { @@ -803,9 +840,15 @@ pub async fn cli() -> Result<()> { "Session started" ); + let session_id = if let Some(id) = identifier { + Some(get_session_id(id).await?) 
+            } else {
+                None
+            };
+
             // Run session command by default
-            let mut session: crate::Session = build_session(SessionBuilderConfig {
-                identifier: identifier.map(extract_identifier),
+            let mut session: crate::CliSession = build_session(SessionBuilderConfig {
+                session_id,
                 resume,
                 no_session: false,
                 extensions,
@@ -841,6 +884,7 @@
 
             let (total_tokens, message_count) = session
                 .get_metadata()
+                .await
                 .map(|m| (m.total_tokens.unwrap_or(0), m.message_count))
                 .unwrap_or((0, 0));
 
@@ -994,9 +1038,14 @@
                     std::process::exit(1);
                 }
             };
+            let session_id = if let Some(id) = identifier {
+                Some(get_session_id(id).await?)
+            } else {
+                None
+            };
 
             let mut session = build_session(SessionBuilderConfig {
-                identifier: identifier.map(extract_identifier),
+                session_id,
                 resume,
                 no_session,
                 extensions,
@@ -1048,6 +1097,7 @@
 
             let (total_tokens, message_count) = session
                 .get_metadata()
+                .await
                 .map(|m| (m.total_tokens.unwrap_or(0), m.message_count))
                 .unwrap_or((0, 0));
 
@@ -1172,7 +1222,7 @@
         } else {
             // Run session command by default
             let mut session = build_session(SessionBuilderConfig {
-                identifier: None,
+                session_id: None,
                 resume: false,
                 no_session: false,
                 extensions: Vec::new(),
@@ -1188,7 +1238,7 @@
                 max_tool_repetitions: None,
                 max_turns: None,
                 scheduled_job_id: None,
-                interactive: true, // Default case is always interactive
+                interactive: true,
                 quiet: false,
                 sub_recipes: None,
                 final_output_response: None,
diff --git a/crates/goose-cli/src/commands/bench.rs b/crates/goose-cli/src/commands/bench.rs
index 982c74aafa93..d67b91913151 100644
--- a/crates/goose-cli/src/commands/bench.rs
+++ b/crates/goose-cli/src/commands/bench.rs
@@ -1,38 +1,42 @@
 use crate::session::build_session;
 use crate::session::SessionBuilderConfig;
-use crate::{logging, session, Session};
+use crate::{logging, CliSession};
 use async_trait::async_trait;
 use goose::conversation::Conversation;
 use goose_bench::bench_session::{BenchAgent, BenchBaseSession};
 use goose_bench::eval_suites::ExtensionRequirements;
-use std::path::PathBuf;
 use std::sync::Arc;
 use tokio::sync::Mutex;
 
 // allow session obj to be used in benchmarking
 #[async_trait]
-impl BenchBaseSession for Session {
+impl BenchBaseSession for CliSession {
     async fn headless(&mut self, message: String) -> anyhow::Result<()> {
         self.headless(message).await
     }
-    fn session_file(&self) -> Option<PathBuf> {
-        self.session_file()
-    }
     fn message_history(&self) -> Conversation {
         self.message_history()
     }
     fn get_total_token_usage(&self) -> anyhow::Result<Option<i32>> {
-        self.get_total_token_usage()
+        // Since the trait requires sync but the session method is async,
+        // we need to block on the async call
+        tokio::task::block_in_place(|| {
+            tokio::runtime::Handle::current().block_on(self.get_total_token_usage())
+        })
+    }
+
+    fn get_session_id(&self) -> anyhow::Result<String> {
+        self.session_id()
+            .cloned()
+            .ok_or_else(|| anyhow::anyhow!("No session ID available"))
     }
 }
 
 pub async fn agent_generator(
     requirements: ExtensionRequirements,
     session_id: String,
 ) -> BenchAgent {
-    let identifier = Some(session::Identifier::Name(session_id));
-
     let base_session = build_session(SessionBuilderConfig {
-        identifier,
+        session_id: Some(session_id),
         resume: false,
         no_session: false,
         extensions: requirements.external,
@@ -56,10 +60,8 @@
     })
     .await;
 
-    // package session obj into benchmark-compatible struct
     let bench_agent =
BenchAgent::new(Box::new(base_session)); - // Initialize logging with error capture let errors = Some(Arc::new(Mutex::new(bench_agent.get_errors().await))); logging::setup_logging(Some("bench"), errors).expect("Failed to initialize logging"); diff --git a/crates/goose-cli/src/commands/schedule.rs b/crates/goose-cli/src/commands/schedule.rs index 4d94c6de8e80..c0be8a21166c 100644 --- a/crates/goose-cli/src/commands/schedule.rs +++ b/crates/goose-cli/src/commands/schedule.rs @@ -219,11 +219,10 @@ pub async fn handle_schedule_sessions(id: String, limit: Option) -> Result< // sessions is now Vec<(String, SessionMetadata)> for (session_name, metadata) in sessions { println!( - " - Session ID: {}, Working Dir: {}, Description: \"{}\", Messages: {}, Schedule ID: {:?}", + " - Session ID: {}, Working Dir: {}, Description: \"{}\", Schedule ID: {:?}", session_name, // Display the session_name as Session ID metadata.working_dir.display(), metadata.description, - metadata.message_count, metadata.schedule_id.as_deref().unwrap_or("N/A") ); } diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index c37e57994b51..fa8e13f7c40d 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -1,19 +1,19 @@ use crate::session::message_to_markdown; use anyhow::{Context, Result}; + use cliclack::{confirm, multiselect, select}; -use goose::session::info::{get_valid_sorted_sessions, SessionInfo, SortOrder}; -use goose::session::{self, Identifier}; +use goose::session::{Session, SessionManager}; use goose::utils::safe_truncate; use regex::Regex; use std::fs; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; const TRUNCATED_DESC_LENGTH: usize = 60; -pub fn remove_sessions(sessions: Vec) -> Result<()> { +pub async fn remove_sessions(sessions: Vec) -> Result<()> { println!("The following sessions will be removed:"); for session in &sessions { - println!("- {}", session.id); + println!("- {} {}", session.id, session.description); } let should_delete = confirm("Are you sure you want to delete these sessions?") @@ -22,8 +22,7 @@ pub fn remove_sessions(sessions: Vec) -> Result<()> { if should_delete { for session in sessions { - fs::remove_file(session.path.clone()) - .with_context(|| format!("Failed to remove session file '{}'", session.path))?; + SessionManager::delete_session(&session.id).await?; println!("Session `{}` removed.", session.id); } } else { @@ -33,7 +32,7 @@ pub fn remove_sessions(sessions: Vec) -> Result<()> { Ok(()) } -fn prompt_interactive_session_removal(sessions: &[SessionInfo]) -> Result> { +fn prompt_interactive_session_removal(sessions: &[Session]) -> Result> { if sessions.is_empty() { println!("No sessions to delete."); return Ok(vec![]); @@ -43,16 +42,16 @@ fn prompt_interactive_session_removal(sessions: &[SessionInfo]) -> Result = sessions + let display_map: std::collections::HashMap = sessions .iter() .map(|s| { - let desc = if s.metadata.description.is_empty() { + let desc = if s.description.is_empty() { "(no description)" } else { - &s.metadata.description + &s.description }; let truncated_desc = safe_truncate(desc, TRUNCATED_DESC_LENGTH); - let display_text = format!("{} - {} ({})", s.modified, truncated_desc, s.id); + let display_text = format!("{} - {} ({})", s.updated_at, truncated_desc, s.id); (display_text, s.clone()) }) .collect(); @@ -63,7 +62,7 @@ fn prompt_interactive_session_removal(sessions: &[SessionInfo]) -> Result = selector.interact()?; - let selected_sessions: Vec = 
selected_display_texts + let selected_sessions: Vec = selected_display_texts .into_iter() .filter_map(|text| display_map.get(&text).cloned()) .collect(); @@ -71,8 +70,8 @@ fn prompt_interactive_session_removal(sessions: &[SessionInfo]) -> Result, regex_string: Option) -> Result<()> { - let all_sessions = match get_valid_sorted_sessions(SortOrder::Descending) { +pub async fn handle_session_remove(id: Option, regex_string: Option) -> Result<()> { + let all_sessions = match SessionManager::list_sessions().await { Ok(sessions) => sessions, Err(e) => { tracing::error!("Failed to retrieve sessions: {:?}", e); @@ -80,7 +79,7 @@ pub fn handle_session_remove(id: Option, regex_string: Option) - } }; - let matched_sessions: Vec; + let matched_sessions: Vec; if let Some(id_val) = id { if let Some(session) = all_sessions.iter().find(|s| s.id == id_val) { @@ -112,23 +111,16 @@ pub fn handle_session_remove(id: Option, regex_string: Option) - return Ok(()); } - remove_sessions(matched_sessions) + remove_sessions(matched_sessions).await } -pub fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Result<()> { - let sort_order = if ascending { - SortOrder::Ascending +pub async fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Result<()> { + let mut sessions = SessionManager::list_sessions().await?; + if ascending { + sessions.sort_by(|a, b| a.updated_at.cmp(&b.updated_at)); } else { - SortOrder::Descending - }; - - let sessions = match get_valid_sorted_sessions(sort_order) { - Ok(sessions) => sessions, - Err(e) => { - tracing::error!("Failed to list sessions: {:?}", e); - return Err(anyhow::anyhow!("Failed to list sessions")); - } - }; + sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + } match format.as_str() { "json" => { @@ -138,27 +130,18 @@ pub fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Re if sessions.is_empty() { println!("No sessions found"); return Ok(()); - } else { - println!("Available sessions:"); - for SessionInfo { - id, - path, - metadata, - modified, - } in sessions - { - let description = if metadata.description.is_empty() { - "(none)" - } else { - &metadata.description - }; - let output = format!("{} - {} - {}", id, description, modified); - if verbose { - println!(" {}", output); - println!(" Path: {}", path); - } else { - println!("{}", output); - } + } + + println!("Available sessions:"); + for session in sessions { + let output = format!( + "{} - {} - {}", + session.id, session.description, session.updated_at + ); + if verbose { + println!(" {}", output); + } else { + println!("{}", output); } } } @@ -166,68 +149,55 @@ pub fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Re Ok(()) } -/// Export a session to Markdown without creating a full Session object -/// -/// This function directly reads messages from the session file and converts them to Markdown -/// without creating an Agent or prompting about working directories. 
-pub fn handle_session_export(identifier: Identifier, output_path: Option) -> Result<()> { - // Get the session file path - let session_file_path = match goose::session::get_path(identifier.clone()) { - Ok(path) => path, +pub async fn handle_session_export( + session_id: String, + output_path: Option, + format: String, +) -> Result<()> { + let session = match SessionManager::get_session(&session_id, true).await { + Ok(session) => session, Err(e) => { - return Err(anyhow::anyhow!("Invalid session identifier: {}", e)); + return Err(anyhow::anyhow!( + "Session '{}' not found or failed to read: {}", + session_id, + e + )); } }; - if !session_file_path.exists() { - return Err(anyhow::anyhow!( - "Session file not found (expected path: {})", - session_file_path.display() - )); - } - - // Read messages directly without using Session - let messages = match goose::session::read_messages(&session_file_path) { - Ok(msgs) => msgs, - Err(e) => { - return Err(anyhow::anyhow!("Failed to read session messages: {}", e)); + let output = match format.as_str() { + "json" => serde_json::to_string_pretty(&session)?, + "yaml" => serde_yaml::to_string(&session)?, + "markdown" => { + let conversation = session + .conversation + .ok_or_else(|| anyhow::anyhow!("Session has no messages"))?; + export_session_to_markdown(conversation.messages().to_vec(), &session.description) } + _ => return Err(anyhow::anyhow!("Unsupported format: {}", format)), }; - // Generate the markdown content using the export functionality - let markdown = - export_session_to_markdown(messages.messages().clone(), &session_file_path, None); - - // Output the markdown - if let Some(output) = output_path { - fs::write(&output, markdown) - .with_context(|| format!("Failed to write to output file: {}", output.display()))?; - println!("Session exported to {}", output.display()); + if let Some(output_path) = output_path { + fs::write(&output_path, output).with_context(|| { + format!("Failed to write to output file: {}", output_path.display()) + })?; + println!("Session exported to {}", output_path.display()); } else { - println!("{}", markdown); + println!("{}", output); } Ok(()) } - /// Convert a list of messages to markdown format for session export /// /// This function handles the formatting of a complete session including headers, /// message organization, and proper tool request/response pairing. 
fn export_session_to_markdown( messages: Vec, - session_file: &Path, - session_name_override: Option<&str>, + session_name: &String, ) -> String { let mut markdown_output = String::new(); - let session_name = session_name_override.unwrap_or_else(|| { - session_file - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("Unnamed Session") - }); - markdown_output.push_str(&format!("# Session Export: {}\n\n", session_name)); if messages.is_empty() { @@ -293,15 +263,8 @@ fn export_session_to_markdown( /// Prompt the user to interactively select a session /// /// Shows a list of available sessions and lets the user select one -pub fn prompt_interactive_session_selection() -> Result { - // Get sessions sorted by modification date (newest first) - let sessions = match get_valid_sorted_sessions(SortOrder::Descending) { - Ok(sessions) => sessions, - Err(e) => { - tracing::error!("Failed to list sessions: {:?}", e); - return Err(anyhow::anyhow!("Failed to list sessions")); - } - }; +pub async fn prompt_interactive_session_selection() -> Result { + let sessions = SessionManager::list_sessions().await?; if sessions.is_empty() { return Err(anyhow::anyhow!("No sessions found")); @@ -311,19 +274,17 @@ pub fn prompt_interactive_session_selection() -> Result { let mut selector = select("Select a session to export:"); // Map to display text - let display_map: std::collections::HashMap = sessions + let display_map: std::collections::HashMap = sessions .iter() .map(|s| { - let desc = if s.metadata.description.is_empty() { + let desc = if s.description.is_empty() { "(no description)" } else { - &s.metadata.description + &s.description }; + let truncated_desc = safe_truncate(desc, TRUNCATED_DESC_LENGTH); - // Truncate description if too long - let truncated_desc = safe_truncate(desc, 40); - - let display_text = format!("{} - {} ({})", s.modified, truncated_desc, s.id); + let display_text = format!("{} - {} ({})", s.updated_at, truncated_desc, s.id); (display_text, s.clone()) }) .collect(); @@ -346,7 +307,7 @@ pub fn prompt_interactive_session_selection() -> Result { // Retrieve the selected session if let Some(session) = display_map.get(&selected_display_text) { - Ok(goose::session::Identifier::Name(session.id.clone())) + Ok(session.id.clone()) } else { Err(anyhow::anyhow!("Invalid selection")) } diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index 6f175d57e3ae..4543065d38b0 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -8,24 +8,25 @@ use axum::{ routing::get, Json, Router, }; +use goose::session::SessionManager; +use webbrowser; + use futures::{sink::SinkExt, stream::StreamExt}; use goose::agents::{Agent, AgentEvent}; use goose::conversation::message::Message as GooseMessage; -use goose::conversation::Conversation; -use goose::session; + +use axum::response::Redirect; use serde::{Deserialize, Serialize}; use std::{net::SocketAddr, sync::Arc}; use tokio::sync::{Mutex, RwLock}; use tower_http::cors::{Any, CorsLayer}; use tracing::error; -type SessionStore = Arc>>>>; type CancellationStore = Arc>>; #[derive(Clone)] struct AppState { agent: Arc, - sessions: SessionStore, cancellations: CancellationStore, } @@ -123,7 +124,6 @@ pub async fn handle_web(port: u16, host: String, open: bool) -> Result<()> { let state = AppState { agent: Arc::new(agent), - sessions: Arc::new(RwLock::new(std::collections::HashMap::new())), cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())), }; @@ -169,8 +169,15 @@ pub async fn 
handle_web(port: u16, host: String, open: bool) -> Result<()> { Ok(()) } -async fn serve_index() -> Html<&'static str> { - Html(include_str!("../../static/index.html")) +async fn serve_index() -> Result { + let session = SessionManager::create_session( + std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), + "Web session".to_string(), + ) + .await + .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?; + + Ok(Redirect::to(&format!("/session/{}", session.id))) } async fn serve_session( @@ -222,23 +229,19 @@ async fn health_check() -> Json { } async fn list_sessions() -> Json { - match session::list_sessions() { + match SessionManager::list_sessions().await { Ok(sessions) => { - let session_info: Vec = sessions - .into_iter() - .filter_map(|(name, path)| { - session::read_metadata(&path).ok().map(|metadata| { - serde_json::json!({ - "name": name, - "path": path, - "description": metadata.description, - "message_count": metadata.message_count, - "working_dir": metadata.working_dir - }) - }) - }) - .collect(); - + let mut session_info = Vec::new(); + + for session in sessions { + session_info.push(serde_json::json!({ + "name": session.id, + "path": session.id, + "description": session.description, + "message_count": session.message_count, + "working_dir": session.working_dir + })); + } Json(serde_json::json!({ "sessions": session_info })) @@ -251,30 +254,14 @@ async fn list_sessions() -> Json { async fn get_session( axum::extract::Path(session_id): axum::extract::Path, ) -> Json { - let session_file = match session::get_path(session::Identifier::Name(session_id)) { - Ok(path) => path, - Err(e) => { - return Json(serde_json::json!({ - "error": format!("Invalid session ID: {}", e) - })); - } - }; - - let error_response = |e: Box| { - Json(serde_json::json!({ + match SessionManager::get_session(&session_id, true).await { + Ok(session) => Json(serde_json::json!({ + "metadata": session, + "messages": session.conversation.unwrap_or_default().messages() + })), + Err(e) => Json(serde_json::json!({ "error": e.to_string() - })) - }; - - match session::read_messages(&session_file) { - Ok(messages) => match session::read_metadata(&session_file) { - Ok(metadata) => Json(serde_json::json!({ - "metadata": metadata, - "messages": messages - })), - Err(e) => error_response(e.into()), - }, - Err(e) => error_response(e.into()), + })), } } @@ -299,46 +286,14 @@ async fn handle_socket(socket: WebSocket, state: AppState) { session_id, .. 
}) => { - // Get session file path from session_id - let session_file = match session::get_path(session::Identifier::Name( - session_id.clone(), - )) { - Ok(path) => path, - Err(e) => { - tracing::error!("Failed to get session path: {}", e); - continue; - } - }; - - // Get or create session in memory (for fast access during processing) - let session_messages = { - let sessions = state.sessions.read().await; - if let Some(session) = sessions.get(&session_id) { - session.clone() - } else { - drop(sessions); - let mut sessions = state.sessions.write().await; - - // Load existing messages from JSONL file if it exists - let existing_messages = - session::read_messages(&session_file).unwrap_or_default(); - - let new_session = Arc::new(Mutex::new(existing_messages)); - sessions.insert(session_id.clone(), new_session.clone()); - new_session - } - }; - - // Clone sender for async processing let sender_clone = sender.clone(); let agent = state.agent.clone(); + let session_id_clone = session_id.clone(); - // Process message in a separate task to allow streaming let task_handle = tokio::spawn(async move { let result = process_message_streaming( &agent, - session_messages, - session_file, + session_id_clone, content, sender_clone, ) @@ -349,25 +304,21 @@ async fn handle_socket(socket: WebSocket, state: AppState) { } }); - // Store the abort handle { let mut cancellations = state.cancellations.write().await; cancellations .insert(session_id.clone(), task_handle.abort_handle()); } - // Wait for task completion and handle abort + // Handle task completion and cleanup let sender_for_abort = sender.clone(); let session_id_for_cleanup = session_id.clone(); let cancellations_for_cleanup = state.cancellations.clone(); tokio::spawn(async move { match task_handle.await { - Ok(_) => { - // Task completed normally - } + Ok(_) => {} Err(e) if e.is_cancelled() => { - // Task was aborted let mut sender = sender_for_abort.lock().await; let _ = sender .send(Message::Text( @@ -387,11 +338,8 @@ async fn handle_socket(socket: WebSocket, state: AppState) { } } - // Clean up cancellation token - { - let mut cancellations = cancellations_for_cleanup.write().await; - cancellations.remove(&session_id_for_cleanup); - } + let mut cancellations = cancellations_for_cleanup.write().await; + cancellations.remove(&session_id_for_cleanup); }); } Ok(WebSocketMessage::Cancel { session_id }) => { @@ -436,27 +384,16 @@ async fn handle_socket(socket: WebSocket, state: AppState) { async fn process_message_streaming( agent: &Agent, - session_messages: Arc>, - session_file: std::path::PathBuf, + session_id: String, content: String, sender: Arc>>, ) -> Result<()> { use futures::StreamExt; use goose::agents::SessionConfig; use goose::conversation::message::MessageContent; - use goose::session; - // Create a user message let user_message = GooseMessage::user().with_text(content.clone()); - // Messages will be auto-compacted in agent.reply() if needed - let messages: Conversation = { - let mut session_msgs = session_messages.lock().await; - session_msgs.push(user_message.clone()); - session_msgs.clone() - }; - - // Persist messages to JSONL file with provider for automatic description generation let provider = agent.provider().await; if provider.is_err() { let error_msg = "I'm not properly configured yet. 
Please configure a provider through the CLI first using `goose configure`.".to_string(); @@ -475,19 +412,13 @@ async fn process_message_streaming( return Ok(()); } - let provider = provider.unwrap(); - let working_dir = Some(std::env::current_dir()?); - session::persist_messages( - &session_file, - &messages, - Some(provider.clone()), - working_dir.clone(), - ) - .await?; + let session = SessionManager::get_session(&session_id, true).await?; + let mut messages = session.conversation.unwrap_or_default(); + messages.push(user_message); let session_config = SessionConfig { - id: session::Identifier::Path(session_file.clone()), - working_dir: std::env::current_dir()?, + id: session.id.clone(), + working_dir: session.working_dir, schedule_id: None, execution_mode: None, max_turns: None, @@ -502,29 +433,11 @@ async fn process_message_streaming( while let Some(result) = stream.next().await { match result { Ok(AgentEvent::Message(message)) => { - // Add message to our session - { - let mut session_msgs = session_messages.lock().await; - session_msgs.push(message.clone()); - } + SessionManager::add_message(&session_id, &message).await?; - // Persist messages to JSONL file (no provider needed for assistant messages) - let current_messages = { - let session_msgs = session_messages.lock().await; - session_msgs.clone() - }; - session::persist_messages( - &session_file, - ¤t_messages, - None, - working_dir.clone(), - ) - .await?; - // Handle different message content types for content in &message.content { match content { MessageContent::Text(text) => { - // Send the text response let mut sender = sender.lock().await; let _ = sender .send(Message::Text( @@ -539,7 +452,6 @@ async fn process_message_streaming( .await; } MessageContent::ToolRequest(req) => { - // Send tool request notification let mut sender = sender.lock().await; if let Ok(tool_call) = &req.tool_call { let _ = sender @@ -557,13 +469,8 @@ async fn process_message_streaming( .await; } } - MessageContent::ToolResponse(_resp) => { - // Tool responses are already included in the complete message stream - // and will be persisted to session history. No need to send separate - // WebSocket messages as this would cause duplicates. 
- } + MessageContent::ToolResponse(_resp) => {} MessageContent::ToolConfirmationRequest(confirmation) => { - // Send tool confirmation request let mut sender = sender.lock().await; let _ = sender .send(Message::Text( @@ -580,8 +487,6 @@ async fn process_message_streaming( )) .await; - // For now, auto-approve in web mode - // TODO: Implement proper confirmation UI agent.handle_confirmation( confirmation.id.clone(), goose::permission::PermissionConfirmation { @@ -591,7 +496,6 @@ async fn process_message_streaming( ).await; } MessageContent::Thinking(thinking) => { - // Send thinking indicator let mut sender = sender.lock().await; let _ = sender .send(Message::Text( @@ -604,7 +508,6 @@ async fn process_message_streaming( .await; } MessageContent::ContextLengthExceeded(msg) => { - // Send context exceeded notification let mut sender = sender.lock().await; let _ = sender .send(Message::Text( @@ -618,55 +521,27 @@ async fn process_message_streaming( )) .await; - // For now, auto-summarize in web mode - // TODO: Implement proper UI for context handling let (summarized_messages, _, _) = agent.summarize_context(messages.messages()).await?; - { - let mut session_msgs = session_messages.lock().await; - *session_msgs = summarized_messages; - } - } - _ => { - // Handle other message types as needed + SessionManager::replace_conversation( + &session_id, + &summarized_messages, + ) + .await?; } + _ => {} } } } - Ok(AgentEvent::HistoryReplaced(new_messages)) => { - // Replace the session's message history with the compacted messages - { - let mut session_msgs = session_messages.lock().await; - *session_msgs = Conversation::new_unvalidated(new_messages); - } - - // Persist the updated messages to the JSONL file - let current_messages = { - let session_msgs = session_messages.lock().await; - session_msgs.clone() - }; - - if let Err(e) = session::persist_messages( - &session_file, - ¤t_messages, - None, // No provider needed for persisting - working_dir.clone(), - ) - .await - { - error!("Failed to persist compacted messages: {}", e); - } + Ok(AgentEvent::HistoryReplaced(_new_messages)) => { + tracing::info!("History replaced, compacting happened in reply"); } Ok(AgentEvent::McpNotification(_notification)) => { - // Handle MCP notifications if needed - // For now, we'll just log them tracing::info!("Received MCP notification in web interface"); } Ok(AgentEvent::ModelChange { model, mode }) => { - // Log model change tracing::info!("Model changed to {} in {} mode", model, mode); } - Err(e) => { error!("Error in message stream: {}", e); let mut sender = sender.lock().await; @@ -699,7 +574,6 @@ async fn process_message_streaming( } } - // Send completion message let mut sender = sender.lock().await; let _ = sender .send(Message::Text( @@ -713,6 +587,3 @@ async fn process_message_streaming( Ok(()) } - -// Add webbrowser dependency for opening browser -use webbrowser; diff --git a/crates/goose-cli/src/lib.rs b/crates/goose-cli/src/lib.rs index ad0641b1817b..b2e882fd9cb4 100644 --- a/crates/goose-cli/src/lib.rs +++ b/crates/goose-cli/src/lib.rs @@ -10,7 +10,7 @@ pub mod session; pub mod signal; // Re-export commonly used types -pub use session::Session; +pub use session::CliSession; pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { top_level_domain: "Block".to_string(), diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 43dbb2599fd4..a9356b19e38e 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ 
b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -4,7 +4,7 @@ use goose::conversation::Conversation; use crate::scenario_tests::message_generator::MessageGenerator; use crate::scenario_tests::mock_client::weather_client; use crate::scenario_tests::provider_configs::{get_provider_configs, ProviderConfig}; -use crate::session::Session; +use crate::session::CliSession; use anyhow::Result; use goose::agents::Agent; use goose::model::ModelConfig; @@ -218,7 +218,7 @@ where .update_provider(provider_arc as Arc) .await?; - let mut session = Session::new(agent, None, false, None, None, None, None); + let mut session = CliSession::new(agent, None, false, None, None, None, None); let mut error = None; for message in &messages { diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 0a7b7f8fb1a5..988954330c87 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -1,28 +1,27 @@ +use super::output; +use super::CliSession; use console::style; use goose::agents::types::RetryConfig; use goose::agents::Agent; use goose::config::{Config, ExtensionConfig, ExtensionConfigManager}; use goose::providers::create; use goose::recipe::{Response, SubRecipe}; -use goose::session; -use goose::session::Identifier; + +use goose::session::SessionManager; use rustyline::EditMode; use std::collections::HashSet; use std::process; use std::sync::Arc; use tokio::task::JoinSet; -use super::output; -use super::Session; - /// Configuration for building a new Goose session /// /// This struct contains all the parameters needed to create a new session, /// including session identification, extension configuration, and debug settings. #[derive(Default, Clone, Debug)] pub struct SessionBuilderConfig { - /// Optional identifier for the session (name or path) - pub identifier: Option, + /// Optional identifier for the session + pub session_id: Option, /// Whether to resume an existing session pub resume: bool, /// Whether to run without a session file @@ -129,20 +128,8 @@ async fn offer_extension_debugging_help( } } - // Create a temporary session file for this debugging session - let temp_session_file = - std::env::temp_dir().join(format!("goose_debug_extension_{}.jsonl", extension_name)); - // Create the debugging session - let mut debug_session = Session::new( - debug_agent, - Some(temp_session_file.clone()), - false, - None, - None, - None, - None, - ); + let mut debug_session = CliSession::new(debug_agent, None, false, None, None, None, None); // Process the debugging request println!("{}", style("Analyzing the extension failure...").yellow()); @@ -160,10 +147,6 @@ async fn offer_extension_debugging_help( ); } } - - // Clean up the temporary session file - let _ = std::fs::remove_file(temp_session_file); - Ok(()) } @@ -174,7 +157,7 @@ pub struct SessionSettings { pub temperature: Option, } -pub async fn build_session(session_config: SessionBuilderConfig) -> Session { +pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { // Load config and get provider/model let config = Config::global(); @@ -257,68 +240,64 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { }); // Handle session file resolution and resuming - let session_file: Option = if session_config.no_session { + let session_id: Option = if session_config.no_session { None } else if session_config.resume { - if let Some(identifier) = session_config.identifier { - let session_file = match session::get_path(identifier) 
{ - Err(e) => { - output::render_error(&format!("Invalid session identifier: {}", e)); + if let Some(session_id) = session_config.session_id { + match SessionManager::get_session(&session_id, false).await { + Ok(_) => Some(session_id), + Err(_) => { + output::render_error(&format!( + "Cannot resume session {} - no such session exists", + style(&session_id).cyan() + )); process::exit(1); } - Ok(path) => path, - }; - if !session_file.exists() { - output::render_error(&format!( - "Cannot resume session {} - no such session exists", - style(session_file.display()).cyan() - )); - process::exit(1); } - - Some(session_file) } else { - // Try to resume most recent session - match session::get_most_recent_session() { - Ok(file) => Some(file), + match SessionManager::list_sessions().await { + Ok(sessions) => { + if sessions.is_empty() { + output::render_error("Cannot resume - no previous sessions found"); + process::exit(1); + } + Some(sessions[0].id.clone()) + } Err(_) => { output::render_error("Cannot resume - no previous sessions found"); process::exit(1); } } } + } else if let Some(session_id) = session_config.session_id { + Some(session_id) } else { - // Create new session with provided name/path or generated name - let id = match session_config.identifier { - Some(identifier) => identifier, - None => Identifier::Name(session::generate_session_id()), - }; - - // Just get the path - file will be created when needed - match session::get_path(id) { - Ok(path) => Some(path), - Err(e) => { - output::render_error(&format!("Failed to create session path: {}", e)); - process::exit(1); - } - } + let session = SessionManager::create_session( + std::env::current_dir().unwrap(), + "CLI Session".to_string(), + ) + .await + .unwrap(); + Some(session.id) }; if session_config.resume { - if let Some(session_file) = session_file.as_ref() { - // Read the session metadata - let metadata = session::read_metadata(session_file).unwrap_or_else(|e| { - output::render_error(&format!("Failed to read session metadata: {}", e)); - process::exit(1); - }); + if let Some(session_id) = session_id.as_ref() { + // Read the session metadata from database + let metadata = SessionManager::get_session(session_id, false) + .await + .unwrap_or_else(|e| { + output::render_error(&format!("Failed to read session metadata: {}", e)); + process::exit(1); + }); let current_workdir = std::env::current_dir().expect("Failed to get current working directory"); if current_workdir != metadata.working_dir { // Ask user if they want to change the working directory let change_workdir = cliclack::confirm(format!("{} The original working directory of this session was set to {}. Your current directory is {}. 
Do you want to switch back to the original working directory?", style("WARNING:").yellow(), style(metadata.working_dir.display()).cyan(), style(current_workdir.display()).cyan())) - .initial_value(true) - .interact().expect("Failed to get user input"); + .initial_value(true) + .interact().expect("Failed to get user input"); if change_workdir { if !metadata.working_dir.exists() { @@ -417,9 +396,9 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { }); // Create new session - let mut session = Session::new( + let mut session = CliSession::new( Arc::try_unwrap(agent_ptr).unwrap_or_else(|_| panic!("There should be no more references")), - session_file.clone(), + session_id.clone(), session_config.debug, session_config.scheduled_job_id.clone(), session_config.max_turns, @@ -586,7 +565,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { session_config.resume, &provider_name, &model_name, - &session_file, + &session_id, Some(&provider_for_display), ); } @@ -600,7 +579,7 @@ mod tests { #[test] fn test_session_builder_config_creation() { let config = SessionBuilderConfig { - identifier: Some(Identifier::Name("test".to_string())), + session_id: Some("test".to_string()), resume: false, no_session: false, extensions: vec!["echo test".to_string()], @@ -639,7 +618,7 @@ mod tests { fn test_session_builder_config_default() { let config = SessionBuilderConfig::default(); - assert!(config.identifier.is_none()); + assert!(config.session_id.is_none()); assert!(!config.resume); assert!(!config.no_session); assert!(config.extensions.is_empty()); diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index acd3d8432a1c..68ab52541e55 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -21,7 +21,6 @@ use goose::permission::permission_confirmation::PrincipalType; use goose::permission::Permission; use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; -pub use goose::session::Identifier; use goose::utils::safe_truncate; use anyhow::{Context, Result}; @@ -39,6 +38,7 @@ use rmcp::model::ServerNotification; use rmcp::model::{ErrorCode, ErrorData}; use goose::conversation::message::{Message, MessageContent}; +use goose::session::SessionManager; use rand::{distributions::Alphanumeric, Rng}; use rustyline::EditMode; use serde_json::Value; @@ -54,13 +54,12 @@ pub enum RunMode { Plan, } -pub struct Session { +pub struct CliSession { agent: Agent, messages: Conversation, - session_file: Option, - // Cache for completion data - using std::sync for thread safety without async + session_id: Option, completion_cache: Arc>, - debug: bool, // New field for debug mode + debug: bool, run_mode: RunMode, scheduled_job_id: Option, // ID of the scheduled job that triggered this session max_turns: Option, @@ -111,8 +110,6 @@ pub async fn classify_planner_response( ) .await?; - // println!("classify_planner_response: {result:?}\n"); // TODO: remove - let predicted = result.as_concat_text(); if predicted.to_lowercase().contains("plan") { Ok(PlannerResponseType::Plan) @@ -121,30 +118,33 @@ pub async fn classify_planner_response( } } -impl Session { +impl CliSession { pub fn new( agent: Agent, - session_file: Option, + session_id: Option, debug: bool, scheduled_job_id: Option, max_turns: Option, edit_mode: Option, retry_config: Option, ) -> Self { - let messages = if let Some(session_file) = &session_file { - session::read_messages(session_file).unwrap_or_else(|e| { - 
eprintln!("Warning: Failed to load message history: {}", e); - Conversation::new_unvalidated(Vec::new()) + let messages = if let Some(session_id) = &session_id { + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + SessionManager::get_session(session_id, true) + .await + .map(|session| session.conversation.unwrap_or_default()) + .unwrap() + }) }) } else { - // Don't try to read messages if we're not saving sessions Conversation::new_unvalidated(Vec::new()) }; - Session { + CliSession { agent, messages, - session_file, + session_id, completion_cache: Arc::new(std::sync::RwLock::new(CompletionCache::new())), debug, run_mode: RunMode::Normal, @@ -155,13 +155,15 @@ impl Session { } } - /// Helper function to summarize context messages + pub fn session_id(&self) -> Option<&String> { + self.session_id.as_ref() + } + async fn summarize_context_messages( messages: &mut Conversation, agent: &Agent, message_suffix: &str, ) -> Result<()> { - // Summarize messages to fit within context length let (summarized_messages, _, _) = agent.summarize_context(messages.messages()).await?; let msg = format!("Context maxed out\n{}\n{}", "-".repeat(50), message_suffix); output::render_text(&msg, Some(Color::Yellow), true); @@ -179,7 +181,6 @@ impl Session { let mut parts: Vec<&str> = extension_command.split_whitespace().collect(); let mut envs = HashMap::new(); - // Parse environment variables (format: KEY=value) while let Some(part) = parts.first() { if !part.contains('=') { break; @@ -194,7 +195,6 @@ impl Session { } let cmd = parts.remove(0).to_string(); - // Generate a random name for the ephemeral extension let name: String = rand::thread_rng() .sample_iter(&Alphanumeric) .take(8) @@ -374,46 +374,10 @@ impl Session { cancel_token: CancellationToken, ) -> Result<()> { let cancel_token = cancel_token.clone(); - let message_text = message.as_concat_text(); - self.push_message(message); - // Get the provider from the agent for description generation - let provider = self.agent.provider().await?; - - // Persist messages with provider for automatic description generation - if let Some(session_file) = &self.session_file { - let working_dir = Some( - std::env::current_dir().expect("failed to get current session working directory"), - ); - - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - Some(provider), - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; - } - - // Track the current directory and last instruction in projects.json - let session_id = self - .session_file - .as_ref() - .and_then(|p| p.file_stem()) - .and_then(|s| s.to_str()) - .map(|s| s.to_string()); - - if let Err(e) = crate::project_tracker::update_project_tracker( - Some(&message_text), - session_id.as_deref(), - ) { - eprintln!( - "Warning: Failed to update project tracker with instruction: {}", - e - ); - } + // TODO(Douwe): Make sure we generate the description here still: + self.push_message(message); self.process_agent_response(false, cancel_token).await?; Ok(()) } @@ -493,35 +457,14 @@ impl Session { self.push_message(Message::user().with_text(&content)); // Track the current directory and last instruction in projects.json - let session_id = self - .session_file - .as_ref() - .and_then(|p| p.file_stem()) - .and_then(|s| s.to_str()) - .map(|s| s.to_string()); - if let Err(e) = crate::project_tracker::update_project_tracker( Some(&content), - session_id.as_deref(), + self.session_id.as_deref(), ) { eprintln!("Warning: Failed to update project tracker with instruction: 
{}", e); } - let provider = self.agent.provider().await?; - - // Persist messages with provider for automatic description generation - if let Some(session_file) = &self.session_file { - let working_dir = Some(std::env::current_dir().unwrap_or_default()); - - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - Some(provider), - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; - } + let _provider = self.agent.provider().await?; output::show_thinking(); let start_time = Instant::now(); @@ -659,16 +602,25 @@ impl Session { input::InputResult::Clear => { save_history(&mut editor); + if let Some(session_id) = &self.session_id { + if let Err(e) = SessionManager::replace_conversation( + session_id, + &Conversation::default(), + ) + .await + { + output::render_error(&format!("Failed to clear session: {}", e)); + continue; + } + } + self.messages.clear(); tracing::info!("Chat context cleared by user."); output::render_message( &Message::assistant().with_text("Chat context cleared."), self.debug, ); - if let Some(file) = self.session_file.as_ref().filter(|f| f.exists()) { - std::fs::remove_file(file)?; - std::fs::File::create(file)?; - } + continue; } input::InputResult::PromptCommand(opts) => { @@ -729,44 +681,30 @@ impl Session { output::show_thinking(); // Get the provider for summarization - let provider = self.agent.provider().await?; + let _provider = self.agent.provider().await?; - // Call the summarize_context method which uses the summarize_messages function + // Call the summarize_context method let (summarized_messages, _token_counts, summarization_usage) = self .agent .summarize_context(self.messages.messages()) .await?; // Update the session messages with the summarized ones - self.messages = summarized_messages; - - // Persist the summarized messages and update session metadata with new token counts - if let Some(session_file) = &self.session_file { - let working_dir = std::env::current_dir().ok(); - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - Some(provider), - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; + self.messages = summarized_messages.clone(); + + // Persist the summarized messages and update session metadata + if let Some(session_id) = &self.session_id { + // Replace all messages with the summarized version + SessionManager::replace_conversation(session_id, &summarized_messages) + .await?; // Update session metadata with the new token counts from summarization if let Some(usage) = summarization_usage { - let session_file_path = session::storage::get_path( - session::storage::Identifier::Path(session_file.to_path_buf()), - )?; - let mut metadata = - session::storage::read_metadata(&session_file_path)?; + let session = + SessionManager::get_session(session_id, false).await?; // Update token counts with the summarization usage - // Use output tokens as total since that's what's actually in the context going forward let summary_tokens = usage.usage.output_tokens.unwrap_or(0); - metadata.total_tokens = Some(summary_tokens); - metadata.input_tokens = None; // Clear input tokens since we now have a summary - metadata.output_tokens = Some(summary_tokens); - metadata.message_count = self.messages.len(); // Update accumulated tokens (add the summarization cost) let accumulate = |a: Option, b: Option| -> Option { @@ -775,20 +713,28 @@ impl Session { _ => a.or(b), } }; - metadata.accumulated_total_tokens = accumulate( - metadata.accumulated_total_tokens, + + let accumulated_total = accumulate( + 
session.accumulated_total_tokens, usage.usage.total_tokens, ); - metadata.accumulated_input_tokens = accumulate( - metadata.accumulated_input_tokens, + let accumulated_input = accumulate( + session.accumulated_input_tokens, usage.usage.input_tokens, ); - metadata.accumulated_output_tokens = accumulate( - metadata.accumulated_output_tokens, + let accumulated_output = accumulate( + session.accumulated_output_tokens, usage.usage.output_tokens, ); - session::storage::update_metadata(&session_file_path, &metadata) + SessionManager::update_session(session_id) + .total_tokens(Some(summary_tokens)) + .input_tokens(None) + .output_tokens(Some(summary_tokens)) + .accumulated_total_tokens(accumulated_total) + .accumulated_input_tokens(accumulated_input) + .accumulated_output_tokens(accumulated_output) + .apply() .await?; } } @@ -808,19 +754,10 @@ impl Session { } else { println!("{}", console::style("Summarization cancelled.").yellow()); } - continue; } } } - - println!( - "\nClosing session.{}", - self.session_file - .as_ref() - .map(|p| format!(" Recorded to {}", p.display())) - .unwrap_or_default() - ); Ok(()) } @@ -919,16 +856,13 @@ impl Session { ) -> Result<()> { let cancel_token_clone = cancel_token.clone(); - let session_config = self.session_file.as_ref().map(|s| { - let session_id = session::Identifier::Path(s.clone()); - SessionConfig { - id: session_id.clone(), - working_dir: std::env::current_dir().unwrap_or_default(), - schedule_id: self.scheduled_job_id.clone(), - execution_mode: None, - max_turns: self.max_turns, - retry_config: self.retry_config.clone(), - } + let session_config = self.session_id.as_ref().map(|session_id| SessionConfig { + id: session_id.clone(), + working_dir: std::env::current_dir().unwrap_or_default(), + schedule_id: self.scheduled_job_id.clone(), + execution_mode: None, + max_turns: self.max_turns, + retry_config: self.retry_config.clone(), }); let mut stream = self .agent @@ -998,17 +932,6 @@ impl Session { Err(ErrorData { code: ErrorCode::INVALID_REQUEST, message: std::borrow::Cow::from("Tool call cancelled by user".to_string()), data: None }) )); self.messages.push(response_message); - if let Some(session_file) = &self.session_file { - let working_dir = std::env::current_dir().ok(); - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - None, - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; - } cancel_token_clone.cancel(); drop(stream); break; @@ -1140,22 +1063,8 @@ impl Session { ); } } - self.messages.push(message.clone()); - // No need to update description on assistant messages - if let Some(session_file) = &self.session_file { - let working_dir = std::env::current_dir().ok(); - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - None, - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; - } - if interactive {output::hide_thinking()}; let _ = progress_bars.hide(); output::render_message(&message, self.debug); @@ -1267,26 +1176,10 @@ impl Session { _ => (), } } - Some(Ok(AgentEvent::HistoryReplaced(new_messages))) => { - // Replace the session's message history with the compacted messages - self.messages = Conversation::new_unvalidated(new_messages); - - // Persist the updated messages to the session file - if let Some(session_file) = &self.session_file { - let provider = self.agent.provider().await.ok(); - let working_dir = std::env::current_dir().ok(); - if let Err(e) = session::persist_messages_with_schedule_id( - session_file, - &self.messages, - provider, - 
self.scheduled_job_id.clone(), - working_dir, - ).await { - eprintln!("Failed to persist compacted messages: {}", e); - } - } - } - Some(Ok(AgentEvent::ModelChange { model, mode })) => { + Some(Ok(AgentEvent::HistoryReplaced(new_messages))) => { + self.messages = Conversation::new_unvalidated(new_messages.clone()); + } + Some(Ok(AgentEvent::ModelChange { model, mode })) => { // Log model change if in debug mode if self.debug { eprintln!("Model changed to {} in {} mode", model, mode); @@ -1308,20 +1201,8 @@ impl Session { // Try auto-compaction first - keep the stream alive! if let Ok(compact_result) = goose::context_mgmt::auto_compact::perform_compaction(&self.agent, self.messages.messages()).await { self.messages = compact_result.messages; - - // Persist the compacted messages - if let Some(session_file) = &self.session_file { - let provider = self.agent.provider().await.ok(); - let working_dir = std::env::current_dir().ok(); - if let Err(e) = session::persist_messages_with_schedule_id( - session_file, - &self.messages, - provider, - self.scheduled_job_id.clone(), - working_dir, - ).await { - eprintln!("Failed to persist compacted messages: {}", e); - } + if let Some(session_id) = &self.session_id { + SessionManager::replace_conversation(session_id, &self.messages).await?; } output::render_text( @@ -1465,40 +1346,13 @@ impl Session { }), )); } + // TODO(Douwe): update also db self.push_message(response_message); - - // No need for description update here - if let Some(session_file) = &self.session_file { - let working_dir = std::env::current_dir().ok(); - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - None, - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; - } - let prompt = format!( "The existing call to {} was interrupted. How would you like to proceed?", last_tool_name ); self.push_message(Message::assistant().with_text(&prompt)); - - // No need for description update here - if let Some(session_file) = &self.session_file { - let working_dir = std::env::current_dir().ok(); - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - None, - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; - } - output::render_message(&Message::assistant().with_text(&prompt), self.debug); } else { // An interruption occurred outside of a tool request-response. @@ -1509,20 +1363,6 @@ impl Session { // Interruption occurred after a tool had completed but not assistant reply let prompt = "The tool calling loop was interrupted. 
How would you like to proceed?"; self.push_message(Message::assistant().with_text(prompt)); - - // No need for description update here - if let Some(session_file) = &self.session_file { - let working_dir = std::env::current_dir().ok(); - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - None, - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; - } - output::render_message( &Message::assistant().with_text(prompt), self.debug, @@ -1545,10 +1385,6 @@ impl Session { Ok(()) } - pub fn session_file(&self) -> Option { - self.session_file.clone() - } - /// Update the completion cache with fresh data /// This should be called before the interactive session starts pub async fn update_completion_cache(&mut self) -> Result<()> { @@ -1619,17 +1455,16 @@ impl Session { ); } - pub fn get_metadata(&self) -> Result { - if !self.session_file.as_ref().is_some_and(|f| f.exists()) { - return Err(anyhow::anyhow!("Session file does not exist")); + pub async fn get_metadata(&self) -> Result { + match &self.session_id { + Some(id) => SessionManager::get_session(id, false).await, + None => Err(anyhow::anyhow!("No session available")), } - - session::read_metadata(self.session_file.as_ref().unwrap()) } // Get the session's total token usage - pub fn get_total_token_usage(&self) -> Result> { - let metadata = self.get_metadata()?; + pub async fn get_total_token_usage(&self) -> Result> { + let metadata = self.get_metadata().await?; Ok(metadata.total_tokens) } @@ -1661,7 +1496,7 @@ impl Session { } } - match self.get_metadata() { + match self.get_metadata().await { Ok(metadata) => { let total_tokens = metadata.total_tokens.unwrap_or(0) as usize; diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 0582ec976e06..7c601261eeb9 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -14,7 +14,7 @@ use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; use std::io::{Error, IsTerminal, Write}; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::sync::Arc; use std::time::Duration; @@ -685,12 +685,12 @@ pub fn display_session_info( resume: bool, provider: &str, model: &str, - session_file: &Option, + session_id: &Option, provider_instance: Option<&Arc>, ) { let start_session_msg = if resume { "resuming session |" - } else if session_file.is_none() { + } else if session_id.is_none() { "running without session |" } else { "starting session |" @@ -732,14 +732,6 @@ pub fn display_session_info( ); } - if let Some(session_file) = session_file { - println!( - " {} {}", - style("logging to").dim(), - style(session_file.display()).dim().cyan(), - ); - } - println!( " {} {}", style("working directory:").dim(), diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index c677d03796fc..c9da2f1d158e 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -3,10 +3,11 @@ use goose::agents::extension::ToolInfo; use goose::agents::ExtensionConfig; use goose::config::permission::PermissionLevel; use goose::config::ExtensionEntry; +use goose::conversation::Conversation; use goose::permission::permission_confirmation::PrincipalType; use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata}; -use goose::session::info::SessionInfo; -use goose::session::SessionMetadata; + +use goose::session::{Session, SessionInsights}; use rmcp::model::{ Annotations, Content, EmbeddedResource, ImageContent, RawEmbeddedResource, 
RawImageContent, RawResource, RawTextContent, ResourceContents, Role, TextContent, Tool, ToolAnnotations, @@ -45,8 +46,6 @@ macro_rules! derive_utoipa { } fn convert_schemars_to_utoipa(schema: rmcp::schemars::Schema) -> RefOr { - // For schemars 1.0+, we need to work with the public API - // The schema is now a wrapper around a JSON Value that can be either an object or bool if let Some(true) = schema.as_bool() { return RefOr::T(Schema::Object(ObjectBuilder::new().build())); } @@ -55,12 +54,10 @@ fn convert_schemars_to_utoipa(schema: rmcp::schemars::Schema) -> RefOr { return RefOr::T(Schema::Object(ObjectBuilder::new().build())); } - // For object schemas, we'll need to work with the JSON Value directly if let Some(obj) = schema.as_object() { return convert_json_object_to_utoipa(obj); } - // Fallback RefOr::T(Schema::Object(ObjectBuilder::new().build())) } @@ -69,12 +66,10 @@ fn convert_json_object_to_utoipa( ) -> RefOr { use serde_json::Value; - // Handle $ref if let Some(Value::String(reference)) = obj.get("$ref") { return RefOr::Ref(Ref::new(reference.clone())); } - // Handle oneOf, allOf, anyOf if let Some(Value::Array(one_of)) = obj.get("oneOf") { let mut builder = OneOfBuilder::new(); for item in one_of { @@ -105,11 +100,9 @@ fn convert_json_object_to_utoipa( return RefOr::T(Schema::AnyOf(builder.build())); } - // Handle type-based schemas match obj.get("type") { Some(Value::String(type_str)) => convert_typed_schema(type_str, obj), Some(Value::Array(types)) => { - // Multiple types - use AnyOf let mut builder = AnyOfBuilder::new(); for type_val in types { if let Value::String(type_str) = type_val { @@ -119,7 +112,7 @@ fn convert_json_object_to_utoipa( RefOr::T(Schema::AnyOf(builder.build())) } None => RefOr::T(Schema::Object(ObjectBuilder::new().build())), - _ => RefOr::T(Schema::Object(ObjectBuilder::new().build())), // Handle other value types + _ => RefOr::T(Schema::Object(ObjectBuilder::new().build())), } } @@ -133,7 +126,6 @@ fn convert_typed_schema( "object" => { let mut object_builder = ObjectBuilder::new(); - // Add properties if let Some(Value::Object(properties)) = obj.get("properties") { for (name, prop_value) in properties { if let Ok(prop_schema) = rmcp::schemars::Schema::try_from(prop_value.clone()) { @@ -143,7 +135,6 @@ fn convert_typed_schema( } } - // Add required fields if let Some(Value::Array(required)) = obj.get("required") { for req in required { if let Value::String(field_name) = req { @@ -152,7 +143,6 @@ fn convert_typed_schema( } } - // Handle additional properties if let Some(additional) = obj.get("additionalProperties") { match additional { Value::Bool(false) => { @@ -178,7 +168,6 @@ fn convert_typed_schema( "array" => { let mut array_builder = ArrayBuilder::new(); - // Add items schema if let Some(items) = obj.get("items") { match items { Value::Object(_) | Value::Bool(_) => { @@ -188,7 +177,6 @@ fn convert_typed_schema( } } Value::Array(item_schemas) => { - // Multiple item types - use AnyOf let mut any_of = AnyOfBuilder::new(); for item in item_schemas { if let Ok(schema) = rmcp::schemars::Schema::try_from(item.clone()) { @@ -202,7 +190,6 @@ fn convert_typed_schema( } } - // Add constraints if let Some(Value::Number(min_items)) = obj.get("minItems") { if let Some(min) = min_items.as_u64() { array_builder = array_builder.min_items(Some(min as usize)); @@ -333,8 +320,6 @@ struct AnnotatedSchema {} impl<'__s> ToSchema<'__s> for AnnotatedSchema { fn schema() -> (&'__s str, utoipa::openapi::RefOr) { - // Create a oneOf schema with only the variants we 
actually use in the API - // This avoids the circular reference from RawContent::Audio(AudioContent) let schema = Schema::OneOf( OneOfBuilder::new() .item(RefOr::Ref(Ref::new("#/components/schemas/RawTextContent"))) @@ -352,7 +337,6 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { } } -#[allow(dead_code)] // Used by utoipa for OpenAPI generation #[derive(OpenApi)] #[openapi( paths( @@ -384,7 +368,10 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { super::routes::reply::confirm_permission, super::routes::context::manage_context, super::routes::session::list_sessions, - super::routes::session::get_session_history, + super::routes::session::get_session, + super::routes::session::get_session_insights, + super::routes::session::update_session_description, + super::routes::session::delete_session, super::routes::schedule::create_schedule, super::routes::schedule::list_schedules, super::routes::schedule::delete_schedule, @@ -419,7 +406,7 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { super::routes::context::ContextManageRequest, super::routes::context::ContextManageResponse, super::routes::session::SessionListResponse, - super::routes::session::SessionHistoryResponse, + super::routes::session::UpdateSessionDescriptionRequest, Message, MessageContent, MessageMetadata, @@ -454,9 +441,10 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { PermissionLevel, PrincipalType, ModelInfo, - SessionInfo, - SessionMetadata, - goose::session::ExtensionData, + Session, + SessionInsights, + Conversation, + goose::session::extension_data::ExtensionData, super::routes::schedule::CreateScheduleRequest, super::routes::schedule::UpdateScheduleRequest, super::routes::schedule::KillJobResponse, @@ -498,7 +486,6 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { super::routes::agent::UpdateRouterToolSelectorRequest, super::routes::agent::StartAgentRequest, super::routes::agent::ResumeAgentRequest, - super::routes::agent::StartAgentResponse, super::routes::agent::ErrorResponse, super::routes::setup::SetupResponse, )) diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 1c546d5d66f1..bad6d27a2e99 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -6,13 +6,11 @@ use axum::{ Json, Router, }; use goose::config::PermissionManager; -use goose::conversation::message::Message; -use goose::conversation::Conversation; + use goose::model::ModelConfig; use goose::providers::create; use goose::recipe::{Recipe, Response}; -use goose::session; -use goose::session::SessionMetadata; +use goose::session::{Session, SessionManager}; use goose::{ agents::{extension::ToolInfo, extension_manager::get_parameter_names}, config::permission::PermissionLevel, @@ -22,7 +20,6 @@ use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::atomic::Ordering; use std::sync::Arc; -use tracing::error; #[derive(Deserialize, utoipa::ToSchema)] pub struct ExtendPromptRequest { @@ -81,14 +78,6 @@ pub struct ResumeAgentRequest { session_id: String, } -// This is the same as SessionHistoryResponse -#[derive(Serialize, utoipa::ToSchema)] -pub struct StartAgentResponse { - session_id: String, - metadata: SessionMetadata, - messages: Vec, -} - #[derive(Serialize, utoipa::ToSchema)] pub struct ErrorResponse { error: String, @@ -99,7 +88,7 @@ pub struct ErrorResponse { path = "/agent/start", request_body = StartAgentRequest, responses( - (status = 200, description = "Agent started successfully", body = StartAgentResponse), + (status = 200, 
description = "Agent started successfully", body = Session), (status = 400, description = "Bad request - invalid working directory"), (status = 401, description = "Unauthorized - invalid secret key"), (status = 500, description = "Internal server error") @@ -108,39 +97,28 @@ pub struct ErrorResponse { async fn start_agent( State(state): State>, Json(payload): Json, -) -> Result, StatusCode> { - let session_id = session::generate_session_id(); +) -> Result, StatusCode> { let counter = state.session_counter.fetch_add(1, Ordering::SeqCst) + 1; + let description = format!("New session {}", counter); + + let mut session = + SessionManager::create_session(PathBuf::from(&payload.working_dir), description) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if let Some(recipe) = payload.recipe { + SessionManager::update_session(&session.id) + .recipe(Some(recipe)) + .apply() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + session = SessionManager::get_session(&session.id, false) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } - let metadata = SessionMetadata { - working_dir: PathBuf::from(&payload.working_dir), - description: format!("New session {}", counter), - schedule_id: None, - message_count: 0, - total_tokens: Some(0), - input_tokens: Some(0), - output_tokens: Some(0), - accumulated_total_tokens: Some(0), - accumulated_input_tokens: Some(0), - accumulated_output_tokens: Some(0), - extension_data: Default::default(), - recipe: payload.recipe, - }; - - let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { - Ok(path) => path, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; - - let conversation = Conversation::empty(); - session::storage::save_messages_with_metadata(&session_path, &metadata, &conversation) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - Ok(Json(StartAgentResponse { - session_id, - metadata, - messages: conversation.messages().clone(), - })) + Ok(Json(session)) } #[utoipa::path( @@ -148,7 +126,7 @@ async fn start_agent( path = "/agent/resume", request_body = ResumeAgentRequest, responses( - (status = 200, description = "Agent started successfully", body = StartAgentResponse), + (status = 200, description = "Agent started successfully", body = Session), (status = 400, description = "Bad request - invalid working directory"), (status = 401, description = "Unauthorized - invalid secret key"), (status = 500, description = "Internal server error") @@ -156,28 +134,12 @@ async fn start_agent( )] async fn resume_agent( Json(payload): Json, -) -> Result, StatusCode> { - let session_path = - match session::get_path(session::Identifier::Name(payload.session_id.clone())) { - Ok(path) => path, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; - - let metadata = session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?; - - let conversation = match session::read_messages(&session_path) { - Ok(messages) => messages, - Err(e) => { - error!("Failed to read session messages: {:?}", e); - return Err(StatusCode::NOT_FOUND); - } - }; +) -> Result, StatusCode> { + let session = SessionManager::get_session(&payload.session_id, true) + .await + .map_err(|_| StatusCode::NOT_FOUND)?; - Ok(Json(StartAgentResponse { - session_id: payload.session_id.clone(), - metadata, - messages: conversation.messages().clone(), - })) + Ok(Json(session)) } #[utoipa::path( diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 
3be6d8072a84..070274181de0 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -822,8 +822,8 @@ mod tests { let status_code = result.unwrap_err(); assert!(status_code == StatusCode::BAD_REQUEST, - "Expected BAD_REQUEST (authentication error) or INTERNAL_SERVER_ERROR (other errors), got: {}", - status_code + "Expected BAD_REQUEST (authentication error) or INTERNAL_SERVER_ERROR (other errors), got: {}", + status_code ); std::env::remove_var("OPENAI_API_KEY"); diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 9f118a3b5add..09c1c41bff63 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -11,14 +11,12 @@ use futures::{stream::StreamExt, Stream}; use goose::conversation::message::{Message, MessageContent}; use goose::conversation::Conversation; use goose::execution::SessionExecutionMode; +use goose::permission::{Permission, PermissionConfirmation}; +use goose::session::SessionManager; use goose::{ agents::{AgentEvent, SessionConfig}, permission::permission_confirmation::PrincipalType, }; -use goose::{ - permission::{Permission, PermissionConfirmation}, - session, -}; use mcp_core::ToolResult; use rmcp::model::{Content, ServerNotification}; use serde::{Deserialize, Serialize}; @@ -228,30 +226,13 @@ async fn reply_handler( } }; - // Load session metadata to get the working directory and other config - let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { - Ok(path) => path, - Err(e) => { - tracing::error!("Failed to get session path for {}: {}", session_id, e); - let _ = stream_event( - MessageEvent::Error { - error: format!("Failed to get session path: {}", e), - }, - &task_tx, - &cancel_token, - ) - .await; - return; - } - }; - - let session_metadata = match session::read_metadata(&session_path) { + let session = match SessionManager::get_session(&session_id, false).await { Ok(metadata) => metadata, Err(e) => { - tracing::error!("Failed to read session metadata for {}: {}", session_id, e); + tracing::error!("Failed to read session for {}: {}", session_id, e); let _ = stream_event( MessageEvent::Error { - error: format!("Failed to read session metadata: {}", e), + error: format!("Failed to read session: {}", e), }, &task_tx, &cancel_token, @@ -262,9 +243,9 @@ async fn reply_handler( }; let session_config = SessionConfig { - id: session::Identifier::Name(session_id.clone()), - working_dir: session_metadata.working_dir.clone(), - schedule_id: session_metadata.schedule_id.clone(), + id: session_id.clone(), + working_dir: session.working_dir.clone(), + schedule_id: session.schedule_id.clone(), execution_mode: None, max_turns: None, retry_config: None, @@ -294,22 +275,6 @@ async fn reply_handler( }; let mut all_messages = messages.clone(); - let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { - Ok(path) => path, - Err(e) => { - tracing::error!("Failed to get session path: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: format!("Failed to get session path: {}", e), - }, - &task_tx, - &cancel_token, - ) - .await; - return; - } - }; - let saved_message_count = all_messages.len(); let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(500)); loop { @@ -376,40 +341,18 @@ async fn reply_handler( } } - if all_messages.len() > saved_message_count { - if let Ok(provider) = agent.provider().await { - let provider = 
Arc::clone(&provider); - let session_path_clone = session_path.to_path_buf(); - let all_messages_clone = all_messages.clone(); - let working_dir = session_config.working_dir.clone(); - tokio::spawn(async move { - if let Err(e) = session::persist_messages( - &session_path_clone, - &all_messages_clone, - Some(provider), - Some(working_dir), - ) - .await - { - tracing::error!("Failed to store session history: {:?}", e); - } - }); - } - } let session_duration = session_start.elapsed(); - if let Ok(metadata) = session::read_metadata(&session_path) { - let total_tokens = metadata.total_tokens.unwrap_or(0); - let message_count = metadata.message_count; - + if let Ok(session) = SessionManager::get_session(&session_id, true).await { + let total_tokens = session.total_tokens.unwrap_or(0); tracing::info!( counter.goose.session_completions = 1, session_type = "app", interface = "ui", exit_type = "normal", duration_ms = session_duration.as_millis() as u64, - total_tokens, - message_count, + total_tokens = total_tokens, + message_count = session.message_count, "Session completed" ); diff --git a/crates/goose-server/src/routes/schedule.rs b/crates/goose-server/src/routes/schedule.rs index 663890080089..5203e8be703b 100644 --- a/crates/goose-server/src/routes/schedule.rs +++ b/crates/goose-server/src/routes/schedule.rs @@ -319,24 +319,23 @@ async fn sessions_handler( .await { Ok(session_tuples) => { - // Expecting Vec<(String, goose::session::storage::SessionMetadata)> - let display_infos: Vec = session_tuples - .into_iter() - .map(|(session_name, metadata)| SessionDisplayInfo { + let mut display_infos = Vec::new(); + for (session_name, session) in session_tuples { + display_infos.push(SessionDisplayInfo { id: session_name.clone(), - name: metadata.description, // Use description as name + name: session.description, created_at: parse_session_name_to_iso(&session_name), - working_dir: metadata.working_dir.to_string_lossy().into_owned(), - schedule_id: metadata.schedule_id, // This is the ID of the schedule itself - message_count: metadata.message_count, - total_tokens: metadata.total_tokens, - input_tokens: metadata.input_tokens, - output_tokens: metadata.output_tokens, - accumulated_total_tokens: metadata.accumulated_total_tokens, - accumulated_input_tokens: metadata.accumulated_input_tokens, - accumulated_output_tokens: metadata.accumulated_output_tokens, - }) - .collect(); + working_dir: session.working_dir.to_string_lossy().into_owned(), + schedule_id: session.schedule_id, + message_count: session.message_count, + total_tokens: session.total_tokens, + input_tokens: session.input_tokens, + output_tokens: session.output_tokens, + accumulated_total_tokens: session.accumulated_total_tokens, + accumulated_input_tokens: session.accumulated_input_tokens, + accumulated_output_tokens: session.accumulated_output_tokens, + }); + } Ok(Json(display_infos)) } Err(e) => { diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index 9109d855d0dc..c7cee665fcc5 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -1,7 +1,3 @@ -use chrono::DateTime; -use std::collections::HashMap; -use std::sync::Arc; - use crate::state::AppState; use axum::{ extract::Path, @@ -9,65 +5,28 @@ use axum::{ routing::{delete, get, put}, Json, Router, }; -use goose::conversation::message::Message; -use goose::session; -use goose::session::info::{get_valid_sorted_sessions, SessionInfo, SortOrder}; -use goose::session::SessionMetadata; +use 
goose::session::session_manager::SessionInsights; +use goose::session::{Session, SessionManager}; use serde::{Deserialize, Serialize}; -use tracing::{error, info}; +use std::sync::Arc; use utoipa::ToSchema; #[derive(Serialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct SessionListResponse { /// List of available session information objects - sessions: Vec, -} - -#[derive(Serialize, ToSchema)] -#[serde(rename_all = "camelCase")] -pub struct SessionHistoryResponse { - /// Unique identifier for the session - session_id: String, - /// Session metadata containing creation time and other details - metadata: SessionMetadata, - /// List of messages in the session conversation - messages: Vec, + sessions: Vec, } #[derive(Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] -pub struct UpdateSessionMetadataRequest { +pub struct UpdateSessionDescriptionRequest { /// Updated description (name) for the session (max 200 characters) description: String, } const MAX_DESCRIPTION_LENGTH: usize = 200; -#[derive(Serialize, ToSchema, Debug)] -#[serde(rename_all = "camelCase")] -pub struct SessionInsights { - /// Total number of sessions - total_sessions: usize, - /// Most active working directories with session counts - most_active_dirs: Vec<(String, usize)>, - /// Average session duration in minutes - avg_session_duration: f64, - /// Total tokens used across all sessions - total_tokens: i64, - /// Activity trend for the last 7 days - recent_activity: Vec<(String, usize)>, -} - -#[derive(Serialize, ToSchema, Debug)] -#[allow(dead_code)] -#[serde(rename_all = "camelCase")] -pub struct ActivityHeatmapCell { - pub week: usize, - pub day: usize, - pub count: usize, -} - #[utoipa::path( get, path = "/sessions", @@ -81,9 +40,9 @@ pub struct ActivityHeatmapCell { ), tag = "Session Management" )] -// List all available sessions async fn list_sessions() -> Result, StatusCode> { - let sessions = get_valid_sorted_sessions(SortOrder::Descending) + let sessions = SessionManager::list_sessions() + .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(SessionListResponse { sessions })) @@ -96,7 +55,7 @@ async fn list_sessions() -> Result, StatusCode> { ("session_id" = String, Path, description = "Unique identifier for the session") ), responses( - (status = 200, description = "Session history retrieved successfully", body = SessionHistoryResponse), + (status = 200, description = "Session history retrieved successfully", body = Session), (status = 401, description = "Unauthorized - Invalid or missing API key"), (status = 404, description = "Session not found"), (status = 500, description = "Internal server error") @@ -106,40 +65,13 @@ async fn list_sessions() -> Result, StatusCode> { ), tag = "Session Management" )] -// Get a specific session's history -async fn get_session_history( - Path(session_id): Path, -) -> Result, StatusCode> { - let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { - Ok(path) => path, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; - - let metadata = session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?; - - let messages = match session::read_messages(&session_path) { - Ok(messages) => messages, - Err(e) => { - error!("Failed to read session messages: {:?}", e); - return Err(StatusCode::NOT_FOUND); - } - }; - - // Filter messages to only include user_visible ones - let user_visible_messages: Vec = messages - .messages() - .iter() - .filter(|m| m.is_user_visible()) - .cloned() - .collect(); +async fn 
get_session(Path(session_id): Path) -> Result, StatusCode> { + let session = SessionManager::get_session(&session_id, true) + .await + .map_err(|_| StatusCode::NOT_FOUND)?; - Ok(Json(SessionHistoryResponse { - session_id, - metadata, - messages: user_visible_messages, - })) + Ok(Json(session)) } - #[utoipa::path( get, path = "/sessions/insights", @@ -154,115 +86,21 @@ async fn get_session_history( tag = "Session Management" )] async fn get_session_insights() -> Result, StatusCode> { - info!("Received request for session insights"); - - let sessions = get_valid_sorted_sessions(SortOrder::Descending).map_err(|e| { - error!("Failed to get session info: {:?}", e); - StatusCode::INTERNAL_SERVER_ERROR - })?; - - // Filter out sessions without descriptions - let sessions: Vec = sessions - .into_iter() - .filter(|session| !session.metadata.description.is_empty()) - .collect(); - - info!("Found {} sessions with descriptions", sessions.len()); - - // Calculate insights - let total_sessions = sessions.len(); - - // Debug: Log if we have very few sessions, which might indicate filtering issues - if total_sessions == 0 { - info!("Warning: No sessions found with descriptions"); - } - - // Track directory usage - let mut dir_counts: HashMap = HashMap::new(); - let mut total_duration = 0.0; - let mut total_tokens = 0; - let mut activity_by_date: HashMap = HashMap::new(); - - for session in &sessions { - // Track directory usage - let dir = session.metadata.working_dir.to_string_lossy().to_string(); - *dir_counts.entry(dir).or_insert(0) += 1; - - // Track tokens - only add positive values to prevent negative totals - if let Some(tokens) = session.metadata.accumulated_total_tokens { - match tokens.cmp(&0) { - std::cmp::Ordering::Greater => { - total_tokens += tokens as i64; - } - std::cmp::Ordering::Less => { - // Log negative token values for debugging - info!( - "Warning: Session {} has negative accumulated_total_tokens: {}", - session.id, tokens - ); - } - std::cmp::Ordering::Equal => { - // Zero tokens, no action needed - } - } - } - - // Track activity by date - if let Ok(date) = DateTime::parse_from_str(&session.modified, "%Y-%m-%d %H:%M:%S UTC") { - let date_str = date.format("%Y-%m-%d").to_string(); - *activity_by_date.entry(date_str).or_insert(0) += 1; - } - - // Calculate session duration from messages - let session_path = session::get_path(session::Identifier::Name(session.id.clone())); - if let Ok(session_path) = session_path { - if let Ok(messages) = session::read_messages(&session_path) { - if let (Some(first), Some(last)) = (messages.first(), messages.last()) { - let duration = (last.created - first.created) as f64 / 60.0; // Convert to minutes - total_duration += duration; - } - } - } - } - - // Get top 3 most active directories - let mut dir_vec: Vec<(String, usize)> = dir_counts.into_iter().collect(); - dir_vec.sort_by(|a, b| b.1.cmp(&a.1)); - let most_active_dirs = dir_vec.into_iter().take(3).collect(); - - // Calculate average session duration - let avg_session_duration = if total_sessions > 0 { - total_duration / total_sessions as f64 - } else { - 0.0 - }; - - // Get last 7 days of activity - let mut activity_vec: Vec<(String, usize)> = activity_by_date.into_iter().collect(); - activity_vec.sort_by(|a, b| b.0.cmp(&a.0)); // Sort by date descending - let recent_activity = activity_vec.into_iter().take(7).collect(); - - let insights = SessionInsights { - total_sessions, - most_active_dirs, - avg_session_duration, - total_tokens, - recent_activity, - }; - - info!("Returning insights: 
{:?}", insights); + let insights = SessionManager::get_insights() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(insights)) } #[utoipa::path( put, - path = "/sessions/{session_id}/metadata", - request_body = UpdateSessionMetadataRequest, + path = "/sessions/{session_id}/description", + request_body = UpdateSessionDescriptionRequest, params( ("session_id" = String, Path, description = "Unique identifier for the session") ), responses( - (status = 200, description = "Session metadata updated successfully"), + (status = 200, description = "Session description updated successfully"), (status = 400, description = "Bad request - Description too long (max 200 characters)"), (status = 401, description = "Unauthorized - Invalid or missing API key"), (status = 404, description = "Session not found"), @@ -273,27 +111,17 @@ async fn get_session_insights() -> Result, StatusCode> { ), tag = "Session Management" )] -// Update session metadata -async fn update_session_metadata( +async fn update_session_description( Path(session_id): Path, - Json(request): Json, + Json(request): Json, ) -> Result { - // Validate description length if request.description.len() > MAX_DESCRIPTION_LENGTH { return Err(StatusCode::BAD_REQUEST); } - let session_path = session::get_path(session::Identifier::Name(session_id.clone())) - .map_err(|_| StatusCode::BAD_REQUEST)?; - - // Read current metadata - let mut metadata = session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?; - - // Update description - metadata.description = request.description; - - // Save updated metadata - session::update_metadata(&session_path, &metadata) + SessionManager::update_session(&session_id) + .description(request.description) + .apply() .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; @@ -302,7 +130,7 @@ async fn update_session_metadata( #[utoipa::path( delete, - path = "/sessions/{session_id}/delete", + path = "/sessions/{session_id}", params( ("session_id" = String, Path, description = "Unique identifier for the session") ), @@ -317,93 +145,29 @@ async fn update_session_metadata( ), tag = "Session Management" )] -// Delete a session async fn delete_session(Path(session_id): Path) -> Result { - // Get the session path - let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { - Ok(path) => path, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; - - // Check if session file exists - if !session_path.exists() { - return Err(StatusCode::NOT_FOUND); - } - - // Delete the session file - std::fs::remove_file(&session_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + SessionManager::delete_session(&session_id) + .await + .map_err(|e| { + if e.to_string().contains("not found") { + StatusCode::NOT_FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + })?; Ok(StatusCode::OK) } -// Configure routes for this module pub fn routes(state: Arc) -> Router { Router::new() .route("/sessions", get(list_sessions)) - .route("/sessions/{session_id}", get(get_session_history)) - .route("/sessions/{session_id}/delete", delete(delete_session)) + .route("/sessions/{session_id}", get(get_session)) + .route("/sessions/{session_id}", delete(delete_session)) .route("/sessions/insights", get(get_session_insights)) .route( - "/sessions/{session_id}/metadata", - put(update_session_metadata), + "/sessions/{session_id}/description", + put(update_session_description), ) .with_state(state) } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn 
test_update_session_metadata_request_deserialization() { - // Test that our request struct can be deserialized properly - let json = r#"{"description": "test description"}"#; - let request: UpdateSessionMetadataRequest = serde_json::from_str(json).unwrap(); - assert_eq!(request.description, "test description"); - } - - #[tokio::test] - async fn test_update_session_metadata_request_validation() { - // Test empty description - let empty_request = UpdateSessionMetadataRequest { - description: "".to_string(), - }; - assert_eq!(empty_request.description, ""); - - // Test normal description - let normal_request = UpdateSessionMetadataRequest { - description: "My Session Name".to_string(), - }; - assert_eq!(normal_request.description, "My Session Name"); - - // Test description at max length (should be valid) - let max_length_description = "A".repeat(MAX_DESCRIPTION_LENGTH); - let max_request = UpdateSessionMetadataRequest { - description: max_length_description.clone(), - }; - assert_eq!(max_request.description, max_length_description); - assert_eq!(max_request.description.len(), MAX_DESCRIPTION_LENGTH); - - // Test description over max length - let over_max_description = "A".repeat(MAX_DESCRIPTION_LENGTH + 1); - let over_max_request = UpdateSessionMetadataRequest { - description: over_max_description.clone(), - }; - assert_eq!(over_max_request.description, over_max_description); - assert!(over_max_request.description.len() > MAX_DESCRIPTION_LENGTH); - } - - #[tokio::test] - async fn test_description_length_validation() { - // Test the validation logic used in the endpoint - let valid_description = "A".repeat(MAX_DESCRIPTION_LENGTH); - assert!(valid_description.len() <= MAX_DESCRIPTION_LENGTH); - - let invalid_description = "A".repeat(MAX_DESCRIPTION_LENGTH + 1); - assert!(invalid_description.len() > MAX_DESCRIPTION_LENGTH); - - // Test edge cases - assert!(String::new().len() <= MAX_DESCRIPTION_LENGTH); // Empty string - assert!("Short".len() <= MAX_DESCRIPTION_LENGTH); // Short string - } -} diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index a98b19c9187b..a575c5277642 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -31,18 +31,18 @@ thiserror = "1.0" futures = "0.3" dirs = "5.0" reqwest = { version = "0.12.9", features = [ - "rustls-tls-native-roots", - "json", - "cookies", - "gzip", - "brotli", - "deflate", - "zstd", - "charset", - "http2", - "stream", - "blocking" - ], default-features = false } + "rustls-tls-native-roots", + "json", + "cookies", + "gzip", + "brotli", + "deflate", + "zstd", + "charset", + "http2", + "stream", + "blocking" +], default-features = false } tokio = { version = "1.43", features = ["full"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -79,6 +79,7 @@ rand = "0.8.5" utoipa = { version = "4.1", features = ["chrono"] } tokio-cron-scheduler = "0.14.0" urlencoding = "2.1" +sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "sqlite", "chrono", "json"] } # For Bedrock provider aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } @@ -100,7 +101,6 @@ ahash = "0.8" tokio-util = "0.7.15" unicode-normalization = "0.1" -arrow = "52.2" oauth2 = "5.0.0" [target.'cfg(target_os = "windows")'.dependencies] diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 4b53c82c8a4d..2ad6e54fa3d1 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -42,8 +42,6 @@ use crate::providers::errors::ProviderError; use 
crate::recipe::{Author, Recipe, Response, Settings, SubRecipe}; use crate::scheduler_trait::SchedulerTrait; use crate::security::security_inspector::SecurityInspector; -use crate::session; -use crate::session::extension_data::ExtensionState; use crate::tool_inspection::ToolInspectionManager; use crate::tool_monitor::RepetitionInspector; use crate::utils::is_token_cancelled; @@ -55,7 +53,7 @@ use rmcp::model::{ use serde_json::Value; use tokio::sync::{mpsc, Mutex}; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, info, instrument}; +use tracing::{debug, error, info, instrument, warn}; use super::final_output_tool::FinalOutputTool; use super::model_selector::autopilot::AutoPilot; @@ -66,12 +64,14 @@ use crate::agents::todo_tools::{ todo_read_tool, todo_write_tool, TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME, }; use crate::conversation::message::{Message, ToolRequest}; +use crate::session::extension_data::ExtensionState; +use crate::session::{extension_data, SessionManager}; const DEFAULT_MAX_TURNS: u32 = 1000; /// Context needed for the reply function pub struct ReplyContext { - pub messages: Conversation, + pub conversation: Conversation, pub tools: Vec, pub toolshim_tools: Vec, pub system_prompt: String, @@ -268,7 +268,7 @@ impl Agent { .await; Ok(ReplyContext { - messages: conversation, + conversation, tools, toolshim_tools, system_prompt, @@ -506,17 +506,12 @@ impl Agent { ))) } else if tool_call.name == TODO_READ_TOOL_NAME { // Handle task planner read tool - let session_file_path = if let Some(session_config) = session { - session::storage::get_path(session_config.id.clone()).ok() - } else { - None - }; - - let todo_content = if let Some(path) = session_file_path { - session::storage::read_metadata(&path) + let todo_content = if let Some(session_config) = session { + SessionManager::get_session(&session_config.id, false) + .await .ok() - .and_then(|m| { - session::TodoState::from_extension_data(&m.extension_data) + .and_then(|metadata| { + extension_data::TodoState::from_extension_data(&metadata.extension_data) .map(|state| state.content) }) .unwrap_or_default() @@ -551,43 +546,39 @@ impl Agent { None, ))) } else if let Some(session_config) = session { - // Update session metadata with new TODO content - match session::storage::get_path(session_config.id.clone()) { - Ok(path) => match session::storage::read_metadata(&path) { - Ok(mut metadata) => { - let todo_state = session::TodoState::new(content); - todo_state - .to_extension_data(&mut metadata.extension_data) - .ok(); - - let path_clone = path.clone(); - let metadata_clone = metadata.clone(); - let update_result = tokio::task::spawn(async move { - session::storage::update_metadata(&path_clone, &metadata_clone) - .await - }) - .await; - - match update_result { - Ok(Ok(_)) => ToolCallResult::from(Ok(vec![Content::text( - format!("Updated ({} chars)", char_count), - )])), - _ => ToolCallResult::from(Err(ErrorData::new( + match SessionManager::get_session(&session_config.id, false).await { + Ok(mut session) => { + let todo_state = extension_data::TodoState::new(content); + if todo_state + .to_extension_data(&mut session.extension_data) + .is_ok() + { + match SessionManager::update_session(&session_config.id) + .extension_data(session.extension_data) + .apply() + .await + { + Ok(_) => ToolCallResult::from(Ok(vec![Content::text(format!( + "Updated ({} chars)", + char_count + ))])), + Err(_) => ToolCallResult::from(Err(ErrorData::new( ErrorCode::INTERNAL_ERROR, "Failed to update session metadata".to_string(), None, ))), 
} + } else { + ToolCallResult::from(Err(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "Failed to serialize TODO state".to_string(), + None, + ))) } - Err(_) => ToolCallResult::from(Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - "Failed to read session metadata".to_string(), - None, - ))), - }, + } Err(_) => ToolCallResult::from(Err(ErrorData::new( ErrorCode::INTERNAL_ERROR, - "Failed to get session path".to_string(), + "Failed to read session metadata".to_string(), None, ))), } @@ -911,10 +902,9 @@ impl Agent { > { // Try to get session metadata for more accurate token counts let session_metadata = if let Some(session_config) = session { - match session::storage::get_path(session_config.id.clone()) { - Ok(session_file_path) => session::storage::read_metadata(&session_file_path).ok(), - Err(_) => None, - } + SessionManager::get_session(&session_config.id, false) + .await + .ok() } else { None }; @@ -960,7 +950,7 @@ impl Agent { cancel_token: Option, ) -> Result>> { // Handle auto-compaction before processing - let (messages, compaction_msg, _summarization_usage) = match self + let (conversation, compaction_msg, _summarization_usage) = match self .handle_auto_compaction(unfixed_conversation.messages(), &session) .await? { @@ -969,7 +959,7 @@ impl Agent { let context = self .prepare_reply_context(unfixed_conversation, &session) .await?; - (context.messages, None, None) + (context.conversation, None, None) } }; @@ -977,10 +967,13 @@ impl Agent { if let Some(compaction_msg) = compaction_msg { return Ok(Box::pin(async_stream::try_stream! { yield AgentEvent::Message(Message::assistant().with_summarization_requested(compaction_msg)); - yield AgentEvent::HistoryReplaced(messages.messages().clone()); + yield AgentEvent::HistoryReplaced(conversation.messages().clone()); + if let Some(session_to_store) = &session { + SessionManager::replace_conversation(&session_to_store.id, &conversation).await? + } // Continue with normal reply processing using compacted messages - let mut reply_stream = self.reply_internal(messages, session, cancel_token).await?; + let mut reply_stream = self.reply_internal(conversation, session, cancel_token).await?; while let Some(event) = reply_stream.next().await { yield event?; } @@ -988,19 +981,20 @@ impl Agent { } // No compaction needed, proceed with normal processing - self.reply_internal(messages, session, cancel_token).await + self.reply_internal(conversation, session, cancel_token) + .await } /// Main reply method that handles the actual agent processing async fn reply_internal( &self, - messages: Conversation, + conversation: Conversation, session: Option, cancel_token: Option, ) -> Result>> { - let context = self.prepare_reply_context(messages, &session).await?; + let context = self.prepare_reply_context(conversation, &session).await?; let ReplyContext { - mut messages, + mut conversation, mut tools, mut toolshim_tools, mut system_prompt, @@ -1011,12 +1005,52 @@ impl Agent { let reply_span = tracing::Span::current(); self.reset_retry_attempts().await; - if let Some(content) = messages - .last() - .and_then(|msg| msg.content.first()) - .and_then(|c| c.as_text()) - { - debug!("user_message" = &content); + // This will need further refactoring. In the ideal world we pass the new message into + // reply and load the existing conversation. Until we get to that point, fetch the conversation + // so far and append the last (user) message that the caller already added. 
+ if let Some(session_config) = &session { + let stored_conversation = SessionManager::get_session(&session_config.id, true) + .await? + .conversation + .ok_or_else(|| { + anyhow::anyhow!("Session {} has no conversation", session_config.id) + })?; + + match conversation.len().cmp(&stored_conversation.len()) { + std::cmp::Ordering::Equal => { + if conversation != stored_conversation { + warn!("Session messages mismatch - replacing with incoming"); + SessionManager::replace_conversation(&session_config.id, &conversation) + .await?; + } + } + std::cmp::Ordering::Greater + if conversation.len() == stored_conversation.len() + 1 => + { + let last_message = conversation.last().unwrap(); + if let Some(content) = last_message.content.first().and_then(|c| c.as_text()) { + debug!("user_message" = &content); + } + SessionManager::add_message(&session_config.id, last_message).await?; + } + _ => { + warn!( + "Unexpected session state: stored={}, incoming={}. Replacing.", + stored_conversation.len(), + conversation.len() + ); + SessionManager::replace_conversation(&session_config.id, &conversation).await?; + } + } + let provider = self.provider().await?; + let session_id = session_config.id.clone(); + tokio::spawn(async move { + if let Err(e) = + SessionManager::maybe_update_description(&session_id, provider).await + { + warn!("Failed to generate session description: {}", e); + } + }); } Ok(Box::pin(async_stream::try_stream! { @@ -1054,7 +1088,7 @@ impl Agent { { let mut autopilot = self.autopilot.lock().await; - if let Some((new_provider, role, model)) = autopilot.check_for_switch(&messages, self.provider().await?).await? { + if let Some((new_provider, role, model)) = autopilot.check_for_switch(&conversation, self.provider().await?).await? { debug!("AutoPilot switching to {} role with model {}", role, model); self.update_provider(new_provider).await?; @@ -1065,17 +1099,16 @@ impl Agent { } } - let mut stream = Self::stream_response_from_provider( self.provider().await?, &system_prompt, - messages.messages(), + conversation.messages(), &tools, &toolshim_tools, ).await?; - let mut added_message = false; - let mut messages_to_add = Vec::new(); + let mut no_tools_called = true; + let mut messages_to_add = Conversation::default(); let mut tools_updated = false; while let Some(next) = stream.next().await { @@ -1109,12 +1142,12 @@ impl Agent { // Record usage for the session if let Some(ref session_config) = &session { if let Some(ref usage) = usage { - Self::update_session_metrics(session_config, usage, messages.len()) - .await?; + Self::update_session_metrics(session_config, usage).await?; } } if let Some(response) = response { + messages_to_add.push(response.clone()); let ToolCategorizeResult { frontend_requests, remaining_requests, @@ -1161,7 +1194,7 @@ impl Agent { let inspection_results = self.tool_inspection_manager .inspect_tools( &remaining_requests, - messages.messages(), + conversation.messages(), ) .await?; @@ -1260,25 +1293,26 @@ impl Agent { let final_message_tool_resp = message_tool_response.lock().await.clone(); yield AgentEvent::Message(final_message_tool_resp.clone()); - added_message = true; - messages_to_add.push(response); + no_tools_called = false; messages_to_add.push(final_message_tool_resp); } } Err(ProviderError::ContextLengthExceeded(error_msg)) => { info!("Context length exceeded, attempting compaction"); - match auto_compact::perform_compaction(self, messages.messages()).await { + match auto_compact::perform_compaction(self, conversation.messages()).await { Ok(compact_result) => { 
- messages = compact_result.messages; + conversation = compact_result.messages; yield AgentEvent::Message( Message::assistant().with_summarization_requested( "Context limit reached. Conversation has been automatically compacted to continue." ) ); - yield AgentEvent::HistoryReplaced(messages.messages().to_vec()); - + yield AgentEvent::HistoryReplaced(conversation.messages().to_vec()); + if let Some(session_to_store) = &session { + SessionManager::replace_conversation(&session_to_store.id, &conversation).await? + } continue; } Err(_) => { @@ -1301,40 +1335,49 @@ impl Agent { if tools_updated { (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; } - if !added_message { + let mut exit_chat = false; + if no_tools_called { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { if final_output_tool.final_output.is_none() { - tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); + warn!("Final output tool has not been called yet. Continuing agent loop."); let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); messages_to_add.push(message.clone()); yield AgentEvent::Message(message); - messages.extend(messages_to_add); - continue } else { let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); messages_to_add.push(message.clone()); yield AgentEvent::Message(message); + exit_chat = true; } - } - - match self.handle_retry_logic(&mut messages, &session, &initial_messages).await { - Ok(should_retry) => { - if should_retry { - info!("Retry logic triggered, restarting agent loop"); - continue; + } else { + match self.handle_retry_logic(&mut conversation, &session, &initial_messages).await { + Ok(should_retry) => { + if should_retry { + info!("Retry logic triggered, restarting agent loop"); + } else { + exit_chat = true; + } + } + Err(e) => { + error!("Retry logic failed: {}", e); + yield AgentEvent::Message(Message::assistant().with_text( + format!("Retry logic encountered an error: {}", e) + )); + exit_chat = true; } - } - Err(e) => { - error!("Retry logic failed: {}", e); - yield AgentEvent::Message(Message::assistant().with_text( - format!("Retry logic encountered an error: {}", e) - )); } } - break; } - messages.extend(messages_to_add); + if let Some(session_config) = &session { + for msg in &messages_to_add { + SessionManager::add_message(&session_config.id, msg).await?; + } + } + conversation.extend(messages_to_add); + if exit_chat { + break; + } tokio::task::yield_now().await; } diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index ba6aa8d09284..641e65a0341c 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -15,7 +15,7 @@ use crate::providers::toolshim::{ modify_system_prompt_for_tool_json, OllamaInterpreter, }; -use crate::session; +use crate::session::SessionManager; use rmcp::model::Tool; async fn toolshim_postprocess( @@ -276,23 +276,9 @@ impl Agent { pub(crate) async fn update_session_metrics( session_config: &crate::agents::types::SessionConfig, usage: &ProviderUsage, - messages_length: usize, ) -> Result<()> { - let session_file_path = match session::storage::get_path(session_config.id.clone()) { - Ok(path) => path, - Err(e) => { - return Err(anyhow::anyhow!("Failed to get session file path: {}", e)); - } - }; - let mut metadata = session::storage::read_metadata(&session_file_path)?; - - metadata.schedule_id = session_config.schedule_id.clone(); - - 
metadata.total_tokens = usage.usage.total_tokens; - metadata.input_tokens = usage.usage.input_tokens; - metadata.output_tokens = usage.usage.output_tokens; - - metadata.message_count = messages_length + 1; + let session_id = session_config.id.as_str(); + let session = SessionManager::get_session(session_id, false).await?; let accumulate = |a: Option, b: Option| -> Option { match (a, b) { @@ -300,16 +286,24 @@ impl Agent { _ => a.or(b), } }; - metadata.accumulated_total_tokens = - accumulate(metadata.accumulated_total_tokens, usage.usage.total_tokens); - metadata.accumulated_input_tokens = - accumulate(metadata.accumulated_input_tokens, usage.usage.input_tokens); - metadata.accumulated_output_tokens = accumulate( - metadata.accumulated_output_tokens, - usage.usage.output_tokens, - ); - session::storage::update_metadata(&session_file_path, &metadata).await?; + let accumulated_total = + accumulate(session.accumulated_total_tokens, usage.usage.total_tokens); + let accumulated_input = + accumulate(session.accumulated_input_tokens, usage.usage.input_tokens); + let accumulated_output = + accumulate(session.accumulated_output_tokens, usage.usage.output_tokens); + + SessionManager::update_session(session_id) + .schedule_id(session_config.schedule_id.clone()) + .total_tokens(usage.usage.total_tokens) + .input_tokens(usage.usage.input_tokens) + .output_tokens(usage.usage.output_tokens) + .accumulated_total_tokens(accumulated_total) + .accumulated_input_tokens(accumulated_input) + .accumulated_output_tokens(accumulated_output) + .apply() + .await?; Ok(()) } diff --git a/crates/goose/src/agents/schedule_tool.rs b/crates/goose/src/agents/schedule_tool.rs index 7533ede8b03e..d38ab800bbd8 100644 --- a/crates/goose/src/agents/schedule_tool.rs +++ b/crates/goose/src/agents/schedule_tool.rs @@ -421,12 +421,12 @@ impl Agent { } else { let sessions_info: Vec = sessions .into_iter() - .map(|(session_name, metadata)| { + .map(|(session_name, session)| { format!( "- Session: {} (Messages: {}, Working Dir: {})", session_name, - metadata.message_count, - metadata.working_dir.display() + session.conversation.unwrap_or_default().len(), + session.working_dir.display() ) }) .collect(); @@ -462,55 +462,19 @@ impl Agent { ) })?; - // Get the session file path - let session_path = match crate::session::storage::get_path( - crate::session::storage::Identifier::Name(session_id.to_string()), - ) { - Ok(path) => path, - Err(e) => { - return Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Invalid session ID '{}': {}", session_id, e), - None, - )); - } - }; - - // Check if session file exists - if !session_path.exists() { - return Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Session '{}' not found", session_id), - None, - )); - } - - // Read session metadata - let metadata = match crate::session::storage::read_metadata(&session_path) { + let session = match crate::session::SessionManager::get_session(session_id, true).await { Ok(metadata) => metadata, Err(e) => { return Err(ErrorData::new( ErrorCode::INTERNAL_ERROR, - format!("Failed to read session metadata: {}", e), - None, - )); - } - }; - - // Read session messages - let messages = match crate::session::storage::read_messages(&session_path) { - Ok(messages) => messages, - Err(e) => { - return Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Failed to read session messages: {}", e), + format!("Failed to read session for '{}': {}", session_id, e), None, )); } }; // Format the response with metadata and messages - let metadata_json = match 
serde_json::to_string_pretty(&metadata) { + let metadata_json = match serde_json::to_string_pretty(&session) { Ok(json) => json, Err(e) => { return Err(ErrorData::new( @@ -521,20 +485,9 @@ impl Agent { } }; - let messages_json = match serde_json::to_string_pretty(&messages) { - Ok(json) => json, - Err(e) => { - return Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Failed to serialize messages: {}", e), - None, - )); - } - }; - Ok(vec![Content::text(format!( - "Session '{}' Content:\n\nMetadata:\n{}\n\nMessages:\n{}", - session_id, metadata_json, messages_json + "Session '{}' Content:\n\nSession:\n{}", + session_id, metadata_json ))]) } } diff --git a/crates/goose/src/agents/types.rs b/crates/goose/src/agents/types.rs index 8cb2c8fbadfb..671ee1499353 100644 --- a/crates/goose/src/agents/types.rs +++ b/crates/goose/src/agents/types.rs @@ -1,4 +1,3 @@ -use crate::session; use mcp_core::ToolResult; use rmcp::model::{Content, Tool}; use serde::{Deserialize, Serialize}; @@ -82,7 +81,7 @@ pub struct FrontendTool { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SessionConfig { /// Unique identifier for the session - pub id: session::Identifier, + pub id: String, /// Working directory for the session pub working_dir: PathBuf, /// ID of the schedule that triggered this session, if any diff --git a/crates/goose/src/context_mgmt/auto_compact.rs b/crates/goose/src/context_mgmt/auto_compact.rs index fa03a941405f..01941af6141c 100644 --- a/crates/goose/src/context_mgmt/auto_compact.rs +++ b/crates/goose/src/context_mgmt/auto_compact.rs @@ -55,7 +55,7 @@ pub async fn check_compaction_needed( agent: &Agent, messages: &[Message], threshold_override: Option, - session_metadata: Option<&crate::session::storage::SessionMetadata>, + session_metadata: Option<&crate::session::Session>, ) -> Result { // Get threshold from config or use override let config = Config::global(); @@ -182,7 +182,7 @@ pub async fn check_and_compact_messages( agent: &Agent, messages: &[Message], threshold_override: Option, - session_metadata: Option<&crate::session::storage::SessionMetadata>, + session_metadata: Option<&crate::session::Session>, ) -> Result { // First check if compaction is needed let check_result = @@ -242,6 +242,7 @@ pub async fn check_and_compact_messages( mod tests { use super::*; use crate::conversation::message::{Message, MessageContent}; + use crate::session::extension_data; use crate::{ agents::Agent, model::ModelConfig, @@ -303,21 +304,32 @@ mod tests { fn create_test_session_metadata( message_count: usize, working_dir: &str, - ) -> crate::session::storage::SessionMetadata { + ) -> crate::session::Session { + use crate::conversation::Conversation; use std::path::PathBuf; - crate::session::storage::SessionMetadata { - message_count, + + let mut conversation = Conversation::default(); + for i in 0..message_count { + conversation.push(create_test_message(format!("message {}", i).as_str())); + } + + crate::session::Session { + id: "test_session".to_string(), working_dir: PathBuf::from(working_dir), description: "Test session".to_string(), + created_at: "2024-01-01T00:00:00Z".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), schedule_id: Some("test_job".to_string()), + recipe: None, total_tokens: Some(100), input_tokens: Some(50), output_tokens: Some(50), accumulated_total_tokens: Some(100), accumulated_input_tokens: Some(50), accumulated_output_tokens: Some(50), - extension_data: crate::session::ExtensionData::new(), - recipe: None, + extension_data: 
extension_data::ExtensionData::new(), + conversation: Some(conversation), + message_count, } } @@ -540,7 +552,7 @@ mod tests { #[tokio::test] async fn test_auto_compact_uses_session_metadata() { - use crate::session::storage::SessionMetadata; + use crate::session::Session; let mock_provider = Arc::new(MockProvider { model_config: ModelConfig::new("test-model") @@ -557,22 +569,22 @@ mod tests { create_test_message("Second message"), ]; - // Create session metadata with specific token counts + // Create session with specific token counts #[allow(clippy::field_reassign_with_default)] - let mut session_metadata = SessionMetadata::default(); + let mut session = Session::default(); { - session_metadata.total_tokens = Some(8000); // High token count to trigger compaction - session_metadata.accumulated_total_tokens = Some(15000); // Even higher accumulated count - session_metadata.input_tokens = Some(5000); - session_metadata.output_tokens = Some(3000); + session.total_tokens = Some(8000); // High token count to trigger compaction + session.accumulated_total_tokens = Some(15000); // Even higher accumulated count + session.input_tokens = Some(5000); + session.output_tokens = Some(3000); } - // Test with session metadata - should use total_tokens for compaction (not accumulated) + // Test with session - should use total_tokens for compaction (not accumulated) let result_with_metadata = check_compaction_needed( &agent, &messages, Some(0.3), // 30% threshold - Some(&session_metadata), + Some(&session), ) .await .unwrap(); @@ -595,8 +607,8 @@ mod tests { assert!(!result_without_metadata.needs_compaction); assert!(result_without_metadata.current_tokens < 8000); - // Test with metadata that has only accumulated tokens (no total_tokens) - let mut session_metadata_no_total = SessionMetadata::default(); + // Test with session that has only accumulated tokens (no total_tokens) + let mut session_metadata_no_total = Session::default(); #[allow(clippy::field_reassign_with_default)] { session_metadata_no_total.accumulated_total_tokens = Some(7500); @@ -616,7 +628,7 @@ mod tests { assert!(result_with_no_total.current_tokens < 7500); // Test with metadata that has no token counts - should fall back to estimation - let empty_metadata = SessionMetadata::default(); + let empty_metadata = Session::default(); let result_with_empty_metadata = check_compaction_needed( &agent, @@ -634,7 +646,7 @@ mod tests { #[tokio::test] async fn test_auto_compact_end_to_end_with_metadata() { - use crate::session::storage::SessionMetadata; + use crate::session::Session; let mock_provider = Arc::new(MockProvider { model_config: ModelConfig::new("test-model") @@ -655,10 +667,10 @@ mod tests { ]; // Create session metadata with high token count to trigger compaction - let mut session_metadata = SessionMetadata::default(); + let mut session = Session::default(); #[allow(clippy::field_reassign_with_default)] { - session_metadata.total_tokens = Some(9000); // High enough to trigger compaction + session.total_tokens = Some(9000); // High enough to trigger compaction } // Test full compaction flow with session metadata @@ -666,7 +678,7 @@ mod tests { &agent, &messages, Some(0.3), // 30% threshold - Some(&session_metadata), + Some(&session), ) .await .unwrap(); @@ -704,7 +716,14 @@ mod tests { let comprehensive_metadata = create_test_session_metadata(3, "/test/working/dir"); // Verify the helper created non-null metadata - assert_eq!(comprehensive_metadata.message_count, 3); + assert_eq!( + comprehensive_metadata + .clone() + .conversation + 
.unwrap_or_default() + .len(), + 3 + ); assert_eq!( comprehensive_metadata.working_dir.to_str().unwrap(), "/test/working/dir" diff --git a/crates/goose/src/conversation/mod.rs b/crates/goose/src/conversation/mod.rs index 132e691e1a36..9d4b63924e53 100644 --- a/crates/goose/src/conversation/mod.rs +++ b/crates/goose/src/conversation/mod.rs @@ -3,11 +3,12 @@ use rmcp::model::Role; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use thiserror::Error; +use utoipa::ToSchema; pub mod message; mod tool_result_serde; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)] pub struct Conversation(Vec); #[derive(Error, Debug)] @@ -122,6 +123,23 @@ impl Default for Conversation { } } +impl IntoIterator for Conversation { + type Item = Message; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} +impl<'a> IntoIterator for &'a Conversation { + type Item = &'a Message; + type IntoIter = std::slice::Iter<'a, Message>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + /// Fix a conversation that we're about to send to an LLM. So the last and first /// messages should always be from the user. pub fn fix_conversation(conversation: Conversation) -> (Conversation, Vec) { diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 3cc034269156..26a1716a75dc 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -31,6 +31,8 @@ pub fn get_current_model() -> Option { CURRENT_MODEL.lock().ok().and_then(|model| model.clone()) } +pub static MSG_COUNT_FOR_SESSION_NAME_GENERATION: usize = 3; + /// Information about a model's capabilities #[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)] pub struct ModelInfo { @@ -454,7 +456,7 @@ pub trait Provider: Send + Sync { messages .iter() .filter(|m| m.role == rmcp::model::Role::User) - .take(3) + .take(MSG_COUNT_FOR_SESSION_NAME_GENERATION) .map(|m| m.as_concat_text()) .collect() } @@ -476,7 +478,12 @@ pub trait Provider: Send + Sync { ) .await?; - let description = result.0.as_concat_text(); + let description = result + .0 + .as_concat_text() + .split_whitespace() + .collect::>() + .join(" "); Ok(safe_truncate(&description, 100)) } diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index d688c2311021..503c86584ada 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -21,8 +21,7 @@ use crate::providers::base::Provider as GooseProvider; // Alias to avoid conflic use crate::providers::create; use crate::recipe::Recipe; use crate::scheduler_trait::SchedulerTrait; -use crate::session; -use crate::session::storage::SessionMetadata; +use crate::session::{Session, SessionManager}; // Track running tasks with their abort handles type RunningTasksMap = HashMap; @@ -649,38 +648,24 @@ impl Scheduler { &self, sched_id: &str, limit: usize, - ) -> Result, SchedulerError> { - // Changed return type - let all_session_files = session::storage::list_sessions() + ) -> Result, SchedulerError> { + let all_sessions = SessionManager::list_sessions() + .await .map_err(|e| SchedulerError::StorageError(io::Error::other(e)))?; - let mut schedule_sessions: Vec<(String, SessionMetadata)> = Vec::new(); + let mut schedule_sessions: Vec<(String, Session)> = Vec::new(); - for (session_name, session_path) in all_session_files { - match session::storage::read_metadata(&session_path) { - Ok(metadata) => { - // metadata 
is not mutable here, and SessionMetadata is original - if metadata.schedule_id.as_deref() == Some(sched_id) { - schedule_sessions.push((session_name, metadata)); // Keep the tuple - } - } - Err(e) => { - tracing::warn!( - "Failed to read metadata for session file {}: {}. Skipping.", - session_path.display(), - e - ); - } + for session in all_sessions { + if session.schedule_id.as_deref() == Some(sched_id) { + schedule_sessions.push((session.id.clone(), session)); } } + schedule_sessions.sort_by(|a, b| b.0.cmp(&a.0)); - schedule_sessions.sort_by(|a, b| b.0.cmp(&a.0)); // Sort by session_name (timestamp string) - - // Keep the tuple, just take the limit - let result_sessions: Vec<(String, SessionMetadata)> = + let result_sessions: Vec<(String, Session)> = schedule_sessions.into_iter().take(limit).collect(); - Ok(result_sessions) // Return the Vec of tuples + Ok(result_sessions) } pub async fn run_now(&self, sched_id: &str) -> Result { @@ -1066,7 +1051,7 @@ struct JobExecutionError { async fn run_scheduled_job_internal( job: ScheduledJob, - provider_override: Option>, // New optional parameter + provider_override: Option>, jobs_arc: Option>>, job_id: Option, ) -> std::result::Result { @@ -1116,7 +1101,7 @@ async fn run_scheduled_job_internal( let agent: Agent = Agent::new(); - let agent_provider: Arc; // Use the aliased GooseProvider + let agent_provider: Arc; if let Some(provider) = provider_override { agent_provider = provider; @@ -1155,7 +1140,8 @@ async fn run_scheduled_job_internal( ), })?; } - if let Some(recipe_extensions) = recipe.extensions { + + if let Some(ref recipe_extensions) = recipe.extensions { for extension in recipe_extensions { agent .add_extension(extension.clone()) @@ -1174,49 +1160,53 @@ async fn run_scheduled_job_internal( }); } tracing::info!("Agent configured with provider for job '{}'", job.id); - - // Log the execution mode let execution_mode = job.execution_mode.as_deref().unwrap_or("background"); tracing::info!("Job '{}' running in {} mode", job.id, execution_mode); - let session_id_for_return = session::generate_session_id(); - - // Update the job with the session ID if we have access to the jobs arc - if let (Some(jobs_arc), Some(job_id_str)) = (jobs_arc.as_ref(), job_id.as_ref()) { - let mut jobs_guard = jobs_arc.lock().await; - if let Some((_, job_def)) = jobs_guard.get_mut(job_id_str) { - job_def.current_session_id = Some(session_id_for_return.clone()); + let current_dir = match std::env::current_dir() { + Ok(cd) => cd, + Err(e) => { + return Err(JobExecutionError { + job_id: job.id.clone(), + error: format!("Failed to get current directory for job execution: {}", e), + }); } - } + }; - let session_file_path = match crate::session::storage::get_path( - crate::session::storage::Identifier::Name(session_id_for_return.clone()), - ) { - Ok(path) => path, + // Create session upfront for both cases + let session = match SessionManager::create_session( + current_dir.clone(), + if recipe.prompt.is_some() { + format!("Scheduled job: {}", job.id) + } else { + "Empty job - no prompt".to_string() + }, + ) + .await + { + Ok(s) => s, Err(e) => { return Err(JobExecutionError { job_id: job.id.clone(), - error: format!("Failed to get session file path: {}", e), + error: format!("Failed to create session: {}", e), }); } }; - if let Some(prompt_text) = recipe.prompt { + // Update the job with the session ID if we have access to the jobs arc + if let (Some(jobs_arc), Some(job_id_str)) = (jobs_arc.as_ref(), job_id.as_ref()) { + let mut jobs_guard = jobs_arc.lock().await; + if 
let Some((_, job_def)) = jobs_guard.get_mut(job_id_str) { + job_def.current_session_id = Some(session.id.clone()); + } + } + + if let Some(ref prompt_text) = recipe.prompt { let mut all_session_messages = Conversation::new_unvalidated(vec![Message::user().with_text(prompt_text.clone())]); - let current_dir = match std::env::current_dir() { - Ok(cd) => cd, - Err(e) => { - return Err(JobExecutionError { - job_id: job.id.clone(), - error: format!("Failed to get current directory for job execution: {}", e), - }); - } - }; - let session_config = SessionConfig { - id: crate::session::storage::Identifier::Name(session_id_for_return.clone()), + id: session.id.clone(), working_dir: current_dir.clone(), schedule_id: Some(job.id.clone()), execution_mode: job.execution_mode.clone(), @@ -1236,7 +1226,6 @@ async fn run_scheduled_job_internal( use futures::StreamExt; while let Some(message_result) = stream.next().await { - // Check if the task has been cancelled tokio::task::yield_now().await; match message_result { @@ -1246,15 +1235,9 @@ async fn run_scheduled_job_internal( } all_session_messages.push(msg); } - Ok(AgentEvent::McpNotification(_)) => { - // Handle notifications if needed - } - Ok(AgentEvent::ModelChange { .. }) => { - // Model change events are informational, just continue - } - Ok(AgentEvent::HistoryReplaced(_)) => { - // Handle history replacement events if needed - } + Ok(AgentEvent::McpNotification(_)) => {} + Ok(AgentEvent::ModelChange { .. }) => {} + Ok(AgentEvent::HistoryReplaced(_)) => {} Err(e) => { tracing::error!( "[Job {}] Error receiving message from agent: {}", @@ -1265,51 +1248,6 @@ async fn run_scheduled_job_internal( } } } - - match crate::session::storage::read_metadata(&session_file_path) { - Ok(mut updated_metadata) => { - updated_metadata.message_count = all_session_messages.len(); - if let Err(e) = crate::session::storage::save_messages_with_metadata( - &session_file_path, - &updated_metadata, - &all_session_messages, - ) { - tracing::error!( - "[Job {}] Failed to persist final messages: {}", - job.id, - e - ); - } - } - Err(e) => { - tracing::error!( - "[Job {}] Failed to read updated metadata before final save: {}", - job.id, - e - ); - let fallback_metadata = crate::session::storage::SessionMetadata { - working_dir: current_dir.clone(), - description: String::new(), - schedule_id: Some(job.id.clone()), - message_count: all_session_messages.len(), - total_tokens: None, - input_tokens: None, - output_tokens: None, - accumulated_total_tokens: None, - accumulated_input_tokens: None, - accumulated_output_tokens: None, - extension_data: crate::session::ExtensionData::new(), - recipe: None, - }; - if let Err(e_fb) = crate::session::storage::save_messages_with_metadata( - &session_file_path, - &fallback_metadata, - &all_session_messages, - ) { - tracing::error!("[Job {}] Failed to persist final messages with fallback metadata: {}", job.id, e_fb); - } - } - } } Err(e) => { return Err(JobExecutionError { @@ -1324,28 +1262,19 @@ async fn run_scheduled_job_internal( job.id, job.source ); - let metadata = crate::session::storage::SessionMetadata { - working_dir: std::env::current_dir().unwrap_or_default(), - description: "Empty job - no prompt".to_string(), - schedule_id: Some(job.id.clone()), - message_count: 0, - ..Default::default() - }; - if let Err(e) = crate::session::storage::save_messages_with_metadata( - &session_file_path, - &metadata, - &Conversation::new_unvalidated(vec![]), - ) { - tracing::error!( - "[Job {}] Failed to persist metadata for empty job: {}", - 
job.id, - e - ); - } + } + + if let Err(e) = SessionManager::update_session(&session.id) + .schedule_id(Some(job.id.clone())) + .recipe(Some(recipe)) + .apply() + .await + { + tracing::error!("[Job {}] Failed to update session metadata: {}", job.id, e); } tracing::info!("Finished job: {}", job.id); - Ok(session_id_for_return) + Ok(session.id) } #[async_trait] @@ -1378,7 +1307,7 @@ impl SchedulerTrait for Scheduler { &self, sched_id: &str, limit: usize, - ) -> Result, SchedulerError> { + ) -> Result, SchedulerError> { self.sessions(sched_id, limit).await } @@ -1407,15 +1336,12 @@ mod tests { use super::*; use crate::recipe::Recipe; use crate::{ - model::ModelConfig, // Use the actual ModelConfig for the mock's field + model::ModelConfig, providers::base::{ProviderMetadata, ProviderUsage, Usage}, providers::errors::ProviderError, }; use rmcp::model::Tool; use rmcp::model::{AnnotateAble, RawTextContent, Role}; - // Removed: use crate::session::storage::{get_most_recent_session, read_metadata}; - // `read_metadata` is still used by the test itself, so keep it or its module. - use crate::session::storage::read_metadata; use crate::conversation::message::{Message, MessageContent}; use std::env; @@ -1488,7 +1414,8 @@ mod tests { let recipe_dir = temp_dir.path().join("recipes_for_test_scheduler"); fs::create_dir_all(&recipe_dir)?; - let _ = session::storage::ensure_session_dir().expect("Failed to ensure app session dir"); + let _ = crate::session::session_manager::ensure_session_dir() + .expect("Failed to ensure app session dir"); let schedule_id_str = "test_schedule_001_scheduler_check".to_string(); let recipe_filename = recipe_dir.join(format!("{}.json", schedule_id_str)); @@ -1539,39 +1466,31 @@ mod tests { .await .expect("run_scheduled_job_internal failed"); - let session_dir = session::storage::ensure_session_dir()?; - let expected_session_path = session_dir.join(format!("{}.jsonl", created_session_id)); - - assert!( - expected_session_path.exists(), - "Expected session file {} was not created", - expected_session_path.display() - ); - - let metadata = read_metadata(&expected_session_path)?; + let session = SessionManager::get_session(&created_session_id, true).await?; + let schedule_id = session.schedule_id.clone(); assert_eq!( - metadata.schedule_id, + schedule_id, Some(schedule_id_str.clone()), - "Session metadata schedule_id ({:?}) does not match the job ID ({}). File: {}", - metadata.schedule_id, + "Session metadata schedule_id ({:?}) does not match the job ID ({}). 
Session: {}", + schedule_id, schedule_id_str, - expected_session_path.display() + created_session_id ); - // Check if messages were written - let messages_in_file = crate::session::storage::read_messages(&expected_session_path)?; + // Check if messages were written using SessionManager + let messages_in_session = session.conversation.unwrap_or_default(); assert!( - !messages_in_file.is_empty(), - "No messages were written to the session file: {}", - expected_session_path.display() + !messages_in_session.is_empty(), + "No messages were written to the session: {}", + created_session_id ); // We expect at least a user prompt and an assistant response assert!( - messages_in_file.len() >= 2, - "Expected at least 2 messages (prompt + response), found {} in file: {}", - messages_in_file.len(), - expected_session_path.display() + messages_in_session.len() >= 2, + "Expected at least 2 messages (prompt + response), found {} in session: {}", + messages_in_session.len(), + created_session_id ); // Clean up environment variables diff --git a/crates/goose/src/scheduler_trait.rs b/crates/goose/src/scheduler_trait.rs index f23b124c2e3b..c4ef864576d9 100644 --- a/crates/goose/src/scheduler_trait.rs +++ b/crates/goose/src/scheduler_trait.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use crate::scheduler::{ScheduledJob, SchedulerError}; -use crate::session::storage::SessionMetadata; +use crate::session::Session; /// Common trait for all scheduler implementations #[async_trait] @@ -30,7 +30,7 @@ pub trait SchedulerTrait: Send + Sync { &self, sched_id: &str, limit: usize, - ) -> Result, SchedulerError>; + ) -> Result, SchedulerError>; /// Update a schedule's cron expression async fn update_schedule(&self, sched_id: &str, new_cron: String) diff --git a/crates/goose/src/session/info.rs b/crates/goose/src/session/info.rs deleted file mode 100644 index 9ebedb1aeef3..000000000000 --- a/crates/goose/src/session/info.rs +++ /dev/null @@ -1,99 +0,0 @@ -use crate::session::{self, SessionMetadata}; -use anyhow::Result; -use serde::Serialize; -use std::cmp::Ordering; -use utoipa::ToSchema; - -#[derive(Clone, Serialize, ToSchema)] -pub struct SessionInfo { - pub id: String, - pub path: String, - pub modified: String, - pub metadata: SessionMetadata, -} - -/// Sort order for listing sessions -pub enum SortOrder { - Ascending, - Descending, -} - -pub fn get_valid_sorted_sessions(sort_order: SortOrder) -> Result> { - let sessions = match session::list_sessions() { - Ok(sessions) => sessions, - Err(e) => { - tracing::error!("Failed to list sessions: {:?}", e); - return Err(anyhow::anyhow!("Failed to list sessions")); - } - }; - - let mut session_infos: Vec = Vec::new(); - let mut corrupted_count = 0; - - for (id, path) in sessions { - // Get file modification time with fallback - let modified = path - .metadata() - .and_then(|m| m.modified()) - .map(|time| { - chrono::DateTime::::from(time) - .format("%Y-%m-%d %H:%M:%S UTC") - .to_string() - }) - .unwrap_or_else(|_| { - tracing::warn!("Failed to get modification time for session: {}", id); - "Unknown".to_string() - }); - - // Try to read metadata with error handling - match session::read_metadata(&path) { - Ok(metadata) => { - session_infos.push(SessionInfo { - id, - path: path.to_string_lossy().to_string(), - modified, - metadata, - }); - } - Err(e) => { - corrupted_count += 1; - tracing::warn!( - "Failed to read metadata for session '{}': {}. 
Skipping corrupted session.", - id, - e - ); - - // Optionally, we could create a placeholder entry for corrupted sessions - // to show them in the UI with an error indicator, but for now we skip them - continue; - } - } - } - - if corrupted_count > 0 { - tracing::warn!( - "Skipped {} corrupted sessions during listing", - corrupted_count - ); - } - - // Sort sessions by modified date - // Since all dates are in ISO format (YYYY-MM-DD HH:MM:SS UTC), we can just use string comparison - // This works because the ISO format ensures lexicographical ordering matches chronological ordering - session_infos.sort_by(|a, b| { - if a.modified == "Unknown" && b.modified == "Unknown" { - return Ordering::Equal; - } else if a.modified == "Unknown" { - return Ordering::Greater; // Unknown dates go last - } else if b.modified == "Unknown" { - return Ordering::Less; - } - - match sort_order { - SortOrder::Ascending => a.modified.cmp(&b.modified), - SortOrder::Descending => b.modified.cmp(&a.modified), - } - }); - - Ok(session_infos) -} diff --git a/crates/goose/src/session/legacy.rs b/crates/goose/src/session/legacy.rs new file mode 100644 index 000000000000..3de432ce8228 --- /dev/null +++ b/crates/goose/src/session/legacy.rs @@ -0,0 +1,113 @@ +use crate::conversation::Conversation; +use crate::session::Session; +use anyhow::Result; +use chrono::NaiveDateTime; +use std::fs; +use std::io::{self, BufRead}; +use std::path::{Path, PathBuf}; +use std::time::SystemTime; + +const MAX_FILE_SIZE: u64 = 50 * 1024 * 1024; + +pub fn list_sessions(session_dir: &PathBuf) -> Result> { + let entries = fs::read_dir(session_dir)? + .filter_map(|entry| { + let entry = entry.ok()?; + let path = entry.path(); + + if path.extension().is_some_and(|ext| ext == "jsonl") { + let name = path.file_stem()?.to_string_lossy().to_string(); + Some((name, path)) + } else { + None + } + }) + .collect::>(); + + Ok(entries) +} + +pub fn load_session(session_name: &str, session_path: &Path) -> Result { + let file = fs::File::open(session_path).map_err(|e| { + anyhow::anyhow!( + "Failed to open session file {}: {}", + session_path.display(), + e + ) + })?; + + let file_metadata = file.metadata()?; + + if file_metadata.len() > MAX_FILE_SIZE { + return Err(anyhow::anyhow!("Session file too large")); + } + if file_metadata.len() == 0 { + return Err(anyhow::anyhow!("Empty session file")); + } + + let modified_time = file_metadata.modified().unwrap_or(SystemTime::now()); + let created_time = file_metadata + .created() + .unwrap_or_else(|_| parse_session_timestamp(session_name).unwrap_or(modified_time)); + + let reader = io::BufReader::new(file); + let mut lines = reader.lines(); + let mut messages = Vec::new(); + let mut session = Session { + id: session_name.to_string(), + ..Default::default() + }; + + if let Some(Ok(line)) = lines.next() { + let mut metadata_json: serde_json::Value = serde_json::from_str(&line) + .map_err(|_| anyhow::anyhow!("Invalid session metadata JSON"))?; + + if let Some(obj) = metadata_json.as_object_mut() { + obj.entry("id").or_insert(serde_json::json!(session_name)); + obj.entry("created_at") + .or_insert(serde_json::json!(format_timestamp(created_time)?)); + obj.entry("updated_at") + .or_insert(serde_json::json!(format_timestamp(modified_time)?)); + obj.entry("extension_data").or_insert(serde_json::json!({})); + obj.entry("message_count").or_insert(serde_json::json!(0)); + + if let Some(desc) = obj.get_mut("description") { + if let Some(desc_str) = desc.as_str() { + *desc = serde_json::json!(desc_str + .split_whitespace() 
+                        .collect::<Vec<&str>>()
+                        .join(" "));
+                }
+            }
+        }
+        session = serde_json::from_value(metadata_json)?;
+        session.id = session_name.to_string();
+    }
+
+    for line in lines.map_while(Result::ok) {
+        if let Ok(message) = serde_json::from_str(&line) {
+            messages.push(message);
+        }
+    }
+
+    if !messages.is_empty() {
+        session.conversation = Some(Conversation::new_unvalidated(messages));
+    }
+
+    Ok(session)
+}
+
+fn format_timestamp(time: SystemTime) -> Result<String> {
+    let duration = time.duration_since(std::time::UNIX_EPOCH)?;
+    let timestamp = chrono::DateTime::from_timestamp(duration.as_secs() as i64, 0)
+        .unwrap_or_default()
+        .format("%Y-%m-%d %H:%M:%S")
+        .to_string();
+    Ok(timestamp)
+}
+
+fn parse_session_timestamp(session_name: &str) -> Option<SystemTime> {
+    NaiveDateTime::parse_from_str(session_name, "%Y%m%d_%H%M%S")
+        .ok()
+        .map(|dt| SystemTime::from(dt.and_utc()))
+}
diff --git a/crates/goose/src/session/mod.rs b/crates/goose/src/session/mod.rs
index 97381c7e522f..5ed8312ecd54 100644
--- a/crates/goose/src/session/mod.rs
+++ b/crates/goose/src/session/mod.rs
@@ -1,14 +1,5 @@
 pub mod extension_data;
-pub mod info;
-pub mod storage;
+mod legacy;
+pub mod session_manager;
 
-// Re-export common session types and functions
-pub use storage::{
-    ensure_session_dir, generate_description, generate_description_with_schedule_id,
-    generate_session_id, get_most_recent_session, get_path, list_sessions, persist_messages,
-    persist_messages_with_schedule_id, read_messages, read_metadata, update_metadata, Identifier,
-    SessionMetadata,
-};
-
-pub use extension_data::{ExtensionData, ExtensionState, TodoState};
-pub use info::{get_valid_sorted_sessions, SessionInfo};
+pub use session_manager::{Session, SessionInsights, SessionManager};
diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs
new file mode 100644
index 000000000000..6bbd57bdce94
--- /dev/null
+++ b/crates/goose/src/session/session_manager.rs
@@ -0,0 +1,858 @@
+use crate::config::APP_STRATEGY;
+use crate::conversation::message::Message;
+use crate::conversation::Conversation;
+use crate::providers::base::{Provider, MSG_COUNT_FOR_SESSION_NAME_GENERATION};
+use crate::recipe::Recipe;
+use crate::session::extension_data::ExtensionData;
+use anyhow::Result;
+use etcetera::{choose_app_strategy, AppStrategy};
+use rmcp::model::Role;
+use serde::{Deserialize, Serialize};
+use sqlx::sqlite::SqliteConnectOptions;
+use sqlx::{Pool, Sqlite};
+use std::fs;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+use tokio::sync::OnceCell;
+use tracing::{info, warn};
+use utoipa::ToSchema;
+
+const CURRENT_SCHEMA_VERSION: i32 = 1;
+
+static SESSION_STORAGE: OnceCell<Arc<SessionStorage>> = OnceCell::const_new();
+
+#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
+pub struct Session {
+    pub id: String,
+    #[schema(value_type = String)]
+    pub working_dir: PathBuf,
+    pub description: String,
+    pub created_at: String,
+    pub updated_at: String,
+    pub extension_data: ExtensionData,
+    pub total_tokens: Option<i32>,
+    pub input_tokens: Option<i32>,
+    pub output_tokens: Option<i32>,
+    pub accumulated_total_tokens: Option<i32>,
+    pub accumulated_input_tokens: Option<i32>,
+    pub accumulated_output_tokens: Option<i32>,
+    pub schedule_id: Option<String>,
+    pub recipe: Option<Recipe>,
+    pub conversation: Option<Conversation>,
+    pub message_count: usize,
+}
+
+pub struct SessionUpdateBuilder {
+    session_id: String,
+    description: Option<String>,
+    working_dir: Option<PathBuf>,
+    extension_data: Option<ExtensionData>,
+    total_tokens: Option<Option<i32>>,
+    input_tokens: Option<Option<i32>>,
+    output_tokens: Option<Option<i32>>,
+    accumulated_total_tokens: Option<Option<i32>>,
+    accumulated_input_tokens: Option<Option<i32>>,
+    accumulated_output_tokens: Option<Option<i32>>,
+    schedule_id: Option<Option<String>>,
+    recipe: Option<Option<Recipe>>,
+}
+
+#[derive(Serialize, ToSchema, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionInsights {
+    /// Total number of sessions
+    total_sessions: usize,
+    /// Total tokens used across all sessions
+    total_tokens: i64,
+}
+
+impl SessionUpdateBuilder {
+    fn new(session_id: String) -> Self {
+        Self {
+            session_id,
+            description: None,
+            working_dir: None,
+            extension_data: None,
+            total_tokens: None,
+            input_tokens: None,
+            output_tokens: None,
+            accumulated_total_tokens: None,
+            accumulated_input_tokens: None,
+            accumulated_output_tokens: None,
+            schedule_id: None,
+            recipe: None,
+        }
+    }
+
+    pub fn description(mut self, description: impl Into<String>) -> Self {
+        self.description = Some(description.into());
+        self
+    }
+
+    pub fn working_dir(mut self, working_dir: PathBuf) -> Self {
+        self.working_dir = Some(working_dir);
+        self
+    }
+
+    pub fn extension_data(mut self, data: ExtensionData) -> Self {
+        self.extension_data = Some(data);
+        self
+    }
+
+    pub fn total_tokens(mut self, tokens: Option<i32>) -> Self {
+        self.total_tokens = Some(tokens);
+        self
+    }
+
+    pub fn input_tokens(mut self, tokens: Option<i32>) -> Self {
+        self.input_tokens = Some(tokens);
+        self
+    }
+
+    pub fn output_tokens(mut self, tokens: Option<i32>) -> Self {
+        self.output_tokens = Some(tokens);
+        self
+    }
+
+    pub fn accumulated_total_tokens(mut self, tokens: Option<i32>) -> Self {
+        self.accumulated_total_tokens = Some(tokens);
+        self
+    }
+
+    pub fn accumulated_input_tokens(mut self, tokens: Option<i32>) -> Self {
+        self.accumulated_input_tokens = Some(tokens);
+        self
+    }
+
+    pub fn accumulated_output_tokens(mut self, tokens: Option<i32>) -> Self {
+        self.accumulated_output_tokens = Some(tokens);
+        self
+    }
+
+    pub fn schedule_id(mut self, schedule_id: Option<String>) -> Self {
+        self.schedule_id = Some(schedule_id);
+        self
+    }
+
+    pub fn recipe(mut self, recipe: Option<Recipe>) -> Self {
+        self.recipe = Some(recipe);
+        self
+    }
+
+    pub async fn apply(self) -> Result<()> {
+        SessionManager::apply_update(self).await
+    }
+}
+
+pub struct SessionManager;
+
+impl SessionManager {
+    pub async fn instance() -> Result<Arc<SessionStorage>> {
+        SESSION_STORAGE
+            .get_or_try_init(|| async { SessionStorage::new().await.map(Arc::new) })
+            .await
+            .map(Arc::clone)
+    }
+
+    pub async fn create_session(working_dir: PathBuf, description: String) -> Result<Session> {
+        let today = chrono::Utc::now().format("%Y%m%d").to_string();
+        let storage = Self::instance().await?;
+
+        let mut tx = storage.pool.begin().await?;
+
+        let max_idx = sqlx::query_scalar::<_, Option<i64>>(
+            "SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER)) FROM sessions WHERE id LIKE ?",
+        )
+        .bind(format!("{}_%", today))
+        .fetch_one(&mut *tx)
+        .await?
+        .unwrap_or(0);
+
+        let session_id = format!("{}_{}", today, max_idx + 1);
+
+        sqlx::query(
+            r#"
+            INSERT INTO sessions (id, description, working_dir, extension_data)
+            VALUES (?, ?, ?, '{}')
+            "#,
+        )
+        .bind(&session_id)
+        .bind(&description)
+        .bind(working_dir.to_string_lossy().as_ref())
+        .execute(&mut *tx)
+        .await?;
+
+        tx.commit().await?;
+
+        Self::get_session(&session_id, false).await
+    }
+
+    pub async fn get_session(id: &str, include_messages: bool) -> Result<Session> {
+        Self::instance()
+            .await?
+ .get_session(id, include_messages) + .await + } + + pub fn update_session(id: &str) -> SessionUpdateBuilder { + SessionUpdateBuilder::new(id.to_string()) + } + + async fn apply_update(builder: SessionUpdateBuilder) -> Result<()> { + Self::instance().await?.apply_update(builder).await + } + + pub async fn add_message(id: &str, message: &Message) -> Result<()> { + Self::instance().await?.add_message(id, message).await + } + + pub async fn replace_conversation(id: &str, conversation: &Conversation) -> Result<()> { + Self::instance() + .await? + .replace_conversation(id, conversation) + .await + } + + pub async fn list_sessions() -> Result> { + Self::instance().await?.list_sessions().await + } + + pub async fn delete_session(id: &str) -> Result<()> { + Self::instance().await?.delete_session(id).await + } + + pub async fn get_insights() -> Result { + Self::instance().await?.get_insights().await + } + + pub async fn maybe_update_description(id: &str, provider: Arc) -> Result<()> { + let session = Self::get_session(id, true).await?; + let conversation = session + .conversation + .ok_or_else(|| anyhow::anyhow!("No messages found"))?; + + let user_message_count = conversation + .messages() + .iter() + .filter(|m| matches!(m.role, Role::User)) + .count(); + + if user_message_count <= MSG_COUNT_FOR_SESSION_NAME_GENERATION { + let description = provider.generate_session_name(&conversation).await?; + Self::update_session(id) + .description(description) + .apply() + .await + } else { + Ok(()) + } + } +} + +pub struct SessionStorage { + pool: Pool, +} + +pub fn ensure_session_dir() -> Result { + let data_dir = choose_app_strategy(APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .data_dir() + .join("sessions"); + + if !data_dir.exists() { + fs::create_dir_all(&data_dir)?; + } + + Ok(data_dir) +} + +fn role_to_string(role: &Role) -> &'static str { + match role { + Role::User => "user", + Role::Assistant => "assistant", + } +} + +impl Default for Session { + fn default() -> Self { + Self { + id: String::new(), + working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + description: String::new(), + created_at: String::new(), + updated_at: String::new(), + extension_data: ExtensionData::default(), + total_tokens: None, + input_tokens: None, + output_tokens: None, + accumulated_total_tokens: None, + accumulated_input_tokens: None, + accumulated_output_tokens: None, + schedule_id: None, + recipe: None, + conversation: None, + message_count: 0, + } + } +} + +impl Session { + pub fn without_messages(mut self) -> Self { + self.conversation = None; + self + } +} + +impl sqlx::FromRow<'_, sqlx::sqlite::SqliteRow> for Session { + fn from_row(row: &sqlx::sqlite::SqliteRow) -> Result { + use sqlx::Row; + + let recipe_json: Option = row.try_get("recipe_json")?; + let recipe = recipe_json.and_then(|json| serde_json::from_str(&json).ok()); + + Ok(Session { + id: row.try_get("id")?, + working_dir: PathBuf::from(row.try_get::("working_dir")?), + description: row.try_get("description")?, + created_at: row.try_get("created_at")?, + updated_at: row.try_get("updated_at")?, + extension_data: serde_json::from_str(&row.try_get::("extension_data")?) 
+ .unwrap_or_default(), + total_tokens: row.try_get("total_tokens")?, + input_tokens: row.try_get("input_tokens")?, + output_tokens: row.try_get("output_tokens")?, + accumulated_total_tokens: row.try_get("accumulated_total_tokens")?, + accumulated_input_tokens: row.try_get("accumulated_input_tokens")?, + accumulated_output_tokens: row.try_get("accumulated_output_tokens")?, + schedule_id: row.try_get("schedule_id")?, + recipe, + conversation: None, + message_count: row.try_get("message_count").unwrap_or(0) as usize, + }) + } +} + +impl SessionStorage { + async fn new() -> Result { + let session_dir = ensure_session_dir()?; + let db_path = session_dir.join("sessions.db"); + + let storage = if db_path.exists() { + Self::open(&db_path).await? + } else { + let storage = Self::create(&db_path).await?; + + if let Err(e) = storage.import_legacy(&session_dir).await { + warn!("Failed to import some legacy sessions: {}", e); + } + + storage + }; + + Ok(storage) + } + + async fn get_pool(db_path: &Path, create_if_missing: bool) -> Result> { + let options = SqliteConnectOptions::new() + .filename(db_path) + .create_if_missing(create_if_missing); + + sqlx::SqlitePool::connect_with(options).await.map_err(|e| { + anyhow::anyhow!( + "Failed to open SQLite database at '{}': {}", + db_path.display(), + e + ) + }) + } + + async fn open(db_path: &Path) -> Result { + let pool = Self::get_pool(db_path, false).await?; + + let storage = Self { pool }; + storage.run_migrations().await?; + Ok(storage) + } + + async fn create(db_path: &Path) -> Result { + let pool = Self::get_pool(db_path, true).await?; + + sqlx::query( + r#" + CREATE TABLE schema_version ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + "#, + ) + .execute(&pool) + .await?; + + sqlx::query("INSERT INTO schema_version (version) VALUES (?)") + .bind(CURRENT_SCHEMA_VERSION) + .execute(&pool) + .await?; + + sqlx::query( + r#" + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + description TEXT NOT NULL DEFAULT '', + working_dir TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + extension_data TEXT DEFAULT '{}', + total_tokens INTEGER, + input_tokens INTEGER, + output_tokens INTEGER, + accumulated_total_tokens INTEGER, + accumulated_input_tokens INTEGER, + accumulated_output_tokens INTEGER, + schedule_id TEXT, + recipe_json TEXT + ) + "#, + ) + .execute(&pool) + .await?; + + sqlx::query( + r#" + CREATE TABLE messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL REFERENCES sessions(id), + role TEXT NOT NULL, + content_json TEXT NOT NULL, + created_timestamp INTEGER NOT NULL, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + tokens INTEGER + ) + "#, + ) + .execute(&pool) + .await?; + + sqlx::query("CREATE INDEX idx_messages_session ON messages(session_id)") + .execute(&pool) + .await?; + sqlx::query("CREATE INDEX idx_messages_timestamp ON messages(timestamp)") + .execute(&pool) + .await?; + sqlx::query("CREATE INDEX idx_sessions_updated ON sessions(updated_at DESC)") + .execute(&pool) + .await?; + + Ok(Self { pool }) + } + + async fn import_legacy(&self, session_dir: &PathBuf) -> Result<()> { + use crate::session::legacy; + + let sessions = match legacy::list_sessions(session_dir) { + Ok(sessions) => sessions, + Err(_) => { + warn!("No legacy sessions found to import"); + return Ok(()); + } + }; + + if sessions.is_empty() { + return Ok(()); + } + + let mut imported_count = 0; + let mut failed_count = 0; + + for (session_name, 
session_path) in sessions { + match legacy::load_session(&session_name, &session_path) { + Ok(session) => match self.import_legacy_session(&session).await { + Ok(_) => { + imported_count += 1; + info!(" ✓ Imported: {}", session_name); + } + Err(e) => { + failed_count += 1; + info!(" ✗ Failed to import {}: {}", session_name, e); + } + }, + Err(e) => { + failed_count += 1; + info!(" ✗ Failed to load {}: {}", session_name, e); + } + } + } + + info!( + "Import complete: {} successful, {} failed", + imported_count, failed_count + ); + Ok(()) + } + + async fn import_legacy_session(&self, session: &Session) -> Result<()> { + let recipe_json = match &session.recipe { + Some(recipe) => Some(serde_json::to_string(recipe)?), + None => None, + }; + + sqlx::query( + r#" + INSERT INTO sessions ( + id, description, working_dir, created_at, updated_at, extension_data, + total_tokens, input_tokens, output_tokens, + accumulated_total_tokens, accumulated_input_tokens, accumulated_output_tokens, + schedule_id, recipe_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + "#, + ) + .bind(&session.id) + .bind(&session.description) + .bind(session.working_dir.to_string_lossy().as_ref()) + .bind(&session.created_at) + .bind(&session.updated_at) + .bind(serde_json::to_string(&session.extension_data)?) + .bind(session.total_tokens) + .bind(session.input_tokens) + .bind(session.output_tokens) + .bind(session.accumulated_total_tokens) + .bind(session.accumulated_input_tokens) + .bind(session.accumulated_output_tokens) + .bind(&session.schedule_id) + .bind(recipe_json) + .execute(&self.pool) + .await?; + + if let Some(conversation) = &session.conversation { + self.replace_conversation(&session.id, conversation).await?; + } + Ok(()) + } + + async fn run_migrations(&self) -> Result<()> { + let current_version = self.get_schema_version().await?; + + if current_version < CURRENT_SCHEMA_VERSION { + info!( + "Running database migrations from v{} to v{}...", + current_version, CURRENT_SCHEMA_VERSION + ); + + for version in (current_version + 1)..=CURRENT_SCHEMA_VERSION { + info!(" Applying migration v{}...", version); + self.apply_migration(version).await?; + self.update_schema_version(version).await?; + info!(" ✓ Migration v{} complete", version); + } + + info!("All migrations complete"); + } + + Ok(()) + } + + async fn get_schema_version(&self) -> Result { + let table_exists = sqlx::query_scalar::<_, bool>( + r#" + SELECT EXISTS ( + SELECT name FROM sqlite_master + WHERE type='table' AND name='schema_version' + ) + "#, + ) + .fetch_one(&self.pool) + .await?; + + if !table_exists { + return Ok(0); + } + + let version = sqlx::query_scalar::<_, i32>("SELECT MAX(version) FROM schema_version") + .fetch_one(&self.pool) + .await?; + + Ok(version) + } + + async fn update_schema_version(&self, version: i32) -> Result<()> { + sqlx::query("INSERT INTO schema_version (version) VALUES (?)") + .bind(version) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn apply_migration(&self, version: i32) -> Result<()> { + match version { + 1 => { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + "#, + ) + .execute(&self.pool) + .await?; + } + _ => { + anyhow::bail!("Unknown migration version: {}", version); + } + } + + Ok(()) + } + + async fn get_session(&self, id: &str, include_messages: bool) -> Result { + let mut session = sqlx::query_as::<_, Session>( + r#" + SELECT id, working_dir, description, created_at, updated_at, 
extension_data, + total_tokens, input_tokens, output_tokens, + accumulated_total_tokens, accumulated_input_tokens, accumulated_output_tokens, + schedule_id, recipe_json + FROM sessions + WHERE id = ? + "#, + ) + .bind(id) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + + if include_messages { + let conv = self.get_conversation(&session.id).await?; + session.message_count = conv.messages().len(); + session.conversation = Some(conv); + } else { + let count = + sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM messages WHERE session_id = ?") + .bind(&session.id) + .fetch_one(&self.pool) + .await? as usize; + session.message_count = count; + } + + Ok(session) + } + + async fn apply_update(&self, builder: SessionUpdateBuilder) -> Result<()> { + let mut updates = Vec::new(); + let mut query = String::from("UPDATE sessions SET "); + + macro_rules! add_update { + ($field:expr, $name:expr) => { + if $field.is_some() { + if !updates.is_empty() { + query.push_str(", "); + } + updates.push($name); + query.push_str($name); + query.push_str(" = ?"); + } + }; + } + + add_update!(builder.description, "description"); + add_update!(builder.working_dir, "working_dir"); + add_update!(builder.extension_data, "extension_data"); + add_update!(builder.total_tokens, "total_tokens"); + add_update!(builder.input_tokens, "input_tokens"); + add_update!(builder.output_tokens, "output_tokens"); + add_update!(builder.accumulated_total_tokens, "accumulated_total_tokens"); + add_update!(builder.accumulated_input_tokens, "accumulated_input_tokens"); + add_update!( + builder.accumulated_output_tokens, + "accumulated_output_tokens" + ); + add_update!(builder.schedule_id, "schedule_id"); + add_update!(builder.recipe, "recipe_json"); + + if updates.is_empty() { + return Ok(()); + } + + if !updates.is_empty() { + query.push_str(", "); + } + query.push_str("updated_at = datetime('now') WHERE id = ?"); + + let mut q = sqlx::query(&query); + + if let Some(desc) = builder.description { + q = q.bind(desc); + } + if let Some(wd) = builder.working_dir { + q = q.bind(wd.to_string_lossy().to_string()); + } + if let Some(ed) = builder.extension_data { + q = q.bind(serde_json::to_string(&ed)?); + } + if let Some(tt) = builder.total_tokens { + q = q.bind(tt); + } + if let Some(it) = builder.input_tokens { + q = q.bind(it); + } + if let Some(ot) = builder.output_tokens { + q = q.bind(ot); + } + if let Some(att) = builder.accumulated_total_tokens { + q = q.bind(att); + } + if let Some(ait) = builder.accumulated_input_tokens { + q = q.bind(ait); + } + if let Some(aot) = builder.accumulated_output_tokens { + q = q.bind(aot); + } + if let Some(sid) = builder.schedule_id { + q = q.bind(sid); + } + if let Some(recipe) = builder.recipe { + let recipe_json = recipe.map(|r| serde_json::to_string(&r)).transpose()?; + q = q.bind(recipe_json); + } + + q = q.bind(&builder.session_id); + q.execute(&self.pool).await?; + + Ok(()) + } + + async fn get_conversation(&self, session_id: &str) -> Result { + let rows = sqlx::query_as::<_, (String, String, i64)>( + "SELECT role, content_json, created_timestamp FROM messages WHERE session_id = ? 
ORDER BY timestamp", + ) + .bind(session_id) + .fetch_all(&self.pool) + .await?; + + let mut messages = Vec::new(); + for (role_str, content_json, created_timestamp) in rows { + let role = match role_str.as_str() { + "user" => Role::User, + "assistant" => Role::Assistant, + _ => continue, + }; + + let content = serde_json::from_str(&content_json)?; + let message = Message::new(role, created_timestamp, content); + messages.push(message); + } + + Ok(Conversation::new_unvalidated(messages)) + } + + async fn add_message(&self, session_id: &str, message: &Message) -> Result<()> { + sqlx::query( + r#" + INSERT INTO messages (session_id, role, content_json, created_timestamp) + VALUES (?, ?, ?, ?) + "#, + ) + .bind(session_id) + .bind(role_to_string(&message.role)) + .bind(serde_json::to_string(&message.content)?) + .bind(message.created) + .execute(&self.pool) + .await?; + + sqlx::query("UPDATE sessions SET updated_at = datetime('now') WHERE id = ?") + .bind(session_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn replace_conversation( + &self, + session_id: &str, + conversation: &Conversation, + ) -> Result<()> { + let mut tx = self.pool.begin().await?; + + sqlx::query("DELETE FROM messages WHERE session_id = ?") + .bind(session_id) + .execute(&mut *tx) + .await?; + + for message in conversation.messages() { + sqlx::query( + r#" + INSERT INTO messages (session_id, role, content_json, created_timestamp) + VALUES (?, ?, ?, ?) + "#, + ) + .bind(session_id) + .bind(role_to_string(&message.role)) + .bind(serde_json::to_string(&message.content)?) + .bind(message.created) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + Ok(()) + } + + async fn list_sessions(&self) -> Result> { + sqlx::query_as::<_, Session>( + r#" + SELECT s.id, s.working_dir, s.description, s.created_at, s.updated_at, s.extension_data, + s.total_tokens, s.input_tokens, s.output_tokens, + s.accumulated_total_tokens, s.accumulated_input_tokens, s.accumulated_output_tokens, + s.schedule_id, s.recipe_json, + COUNT(m.id) as message_count + FROM sessions s + INNER JOIN messages m ON s.id = m.session_id + GROUP BY s.id + ORDER BY s.updated_at DESC + "#, + ) + .fetch_all(&self.pool) + .await + .map_err(Into::into) + } + + async fn delete_session(&self, session_id: &str) -> Result<()> { + let exists = + sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?)") + .bind(session_id) + .fetch_one(&self.pool) + .await?; + + if !exists { + return Err(anyhow::anyhow!("Session not found")); + } + + sqlx::query("DELETE FROM messages WHERE session_id = ?") + .bind(session_id) + .execute(&self.pool) + .await?; + + sqlx::query("DELETE FROM sessions WHERE id = ?") + .bind(session_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn get_insights(&self) -> Result { + let row = sqlx::query_as::<_, (i64, Option)>( + r#" + SELECT COUNT(*) as total_sessions, + COALESCE(SUM(COALESCE(accumulated_total_tokens, total_tokens, 0)), 0) as total_tokens + FROM sessions + "#, + ) + .fetch_one(&self.pool) + .await?; + + Ok(SessionInsights { + total_sessions: row.0 as usize, + total_tokens: row.1.unwrap_or(0), + }) + } +} diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs deleted file mode 100644 index 6a32068c6bdb..000000000000 --- a/crates/goose/src/session/storage.rs +++ /dev/null @@ -1,1989 +0,0 @@ -// IMPORTANT: This file includes session recovery functionality to handle corrupted session files. 
-// Only essential logging is included with the [SESSION] prefix to track: -// - Total message counts -// - Corruption detection and recovery -// - Backup creation -// Additional debug logging can be added if needed for troubleshooting. - -use crate::conversation::message::Message; -use crate::conversation::Conversation; -use crate::providers::base::Provider; -use crate::recipe::Recipe; -use crate::session::extension_data::ExtensionData; -use crate::utils::safe_truncate; -use anyhow::Result; -use chrono::Local; -use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs}; -use regex::Regex; -use serde::{Deserialize, Serialize}; -use std::fs; -use std::io::{self, BufRead, Write}; -use std::ops::DerefMut; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use utoipa::ToSchema; - -// Security limits -const MAX_FILE_SIZE: u64 = 10 * 1024 * 1024; // 10MB -const MAX_MESSAGE_COUNT: usize = 5000; -const MAX_LINE_LENGTH: usize = 1024 * 1024; // 1MB per line - -fn get_home_dir() -> PathBuf { - choose_app_strategy(crate::config::APP_STRATEGY.clone()) - .expect("goose requires a home dir") - .home_dir() - .to_path_buf() -} - -fn get_current_working_dir() -> PathBuf { - std::env::current_dir() - .or_else(|_| Ok::(get_home_dir())) - .expect("could not determine the current working directory") -} - -/// Metadata for a session, stored as the first line in the session file -#[derive(Debug, Clone, Serialize, ToSchema)] -pub struct SessionMetadata { - /// Working directory for the session - #[schema(value_type = String, example = "/home/user/sessions/session1")] - pub working_dir: PathBuf, - /// A short description of the session, typically 3 words or less - pub description: String, - /// ID of the schedule that triggered this session, if any - pub schedule_id: Option, - - /// Number of messages in the session - pub message_count: usize, - /// The total number of tokens used in the session. Retrieved from the provider's last usage. - pub total_tokens: Option, - /// The number of input tokens used in the session. Retrieved from the provider's last usage. - pub input_tokens: Option, - /// The number of output tokens used in the session. Retrieved from the provider's last usage. - pub output_tokens: Option, - /// The total number of tokens used in the session. Accumulated across all messages (useful for tracking cost over an entire session). - pub accumulated_total_tokens: Option, - /// The number of input tokens used in the session. Accumulated across all messages. - pub accumulated_input_tokens: Option, - /// The number of output tokens used in the session. Accumulated across all messages. 
- pub accumulated_output_tokens: Option, - - /// Extension data containing extension states - #[serde(default)] - pub extension_data: ExtensionData, - - pub recipe: Option, -} - -// Custom deserializer to handle old sessions without working_dir -impl<'de> Deserialize<'de> for SessionMetadata { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - #[derive(Deserialize)] - struct Helper { - description: String, - message_count: usize, - schedule_id: Option, - total_tokens: Option, - input_tokens: Option, - output_tokens: Option, - accumulated_total_tokens: Option, - accumulated_input_tokens: Option, - accumulated_output_tokens: Option, - working_dir: Option, - #[serde(default)] - extension_data: ExtensionData, - recipe: Option, - } - - let helper = Helper::deserialize(deserializer)?; - - // Get working dir, falling back to home if not specified or if specified dir doesn't exist - let working_dir = helper - .working_dir - .filter(|path| path.exists()) - .unwrap_or_else(get_current_working_dir); - - Ok(SessionMetadata { - description: helper.description, - message_count: helper.message_count, - schedule_id: helper.schedule_id, - total_tokens: helper.total_tokens, - input_tokens: helper.input_tokens, - output_tokens: helper.output_tokens, - accumulated_total_tokens: helper.accumulated_total_tokens, - accumulated_input_tokens: helper.accumulated_input_tokens, - accumulated_output_tokens: helper.accumulated_output_tokens, - working_dir, - extension_data: helper.extension_data, - recipe: helper.recipe, - }) - } -} - -impl SessionMetadata { - pub fn new(working_dir: PathBuf) -> Self { - // If working_dir doesn't exist, fall back to home directory - let working_dir = if !working_dir.exists() { - get_home_dir() - } else { - working_dir - }; - - Self { - working_dir, - description: String::new(), - schedule_id: None, - message_count: 0, - total_tokens: None, - input_tokens: None, - output_tokens: None, - accumulated_total_tokens: None, - accumulated_input_tokens: None, - accumulated_output_tokens: None, - extension_data: ExtensionData::new(), - recipe: None, - } - } -} - -impl Default for SessionMetadata { - fn default() -> Self { - Self::new(get_current_working_dir()) - } -} - -// The single app name used for all Goose applications -const APP_NAME: &str = "goose"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Identifier { - Name(String), - Path(PathBuf), -} - -pub fn get_path(id: Identifier) -> Result { - let path = match id { - Identifier::Name(name) => { - // Validate session name for security - if name.is_empty() || name.len() > 255 { - return Err(anyhow::anyhow!("Invalid session name length")); - } - - // Check for path traversal attempts - if name.contains("..") || name.contains('/') || name.contains('\\') { - return Err(anyhow::anyhow!("Invalid characters in session name")); - } - - let session_dir = ensure_session_dir().map_err(|e| { - tracing::error!("Failed to create session directory: {}", e); - anyhow::anyhow!("Failed to access session directory") - })?; - session_dir.join(format!("{}.jsonl", name)) - } - Identifier::Path(path) => { - // In test mode, allow temporary directory paths - #[cfg(test)] - { - if let Some(path_str) = path.to_str() { - if path_str.contains("/tmp") || path_str.contains("/.tmp") { - // Allow test temporary directories - return Ok(path); - } - } - } - - // Validate that the path is within allowed directories - let session_dir = ensure_session_dir().map_err(|e| { - tracing::error!("Failed to create session directory: 
{}", e); - anyhow::anyhow!("Failed to access session directory") - })?; - - // Handle path validation with Windows-compatible logic - let is_path_allowed = validate_path_within_session_dir(&path, &session_dir)?; - if !is_path_allowed { - tracing::warn!( - "Attempted access outside session directory: {:?} not within {:?}", - path, - session_dir - ); - return Err(anyhow::anyhow!("Path not allowed")); - } - - path - } - }; - - // Additional security check for file extension (skip for special no-session paths) - if let Some(ext) = path.extension() { - if ext != "jsonl" { - return Err(anyhow::anyhow!("Invalid file extension")); - } - } - - Ok(path) -} - -/// Validate that a path is within the session directory, with Windows-compatible logic -/// -/// This function handles Windows-specific path issues like: -/// - UNC path conversion during canonicalization -/// - Case sensitivity differences -/// - Path separator normalization -/// - Drive letter casing inconsistencies -fn validate_path_within_session_dir(path: &Path, session_dir: &Path) -> Result { - // First, try the simple case - if canonicalization works cleanly - if let (Ok(canonical_path), Ok(canonical_session_dir)) = - (path.canonicalize(), session_dir.canonicalize()) - { - if canonical_path.starts_with(&canonical_session_dir) { - return Ok(true); - } - } - - // Fallback approach for Windows: normalize paths manually - let normalized_path = normalize_path_for_comparison(path); - let normalized_session_dir = normalize_path_for_comparison(session_dir); - - // Check if the normalized path starts with the normalized session directory - if normalized_path.starts_with(&normalized_session_dir) { - return Ok(true); - } - - // Additional check: if the path doesn't exist yet, check its parent directory - if !path.exists() { - if let Some(parent) = path.parent() { - return validate_path_within_session_dir(parent, session_dir); - } - } - - Ok(false) -} - -/// Normalize a path for cross-platform comparison -/// -/// This handles Windows-specific issues like: -/// - Converting to absolute paths -/// - Normalizing path separators -/// - Handling case sensitivity -fn normalize_path_for_comparison(path: &Path) -> PathBuf { - // Try to canonicalize first, but fall back to absolute path if that fails - let absolute_path = if let Ok(canonical) = path.canonicalize() { - canonical - } else if let Ok(absolute) = path.to_path_buf().canonicalize() { - absolute - } else { - // Last resort: try to make it absolute manually - if path.is_absolute() { - path.to_path_buf() - } else { - // If we can't make it absolute, use the current directory - std::env::current_dir() - .unwrap_or_else(|_| PathBuf::from(".")) - .join(path) - } - }; - - // On Windows, normalize the path representation - #[cfg(windows)] - { - // Convert the path to components and rebuild it normalized - let components: Vec<_> = absolute_path.components().collect(); - let mut normalized = PathBuf::new(); - - for component in components { - match component { - std::path::Component::Prefix(prefix) => { - // Handle drive letters and UNC paths - let prefix_str = prefix.as_os_str().to_string_lossy(); - if prefix_str.starts_with("\\\\?\\") { - // Remove UNC prefix and add the drive letter normally - let clean_prefix = &prefix_str[4..]; - normalized.push(clean_prefix); - } else { - normalized.push(component); - } - } - std::path::Component::RootDir => { - normalized.push(component); - } - std::path::Component::CurDir | std::path::Component::ParentDir => { - // Skip these as they should be resolved by 
canonicalization - continue; - } - std::path::Component::Normal(name) => { - // Normalize case for Windows - let name_str = name.to_string_lossy().to_lowercase(); - normalized.push(name_str); - } - } - } - - normalized - } - - #[cfg(not(windows))] - { - absolute_path - } -} - -/// Ensure the session directory exists and return its path -pub fn ensure_session_dir() -> Result { - let app_strategy = AppStrategyArgs { - top_level_domain: "Block".to_string(), - author: "Block".to_string(), - app_name: APP_NAME.to_string(), - }; - - let data_dir = choose_app_strategy(app_strategy) - .expect("goose requires a home dir") - .data_dir() - .join("sessions"); - - if !data_dir.exists() { - fs::create_dir_all(&data_dir)?; - } - - Ok(data_dir) -} - -/// Get the path to the most recently modified session file -pub fn get_most_recent_session() -> Result { - let session_dir = ensure_session_dir()?; - let mut entries = fs::read_dir(&session_dir)? - .filter_map(|entry| entry.ok()) - .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "jsonl")) - .collect::>(); - - if entries.is_empty() { - return Err(anyhow::anyhow!("No session files found")); - } - - // Sort by modification time, most recent first - entries.sort_by(|a, b| { - b.metadata() - .and_then(|m| m.modified()) - .unwrap_or(std::time::SystemTime::UNIX_EPOCH) - .cmp( - &a.metadata() - .and_then(|m| m.modified()) - .unwrap_or(std::time::SystemTime::UNIX_EPOCH), - ) - }); - - Ok(entries[0].path()) -} - -/// List all available session files -pub fn list_sessions() -> Result> { - let session_dir = ensure_session_dir()?; - let entries = fs::read_dir(&session_dir)? - .filter_map(|entry| { - let entry = entry.ok()?; - let path = entry.path(); - - if path.extension().is_some_and(|ext| ext == "jsonl") { - let name = path.file_stem()?.to_string_lossy().to_string(); - Some((name, path)) - } else { - None - } - }) - .collect::>(); - - Ok(entries) -} - -/// Generate a session ID using timestamp format (yyyymmdd_hhmmss) -/// TODO(Douwe): make this actually be unique -pub fn generate_session_id() -> String { - Local::now().format("%Y%m%d_%H%M%S").to_string() -} - -/// Read messages from a session file with corruption recovery -/// -/// Creates the file if it doesn't exist, reads and deserializes all messages if it does. -/// The first line of the file is expected to be metadata, and the rest are messages. -/// Large messages are automatically truncated to prevent memory issues. -/// Includes recovery mechanisms for corrupted files. -/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Includes all security limits from read_messages_with_truncation -pub fn read_messages(session_file: &Path) -> Result { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - let result = read_messages_with_truncation(&secure_path, Some(50000)); // 50KB limit per message content - match &result { - Ok(_messages) => {} - Err(e) => println!( - "[SESSION] Failed to read messages from {:?}: {}", - secure_path, e - ), - } - result -} - -/// Read messages from a session file with optional content truncation and corruption recovery -/// -/// Creates the file if it doesn't exist, reads and deserializes all messages if it does. -/// The first line of the file is expected to be metadata, and the rest are messages. -/// If max_content_size is Some, large message content will be truncated during loading. -/// Includes robust error handling and corruption recovery mechanisms. 
-/// -/// Security features: -/// - File size limits to prevent resource exhaustion -/// - Message count limits to prevent DoS attacks -/// - Line length restrictions to prevent memory issues -pub fn read_messages_with_truncation( - session_file: &Path, - max_content_size: Option, -) -> Result { - // Security check: file size limit - if session_file.exists() { - let metadata = fs::metadata(session_file)?; - if metadata.len() > MAX_FILE_SIZE { - tracing::warn!("Session file exceeds size limit: {} bytes", metadata.len()); - return Err(anyhow::anyhow!("Session file too large")); - } - } - - // Check if there's a backup file we should restore from - let backup_file = session_file.with_extension("backup"); - if !session_file.exists() && backup_file.exists() { - println!( - "[SESSION] Session file missing but backup exists, restoring from backup: {:?}", - backup_file - ); - tracing::warn!( - "Session file missing but backup exists, restoring from backup: {:?}", - backup_file - ); - if let Err(e) = fs::copy(&backup_file, session_file) { - println!("[SESSION] Failed to restore from backup: {}", e); - tracing::error!("Failed to restore from backup: {}", e); - } - } - - // Open the file with appropriate options - let file = fs::OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(false) - .open(session_file)?; - - let reader = io::BufReader::new(file); - let mut lines = reader.lines(); - let mut messages = Vec::new(); - let mut corrupted_lines = Vec::new(); - let mut line_number = 1; - let mut message_count = 0; - - // Read the first line as metadata or create default if empty/missing - if let Some(line_result) = lines.next() { - match line_result { - Ok(line) => { - // Security check: line length - if line.len() > MAX_LINE_LENGTH { - tracing::warn!("Line {} exceeds length limit", line_number); - return Err(anyhow::anyhow!("Line too long")); - } - - // Try to parse as metadata, but if it fails, treat it as a message - if let Ok(_metadata) = serde_json::from_str::(&line) { - // Metadata successfully parsed, continue with the rest of the lines as messages - } else { - // This is not metadata, it's a message - match parse_message_with_truncation(&line, max_content_size) { - Ok(message) => { - messages.push(message); - message_count += 1; - } - Err(e) => { - println!("[SESSION] Failed to parse first line as message: {}", e); - println!("[SESSION] Attempting to recover corrupted first line..."); - tracing::warn!("Failed to parse first line as message: {}", e); - - // Try to recover the corrupted line - match attempt_corruption_recovery(&line, max_content_size) { - Ok(recovered) => { - println!( - "[SESSION] Successfully recovered corrupted first line!" 
- ); - messages.push(recovered); - message_count += 1; - } - Err(recovery_err) => { - println!( - "[SESSION] Failed to recover corrupted first line: {}", - recovery_err - ); - corrupted_lines.push((line_number, line)); - } - } - } - } - } - } - Err(e) => { - println!("[SESSION] Failed to read first line: {}", e); - tracing::error!("Failed to read first line: {}", e); - corrupted_lines.push((line_number, "[Unreadable line]".to_string())); - } - } - line_number += 1; - } - - // Read the rest of the lines as messages - for line_result in lines { - // Security check: message count limit - if message_count >= MAX_MESSAGE_COUNT { - tracing::warn!("Message count limit reached: {}", MAX_MESSAGE_COUNT); - println!( - "[SESSION] Message count limit reached, stopping at {}", - MAX_MESSAGE_COUNT - ); - break; - } - - match line_result { - Ok(line) => { - // Security check: line length - if line.len() > MAX_LINE_LENGTH { - tracing::warn!("Line {} exceeds length limit", line_number); - corrupted_lines.push(( - line_number, - "[Line too long - truncated for security]".to_string(), - )); - line_number += 1; - continue; - } - - match parse_message_with_truncation(&line, max_content_size) { - Ok(message) => { - messages.push(message); - message_count += 1; - } - Err(e) => { - println!("[SESSION] Failed to parse line {}: {}", line_number, e); - println!( - "[SESSION] Attempting to recover corrupted line {}...", - line_number - ); - tracing::warn!("Failed to parse line {}: {}", line_number, e); - - // Try to recover the corrupted line - match attempt_corruption_recovery(&line, max_content_size) { - Ok(recovered) => { - println!( - "[SESSION] Successfully recovered corrupted line {}!", - line_number - ); - messages.push(recovered); - message_count += 1; - } - Err(recovery_err) => { - println!( - "[SESSION] Failed to recover corrupted line {}: {}", - line_number, recovery_err - ); - corrupted_lines.push((line_number, line)); - } - } - } - } - } - Err(e) => { - println!("[SESSION] Failed to read line {}: {}", line_number, e); - tracing::error!("Failed to read line {}: {}", line_number, e); - corrupted_lines.push((line_number, "[Unreadable line]".to_string())); - } - } - line_number += 1; - } - - // If we found corrupted lines, create a backup and log the issues - if !corrupted_lines.is_empty() { - println!( - "[SESSION] Found {} corrupted lines, creating backup", - corrupted_lines.len() - ); - tracing::warn!( - "Found {} corrupted lines in session file, creating backup", - corrupted_lines.len() - ); - - // Create a backup of the original file - if !backup_file.exists() { - if let Err(e) = fs::copy(session_file, &backup_file) { - println!("[SESSION] Failed to create backup file: {}", e); - tracing::error!("Failed to create backup file: {}", e); - } else { - println!("[SESSION] Created backup file: {:?}", backup_file); - tracing::info!("Created backup file: {:?}", backup_file); - } - } - - // Log details about corrupted lines (with limited detail for security) - for (num, line) in &corrupted_lines { - let preview = if line.len() > 50 { - format!("{}... 
(truncated)", safe_truncate(line, 50)) - } else { - line.clone() - }; - tracing::debug!("Corrupted line {}: {}", num, preview); - } - } - - Ok(Conversation::new_unvalidated(messages)) -} - -/// Parse a message from JSON string with optional content truncation -fn parse_message_with_truncation( - json_str: &str, - max_content_size: Option, -) -> Result { - // First try to parse normally - match serde_json::from_str::(json_str) { - Ok(mut message) => { - // If we have a size limit, check and truncate if needed - if let Some(max_size) = max_content_size { - truncate_message_content_in_place(&mut message, max_size); - } - Ok(message) - } - Err(_e) => { - // If parsing fails and the string is very long, it might be due to size - if json_str.len() > 100000 { - println!( - "[SESSION] Very large message detected ({}KB), attempting truncation", - json_str.len() / 1024 - ); - tracing::warn!( - "Failed to parse very large message ({}KB), attempting truncation", - json_str.len() / 1024 - ); - - // Try to truncate the JSON string itself before parsing - let truncated_json = if let Some(max_size) = max_content_size { - truncate_json_string(json_str, max_size) - } else { - json_str.to_string() - }; - - match serde_json::from_str::(&truncated_json) { - Ok(message) => { - tracing::info!("Successfully parsed message after JSON truncation"); - Ok(message) - } - Err(_) => { - println!( - "[SESSION] Failed to parse even after truncation, attempting recovery" - ); - tracing::error!("Failed to parse message even after truncation"); - attempt_corruption_recovery(json_str, max_content_size) - } - } - } else { - // Try intelligent corruption recovery - attempt_corruption_recovery(json_str, max_content_size) - } - } - } -} - -/// Truncate content within a message in place -fn truncate_message_content_in_place(message: &mut Message, max_content_size: usize) { - use crate::conversation::message::MessageContent; - use rmcp::model::{RawContent, ResourceContents}; - - for content in &mut message.content { - match content { - MessageContent::Text(text_content) => { - if text_content.text.chars().count() > max_content_size { - let truncated = format!( - "{}\n\n[... content truncated during session loading from {} to {} characters ...]", - safe_truncate(&text_content.text, max_content_size), - text_content.text.chars().count(), - max_content_size - ); - text_content.text = truncated; - } - } - MessageContent::ToolResponse(tool_response) => { - if let Ok(ref mut result) = tool_response.tool_result { - for content_item in result { - match content_item.deref_mut() { - RawContent::Text(ref mut text_content) => { - if text_content.text.chars().count() > max_content_size { - let truncated = format!( - "{}\n\n[... tool response truncated during session loading from {} to {} characters ...]", - safe_truncate(&text_content.text, max_content_size), - text_content.text.chars().count(), - max_content_size - ); - text_content.text = truncated; - } - } - RawContent::Resource(ref mut resource_content) => { - if let ResourceContents::TextResourceContents { text, .. } = - &mut resource_content.resource - { - if text.chars().count() > max_content_size { - let truncated = format!( - "{}\n\n[... 
resource content truncated during session loading from {} to {} characters ...]", - safe_truncate(text, max_content_size), - text.chars().count(), - max_content_size - ); - *text = truncated; - } - } - } - _ => {} // Other content types are typically smaller - } - } - } - } - _ => {} // Other content types are typically smaller - } - } -} - -/// Attempt to recover corrupted JSON lines using various strategies -fn attempt_corruption_recovery(json_str: &str, max_content_size: Option) -> Result { - // Strategy 1: Try to fix common JSON corruption issues - if let Ok(message) = try_fix_json_corruption(json_str, max_content_size) { - println!("[SESSION] Recovered using JSON corruption fix"); - return Ok(message); - } - - // Strategy 2: Try to extract partial content if it looks like a message - if let Ok(message) = try_extract_partial_message(json_str) { - println!("[SESSION] Recovered using partial message extraction"); - return Ok(message); - } - - // Strategy 3: Try to fix truncated JSON - if let Ok(message) = try_fix_truncated_json(json_str, max_content_size) { - println!("[SESSION] Recovered using truncated JSON fix"); - return Ok(message); - } - - // Strategy 4: Create a placeholder message with the raw content - println!("[SESSION] All recovery strategies failed, creating placeholder message"); - let preview = if json_str.len() > 200 { - format!("{}...", safe_truncate(json_str, 200)) - } else { - json_str.to_string() - }; - - Ok(Message::user().with_text(format!( - "[RECOVERED FROM CORRUPTED LINE]\nOriginal content preview: {}\n\n[This message was recovered from a corrupted session file line. The original data may be incomplete.]", - preview - ))) -} - -/// Try to fix common JSON corruption patterns -fn try_fix_json_corruption(json_str: &str, max_content_size: Option) -> Result { - let mut fixed_json = json_str.to_string(); - let mut fixes_applied = Vec::new(); - - // Fix 1: Remove trailing commas before closing braces/brackets - if fixed_json.contains(",}") || fixed_json.contains(",]") { - fixed_json = fixed_json.replace(",}", "}").replace(",]", "]"); - fixes_applied.push("trailing commas"); - } - - // Fix 2: Try to close unclosed quotes in text fields - if let Some(text_start) = fixed_json.find("\"text\":\"") { - let content_start = text_start + 8; - if let Some(remaining) = fixed_json.get(content_start..) 
{ - // Count quotes to see if we have an odd number (unclosed quote) - let quote_count = remaining.matches('"').count(); - if quote_count % 2 == 1 { - // Find the last quote and see if we need to close it - if let Some(last_quote_pos) = remaining.rfind('"') { - let after_last_quote = &remaining[last_quote_pos + 1..]; - if !after_last_quote.trim_start().starts_with(',') - && !after_last_quote.trim_start().starts_with('}') - { - // Insert a closing quote before the next field or end - if let Some(next_field) = after_last_quote.find(',') { - fixed_json.insert(content_start + last_quote_pos + 1 + next_field, '"'); - fixes_applied.push("unclosed quotes"); - } else if after_last_quote.contains('}') { - if let Some(brace_pos) = after_last_quote.find('}') { - fixed_json - .insert(content_start + last_quote_pos + 1 + brace_pos, '"'); - fixes_applied.push("unclosed quotes"); - } - } - } - } - } - } - } - - // Fix 3: Try to close unclosed JSON objects/arrays - let open_braces = fixed_json.matches('{').count(); - let close_braces = fixed_json.matches('}').count(); - let open_brackets = fixed_json.matches('[').count(); - let close_brackets = fixed_json.matches(']').count(); - - if open_braces > close_braces { - for _ in 0..(open_braces - close_braces) { - fixed_json.push('}'); - } - fixes_applied.push("unclosed braces"); - } - - if open_brackets > close_brackets { - for _ in 0..(open_brackets - close_brackets) { - fixed_json.push(']'); - } - fixes_applied.push("unclosed brackets"); - } - - // Fix 4: Remove control characters that might break JSON parsing - let original_len = fixed_json.len(); - fixed_json = fixed_json - .chars() - .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t') - .collect(); - if fixed_json.len() != original_len { - fixes_applied.push("control characters"); - } - - if !fixes_applied.is_empty() { - match serde_json::from_str::(&fixed_json) { - Ok(mut message) => { - if let Some(max_size) = max_content_size { - truncate_message_content_in_place(&mut message, max_size); - } - return Ok(message); - } - Err(e) => { - println!("[SESSION] JSON fixes didn't work: {}", e); - } - } - } - - Err(anyhow::anyhow!("JSON corruption fixes failed")) -} - -/// Try to extract a partial message from corrupted JSON -fn try_extract_partial_message(json_str: &str) -> Result { - // Look for recognizable patterns that indicate this was a message - - // Try to extract role - let role = if json_str.contains("\"role\":\"user\"") { - rmcp::model::Role::User - } else if json_str.contains("\"role\":\"assistant\"") { - rmcp::model::Role::Assistant - } else { - rmcp::model::Role::User // Default fallback - }; - - // Try to extract text content - let mut extracted_text = String::new(); - - // Look for text field content - if let Some(text_start) = json_str.find("\"text\":\"") { - let content_start = text_start + 8; - if let Some(content_end) = json_str[content_start..].find("\",") { - extracted_text = json_str[content_start..content_start + content_end].to_string(); - } else if let Some(content_end) = json_str[content_start..].find("\"") { - extracted_text = json_str[content_start..content_start + content_end].to_string(); - } else { - // Take everything after "text":" until we hit a likely end - let remaining = &json_str[content_start..]; - if let Some(end_pos) = remaining.find('}') { - extracted_text = remaining[..end_pos].trim_end_matches('"').to_string(); - } else { - extracted_text = remaining.to_string(); - } - } - } - - // If we couldn't extract text, try to find any readable content - if 
extracted_text.is_empty() { - // Look for any quoted strings that might be content - let quote_pattern = Regex::new(r#""([^"]{10,})""#).unwrap(); - if let Some(captures) = quote_pattern.find(json_str) { - extracted_text = captures.as_str().trim_matches('"').to_string(); - } - } - - if !extracted_text.is_empty() { - let message = match role { - rmcp::model::Role::User => Message::user(), - rmcp::model::Role::Assistant => Message::assistant(), - }; - - return Ok(message.with_text(format!("[PARTIALLY RECOVERED] {}", extracted_text))); - } - - Err(anyhow::anyhow!("Could not extract partial message")) -} - -/// Try to fix truncated JSON by completing it -fn try_fix_truncated_json(json_str: &str, max_content_size: Option) -> Result { - let mut completed_json = json_str.to_string(); - - // If the JSON appears to be cut off mid-field, try to complete it - if !completed_json.trim().ends_with('}') && !completed_json.trim().ends_with(']') { - // Try to find where it was likely cut off - if let Some(last_quote) = completed_json.rfind('"') { - let after_quote = &completed_json[last_quote + 1..]; - if !after_quote.contains('"') && !after_quote.contains('}') { - // Looks like it was cut off in the middle of a string value - completed_json.push('"'); - - // Try to close the JSON structure - let open_braces = completed_json.matches('{').count(); - let close_braces = completed_json.matches('}').count(); - - for _ in 0..(open_braces - close_braces) { - completed_json.push('}'); - } - - match serde_json::from_str::(&completed_json) { - Ok(mut message) => { - if let Some(max_size) = max_content_size { - truncate_message_content_in_place(&mut message, max_size); - } - return Ok(message); - } - Err(e) => { - println!("[SESSION] Truncation fix didn't work: {}", e); - } - } - } - } - } - - Err(anyhow::anyhow!("Truncation fix failed")) -} - -/// Attempt to truncate a JSON string by finding and truncating large text values -fn truncate_json_string(json_str: &str, max_content_size: usize) -> String { - // This is a heuristic approach - look for large text values in the JSON - // and truncate them. This is not perfect but should handle the common case - // of large tool responses. - - if json_str.len() <= max_content_size * 2 { - return json_str.to_string(); - } - - // Try to find patterns that look like large text content - // Look for "text":"..." patterns and truncate the content - let mut result = json_str.to_string(); - - // Simple regex-like approach to find and truncate large text values - if let Some(start) = result.find("\"text\":\"") { - let text_start = start + 8; // Length of "text":" - if let Some(end) = result[text_start..].find("\",") { - let text_end = text_start + end; - let text_content = &result[text_start..text_end]; - - if text_content.len() > max_content_size { - let truncated_text = format!( - "{}\n\n[... content truncated during JSON parsing from {} to {} characters ...]", - safe_truncate(text_content, max_content_size), - text_content.len(), - max_content_size - ); - result.replace_range(text_start..text_end, &truncated_text); - } - } - } - - result -} - -/// Read session metadata from a session file with security validation -/// -/// Returns default empty metadata if the file doesn't exist or has no metadata. -/// Includes security checks for file access and content validation. 
-pub fn read_metadata(session_file: &Path) -> Result { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - if !secure_path.exists() { - return Ok(SessionMetadata::default()); - } - - // Security check: file size - let file_metadata = fs::metadata(&secure_path)?; - if file_metadata.len() > MAX_FILE_SIZE { - tracing::warn!("Session file exceeds size limit during metadata read"); - return Err(anyhow::anyhow!("Session file too large")); - } - - let file = fs::File::open(&secure_path).map_err(|e| { - tracing::error!("Failed to open session file for metadata read: {}", e); - anyhow::anyhow!("Failed to access session file") - })?; - let mut reader = io::BufReader::new(file); - let mut first_line = String::new(); - - // Read just the first line - if reader.read_line(&mut first_line)? > 0 { - // Security check: line length - if first_line.len() > MAX_LINE_LENGTH { - tracing::warn!("Metadata line exceeds length limit"); - return Err(anyhow::anyhow!("Metadata line too long")); - } - - // Try to parse as metadata - match serde_json::from_str::(&first_line) { - Ok(metadata) => Ok(metadata), - Err(e) => { - // If the first line isn't metadata, return default - tracing::debug!("Metadata parse error: {}", e); - Ok(SessionMetadata::default()) - } - } - } else { - // Empty file, return default - Ok(SessionMetadata::default()) - } -} - -/// Write messages to a session file with metadata -/// -/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format. -/// If a provider is supplied, it will automatically generate a description when appropriate. -/// -/// Security features: -/// - Validates file paths to prevent directory traversal -pub async fn persist_messages( - session_file: &Path, - messages: &Conversation, - provider: Option>, - working_dir: Option, -) -> Result<()> { - persist_messages_with_schedule_id(session_file, messages, provider, None, working_dir).await -} - -/// Write messages to a session file with metadata, including an optional scheduled job ID -/// -/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format. -/// If a provider is supplied, it will automatically generate a description when appropriate. 
-/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Limits error message details in logs -/// - Uses atomic file operations via save_messages_with_metadata -pub async fn persist_messages_with_schedule_id( - session_file: &Path, - messages: &Conversation, - provider: Option>, - schedule_id: Option, - working_dir: Option, -) -> Result<()> { - // Validate the session file path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Security check: message count limit - if messages.len() > MAX_MESSAGE_COUNT { - tracing::warn!("Message count exceeds limit: {}", messages.len()); - return Err(anyhow::anyhow!("Too many messages")); - } - - // Count user messages - let user_message_count = messages - .iter() - .filter(|m| m.role == rmcp::model::Role::User && !m.as_concat_text().trim().is_empty()) - .count(); - - // Check if we need to update the description (after 1st or 3rd user message) - match provider { - Some(provider) if user_message_count < 4 => { - //generate_description is responsible for writing the messages - generate_description_with_schedule_id( - &secure_path, - messages, - provider, - schedule_id, - working_dir, - ) - .await - } - _ => { - // Read existing metadata or create new with proper working_dir - let mut metadata = if secure_path.exists() { - read_metadata(&secure_path)? - } else { - // Create new metadata with the provided working_dir or fall back to home - let work_dir = working_dir.clone().unwrap_or_else(get_home_dir); - SessionMetadata::new(work_dir) - }; - - // Update the working_dir if provided (even for existing files) - if let Some(work_dir) = working_dir { - metadata.working_dir = work_dir; - } - - // Update the schedule_id if provided - if schedule_id.is_some() { - metadata.schedule_id = schedule_id; - } - - // Write the file with metadata and messages - save_messages_with_metadata(&secure_path, &metadata, messages) - } - } -} - -/// Write messages to a session file with the provided metadata using secure atomic operations -/// -/// This function uses atomic file operations to prevent corruption: -/// 1. Writes to a temporary file first with secure permissions -/// 2. Uses fs2 file locking to prevent concurrent writes -/// 3. Atomically moves the temp file to the final location -/// 4. 
Includes comprehensive error handling and recovery -/// -/// Security features: -/// - Secure temporary file creation with restricted permissions -/// - Path validation to prevent directory traversal -/// - File size and message count limits -/// - Sanitized error messages to prevent information leakage -pub fn save_messages_with_metadata( - session_file: &Path, - metadata: &SessionMetadata, - conversation: &Conversation, -) -> Result<()> { - use fs2::FileExt; - - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Security check: message count limit - if conversation.len() > MAX_MESSAGE_COUNT { - tracing::warn!( - "Message count exceeds limit during save: {}", - conversation.len() - ); - return Err(anyhow::anyhow!("Too many messages to save")); - } - - // Create a temporary file in the same directory to ensure atomic move - let temp_file = secure_path.with_extension("tmp"); - - // Ensure the parent directory exists - if let Some(parent) = secure_path.parent() { - fs::create_dir_all(parent).map_err(|e| { - tracing::error!("Failed to create parent directory: {}", e); - anyhow::anyhow!("Failed to create session directory") - })?; - } - - // Create and lock the temporary file with secure permissions - let file = fs::OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(&temp_file) - .map_err(|e| { - tracing::error!("Failed to create temporary file: {}", e); - anyhow::anyhow!("Failed to create temporary session file") - })?; - - // Set secure file permissions (Unix only - read/write for owner only) - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let mut perms = file.metadata()?.permissions(); - perms.set_mode(0o600); // rw------- - fs::set_permissions(&temp_file, perms).map_err(|e| { - tracing::error!("Failed to set secure file permissions: {}", e); - anyhow::anyhow!("Failed to secure temporary file") - })?; - } - - // Get an exclusive lock on the file - file.try_lock_exclusive().map_err(|e| { - tracing::error!("Failed to lock file: {}", e); - anyhow::anyhow!("Failed to lock session file") - })?; - - // Write to temporary file - { - let mut writer = io::BufWriter::new(&file); - - // Write metadata as the first line - serde_json::to_writer(&mut writer, &metadata).map_err(|e| { - tracing::error!("Failed to serialize metadata: {}", e); - anyhow::anyhow!("Failed to write session metadata") - })?; - writeln!(writer)?; - - // Write all messages with progress tracking - for (i, message) in conversation.iter().enumerate() { - serde_json::to_writer(&mut writer, &message).map_err(|e| { - tracing::error!("Failed to serialize message {}: {}", i, e); - anyhow::anyhow!("Failed to write session message") - })?; - writeln!(writer)?; - } - - // Ensure all data is written to disk - writer.flush().map_err(|e| { - tracing::error!("Failed to flush writer: {}", e); - anyhow::anyhow!("Failed to flush session data") - })?; - } - - // Sync to ensure data is persisted - file.sync_all().map_err(|e| { - tracing::error!("Failed to sync data: {}", e); - anyhow::anyhow!("Failed to sync session data") - })?; - - // Release the lock - fs2::FileExt::unlock(&file).map_err(|e| { - tracing::error!("Failed to unlock file: {}", e); - anyhow::anyhow!("Failed to unlock session file") - })?; - - // Atomically move the temporary file to the final location - fs::rename(&temp_file, &secure_path).map_err(|e| { - // Clean up temp file on failure - tracing::error!("Failed to move temporary file: {}", e); - let _ = 
fs::remove_file(&temp_file); - anyhow::anyhow!("Failed to finalize session file") - })?; - - tracing::debug!("Successfully saved session file: {:?}", secure_path); - Ok(()) -} - -/// Generate a description for the session using the provider -/// -/// This function is called when appropriate to generate a short description -/// of the session based on the conversation history. -pub async fn generate_description( - session_file: &Path, - messages: &Conversation, - provider: Arc, - working_dir: Option, -) -> Result<()> { - generate_description_with_schedule_id(session_file, messages, provider, None, working_dir).await -} - -/// Generate a description for the session using the provider, including an optional scheduled job ID and working directory -/// -/// This function is called when appropriate to generate a short description -/// of the session based on the conversation history. -/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Limits context size to prevent resource exhaustion -/// - Uses secure file operations for saving -pub async fn generate_description_with_schedule_id( - session_file: &Path, - messages: &Conversation, - provider: Arc, - schedule_id: Option, - working_dir: Option, -) -> Result<()> { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Security check: message count limit - if messages.len() > MAX_MESSAGE_COUNT { - tracing::warn!( - "Message count exceeds limit during description generation: {}", - messages.len() - ); - return Err(anyhow::anyhow!( - "Too many messages for description generation" - )); - } - - // Use the provider's session naming capability - let sanitized_description = provider - .generate_session_name(messages) - .await - .map_err(|e| { - tracing::error!("Failed to generate session description: {}", e); - anyhow::anyhow!("Failed to generate session description") - })?; - - // Create metadata with proper working_dir or read existing and update - let mut metadata = if secure_path.exists() { - read_metadata(&secure_path)? 
- } else { - // Create new metadata with the provided working_dir or fall back to home - let work_dir = working_dir.clone().unwrap_or_else(get_home_dir); - SessionMetadata::new(work_dir) - }; - - // Update description and schedule_id - metadata.description = sanitized_description; - if schedule_id.is_some() { - metadata.schedule_id = schedule_id; - } - - // Update the working_dir if provided (even for existing files) - if let Some(work_dir) = working_dir { - metadata.working_dir = work_dir; - } - - // Update the file with the new metadata and existing messages - save_messages_with_metadata(&secure_path, &metadata, messages) -} - -/// Update only the metadata in a session file, preserving all messages -/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Uses secure file operations for reading and writing -pub async fn update_metadata(session_file: &Path, metadata: &SessionMetadata) -> Result<()> { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Read all messages from the file - let messages = read_messages(&secure_path)?; - - // Rewrite the file with the new metadata and existing messages - save_messages_with_metadata(&secure_path, metadata, &messages) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::conversation::message::{Message, MessageContent}; - use tempfile::tempdir; - - #[test] - fn test_corruption_recovery() -> Result<()> { - let test_cases = [ - // Case 1: Unclosed quotes - ( - r#"{"role":"user","content":[{"type":"text","text":"Hello there}]"#, - "Unclosed JSON with truncated content", - ), - // Case 2: Trailing comma - ( - r#"{"role":"user","content":[{"type":"text","text":"Test"},]}"#, - "JSON with trailing comma", - ), - // Case 3: Missing closing brace - ( - r#"{"role":"user","content":[{"type":"text","text":"Test""#, - "Incomplete JSON structure", - ), - // Case 4: Control characters in text - ( - r#"{"role":"user","content":[{"type":"text","text":"Test\u{0000}with\u{0001}control\u{0002}chars"}]}"#, - "JSON with control characters", - ), - // Case 5: Partial message with role and text - ( - r#"broken{"role": "assistant", "text": "This is recoverable content"more broken"#, - "Partial message with recoverable content", - ), - ]; - - println!("[TEST] Starting corruption recovery tests..."); - for (i, (corrupt_json, desc)) in test_cases.iter().enumerate() { - println!("\n[TEST] Case {}: {}", i + 1, desc); - println!( - "[TEST] Input: {}", - if corrupt_json.len() > 100 { - safe_truncate(corrupt_json, 100) - } else { - corrupt_json.to_string() - } - ); - - // Try to parse the corrupted JSON - match attempt_corruption_recovery(corrupt_json, Some(50000)) { - Ok(message) => { - println!("[TEST] Successfully recovered message"); - // Verify we got some content - if let Some(MessageContent::Text(text_content)) = message.content.first() { - assert!( - !text_content.text.is_empty(), - "Recovered message should have content" - ); - println!( - "[TEST] Recovered content: {}", - if text_content.text.len() > 50 { - format!("{}...", &text_content.text[..50]) - } else { - text_content.text.clone() - } - ); - } - } - Err(e) => { - println!("[TEST] Failed to recover: {}", e); - panic!("Failed to recover from case {}: {}", i + 1, desc); - } - } - } - - println!("\n[TEST] All corruption recovery tests passed!"); - Ok(()) - } - - #[tokio::test] - async fn test_read_write_messages() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - 
// Create some test messages - let messages = Conversation::new_unvalidated(vec![ - Message::user().with_text("Hello"), - Message::assistant().with_text("Hi there"), - ]); - - // Write messages - persist_messages(&file_path, &messages, None, None).await?; - - // Read them back - let read_messages = read_messages(&file_path)?; - - // Compare - assert_eq!(messages.len(), read_messages.len()); - for (orig, read) in messages.iter().zip(read_messages.iter()) { - assert_eq!(orig.role, read.role); - assert_eq!(orig.content.len(), read.content.len()); - - // Compare first text content - if let (Some(MessageContent::Text(orig_text)), Some(MessageContent::Text(read_text))) = - (orig.content.first(), read.content.first()) - { - assert_eq!(orig_text.text, read_text.text); - } else { - panic!("Messages don't match expected structure"); - } - } - - Ok(()) - } - - #[test] - fn test_empty_file() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("empty.jsonl"); - - // Reading an empty file should return empty vec - let messages = read_messages(&file_path)?; - assert!(messages.is_empty()); - - Ok(()) - } - - #[test] - fn test_generate_session_id() { - let id = generate_session_id(); - - // Check that it follows the timestamp format (yyyymmdd_hhmmss) - assert_eq!(id.len(), 15); // 8 chars for date + 1 for underscore + 6 for time - assert!(id.contains('_')); - - // Split by underscore and check parts - let parts: Vec<&str> = id.split('_').collect(); - assert_eq!(parts.len(), 2); - - // Date part should be 8 digits - assert_eq!(parts[0].len(), 8); - // Time part should be 6 digits - assert_eq!(parts[1].len(), 6); - } - - #[tokio::test] - async fn test_special_characters_and_long_text() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("special.jsonl"); - - // Insert some problematic JSON-like content between moderately long text - // (keeping under truncation limit to test serialization/deserialization) - let long_text = format!( - "Start_of_message\n{}{}SOME_MIDDLE_TEXT{}End_of_message", - "A".repeat(10_000), // Reduced from 100_000 to stay under 50KB limit - "\"}]\n", - "A".repeat(10_000) // Reduced from 100_000 to stay under 50KB limit - ); - - let special_chars = vec![ - // Long text - long_text.as_str(), - // Newlines in different positions - "Line 1\nLine 2", - "Line 1\r\nLine 2", - "\nStart with newline", - "End with newline\n", - "\n\nMultiple\n\nNewlines\n\n", - // JSON special characters - "Quote\"in middle", - "\"Quote at start", - "Quote at end\"", - "Multiple\"\"Quotes", - "{\"json\": \"looking text\"}", - // Unicode and special characters - "Unicode: 🦆🤖👾", - "Special: \\n \\r \\t", - "Mixed: \n\"🦆\"\r\n\\n", - // Control characters - "Tab\there", - "Bell\u{0007}char", - "Null\u{0000}char", - // Long text with mixed content - "A very long message with multiple lines\nand \"quotes\"\nand emojis 🦆\nand \\escaped chars", - // Potentially problematic JSON content - "}{[]\",\\", - "]}}\"\\n\\\"{[", - "Edge case: } ] some text", - "{\"foo\": \"} ]\"}", - "}]", - ]; - - let mut messages = Conversation::empty(); - for text in special_chars { - messages.push(Message::user().with_text(text)); - messages.push(Message::assistant().with_text(text)); - } - - // Write messages with special characters - persist_messages(&file_path, &messages, None, None).await?; - - // Read them back - let read_messages = read_messages(&file_path)?; - - // Compare all messages - assert_eq!(messages.len(), read_messages.len()); - for (i, (orig, read)) in 
messages.iter().zip(read_messages.iter()).enumerate() { - assert_eq!(orig.role, read.role, "Role mismatch at message {}", i); - assert_eq!( - orig.content.len(), - read.content.len(), - "Content length mismatch at message {}", - i - ); - - if let (Some(MessageContent::Text(orig_text)), Some(MessageContent::Text(read_text))) = - (orig.content.first(), read.content.first()) - { - assert_eq!( - orig_text.text, read_text.text, - "Text mismatch at message {}\nExpected: {}\nGot: {}", - i, orig_text.text, read_text.text - ); - } else { - panic!("Messages don't match expected structure at index {}", i); - } - } - - // Verify file format - let contents = fs::read_to_string(&file_path)?; - let lines: Vec<&str> = contents.lines().collect(); - - // First line should be metadata - assert!( - lines[0].contains("\"description\""), - "First line should be metadata" - ); - - // Each subsequent line should be valid JSON - for (i, line) in lines.iter().enumerate().skip(1) { - assert!( - serde_json::from_str::(line).is_ok(), - "Invalid JSON at line {}: {}", - i + 1, - line - ); - } - - Ok(()) - } - - #[tokio::test] - async fn test_large_content_truncation() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("large_content.jsonl"); - - // Create a message with content larger than the 50KB truncation limit - let very_large_text = "A".repeat(100_000); // 100KB of text - let messages = Conversation::new_unvalidated(vec![ - Message::user().with_text(&very_large_text), - Message::assistant().with_text("Small response"), - ]); - - // Write messages - persist_messages(&file_path, &messages, None, None).await?; - - // Read them back - should be truncated - let read_messages = read_messages(&file_path)?; - - assert_eq!(messages.len(), read_messages.len()); - - // First message should be truncated - if let Some(MessageContent::Text(read_text)) = - read_messages.first().unwrap().content.first() - { - assert!( - read_text.text.len() < very_large_text.len(), - "Content should be truncated" - ); - assert!( - read_text - .text - .contains("content truncated during session loading"), - "Should contain truncation notice" - ); - assert!( - read_text.text.starts_with("AAAA"), - "Should start with original content" - ); - } else { - panic!("Expected text content in first message"); - } - - // Second message should be unchanged - if let Some(MessageContent::Text(read_text)) = read_messages.messages()[1].content.first() { - assert_eq!(read_text.text, "Small response"); - } else { - panic!("Expected text content in second message"); - } - - Ok(()) - } - - #[tokio::test] - async fn test_metadata_special_chars() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("metadata.jsonl"); - - let mut metadata = SessionMetadata::default(); - #[allow(clippy::field_reassign_with_default)] - { - metadata.description = "Description with\nnewline and \"quotes\" and 🦆".to_string(); - } - - let messages = Conversation::new_unvalidated(vec![Message::user().with_text("test")]); - - // Write with special metadata - save_messages_with_metadata(&file_path, &metadata, &messages)?; - - // Read back metadata - let read_metadata = read_metadata(&file_path)?; - assert_eq!(metadata.description, read_metadata.description); - - Ok(()) - } - - #[test] - fn test_invalid_working_dir() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - // Create metadata with non-existent directory - let invalid_dir = PathBuf::from("/path/that/does/not/exist"); - - let metadata = 
SessionMetadata::new(invalid_dir.clone()); - - // Should fall back to home directory - assert_ne!(metadata.working_dir, invalid_dir); - assert_eq!(metadata.working_dir, get_home_dir()); - - // Test deserialization of invalid directory - let messages = Conversation::new_unvalidated(vec![Message::user().with_text("test")]); - save_messages_with_metadata(&file_path, &metadata, &messages)?; - - // Modify the file to include invalid directory - let contents = fs::read_to_string(&file_path)?; - let mut lines: Vec = contents.lines().map(String::from).collect(); - lines[0] = lines[0].replace( - &get_home_dir().to_string_lossy().into_owned(), - &invalid_dir.to_string_lossy(), - ); - fs::write(&file_path, lines.join("\n"))?; - - // Read back - should fall back to home dir - let read_metadata = read_metadata(&file_path)?; - assert_ne!(read_metadata.working_dir, invalid_dir); - assert_eq!(read_metadata.working_dir, get_current_working_dir()); - - Ok(()) - } - - #[tokio::test] - async fn test_working_dir_preservation() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - // Create a temporary working directory - let working_dir = tempdir()?; - let working_dir_path = working_dir.path().to_path_buf(); - - // Create messages - let messages = - Conversation::new_unvalidated(vec![Message::user().with_text("test message")]); - - // Use persist_messages_with_schedule_id to set working dir - persist_messages_with_schedule_id( - &file_path, - &messages, - None, - None, - Some(working_dir_path.clone()), - ) - .await?; - - // Read back the metadata and verify working_dir is preserved - let metadata = read_metadata(&file_path)?; - assert_eq!(metadata.working_dir, working_dir_path); - - // Verify the messages are also preserved - let read_messages = read_messages(&file_path)?; - assert_eq!(read_messages.len(), 1); - assert_eq!( - read_messages.first().unwrap().role, - messages.messages()[0].role - ); - - Ok(()) - } - - #[tokio::test] - async fn test_working_dir_issue_fixed() -> Result<()> { - // This test demonstrates that the working_dir issue in jsonl files is fixed - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - // Create a temporary working directory (this simulates the actual working directory) - let working_dir = tempdir()?; - let working_dir_path = working_dir.path().to_path_buf(); - - // Create messages - let messages = - Conversation::new_unvalidated(vec![Message::user().with_text("test message")]); - - // Get the home directory for comparison - let home_dir = get_home_dir(); - - // Test 1: Using the old persist_messages function (without working_dir) - // This will fall back to home directory since no working_dir is provided - persist_messages(&file_path, &messages, None, None).await?; - - // Read back the metadata - this should now have the home directory as working_dir - let metadata_old = read_metadata(&file_path)?; - assert_eq!( - metadata_old.working_dir, home_dir, - "persist_messages should use home directory when no working_dir is provided" - ); - - // Test 2: Using persist_messages_with_schedule_id function - // This should properly set the working_dir (this is the main fix) - persist_messages_with_schedule_id( - &file_path, - &messages, - None, - None, - Some(working_dir_path.clone()), - ) - .await?; - - // Read back the metadata - this should now have the correct working_dir - let metadata_new = read_metadata(&file_path)?; - assert_eq!( - metadata_new.working_dir, working_dir_path, - "persist_messages_with_schedule_id should 
use provided working_dir" - ); - assert_ne!( - metadata_new.working_dir, home_dir, - "working_dir should be different from home directory" - ); - - // Test 3: Create a new session file without working_dir (should fall back to home) - let file_path_2 = dir.path().join("test2.jsonl"); - persist_messages_with_schedule_id( - &file_path_2, - &messages, - None, - None, - None, // No working_dir provided - ) - .await?; - - let metadata_fallback = read_metadata(&file_path_2)?; - assert_eq!(metadata_fallback.working_dir, home_dir, "persist_messages_with_schedule_id should fall back to home directory when no working_dir is provided"); - - // Test 4: Test that the fix works for existing files - // Create a session file and then add to it with different working_dir - let file_path_3 = dir.path().join("test3.jsonl"); - - // First, create with home directory - persist_messages(&file_path_3, &messages, None, None).await?; - let metadata_initial = read_metadata(&file_path_3)?; - assert_eq!( - metadata_initial.working_dir, home_dir, - "Initial session should use home directory" - ); - - // Then update with a specific working_dir - persist_messages_with_schedule_id( - &file_path_3, - &messages, - None, - None, - Some(working_dir_path.clone()), - ) - .await?; - - let metadata_updated = read_metadata(&file_path_3)?; - assert_eq!( - metadata_updated.working_dir, working_dir_path, - "Updated session should use new working_dir" - ); - - // Test 5: Most important test - simulate the real-world scenario where - // CLI and web interfaces pass the current directory instead of None - let file_path_4 = dir.path().join("test4.jsonl"); - let current_dir = std::env::current_dir()?; - - // This is what web.rs and session/mod.rs do now after the fix - persist_messages_with_schedule_id( - &file_path_4, - &messages, - None, - None, - Some(current_dir.clone()), - ) - .await?; - - let metadata_current = read_metadata(&file_path_4)?; - assert_eq!( - metadata_current.working_dir, current_dir, - "Session should use current directory when explicitly provided" - ); - // This should NOT be the home directory anymore (unless current_dir == home_dir) - if current_dir != home_dir { - assert_ne!( - metadata_current.working_dir, home_dir, - "working_dir should be different from home directory when current_dir is different" - ); - } - - Ok(()) - } - - #[test] - fn test_windows_path_validation() -> Result<()> { - // Test the Windows path validation logic - let temp_dir = tempfile::tempdir()?; - let session_dir = temp_dir.path().join("sessions"); - fs::create_dir_all(&session_dir)?; - - // Test case 1: Valid path within session directory - let valid_path = session_dir.join("test.jsonl"); - assert!(validate_path_within_session_dir(&valid_path, &session_dir)?); - - // Test case 2: Invalid path outside session directory - let invalid_path = temp_dir.path().join("outside.jsonl"); - assert!(!validate_path_within_session_dir( - &invalid_path, - &session_dir - )?); - - // Test case 3: Path with different separators (simulate Windows issue) - let mixed_sep_path = session_dir.join("subdir").join("test.jsonl"); - fs::create_dir_all(mixed_sep_path.parent().unwrap())?; - assert!(validate_path_within_session_dir( - &mixed_sep_path, - &session_dir - )?); - - // Test case 4: Non-existent path within session directory - let nonexistent_path = session_dir.join("nonexistent").join("test.jsonl"); - assert!(validate_path_within_session_dir( - &nonexistent_path, - &session_dir - )?); - - Ok(()) - } - - #[test] - fn test_path_normalization() { - let temp_dir = 
tempfile::tempdir().unwrap();
-        let test_path = temp_dir.path().join("test");
-
-        // Test that normalization doesn't crash and returns a path
-        let normalized = normalize_path_for_comparison(&test_path);
-        assert!(!normalized.as_os_str().is_empty());
-
-        // Test with existing path
-        fs::create_dir_all(&test_path).unwrap();
-        let normalized_existing = normalize_path_for_comparison(&test_path);
-        assert!(!normalized_existing.as_os_str().is_empty());
-    }
-
-    #[tokio::test]
-    async fn test_save_session_parameter() -> Result<()> {
-        let dir = tempdir()?;
-        let file_path = dir.path().join("test_save_session.jsonl");
-
-        let messages = Conversation::new_unvalidated(vec![
-            Message::user().with_text("Hello"),
-            Message::assistant().with_text("Hi there"),
-        ]);
-
-        let metadata = SessionMetadata::default();
-
-        // Test with save_session = true - should create file
-        save_messages_with_metadata(&file_path, &metadata, &messages)?;
-        assert!(
-            file_path.exists(),
-            "File should be created when save_session=true"
-        );
-
-        // Verify content is correct
-        let read_messages = read_messages(&file_path)?;
-        assert_eq!(messages.len(), read_messages.len());
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_persist_messages_with_save_session_false() -> Result<()> {
-        let dir = tempdir()?;
-        let file_path = dir.path().join("test_persist_no_save.jsonl");
-
-        let messages = Conversation::new_unvalidated(vec![
-            Message::user().with_text("Test message"),
-            Message::assistant().with_text("Test response"),
-        ]);
-
-        // Test persist_messages_with_schedule_id with working_dir parameter
-        persist_messages_with_schedule_id(
-            &file_path,
-            &messages,
-            None,
-            Some("test_schedule".to_string()),
-            None,
-        )
-        .await?;
-
-        assert!(
-            file_path.exists(),
-            "File should be created when save_session=true"
-        );
-
-        // Verify the schedule_id was set correctly
-        let metadata = read_metadata(&file_path)?;
-        assert_eq!(metadata.schedule_id, Some("test_schedule".to_string()));
-
-        Ok(())
-    }
-}
diff --git a/crates/goose/src/temporal_scheduler.rs b/crates/goose/src/temporal_scheduler.rs
index 5f86f8d405f7..955443453279 100644
--- a/crates/goose/src/temporal_scheduler.rs
+++ b/crates/goose/src/temporal_scheduler.rs
@@ -11,7 +11,7 @@ use tracing::{info, warn};
 
 use crate::scheduler::{normalize_cron_expression, ScheduledJob, SchedulerError};
 use crate::scheduler_trait::SchedulerTrait;
-use crate::session::storage::SessionMetadata;
+use crate::session::{Session, SessionManager};
 
 const TEMPORAL_SERVICE_STARTUP_TIMEOUT: Duration = Duration::from_secs(15);
 const TEMPORAL_SERVICE_HEALTH_CHECK_INTERVAL: Duration = Duration::from_millis(500);
@@ -699,37 +699,23 @@ impl TemporalScheduler {
         }
     }
 
-    // Note: This method fetches sessions from the session storage directly
-    // since Temporal service doesn't track session metadata
     pub async fn sessions(
         &self,
         sched_id: &str,
         limit: usize,
-    ) -> Result<Vec<(String, SessionMetadata)>, SchedulerError> {
-        use crate::session::storage;
+    ) -> Result<Vec<(String, Session)>, SchedulerError> {
+        use crate::session::SessionManager;
 
-        // Get all session files
-        let all_session_files = storage::list_sessions().map_err(|e| {
+        // Get all sessions from the database
+        let all_sessions = SessionManager::list_sessions().await.map_err(|e| {
             SchedulerError::SchedulerInternalError(format!("Failed to list sessions: {}", e))
         })?;
 
-        let mut schedule_sessions: Vec<(String, SessionMetadata)> = Vec::new();
+        let mut schedule_sessions: Vec<(String, Session)> = Vec::new();
 
-        for (session_name, session_path) in all_session_files {
-            match storage::read_metadata(&session_path) {
- Ok(metadata) => { - // Check if this session belongs to the requested schedule - if metadata.schedule_id.as_deref() == Some(sched_id) { - schedule_sessions.push((session_name, metadata)); - } - } - Err(e) => { - tracing::warn!( - "Failed to read metadata for session file {}: {}. Skipping.", - session_path.display(), - e - ); - } + for session in all_sessions { + if session.schedule_id.as_deref() == Some(sched_id) { + schedule_sessions.push((session.id.clone(), session)); } } @@ -737,10 +723,10 @@ impl TemporalScheduler { schedule_sessions.sort_by(|a, b| b.0.cmp(&a.0)); // Take only the requested limit - let result_sessions: Vec<(String, SessionMetadata)> = + let result_sessions: Vec<(String, Session)> = schedule_sessions.into_iter().take(limit).collect(); - tracing::info!( + info!( "Found {} sessions for schedule '{}'", result_sessions.len(), sched_id @@ -858,38 +844,41 @@ impl TemporalScheduler { let recent_sessions = self.sessions(&job.id, 3).await?; let mut has_active_session = false; - for (session_name, _) in recent_sessions { - let session_path = match crate::session::storage::get_path( - crate::session::storage::Identifier::Name(session_name.clone()), - ) { - Ok(path) => path, + for (session_id, _) in recent_sessions { + // Get session info from database to check last update time + match SessionManager::list_sessions().await { + Ok(all_sessions) => { + if let Some(session_info) = + all_sessions.iter().find(|s| s.id == session_id) + { + // Parse the updated_at timestamp from the database + if let Ok(modified_dt) = DateTime::parse_from_str( + &session_info.updated_at, + "%Y-%m-%d %H:%M:%S UTC", + ) { + let modified_utc = modified_dt.with_timezone(&Utc); + let now = Utc::now(); + let time_diff = now.signed_duration_since(modified_utc); + + // Increased tolerance to 5 minutes to reduce false positives + if time_diff.num_minutes() < 5 { + has_active_session = true; + tracing::debug!( + "Found active session for job '{}' modified {} minutes ago", + job.id, + time_diff.num_minutes() + ); + break; + } + } + } + } Err(e) => { tracing::warn!( - "Failed to get session path for '{}': {}", - session_name, + "Failed to list sessions to check activity for job '{}': {}", + job.id, e ); - continue; - } - }; - - // Check if session file was modified recently (within last 5 minutes instead of 2) - if let Ok(metadata) = std::fs::metadata(&session_path) { - if let Ok(modified) = metadata.modified() { - let modified_dt: DateTime = modified.into(); - let now = Utc::now(); - let time_diff = now.signed_duration_since(modified_dt); - - // Increased tolerance to 5 minutes to reduce false positives - if time_diff.num_minutes() < 5 { - has_active_session = true; - tracing::debug!( - "Found active session for job '{}' modified {} minutes ago", - job.id, - time_diff.num_minutes() - ); - break; - } } } } @@ -897,9 +886,9 @@ impl TemporalScheduler { // Only mark as completed if both Temporal service check failed AND no recent session activity if !has_active_session { tracing::info!( - "No active sessions found for job '{}' in the last 5 minutes, marking as completed", - job.id - ); + "No active sessions found for job '{}' in the last 5 minutes, marking as completed", + job.id + ); let request = JobRequest { action: "mark_completed".to_string(), @@ -966,14 +955,16 @@ impl TemporalScheduler { let recent_sessions = self.sessions(sched_id, 1).await?; if let Some((session_name, _session_metadata)) = recent_sessions.first() { - // Check if this session is still active by looking at the session file - let session_path = 
match crate::session::storage::get_path(
-                crate::session::storage::Identifier::Name(session_name.clone()),
-            ) {
-                Ok(path) => path,
+            // Check if this session still exists in the new system
+            match SessionManager::get_session(session_name, false).await {
+                Ok(_) => {
+                    // Session exists, consider it active
+                    let start_time = Utc::now();
+                    return Ok(Some((session_name.clone(), start_time)));
+                }
                 Err(e) => {
                     tracing::warn!(
-                        "Failed to get session path for '{}': {}",
+                        "Session '{}' no longer exists: {}",
                         session_name,
                         e
                     );
@@ -983,21 +974,6 @@ impl TemporalScheduler {
                     let start_time = Utc::now();
                     return Ok(Some((session_id, start_time)));
                 }
-            };
-
-            // If the session file was modified recently (within last 5 minutes),
-            // consider it as the current running session
-            if let Ok(metadata) = std::fs::metadata(&session_path) {
-                if let Ok(modified) = metadata.modified() {
-                    let modified_dt: DateTime<Utc> = modified.into();
-                    let now = Utc::now();
-                    let time_diff = now.signed_duration_since(modified_dt);
-
-                    if time_diff.num_minutes() < 5 {
-                        // This looks like an active session
-                        return Ok(Some((session_name.clone(), modified_dt)));
-                    }
-                }
-            }
         }
     }
 
@@ -1023,10 +999,9 @@ impl TemporalScheduler {
     async fn make_request(&self, request: JobRequest) -> Result {
         let url = format!("{}/jobs", self.service_url);
 
-        tracing::info!(
+        info!(
             "TemporalScheduler: Making HTTP request to {} with action '{}'",
-            url,
-            request.action
+            url, request.action
         );
 
         let response = self
@@ -1212,7 +1187,7 @@ impl SchedulerTrait for TemporalScheduler {
         &self,
         sched_id: &str,
         limit: usize,
-    ) -> Result<Vec<(String, SessionMetadata)>, SchedulerError> {
+    ) -> Result<Vec<(String, Session)>, SchedulerError> {
         self.sessions(sched_id, limit).await
     }
 
diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs
index 6f551d5c926d..d1ce5231ec3e 100644
--- a/crates/goose/tests/agent.rs
+++ b/crates/goose/tests/agent.rs
@@ -364,10 +364,9 @@ mod schedule_tool_tests {
     use goose::agents::platform_tools::PLATFORM_MANAGE_SCHEDULE_TOOL_NAME;
     use goose::scheduler::{ScheduledJob, SchedulerError};
     use goose::scheduler_trait::SchedulerTrait;
-    use goose::session::storage::SessionMetadata;
+    use goose::session::Session;
     use std::sync::Arc;
 
-    // Mock scheduler for testing
    struct MockScheduler {
         jobs: tokio::sync::Mutex>,
     }
@@ -419,7 +418,7 @@ mod schedule_tool_tests {
             &self,
             _sched_id: &str,
             _limit: usize,
-        ) -> Result<Vec<(String, SessionMetadata)>, SchedulerError> {
+        ) -> Result<Vec<(String, Session)>, SchedulerError> {
             Ok(vec![])
         }
 
@@ -853,7 +852,7 @@ mod final_output_tool_tests {
 mod retry_tests {
     use super::*;
     use async_trait::async_trait;
-    use goose::agents::types::{RetryConfig, SessionConfig, SuccessCheck};
+    use goose::agents::types::{RetryConfig, SuccessCheck};
     use goose::conversation::message::Message;
     use goose::conversation::Conversation;
     use goose::model::ModelConfig;
@@ -939,21 +938,10 @@ mod retry_tests {
             "Valid config should pass validation"
         );
 
-        let session_config = SessionConfig {
-            id: goose::session::Identifier::Name("test-retry".to_string()),
-            working_dir: std::env::current_dir()?,
-            schedule_id: None,
-            execution_mode: None,
-            max_turns: None,
-            retry_config: Some(retry_config),
-        };
-
         let conversation =
             Conversation::new(vec![Message::user().with_text("Complete this task")]).unwrap();
 
-        let reply_stream = agent
-            .reply(conversation, Some(session_config), None)
-            .await?;
+        let reply_stream = agent.reply(conversation, None, None).await?;
         tokio::pin!(reply_stream);
 
         let mut responses = Vec::new();
@@ -1051,10 +1039,8 @@ mod max_turns_tests {
     use goose::model::ModelConfig;
     use goose::providers::base::{Provider,
ProviderMetadata, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; - use goose::session::storage::Identifier; use mcp_core::tool::ToolCall; use rmcp::model::Tool; - use std::path::PathBuf; struct MockToolProvider {} @@ -1116,21 +1102,9 @@ mod max_turns_tests { let provider = Arc::new(MockToolProvider::new()); agent.update_provider(provider).await?; // The mock provider will call a non-existent tool, which will fail and allow the loop to continue - - // Create session config with max_turns = 1 - let session_config = goose::agents::SessionConfig { - id: Identifier::Name("test_session".to_string()), - working_dir: PathBuf::from("/tmp"), - schedule_id: None, - execution_mode: None, - max_turns: Some(1), - retry_config: None, - }; let conversation = Conversation::new(vec![Message::user().with_text("Hello")]).unwrap(); - let reply_stream = agent - .reply(conversation, Some(session_config), None) - .await?; + let reply_stream = agent.reply(conversation, None, None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); diff --git a/crates/goose/tests/private_tests.rs b/crates/goose/tests/private_tests.rs index 6e8f544807da..e5788163021b 100644 --- a/crates/goose/tests/private_tests.rs +++ b/crates/goose/tests/private_tests.rs @@ -757,85 +757,6 @@ async fn test_schedule_tool_sessions_action_empty() { assert!(calls.contains(&"sessions".to_string())); } -#[tokio::test] -async fn test_schedule_tool_session_content_action() { - let (agent, _) = ScheduleToolTestBuilder::new().build().await; - - // Test with a non-existent session - let arguments = json!({ - "action": "session_content", - "session_id": "non_existent_session" - }); - - let result = agent - .handle_schedule_management(arguments, "test_req".to_string()) - .await; - assert!(result.is_err()); - - if let Err(err) = result { - assert!(err - .message - .contains("Session 'non_existent_session' not found")); - } else { - panic!("Expected ExecutionError"); - } -} - -#[tokio::test] -async fn test_schedule_tool_session_content_action_with_real_session() { - let (agent, _) = ScheduleToolTestBuilder::new().build().await; - - // Create a temporary session file in the proper session directory - let session_dir = goose::session::storage::ensure_session_dir().unwrap(); - let session_id = "test_session_real"; - let session_path = session_dir.join(format!("{}.jsonl", session_id)); - - // Create test metadata and messages - let metadata = create_test_session_metadata(2, "/tmp"); - let messages = goose::conversation::Conversation::new_unvalidated(vec![ - goose::conversation::message::Message::user().with_text("Hello"), - goose::conversation::message::Message::assistant().with_text("Hi there!"), - ]); - - // Save the session file - goose::session::storage::save_messages_with_metadata(&session_path, &metadata, &messages) - .unwrap(); - - // Test the session_content action - let arguments = json!({ - "action": "session_content", - "session_id": session_id - }); - - let result = agent - .handle_schedule_management(arguments, "test_req".to_string()) - .await; - - // Clean up the test session file - let _ = std::fs::remove_file(&session_path); - - // Verify the result - assert!(result.is_ok()); - - if let Ok(content) = result { - assert_eq!(content.len(), 1); - if let Some(text_content) = content[0].as_text() { - assert!(text_content - .text - .contains("Session 'test_session_real' Content:")); - assert!(text_content.text.contains("Metadata:")); - assert!(text_content.text.contains("Messages:")); - 
assert!(text_content.text.contains("Hello")); - assert!(text_content.text.contains("Hi there!")); - assert!(text_content.text.contains("Test session")); - } else { - panic!("Expected text content"); - } - } else { - panic!("Expected successful result"); - } -} - #[tokio::test] async fn test_schedule_tool_session_content_action_missing_session_id() { let (agent, _) = ScheduleToolTestBuilder::new().build().await; diff --git a/crates/goose/tests/test_support.rs b/crates/goose/tests/test_support.rs index feea692d6a72..d53319477ddf 100644 --- a/crates/goose/tests/test_support.rs +++ b/crates/goose/tests/test_support.rs @@ -12,7 +12,7 @@ use tokio::sync::Mutex; use goose::agents::Agent; use goose::scheduler::{ScheduledJob, SchedulerError}; use goose::scheduler_trait::SchedulerTrait; -use goose::session::storage::SessionMetadata; +use goose::session::Session; #[derive(Debug, Clone)] pub enum MockBehavior { @@ -30,7 +30,7 @@ pub struct ConfigurableMockScheduler { call_log: Arc>>, behaviors: Arc>>, #[allow(clippy::type_complexity)] - sessions_data: Arc>>>, + sessions_data: Arc>>>, } #[allow(dead_code)] @@ -184,7 +184,7 @@ impl SchedulerTrait for ConfigurableMockScheduler { &self, sched_id: &str, limit: usize, - ) -> Result, SchedulerError> { + ) -> Result, SchedulerError> { self.log_call("sessions").await; match self.get_behavior("sessions").await { @@ -362,11 +362,7 @@ impl ScheduleToolTestBuilder { self } - pub async fn with_sessions_data( - self, - job_id: &str, - sessions: Vec<(String, SessionMetadata)>, - ) -> Self { + pub async fn with_sessions_data(self, job_id: &str, sessions: Vec<(String, Session)>) -> Self { { let mut sessions_data = self.scheduler.sessions_data.lock().await; sessions_data.insert(job_id.to_string(), sessions); @@ -381,13 +377,14 @@ impl ScheduleToolTestBuilder { } } -// Helper function to create test session metadata -pub fn create_test_session_metadata(message_count: usize, working_dir: &str) -> SessionMetadata { - SessionMetadata { - message_count, +pub fn create_test_session_metadata(message_count: usize, working_dir: &str) -> Session { + Session { + id: "".to_string(), working_dir: PathBuf::from(working_dir), description: "Test session".to_string(), + created_at: "".to_string(), schedule_id: Some("test_job".to_string()), + recipe: None, total_tokens: Some(100), input_tokens: Some(50), output_tokens: Some(50), @@ -395,6 +392,8 @@ pub fn create_test_session_metadata(message_count: usize, working_dir: &str) -> accumulated_input_tokens: Some(50), accumulated_output_tokens: Some(50), extension_data: Default::default(), - recipe: None, + updated_at: "".to_string(), + conversation: None, + message_count, } } diff --git a/crates/goose/tests/todo_session_integration.rs b/crates/goose/tests/todo_session_integration.rs deleted file mode 100644 index c63ec5f0d1d3..000000000000 --- a/crates/goose/tests/todo_session_integration.rs +++ /dev/null @@ -1,496 +0,0 @@ -use futures::StreamExt; -use goose::agents::types::SessionConfig; -use goose::agents::{Agent, AgentEvent}; -use goose::conversation::message::Message; -use goose::conversation::Conversation; -use goose::model::ModelConfig; -use goose::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; -use goose::providers::errors::ProviderError; -use goose::session; -use goose::session::storage::SessionMetadata; -use rmcp::model::Tool; -use std::sync::Arc; -use tempfile::TempDir; -use uuid::Uuid; - -// Mock provider implementation for testing -struct MockProvider { - model_config: ModelConfig, -} - -impl MockProvider { - 
fn new() -> Self { - Self { - model_config: ModelConfig::new_or_fail("mock-model"), - } - } -} - -#[async_trait::async_trait] -impl Provider for MockProvider { - fn metadata() -> ProviderMetadata - where - Self: Sized, - { - ProviderMetadata::new( - "mock", - "Mock Provider", - "A mock provider for testing", - "mock-model", - vec!["mock-model"], - "https://example.com", - vec![], - ) - } - - async fn complete( - &self, - _system: &str, - _messages: &[Message], - _tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - // Return a simple mock response - Ok(( - Message::assistant().with_text("Mock response"), - ProviderUsage::new( - "mock-model".to_string(), - Usage::new(Some(10), Some(20), Some(30)), - ), - )) - } - - async fn complete_with_model( - &self, - _model_config: &ModelConfig, - _system: &str, - _messages: &[Message], - _tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - // Return a simple mock response - Ok(( - Message::assistant().with_text("Mock response"), - ProviderUsage::new( - "mock-model".to_string(), - Usage::new(Some(10), Some(20), Some(30)), - ), - )) - } - - fn get_model_config(&self) -> ModelConfig { - self.model_config.clone() - } - - async fn stream( - &self, - _system: &str, - _messages: &[Message], - _tools: &[Tool], - ) -> Result { - // Return a simple mock stream - let message = Message::assistant().with_text("Mock stream response"); - let usage = ProviderUsage::new( - "mock-model".to_string(), - Usage::new(Some(10), Some(20), Some(30)), - ); - Ok(goose::providers::base::stream_from_single_message( - message, usage, - )) - } - - fn supports_streaming(&self) -> bool { - true - } - - async fn generate_session_name( - &self, - _messages: &Conversation, - ) -> Result { - Ok("Mock session description".to_string()) - } -} - -async fn create_test_session_dir() -> TempDir { - TempDir::new().unwrap() -} - -async fn create_test_agent_with_mock_provider() -> Agent { - let agent = Agent::new(); - let mock_provider = Arc::new(MockProvider::new()); - agent.update_provider(mock_provider).await.unwrap(); - agent -} - -#[tokio::test] -async fn test_todo_add_persists_to_session() { - let temp_dir = create_test_session_dir().await; - let session_id = session::Identifier::Name(format!("test_session_{}", uuid::Uuid::new_v4())); - let agent = create_test_agent_with_mock_provider().await; - - // Create a conversation with a TODO add request - let messages = - vec![Message::user().with_text("Add these tasks to my todo list: Buy milk, Call dentist")]; - let conversation = Conversation::new(messages).unwrap(); - - let session_config = SessionConfig { - id: session_id.clone(), - working_dir: temp_dir.path().to_path_buf(), - schedule_id: None, - max_turns: Some(10), - execution_mode: Some("auto".to_string()), - retry_config: None, - }; - - // Process the conversation - let mut stream = agent - .reply(conversation, Some(session_config.clone()), None) - .await - .unwrap(); - - // Collect all events - while let Some(event) = stream.next().await { - if let Ok(_event) = event { - // Process events - } - } - - // Verify TODO was persisted to session - let session_path = goose::session::storage::get_path(session_id).unwrap(); - let metadata = goose::session::storage::read_metadata(&session_path).unwrap(); - - // Since we're using a mock provider, we can't test the actual TODO content - // but we can verify the metadata structure is correct - assert!( - metadata.extension_data.extension_states.is_empty() - || 
!metadata.extension_data.extension_states.is_empty() - ); -} - -#[tokio::test] -async fn test_todo_list_reads_from_session() { - let temp_dir = create_test_session_dir().await; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - let agent = create_test_agent_with_mock_provider().await; - - // Pre-populate session with TODO content - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - let mut metadata = SessionMetadata::default(); - use goose::session::extension_data::{ExtensionState, TodoState}; - let todo_state = TodoState::new("- Task 1\n- Task 2\n- Task 3".to_string()); - todo_state - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - goose::session::storage::update_metadata(&session_path, &metadata) - .await - .unwrap(); - - // Create a conversation requesting TODO list - let messages = vec![Message::user().with_text("Show me my todo list")]; - let conversation = Conversation::new(messages).unwrap(); - - let session_config = SessionConfig { - id: session_id.clone(), - working_dir: temp_dir.path().to_path_buf(), - schedule_id: None, - max_turns: Some(10), - execution_mode: Some("auto".to_string()), - retry_config: None, - }; - - // Process the conversation - let mut stream = agent - .reply(conversation, Some(session_config), None) - .await - .unwrap(); - - // Collect all events - while let Some(event) = stream.next().await { - if let Ok(AgentEvent::Message(msg)) = event { - let _text = msg.as_concat_text(); - // With mock provider, we can't verify the actual content - } - } - - // Verify the TODO content is still in session - let metadata_after = goose::session::storage::read_metadata(&session_path).unwrap(); - let todo_state_after = TodoState::from_extension_data(&metadata_after.extension_data); - assert!(todo_state_after.is_some()); - assert_eq!( - todo_state_after.unwrap().content, - "- Task 1\n- Task 2\n- Task 3".to_string() - ); -} - -#[tokio::test] -async fn test_todo_isolation_between_sessions() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let session1_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - let session2_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - - // Add TODO to session1 - let session1_path = goose::session::storage::get_path(session1_id.clone()).unwrap(); - let mut metadata1 = SessionMetadata::default(); - let todo_state1 = TodoState::new("Session 1 tasks".to_string()); - todo_state1 - .to_extension_data(&mut metadata1.extension_data) - .unwrap(); - goose::session::storage::update_metadata(&session1_path, &metadata1) - .await - .unwrap(); - - // Add different TODO to session2 - let session2_path = goose::session::storage::get_path(session2_id.clone()).unwrap(); - let mut metadata2 = SessionMetadata::default(); - let todo_state2 = TodoState::new("Session 2 tasks".to_string()); - todo_state2 - .to_extension_data(&mut metadata2.extension_data) - .unwrap(); - goose::session::storage::update_metadata(&session2_path, &metadata2) - .await - .unwrap(); - - // Verify isolation - let metadata1_read = goose::session::storage::read_metadata(&session1_path).unwrap(); - let metadata2_read = goose::session::storage::read_metadata(&session2_path).unwrap(); - - let todo1 = TodoState::from_extension_data(&metadata1_read.extension_data).unwrap(); - let todo2 = TodoState::from_extension_data(&metadata2_read.extension_data).unwrap(); - - assert_eq!(todo1.content, "Session 1 tasks"); - assert_eq!(todo2.content, 
"Session 2 tasks"); -} - -#[tokio::test] -async fn test_todo_clear_removes_from_session() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let temp_dir = create_test_session_dir().await; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - let agent = create_test_agent_with_mock_provider().await; - - // Pre-populate session with TODO content - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - let mut metadata = SessionMetadata::default(); - let todo_state = TodoState::new("- Task to clear".to_string()); - todo_state - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - goose::session::storage::update_metadata(&session_path, &metadata) - .await - .unwrap(); - - // Create a conversation to clear TODO - let messages = vec![Message::user().with_text("Clear my entire todo list")]; - let conversation = Conversation::new(messages).unwrap(); - - let session_config = SessionConfig { - id: session_id.clone(), - working_dir: temp_dir.path().to_path_buf(), - schedule_id: None, - max_turns: Some(10), - execution_mode: Some("auto".to_string()), - retry_config: None, - }; - - // Process the conversation - let mut stream = agent - .reply(conversation, Some(session_config), None) - .await - .unwrap(); - - // Consume the stream - while (stream.next().await).is_some() {} - - // With mock provider, the TODO won't actually be cleared via tool calls - // but we can verify the structure is correct - let metadata_after = goose::session::storage::read_metadata(&session_path).unwrap(); - let todo_state_after = TodoState::from_extension_data(&metadata_after.extension_data); - assert!(todo_state_after.is_some()); // Will still have the original content with mock -} - -#[tokio::test] -async fn test_todo_persistence_across_agent_instances() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - - // First agent instance adds TODO - { - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - let mut metadata = SessionMetadata::default(); - let todo_state = TodoState::new("Persistent task".to_string()); - todo_state - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - goose::session::storage::update_metadata(&session_path, &metadata) - .await - .unwrap(); - } - - // Second agent instance reads TODO - { - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - let metadata = goose::session::storage::read_metadata(&session_path).unwrap(); - let todo_state = TodoState::from_extension_data(&metadata.extension_data).unwrap(); - assert_eq!(todo_state.content, "Persistent task"); - } -} - -#[tokio::test] -async fn test_todo_max_chars_limit() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - - // Set a small limit for testing - std::env::set_var("GOOSE_TODO_MAX_CHARS", "50"); - - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - let mut metadata = SessionMetadata::default(); - - // Try to set content that exceeds the limit - let long_content = "x".repeat(100); - let todo_state = TodoState::new(long_content.clone()); - todo_state - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - - // This should succeed at the storage level (storage doesn't enforce limits) - 
goose::session::storage::update_metadata(&session_path, &metadata) - .await - .unwrap(); - - // But when the agent tries to write through the TODO tool, it should enforce the limit - // This would be tested through the agent's dispatch_todo_tool_with_session method - - // Clean up - std::env::remove_var("GOOSE_TODO_MAX_CHARS"); -} - -#[tokio::test] -async fn test_todo_with_special_characters() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - let mut metadata = SessionMetadata::default(); - - // Test with various special characters - let special_content = r#" -- Task with "quotes" -- Task with 'single quotes' -- Task with emoji 🎉 -- Task with unicode: 你好 -- Task with newline - continuation -- Task with tab separation -"#; - - let todo_state = TodoState::new(special_content.to_string()); - todo_state - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - goose::session::storage::update_metadata(&session_path, &metadata) - .await - .unwrap(); - - // Read back and verify - let metadata_read = goose::session::storage::read_metadata(&session_path).unwrap(); - let todo_state_read = TodoState::from_extension_data(&metadata_read.extension_data).unwrap(); - assert_eq!(todo_state_read.content, special_content); -} - -#[tokio::test] -async fn test_todo_concurrent_access() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - - // Spawn multiple concurrent TODO operations - let mut handles = vec![]; - - for i in 0..5 { - let session_id_clone = session_id.clone(); - - let handle = tokio::spawn(async move { - let session_path = goose::session::storage::get_path(session_id_clone).unwrap(); - let mut metadata = goose::session::storage::read_metadata(&session_path) - .unwrap_or_else(|_| SessionMetadata::default()); - - let current_content = TodoState::from_extension_data(&metadata.extension_data) - .map(|t| t.content) - .unwrap_or_default(); - let new_todo = TodoState::new(format!("{}\n- Task {}", current_content, i)); - new_todo - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - - goose::session::storage::update_metadata(&session_path, &metadata).await - }); - - handles.push(handle); - } - - // Wait for all operations to complete - for handle in handles { - handle.await.unwrap().unwrap(); - } - - // Verify final state contains at least one task - let session_path = goose::session::storage::get_path(session_id).unwrap(); - let metadata = goose::session::storage::read_metadata(&session_path).unwrap(); - let todo_state = TodoState::from_extension_data(&metadata.extension_data).unwrap(); - - // Should contain at least one task (concurrent writes may overwrite) - assert!(todo_state.content.contains("Task")); -} - -#[tokio::test] -async fn test_todo_empty_session_returns_empty() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - let metadata = goose::session::storage::read_metadata(&session_path) - .unwrap_or_else(|_| SessionMetadata::default()); - - let todo_state = TodoState::from_extension_data(&metadata.extension_data); - assert!(todo_state.is_none() || 
todo_state.unwrap().content.is_empty()); -} - -#[tokio::test] -async fn test_todo_update_preserves_other_metadata() { - use goose::session::extension_data::{ExtensionState, TodoState}; - let session_id = session::Identifier::Name(format!("test_session_{}", Uuid::new_v4())); - - let session_path = goose::session::storage::get_path(session_id.clone()).unwrap(); - - // Set initial metadata with various fields - let mut metadata = SessionMetadata::default(); - #[allow(clippy::field_reassign_with_default)] - { - metadata.message_count = 5; - metadata.description = "Test session".to_string(); - metadata.total_tokens = Some(1000); - } - let todo_state = TodoState::new("Initial TODO".to_string()); - todo_state - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - - goose::session::storage::update_metadata(&session_path, &metadata) - .await - .unwrap(); - - // Update only TODO content - let todo_state_updated = TodoState::new("Updated TODO".to_string()); - todo_state_updated - .to_extension_data(&mut metadata.extension_data) - .unwrap(); - goose::session::storage::update_metadata(&session_path, &metadata) - .await - .unwrap(); - - // Verify other fields are preserved - let metadata_read = goose::session::storage::read_metadata(&session_path).unwrap(); - assert_eq!(metadata_read.message_count, 5); - assert_eq!(metadata_read.description, "Test session"); - assert_eq!(metadata_read.total_tokens, Some(1000)); - let todo_state_read = TodoState::from_extension_data(&metadata_read.extension_data).unwrap(); - assert_eq!(todo_state_read.content, "Updated TODO"); -} diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index bb6a08ba4717..552348fb461a 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -107,7 +107,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/StartAgentResponse" + "$ref": "#/components/schemas/Session" } } } @@ -185,7 +185,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/StartAgentResponse" + "$ref": "#/components/schemas/Session" } } } @@ -1500,12 +1500,43 @@ ] } }, + "/sessions/insights": { + "get": { + "tags": [ + "Session Management" + ], + "operationId": "get_session_insights", + "responses": { + "200": { + "description": "Session insights retrieved successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SessionInsights" + } + } + } + }, + "401": { + "description": "Unauthorized - Invalid or missing API key" + }, + "500": { + "description": "Internal server error" + } + }, + "security": [ + { + "api_key": [] + } + ] + } + }, "/sessions/{session_id}": { "get": { "tags": [ "Session Management" ], - "operationId": "get_session_history", + "operationId": "get_session", "parameters": [ { "name": "session_id", @@ -1523,7 +1554,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/SessionHistoryResponse" + "$ref": "#/components/schemas/Session" } } } @@ -1543,6 +1574,93 @@ "api_key": [] } ] + }, + "delete": { + "tags": [ + "Session Management" + ], + "operationId": "delete_session", + "parameters": [ + { + "name": "session_id", + "in": "path", + "description": "Unique identifier for the session", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Session deleted successfully" + }, + "401": { + "description": "Unauthorized - Invalid or missing API key" + }, + "404": { + "description": "Session not found" + }, + "500": { + "description": "Internal 
server error" + } + }, + "security": [ + { + "api_key": [] + } + ] + } + }, + "/sessions/{session_id}/description": { + "put": { + "tags": [ + "Session Management" + ], + "operationId": "update_session_description", + "parameters": [ + { + "name": "session_id", + "in": "path", + "description": "Unique identifier for the session", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateSessionDescriptionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Session description updated successfully" + }, + "400": { + "description": "Bad request - Description too long (max 200 characters)" + }, + "401": { + "description": "Unauthorized - Invalid or missing API key" + }, + "404": { + "description": "Session not found" + }, + "500": { + "description": "Internal server error" + } + }, + "security": [ + { + "api_key": [] + } + ] } }, "/status": { @@ -1815,6 +1933,12 @@ } } }, + "Conversation": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Message" + } + }, "CreateCustomProviderRequest": { "type": "object", "required": [ @@ -3426,6 +3550,92 @@ } } }, + "Session": { + "type": "object", + "required": [ + "id", + "working_dir", + "description", + "created_at", + "updated_at", + "extension_data", + "message_count" + ], + "properties": { + "accumulated_input_tokens": { + "type": "integer", + "format": "int32", + "nullable": true + }, + "accumulated_output_tokens": { + "type": "integer", + "format": "int32", + "nullable": true + }, + "accumulated_total_tokens": { + "type": "integer", + "format": "int32", + "nullable": true + }, + "conversation": { + "allOf": [ + { + "$ref": "#/components/schemas/Conversation" + } + ], + "nullable": true + }, + "created_at": { + "type": "string" + }, + "description": { + "type": "string" + }, + "extension_data": { + "$ref": "#/components/schemas/ExtensionData" + }, + "id": { + "type": "string" + }, + "input_tokens": { + "type": "integer", + "format": "int32", + "nullable": true + }, + "message_count": { + "type": "integer", + "minimum": 0 + }, + "output_tokens": { + "type": "integer", + "format": "int32", + "nullable": true + }, + "recipe": { + "allOf": [ + { + "$ref": "#/components/schemas/Recipe" + } + ], + "nullable": true + }, + "schedule_id": { + "type": "string", + "nullable": true + }, + "total_tokens": { + "type": "integer", + "format": "int32", + "nullable": true + }, + "updated_at": { + "type": "string" + }, + "working_dir": { + "type": "string" + } + } + }, "SessionConfigRequest": { "type": "object", "required": [ @@ -3507,50 +3717,22 @@ } } }, - "SessionHistoryResponse": { - "type": "object", - "required": [ - "sessionId", - "metadata", - "messages" - ], - "properties": { - "messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Message" - }, - "description": "List of messages in the session conversation" - }, - "metadata": { - "$ref": "#/components/schemas/SessionMetadata" - }, - "sessionId": { - "type": "string", - "description": "Unique identifier for the session" - } - } - }, - "SessionInfo": { + "SessionInsights": { "type": "object", "required": [ - "id", - "path", - "modified", - "metadata" + "totalSessions", + "totalTokens" ], "properties": { - "id": { - "type": "string" - }, - "metadata": { - "$ref": "#/components/schemas/SessionMetadata" - }, - "modified": { - "type": "string" + "totalSessions": { + "type": "integer", + "description": "Total number of 
sessions", + "minimum": 0 }, - "path": { - "type": "string" + "totalTokens": { + "type": "integer", + "format": "int64", + "description": "Total tokens used across all sessions" } } }, @@ -3563,89 +3745,12 @@ "sessions": { "type": "array", "items": { - "$ref": "#/components/schemas/SessionInfo" + "$ref": "#/components/schemas/Session" }, "description": "List of available session information objects" } } }, - "SessionMetadata": { - "type": "object", - "description": "Metadata for a session, stored as the first line in the session file", - "required": [ - "working_dir", - "description", - "message_count" - ], - "properties": { - "accumulated_input_tokens": { - "type": "integer", - "format": "int32", - "description": "The number of input tokens used in the session. Accumulated across all messages.", - "nullable": true - }, - "accumulated_output_tokens": { - "type": "integer", - "format": "int32", - "description": "The number of output tokens used in the session. Accumulated across all messages.", - "nullable": true - }, - "accumulated_total_tokens": { - "type": "integer", - "format": "int32", - "description": "The total number of tokens used in the session. Accumulated across all messages (useful for tracking cost over an entire session).", - "nullable": true - }, - "description": { - "type": "string", - "description": "A short description of the session, typically 3 words or less" - }, - "extension_data": { - "$ref": "#/components/schemas/ExtensionData" - }, - "input_tokens": { - "type": "integer", - "format": "int32", - "description": "The number of input tokens used in the session. Retrieved from the provider's last usage.", - "nullable": true - }, - "message_count": { - "type": "integer", - "description": "Number of messages in the session", - "minimum": 0 - }, - "output_tokens": { - "type": "integer", - "format": "int32", - "description": "The number of output tokens used in the session. Retrieved from the provider's last usage.", - "nullable": true - }, - "recipe": { - "allOf": [ - { - "$ref": "#/components/schemas/Recipe" - } - ], - "nullable": true - }, - "schedule_id": { - "type": "string", - "description": "ID of the schedule that triggered this session, if any", - "nullable": true - }, - "total_tokens": { - "type": "integer", - "format": "int32", - "description": "The total number of tokens used in the session. 
Retrieved from the provider's last usage.", - "nullable": true - }, - "working_dir": { - "type": "string", - "description": "Working directory for the session", - "example": "/home/user/sessions/session1" - } - } - }, "SessionsQuery": { "type": "object", "properties": { @@ -3708,28 +3813,6 @@ } } }, - "StartAgentResponse": { - "type": "object", - "required": [ - "session_id", - "metadata", - "messages" - ], - "properties": { - "messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Message" - } - }, - "metadata": { - "$ref": "#/components/schemas/SessionMetadata" - }, - "session_id": { - "type": "string" - } - } - }, "SubRecipe": { "type": "object", "required": [ @@ -4029,6 +4112,18 @@ } } }, + "UpdateSessionDescriptionRequest": { + "type": "object", + "required": [ + "description" + ], + "properties": { + "description": { + "type": "string", + "description": "Updated description (name) for the session (max 200 characters)" + } + } + }, "UpsertConfigQuery": { "type": "object", "required": [ diff --git a/ui/desktop/src/App.tsx b/ui/desktop/src/App.tsx index 17a6f9bde9dd..12aa2fe37e32 100644 --- a/ui/desktop/src/App.tsx +++ b/ui/desktop/src/App.tsx @@ -15,7 +15,6 @@ import { ExtensionInstallModal } from './components/ExtensionInstallModal'; import { ToastContainer } from 'react-toastify'; import { GoosehintsModal } from './components/GoosehintsModal'; import AnnouncementModal from './components/AnnouncementModal'; -import { generateSessionId } from './sessions'; import ProviderGuard from './components/ProviderGuard'; import { ChatType } from './types/chat'; @@ -321,7 +320,7 @@ export function AppInner() { const [_searchParams, setSearchParams] = useSearchParams(); const [chat, setChat] = useState({ - sessionId: generateSessionId(), + sessionId: '', title: 'Pair Chat', messages: [], messageHistoryIndex: 0, diff --git a/ui/desktop/src/api/sdk.gen.ts b/ui/desktop/src/api/sdk.gen.ts index 39ac4c781667..7be7d8b9dac3 100644 --- a/ui/desktop/src/api/sdk.gen.ts +++ b/ui/desktop/src/api/sdk.gen.ts @@ -1,7 +1,7 @@ // This file is auto-generated by @hey-api/openapi-ts import type { Options as ClientOptions, TDataShape, Client } from './client'; -import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, ResumeAgentData, ResumeAgentResponses, ResumeAgentErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, StartAgentData, StartAgentResponses, StartAgentErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, 
RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, DeleteRecipeData, DeleteRecipeResponses, DeleteRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ListRecipesData, ListRecipesResponses, ListRecipesErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionHistoryData, GetSessionHistoryResponses, GetSessionHistoryErrors, StatusData, StatusResponses } from './types.gen'; +import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, ResumeAgentData, ResumeAgentResponses, ResumeAgentErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, StartAgentData, StartAgentResponses, StartAgentErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, 
DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, DeleteRecipeData, DeleteRecipeResponses, DeleteRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ListRecipesData, ListRecipesResponses, ListRecipesErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionInsightsData, GetSessionInsightsResponses, GetSessionInsightsErrors, DeleteSessionData, DeleteSessionResponses, DeleteSessionErrors, GetSessionData, GetSessionResponses, GetSessionErrors, UpdateSessionDescriptionData, UpdateSessionDescriptionResponses, UpdateSessionDescriptionErrors, StatusData, StatusResponses } from './types.gen'; import { client as _heyApiClient } from './client.gen'; export type Options = ClientOptions & { @@ -424,13 +424,38 @@ export const listSessions = (options?: Opt }); }; -export const getSessionHistory = (options: Options) => { - return (options.client ?? _heyApiClient).get({ +export const getSessionInsights = (options?: Options) => { + return (options?.client ?? _heyApiClient).get({ + url: '/sessions/insights', + ...options + }); +}; + +export const deleteSession = (options: Options) => { + return (options.client ?? _heyApiClient).delete({ + url: '/sessions/{session_id}', + ...options + }); +}; + +export const getSession = (options: Options) => { + return (options.client ?? _heyApiClient).get({ url: '/sessions/{session_id}', ...options }); }; +export const updateSessionDescription = (options: Options) => { + return (options.client ?? _heyApiClient).put({ + url: '/sessions/{session_id}/description', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options.headers + } + }); +}; + export const status = (options?: Options) => { return (options?.client ?? 
_heyApiClient).get({ url: '/status', diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index f05649ef8f7f..5bde90797027 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -103,6 +103,8 @@ export type ContextManageResponse = { tokenCounts: Array; }; +export type Conversation = Array; + export type CreateCustomProviderRequest = { api_key: string; api_url: string; @@ -698,6 +700,25 @@ export type ScheduledJob = { source: string; }; +export type Session = { + accumulated_input_tokens?: number | null; + accumulated_output_tokens?: number | null; + accumulated_total_tokens?: number | null; + conversation?: Conversation | null; + created_at: string; + description: string; + extension_data: ExtensionData; + id: string; + input_tokens?: number | null; + message_count: number; + output_tokens?: number | null; + recipe?: Recipe | null; + schedule_id?: string | null; + total_tokens?: number | null; + updated_at: string; + working_dir: string; +}; + export type SessionConfigRequest = { response?: Response | null; session_id: string; @@ -718,78 +739,22 @@ export type SessionDisplayInfo = { workingDir: string; }; -export type SessionHistoryResponse = { +export type SessionInsights = { /** - * List of messages in the session conversation + * Total number of sessions */ - messages: Array; - metadata: SessionMetadata; + totalSessions: number; /** - * Unique identifier for the session + * Total tokens used across all sessions */ - sessionId: string; -}; - -export type SessionInfo = { - id: string; - metadata: SessionMetadata; - modified: string; - path: string; + totalTokens: number; }; export type SessionListResponse = { /** * List of available session information objects */ - sessions: Array; -}; - -/** - * Metadata for a session, stored as the first line in the session file - */ -export type SessionMetadata = { - /** - * The number of input tokens used in the session. Accumulated across all messages. - */ - accumulated_input_tokens?: number | null; - /** - * The number of output tokens used in the session. Accumulated across all messages. - */ - accumulated_output_tokens?: number | null; - /** - * The total number of tokens used in the session. Accumulated across all messages (useful for tracking cost over an entire session). - */ - accumulated_total_tokens?: number | null; - /** - * A short description of the session, typically 3 words or less - */ - description: string; - extension_data?: ExtensionData; - /** - * The number of input tokens used in the session. Retrieved from the provider's last usage. - */ - input_tokens?: number | null; - /** - * Number of messages in the session - */ - message_count: number; - /** - * The number of output tokens used in the session. Retrieved from the provider's last usage. - */ - output_tokens?: number | null; - recipe?: Recipe | null; - /** - * ID of the schedule that triggered this session, if any - */ - schedule_id?: string | null; - /** - * The total number of tokens used in the session. Retrieved from the provider's last usage. 
- */ - total_tokens?: number | null; - /** - * Working directory for the session - */ - working_dir: string; + sessions: Array; }; export type SessionsQuery = { @@ -812,12 +777,6 @@ export type StartAgentRequest = { working_dir: string; }; -export type StartAgentResponse = { - messages: Array; - metadata: SessionMetadata; - session_id: string; -}; - export type SubRecipe = { description?: string | null; name: string; @@ -930,6 +889,13 @@ export type UpdateScheduleRequest = { cron: string; }; +export type UpdateSessionDescriptionRequest = { + /** + * Updated description (name) for the session (max 200 characters) + */ + description: string; +}; + export type UpsertConfigQuery = { is_secret: boolean; key: string; @@ -1020,7 +986,7 @@ export type ResumeAgentResponses = { /** * Agent started successfully */ - 200: StartAgentResponse; + 200: Session; }; export type ResumeAgentResponse = ResumeAgentResponses[keyof ResumeAgentResponses]; @@ -1082,10 +1048,10 @@ export type StartAgentResponses = { /** * Agent started successfully */ - 200: StartAgentResponse; + 200: Session; }; -export type StartAgentResponse2 = StartAgentResponses[keyof StartAgentResponses]; +export type StartAgentResponse = StartAgentResponses[keyof StartAgentResponses]; export type GetToolsData = { body?: never; @@ -2144,7 +2110,34 @@ export type ListSessionsResponses = { export type ListSessionsResponse = ListSessionsResponses[keyof ListSessionsResponses]; -export type GetSessionHistoryData = { +export type GetSessionInsightsData = { + body?: never; + path?: never; + query?: never; + url: '/sessions/insights'; +}; + +export type GetSessionInsightsErrors = { + /** + * Unauthorized - Invalid or missing API key + */ + 401: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type GetSessionInsightsResponses = { + /** + * Session insights retrieved successfully + */ + 200: SessionInsights; +}; + +export type GetSessionInsightsResponse = GetSessionInsightsResponses[keyof GetSessionInsightsResponses]; + +export type DeleteSessionData = { body?: never; path: { /** @@ -2156,7 +2149,7 @@ export type GetSessionHistoryData = { url: '/sessions/{session_id}'; }; -export type GetSessionHistoryErrors = { +export type DeleteSessionErrors = { /** * Unauthorized - Invalid or missing API key */ @@ -2171,14 +2164,86 @@ export type GetSessionHistoryErrors = { 500: unknown; }; -export type GetSessionHistoryResponses = { +export type DeleteSessionResponses = { + /** + * Session deleted successfully + */ + 200: unknown; +}; + +export type GetSessionData = { + body?: never; + path: { + /** + * Unique identifier for the session + */ + session_id: string; + }; + query?: never; + url: '/sessions/{session_id}'; +}; + +export type GetSessionErrors = { + /** + * Unauthorized - Invalid or missing API key + */ + 401: unknown; + /** + * Session not found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type GetSessionResponses = { /** * Session history retrieved successfully */ - 200: SessionHistoryResponse; + 200: Session; +}; + +export type GetSessionResponse = GetSessionResponses[keyof GetSessionResponses]; + +export type UpdateSessionDescriptionData = { + body: UpdateSessionDescriptionRequest; + path: { + /** + * Unique identifier for the session + */ + session_id: string; + }; + query?: never; + url: '/sessions/{session_id}/description'; +}; + +export type UpdateSessionDescriptionErrors = { + /** + * Bad request - Description too long (max 200 characters) + */ + 400: unknown; + /** + * 
Unauthorized - Invalid or missing API key + */ + 401: unknown; + /** + * Session not found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; }; -export type GetSessionHistoryResponse = GetSessionHistoryResponses[keyof GetSessionHistoryResponses]; +export type UpdateSessionDescriptionResponses = { + /** + * Session description updated successfully + */ + 200: unknown; +}; export type StatusData = { body?: never; diff --git a/ui/desktop/src/components/BaseChat.tsx b/ui/desktop/src/components/BaseChat.tsx index be8728e3b5ad..66aeb5cfe75b 100644 --- a/ui/desktop/src/components/BaseChat.tsx +++ b/ui/desktop/src/components/BaseChat.tsx @@ -265,7 +265,7 @@ function BaseChatContent({ sessionOutputTokens, localInputTokens, localOutputTokens, - sessionMetadata, + session: sessionMetadata, }); useEffect(() => { diff --git a/ui/desktop/src/components/ExtensionInstallModal.test.tsx b/ui/desktop/src/components/ExtensionInstallModal.test.tsx index 9d65635882cd..3720271cf8ae 100644 --- a/ui/desktop/src/components/ExtensionInstallModal.test.tsx +++ b/ui/desktop/src/components/ExtensionInstallModal.test.tsx @@ -94,8 +94,7 @@ describe('ExtensionInstallModal', () => { expect(screen.getAllByRole('button')).toHaveLength(3); }); - - it("should handle i-ching-mcp-server as allowed command", async () => { + it('should handle i-ching-mcp-server as allowed command', async () => { mockElectron.getAllowedExtensions.mockResolvedValue([]); render(); @@ -103,13 +102,16 @@ describe('ExtensionInstallModal', () => { const eventHandler = getAddExtensionEventHandler(); await act(async () => { - await eventHandler({}, "goose://extension?cmd=i-ching-mcp-server&id=i-ching&name=I%20Ching&description=I%20Ching%20divination"); + await eventHandler( + {}, + 'goose://extension?cmd=i-ching-mcp-server&id=i-ching&name=I%20Ching&description=I%20Ching%20divination' + ); }); - expect(screen.getByRole("dialog")).toBeInTheDocument(); - expect(screen.getByText("Confirm Extension Installation")).toBeInTheDocument(); + expect(screen.getByRole('dialog')).toBeInTheDocument(); + expect(screen.getByText('Confirm Extension Installation')).toBeInTheDocument(); expect(screen.getByText(/I Ching extension/)).toBeInTheDocument(); - expect(screen.getAllByRole("button")).toHaveLength(3); + expect(screen.getAllByRole('button')).toHaveLength(3); }); it('should handle blocked extension', async () => { mockElectron.getAllowedExtensions.mockResolvedValue(['uvx allowed-package']); diff --git a/ui/desktop/src/components/GooseSidebar/SessionsSection.tsx b/ui/desktop/src/components/GooseSidebar/SessionsSection.tsx deleted file mode 100644 index bd144dbbb96e..000000000000 --- a/ui/desktop/src/components/GooseSidebar/SessionsSection.tsx +++ /dev/null @@ -1,356 +0,0 @@ -import React, { useEffect, useState, useCallback, useRef } from 'react'; -import { Search, ChevronDown, Folder, Loader2 } from 'lucide-react'; -import { fetchSessions, type Session } from '../../sessions'; -import { Input } from '../ui/input'; -import { - SidebarMenu, - SidebarMenuItem, - SidebarMenuButton, - SidebarGroup, - SidebarGroupLabel, - SidebarGroupContent, -} from '../ui/sidebar'; -import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '../ui/collapsible'; -import { useTextAnimator } from '../../hooks/use-text-animator'; - -interface SessionsSectionProps { - onSelectSession: (sessionId: string) => void; - refreshTrigger?: number; -} - -interface GroupedSessions { - today: Session[]; - yesterday: Session[]; - older: { [key: string]: Session[] }; -} - 
-export const SessionsSection: React.FC = ({ - onSelectSession, - refreshTrigger, -}) => { - const [sessions, setSessions] = useState([]); - const [searchTerm, setSearchTerm] = useState(''); - const [groupedSessions, setGroupedSessions] = useState({ - today: [], - yesterday: [], - older: {}, - }); - const [sessionsWithDescriptions, setSessionsWithDescriptions] = useState>(new Set()); - - const refreshTimeoutRef = useRef | null>(null); - - const groupSessions = useCallback((sessionsToGroup: Session[]) => { - const now = new Date(); - const today = new Date(now.getFullYear(), now.getMonth(), now.getDate()); - const yesterday = new Date(today); - yesterday.setDate(yesterday.getDate() - 1); - - const grouped: GroupedSessions = { - today: [], - yesterday: [], - older: {}, - }; - - sessionsToGroup.forEach((session) => { - const sessionDate = new Date(session.modified); - const sessionDateOnly = new Date( - sessionDate.getFullYear(), - sessionDate.getMonth(), - sessionDate.getDate() - ); - - if (sessionDateOnly.getTime() === today.getTime()) { - grouped.today.push(session); - } else if (sessionDateOnly.getTime() === yesterday.getTime()) { - grouped.yesterday.push(session); - } else { - const dateKey = sessionDateOnly.toISOString().split('T')[0]; - if (!grouped.older[dateKey]) { - grouped.older[dateKey] = []; - } - grouped.older[dateKey].push(session); - } - }); - - // Sort older sessions by date (newest first) - const sortedOlder: { [key: string]: Session[] } = {}; - Object.keys(grouped.older) - .sort() - .reverse() - .forEach((key) => { - sortedOlder[key] = grouped.older[key]; - }); - - grouped.older = sortedOlder; - setGroupedSessions(grouped); - }, []); - - const loadSessions = useCallback(async () => { - try { - const sessions = await fetchSessions(); - setSessions(sessions); - groupSessions(sessions); - } catch (err) { - console.error('Failed to load sessions:', err); - setSessions([]); - setGroupedSessions({ today: [], yesterday: [], older: {} }); - } - }, [groupSessions]); - - // Debounced refresh function - const debouncedRefresh = useCallback(() => { - console.log('SessionsSection: Debounced refresh triggered'); - // Clear any existing timeout - if (refreshTimeoutRef.current) { - window.clearTimeout(refreshTimeoutRef.current); - } - - // Set new timeout - reduced to 200ms for faster response - refreshTimeoutRef.current = setTimeout(() => { - console.log('SessionsSection: Executing debounced refresh'); - loadSessions(); - refreshTimeoutRef.current = null; - }, 200); - }, [loadSessions]); - - // Cleanup timeout on unmount - useEffect(() => { - return () => { - if (refreshTimeoutRef.current) { - window.clearTimeout(refreshTimeoutRef.current); - } - }; - }, []); - - useEffect(() => { - console.log('SessionsSection: Initial load'); - loadSessions(); - }, [loadSessions]); - - // Add effect to refresh sessions when refreshTrigger changes - useEffect(() => { - if (refreshTrigger) { - console.log('SessionsSection: Refresh trigger changed, triggering refresh'); - debouncedRefresh(); - } - }, [refreshTrigger, debouncedRefresh]); - - // Add effect to listen for session creation events - useEffect(() => { - const handleSessionCreated = () => { - console.log('SessionsSection: Session created event received'); - debouncedRefresh(); - }; - - const handleMessageStreamFinish = () => { - console.log('SessionsSection: Message stream finished event received'); - // Always refresh when message stream finishes - debouncedRefresh(); - }; - - // Listen for custom events that indicate a session was created - 
window.addEventListener('session-created', handleSessionCreated); - - // Also listen for message stream finish events - window.addEventListener('message-stream-finished', handleMessageStreamFinish); - - return () => { - window.removeEventListener('session-created', handleSessionCreated); - window.removeEventListener('message-stream-finished', handleMessageStreamFinish); - }; - }, [debouncedRefresh]); - - useEffect(() => { - if (searchTerm) { - const filtered = sessions.filter((session) => - (session.metadata.description || session.id) - .toLowerCase() - .includes(searchTerm.toLowerCase()) - ); - groupSessions(filtered); - } else { - groupSessions(sessions); - } - }, [searchTerm, sessions, groupSessions]); - - // Component for individual session items with loading and animation states - const SessionItem = ({ session }: { session: Session }) => { - const hasDescription = - session.metadata.description && session.metadata.description.trim() !== ''; - const isNewSession = session.id.match(/^\d{8}_\d{6}$/); - const messageCount = session.metadata.message_count || 0; - // Show loading for new sessions with few messages and no description - // Only show loading for sessions created in the last 5 minutes - const sessionDate = new Date(session.modified); - const fiveMinutesAgo = new Date(Date.now() - 5 * 60 * 1000); - const isRecentSession = sessionDate > fiveMinutesAgo; - const shouldShowLoading = - !hasDescription && isNewSession && messageCount <= 2 && isRecentSession; - const [isAnimating, setIsAnimating] = useState(false); - - // Use text animator only for sessions that need animation - const descriptionRef = useTextAnimator({ - text: isAnimating ? session.metadata.description : '', - }); - - // Track when description becomes available and trigger animation - useEffect(() => { - if (hasDescription && !sessionsWithDescriptions.has(session.id)) { - setSessionsWithDescriptions((prev) => new Set(prev).add(session.id)); - - // Only animate for new sessions that were showing loading - if (shouldShowLoading) { - setIsAnimating(true); - } - } - }, [hasDescription, session.id, shouldShowLoading]); - - const handleClick = () => { - console.log('SessionItem: Clicked on session:', session.id); - onSelectSession(session.id); - }; - - return ( - - -
-
- {shouldShowLoading ? ( -
- - Generating description... -
- ) : ( - - {hasDescription ? session.metadata.description : `Session ${session.id}`} - - )} -
-
- - {session.metadata.working_dir} -
-
-
-
- ); - }; - - const renderSessionGroup = (sessions: Session[], title: string, index: number) => { - if (sessions.length === 0) return null; - - const isFirstTwoGroups = index < 2; - - return ( - - - - -
- - {title} - -
- -
-
- - - - {sessions.map((session, sessionIndex) => ( -
- -
- ))} -
-
-
-
-
- ); - }; - - return ( - - - - -
- Sessions -
- -
-
- - - {/* Search Input */} -
-
- - ) => - setSearchTerm(e.target.value) - } - /> -
-
- - {/* Sessions Groups */} -
- {(() => { - let groupIndex = 0; - const groups = [ - { sessions: groupedSessions.today, title: 'Today' }, - { sessions: groupedSessions.yesterday, title: 'Yesterday' }, - ...Object.entries(groupedSessions.older).map(([date, sessions]) => ({ - sessions, - title: new Date(date).toLocaleDateString('en-US', { - weekday: 'long', - year: 'numeric', - month: 'long', - day: 'numeric', - }), - })), - ]; - - return groups.map(({ sessions, title }) => { - if (sessions.length === 0) return null; - const currentIndex = groupIndex++; - return ( -
- {renderSessionGroup(sessions, title, currentIndex)} -
- ); - }); - })()} -
-
-
-
-
- ); -}; diff --git a/ui/desktop/src/components/schedule/ScheduleDetailView.tsx b/ui/desktop/src/components/schedule/ScheduleDetailView.tsx index e3be8809bffe..5f4e8cbc4040 100644 --- a/ui/desktop/src/components/schedule/ScheduleDetailView.tsx +++ b/ui/desktop/src/components/schedule/ScheduleDetailView.tsx @@ -3,7 +3,6 @@ import { Button } from '../ui/button'; import { ScrollArea } from '../ui/scroll-area'; import BackButton from '../ui/BackButton'; import { Card } from '../ui/card'; -import { fetchSessionDetails, SessionDetails } from '../../sessions'; import { getScheduleSessions, runScheduleNow, @@ -21,6 +20,7 @@ import { toastError, toastSuccess } from '../../toasts'; import { Loader2, Pause, Play, Edit, Square, Eye } from 'lucide-react'; import cronstrue from 'cronstrue'; import { formatToLocalDateWithTimezone } from '../../utils/date'; +import { getSession, Session } from '../../api'; interface ScheduleSessionMeta { id: string; @@ -146,7 +146,7 @@ const ScheduleDetailView: React.FC = ({ scheduleId, onN // Track if we explicitly killed a job to distinguish from natural completion const [jobWasKilled, setJobWasKilled] = useState(false); - const [selectedSessionDetails, setSelectedSessionDetails] = useState(null); + const [selectedSessionDetails, setSelectedSessionDetails] = useState(null); const [isLoadingSessionDetails, setIsLoadingSessionDetails] = useState(false); const [sessionDetailsError, setSessionDetailsError] = useState(null); const [isEditModalOpen, setIsEditModalOpen] = useState(false); @@ -430,8 +430,11 @@ const ScheduleDetailView: React.FC = ({ scheduleId, onN setSessionDetailsError(null); setSelectedSessionDetails(null); try { - const details = await fetchSessionDetails(sessionId); - setSelectedSessionDetails(details); + const response = await getSession({ + path: { session_id: sessionId }, + throwOnError: true, + }); + setSelectedSessionDetails(response.data); } catch (err) { console.error(`Failed to load session details for ${sessionId}:`, err); const errorMsg = err instanceof Error ? 
err.message : 'Failed to load session details.'; @@ -459,7 +462,7 @@ const ScheduleDetailView: React.FC = ({ scheduleId, onN setSelectedSessionDetails(null); setSessionDetailsError(null); }} - onRetry={() => loadAndShowSessionDetails(selectedSessionDetails?.sessionId)} + onRetry={() => loadAndShowSessionDetails(selectedSessionDetails?.id)} showActionButtons={true} /> ); diff --git a/ui/desktop/src/components/sessions/SessionHistoryView.tsx b/ui/desktop/src/components/sessions/SessionHistoryView.tsx index c0b08df21d20..781e42d84137 100644 --- a/ui/desktop/src/components/sessions/SessionHistoryView.tsx +++ b/ui/desktop/src/components/sessions/SessionHistoryView.tsx @@ -11,7 +11,7 @@ import { LoaderCircle, AlertCircle, } from 'lucide-react'; -import { resumeSession, type SessionDetails } from '../../sessions'; +import { resumeSession } from '../../sessions'; import { Button } from '../ui/button'; import { toast } from 'react-toastify'; import { MainPanelLayout } from '../Layout/MainPanelLayout'; @@ -32,6 +32,8 @@ import { ContextManagerProvider } from '../context_management/ContextManager'; import { Message } from '../../types/message'; import BackButton from '../ui/BackButton'; import { Tooltip, TooltipContent, TooltipTrigger } from '../ui/Tooltip'; +import { Session } from '../../api'; +import { convertApiMessageToFrontendMessage } from '../context_management'; // Helper function to determine if a message is a user message (same as useChatEngine) const isUserMessage = (message: Message): boolean => { @@ -46,7 +48,7 @@ const filterMessagesForDisplay = (messages: Message[]): Message[] => { }; interface SessionHistoryViewProps { - session: SessionDetails; + session: Session; isLoading: boolean; error: string | null; onBack: () => void; @@ -73,14 +75,12 @@ const SessionHeader: React.FC<{ ); }; -// Session messages component that uses the same rendering as BaseChat const SessionMessages: React.FC<{ messages: Message[]; isLoading: boolean; error: string | null; onRetry: () => void; }> = ({ messages, isLoading, error, onRetry }) => { - // Filter messages for display (same as BaseChat) const filteredMessages = filterMessagesForDisplay(messages); return ( @@ -153,6 +153,8 @@ const SessionHistoryView: React.FC = ({ const [isCopied, setIsCopied] = useState(false); const [canShare, setCanShare] = useState(false); + const messages = (session.conversation || []).map(convertApiMessageToFrontendMessage); + useEffect(() => { const savedSessionConfig = localStorage.getItem('session_sharing_config'); if (savedSessionConfig) { @@ -183,10 +185,10 @@ const SessionHistoryView: React.FC = ({ const shareToken = await createSharedSession( config.baseUrl, - session.metadata.working_dir, - session.messages, - session.metadata.description || 'Shared Session', - session.metadata.total_tokens || 0 + session.working_dir, + messages, + session.description || 'Shared Session', + session.total_tokens || 0 ); const shareableLink = `goose://sessions/${shareToken}`; @@ -270,32 +272,32 @@ const SessionHistoryView: React.FC = ({
- {!isLoading && session.messages.length > 0 ? ( + {!isLoading ? ( <>
- {formatMessageTimestamp(session.messages[0]?.created)} + {formatMessageTimestamp(messages[0]?.created)} - {session.metadata.message_count} + {session.message_count} - {session.metadata.total_tokens !== null && ( + {session.total_tokens !== null && ( - {(session.metadata.total_tokens || 0).toLocaleString()} + {(session.total_tokens || 0).toLocaleString()} )}
- {session.metadata.working_dir} + {session.working_dir}
@@ -309,7 +311,7 @@ const SessionHistoryView: React.FC = ({ = ({ session, extraActions }) => { return (
-
{session.metadata.description || `Session ${session.id}`}
+
{session.description || `Session ${session.id}`}
- {formatDate(session.modified)} • {session.metadata.message_count} messages + {formatDate(session.updated_at)} • {session.message_count} messages
-
{session.metadata.working_dir}
+
{session.working_dir}
{extraActions &&
{extraActions}
}
diff --git a/ui/desktop/src/components/sessions/SessionListView.tsx b/ui/desktop/src/components/sessions/SessionListView.tsx index 935a47e5757d..e56842b9c99e 100644 --- a/ui/desktop/src/components/sessions/SessionListView.tsx +++ b/ui/desktop/src/components/sessions/SessionListView.tsx @@ -8,7 +8,6 @@ import { Edit2, Trash2, } from 'lucide-react'; -import { fetchSessions, updateSessionMetadata, deleteSession, type Session } from '../../sessions'; import { Card } from '../ui/card'; import { Button } from '../ui/button'; import { ScrollArea } from '../ui/scroll-area'; @@ -21,6 +20,7 @@ import { groupSessionsByDate, type DateGroup } from '../../utils/dateUtils'; import { Skeleton } from '../ui/skeleton'; import { toast } from 'react-toastify'; import { ConfirmationModal } from '../ui/ConfirmationModal'; +import { deleteSession, listSessions, Session, updateSessionDescription } from '../../api'; interface EditSessionModalProps { session: Session | null; @@ -37,7 +37,7 @@ const EditSessionModal = React.memo( useEffect(() => { if (session && isOpen) { - setDescription(session.metadata.description || session.id); + setDescription(session.description || session.id); } else if (!isOpen) { // Reset state when modal closes setDescription(''); @@ -49,14 +49,18 @@ const EditSessionModal = React.memo( if (!session || disabled) return; const trimmedDescription = description.trim(); - if (trimmedDescription === session.metadata.description) { + if (trimmedDescription === session.description) { onClose(); return; } setIsUpdating(true); try { - await updateSessionMetadata(session.id, trimmedDescription); + await updateSessionDescription({ + path: { session_id: session.id }, + body: { description: trimmedDescription }, + throwOnError: true, + }); await onSave(session.id, trimmedDescription); // Close modal, then show success toast on a timeout to let the UI update complete. @@ -68,8 +72,7 @@ const EditSessionModal = React.memo( const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'; console.error('Failed to update session description:', errorMessage); toast.error(`Failed to update session description: ${errorMessage}`); - // Reset to original description on error - setDescription(session.metadata.description || session.id); + setDescription(session.description || session.id); } finally { setIsUpdating(false); } @@ -213,7 +216,8 @@ const SessionListView: React.FC = React.memo( setShowContent(false); setError(null); try { - const sessions = await fetchSessions(); + const resp = await listSessions({ throwOnError: true }); + const sessions = resp.data.sessions; // Use startTransition to make state updates non-blocking startTransition(() => { setSessions(sessions); @@ -291,20 +295,20 @@ const SessionListView: React.FC = React.memo( startTransition(() => { const searchTerm = caseSensitive ? 
debouncedSearchTerm : debouncedSearchTerm.toLowerCase(); const filtered = sessions.filter((session) => { - const description = session.metadata.description || session.id; - const path = session.path; - const workingDir = session.metadata.working_dir; + const description = session.description || session.id; + const workingDir = session.working_dir; + const sessionId = session.id; if (caseSensitive) { return ( description.includes(searchTerm) || - path.includes(searchTerm) || + sessionId.includes(searchTerm) || workingDir.includes(searchTerm) ); } else { return ( description.toLowerCase().includes(searchTerm) || - path.toLowerCase().includes(searchTerm) || + sessionId.toLowerCase().includes(searchTerm) || workingDir.toLowerCase().includes(searchTerm) ); } @@ -355,11 +359,7 @@ const SessionListView: React.FC = React.memo( const handleModalSave = useCallback(async (sessionId: string, newDescription: string) => { // Update state immediately for optimistic UI setSessions((prevSessions) => - prevSessions.map((s) => - s.id === sessionId - ? { ...s, metadata: { ...s.metadata, description: newDescription } } - : s - ) + prevSessions.map((s) => (s.id === sessionId ? { ...s, description: newDescription } : s)) ); }, []); @@ -378,18 +378,21 @@ const SessionListView: React.FC = React.memo( setShowDeleteConfirmation(false); const sessionToDeleteId = sessionToDelete.id; - const sessionName = sessionToDelete.metadata.description || sessionToDelete.id; + const sessionName = sessionToDelete.description || sessionToDelete.id; setSessionToDelete(null); try { - await deleteSession(sessionToDeleteId); + await deleteSession({ + path: { session_id: sessionToDeleteId }, + throwOnError: true, + }); toast.success('Session deleted successfully'); - loadSessions(); } catch (error) { console.error('Error deleting session:', error); const errorMessage = error instanceof Error ? error.message : 'Unknown error'; toast.error(`Failed to delete session "${sessionName}": ${errorMessage}`); } + await loadSessions(); }, [sessionToDelete, loadSessions]); const handleCancelDelete = useCallback(() => { @@ -451,16 +454,16 @@ const SessionListView: React.FC = React.memo(

- {session.metadata.description || session.id} + {session.description || session.id}

- {formatMessageTimestamp(Date.parse(session.modified) / 1000)} + {formatMessageTimestamp(Date.parse(session.updated_at) / 1000)}
- {session.metadata.working_dir} + {session.working_dir}
@@ -468,14 +471,12 @@ const SessionListView: React.FC = React.memo(
- {session.metadata.message_count} + {session.message_count}
- {session.metadata.total_tokens !== null && ( + {session.total_tokens !== null && (
- - {(session.metadata.total_tokens || 0).toLocaleString()} - + {(session.total_tokens || 0).toLocaleString()}
)}
@@ -675,7 +676,7 @@ const SessionListView: React.FC = React.memo( = ({ onBack, ch /** * Props for the SessionMessages component */ -export interface SessionMessagesProps { +interface SessionMessagesProps { messages: Message[]; isLoading: boolean; error: string | null; diff --git a/ui/desktop/src/components/sessions/SessionsInsights.tsx b/ui/desktop/src/components/sessions/SessionsInsights.tsx index 8001a1967794..3de44e413cf4 100644 --- a/ui/desktop/src/components/sessions/SessionsInsights.tsx +++ b/ui/desktop/src/components/sessions/SessionsInsights.tsx @@ -1,23 +1,21 @@ import { useEffect, useState } from 'react'; import { Card, CardContent, CardDescription } from '../ui/card'; -import { getApiUrl } from '../../config'; import { Greeting } from '../common/Greeting'; -import { fetchSessions, type Session, resumeSession } from '../../sessions'; import { useNavigate } from 'react-router-dom'; import { Button } from '../ui/button'; import { ChatSmart } from '../icons/'; import { Goose } from '../icons/Goose'; import { Skeleton } from '../ui/skeleton'; - -interface SessionInsightsType { - totalSessions: number; - mostActiveDirs: [string, number][]; - avgSessionDuration: number; - totalTokens: number; -} +import { + getSessionInsights, + listSessions, + Session, + SessionInsights as ApiSessionInsights, +} from '../../api'; +import { resumeSession } from '../../sessions'; export function SessionInsights() { - const [insights, setInsights] = useState(null); + const [insights, setInsights] = useState(null); const [error, setError] = useState(null); const [recentSessions, setRecentSessions] = useState([]); const [isLoading, setIsLoading] = useState(true); @@ -29,31 +27,14 @@ export function SessionInsights() { const loadInsights = async () => { try { - const response = await fetch(getApiUrl('/sessions/insights'), { - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - 'X-Secret-Key': await window.electron.getSecretKey(), - }, - }); - - if (!response.ok) { - const errorText = await response.text(); - throw new Error(`Failed to fetch insights: ${response.status} ${errorText}`); - } - - const data = await response.json(); - setInsights(data); - // Clear any previous error when insights load successfully + const response = await getSessionInsights({ throwOnError: true }); + setInsights(response.data); setError(null); } catch (error) { console.error('Failed to load insights:', error); setError(error instanceof Error ? error.message : 'Failed to load insights'); - // Set fallback insights data so the UI can still render setInsights({ totalSessions: 0, - mostActiveDirs: [], - avgSessionDuration: 0, totalTokens: 0, }); } finally { @@ -63,10 +44,8 @@ export function SessionInsights() { const loadRecentSessions = async () => { try { - const sessions = await fetchSessions(); - setRecentSessions(sessions.slice(0, 3)); - } catch (error) { - console.error('Failed to load recent sessions:', error); + const response = await listSessions({ throwOnError: true }); + setRecentSessions(response.data.sessions.slice(0, 3)); } finally { setIsLoadingSessions(false); } @@ -85,6 +64,7 @@ export function SessionInsights() { mostActiveDirs: [], avgSessionDuration: 0, totalTokens: 0, + recentActivity: [], }; } // If we already have insights, just make sure loading is false @@ -155,16 +135,6 @@ export function SessionInsights() { - {/* Average Duration Card Skeleton */} - {/**/} - {/* */} - {/*
*/} - {/* */} - {/* Avg. chat length*/} - {/*
*/} - {/*
*/} - {/*
*/} - {/* Total Tokens Card Skeleton */} @@ -363,11 +333,11 @@ export function SessionInsights() {
- {session.metadata.description || session.id} + {session.description || session.id}
- {formatDateOnly(session.modified)} + {formatDateOnly(session.updated_at)}
)) diff --git a/ui/desktop/src/components/sessions/SessionsView.tsx b/ui/desktop/src/components/sessions/SessionsView.tsx index 7b56bf13111e..3f06d31bba89 100644 --- a/ui/desktop/src/components/sessions/SessionsView.tsx +++ b/ui/desktop/src/components/sessions/SessionsView.tsx @@ -1,17 +1,16 @@ import React, { useState, useEffect, useCallback } from 'react'; import { View, ViewOptions } from '../../utils/navigationUtils'; -import { fetchSessionDetails, type SessionDetails } from '../../sessions'; import SessionListView from './SessionListView'; import SessionHistoryView from './SessionHistoryView'; -import { toastError } from '../../toasts'; import { useLocation } from 'react-router-dom'; +import { getSession, Session } from '../../api'; interface SessionsViewProps { setView: (view: View, viewOptions?: ViewOptions) => void; } const SessionsView: React.FC = ({ setView }) => { - const [selectedSession, setSelectedSession] = useState(null); + const [selectedSession, setSelectedSession] = useState(null); const [showSessionHistory, setShowSessionHistory] = useState(false); const [isLoadingSession, setIsLoadingSession] = useState(false); const [error, setError] = useState(null); @@ -23,21 +22,17 @@ const SessionsView: React.FC = ({ setView }) => { setError(null); setShowSessionHistory(true); try { - const sessionDetails = await fetchSessionDetails(sessionId); - setSelectedSession(sessionDetails); + const response = await getSession({ + path: { session_id: sessionId }, + throwOnError: true, + }); + setSelectedSession(response.data); } catch (err) { console.error(`Failed to load session details for ${sessionId}:`, err); setError('Failed to load session details. Please try again later.'); // Keep the selected session null if there's an error setSelectedSession(null); setShowSessionHistory(false); - - const errorMessage = err instanceof Error ? err.message : String(err); - toastError({ - title: 'Failed to load session. 
The file may be corrupted.', - msg: 'Please try again later.', - traceback: errorMessage, - }); } finally { setIsLoadingSession(false); setInitialSessionId(null); @@ -68,7 +63,7 @@ const SessionsView: React.FC = ({ setView }) => { const handleRetryLoadSession = () => { if (selectedSession) { - loadSessionDetails(selectedSession.sessionId); + loadSessionDetails(selectedSession.id); } }; @@ -78,14 +73,15 @@ const SessionsView: React.FC = ({ setView }) => { = ({ setView }) => { ); }; diff --git a/ui/desktop/src/contexts/ChatContext.tsx b/ui/desktop/src/contexts/ChatContext.tsx index 92e0a8d58593..eb8d328faeab 100644 --- a/ui/desktop/src/contexts/ChatContext.tsx +++ b/ui/desktop/src/contexts/ChatContext.tsx @@ -1,6 +1,5 @@ import React, { createContext, useContext, ReactNode } from 'react'; import { ChatType } from '../types/chat'; -import { generateSessionId } from '../sessions'; import { Recipe } from '../recipe'; import { useDraftContext } from './DraftContext'; @@ -54,16 +53,14 @@ export const ChatProvider: React.FC = ({ }; const resetChat = () => { - const newSessionId = generateSessionId(); setChat({ - sessionId: newSessionId, + sessionId: '', title: DEFAULT_CHAT_TITLE, messages: [], messageHistoryIndex: 0, - recipeConfig: null, // Clear recipe when resetting chat - recipeParameters: null, // Clear when resetting chat + recipeConfig: null, + recipeParameters: null, }); - // Clear draft when resetting chat clearDraft(); }; diff --git a/ui/desktop/src/goosed.ts b/ui/desktop/src/goosed.ts index a74904b9d2f2..1f957f8a5c2f 100644 --- a/ui/desktop/src/goosed.ts +++ b/ui/desktop/src/goosed.ts @@ -34,7 +34,8 @@ export const checkServerStatus = async (client: Client): Promise => { try { await status({ client, throwOnError: true }); return true; - } catch { + } catch (error) { + log.error('failure to connect, will retry', error); if (attempt === maxAttempts) { log.error(`Server failed to respond after ${(interval * maxAttempts) / 1000} seconds`); } diff --git a/ui/desktop/src/hooks/useAgent.ts b/ui/desktop/src/hooks/useAgent.ts index d70b1c48cc4d..4c5a2aaab5c2 100644 --- a/ui/desktop/src/hooks/useAgent.ts +++ b/ui/desktop/src/hooks/useAgent.ts @@ -1,4 +1,4 @@ -import { useState, useCallback, useRef } from 'react'; +import { useCallback, useRef, useState } from 'react'; import { useConfig } from '../components/ConfigContext'; import { ChatType } from '../types/chat'; import { initializeSystem } from '../utils/providerUtils'; @@ -8,11 +8,11 @@ import { initConfig, Message as ApiMessage, readAllConfig, + Recipe, recoverConfig, resumeAgent, startAgent, validateConfig, - Recipe, } from '../api'; import { COST_TRACKING_ENABLED } from '../updates'; import { convertApiMessageToFrontendMessage } from '../components/context_management'; @@ -72,19 +72,17 @@ export function useAgent(): UseAgentReturn { throwOnError: true, }); - const agentSessionInfo = agentResponse.data; - const sessionMetadata = agentSessionInfo.metadata; - let chat: ChatType = { - sessionId: agentSessionInfo.session_id, - title: sessionMetadata.recipe?.title || sessionMetadata.description, + const agentSession = agentResponse.data; + const messages = agentSession.conversation || []; + return { + sessionId: agentSession.id, + title: agentSession.recipe?.title || agentSession.description, messageHistoryIndex: 0, - messages: agentSessionInfo.messages.map((message: ApiMessage) => + messages: messages?.map((message: ApiMessage) => convertApiMessageToFrontendMessage(message) ), - recipeConfig: sessionMetadata.recipe, + recipeConfig: 
agentSession.recipe, }; - - return chat; } if (initPromiseRef.current) { @@ -121,11 +119,11 @@ export function useAgent(): UseAgentReturn { throwOnError: true, }); - const agentSessionInfo = agentResponse.data; - if (!agentSessionInfo) { + const agentSession = agentResponse.data; + if (!agentSession) { throw Error('Failed to get session info'); } - setSessionId(agentSessionInfo.session_id); + setSessionId(agentSession.id); agentWaitingMessage('Agent is loading config'); @@ -139,7 +137,7 @@ export function useAgent(): UseAgentReturn { } agentWaitingMessage('Extensions are loading'); - await initializeSystem(agentSessionInfo.session_id, provider as string, model as string, { + await initializeSystem(agentSession.id, provider as string, model as string, { getExtensions, addExtension, setIsExtensionsLoading: initContext.setIsExtensionsLoading, @@ -153,15 +151,15 @@ export function useAgent(): UseAgentReturn { } } - const sessionMetadata = agentSessionInfo.metadata; + const messages = agentSession.conversation || []; let initChat: ChatType = { - sessionId: agentSessionInfo.session_id, - title: sessionMetadata.recipe?.title || sessionMetadata.description, + sessionId: agentSession.id, + title: agentSession.recipe?.title || agentSession.description, messageHistoryIndex: 0, - messages: agentSessionInfo.messages.map((message: ApiMessage) => + messages: messages.map((message: ApiMessage) => convertApiMessageToFrontendMessage(message) ), - recipeConfig: sessionMetadata.recipe, + recipeConfig: agentSession.recipe, }; setAgentState(AgentState.INITIALIZED); diff --git a/ui/desktop/src/hooks/useChatEngine.ts b/ui/desktop/src/hooks/useChatEngine.ts index f9ab4b3e687b..658c02231578 100644 --- a/ui/desktop/src/hooks/useChatEngine.ts +++ b/ui/desktop/src/hooks/useChatEngine.ts @@ -1,7 +1,6 @@ import { useCallback, useEffect, useMemo, useState } from 'react'; import { getApiUrl } from '../config'; import { useMessageStream } from './useMessageStream'; -import { fetchSessionDetails } from '../sessions'; import { LocalMessageStorage } from '../utils/localMessageStorage'; import { Message, @@ -16,6 +15,7 @@ import { } from '../types/message'; import { ChatType } from '../types/chat'; import { ChatState } from '../types/chatState'; +import { getSession } from '../api'; // Helper function to determine if a message is a user message const isUserMessage = (message: Message): boolean => { @@ -85,7 +85,7 @@ export const useChatEngine = ({ handleInputChange: _handleInputChange, updateMessageStreamBody, notifications, - sessionMetadata, + session, setError, } = useMessageStream({ api: getApiUrl('/reply'), @@ -199,14 +199,17 @@ export const useChatEngine = ({ setChat((prevChat: ChatType) => ({ ...prevChat, messages })); }, [messages, setChat]); - // Fetch session metadata to get token count useEffect(() => { const fetchSessionTokens = async () => { try { - const sessionDetails = await fetchSessionDetails(chat.sessionId); - setSessionTokenCount(sessionDetails.metadata.total_tokens || 0); - setSessionInputTokens(sessionDetails.metadata.accumulated_input_tokens || 0); - setSessionOutputTokens(sessionDetails.metadata.accumulated_output_tokens || 0); + const response = await getSession({ + path: { session_id: chat.sessionId }, + throwOnError: true, + }); + const sessionDetails = response.data; + setSessionTokenCount(sessionDetails.total_tokens || 0); + setSessionInputTokens(sessionDetails.accumulated_input_tokens || 0); + setSessionOutputTokens(sessionDetails.accumulated_output_tokens || 0); } catch (err) { 
console.error('Error fetching session token count:', err); } @@ -219,13 +222,13 @@ export const useChatEngine = ({ // Update token counts when sessionMetadata changes from the message stream useEffect(() => { - console.log('Session metadata received:', sessionMetadata); - if (sessionMetadata) { - setSessionTokenCount(sessionMetadata.total_tokens || 0); - setSessionInputTokens(sessionMetadata.accumulated_input_tokens || 0); - setSessionOutputTokens(sessionMetadata.accumulated_output_tokens || 0); + console.log('Session metadata received:', session); + if (session) { + setSessionTokenCount(session.total_tokens || 0); + setSessionInputTokens(session.accumulated_input_tokens || 0); + setSessionOutputTokens(session.accumulated_output_tokens || 0); } - }, [sessionMetadata]); + }, [session]); useEffect(() => { return () => { @@ -473,7 +476,7 @@ export const useChatEngine = ({ // Stream utilities updateMessageStreamBody, - sessionMetadata, + sessionMetadata: session, // Utilities isUserMessage, diff --git a/ui/desktop/src/hooks/useCostTracking.ts b/ui/desktop/src/hooks/useCostTracking.ts index cf22d0e0f3cd..4e8b390b23a4 100644 --- a/ui/desktop/src/hooks/useCostTracking.ts +++ b/ui/desktop/src/hooks/useCostTracking.ts @@ -1,14 +1,14 @@ import { useEffect, useRef, useState } from 'react'; import { useModelAndProvider } from '../components/ModelAndProviderContext'; import { getCostForModel } from '../utils/costDatabase'; -import { SessionMetadata } from '../api'; +import { Session } from '../api'; interface UseCostTrackingProps { sessionInputTokens: number; sessionOutputTokens: number; localInputTokens: number; localOutputTokens: number; - sessionMetadata?: SessionMetadata | null; + session?: Session | null; } export const useCostTracking = ({ @@ -16,7 +16,7 @@ export const useCostTracking = ({ sessionOutputTokens, localInputTokens, localOutputTokens, - sessionMetadata, + session, }: UseCostTrackingProps) => { const [sessionCosts, setSessionCosts] = useState<{ [key: string]: { @@ -79,7 +79,7 @@ export const useCostTracking = ({ sessionOutputTokens, localInputTokens, localOutputTokens, - sessionMetadata, + session, ]); return { diff --git a/ui/desktop/src/hooks/useMessageStream.ts b/ui/desktop/src/hooks/useMessageStream.ts index e20244a41bef..f708ae50cf70 100644 --- a/ui/desktop/src/hooks/useMessageStream.ts +++ b/ui/desktop/src/hooks/useMessageStream.ts @@ -1,7 +1,7 @@ import { useCallback, useEffect, useId, useReducer, useRef, useState } from 'react'; import useSWR from 'swr'; import { createUserMessage, hasCompletedToolCalls, Message, Role } from '../types/message'; -import { getSessionHistory, SessionMetadata } from '../api'; +import { getSession, Session } from '../api'; import { ChatState } from '../types/chatState'; let messageIdCounter = 0; @@ -152,8 +152,8 @@ export interface UseMessageStreamHelpers { /** Current model info from the backend */ currentModelInfo: { model: string; mode: string } | null; - /** Session metadata including token counts */ - sessionMetadata: SessionMetadata | null; + /** Session including token counts */ + session: Session | null; /** Clear error state */ setError: (error: Error | undefined) => void; @@ -188,7 +188,7 @@ export function useMessageStream({ const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>( null ); - const [sessionMetadata, setSessionMetadata] = useState(null); + const [session, setSession] = useState(null); // expose a way to update the body so we can update the session id when CLE occurs const 
updateMessageStreamBody = useCallback((newBody: object) => { @@ -333,26 +333,21 @@ export function useMessageStream({ } case 'Finish': { - // Call onFinish with the last message if available if (onFinish && currentMessages.length > 0) { const lastMessage = currentMessages[currentMessages.length - 1]; onFinish(lastMessage, parsedEvent.reason); } - // Fetch updated session metadata with token counts const sessionId = (extraMetadataRef.current.body as Record) ?.session_id as string; if (sessionId) { - try { - const sessionResponse = await getSessionHistory({ - path: { session_id: sessionId }, - }); - - if (sessionResponse.data?.metadata) { - setSessionMetadata(sessionResponse.data?.metadata); - } - } catch (error) { - console.error('Failed to fetch session metadata:', error); + const sessionResponse = await getSession({ + path: { session_id: sessionId }, + throwOnError: true, + }); + + if (sessionResponse.data) { + setSession(sessionResponse.data); } } break; @@ -631,7 +626,7 @@ export function useMessageStream({ updateMessageStreamBody, notifications, currentModelInfo, - sessionMetadata, + session: session, setError, }; } diff --git a/ui/desktop/src/sessions.ts b/ui/desktop/src/sessions.ts index 74bfc9946455..5ea2413547e3 100644 --- a/ui/desktop/src/sessions.ts +++ b/ui/desktop/src/sessions.ts @@ -1,185 +1,16 @@ -import { Message } from './types/message'; -import { - getSessionHistory, - listSessions, - SessionInfo, - Message as ApiMessage, - SessionMetadata, -} from './api'; -import { convertApiMessageToFrontendMessage } from './components/context_management'; -import { getApiUrl } from './config'; +import { Session } from './api'; -// Helper function to ensure working directory is set -export function ensureWorkingDir(metadata: Partial): SessionMetadata { - return { - description: metadata.description || '', - message_count: metadata.message_count || 0, - total_tokens: metadata.total_tokens || null, - working_dir: metadata.working_dir || process.env.HOME || '', - accumulated_input_tokens: metadata.accumulated_input_tokens || null, - accumulated_output_tokens: metadata.accumulated_output_tokens || null, - accumulated_total_tokens: metadata.accumulated_total_tokens || null, - }; -} - -export interface Session { - id: string; - path: string; - modified: string; - metadata: SessionMetadata; -} - -export interface SessionDetails { - sessionId: string; - metadata: SessionMetadata; - messages: Message[]; -} - -/** - * Generate a session ID in the format yyyymmdd_hhmmss - */ -export function generateSessionId(): string { - const now = new Date(); - const year = now.getFullYear(); - const month = String(now.getMonth() + 1).padStart(2, '0'); - const day = String(now.getDate()).padStart(2, '0'); - const hours = String(now.getHours()).padStart(2, '0'); - const minutes = String(now.getMinutes()).padStart(2, '0'); - const seconds = String(now.getSeconds()).padStart(2, '0'); - - return `${year}${month}${day}_${hours}${minutes}${seconds}`; -} - -/** - * Fetches all available sessions from the API - * @returns Promise with sessions data - */ -/** - * Fetches all available sessions from the API - * @returns Promise with an array of Session objects - */ -export async function fetchSessions(): Promise { - const response = await listSessions(); - - // Check if the response has the expected structure - if (response && response.data && response.data.sessions) { - // Since the API returns SessionInfo, we need to convert to Session - const sessions = response.data.sessions - .filter( - (sessionInfo: SessionInfo) 
=> sessionInfo.metadata && sessionInfo.metadata.message_count > 0 - ) - .map( - (sessionInfo: SessionInfo): Session => ({ - id: sessionInfo.id, - path: sessionInfo.path, - modified: sessionInfo.modified, - metadata: ensureWorkingDir(sessionInfo.metadata), - }) - ); - - // order sessions by 'modified' date descending - sessions.sort( - (a: Session, b: Session) => new Date(b.modified).getTime() - new Date(a.modified).getTime() - ); - - return sessions; - } else { - throw new Error('Unexpected response format from listSessions'); - } -} - -/** - * Fetches details for a specific session - * @param sessionId The ID of the session to fetch - * @returns Promise with session details - */ -export async function fetchSessionDetails(sessionId: string): Promise { - const response = await getSessionHistory({ - path: { session_id: sessionId }, - }); - - try { - // Convert the SessionHistoryResponse to a SessionDetails object - return { - sessionId: response.data.sessionId, - metadata: ensureWorkingDir(response.data.metadata), - messages: response.data.messages.map((message: ApiMessage) => - convertApiMessageToFrontendMessage(message) - ), // slight diffs between backend and frontend Message obj - }; - } catch (error) { - console.error(`Error fetching session details for ${sessionId}:`, error); - throw error; - } -} - -/** - * Updates the metadata for a specific session - * @param sessionId The ID of the session to update - * @param description The new description (name) for the session - * @returns Promise that resolves when the update is complete - */ -export async function updateSessionMetadata(sessionId: string, description: string): Promise { - const url = getApiUrl(`/sessions/${sessionId}/metadata`); - const secretKey = await window.electron.getSecretKey(); - - const response = await fetch(url, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - 'X-Secret-Key': secretKey, - }, - body: JSON.stringify({ description }), - }); - - if (!response.ok) { - const errorText = await response.text(); - throw new Error(`Failed to update session metadata: ${response.statusText} - ${errorText}`); - } -} - -/** - * Resumes a session. Currently, this opens a new window with the session loaded. - */ -export function resumeSession(session: SessionDetails | Session) { - const resumedSessionId = 'sessionId' in session ? 
session.sessionId : session.id; - console.log('Launching session in new window:', resumedSessionId); - const workingDir = session.metadata?.working_dir; +export function resumeSession(session: Session) { + console.log('Launching session in new window:', session.description || session.id); + const workingDir = session.working_dir; if (!workingDir) { - throw new Error('Cannot resume session: working directory is missing in session metadata'); + throw new Error('Cannot resume session: working directory is missing in session'); } window.electron.createChatWindow( undefined, // query workingDir, undefined, // version - resumedSessionId + session.id ); } - -/** - * Deletes a specific session - * @param sessionId The ID of the session to delete - * @returns Promise that resolves when the deletion is complete - */ -export async function deleteSession(sessionId: string): Promise { - try { - const url = getApiUrl(`/sessions/${sessionId}/delete`); - const secretKey = await window.electron.getSecretKey(); - - const response = await fetch(url, { - method: 'DELETE', - headers: { - 'X-Secret-Key': secretKey, - }, - }); - - if (!response.ok) { - const errorText = await response.text(); - throw new Error(`Failed to delete session: ${response.statusText} - ${errorText}`); - } - } catch (error) { - console.error(`Error deleting session ${sessionId}:`, error); - throw error; - } -} diff --git a/ui/desktop/src/utils/dateUtils.ts b/ui/desktop/src/utils/dateUtils.ts index ca9b15c5f2f8..1902e057eb5f 100644 --- a/ui/desktop/src/utils/dateUtils.ts +++ b/ui/desktop/src/utils/dateUtils.ts @@ -1,4 +1,4 @@ -import { Session } from '../sessions'; +import { Session } from '../api'; export interface DateGroup { label: string; @@ -16,7 +16,7 @@ export function groupSessionsByDate(sessions: Session[]): DateGroup[] { const groups: { [key: string]: DateGroup } = {}; sessions.forEach((session) => { - const sessionDate = new Date(session.modified); + const sessionDate = new Date(session.updated_at); const sessionDateStart = new Date(sessionDate); sessionDateStart.setHours(0, 0, 0, 0);