Skip to content

Commit

Permalink
[BUG] Use python logging level (#2705)
Browse files Browse the repository at this point in the history
Example script:
```
import daft
print("Only warnings and errors should print")
daft.daft.test_logging()

print("\nSetting logging level to debug, all messages should print")
from daft.logging import setup_debug_logger
setup_debug_logger()
daft.daft.test_logging()
```

Output:
```
Only warnings and errors should print
WARN from rust
ERROR from rust

Setting logging level to debug, all messages should print
DEBUG:daft.pylib:DEBUG from rust
INFO:daft.pylib:INFO from rust
WARNING:daft.pylib:WARN from rust
ERROR:daft.pylib:ERROR from rust
```

---------

Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
3 people authored Aug 24, 2024
1 parent 3647b26 commit bf5c853
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 2 deletions.
1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1794,6 +1794,7 @@ class PyDaftPlanningConfig:
def build_type() -> str: ...
def version() -> str: ...
def refresh_logger() -> None: ...
def get_max_log_level() -> str: ...
def __getattr__(name) -> Any: ...
def io_glob(
path: str,
Expand Down
5 changes: 5 additions & 0 deletions daft/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import logging

from daft import refresh_logger


def setup_debug_logger(
exclude_prefix: list[str] = [],
daft_only: bool = True,
):
logging.basicConfig(level="DEBUG")
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)

if daft_only:
for handler in root_logger.handlers:
Expand All @@ -18,3 +21,5 @@ def setup_debug_logger(
for prefix in exclude_prefix:
for handler in root_logger.handlers:
handler.addFilter(lambda record: not record.name.startswith(prefix))

refresh_logger()
31 changes: 29 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub mod pylib {
lazy_static! {
static ref LOG_RESET_HANDLE: pyo3_log::ResetHandle = pyo3_log::init();
}

#[pyfunction]
pub fn version() -> &'static str {
daft_core::VERSION
Expand All @@ -63,13 +64,38 @@ pub mod pylib {
}

#[pyfunction]
pub fn refresh_logger() {
pub fn get_max_log_level() -> &'static str {
log::max_level().as_str()
}

#[pyfunction]
pub fn refresh_logger(py: Python) -> PyResult<()> {
use log::LevelFilter;
let logging = py.import("logging")?;
let python_log_level = logging
.getattr("getLogger")?
.call0()?
.getattr("level")?
.extract::<usize>()
.unwrap_or(0);

// https://docs.python.org/3/library/logging.html#logging-levels
let level_filter = match python_log_level {
0 => LevelFilter::Off,
1..=10 => LevelFilter::Debug,
11..=20 => LevelFilter::Info,
21..=30 => LevelFilter::Warn,
_ => LevelFilter::Error,
};

LOG_RESET_HANDLE.reset();
log::set_max_level(level_filter);
Ok(())
}

#[pymodule]
fn daft(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
refresh_logger();
refresh_logger(_py)?;
init_tracing(crate::should_enable_chrome_trace());

common_daft_config::register_modules(_py, m)?;
Expand All @@ -93,6 +119,7 @@ pub mod pylib {
m.add_wrapped(wrap_pyfunction!(version))?;
m.add_wrapped(wrap_pyfunction!(build_type))?;
m.add_wrapped(wrap_pyfunction!(refresh_logger))?;
m.add_wrapped(wrap_pyfunction!(get_max_log_level))?;
Ok(())
}
}
41 changes: 41 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging

import pytest


def test_logger_initialization():
import daft

rust_level = daft.daft.get_max_log_level()

assert rust_level == "WARN"


def test_debug_logger():
import daft
from daft.logging import setup_debug_logger

setup_debug_logger()
rust_level = daft.daft.get_max_log_level()
assert rust_level == "DEBUG"


@pytest.mark.parametrize(
"level, expected",
[
(logging.DEBUG, "DEBUG"),
(logging.INFO, "INFO"),
(logging.WARNING, "WARN"),
(logging.ERROR, "ERROR"),
],
)
def test_refresh_logger(level, expected):
import logging

import daft

logging.getLogger().setLevel(level)
daft.daft.refresh_logger()

rust_level = daft.daft.get_max_log_level()
assert rust_level == expected

0 comments on commit bf5c853

Please sign in to comment.