Skip to content

Commit 79721be

Browse files
committed
feat: Add get_or_infer_runner_type to support getting runner type from context
Signed-off-by: plotor <[email protected]>
1 parent 296a129 commit 79721be

File tree

7 files changed

+191
-20
lines changed

7 files changed

+191
-20
lines changed

daft/context.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ def __init__(self, ctx: PyDaftContext | None = None):
4747
else:
4848
self._ctx = PyDaftContext()
4949

50+
def get_or_infer_runner_type(self) -> str:
51+
"""Get or infer the runner type.
52+
53+
This API will get or infer the currently used runner type according to the following strategies:
54+
1. If the `runner` has been set, return its type directly;
55+
2. Try to determine whether it's currently running on a ray cluster. If so, consider it to be a ray type;
56+
3. Try to determine based on `DAFT_RUNNER` env variable.
57+
58+
:return: runner type string ("native" or "ray")
59+
"""
60+
if self._ctx._runner is not None:
61+
return self._ctx._runner.name
62+
63+
return self._ctx.get_or_infer_runner_type()
64+
5065
def get_or_create_runner(self) -> Runner[PartitionT]:
5166
return self._ctx.get_or_create_runner()
5267

daft/daft/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,7 @@ class PyDaftContext:
19621962
def __init__(self) -> None: ...
19631963
_runner: Runner[Any]
19641964
def get_or_create_runner(self) -> Runner[PartitionT]: ...
1965+
def get_or_infer_runner_type(self) -> str: ...
19651966
_daft_execution_config: PyDaftExecutionConfig
19661967
_daft_planning_config: PyDaftPlanningConfig
19671968
@property

daft/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def column_inputs_to_expressions(columns: ManyColumnsInputType) -> list[Expressi
126126
return [col(c) if isinstance(c, str) else c for c in column_iter]
127127

128128

129-
def detect_ray_state() -> bool:
129+
def detect_ray_state() -> tuple[bool, bool]:
130130
ray_is_initialized = False
131131
ray_is_in_job = False
132132
in_ray_worker = False
@@ -145,4 +145,4 @@ def detect_ray_state() -> bool:
145145
except ImportError:
146146
pass
147147

148-
return not in_ray_worker and (ray_is_initialized or ray_is_in_job)
148+
return ray_is_initialized or ray_is_in_job, in_ray_worker

src/daft-context/src/lib.rs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ pub fn set_runner_ray(
244244
) -> DaftResult<DaftContext> {
245245
let ctx = get_context();
246246

247+
let runner_type = get_runner_type_from_env();
248+
if !runner_type.is_empty() && runner_type != RayRunner::NAME {
249+
log::warn!(
250+
"Ignore inconsistent $DAFT_RUNNER='{}' env when setting runner as ray",
251+
runner_type
252+
);
253+
}
254+
247255
let runner = Runner::Ray(RayRunner::try_new(
248256
address,
249257
max_task_backlog,
@@ -268,6 +276,14 @@ pub fn set_runner_ray(
268276
pub fn set_runner_native(num_threads: Option<usize>) -> DaftResult<DaftContext> {
269277
let ctx = get_context();
270278

279+
let runner_type = get_runner_type_from_env();
280+
if !runner_type.is_empty() && runner_type != NativeRunner::NAME {
281+
log::warn!(
282+
"Ignore inconsistent $DAFT_RUNNER='{}' env when setting runner as native",
283+
runner_type
284+
);
285+
}
286+
271287
let runner = Runner::Native(NativeRunner::try_new(num_threads)?);
272288
let runner = Arc::new(runner);
273289

@@ -322,30 +338,45 @@ fn get_ray_runner_config_from_env() -> RunnerConfig {
322338

323339
/// Helper function to automatically detect whether to use the ray runner.
324340
#[cfg(feature = "python")]
325-
fn detect_ray_state() -> bool {
341+
fn detect_ray_state() -> (bool, bool) {
326342
Python::with_gil(|py| {
327343
py.import(pyo3::intern!(py, "daft.utils"))
328344
.and_then(|m| m.getattr(pyo3::intern!(py, "detect_ray_state")))
329345
.and_then(|m| m.call0())
330346
.and_then(|m| m.extract())
331-
.unwrap_or(false)
347+
.unwrap_or((false, false))
332348
})
333349
}
334350

335351
#[cfg(feature = "python")]
336-
fn get_runner_config_from_env() -> DaftResult<RunnerConfig> {
352+
fn get_runner_type_from_env() -> String {
337353
const DAFT_RUNNER: &str = "DAFT_RUNNER";
338354

339-
let runner_from_envvar = std::env::var(DAFT_RUNNER)
355+
std::env::var(DAFT_RUNNER)
340356
.unwrap_or_default()
341-
.to_lowercase();
342-
343-
match runner_from_envvar.as_str() {
344-
"native" => Ok(RunnerConfig::Native { num_threads: None }),
345-
"ray" => Ok(get_ray_runner_config_from_env()),
346-
"py" => Err(DaftError::ValueError("The PyRunner was removed from Daft from v0.5.0 onwards. Please set the env to `DAFT_RUNNER=native` instead.".to_string())),
347-
"" => Ok(if detect_ray_state() { get_ray_runner_config_from_env() } else { RunnerConfig::Native { num_threads: None }}),
348-
other => Err(DaftError::ValueError(format!("Invalid runner type `DAFT_RUNNER={other}` specified through the env. Please use either `native` or `ray` instead.")))
357+
.to_lowercase()
358+
}
359+
360+
#[cfg(feature = "python")]
361+
fn get_runner_config_from_env() -> DaftResult<RunnerConfig> {
362+
match get_runner_type_from_env().as_str() {
363+
NativeRunner::NAME => Ok(RunnerConfig::Native { num_threads: None }),
364+
RayRunner::NAME => Ok(get_ray_runner_config_from_env()),
365+
"py" => Err(DaftError::ValueError(
366+
"The PyRunner was removed from Daft from v0.5.0 onwards. \
367+
Please set the env to `DAFT_RUNNER=native` instead."
368+
.to_string(),
369+
)),
370+
"" => Ok(if detect_ray_state() == (true, false) {
371+
// on ray but not in ray worker
372+
get_ray_runner_config_from_env()
373+
} else {
374+
RunnerConfig::Native { num_threads: None }
375+
}),
376+
other => Err(DaftError::ValueError(format!(
377+
"Invalid runner type `DAFT_RUNNER={other}` specified through the env. \
378+
Please use either `native` or `ray` instead."
379+
))),
349380
}
350381
}
351382

@@ -366,7 +397,7 @@ pub fn reset_runner() {
366397
}
367398

368399
#[cfg(feature = "python")]
369-
pub fn register_modules(parent: &Bound<PyModule>) -> pyo3::PyResult<()> {
400+
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
370401
parent.add_function(wrap_pyfunction!(
371402
python::get_runner_config_from_env,
372403
parent

src/daft-context/src/python.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use std::sync::Arc;
22

33
use common_daft_config::{PyDaftExecutionConfig, PyDaftPlanningConfig};
44
use common_error::DaftError;
5-
use pyo3::prelude::*;
5+
use daft_py_runners::{NativeRunner, RayRunner};
6+
use pyo3::{prelude::*, IntoPyObjectExt};
67

7-
use crate::{DaftContext, Runner, RunnerConfig};
8+
use crate::{detect_ray_state, DaftContext, Runner, RunnerConfig};
89

910
#[pyclass]
1011
pub struct PyRunnerConfig {
@@ -13,7 +14,7 @@ pub struct PyRunnerConfig {
1314

1415
#[pyclass]
1516
pub struct PyDaftContext {
16-
inner: crate::DaftContext,
17+
inner: DaftContext,
1718
}
1819

1920
impl Default for PyDaftContext {
@@ -45,6 +46,27 @@ impl PyDaftContext {
4546
}
4647
}
4748
}
49+
50+
pub fn get_or_infer_runner_type(&self, py: Python) -> PyResult<PyObject> {
51+
match self.inner.runner() {
52+
Some(runner) => match runner.as_ref() {
53+
Runner::Ray(_) => RayRunner::NAME,
54+
Runner::Native(_) => NativeRunner::NAME,
55+
},
56+
None => {
57+
if let (true, _) = detect_ray_state() {
58+
RayRunner::NAME
59+
} else {
60+
match super::get_runner_config_from_env()? {
61+
RunnerConfig::Ray { .. } => RayRunner::NAME,
62+
RunnerConfig::Native { .. } => NativeRunner::NAME,
63+
}
64+
}
65+
}
66+
}
67+
.into_py_any(py)
68+
}
69+
4870
#[getter(_daft_execution_config)]
4971
pub fn get_daft_execution_config(&self, py: Python) -> PyResult<PyDaftExecutionConfig> {
5072
let config = py.allow_threads(|| self.inner.execution_config());

src/daft-py-runners/src/lib.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ pub struct RayRunner {
2222

2323
#[cfg(feature = "python")]
2424
impl RayRunner {
25+
pub const NAME: &'static str = "ray";
26+
2527
pub fn try_new(
2628
address: Option<String>,
2729
max_task_backlog: Option<usize>,
@@ -53,6 +55,8 @@ pub struct NativeRunner {
5355

5456
#[cfg(feature = "python")]
5557
impl NativeRunner {
58+
pub const NAME: &'static str = "native";
59+
5660
pub fn try_new(num_threads: Option<usize>) -> DaftResult<Self> {
5761
Python::with_gil(|py| {
5862
let native_runner_module = py.import(intern!(py, "daft.runners.native_runner"))?;
@@ -84,13 +88,13 @@ impl Runner {
8488
Python::with_gil(|py| {
8589
let name = obj.getattr(py, "name")?.extract::<String>(py)?;
8690
match name.as_ref() {
87-
"ray" => {
91+
RayRunner::NAME => {
8892
let ray_runner = RayRunner {
8993
pyobj: Arc::new(obj),
9094
};
9195
Ok(Self::Ray(ray_runner))
9296
}
93-
"native" => {
97+
NativeRunner::NAME => {
9498
let native_runner = NativeRunner {
9599
pyobj: Arc::new(obj),
96100
};

tests/test_context.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,101 @@ def test_cannot_set_runner_ray_after_py():
238238
)
239239
assert result.stdout.decode().strip() in {"native"}
240240
assert "DaftError::InternalError Cannot set runner more than once" in result.stderr.decode().strip()
241+
242+
243+
@pytest.mark.parametrize("daft_runner_envvar", ["ray", "native"])
244+
def test_get_or_infer_runner_type_from_env(daft_runner_envvar):
245+
get_or_infer_runner_type_py_script = """
246+
import daft
247+
248+
print(daft.context.get_context().get_or_infer_runner_type())
249+
250+
251+
@daft.udf(return_dtype=daft.DataType.string())
252+
def my_udf(foo):
253+
runner_type = daft.context.get_context().get_or_infer_runner_type()
254+
return [f"{runner_type}_{f}" for f in foo]
255+
256+
257+
df = daft.from_pydict({"foo": [7]})
258+
pd = df.with_column(column_name="bar", expr=my_udf(df["foo"])).to_pydict()
259+
print(pd["bar"][0])
260+
"""
261+
262+
with with_null_env():
263+
result = subprocess.run(
264+
[sys.executable, "-c", get_or_infer_runner_type_py_script],
265+
capture_output=True,
266+
env={"DAFT_RUNNER": daft_runner_envvar},
267+
)
268+
269+
assert result.stdout.decode().strip() == f"{daft_runner_envvar}\n{daft_runner_envvar}_7"
270+
271+
272+
def test_get_or_infer_runner_type_with_set_runner_native():
273+
get_or_infer_runner_type_py_script = """
274+
import daft
275+
276+
daft.context.set_runner_native()
277+
278+
print(daft.context.get_context().get_or_infer_runner_type())
279+
280+
281+
@daft.udf(return_dtype=daft.DataType.string())
282+
def my_udf(foo):
283+
runner_type = daft.context.get_context().get_or_infer_runner_type()
284+
return [f"{runner_type}_{f}" for f in foo]
285+
286+
287+
df = daft.from_pydict({"foo": [7]})
288+
pd = df.with_column(column_name="bar", expr=my_udf(df["foo"])).to_pydict()
289+
print(pd["bar"][0])
290+
"""
291+
292+
with with_null_env():
293+
result = subprocess.run([sys.executable, "-c", get_or_infer_runner_type_py_script], capture_output=True)
294+
assert result.stdout.decode().strip() == "native\nnative_7"
295+
296+
297+
def test_get_or_infer_runner_type_with_set_runner_ray():
298+
get_or_infer_runner_type_py_script = """
299+
import daft
300+
301+
daft.context.set_runner_ray()
302+
303+
print(daft.context.get_context().get_or_infer_runner_type())
304+
305+
306+
@daft.udf(return_dtype=daft.DataType.string())
307+
def my_udf(foo):
308+
runner_type = daft.context.get_context().get_or_infer_runner_type()
309+
return [f"{runner_type}_{f}" for f in foo]
310+
311+
312+
df = daft.from_pydict({"foo": [7]})
313+
pd = df.with_column(column_name="bar", expr=my_udf(df["foo"])).to_pydict()
314+
print(pd["bar"][0])
315+
"""
316+
317+
with with_null_env():
318+
result = subprocess.run([sys.executable, "-c", get_or_infer_runner_type_py_script], capture_output=True)
319+
assert result.stdout.decode().strip() == "ray\nray_7"
320+
321+
322+
@pytest.mark.parametrize("daft_runner_envvar", ["ray", "native"])
323+
def test_get_or_infer_runner_type_with_inconsistent_settings(daft_runner_envvar):
324+
get_or_infer_runner_type_py_script = """
325+
import daft
326+
327+
print(daft.context.get_context().get_or_infer_runner_type())
328+
daft.context.set_runner_ray()
329+
print(daft.context.get_context().get_or_infer_runner_type())
330+
"""
331+
332+
with with_null_env():
333+
result = subprocess.run(
334+
[sys.executable, "-c", get_or_infer_runner_type_py_script],
335+
capture_output=True,
336+
env={"DAFT_RUNNER": daft_runner_envvar},
337+
)
338+
assert result.stdout.decode().strip() == f"{daft_runner_envvar}\nray"

0 commit comments

Comments
 (0)