Commit c6cd267

make depth default to 2

1 parent b04b96a commit c6cd267

4 files changed (+15, −16 lines)

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 5 additions & 8 deletions

@@ -92,9 +92,7 @@ class RowMatrix(
            s"Do not support vector operation from type ${rBrz.getClass.getName}.")
        }
        U
-    },
-    combOp = (U1, U2) => U1 += U2,
-    depth = 2)
+    }, combOp = (U1, U2) => U1 += U2)
   }

   /**
@@ -109,7 +107,7 @@ class RowMatrix(
       seqOp = (U, v) => {
         RowMatrix.dspr(1.0, v, U.data)
         U
-      }, combOp = (U1, U2) => U1 += U2, depth = 2)
+      }, combOp = (U1, U2) => U1 += U2)

     RowMatrix.triuToFull(n, GU.data)
   }
@@ -292,8 +290,8 @@ class RowMatrix(
     val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
       seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
       combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) =>
-        (s1._1 + s2._1, s1._2 += s2._2),
-      depth = 2)
+        (s1._1 + s2._1, s1._2 += s2._2)
+    )

     updateNumRows(m)

@@ -355,8 +353,7 @@ class RowMatrix(
   def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
     val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2),
-      depth = 2)
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
     updateNumRows(summary.count)
     summary
   }
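The call sites in this file (and the ones in GradientDescent and LBFGS below) only drop the explicit `depth = 2` argument, so behavior is unchanged. For illustration (`zero`, `seqOp`, and `combOp` are placeholders, not names from the diff), the new default makes these two calls equivalent:

    rows.treeAggregate(zero)(seqOp, combOp)             // depth defaults to 2
    rows.treeAggregate(zero)(seqOp, combOp, depth = 2)  // explicit, as before this commit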

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 1 addition & 1 deletion

@@ -185,7 +185,7 @@ object GradientDescent extends Logging {
        },
        combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
          (grad1 += grad2, loss1 + loss2)
-        }, depth = 2)
+        })

    /**
     * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 1 addition & 1 deletion

@@ -208,7 +208,7 @@ object LBFGS extends Logging {
       },
       combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
         (grad1 += grad2, loss1 + loss2)
-      }, depth = 2)
+      })

    /**
     * regVal is sum of weight squares if it's L2 updater;
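Both optimizers aggregate a (gradient, loss) pair per task, and their `combOp` merges two pairs by adding gradients in place and summing the losses; only the explicit `depth = 2` is dropped here. A self-contained sketch of that combine step (plain Scala, with an array standing in for the Breeze gradient vector used in the real code; all names here are illustrative):

    // Merge two partial (gradient, loss) results, mirroring combOp above:
    // add the second gradient into the first in place, and sum the losses.
    def combine(
        c1: (Array[Double], Double),
        c2: (Array[Double], Double)): (Array[Double], Double) = {
      val ((grad1, loss1), (grad2, loss2)) = (c1, c2)
      var i = 0
      while (i < grad1.length) {
        grad1(i) += grad2(i)
        i += 1
      }
      (grad1, loss1 + loss2)
    }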

mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala

Lines changed: 8 additions & 6 deletions

@@ -49,11 +49,12 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
   }

   /**
-   * Reduces the elements of this RDD in a tree pattern.
-   * @param depth suggested depth of the tree
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
    * @see [[org.apache.spark.rdd.RDD#reduce]]
    */
-  def treeReduce(f: (T, T) => T, depth: Int): T = {
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
     require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     val cleanF = self.context.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
@@ -80,14 +81,15 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
   }

   /**
-   * Aggregates the elements of this RDD in a tree pattern.
-   * @param depth suggested depth of the tree
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
    * @see [[org.apache.spark.rdd.RDD#aggregate]]
    */
   def treeAggregate[U: ClassTag](zeroValue: U)(
       seqOp: (U, T) => U,
       combOp: (U, U) => U,
-      depth: Int): U = {
+      depth: Int = 2): U = {
     require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (self.partitions.size == 0) {
       return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
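With the default in place, callers can invoke both methods without naming `depth`. A minimal usage sketch (assuming a live SparkContext `sc` and the implicit conversion provided by `org.apache.spark.mllib.rdd.RDDFunctions`; illustrative, not part of the diff):

    import org.apache.spark.mllib.rdd.RDDFunctions._

    val rdd = sc.parallelize(1 to 1000, 100)

    // Both calls now aggregate through one intermediate combine level
    // (depth = 2) before results reach the driver.
    val sum = rdd.treeReduce(_ + _)                // same as treeReduce(_ + _, 2)
    val total = rdd.treeAggregate(0)(_ + _, _ + _) // same as passing depth = 2

A depth of 2 adds only one extra stage while sparing the driver from merging one partial result per partition, which makes it a sensible default for RDDs with many partitions.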
