Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 20 additions & 10 deletions crates/sail-cli/src/spark/mcp_server.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use std::fmt;
use std::fmt::Formatter;
use std::net::Ipv4Addr;
use std::sync::Arc;

use clap::ValueEnum;
use log::info;
use pyo3::prelude::PyAnyMethods;
use pyo3::{PyResult, Python};
use sail_spark_connect::entrypoint::serve;
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_spark_connect::entrypoint::{serve, SessionManagerOptions};
use sail_telemetry::telemetry::init_telemetry;
use tokio::net::TcpListener;
use tokio::runtime::Runtime;

use crate::python::Modules;

Expand Down Expand Up @@ -42,31 +44,39 @@ pub struct McpSettings {
pub spark_remote: Option<String>,
}

fn run_spark_connect_server(runtime: &Runtime) -> Result<String, Box<dyn std::error::Error>> {
let (server_port, server_task) = runtime.block_on(async move {
fn run_spark_connect_server(
options: SessionManagerOptions,
) -> Result<String, Box<dyn std::error::Error>> {
let handle = options.runtime.primary().clone();
let (server_port, server_task) = handle.block_on(async move {
// Listen on only the loopback interface for security.
let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 0)).await?;
let port = listener.local_addr()?.port();
let task = async move {
info!("Starting the Spark Connect server on port {port}...");
let _ = serve(listener, shutdown()).await;
let _ = serve(listener, shutdown(), options).await;
info!("The Spark Connect server has stopped.");
};
<Result<_, Box<dyn std::error::Error>>>::Ok((port, task))
})?;
runtime.spawn(server_task);
handle.spawn(server_task);
Ok(format!("sc://127.0.0.1:{server_port}"))
}

pub fn run_spark_mcp_server(settings: McpSettings) -> Result<(), Box<dyn std::error::Error>> {
init_telemetry()?;

let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let config = Arc::new(AppConfig::load()?);
let runtime = RuntimeManager::try_new(&config.runtime)?;

let spark_remote = match settings.spark_remote {
None => run_spark_connect_server(&runtime)?,
None => {
let options = SessionManagerOptions {
config: Arc::clone(&config),
runtime: runtime.handle(),
};
run_spark_connect_server(options)?
}
Some(x) => x,
};

Expand Down
21 changes: 12 additions & 9 deletions crates/sail-cli/src/spark/server.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::net::IpAddr;
use std::sync::Arc;

use log::info;
use sail_spark_connect::entrypoint::serve;
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_spark_connect::entrypoint::{serve, SessionManagerOptions};
use sail_telemetry::telemetry::init_telemetry;
use tokio::net::TcpListener;

const SERVER_STACK_SIZE: usize = 1024 * 1024 * 8;

/// Handles graceful shutdown by waiting for a `SIGINT` signal in [tokio].
///
/// The `SIGINT` signal is captured by Python if the `_signal` module is imported [1].
Expand All @@ -28,19 +29,21 @@ async fn shutdown() {
pub fn run_spark_connect_server(ip: IpAddr, port: u16) -> Result<(), Box<dyn std::error::Error>> {
init_telemetry()?;

let runtime = tokio::runtime::Builder::new_multi_thread()
.thread_stack_size(SERVER_STACK_SIZE)
.enable_all()
.build()?;
let config = Arc::new(AppConfig::load()?);
let runtime = RuntimeManager::try_new(&config.runtime)?;
let options = SessionManagerOptions {
config: Arc::clone(&config),
runtime: runtime.handle(),
};

runtime.block_on(async {
runtime.handle().primary().block_on(async {
// A secure connection can be handled by a gateway in production.
let listener = TcpListener::bind((ip, port)).await?;
info!(
"Starting the Spark Connect server on {}...",
listener.local_addr()?
);
serve(listener, shutdown()).await?;
serve(listener, shutdown(), options).await?;
info!("The Spark Connect server has stopped.");
<Result<(), Box<dyn std::error::Error>>>::Ok(())
})?;
Expand Down
21 changes: 14 additions & 7 deletions crates/sail-cli/src/spark/shell.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
use std::net::Ipv4Addr;
use std::sync::Arc;

use pyo3::prelude::PyAnyMethods;
use pyo3::{PyResult, Python};
use sail_spark_connect::entrypoint::serve;
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_spark_connect::entrypoint::{serve, SessionManagerOptions};
use tokio::net::TcpListener;
use tokio::sync::oneshot;

use crate::python::Modules;

pub fn run_pyspark_shell() -> Result<(), Box<dyn std::error::Error>> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let config = Arc::new(AppConfig::load()?);
let runtime = RuntimeManager::try_new(&config.runtime)?;
let options = SessionManagerOptions {
config,
runtime: runtime.handle(),
};
let (_tx, rx) = oneshot::channel::<()>();
let (server_port, server_task) = runtime.block_on(async move {
let handle = runtime.handle().primary().clone();
let (server_port, server_task) = handle.block_on(async move {
// Listen on only the loopback interface for security.
let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 0)).await?;
let port = listener.local_addr()?.port();
Expand All @@ -25,11 +32,11 @@ pub fn run_pyspark_shell() -> Result<(), Box<dyn std::error::Error>> {
let _ = rx.await;
};
let task = async {
let _ = serve(listener, shutdown).await;
let _ = serve(listener, shutdown, options).await;
};
<Result<_, Box<dyn std::error::Error>>>::Ok((port, task))
})?;
runtime.spawn(server_task);
handle.spawn(server_task);
Python::with_gil(|py| -> PyResult<_> {
let shell = Modules::SPARK_SHELL.load(py)?;
shell
Expand Down
13 changes: 8 additions & 5 deletions crates/sail-cli/src/worker/entrypoint.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_telemetry::telemetry::init_telemetry;

pub fn run_worker() -> Result<(), Box<dyn std::error::Error>> {
init_telemetry()?;

let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;

runtime.block_on(sail_execution::run_worker())?;
let config = AppConfig::load()?;
let runtime = RuntimeManager::try_new(&config.runtime)?;
runtime
.handle()
.primary()
.block_on(sail_execution::run_worker(&config, runtime.handle()))?;

fastrace::flush();

Expand Down
1 change: 1 addition & 0 deletions crates/sail-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ figment = { workspace = true }
half = { workspace = true }
log = { workspace = true }
iana-time-zone = { workspace = true }
tokio = { workspace = true }
7 changes: 7 additions & 0 deletions crates/sail-common/src/config/application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const DEFAULT_CONFIG: &str = include_str!("default.toml");
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
pub mode: ExecutionMode,
pub runtime: RuntimeConfig,
pub cluster: ClusterConfig,
pub execution: ExecutionConfig,
pub kubernetes: KubernetesConfig,
Expand Down Expand Up @@ -67,6 +68,12 @@ pub enum ExecutionMode {
KubernetesCluster,
}

/// Settings for the Tokio runtimes created by the application.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig {
    /// Thread stack size, in bytes, applied to every runtime that is built.
    pub stack_size: usize,
    /// Whether a secondary runtime is created in addition to the primary one.
    pub enable_secondary: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
pub enable_tls: bool,
Expand Down
4 changes: 4 additions & 0 deletions crates/sail-common/src/config/default.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
mode = "local"

[runtime]
stack_size = 8388608
enable_secondary = false

[cluster]
enable_tls = false
driver_listen_host = "127.0.0.1"
Expand Down
1 change: 1 addition & 0 deletions crates/sail-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod datetime;
pub mod debug;
pub mod error;
pub mod object;
pub mod runtime;
pub mod spec;
pub mod string;
pub mod tests;
53 changes: 53 additions & 0 deletions crates/sail-common/src/runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use tokio::runtime::{Handle, Runtime};

use crate::config::RuntimeConfig;
use crate::error::{CommonError, CommonResult};

/// Owner of the Tokio runtimes used by the application.
///
/// The manager holds the [`Runtime`] values themselves; use
/// [`RuntimeManager::handle`] to obtain cloneable handles to them.
#[derive(Debug)]
pub struct RuntimeManager {
    /// The runtime that is always created.
    primary: Runtime,
    /// An additional runtime, present only when enabled in the configuration.
    secondary: Option<Runtime>,
}

impl RuntimeManager {
    /// Creates the runtimes described by `config`.
    ///
    /// The primary runtime is always built. The secondary runtime is built
    /// (with the same stack size) only when `config.enable_secondary` is set.
    ///
    /// # Errors
    ///
    /// Returns an internal [`CommonError`] if any runtime fails to build.
    pub fn try_new(config: &RuntimeConfig) -> CommonResult<Self> {
        let primary = Self::build_runtime(config.stack_size)?;
        // `then` only invokes the builder when the flag is set, and
        // `transpose` lets `?` surface a build failure.
        let secondary = config
            .enable_secondary
            .then(|| Self::build_runtime(config.stack_size))
            .transpose()?;

        Ok(Self { primary, secondary })
    }

    /// Returns a [`RuntimeHandle`] holding cloned handles to the managed runtimes.
    pub fn handle(&self) -> RuntimeHandle {
        RuntimeHandle {
            primary: self.primary.handle().clone(),
            secondary: self.secondary.as_ref().map(|rt| rt.handle().clone()),
        }
    }

    /// Builds a multi-threaded Tokio runtime with all drivers enabled and the
    /// given thread stack size (in bytes).
    fn build_runtime(stack_size: usize) -> CommonResult<Runtime> {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.thread_stack_size(stack_size).enable_all();
        match builder.build() {
            Ok(runtime) => Ok(runtime),
            // Only the message of the underlying error is carried over.
            Err(e) => Err(CommonError::internal(e.to_string())),
        }
    }
}

/// Cloneable set of handles to the runtimes owned by a [`RuntimeManager`].
#[derive(Debug, Clone)]
pub struct RuntimeHandle {
    /// Handle to the primary runtime.
    primary: Handle,
    /// Handle to the secondary runtime, if one was created.
    secondary: Option<Handle>,
}

impl RuntimeHandle {
    /// Borrows the handle to the primary runtime.
    pub fn primary(&self) -> &Handle {
        let Self { primary, .. } = self;
        primary
    }

    /// Borrows the handle to the secondary runtime, or `None` when no
    /// secondary runtime exists.
    pub fn secondary(&self) -> Option<&Handle> {
        let Self { secondary, .. } = self;
        secondary.as_ref()
    }
}
4 changes: 3 additions & 1 deletion crates/sail-execution/src/driver/actor/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ impl Actor for DriverActor {

fn new(options: DriverOptions) -> Self {
let worker_manager: Arc<dyn WorkerManager> = match &options.worker_manager {
WorkerManagerOptions::Local => Arc::new(LocalWorkerManager::new()),
WorkerManagerOptions::Local => {
Arc::new(LocalWorkerManager::new(options.runtime.clone()))
}
WorkerManagerOptions::Kubernetes(options) => {
Arc::new(KubernetesWorkerManager::new(options.clone()))
}
Expand Down
3 changes: 2 additions & 1 deletion crates/sail-execution/src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ pub(crate) use actor::DriverActor;
pub(crate) use client::DriverClient;
pub(crate) use event::DriverEvent;
pub(crate) use gen::driver_service_client::DriverServiceClient;
pub(crate) use options::{DriverOptions, WorkerManagerOptions};
pub use options::DriverOptions;
pub(crate) use options::WorkerManagerOptions;
8 changes: 5 additions & 3 deletions crates/sail-execution/src/driver/options.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::time::Duration;

use sail_common::config::{AppConfig, ExecutionMode};
use sail_common::runtime::RuntimeHandle;
use sail_server::RetryStrategy;

use crate::error::{ExecutionError, ExecutionResult};
Expand All @@ -25,6 +26,7 @@ pub struct DriverOptions {
pub job_output_buffer: usize,
pub rpc_retry_strategy: RetryStrategy,
pub worker_manager: WorkerManagerOptions,
pub runtime: RuntimeHandle,
}

#[derive(Debug)]
Expand All @@ -33,9 +35,8 @@ pub enum WorkerManagerOptions {
Kubernetes(KubernetesWorkerManagerOptions),
}

impl TryFrom<&AppConfig> for DriverOptions {
type Error = ExecutionError;
fn try_from(config: &AppConfig) -> ExecutionResult<Self> {
impl DriverOptions {
pub fn try_new(config: &AppConfig, runtime: RuntimeHandle) -> ExecutionResult<Self> {
let worker_manager = match config.mode {
ExecutionMode::Local => {
return Err(ExecutionError::InvalidArgument(
Expand Down Expand Up @@ -79,6 +80,7 @@ impl TryFrom<&AppConfig> for DriverOptions {
task_launch_timeout: Duration::from_secs(config.cluster.task_launch_timeout_secs),
job_output_buffer: config.cluster.job_output_buffer,
worker_manager,
runtime,
})
}
}
Loading