Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions crates/goose-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ sigstore-verification = { version = "0.1", default-features = false, features =
winapi = { version = "0.3", features = ["wincred"] }

[features]
default = ["code-mode"]
default = ["code-mode", "local-inference"]
code-mode = ["goose/code-mode", "goose-acp/code-mode"]
cuda = ["goose/cuda"]
local-inference = ["goose/local-inference"]
cuda = ["goose/cuda", "local-inference"]
# disables the update command
disable-update = []

Expand Down
5 changes: 5 additions & 0 deletions crates/goose-cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ enum Command {
command: TermCommand,
},
/// Manage local inference models
#[cfg(feature = "local-inference")]
#[command(about = "Manage local inference models", visible_alias = "lm")]
LocalModels {
#[command(subcommand)]
Expand Down Expand Up @@ -892,6 +893,7 @@ enum Command {
},
}

#[cfg(feature = "local-inference")]
#[derive(Subcommand)]
enum LocalModelsCommand {
/// Search HuggingFace for GGUF models
Expand Down Expand Up @@ -1013,6 +1015,7 @@ fn get_command_name(command: &Option<Command>) -> &'static str {
Some(Command::Update { .. }) => "update",
Some(Command::Recipe { .. }) => "recipe",
Some(Command::Term { .. }) => "term",
#[cfg(feature = "local-inference")]
Some(Command::LocalModels { .. }) => "local-models",
Some(Command::Completion { .. }) => "completion",
Some(Command::ValidateExtensions { .. }) => "validate-extensions",
Expand Down Expand Up @@ -1473,6 +1476,7 @@ async fn handle_term_subcommand(command: TermCommand) -> Result<()> {
}
}

#[cfg(feature = "local-inference")]
async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()> {
use goose::providers::local_inference::hf_models;
use goose::providers::local_inference::local_model_registry::{
Expand Down Expand Up @@ -1759,6 +1763,7 @@ pub async fn cli() -> anyhow::Result<()> {
}
Some(Command::Recipe { command }) => handle_recipe_subcommand(command),
Some(Command::Term { command }) => handle_term_subcommand(command).await,
#[cfg(feature = "local-inference")]
Some(Command::LocalModels { command }) => handle_local_models_command(command).await,
Some(Command::ValidateExtensions { file }) => {
use goose::agents::validate_extensions::validate_bundled_extensions;
Expand Down
5 changes: 3 additions & 2 deletions crates/goose-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ description.workspace = true
workspace = true

[features]
default = ["code-mode"]
default = ["code-mode", "local-inference"]
code-mode = ["goose/code-mode"]
cuda = ["goose/cuda"]
local-inference = ["goose/local-inference"]
cuda = ["goose/cuda", "local-inference"]

[dependencies]
goose = { path = "../goose", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions crates/goose-server/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub async fn check_token(
next: Next,
) -> Result<Response, StatusCode> {
if request.uri().path() == "/status"
|| request.uri().path() == "/features"
|| request.uri().path() == "/mcp-ui-proxy"
|| request.uri().path() == "/mcp-app-proxy"
|| request.uri().path() == "/mcp-app-guest"
Expand Down
53 changes: 35 additions & 18 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,20 +479,7 @@ derive_utoipa!(Icon as IconSchema);
super::routes::telemetry::send_telemetry_event,
super::routes::dictation::transcribe_dictation,
super::routes::dictation::get_dictation_config,
super::routes::dictation::list_models,
super::routes::dictation::download_model,
super::routes::dictation::get_download_progress,
super::routes::dictation::cancel_download,
super::routes::dictation::delete_model,
super::routes::local_inference::list_local_models,
super::routes::local_inference::search_hf_models,
super::routes::local_inference::get_repo_files,
super::routes::local_inference::download_hf_model,
super::routes::local_inference::get_local_model_download_progress,
super::routes::local_inference::cancel_local_model_download,
super::routes::local_inference::delete_local_model,
super::routes::local_inference::get_model_settings,
super::routes::local_inference::update_model_settings,
super::routes::features::get_features,
),
components(schemas(
super::routes::config_management::UpsertConfigQuery,
Expand Down Expand Up @@ -671,6 +658,33 @@ derive_utoipa!(Icon as IconSchema);
super::routes::dictation::TranscribeResponse,
goose::dictation::providers::DictationProvider,
super::routes::dictation::DictationProviderStatus,
super::routes::features::FeaturesResponse,
DownloadProgress,
DownloadStatus,
))
)]
pub struct ApiDoc;

#[cfg(feature = "local-inference")]
#[derive(OpenApi)]
#[openapi(
paths(
super::routes::dictation::list_models,
super::routes::dictation::download_model,
super::routes::dictation::get_download_progress,
super::routes::dictation::cancel_download,
super::routes::dictation::delete_model,
super::routes::local_inference::list_local_models,
super::routes::local_inference::search_hf_models,
super::routes::local_inference::get_repo_files,
super::routes::local_inference::download_hf_model,
super::routes::local_inference::get_local_model_download_progress,
super::routes::local_inference::cancel_local_model_download,
super::routes::local_inference::delete_local_model,
super::routes::local_inference::get_model_settings,
super::routes::local_inference::update_model_settings,
),
components(schemas(
super::routes::dictation::WhisperModelResponse,
super::routes::local_inference::LocalModelResponse,
super::routes::local_inference::ModelDownloadStatus,
Expand All @@ -681,14 +695,17 @@ derive_utoipa!(Icon as IconSchema);
super::routes::local_inference::RepoVariantsResponse,
goose::providers::local_inference::local_model_registry::ModelSettings,
goose::providers::local_inference::local_model_registry::SamplingConfig,
DownloadProgress,
DownloadStatus,
))
)]
pub struct ApiDoc;
pub struct LocalInferenceApiDoc;

#[allow(dead_code)] // Used by generate_schema binary
/// Render the server's OpenAPI specification as pretty-printed JSON.
///
/// Starts from the always-available `ApiDoc` and, when the
/// `local-inference` feature is compiled in, merges the feature-gated
/// `LocalInferenceApiDoc` paths and schemas into it before serializing.
pub fn generate_schema() -> String {
    // `mut` is only exercised when the feature-gated merge below is compiled in.
    #[allow(unused_mut)]
    let mut api_doc = ApiDoc::openapi();

    #[cfg(feature = "local-inference")]
    api_doc.merge(LocalInferenceApiDoc::openapi());

    serde_json::to_string_pretty(&api_doc).unwrap()
}
32 changes: 25 additions & 7 deletions crates/goose-server/src/routes/dictation.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use crate::routes::errors::ErrorResponse;
use crate::state::AppState;
use axum::{
extract::{DefaultBodyLimit, Path},
extract::DefaultBodyLimit,
http::StatusCode,
routing::{delete, get, post},
routing::{get, post},
Json, Router,
};
#[cfg(feature = "local-inference")]
use axum::{extract::Path, routing::delete};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
#[cfg(feature = "local-inference")]
use goose::dictation::providers::transcribe_local;
use goose::dictation::providers::{
is_configured, transcribe_local, transcribe_with_provider, DictationProvider, PROVIDERS,
all_providers, is_configured, transcribe_with_provider, DictationProvider,
};
#[cfg(feature = "local-inference")]
use goose::dictation::whisper;
#[cfg(feature = "local-inference")]
use goose::download_manager::{get_download_manager, DownloadProgress};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand All @@ -19,6 +25,7 @@ use utoipa::ToSchema;

const MAX_AUDIO_SIZE_BYTES: usize = 50 * 1024 * 1024;

#[cfg(feature = "local-inference")]
#[derive(Debug, Serialize, ToSchema)]
pub struct WhisperModelResponse {
#[serde(flatten)]
Expand Down Expand Up @@ -171,6 +178,7 @@ pub async fn transcribe_dictation(
)
.await
.map_err(convert_error)?,
#[cfg(feature = "local-inference")]
DictationProvider::Local => transcribe_local(audio_bytes).await.map_err(convert_error)?,
};

Expand All @@ -189,7 +197,7 @@ pub async fn get_dictation_config(
let config = goose::config::Config::global();
let mut providers = HashMap::new();

for def in PROVIDERS {
for def in all_providers() {
let provider = def.provider;
let configured = is_configured(provider);

Expand Down Expand Up @@ -222,6 +230,7 @@ pub async fn get_dictation_config(
Ok(Json(providers))
}

#[cfg(feature = "local-inference")]
#[utoipa::path(
get,
path = "/dictation/models",
Expand All @@ -243,6 +252,7 @@ pub async fn list_models() -> Result<Json<Vec<WhisperModelResponse>>, ErrorRespo
Ok(Json(models))
}

#[cfg(feature = "local-inference")]
#[utoipa::path(
post,
path = "/dictation/models/{model_id}/download",
Expand Down Expand Up @@ -274,6 +284,7 @@ pub async fn download_model(Path(model_id): Path<String>) -> Result<StatusCode,
Ok(StatusCode::ACCEPTED)
}

#[cfg(feature = "local-inference")]
#[utoipa::path(
get,
path = "/dictation/models/{model_id}/download",
Expand All @@ -293,6 +304,7 @@ pub async fn get_download_progress(
Ok(Json(progress))
}

#[cfg(feature = "local-inference")]
#[utoipa::path(
delete,
path = "/dictation/models/{model_id}/download",
Expand All @@ -307,6 +319,7 @@ pub async fn cancel_download(Path(model_id): Path<String>) -> Result<StatusCode,
Ok(StatusCode::OK)
}

#[cfg(feature = "local-inference")]
#[utoipa::path(
delete,
path = "/dictation/models/{model_id}",
Expand Down Expand Up @@ -334,9 +347,12 @@ pub async fn delete_model(Path(model_id): Path<String>) -> Result<StatusCode, Er
}

pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
let router = Router::new()
.route("/dictation/transcribe", post(transcribe_dictation))
.route("/dictation/config", get(get_dictation_config))
.route("/dictation/config", get(get_dictation_config));

#[cfg(feature = "local-inference")]
let router = router
.route("/dictation/models", get(list_models))
.route(
"/dictation/models/{model_id}/download",
Expand All @@ -350,7 +366,9 @@ pub fn routes(state: Arc<AppState>) -> Router {
"/dictation/models/{model_id}/download",
delete(cancel_download),
)
.route("/dictation/models/{model_id}", delete(delete_model))
.route("/dictation/models/{model_id}", delete(delete_model));

router
.layer(DefaultBodyLimit::max(MAX_AUDIO_SIZE_BYTES))
.with_state(state)
}
33 changes: 33 additions & 0 deletions crates/goose-server/src/routes/features.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use axum::{routing::get, Json, Router};
use serde::Serialize;
use std::collections::HashMap;
use utoipa::ToSchema;

/// Response payload for `GET /features`: reports which compile-time Cargo
/// features this server binary was built with, so clients can adapt their
/// behavior without probing feature-gated endpoints.
#[derive(Serialize, ToSchema)]
pub struct FeaturesResponse {
    /// Map of feature name to enabled status
    pub features: HashMap<String, bool>,
}

/// Report the compile-time feature flags this binary was built with.
///
/// Each known Cargo feature is evaluated via `cfg!` at compile time, so the
/// returned map is a constant property of the running binary.
#[utoipa::path(
    get,
    path = "/features",
    responses(
        (status = 200, description = "Compile-time feature flags", body = FeaturesResponse),
    )
)]
pub async fn get_features() -> Json<FeaturesResponse> {
    let features: HashMap<String, bool> = [
        ("local-inference", cfg!(feature = "local-inference")),
        ("code-mode", cfg!(feature = "code-mode")),
    ]
    .into_iter()
    .map(|(name, enabled)| (name.to_string(), enabled))
    .collect();

    Json(FeaturesResponse { features })
}

/// Build the router exposing the `/features` endpoint (served without
/// authentication; the path is allow-listed in the auth middleware).
pub fn routes() -> Router {
    Router::new().route("/features", get(get_features))
}
15 changes: 11 additions & 4 deletions crates/goose-server/src/routes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ pub mod agent;
pub mod config_management;
pub mod dictation;
pub mod errors;
pub mod features;
pub mod gateway;
#[cfg(feature = "local-inference")]
pub mod local_inference;
pub mod mcp_app_proxy;
pub mod mcp_ui_proxy;
Expand All @@ -27,13 +29,11 @@ use axum::Router;

// Function to configure all routes
pub fn configure(state: Arc<crate::state::AppState>, secret_key: String) -> Router {
Router::new()
let router = Router::new()
.merge(status::routes(state.clone()))
.merge(reply::routes(state.clone()))
.merge(action_required::routes(state.clone()))
.merge(agent::routes(state.clone()))
.merge(dictation::routes(state.clone()))
.merge(local_inference::routes(state.clone()))
.merge(config_management::routes(state.clone()))
.merge(prompts::routes())
.merge(recipe::routes(state.clone()))
Expand All @@ -46,5 +46,12 @@ pub fn configure(state: Arc<crate::state::AppState>, secret_key: String) -> Rout
.merge(mcp_ui_proxy::routes(secret_key.clone()))
.merge(mcp_app_proxy::routes(secret_key))
.merge(session_events::routes(state.clone()))
.merge(sampling::routes(state))
.merge(sampling::routes(state.clone()))
.merge(dictation::routes(state.clone()))
.merge(features::routes());

#[cfg(feature = "local-inference")]
let router = router.merge(local_inference::routes(state));
Comment thread
jh-block marked this conversation as resolved.

router
}
3 changes: 3 additions & 0 deletions crates/goose-server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::session_event_bus::SessionEventBus;
use crate::tunnel::TunnelManager;
use goose::agents::ExtensionLoadResult;
use goose::gateway::manager::GatewayManager;
#[cfg(feature = "local-inference")]
use goose::providers::local_inference::InferenceRuntime;

type ExtensionLoadingTasks =
Expand All @@ -26,6 +27,7 @@ pub struct AppState {
pub tunnel_manager: Arc<TunnelManager>,
pub gateway_manager: Arc<GatewayManager>,
pub extension_loading_tasks: ExtensionLoadingTasks,
#[cfg(feature = "local-inference")]
pub inference_runtime: Arc<InferenceRuntime>,
session_buses: Arc<Mutex<HashMap<String, Arc<SessionEventBus>>>>,
}
Expand All @@ -45,6 +47,7 @@ impl AppState {
tunnel_manager,
gateway_manager,
extension_loading_tasks: Arc::new(Mutex::new(HashMap::new())),
#[cfg(feature = "local-inference")]
inference_runtime: InferenceRuntime::get_or_init(),
session_buses: Arc::new(Mutex::new(HashMap::new())),
}))
Expand Down
Loading
Loading