diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 3ee66c95edb9..6f93365059d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -248,22 +248,24 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { } val msg = "ALS only supports non-Null values" - withClue("Invalid Long: out of range") { - val df = sc.parallelize(Seq( - (1231000000000L, 12L, 0.5), - (1112L, 21L, 1.0) - )).toDF("item", "user", "rating") - val e = intercept[Exception] { new ALS().setMaxIter(1).fit(df) } - assert(e.getMessage.contains(msg)) - } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withClue("Invalid Long: out of range") { + val df = sc.parallelize(Seq( + (1231000000000L, 12L, 0.5), + (1112L, 21L, 1.0) + )).toDF("item", "user", "rating") + val e = intercept[Exception] { new ALS().setMaxIter(1).fit(df) } + assert(e.getMessage.contains(msg)) + } - withClue("Invalid Double: out of range") { - val df = sc.parallelize(Seq( - (1231000000000.0, 12.0, 0.5), - (111.0, 21.0, 1.0) - )).toDF("item", "user", "rating") - val e = intercept[Exception] { new ALS().setMaxIter(1).fit(df) } - assert(e.getMessage.contains(msg)) + withClue("Invalid Double: out of range") { + val df = sc.parallelize(Seq( + (1231000000000.0, 12.0, 0.5), + (111.0, 21.0, 1.0) + )).toDF("item", "user", "rating") + val e = intercept[Exception] { new ALS().setMaxIter(1).fit(df) } + assert(e.getMessage.contains(msg)) + } } withClue("Invalid Double: fractional part") { @@ -275,18 +277,20 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { assert(e.getMessage.contains(msg)) } - withClue("Invalid Decimal: out of range") { - val df = sc.parallelize(Seq( - (1231000000000.0, 12L, 0.5), - (1112.0, 21L, 1.0) - )).toDF("item", "user", "rating") - .select( - col("item").cast(DecimalType(15, 2)).as("item"), - col("user").cast(DecimalType(15, 2)).as("user"), - col("rating") - ) - val e = intercept[Exception] { new ALS().setMaxIter(1).fit(df) } - assert(e.getMessage.contains(msg)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withClue("Invalid Decimal: out of range") { + val df = sc.parallelize(Seq( + (1231000000000.0, 12L, 0.5), + (1112.0, 21L, 1.0) + )).toDF("item", "user", "rating") + .select( + col("item").cast(DecimalType(15, 2)).as("item"), + col("user").cast(DecimalType(15, 2)).as("user"), + col("rating") + ) + val e = intercept[Exception] { new ALS().setMaxIter(1).fit(df) } + assert(e.getMessage.contains(msg)) + } } withClue("Invalid Decimal: fractional part") {