Skip to content

Commit 713779f

Browse files
committed
feat: Add get_runner_type method to support getting the currently used Runner type
Signed-off-by: plotor <[email protected]>
1 parent 6ef8e52 commit 713779f

File tree

6 files changed

+81
-20
lines changed

6 files changed

+81
-20
lines changed

daft/context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def __init__(self, ctx: PyDaftContext | None = None):
4747
else:
4848
self._ctx = PyDaftContext()
4949

50+
# Return the name of the runner type for this context as a string.
# When a runner object is already attached to the wrapped context (`self._ctx._runner`
# is not None), its `name` attribute is returned directly; otherwise the lookup is
# delegated to `self._ctx.get_runner_type()`.
def get_runner_type(self) -> str:
51+
if self._ctx._runner is not None:
52+
return self._ctx._runner.name
53+
54+
# No runner attached yet: ask the wrapped PyDaftContext instead.
return self._ctx.get_runner_type()
55+
5056
def get_or_create_runner(self) -> Runner[PartitionT]:
5157
return self._ctx.get_or_create_runner()
5258

daft/daft/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,6 +1955,7 @@ class PyDaftContext:
19551955
def __init__(self) -> None: ...
19561956
_runner: Runner[Any]
19571957
def get_or_create_runner(self) -> Runner[PartitionT]: ...
1958+
def get_runner_type(self) -> str: ...
19581959
_daft_execution_config: PyDaftExecutionConfig
19591960
_daft_planning_config: PyDaftPlanningConfig
19601961
@property

daft/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def column_inputs_to_expressions(columns: ManyColumnsInputType) -> list[Expressi
126126
return [col(c) if isinstance(c, str) else c for c in column_iter]
127127

128128

129-
def detect_ray_state() -> bool:
129+
def detect_ray_state() -> tuple[bool, bool]:
130130
ray_is_initialized = False
131131
ray_is_in_job = False
132132
in_ray_worker = False
@@ -145,4 +145,4 @@ def detect_ray_state() -> bool:
145145
except ImportError:
146146
pass
147147

148-
return not in_ray_worker and (ray_is_initialized or ray_is_in_job)
148+
return ray_is_initialized or ray_is_in_job, in_ray_worker

src/daft-context/src/lib.rs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ pub fn set_runner_ray(
244244
) -> DaftResult<DaftContext> {
245245
let ctx = get_context();
246246

247+
let runner_type = get_runner_type_from_env();
248+
if !runner_type.is_empty() && runner_type != RayRunner::NAME {
249+
log::warn!(
250+
"Ignore inconsistent $DAFT_RUNNER='{}' env when setting runner as ray",
251+
runner_type
252+
);
253+
}
254+
247255
let runner = Runner::Ray(RayRunner::try_new(
248256
address,
249257
max_task_backlog,
@@ -268,6 +276,14 @@ pub fn set_runner_ray(
268276
pub fn set_runner_native(num_threads: Option<usize>) -> DaftResult<DaftContext> {
269277
let ctx = get_context();
270278

279+
let runner_type = get_runner_type_from_env();
280+
if !runner_type.is_empty() && runner_type != NativeRunner::NAME {
281+
log::warn!(
282+
"Ignore inconsistent $DAFT_RUNNER='{}' env when setting runner as native",
283+
runner_type
284+
);
285+
}
286+
271287
let runner = Runner::Native(NativeRunner::try_new(num_threads)?);
272288
let runner = Arc::new(runner);
273289

@@ -322,30 +338,45 @@ fn get_ray_runner_config_from_env() -> RunnerConfig {
322338

323339
/// Helper function to automatically detect whether to use the ray runner.
324340
#[cfg(feature = "python")]
325-
fn detect_ray_state() -> bool {
341+
fn detect_ray_state() -> (bool, bool) {
326342
Python::with_gil(|py| {
327343
py.import(pyo3::intern!(py, "daft.utils"))
328344
.and_then(|m| m.getattr(pyo3::intern!(py, "detect_ray_state")))
329345
.and_then(|m| m.call0())
330346
.and_then(|m| m.extract())
331-
.unwrap_or(false)
347+
.unwrap_or((false, false))
332348
})
333349
}
334350

335351
#[cfg(feature = "python")]
336-
fn get_runner_config_from_env() -> DaftResult<RunnerConfig> {
352+
fn get_runner_type_from_env() -> String {
337353
const DAFT_RUNNER: &str = "DAFT_RUNNER";
338354

339-
let runner_from_envvar = std::env::var(DAFT_RUNNER)
355+
std::env::var(DAFT_RUNNER)
340356
.unwrap_or_default()
341-
.to_lowercase();
342-
343-
match runner_from_envvar.as_str() {
344-
"native" => Ok(RunnerConfig::Native { num_threads: None }),
345-
"ray" => Ok(get_ray_runner_config_from_env()),
346-
"py" => Err(DaftError::ValueError("The PyRunner was removed from Daft from v0.5.0 onwards. Please set the env to `DAFT_RUNNER=native` instead.".to_string())),
347-
"" => Ok(if detect_ray_state() { get_ray_runner_config_from_env() } else { RunnerConfig::Native { num_threads: None }}),
348-
other => Err(DaftError::ValueError(format!("Invalid runner type `DAFT_RUNNER={other}` specified through the env. Please use either `native` or `ray` instead.")))
357+
.to_lowercase()
358+
}
359+
360+
/// Chooses a `RunnerConfig` from the value returned by `get_runner_type_from_env()`.
///
/// - `NativeRunner::NAME` selects the native runner with `num_threads: None`.
/// - `RayRunner::NAME` selects whatever `get_ray_runner_config_from_env()` returns.
/// - `"py"` is rejected with a `DaftError::ValueError` saying the PyRunner was removed.
/// - An empty value falls back to auto-detection: the Ray config is used only when
///   `detect_ray_state()` returns exactly `(true, false)`; otherwise the native runner.
/// - Any other value is rejected with a `DaftError::ValueError` naming the offending value.
///
/// # Errors
///
/// Returns `DaftError::ValueError` for `"py"` and for any unrecognised runner name.
#[cfg(feature = "python")]
361+
fn get_runner_config_from_env() -> DaftResult<RunnerConfig> {
362+
match get_runner_type_from_env().as_str() {
363+
NativeRunner::NAME => Ok(RunnerConfig::Native { num_threads: None }),
364+
RayRunner::NAME => Ok(get_ray_runner_config_from_env()),
365+
"py" => Err(DaftError::ValueError(
366+
"The PyRunner was removed from Daft from v0.5.0 onwards. \
367+
Please set the env to `DAFT_RUNNER=native` instead."
368+
.to_string(),
369+
)),
370+
"" => Ok(if detect_ray_state() == (true, false) {
371+
// (true, false): the first flag (Ray detected) is set and the second
// (running inside a Ray worker) is not, i.e. on Ray but not in a Ray worker.
372+
get_ray_runner_config_from_env()
373+
} else {
374+
RunnerConfig::Native { num_threads: None }
375+
}),
376+
other => Err(DaftError::ValueError(format!(
377+
"Invalid runner type `DAFT_RUNNER={other}` specified through the env. \
378+
Please use either `native` or `ray` instead."
379+
))),
349380
}
350381
}
351382

@@ -366,7 +397,7 @@ pub fn reset_runner() {
366397
}
367398

368399
#[cfg(feature = "python")]
369-
pub fn register_modules(parent: &Bound<PyModule>) -> pyo3::PyResult<()> {
400+
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
370401
parent.add_function(wrap_pyfunction!(
371402
python::get_runner_config_from_env,
372403
parent

src/daft-context/src/python.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use std::sync::Arc;
22

33
use common_daft_config::{PyDaftExecutionConfig, PyDaftPlanningConfig};
44
use common_error::DaftError;
5-
use pyo3::prelude::*;
5+
use daft_py_runners::{NativeRunner, RayRunner};
6+
use pyo3::{prelude::*, IntoPyObjectExt};
67

7-
use crate::{DaftContext, Runner, RunnerConfig};
8+
use crate::{detect_ray_state, DaftContext, Runner, RunnerConfig};
89

910
#[pyclass]
1011
pub struct PyRunnerConfig {
@@ -13,7 +14,7 @@ pub struct PyRunnerConfig {
1314

1415
#[pyclass]
1516
pub struct PyDaftContext {
16-
inner: crate::DaftContext,
17+
inner: DaftContext,
1718
}
1819

1920
impl Default for PyDaftContext {
@@ -45,6 +46,24 @@ impl PyDaftContext {
4546
}
4647
}
4748
}
49+
50+
/// Returns the runner type name as a Python object.
///
/// If the wrapped context already has a runner, the name is taken from its variant
/// (`RayRunner::NAME` or `NativeRunner::NAME`). If no runner has been set, the name is
/// derived from the first element of `detect_ray_state()`: `RayRunner::NAME` when it is
/// `true`, `NativeRunner::NAME` otherwise.
///
/// NOTE(review): the no-runner fallback looks only at `detect_ray_state().0` and never
/// reads `DAFT_RUNNER`, whereas `get_runner_config_from_env` honours the env var and
/// picks Ray only for `(true, false)`. The two can therefore disagree (e.g. inside a
/// Ray worker, or with `DAFT_RUNNER=native` while Ray is initialized) — confirm this
/// difference is intended.
pub fn get_runner_type(&self, py: Python) -> PyResult<PyObject> {
51+
let runner_type = self.inner.runner().map_or_else(
52+
// No runner set yet: fall back to Ray-state detection.
|| {
53+
if detect_ray_state().0 {
54+
RayRunner::NAME
55+
} else {
56+
NativeRunner::NAME
57+
}
58+
},
59+
// A runner exists: report the name that matches its variant.
|runner| match runner.as_ref() {
60+
Runner::Ray(_) => RayRunner::NAME,
61+
Runner::Native(_) => NativeRunner::NAME,
62+
},
63+
);
64+
runner_type.into_py_any(py)
65+
}
66+
4867
#[getter(_daft_execution_config)]
4968
pub fn get_daft_execution_config(&self, py: Python) -> PyResult<PyDaftExecutionConfig> {
5069
let config = py.allow_threads(|| self.inner.execution_config());

src/daft-py-runners/src/lib.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ pub struct RayRunner {
2222

2323
#[cfg(feature = "python")]
2424
impl RayRunner {
25+
pub const NAME: &'static str = "ray";
26+
2527
pub fn try_new(
2628
address: Option<String>,
2729
max_task_backlog: Option<usize>,
@@ -53,6 +55,8 @@ pub struct NativeRunner {
5355

5456
#[cfg(feature = "python")]
5557
impl NativeRunner {
58+
pub const NAME: &'static str = "native";
59+
5660
pub fn try_new(num_threads: Option<usize>) -> DaftResult<Self> {
5761
Python::with_gil(|py| {
5862
let native_runner_module = py.import(intern!(py, "daft.runners.native_runner"))?;
@@ -84,13 +88,13 @@ impl Runner {
8488
Python::with_gil(|py| {
8589
let name = obj.getattr(py, "name")?.extract::<String>(py)?;
8690
match name.as_ref() {
87-
"ray" => {
91+
RayRunner::NAME => {
8892
let ray_runner = RayRunner {
8993
pyobj: Arc::new(obj),
9094
};
9195
Ok(Self::Ray(ray_runner))
9296
}
93-
"native" => {
97+
NativeRunner::NAME => {
9498
let native_runner = NativeRunner {
9599
pyobj: Arc::new(obj),
96100
};

0 commit comments

Comments
 (0)