diff --git a/docs/FAQ.rst b/docs/FAQ.rst
index 14c7f7dd7265..aaaf4094ce3c 100644
--- a/docs/FAQ.rst
+++ b/docs/FAQ.rst
@@ -377,3 +377,42 @@ We strongly recommend installation from the ``conda-forge`` channel and not from
 For some specific examples, see `this comment <https://github.com/microsoft/LightGBM/issues/4948#issuecomment-1013766397>`__.
 
 In addition, as of ``lightgbm==4.4.0``, the ``conda-forge`` package automatically supports CUDA-based GPU acceleration.
+
+5. How do I subclass ``scikit-learn`` estimators?
+-------------------------------------------------
+
+For ``lightgbm <= 4.5.0``, copy all of the constructor arguments from the corresponding
+``lightgbm`` class into the constructor of your custom estimator.
+
+For later versions, just ensure that the constructor of your custom estimator calls ``super().__init__()``.
+
+Consider the example below, which implements a regressor that allows creation of truncated predictions.
+This pattern will work with ``lightgbm > 4.5.0``.
+
+.. code-block:: python
+
+    import numpy as np
+    from lightgbm import LGBMRegressor
+    from sklearn.datasets import make_regression
+
+    class TruncatedRegressor(LGBMRegressor):
+
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+
+        def predict(self, X, max_score: float = np.inf):
+            preds = super().predict(X)
+            np.clip(preds, a_min=None, a_max=max_score, out=preds)
+            return preds
+
+    X, y = make_regression(n_samples=1_000, n_features=4)
+
+    reg_trunc = TruncatedRegressor().fit(X, y)
+
+    preds = reg_trunc.predict(X)
+    print(f"mean: {preds.mean():.2f}, max: {preds.max():.2f}")
+    # mean: -6.81, max: 345.10
+
+    preds_trunc = reg_trunc.predict(X, max_score=preds.mean())
+    print(f"mean: {preds_trunc.mean():.2f}, max: {preds_trunc.max():.2f}")
+    # mean: -56.50, max: -6.81
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 8b939a8cb49d..12e778f37075 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -1115,6 +1115,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
 
     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,
@@ -1318,6 +1319,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
 
     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,
@@ -1485,6 +1487,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
 
     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 108ef1e14498..ab0686e216fa 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -488,6 +488,7 @@ class LGBMModel(_LGBMModelBase):
 
     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,
@@ -745,7 +746,35 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
         params : dict
             Parameter names mapped to their values.
         """
+        # Based on: https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941
+        # which was based on: https://stackoverflow.com/questions/59248211
+        #
+        # `get_params()` flows like this:
+        #
+        # 0. Get parameters in subclass (self.__class__) first, by using inspect.
+        # 1. Get parameters in all parent classes (especially `LGBMModel`).
+        # 2. Get whatever was passed via `**kwargs`.
+        # 3. Merge them.
+        #
+        # This needs to accommodate being called recursively in the following
+        # inheritance graphs (and similar for classification and ranking):
+        #
+        #   DaskLGBMRegressor -> LGBMRegressor     -> LGBMModel -> BaseEstimator
+        #   (custom subclass) -> LGBMRegressor     -> LGBMModel -> BaseEstimator
+        #                        LGBMRegressor     -> LGBMModel -> BaseEstimator
+        #                        (custom subclass) -> LGBMModel -> BaseEstimator
+        #                                             LGBMModel -> BaseEstimator
+        #
         params = super().get_params(deep=deep)
+        cp = copy.copy(self)
+        # If the immediate parent defines get_params(), use that.
+        if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
+            cp.__class__ = cp.__class__.__bases__[0]
+        # Otherwise, skip it and assume the next class will have it.
+        # This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
+        else:
+            cp.__class__ = cp.__class__.__bases__[1]
+        params.update(cp.__class__.get_params(cp, deep))
         params.update(self._other_params)
         return params
 
@@ -1285,6 +1314,57 @@ def feature_names_in_(self) -> None:
 class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
     """LightGBM regressor."""
 
+    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
+    #       docs, help(), and tab completion.
+    def __init__(
+        self,
+        *,
+        boosting_type: str = "gbdt",
+        num_leaves: int = 31,
+        max_depth: int = -1,
+        learning_rate: float = 0.1,
+        n_estimators: int = 100,
+        subsample_for_bin: int = 200000,
+        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
+        class_weight: Optional[Union[Dict, str]] = None,
+        min_split_gain: float = 0.0,
+        min_child_weight: float = 1e-3,
+        min_child_samples: int = 20,
+        subsample: float = 1.0,
+        subsample_freq: int = 0,
+        colsample_bytree: float = 1.0,
+        reg_alpha: float = 0.0,
+        reg_lambda: float = 0.0,
+        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
+        n_jobs: Optional[int] = None,
+        importance_type: str = "split",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            importance_type=importance_type,
+            **kwargs,
+        )
+
+    __init__.__doc__ = LGBMModel.__init__.__doc__
+
     def _more_tags(self) -> Dict[str, Any]:
         # handle the case where RegressorMixin possibly provides _more_tags()
         if callable(getattr(_LGBMRegressorBase, "_more_tags", None)):
@@ -1344,6 +1424,57 @@ def fit(  # type: ignore[override]
 class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
     """LightGBM classifier."""
 
+    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
+    #       docs, help(), and tab completion.
+    def __init__(
+        self,
+        *,
+        boosting_type: str = "gbdt",
+        num_leaves: int = 31,
+        max_depth: int = -1,
+        learning_rate: float = 0.1,
+        n_estimators: int = 100,
+        subsample_for_bin: int = 200000,
+        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
+        class_weight: Optional[Union[Dict, str]] = None,
+        min_split_gain: float = 0.0,
+        min_child_weight: float = 1e-3,
+        min_child_samples: int = 20,
+        subsample: float = 1.0,
+        subsample_freq: int = 0,
+        colsample_bytree: float = 1.0,
+        reg_alpha: float = 0.0,
+        reg_lambda: float = 0.0,
+        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
+        n_jobs: Optional[int] = None,
+        importance_type: str = "split",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            importance_type=importance_type,
+            **kwargs,
+        )
+
+    __init__.__doc__ = LGBMModel.__init__.__doc__
+
     def _more_tags(self) -> Dict[str, Any]:
         # handle the case where ClassifierMixin possibly provides _more_tags()
         if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
@@ -1554,6 +1685,57 @@ class LGBMRanker(LGBMModel):
         Please use this class mainly for training and applying ranking models in common sklearnish way.
     """
 
+    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
+    #       docs, help(), and tab completion.
+    def __init__(
+        self,
+        *,
+        boosting_type: str = "gbdt",
+        num_leaves: int = 31,
+        max_depth: int = -1,
+        learning_rate: float = 0.1,
+        n_estimators: int = 100,
+        subsample_for_bin: int = 200000,
+        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
+        class_weight: Optional[Union[Dict, str]] = None,
+        min_split_gain: float = 0.0,
+        min_child_weight: float = 1e-3,
+        min_child_samples: int = 20,
+        subsample: float = 1.0,
+        subsample_freq: int = 0,
+        colsample_bytree: float = 1.0,
+        reg_alpha: float = 0.0,
+        reg_lambda: float = 0.0,
+        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
+        n_jobs: Optional[int] = None,
+        importance_type: str = "split",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            importance_type=importance_type,
+            **kwargs,
+        )
+
+    __init__.__doc__ = LGBMModel.__init__.__doc__
+
     def fit(  # type: ignore[override]
         self,
         X: _LGBM_ScikitMatrixLike,
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index b5e17991f63d..ad13734187d8 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1373,26 +1373,42 @@ def test_machines_should_be_used_if_provided(task, cluster):
 
 
 @pytest.mark.parametrize(
-    "classes",
+    "dask_est,sklearn_est",
     [
         (lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
         (lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
         (lgb.DaskLGBMRanker, lgb.LGBMRanker),
     ],
 )
-def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes):
-    dask_spec = inspect.getfullargspec(classes[0])
-    sklearn_spec = inspect.getfullargspec(classes[1])
+def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(dask_est, sklearn_est):
+    dask_spec = inspect.getfullargspec(dask_est)
+    sklearn_spec = inspect.getfullargspec(sklearn_est)
+
+    # should not allow for any varargs
     assert dask_spec.varargs == sklearn_spec.varargs
+    assert dask_spec.varargs is None
+
+    # the only varkw should be **kwargs,
+    # for pass-through to parent classes' __init__()
     assert dask_spec.varkw == sklearn_spec.varkw
-    assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
-    assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
+    assert dask_spec.varkw == "kwargs"
 
     # "client" should be the only different, and the final argument
-    assert dask_spec.args[:-1] == sklearn_spec.args
-    assert dask_spec.defaults[:-1] == sklearn_spec.defaults
-    assert dask_spec.args[-1] == "client"
-    assert dask_spec.defaults[-1] is None
+    assert dask_spec.kwonlyargs == [*sklearn_spec.kwonlyargs, "client"]
+
+    # default values for all constructor arguments should be identical
+    #
+    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
+    #       any of LGBMModel's constructor arguments, this will need to be updated
+    assert dask_spec.kwonlydefaults == {**sklearn_spec.kwonlydefaults, "client": None}
+
+    # only positional argument should be 'self'
+    assert dask_spec.args == sklearn_spec.args
+    assert dask_spec.args == ["self"]
+    assert dask_spec.defaults is None
+
+    # get_params() should be identical, except for "client"
+    assert dask_est().get_params() == {**sklearn_est().get_params(), "client": None}
 
 
 @pytest.mark.parametrize(
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index da6c94d41183..e26e14c24ec6 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -1,4 +1,5 @@
 # coding: utf-8
+import inspect
 import itertools
 import math
 import re
@@ -22,6 +23,7 @@
 
 import lightgbm as lgb
 from lightgbm.compat import (
+    DASK_INSTALLED,
     DATATABLE_INSTALLED,
     PANDAS_INSTALLED,
     _sklearn_version,
@@ -83,6 +85,30 @@ def __call__(self, env):
         env.model.attr_set_inside_callback = env.iteration * 10
 
 
+class ExtendedLGBMClassifier(lgb.LGBMClassifier):
+    """Class for testing that inheriting from LGBMClassifier works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-classifier", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
+class ExtendedLGBMRanker(lgb.LGBMRanker):
+    """Class for testing that inheriting from LGBMRanker works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-ranker", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
+class ExtendedLGBMRegressor(lgb.LGBMRegressor):
+    """Class for testing that inheriting from LGBMRegressor works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-regressor", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
 def custom_asymmetric_obj(y_true, y_pred):
     residual = (y_true - y_pred).astype(np.float64)
     grad = np.where(residual < 0, -2 * 10.0 * residual, -2 * residual)
@@ -475,6 +501,193 @@ def test_clone_and_property():
     assert isinstance(clf.feature_importances_, np.ndarray)
 
 
+@pytest.mark.parametrize("estimator", (lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker))
+def test_estimators_all_have_the_same_kwargs_and_defaults(estimator):
+    base_spec = inspect.getfullargspec(lgb.LGBMModel)
+    subclass_spec = inspect.getfullargspec(estimator)
+
+    # should not allow for any varargs
+    assert subclass_spec.varargs == base_spec.varargs
+    assert subclass_spec.varargs is None
+
+    # the only varkw should be **kwargs,
+    assert subclass_spec.varkw == base_spec.varkw
+    assert subclass_spec.varkw == "kwargs"
+
+    # default values for all constructor arguments should be identical
+    #
+    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
+    #       any of LGBMModel's constructor arguments, this will need to be updated
+    assert subclass_spec.kwonlydefaults == base_spec.kwonlydefaults
+
+    # only positional argument should be 'self'
+    assert subclass_spec.args == base_spec.args
+    assert subclass_spec.args == ["self"]
+    assert subclass_spec.defaults is None
+
+    # get_params() should be identical
+    assert estimator().get_params() == lgb.LGBMModel().get_params()
+
+
+def test_subclassing_get_params_works():
+    expected_params = {
+        "boosting_type": "gbdt",
+        "class_weight": None,
+        "colsample_bytree": 1.0,
+        "importance_type": "split",
+        "learning_rate": 0.1,
+        "max_depth": -1,
+        "min_child_samples": 20,
+        "min_child_weight": 0.001,
+        "min_split_gain": 0.0,
+        "n_estimators": 100,
+        "n_jobs": None,
+        "num_leaves": 31,
+        "objective": None,
+        "random_state": None,
+        "reg_alpha": 0.0,
+        "reg_lambda": 0.0,
+        "subsample": 1.0,
+        "subsample_for_bin": 200000,
+        "subsample_freq": 0,
+    }
+
+    # Overrides, used to test that passing through **kwargs works as expected.
+    #
+    # why these?
+    #
+    #  - 'n_estimators' directly matches a keyword arg for the scikit-learn estimators
+    #  - 'eta' is a parameter alias for 'learning_rate'
+    overrides = {"n_estimators": 13, "eta": 0.07}
+
+    # lightgbm-official classes
+    for est in [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRanker, lgb.LGBMRegressor]:
+        assert est().get_params() == expected_params
+        assert est(**overrides).get_params() == {
+            **expected_params,
+            "eta": 0.07,
+            "n_estimators": 13,
+            "learning_rate": 0.1,
+        }
+
+    if DASK_INSTALLED:
+        for est in [lgb.DaskLGBMClassifier, lgb.DaskLGBMRanker, lgb.DaskLGBMRegressor]:
+            assert est().get_params() == {
+                **expected_params,
+                "client": None,
+            }
+            assert est(**overrides).get_params() == {
+                **expected_params,
+                "eta": 0.07,
+                "n_estimators": 13,
+                "learning_rate": 0.1,
+                "client": None,
+            }
+
+    # custom sub-classes
+    assert ExtendedLGBMClassifier().get_params() == {**expected_params, "some_other_param": "lgbm-classifier"}
+    assert ExtendedLGBMClassifier(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-classifier",
+    }
+    assert ExtendedLGBMRanker().get_params() == {
+        **expected_params,
+        "some_other_param": "lgbm-ranker",
+    }
+    assert ExtendedLGBMRanker(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-ranker",
+    }
+    assert ExtendedLGBMRegressor().get_params() == {
+        **expected_params,
+        "some_other_param": "lgbm-regressor",
+    }
+    assert ExtendedLGBMRegressor(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-regressor",
+    }
+
+
+@pytest.mark.parametrize("task", all_tasks)
+def test_subclassing_works(task):
+    # param values to make training deterministic and
+    # just train a small, cheap model
+    params = {
+        "deterministic": True,
+        "force_row_wise": True,
+        "n_jobs": 1,
+        "n_estimators": 5,
+        "num_leaves": 11,
+        "random_state": 708,
+    }
+
+    X, y, g = _create_data(task=task)
+    if task == "ranking":
+        est = lgb.LGBMRanker(**params).fit(X, y, group=g)
+        est_sub = ExtendedLGBMRanker(**params).fit(X, y, group=g)
+    elif task.endswith("classification"):
+        est = lgb.LGBMClassifier(**params).fit(X, y)
+        est_sub = ExtendedLGBMClassifier(**params).fit(X, y)
+    else:
+        est = lgb.LGBMRegressor(**params).fit(X, y)
+        est_sub = ExtendedLGBMRegressor(**params).fit(X, y)
+
+    np.testing.assert_allclose(est.predict(X), est_sub.predict(X))
+
+
+@pytest.mark.parametrize(
+    "estimator_to_task",
+    [
+        (lgb.LGBMClassifier, "binary-classification"),
+        (ExtendedLGBMClassifier, "binary-classification"),
+        (lgb.LGBMRanker, "ranking"),
+        (ExtendedLGBMRanker, "ranking"),
+        (lgb.LGBMRegressor, "regression"),
+        (ExtendedLGBMRegressor, "regression"),
+    ],
+)
+def test_parameter_aliases_are_handled_correctly(estimator_to_task):
+    estimator, task = estimator_to_task
+    # scikit-learn estimators should remember every parameter passed
+    # via keyword arguments in the estimator constructor, but then
+    # only pass the correct value down to LightGBM's C++ side
+    params = {
+        "eta": 0.08,
+        "num_iterations": 3,
+        "num_leaves": 5,
+    }
+    X, y, g = _create_data(task=task)
+    mod = estimator(**params)
+    if task == "ranking":
+        mod.fit(X, y, group=g)
+    else:
+        mod.fit(X, y)
+
+    # scikit-learn get_params()
+    p = mod.get_params()
+    assert p["eta"] == 0.08
+    assert p["learning_rate"] == 0.1
+
+    # lgb.Booster's 'params' attribute
+    p = mod.booster_.params
+    assert p["eta"] == 0.08
+    assert p["learning_rate"] == 0.1
+
+    # Config in the 'LightGBM::Booster' on the C++ side
+    p = mod.booster_._get_loaded_param()
+    assert p["learning_rate"] == 0.1
+    assert "eta" not in p
+
+
 def test_joblib(tmp_path):
     X, y = make_synthetic_regression()
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -1463,7 +1676,10 @@ def _get_expected_failed_tests(estimator):
     return estimator._more_tags()["_xfail_checks"]
 
 
-@parametrize_with_checks([lgb.LGBMClassifier(), lgb.LGBMRegressor()], expected_failed_checks=_get_expected_failed_tests)
+@parametrize_with_checks(
+    [ExtendedLGBMClassifier(), ExtendedLGBMRegressor(), lgb.LGBMClassifier(), lgb.LGBMRegressor()],
+    expected_failed_checks=_get_expected_failed_tests,
+)
 def test_sklearn_integration(estimator, check):
     estimator.set_params(min_child_samples=1, min_data_in_bin=1)
     check(estimator)