diff --git a/Cargo.lock b/Cargo.lock
index aff2602dfa..e62ab0b95a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1832,6 +1832,14 @@ dependencies = [
  "typetag",
 ]

+[[package]]
+name = "common-logging"
+version = "0.3.0-dev0"
+dependencies = [
+ "arc-swap",
+ "log",
+]
+
 [[package]]
 name = "common-macros"
 version = "0.3.0-dev0"
@@ -2289,6 +2297,7 @@ dependencies = [
  "common-display",
  "common-file-formats",
  "common-hashable-float-wrapper",
+ "common-logging",
  "common-metrics",
  "common-partitioning",
  "common-resource-request",
@@ -2866,12 +2875,14 @@ dependencies = [
 name = "daft-local-execution"
 version = "0.3.0-dev0"
 dependencies = [
+ "arc-swap",
  "async-trait",
  "capitalize",
  "common-daft-config",
  "common-display",
  "common-error",
  "common-file-formats",
+ "common-logging",
  "common-metrics",
  "common-py-serde",
  "common-resource-request",
@@ -2879,6 +2890,7 @@ dependencies = [
  "common-scan-info",
  "common-system-info",
  "common-tracing",
+ "console 0.16.0",
  "daft-core",
  "daft-csv",
  "daft-dsl",
diff --git a/Cargo.toml b/Cargo.toml
index 81e696caee..8b40f5b114 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,7 @@ common-daft-config = {path = "src/common/daft-config", default-features = false}
 common-display = {path = "src/common/display", default-features = false}
 common-file-formats = {path = "src/common/file-formats", default-features = false}
 common-hashable-float-wrapper = {path = "src/common/hashable-float-wrapper", default-features = false}
+common-logging = {path = "src/common/logging", default-features = false}
 common-metrics = {path = "src/common/metrics", default-features = false}
 common-partitioning = {path = "src/common/partitioning", default-features = false}
 common-resource-request = {path = "src/common/resource-request", default-features = false}
@@ -217,6 +218,7 @@ members = [

 [workspace.dependencies]
 approx = "0.5.1"
+arc-swap = "1.7.1"
 async-compat = "0.2.3"
 async-compression = {version = "0.4.12", features = [
   "tokio",
diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index 42cdb58194..e1f4495449 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -1843,8 +1843,8 @@ class NativeExecutor:
         plan: LocalPhysicalPlan,
         psets: dict[str, list[PyMicroPartition]],
         daft_execution_config: PyDaftExecutionConfig,
-        results_buffer_size: int | None,
-        context: dict[str, str] | None,
+        results_buffer_size: int | None = None,
+        context: dict[str, str] | None = None,
     ) -> AsyncIterator[PyMicroPartition]: ...
     @staticmethod
    def repr_ascii(builder: LogicalPlanBuilder, daft_execution_config: PyDaftExecutionConfig, simple: bool) -> str: ...
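A note on the dependency changes before the functional diffs: `arc-swap`, added to the workspace here, is the primitive that the new `common-logging` crate and the executor's stdout handler build on. A minimal standalone sketch of its load/store semantics (illustrative only, not code from this patch; `config` is a hypothetical value):

```rust
use std::sync::Arc;

use arc_swap::ArcSwap;

fn main() {
    // Readers call `load()` lock-free; writers atomically replace the Arc.
    let config = ArcSwap::new(Arc::new(String::from("baseline")));
    assert_eq!(config.load().as_str(), "baseline");

    // `store` swaps in a new value: later `load`s observe it, while guards
    // already held elsewhere keep the old Arc alive until dropped.
    config.store(Arc::new(String::from("override")));
    assert_eq!(config.load().as_str(), "override");
}
```

Because `load()` never blocks, the hot logging path introduced below can consult the currently-installed logger on every record without taking a lock.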
diff --git a/daft/execution/udf.py b/daft/execution/udf.py
index 4cc7111924..a02fc2d084 100644
--- a/daft/execution/udf.py
+++ b/daft/execution/udf.py
@@ -9,7 +9,7 @@ import tempfile
 from multiprocessing import resource_tracker, shared_memory
 from multiprocessing.connection import Listener
-from typing import TYPE_CHECKING
+from typing import IO, TYPE_CHECKING, cast

 from daft.errors import UDFException
 from daft.expressions import Expression, ExpressionsProjection
@@ -21,9 +21,11 @@ logger = logging.getLogger(__name__)

 _ENTER = "__ENTER__"
+_READY = "ready"
 _SUCCESS = "success"
 _UDF_ERROR = "udf_error"
 _ERROR = "error"
+_OUTPUT_DIVIDER = b"_DAFT_OUTPUT_DIVIDER_\n"
 _SENTINEL = ("__EXIT__", 0)
@@ -66,7 +68,11 @@
                 "daft.execution.udf_worker",
                 self.socket_path,
                 secret.hex(),
-            ]
+            ],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            # Python auto-buffers stdout by default, so disable
+            env={"PYTHONUNBUFFERED": "1"},
         )
@@ -79,8 +85,22 @@
         )
         expr_projection_bytes = pickle.dumps(expr_projection)
         self.handle_conn.send((_ENTER, expr_projection_bytes))
-
-    def eval_input(self, input: PyMicroPartition) -> PyMicroPartition:
+        response = self.handle_conn.recv()
+        if response != _READY:
+            raise RuntimeError(f"Expected '{_READY}' but got {response}")
+
+    def trace_output(self) -> list[str]:
+        lines = []
+        while True:
+            line = cast("IO[bytes]", self.process.stdout).readline()
+            # UDF process is expected to return the divider
+            # after initialization and every iteration
+            if line == b"" or line == _OUTPUT_DIVIDER or self.process.poll() is not None:
+                break
+            lines.append(line.decode().rstrip())
+        return lines
+
+    def eval_input(self, input: PyMicroPartition) -> tuple[PyMicroPartition, list[str]]:
         if self.process.poll() is not None:
             raise RuntimeError("UDF process has terminated")
@@ -89,6 +109,7 @@
         self.handle_conn.send((shm_name, shm_size))

         response = self.handle_conn.recv()
+        stdout = self.trace_output()
         if response[0] == _UDF_ERROR:
             base_exc: Exception = pickle.loads(response[3])
             if sys.version_info >= (3, 11):
@@ -100,7 +121,7 @@
             out_name, out_size = response[1], response[2]
             output_bytes = self.transport.read_and_release(out_name, out_size)
             deserialized = MicroPartition.from_ipc_stream(output_bytes)
-            return deserialized._micropartition
+            return (deserialized._micropartition, stdout)
         else:
             raise RuntimeError(f"Unknown response from actor: {response}")
diff --git a/daft/execution/udf_worker.py b/daft/execution/udf_worker.py
index c2978066c0..fcff9daf71 100644
--- a/daft/execution/udf_worker.py
+++ b/daft/execution/udf_worker.py
@@ -6,7 +6,16 @@
 from traceback import TracebackException

 from daft.errors import UDFException
-from daft.execution.udf import _ENTER, _ERROR, _SENTINEL, _SUCCESS, _UDF_ERROR, SharedMemoryTransport
+from daft.execution.udf import (
+    _ENTER,
+    _ERROR,
+    _OUTPUT_DIVIDER,
+    _READY,
+    _SENTINEL,
+    _SUCCESS,
+    _UDF_ERROR,
+    SharedMemoryTransport,
+)
 from daft.expressions.expressions import ExpressionsProjection
 from daft.recordbatch.micropartition import MicroPartition
@@ -22,23 +31,33 @@ def udf_event_loop(
     name, expr_projection_bytes = conn.recv()
     if name != _ENTER:
         raise ValueError(f"Expected '{_ENTER}' but got {name}")
-    uninitialized_projection: ExpressionsProjection = pickle.loads(expr_projection_bytes)

     transport = SharedMemoryTransport()
     try:
-        initialized_projection = ExpressionsProjection([e._initialize_udfs() for e in uninitialized_projection])
+        conn.send(_READY)
+        expression_projection = None
         while True:
             name, size = conn.recv()
             if (name, size) == _SENTINEL:
                 break

+            # We initialize after ready to avoid blocking the main thread
+            if expression_projection is None:
+                uninitialized_projection: ExpressionsProjection = pickle.loads(expr_projection_bytes)
+                initialized_projection = ExpressionsProjection([e._initialize_udfs() for e in uninitialized_projection])
+                expression_projection = initialized_projection
+
             input_bytes = transport.read_and_release(name, size)
             input = MicroPartition.from_ipc_stream(input_bytes)
-            evaluated = input.eval_expression_list(initialized_projection)
-            output_bytes = evaluated.to_ipc_stream()
+            evaluated = input.eval_expression_list(expression_projection)
+            output_bytes = evaluated.to_ipc_stream()
             out_name, out_size = transport.write_and_close(output_bytes)
+
+            print(_OUTPUT_DIVIDER.decode(), end="", file=sys.stderr, flush=True)
+            sys.stdout.flush()
+            sys.stderr.flush()
             conn.send((_SUCCESS, out_name, out_size))
     except UDFException as e:
         exc = e.__cause__
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py
index ed506a3b92..40e90de626 100644
--- a/daft/expressions/expressions.py
+++ b/daft/expressions/expressions.py
@@ -5130,6 +5130,17 @@ def resolve_schema(self, schema: Schema) -> Schema:
     def __repr__(self) -> str:
         return f"{self._output_name_to_exprs.values()}"

+    @classmethod
+    def _from_serialized(cls, _output_name_to_exprs: dict[str, Expression]) -> ExpressionsProjection:
+        obj = cls.__new__(cls)
+        obj._output_name_to_exprs = _output_name_to_exprs
+        return obj
+
+    def __reduce__(
+        self,
+    ) -> tuple[Callable[[dict[str, Expression]], ExpressionsProjection], tuple[dict[str, Expression]]]:
+        return ExpressionsProjection._from_serialized, (self._output_name_to_exprs,)
+

 class ExpressionImageNamespace(ExpressionNamespace):
     """Expression operations for image columns. The following methods are available under the `expr.image` attribute."""
diff --git a/src/common/logging/Cargo.toml b/src/common/logging/Cargo.toml
new file mode 100644
index 0000000000..ba98ec4c6a
--- /dev/null
+++ b/src/common/logging/Cargo.toml
@@ -0,0 +1,11 @@
+[dependencies]
+arc-swap = {workspace = true}
+log = {workspace = true}
+
+[lints]
+workspace = true
+
+[package]
+edition = {workspace = true}
+name = "common-logging"
+version = {workspace = true}
diff --git a/src/common/logging/src/lib.rs b/src/common/logging/src/lib.rs
new file mode 100644
index 0000000000..e4c8ba1f7c
--- /dev/null
+++ b/src/common/logging/src/lib.rs
@@ -0,0 +1,85 @@
+use std::sync::{Arc, LazyLock};
+
+use arc_swap::ArcSwap;
+use log::Log;
+
+type BoxedLogger = Box<dyn Log>;
+
+/// A logger that can be internally modified at runtime.
+/// Usually, loggers can only be initialized once, but this container can
+/// swap out the internal logger at runtime atomically.
+pub struct SwappableLogger {
+    base: ArcSwap<BoxedLogger>,
+    temp: ArcSwap<Option<BoxedLogger>>,
+}
+
+impl SwappableLogger {
+    pub fn new(logger: BoxedLogger) -> Self {
+        Self {
+            base: ArcSwap::new(Arc::new(logger)),
+            temp: ArcSwap::new(Arc::new(None)),
+        }
+    }
+
+    pub fn set_base_logger(&self, logger: BoxedLogger) {
+        self.base.store(Arc::new(logger));
+    }
+
+    pub fn get_base_logger(&self) -> Arc<BoxedLogger> {
+        self.base.load().to_owned()
+    }
+
+    pub fn set_temp_logger(&self, logger: BoxedLogger) {
+        self.temp.store(Arc::new(Some(logger)));
+    }
+
+    pub fn reset_temp_logger(&self) {
+        self.temp.store(Arc::new(None));
+    }
+}
+
+impl Log for SwappableLogger {
+    fn enabled(&self, metadata: &log::Metadata) -> bool {
+        if let Some(temp) = self.temp.load().as_ref() {
+            temp.enabled(metadata)
+        } else {
+            self.base.load().enabled(metadata)
+        }
+    }
+
+    fn log(&self, record: &log::Record) {
+        if let Some(temp) = self.temp.load().as_ref() {
+            temp.log(record);
+        } else {
+            self.base.load().log(record);
+        }
+    }
+
+    fn flush(&self) {
+        if let Some(temp) = self.temp.load().as_ref() {
+            temp.flush();
+        } else {
+            self.base.load().flush();
+        }
+    }
+}
+
+/// A Noop logger that does nothing.
+/// Used for initialization purposes only, should never actually be used.
+struct NoopLogger;
+
+impl Log for NoopLogger {
+    fn enabled(&self, _metadata: &log::Metadata) -> bool {
+        false
+    }
+
+    fn log(&self, _record: &log::Record) {}
+
+    fn flush(&self) {}
+}
+
+/// The global logger that can be swapped out at runtime.
+/// This is initialized to a NoopLogger to avoid any logging during initialization.
+/// It can be swapped out with a real logger using `set_base_logger`.
+pub static GLOBAL_LOGGER: LazyLock<Arc<SwappableLogger>> =
+    LazyLock::new(|| Arc::new(SwappableLogger::new(Box::new(NoopLogger))));
diff --git a/src/daft-dsl/src/functions/python/udf.rs b/src/daft-dsl/src/functions/python/udf.rs
index 627a68b51f..d71737cb58 100644
--- a/src/daft-dsl/src/functions/python/udf.rs
+++ b/src/daft-dsl/src/functions/python/udf.rs
@@ -86,7 +86,6 @@ impl LegacyPythonUDF {
             MaybeInitializedUDF::Initialized(func) => func.clone().unwrap().clone_ref(py),
             MaybeInitializedUDF::Uninitialized { inner, init_args } => {
                 // TODO(Kevin): warn user if initialization is taking too long and ask them to use actor pool UDFs
-
                 py_udf_initialize(py, inner.clone().unwrap(), init_args.clone().unwrap())?
             }
         };
diff --git a/src/daft-io/src/stats.rs b/src/daft-io/src/stats.rs
index a4e70cf2ce..81410e7580 100644
--- a/src/daft-io/src/stats.rs
+++ b/src/daft-io/src/stats.rs
@@ -27,7 +27,7 @@ impl Drop for IOStatsContext {
         let num_puts = self.load_put_requests();
         let mean_get_size = (bytes_read as f64) / (num_gets as f64);
         let mean_put_size = (bytes_uploaded as f64) / (num_puts as f64);
-        log::info!(
+        log::debug!(
             "IOStatsContext: {}, Gets: {}, Heads: {}, Lists: {}, BytesRead: {}, AvgGetSize: {}, BytesUploaded: {}, AvgPutSize: {}",
             self.name,
             num_gets,
diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml
index 71df091ee8..729abed19d 100644
--- a/src/daft-local-execution/Cargo.toml
+++ b/src/daft-local-execution/Cargo.toml
@@ -1,10 +1,12 @@
 [dependencies]
+arc-swap = {workspace = true}
 async-trait = {workspace = true}
 capitalize = "*"
 common-daft-config = {path = "../common/daft-config", default-features = false}
 common-display = {path = "../common/display", default-features = false}
 common-error = {path = "../common/error", default-features = false}
 common-file-formats = {path = "../common/file-formats", default-features = false}
+common-logging = {path = "../common/logging", default-features = false}
 common-metrics = {path = "../common/metrics", default-features = false}
 common-py-serde = {path = "../common/py-serde", default-features = false}
 common-resource-request = {path = "../common/resource-request", default-features = false}
@@ -12,6 +14,7 @@
 common-runtime = {path = "../common/runtime", default-features = false}
 common-scan-info = {path = "../common/scan-info", default-features = false}
 common-system-info = {path = "../common/system-info", default-features = false}
 common-tracing = {path = "../common/tracing", default-features = false}
+console = "*"
 daft-core = {path = "../daft-core", default-features = false}
 daft-csv = {path = "../daft-csv", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
diff --git a/src/daft-local-execution/src/intermediate_ops/udf.rs b/src/daft-local-execution/src/intermediate_ops/udf.rs
index c3d47359be..62e4ea3934 100644
--- a/src/daft-local-execution/src/intermediate_ops/udf.rs
+++ b/src/daft-local-execution/src/intermediate_ops/udf.rs
@@ -1,4 +1,12 @@
-use std::{ops::RangeInclusive, sync::Arc, time::Duration, vec};
+use std::{
+    ops::RangeInclusive,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+    time::Duration,
+    vec,
+};

 use common_error::{DaftError, DaftResult};
 use common_resource_request::ResourceRequest;
@@ -31,17 +39,25 @@
 use crate::{ops::NodeType, pipeline::NodeName, ExecutionRuntimeContext, ExecutionTaskSpawner};

 const NUM_TEST_ITERATIONS_RANGE: RangeInclusive<usize> = 10..=20;
 const GIL_CONTRIBUTION_THRESHOLD: f64 = 0.5;

-pub(crate) struct UdfHandle {
+/// Common parameters for UDF handle and operator
+struct UdfParams {
     udf_expr: BoundExpr,
     passthrough_columns: Vec<BoundExpr>,
+    udf_name: String,
     output_schema: SchemaRef,
+}
+
+pub(crate) struct UdfHandle {
+    params: Arc<UdfParams>,
+    worker_idx: usize,
     // Optional PyObject handle to external UDF worker.
     // Required for ActorPoolUDFs
     // Optional for stateless UDFs
     // - Starts as None indicating that the UDF is run in-line with the thread
     // - If excessive GIL contention is detected, the UDF will be moved to an external worker
+    // Second bool indicates if the UDF was initialized
     #[cfg(feature = "python")]
-    inner: Option<PyObject>,
+    handle: Option<PyObject>,
     // Data used to track GIL contention
     check_gil_contention: bool,
     min_num_test_iterations: usize,
@@ -52,18 +68,16 @@ impl UdfHandle {
     fn no_handle(
-        udf_expr: &BoundExpr,
-        passthrough_columns: &[BoundExpr],
-        output_schema: SchemaRef,
+        params: Arc<UdfParams>,
+        worker_idx: usize,
         check_gil_contention: bool,
         min_num_test_iterations: usize,
     ) -> Self {
         Self {
-            udf_expr: udf_expr.clone(),
-            passthrough_columns: passthrough_columns.to_vec(),
-            output_schema,
+            params,
+            worker_idx,
             #[cfg(feature = "python")]
-            inner: None,
+            handle: None,
             check_gil_contention,
             min_num_test_iterations,
             total_runtime: Duration::from_secs(0),
@@ -75,13 +89,15 @@
     fn create_handle(&mut self) -> DaftResult<()> {
         #[cfg(feature = "python")]
         {
-            let py_expr = PyExpr::from(self.udf_expr.as_ref().clone());
+            let py_expr = PyExpr::from(self.params.udf_expr.as_ref().clone());
             let passthrough_exprs = self
+                .params
                 .passthrough_columns
                 .iter()
                 .map(|expr| PyExpr::from(expr.as_ref().clone()))
                 .collect::<Vec<PyExpr>>();
-            self.inner = Some(Python::with_gil(|py| {
+
+            let handle = Python::with_gil(|py| {
                 // create python object
                 Ok::<PyObject, PyErr>(
                     py.import(pyo3::intern!(py, "daft.execution.udf"))?
                         .call1((py_expr, passthrough_exprs))?
                         .unbind(),
                 )
-            })?);
+            })?;
+
+            self.handle = Some(handle);
         }

         #[cfg(not(feature = "python"))]
@@ -104,18 +122,25 @@
     fn eval_input_with_handle(
         &self,
         input: Arc<MicroPartition>,
-        inner: &PyObject,
+        handle: &PyObject,
     ) -> DaftResult<Arc<MicroPartition>> {
-        Python::with_gil(|py| {
-            Ok(inner
+        use crate::STDOUT;
+
+        let (micropartition, outs) = Python::with_gil(|py| {
+            handle
                 .bind(py)
                 .call_method1(
                     pyo3::intern!(py, "eval_input"),
                     (PyMicroPartition::from(input),),
                 )?
-                .extract::<PyMicroPartition>()?
-                .into())
-        })
+                .extract::<(PyMicroPartition, Vec<String>)>()
+        })?;
+
+        let label = format!("[`{}` Worker #{}]", self.params.udf_name, self.worker_idx);
+        for line in outs {
+            STDOUT.print(&label, &line);
+        }
+        Ok(micropartition.into())
     }

     #[cfg(feature = "python")]
@@ -126,7 +151,7 @@
             functions::{python::LegacyPythonUDF, scalar::ScalarFn},
             python_udf::PyScalarFn,
         };
-        let inner_expr = self.udf_expr.inner();
+        let inner_expr = self.params.udf_expr.inner();
         let (inner_expr, out_name) = inner_expr.unwrap_alias();

         enum UdfImpl<'a> {
@@ -153,7 +178,6 @@
         let input_batches = input.get_tables()?;
         let mut output_batches = Vec::with_capacity(input_batches.len());

-        // Iterate over MicroPartition batches
         for batch in input_batches.as_ref() {
             use std::time::Instant;
@@ -165,8 +189,7 @@
                 UdfImpl::Legacy(f) => f.call_udf(func_input.columns())?,
                 UdfImpl::PyScalarFn(f) => f.call(func_input.columns())?,
             };
-            let end_time = Instant::now();
-            let total_runtime = end_time - start_time;
+            let total_runtime = start_time.elapsed();

             // Rename if necessary
             if let Some(out_name) = out_name.as_ref() {
@@ -179,7 +202,7 @@
             self.num_batches += 1;

             let passthrough_input =
-                batch.eval_expression_list(self.passthrough_columns.as_slice())?;
+                batch.eval_expression_list(self.params.passthrough_columns.as_slice())?;
             let series = passthrough_input.append_column(result)?;
             output_batches.push(series);
         }
@@ -194,7 +217,7 @@
         }

         Ok(Arc::new(MicroPartition::new_loaded(
-            self.output_schema.clone(),
+            self.params.output_schema.clone(),
             Arc::new(output_batches),
             None,
         )))
@@ -208,8 +231,8 @@

         #[cfg(feature = "python")]
         {
-            if let Some(inner) = &self.inner {
-                self.eval_input_with_handle(input, inner)
+            if let Some(handle) = &self.handle {
+                self.eval_input_with_handle(input, handle)
             } else {
                 self.eval_input_inline(input)
             }
@@ -219,12 +242,14 @@
     fn teardown(&self) -> DaftResult<()> {
         #[cfg(feature = "python")]
         {
-            let Some(inner) = &self.inner else {
+            let Some(handle) = &self.handle else {
                 return Ok(());
             };

             Python::with_gil(|py| {
-                inner.bind(py).call_method0(pyo3::intern!(py, "teardown"))?;
+                handle
+                    .bind(py)
+                    .call_method0(pyo3::intern!(py, "teardown"))?;
                 Ok(())
             })
         }
@@ -256,9 +281,8 @@ pub(crate) struct UdfState {
 }

 pub(crate) struct UdfOperator {
-    project: BoundExpr,
-    passthrough_columns: Vec<BoundExpr>,
-    output_schema: SchemaRef,
+    params: Arc<UdfParams>,
+    worker_count: AtomicUsize,
     udf_properties: UDFProperties,
     concurrency: usize,
     memory_request: u64,
@@ -272,7 +296,6 @@ impl UdfOperator {
     ) -> DaftResult<Self> {
         let project_unbound = project.inner().clone();

-        // count_udfs counts both actor pool and stateless udfs
         let num_udfs = count_udfs(&[project_unbound.clone()]);
         assert_eq!(num_udfs, 1, "Expected only one udf in an udf project");
         let udf_properties = get_udf_properties(&project_unbound);
@@ -291,9 +314,13 @@ impl UdfOperator {
             .unwrap_or(0);

         Ok(Self {
-            project,
-            passthrough_columns,
-            output_schema: output_schema.clone(),
+            params: Arc::new(UdfParams {
+                udf_expr: project,
+                passthrough_columns,
+                udf_name: udf_properties.name.clone(),
+                output_schema: output_schema.clone(),
+            }),
+            worker_count: AtomicUsize::new(0),
             udf_properties,
             concurrency,
             memory_request,
@@ -315,8 +342,7 @@ impl UdfOperator {
         let requested_num_cpus = resource_request.num_cpus().unwrap();
         if requested_num_cpus > num_cpus as f64 {
             Err(DaftError::ValueError(format!(
-                "{}: Requested {} CPUs but only found {} available",
-                full_name, requested_num_cpus, num_cpus
+                "`{full_name}` requested {requested_num_cpus} CPUs but found only {num_cpus} available"
             )))
         } else {
             Ok((num_cpus as f64 / requested_num_cpus).clamp(1.0, num_cpus as f64) as usize)
@@ -355,11 +381,10 @@ impl IntermediateOperator for UdfOperator {
     }

     fn name(&self) -> NodeName {
-        let full_name = &self.udf_properties.name;
-        let udf_name = if let Some((_, udf_name)) = full_name.rsplit_once('.') {
+        let udf_name = if let Some((_, udf_name)) = self.params.udf_name.rsplit_once('.') {
             udf_name
         } else {
-            full_name
+            &self.params.udf_name
         };

         format!("UDF {}", udf_name).into()
@@ -374,13 +399,13 @@
         res.push("UDF Executor:".to_string());
         res.push(format!(
             "UDF {} = {}",
-            self.udf_properties.name, self.project
+            self.params.udf_name, self.params.udf_expr
         ));
         res.push(format!(
             "Passthrough Columns = [{}]",
-            self.passthrough_columns.iter().join(", ")
+            self.params.passthrough_columns.iter().join(", ")
         ));
-        res.push(format!("Concurrency = {:?}", self.concurrency));
+        res.push(format!("Concurrency = {}", self.concurrency));
         if let Some(resource_request) = &self.udf_properties.resource_request {
             let multiline_display = resource_request.multiline_display();
             res.push(format!(
@@ -394,15 +419,16 @@
     }

     fn make_state(&self) -> DaftResult<Self::State> {
+        let worker_count = self.worker_count.fetch_add(1, Ordering::SeqCst);
         let mut rng = rand::thread_rng();
         let mut udf_handle = UdfHandle::no_handle(
-            &self.project,
-            &self.passthrough_columns,
-            self.output_schema.clone(),
+            self.params.clone(),
+            worker_count,
             matches!(self.udf_properties.use_process, Some(false)),
             rng.gen_range(NUM_TEST_ITERATIONS_RANGE),
         );
+
         if self.udf_properties.is_actor_pool_udf()
             || self.udf_properties.use_process.unwrap_or(false)
         {
diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs
index cc837c1cec..65809ca581 100644
--- a/src/daft-local-execution/src/lib.rs
+++ b/src/daft-local-execution/src/lib.rs
@@ -17,12 +17,14 @@ mod streaming_sink;
 use std::{
     future::Future,
     pin::Pin,
-    sync::Arc,
+    sync::{Arc, LazyLock},
     task::{Context, Poll},
 };

+use arc_swap::ArcSwap;
 use common_error::{DaftError, DaftResult};
 use common_runtime::{RuntimeRef, RuntimeTask};
+use console::style;
 use resource_manager::MemoryManager;
 pub use run::{ExecutionEngineResult, NativeExecutor};
 use runtime_stats::{RuntimeStats, RuntimeStatsManager, TimedFuture};
@@ -245,6 +247,50 @@ impl ExecutionTaskSpawner {
     }
 }

+// ---------------------------- STDOUT / STDERR PIPING ---------------------------- //
+
+/// Target for printing to.
+trait PythonPrintTarget: Send + Sync + 'static {
+    fn println(&self, message: &str);
+}
+
+/// A static entity that redirects Python sys.stdout / sys.stderr to handle Rust side effects.
+/// Tracks internal tags to reduce interweaving of user prints.
+/// Can also register callbacks, for example for suspending the progress bar before prints.
+struct StdoutHandler {
+    target: ArcSwap<Option<Box<dyn PythonPrintTarget>>>,
+}
+
+impl StdoutHandler {
+    pub fn new() -> Self {
+        Self {
+            target: ArcSwap::new(Arc::new(None)),
+        }
+    }
+
+    fn set_target(&self, target: Box<dyn PythonPrintTarget>) {
+        self.target.store(Arc::new(Some(target)));
+    }
+
+    fn reset_target(&self) {
+        self.target.store(Arc::new(None));
+    }
+
+    fn print(&self, prefix: &str, message: &str) {
+        let message = format!("{} {}", style(prefix).magenta(), message);
+
+        if let Some(target) = self.target.load().as_ref() {
+            target.println(&message);
+        } else {
+            println!("{message}");
+        }
+    }
+}
+
+static STDOUT: LazyLock<Arc<StdoutHandler>> = LazyLock::new(|| Arc::new(StdoutHandler::new()));
+
+// -------------------------------------------------------------------------------- //
+
 #[cfg(feature = "python")]
 use pyo3::prelude::*;
diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs
index 695c9c6889..ad36c3b912 100644
--- a/src/daft-local-execution/src/pipeline.rs
+++ b/src/daft-local-execution/src/pipeline.rs
@@ -423,7 +423,6 @@ pub fn physical_plan_to_pipeline(
             passthrough_columns,
             stats_state,
             schema,
-            ..
         }) => {
             let proj_op =
                 UdfOperator::try_new(project.clone(), passthrough_columns.clone(), schema)
diff --git a/src/daft-local-execution/src/runtime_stats/mod.rs b/src/daft-local-execution/src/runtime_stats/mod.rs
index 3027512045..2a23babde7 100644
--- a/src/daft-local-execution/src/runtime_stats/mod.rs
+++ b/src/daft-local-execution/src/runtime_stats/mod.rs
@@ -141,6 +141,16 @@
             loop {
                 tokio::select! {
                     biased;
+                    _ = &mut finish_rx => {
+                        if !active_nodes.is_empty() {
+                            log::debug!(
+                                "RuntimeStatsManager finished with active nodes {{{}}}",
+                                active_nodes.iter().map(|id: &usize| id.to_string()).join(", ")
+                            );
+                        }
+                        break;
+                    }
+
                     Some((node_id, is_initialize)) = node_rx.recv() => {
                         if is_initialize && active_nodes.insert(node_id) {
                             for subscriber in &subscribers {
@@ -163,16 +173,6 @@
                         }
                     }

-                    _ = &mut finish_rx => {
-                        if !active_nodes.is_empty() {
-                            log::warn!(
-                                "RuntimeStatsManager finished with active nodes {{{}}}",
-                                active_nodes.iter().map(|id| id.to_string()).join(", ")
-                            );
-                        }
-                        break;
-                    }
-
                     _ = interval.tick() => {
                         if active_nodes.is_empty() {
                             continue;
diff --git a/src/daft-local-execution/src/runtime_stats/subscribers/progress_bar.rs b/src/daft-local-execution/src/runtime_stats/subscribers/progress_bar.rs
index 774a173e64..83e43b774d 100644
--- a/src/daft-local-execution/src/runtime_stats/subscribers/progress_bar.rs
+++ b/src/daft-local-execution/src/runtime_stats/subscribers/progress_bar.rs
@@ -1,13 +1,16 @@
 use std::{sync::Arc, time::Duration};

 use common_error::DaftResult;
+use common_logging::GLOBAL_LOGGER;
 use common_metrics::StatSnapshotSend;
 use indicatif::{ProgressDrawTarget, ProgressStyle};
 use itertools::Itertools;
+use log::Log;

 use crate::{
     ops::{NodeCategory, NodeInfo},
     runtime_stats::{subscribers::RuntimeStatsSubscriber, RuntimeStats, CPU_US_KEY},
+    PythonPrintTarget, STDOUT,
 };

 /// Convert statistics to a message for progress bars
@@ -39,6 +42,49 @@ impl ProgressBarColor {

 const TICK_INTERVAL: Duration = Duration::from_millis(100);

+struct IndicatifLogger<L: Log> {
+    pbar: indicatif::MultiProgress,
+    inner: L,
+}
+
+impl<L: Log> IndicatifLogger<L> {
+    fn new(pbar: indicatif::MultiProgress, inner: L) -> Self {
+        Self { pbar, inner }
+    }
+}
+
+impl<L: Log> Log for IndicatifLogger<L> {
+    fn enabled(&self, metadata: &log::Metadata) -> bool {
+        self.inner.enabled(metadata)
+    }
+
+    fn log(&self, record: &log::Record) {
+        if self.inner.enabled(record.metadata()) {
+            self.pbar.suspend(|| self.inner.log(record));
+        }
+    }
+
+    fn flush(&self) {
+        self.inner.flush();
+    }
+}
+
+struct IndicatifPrintTarget {
+    pbar: indicatif::MultiProgress,
+}
+
+impl IndicatifPrintTarget {
+    fn new(pbar: indicatif::MultiProgress) -> Self {
+        Self { pbar }
+    }
+}
+
+impl PythonPrintTarget for IndicatifPrintTarget {
+    fn println(&self, message: &str) {
+        self.pbar.println(message).unwrap();
+    }
+}
+
 #[derive(Debug)]
 struct IndicatifProgressBarManager {
     multi_progress: indicatif::MultiProgress,
@@ -49,6 +95,17 @@
 impl IndicatifProgressBarManager {
     fn new(node_stats: &[(Arc<NodeInfo>, Arc<RuntimeStats>)]) -> Self {
         let multi_progress = indicatif::MultiProgress::new();
+
+        if cfg!(feature = "python") {
+            // Register the IndicatifLogger to redirect Rust logs correctly
+            GLOBAL_LOGGER.set_temp_logger(Box::new(IndicatifLogger::new(
+                multi_progress.clone(),
+                GLOBAL_LOGGER.get_base_logger(),
+            )));
+
+            STDOUT.set_target(Box::new(IndicatifPrintTarget::new(multi_progress.clone())));
+        }
+
         multi_progress.set_move_cursor(true);
         multi_progress.set_draw_target(ProgressDrawTarget::stderr_with_hz(10));
@@ -99,6 +156,15 @@
     }
 }

+impl Drop for IndicatifProgressBarManager {
+    fn drop(&mut self) {
+        if cfg!(feature = "python") {
+            GLOBAL_LOGGER.reset_temp_logger();
+            STDOUT.reset_target();
+        }
+    }
+}
+
 impl RuntimeStatsSubscriber for IndicatifProgressBarManager {
     #[cfg(test)]
     #[allow(dead_code)]
@@ -126,7 +192,8 @@
         Ok(())
     }

-    fn finish(self: Box<Self>) -> DaftResult<()> {
+    fn finish(mut self: Box<Self>) -> DaftResult<()> {
+        self.pbars.clear();
         self.multi_progress.clear()?;
         Ok(())
     }
diff --git a/src/daft-physical-plan/src/physical_planner/planner.rs b/src/daft-physical-plan/src/physical_planner/planner.rs
index b3e530ff8d..8c8c011bbb 100644
--- a/src/daft-physical-plan/src/physical_planner/planner.rs
+++ b/src/daft-physical-plan/src/physical_planner/planner.rs
@@ -545,11 +545,11 @@ impl AdaptivePlanner {
         self.stage_cache.insert_stage(&next_stage)?;
         match &next_stage {
             QueryStageOutput::Final { physical_plan } => {
-                log::info!("Emitting final plan:\n {}", physical_plan.repr_ascii(true));
+                log::debug!("Emitting final plan:\n {}", physical_plan.repr_ascii(true));
                 self.status = AdaptivePlannerStatus::Done;
             }
             QueryStageOutput::Partial { physical_plan, .. } => {
-                log::info!(
+                log::debug!(
                     "Emitting partial plan:\n {}",
                     physical_plan.repr_ascii(true)
                 );
diff --git a/src/lib.rs b/src/lib.rs
index bdf528c36f..e2d3f29911 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -49,10 +49,18 @@ fn should_enable_chrome_trace() -> bool {
 pub mod pylib {
     use std::sync::LazyLock;

+    use common_logging::GLOBAL_LOGGER;
     use common_tracing::{init_opentelemetry_providers, init_tracing};
     use pyo3::prelude::*;

-    static LOG_RESET_HANDLE: LazyLock<pyo3_log::ResetHandle> = LazyLock::new(pyo3_log::init);
+    static LOG_RESET_HANDLE: LazyLock<pyo3_log::ResetHandle> = LazyLock::new(|| {
+        let py_logger = Box::new(pyo3_log::Logger::default());
+        let handle = py_logger.reset_handle();
+
+        GLOBAL_LOGGER.set_base_logger(py_logger);
+        log::set_boxed_logger(Box::new(GLOBAL_LOGGER.clone())).unwrap();
+        handle
+    });

     #[pyfunction]
     pub fn version() -> &'static str {
diff --git a/tests/expressions/test_legacy_udf.py b/tests/expressions/test_legacy_udf.py
index 2fafdeba79..7b04ca4e62 100644
--- a/tests/expressions/test_legacy_udf.py
+++ b/tests/expressions/test_legacy_udf.py
@@ -40,7 +40,7 @@ def repeat_n(data, n):
 def test_class_udf(batch_size, use_actor_pool):
     df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

-    @udf(return_dtype=DataType.string(), batch_size=batch_size)
+    @udf(return_dtype=DataType.string(), batch_size=batch_size, use_process=False)
     class RepeatN:
         def __init__(self):
             self.n = 2
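For context on how the pieces above fit together: `GLOBAL_LOGGER` is installed once as the process-wide `log` backend (see the `LOG_RESET_HANDLE` change in `src/lib.rs`), and the progress bar temporarily reroutes records through it while it owns the terminal. A minimal sketch of that lifecycle, assuming the `common-logging` API added in this diff; `StderrLogger` is a hypothetical stand-in for the `pyo3_log` base logger:

```rust
use common_logging::GLOBAL_LOGGER;
use log::{Log, Metadata, Record};

// Hypothetical base logger standing in for pyo3_log::Logger.
struct StderrLogger;

impl Log for StderrLogger {
    fn enabled(&self, _metadata: &Metadata) -> bool {
        true
    }

    fn log(&self, record: &Record) {
        eprintln!("[{}] {}", record.level(), record.args());
    }

    fn flush(&self) {}
}

fn main() {
    // Install once, mirroring LOG_RESET_HANDLE in src/lib.rs.
    GLOBAL_LOGGER.set_base_logger(Box::new(StderrLogger));
    log::set_boxed_logger(Box::new(GLOBAL_LOGGER.clone())).unwrap();
    log::set_max_level(log::LevelFilter::Info);

    log::info!("routed through the base logger");

    // While a progress bar is live, IndicatifProgressBarManager swaps in a
    // temp logger that suspends the bar around each record...
    GLOBAL_LOGGER.set_temp_logger(Box::new(StderrLogger));
    log::info!("routed through the temp logger");

    // ...and its Drop impl restores the base logger afterwards.
    GLOBAL_LOGGER.reset_temp_logger();
}
```

The same swap-in/swap-out shape is used for the `STDOUT` handler in `daft-local-execution`, so UDF worker prints go through `MultiProgress::println` instead of tearing the bars.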