@@ -42,7 +42,8 @@ private[mllib] class GridPartitioner(
4242 override val numPartitions : Int ) extends Partitioner {
4343
4444 /**
45- * Returns the index of the partition the SubMatrix belongs to.
45+ * Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise
46+ * partitioning.
4647 *
4748 * @param key The key for the SubMatrix. Can be its position in the grid (its column major index)
4849 * or a tuple of three integers that are the final row index after the multiplication,
@@ -51,22 +52,25 @@ private[mllib] class GridPartitioner(
5152 * @return The index of the partition, which the SubMatrix belongs to.
5253 */
5354 override def getPartition (key : Any ): Int = {
55+ val sqrtPartition = math.round(math.sqrt(numPartitions)).toInt
56+ // numPartitions may not be the square of a number, it can even be a prime number
57+
5458 key match {
55- case (rowIndex : Int , colIndex : Int ) =>
56- Utils .nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
57- case (rowIndex : Int , innerIndex : Int , colIndex : Int ) =>
58- Utils .nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
59+ case (blockRowIndex : Int , blockColIndex : Int ) =>
60+ Utils .nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
61+ case (blockRowIndex : Int , innerIndex : Int , blockColIndex : Int ) =>
62+ Utils .nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
5963 case _ =>
60- throw new IllegalArgumentException (" Unrecognized key" )
64+ throw new IllegalArgumentException (s " Unrecognized key. key: $ key" )
6165 }
6266 }
6367
6468 /** Checks whether the partitioners have the same characteristics */
6569 override def equals (obj : Any ): Boolean = {
6670 obj match {
6771 case r : GridPartitioner =>
68- (this .numPartitions == r.numPartitions ) && (this .rowsPerBlock == r.rowsPerBlock) &&
69- (this .colsPerBlock == r.colsPerBlock)
72+ (this .numRowBlocks == r.numRowBlocks ) && (this .numColBlocks == r.numColBlocks)
73+ (this .rowsPerBlock == r.rowsPerBlock) && ( this . colsPerBlock == r.colsPerBlock)
7074 case _ =>
7175 false
7276 }
@@ -85,7 +89,7 @@ class BlockMatrix(
8589 val numColBlocks : Int ,
8690 val rdd : RDD [((Int , Int ), Matrix )]) extends DistributedMatrix with Logging {
8791
88- type SubMatrix = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), matrix)
92+ private type SubMatrix = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), matrix)
8993
9094 /**
9195 * Alternate constructor for BlockMatrix without the input of a partitioner. Will use a Grid
0 commit comments