Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 150 additions & 7 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,13 +874,6 @@ class TreeClassifierParams(object):
def __init__(self):
super(TreeClassifierParams, self).__init__()

@since("1.6.0")
def setImpurity(self, value):
"""
Sets the value of :py:attr:`impurity`.
"""
return self._set(impurity=value)

@since("1.6.0")
def getImpurity(self):
"""
Expand Down Expand Up @@ -1003,6 +996,49 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return DecisionTreeClassificationModel(java_model)

def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
"""
return self._set(maxDepth=value)

def setMaxBins(self, value):
"""
Sets the value of :py:attr:`maxBins`.
"""
return self._set(maxBins=value)

def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
"""
return self._set(minInstancesPerNode=value)

def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
"""
return self._set(minInfoGain=value)

def setMaxMemoryInMB(self, value):
"""
Sets the value of :py:attr:`maxMemoryInMB`.
"""
return self._set(maxMemoryInMB=value)

def setCacheNodeIds(self, value):
"""
Sets the value of :py:attr:`cacheNodeIds`.
"""
return self._set(cacheNodeIds=value)

@since("1.4.0")
def setImpurity(self, value):
"""
Sets the value of :py:attr:`impurity`.
"""
return self._set(impurity=value)


@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
Expand Down Expand Up @@ -1133,6 +1169,63 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)

def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
"""
return self._set(maxDepth=value)

def setMaxBins(self, value):
"""
Sets the value of :py:attr:`maxBins`.
"""
return self._set(maxBins=value)

def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
"""
return self._set(minInstancesPerNode=value)

def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
"""
return self._set(minInfoGain=value)

def setMaxMemoryInMB(self, value):
"""
Sets the value of :py:attr:`maxMemoryInMB`.
"""
return self._set(maxMemoryInMB=value)

def setCacheNodeIds(self, value):
"""
Sets the value of :py:attr:`cacheNodeIds`.
"""
return self._set(cacheNodeIds=value)

@since("1.4.0")
def setImpurity(self, value):
"""
Sets the value of :py:attr:`impurity`.
"""
return self._set(impurity=value)

@since("1.4.0")
def setNumTrees(self, value):
"""
Sets the value of :py:attr:`numTrees`.
"""
return self._set(numTrees=value)

@since("1.4.0")
def setSubsamplingRate(self, value):
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
return self._set(subsamplingRate=value)

@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
Expand Down Expand Up @@ -1317,13 +1410,63 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return GBTClassificationModel(java_model)

def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
"""
return self._set(maxDepth=value)

def setMaxBins(self, value):
"""
Sets the value of :py:attr:`maxBins`.
"""
return self._set(maxBins=value)

def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
"""
return self._set(minInstancesPerNode=value)

def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
"""
return self._set(minInfoGain=value)

def setMaxMemoryInMB(self, value):
"""
Sets the value of :py:attr:`maxMemoryInMB`.
"""
return self._set(maxMemoryInMB=value)

def setCacheNodeIds(self, value):
"""
Sets the value of :py:attr:`cacheNodeIds`.
"""
return self._set(cacheNodeIds=value)

@since("1.4.0")
def setImpurity(self, value):
"""
Sets the value of :py:attr:`impurity`.
"""
return self._set(impurity=value)

@since("1.4.0")
def setLossType(self, value):
"""
Sets the value of :py:attr:`lossType`.
"""
return self._set(lossType=value)

@since("1.4.0")
def setSubsamplingRate(self, value):
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
return self._set(subsamplingRate=value)

@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
Expand Down
36 changes: 0 additions & 36 deletions python/pyspark/ml/param/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,72 +765,36 @@ class DecisionTreeParams(Params):
def __init__(self):
super(DecisionTreeParams, self).__init__()

def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
"""
return self._set(maxDepth=value)

def getMaxDepth(self):
"""
Gets the value of maxDepth or its default value.
"""
return self.getOrDefault(self.maxDepth)

def setMaxBins(self, value):
"""
Sets the value of :py:attr:`maxBins`.
"""
return self._set(maxBins=value)

def getMaxBins(self):
"""
Gets the value of maxBins or its default value.
"""
return self.getOrDefault(self.maxBins)

def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
"""
return self._set(minInstancesPerNode=value)

def getMinInstancesPerNode(self):
"""
Gets the value of minInstancesPerNode or its default value.
"""
return self.getOrDefault(self.minInstancesPerNode)

def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
"""
return self._set(minInfoGain=value)

def getMinInfoGain(self):
"""
Gets the value of minInfoGain or its default value.
"""
return self.getOrDefault(self.minInfoGain)

def setMaxMemoryInMB(self, value):
"""
Sets the value of :py:attr:`maxMemoryInMB`.
"""
return self._set(maxMemoryInMB=value)

def getMaxMemoryInMB(self):
"""
Gets the value of maxMemoryInMB or its default value.
"""
return self.getOrDefault(self.maxMemoryInMB)

def setCacheNodeIds(self, value):
"""
Sets the value of :py:attr:`cacheNodeIds`.
"""
return self._set(cacheNodeIds=value)

def getCacheNodeIds(self):
"""
Gets the value of cacheNodeIds or its default value.
Expand Down
Loading