Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ derive_utoipa!(Icon as IconSchema);
super::routes::session::get_session_insights,
super::routes::session::update_session_description,
super::routes::session::delete_session,
super::routes::session::export_session,
super::routes::session::import_session,
super::routes::session::update_session_user_recipe_values,
super::routes::schedule::create_schedule,
super::routes::schedule::list_schedules,
Expand Down Expand Up @@ -389,6 +391,7 @@ derive_utoipa!(Icon as IconSchema);
super::routes::reply::ChatRequest,
super::routes::context::ContextManageRequest,
super::routes::context::ContextManageResponse,
super::routes::session::ImportSessionRequest,
super::routes::session::SessionListResponse,
super::routes::session::UpdateSessionDescriptionRequest,
super::routes::session::UpdateSessionUserRecipeValuesRequest,
Expand Down
59 changes: 59 additions & 0 deletions crates/goose-server/src/routes/session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::state::AppState;
use axum::routing::post;
use axum::{
extract::Path,
http::StatusCode,
Expand Down Expand Up @@ -33,6 +34,12 @@ pub struct UpdateSessionUserRecipeValuesRequest {
user_recipe_values: HashMap<String, String>,
}

#[derive(Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ImportSessionRequest {
json: String,
}

const MAX_DESCRIPTION_LENGTH: usize = 200;

#[utoipa::path(
Expand Down Expand Up @@ -199,11 +206,63 @@ async fn delete_session(Path(session_id): Path<String>) -> Result<StatusCode, St
Ok(StatusCode::OK)
}

#[utoipa::path(
get,
path = "/sessions/{session_id}/export",
params(
("session_id" = String, Path, description = "Unique identifier for the session")
),
responses(
(status = 200, description = "Session exported successfully", body = String),
(status = 401, description = "Unauthorized - Invalid or missing API key"),
(status = 404, description = "Session not found"),
(status = 500, description = "Internal server error")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these just add to the openapi, you don't need to enumerate them, but you can if you want to. (on our other endpoints, I don't think they're always telling the truth)

),
security(
("api_key" = [])
),
tag = "Session Management"
)]
async fn export_session(Path(session_id): Path<String>) -> Result<Json<String>, StatusCode> {
let exported = SessionManager::export_session(&session_id)
.await
.map_err(|_| StatusCode::NOT_FOUND)?;

Ok(Json(exported))
}

#[utoipa::path(
post,
path = "/sessions/import",
request_body = ImportSessionRequest,
responses(
(status = 200, description = "Session imported successfully", body = Session),
(status = 401, description = "Unauthorized - Invalid or missing API key"),
(status = 400, description = "Bad request - Invalid JSON"),
(status = 500, description = "Internal server error")
),
security(
("api_key" = [])
),
tag = "Session Management"
)]
async fn import_session(
Json(request): Json<ImportSessionRequest>,
) -> Result<Json<Session>, StatusCode> {
let session = SessionManager::import_session(&request.json)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;

Ok(Json(session))
}

pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/sessions", get(list_sessions))
.route("/sessions/{session_id}", get(get_session))
.route("/sessions/{session_id}", delete(delete_session))
.route("/sessions/{session_id}/export", get(export_session))
.route("/sessions/import", post(import_session))
.route("/sessions/insights", get(get_session_insights))
.route(
"/sessions/{session_id}/description",
Expand Down
119 changes: 119 additions & 0 deletions crates/goose/src/session/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ impl SessionManager {
Self::instance().await?.get_insights().await
}

pub async fn export_session(id: &str) -> Result<String> {
Self::instance().await?.export_session(id).await
}

pub async fn import_session(json: &str) -> Result<Session> {
Self::instance().await?.import_session(json).await
}

pub async fn maybe_update_description(id: &str, provider: Arc<dyn Provider>) -> Result<()> {
let session = Self::get_session(id, true).await?;
let conversation = session
Expand Down Expand Up @@ -895,6 +903,41 @@ impl SessionStorage {
total_tokens: row.1.unwrap_or(0),
})
}

async fn export_session(&self, id: &str) -> Result<String> {
let session = self.get_session(id, true).await?;
serde_json::to_string_pretty(&session).map_err(Into::into)
}

async fn import_session(&self, json: &str) -> Result<Session> {
let import: Session = serde_json::from_str(json)?;

let session = self
.create_session(import.working_dir.clone(), import.description.clone())
.await?;

self.apply_update(
SessionUpdateBuilder::new(session.id.clone())
.extension_data(import.extension_data)
.total_tokens(import.total_tokens)
.input_tokens(import.input_tokens)
.output_tokens(import.output_tokens)
.accumulated_total_tokens(import.accumulated_total_tokens)
.accumulated_input_tokens(import.accumulated_input_tokens)
.accumulated_output_tokens(import.accumulated_output_tokens)
.schedule_id(import.schedule_id)
.recipe(import.recipe)
.user_recipe_values(import.user_recipe_values),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider constructing the struct itself without the builder so any added new fields are caught at compile time

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had that at first, but it becomes awkward since we still want to call create_session at this point so then we need to update only the other fields?

)
.await?;

if let Some(conversation) = import.conversation {
self.replace_conversation(&session.id, &conversation)
.await?;
}

self.get_session(&session.id, true).await
}
}

#[cfg(test)]
Expand Down Expand Up @@ -997,4 +1040,80 @@ mod tests {
let expected_tokens = 100 * NUM_CONCURRENT_SESSIONS * (NUM_CONCURRENT_SESSIONS - 1) / 2;
assert_eq!(insights.total_tokens, expected_tokens as i64);
}

#[tokio::test]
async fn test_export_import_roundtrip() {
const DESCRIPTION: &str = "Original session";
const TOTAL_TOKENS: i32 = 500;
const INPUT_TOKENS: i32 = 300;
const OUTPUT_TOKENS: i32 = 200;
const ACCUMULATED_TOKENS: i32 = 1000;
const USER_MESSAGE: &str = "test message";
const ASSISTANT_MESSAGE: &str = "test response";

let temp_dir = TempDir::new().unwrap();
let db_path = temp_dir.path().join("test_export.db");
let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap());

let original = storage
.create_session(PathBuf::from("/tmp/test"), DESCRIPTION.to_string())
.await
.unwrap();

storage
.apply_update(
SessionUpdateBuilder::new(original.id.clone())
.total_tokens(Some(TOTAL_TOKENS))
.input_tokens(Some(INPUT_TOKENS))
.output_tokens(Some(OUTPUT_TOKENS))
.accumulated_total_tokens(Some(ACCUMULATED_TOKENS)),
)
.await
.unwrap();

storage
.add_message(
&original.id,
&Message {
id: None,
role: Role::User,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text(USER_MESSAGE)],
metadata: Default::default(),
},
)
.await
.unwrap();

storage
.add_message(
&original.id,
&Message {
id: None,
role: Role::Assistant,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text(ASSISTANT_MESSAGE)],
metadata: Default::default(),
},
)
.await
.unwrap();

let exported = storage.export_session(&original.id).await.unwrap();
let imported = storage.import_session(&exported).await.unwrap();

assert_ne!(imported.id, original.id);
assert_eq!(imported.description, DESCRIPTION);
assert_eq!(imported.working_dir, PathBuf::from("/tmp/test"));
assert_eq!(imported.total_tokens, Some(TOTAL_TOKENS));
assert_eq!(imported.input_tokens, Some(INPUT_TOKENS));
assert_eq!(imported.output_tokens, Some(OUTPUT_TOKENS));
assert_eq!(imported.accumulated_total_tokens, Some(ACCUMULATED_TOKENS));
assert_eq!(imported.message_count, 2);

let conversation = imported.conversation.unwrap();
assert_eq!(conversation.messages().len(), 2);
assert_eq!(conversation.messages()[0].role, Role::User);
assert_eq!(conversation.messages()[1].role, Role::Assistant);
}
}
100 changes: 100 additions & 0 deletions ui/desktop/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,50 @@
]
}
},
"/sessions/import": {
"post": {
"tags": [
"Session Management"
],
"operationId": "import_session",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ImportSessionRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Session imported successfully",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Session"
}
}
}
},
"400": {
"description": "Bad request - Invalid JSON"
},
"401": {
"description": "Unauthorized - Invalid or missing API key"
},
"500": {
"description": "Internal server error"
}
},
"security": [
{
"api_key": []
}
]
}
},
"/sessions/insights": {
"get": {
"tags": [
Expand Down Expand Up @@ -1778,6 +1822,51 @@
]
}
},
"/sessions/{session_id}/export": {
"get": {
"tags": [
"Session Management"
],
"operationId": "export_session",
"parameters": [
{
"name": "session_id",
"in": "path",
"description": "Unique identifier for the session",
"required": true,
"schema": {
"type": "string"
}
}
],
"responses": {
"200": {
"description": "Session exported successfully",
"content": {
"text/plain": {
"schema": {
"type": "string"
}
}
}
},
"401": {
"description": "Unauthorized - Invalid or missing API key"
},
"404": {
"description": "Session not found"
},
"500": {
"description": "Internal server error"
}
},
"security": [
{
"api_key": []
}
]
}
},
"/sessions/{session_id}/user_recipe_values": {
"put": {
"tags": [
Expand Down Expand Up @@ -2819,6 +2908,17 @@
}
}
},
"ImportSessionRequest": {
"type": "object",
"required": [
"json"
],
"properties": {
"json": {
"type": "string"
}
}
},
"InspectJobResponse": {
"type": "object",
"properties": {
Expand Down
Loading
Loading