diff --git a/.github/workflows/sarplus.yml b/.github/workflows/sarplus.yml new file mode 100644 index 0000000000..e2b3fc6103 --- /dev/null +++ b/.github/workflows/sarplus.yml @@ -0,0 +1,159 @@ +# This workflow will run tests and do packaging for contrib/sarplus. +# +# References: +# * GitHub Actions workflow templates +# + [python package](https://github.com/actions/starter-workflows/blob/main/ci/python-package.yml) +# + [scala](https://github.com/actions/starter-workflows/blob/main/ci/scala.yml) +# * [GitHub hosted runner - Ubuntu 20.04 LTS](https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-README.md) +# * [Azure Databricks runtime releases](https://docs.microsoft.com/en-us/azure/databricks/release-notes/runtime/releases) + + +name: sarplus test and package + +on: + push: + paths: + - contrib/sarplus/python/** + - contrib/sarplus/scala/** + - contrib/sarplus/VERSION + - .github/workflows/sarplus.yml + +env: + PYTHON_ROOT: ${{ github.workspace }}/contrib/sarplus/python + SCALA_ROOT: ${{ github.workspace }}/contrib/sarplus/scala + +jobs: + python: + # Test pysarplus with different versions of Python. + # Package pysarplus and upload as GitHub workflow artifact when merged into + # the main branch. + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install -U build pip twine + python -m pip install -U flake8 pytest pytest-cov scikit-learn + + - name: Lint with flake8 + run: | + cd "${PYTHON_ROOT}" + # See https://flake8.pycqa.org/en/latest/user/index.html + flake8 . + + - name: Package and check + run: | + cd "${PYTHON_ROOT}" + cp ../VERSION ./pysarplus/ + python -m build --sdist + python -m twine check dist/* + + - name: Test + run: | + cd "${PYTHON_ROOT}" + python -m pip install dist/*.gz + + cd "${SCALA_ROOT}" + export SPARK_VERSION=$(python -m pip show pyspark | grep -i version | cut -d ' ' -f 2) + SPARK_JAR_DIR=$(python -m pip show pyspark | grep -i location | cut -d ' ' -f2)/pyspark/jars + SCALA_JAR=$(ls ${SPARK_JAR_DIR}/scala-library*) + HADOOP_JAR=$(ls ${SPARK_JAR_DIR}/hadoop-client-api*) + SCALA_VERSION=${SCALA_JAR##*-} + export SCALA_VERSION=${SCALA_VERSION%.*} + HADOOP_VERSION=${HADOOP_JAR##*-} + export HADOOP_VERSION=${HADOOP_VERSION%.*} + sbt ++"${SCALA_VERSION}"! package + + cd "${PYTHON_ROOT}" + pytest ./tests + echo "sarplus_version=$(cat ../VERSION)" >> $GITHUB_ENV + + - name: Upload Python package as GitHub artifact + if: github.ref == 'refs/heads/main' && matrix.python-version == '3.10' + uses: actions/upload-artifact@v2 + with: + name: pysarplus-${{ env.sarplus_version }} + path: ${{ env.PYTHON_ROOT }}/dist/*.gz + + scala-test: + # Test sarplus with different versions of Databricks runtime, 2 LTSs and 1 + # latest. 
+ runs-on: ubuntu-latest + strategy: + matrix: + include: + - scala-version: "2.12.10" + spark-version: "3.0.1" + hadoop-version: "2.7.4" + databricks-runtime: "ADB 7.3 LTS" + + - scala-version: "2.12.10" + spark-version: "3.1.2" + hadoop-version: "2.7.4" + databricks-runtime: "ADB 9.1 LTS" + + - scala-version: "2.12.14" + spark-version: "3.2.0" + hadoop-version: "3.3.1" + databricks-runtime: "ADB 10.0" + + steps: + - uses: actions/checkout@v2 + + - name: Test + run: | + cd "${SCALA_ROOT}" + export SPARK_VERSION="${{ matrix.spark-version }}" + export HADOOP_VERSION="${{ matrix.hadoop-version }}" + sbt ++${{ matrix.scala-version }}! test + + scala-package: + # Package sarplus and upload as GitHub workflow artifact when merged into + # the main branch. + needs: scala-test + if: github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Package + env: + GPG_KEY: ${{ secrets.SARPLUS_GPG_PRI_KEY_ASC }} + run: | + # generate artifacts + cd "${SCALA_ROOT}" + export SPARK_VERSION="3.1.2" + export HADOOP_VERSION="2.7.4" + export SCALA_VERSION="2.12.10" + sbt ++${SCALA_VERSION}! package + sbt ++${SCALA_VERSION}! packageDoc + sbt ++${SCALA_VERSION}! packageSrc + sbt ++${SCALA_VERSION}! makePom + export SPARK_VERSION="3.2.0" + export HADOOP_VERSION="3.3.1" + export SCALA_VERSION="2.12.14" + sbt ++${SCALA_VERSION}! package + + # sign with GPG + cd target/scala-2.12 + gpg --import <(cat <<< "${GPG_KEY}") + for file in {*.jar,*.pom}; do gpg -ab "${file}"; done + + # bundle + jar cvf sarplus-bundle_2.12-$(cat ../VERSION).jar *.jar *.pom *.asc + echo "sarplus_version=$(cat ../VERSION)" >> $GITHUB_ENV + + - name: Upload Scala bundle as GitHub artifact + uses: actions/upload-artifact@v2 + with: + name: sarplus-bundle_2.12-${{ env.sarplus_version }} + path: ${{ env.SCALA_ROOT }}/target/scala-2.12/sarplus-bundle_2.12-${{ env.sarplus_version }}.jar diff --git a/AUTHORS.md b/AUTHORS.md index 1b2e99d19d..18b792d5a0 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -103,6 +103,8 @@ To contributors: please add your name to the list when you submit a patch to the * Windows test pipelines * **[Satyadev Ntv](https://github.com/satyadevntv)** * GeoIMC algorithm +* **[Simon Zhao](https://github.com/simonzhaoms)** + * SARplus algorithm upgrade * **[Yan Zhang](https://github.com/YanZhangADS)** * Diversity metrics including coverage, novelty, diversity, and serendipity * Diversity metrics evaluation sample notebook diff --git a/contrib/sarplus/DEVELOPMENT.md b/contrib/sarplus/DEVELOPMENT.md index 8d1557cbba..d601be582f 100644 --- a/contrib/sarplus/DEVELOPMENT.md +++ b/contrib/sarplus/DEVELOPMENT.md @@ -1,41 +1,109 @@ # Packaging -For [databricks](https://databricks.com/) to properly install a [C++ extension](https://docs.python.org/3/extending/building.html), one must take a detour through [pypi](https://pypi.org/). -Use [twine](https://github.com/pypa/twine) to upload the package to [pypi](https://pypi.org/). +For [databricks](https://databricks.com/) to properly install a [C++ +extension](https://docs.python.org/3/extending/building.html), one +must take a detour through [pypi](https://pypi.org/). Use +[twine](https://github.com/pypa/twine) to upload the package to +[pypi](https://pypi.org/). 
```bash
-cd python
-
-python setup.py sdist
+# build dependencies
+python -m pip install -U build pip twine

-twine upload dist/pysarplus-*.tar.gz
+cd python
+cp ../VERSION ./pysarplus/ # version file
+python -m build --sdist
+python -m twine upload dist/*
```

-On [Spark](https://spark.apache.org/) one can install all 3 components (C++, Python, Scala) in one pass by creating a [Spark Package](https://spark-packages.org/). Documentation is rather sparse. Steps to install
+On [Spark](https://spark.apache.org/) one can install all 3 components
+(C++, Python, Scala) in one pass by creating a [Spark
+Package](https://spark-packages.org/). Steps to install:

1. Package and publish the [pip package](python/setup.py) (see above)
-2. Package the [Spark package](scala/build.sbt), which includes the [Scala formatter](scala/src/main/scala/microsoft/sarplus) and references the [pip package](scala/python/requirements.txt) (see below)
-3. Upload the zipped Scala package to [Spark Package](https://spark-packages.org/) through a browser. [sbt spPublish](https://github.com/databricks/sbt-spark-package) has a few [issues](https://github.com/databricks/sbt-spark-package/issues/31) so it always fails for me. Don't use spPublishLocal as the packages are not created properly (names don't match up, [issue](https://github.com/databricks/sbt-spark-package/issues/17)) and furthermore fail to install if published to [Spark-Packages.org](https://spark-packages.org/).
+2. Package the [Spark package](scala/build.sbt), which includes the
+   [Scala formatter](scala/src/main/scala/microsoft/sarplus) and
+   references the pip package (see below)
+3. Upload the zipped Scala package bundle to [Nexus Repository
+   Manager](https://oss.sonatype.org/) through a browser (see the
+   [publish manual](https://central.sonatype.org/publish/publish-manual/)).

```bash
+export SPARK_VERSION="3.1.2"
+export HADOOP_VERSION="2.7.4"
+export SCALA_VERSION="2.12.10"
+GPG_KEY=""
+
+# generate artifacts
cd scala
-sbt spPublish
+sbt ++${SCALA_VERSION}! package
+sbt ++${SCALA_VERSION}! packageDoc
+sbt ++${SCALA_VERSION}! packageSrc
+sbt ++${SCALA_VERSION}! makePom
+
+# generate the artifact (sarplus-*-spark32.jar) for Spark 3.2+
+export SPARK_VERSION="3.2.0"
+export HADOOP_VERSION="3.3.1"
+export SCALA_VERSION="2.12.14"
+sbt ++${SCALA_VERSION}! package
+
+# sign with GPG
+cd target/scala-${SCALA_VERSION%.*}
+gpg --import <(cat <<< "${GPG_KEY}")
+for file in {*.jar,*.pom}; do gpg -ab "${file}"; done
+
+# bundle
+jar cvf sarplus-bundle_2.12-$(cat ../VERSION).jar *.jar *.pom *.asc
```

+where `SPARK_VERSION`, `HADOOP_VERSION` and `SCALA_VERSION` should be
+customized as needed.
+
+
## Testing

To test the python UDF + C++ backend

```bash
-cd python
-python setup.py install && pytest -s tests/
+# build dependencies
+python -m pip install -U build pip twine
+
+# build and install
+cd python
+cp ../VERSION ./pysarplus/ # version file
+python -m build --sdist
+python -m pip install dist/*.gz
+
+# test (the Spark-based tests also expect the sarplus JAR built under scala/target, see below)
+pytest ./tests
```

To test the Scala formatter

```bash
+export SPARK_VERSION=3.2.0
+export HADOOP_VERSION=3.3.1
+export SCALA_VERSION=2.12.14
+
cd scala
-sbt test
+sbt ++${SCALA_VERSION}! test
```
-(use ~test and it will automatically check for changes in source files, but not build.sbt)
+
+## Notes for Spark 3.x ##
+
+The code has been modified to support Spark 3.x, and has been tested
+under different versions of Databricks Runtime (including 6.4 Extended
+Support, 7.3 LTS, 9.1 LTS, 10.0 and 10.1) on Azure Databricks Service.
+However, there is a breaking change of
+[org/apache.spark.sql.execution.datasources.OutputWriter](https://github.com/apache/spark/blob/dc0fa1eef74238d745dabfdc86705b59d95b07e1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala#L74)
+on **Spark 3.2**, which adds an extra function `path()`, so an
+additional JAR file with the classifier `spark32` will be needed if
+running on Spark 3.2 (see above for packaging).
+
+In addition, extra configurations are required when running on Spark
+3.x:
+
+```
+spark.sql.sources.default parquet
+spark.sql.legacy.createHiveTableByDefault true
+```
diff --git a/contrib/sarplus/README.md b/contrib/sarplus/README.md
index d898e2648b..bf27610b19 100644
--- a/contrib/sarplus/README.md
+++ b/contrib/sarplus/README.md
@@ -5,7 +5,13 @@ Pronounced surplus as it's simply better if not best!

[![Build Status](https://dev.azure.com/best-practices/recommenders/_apis/build/status/contrib%20sarplus?branchName=master)](https://dev.azure.com/best-practices/recommenders/_build/latest?definitionId=107&branchName=master) [![PyPI version](https://badge.fury.io/py/pysarplus.svg)](https://badge.fury.io/py/pysarplus)

-Simple Algorithm for Recommendation (SAR) is a neighborhood based algorithm for personalized recommendations based on user transaction history. SAR recommends items that are most **similar** to the ones that the user already has an existing **affinity** for. Two items are **similar** if the users that interacted with one item are also likely to have interacted with the other. A user has an **affinity** to an item if they have interacted with it in the past.
+Simple Algorithm for Recommendation (SAR) is a neighborhood based
+algorithm for personalized recommendations based on user transaction
+history. SAR recommends items that are most **similar** to the ones
+that the user already has an existing **affinity** for. Two items are
+**similar** if the users that interacted with one item are also likely
+to have interacted with the other. A user has an **affinity** to an
+item if they have interacted with it in the past.

SARplus is an efficient implementation of this algorithm for Spark.

@@ -13,7 +19,8 @@ Features:

* Scalable PySpark based [implementation](python/pysarplus/SARPlus.py)
* Fast C++ based [predictions](python/src/pysarplus.cpp)
-* Reduced memory consumption: similarity matrix cached in-memory once per worker, shared accross python executors
+* Reduced memory consumption: similarity matrix cached in-memory once
+  per worker, shared across python executors

## Benchmarks

@@ -25,15 +32,23 @@ Features:

There are a couple of key optimizations:

-* map item ids (e.g. strings) to a continuous set of indexes to optmize storage and simplify access
-* convert similarity matrix to exactly the representation the C++ component needs, thus enabling simple shared, memory mapping of the cache file and avoid parsing
-* shared read-only memory mapping allows us to re-use the same memory from multiple python executors on the same worker node
-* partition the input test users and past seen items by users, allowing for scale out
+* map item ids (e.g. strings) to a continuous set of indexes to
+  optimize storage and simplify access
+* convert similarity matrix to exactly the representation the C++
+  component needs, thus enabling simple shared memory mapping of the
+  cache file and avoiding parsing.
+  This requires a custom formatter, written in Scala.
+* shared read-only memory mapping allows us to re-use the same memory
+  from multiple python executors on the same worker node
+* partition the input test users and past seen items by users,
+  allowing for scale out
* perform as much of the work as possible in PySpark (way simpler)
* top-k computation
-** reverse the join by summing reverse joining the users past seen items with any related items
-** make sure to always just keep top-k items in-memory
-** use standard join using binary search between users past seen items and the related items
+  + reverse the join by reverse joining the users' past seen items
+    with any related items and summing
+  + make sure to always just keep top-k items in-memory
+  + use standard join using binary search between users' past seen
+    items and the related items

![Image of sarplus top-k recommendation optimization](https://recodatasets.z20.web.core.windows.net/images/sarplus_udf.svg)

@@ -76,7 +91,7 @@ Insert this cell prior to the code above.

```python
import os

-SUBMIT_ARGS = "--packages eisber:sarplus:0.2.6 pyspark-shell"
+SUBMIT_ARGS = "--packages com.microsoft.sarplus:sarplus:0.5.0 pyspark-shell"
os.environ["PYSPARK_SUBMIT_ARGS"] = SUBMIT_ARGS

from pyspark.sql import SparkSession

@@ -96,21 +111,26 @@ spark = (

```bash
pip install pysarplus
-pyspark --packages eisber:sarplus:0.2.6 --conf spark.sql.crossJoin.enabled=true
+pyspark --packages com.microsoft.sarplus:sarplus:0.5.0 --conf spark.sql.crossJoin.enabled=true
```

### Databricks

-One must set the crossJoin property to enable calculation of the similarity matrix (Clusters / < Cluster > / Configuration / Spark Config)
+One must set the crossJoin property to enable calculation of the
+similarity matrix (Clusters / < Cluster > / Configuration /
+Spark Config)

```
spark.sql.crossJoin.enabled true
+spark.sql.sources.default parquet
+spark.sql.legacy.createHiveTableByDefault true
```

1. Navigate to your workspace
2. Create library
3. Under 'Source' select 'Maven Coordinate'
-4. Enter 'eisber:sarplus:0.2.5' or 'eisber:sarplus:0.2.6' if you're on Spark 2.4.1
+4. Enter 'com.microsoft.sarplus:sarplus:0.5.0' or
+   'com.microsoft.sarplus:sarplus:0.5.0:spark32' if you're on Spark 3.2+
5. Hit 'Create Library'
6. Attach to your cluster
7. Create 2nd library

@@ -130,10 +150,10 @@ You'll also have to mount shared storage

2. Generate new token: enter 'sarplus'
3. Use databricks shell (installation here)
4. databricks configure --token
-4.1. Host: e.g. https://westus.azuredatabricks.net
+   1. Host: e.g. https://westus.azuredatabricks.net
5. databricks secrets create-scope --scope all --initial-manage-principal users
6. databricks secrets put --scope all --key sarpluscache
-6.1. enter Azure Storage Blob key of Azure Storage created before
+   1. enter Azure Storage Blob key of Azure Storage created before
7. Run mount code

@@ -153,4 +173,5 @@ logging.getLogger("py4j").setLevel(logging.ERROR)

## Development

-See [DEVELOPMENT.md](DEVELOPMENT.md) for implementation details and development information.
+See [DEVELOPMENT.md](DEVELOPMENT.md) for implementation details and
+development information.
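For quick reference, here is a minimal end-to-end sketch of the Python API as exercised by this PR's test suite. The column names, sample data, cache path and Spark settings are illustrative placeholders rather than documented defaults, and the sarplus JAR must be on the session's classpath.

```python
import pandas as pd
from pyspark.sql import SparkSession
from pysarplus import SARPlus

# Assumes the sarplus JAR is available, e.g. a session started with
#   pyspark --packages com.microsoft.sarplus:sarplus:0.5.0
spark = (
    SparkSession.builder.appName("sarplus-example")
    .config("spark.sql.crossJoin.enabled", "true")
    .config("spark.sql.sources.default", "parquet")               # Spark 3.x
    .config("spark.sql.legacy.createHiveTableByDefault", "true")  # Spark 3.x
    .getOrCreate()
)

# Column names as used by the test suite.
header = {
    "col_user": "UserId",
    "col_item": "MovieId",
    "col_rating": "Rating",
    "col_timestamp": "Timestamp",
}

# Tiny synthetic interaction log, in the same shape as the test fixtures.
train_df = spark.createDataFrame(
    pd.DataFrame({
        header["col_user"]: [1, 1, 1, 2, 2, 2],
        header["col_item"]: [1, 2, 3, 1, 4, 5],
        header["col_rating"]: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        header["col_timestamp"]: [1535133442 + 20 * i for i in range(6)],
    })
)

model = SARPlus(
    spark,
    **header,
    similarity_type="jaccard",
    timedecay_formula=True,
    time_decay_coefficient=30,
)
model.fit(train_df)

# Score user 1; the cache path is a placeholder (on Databricks use a
# mounted location such as dbfs:/mnt/sarpluscache/...).
recommendations = model.recommend_k_items(
    train_df.filter(f"{header['col_user']} = 1"),
    cache_path="/tmp/sarplus_cache",
    top_k=3,
    n_user_prediction_partitions=1,
)
recommendations.show()
```

As in the tests, each call to `recommend_k_items` should get its own `cache_path`, since concurrent runs that share a cache interfere with each other.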
diff --git a/contrib/sarplus/VERSION b/contrib/sarplus/VERSION new file mode 100644 index 0000000000..79a2734bbf --- /dev/null +++ b/contrib/sarplus/VERSION @@ -0,0 +1 @@ +0.5.0 \ No newline at end of file diff --git a/contrib/sarplus/azure-pipelines.yml b/contrib/sarplus/azure-pipelines.yml deleted file mode 100644 index ae56707d70..0000000000 --- a/contrib/sarplus/azure-pipelines.yml +++ /dev/null @@ -1,93 +0,0 @@ -# Python package -# Create and test a Python package on multiple Python versions. -# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more: -# https://docs.microsoft.com/azure/devops/pipelines/languages/python - -pr: - branches: - include: - - staging - - master - paths: - include: - - contrib/sarplus/* - -# no CI trigger -trigger: none - -jobs: -- job: 'Test' - pool: - vmImage: 'Ubuntu 16.04' - strategy: - matrix: - Python35-Spark2.3: - python.version: '3.5' - spark.version: '2.3.0' - Python36-Spark2.3: - python.version: '3.6' - spark.version: '2.3.0' - Python35-Spark2.4.1: - python.version: '3.5' - spark.version: '2.4.1' - Python36-Spark2.4.1: - python.version: '3.6' - spark.version: '2.4.1' - Python36-Spark2.4.3: - python.version: '3.6' - spark.version: '2.4.3' - Python37-Spark2.4.3: - python.version: '3.7' - spark.version: '2.4.3' - maxParallel: 4 - - steps: - - task: ComponentGovernanceComponentDetection@0 - inputs: - scanType: 'Register' - verbosity: 'Verbose' - alertWarningLevel: 'High' - sourceScanPath: contrib/sarplus - - - task: UsePythonVersion@0 - inputs: - versionSpec: '$(python.version)' - architecture: 'x64' - - # pyarrow version: https://issues.apache.org/jira/projects/SPARK/issues/SPARK-29367 - - script: python -m pip install --upgrade pip && pip install pyspark==$(spark.version) pytest pandas pybind11 pyarrow==0.14.1 sklearn - displayName: 'Install dependencies' - - - script: | - cd contrib/sarplus/scala - sparkversion=$(spark.version) sbt package - cd ../python - python setup.py install - pytest tests --doctest-modules --junitxml=junit/test-results.xml - displayName: 'pytest' - - - script: | - cd contrib/sarplus/scala - sparkversion=$(spark.version) sbt test - displayName: 'scala test' - - - - task: PublishTestResults@2 - inputs: - testResultsFiles: '**/test-results.xml' - testRunTitle: 'Python $(python.version)' - condition: succeededOrFailed() - -- job: 'Publish' - dependsOn: 'Test' - pool: - vmImage: 'Ubuntu 16.04' - - steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.x' - architecture: 'x64' - - - script: cd contrib/sarplus/python && python setup.py sdist - displayName: 'Build sdist' diff --git a/contrib/sarplus/python/.flake8 b/contrib/sarplus/python/.flake8 new file mode 100644 index 0000000000..4c042c65b0 --- /dev/null +++ b/contrib/sarplus/python/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 120 +ignore = W291 +per-file-ignores = pysarplus/SARPlus.py: E501, pysarplus/__init__.py: F401 \ No newline at end of file diff --git a/contrib/sarplus/python/README.md b/contrib/sarplus/python/README.md new file mode 100644 index 0000000000..a280bc7614 --- /dev/null +++ b/contrib/sarplus/python/README.md @@ -0,0 +1,13 @@ +# SARplus + +Simple Algorithm for Recommendation (SAR) is a neighborhood based +algorithm for personalized recommendations based on user transaction +history. SAR recommends items that are most **similar** to the ones +that the user already has an existing **affinity** for. 
Two items are +**similar** if the users that interacted with one item are also likely +to have interacted with the other. A user has an **affinity** to an +item if they have interacted with it in the past. + +SARplus is an efficient implementation of this algorithm for Spark. +More details can be found at +[sarplus@microsoft/recommenders](https://github.com/microsoft/recommenders/tree/main/contrib/sarplus). diff --git a/contrib/sarplus/python/pyproject.toml b/contrib/sarplus/python/pyproject.toml new file mode 100644 index 0000000000..415ac9499d --- /dev/null +++ b/contrib/sarplus/python/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = [ + "pybind11", + "setuptools>=42", + "wheel", +] +build-backend = "setuptools.build_meta" diff --git a/contrib/sarplus/python/pysarplus/SARModel.py b/contrib/sarplus/python/pysarplus/SARModel.py index bd18c2c88f..afd90cd0f9 100644 --- a/contrib/sarplus/python/pysarplus/SARModel.py +++ b/contrib/sarplus/python/pysarplus/SARModel.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + import pysarplus_cpp import os diff --git a/contrib/sarplus/python/pysarplus/SARPlus.py b/contrib/sarplus/python/pysarplus/SARPlus.py index c372bf82bd..7bb619bf67 100644 --- a/contrib/sarplus/python/pysarplus/SARPlus.py +++ b/contrib/sarplus/python/pysarplus/SARPlus.py @@ -1,13 +1,10 @@ -""" -This is the one and only (to rule them all) implementation of SAR. -""" +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +"""This is the implementation of SAR.""" import logging -import pyspark.sql.functions as F import pandas as pd from pyspark.sql.types import ( - StringType, - DoubleType, StructType, StructField, IntegerType, @@ -16,6 +13,7 @@ from pyspark.sql.functions import pandas_udf, PandasUDFType from pysarplus import SARModel + SIM_COOCCUR = "cooccurrence" SIM_JACCARD = "jaccard" SIM_LIFT = "lift" @@ -25,7 +23,7 @@ class SARPlus: - """SAR implementation for PySpark""" + """SAR implementation for PySpark.""" def __init__( self, @@ -41,6 +39,21 @@ def __init__( timedecay_formula=False, threshold=1, ): + + """Initialize model parameters + Args: + spark (pyspark.sql.SparkSession): Spark session + col_user (str): user column name + col_item (str): item column name + col_rating (str): rating column name + col_timestamp (str): timestamp column name + table_prefix (str): name prefix of the generated tables + similarity_type (str): ['cooccurrence', 'jaccard', 'lift'] option for computing item-item similarity + time_decay_coefficient (float): number of days till ratings are decayed by 1/2 + time_now (int | None): current time for time decay calculation + timedecay_formula (bool): flag to apply time decay + threshold (int): item-item co-occurrences below this threshold will be removed + """ assert threshold > 0 self.spark = spark @@ -66,13 +79,15 @@ def f(self, str, **kwargs): # current time for time decay calculation # cooccurrence matrix threshold def fit(self, df): - """Main fit method for SAR. Expects the dataframes to have row_id, col_id columns which are indexes, + """Main fit method for SAR. + + Expects the dataframes to have row_id, col_id columns which are indexes, i.e. contain the sequential integer index of the original alphanumeric user and item IDs. Dataframe also contains rating and timestamp as floats; timestamp is in seconds since Epoch by default. Arguments: - df (pySpark.DataFrame): input dataframe which contains the index of users and items. 
""" - + df (pySpark.DataFrame): input dataframe which contains the index of users and items. + """ # threshold - items below this number get set to zero in coocurrence counts df.createOrReplaceTempView(self.f("{prefix}df_train_input")) @@ -93,12 +108,12 @@ def fit(self, df): query = self.f( """ SELECT - {col_user}, {col_item}, + {col_user}, {col_item}, SUM({col_rating} * EXP(-log(2) * (latest_timestamp - CAST({col_timestamp} AS long)) / ({time_decay_coefficient} * 3600 * 24))) as {col_rating} FROM {prefix}df_train_input, (SELECT CAST(MAX({col_timestamp}) AS long) latest_timestamp FROM {prefix}df_train_input) - GROUP BY {col_user}, {col_item} - CLUSTER BY {col_user} + GROUP BY {col_user}, {col_item} + CLUSTER BY {col_user} """ ) diff --git a/contrib/sarplus/python/pysarplus/__init__.py b/contrib/sarplus/python/pysarplus/__init__.py index 4e44ba7fe9..0d922d7df1 100644 --- a/contrib/sarplus/python/pysarplus/__init__.py +++ b/contrib/sarplus/python/pysarplus/__init__.py @@ -1,2 +1,20 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path + from .SARModel import SARModel from .SARPlus import SARPlus + +__title__ = "pysarplus" +__version__ = (Path(__file__).resolve().parent / "VERSION").read_text().strip() +__author__ = "RecoDev Team at Microsoft" +__license__ = "MIT" +__copyright__ = "Copyright 2018-present Microsoft Corporation" + +# Synonyms +TITLE = __title__ +VERSION = __version__ +AUTHOR = __author__ +LICENSE = __license__ +COPYRIGHT = __copyright__ diff --git a/contrib/sarplus/python/setup.py b/contrib/sarplus/python/setup.py index bb072fe324..fc1c866189 100644 --- a/contrib/sarplus/python/setup.py +++ b/contrib/sarplus/python/setup.py @@ -1,5 +1,9 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ import sysconfig +from pathlib import Path from setuptools import setup from setuptools.extension import Extension @@ -14,35 +18,47 @@ def __str__(self): return pybind11.get_include(self.user) +DEPENDENCIES = [ + "numpy", + "pandas", + "pyarrow>=1.0.0", + "pybind11>=2.2", + "pyspark>=3.0.0" +] + setup( name="pysarplus", - version="0.2.6", + version=(Path(__file__).resolve().parent / "pysarplus" / "VERSION").read_text().strip(), description="SAR prediction for use with PySpark", - url="https://github.com/Microsoft/Recommenders/contrib/sarplus", - author="Markus Cozowicz", - author_email="marcozo@microsoft.com", + long_description=(Path(__file__).resolve().parent / "README.md").read_text(), + long_description_content_type='text/markdown', + url="https://github.com/microsoft/recommenders/tree/main/contrib/sarplus", + author="RecoDev Team at Microsoft", + author_email="recodevteam@service.microsoft.com", license="MIT", classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3.4", - "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering :: Mathematics", ], setup_requires=["pytest-runner"], - install_requires=["pybind11>=2.2"], + install_requires=DEPENDENCIES, tests_require=["pytest"], packages=["pysarplus"], + package_data={"": ["VERSION"]}, ext_modules=[ Extension( "pysarplus_cpp", ["src/pysarplus.cpp"], include_dirs=[get_pybind_include(), get_pybind_include(user=True)], - extra_compile_args=sysconfig.get_config_var("CFLAGS").split() - + ["-std=c++11", "-Wall", "-Wextra"], + extra_compile_args=sysconfig.get_config_var("CFLAGS").split() + ["-std=c++11", "-Wall", "-Wextra"], libraries=["stdc++"], language="c++11", ) diff --git a/contrib/sarplus/python/src/pysarplus.cpp b/contrib/sarplus/python/src/pysarplus.cpp index 7a5a2739f9..0b06912740 100644 --- a/contrib/sarplus/python/src/pysarplus.cpp +++ b/contrib/sarplus/python/src/pysarplus.cpp @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + #include #include diff --git a/contrib/sarplus/python/tests/conftest.py b/contrib/sarplus/python/tests/conftest.py new file mode 100644 index 0000000000..44efbde4a7 --- /dev/null +++ b/contrib/sarplus/python/tests/conftest.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import calendar +import datetime +import pandas as pd +import pytest + +from sklearn.model_selection import train_test_split + + +@pytest.fixture(scope="module") +def demo_usage_data(header, sar_settings): + # load the data + data = pd.read_csv(sar_settings["FILE_DIR"] + "demoUsage.csv") + data["rating"] = pd.Series([1] * data.shape[0]) + data = data.rename( + columns={ + "userId": header["col_user"], + "productId": header["col_item"], + "rating": header["col_rating"], + "timestamp": header["col_timestamp"], + } + ) + + # convert timestamp + data[header["col_timestamp"]] = data[header["col_timestamp"]].apply( + lambda s: float( + calendar.timegm( + datetime.datetime.strptime(s, "%Y/%m/%dT%H:%M:%S").timetuple() + ) + ) + ) + + return data + + +@pytest.fixture(scope="module") +def header(): + header = { + "col_user": "UserId", + "col_item": "MovieId", + "col_rating": "Rating", + "col_timestamp": "Timestamp", + } + return header + + +@pytest.fixture(scope="module") +def pandas_dummy(header): + ratings_dict = { + header["col_user"]: [1, 1, 1, 1, 2, 2, 2, 2, 2, 2], + header["col_item"]: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + header["col_rating"]: [1, 2, 3, 4, 5, 1, 2, 3, 4, 5], + } + df = pd.DataFrame(ratings_dict) + return df + + +@pytest.fixture(scope="module") +def pandas_dummy_timestamp(pandas_dummy, header): + time = 1535133442 + time_series = [time + 20 * i for i in range(10)] + df = pandas_dummy + df[header["col_timestamp"]] = time_series + return df + + +@pytest.fixture(scope="module") +def sar_settings(): + return { + # absolute tolerance parameter for matrix equivalence in SAR tests + "ATOL": 1e-8, + # directory of the current file - used to link unit test data + "FILE_DIR": "https://recodatasets.z20.web.core.windows.net/sarunittest/", + # user ID used in the test files (they are designed for this user ID, this is part of the test) + "TEST_USER_ID": "0003000098E85347", + } + + +@pytest.fixture(scope="module") +def train_test_dummy_timestamp(pandas_dummy_timestamp): + return train_test_split(pandas_dummy_timestamp, test_size=0.2, random_state=0) diff --git a/contrib/sarplus/python/tests/test_pyspark_sar.py b/contrib/sarplus/python/tests/test_pyspark_sar.py index f2b85b5e2c..110e469f36 100644 --- a/contrib/sarplus/python/tests/test_pyspark_sar.py +++ b/contrib/sarplus/python/tests/test_pyspark_sar.py @@ -1,13 +1,13 @@ -import calendar -import datetime +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ import math +from pathlib import Path + import numpy as np import pandas as pd -import pytest -import os -from sklearn.model_selection import train_test_split - from pyspark.sql import SparkSession +import pytest from pysarplus import SARPlus, SARModel @@ -20,7 +20,7 @@ def assert_compare(expected_id, expected_score, actual_prediction): @pytest.fixture(scope="module") -def spark(app_name="Sample", url="local[*]", memory="1G"): +def spark(tmp_path_factory, app_name="Sample", url="local[*]", memory="1G"): """Start Spark if not started Args: app_name (str): sets name of the application @@ -28,19 +28,26 @@ def spark(app_name="Sample", url="local[*]", memory="1G"): memory (str): size of memory for spark driver """ + try: + sarplus_jar_path = next( + Path(__file__) + .parents[2] + .joinpath("scala", "target") + .glob("**/sarplus*.jar") + ).absolute() + except StopIteration: + raise Exception("Could not find Sarplus JAR file") + spark = ( SparkSession.builder.appName(app_name) .master(url) - .config( - "spark.jars", - os.path.dirname(__file__) - + "/../../scala/target/scala-2.11/sarplus_2.11-0.2.6.jar", - ) + .config("spark.jars", sarplus_jar_path) .config("spark.driver.memory", memory) .config("spark.sql.shuffle.partitions", "1") .config("spark.default.parallelism", "1") .config("spark.sql.crossJoin.enabled", True) .config("spark.ui.enabled", False) + .config("spark.sql.warehouse.dir", str(tmp_path_factory.mktemp("spark"))) # .config("spark.eventLog.enabled", True) # only for local debugging, breaks on build server .getOrCreate() ) @@ -59,17 +66,6 @@ def sample_cache(spark): return path -@pytest.fixture(scope="module") -def header(): - header = { - "col_user": "UserId", - "col_item": "MovieId", - "col_rating": "Rating", - "col_timestamp": "Timestamp", - } - return header - - @pytest.fixture(scope="module") def pandas_dummy_dataset(header): """Load sample dataset in pandas for testing; can be used to create a Spark dataframe @@ -177,78 +173,6 @@ def test_e2e(spark, pandas_dummy_dataset, header): assert np.allclose(r1.score.values, r2.score.values, 1e-3) -@pytest.fixture(scope="module") -def pandas_dummy(header): - ratings_dict = { - header["col_user"]: [1, 1, 1, 1, 2, 2, 2, 2, 2, 2], - header["col_item"]: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - header["col_rating"]: [1, 2, 3, 4, 5, 1, 2, 3, 4, 5], - } - df = pd.DataFrame(ratings_dict) - return df - - -@pytest.fixture(scope="module") -def pandas_dummy_timestamp(pandas_dummy, header): - time = 1535133442 - time_series = [time + 20 * i for i in range(10)] - df = pandas_dummy - df[header["col_timestamp"]] = time_series - return df - - -@pytest.fixture(scope="module") -def train_test_dummy_timestamp(pandas_dummy_timestamp): - return train_test_split(pandas_dummy_timestamp, test_size=0.2, random_state=0) - - -@pytest.fixture(scope="module") -def demo_usage_data(header, sar_settings): - # load the data - data = pd.read_csv(sar_settings["FILE_DIR"] + "demoUsage.csv") - data["rating"] = pd.Series([1] * data.shape[0]) - data = data.rename( - columns={ - "userId": header["col_user"], - "productId": header["col_item"], - "rating": header["col_rating"], - "timestamp": header["col_timestamp"], - } - ) - - # convert timestamp - data[header["col_timestamp"]] = data[header["col_timestamp"]].apply( - lambda s: float( - calendar.timegm( - datetime.datetime.strptime(s, "%Y/%m/%dT%H:%M:%S").timetuple() - ) - ) - ) - - return data - - -@pytest.fixture(scope="module") -def demo_usage_data_spark(spark, demo_usage_data, header): - data_local = demo_usage_data[[x[1] for x in 
header.items()]] - # TODO: install pyArrow in DS VM - # spark.conf.set("spark.sql.execution.arrow.enabled", "true") - data = spark.createDataFrame(data_local) - return data - - -@pytest.fixture(scope="module") -def sar_settings(): - return { - # absolute tolerance parameter for matrix equivalence in SAR tests - "ATOL": 1e-8, - # directory of the current file - used to link unit test data - "FILE_DIR": "http://recodatasets.blob.core.windows.net/sarunittest/", - # user ID used in the test files (they are designed for this user ID, this is part of the test) - "TEST_USER_ID": "0003000098E85347", - } - - @pytest.mark.parametrize( "similarity_type, timedecay_formula", [("jaccard", False), ("lift", True)] ) @@ -259,7 +183,7 @@ def test_fit( spark, **header, timedecay_formula=timedecay_formula, - similarity_type=similarity_type + similarity_type=similarity_type, ) trainset, testset = train_test_dummy_timestamp @@ -276,6 +200,7 @@ def test_fit( Main SAR tests are below - load test files which are used for both Scala SAR and Python reference implementations """ + # Tests 1-6 @pytest.mark.parametrize( "threshold,similarity_type,file", @@ -289,7 +214,13 @@ def test_fit( ], ) def test_sar_item_similarity( - spark, threshold, similarity_type, file, demo_usage_data, sar_settings, header + spark, + threshold, + similarity_type, + file, + demo_usage_data, + sar_settings, + header, ): model = SARPlus( @@ -299,7 +230,7 @@ def test_sar_item_similarity( time_decay_coefficient=30, time_now=None, threshold=threshold, - similarity_type=similarity_type + similarity_type=similarity_type, ) df = spark.createDataFrame(demo_usage_data) @@ -339,7 +270,9 @@ def test_sar_item_similarity( ) assert np.allclose( - item_similarity.value.values, item_similarity_ref.value.values + item_similarity.value.values, + item_similarity_ref.value.values, + atol=sar_settings["ATOL"], ) @@ -353,7 +286,7 @@ def test_user_affinity(spark, demo_usage_data, sar_settings, header): timedecay_formula=True, time_decay_coefficient=30, time_now=time_now, - similarity_type="cooccurrence" + similarity_type="cooccurrence", ) df = spark.createDataFrame(demo_usage_data) @@ -393,7 +326,14 @@ def test_user_affinity(spark, demo_usage_data, sar_settings, header): [(3, "cooccurrence", "count"), (3, "jaccard", "jac"), (3, "lift", "lift")], ) def test_userpred( - spark, threshold, similarity_type, file, header, sar_settings, demo_usage_data + spark, + tmp_path, + threshold, + similarity_type, + file, + header, + sar_settings, + demo_usage_data, ): time_now = demo_usage_data[header["col_timestamp"]].max() @@ -407,19 +347,13 @@ def test_userpred( time_decay_coefficient=30, time_now=time_now, threshold=threshold, - similarity_type=similarity_type + similarity_type=similarity_type, ) df = spark.createDataFrame(demo_usage_data) model.fit(df) - url = ( - sar_settings["FILE_DIR"] - + "userpred_" - + file - + str(threshold) - + "_userid_only.csv" - ) + url = sar_settings["FILE_DIR"] + "userpred_" + file + str(threshold) + "_userid_only.csv" pred_ref = pd.read_csv(url) pred_ref = ( @@ -428,14 +362,14 @@ def test_userpred( .reset_index(drop=True) ) - # Note: it's important to have a separate cache_path for each run as they're interferring with each other + # Note: it's important to have a separate cache_path for each run as they're interfering with each other pred = model.recommend_k_items( spark.createDataFrame( demo_usage_data[ demo_usage_data[header["col_user"]] == sar_settings["TEST_USER_ID"] ] ), - cache_path="test_userpred-" + test_id, + 
cache_path=str(tmp_path.joinpath("test_userpred-" + test_id)), top_k=10, n_user_prediction_partitions=1, ) diff --git a/contrib/sarplus/scala/build.sbt b/contrib/sarplus/scala/build.sbt index f79cb49a43..7c573270d6 100644 --- a/contrib/sarplus/scala/build.sbt +++ b/contrib/sarplus/scala/build.sbt @@ -1,30 +1,93 @@ -scalaVersion := "2.11.8" +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ -sparkVersion := sys.env.get("sparkversion").getOrElse("2.3.0") - -spName := "microsoft/sarplus" - -organization := "microsoft" name := "sarplus" -version := "0.2.6" +// Denpendency configuration -sparkComponents ++= Seq("core", "sql", "mllib") +lazy val sparkVer = settingKey[String]("spark version") +lazy val hadoopVer = settingKey[String]("hadoop version") -libraryDependencies ++= Seq( - "commons-io" % "commons-io" % "2.6", - "com.google.guava" % "guava" % "25.0-jre", - "org.scalatest" %% "scalatest" % "3.0.5" % "test", - "org.scalamock" %% "scalamock" % "4.1.0" % "test" +lazy val commonSettings = Seq( + organization := "sarplus.microsoft", + version := IO.read(new File("../VERSION")), + resolvers ++= Seq( + Resolver.sonatypeRepo("snapshots"), + Resolver.sonatypeRepo("releases"), + ), + addCompilerPlugin("org.scalamacros" % "paradise" % "2.1.1" cross CrossVersion.full), + sparkVer := sys.env.getOrElse("SPARK_VERSION", "3.2.0"), + hadoopVer := sys.env.getOrElse("HADOOP_VERSION", "3.3.1"), + libraryDependencies ++= Seq( + "com.fasterxml.jackson.core" % "jackson-databind" % "2.12.2", + "commons-io" % "commons-io" % "2.8.0", + "org.apache.hadoop" % "hadoop-common" % hadoopVer.value, + "org.apache.hadoop" % "hadoop-hdfs" % hadoopVer.value, + "org.apache.spark" %% "spark-core" % sparkVer.value, + "org.apache.spark" %% "spark-mllib" % sparkVer.value, + "org.apache.spark" %% "spark-sql" % sparkVer.value, + "org.scala-lang" % "scala-reflect" % scalaVersion.value, + "com.google.guava" % "guava" % "15.0", + "org.scalamock" %% "scalamock" % "4.1.0" % "test", + "org.scalatest" %% "scalatest" % "3.0.8" % "test", + "xerces" % "xercesImpl" % "2.12.1", + ), + Compile / packageBin / artifact := { + val prev: Artifact = (Compile / packageBin / artifact).value + prev.withClassifier( + prev.classifier match { + case None => { + val splitVer = sparkVer.value.split('.') + val major = splitVer(0).toInt + val minor = splitVer(1).toInt + if (major >=3 && minor >= 2) Some("spark32") else None + } + case Some(s: String) => Some(s) + } + ) + }, ) -// All Spark Packages need a license -licenses := Seq("MIT" -> url("http://opensource.org/licenses/MIT")) +lazy val compat = project.settings(commonSettings) +lazy val root = (project in file(".")) + .dependsOn(compat) + .settings( + name := "sarplus", + commonSettings, + ) -// doesn't work anyway... -credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") // A file containing credentials -spHomepage := "http://github.com/Microsoft/Recommenders/contrib/sarplus" +// POM metadata configuration. 
See https://www.scala-sbt.org/release/docs/Using-Sonatype.html -// If you published your package to Maven Central for this release (must be done prior to spPublish) -spIncludeMaven := true +organization := "com.microsoft.sarplus" +organizationName := "microsoft" +organizationHomepage := Some(url("https://microsoft.com")) + +scmInfo := Some( + ScmInfo( + url("https://github.com/microsoft/recommenders/tree/main/contrib/sarplus"), + "scm:git@github.com:microsoft/recommenders.git" + ) +) + +developers := List( + Developer( + id = "recodev", + name = "RecoDev Team at Microsoft", + email = "recodevteam@service.microsoft.com", + url = url("https://github.com/microsoft/recommenders/") + ) +) + +description := "sarplus" +licenses := Seq("MIT" -> url("http://opensource.org/licenses/MIT")) +homepage := Some(url("https://github.com/microsoft/recommenders/tree/main/contrib/sarplus")) +pomIncludeRepository := { _ => false } +publishTo := { + val nexus = "https://oss.sonatype.org/" + if (isSnapshot.value) Some("snapshots" at nexus + "content/repositories/snapshots") + else Some("releases" at nexus + "service/local/staging/deploy/maven2") +} +publishMavenStyle := true diff --git a/contrib/sarplus/scala/compat/src/main/scala/com/microsoft/sarplus/compat/spark/since3p2defvisible.scala b/contrib/sarplus/scala/compat/src/main/scala/com/microsoft/sarplus/compat/spark/since3p2defvisible.scala new file mode 100644 index 0000000000..780581d4e5 --- /dev/null +++ b/contrib/sarplus/scala/compat/src/main/scala/com/microsoft/sarplus/compat/spark/since3p2defvisible.scala @@ -0,0 +1,39 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +package com.microsoft.sarplus.spark + +import scala.annotation.{StaticAnnotation, compileTimeOnly} +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import util.Properties.versionNumberString + +@compileTimeOnly("enable macro paradise to expand macro annotations") +class since3p2defvisible extends StaticAnnotation { + def macroTransform(annottees: Any*): Any = macro since3p2defvisibleMacro.impl +} + +object since3p2defvisibleMacro { + def impl(c: Context)(annottees: c.Tree*) = { + import c.universe._ + annottees match { + case q"$mods def $name[..$tparams](...$paramss): $tpt = $body" :: tail => + // NOTE: There seems no way to find out the Spark version. + val major = versionNumberString.split('.')(0).toInt + val minor = versionNumberString.split('.')(1).toInt + val patch = versionNumberString.split('.')(2).toInt + if (major >= 2 && minor >= 12 && patch >= 14) { + q""" + $mods def $name[..$tparams](...$paramss): $tpt = + $body + """ + } else { + q"" + } + case _ => throw new IllegalArgumentException("Please annotate a method") + } + } +} diff --git a/contrib/sarplus/scala/project/build.properties b/contrib/sarplus/scala/project/build.properties index 133a8f197e..10fd9eee04 100644 --- a/contrib/sarplus/scala/project/build.properties +++ b/contrib/sarplus/scala/project/build.properties @@ -1 +1 @@ -sbt.version=0.13.17 +sbt.version=1.5.5 diff --git a/contrib/sarplus/scala/project/plugins.sbt b/contrib/sarplus/scala/project/plugins.sbt index f1495154b8..0a8aeeaba2 100644 --- a/contrib/sarplus/scala/project/plugins.sbt +++ b/contrib/sarplus/scala/project/plugins.sbt @@ -1,4 +1,3 @@ -// You may use this file to add plugin dependencies for sbt. 
-resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven" - -addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.6") +addSbtPlugin("no.arktekk.sbt" % "aether-deploy" % "0.27.0") +addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.9.1") +addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.1.2") diff --git a/contrib/sarplus/scala/python/pysarplus_dummy/__init__.py b/contrib/sarplus/scala/python/pysarplus_dummy/__init__.py index aa0be35f48..0720a92163 100644 --- a/contrib/sarplus/scala/python/pysarplus_dummy/__init__.py +++ b/contrib/sarplus/scala/python/pysarplus_dummy/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + installed = 1 diff --git a/contrib/sarplus/scala/python/setup.py b/contrib/sarplus/scala/python/setup.py index 49a0dc8b59..9821b294c9 100644 --- a/contrib/sarplus/scala/python/setup.py +++ b/contrib/sarplus/scala/python/setup.py @@ -1,11 +1,15 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + from distutils.core import setup +import os setup( name="pysarplus_dummy", - version="0.2", + version=(Path(__file__).resolve().parent.parent.parent / "VERSION").read_text().strip(), description="pysarplus dummy package to trigger spark packaging", - author="Markus Cozowicz", - author_email="marcozo@microsoft.com", + author="RecoDev Team at Microsoft", + author_email="recodevteam@service.microsoft.com", url="https://github.com/Microsoft/Recommenders/contrib/sarplus", packages=["pysarplus_dummy"], ) diff --git a/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/DefaultSource.scala b/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/DefaultSource.scala index c693e870ad..f7a1da5376 100644 --- a/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/DefaultSource.scala +++ b/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/DefaultSource.scala @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + package com.microsoft.sarplus import org.apache.spark.sql.sources.DataSourceRegister diff --git a/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriter.scala b/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriter.scala index 7f8f8d446f..49c924c6c7 100644 --- a/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriter.scala +++ b/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriter.scala @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
+ */ + package com.microsoft.sarplus import java.io.{DataOutputStream, FileInputStream, FileOutputStream, BufferedOutputStream, OutputStream} @@ -11,8 +16,10 @@ import org.apache.spark.sql.types._ import org.apache.commons.io.IOUtils import com.google.common.io.LittleEndianDataOutputStream +import com.microsoft.sarplus.spark.since3p2defvisible + class SARCacheOutputWriter( - path: String, + filePath: String, outputStream: OutputStream, schema: StructType) extends OutputWriter { @@ -20,8 +27,8 @@ class SARCacheOutputWriter( if (schema.length < 3) throw new IllegalArgumentException("Schema must have at least 3 fields") - val pathOffset = path + ".offsets" - val pathRelated = path + ".related" + val pathOffset = filePath + ".offsets" + val pathRelated = filePath + ".related" // temporary output files val tempOutputOffset = new LittleEndianDataOutputStream(new BufferedOutputStream(new FileOutputStream(pathOffset), 8*1024)) @@ -44,7 +51,7 @@ class SARCacheOutputWriter( if(lastId != i1) { - tempOutputOffset.writeLong(rowNumber) + tempOutputOffset.writeLong(rowNumber) offsetCount += 1 lastId = i1 } @@ -64,7 +71,7 @@ class SARCacheOutputWriter( if(lastId != i1) { - tempOutputOffset.writeLong(rowNumber) + tempOutputOffset.writeLong(rowNumber) offsetCount += 1 lastId = i1 } @@ -75,7 +82,7 @@ class SARCacheOutputWriter( rowNumber += 1 } - override def close(): Unit = + override def close(): Unit = { tempOutputOffset.writeLong(rowNumber) offsetCount += 1 @@ -94,5 +101,8 @@ class SARCacheOutputWriter( input.close outputFinal.close - } + } + + @since3p2defvisible + override def path(): String = filePath } diff --git a/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriterFactory.scala b/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriterFactory.scala index 71d1e44f37..2e41effa3a 100644 --- a/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriterFactory.scala +++ b/contrib/sarplus/scala/src/main/scala/com/microsoft/sarplus/SARCacheOutputWriterFactory.scala @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + package com.microsoft.sarplus import org.apache.hadoop.mapreduce.TaskAttemptContext diff --git a/contrib/sarplus/scala/src/test/scala/com/microsoft/sarplus/SARCacheOutputWriterSpec.scala b/contrib/sarplus/scala/src/test/scala/com/microsoft/sarplus/SARCacheOutputWriterSpec.scala index 5eadfd8012..7565965e80 100644 --- a/contrib/sarplus/scala/src/test/scala/com/microsoft/sarplus/SARCacheOutputWriterSpec.scala +++ b/contrib/sarplus/scala/src/test/scala/com/microsoft/sarplus/SARCacheOutputWriterSpec.scala @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + package com.microsoft.sarplus import org.scalatest._
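A note on how the single Scala code base stays compatible with both Spark lines: the `@since3p2defvisible` annotation used on `path()` above comes from the new compat subproject, and its macro keeps an annotated method only when the Scala library in use is 2.12.14 or newer (the toolchain Spark 3.2 builds against, since the macro has no direct way to detect the Spark version) and removes it otherwise. The sketch below illustrates the pattern with a made-up class name, assuming the macro-paradise compiler plugin that `build.sbt` already adds.

```scala
import com.microsoft.sarplus.spark.since3p2defvisible

// Illustrative only: a writer-like class exposing a path() accessor.
class ExampleWriter(filePath: String) {
  // Kept when compiling with Scala >= 2.12.14 (the Spark 3.2 toolchain),
  // where OutputWriter requires path(); elided entirely by the macro on
  // older Scala versions so the same source still builds for Spark 3.0/3.1.
  @since3p2defvisible
  def path(): String = filePath
}
```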