Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add tests and pytest github action #28

Merged
merged 27 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pypi_release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ jobs:
uses: JRubics/[email protected]
with:
pypi_token: ${{ secrets.PYPI_TOKEN }}
poetry_install_options: "--without dev"
33 changes: 33 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Tests

on:
pull_request:
types: [opened, reopened, ready_for_review, review_requested]

jobs:
run-pytest:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Set up python
uses: actions/setup-python@v5
with:
python-version: '3.12'

- name: Install Poetry
uses: snok/install-poetry@v1

- name: Install dependencies
run: poetry install --no-interaction --with dev

- name: Run tests
run: poetry run pytest --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=tests/ > pytest-coverage.txt

- name: Pytest coverage comment
uses: MishaKav/pytest-coverage-comment@main
with:
pytest-coverage-path: ./pytest-coverage.txt
junitxml-path: ./pytest.xml
13 changes: 1 addition & 12 deletions mqboost/constraints.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import pandas as pd

from mqboost.base import (
FUNC_TYPE,
ModelName,
MQStr,
ParamsLike,
TypeName,
ValidationException,
)
from mqboost.base import FUNC_TYPE, ModelName, MQStr, ParamsLike, TypeName


def set_monotone_constraints(
Expand All @@ -27,10 +20,6 @@ def set_monotone_constraints(
ParamsLike
"""
constraints_fucs = FUNC_TYPE.get(model_name).get(TypeName.constraints_type)
if MQStr.obj.value in params:
raise ValidationException(
"The parameter named 'objective' must be excluded in params"
)
_params = params.copy()
if MQStr.mono.value in _params:
_monotone_constraints = list(_params[MQStr.mono.value])
Expand Down
5 changes: 0 additions & 5 deletions mqboost/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,6 @@ def predict_dtype(self) -> Callable:
"""Get the data type function for prediction data."""
return self._predict_dtype

@property
def model(self) -> ModelName:
"""Get the model type."""
return self._model

@property
def columns(self) -> pd.Index:
"""Get the column names of the input features."""
Expand Down
21 changes: 8 additions & 13 deletions mqboost/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from mqboost.base import DtrainLike, ModelName, ObjectiveName
from mqboost.utils import delta_validate
from mqboost.utils import delta_validate, epsilon_validate

CHECK_LOSS: str = "check_loss"

Expand All @@ -26,19 +26,13 @@ def _hess_rho(error: np.ndarray, alpha: float) -> np.ndarray:


# Huber loss
def _error_delta_compare(
error: np.ndarray, delta: float
) -> tuple[np.ndarray, np.ndarray]:
"""Rerutn two boolean arrays indicating where the errors are smaller or larger than delta."""
_abs_error = np.abs(error)
return (_abs_error <= delta).astype(int), (_abs_error > delta).astype(int)


def _grad_huber(error: np.ndarray, alpha: float, delta: float) -> np.ndarray:
"""Compute the gradient of the huber loss function."""
_smaller_delta, _bigger_delta = _error_delta_compare(error=error, delta=delta)
_grad = _grad_rho(error=error, alpha=alpha)
_abs_error = np.abs(error)
_smaller_delta = (_abs_error <= delta).astype(int)
_bigger_delta = (_abs_error > delta).astype(int)
_r = _rho(error=error, alpha=alpha)
_grad = _grad_rho(error=error, alpha=alpha)
return _r * _smaller_delta + _grad * _bigger_delta


Expand Down Expand Up @@ -174,9 +168,10 @@ def __init__(
if objective == ObjectiveName.check:
self._fobj = partial(check_loss_grad_hess, alphas=alphas)
elif objective == ObjectiveName.huber:
_delta = delta_validate(delta=delta)
self._fobj = partial(huber_loss_grad_hess, alphas=alphas, delta=_delta)
delta_validate(delta=delta)
self._fobj = partial(huber_loss_grad_hess, alphas=alphas, delta=delta)
elif objective == ObjectiveName.approx:
epsilon_validate(epsilon=epsilon)
self._fobj = partial(approx_loss_grad_hess, alphas=alphas, epsilon=epsilon)

if model == ModelName.lightgbm:
Expand Down
8 changes: 6 additions & 2 deletions mqboost/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mqboost.constraints import set_monotone_constraints
from mqboost.dataset import MQDataset
from mqboost.objective import MQObjective
from mqboost.utils import delta_validate
from mqboost.utils import delta_validate, epsilon_validate, params_validate

__all__ = ["MQOptimizer"]

Expand Down Expand Up @@ -93,9 +93,12 @@ def __init__(
epsilon: float = 1e-5,
) -> None:
"""Initialize the MQOptimizer."""
delta_validate(delta=delta)
epsilon_validate(epsilon=epsilon)

self._model = ModelName.get(model)
self._objective = ObjectiveName.get(objective)
self._delta = delta_validate(delta)
self._delta = delta
self._epsilon = epsilon
self._get_params = _GET_PARAMS_FUNC.get(self._model)

Expand Down Expand Up @@ -179,6 +182,7 @@ def __optuna_objective(
) -> float:
"""Objective function for Optuna to minimize."""
params = get_params_func(trial=trial)
params_validate(params=params)
params = set_monotone_constraints(
params=params,
columns=self._dataset.columns,
Expand Down
2 changes: 2 additions & 0 deletions mqboost/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mqboost.constraints import set_monotone_constraints
from mqboost.dataset import MQDataset
from mqboost.objective import MQObjective
from mqboost.utils import params_validate

__all__ = ["MQRegressor"]

Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(
epsilon: float = 1e-5,
) -> None:
"""Initialize the MQRegressor."""
params_validate(params=params)
self._params = params
self._model = ModelName.get(model)
self._objective = ObjectiveName.get(objective)
Expand Down
29 changes: 25 additions & 4 deletions mqboost/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import numpy as np
import pandas as pd

from mqboost.base import AlphaLike, ValidationException, XdataLike, YdataLike
from mqboost.base import (
AlphaLike,
MQStr,
ParamsLike,
ValidationException,
XdataLike,
YdataLike,
)


def alpha_validate(
Expand Down Expand Up @@ -49,7 +56,7 @@ def prepare_x(
_alpha_repeat_count_list = [list(repeat(alpha, len(x))) for alpha in alphas]
_alpha_repeat_list = list(chain.from_iterable(_alpha_repeat_count_list))

_repeated_x = pd.concat([x] * len(alphas), axis=0)
_repeated_x = pd.concat([x] * len(alphas), axis=0).reset_index(drop=True)
_repeated_x = _repeated_x.assign(
_tau=_alpha_repeat_list,
)
Expand All @@ -64,7 +71,7 @@ def prepare_y(
return np.concatenate(list(repeat(y, len(alphas))))


def delta_validate(delta: float) -> float:
def delta_validate(delta: float) -> None:
"""Validates the delta parameter ensuring it is a positive float and less than or equal to 0.05."""
_delta_upper_bound: float = 0.05

Expand All @@ -77,4 +84,18 @@ def delta_validate(delta: float) -> float:
if delta > _delta_upper_bound:
warnings.warn("Delta should be 0.05 or less.")

return delta

def epsilon_validate(epsilon: float) -> None:
if not isinstance(epsilon, float):
raise ValidationException("Epsilon is not float type")

if epsilon <= 0:
raise ValidationException("Epsilon must be positive")


def params_validate(params: ParamsLike) -> None:
"""Validates the model parameter ensuring its key dosen't contain 'objective'."""
if MQStr.obj.value in params:
raise ValidationException(
"The parameter named 'objective' must be excluded in params"
)
Loading
Loading