@@ -244,6 +244,14 @@ pub fn set_runner_ray(
244
244
) -> DaftResult < DaftContext > {
245
245
let ctx = get_context ( ) ;
246
246
247
+ let runner_type = get_runner_type_from_env ( ) ;
248
+ if !runner_type. is_empty ( ) && runner_type != RayRunner :: NAME {
249
+ log:: warn!(
250
+ "Ignore inconsistent $DAFT_RUNNER='{}' env when setting runner as ray" ,
251
+ runner_type
252
+ ) ;
253
+ }
254
+
247
255
let runner = Runner :: Ray ( RayRunner :: try_new (
248
256
address,
249
257
max_task_backlog,
@@ -268,6 +276,14 @@ pub fn set_runner_ray(
268
276
pub fn set_runner_native ( num_threads : Option < usize > ) -> DaftResult < DaftContext > {
269
277
let ctx = get_context ( ) ;
270
278
279
+ let runner_type = get_runner_type_from_env ( ) ;
280
+ if !runner_type. is_empty ( ) && runner_type != NativeRunner :: NAME {
281
+ log:: warn!(
282
+ "Ignore inconsistent $DAFT_RUNNER='{}' env when setting runner as native" ,
283
+ runner_type
284
+ ) ;
285
+ }
286
+
271
287
let runner = Runner :: Native ( NativeRunner :: try_new ( num_threads) ?) ;
272
288
let runner = Arc :: new ( runner) ;
273
289
@@ -322,30 +338,45 @@ fn get_ray_runner_config_from_env() -> RunnerConfig {
322
338
323
339
/// Helper function to automatically detect whether to use the ray runner.
324
340
#[ cfg( feature = "python" ) ]
325
- fn detect_ray_state ( ) -> bool {
341
+ fn detect_ray_state ( ) -> ( bool , bool ) {
326
342
Python :: with_gil ( |py| {
327
343
py. import ( pyo3:: intern!( py, "daft.utils" ) )
328
344
. and_then ( |m| m. getattr ( pyo3:: intern!( py, "detect_ray_state" ) ) )
329
345
. and_then ( |m| m. call0 ( ) )
330
346
. and_then ( |m| m. extract ( ) )
331
- . unwrap_or ( false )
347
+ . unwrap_or ( ( false , false ) )
332
348
} )
333
349
}
334
350
335
351
#[ cfg( feature = "python" ) ]
336
- fn get_runner_config_from_env ( ) -> DaftResult < RunnerConfig > {
352
+ fn get_runner_type_from_env ( ) -> String {
337
353
const DAFT_RUNNER : & str = "DAFT_RUNNER" ;
338
354
339
- let runner_from_envvar = std:: env:: var ( DAFT_RUNNER )
355
+ std:: env:: var ( DAFT_RUNNER )
340
356
. unwrap_or_default ( )
341
- . to_lowercase ( ) ;
342
-
343
- match runner_from_envvar. as_str ( ) {
344
- "native" => Ok ( RunnerConfig :: Native { num_threads : None } ) ,
345
- "ray" => Ok ( get_ray_runner_config_from_env ( ) ) ,
346
- "py" => Err ( DaftError :: ValueError ( "The PyRunner was removed from Daft from v0.5.0 onwards. Please set the env to `DAFT_RUNNER=native` instead." . to_string ( ) ) ) ,
347
- "" => Ok ( if detect_ray_state ( ) { get_ray_runner_config_from_env ( ) } else { RunnerConfig :: Native { num_threads : None } } ) ,
348
- other => Err ( DaftError :: ValueError ( format ! ( "Invalid runner type `DAFT_RUNNER={other}` specified through the env. Please use either `native` or `ray` instead." ) ) )
357
+ . to_lowercase ( )
358
+ }
359
+
360
+ #[ cfg( feature = "python" ) ]
361
+ fn get_runner_config_from_env ( ) -> DaftResult < RunnerConfig > {
362
+ match get_runner_type_from_env ( ) . as_str ( ) {
363
+ NativeRunner :: NAME => Ok ( RunnerConfig :: Native { num_threads : None } ) ,
364
+ RayRunner :: NAME => Ok ( get_ray_runner_config_from_env ( ) ) ,
365
+ "py" => Err ( DaftError :: ValueError (
366
+ "The PyRunner was removed from Daft from v0.5.0 onwards. \
367
+ Please set the env to `DAFT_RUNNER=native` instead."
368
+ . to_string ( ) ,
369
+ ) ) ,
370
+ "" => Ok ( if detect_ray_state ( ) == ( true , false ) {
371
+ // on ray but not in ray worker
372
+ get_ray_runner_config_from_env ( )
373
+ } else {
374
+ RunnerConfig :: Native { num_threads : None }
375
+ } ) ,
376
+ other => Err ( DaftError :: ValueError ( format ! (
377
+ "Invalid runner type `DAFT_RUNNER={other}` specified through the env. \
378
+ Please use either `native` or `ray` instead."
379
+ ) ) ) ,
349
380
}
350
381
}
351
382
@@ -366,7 +397,7 @@ pub fn reset_runner() {
366
397
}
367
398
368
399
#[ cfg( feature = "python" ) ]
369
- pub fn register_modules ( parent : & Bound < PyModule > ) -> pyo3 :: PyResult < ( ) > {
400
+ pub fn register_modules ( parent : & Bound < PyModule > ) -> PyResult < ( ) > {
370
401
parent. add_function ( wrap_pyfunction ! (
371
402
python:: get_runner_config_from_env,
372
403
parent
0 commit comments