diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4ccef788ce8e..8595e7ec0e67 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -627,6 +627,7 @@ def __hash__(self): "pyspark.ml.tuning", # unittests "pyspark.ml.tests.test_algorithms", + "pyspark.ml.tests.test_als", "pyspark.ml.tests.test_base", "pyspark.ml.tests.test_evaluation", "pyspark.ml.tests.test_feature", diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 65c7d399a88b..1e6be16ef62b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -324,13 +324,22 @@ class ALSModel private[ml] ( // create a new column named map(predictionCol) by running the predict UDF. val validatedUsers = checkIntegers(dataset, $(userCol)) val validatedItems = checkIntegers(dataset, $(itemCol)) + + val validatedInputAlias = Identifiable.randomUID("__als_validated_input") + val itemFactorsAlias = Identifiable.randomUID("__als_item_factors") + val userFactorsAlias = Identifiable.randomUID("__als_user_factors") + val predictions = dataset - .join(userFactors, - validatedUsers === userFactors("id"), "left") - .join(itemFactors, - validatedItems === itemFactors("id"), "left") - .select(dataset("*"), - predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) + .withColumns(Seq($(userCol), $(itemCol)), Seq(validatedUsers, validatedItems)) + .alias(validatedInputAlias) + .join(userFactors.alias(userFactorsAlias), + col(s"${validatedInputAlias}.${$(userCol)}") === col(s"${userFactorsAlias}.id"), "left") + .join(itemFactors.alias(itemFactorsAlias), + col(s"${validatedInputAlias}.${$(itemCol)}") === col(s"${itemFactorsAlias}.id"), "left") + .select(col(s"${validatedInputAlias}.*"), + predict(col(s"${userFactorsAlias}.features"), col(s"${itemFactorsAlias}.features")) + .alias($(predictionCol))) + getColdStartStrategy match { case ALSModel.Drop => predictions.na.drop("all", Seq($(predictionCol))) diff --git a/python/pyspark/ml/tests/test_als.py b/python/pyspark/ml/tests/test_als.py new file mode 100644 index 000000000000..8eec0d937768 --- /dev/null +++ b/python/pyspark/ml/tests/test_als.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import unittest + +import pyspark.sql.functions as sf +from pyspark.ml.recommendation import ALS, ALSModel +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ALSTest(ReusedSQLTestCase): + def test_ambiguous_column(self): + data = self.spark.createDataFrame( + [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]], + ["user", "item", "rating"], + ) + model = ALS( + userCol="user", + itemCol="item", + ratingCol="rating", + numUserBlocks=10, + numItemBlocks=10, + maxIter=1, + seed=42, + ).fit(data) + + with tempfile.TemporaryDirectory() as d: + model.write().overwrite().save(d) + loaded_model = ALSModel().load(d) + + with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": False}): + users = loaded_model.userFactors.select(sf.col("id").alias("user")) + items = loaded_model.itemFactors.select(sf.col("id").alias("item")) + predictions = loaded_model.transform(users.crossJoin(items)) + self.assertTrue(predictions.count() > 0) + + with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": True}): + users = loaded_model.userFactors.select(sf.col("id").alias("user")) + items = loaded_model.itemFactors.select(sf.col("id").alias("item")) + predictions = loaded_model.transform(users.crossJoin(items)) + self.assertTrue(predictions.count() > 0) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_als import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2)