Commit 810d6df

update tokenizer/parser implementation

1 parent 7aac03a

5 files changed: +73 −54 lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 2 additions & 1 deletion

@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.SparkException
 
 /**
  * Represents a numeric vector, whose index type is Int and value type is Double.
@@ -141,7 +142,7 @@ object Vectors {
       case Seq(size: Double, indices: Array[Double], values: Array[Double]) =>
         Vectors.sparse(size.toInt, indices.map(_.toInt), values)
       case other =>
-        sys.error(s"Cannot parse $other.")
+        throw new SparkException(s"Cannot parse $other.")
     }
   }
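The caller-visible effect: a malformed vector string now fails with org.apache.spark.SparkException rather than the bare RuntimeException produced by sys.error, so callers can catch it specifically. A minimal sketch (input literals borrowed from the updated test suite below):

import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors

// Well-formed dense and sparse strings parse as before.
val dense = Vectors.parse("[1.0E6,0.0,-2.0e-7]")
val sparse = Vectors.parse("(3,[0,2],[1.0,-2.0])")

// A two-element tuple matches neither the sparse pattern nor a dense
// array, so it falls through to the new SparkException.
try {
  Vectors.parse("(1,[1,2])")
} catch {
  case e: SparkException => println(s"rejected: ${e.getMessage}")
}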

mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala

Lines changed: 2 additions & 1 deletion

@@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression
 
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.SparkException
 
 /**
  * Class that represents the features and labels of a data point.
@@ -43,7 +44,7 @@ object LabeledPoint {
       case Seq(label: Double, numeric: Any) =>
         LabeledPoint(label, Vectors.parseNumeric(numeric))
       case other =>
-        sys.error(s"Cannot parse $other.")
+        throw new SparkException(s"Cannot parse $other.")
     }
   } else { // dense format used before v1.0
     val parts = s.split(',')
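Same pattern for labeled points: a string that tokenizes but does not reduce to a (label, features) pair now surfaces as SparkException. The enclosing method signature is not visible in this hunk; in MLlib this logic lives in LabeledPoint.parse, which the sketch below assumes:

import org.apache.spark.SparkException
import org.apache.spark.mllib.regression.LabeledPoint

// v1.0 format: a (label, [features]) tuple routed through NumericParser.
val p = LabeledPoint.parse("(1.0,[1.0,0.0,3.0])")

// Three bare numbers form a tuple that matches neither case above,
// so the new SparkException propagates with a "Cannot parse" message.
try {
  LabeledPoint.parse("(1.0,2.0,3.0)")
} catch {
  case e: SparkException => println(s"rejected: ${e.getMessage}")
}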

mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala

Lines changed: 61 additions & 48 deletions

@@ -19,6 +19,8 @@ package org.apache.spark.mllib.util
 
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 
+import org.apache.spark.SparkException
+
 private[mllib] object NumericTokenizer {
   val NUMBER = -1
   val END = -2
@@ -61,39 +63,43 @@ private[mllib] class NumericTokenizer(s: String, start: Int, end: Int) {
    */
   def next(): Int = {
     if (cur < end) {
-      val c = s(cur)
-      c match {
-        case '(' | '[' =>
-          allowComma = false
-          cur += 1
-          c
-        case ')' | ']' =>
-          allowComma = true
-          cur += 1
-          c
-        case ',' =>
-          if (allowComma) {
-            cur += 1
-            allowComma = false
-            next()
-          } else {
-            sys.error("Found a ',' at a wrong location.")
-          }
-        case other => // expecting a number
-          var inNumber = true
-          val sb = new StringBuilder()
-          while (cur < end && inNumber) {
-            val d = s(cur)
-            if (d == ')' || d == ']' || d == ',') {
-              inNumber = false
-            } else {
-              sb.append(d)
-              cur += 1
-            }
-          }
-          _value = sb.toString().toDouble
-          allowComma = true
-          NUMBER
+      val c = s.charAt(cur)
+      if (c == '(' || c == '[') {
+        allowComma = false
+        cur += 1
+        c
+      } else if (c == ')' || c == ']') {
+        allowComma = true
+        cur += 1
+        c
+      } else if (c == ',') {
+        if (allowComma) {
+          cur += 1
+          allowComma = false
+          next()
+        } else {
+          throw new SparkException(s"Found a ',' at a wrong location: $cur.")
+        }
+      } else {
+        // expecting a number
+        var inNumber = true
+        val beginAt = cur
+        while (cur < end && inNumber) {
+          val d = s.charAt(cur)
+          if (d == ')' || d == ']' || d == ',') {
+            inNumber = false
+          } else {
+            cur += 1
+          }
+        }
+        try {
+          _value = java.lang.Double.parseDouble(s.substring(beginAt, cur))
+        } catch {
+          case e: Throwable =>
+            throw new SparkException("Error parsing a number", e)
+        }
+        allowComma = true
+        NUMBER
       }
     } else {
       END
@@ -110,15 +116,17 @@ private[mllib] object NumericParser {
   def parse(s: String): Any = parse(new NumericTokenizer(s))
 
   private def parse(tokenizer: NumericTokenizer): Any = {
-    tokenizer.next() match {
-      case '(' =>
-        parseTuple(tokenizer)
-      case '[' =>
-        parseArray(tokenizer)
-      case NUMBER =>
-        tokenizer.value
-      case END =>
-        null
+    val token = tokenizer.next()
+    if (token == NUMBER) {
+      tokenizer.value
+    } else if (token == '(') {
+      parseTuple(tokenizer)
+    } else if (token == '[') {
+      parseArray(tokenizer)
+    } else if (token == END) {
+      null
+    } else {
+      throw new SparkException(s"Cannot recognize token type: $token.")
     }
   }
@@ -129,25 +137,30 @@ private[mllib] object NumericParser {
       values.append(tokenizer.value)
       token = tokenizer.next()
     }
-    require(token == ']')
+    if (token != ']') {
+      throw new SparkException(s"An array must end with ] but got $token.")
+    }
     values.toArray
   }
 
   private def parseTuple(tokenizer: NumericTokenizer): Seq[_] = {
     val items = ListBuffer.empty[Any]
     var token = tokenizer.next()
     while (token != ')' && token != END) {
-      token match {
-        case '(' =>
-          items.append(parseTuple(tokenizer))
-        case '[' =>
-          items.append(parseArray(tokenizer))
-        case NUMBER =>
-          items.append(tokenizer.value)
+      if (token == NUMBER) {
+        items.append(tokenizer.value)
+      } else if (token == '(') {
+        items.append(parseTuple(tokenizer))
+      } else if (token == '[') {
+        items.append(parseArray(tokenizer))
+      } else {
+        throw new SparkException(s"Cannot recognize token type: $token.")
      }
      token = tokenizer.next()
    }
-    require(token == ')')
+    if (token != ')') {
+      throw new SparkException(s"A tuple must end with ) but got $token.")
+    }
     items.toSeq
   }
 }
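For readers new to this parser, the return shapes are: a bare number parses to a Double, a [...] block to an Array[Double], and a (...) block to a possibly nested Seq[Any]. NumericParser is private[mllib], so any code exercising it must live under the org.apache.spark.mllib package; the demo object below is hypothetical:

package org.apache.spark.mllib.util

// Hypothetical demo object; NumericParser is private[mllib], so this
// sketch only compiles inside the mllib source tree.
object NumericParserDemo {
  def main(args: Array[String]): Unit = {
    // A bare number parses to a Double.
    val n = NumericParser.parse("1.5").asInstanceOf[Double]

    // "(1.0,[2.0,3.0])" parses to Seq(1.0, Array(2.0, 3.0)).
    val parsed = NumericParser.parse("(1.0,[2.0,3.0])").asInstanceOf[Seq[Any]]
    val label = parsed(0).asInstanceOf[Double]         // 1.0
    val values = parsed(1).asInstanceOf[Array[Double]] // Array(2.0, 3.0)
    println(s"$n $label ${values.mkString(",")}")
  }
}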

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 5 additions & 3 deletions

@@ -19,6 +19,8 @@ package org.apache.spark.mllib.linalg
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
+
 class VectorsSuite extends FunSuite {
 
   val arr = Array(0.1, 0.0, 0.3, 0.4)
@@ -105,7 +107,7 @@ class VectorsSuite extends FunSuite {
     val vectors = Seq(
       Vectors.dense(Array.empty[Double]),
       Vectors.dense(1.0),
-      Vectors.dense(1.0, 0.0, -2.0),
+      Vectors.dense(1.0E6, 0.0, -2.0e-7),
       Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
       Vectors.sparse(1, Array(0), Array(1.0)),
       Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)))
@@ -115,9 +117,9 @@ class VectorsSuite extends FunSuite {
       assert(v === v1)
     }
 
-    val malformatted = Seq("1", "[1,,]", "[1,2", "(1,[1,2])", "(1,[1],[2.0,1.0])")
+    val malformatted = Seq("1", "[1,,]", "[1,2b]", "(1,[1,2])", "([1],[2.0,1.0])")
     malformatted.foreach { s =>
-      intercept[RuntimeException] {
+      intercept[SparkException] {
        Vectors.parse(s)
        println(s"Didn't detect malformatted string $s.")
      }
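The new fixture values 1.0E6 and -2.0e-7 exercise scientific notation, which java.lang.Double.parseDouble (now used by the tokenizer) accepts, and "[1,2b]" replaces the unterminated "[1,2", presumably so the failure exercises the new number-parsing path rather than the missing-bracket check. A one-liner illustrating the round-trip assumption:

// parseDouble handles both exponent spellings used in the fixtures.
val xs = Array("1.0E6", "-2.0e-7").map(s => java.lang.Double.parseDouble(s))
println(xs.mkString(",")) // 1000000.0,-2.0E-7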

mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala

Lines changed: 3 additions & 1 deletion

@@ -21,6 +21,8 @@ import scala.collection.mutable.ListBuffer
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
+
 class NumericParserSuite extends FunSuite {
 
   test("tokenizer") {
@@ -42,7 +44,7 @@ class NumericParserSuite extends FunSuite {
 
     val malformatted = Seq("a", "[1,,]", "0.123.4", "1 2", "3+4")
     malformatted.foreach { s =>
-      intercept[RuntimeException] {
+      intercept[SparkException] {
        val tokenizer = new NumericTokenizer(s)
        while (tokenizer.next() != NumericTokenizer.END) {
          // do nothing
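The token protocol this test loop relies on: next() returns NUMBER (-1) with the parsed Double available via value, the delimiter character itself for parentheses and brackets, and END (-2) at end of input. A sketch of driving the tokenizer directly (hypothetical demo object; NumericTokenizer is private[mllib]):

package org.apache.spark.mllib.util

import scala.collection.mutable.ListBuffer

// Hypothetical demo; collects every NUMBER token from a vector string.
object NumericTokenizerDemo {
  def main(args: Array[String]): Unit = {
    val tokenizer = new NumericTokenizer("(1.0,[2.0,3.0])")
    val numbers = ListBuffer.empty[Double]
    var token = tokenizer.next()
    while (token != NumericTokenizer.END) {
      if (token == NumericTokenizer.NUMBER) {
        numbers += tokenizer.value
      }
      token = tokenizer.next()
    }
    println(numbers.mkString(",")) // 1.0,2.0,3.0
  }
}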
