diff --git a/unifiedbooster/gbdt.py b/unifiedbooster/gbdt.py index 68f1033..97a9c4e 100644 --- a/unifiedbooster/gbdt.py +++ b/unifiedbooster/gbdt.py @@ -128,7 +128,7 @@ def fit(self, X, y, **kwargs): self.classes_ ) # for compatibility with sklearn if getattr(self, "model_type") == "gradientboosting": - setattr(self, "model").max_features *= X.shape[1] + self.model.max_features = int(self.model.max_features*X.shape[1]) return getattr(self, "model").fit(X, y, **kwargs) def predict(self, X):