2 changes: 2 additions & 0 deletions R/pkg/NAMESPACE
@@ -249,6 +249,8 @@ exportMethods("%<=>%",
"getField",
"getItem",
"greatest",
"grouping_bit",
"grouping_id",
"hex",
"histogram",
"hour",
85 changes: 85 additions & 0 deletions R/pkg/R/functions.R
@@ -3871,6 +3871,7 @@ setMethod("posexplode_outer",
#' @rdname not
#' @name not
#' @aliases not,Column-method
#' @family normal_funcs
#' @export
#' @examples \dontrun{
#' df <- createDataFrame(data.frame(
@@ -3890,3 +3891,87 @@ setMethod("not",
jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc)
column(jc)
})

#' grouping_bit
#'
#' Indicates whether a specified column in a GROUP BY list is aggregated or not,
#' returning 1 for aggregated or 0 for not aggregated in the result set.
#'
#' Same as \code{GROUPING} in SQL and the \code{grouping} function in Scala.
#'
#' @param x Column to compute on
#'
#' @rdname grouping_bit
#' @name grouping_bit
#' @family agg_funcs
#' @aliases grouping_bit,Column-method
#' @export
#' @examples \dontrun{
#' df <- createDataFrame(mtcars)
#'
#' # With cube
#' agg(
#' cube(df, "cyl", "gear", "am"),
#' mean(df$mpg),
#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am)
#' )
#'
#' # With rollup
#' agg(
#' rollup(df, "cyl", "gear", "am"),
#' mean(df$mpg),
#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am)
#' )
#' }
#' @note grouping_bit since 2.3.0
setMethod("grouping_bit",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc)
column(jc)
})
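
For intuition, a minimal sketch of what the bit encodes, in the style of the examples above (hypothetical session; the "g_cyl" alias is illustrative, and row order depends on the data):

df <- createDataFrame(mtcars)
collect(agg(
  cube(df, "cyl"),
  mean(df$mpg),
  alias(grouping_bit(df$cyl), "g_cyl")
))
# Rows grouped by a concrete cyl value carry g_cyl == 0; the grand-total
# row, where cyl was aggregated away (shown as NA), carries g_cyl == 1.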

#' grouping_id
#'
#' Returns the level of grouping.
#'
#' Equivalent to \code{
#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn)
#' }
#'
#' @param x Column to compute on
#' @param ... additional Column(s) (optional).
#'
#' @rdname grouping_id
#' @name grouping_id
#' @family agg_funcs
#' @aliases grouping_id,Column-method
#' @export
#' @examples \dontrun{
#' df <- createDataFrame(mtcars)
#'
#' # With cube
#' agg(
#' cube(df, "cyl", "gear", "am"),
#' mean(df$mpg),
#' grouping_id(df$cyl, df$gear, df$am)
#' )
#'
#' # With rollup
#' agg(
#' rollup(df, "cyl", "gear", "am"),
#' mean(df$mpg),
#' grouping_id(df$cyl, df$gear, df$am)
#' )
#' }
#' @note grouping_id since 2.3.0
setMethod("grouping_id",
signature(x = "Column"),
function(x, ...) {
jcols <- lapply(list(x, ...), function(x) {
stopifnot(class(x) == "Column")
x@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols)
column(jc)
})
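
To make the formula concrete, a worked instance for n = 3 grouping columns; the arithmetic follows from the formula above rather than from an actual run:

# For cube(df, "cyl", "gear", "am"), an output row in which gear and am
# were aggregated away but cyl was not has
#   grouping_bit(cyl) = 0, grouping_bit(gear) = 1, grouping_bit(am) = 1
# and therefore
#   grouping_id = 0 * 2^2 + 1 * 2^1 + 1 * 2^0 = 3
agg(
  cube(df, "cyl", "gear", "am"),
  alias(grouping_id(df$cyl, df$gear, df$am), "gid")
)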
8 changes: 8 additions & 0 deletions R/pkg/R/generics.R
@@ -1052,6 +1052,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime")
#' @export
setGeneric("greatest", function(x, ...) { standardGeneric("greatest") })

#' @rdname grouping_bit
#' @export
setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") })

#' @rdname grouping_id
#' @export
setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") })

#' @rdname hex
#' @export
setGeneric("hex", function(x) { standardGeneric("hex") })
56 changes: 54 additions & 2 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1848,7 +1848,11 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
orderBy(
agg(
cube(df, "year", "department"),
expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary")
expr("sum(salary) AS total_salary"),
expr("avg(salary) AS average_salary"),
alias(grouping_bit(df$year), "grouping_year"),
alias(grouping_bit(df$department), "grouping_department"),
alias(grouping_id(df$year, df$department), "grouping_id")
),
"year", "department"
)
@@ -1875,6 +1879,30 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
mean(c(21000, 32000, 22000)), # 2017
22000, 32000, 21000 # 2017 each department
),
grouping_year = c(
1, # global
1, 1, 1, # by department
0, # 2016
0, 0, 0, # 2016 by department
0, # 2017
0, 0, 0 # 2017 by department
),
grouping_department = c(
1, # global
0, 0, 0, # by department
1, # 2016
0, 0, 0, # 2016 by department
1, # 2017
0, 0, 0 # 2017 by department
),
grouping_id = c(
3, # 11
2, 2, 2, # 10
1, # 01
0, 0, 0, # 00
1, # 01
0, 0, 0 # 00
),
stringsAsFactors = FALSE
)

@@ -1896,7 +1924,10 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
orderBy(
agg(
rollup(df, "year", "department"),
expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary")
expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"),
alias(grouping_bit(df$year), "grouping_year"),
alias(grouping_bit(df$department), "grouping_department"),
alias(grouping_id(df$year, df$department), "grouping_id")
),
"year", "department"
)
Expand All @@ -1920,6 +1951,27 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
mean(c(21000, 32000, 22000)), # 2017
22000, 32000, 21000 # 2017 each department
),
grouping_year = c(
1, # global
0, # 2016
0, 0, 0, # 2016 each department
0, # 2017
0, 0, 0 # 2017 each department
),
grouping_department = c(
1, # global
1, # 2016
0, 0, 0, # 2016 each department
1, # 2017
0, 0, 0 # 2017 each department
),
grouping_id = c(
3, # 11
1, # 01
0, 0, 0, # 00
1, # 01
0, 0, 0 # 00
),
stringsAsFactors = FALSE
)

1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1734,6 +1734,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Return information about blocks stored in all of the slaves
*/
@DeveloperApi
@deprecated("This method may change or be removed in a future release.", "2.2.0")
def getExecutorStorageStatus: Array[StorageStatus] = {
assertNotStopped()
env.blockManager.master.getStorageStatus
8 changes: 8 additions & 0 deletions docs/sparkr.md
@@ -644,3 +644,11 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma
## Upgrading to SparkR 2.1.0

- `join` no longer performs a Cartesian Product by default; use `crossJoin` instead.

## Upgrading to SparkR 2.2.0

- A `numPartitions` parameter has been added to `createDataFrame` and `as.DataFrame`. When splitting the data, the partition position calculation now matches the one in Scala (see the sketch after this list).
- The method `createExternalTable` has been deprecated in favor of `createTable`. Either method can be called to create an external or managed table. Additional catalog methods have also been added.
- By default, `derby.log` is now saved to `tempdir()`. The file is created when a SparkSession is instantiated with `enableHiveSupport` set to `TRUE`.
- `spark.lda` was not setting the optimizer correctly; this has been fixed.
- Several model summary outputs now return `coefficients` as a `matrix`, including `spark.logit`, `spark.kmeans`, and `spark.glm`. The summary for `spark.gaussianMixture` now includes the log-likelihood as `loglik`.
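
A minimal sketch of the new `numPartitions` argument (assumes a running SparkSession; the value 4 is arbitrary):

df <- createDataFrame(mtcars, numPartitions = 4)
df2 <- as.DataFrame(mtcars, numPartitions = 4)  # same argument, same effect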
30 changes: 30 additions & 0 deletions python/pyspark/ml/recommendation.py
@@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
Row(user=1, item=0, prediction=2.6258413791656494)
>>> predictions[2]
Row(user=2, item=0, prediction=-1.5018409490585327)
>>> user_recs = model.recommendForAllUsers(3)
>>> user_recs.where(user_recs.user == 0)\
.select("recommendations.item", "recommendations.rating").collect()
[Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])]
>>> item_recs = model.recommendForAllItems(3)
>>> item_recs.where(item_recs.item == 2)\
.select("recommendations.user", "recommendations.rating").collect()
[Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])]
>>> als_path = temp_path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
@@ -384,6 +392,28 @@ def itemFactors(self):
"""
return self._call_java("itemFactors")

@since("2.2.0")
def recommendForAllUsers(self, numItems):
"""
Returns top `numItems` items recommended for each user, for all users.

:param numItems: max number of recommendations for each user
:return: a DataFrame of (userCol, recommendations), where recommendations are
stored as an array of (itemCol, rating) Rows.
"""
return self._call_java("recommendForAllUsers", numItems)

@since("2.2.0")
def recommendForAllItems(self, numUsers):
"""
Returns top `numUsers` users recommended for each item, for all items.

:param numUsers: max number of recommendations for each item
:return: a DataFrame of (itemCol, recommendations), where recommendations are
stored as an array of (userCol, rating) Rows.
"""
return self._call_java("recommendForAllItems", numUsers)


if __name__ == "__main__":
import doctest
@@ -474,7 +474,8 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("Exception", output)
}

test("newProductSeqEncoder with REPL defined class") {
// TODO: [SPARK-20548] Fix and re-enable
ignore("newProductSeqEncoder with REPL defined class") {
val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
"""
|case class Click(id: Int)
@@ -534,6 +534,7 @@ predicate
| NOT? kind=IN '(' query ')'
| NOT? kind=(RLIKE | LIKE) pattern=valueExpression
| IS NOT? kind=NULL
| IS NOT? kind=DISTINCT FROM right=valueExpression
;

valueExpression
@@ -151,8 +151,7 @@ case class GetJsonObject(json: Expression, path: Expression)
try {
/* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
detect character encoding which could fail for some malformed strings */
Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader(
new ByteArrayInputStream(jsonStr.getBytes), "UTF-8"))) { parser =>
Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, jsonStr)) { parser =>
val output = new ByteArrayOutputStream()
val matched = Utils.tryWithResource(
jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator =>
@@ -398,9 +397,8 @@ case class JsonTuple(children: Seq[Expression])
try {
/* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
detect character encoding which could fail for some malformed strings */
Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader(
new ByteArrayInputStream(json.getBytes), "UTF-8"))) {
parser => parseRow(parser, input)
Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
parseRow(parser, input)
}
} catch {
case _: JsonProcessingException =>
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.json

import java.io.InputStream
import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text
@@ -33,7 +33,10 @@ private[sql] object CreateJacksonParser extends Serializable {
val bb = record.getByteBuffer
assert(bb.hasArray)

jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
val bain = new ByteArrayInputStream(
bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())

jsonFactory.createParser(new InputStreamReader(bain, "UTF-8"))
}

def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
@@ -935,6 +935,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* - (NOT) LIKE
* - (NOT) RLIKE
* - IS (NOT) NULL.
* - IS (NOT) DISTINCT FROM
*/
private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
// Invert a predicate if it has a valid NOT clause.
@@ -962,6 +963,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
IsNotNull(e)
case SqlBaseParser.NULL =>
IsNull(e)
case SqlBaseParser.DISTINCT if ctx.NOT != null =>
EqualNullSafe(e, expression(ctx.right))
case SqlBaseParser.DISTINCT =>
Not(EqualNullSafe(e, expression(ctx.right)))
}
}
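
As the two new cases show, `a IS DISTINCT FROM b` is parsed to `Not(EqualNullSafe(a, b))` and `a IS NOT DISTINCT FROM b` to `EqualNullSafe(a, b)`, which is what makes the predicate null-safe. A sketch of the resulting semantics from SparkR (illustrative queries, assuming a session with this change applied):

collect(sql("SELECT 1 IS DISTINCT FROM NULL AS d"))         # d is TRUE
collect(sql("SELECT NULL IS NOT DISTINCT FROM NULL AS e"))  # e is TRUE
collect(sql("SELECT 1 != NULL AS f"))                       # f is NA; plain inequality is not null-safe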

@@ -453,6 +453,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
}

test("SPARK-20549: from_json bad UTF-8") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
JsonToStructs(schema, Map.empty, Literal(badJson), gmtId),
null)
}

test("from_json with timestamp") {
val schema = StructType(StructField("t", TimestampType) :: Nil)

@@ -167,6 +167,11 @@ class ExpressionParserSuite extends PlanTest {
assertEqual("a = b is not null", ('a === 'b).isNotNull)
}

test("is distinct expressions") {
assertEqual("a is distinct from b", !('a <=> 'b))
assertEqual("a is not distinct from b", 'a <=> 'b)
}

test("binary arithmetic expressions") {
// Simple operations
assertEqual("a * b", 'a * 'b)