diff --git a/CMakeLists.txt b/CMakeLists.txt index b901f41c29f2..c6e0a4e19a33 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,6 +91,7 @@ cmake_dependent_option(ENABLE_TESTCOVERAGE "Enable compilation with test coverag option(BUILD_EXTENSION_PATH "Path to extension to build" "") option(BUILD_CYTHON_MODULES "Build cython modules." OFF) option(LOG_FATAL_THROW "Log exceptions but do not abort" ON) +option(BUILD_JAVA_NATIVE "Skip signal handler registration for Java Binding" OFF) cmake_dependent_option(USE_SPLIT_ARCH_DLL "Build a separate DLL for each Cuda arch (Windows only)." ON "MSVC" OFF) cmake_dependent_option(USE_CCACHE "Attempt using CCache to wrap the compilation" ON "UNIX" OFF) cmake_dependent_option(MXNET_FORCE_SHARED_CRT "Build with dynamic CRT on Windows (/MD)" ON "MXNET_BUILD_SHARED_LIBS" OFF) @@ -970,6 +971,10 @@ if(USE_CPP_PACKAGE) target_compile_definitions(mxnet PUBLIC MXNET_USE_CPP_PACKAGE=1) endif() +if(BUILD_JAVA_NATIVE) + add_definitions(-DSKIP_SIGNAL_HANDLER_REGISTRATION=1) +endif() + if(NOT CMAKE_BUILD_TYPE STREQUAL "Distribution") # Staticbuild applies linker version script to hide private symbols, breaking unit tests add_subdirectory(tests) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 70934f64080c..1e75ab2d5e33 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -310,6 +310,23 @@ build_ubuntu_cpu() { build_ubuntu_cpu_openblas } +build_ubuntu_cpu_and_test_java() { + build_ubuntu_cpu_openblas_java + java_package_integration_test +} + +java_package_integration_test() { + # make sure you are using java 11 + # build java project + cd /work/mxnet/java-package + ./gradlew build -x javadoc + # generate native library + ./gradlew :native:buildLocalLibraryJarDefault + ./gradlew :native:mkl-linuxJar + # run integration + ./gradlew :integration:run +} + build_ubuntu_cpu_openblas() { set -ex cd /work/build @@ -327,6 +344,24 @@ build_ubuntu_cpu_openblas() { ninja } 
+build_ubuntu_cpu_openblas_java() { + set -ex + cd /work/build + CXXFLAGS="-Wno-error=strict-overflow" CC=gcc-7 CXX=g++-7 cmake \ + -DCMAKE_BUILD_TYPE="RelWithDebInfo" \ + -DENABLE_TESTCOVERAGE=ON \ + -DUSE_TVM_OP=ON \ + -DUSE_BLAS=Open \ + -DUSE_ONEDNN=OFF \ + -DUSE_CUDA=OFF \ + -DUSE_DIST_KVSTORE=ON \ + -DBUILD_CYTHON_MODULES=ON \ + -DBUILD_EXTENSION_PATH=/work/mxnet/example/extensions/lib_external_ops \ + -DBUILD_JAVA_NATIVE=ON \ + -G Ninja /work/mxnet + ninja +} + build_ubuntu_cpu_mkl() { set -ex cd /work/build diff --git a/java-package/Develop.md b/java-package/Develop.md new file mode 100644 index 000000000000..8bc0ffe90c6f --- /dev/null +++ b/java-package/Develop.md @@ -0,0 +1,109 @@ + + + + + + + + + + + + + + + + + +# Development Tips + +## Set up the Project +### Step 1. Obtain MXNet Library +The first step is to obtain the mxnet library. We recommend you build it from source. Alternatively, you can take the library +from an installed mxnet package, as described under "Download Pre-built library" below. +#### Build from source +Refer to [Build From Source](https://mxnet.apache.org/get_started/build_from_source#building-mxnet) +For MacOS users: +- Prepare +```shell +# Install OS X Developer Tools +$ xcode-select --install + +# Install Homebrew +$ /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" + +# Install dependencies +$ brew install cmake ninja ccache opencv +``` +- Clone 3rd party projects +```shell +# Clone the 3rd-party dependencies of mxnet (required) +$ git submodule update --init --recursive +``` +- Build MXNet +```shell +# select and copy the cmake configuration file for macos +$ cp config/darwin.cmake config.cmake + +# create build directory for project +$ mkdir build; cd build + +# cmake +$ cmake .. +$ cmake --build . +``` +Libraries will be generated under the directory _build/_. + +For Linux users: +Docker might help you build libraries on different platforms. You can get help from [README for CI](../ci/README.md). +For example, you can build mxnet on Ubuntu with the following command.
+```shell +$ python3 ci/build.py -p ubuntu_cpu +``` +#### Download Pre-built library +You can find the mxnet library in an installed mxnet package, such as the Python module. However, mxnet 2.0 has not been released +yet, which is why we recommend building it from source. +```shell +# download python module for mxnet (note that mxnet 2.0 hasn't been released yet) +$ pip3 install mxnet==1.7.0.post2 +# find the location of the installed module +$ python +Python 3.6.8 |Anaconda, Inc.| (default, Dec 29 2018, 19:04:46) +>>> import mxnet +>>> mxnet +<module 'mxnet' from '/Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/__init__.py'> +>>> quit() +# you can locate the module under /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/ +$ ls /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/ | grep libmxnet +libmxnet.dylib +``` +The compiled library is the file named _libmxnet.*_. For MacOS, the file has the suffix +_.dylib_; for Linux, the suffix is _.so_; for Windows, the suffix is _.dll_. + +### Step 2. Build MXNet Native Lib for Java +The project uses gradle to manage dependencies. You can build the project using gradle. We have to encapsulate the mxnet +library into a jar file so that we can load it into the JVM. +```shell +$ cd java-package +# Build the project +$ ./gradlew build +# Create gradle tasks to package mxnet library into jar +# The task name is in this form {$flavor}-{$platform}Jar +# MacOS -> mkl-osxJar +# Linux -> mkl-linuxJar +# Windows -> mkl-winJar +$ ./gradlew :native:buildLocalLibraryJarDefault +# Build native lib for macos +$ ./gradlew mkl-osxJar +# Check the lib for osx +$ ls native/build/libs | grep osx +native-2.0.0-SNAPSHOT-osx-x86_64.jar +``` +The jar file _native-2.0.0-SNAPSHOT-osx-x86_64.jar_ is the output lib file. + +### Step 3. Run Integration Test +When we execute the integration test task, the built mxnet native lib is added to the classpath automatically.
+```shell +$ ./gradlew :integration:run + +``` \ No newline at end of file diff --git a/java-package/README.md b/java-package/README.md new file mode 100644 index 000000000000..eb5b901ea72d --- /dev/null +++ b/java-package/README.md @@ -0,0 +1,58 @@ + + + + + + + + + + + + + + + + + +# Java Package for MXNet 2.0 + +## Requirements + +## Install + +## Scripts +- customize mxnet library path +```bash +export MXNET_LIBRARY_PATH=//anaconda3/lib/python3.8/site-packages/mxnet/ +``` + + +## Tests +Test case for a rough inference run with MXNet model +```bash +./gradlew :integration:run +``` + +## Example + +```java +try (MxResource base = BaseMxResource.getSystemMxResource()) + { + Model model = Model.loadModel(Item.MLP); +// Model model = Model.loadModel("test", Paths.get("/Users/cspchen/mxnet.java_package/cache/repo/test-models/mlp.tar.gz/mlp/")); + Predictor predictor = model.newPredictor(); + NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones(); + NDList inputs = new NDList(); + inputs.add(input); + NDList result = predictor.predict(inputs); + NDArray expected = NDArray.create( + base, + new float[]{4.93476f, -0.76084447f, 0.37713608f, 0.6605506f, -1.3485785f, -0.8736369f + , 0.018061712f, -1.3274033f, 1.0609543f, 0.24042489f}, new Shape(1, 10)); + Assertions.assertAlmostEquals(result.get(0), expected); + + } catch (IOException e) { + logger.error(e.getMessage(), e); + } +``` \ No newline at end of file diff --git a/java-package/build.gradle b/java-package/build.gradle new file mode 100644 index 000000000000..909dbbb2e92d --- /dev/null +++ b/java-package/build.gradle @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id "com.github.spotbugs" version "4.2.0" apply true +} + +defaultTasks 'build' + +allprojects { + group 'org.apache.mxnet' + boolean isRelease = project.hasProperty("release") || project.hasProperty("staging") + version = "${java_package_version}" + (isRelease ? "" : "-SNAPSHOT") + + repositories { +// maven { +// url "https://mlrepo.djl.ai/maven/" +// } + mavenCentral() + maven { + url 'https://oss.sonatype.org/content/repositories/snapshots/' + } + } + + apply plugin: 'idea' + idea { + module { + outputDir = file('build/classes/java/main') + testOutputDir = file('build/classes/java/test') + // inheritOutputDirs = true + } + } +} + +def javaProjects() { + return subprojects.findAll { new File(it.projectDir, "src/main").exists() } +} + +configure(javaProjects()) { + apply plugin: 'java-library' + sourceCompatibility = 1.8 + targetCompatibility = 1.8 + compileJava.options.encoding = "UTF-8" + compileTestJava.options.encoding = "UTF-8" + if (JavaVersion.current() != JavaVersion.VERSION_1_8) { + compileJava.options.compilerArgs.addAll(["--release", "8"]) + } + + apply plugin: 'eclipse' + + eclipse { + jdt.file.withProperties { props -> + props.setProperty "org.eclipse.jdt.core.circularClasspath", "warning" + } + classpath { + sourceSets.test.java { + srcDirs = ["src/test/java"] + exclude "**/package-info.java" + } + } + } + + apply from:
file("${rootProject.projectDir}/tools/gradle/java-formatter.gradle") + apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle") + + test { + // tensorflow mobilenet and resnet require more cpu memory + maxHeapSize = "4096m" + doFirst { + if (JavaVersion.current() != JavaVersion.VERSION_1_8) { + jvmArgs = [ + '--add-opens', "java.base/jdk.internal.loader=ALL-UNNAMED" + ] + } + } + + useTestNG() { +// suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml + } + + testLogging { + showStandardStreams = true + events "passed", "skipped", "failed", "standardOut", "standardError" + } + + doFirst { + systemProperties System.getProperties() + systemProperties.remove("user.dir") + systemProperty "org.apache.mxnet.logging.level", "debug" + systemProperty "org.slf4j.simpleLogger.defaultLogLevel", "debug" + systemProperty "org.slf4j.simpleLogger.log.org.mortbay.log", "warn" + systemProperty "disableProgressBar", "true" + systemProperty "nightly", System.getProperty("nightly", "false") +// systemProperty "java.library.path", "/Users/cspchen/Work/incubator-mxnet/build" + if (gradle.startParameter.offline) { + systemProperty "offline", "true" + } + } + } + + compileJava { + options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror" + } + + compileTestJava { + options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror" + } + + jar { + manifest { + attributes("Automatic-Module-Name": "org.apache.mxnet.${project.name.replace('-', '_')}") // prefix matches the 'org.apache.mxnet' group; '-' is not legal in a Java module name, so it is mapped to '_' + } + } +} + +apply from: file("${rootProject.projectDir}/tools/gradle/jacoco.gradle") +apply from: file("${rootProject.projectDir}/tools/gradle/stats.gradle") diff --git a/java-package/example/build.gradle b/java-package/example/build.gradle new file mode 100644 index 000000000000..a6831eecfb29 --- /dev/null +++ b/java-package/example/build.gradle @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java' +} + +group 'incubator-mxnet.java-package' +version '0.0.1-SNAPSHOT' + +repositories { + mavenCentral() +} + +dependencies { + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.0' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0' +} + +test { + useJUnitPlatform() +} \ No newline at end of file diff --git a/java-package/gradle.properties b/java-package/gradle.properties new file mode 100644 index 000000000000..9d32099c364b --- /dev/null +++ b/java-package/gradle.properties @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +org.gradle.daemon=true +org.gradle.jvmargs=-Xmx2048M + +systemProp.org.gradle.internal.http.socketTimeout=120000 +systemProp.org.gradle.internal.http.connectionTimeout=60000 + +# FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308 +systemProp.org.gradle.internal.publish.checksums.insecure=true + +java_package_version=0.0.1 +mxnet_version=2.0.0 +api_version=0.0.1 +jnarator_version=0.0.1 + +antlr_version=4.7.2 +commons_cli_version=1.4 +commons_compress_version=1.20 +commons_csv_version=1.8 +gson_version=2.8.6 +jna_version=5.3.0 +netty_version=4.1.51.Final +slf4j_version=1.7.30 +log4j_slf4j_version=2.13.3 +testng_version=7.1.0 +powermock_version=2.0.7 + diff --git a/java-package/gradle/wrapper/gradle-wrapper.jar b/java-package/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000000..e708b1c023ec Binary files /dev/null and b/java-package/gradle/wrapper/gradle-wrapper.jar differ diff --git a/java-package/gradle/wrapper/gradle-wrapper.properties b/java-package/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000000..689c068d4cc6 --- /dev/null +++ b/java-package/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.0-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/java-package/gradlew b/java-package/gradlew new file mode 100755 index 000000000000..ebb6c09866e6 --- /dev/null +++ b/java-package/gradlew @@ -0,0 +1,186 @@ +#!/usr/bin/env sh + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. 
+while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? 
-eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set 
-- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/java-package/gradlew.bat b/java-package/gradlew.bat new file mode 100644 index 000000000000..bcf7fb7e32ce --- /dev/null +++ b/java-package/gradlew.bat @@ -0,0 +1,92 @@ +@REM +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. 
+@REM + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/java-package/integration/build.gradle b/java-package/integration/build.gradle new file mode 100644 index 000000000000..fed8860c9a5a --- /dev/null +++ b/java-package/integration/build.gradle @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +plugins { + id 'application' + id 'jacoco' +} + +group 'org.apache.mxnet' +version '0.0.1-SNAPSHOT' + +repositories { + mavenCentral() +} + +application { + mainClassName = System.getProperty("main", "org.apache.mxnet.integration.IntegrationTest") +} + +dependencies { + api "commons-cli:commons-cli:${commons_cli_version}" + api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}" + + api project(":mxnet-engine") + implementation "org.testng:testng:${testng_version}" +// testImplementation(":mxnet-engine") + testImplementation("org.testng:testng:${testng_version}") { + exclude group: "junit", module: "junit" + } + testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" +} + +run { + systemProperties System.getProperties() + systemProperties.remove("user.dir") + systemProperty("file.encoding", "UTF-8") + jvmArgs "-Xverify:none" +} + +checkstyleMain { + // skip check style for this package + exclude 'org/apache/mxnet/integration/**' +} + +//test { +// +// useTestNG() +// filter { +// includeTestsMatching "org.apache.mxnet.integration.tests.engine.*" +// } +// +//} + diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java new file mode 100644 index 000000000000..96d0ad3efced --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.integration; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; +import java.util.stream.Collectors; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.mxnet.integration.util.Arguments; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.SkipException; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +public class IntegrationTest { + + private static final Logger logger = LoggerFactory.getLogger(IntegrationTest.class); + + private Class source; + + public IntegrationTest(Class source) { + this.source = source; + } + + public static void main(String[] 
args) { + new IntegrationTest(IntegrationTest.class).runTests(args); + // TODO: not elegant solution to native library crash + // System.exit(0); + } + + public boolean runTests(String[] args) { + Options options = Arguments.getOptions(); + try { + DefaultParser parser = new DefaultParser(); + CommandLine cmd = parser.parse(options, args, null, false); + Arguments arguments = new Arguments(cmd); + + Duration duration = Duration.ofMinutes(arguments.getDuration()); + List tests = listTests(arguments, source); + + boolean testsPassed = true; + while (!duration.isNegative()) { + long begin = System.currentTimeMillis(); + + testsPassed = testsPassed && runTests(tests); + + long delta = System.currentTimeMillis() - begin; + duration = duration.minus(Duration.ofMillis(delta)); + } + return testsPassed; + } catch (ParseException e) { + HelpFormatter formatter = new HelpFormatter(); + formatter.setLeftPadding(1); + formatter.setWidth(120); + formatter.printHelp(e.getMessage(), options); + return false; + } catch (Throwable t) { + logger.error("Unexpected error", t); + return false; + } + } + + private boolean runTests(List tests) { + Map totals = new ConcurrentHashMap<>(); + for (TestClass testClass : tests) { + logger.info("Running test {} ...", testClass.getName()); + int testCount = testClass.getTestCount(); + + try { + if (!testClass.beforeClass()) { + totals.merge(TestResult.FAILED, testCount, Integer::sum); + continue; + } + + for (int i = 0; i < testCount; ++i) { + TestResult result = testClass.runTest(i); + totals.merge(result, 1, Integer::sum); + } + } finally { + testClass.afterClass(); + } + } + + int totalFailed = totals.getOrDefault(TestResult.FAILED, 0); + int totalPassed = totals.getOrDefault(TestResult.SUCCESS, 0); + int totalSkipped = totals.getOrDefault(TestResult.SKIPPED, 0); + int totalUnsupported = totals.getOrDefault(TestResult.UNSUPPORTED, 0); + if (totalSkipped > 0) { + logger.info("Skipped: {} tests", totalSkipped); + } + if (totalUnsupported > 0) { 
+ logger.info("Unsupported: {} tests", totalUnsupported); + } + if (totalFailed > 0) { + logger.error("Failed {} out of {} tests", totalFailed, totalFailed + totalPassed); + } else { + logger.info("Passed all {} tests", totalPassed); + } + return totalFailed == 0; + } + + private static List listTests(Arguments arguments, Class source) + throws IOException, ReflectiveOperationException, URISyntaxException { + String className = arguments.getClassName(); + String methodName = arguments.getMethodName(); + List tests = new ArrayList<>(); + try { + if (className != null) { + Class clazz; + if (className.startsWith(arguments.getPackageName())) { + clazz = Class.forName(className); + } else { + clazz = Class.forName(arguments.getPackageName() + className); + } + getTestsInClass(clazz, methodName).map(tests::add); + } else { + List> classes = listTestClasses(arguments, source); + for (Class clazz : classes) { + getTestsInClass(clazz, methodName).map(tests::add); + } + } + } catch (ReflectiveOperationException | IOException | URISyntaxException e) { + logger.error("Failed to resolve test class.", e); + throw e; + } + return tests; + } + + private static Optional getTestsInClass(Class clazz, String methodName) + throws ReflectiveOperationException { + if (clazz.getConstructors().length == 0) { + return Optional.empty(); + } + Constructor ctor = clazz.getConstructor(); + Object obj = ctor.newInstance(); + TestClass testClass = new TestClass(obj); + + for (Method method : clazz.getDeclaredMethods()) { + Test testMethod = method.getAnnotation(Test.class); + if (testMethod != null) { + if (testMethod.enabled() + && (methodName == null || methodName.equals(method.getName()))) { + testClass.addTestMethod(method); + } + continue; + } + BeforeClass beforeClass = method.getAnnotation(BeforeClass.class); + if (beforeClass != null) { + testClass.addBeforeClass(method); + continue; + } + AfterClass afterClass = method.getAnnotation(AfterClass.class); + if (afterClass != null) { + 
testClass.addAfterClass(method); + continue; + } + BeforeTest beforeTest = method.getAnnotation(BeforeTest.class); + if (beforeTest != null) { + testClass.addBeforeTest(method); + continue; + } + AfterTest afterTest = method.getAnnotation(AfterTest.class); + if (afterTest != null) { + testClass.addAfterTest(method); + } + } + + return Optional.of(testClass); + } + + private static List> listTestClasses(Arguments arguments, Class clazz) + throws IOException, ClassNotFoundException, URISyntaxException { + URL url = clazz.getProtectionDomain().getCodeSource().getLocation(); + String path = url.getPath(); + + if (!"file".equalsIgnoreCase(url.getProtocol())) { + return Collections.emptyList(); + } + + List> classList = new ArrayList<>(); + + Path classPath = Paths.get(url.toURI()); + if (Files.isDirectory(classPath)) { + Collection files = + Files.walk(classPath) + .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) + .collect(Collectors.toList()); + for (Path file : files) { + Path p = classPath.relativize(file); + String className = p.toString(); + className = className.substring(0, className.lastIndexOf('.')); + className = className.replace(File.separatorChar, '.'); + if (className.startsWith(arguments.getPackageName()) && !className.contains("$")) { + try { + classList.add(Class.forName(className)); + } catch (ExceptionInInitializerError ignore) { + // ignore + } + } + } + } else if (path.toLowerCase().endsWith(".jar")) { + try (JarFile jarFile = new JarFile(classPath.toFile())) { + Enumeration en = jarFile.entries(); + while (en.hasMoreElements()) { + JarEntry entry = en.nextElement(); + String fileName = entry.getName(); + if (fileName.endsWith(".class")) { + fileName = fileName.substring(0, fileName.lastIndexOf('.')); + fileName = fileName.replace('/', '.'); + if (fileName.startsWith(arguments.getPackageName())) { + try { + classList.add(Class.forName(fileName)); + } catch (ExceptionInInitializerError ignore) { + // ignore + } + } + } + } + } 
+ } + + return classList; + } + + private static final class TestClass { + + private Object object; + private List testMethods; + private List beforeClass; + private List afterClass; + private List beforeTest; + private List afterTest; + + public TestClass(Object object) { + this.object = object; + testMethods = new ArrayList<>(); + beforeClass = new ArrayList<>(); + afterClass = new ArrayList<>(); + beforeTest = new ArrayList<>(); + afterTest = new ArrayList<>(); + } + + public void addTestMethod(Method method) { + testMethods.add(method); + } + + public void addBeforeClass(Method method) { + beforeClass.add(method); + } + + public void addAfterClass(Method method) { + afterClass.add(method); + } + + public void addBeforeTest(Method method) { + beforeTest.add(method); + } + + public void addAfterTest(Method method) { + afterTest.add(method); + } + + public boolean beforeClass() { + try { + for (Method method : beforeClass) { + method.invoke(object); + } + return true; + } catch (InvocationTargetException | IllegalAccessException e) { + logger.error("", e.getCause()); + } + return false; + } + + public void afterClass() { + try { + for (Method method : afterClass) { + method.invoke(object); + } + } catch (InvocationTargetException | IllegalAccessException e) { + logger.error("", e.getCause()); + } + } + + public boolean beforeTest() { + try { + for (Method method : beforeTest) { + method.invoke(object); + } + return true; + } catch (InvocationTargetException | IllegalAccessException e) { + logger.error("", e.getCause()); + } + return false; + } + + public void afterTest() { + try { + for (Method method : afterTest) { + method.invoke(object); + } + } catch (InvocationTargetException | IllegalAccessException e) { + logger.error("", e.getCause()); + } + } + + public TestResult runTest(int index) { + if (!beforeTest()) { + return TestResult.FAILED; + } + + TestResult result; + Method method = testMethods.get(index); + try { + long begin = System.nanoTime(); + 
method.invoke(object); + String time = String.format("%.3f", (System.nanoTime() - begin) / 1000_0000f); + logger.info("Test {}.{} PASSED, duration: {}", getName(), method.getName(), time); + result = TestResult.SUCCESS; + } catch (IllegalAccessException | InvocationTargetException e) { + if (expectedException(method, e)) { + logger.info("Test {}.{} PASSED", getName(), method.getName()); + result = TestResult.SUCCESS; + } else if (e.getCause() instanceof SkipException) { + logger.info("Test {}.{} SKIPPED", getName(), method.getName()); + result = TestResult.SKIPPED; + } else if (e.getCause() instanceof UnsupportedOperationException) { + logger.info("Test {}.{} UNSUPPORTED", getName(), method.getName()); + logger.trace("", e.getCause()); + result = TestResult.UNSUPPORTED; + } else { + logger.error("Test {}.{} FAILED", getName(), method.getName()); + logger.error("", e.getCause()); + result = TestResult.FAILED; + } + } finally { + afterTest(); + } + return result; + } + + public int getTestCount() { + return testMethods.size(); + } + + public String getName() { + return object.getClass().getName(); + } + + private static boolean expectedException(Method method, Exception e) { + Test test = method.getAnnotation(Test.class); + Class[] exceptions = test.expectedExceptions(); + if (exceptions.length > 0) { + Throwable exception = e.getCause(); + for (Class c : exceptions) { + if (c.isInstance(exception)) { + return true; + } + } + } + return false; + } + } + + public enum TestResult { + SUCCESS, + FAILED, + SKIPPED, + UNSUPPORTED; + } +} diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java new file mode 100644 index 000000000000..f97e77a8ce53 --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains integration tests that use the engine to test the actual behavior of the API. */ +package org.apache.mxnet.integration; diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java new file mode 100644 index 000000000000..8beffb2f83fa --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.integration.tests.engine; + +import java.io.IOException; +import org.apache.mxnet.engine.BaseMxResource; +import org.apache.mxnet.engine.Model; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.engine.Predictor; +import org.apache.mxnet.integration.tests.jna.JnaUtilTest; +import org.apache.mxnet.integration.util.Assertions; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDList; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.repository.Item; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.Test; + +public class ModelTest { + private static final Logger logger = LoggerFactory.getLogger(JnaUtilTest.class); + + @Test + public void modelLoadAndPredictTest() { + try (MxResource base = BaseMxResource.getSystemMxResource()) { + Model model = Model.loadModel(Item.MLP); + Predictor predictor = model.newPredictor(); + NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones(); + NDList inputs = new NDList(); + inputs.add(input); + NDList result = predictor.predict(inputs); + NDArray expected = + NDArray.create( + base, + new float[] { + 4.93476f, + -0.76084447f, + 0.37713608f, + 0.6605506f, + -1.3485785f, + -0.8736369f, + 0.018061712f, + -1.3274033f, + 1.0609543f, + 0.24042489f + }, + new Shape(1, 10)); + Assertions.assertAlmostEquals(result.get(0), expected); + } catch (IOException e) { + logger.error(e.getMessage(), e); + } + } +} diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java new file mode 100644 index 000000000000..cc6a916609e2 --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java @@ -0,0 
+1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains integration tests that use the engine to test the actual behavior of the API. */ +package org.apache.mxnet.integration.tests.engine; diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java new file mode 100644 index 000000000000..e75114029972 --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.integration.tests.jna; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.mxnet.engine.BaseMxResource; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.engine.Symbol; +import org.apache.mxnet.jna.JnaUtils; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDList; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.nn.Parameter; +import org.apache.mxnet.nn.SymbolBlock; +import org.apache.mxnet.repository.Item; +import org.apache.mxnet.repository.Repository; +import org.apache.mxnet.util.PairList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class JnaUtilTest { + + private static final Logger logger = LoggerFactory.getLogger(JnaUtilTest.class); + + @Test + public void doForwardTest() throws IOException { + try (MxResource base = BaseMxResource.getSystemMxResource()) { + Path modelPath = Repository.initRepository(Item.MLP); + Path symbolPath = modelPath.resolve("mlp-symbol.json"); + Path paramsPath = modelPath.resolve("mlp-0000.params"); + Symbol symbol = Symbol.loadSymbol(base, symbolPath); + SymbolBlock block = new SymbolBlock(base, symbol); + Device device = Device.defaultIfNull(); + NDList mxNDArray = JnaUtils.loadNdArray(base, paramsPath, 
Device.defaultIfNull(null)); + + // load parameters + List parameters = block.getAllParameters(); + Map map = new ConcurrentHashMap<>(); + parameters.forEach(p -> map.put(p.getName(), p)); + + for (NDArray nd : mxNDArray) { + String key = nd.getName(); + if (key == null) { + throw new IllegalArgumentException( + "Array names must be present in parameter file"); + } + + String paramName = key.split(":", 2)[1]; + Parameter parameter = map.remove(paramName); + parameter.setArray(nd); + } + block.setInputNames(new ArrayList<>(map.keySet())); + + NDArray arr = NDArray.create(base, new Shape(1, 28, 28), device).ones(); + block.forward(new NDList(arr), new PairList<>(), device); + logger.info( + "Number of MxResource managed by baseMxResource: {}", + BaseMxResource.getSystemMxResource().getSubResource().size()); + } catch (IOException e) { + logger.error(e.getMessage(), e); + throw e; + } + Assert.assertEquals(BaseMxResource.getSystemMxResource().getSubResource().size(), 0); + } + + @Test + public void createNdArray() { + try { + try (BaseMxResource base = BaseMxResource.getSystemMxResource()) { + int[] originIntegerArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float[] originFloatArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + double[] originDoubleArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + long[] originLongArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + boolean[] originBooleanArray = { + true, false, false, true, true, true, true, false, false, true, true, true + }; + byte[] originByteArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + NDArray intArray = NDArray.create(base, originIntegerArray, new Shape(3, 4)); + NDArray floatArray = NDArray.create(base, originFloatArray, new Shape(3, 4)); + NDArray doubleArray = NDArray.create(base, originDoubleArray, new Shape(3, 4)); + NDArray longArray = NDArray.create(base, originLongArray, new Shape(3, 4)); + NDArray booleanArray = NDArray.create(base, originBooleanArray, new Shape(3, 4)); + NDArray byteArray = 
NDArray.create(base, originByteArray, new Shape(3, 4)); + NDArray intArray2 = NDArray.create(base, originIntegerArray); + NDArray floatArray2 = NDArray.create(base, originFloatArray); + NDArray doubleArray2 = NDArray.create(base, originDoubleArray); + NDArray longArray2 = NDArray.create(base, originLongArray); + NDArray booleanArray2 = NDArray.create(base, originBooleanArray); + NDArray byteArray2 = NDArray.create(base, originByteArray); + + int[] ndArrayInt = intArray.toIntArray(); + Assert.assertEquals(originIntegerArray, ndArrayInt); + // Float -> Double + float[] floats = floatArray.toFloatArray(); + Assert.assertEquals(originFloatArray, floats); + double[] ndArrayDouble = doubleArray.toDoubleArray(); + Assert.assertEquals(originDoubleArray, ndArrayDouble); + long[] ndArrayLong = longArray.toLongArray(); + Assert.assertEquals(originLongArray, ndArrayLong); + boolean[] ndArrayBoolean = booleanArray.toBooleanArray(); + Assert.assertEquals(originBooleanArray, ndArrayBoolean); + byte[] ndArrayByte = byteArray.toByteArray(); + Assert.assertEquals(originByteArray, ndArrayByte); + + int[] ndArrayInt2 = intArray2.toIntArray(); + Assert.assertEquals(originIntegerArray, ndArrayInt2); + + // Float -> Double + float[] floats2 = floatArray2.toFloatArray(); + Assert.assertEquals(originFloatArray, floats2); + double[] ndArrayDouble2 = doubleArray2.toDoubleArray(); + Assert.assertEquals(originDoubleArray, ndArrayDouble2); + long[] ndArrayLong2 = longArray2.toLongArray(); + Assert.assertEquals(originLongArray, ndArrayLong2); + boolean[] ndArrayBoolean2 = booleanArray2.toBooleanArray(); + Assert.assertEquals(originBooleanArray, ndArrayBoolean2); + byte[] ndArrayByte2 = byteArray2.toByteArray(); + Assert.assertEquals(originByteArray, ndArrayByte2); + } catch (ClassCastException e) { + logger.error(e.getMessage()); + throw e; + } + BaseMxResource base = BaseMxResource.getSystemMxResource(); + int countNotReleased = 0; + for (MxResource mxResource : base.getSubResource().values()) 
{ + if (!mxResource.getClosed()) { + ++countNotReleased; + } + } + Assert.assertEquals(countNotReleased, 0); + } catch (ClassCastException e) { + logger.error(e.getMessage()); + throw e; + } + } + + @Test + public void loadNdArray() throws IOException { + try (BaseMxResource base = BaseMxResource.getSystemMxResource()) { + Path modelPath = Repository.initRepository(Item.MLP); + Path paramsPath = modelPath.resolve("mlp-0000.params"); + NDList mxNDArray = + JnaUtils.loadNdArray( + base, Paths.get(paramsPath.toUri()), Device.defaultIfNull(null)); + logger.info(mxNDArray.toString()); + logger.info( + String.format( + "The amount of sub resources managed by BaseMxResource: %s", + base.getSubResource().size())); + } catch (IOException e) { + logger.error(e.getMessage()); + throw e; + } + logger.info( + String.format( + "The amount of sub resources managed by BaseMxResource: %s", + BaseMxResource.getSystemMxResource().getSubResource().size())); + } +} diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java new file mode 100644 index 000000000000..15771f3e2fa0 --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.integration.util; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; + +public class Arguments { + + private String methodName; + private String className; + private String packageName; + private int duration; + private int iteration = 1; + + public Arguments(CommandLine cmd) { + methodName = cmd.getOptionValue("method-name"); + className = cmd.getOptionValue("class-name"); + if (cmd.hasOption("package-name")) { + packageName = cmd.getOptionValue("package-name"); + } else { + packageName = "org.apache.mxnet.integration.tests."; + } + + if (cmd.hasOption("duration")) { + duration = Integer.parseInt(cmd.getOptionValue("duration")); + } + if (cmd.hasOption("iteration")) { + iteration = Integer.parseInt(cmd.getOptionValue("iteration")); + } + } + + public static Options getOptions() { + Options options = new Options(); + options.addOption( + Option.builder("d") + .longOpt("duration") + .hasArg() + .argName("DURATION") + .desc("Duration of the test.") + .build()); + options.addOption( + Option.builder("n") + .longOpt("iteration") + .hasArg() + .argName("ITERATION") + .desc("Number of iterations in each test.") + .build()); + options.addOption( + Option.builder("p") + .longOpt("package-name") + .hasArg() + .argName("PACKAGE-NAME") + .desc("Name of the package to run") + .build()); + options.addOption( + Option.builder("c") + .longOpt("class-name") + .hasArg() + .argName("CLASS-NAME") + .desc("Name of the class to run") + .build()); + options.addOption( 
+ Option.builder("m") + .longOpt("method-name") + .hasArg() + .argName("METHOD-NAME") + .desc("Name of the method to run") + .build()); + return options; + } + + public int getDuration() { + return duration; + } + + public int getIteration() { + return iteration; + } + + public String getPackageName() { + return packageName; + } + + public String getClassName() { + return className; + } + + public String getMethodName() { + return methodName; + } +} diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java new file mode 100644 index 000000000000..b342b451ba0b --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mxnet.integration.util; + +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDList; +import org.testng.Assert; + +public final class Assertions { + private static final double RTOL = 1e-5; + private static final double ATOL = 1e-3; + + private Assertions() {} + + private static String getDefaultErrorMessage(T actual, T expected) { + return getDefaultErrorMessage(actual, expected, null); + } + + private static String getDefaultErrorMessage(T actual, T expected, String errorMessage) { + StringBuilder sb = new StringBuilder(100); + if (errorMessage != null) { + sb.append(errorMessage); + } + sb.append(System.lineSeparator()) + .append("Expected: ") + .append(expected) + .append(System.lineSeparator()) + .append("Actual: ") + .append(actual); + return sb.toString(); + } + + public static void assertAlmostEquals(NDArray actual, NDArray expected) { + assertAlmostEquals(actual, expected, RTOL, ATOL); + } + + public static void assertAlmostEquals(NDList actual, NDList expected) { + assertAlmostEquals(actual, expected, RTOL, ATOL); + } + + public static void assertAlmostEquals(double actual, double expected) { + assertAlmostEquals(actual, expected, RTOL, ATOL); + } + + public static void assertAlmostEquals( + double actual, double expected, double rtol, double atol) { + if (Math.abs(actual - expected) > (atol + rtol * Math.abs(expected))) { + throw new AssertionError(getDefaultErrorMessage(actual, expected)); + } + } + + public static void assertAlmostEquals( + NDList actual, NDList expected, double rtol, double atol) { + Assert.assertEquals( + actual.size(), + expected.size(), + getDefaultErrorMessage( + actual.size(), expected.size(), "The NDLists have different sizes")); + int size = actual.size(); + for (int i = 0; i < size; i++) { + assertAlmostEquals(actual.get(i), expected.get(i), rtol, atol); + } + } + + public static void assertAlmostEquals( + NDArray actual, NDArray expected, double rtol, double atol) { + if 
(!actual.getShape().equals(expected.getShape())) { + throw new AssertionError( + getDefaultErrorMessage( + actual.getShape(), + expected.getShape(), + "The shape of two NDArray are different!")); + } + Number[] actualDoubleArray = actual.toArray(); + Number[] expectedDoubleArray = expected.toArray(); + for (int i = 0; i < actualDoubleArray.length; i++) { + double a = actualDoubleArray[i].doubleValue(); + double b = expectedDoubleArray[i].doubleValue(); + if (Math.abs(a - b) > (atol + rtol * Math.abs(b))) { + throw new AssertionError("Expected:" + b + " but got " + a); + } + } + } + + public static void assertInPlaceEquals(NDArray actual, NDArray expected, NDArray original) { + Assert.assertEquals( + actual, expected, getDefaultErrorMessage(actual, expected, "Assert Equal failed!")); + Assert.assertSame( + original, + actual, + getDefaultErrorMessage(original, expected, "Assert Inplace failed!")); + } + + public static void assertInPlaceAlmostEquals( + NDArray actual, NDArray expected, NDArray original) { + assertInPlaceAlmostEquals(actual, expected, original, RTOL, ATOL); + } + + public static void assertInPlaceAlmostEquals( + NDArray actual, NDArray expected, NDArray original, double rtol, double atol) { + assertAlmostEquals(actual, expected, rtol, atol); + Assert.assertSame( + original, + actual, + getDefaultErrorMessage(original, expected, "Assert Inplace failed!")); + } +} diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java new file mode 100644 index 000000000000..e897c27e13ab --- /dev/null +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.integration.util; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; +import java.util.stream.Collectors; + +public final class CoverageUtils { + + private CoverageUtils() {} + + public static void testGetterSetters(Class baseClass) + throws IOException, ReflectiveOperationException, URISyntaxException { + List> list = getClasses(baseClass); + for (Class clazz : list) { + Object obj = null; + if (clazz.isEnum()) { + obj = clazz.getEnumConstants()[0]; + } else { + Constructor[] constructors = clazz.getConstructors(); + for (Constructor con : constructors) { + try { + Class[] types = con.getParameterTypes(); + Object[] args = new 
Object[types.length]; + for (int i = 0; i < args.length; ++i) { + args[i] = getMockInstance(types[i], true); + } + con.setAccessible(true); + obj = con.newInstance(args); + } catch (ReflectiveOperationException ignore) { + // ignore + } + } + } + if (obj == null) { + continue; + } + + Method[] methods = clazz.getDeclaredMethods(); + for (Method method : methods) { + String methodName = method.getName(); + int parameterCount = method.getParameterCount(); + try { + if (parameterCount == 0 + && (methodName.startsWith("get") + || methodName.startsWith("is") + || "toString".equals(methodName) + || "hashCode".equals(methodName))) { + method.invoke(obj); + } else if (parameterCount == 1 + && (methodName.startsWith("set") || "fromValue".equals(methodName))) { + Class type = method.getParameterTypes()[0]; + method.invoke(obj, getMockInstance(type, true)); + } else if ("equals".equals(methodName)) { + method.invoke(obj, obj); + method.invoke(obj, (Object) null); + Class type = method.getParameterTypes()[0]; + method.invoke(obj, getMockInstance(type, true)); + } + } catch (ReflectiveOperationException ignore) { + // ignore + } + } + } + } + + private static List> getClasses(Class clazz) + throws IOException, ReflectiveOperationException, URISyntaxException { + ClassLoader appClassLoader = Thread.currentThread().getContextClassLoader(); + Field field = appClassLoader.getClass().getDeclaredField("ucp"); + field.setAccessible(true); + Object ucp = field.get(appClassLoader); + Method method = ucp.getClass().getDeclaredMethod("getURLs"); + URL[] urls = (URL[]) method.invoke(ucp); + ClassLoader cl = new TestClassLoader(urls, Thread.currentThread().getContextClassLoader()); + + URL url = clazz.getProtectionDomain().getCodeSource().getLocation(); + String path = url.getPath(); + + if (!"file".equalsIgnoreCase(url.getProtocol())) { + return Collections.emptyList(); + } + + List> classList = new ArrayList<>(); + + Path classPath = Paths.get(url.toURI()); + if 
(Files.isDirectory(classPath)) { + Collection files = + Files.walk(classPath) + .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) + .collect(Collectors.toList()); + for (Path file : files) { + Path p = classPath.relativize(file); + String className = p.toString(); + className = className.substring(0, className.lastIndexOf('.')); + className = className.replace(File.separatorChar, '.'); + + try { + classList.add(Class.forName(className, true, cl)); + } catch (Error ignore) { + // ignore + } + } + } else if (path.toLowerCase().endsWith(".jar")) { + try (JarFile jarFile = new JarFile(classPath.toFile())) { + Enumeration en = jarFile.entries(); + while (en.hasMoreElements()) { + JarEntry entry = en.nextElement(); + String fileName = entry.getName(); + if (fileName.endsWith(".class")) { + fileName = fileName.substring(0, fileName.lastIndexOf('.')); + fileName = fileName.replace('/', '.'); + try { + classList.add(Class.forName(fileName, true, cl)); + } catch (Error ignore) { + // ignore + } + } + } + } + } + + return classList; + } + + private static Object getMockInstance(Class clazz, boolean useConstructor) { + if (clazz.isPrimitive()) { + if (clazz == Boolean.TYPE) { + return Boolean.TRUE; + } + if (clazz == Character.TYPE) { + return '0'; + } + if (clazz == Byte.TYPE) { + return (byte) 0; + } + if (clazz == Short.TYPE) { + return (short) 0; + } + if (clazz == Integer.TYPE) { + return 0; + } + if (clazz == Long.TYPE) { + return 0L; + } + if (clazz == Float.TYPE) { + return 0f; + } + if (clazz == Double.TYPE) { + return 0d; + } + } + + if (clazz.isAssignableFrom(String.class)) { + return ""; + } + + if (clazz.isAssignableFrom(List.class)) { + return new ArrayList<>(); + } + + if (clazz.isAssignableFrom(Set.class)) { + return new HashSet<>(); + } + + if (clazz.isAssignableFrom(Map.class)) { + return new HashMap<>(); + } + + if (clazz.isEnum()) { + return clazz.getEnumConstants()[0]; + } + + if (clazz.isInterface()) { + return 
newProxyInstance(clazz); + } + + if (useConstructor) { + Constructor[] constructors = clazz.getConstructors(); + for (Constructor con : constructors) { + try { + Class[] types = con.getParameterTypes(); + Object[] args = new Object[types.length]; + for (int i = 0; i < args.length; ++i) { + args[i] = getMockInstance(types[i], false); + } + con.setAccessible(true); + return con.newInstance(args); + } catch (ReflectiveOperationException ignore) { + // ignore + } + } + } + + return null; + } + + @SuppressWarnings({"rawtypes", "PMD.UseProperClassLoader"}) + private static Object newProxyInstance(Class clazz) { + ClassLoader cl = clazz.getClassLoader(); + return Proxy.newProxyInstance(cl, new Class[] {clazz}, (proxy, method, args) -> null); + } + + private static final class TestClassLoader extends URLClassLoader { + + public TestClassLoader(URL[] urls, ClassLoader parent) { + super(urls, parent); + } + + /** {@inheritDoc} */ + @Override + public Class loadClass(String name) throws ClassNotFoundException { + try { + return findClass(name); + } catch (ClassNotFoundException e) { + ClassLoader classLoader = getParent(); + if (classLoader == null) { + classLoader = getSystemClassLoader(); + } + return classLoader.loadClass(name); + } + } + } +} diff --git a/java-package/integration/src/main/resources/log4j2.xml b/java-package/integration/src/main/resources/log4j2.xml new file mode 100644 index 000000000000..ff05a0145057 --- /dev/null +++ b/java-package/integration/src/main/resources/log4j2.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + diff --git a/java-package/jnarator/build.gradle b/java-package/jnarator/build.gradle new file mode 100644 index 000000000000..abc7cc0a90d2 --- /dev/null +++ b/java-package/jnarator/build.gradle @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
// Build script for the jnarator module. The 'antlr' plugin is applied and the ANTLR tool
// is declared as a dependency below.
plugins {
    id 'antlr'
}

dependencies {
    // ANTLR tool used by the 'antlr' plugin
    antlr "org.antlr:antlr4:${antlr_version}"

    api "commons-cli:commons-cli:${commons_cli_version}"
    api "org.antlr:antlr4-runtime:${antlr_version}"
    api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"

    // TestNG without its transitive junit dependency
    testImplementation("org.testng:testng:${testng_version}") {
        exclude group: "junit", module: "junit"
    }

    testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
    testRuntimeOnly project(":mxnet-engine")
//    testRuntimeOnly ":mxnet-native-auto:${mxnet_version}"
}

// Limit checkstyle to src/main/java.
// NOTE(review): presumably this keeps generated ANTLR sources out of the check; confirm.
checkstyleMain.source = 'src/main/java'

checkstyleMain {
    // skip check style for this package
    exclude 'org/apache/mxnet/jnarator/**'
}

// Limit PMD to src/main/java as well.
pmdMain.source = 'src/main/java'
//pmdMain.ignoreFailures(true)
// SpotBugs findings do not fail the build.
spotbugs.ignoreFailures = true

// Package a runnable jar: the manifest names the entry point, and every runtimeClasspath
// entry is merged in (directories as they are, archives unpacked through zipTree).
jar {
    manifest {
        attributes (
                "Main-Class"    : "org.apache.mxnet.jnarator.Main",
                "Multi-Release" : true
        )
    }
    includeEmptyDirs = false
    // Files that occur in more than one dependency are all included.
    duplicatesStrategy = DuplicatesStrategy.INCLUDE
    from configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) }
}
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// ANTLR4 grammar for C source text, used by jnarator to parse C headers.
// Besides the usual C rules it accepts several GCC/MSVC extensions and contains a few
// project-specific adjustments, each marked with a comment at the rule concerned:
//   - DigitSequence alternatives in castExpression / assignmentExpression / directDeclarator
//   - declarationSpecifiers2 and the pointer alternative of typeSpecifier
//   - compilationUnit accepting a closing '}' instead of EOF
//   - the lexer skipping preprocessor lines and some export macros (see Whitespace)
grammar C;

// Generated parser and lexer classes are placed in this Java package.
@parser::header {
package org.apache.mxnet.jnarator.parser;

}

@lexer::header {
package org.apache.mxnet.jnarator.parser;

}


primaryExpression
    :   Identifier
    |   Constant
    |   StringLiteral+
    |   '(' expression ')'
    |   genericSelection
    |   '__extension__'? '(' compoundStatement ')' // Blocks (GCC extension)
    |   '__builtin_va_arg' '(' unaryExpression ',' typeName ')'
    |   '__builtin_offsetof' '(' typeName ',' unaryExpression ')'
    ;

genericSelection
    :   '_Generic' '(' assignmentExpression ',' genericAssocList ')'
    ;

genericAssocList
    :   genericAssociation
    |   genericAssocList ',' genericAssociation
    ;

genericAssociation
    :   typeName ':' assignmentExpression
    |   'default' ':' assignmentExpression
    ;

postfixExpression
    :   primaryExpression
    |   postfixExpression '[' expression ']'
    |   postfixExpression '(' argumentExpressionList? ')'
    |   postfixExpression '.' Identifier
    |   postfixExpression '->' Identifier
    |   postfixExpression '++'
    |   postfixExpression '--'
    |   '(' typeName ')' '{' initializerList '}'
    |   '(' typeName ')' '{' initializerList ',' '}'
    |   '__extension__' '(' typeName ')' '{' initializerList '}'
    |   '__extension__' '(' typeName ')' '{' initializerList ',' '}'
    ;

argumentExpressionList
    :   assignmentExpression
    |   argumentExpressionList ',' assignmentExpression
    ;

unaryExpression
    :   postfixExpression
    |   '++' unaryExpression
    |   '--' unaryExpression
    |   unaryOperator castExpression
    |   'sizeof' unaryExpression
    |   'sizeof' '(' typeName ')'
    |   '_Alignof' '(' typeName ')'
    |   '&&' Identifier // GCC extension address of label
    ;

unaryOperator
    :   '&' | '*' | '+' | '-' | '~' | '!'
    ;

// NOTE(review): the DigitSequence alternative accepts a bare DigitSequence token as an
// expression. The original "// for" comment is truncated; its intended wording is unknown.
castExpression
    :   '(' typeName ')' castExpression
    |   '__extension__' '(' typeName ')' castExpression
    |   unaryExpression
    |   DigitSequence // for
    ;

multiplicativeExpression
    :   castExpression
    |   multiplicativeExpression '*' castExpression
    |   multiplicativeExpression '/' castExpression
    |   multiplicativeExpression '%' castExpression
    ;

additiveExpression
    :   multiplicativeExpression
    |   additiveExpression '+' multiplicativeExpression
    |   additiveExpression '-' multiplicativeExpression
    ;

shiftExpression
    :   additiveExpression
    |   shiftExpression '<<' additiveExpression
    |   shiftExpression '>>' additiveExpression
    ;

relationalExpression
    :   shiftExpression
    |   relationalExpression '<' shiftExpression
    |   relationalExpression '>' shiftExpression
    |   relationalExpression '<=' shiftExpression
    |   relationalExpression '>=' shiftExpression
    ;

equalityExpression
    :   relationalExpression
    |   equalityExpression '==' relationalExpression
    |   equalityExpression '!=' relationalExpression
    ;

andExpression
    :   equalityExpression
    |   andExpression '&' equalityExpression
    ;

exclusiveOrExpression
    :   andExpression
    |   exclusiveOrExpression '^' andExpression
    ;

inclusiveOrExpression
    :   exclusiveOrExpression
    |   inclusiveOrExpression '|' exclusiveOrExpression
    ;

logicalAndExpression
    :   inclusiveOrExpression
    |   logicalAndExpression '&&' inclusiveOrExpression
    ;

logicalOrExpression
    :   logicalAndExpression
    |   logicalOrExpression '||' logicalAndExpression
    ;

conditionalExpression
    :   logicalOrExpression ('?' expression ':' conditionalExpression)?
    ;

// Same DigitSequence alternative as in castExpression (see the note there).
assignmentExpression
    :   conditionalExpression
    |   unaryExpression assignmentOperator assignmentExpression
    |   DigitSequence // for
    ;

assignmentOperator
    :   '=' | '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
    ;

expression
    :   assignmentExpression
    |   expression ',' assignmentExpression
    ;

constantExpression
    :   conditionalExpression
    ;

declaration
    :   declarationSpecifiers initDeclaratorList ';'
    |   declarationSpecifiers ';'
    |   staticAssertDeclaration
    ;

declarationSpecifiers
    :   declarationSpecifier+
    ;

// Same body as declarationSpecifiers; referenced only by the abstract-declarator
// alternative of parameterDeclaration, which gives that alternative its own context type.
declarationSpecifiers2
    :   declarationSpecifier+
    ;

declarationSpecifier
    :   storageClassSpecifier
    |   typeSpecifier
    |   typeQualifier
    |   functionSpecifier
    |   alignmentSpecifier
    ;

initDeclaratorList
    :   initDeclarator
    |   initDeclaratorList ',' initDeclarator
    ;

initDeclarator
    :   declarator
    |   declarator '=' initializer
    ;

storageClassSpecifier
    :   'typedef'
    |   'extern'
    |   'static'
    |   '_Thread_local'
    |   'auto'
    |   'register'
    ;

// The last alternative (typeSpecifier pointer) lets a pointer be attached to the type
// specifier itself, so '*' can appear inside a typeSpecifier subtree.
typeSpecifier
    :   ('void'
    |   'char'
    |   'short'
    |   'int'
    |   'long'
    |   'float'
    |   'double'
    |   'signed'
    |   'unsigned'
    |   '_Bool'
    |   '_Complex'
    |   '__m128'
    |   '__m128d'
    |   '__m128i')
    |   '__extension__' '(' ('__m128' | '__m128d' | '__m128i') ')'
    |   atomicTypeSpecifier
    |   structOrUnionSpecifier
    |   enumSpecifier
    |   typedefName
    |   '__typeof__' '(' constantExpression ')' // GCC extension
    |   typeSpecifier pointer
    ;

structOrUnionSpecifier
    :   structOrUnion Identifier? '{' structDeclarationList '}'
    |   structOrUnion Identifier
    ;

structOrUnion
    :   'struct'
    |   'union'
    ;

structDeclarationList
    :   structDeclaration
    |   structDeclarationList structDeclaration
    ;

structDeclaration
    :   specifierQualifierList structDeclaratorList? ';'
    |   staticAssertDeclaration
    ;

specifierQualifierList
    :   typeSpecifier specifierQualifierList?
    |   typeQualifier specifierQualifierList?
    ;

structDeclaratorList
    :   structDeclarator
    |   structDeclaratorList ',' structDeclarator
    ;

structDeclarator
    :   declarator
    |   declarator? ':' constantExpression
    ;

enumSpecifier
    :   'enum' Identifier? '{' enumeratorList '}'
    |   'enum' Identifier? '{' enumeratorList ',' '}'
    |   'enum' Identifier
    ;

enumeratorList
    :   enumerator
    |   enumeratorList ',' enumerator
    ;

enumerator
    :   enumerationConstant
    |   enumerationConstant '=' constantExpression
    ;

enumerationConstant
    :   Identifier
    ;

atomicTypeSpecifier
    :   '_Atomic' '(' typeName ')'
    ;

typeQualifier
    :   'const'
    |   'restrict'
    |   'volatile'
    |   '_Atomic'
    ;

functionSpecifier
    :   ('inline'
    |   '_Noreturn'
    |   '__inline__' // GCC extension
    |   '__stdcall')
    |   gccAttributeSpecifier
    |   '__declspec' '(' Identifier ')'
    ;

alignmentSpecifier
    :   '_Alignas' '(' typeName ')'
    |   '_Alignas' '(' constantExpression ')'
    ;

declarator
    :   pointer? directDeclarator gccDeclaratorExtension*
    ;

directDeclarator
    :   Identifier
    |   '(' declarator ')'
    |   directDeclarator '[' typeQualifierList? assignmentExpression? ']'
    |   directDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
    |   directDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
    |   directDeclarator '[' typeQualifierList? '*' ']'
    |   directDeclarator '(' parameterTypeList ')'
    |   directDeclarator '(' identifierList? ')'
    |   Identifier ':' DigitSequence  // bit field
    |   '(' typeSpecifier? pointer directDeclarator ')' // function pointer like: (__cdecl *f)
    ;

gccDeclaratorExtension
    :   '__asm' '(' StringLiteral+ ')'
    |   gccAttributeSpecifier
    ;

gccAttributeSpecifier
    :   '__attribute__' '(' '(' gccAttributeList ')' ')'
    ;

gccAttributeList
    :   gccAttribute (',' gccAttribute)*
    |   // empty
    ;

gccAttribute
    :   ~(',' | '(' | ')') // relaxed def for "identifier or reserved word"
        ('(' argumentExpressionList? ')')?
    |   // empty
    ;

// Not referenced by any other rule in this grammar.
nestedParenthesesBlock
    :   (   ~('(' | ')')
        |   '(' nestedParenthesesBlock ')'
        )*
    ;

pointer
    :   '*' typeQualifierList?
    |   '*' typeQualifierList? pointer
    |   '^' typeQualifierList? // Blocks language extension
    |   '^' typeQualifierList? pointer // Blocks language extension
    ;

typeQualifierList
    :   typeQualifier
    |   typeQualifierList typeQualifier
    ;

parameterTypeList
    :   parameterList
    |   parameterList ',' '...'
    ;

parameterList
    :   parameterDeclaration
    |   parameterList ',' parameterDeclaration
    ;

// Named parameters use declarationSpecifiers; unnamed (abstract) ones use
// declarationSpecifiers2.
parameterDeclaration
    :   declarationSpecifiers declarator
    |   declarationSpecifiers2 abstractDeclarator?
    ;

identifierList
    :   Identifier
    |   identifierList ',' Identifier
    ;

typeName
    :   specifierQualifierList abstractDeclarator?
    ;

abstractDeclarator
    :   pointer
    |   pointer? directAbstractDeclarator gccDeclaratorExtension*
    ;

directAbstractDeclarator
    :   '(' abstractDeclarator ')' gccDeclaratorExtension*
    |   '[' typeQualifierList? assignmentExpression? ']'
    |   '[' 'static' typeQualifierList? assignmentExpression ']'
    |   '[' typeQualifierList 'static' assignmentExpression ']'
    |   '[' '*' ']'
    |   '(' parameterTypeList? ')' gccDeclaratorExtension*
    |   directAbstractDeclarator '[' typeQualifierList? assignmentExpression? ']'
    |   directAbstractDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
    |   directAbstractDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
    |   directAbstractDeclarator '[' '*' ']'
    |   directAbstractDeclarator '(' parameterTypeList? ')' gccDeclaratorExtension*
    ;

typedefName
    :   Identifier
    ;

initializer
    :   assignmentExpression
    |   '{' initializerList '}'
    |   '{' initializerList ',' '}'
    ;

initializerList
    :   designation? initializer
    |   initializerList ',' designation? initializer
    ;

designation
    :   designatorList '='
    ;

designatorList
    :   designator
    |   designatorList designator
    ;

designator
    :   '[' constantExpression ']'
    |   '.' Identifier
    ;

staticAssertDeclaration
    :   '_Static_assert' '(' constantExpression ',' StringLiteral+ ')' ';'
    ;

// The last alternative is an inline-assembly statement: operand groups separated by ':'.
statement
    :   labeledStatement
    |   compoundStatement
    |   expressionStatement
    |   selectionStatement
    |   iterationStatement
    |   jumpStatement
    |   ('__asm' | '__asm__') ('volatile' | '__volatile__') '(' (logicalOrExpression (',' logicalOrExpression)*)? (':' (logicalOrExpression (',' logicalOrExpression)*)?)* ')' ';'
    ;

labeledStatement
    :   Identifier ':' statement
    |   'case' constantExpression ':' statement
    |   'default' ':' statement
    ;

compoundStatement
    :   '{' blockItemList? '}'
    ;

blockItemList
    :   blockItem
    |   blockItemList blockItem
    ;

blockItem
    :   statement
    |   declaration
    ;

expressionStatement
    :   expression? ';'
    ;

selectionStatement
    :   'if' '(' expression ')' statement ('else' statement)?
    |   'switch' '(' expression ')' statement
    ;

iterationStatement
    :   While '(' expression ')' statement
    |   Do statement While '(' expression ')' ';'
    |   For '(' forCondition ')' statement
    ;

// Earlier forms of the 'for' alternative, kept commented out:
//    |   'for' '(' expression? ';' expression?  ';' forUpdate? ')' statement
//    |   For '(' declaration  expression? ';' expression? ')' statement

forCondition
    :   forDeclaration ';' forExpression? ';' forExpression?
    |   expression? ';' forExpression? ';' forExpression?
    ;

forDeclaration
    :   declarationSpecifiers initDeclaratorList
    |   declarationSpecifiers
    ;

forExpression
    :   assignmentExpression
    |   forExpression ',' assignmentExpression
    ;

jumpStatement
    :   'goto' Identifier ';'
    |   'continue' ';'
    |   'break' ';'
    |   'return' expression? ';'
    |   'goto' unaryExpression ';' // GCC extension
    ;

// Entry rule. Input may end with EOF or with a single '}'.
// NOTE(review): the '}' is presumably the brace closing an `extern "C" {` block whose
// opening text is skipped by the Whitespace lexer rule; confirm.
compilationUnit
    :   translationUnit? ( EOF | '}' )
    ;

translationUnit
    :   externalDeclaration
    |   translationUnit externalDeclaration
    ;

externalDeclaration
    :   functionDefinition
    |   declaration
    |   ';' // stray ;
    ;

functionDefinition
    :   declarationSpecifiers? declarator declarationList? compoundStatement
    ;

declarationList
    :   declaration
    |   declarationList declaration
    ;

// ---- Lexer: keywords ----
Auto : 'auto';
Break : 'break';
Case : 'case';
Char : 'char';
Const : 'const';
Continue : 'continue';
Default : 'default';
Do : 'do';
Double : 'double';
Else : 'else';
Enum : 'enum';
Extern : 'extern';
Float : 'float';
For : 'for';
Goto : 'goto';
If : 'if';
Inline : 'inline';
Int : 'int';
Long : 'long';
Register : 'register';
Restrict : 'restrict';
Return : 'return';
Short : 'short';
Signed : 'signed';
Sizeof : 'sizeof';
Static : 'static';
Struct : 'struct';
Switch : 'switch';
Typedef : 'typedef';
Union : 'union';
Unsigned : 'unsigned';
Void : 'void';
Volatile : 'volatile';
While : 'while';

Alignas : '_Alignas';
Alignof : '_Alignof';
Atomic : '_Atomic';
Bool : '_Bool';
Complex : '_Complex';
Generic : '_Generic';
Imaginary : '_Imaginary';
Noreturn : '_Noreturn';
StaticAssert : '_Static_assert';
ThreadLocal : '_Thread_local';

// ---- Lexer: punctuation and operators ----
LeftParen : '(';
RightParen : ')';
LeftBracket : '[';
RightBracket : ']';
LeftBrace : '{';
RightBrace : '}';

Less : '<';
LessEqual : '<=';
Greater : '>';
GreaterEqual : '>=';
LeftShift : '<<';
RightShift : '>>';

Plus : '+';
PlusPlus : '++';
Minus : '-';
MinusMinus : '--';
Star : '*';
Div : '/';
Mod : '%';

And : '&';
Or : '|';
AndAnd : '&&';
OrOr : '||';
Caret : '^';
Not : '!';
Tilde : '~';

Question : '?';
Colon : ':';
Semi : ';';
Comma : ',';

Assign : '=';
// '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
StarAssign : '*=';
DivAssign : '/=';
ModAssign : '%=';
PlusAssign : '+=';
MinusAssign : '-=';
LeftShiftAssign : '<<=';
RightShiftAssign : '>>=';
AndAssign : '&=';
XorAssign : '^=';
OrAssign : '|=';

Equal : '==';
NotEqual : '!=';

Arrow : '->';
Dot : '.';
Ellipsis : '...';

// ---- Lexer: identifiers ----
Identifier
    :   IdentifierNondigit
        (   IdentifierNondigit
        |   Digit
        )*
    ;

fragment
IdentifierNondigit
    :   Nondigit
    |   UniversalCharacterName
    //|   // other implementation-defined characters...
    ;

fragment
Nondigit
    :   [a-zA-Z_]
    ;

fragment
Digit
    :   [0-9]
    ;

fragment
UniversalCharacterName
    :   '\\u' HexQuad
    |   '\\U' HexQuad HexQuad
    ;

fragment
HexQuad
    :   HexadecimalDigit HexadecimalDigit HexadecimalDigit HexadecimalDigit
    ;

// ---- Lexer: constants ----
Constant
    :   IntegerConstant
    |   FloatingConstant
    //|   EnumerationConstant
    |   CharacterConstant
    ;

fragment
IntegerConstant
    :   DecimalConstant IntegerSuffix?
    |   OctalConstant IntegerSuffix?
    |   HexadecimalConstant IntegerSuffix?
    |	BinaryConstant
    ;

fragment
BinaryConstant
	:	'0' [bB] [0-1]+
	;

fragment
DecimalConstant
    :   NonzeroDigit Digit*
    ;

fragment
OctalConstant
    :   '0' OctalDigit*
    ;

fragment
HexadecimalConstant
    :   HexadecimalPrefix HexadecimalDigit+
    ;

fragment
HexadecimalPrefix
    :   '0' [xX]
    ;

fragment
NonzeroDigit
    :   [1-9]
    ;

fragment
OctalDigit
    :   [0-7]
    ;

fragment
HexadecimalDigit
    :   [0-9a-fA-F]
    ;

fragment
IntegerSuffix
    :   UnsignedSuffix LongSuffix?
    |   UnsignedSuffix LongLongSuffix
    |   LongSuffix UnsignedSuffix?
    |   LongLongSuffix UnsignedSuffix?
    ;

fragment
UnsignedSuffix
    :   [uU]
    ;

fragment
LongSuffix
    :   [lL]
    ;

fragment
LongLongSuffix
    :   'll' | 'LL'
    ;

fragment
FloatingConstant
    :   DecimalFloatingConstant
    |   HexadecimalFloatingConstant
    ;

fragment
DecimalFloatingConstant
    :   FractionalConstant ExponentPart? FloatingSuffix?
    |   DigitSequence ExponentPart FloatingSuffix?
    ;

fragment
HexadecimalFloatingConstant
    :   HexadecimalPrefix HexadecimalFractionalConstant BinaryExponentPart FloatingSuffix?
    |   HexadecimalPrefix HexadecimalDigitSequence BinaryExponentPart FloatingSuffix?
    ;

fragment
FractionalConstant
    :   DigitSequence? '.' DigitSequence
    |   DigitSequence '.'
    ;

fragment
ExponentPart
    :   'e' Sign? DigitSequence
    |   'E' Sign? DigitSequence
    ;

fragment
Sign
    :   '+' | '-'
    ;

// Not a fragment: DigitSequence is also emitted as a token of its own and is referenced
// directly by the parser rules castExpression, assignmentExpression and directDeclarator.
DigitSequence
    :   Digit+
    ;

fragment
HexadecimalFractionalConstant
    :   HexadecimalDigitSequence? '.' HexadecimalDigitSequence
    |   HexadecimalDigitSequence '.'
    ;

fragment
BinaryExponentPart
    :   'p' Sign? DigitSequence
    |   'P' Sign? DigitSequence
    ;

fragment
HexadecimalDigitSequence
    :   HexadecimalDigit+
    ;

fragment
FloatingSuffix
    :   'f' | 'l' | 'F' | 'L'
    ;

fragment
CharacterConstant
    :   '\'' CCharSequence '\''
    |   'L\'' CCharSequence '\''
    |   'u\'' CCharSequence '\''
    |   'U\'' CCharSequence '\''
    ;

fragment
CCharSequence
    :   CChar+
    ;

fragment
CChar
    :   ~['\\\r\n]
    |   EscapeSequence
    ;

fragment
EscapeSequence
    :   SimpleEscapeSequence
    |   OctalEscapeSequence
    |   HexadecimalEscapeSequence
    |   UniversalCharacterName
    ;

fragment
SimpleEscapeSequence
    :   '\\' ['"?abfnrtv\\]
    ;

fragment
OctalEscapeSequence
    :   '\\' OctalDigit
    |   '\\' OctalDigit OctalDigit
    |   '\\' OctalDigit OctalDigit OctalDigit
    ;

fragment
HexadecimalEscapeSequence
    :   '\\x' HexadecimalDigit+
    ;

// ---- Lexer: string literals ----
StringLiteral
    :   EncodingPrefix? '"' SCharSequence? '"'
    ;

fragment
EncodingPrefix
    :   'u8'
    |   'u'
    |   'U'
    |   'L'
    ;

fragment
SCharSequence
    :   SChar+
    ;

// The last two alternatives allow a backslash followed by a line break inside a string.
fragment
SChar
    :   ~["\\\r\n]
    |   EscapeSequence
    |   '\\\n'   // Added line
    |   '\\\r\n' // Added line
    ;

// ---- Lexer: skipped input ----
// Everything from '#' to the end of the line is dropped.
LineAfterPreprocessing
    :   '#' ~[\r\n]*
        -> skip
    ;

// Besides blanks and tabs, this rule also drops the literal texts MXNET_DLL, NNVM_DLL and
// `extern "C" {`, and any `DEFAULT ( ... )` group up to the first ')'.
Whitespace
    :   ( [ \t]
        | 'MXNET_DLL'
        | 'NNVM_DLL'
        | 'extern "C" {'
        | 'DEFAULT' Whitespace? '(' .*? ')'
        )+
        -> skip
    ;

Newline
    :   (   '\r' '\n'?
        |   '\n'
        )
        -> skip
    ;

BlockComment
    :   '/*' .*? '*/'
        -> skip
    ;

LineComment
    :   '//' ~[\r\n]*
        -> skip
    ;
+ */ +package org.apache.mxnet.jnarator; + +import java.util.ArrayList; +import java.util.List; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.apache.mxnet.jnarator.parser.CParser; + +public final class AntlrUtils { + + private AntlrUtils() {} + + public static boolean isTypeDef(CParser.DeclarationSpecifiersContext specs) { + if (specs.isEmpty()) { + return false; + } + + CParser.DeclarationSpecifierContext spec = + (CParser.DeclarationSpecifierContext) specs.getChild(0); + CParser.StorageClassSpecifierContext storage = spec.storageClassSpecifier(); + if (storage != null) { + return storage.Typedef() != null; + } + return false; + } + + public static String getTypeDefValue(CParser.DeclarationSpecifiersContext specs) { + List list = new ArrayList<>(); + for (int i = 1; i < specs.getChildCount(); ++i) { + list.add(specs.getChild(i).getText()); + } + return String.join(" ", list); + } + + public static boolean isEnum(CParser.DeclarationSpecifiersContext specs) { + if (specs.isEmpty()) { + return false; + } + + CParser.DeclarationSpecifierContext spec = + (CParser.DeclarationSpecifierContext) specs.getChild(0); + CParser.TypeSpecifierContext type = spec.typeSpecifier(); + if (type == null) { + return false; + } + return type.enumSpecifier() != null; + } + + public static boolean isStructOrUnion(CParser.DeclarationSpecifiersContext specs) { + if (specs.isEmpty()) { + return false; + } + + CParser.DeclarationSpecifierContext spec = + (CParser.DeclarationSpecifierContext) specs.getChild(0); + CParser.TypeSpecifierContext type = spec.typeSpecifier(); + if (type == null) { + return false; + } + return type.structOrUnionSpecifier() != null; + } + + public static String getText(ParseTree tree) { + StringBuilder sb = new StringBuilder(); + getText(sb, tree); + return sb.toString(); + } + + private static void getText(StringBuilder sb, ParseTree tree) { + if (tree instanceof TerminalNode) { + sb.append("\"v\" : 
\"").append(tree.getText()).append('"'); + return; + } + sb.append('"'); + sb.append(tree.getClass().getSimpleName()).append("\" : {"); + for (int i = 0; i < tree.getChildCount(); i++) { + getText(sb, tree.getChild(i)); + if (i < tree.getChildCount() - 1) { + sb.append(','); + } + } + sb.append('}'); + } + + public static String toCamelCase(String name) { + String[] tokens = name.split("_"); + for (int i = 0; i < tokens.length; ++i) { + char upper = Character.toUpperCase(tokens[i].charAt(0)); + tokens[i] = upper + tokens[i].substring(1); // NOPMD + } + return String.join("", tokens); + } +} diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java new file mode 100644 index 000000000000..3b1c1fe1e58c --- /dev/null +++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mxnet.jnarator; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.apache.mxnet.jnarator.parser.CParser; + +public class DataType { + + private boolean isConst; + private boolean functionPointer; + private int pointerCount; + private StringBuilder type = new StringBuilder(); // NOPMD + + public boolean isConst() { + return isConst; + } + + public void setConst() { + isConst = true; + } + + public boolean isFunctionPointer() { + return functionPointer; + } + + public void setFunctionPointer(boolean functionPointer) { + this.functionPointer = functionPointer; + } + + public int getPointerCount() { + return pointerCount; + } + + public void setPointerCount(int pointerCount) { + this.pointerCount = pointerCount; + } + + public void increasePointerCount() { + ++pointerCount; + } + + public String getType() { + return type.toString(); + } + + public void setType(String typeName) { + type.setLength(0); + type.append(typeName); + } + + public void appendTypeName(String name) { + if (type.length() > 0) { + type.append(' '); + } + type.append(name); + } + + public String map(Map map, Set structs) { + String typeName = type.toString().trim(); + TypeDefine typeDefine = map.get(typeName); + boolean isStruct = structs.contains(typeName); + if (typeDefine != null && !typeDefine.isCallBack()) { + typeName = typeDefine.getValue(); + + String mapped = typeName.replaceAll("const ", "").replaceAll(" const", ""); + if (typeName.length() - mapped.length() > 0) { + isConst = true; + } + typeName = mapped; + mapped = typeName.replaceAll("\\*", ""); + int count = typeName.length() - mapped.length(); + pointerCount += count; + typeName = mapped; + setType(typeName); + } + + if (pointerCount > 2) { + return "PointerByReference"; + } + + typeName = baseTypeMapping(typeName); + + if 
(pointerCount == 2) { + if (isConst && "char".equals(typeName)) { + return "String[]"; + } + return "PointerByReference"; + } + + if (pointerCount == 1) { + switch (typeName) { + case "byte": + return "ByteBuffer"; + case "NativeSize": + return "NativeSizeByReference"; + case "int": + if (isConst) { + return "int[]"; + } + return "IntBuffer"; + case "long": + if (isConst) { + return "long[]"; + } + return "LongBuffer"; + case "char": + if (isConst) { + return "String"; + } + return "ByteBuffer"; + case "float": + return "FloatBuffer"; + case "void": + return "Pointer"; + default: + if (isStruct) { + return typeName + ".ByReference"; + } + return "Pointer"; + } + } + return typeName; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + if (isConst) { + sb.append("const "); + } + sb.append(type); + if (pointerCount > 0) { + sb.append(' '); + for (int i = 0; i < pointerCount; ++i) { + sb.append('*'); + } + } + return sb.toString(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DataType dataType = (DataType) o; + return type.toString().equals(dataType.type.toString()); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(type); + } + + static DataType parse(ParseTree tree) { + DataType dataType = new DataType(); + parseTypeSpec(dataType, tree); + return dataType; + } + + static List parseDataTypes(List list) { + List ret = new ArrayList<>(); + DataType dataType = new DataType(); + for (CParser.DeclarationSpecifierContext spec : list) { + CParser.TypeQualifierContext qualifier = spec.typeQualifier(); + if (qualifier != null) { + String qualifierName = qualifier.getText(); + if ("const".equals(qualifierName)) { + dataType.setConst(); + } else { + dataType.appendTypeName(qualifierName); + } + continue; + } + + CParser.TypeSpecifierContext 
type = spec.typeSpecifier(); + parseTypeSpec(dataType, type); + ret.add(dataType); + dataType = new DataType(); + } + + return ret; + } + + private static void parseTypeSpec(DataType dataType, ParseTree tree) { + if (tree == null) { + return; + } + + if (tree instanceof CParser.StructOrUnionContext) { + return; + } + if (tree instanceof CParser.TypedefNameContext) { + if (dataType.getType().isEmpty()) { + dataType.appendTypeName(tree.getText()); + } + return; + } + + if (tree instanceof TerminalNode) { + String value = tree.getText(); + if ("const".equals(value)) { + dataType.setConst(); + } else if ("*".equals(value)) { + dataType.increasePointerCount(); + } else { + dataType.appendTypeName(value); + } + return; + } + + for (int i = 0; i < tree.getChildCount(); i++) { + parseTypeSpec(dataType, tree.getChild(i)); + } + } + + private static String baseTypeMapping(String type) { + switch (type) { + case "uint64_t": + case "int64_t": + case "long": + return "long"; + case "uint32_t": + case "unsigned int": + case "unsigned": + case "int": + return "int"; + case "bool": + return "byte"; + case "size_t": + return "NativeSize"; + case "char": + case "void": + case "float": + default: + return type; + } + } +} diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java new file mode 100644 index 000000000000..3cdc9c90bd6f --- /dev/null +++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.jnarator; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import org.apache.mxnet.jnarator.parser.CParser; + +public class FuncInfo { + + private String name; + private DataType returnType; + private List parameters = new ArrayList<>(); + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public DataType getReturnType() { + return returnType; + } + + public void setReturnType(DataType returnType) { + this.returnType = returnType; + } + + public List getParameters() { + return parameters; + } + + public void setParameters(List parameters) { + this.parameters = parameters; + } + + public void addParameter(Parameter parameter) { + parameters.add(parameter); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(returnType).append(' ').append(name).append('('); + if (parameters != null) { + boolean first = true; + for (Parameter param : parameters) { + if (first) { + first = false; + } else { + sb.append(", "); + } + sb.append(param); + } + } + + sb.append(");"); + return sb.toString(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FuncInfo funcInfo = (FuncInfo) o; + return name.equals(funcInfo.name) && parameters.equals(funcInfo.parameters); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return 
Objects.hash(name); + } + + static FuncInfo parse(CParser.DeclarationContext ctx) { + FuncInfo info = new FuncInfo(); + + List specs = + ctx.declarationSpecifiers().declarationSpecifier(); + List dataTypes = DataType.parseDataTypes(specs); + info.setReturnType(dataTypes.get(0)); + if (dataTypes.size() > 1) { + info.setName(dataTypes.get(1).getType()); + } + + CParser.InitDeclaratorContext init = ctx.initDeclaratorList().initDeclarator(); + CParser.DirectDeclaratorContext declarator = init.declarator().directDeclarator(); + + CParser.DirectDeclaratorContext name = declarator.directDeclarator(); + if (info.getName() == null) { + info.setName(name.getText()); + CParser.ParameterTypeListContext paramListCtx = declarator.parameterTypeList(); + if (paramListCtx != null) { + Parameter.parseParams(info.getParameters(), paramListCtx); + } + } else { + DataType dataType = new DataType(); + CParser.TypeSpecifierContext type = declarator.typeSpecifier(); + dataType.appendTypeName(type.getText()); + if (declarator.pointer() != null) { + dataType.increasePointerCount(); + } + Parameter param = new Parameter(dataType, name.getText()); + info.addParameter(param); + } + + return info; + } +} diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java new file mode 100644 index 000000000000..7be963d3d0ff --- /dev/null +++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.jnarator; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +public class JnaGenerator { + + private Path dir; + private String packageName; + private String libName; + private String className; + private Map typedefMap; + private Set structs; + private Properties mapping; + + public JnaGenerator( + String libName, + String packageName, + Map typedefMap, + Set structs, + Properties mapping) { + this.libName = libName; + this.packageName = packageName; + this.typedefMap = typedefMap; + this.structs = structs; + this.mapping = mapping; + } + + public void init(String output) throws IOException { + String[] tokens = packageName.split("\\."); + dir = Paths.get(output, tokens); + Files.createDirectories(dir); + className = AntlrUtils.toCamelCase(libName) + "Library"; + } + + @SuppressWarnings("PMD.UseConcurrentHashMap") + public void writeStructure(Map> structMap) throws IOException { + for (Map.Entry> entry : structMap.entrySet()) { + String name = entry.getKey(); + Path path = dir.resolve(name + ".java"); + try (BufferedWriter writer = Files.newBufferedWriter(path)) { + writer.append("package ").append(packageName).append(";\n\n"); + + Set importSet = new HashSet<>(); + 
importSet.add("com.sun.jna.Pointer"); + importSet.add("com.sun.jna.Structure"); + importSet.add("java.util.List"); + + Map fieldNames = new LinkedHashMap<>(); + for (TypeDefine typeDefine : entry.getValue()) { + String fieldName = typeDefine.getValue(); + String typeName; + if (typeDefine.isCallBack()) { + typeName = AntlrUtils.toCamelCase(fieldName) + "Callback"; + importSet.add("com.sun.jna.Callback"); + for (Parameter param : typeDefine.getParameters()) { + String type = param.getType().map(typedefMap, structs); + addImports(importSet, type); + } + } else { + typeName = typeDefine.getDataType().map(typedefMap, structs); + addImports(importSet, typeName); + } + fieldNames.put(fieldName, typeName); + } + + int fieldCount = fieldNames.size(); + if (fieldCount < 2) { + importSet.add("java.util.Collections"); + } else { + importSet.add("java.util.Arrays"); + } + + List imports = new ArrayList<>(importSet.size()); + imports.addAll(importSet); + Collections.sort(imports); + for (String imp : imports) { + writer.append("import ").append(imp).append(";\n"); + } + + writer.append("\npublic class ").append(name).append(" extends Structure {\n"); + if (fieldCount > 0) { + writer.write("\n"); + } + for (Map.Entry field : fieldNames.entrySet()) { + writer.append(" public ").append(field.getValue()).append(' '); + writer.append(field.getKey()).append(";\n"); + } + + writer.append("\n public ").append(name).append("() {\n"); + writer.append(" }\n"); + writer.append("\n public ").append(name).append("(Pointer peer) {\n"); + writer.append(" super(peer);\n"); + writer.append(" }\n"); + + writer.append("\n @Override\n"); + writer.append(" protected List getFieldOrder() {\n"); + switch (fieldNames.size()) { + case 0: + writer.append(" return Collections.emptyList();\n"); + break; + case 1: + writer.append(" return Collections.singletonList("); + String firstField = fieldNames.keySet().iterator().next(); + writer.append('"').append(firstField).append("\");\n"); + break; + default: + 
writer.append(" return Arrays.asList("); + boolean first = true; + for (String fieldName : fieldNames.keySet()) { + if (first) { + first = false; + } else { + writer.write(", "); + } + writer.append('"').append(fieldName).append('"'); + } + writer.append(");\n"); + break; + } + writer.append(" }\n"); + + for (TypeDefine typeDefine : entry.getValue()) { + String fieldName = typeDefine.getValue(); + String typeName = fieldNames.get(fieldName); + String getterName; + if (!typeDefine.isCallBack()) { + getterName = AntlrUtils.toCamelCase(fieldName); + } else { + getterName = typeName; + } + + writer.append("\n public void set").append(getterName).append('('); + writer.append(typeName).append(' ').append(fieldName).append(") {\n"); + writer.append(" this.").append(fieldName).append(" = "); + writer.append(fieldName).append(";\n"); + writer.append(" }\n"); + writer.append("\n public ").append(typeName).append(" get"); + writer.append(getterName).append("() {\n"); + writer.append(" return ").append(fieldName).append(";\n"); + writer.append(" }\n"); + } + + writer.append("\n public static final class ByReference extends "); + writer.append(name).append(" implements Structure.ByReference {}\n"); + + writer.append("\n public static final class ByValue extends "); + writer.append(name).append(" implements Structure.ByValue {}\n"); + + for (TypeDefine typeDefine : entry.getValue()) { + if (typeDefine.isCallBack()) { + DataType dataType = typeDefine.getDataType(); + String fieldName = typeDefine.getValue(); + + String callbackName = fieldNames.get(fieldName); + String returnType = mapping.getProperty(callbackName); + if (returnType == null) { + returnType = dataType.map(typedefMap, structs); + } + + writer.append("\n public interface ").append(callbackName); + writer.append(" extends Callback {\n"); + writer.append(" ").append(returnType).append(" apply("); + writeParameters(writer, fieldName, typeDefine.getParameters()); + writer.append(");\n"); + writer.append(" }\n"); + } + } 
+ + writer.append("}\n"); + } + } + } + + public void writeLibrary(Collection functions, Map> enumMap) + throws IOException { + try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve(className + ".java"))) { + writer.append("package ").append(packageName).append(";\n\n"); + + writer.append("import com.sun.jna.Callback;\n"); + writer.append("import com.sun.jna.Library;\n"); + writer.append("import com.sun.jna.Pointer;\n"); + writer.append("import com.sun.jna.ptr.PointerByReference;\n"); + writer.append("import java.nio.ByteBuffer;\n"); + writer.append("import java.nio.FloatBuffer;\n"); + writer.append("import java.nio.IntBuffer;\n"); + writer.append("import java.nio.LongBuffer;\n"); + + writer.append("\npublic interface ").append(className).append(" extends Library {\n\n"); + + for (Map.Entry> entry : enumMap.entrySet()) { + String name = entry.getKey(); + writer.append("\n enum ").append(name).append(" {\n"); + List fields = entry.getValue(); + int count = 0; + for (String field : fields) { + writer.append(" ").append(field); + if (++count < fields.size()) { + writer.append(','); + } + writer.append('\n'); + } + writer.append(" }\n"); + } + + for (TypeDefine typeDefine : typedefMap.values()) { + if (typeDefine.isCallBack()) { + String callbackName = typeDefine.getDataType().getType(); + String returnType = mapping.getProperty(callbackName); + if (returnType == null) { + returnType = typeDefine.getValue(); + } + writer.append("\n interface ").append(callbackName); + writer.append(" extends Callback {\n"); + writer.append(" ").append(returnType).append(" apply("); + writeParameters(writer, callbackName, typeDefine.getParameters()); + writer.append(");\n"); + writer.append(" }\n"); + } + } + + for (FuncInfo info : functions) { + writeFunction(writer, info); + } + writer.append("}\n"); + } + } + + public void writeNativeSize() throws IOException { + try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve("NativeSize.java"))) { + 
writer.append("package ").append(packageName).append(";\n\n"); + writer.append("import com.sun.jna.IntegerType;\n"); + writer.append("import com.sun.jna.Native;\n\n"); + + writer.append("public class NativeSize extends IntegerType {\n\n"); + writer.append(" private static final long serialVersionUID = 1L;\n\n"); + writer.append(" public static final int SIZE = Native.SIZE_T_SIZE;\n\n"); + writer.append(" public NativeSize() {\n"); + writer.append(" this(0);\n"); + writer.append(" }\n\n"); + writer.append(" public NativeSize(long value) {\n"); + writer.append(" super(SIZE, value);\n"); + writer.append(" }\n"); + writer.append("}\n"); + } + + Path path = dir.resolve("NativeSizeByReference.java"); + try (BufferedWriter writer = Files.newBufferedWriter(path)) { + writer.append("package ").append(packageName).append(";\n\n"); + writer.append("import com.sun.jna.ptr.ByReference;\n\n"); + writer.append("public class NativeSizeByReference extends ByReference {\n\n"); + writer.append(" public NativeSizeByReference() {\n"); + writer.append(" this(new NativeSize(0));\n"); + writer.append(" }\n\n"); + writer.append(" public NativeSizeByReference(NativeSize value) {\n"); + writer.append(" super(NativeSize.SIZE);\n"); + writer.append(" setValue(value);\n"); + writer.append(" }\n\n"); + writer.append(" public void setValue(NativeSize value) {\n"); + writer.append(" if (NativeSize.SIZE == 4) {\n"); + writer.append(" getPointer().setInt(0, value.intValue());\n"); + writer.append(" } else if (NativeSize.SIZE == 8) {\n"); + writer.append(" getPointer().setLong(0, value.longValue());\n"); + writer.append(" } else {\n"); + writer.append( + " throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n"); + writer.append(" }\n"); + writer.append(" }\n\n"); + writer.append(" public NativeSize getValue() {\n"); + writer.append(" if (NativeSize.SIZE == 4) {\n"); + writer.append(" return new NativeSize(getPointer().getInt(0));\n"); + writer.append(" } else if 
(NativeSize.SIZE == 8) {\n"); + writer.append(" return new NativeSize(getPointer().getLong(0));\n"); + writer.append(" } else {\n"); + writer.append( + " throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n"); + writer.append(" }\n"); + writer.append(" }\n"); + writer.append("}\n"); + } + } + + private void writeFunction(BufferedWriter writer, FuncInfo info) throws IOException { + String funcName = info.getName(); + String returnType = mapping.getProperty(funcName); + if (returnType == null) { + returnType = info.getReturnType().map(typedefMap, structs); + } + writer.append("\n ").append(returnType).append(' '); + writer.append(funcName).append('('); + writeParameters(writer, funcName, info.getParameters()); + writer.append(");\n"); + } + + private void writeParameters(BufferedWriter writer, String funcName, List parameters) + throws IOException { + if (parameters != null) { + boolean first = true; + for (Parameter param : parameters) { + if (first) { + first = false; + } else { + writer.append(", "); + } + String paramName = param.getName(); + String type = mapping.getProperty(funcName + '.' + paramName); + if (type == null) { + type = param.getType().map(typedefMap, structs); + } + if (!"void".equals(type)) { + writer.append(type).append(' '); + writer.append(paramName); + } + } + } + } + + private static void addImports(Set importSet, String typeName) { + switch (typeName) { + case "ByReference": + case "ByteByReference": + case "DoubleByReference": + case "FloatByReference": + case "IntByReference": + case "LongByReference": + case "NativeLongByReference": + case "PointerByReference": + case "ShortByReference": + importSet.add("com.sun.jna.ptr." + typeName); + break; + case "ByteBuffer": + case "DoubleBuffer": + case "FloatBuffer": + case "IntBuffer": + case "LongBuffer": + case "ShortBuffer": + importSet.add("java.nio." 
+ typeName); + break; + default: + break; + } + } +} diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java new file mode 100644 index 000000000000..16a350e8d474 --- /dev/null +++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mxnet.jnarator; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.antlr.v4.runtime.CharStreams; +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.ParseTreeWalker; +import org.apache.mxnet.jnarator.parser.CBaseListener; +import org.apache.mxnet.jnarator.parser.CLexer; +import org.apache.mxnet.jnarator.parser.CParser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JnaParser { + + static final Logger logger = LoggerFactory.getLogger(Main.class); + + Map> structMap; + Map> enumMap; + List functions; + Map typedefMap; + private Set functionNames; + + public JnaParser() { + structMap = new LinkedHashMap<>(); + enumMap = new LinkedHashMap<>(); + functions = new ArrayList<>(); + typedefMap = new LinkedHashMap<>(); + functionNames = new HashSet<>(); + } + + public void parse(String headerFile) { + try { + CLexer lexer = new CLexer(CharStreams.fromFileName(headerFile)); + CommonTokenStream tokens = new CommonTokenStream(lexer); + CParser parser = new CParser(tokens); + ParseTree tree = parser.compilationUnit(); + + ParseTreeWalker walker = new ParseTreeWalker(); + CBaseListener listener = + new CBaseListener() { + + /** {@inheritDoc} */ + @Override + public void enterDeclaration(CParser.DeclarationContext ctx) { + CParser.DeclarationSpecifiersContext specs = + ctx.declarationSpecifiers(); + CParser.InitDeclaratorListContext init = ctx.initDeclaratorList(); + + if (AntlrUtils.isTypeDef(specs)) { + TypeDefine value = TypeDefine.parse(init, specs); + typedefMap.put(value.getDataType().getType(), value); + } else if (AntlrUtils.isStructOrUnion(specs)) { + CParser.DeclarationSpecifierContext spec = + (CParser.DeclarationSpecifierContext) specs.getChild(0); + CParser.TypeSpecifierContext type = 
spec.typeSpecifier(); + CParser.StructOrUnionSpecifierContext struct = + type.structOrUnionSpecifier(); + String name = struct.Identifier().getText(); + List fields = new ArrayList<>(); + + CParser.StructDeclarationListContext list = + struct.structDeclarationList(); + parseStructFields(fields, list); + + structMap.put(name, fields); + } else if (AntlrUtils.isEnum(specs)) { + CParser.DeclarationSpecifierContext spec = + (CParser.DeclarationSpecifierContext) specs.getChild(0); + CParser.TypeSpecifierContext type = spec.typeSpecifier(); + CParser.EnumSpecifierContext enumSpecifierContext = + type.enumSpecifier(); + String name = enumSpecifierContext.Identifier().getText(); + List fields = new ArrayList<>(); + parseEnum(fields, ctx); + enumMap.put(name, fields); + } else { + FuncInfo info = FuncInfo.parse(ctx); + if (checkDuplicate(info)) { + logger.warn("Duplicate function: {}.", info.getName()); + } else { + functions.add(info); + } + } + } + }; + walker.walk(listener, tree); + } catch (IOException e) { + logger.error("", e); + } + } + + void parseStructFields(List fields, ParseTree tree) { + if (tree instanceof CParser.StructDeclarationContext) { + CParser.StructDeclarationContext ctx = (CParser.StructDeclarationContext) tree; + CParser.SpecifierQualifierListContext qualifierList = ctx.specifierQualifierList(); + DataType dataType = DataType.parse(qualifierList); + + TypeDefine typeDefine = new TypeDefine(); + fields.add(typeDefine); + + typeDefine.setDataType(dataType); + + CParser.StructDeclaratorListContext name = ctx.structDeclaratorList(); + if (name != null) { + typeDefine.setCallBack(true); + + CParser.DirectDeclaratorContext dd = + name.structDeclarator().declarator().directDeclarator(); + CParser.DirectDeclaratorContext nameCtx = + dd.directDeclarator().declarator().directDeclarator(); + String fieldName = nameCtx.getText(); + typeDefine.setValue(fieldName); + + CParser.ParameterTypeListContext paramListCtx = dd.parameterTypeList(); + if (paramListCtx != 
null) { + Parameter.parseParams(typeDefine.getParameters(), paramListCtx); + } + } else { + CParser.SpecifierQualifierListContext nameList = + qualifierList.specifierQualifierList(); + if (nameList.specifierQualifierList() != null) { + typeDefine.setValue(nameList.specifierQualifierList().getText()); + } else { + typeDefine.setValue(nameList.getText()); + } + } + return; + } + + for (int i = 0; i < tree.getChildCount(); i++) { + parseStructFields(fields, tree.getChild(i)); + } + } + + void parseEnum(List fields, ParseTree ctx) { + if (ctx instanceof CParser.EnumerationConstantContext) { + fields.add(ctx.getText()); + return; + } + + for (int i = 0; i < ctx.getChildCount(); i++) { + parseEnum(fields, ctx.getChild(i)); + } + } + + public Map> getStructMap() { + return structMap; + } + + public Map> getEnumMap() { + return enumMap; + } + + public List getFunctions() { + return functions; + } + + public Map getTypedefMap() { + return typedefMap; + } + + boolean checkDuplicate(FuncInfo function) { + if (!functionNames.add(function.getName())) { + for (FuncInfo info : functions) { + if (function.equals(info)) { + return true; + } + } + } + return false; + } +} diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java new file mode 100644 index 000000000000..39aa8fcdb924 --- /dev/null +++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.jnarator; + +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class Main { + + private static final Logger logger = LoggerFactory.getLogger(Main.class); + + private Main() {} + + public static void main(String[] args) { + Options options = Config.getOptions(); + try { + DefaultParser cmdParser = new DefaultParser(); + CommandLine cmd = cmdParser.parse(options, args, null, false); + Config config = new Config(cmd); + + String output = config.getOutput(); + String packageName = config.getPackageName(); + String library = config.getLibrary(); + String[] headerFiles = config.getHeaderFiles(); + String mappingFile = config.getMappingFile(); + + Path dir = Paths.get(output); + Files.createDirectories(dir); + + Properties mapping = new Properties(); + if (mappingFile != null) { + Path file = Paths.get(mappingFile); + if (Files.notExists(file)) { + logger.error("mapping file does not exists: {}", mappingFile); + System.exit(-1); // NOPMD + } + try (InputStream in = Files.newInputStream(file)) { + mapping.load(in); + } + } + + JnaParser jnaParser = new 
JnaParser(); + Map typedefMap = jnaParser.getTypedefMap(); + Map> structMap = jnaParser.getStructMap(); + JnaGenerator generator = + new JnaGenerator(library, packageName, typedefMap, structMap.keySet(), mapping); + generator.init(output); + + for (String headerFile : headerFiles) { + jnaParser.parse(headerFile); + } + + generator.writeNativeSize(); + generator.writeStructure(structMap); + generator.writeLibrary(jnaParser.getFunctions(), jnaParser.getEnumMap()); + } catch (ParseException e) { + HelpFormatter formatter = new HelpFormatter(); + formatter.setLeftPadding(1); + formatter.setWidth(120); + formatter.printHelp(e.getMessage(), options); + System.exit(-1); // NOPMD + } catch (Throwable t) { + logger.error("", t); + System.exit(-1); // NOPMD + } + } + + public static final class Config { + + private String library; + private String packageName; + private String output; + private String[] headerFiles; + private String mappingFile; + + public Config(CommandLine cmd) { + library = cmd.getOptionValue("library"); + packageName = cmd.getOptionValue("package"); + output = cmd.getOptionValue("output"); + headerFiles = cmd.getOptionValues("header"); + mappingFile = cmd.getOptionValue("mappingFile"); + } + + public static Options getOptions() { + Options options = new Options(); + options.addOption( + Option.builder("l") + .longOpt("library") + .hasArg() + .required() + .argName("LIBRARY") + .desc("library name") + .build()); + options.addOption( + Option.builder("p") + .longOpt("package") + .required() + .hasArg() + .argName("PACKAGE") + .desc("Java package name") + .build()); + options.addOption( + Option.builder("o") + .longOpt("output") + .required() + .hasArg() + .argName("OUTPUT") + .desc("output directory") + .build()); + options.addOption( + Option.builder("f") + .longOpt("header") + .required() + .hasArgs() + .argName("HEADER") + .desc("Header files") + .build()); + options.addOption( + Option.builder("m") + .longOpt("mappingFile") + .hasArg() + 
.argName("MAPPING_FILE") + .desc("Type mappingFile config file") + .build()); + return options; + } + + public String getLibrary() { + return library; + } + + public String getPackageName() { + return packageName; + } + + public String getOutput() { + return output; + } + + public String[] getHeaderFiles() { + return headerFiles; + } + + public String getMappingFile() { + return mappingFile; + } + } +} diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java new file mode 100644 index 000000000000..f46e5e7b7399 --- /dev/null +++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mxnet.jnarator; + +import java.util.List; +import java.util.Objects; +import org.antlr.v4.runtime.tree.ParseTree; +import org.apache.mxnet.jnarator.parser.CParser; + +public class Parameter { + + private DataType type; + private String name; + + public Parameter(DataType type, String name) { + this.type = type; + this.name = name; + } + + public DataType getType() { + return type; + } + + public String getName() { + return name; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return type.toString() + ' ' + name; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Parameter parameter = (Parameter) o; + return type.equals(parameter.type); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(type); + } + + static void parseParams(List params, ParseTree ctx) { + if (ctx instanceof CParser.ParameterDeclarationContext) { + CParser.ParameterDeclarationContext declarationContext = + (CParser.ParameterDeclarationContext) ctx; + CParser.DeclarationSpecifiersContext spec = declarationContext.declarationSpecifiers(); + DataType dataType; + if (spec == null) { + dataType = DataType.parse(declarationContext.declarationSpecifiers2()); + } else { + dataType = DataType.parse(spec); + } + + CParser.DeclaratorContext declarator = declarationContext.declarator(); + + String name; + if (declarator != null) { + CParser.PointerContext pointer = declarator.pointer(); + if (pointer != null) { + dataType.increasePointerCount(); + } + name = declarator.directDeclarator().getText(); + } else { + name = "arg" + (params.size() + 1); + } + + Parameter param = new Parameter(dataType, name); + params.add(param); + return; + } + for (int i = 0; i < ctx.getChildCount(); i++) { + parseParams(params, ctx.getChild(i)); + } + } +} diff --git 
a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java new file mode 100644 index 000000000000..1f30d21b28a2 --- /dev/null +++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mxnet.jnarator; + +import java.util.ArrayList; +import java.util.List; +import org.apache.mxnet.jnarator.parser.CParser; + +/** A parsed type definition: either a plain type alias or a callback (function pointer) type. */ +public class TypeDefine { + + private DataType dataType; + private boolean callBack; + private String value; + private List<Parameter> parameters = new ArrayList<>(); + + /** + * Returns the data type that carries the defined type name. + * + * @return the data type + */ + public DataType getDataType() { + return dataType; + } + + /** + * Sets the data type that carries the defined type name. + * + * @param dataType the data type + */ + public void setDataType(DataType dataType) { + this.dataType = dataType; + } + + /** + * Returns whether this definition describes a callback. + * + * @return {@code true} if this definition is a callback + */ + public boolean isCallBack() { + return callBack; + } + + /** + * Marks this definition as a callback or not. + * + * @param callBack {@code true} if this definition is a callback + */ + public void setCallBack(boolean callBack) { + this.callBack = callBack; + } + + /** + * Returns the space-joined text of the declaration specifiers. + * + * @return the value + */ + public String getValue() { + return value; + } + + /** + * Sets the text of the declaration specifiers. + * + * @param value the value + */ + public void setValue(String value) { + this.value = value; + } + + /** + * Returns the callback parameters; the list is empty for a non-callback definition. + * + * @return the mutable parameter list + */ + public List<Parameter> getParameters() { + return parameters; + } + + /** + * Builds a {@code TypeDefine} from a parsed declaration. + * + * @param init the init-declarator list that holds the declared name + * @param specs the declaration specifiers; the first child is skipped (presumably the + * {@code typedef} keyword) and the remaining children, joined by single spaces, become + * the value + * @return the parsed type definition + */ + static TypeDefine parse( + CParser.InitDeclaratorListContext init, CParser.DeclarationSpecifiersContext specs) { + TypeDefine typeDefine = new TypeDefine(); + DataType dataType = new DataType(); + typeDefine.setDataType(dataType); + + CParser.DirectDeclaratorContext ctx = init.initDeclarator().declarator().directDeclarator(); + // a nested direct declarator is what marks the definition as a callback + CParser.DirectDeclaratorContext callback = ctx.directDeclarator(); + if (callback == null) { + dataType.setType(ctx.getText()); + } else { + typeDefine.setCallBack(true); + dataType.setType(callback.declarator().directDeclarator().getText()); + CParser.ParameterTypeListContext paramListCtx = ctx.parameterTypeList(); + List<Parameter> parameters = typeDefine.getParameters(); + Parameter.parseParams(parameters, paramListCtx); + } + + List<String> list = new ArrayList<>(); + for (int i = 1; i < specs.getChildCount(); ++i) { + list.add(specs.getChild(i).getText()); + } + + typeDefine.setValue(String.join(" ", list)); + return typeDefine; + } +} diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java new file mode 100644 index 000000000000..a6cff702d34b --- /dev/null +++ 
b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains classes to generate the Apache MXNet (incubating) native interface. */ +package org.apache.mxnet.jnarator; diff --git a/java-package/jnarator/src/main/resources/log4j2.xml b/java-package/jnarator/src/main/resources/log4j2.xml new file mode 100644 index 000000000000..4818a95eec9a --- /dev/null +++ b/java-package/jnarator/src/main/resources/log4j2.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + + + + + diff --git a/java-package/mxnet-engine/build.gradle b/java-package/mxnet-engine/build.gradle new file mode 100644 index 000000000000..c874f723ffc0 --- /dev/null +++ b/java-package/mxnet-engine/build.gradle @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java' +} + +group 'org.apache.mxnet' +version '0.0.1-SNAPSHOT' + +repositories { + mavenCentral() +} + +/** + * Maps the JVM "os.name" system property to the short OS tag used in the native jar file name. + * Returns "win", "osx" or "linux"; any other OS name is returned unchanged. + */ +def getOsName() { + def os_name = System.properties['os.name'] + // the JVM reports Windows as e.g. "Windows 10", so the check must not be case-sensitive + if (os_name.toLowerCase(Locale.ROOT).contains('windows')) { + return "win" + } else if (os_name.contains('Mac OS X')) { + return "osx" + } else if (os_name.contains('Linux')) { + return "linux" + } else { + return System.properties['os.name'] + } +} + +dependencies { + api "com.google.code.gson:gson:${gson_version}" + api "net.java.dev.jna:jna:${jna_version}" + api "org.apache.commons:commons-compress:${commons_compress_version}" + api "org.slf4j:slf4j-api:${slf4j_version}" + + testImplementation("org.testng:testng:${testng_version}") { + exclude group: "junit", module: "junit" + } + testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" + // Solve the problem: Failed to load class "org.slf4j.impl.StaticLoggerBinder". 
+ implementation "org.slf4j:slf4j-simple:${slf4j_version}" + def osName = getOsName() + implementation files("${project(':native').buildDir}/libs/native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar") +// implementation fileTree(dir: "${project(':native').buildDir}/lib", includes: ["native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar"]) +} + +sourceSets { + main { + java { + srcDirs = ['src/main/java', 'build/generated-src'] + } + } +} + +checkstyleMain.source = 'src/main/java' +pmdMain.source = 'src/main/java' + +task jnarator(dependsOn: ":jnarator:jar") { + outputs.dir "${project.buildDir}/generated-src" + doLast { + File jnaGenerator = project(":jnarator").jar.outputs.files.singleFile + javaexec { + main = "-jar" + args = [ + jnaGenerator.absolutePath, + "-l", + "mxnet", + "-p", + "org.apache.mxnet.jna", + "-o", + "${project.buildDir}/generated-src", + "-m", + "${project.projectDir}/src/main/jna/mapping.properties", + "-f", + "../../include/mxnet/c_api.h", + "../../include/nnvm/c_api.h" + ] + } + } +} + +test { + useTestNG() { + useDefaultListeners = true + } + environment "PATH", "src/test/bin:${environment.PATH}" +// environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}" + maxHeapSize = '6G' + testLogging.showStandardStreams = true + beforeTest { descriptor -> + logger.lifecycle("Running test: " + descriptor) + } + failFast = false + onOutput { descriptor, event -> + logger.lifecycle("Test: " + descriptor + " produced standard out/err: " + event.message ) + } +// debugOptions { +// enabled = true +// port = 4455 +// server = true +// suspend = true +// } +// filter { +// includeTestsMatching("*Test") +// } +} + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//import java.util.regex.Matcher +//import java.util.regex.Pattern + +//def checkForUpdate(String path, String url) { +// def expected = new URL(url).text +// def actual = new File("${project.projectDir}/src/main/include/${path}").text +// if (!actual.equals(expected)) { +// def fileName = path.replaceAll("[/\\\\]", '_') +// file("${project.projectDir}/build").mkdirs() +// (file("${project.projectDir}/build/${fileName}")).text = expected +// logger.warn("[\033[31mWARN\033[0m ] Header file has been changed in open source project: ${path}.") +// } +//} + +//task checkHeaderFile() { +// outputs.files "${project.buildDir}/mxnet_c_api.h", "${project.buildDir}/nnvm_c_api.h" +// doFirst { +// if (gradle.startParameter.offline) { +// logger.warn("[\033[31mWARN\033[0m ] Ignore header validation in offline mode.") +// return +// } +// +// def mxnetUrl = "https://raw.githubusercontent.com/apache/incubator-mxnet/v1.7.x/" +// checkForUpdate("mxnet/c_api.h", "${mxnetUrl}/include/mxnet/c_api.h") +// def content = new URL("https://github.com/apache/incubator-mxnet/tree/v1.7.x/3rdparty").text +// +// Pattern pattern = Pattern.compile("href=\"/apache/incubator-tvm/tree/([a-z0-9]+)\"") +// Matcher m = pattern.matcher(content); +// if (!m.find()) { +// throw new GradleException("Failed to retrieve submodule hash for tvm from github") +// } +// String hash = m.group(1); +// +// def nnvmUrl = 
"https://raw.githubusercontent.com/apache/incubator-tvm/${hash}" +// checkForUpdate("nnvm/c_api.h", "${nnvmUrl}/nnvm/include/nnvm/c_api.h") +// } +//} + +compileJava.dependsOn(jnarator) + +// TODO +//publishing { +// publications { +// maven(MavenPublication) { +// pom { +// name = "DJL Engine Adapter for Apache MXNet" +// description = "Deep Java Library (DJL) Engine Adapter for Apache MXNet" +// url = "http://www.djl.ai/mxnet/${project.name}" +// } +// } +// } +//} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java new file mode 100644 index 000000000000..13ae88d21fad --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.engine; + +import org.apache.mxnet.jna.JnaUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The top-level {@link MxResource} instance, with no parent Resource to manage. 
The {@link + * BaseMxResource} instance will be lazy loaded when the first time called, like when {@link Model} + * instance is loaded for the first time. + */ +public final class BaseMxResource extends MxResource { + + private static final Logger logger = LoggerFactory.getLogger(BaseMxResource.class); + + /** The lazily created singleton; only assigned inside {@link #getSystemMxResource()}. */ + private static BaseMxResource systemMxResource; + + /** + * Creates the top-level resource, switches numpy mode to {@code GLOBAL_ON} and registers a + * JVM shutdown hook that calls {@code JnaUtils.waitAll()}. + */ + protected BaseMxResource() { + super(); + // Workaround MXNet engine lazy initialization issue + JnaUtils.getAllOpNames(); + + JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON); + + // Workaround MXNet shutdown crash issue + Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); // NOPMD + } + + /** + * Getter method for the singleton {@code systemMxResource} instance. The instance is created on + * the first call; the method is synchronized so that only one instance is ever created. + * + * @return The top-level {@link BaseMxResource} instance. + */ + public static synchronized BaseMxResource getSystemMxResource() { + if (systemMxResource == null) { + systemMxResource = new BaseMxResource(); + } + return systemMxResource; + } + + /** {@inheritDoc} */ + @Override + public void close() { + if (!getClosed()) { + // NOTE(review): %S upper-cases the formatted uid; confirm that %s was not intended + logger.debug(String.format("Start to free BaseMxResource instance: %S", this.getUid())); + // only clean sub resources + JnaUtils.waitAll(); + super.freeSubResources(); + setClosed(true); + logger.debug( + String.format("Finish to free BaseMxResource instance: %S", this.getUid())); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java new file mode 100644 index 000000000000..eef058e576a3 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.engine; + +import com.sun.jna.Pointer; +import java.util.List; +import java.util.Map; +import org.apache.mxnet.jna.JnaUtils; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDList; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.nn.Parameter; +import org.apache.mxnet.nn.SymbolBlock; +import org.apache.mxnet.util.Pair; +import org.apache.mxnet.util.PairList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The {@code CachedOp} is an internal helper that provides the core functionality to execute a + * {@link SymbolBlock}. + * + *

We don't recommend users interact with this class directly. Users should use {@link Predictor} + * instead. CachedOp is an operator that simplifies calling and analyzing the input shape. It + * requires minimum input to do inference because most of the information can be obtained from the + * model itself. + */ +public class CachedOp extends MxResource { + + private static final Logger logger = LoggerFactory.getLogger(CachedOp.class); + + private List<Parameter> parameters; + private PairList<String, Integer> dataIndices; + private Map<String, Integer> dataIndicesMap; + private List<Integer> paramIndices; + + /** + * Creates an instance of {@link CachedOp}. + * + *
 
It can be created by using {@link JnaUtils#createCachedOp(SymbolBlock, MxResource)} + * + * @param parent the MxResource object to manage this instance of CachedOp + * @param handle the C handle of the CachedOp + * @param parameters the parameter values + * @param paramIndices the parameters required by the model and their corresponding location + * @param dataIndices the input data names required by the model and their corresponding + * location + */ + public CachedOp( + MxResource parent, + Pointer handle, + List<Parameter> parameters, + List<Integer> paramIndices, + PairList<String, Integer> dataIndices) { + super(parent, handle); + this.parameters = parameters; + this.dataIndices = dataIndices; + this.paramIndices = paramIndices; + this.dataIndicesMap = dataIndices.toMap(); + } + + /** + * Assigns inputs to the empty locations of the input NDArray. + * + * @param data the input in {@link NDList} format + * @return an {@link NDList} + */ + public NDList forward(NDList data) { + // reset the input data index at the beginning + NDArray[] allInputsNDArray = new NDArray[parameters.size()]; + // check device of input + Device device = data.isEmpty() ? Device.defaultIfNull() : data.head().getDevice(); + // fill allInputsNDArray with parameter values on correct device + for (int index : paramIndices) { + Parameter parameter = parameters.get(index); + NDArray value = parameter.getArray(); + if (value == null) { + throw new NullPointerException("Failed to find parameter from parameterStore"); + } + value.setDevice(device); + allInputsNDArray[index] = value; + } + + // fill allInputsNDArray with data values + int index = 0; + for (NDArray array : data) { + // TODO: NDArray name doesn't match. 
To confirm the format of input name + // String inputName = array.getName().split(":")[1]; + String inputName = array.getName(); + // if inputName not provided, value will follow the default order + int idx = indexOf(inputName, index++); + allInputsNDArray[idx] = array; + } + + // check the input, set as Shape(batchSize) by default + for (Pair<String, Integer> pair : dataIndices) { + if (allInputsNDArray[pair.getValue()] == null) { + // TODO: Do we need to set default to the input? + long batchSize = data.head().getShape().get(0); + String key = pair.getKey(); + if (!"prob_label".equals(key) && !"softmax_label".equals(key)) { + logger.warn( + "Input " + + key + + " not found, set NDArray to Shape(" + + batchSize + + ") by default"); + } + // TODO: consider how to manage MxNDArray generated during inference + allInputsNDArray[pair.getValue()] = + NDArray.create(this, new Shape(batchSize), device); + } + } + NDArray[] result = JnaUtils.cachedOpInvoke(getParent(), getHandle(), allInputsNDArray); + return new NDList(result); + } + + /** {@inheritDoc} */ + @Override + public void close() { + if (!getClosed()) { + logger.debug(String.format("Start to free CachedOp instance: %S", this.getUid())); + super.freeSubResources(); + Pointer pointer = handle.getAndSet(null); + if (pointer != null) { + JnaUtils.freeCachedOp(pointer); + } + setClosed(true); + logger.debug(String.format("Finish to free CachedOp instance: %S", this.getUid())); + } + } + + /** + * Resolves the slot of a data input in the full input array. + * + * @param inputName the name of the input, or {@code null} to resolve by position + * @param position the position of the input in the data list, used when the name is null + * @return the index of the input in the full input array + * @throws IllegalArgumentException if {@code inputName} is not a known input name + */ + private int indexOf(String inputName, int position) { + if (inputName == null) { + return dataIndices.valueAt(position); + } + + Integer index = dataIndicesMap.get(inputName); + if (index == null) { + throw new IllegalArgumentException( + "Unknown input name: " + + inputName + + ", expected inputs: " + + dataIndicesMap.keySet().toString()); + } + return index; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java new file mode 
100644 index 000000000000..7fb74c6c6f9d --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.engine; + +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.mxnet.util.cuda.CudaUtils; + +/** + * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@link + * org.apache.mxnet.ndarray.NDArray}. + * + *

Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with + * deviceType and deviceId provided. + */ +public final class Device { + + private static final Map CACHE = new ConcurrentHashMap<>(); + + private static final Device CPU = new Device(Type.CPU, -1); + + private static final Device GPU = Device.of(Type.GPU, 0); + + private String deviceType; + + private int deviceId; + + private static final Device DEFAULT_DEVICE = CPU; + + /** + * Creates a {@code Device} with basic information. + * + * @param deviceType the device type, typically CPU or GPU + * @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can + * choose which GPU to process the NDArray + */ + private Device(String deviceType, int deviceId) { + this.deviceType = deviceType; + this.deviceId = deviceId; + } + + /** + * Returns a {@code Device} with device type and device id. + * + * @param deviceType the device type, typically CPU or GPU + * @param deviceId the deviceId on the hardware. + * @return a {@code Device} instance + */ + public static Device of(String deviceType, int deviceId) { + if (Type.CPU.equals(deviceType)) { + return CPU; + } + String key = deviceType + '-' + deviceId; + return CACHE.computeIfAbsent(key, k -> new Device(deviceType, deviceId)); + } + + /** + * Returns the device type of the Device. + * + * @return the device type of the Device + */ + public String getDeviceType() { + return deviceType; + } + + /** + * Returns the {@code deviceId} of the Device. 
+ * + * @return the {@code deviceId} of the Device + */ + public int getDeviceId() { + if (Type.CPU.equals(deviceType)) { + throw new IllegalStateException("CPU doesn't have device id"); + } + return deviceId; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + if (Type.CPU.equals(deviceType)) { + return deviceType + "()"; + } + return deviceType + '(' + deviceId + ')'; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Device device = (Device) o; + if (Type.CPU.equals(deviceType)) { + return Objects.equals(deviceType, device.getDeviceType()); + } + return deviceId == device.deviceId && Objects.equals(deviceType, device.deviceType); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(deviceType, deviceId); + } + + /** + * Returns the default CPU Device. + * + * @return the default CPU Device + */ + public static Device cpu() { + return CPU; + } + + /** + * Returns the default GPU Device. + * + * @return the default GPU Device + */ + public static Device gpu() { + return GPU; + } + + /** + * Returns a new instance of GPU {@code Device} with the specified {@code deviceId}. + * + * @param deviceId the GPU device ID + * @return a new instance of GPU {@code Device} with specified {@code deviceId} + */ + public static Device gpu(int deviceId) { + return of(Type.GPU, deviceId); + } + + /** + * Returns an array of devices. + * + *

If GPUs are available, it will return an array of {@code Device} of size + * \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device. + * + * @return an array of devices + */ + public static Device[] getDevices() { + return getDevices(Integer.MAX_VALUE); + } + + /** + * Returns an array of devices given the maximum number of GPUs to use. + * + *

If GPUs are available, it will return an array of {@code Device} of size + * \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device. + * + * @param maxGpus the max number of GPUs to use. Use 0 for no GPUs. + * @return an array of devices + */ + public static Device[] getDevices(int maxGpus) { + int count = getGpuCount(); + if (maxGpus <= 0 || count <= 0) { + return new Device[] {CPU}; + } + + count = Math.min(maxGpus, count); + Device[] devices = new Device[count]; + for (int i = 0; i < devices.length; ++i) { + devices[i] = gpu(i); + } + return devices; + } + + /** + * Returns the number of GPUs available in the system. + * + * @return the number of GPUs available in the system + */ + public static int getGpuCount() { + return CudaUtils.getGpuCount(); + } + + /** + * Returns the default context used in Engine. + * + *

The default type is defined by whether the deep learning engine is recognizing GPUs + * available on your machine. If there is no GPU available, CPU will be used. + * + * @return a {@link Device} + */ + private static Device defaultDevice() { + return DEFAULT_DEVICE; + } + + /** + * Returns the given device or the default if it is null. + * + * @param device the device to try to return + * @return the given device or the default if it is null + */ + public static Device defaultIfNull(Device device) { + if (device != null) { + return device; + } + return defaultDevice(); + } + + /** + * Returns the default device. + * + * @return the default device + */ + public static Device defaultIfNull() { + return defaultIfNull(null); + } + + /** Contains device type string constants. */ + public interface Type { + String CPU = "cpu"; + String GPU = "gpu"; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java new file mode 100644 index 000000000000..02ad120e8179 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.engine; + +/** {@code DeviceType} is a class used to map the Device name to their corresponding type number. */ +public final class DeviceType { + + private static final String CPU_PINNED = "cpu_pinned"; + + private DeviceType() {} + + /** + * Converts a {@link Device} to the corresponding MXNet device number. + * + * @param device the java {@link Device} + * @return the MXNet device number: 1 for cpu, 2 for gpu, 3 for cpu_pinned + * @exception IllegalArgumentException the device is null or is not supported + */ + public static int toDeviceType(Device device) { + if (device == null) { + throw new IllegalArgumentException("Unsupported device: null"); + } + + String deviceType = device.getDeviceType(); + + if (Device.Type.CPU.equals(deviceType)) { + return 1; + } else if (Device.Type.GPU.equals(deviceType)) { + return 2; + } else if (CPU_PINNED.equals(deviceType)) { + return 3; + } else { + throw new IllegalArgumentException("Unsupported device: " + device.toString()); + } + } + + /** + * Converts from an MXNet device number to the corresponding device type string. 
 + * + * @param deviceType the MXNet device number + * @return the corresponding device type string, one of the {@link Device.Type} constants + * @exception IllegalArgumentException the device number is not 1, 2 or 3 + */ + public static String fromDeviceType(int deviceType) { + switch (deviceType) { + case 1: + case 3: + // hide the CPU_PINNED to frontend user + // but the advance user can still create CPU_PINNED + // to pass through engine + return Device.Type.CPU; + case 2: + return Device.Type.GPU; + default: + throw new IllegalArgumentException("Unsupported deviceType: " + deviceType); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java new file mode 100644 index 000000000000..c388de38297d --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.engine; + +/** An enum that indicates whether gradient is required. */ +public enum GradReq { + NULL("null", 0), + WRITE("write", 1), + ADD("add", 3); + + private String type; + private int value; + + GradReq(String type, int value) { + this.type = type; + this.value = value; + } + + /** + * Gets the type of this {@code GradReq}. 
 + * + * @return the type + */ + public String getType() { + return type; + } + + /** + * Gets the value of this {@code GradReq}. + * + * @return the value + */ + public int getValue() { + return value; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java new file mode 100644 index 000000000000..90cceb4db0b9 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.engine; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.mxnet.jna.JnaUtils; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDList; +import org.apache.mxnet.ndarray.types.DataType; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.nn.Parameter; +import org.apache.mxnet.nn.SymbolBlock; +import org.apache.mxnet.repository.Item; +import org.apache.mxnet.repository.Repository; +import org.apache.mxnet.translate.NoOpTranslator; +import org.apache.mxnet.translate.Translator; +import org.apache.mxnet.util.PairList; +import org.apache.mxnet.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A model is a collection of artifacts that is created by the training process. + * + *

Model contains methods to load and process a model object. In addition, it provides MXNet
+ * specific functionality, such as {@link #getSymbolBlock()} to obtain the symbolic graph whose
+ * parameters are populated from the parameter file when the model is loaded.
+ */
+public class Model extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Model.class);
+    protected Path modelDir;
+    protected SymbolBlock symbolBlock;
+    protected String modelName;
+    protected DataType dataType;
+    // NOTE(review): the type arguments of this PairList are not recoverable from this file; the
+    // otherwise unused Shape import suggests PairList<String, Shape> - confirm.
+    protected PairList inputData;
+    protected Map<String, Object> artifacts = new ConcurrentHashMap<>();
+    protected Map<String, String> properties = new ConcurrentHashMap<>();
+
+    /**
+     * Creates a model that is managed by the system level {@link BaseMxResource}.
+     *
+     * @param name the model name
+     * @param device the device to use, or {@code null} for the default device
+     */
+    Model(String name, Device device) {
+        this(BaseMxResource.getSystemMxResource(), name, device);
+    }
+
+    private Model(MxResource parent, String name, Device device) {
+        super(parent);
+        setDevice(Device.defaultIfNull(device));
+        setDataType(DataType.FLOAT32);
+        setModelName(name);
+    }
+
+    /**
+     * Creates a default {@link Predictor} instance that uses a {@link NoOpTranslator}.
+     *
+     * @return {@link Predictor}
+     */
+    public Predictor<NDList, NDList> newPredictor() {
+        Translator<NDList, NDList> noOpTranslator = new NoOpTranslator();
+        return newPredictor(noOpTranslator);
+    }
+
+    /**
+     * Creates a {@link Predictor} instance that uses the given {@link Translator}.
+     *
+     * @param translator {@link Translator} used to convert inputs and outputs into {@link NDList}
+     *     to get inferred
+     * @param <I> the input type
+     * @param <O> the output type
+     * @return {@link Predictor}
+     */
+    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
+        return new Predictor<>(this, translator);
+    }
+
+    /**
+     * Creates and initializes a model named {@code "model"} from the model directory.
+     *
+     * @param modelPath {@code Path} model directory
+     * @return loaded {@code Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Path modelPath) throws IOException {
+        return loadModel("model", modelPath);
+    }
+
+    /**
+     * Creates and initializes a model from a repository {@link Item}.
+     *
+     * @param modelItem {@link Item} describing the model
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Item modelItem) throws IOException {
+        Model model = createModel(modelItem);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates and initializes a model with a model name from the model directory.
+     *
+     * @param modelName {@link String} model name
+     * @param modelPath {@link Path} model directory
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(String modelName, Path modelPath) throws IOException {
+        Model model = createModel(modelName, modelPath);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates a model with a specific model name and model directory without loading it. By
+     * default, the {@link Model} instance is managed by the top level {@link BaseMxResource}.
+     *
+     * @param modelName {@link String} model name
+     * @param modelDir {@link Path} local model path
+     * @return {@link Model}
+     */
+    static Model createModel(String modelName, Path modelDir) {
+        Model model = new Model(modelName, Device.defaultIfNull());
+        model.setModelDir(modelDir);
+        return model;
+    }
+
+    /**
+     * Creates a model for a repository {@link Item}. The model directory is whatever {@link
+     * Repository#initRepository(Item)} returns for the item.
+     *
+     * @param item {@link Item} sample model to be created
+     * @return created {@link Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    static Model createModel(Item item) throws IOException {
+        Path modelDir = Repository.initRepository(item);
+        return createModel(item.getName(), modelDir);
+    }
+
+    /**
+     * Initializes the model by loading the symbol and parameters from the model directory.
+     *
+     * @throws IOException when IO operation fails in loading a resource
+     * @throws FileNotFoundException if the model directory is not assigned
+     */
+    public void initial() throws IOException {
+        if (getModelDir() == null) {
+            throw new FileNotFoundException("Model path is not defined!");
+        }
+        load(getModelDir());
+    }
+
+    /**
+     * Loads the model from the {@code modelPath}.
+     *
+     * @param modelPath the directory or file path of the model location
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public void load(Path modelPath) throws IOException {
+        load(modelPath, null, null);
+    }
+
+    /**
+     * Loads the MXNet model from a specified location.
+     *
+     * <p>MXNet Model looks for {MODEL_NAME}-symbol.json and {MODEL_NAME}-{EPOCH}.params files in
+     * the specified directory. By default, it will pick up the latest epoch of the parameter file.
+     * However, users can explicitly specify an epoch to be loaded:
+     *
+     * <pre>{@code
+     * Map<String, String> options = new HashMap<>();
+     * options.put("epoch", "3");
+     * model.load(modelPath, "squeezenet", options);
+     * }</pre>
+     *
+     * <p>When no parameter file is found for {@code prefix}, the name of the model directory is
+     * tried as the prefix before giving up.
+     *
+     * @param modelPath the directory of the model
+     * @param prefix the model file name or path prefix, {@code null} to use the model name
+     * @param options load model options ("epoch", "MxOptimizeFor"), may be {@code null}
+     * @throws IOException Exception for file loading
+     */
+    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
+        modelDir = modelPath.toAbsolutePath();
+        if (prefix == null) {
+            prefix = modelName;
+        }
+        Path paramFile = paramPathResolver(prefix, options);
+        if (paramFile == null) {
+            // fall back to the directory name as the file prefix
+            prefix = modelDir.toFile().getName();
+            paramFile = paramPathResolver(prefix, options);
+            if (paramFile == null) {
+                throw new FileNotFoundException(
+                        "Parameter file with prefix: " + prefix + " not found in: " + modelDir);
+            }
+        }
+
+        if (getSymbolBlock() == null) {
+            // load MxSymbolBlock
+            Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
+            if (Files.notExists(symbolFile)) {
+                throw new FileNotFoundException(
+                        "Symbol file not found: "
+                                + symbolFile
+                                + ", please set block manually for imperative model.");
+            }
+
+            // TODO: change default name "data" to model-specific one
+            setMxSymbolBlock(SymbolBlock.createMxSymbolBlock(this, symbolFile));
+        }
+        loadParameters(paramFile);
+        // TODO: Check if Symbol has all names that params file have
+        if (options != null && options.containsKey("MxOptimizeFor")) {
+            String optimization = (String) options.get("MxOptimizeFor");
+            getSymbolBlock().optimizeFor(optimization, getDevice());
+        }
+    }
+
+    /**
+     * Resolves the parameter file {@code <prefix>-<epoch>.params} (epoch zero padded to four
+     * digits) inside the model directory.
+     *
+     * @param prefix the parameter file prefix
+     * @param options load options, may be {@code null}
+     * @return the parameter file path, or {@code null} if the epoch lookup reports that no file
+     *     was found
+     * @throws IOException when the epoch lookup fails for another IO reason
+     */
+    protected Path paramPathResolver(String prefix, Map<String, ?> options) throws IOException {
+        try {
+            int epoch = getEpoch(prefix, options);
+            return getModelDir()
+                    .resolve(String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch));
+        } catch (FileNotFoundException e) {
+            return null;
+        }
+    }
+
+    private int getEpoch(String prefix, Map<String, ?> options) throws IOException {
+        if (options != null) {
+            Object epochOption = options.get("epoch");
+            if (epochOption != null) {
+                // an explicit "epoch" option wins over directory scanning
+                return Integer.parseInt(epochOption.toString());
+            }
+        }
+        return Utils.getCurrentEpoch(getModelDir(), prefix);
+    }
+
+    /**
+     * Loads the arrays stored in {@code paramFile} into the parameters of the symbol block. The
+     * parameters that receive no array are registered as the input names of the block.
+     *
+     * @param paramFile the parameter file to read
+     * @throws IllegalArgumentException if an array is unnamed, its name lacks the ':' separator,
+     *     or it does not match a (not yet assigned) parameter of the symbol block
+     */
+    @SuppressWarnings("PMD.UseConcurrentHashMap")
+    private void loadParameters(Path paramFile) {
+
+        NDList paramNDlist = JnaUtils.loadNdArray(this, paramFile, getDevice());
+
+        List<Parameter> parameters = getSymbolBlock().getAllParameters();
+        Map<String, Parameter> map = new LinkedHashMap<>();
+        parameters.forEach(p -> map.put(p.getName(), p));
+
+        for (NDArray nd : paramNDlist) {
+            String key = nd.getName();
+            if (key == null) {
+                throw new IllegalArgumentException("Array names must be present in parameter file");
+            }
+
+            // stored names have the form "<prefix>:<parameter name>"
+            String[] parts = key.split(":", 2);
+            if (parts.length < 2) {
+                throw new IllegalArgumentException(
+                        "Malformed array name in parameter file, ':' separator is missing: "
+                                + key);
+            }
+            String paramName = parts[1];
+            Parameter parameter = map.remove(paramName);
+            if (parameter == null) {
+                throw new IllegalArgumentException(
+                        "Array '"
+                                + key
+                                + "' in "
+                                + paramFile
+                                + " does not match a parameter of the symbol block, or appears"
+                                + " more than once");
+            }
+            parameter.setArray(nd);
+        }
+        // whatever is left has no stored value and is therefore treated as a model input
+        getSymbolBlock().setInputNames(new ArrayList<>(map.keySet()));
+
+        // TODO: Find a better to infer model DataType from SymbolBlock.
+        dataType = paramNDlist.head().getDataType();
+        logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, dataType);
+    }
+
+    /**
+     * Get the modelDir from the Model.
+     *
+     * @return {@link Path} modelDir for the Model
+     */
+    public Path getModelDir() {
+        return modelDir;
+    }
+
+    /**
+     * Set the modelDir for the Model.
+     *
+     * @param modelDir {@link Path}
+     */
+    public void setModelDir(Path modelDir) {
+        this.modelDir = modelDir;
+    }
+
+    /**
+     * Get the symbolBlock of the Model.
+     *
+     * @return {@link SymbolBlock}
+     */
+    public SymbolBlock getSymbolBlock() {
+        return symbolBlock;
+    }
+
+    /**
+     * Set the symbolBlock for the Model.
+     *
+     * @param symbolBlock {@link SymbolBlock}
+     */
+    public void setMxSymbolBlock(SymbolBlock symbolBlock) {
+        this.symbolBlock = symbolBlock;
+    }
+
+    /**
+     * Get the name of the Model.
+     *
+     * @return modelName
+     */
+    public String getModelName() {
+        return modelName;
+    }
+
+    /**
+     * Set the model name for the Model.
+     *
+     * @param modelName for the Model
+     */
+    public final void setModelName(String modelName) {
+        this.modelName = modelName;
+    }
+
+    /**
+     * Get data type for the Model.
+     *
+     * @return {@link DataType}
+     */
+    public DataType getDataType() {
+        return dataType;
+    }
+
+    /**
+     * Set data type for the Model.
+     *
+     * @param dataType {@link DataType}
+     */
+    public final void setDataType(DataType dataType) {
+        this.dataType = dataType;
+    }
+
+    /**
+     * Get input data of the Model.
+     *
+     * @return {@link PairList} inputData
+     */
+    public PairList getInputData() {
+        return inputData;
+    }
+
+    /**
+     * Set input data for the Model.
+     *
+     * @param inputData {@link PairList}
+     */
+    public void setInputData(PairList inputData) {
+        this.inputData = inputData;
+    }
+
+    /**
+     * Get the Artifact Object from artifacts by key.
+     *
+     * @param key for the Artifact Object
+     * @return Artifact {@link Object} instance
+     */
+    public Object getArtifact(String key) {
+        return artifacts.get(key);
+    }
+
+    /**
+     * Set the Artifact Object for artifacts.
+     *
+     * @param key for the Artifact
+     * @param artifact {@link Object}
+     */
+    public void setArtifact(String key, Object artifact) {
+        artifacts.put(key, artifact);
+    }
+
+    /**
+     * Get the property from properties by key.
+     *
+     * @param key {@link String}
+     * @return {@link String} property
+     */
+    public String getProperty(String key) {
+        return properties.get(key);
+    }
+
+    /**
+     * Set the property for the Model.
+     *
+     * @param key for the property
+     * @param property value of the property
+     */
+    public void setProperties(String key, String property) {
+        this.properties.put(key, property);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Device getDevice() {
+        if (device == null) {
+            return super.getDevice();
+        }
+        return device;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug(String.format("Start to free Model instance: %S", this.getModelName()));
+            // release sub resources
+            super.freeSubResources();
+            // release itself; the artifact and property maps are unusable afterwards
+            this.symbolBlock = null;
+            this.artifacts = null;
+            this.properties = null;
+            setClosed(true);
+            logger.debug(String.format("Finish to free Model instance: %S", this.getModelName()));
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
new file mode 100644
index 000000000000..b96941f89177
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.ndarray.types.DataType;
+
+/** Helper to convert between {@link DataType} and the MXNet internal DataTypes. */
+public final class MxDataType {
+
+    // Lookup tables are filled once at class initialization and never reassigned.
+    private static final Map<DataType, String> toMx = createMapToMx();
+    private static final Map<String, DataType> fromMx = createMapFromMx();
+
+    private MxDataType() {}
+
+    /** Builds the {@link DataType} to MXNet type name table. */
+    private static Map<DataType, String> createMapToMx() {
+        Map<DataType, String> map = new ConcurrentHashMap<>();
+        map.put(DataType.FLOAT32, "float32");
+        map.put(DataType.FLOAT64, "float64");
+        map.put(DataType.INT32, "int32");
+        map.put(DataType.INT64, "int64");
+        map.put(DataType.UINT8, "uint8");
+        return map;
+    }
+
+    /** Builds the MXNet type name to {@link DataType} table, the inverse of the table above. */
+    private static Map<String, DataType> createMapFromMx() {
+        Map<String, DataType> map = new ConcurrentHashMap<>();
+        map.put("float32", DataType.FLOAT32);
+        map.put("float64", DataType.FLOAT64);
+        map.put("int32", DataType.INT32);
+        map.put("int64", DataType.INT64);
+        map.put("uint8", DataType.UINT8);
+        return map;
+    }
+
+    /**
+     * Converts a MXNet type String into a {@link DataType}.
+     *
+     * @param mxType the type String to convert
+     * @return the {@link DataType}, or {@code null} if the type String has no mapping
+     */
+    public static DataType fromMx(String mxType) {
+        return fromMx.get(mxType);
+    }
+
+    /**
+     * Converts a {@link DataType} into the corresponding MXNet type String.
+     *
+     * @param jType the java {@link DataType} to convert
+     * @return the converted MXNet type string, or {@code null} if the type has no mapping
+     */
+    public static String toMx(DataType jType) {
+        return toMx.get(jType);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
new file mode 100644
index 000000000000..5740d9b5341a
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.NativeResource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An auto closable resource whose life cycle can be managed by its parent {@link MxResource}
+ * instance. Meanwhile, it manages the life cycle of its child {@link MxResource} instances.
+ */
+// NOTE(review): if NativeResource is a generic class its type argument is missing here
+// (presumably NativeResource<Pointer>) - confirm against NativeResource.
+public class MxResource extends NativeResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(MxResource.class);
+
+    // Per-instance state: a static flag would mark every resource closed as soon as one closes.
+    private boolean closed;
+
+    protected Device device;
+
+    private MxResource parent;
+
+    // Created lazily by getSubResource(); null until the first child is registered.
+    private ConcurrentHashMap<String, MxResource> subResources;
+
+    /** Creates a resource without a parent that is not registered anywhere. */
+    protected MxResource() {
+        super();
+        setParent(null);
+    }
+
+    /**
+     * Creates a resource with the given uid and registers it under {@code parent}.
+     *
+     * @param parent the owning resource; when {@code null} the resource is registered under the
+     *     system resource, as the {@link Pointer} based constructor does
+     * @param uid the unique id of this resource
+     */
+    protected MxResource(MxResource parent, String uid) {
+        super(uid);
+        setClosed(false);
+        setParent(parent);
+        registerUnder(parent);
+    }
+
+    /**
+     * Creates a resource with a random uid and registers it under {@code parent}.
+     *
+     * @param parent the owning resource
+     */
+    protected MxResource(MxResource parent) {
+        this(parent, UUID.randomUUID().toString());
+    }
+
+    /**
+     * Creates a resource that wraps a native handle and registers it under {@code parent}.
+     *
+     * @param parent the owning resource; when {@code null} the resource is registered under the
+     *     system resource
+     * @param handle the native handle
+     */
+    protected MxResource(MxResource parent, Pointer handle) {
+        super(handle);
+        setParent(parent);
+        registerUnder(parent);
+    }
+
+    /** Registers this instance as a child of {@code owner}, or of the system resource if null. */
+    private void registerUnder(MxResource owner) {
+        if (owner != null) {
+            owner.addSubResource(this);
+        } else {
+            BaseMxResource.getSystemMxResource().addSubResource(this);
+        }
+    }
+
+    /**
+     * Add the sub {@link MxResource} under the current instance.
+     *
+     * @param subMxResource the instance to be added
+     */
+    public void addSubResource(MxResource subMxResource) {
+        getSubResource().put(subMxResource.getUid(), subMxResource);
+    }
+
+    /** Free all sub {@link MxResource} instances of the current instance. */
+    public void freeSubResources() {
+        if (subResourceInitialized()) {
+            for (MxResource subResource : subResources.values()) {
+                try {
+                    subResource.close();
+                } catch (Exception e) {
+                    // keep going so that one failing child does not leak its siblings
+                    logger.error("MxResource close failed.", e);
+                }
+            }
+            subResources = null;
+        }
+    }
+
+    /**
+     * Check whether {@code subResource} has been initialized.
+     *
+     * @return boolean
+     */
+    public boolean subResourceInitialized() {
+        return subResources != null;
+    }
+
+    /**
+     * Get the {@code subResources} of the {@link MxResource}, creating the map on first use.
+     *
+     * @return subResources
+     */
+    public ConcurrentHashMap<String, MxResource> getSubResource() {
+        if (!subResourceInitialized()) {
+            subResources = new ConcurrentHashMap<>();
+        }
+        return subResources;
+    }
+
+    protected final void setParent(MxResource parent) {
+        this.parent = parent;
+    }
+
+    /**
+     * Get parent {@link MxResource} of the current instance.
+     *
+     * @return {@link MxResource}
+     */
+    public MxResource getParent() {
+        return this.parent;
+    }
+
+    /**
+     * Set the {@link Device} for the {@link MxResource}.
+     *
+     * @param device {@link Device}
+     */
+    public void setDevice(Device device) {
+        this.device = device;
+    }
+
+    /**
+     * Returns the {@link Device} of this {@code MxResource}.
+     *
+     * <p>This base implementation does not read the {@code device} field: it asks the parent
+     * chain and falls back to {@link Device#defaultIfNull(Device)} when there is no parent.
+     *
+     * @return the {@link Device} of this {@code MxResource}
+     */
+    public Device getDevice() {
+        Device curDevice = getParent() == null ? null : getParent().getDevice();
+        return Device.defaultIfNull(curDevice);
+    }
+
+    /**
+     * Sets closed for MxResource instance.
+     *
+     * @param isClosed whether this {@link MxResource} get closed
+     */
+    public final void setClosed(boolean isClosed) {
+        this.closed = isClosed;
+    }
+
+    /**
+     * Get the attribute closed for the MxResource to check out whether it is closed.
+     *
+     * @return closed
+     */
+    public boolean getClosed() {
+        return closed;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        freeSubResources();
+        setClosed(true);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
new file mode 100644
index 000000000000..6869cf6f0738
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.mxnet.engine; + +import java.util.List; +import java.util.Map; +import org.apache.mxnet.util.Pair; +import org.apache.mxnet.util.PairList; + +/** + * An {@code MxResourceList} represents a sequence of {@link MxResource}s with names. + * + *

Each {@link MxResource} in this list can optionally have a name. You can use the name to look
+ * up an MxResource in the MxResourceList.
+ *
+ * @see MxResource
+ */
+public class MxResourceList extends PairList<String, MxResource> {
+
+    /** Creates an empty {@code MxResourceList}. */
+    public MxResourceList() {}
+
+    /**
+     * Constructs an empty {@code MxResourceList} with the specified initial capacity.
+     *
+     * @param initialCapacity the initial capacity of the list
+     * @throws IllegalArgumentException if the specified initial capacity is negative
+     */
+    public MxResourceList(int initialCapacity) {
+        super(initialCapacity);
+    }
+
+    /**
+     * Constructs a {@code MxResourceList} containing the elements of the specified keys and
+     * values.
+     *
+     * @param keys the key list containing the elements to be placed into this {@code
+     *     MxResourceList}
+     * @param values the value list containing the elements to be placed into this {@code
+     *     MxResourceList}
+     * @throws IllegalArgumentException if the keys and values size are different
+     */
+    public MxResourceList(List<String> keys, List<MxResource> values) {
+        super(keys, values);
+    }
+
+    /**
+     * Constructs a {@code MxResourceList} containing the elements of the specified list of Pairs.
+     *
+     * @param list the list containing the elements to be placed into this {@code MxResourceList}
+     */
+    public MxResourceList(List<Pair<String, MxResource>> list) {
+        super(list);
+    }
+
+    /**
+     * Constructs a {@code MxResourceList} containing the elements of the specified map.
+     *
+     * @param map the map containing keys and values
+     */
+    public MxResourceList(Map<String, MxResource> map) {
+        super(map);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
new file mode 100644
index 000000000000..3362e03217d7
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.PairList;
+
+/**
+ * An internal helper for creating the MXNet operator parameters. Every value is stored in its
+ * String form under the parameter name.
+ */
+// NOTE(review): the type arguments of PairList appear to be missing; every key and value added
+// here is a String, but the declared value type cannot be told from this file - confirm.
+public class OpParams extends PairList {
+    // mxnet cpu take index
+    private static final String MXNET_CPU = "cpu(0)";
+
+    /**
+     * Sets the Shape parameter.
+     *
+     * @param shape the shape to set, ignored when {@code null}
+     */
+    public void setShape(Shape shape) {
+        addParam("shape", shape);
+    }
+
+    /**
+     * Sets the device to use for the operation.
+     *
+     * @param device the device to use for the operation
+     */
+    public void setDevice(Device device) {
+        // a cpu device is always passed with an explicit index
+        setParam("ctx", ("cpu".equals(device.getDeviceType()) ? MXNET_CPU : device.toString()));
+    }
+
+    /**
+     * Sets the dataType to use for the operation.
+     *
+     * @param dataType the dataType to use for the operation, ignored when {@code null}
+     */
+    public void setDataType(org.apache.mxnet.ndarray.types.DataType dataType) {
+        if (dataType != null) {
+            setParam("dtype", MxDataType.toMx(dataType));
+        }
+    }
+
+    /**
+     * Sets the sparseFormat to use for the operation.
+     *
+     * @param sparseFormat the sparseFormat to use for the operation, ignored when {@code null}
+     */
+    public void setSparseFormat(SparseFormat sparseFormat) {
+        if (sparseFormat != null) {
+            setParam("stype", String.valueOf(sparseFormat.getValue()));
+        }
+    }
+
+    /**
+     * Sets a (potentially existing) parameter to a new value.
+     *
+     * @param paramName the parameter name to update
+     * @param value the value to set the parameter to
+     */
+    public void setParam(String paramName, String value) {
+        // drop any previous entry so that the name is not stored twice
+        remove(paramName);
+        add(paramName, value);
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param shape the value of the new parameter, ignored when {@code null}
+     */
+    public void addParam(String paramName, Shape shape) {
+        if (shape != null) {
+            add(paramName, shape.toString());
+        }
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, String value) {
+        add(paramName, value);
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, int value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, long value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, double value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, float value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter. The value is stored as {@code "True"} or {@code "False"}.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, boolean value) {
+        add(paramName, value ? "True" : "False");
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, Number value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter with tuple value, formatted as {@code (a, b, c)}.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, int... tuple) {
+        StringBuilder sb = new StringBuilder();
+        sb.append('(');
+        for (int i = 0; i < tuple.length; ++i) {
+            if (i > 0) {
+                sb.append(", ");
+            }
+            sb.append(tuple[i]);
+        }
+        sb.append(')');
+        add(paramName, sb.toString());
+    }
+
+    /**
+     * Adds a parameter with tuple value, formatted as {@code (a, b, c)}.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, long... tuple) {
+        StringBuilder sb = new StringBuilder();
+        sb.append('(');
+        for (int i = 0; i < tuple.length; ++i) {
+            if (i > 0) {
+                sb.append(", ");
+            }
+            sb.append(tuple[i]);
+        }
+        sb.append(')');
+        add(paramName, sb.toString());
+    }
+
+    /**
+     * Adds a parameter with tuple value, formatted as {@code (a, b, c)}.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, float... tuple) {
+        StringBuilder sb = new StringBuilder();
+        sb.append('(');
+        for (int i = 0; i < tuple.length; ++i) {
+            if (i > 0) {
+                sb.append(", ");
+            }
+            sb.append(tuple[i]);
+        }
+        sb.append(')');
+        add(paramName, sb.toString());
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
new file mode 100644
index 000000000000..739f6e0b65bf
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.mxnet.engine; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.apache.mxnet.exception.TranslateException; +import org.apache.mxnet.ndarray.NDList; +import org.apache.mxnet.translate.Translator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The {@code Predictor} class provides a session for model inference. + * + *

You can use a {@code Predictor}, with a specified {@link Translator}, to perform inference on
+ * a {@link Model}.
+ *
+ * @param <I> the input type
+ * @param <O> the output type
+ * @see Model
+ * @see Translator
+ * @see "The guide on memory management"
+ * @see "The guide on running multi-threaded inference"
+ * @see "The guide on inference performance optimization"
+ */
+public class Predictor<I, O> extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Predictor.class);
+    private final Translator<I, O> translator;
+    private final Model model;
+
+    /**
+     * Creates a new instance of {@code Predictor} with the given {@link Model} and {@link
+     * Translator}. The predictor is registered as a sub resource of the model.
+     *
+     * @param model the model on which the predictions are based
+     * @param translator the translator to be used
+     */
+    public Predictor(Model model, Translator<I, O> translator) {
+        super(model);
+        this.model = model;
+        this.translator = translator;
+    }
+
+    /**
+     * Predicts a list of items for inference. Each item is translated, forwarded through the
+     * model and translated back on its own; the items are not combined into one batch.
+     *
+     * @param input the inputs
+     * @return the output objects defined by the user, in the order of the inputs
+     * @throws TranslateException if an error occurs during prediction
+     */
+    @SuppressWarnings("PMD.AvoidRethrowingException")
+    public List<O> predict(List<I> input) {
+        NDList[] ndLists = processInputs(input);
+        for (int i = 0; i < ndLists.length; ++i) {
+            ndLists[i] = forward(ndLists[i]);
+        }
+        return processOutPut(ndLists);
+    }
+
+    /**
+     * Predicts an Item for inference.
+     *
+     * @param input input data
+     * @return O the output object defined by the user
+     * @throws TranslateException if an error occurs during prediction
+     */
+    public O predict(I input) {
+        return predict(Collections.singletonList(input)).get(0);
+    }
+
+    /** Runs one forward pass of the model's symbol block. */
+    private NDList forward(NDList ndList) {
+        logger.trace("Predictor input data: {}", ndList);
+        return model.getSymbolBlock().forward(ndList);
+    }
+
+    // TODO: add batch predict
+
+    /** Translates every input item into an {@link NDList}. */
+    private NDList[] processInputs(List<I> inputs) {
+        int batchSize = inputs.size();
+        NDList[] preprocessed = new NDList[batchSize];
+        try {
+            for (int i = 0; i < batchSize; ++i) {
+                preprocessed[i] = translator.processInput(inputs.get(i));
+            }
+        } catch (Exception e) {
+            logger.error("Error occurs when process input items.", e);
+            throw new TranslateException(e);
+        }
+        return preprocessed;
+    }
+
+    /** Translates every forward result back into an output object. */
+    private List<O> processOutPut(NDList[] ndLists) {
+        List<O> outputs = new ArrayList<>();
+        try {
+            for (NDList mxNDList : ndLists) {
+                outputs.add(translator.processOutput(mxNDList));
+            }
+        } catch (Exception e) {
+            logger.error("Error occurs when process output items.", e);
+            throw new TranslateException(e);
+        }
+        return outputs;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
new file mode 100644
index 000000000000..7777add3c5ba
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.engine; + +import com.sun.jna.Pointer; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import org.apache.mxnet.jna.JnaUtils; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.util.PairList; +import org.apache.mxnet.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@code Symbol} is an internal helper for symbolic model graphs used by the {@link + * org.apache.mxnet.nn.SymbolBlock}. + * + * @see org.apache.mxnet.nn.SymbolBlock + */ +public class Symbol extends MxResource { + + private static final Logger logger = LoggerFactory.getLogger(Symbol.class); + + private String[] outputs; + + protected Symbol(MxResource parent, Pointer handle) { + super(parent, handle); + } + + static Symbol loadFromFile(MxResource parent, String path) { + Pointer p = JnaUtils.createSymbolFromFile(path); + return new Symbol(parent, p); + } + + /** + * Load {@link Symbol} from the given {@link Path}. + * + * @param parent the parent {@link MxResource} + * @param path the {@link Path} to load the {@link Symbol} + * @return {@link Symbol} + */ + public static Symbol loadSymbol(MxResource parent, Path path) { + return loadFromFile(parent, path.toAbsolutePath().toString()); + } + + /** + * Loads a symbol from a json string. 
+ * + * @param parent the parent {@link MxResource} + * @param json the json string of the symbol. + * @return the new symbol + */ + public static Symbol loadJson(MxResource parent, String json) { + Pointer pointer = JnaUtils.createSymbolFromString(json); + return new Symbol(parent, pointer); + } + + /** + * Returns the symbol outputs. + * + * @return the symbol outputs + */ + public String[] getOutputNames() { + if (this.outputs == null) { + this.outputs = JnaUtils.listSymbolOutputs(getHandle()); + } + return this.outputs; + } + + /** {@inheritDoc} */ + @Override + public void close() { + if (!getClosed()) { + logger.debug(String.format("Start to free Symbol instance: %S", this.toJsonString())); + super.freeSubResources(); + Pointer pointer = handle.getAndSet(null); + if (pointer != null) { + JnaUtils.freeSymbol(pointer); + } + setClosed(true); + logger.debug(String.format("Finish to free Symbol instance: %S", this.toJsonString())); + } + } + + /** + * Returns the output symbol by index. + * + * @param index the index of the output + * @return the symbol output as a new symbol + */ + public Symbol get(int index) { + Pointer pointer = JnaUtils.getSymbolOutput(getInternals().getHandle(), index); + return new Symbol(getParent(), pointer); + } + + /** + * Returns the output symbol with the given name. + * + * @param name the name of the symbol to return + * @return the output symbol + * @throws IllegalArgumentException Thrown if no output matches the name + */ + public Symbol get(String name) { + String[] out = getInternalOutputNames(); + int index = Utils.indexOf(out, name); + if (index < 0) { + throw new IllegalArgumentException("Cannot find output that matches name: " + name); + } + return get(index); + } + + /** + * Returns the symbol argument names. + * + * @return the symbol argument names + */ + public String[] getArgNames() { + return JnaUtils.listSymbolArguments(getHandle()); + } + + /** + * Returns the MXNet auxiliary states for the symbol. 
+ * + * @return the MXNet auxiliary states for the symbol + */ + public String[] getAuxNames() { + return JnaUtils.listSymbolAuxiliaryStates(getHandle()); + } + + /** + * Returns the symbol names. + * + * @return the symbol names + */ + public String[] getAllNames() { + return JnaUtils.listSymbolNames(getHandle()); + } + + /** + * Returns the list of names for all internal outputs. + * + * @return a list of names + */ + public List getLayerNames() { + String[] outputNames = getInternalOutputNames(); + String[] allNames = getAllNames(); + Set allNamesSet = new LinkedHashSet<>(Arrays.asList(allNames)); + // Kill all params field and keep the output layer + return Arrays.stream(outputNames) + .filter(n -> !allNamesSet.contains(n)) + .collect(Collectors.toList()); + } + + private String[] getInternalOutputNames() { + return JnaUtils.listSymbolOutputs(getInternals().getHandle()); + } + + /** + * Returns the symbol internals. + * + * @return the symbol internals symbol + */ + public Symbol getInternals() { + Pointer pointer = JnaUtils.getSymbolInternals(getHandle()); + return new Symbol(getParent(), pointer); + } + + /** + * Infers the shapes for all parameters inside a symbol from the given input shapes. 
+ * + * @param pairs the given input name and shape + * @return a map of arguments with names and shapes + */ + public Map inferShape(PairList pairs) { + List> shapes = JnaUtils.inferShape(this, pairs); + if (shapes == null) { + throw new IllegalArgumentException("Cannot infer shape based on the data provided!"); + } + List argShapes = shapes.get(0); + List outputShapes = shapes.get(1); + List auxShapes = shapes.get(2); + // TODO: add output to the map + String[] argNames = getArgNames(); + String[] auxNames = getAuxNames(); + String[] outputNames = getOutputNames(); + Map shapesMap = new ConcurrentHashMap<>(); + for (int i = 0; i < argNames.length; i++) { + shapesMap.put(argNames[i], argShapes.get(i)); + } + for (int i = 0; i < auxNames.length; i++) { + shapesMap.put(auxNames[i], auxShapes.get(i)); + } + for (int i = 0; i < outputNames.length; i++) { + shapesMap.put(outputNames[i], outputShapes.get(i)); + } + return shapesMap; + } + + /** + * [Experimental] Add customized optimization on the Symbol. + * + *

This method can be used with EIA or TensorRT for model acceleration + * + * @param backend backend name + * @param device the device assigned + * @return optimized Symbol + */ + public Symbol optimizeFor(String backend, Device device) { + return new Symbol(getParent(), JnaUtils.optimizeFor(this, backend, device)); + } + + /** + * Converts Symbol to json string for saving purpose. + * + * @return the json string + */ + public String toJsonString() { + return JnaUtils.getSymbolString(getHandle()); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java new file mode 100644 index 000000000000..e88f05d7b3ec --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. 
*/ +package org.apache.mxnet.engine; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java new file mode 100644 index 000000000000..0123a500dddf --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.exception; + +/** + * Base class of the unchecked exceptions declared in this package. Thrown to indicate that a + * native error is raised from the underlying native code. + */ +public class BaseException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs a new exception with the specified detail message. The cause is not initialized, + * and may subsequently be initialized by a call to {@link #initCause}. + * + * @param message the detail message. The detail message is saved for later retrieval by the + * {@link #getMessage()} method. + */ + public BaseException(String message) { + super(message); + } + + /** + * Constructs a new exception with the specified detail message and cause. + * + * <p>Note that the detail message associated with {@code cause} is not automatically + * incorporated in this exception's detail message. + * + * @param message the detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause the cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A {@code null} value is permitted, and indicates that the cause is nonexistent + * or unknown.) + */ + public BaseException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new exception with the specified cause and a detail message of {@code + * (cause==null ? null : cause.toString())} (which typically contains the class and detail + * message of {@code cause}). This constructor is useful for exceptions that are little more + * than wrappers for other throwables (for example, {@link + * java.security.PrivilegedActionException}). + * + * @param cause the cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A {@code null} value is permitted, and indicates that the cause is nonexistent + * or unknown.) + */ + public BaseException(Throwable cause) { + super(cause); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java new file mode 100644 index 000000000000..f6f26fca9cf5 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.exception; + +/** Thrown to indicate that a JNA function was not called as expected. */ +public class JnaCallException extends BaseException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs a new exception with the specified detail message. The cause is not initialized, + * and may subsequently be initialized by a call to {@link #initCause}. + * + * @param message the detail message. The detail message is saved for later retrieval by the + * {@link #getMessage()} method. + */ + public JnaCallException(String message) { + super(message); + } + + /** + * Constructs a new exception with the specified detail message and cause. + * + * <p>Note that the detail message associated with {@code cause} is not automatically + * incorporated in this exception's detail message. + * + * @param message the detail message that is saved for later retrieval by the {@link + * #getMessage()} method + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public JnaCallException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new exception with the specified cause and a detail message of {@code + * (cause==null ? null : cause.toString())} which typically contains the class and detail + * message of {@code cause}. This constructor is useful for exceptions that are little more than + * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}. + * + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public JnaCallException(Throwable cause) { + super(cause); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java new file mode 100644 index 000000000000..5455633b8c40 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.exception; + +/** Thrown to indicate that model parameters are not in the expected format or are malformed. */ +public class MalformedModelException extends ModelException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs a new exception with the specified detail message. The cause is not initialized, + * and may subsequently be initialized by a call to {@link #initCause}. + * + * @param message the detail message. The detail message is saved for later retrieval by the + * {@link #getMessage()} method. + */ + public MalformedModelException(String message) { + super(message); + } + + /** + * Constructs a new exception with the specified detail message and cause. + * + * <p>Note that the detail message associated with {@code cause} is not automatically + * incorporated in this exception's detail message. + * + * @param message the detail message that is saved for later retrieval by the {@link + * #getMessage()} method + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public MalformedModelException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new exception with the specified cause and a detail message of {@code + * (cause==null ? null : cause.toString())} which typically contains the class and detail + * message of {@code cause}. This constructor is useful for exceptions that are little more than + * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}. + * + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public MalformedModelException(Throwable cause) { + super(cause); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java new file mode 100644 index 000000000000..94a12e769174 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.exception; + +/** + * Thrown to indicate a problem with a model. {@link MalformedModelException} is a subclass of this + * exception. + */ +public class ModelException extends BaseException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs a new exception with the specified detail message. The cause is not initialized, + * and may subsequently be initialized by a call to {@link #initCause}. + * + * @param message the detail message that is saved for later retrieval by the {@link + * #getMessage()} method + */ + public ModelException(String message) { + super(message); + } + + /** + * Constructs a new exception with the specified detail message and cause. + * + * <p>Note that the detail message associated with {@code cause} is not automatically + * incorporated in this exception's detail message. + * + * @param message the detail message that is saved for later retrieval by the {@link + * #getMessage()} method + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public ModelException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new exception with the specified cause and a detail message of {@code + * (cause==null ? null : cause.toString())} which typically contains the class and detail + * message of {@code cause}. This constructor is useful for exceptions that are little more than + * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}. + * + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public ModelException(Throwable cause) { + super(cause); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java new file mode 100644 index 000000000000..b17758a5e325 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.exception; + +/** Thrown to indicate that the translate pipeline does not work as expected. */ +public class TranslateException extends BaseException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs a new exception with the specified detail message. The cause is not initialized, + * and may subsequently be initialized by a call to {@link #initCause}. + * + * @param message the detail message. The detail message is saved for later retrieval by the + * {@link #getMessage()} method. + */ + public TranslateException(String message) { + super(message); + } + + /** + * Constructs a new exception with the specified detail message and cause. + * + * <p>Note that the detail message associated with {@code cause} is not automatically + * incorporated in this exception's detail message. + * + * @param message the detail message that is saved for later retrieval by the {@link + * #getMessage()} method + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public TranslateException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new exception with the specified cause and a detail message of {@code + * (cause==null ? null : cause.toString())} which typically contains the class and detail + * message of {@code cause}. This constructor is useful for exceptions that are little more than + * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}. + * + * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A + * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown + */ + public TranslateException(Throwable cause) { + super(cause); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java new file mode 100644 index 000000000000..d464bfd4b7e0 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the exception classes of the Java front-end for Apache MXNet. */ +package org.apache.mxnet.exception; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java new file mode 100644 index 000000000000..a2da11699d59 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mxnet.jna; + +import com.sun.jna.Pointer; +import java.util.List; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.types.SparseFormat; +import org.apache.mxnet.util.PairList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A FunctionInfo represents an operator (i.e. function) within the MXNet Engine. */ +public class FunctionInfo { + + private Pointer handle; + private String name; + private PairList arguments; + + private static final Logger logger = LoggerFactory.getLogger(FunctionInfo.class); + + FunctionInfo(Pointer pointer, String functionName, PairList arguments) { + this.handle = pointer; + this.name = functionName; + this.arguments = arguments; + } + + /** + * Returns the name of the operator. + * + * @return the name of the operator + */ + public String getFunctionName() { + return name; + } + + /** + * Returns the names of the params to the operator. + * + * @return the names of the params to the operator + */ + public List getArgumentNames() { + return arguments.keys(); + } + + /** + * Returns the types of the operator arguments. + * + * @return the types of the operator arguments + */ + public List getArgumentTypes() { + return arguments.values(); + } + /** + * Calls an operator with the given arguments and the given destination NDArray(s). + * + * @param src the input NDArray(s) to the operator + * @param dest the destination NDArray(s) to be overwritten with the result of the operator + * @param params the non-NDArray arguments to the operator. Should be a {@code PairList} + * @return the size of the list returned by {@code JnaUtils.imperativeInvoke}, presumably the + * number of output NDArrays (TODO confirm); this is not an error code + */ + public int invoke(NDArray[] src, NDArray[] dest, PairList params) { + checkDevices(src); + checkDevices(dest); + return JnaUtils.imperativeInvoke(handle, src, dest, params).size(); + } + + /** + * Calls an operator with the given arguments.
+ * + * @param parent {@link MxResource} for the current instance + * @param src the input NDArray(s) to the operator + * @param params the non-NDArray arguments to the operator. Should be a {@code PairList} + * @return the NDArrays created under {@code parent} from the operator results; a result whose + * format is not {@link SparseFormat#DENSE} is created with that format + */ + public NDArray[] invoke(MxResource parent, NDArray[] src, PairList params) { + checkDevices(src); + PairList pairList = + JnaUtils.imperativeInvoke(handle, src, null, params); + return pairList.stream() + .map( + pair -> { + if (pair.getValue() != SparseFormat.DENSE) { + return NDArray.create(parent, pair.getKey(), pair.getValue()); + } + return NDArray.create(parent, pair.getKey()); + }) + .toArray(NDArray[]::new); + } + + private void checkDevices(NDArray[] src) { + // only when debug logging is enabled: warn if the NDArrays are not all in the same device + if (logger.isDebugEnabled() && src.length > 1) { + Device device = src[0].getDevice(); + for (int i = 1; i < src.length; ++i) { + if (!device.equals(src[i].getDevice())) { + logger.warn( + "Please make sure all the NDArrays are in the same device. You can call toDevice() to move the NDArray to the desired Device."); + } + } + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java new file mode 100644 index 000000000000..46cff306af65 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java @@ -0,0 +1,893 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.jna; + +import com.sun.jna.Memory; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import com.sun.jna.ptr.PointerByReference; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.IntBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.mxnet.engine.CachedOp; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.DeviceType; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.engine.Symbol; +import org.apache.mxnet.exception.JnaCallException; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDList; +import org.apache.mxnet.ndarray.types.DataType; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.ndarray.types.SparseFormat; +import org.apache.mxnet.nn.Parameter; +import org.apache.mxnet.nn.SymbolBlock; +import org.apache.mxnet.util.PairList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A class containing utilities to interact with the MXNet Engine's Java Native Access (JNA) layer. 
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class JnaUtils {
+
+    private static final Logger logger = LoggerFactory.getLogger(JnaUtils.class);
+
+    /** Binding to the native MXNet library, loaded when this class is initialized. */
+    public static final MxnetLibrary LIB = LibUtils.loadLibrary();
+
+    /** Pool of reusable out-parameter holders; the held pointer is cleared on recycle. */
+    public static final ObjectPool<PointerByReference> REFS =
+            new ObjectPool<>(PointerByReference::new, r -> r.setValue(null));
+
+    /** Operator-name prefixes stripped by {@link #getOpNamePrefix(String)}. */
+    private static final String[] OP_NAME_PREFIX = {
+        "_contrib_", "_linalg_", "_sparse_", "_image_", "_random_"
+    };
+
+    /** All native operators, keyed by their prefix-stripped name. */
+    private static final Map<String, FunctionInfo> OPS = getNdArrayFunctions();
+
+    /** Names of the library features that are reported as enabled. */
+    private static final Set<String> FEATURES = getFeaturesInternal();
+
+    public static final String[] EMPTY_ARRAY = new String[0];
+
+    private JnaUtils() {
+        // not called
+    }
+
+    /** An enum that enumerates the statuses of numpy mode. */
+    public enum NumpyMode {
+        OFF,
+        THREAD_LOCAL_ON,
+        GLOBAL_ON
+    }
+
+    /** Calls {@code MXNDArrayWaitAll} and fails if the engine reports an error. */
+    public static void waitAll() {
+        checkCall(LIB.MXNDArrayWaitAll());
+    }
+
+    /**
+     * Sets the numpy-shape mode of the engine.
+     *
+     * @param mode the mode to set; its ordinal is what gets passed to the native call
+     */
+    public static void setNumpyMode(NumpyMode mode) {
+        // out-parameter required by the native signature; its value is not used here
+        IntBuffer ret = IntBuffer.allocate(1);
+        checkCall(LIB.MXSetIsNumpyShape(mode.ordinal(), ret));
+    }
+
+    /////////////////////////////////
+    // Related to CacheOp
+    /////////////////////////////////
+
+    /**
+     * Creates a native cached operator for the symbol of the given block.
+     *
+     * <p>Every parameter of the block is one input of the cached op: initialized parameters are
+     * recorded as parameter inputs, uninitialized ones as data inputs.
+     *
+     * @param block the block providing the symbol and the parameters
+     * @param parent the resource passed on as parent of the returned {@link CachedOp}
+     * @return the newly created {@link CachedOp}
+     */
+    public static CachedOp createCachedOp(SymbolBlock block, MxResource parent) {
+        Symbol symbol = block.getSymbol();
+
+        List<Parameter> parameters = block.getAllParameters();
+
+        // record data index in all inputs
+        PairList<String, Integer> dataIndices = new PairList<>();
+        // record parameter index in all inputs
+        List<Integer> paramIndices = new ArrayList<>();
+        int index = 0;
+        for (Parameter parameter : parameters) {
+            // We assume uninitialized parameters are data inputs
+            if (parameter.isInitialized()) {
+                paramIndices.add(index);
+            } else {
+                dataIndices.add(parameter.getName(), index);
+            }
+            ++index;
+        }
+
+        // Creating CachedOp
+        Pointer symbolHandle = symbol.getHandle();
+        PointerByReference ref = REFS.acquire();
+
+        // static_alloc and static_shape are enabled by default
+        String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
+        String[] values = {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};
+
+        checkCall(LIB.MXCreateCachedOp(symbolHandle, keys.length, keys, values, ref, (byte) 0));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+
+        return new CachedOp(parent, pointer, parameters, paramIndices, dataIndices);
+    }
+
+    /** Releases the native cached operator behind {@code handle}. */
+    public static void freeCachedOp(Pointer handle) {
+        checkCall(LIB.MXFreeCachedOp(handle));
+    }
+
+    /////////////////////////////////
+    // About Symbol
+    /////////////////////////////////
+
+    /**
+     * Creates a native symbol from a file.
+     *
+     * @param path the path handed to {@code MXSymbolCreateFromFile}
+     * @return the handle of the created symbol
+     */
+    public static Pointer createSymbolFromFile(String path) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolCreateFromFile(path, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Creates a native symbol from its JSON representation.
+     *
+     * @param json the JSON text handed to {@code MXSymbolCreateFromJSON}
+     * @return the handle of the created symbol
+     */
+    public static Pointer createSymbolFromString(String json) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolCreateFromJSON(json, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /** Returns the output names of {@code symbol}, copied into Java strings. */
+    public static String[] listSymbolOutputs(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.MXSymbolListOutputs(symbol, size, ref));
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    /** Returns the text produced by {@code NNSymbolPrint} for {@code symbol}. */
+    public static String printSymbol(Pointer symbol) {
+        String[] outStr = new String[1];
+        checkCall(LIB.NNSymbolPrint(symbol, outStr));
+        return outStr[0];
+    }
+
+    /** Releases the native symbol behind {@code symbol}. */
+    public static void freeSymbol(Pointer symbol) {
+        checkCall(LIB.MXSymbolFree(symbol));
+    }
+
+    /** Returns the argument names of {@code symbol}, copied into Java strings. */
+    public static String[] listSymbolArguments(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.MXSymbolListArguments(symbol, size, ref));
+
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    /** Returns the auxiliary state names of {@code symbol}, copied into Java strings. */
+    public static String[] listSymbolAuxiliaryStates(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.MXSymbolListAuxiliaryStates(symbol, size, ref));
+
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    /** Returns the input names of {@code symbol} as reported by {@code NNSymbolListInputNames}. */
+    public static String[] listSymbolNames(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.NNSymbolListInputNames(symbol, 0, size, ref));
+
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    /** Returns the handle produced by {@code MXSymbolGetInternals} for {@code symbol}. */
+    public static Pointer getSymbolInternals(Pointer symbol) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolGetInternals(symbol, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Copies one group of shapes returned by shape inference into Java objects.
+     *
+     * @param size holds the number of shapes in the group
+     * @param nDim points to the rank of each shape
+     * @param data points to the dimension values
+     * @return the shapes in native order, or an empty list when {@code size} is zero
+     */
+    private static List<Shape> recoverShape(
+            NativeSizeByReference size, PointerByReference nDim, PointerByReference data) {
+        int shapeLength = (int) size.getValue().longValue();
+        if (shapeLength == 0) {
+            return new ArrayList<>();
+        }
+        int[] dims = nDim.getValue().getIntArray(0, shapeLength);
+        int flattenedLength = 0;
+        for (int dim : dims) {
+            flattenedLength += dim;
+        }
+        // NOTE(review): only the first row pointer is dereferenced, so this relies on the native
+        // side storing all shapes back to back in one buffer - confirm against the C API.
+        long[] flattenedShapes = data.getValue().getPointer(0).getLongArray(0, flattenedLength);
+        int idx = 0;
+        List<Shape> result = new ArrayList<>();
+        for (int dim : dims) {
+            long[] shape = new long[dim];
+            System.arraycopy(flattenedShapes, idx, shape, 0, dim);
+            idx += dim;
+            result.add(new Shape(shape));
+        }
+        return result;
+    }
+
+    /**
+     * Runs shape inference on a symbol.
+     *
+     * @param symbol the symbol to infer shapes for
+     * @param args the known argument shapes, keyed by argument name
+     * @return a list holding the input, output and auxiliary shapes (in that order), or {@code
+     *     null} when the engine reports the inference as incomplete
+     */
+    public static List<List<Shape>> inferShape(Symbol symbol, PairList<String, Shape> args) {
+        Pointer handler = symbol.getHandle();
+        int numArgs = args.size();
+        String[] keys = args.keys().toArray(new String[0]);
+        // The shapes are passed CSR-style: indPtr[i]..indPtr[i + 1] delimits the dimensions of
+        // argument i inside the flattened array, so the offsets have to accumulate.
+        long[] indPtr = new long[numArgs + 1];
+        Shape flattened = new Shape();
+        indPtr[0] = 0;
+        for (int i = 0; i < numArgs; i++) {
+            Shape shape = args.valueAt(i);
+            indPtr[i + 1] = indPtr[i] + shape.dimension();
+            flattened = flattened.addAll(shape);
+        }
+        long[] flattenedShapeArray = flattened.getShape();
+
+        NativeSizeByReference inShapeSize = new NativeSizeByReference();
+        PointerByReference inShapeNDim = REFS.acquire();
+        PointerByReference inShapeData = REFS.acquire();
+        NativeSizeByReference outShapeSize = new NativeSizeByReference();
+        PointerByReference outShapeNDim = REFS.acquire();
+        PointerByReference outShapeData = REFS.acquire();
+        NativeSizeByReference auxShapeSize = new NativeSizeByReference();
+        PointerByReference auxShapeNDim = REFS.acquire();
+        PointerByReference auxShapeData = REFS.acquire();
+        IntBuffer complete = IntBuffer.allocate(1);
+        try {
+            checkCall(
+                    LIB.MXSymbolInferShape64(
+                            handler,
+                            numArgs,
+                            keys,
+                            indPtr,
+                            flattenedShapeArray,
+                            inShapeSize,
+                            inShapeNDim,
+                            inShapeData,
+                            outShapeSize,
+                            outShapeNDim,
+                            outShapeData,
+                            auxShapeSize,
+                            auxShapeNDim,
+                            auxShapeData,
+                            complete));
+            if (complete.get() != 0) {
+                // recoverShape copies everything into Java arrays, so the refs can be recycled
+                // right after
+                return Arrays.asList(
+                        recoverShape(inShapeSize, inShapeNDim, inShapeData),
+                        recoverShape(outShapeSize, outShapeNDim, outShapeData),
+                        recoverShape(auxShapeSize, auxShapeNDim, auxShapeData));
+            }
+            return null;
+        } finally {
+            REFS.recycle(inShapeNDim);
+            REFS.recycle(inShapeData);
+            REFS.recycle(outShapeNDim);
+            REFS.recycle(outShapeData);
+            REFS.recycle(auxShapeNDim);
+            REFS.recycle(auxShapeData);
+        }
+    }
+
+    /**
+     * Asks the engine to optimize a symbol for a backend.
+     *
+     * <p>No arguments, auxiliary states, options or shape/type hints are passed; the corresponding
+     * native parameters receive empty placeholders.
+     *
+     * @param current the symbol to optimize
+     * @param backend the name of the backend
+     * @param device the device whose type is passed to the native call
+     * @return the handle of the symbol returned by the engine
+     */
+    public static Pointer optimizeFor(Symbol current, String backend, Device device) {
+        // TODO: Support partition on parameters
+        PointerByReference returnedSymbolHandle = REFS.acquire();
+        // placeHolders
+        PointerByReference[] placeHolders = {
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire()
+        };
+        // there is no need to update parameters
+        // TODO : check 22th parameter type
+        checkCall(
+                LIB.MXOptimizeForBackend(
+                        current.getHandle(),
+                        backend,
+                        DeviceType.toDeviceType(device),
+                        returnedSymbolHandle,
+                        0,
+                        placeHolders[0],
+                        0,
+                        placeHolders[1],
+                        0,
+                        new String[0],
+                        new String[0],
+                        0,
+                        new String[0],
+                        new long[0],
+                        new int[0],
+                        0,
+                        new String[0],
+                        new int[0],
+                        0,
+                        new String[0],
+                        new int[0],
+                        (byte) 0,
+                        IntBuffer.allocate(0),
+                        placeHolders[2],
+                        placeHolders[3],
+                        IntBuffer.allocate(0),
+                        placeHolders[4],
+                        placeHolders[5]));
+        Pointer ptr = returnedSymbolHandle.getValue();
+        REFS.recycle(returnedSymbolHandle);
+        Arrays.stream(placeHolders).forEach(REFS::recycle);
+        return ptr;
+    }
+
+    /** Returns the JSON text produced by {@code MXSymbolSaveToJSON} for {@code symbol}. */
+    public static String getSymbolString(Pointer symbol) {
+        String[] holder = new String[1];
+        checkCall(LIB.MXSymbolSaveToJSON(symbol, holder));
+        return holder[0];
+    }
+
+    /** Returns the handle of output number {@code index} of {@code symbol}. */
+    public static Pointer getSymbolOutput(Pointer symbol, int index) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolGetOutput(symbol, index, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Loads the arrays stored in a file and places them on the requested device.
+     *
+     * @param parent the resource passed on as parent of every created {@link NDArray}
+     * @param path the file to load
+     * @param device the device the returned arrays should live on
+     * @return the loaded arrays; they carry names when the file stores names
+     * @throws IllegalStateException if the file stores names but not exactly one per array
+     */
+    public static NDList loadNdArray(MxResource parent, Path path, Device device) {
+        IntBuffer handlesSize = IntBuffer.allocate(1);
+        PointerByReference handlesRef = REFS.acquire();
+        PointerByReference namesRef = REFS.acquire();
+        IntBuffer namesSize = IntBuffer.allocate(1);
+        checkCall(LIB.MXNDArrayLoad(path.toString(), handlesSize, handlesRef, namesSize, namesRef));
+        int ndArrayCount = handlesSize.get();
+        int nameCount = namesSize.get();
+        if (nameCount > 0 && ndArrayCount != nameCount) {
+            throw new IllegalStateException(
+                    "Mismatch between names and arrays in checkpoint file: " + path.toString());
+        }
+        Pointer[] handles = handlesRef.getValue().getPointerArray(0, ndArrayCount);
+        NDList ndList = new NDList();
+        if (nameCount == 0) {
+            for (Pointer handle : handles) {
+                ndList.add(NDArray.create(parent, handle));
+            }
+        } else {
+            String[] names = namesRef.getValue().getStringArray(0, nameCount);
+            for (int i = 0; i < ndArrayCount; i++) {
+                NDArray array = NDArray.create(parent, handles[i]);
+                array.setName(names[i]);
+                ndList.add(array);
+            }
+        }
+
+        REFS.recycle(namesRef);
+        REFS.recycle(handlesRef);
+
+        // MXNet always load NDArray on CPU
+        if (Device.cpu().equals(device)) {
+            return ndList;
+        }
+
+        NDList ret = ndList.toDevice(device, true);
+        ndList.close();
+        return ret;
+    }
+
+    /**
+     * Loads the arrays stored in a file and returns their raw handles.
+     *
+     * @param path the file to load
+     * @return name/handle pairs; the name is {@code null} when the file stores no names
+     * @throws IllegalStateException if the file stores names but not exactly one per array
+     */
+    public static PairList<String, Pointer> loadNdArrayFromFile(String path) {
+        IntBuffer handleSize = IntBuffer.allocate(1);
+        IntBuffer namesSize = IntBuffer.allocate(1);
+        PointerByReference handlesRef = REFS.acquire();
+        PointerByReference namesRef = REFS.acquire();
+        checkCall(LIB.MXNDArrayLoad(path, handleSize, handlesRef, namesSize, namesRef));
+        // TODO : construct NDArray Objects
+        int handleCount = handleSize.get();
+        int nameCount = namesSize.get();
+        if (nameCount > 0 && nameCount != handleCount) {
+            throw new IllegalStateException(
+                    "Mismatch between names and arrays in checkpoint file: " + path);
+        }
+        Pointer[] handles = handlesRef.getValue().getPointerArray(0, handleCount);
+
+        PairList<String, Pointer> pairList = new PairList<>();
+
+        if (nameCount == 0) {
+            for (Pointer handle : handles) {
+                pairList.add(null, handle);
+            }
+        } else {
+            String[] names = namesRef.getValue().getStringArray(0, nameCount);
+            for (int i = 0; i < handleCount; i++) {
+                pairList.add(names[i], handles[i]);
+            }
+        }
+        REFS.recycle(namesRef);
+        REFS.recycle(handlesRef);
+
+        return pairList;
+    }
+
+    /** Releases the native array behind {@code handle}. */
+    public static void freeNdArray(Pointer handle) {
+        checkCall(LIB.MXNDArrayFree(handle));
+    }
+
+    /**
+     * Creates a native array from its serialized bytes.
+     *
+     * @param buf the array holding the serialized data
+     * @param offset the index of the first byte to use
+     * @param size the number of bytes to use
+     * @return the handle of the created array
+     */
+    public static Pointer loadNdArrayFromByteArray(byte[] buf, int offset, int size) {
+        // stage the bytes in native memory so the library can read them
+        Memory memory = new Memory(size);
+        memory.write(0, buf, offset, size);
+        PointerByReference outRef = REFS.acquire();
+        checkCall(LIB.MXNDArrayLoadFromRawBytes(memory, new NativeSize(size), outRef));
+        Pointer p = outRef.getValue();
+
+        REFS.recycle(outRef);
+        return p;
+    }
+
+    /**
+     * Creates a native array from the remaining bytes of a buffer.
+     *
+     * <p>The bytes between the buffer's position and its limit are consumed.
+     *
+     * @param byteBuffer the buffer holding the serialized data
+     * @return the handle of the created array
+     */
+    public static Pointer loadNdArrayFromByteBuffer(ByteBuffer byteBuffer) {
+        // TODO: hand a direct buffer to the native call without the intermediate copy
+        // remaining() instead of limit(): a buffer whose position is not 0 would otherwise
+        // underflow in get()
+        byte[] bytes = new byte[byteBuffer.remaining()];
+        byteBuffer.get(bytes);
+        return loadNdArrayFromByteArray(bytes, 0, bytes.length);
+    }
+
+    /**
+     * Serializes a native array.
+     *
+     * @param ndArray the handle of the array to serialize
+     * @return a buffer over the bytes written by {@code MXNDArraySaveRawBytes}
+     */
+    public static ByteBuffer saveNdArrayAsByteBuffer(Pointer ndArray) {
+        NativeSizeByReference size = new NativeSizeByReference();
+        PointerByReference ref = new PointerByReference();
+        checkCall(LIB.MXNDArraySaveRawBytes(ndArray, size, ref));
+        // NOTE(review): the buffer maps native memory that this code never frees; presumably the
+        // library owns it - confirm how long it stays valid.
+        return ref.getValue().getByteBuffer(0, size.getValue().longValue());
+    }
+
+    /** Serializes a native array into a freshly allocated Java byte array. */
+    public static byte[] saveNdArrayAsByteArray(Pointer ndArray) {
+        ByteBuffer buffer = saveNdArrayAsByteBuffer(ndArray);
+        byte[] bytes = new byte[buffer.remaining()];
+        buffer.get(bytes);
+        return bytes;
+    }
+
+    /**
+     * Copies the content of a native array into the memory at {@code data}.
+     *
+     * @param ndArray the handle of the source array
+     * @param data the destination memory
+     * @param len the size passed to {@code MXNDArraySyncCopyToCPU}
+     * @throws IllegalArgumentException if either pointer is {@code null}
+     */
+    public static void syncCopyToCPU(Pointer ndArray, Pointer data, int len) {
+        NativeSize size = new NativeSize(len);
+        checkNDArray(ndArray, "copy from");
+        checkNDArray(data, "copy to");
+        checkCall(LIB.MXNDArraySyncCopyToCPU(ndArray, data, size));
+    }
+
+    /**
+     * Copies the content of a direct buffer into a native array.
+     *
+     * @param ndArray the handle of the destination array
+     * @param data the source; its address is taken with {@code Native.getDirectBufferPointer}
+     * @param len the size passed to {@code MXNDArraySyncCopyFromCPU}
+     */
+    public static void syncCopyFromCPU(Pointer ndArray, Buffer data, int len) {
+        NativeSize size = new NativeSize(len);
+        Pointer pointer = Native.getDirectBufferPointer(data);
+        checkCall(LIB.MXNDArraySyncCopyFromCPU(ndArray, pointer, size));
+    }
+
+    /** Calls {@code MXNDArrayWaitToRead} on a non-null array handle. */
+    public static void waitToRead(Pointer ndArray) {
+        checkNDArray(ndArray, "wait to read");
+        checkCall(LIB.MXNDArrayWaitToRead(ndArray));
+    }
+
+    /** Calls {@code MXNDArrayWaitToWrite} on a non-null array handle. */
+    public static void waitToWrite(Pointer ndArray) {
+        checkNDArray(ndArray, "wait to write");
+        checkCall(LIB.MXNDArrayWaitToWrite(ndArray));
+    }
+
+    /** Returns the handle produced by {@code MXNDArrayDetach} for {@code handle}. */
+    public static Pointer detachGradient(Pointer handle) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXNDArrayDetach(handle, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Returns the handle produced by {@code MXNDArrayGetGrad} for {@code handle}.
+     *
+     * @throws IllegalArgumentException if {@code handle} is {@code null}
+     */
+    public static Pointer getGradient(Pointer handle) {
+        checkNDArray(handle, "get the gradient for");
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXNDArrayGetGrad(handle, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Forwards variables and their gradient buffers to {@code MXAutogradMarkVariables}.
+     *
+     * @param numVar the number of variables
+     * @param varHandles the variable handle passed by reference
+     * @param reqsArray the gradient request types
+     * @param gradHandles the gradient handle passed by reference
+     */
+    public static void autogradMarkVariables(
+            int numVar, Pointer varHandles, IntBuffer reqsArray, Pointer gradHandles) {
+        PointerByReference varRef = REFS.acquire();
+        PointerByReference gradRef = REFS.acquire();
+        varRef.setValue(varHandles);
+        gradRef.setValue(gradHandles);
+        checkCall(LIB.MXAutogradMarkVariables(numVar, varRef, reqsArray, gradRef));
+        REFS.recycle(varRef);
+        REFS.recycle(gradRef);
+    }
+
+    /**
+     * Builds the operator table by querying the library for every operator name.
+     *
+     * @return the operators keyed by their prefix-stripped name
+     */
+    public static Map<String, FunctionInfo> getNdArrayFunctions() {
+        Set<String> opNames = JnaUtils.getAllOpNames();
+        Map<String, FunctionInfo> map = new ConcurrentHashMap<>();
+
+        PointerByReference ref = REFS.acquire();
+        for (String opName : opNames) {
+            checkCall(LIB.NNGetOpHandle(opName, ref));
+            String functionName = getOpNamePrefix(opName);
+            map.put(functionName, getFunctionByName(opName, functionName, ref.getValue()));
+            // the same holder is reused for the next operator
+            ref.setValue(null);
+        }
+        REFS.recycle(ref);
+        return map;
+    }
+
+    /**
+     * Invokes a native operator imperatively.
+     *
+     * @param function the handle of the operator
+     * @param src the input arrays
+     * @param dest the output arrays, or {@code null}
+     * @param params the operator parameters, or {@code null} for none; values are passed as their
+     *     {@code toString()} form
+     * @return one pair of output handle and storage format per reported output
+     */
+    public static PairList<Pointer, SparseFormat> imperativeInvoke(
+            Pointer function, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+        String[] keys;
+        String[] values;
+        if (params == null) {
+            keys = EMPTY_ARRAY;
+            values = EMPTY_ARRAY;
+        } else {
+            keys = params.keyArray(EMPTY_ARRAY);
+            values = params.values().stream().map(Object::toString).toArray(String[]::new);
+        }
+        PointerArray srcArray = toPointerArray(src);
+        PointerArray destArray = toPointerArray(dest);
+        PointerByReference destRef = REFS.acquire();
+        destRef.setValue(destArray);
+        PointerByReference destSType = REFS.acquire();
+        IntBuffer numOutputs = IntBuffer.allocate(1);
+        numOutputs.put(0, 1);
+
+        checkCall(
+                LIB.MXImperativeInvoke(
+                        function,
+                        src.length,
+                        srcArray,
+                        numOutputs,
+                        destRef,
+                        keys.length,
+                        keys,
+                        values,
+                        destSType));
+        int numOfOutputs = numOutputs.get(0);
+        Pointer[] ptrArray = destRef.getValue().getPointerArray(0, numOfOutputs);
+        int[] sTypes = destSType.getValue().getIntArray(0, numOfOutputs);
+        PairList<Pointer, SparseFormat> pairList = new PairList<>();
+        for (int i = 0; i < numOfOutputs; i++) {
+            pairList.add(ptrArray[i], SparseFormat.fromValue(sTypes[i]));
+        }
+        REFS.recycle(destRef);
+        REFS.recycle(destSType);
+        srcArray.recycle();
+
+        if (destArray != null) {
+            destArray.recycle();
+        }
+        return pairList;
+    }
+
+    /** Packs the handles of {@code vals} into a {@link PointerArray}; {@code null} stays null. */
+    private static PointerArray toPointerArray(NDArray[] vals) {
+        if (vals == null) {
+            return null;
+        }
+        Pointer[] valPointers = new Pointer[vals.length];
+        for (int i = 0; i < vals.length; i++) {
+            valPointers[i] = vals[i].getHandle();
+        }
+        return PointerArray.of(valPointers);
+    }
+
+    /**
+     * Looks up an operator by its prefix-stripped name.
+     *
+     * @param opName the operator name
+     * @return the matching {@link FunctionInfo}
+     * @throws IllegalArgumentException if no such operator is registered
+     */
+    public static FunctionInfo op(String opName) {
+        FunctionInfo info = OPS.get(opName);
+        if (info == null) {
+            throw new IllegalArgumentException("Unknown operator: " + opName);
+        }
+        return info;
+    }
+
+    /**
+     * Queries the argument names and types of an operator.
+     *
+     * @param name the full operator name handed to the native call
+     * @param functionName the name stored in the returned {@link FunctionInfo}
+     * @param handle the handle of the operator
+     * @return the description of the operator
+     */
+    public static FunctionInfo getFunctionByName(String name, String functionName, Pointer handle) {
+        String[] nameRef = {name};
+        String[] description = new String[1];
+        IntBuffer numArgs = IntBuffer.allocate(1);
+        PointerByReference argNameRef = REFS.acquire();
+        PointerByReference argTypeRef = REFS.acquire();
+        PointerByReference argDescRef = REFS.acquire();
+        String[] keyVarArgs = new String[1];
+        String[] returnType = new String[1];
+
+        checkCall(
+                LIB.MXSymbolGetAtomicSymbolInfo(
+                        handle,
+                        nameRef,
+                        description,
+                        numArgs,
+                        argNameRef,
+                        argTypeRef,
+                        argDescRef,
+                        keyVarArgs,
+                        returnType));
+
+        int count = numArgs.get();
+        PairList<String, String> arguments = new PairList<>();
+        if (count != 0) {
+            String[] argNames =
+                    argNameRef.getValue().getStringArray(0, count, StandardCharsets.UTF_8.name());
+            String[] argTypes =
+                    argTypeRef.getValue().getStringArray(0, count, StandardCharsets.UTF_8.name());
+            for (int i = 0; i < argNames.length; i++) {
+                arguments.add(argNames[i], argTypes[i]);
+            }
+        }
+
+        REFS.recycle(argNameRef);
+        REFS.recycle(argTypeRef);
+        REFS.recycle(argDescRef);
+
+        return new FunctionInfo(handle, functionName, arguments);
+    }
+
+    /** Returns the names of all operators reported by {@code MXListAllOpNames}. */
+    public static Set<String> getAllOpNames() {
+        IntBuffer outSize = IntBuffer.allocate(1);
+        PointerByReference outArray = REFS.acquire();
+
+        checkCall(LIB.MXListAllOpNames(outSize, outArray));
+
+        int size = outSize.get();
+        Pointer[] pointers = outArray.getValue().getPointerArray(0, size);
+
+        Set<String> set = new HashSet<>();
+        for (Pointer p : pointers) {
+            set.add(p.getString(0, StandardCharsets.UTF_8.name()));
+        }
+        REFS.recycle(outArray);
+        return set;
+    }
+
+    /**
+     * Strips the first matching known prefix from an operator name.
+     *
+     * @param name the operator name
+     * @return the name without the prefix, or {@code name} itself when no prefix matches
+     */
+    public static String getOpNamePrefix(String name) {
+        for (String prefix : OP_NAME_PREFIX) {
+            if (name.startsWith(prefix)) {
+                return name.substring(prefix.length());
+            }
+        }
+        return name;
+    }
+
+    /**
+     * Returns the data type of a native array.
+     *
+     * @throws IllegalArgumentException if {@code handle} is {@code null}
+     */
+    public static DataType getDataTypeOfNdArray(Pointer handle) {
+        IntBuffer dataType = IntBuffer.allocate(1);
+        checkNDArray(handle, "get the data type of");
+        checkCall(LIB.MXNDArrayGetDType(handle, dataType));
+        // the native type code is used as index into the enum constants
+        return DataType.values()[dataType.get()];
+    }
+
+    /**
+     * Returns the device a native array lives on.
+     *
+     * @throws IllegalArgumentException if {@code handle} is {@code null}
+     */
+    public static Device getDeviceOfNdArray(Pointer handle) {
+        IntBuffer deviceType = IntBuffer.allocate(1);
+        IntBuffer deviceId = IntBuffer.allocate(1);
+        checkNDArray(handle, "get the device of");
+        checkCall(LIB.MXNDArrayGetContext(handle, deviceType, deviceId));
+        String deviceTypeStr = DeviceType.fromDeviceType(deviceType.get(0));
+        // CPU is special case which don't have device id
+        return Device.of(deviceTypeStr, deviceId.get(0));
+    }
+
+    /**
+     * Returns the shape of a native array, read as 32-bit dimensions.
+     *
+     * @throws IllegalArgumentException if {@code handle} is {@code null}
+     */
+    public static Shape getShapeOfNdArray(Pointer handle) {
+        IntBuffer dim = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+        checkNDArray(handle, "get the shape of");
+        checkCall(LIB.MXNDArrayGetShape(handle, dim, ref));
+        int nDim = dim.get();
+        if (nDim == 0) {
+            REFS.recycle(ref);
+            return new Shape();
+        }
+        int[] shape = ref.getValue().getIntArray(0, nDim);
+        REFS.recycle(ref);
+        return new Shape(Arrays.stream(shape).asLongStream().toArray());
+    }
+
+    /**
+     * Returns the shape of a native array, read as 64-bit dimensions.
+     *
+     * @throws IllegalArgumentException if {@code handle} is {@code null}
+     */
+    public static Shape getShape64OfNdArray(Pointer handle) {
+        IntBuffer dim = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+        checkNDArray(handle, "get the shape64 of");
+        checkCall(LIB.MXNDArrayGetShape64(handle, dim, ref));
+        int nDim = dim.get();
+        if (nDim == 0) {
+            REFS.recycle(ref);
+            return new Shape();
+        }
+        // the 64-bit variant hands back int64 dimensions; reading them as 32-bit ints would
+        // split every dimension into two bogus values
+        long[] shape = ref.getValue().getLongArray(0, nDim);
+        REFS.recycle(ref);
+        return new Shape(shape);
+    }
+
+    /**
+     * Returns the storage format of a native array.
+     *
+     * @throws IllegalArgumentException if {@code handle} is {@code null}
+     */
+    public static SparseFormat getStorageType(Pointer handle) {
+        IntBuffer type = IntBuffer.allocate(1);
+        checkNDArray(handle, "get the storage type of");
+        checkCall(LIB.MXNDArrayGetStorageType(handle, type));
+        return SparseFormat.fromValue(type.get());
+    }
+
+    /**
+     * Creates a dense native array.
+     *
+     * @param device the device to create the array on
+     * @param shape the shape; every dimension has to fit into an {@code int}
+     * @param dataType the element type; its ordinal is passed to the native call
+     * @param size the value passed as the number of dimensions
+     * @param delayedAlloc whether allocation should be delayed
+     * @return the handle of the created array
+     * @throws ArithmeticException if a dimension does not fit into an {@code int}
+     */
+    public static Pointer createNdArray(
+            Device device, Shape shape, DataType dataType, int size, boolean delayedAlloc) {
+        int deviceType = DeviceType.toDeviceType(device);
+        // NOTE(review): device type 1 is presumably the CPU, which gets no device id - confirm
+        int deviceId = (deviceType != 1) ? device.getDeviceId() : -1;
+        int delay = delayedAlloc ? 1 : 0;
+
+        PointerByReference ref = REFS.acquire();
+        int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+        checkCall(
+                LIB.MXNDArrayCreate(
+                        shapeArray, size, deviceType, deviceId, delay, dataType.ordinal(), ref));
+
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Creates a sparse native array.
+     *
+     * @param fmt the sparse storage format
+     * @param device the device to create the array on
+     * @param shape the shape; every dimension has to fit into an {@code int}
+     * @param dtype the element type
+     * @param auxDTypes the types of the auxiliary arrays
+     * @param auxShapes the shapes of the auxiliary arrays; only the first dimension of each is
+     *     passed on
+     * @param delayedAlloc whether allocation should be delayed
+     * @return the handle of the created array
+     * @throws ArithmeticException if a dimension of {@code shape} does not fit into an {@code int}
+     */
+    public static Pointer createSparseNdArray(
+            SparseFormat fmt,
+            Device device,
+            Shape shape,
+            DataType dtype,
+            DataType[] auxDTypes,
+            Shape[] auxShapes,
+            boolean delayedAlloc) {
+        int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+        int deviceType = DeviceType.toDeviceType(device);
+        // NOTE(review): device type 1 is presumably the CPU, which gets no device id - confirm
+        int deviceId = (deviceType != 1) ? device.getDeviceId() : -1;
+        int delay = delayedAlloc ? 1 : 0;
+        PointerByReference ref = REFS.acquire();
+        IntBuffer auxDTypesInt =
+                IntBuffer.wrap(Arrays.stream(auxDTypes).mapToInt(DataType::ordinal).toArray());
+        IntBuffer auxNDims =
+                IntBuffer.wrap(Arrays.stream(auxShapes).mapToInt(Shape::dimension).toArray());
+        int[] auxShapesInt = Arrays.stream(auxShapes).mapToInt(ele -> (int) ele.head()).toArray();
+        checkCall(
+                LIB.MXNDArrayCreateSparseEx(
+                        fmt.getValue(),
+                        shapeArray,
+                        shapeArray.length,
+                        deviceType,
+                        deviceId,
+                        delay,
+                        dtype.ordinal(),
+                        auxDTypes.length,
+                        auxDTypesInt,
+                        auxNDims,
+                        auxShapesInt,
+                        ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /** Forwards both handles and {@code location} to {@code MXNDArraySyncCopyFromNDArray}. */
+    public static void ndArraySyncCopyFromNdArray(NDArray dest, NDArray src, int location) {
+        checkCall(LIB.MXNDArraySyncCopyFromNDArray(dest.getHandle(), src.getHandle(), location));
+    }
+
+    /** Returns the version number reported by {@code MXGetVersion}. */
+    public static int getVersion() {
+        IntBuffer version = IntBuffer.allocate(1);
+        checkCall(LIB.MXGetVersion(version));
+        return version.get();
+    }
+
+    /**
+     * Runs a cached operator.
+     *
+     * @param parent the resource passed on as parent of every output {@link NDArray}
+     * @param cachedOpHandle the handle of the cached operator
+     * @param inputs the inputs; must contain at least one element because the device type is taken
+     *     from the first one
+     * @return the outputs; sparse outputs are created with their storage format
+     */
+    public static NDArray[] cachedOpInvoke(
+            MxResource parent, Pointer cachedOpHandle, NDArray[] inputs) {
+        IntBuffer buf = IntBuffer.allocate(1);
+        PointerArray array = toPointerArray(inputs);
+        PointerByReference ref = REFS.acquire();
+        PointerByReference outSTypeRef = REFS.acquire();
+        Device device = inputs[0].getDevice();
+        // TODO: check the init value of default_dev_type and default_dev_id
+        checkCall(
+                LIB.MXInvokeCachedOp(
+                        cachedOpHandle,
+                        inputs.length,
+                        array,
+                        DeviceType.toDeviceType(device),
+                        0,
+                        buf,
+                        ref,
+                        outSTypeRef));
+        int numOutputs = buf.get();
+        Pointer[] ptrArray = ref.getValue().getPointerArray(0, numOutputs);
+        int[] sTypes = outSTypeRef.getValue().getIntArray(0, numOutputs);
+        NDArray[] output = new NDArray[numOutputs];
+        for (int i = 0; i < numOutputs; i++) {
+            // storage type 0 takes the plain factory, everything else carries its sparse format
+            if (sTypes[i] != 0) {
+                output[i] = NDArray.create(parent, ptrArray[i], SparseFormat.fromValue(sTypes[i]));
+            } else {
+                output[i] = NDArray.create(parent, ptrArray[i]);
+            }
+        }
+        REFS.recycle(ref);
+        REFS.recycle(outSTypeRef);
+        array.recycle();
+        return output;
+    }
+
+    /**
+     * Rejects {@code null} handles before they reach the native library.
+     *
+     * @param pointer the handle to check
+     * @param msg the attempted action, inserted into the error message
+     * @throws IllegalArgumentException if {@code pointer} is {@code null}
+     */
+    private static void checkNDArray(Pointer pointer, String msg) {
+        if (pointer == null) {
+            throw new IllegalArgumentException(
+                    "Tried to " + msg + " an MXNet NDArray that was already closed");
+        }
+    }
+
+    /**
+     * Turns a non-zero native return code into an exception.
+     *
+     * @param ret the return code of a native call
+     * @throws JnaCallException if {@code ret} is not zero; the message carries the last error
+     *     reported by the library
+     */
+    public static void checkCall(int ret) {
+        if (ret != 0) {
+            // fetch the error once so the log entry and the exception carry the same text
+            String message = "MXNet engine call failed: " + getLastError();
+            logger.error(message);
+            throw new JnaCallException(message);
+        }
+    }
+
+    private static String getLastError() {
+        return LIB.MXGetLastError();
+    }
+
+    /**
+     * Copies a native array of C strings into Java strings.
+     *
+     * @param ref holds the pointer to the native string array
+     * @param size the number of strings
+     * @return the strings decoded as UTF-8; empty when {@code size} is zero
+     */
+    private static String[] toStringArray(PointerByReference ref, int size) {
+        if (size == 0) {
+            return new String[0];
+        }
+
+        Pointer[] pointers = ref.getValue().getPointerArray(0, size);
+
+        String[] arr = new String[size];
+        for (int i = 0; i < size; ++i) {
+            arr[i] = pointers[i].getString(0, StandardCharsets.UTF_8.name());
+        }
+
+        return arr;
+    }
+
+    /** Queries the library features and keeps the names of those whose enabled flag is 1. */
+    private static Set<String> getFeaturesInternal() {
+        PointerByReference ref = REFS.acquire();
+        NativeSizeByReference outSize = new NativeSizeByReference();
+        checkCall(LIB.MXLibInfoFeatures(ref, outSize));
+        int size = outSize.getValue().intValue();
+        if (size == 0) {
+            REFS.recycle(ref);
+            return Collections.emptySet();
+        }
+
+        LibFeature pointer = new LibFeature(ref.getValue());
+        pointer.read();
+
+        LibFeature[] features = (LibFeature[]) pointer.toArray(size);
+
+        Set<String> set = new HashSet<>();
+        for (LibFeature feature : features) {
+            if (feature.getEnabled() == 1) {
+                set.add(feature.getName());
+            }
+        }
+        REFS.recycle(ref);
+        return set;
+    }
+
+    /** Returns the names of the enabled library features, determined once at class load. */
+    public static Set<String> getFeatures() {
+        return FEATURES;
+    }
+
+    /** Returns whether {@code MXAutogradIsTraining} reports 1. */
+    public static boolean autogradIsTraining() {
+        ByteBuffer isTraining = ByteBuffer.allocate(1);
+        checkCall(LIB.MXAutogradIsTraining(isTraining));
+        return isTraining.get(0) == 1;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java
b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java new file mode 100644 index 000000000000..d1103986dd48 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.jna; + +import com.sun.jna.Library; +import com.sun.jna.Native; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.mxnet.util.Platform; +import org.apache.mxnet.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Utilities for finding the MXNet Engine binary on the System. + * + *

 <p>The Engine will be searched for in a variety of locations in the following order:
+ *
+ * <ol>
+ *   <li>In the path specified by the MXNET_LIBRARY_PATH environment variable
+ *   <li>In the path specified by the java.library.path system property
+ *   <li>In a jar file location in the classpath. These jars can be created with the mxnet-native
+ *       module.
+ *   <li>Otherwise the plain library name "mxnet" is handed to JNA, which applies its own library
+ *       search rules
+ * </ol>
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class LibUtils {
+
+    private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);
+
+    private static final String LIB_NAME = "mxnet";
+
+    /** Name of the environment variable that overrides the library location. */
+    private static final String MXNET_LIBRARY_PATH = "MXNET_LIBRARY_PATH";
+
+    /** Classpath resource that marks a native library jar. */
+    private static final String MXNET_PROPERTIES_FILE_PATH = "native/lib/mxnet.properties";
+
+    private LibUtils() {}
+
+    /**
+     * Locates the native library and loads it through JNA.
+     *
+     * @return the JNA binding of the loaded library
+     */
+    public static MxnetLibrary loadLibrary() {
+        String libName = getLibName();
+        logger.debug("Loading mxnet library from: {}", libName);
+        if (System.getProperty("os.name").startsWith("Linux")) {
+            logger.info("Loading on Linux platform");
+            Map<String, Object> options = new ConcurrentHashMap<>();
+            int rtld = 1; // Linux RTLD lazy + local
+            options.put(Library.OPTION_OPEN_FLAGS, rtld);
+            return Native.load(libName, MxnetLibrary.class, options);
+        }
+        return Native.load(libName, MxnetLibrary.class);
+    }
+
+    /**
+     * Determines what to hand to JNA as the library name.
+     *
+     * @return the absolute path of a library found via the override locations or the classpath,
+     *     otherwise the plain name {@code mxnet}
+     * @throws IllegalStateException if native jars are on the classpath but none matches the
+     *     current platform, or a matching one cannot be read or extracted
+     */
+    public static String getLibName() {
+        String libName = findOverrideLibrary();
+        if (libName == null) {
+            libName = LibUtils.findLibraryInClasspath();
+            if (libName == null) {
+                libName = LIB_NAME;
+            }
+        }
+
+        return libName;
+    }
+
+    /**
+     * Searches the MXNET_LIBRARY_PATH environment variable, then the java.library.path property.
+     *
+     * @return the absolute path of the first library found, or {@code null}
+     */
+    private static String findOverrideLibrary() {
+        // TODO: load from jar files
+        String libPath = System.getenv(MXNET_LIBRARY_PATH);
+        if (libPath != null) {
+            String libName = findLibraryInPath(libPath);
+            if (libName != null) {
+                return libName;
+            }
+        }
+
+        libPath = System.getProperty("java.library.path");
+        if (libPath != null) {
+            return findLibraryInPath(libPath);
+        }
+        return null;
+    }
+
+    /**
+     * Looks for a native library jar on the classpath that matches the current platform.
+     *
+     * @return the absolute path of the extracted library, or {@code null} when the classpath holds
+     *     no native jar at all
+     * @throws IllegalStateException if native jars exist but none matches, or their properties
+     *     cannot be read
+     */
+    private static synchronized String findLibraryInClasspath() {
+        Enumeration<URL> urls = getUrls();
+        // No native jars
+        if (!urls.hasMoreElements()) {
+            logger.debug("mxnet.properties not found in class path.");
+            return null;
+        }
+
+        // Find the mxnet library version that matches local system platform
+        // throw exception if no one matches
+        Platform systemPlatform = Platform.fromSystem();
+        try {
+            while (urls.hasMoreElements()) {
+                URL url = urls.nextElement();
+                Platform platform = Platform.fromUrl(url);
+                if (!platform.isPlaceholder() && platform.matches(systemPlatform)) {
+                    return loadLibraryFromClasspath(platform);
+                }
+            }
+        } catch (IOException e) {
+            throw new IllegalStateException(
+                    "Failed to read MXNet native library jar properties", e);
+        }
+
+        throw new IllegalStateException(
+                "Your MXNet native library jar does not match your operating system. Make sure that the Maven Dependency Classifier matches your system type.");
+    }
+
+    /**
+     * Lists every mxnet.properties resource visible to the context class loader.
+     *
+     * @return the resource URLs; empty (never {@code null}) when the lookup fails
+     */
+    private static Enumeration<URL> getUrls() {
+        try {
+            return Thread.currentThread()
+                    .getContextClassLoader()
+                    .getResources(MXNET_PROPERTIES_FILE_PATH);
+        } catch (IOException e) {
+            logger.warn(
+                    String.format(
+                            "IO Exception occurs when try to find the file %s",
+                            MXNET_PROPERTIES_FILE_PATH),
+                    e);
+            // callers iterate the result right away, so report "nothing found" instead of null
+            return Collections.emptyEnumeration();
+        }
+    }
+
+    /**
+     * Extracts the libraries of a native jar into the engine cache directory.
+     *
+     * <p>The files are first copied into a temporary directory, which is then moved to its final
+     * place. An already extracted library is reused.
+     *
+     * @param platform the platform description read from the native jar
+     * @return the absolute path of the extracted library
+     * @throws IllegalStateException if a listed library is missing from the classpath or the
+     *     extraction fails
+     */
+    private static String loadLibraryFromClasspath(Platform platform) {
+        Path tmp = null;
+        try {
+            String libName = System.mapLibraryName(LIB_NAME);
+            Path cacheFolder = Utils.getEngineCacheDir(LIB_NAME);
+            logger.debug("Using cache dir: {}", cacheFolder);
+
+            Path dir = cacheFolder.resolve(platform.getVersion() + platform.getClassifier());
+            Path path = dir.resolve(libName);
+            if (Files.exists(path)) {
+                return path.toAbsolutePath().toString();
+            }
+            Files.createDirectories(cacheFolder);
+            tmp = Files.createTempDirectory(cacheFolder, "tmp");
+            for (String file : platform.getLibraries()) {
+                String libPath = "/native/lib/" + file;
+                try (InputStream is = LibUtils.class.getResourceAsStream(libPath)) {
+                    if (is == null) {
+                        // getResourceAsStream signals a missing resource with null
+                        throw new IllegalStateException(
+                                "MXNet native library not found in classpath: " + libPath);
+                    }
+                    logger.info("Extracting {} to cache ...", file);
+                    Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING);
+                }
+            }
+
+            Utils.moveQuietly(tmp, dir);
+            return path.toAbsolutePath().toString();
+        } catch (IOException e) {
+            throw new IllegalStateException("Failed to extract MXNet native library", e);
+        } finally {
+            if (tmp != null) {
+                Utils.deleteQuietly(tmp);
+            }
+        }
+    }
+
+    /**
+     * Searches a path list for the native library.
+     *
+     * <p>Each entry may be the library file itself or a directory containing it.
+     *
+     * @param libPath entries separated by {@link File#pathSeparator}
+     * @return the absolute path of the first match, or {@code null}
+     */
+    private static String findLibraryInPath(String libPath) {
+        String[] paths = libPath.split(File.pathSeparator);
+        List<String> mappedLibNames;
+        if (com.sun.jna.Platform.isMac()) {
+            mappedLibNames = Arrays.asList("libmxnet.dylib", "libmxnet.jnilib", "libmxnet.so");
+        } else {
+            mappedLibNames = Collections.singletonList(System.mapLibraryName(LIB_NAME));
+        }
+
+        for (String path : paths) {
+            File p = new File(path);
+            if (!p.exists()) {
+                continue;
+            }
+            for (String name : mappedLibNames) {
+                if (p.isFile() && p.getName().endsWith(name)) {
+                    return p.getAbsolutePath();
+                }
+
+                File file = new File(path, name);
+                if (file.exists() && file.isFile()) {
+                    return file.getAbsolutePath();
+                }
+            }
+        }
+        return null;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java new file mode 100644 index 000000000000..d43e4758d78c --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Pointer;
+import java.nio.charset.Charset;
+
+/**
+ * Provides a temporary allocation of an immutable C string ({@code const char*} or {@code const
+ * wchar_t*}) for use when converting a Java String into a native memory function argument.
+ */
+final class NativeString {
+
+    /** Recycled instances; the pool has no factory, so {@code acquire()} may return null. */
+    private static final ObjectPool<NativeString> POOL = new ObjectPool<>(null, null);
+
+    private final Memory pointer;
+
+    /**
+     * Creates a native string (NUL-terminated array of {@code char}) from already encoded bytes.
+     *
+     * @param data the bytes of the string
+     */
+    private NativeString(byte[] data) {
+        // one extra byte for the terminating NUL
+        pointer = new Memory(data.length + 1);
+        setData(data);
+    }
+
+    /** Overwrites the native memory with {@code data} followed by a NUL byte. */
+    private void setData(byte[] data) {
+        pointer.write(0, data, 0, data.length);
+        pointer.setByte(data.length, (byte) 0);
+    }
+
+    /**
+     * Acquires a pooled {@code NativeString} object if available, otherwise a new instance is
+     * created.
+     *
+     * <p>A pooled instance is only reused when its memory can hold the encoded string plus the
+     * terminating NUL; a pooled instance that is too small is dropped.
+     *
+     * @param string the string value
+     * @param encoding the charset encoding
+     * @return a {@code NativeString} object
+     */
+    public static NativeString of(String string, Charset encoding) {
+        byte[] data = string.getBytes(encoding);
+        NativeString array = POOL.acquire();
+        if (array != null && array.pointer.size() > data.length) {
+            array.setData(data);
+            return array;
+        }
+        return new NativeString(data);
+    }
+
+    /** Recycles this instance and return it back to the pool. */
+    public void recycle() {
+        POOL.recycle(this);
+    }
+
+    /**
+     * Returns the peer pointer.
+     *
+     * @return the peer pointer
+     */
+    public Pointer getPointer() {
+        return pointer;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java new file mode 100644 index 000000000000..573acd4f439e --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.jna; + +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** + * A generic object pool implementation.
+ * + * @param the type of object to put in the pool + */ +@SuppressWarnings("MissingJavadocMethod") +public class ObjectPool { + + private Queue queue; + private Supplier supplier; + private Consumer consumer; + + public ObjectPool(Supplier supplier, Consumer consumer) { + queue = new ConcurrentLinkedQueue<>(); + this.supplier = supplier; + this.consumer = consumer; + } + + public T acquire() { + T item = queue.poll(); + if (item == null) { + if (supplier != null) { + return supplier.get(); + } + } + return item; + } + + public void recycle(T item) { + if (consumer != null) { + consumer.accept(item); + } + queue.add(item); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java new file mode 100644 index 000000000000..a864b64abd35 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.jna; + +import com.sun.jna.Function; +import com.sun.jna.Memory; +import com.sun.jna.Native; +import com.sun.jna.Pointer; + +/** + * An abstraction for a native pointer array data type ({@code void**}). + * + * @see Pointer + * @see com.sun.jna.ptr.PointerByReference + * @see Function + */ +@SuppressWarnings("checkstyle:EqualsHashCode") +final class PointerArray extends Memory { + + private static final ObjectPool POOL = new ObjectPool<>(null, null); + + private int length; + + /** + * Constructs a {@link Memory} buffer PointerArray given the Pointers to include in it. + * + * @param arg the pointers to include in the array + */ + private PointerArray(Pointer... arg) { + super(Native.POINTER_SIZE * (arg.length + 1)); + length = arg.length; + setPointers(arg); + } + + /** + * Acquires a pooled {@code PointerArray} object if available, otherwise a new instance is + * created. + * + * @param arg the pointers to include in the array + * @return a {@code PointerArray} object + */ + public static PointerArray of(Pointer... arg) { + PointerArray array = POOL.acquire(); + if (array != null && array.length >= arg.length) { + array.setPointers(arg); + return array; + } + return new PointerArray(arg); + } + + /** Recycles this instance and return it back to the pool. 
*/ + public void recycle() { + POOL.recycle(this); + } + + private void setPointers(Pointer[] pointers) { + for (int i = 0; i < pointers.length; i++) { + setPointer(i * Native.POINTER_SIZE, pointers[i]); + } + setPointer(Native.POINTER_SIZE * length, null); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + return o == this; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java new file mode 100644 index 000000000000..00ffa8f053fe --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.jna; + +import com.sun.jna.Memory; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; + +/** An abstraction for a native string array data type ({@code char**}). 
*/ +@SuppressWarnings("checkstyle:EqualsHashCode") +final class StringArray extends Memory { + + private static final Charset ENCODING = Native.DEFAULT_CHARSET; + private static final ObjectPool POOL = new ObjectPool<>(null, null); + /** Hold all {@code NativeString}, avoid be GCed. */ + private List natives; // NOPMD + + private int length; + + /** + * Create a native array of strings. + * + * @param strings the strings + */ + private StringArray(String[] strings) { + super((strings.length + 1) * Native.POINTER_SIZE); + natives = new ArrayList<>(); + length = strings.length; + setPointers(strings); + } + + private void setPointers(String[] strings) { + for (NativeString ns : natives) { + ns.recycle(); + } + natives.clear(); + for (int i = 0; i < strings.length; ++i) { + Pointer p = null; + if (strings[i] != null) { + NativeString ns = NativeString.of(strings[i], ENCODING); + natives.add(ns); + p = ns.getPointer(); + } + setPointer(Native.POINTER_SIZE * i, p); + } + setPointer(Native.POINTER_SIZE * strings.length, null); + } + + /** + * Acquires a pooled {@code StringArray} object if available, otherwise a new instance is + * created. + * + * @param strings the pointers to include in the array + * @return a {@code StringArray} object + */ + public static StringArray of(String[] strings) { + StringArray array = POOL.acquire(); + if (array != null && array.length >= strings.length) { + array.setPointers(strings); + return array; + } + return new StringArray(strings); + } + + /** Recycles this instance and return it back to the poll. 
*/ + public void recycle() { + POOL.recycle(this); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + return this == o; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java new file mode 100644 index 000000000000..d524d5056e6d --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. */ +package org.apache.mxnet.jna; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java new file mode 100644 index 000000000000..f775f443e7df --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java @@ -0,0 +1,3455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray; + +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.util.Arrays; +import java.util.stream.IntStream; +import org.apache.mxnet.engine.BaseMxResource; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.engine.OpParams; +import org.apache.mxnet.jna.JnaUtils; +import org.apache.mxnet.ndarray.index.NDIndex; +import org.apache.mxnet.ndarray.types.DataType; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.ndarray.types.SparseFormat; +import org.apache.mxnet.util.Float16Utils; +import org.apache.mxnet.util.PairList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Class representing an n-dimensional array. + * + *

NDArray is the core data structure for all mathematical computations. An NDArray represents a + * multidimensional, fixed-size homogeneous array. It has very similar behaviour to the Numpy python + * package with the addition of efficient computing. + */ +public class NDArray extends MxResource { + + private static final Logger logger = LoggerFactory.getLogger(NDArray.class); + + private static final int MAX_SIZE = 100; + private static final int MAX_DEPTH = 10; + private static final int MAX_ROWS = 10; + private static final int MAX_COLUMNS = 20; + private static final NDArray[] EMPTY = new NDArray[0]; + + private String name; + private Device device; + private SparseFormat sparseFormat; + private DataType dataType; + private Shape shape; + // use Boolean object to maintain three status: false, true + // and null which means the flag is not set by the native engine yet + private Boolean hasGradient; + private Integer version; + private NDArrayEx mxNDArrayEx; + + protected NDArray(Pointer handle) { + super(BaseMxResource.getSystemMxResource(), handle); + } + + /** + * Constructs an {@link NDArray} from a native handle and metadata (internal. Use {@method + * create} methods). 
+     *
+     * @param parent the parent {@link MxResource} to manage the life cycle of the {@link NDArray}
+     * @param handle the pointer to the native NDArray memory
+     * @param device the device the new array will be located on
+     * @param shape the shape of the new array
+     * @param dataType the dataType of the new array
+     * @param hasGradient the gradient status of the new array
+     * @throws IllegalArgumentException if any dimension of {@code shape} is negative
+     */
+    NDArray(
+            MxResource parent,
+            Pointer handle,
+            Device device,
+            Shape shape,
+            DataType dataType,
+            boolean hasGradient) {
+        this(parent, handle);
+        setParent(parent);
+        this.device = device;
+        // shape check: reject negative dimensions before caching the shape
+        if (Arrays.stream(shape.getShape()).anyMatch(s -> s < 0)) {
+            throw new IllegalArgumentException("The shape must be >= 0");
+        }
+        this.shape = shape;
+        this.dataType = dataType;
+        this.hasGradient = hasGradient;
+        // register with the parent only when one was supplied
+        if (parent != null) {
+            parent.addSubResource(this);
+        }
+    }
+
+    /**
+     * Constructs an {@link NDArray} from a native handle (internal) and creates its
+     * {@code NDArrayEx} companion.
+     *
+     * @param parent the parent {@link MxResource} to manage the life cycle of the {@link NDArray}
+     * @param handle the pointer to the native NDArray memory
+     */
+    NDArray(MxResource parent, Pointer handle) {
+        super(parent, handle);
+        this.mxNDArrayEx = new NDArrayEx(this);
+    }
+
+    /**
+     * Constructs an {@link NDArray} from a native handle and a known sparse format (internal).
+     *
+     * @param parent the parent {@link MxResource} to manage the life cycle of the {@link NDArray}
+     * @param handle the pointer to the native NDArray memory
+     * @param fmt the sparse format
+     */
+    NDArray(MxResource parent, Pointer handle, SparseFormat fmt) {
+        this(parent, handle);
+        this.sparseFormat = fmt;
+    }
+
+    /**
+     * Creates an NDArray with the given Native Memory Pointer and parent MxResource.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param handle the array's native memory pointer
+     * @return the created array
+     */
+    public static NDArray create(MxResource parent, Pointer handle) {
+        return new NDArray(parent, handle);
+    }
+
+    /**
+     * Creates an NDArray with the given Native Memory Pointer, parent MxResource and sparse
+     * format.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param handle the array's native memory pointer
+     * @param fmt the sparse format
+     * @return the created array
+     */
+    public static NDArray create(MxResource parent, Pointer handle, SparseFormat fmt) {
+        return new NDArray(parent, handle, fmt);
+    }
+
+    /**
+     * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+     * parent {@link MxResource}, {@link Shape} and {@link Device}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, Shape shape, Device device) {
+        return create(parent, shape, DataType.FLOAT32, device);
+    }
+
+    /**
+     * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+     * parent {@link MxResource} and {@link Shape}, on the device returned by {@code
+     * Device.defaultIfNull()}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, Shape shape) {
+        return create(parent, shape, DataType.FLOAT32, Device.defaultIfNull());
+    }
+
+    /**
+     * Creates an uninitialized instance of {@link NDArray} with specified parent {@link
+     * MxResource}, {@link Shape}, {@link DataType}, {@link Device} and {@code hasGradient}.
+ * + * @param parent the parent {@link MxResource} + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @param hasGradient true if the gradient calculation is required for this {@link NDArray} + * @return a new instance of {@link NDArray} + */ + public static NDArray create( + MxResource parent, Shape shape, DataType dataType, Device device, boolean hasGradient) { + Pointer handle = + JnaUtils.createNdArray(device, shape, dataType, shape.dimension(), hasGradient); + return new NDArray(parent, handle, device, shape, dataType, hasGradient); + } + + /** + * Creates an uninitialized instance of {@link NDArray} with specified parent {@link + * MxResource}, {@link Shape}, {@link DataType}, {@link Device}. + * + * @param parent the parent {@link MxResource} + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + public static NDArray create(MxResource parent, Shape shape, DataType dataType, Device device) { + Pointer handle = JnaUtils.createNdArray(device, shape, dataType, shape.dimension(), false); + return new NDArray(parent, handle, Device.defaultIfNull(device), shape, dataType, false); + } + + /** + * Creates and initializes a scalar {@link NDArray}. 
+ * + * @param parent the parent {@link MxResource} instance + * @param data the {@link Number} that needs to be set + * @return a new instance of {@link NDArray} + * @throws IllegalArgumentException when the Type of data is not expected + */ + public static NDArray create(MxResource parent, Number data) { + if (data instanceof Integer) { + return create(parent, data.intValue()); + } else if (data instanceof Float) { + return create(parent, data.floatValue()); + } else if (data instanceof Double) { + return create(parent, data.doubleValue()); + } else if (data instanceof Long) { + return create(parent, data.longValue()); + } else if (data instanceof Byte) { + return create(parent, data.byteValue()); + } else { + throw new IllegalArgumentException("Short conversion not supported!"); + } + } + + /** + * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and float + * array. + * + * @param parent the parent {@link MxResource} to manage this instance + * @param data the float array that needs to be set + * @param shape the {@link Shape} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + public static NDArray create(MxResource parent, float[] data, Shape shape) { + return create(parent, FloatBuffer.wrap(data), shape); + } + + /** + * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and int + * array. + * + * @param parent the parent {@link MxResource} to manage this instance + * @param data the float array that needs to be set + * @param shape the {@link Shape} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + public static NDArray create(MxResource parent, int[] data, Shape shape) { + return create(parent, IntBuffer.wrap(data), shape); + } + + /** + * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and + * double array. 
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the double array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double[] data, Shape shape) {
+        return create(parent, DoubleBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and long
+     * array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the long array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long[] data, Shape shape) {
+        return create(parent, LongBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and byte
+     * array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the byte array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte[] data, Shape shape) {
+        return create(parent, ByteBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and
+     * boolean array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the boolean array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean[] data, Shape shape) {
+        // booleans are stored as one byte each: 1 for true, 0 for false
+        byte[] byteData = new byte[data.length];
+        for (int i = 0; i < data.length; i++) {
+            byteData[i] = (byte) (data[i] ? 1 : 0);
+        }
+        return create(parent, ByteBuffer.wrap(byteData), shape, DataType.BOOLEAN);
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the float that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, float data) {
+        return create(parent, new float[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the int data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, int data) {
+        return create(parent, new int[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the double data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double data) {
+        return create(parent, new double[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the long data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long data) {
+        return create(parent, new long[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the byte data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte data) {
+        return create(parent, new byte[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the boolean data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean data) {
+
+        return create(parent, new boolean[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the float array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, float[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the int array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, int[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the double array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the long array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the byte array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the bool array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 2D {@link NDArray}.
+     *
+     * <p>The rows are copied one after another into a single buffer; the resulting shape is
+     * {@code (data.length, data[0].length)}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the 2D float array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, float[][] data) {
+        // NOTE(review): data[0] is read unconditionally, so an empty outer array throws
+        // ArrayIndexOutOfBoundsException; rows are assumed to be of equal length - TODO confirm
+        FloatBuffer buffer = FloatBuffer.allocate(data.length * data[0].length);
+        for (float[] d : data) {
+            buffer.put(d);
+        }
+        // reset the position so the buffer is read from the start
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2D {@link NDArray}.
+     *
+     * <p>The rows are copied one after another into a single buffer; the resulting shape is
+     * {@code (data.length, data[0].length)}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the 2D int array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, int[][] data) {
+        // NOTE(review): data[0] is read unconditionally, so an empty outer array throws
+        // ArrayIndexOutOfBoundsException; rows are assumed to be of equal length - TODO confirm
+        IntBuffer buffer = IntBuffer.allocate(data.length * data[0].length);
+        for (int[] d : data) {
+            buffer.put(d);
+        }
+        // reset the position so the buffer is read from the start
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2D {@link NDArray}.
+     *
+     * <p>The rows are copied one after another into a single buffer; the resulting shape is
+     * {@code (data.length, data[0].length)}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the 2D double array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double[][] data) {
+        // NOTE(review): data[0] is read unconditionally, so an empty outer array throws
+        // ArrayIndexOutOfBoundsException; rows are assumed to be of equal length - TODO confirm
+        DoubleBuffer buffer = DoubleBuffer.allocate(data.length * data[0].length);
+        for (double[] d : data) {
+            buffer.put(d);
+        }
+        // reset the position so the buffer is read from the start
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2-D {@link NDArray}.
+     *
+     * <p>The rows are copied one after another into a single buffer; the resulting shape is
+     * {@code (data.length, data[0].length)}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the 2D long array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long[][] data) {
+        // NOTE(review): data[0] is read unconditionally, so an empty outer array throws
+        // ArrayIndexOutOfBoundsException; rows are assumed to be of equal length - TODO confirm
+        LongBuffer buffer = LongBuffer.allocate(data.length * data[0].length);
+        for (long[] d : data) {
+            buffer.put(d);
+        }
+        // reset the position so the buffer is read from the start
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2-D {@link NDArray}.
+     *
+     * <p>The rows are copied one after another into a single buffer; the resulting shape is
+     * {@code (data.length, data[0].length)}.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the 2D byte array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte[][] data) {
+        // NOTE(review): data[0] is read unconditionally, so an empty outer array throws
+        // ArrayIndexOutOfBoundsException; rows are assumed to be of equal length - TODO confirm
+        ByteBuffer buffer = ByteBuffer.allocate(data.length * data[0].length);
+        for (byte[] d : data) {
+            buffer.put(d);
+        }
+        // reset the position so the buffer is read from the start
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2-D {@link NDArray}.
+     *
+     * <p>Each boolean is stored as one byte (1 for true, 0 for false); the resulting shape is
+     * {@code (data.length, data[0].length)}.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the boolean array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean[][] data) {
+        // NOTE(review): data[0] is read unconditionally, so an empty outer array throws
+        // ArrayIndexOutOfBoundsException; rows are assumed to be of equal length - TODO confirm
+        ByteBuffer buffer = ByteBuffer.allocate(data.length * data[0].length);
+        for (boolean[] d : data) {
+            for (boolean b : d) {
+                buffer.put((byte) (b ? 1 : 0));
+            }
+        }
+        // reset the position so the buffer is read from the start
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length), DataType.BOOLEAN);
+    }
+
+    /**
+     * Creates and initializes a {@link NDArray} with specified {@link Shape}.
+     *
+

The {@link DataType} of the NDArray is determined by the type of the Buffer.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the data to initialize the {@code NDArray}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    static NDArray create(MxResource parent, Buffer data, Shape shape) {
+        DataType dataType = DataType.fromBuffer(data);
+        return create(parent, data, shape, dataType);
+    }
+
+    /**
+     * Creates a {@link NDArray} with the given {@link Shape} and {@link DataType} on the device
+     * returned by {@code Device.defaultIfNull()} and fills it from the buffer.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the data to initialize the {@code NDArray}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    static NDArray create(MxResource parent, Buffer data, Shape shape, DataType dataType) {
+        NDArray array = create(parent, shape, dataType, Device.defaultIfNull());
+        array.set(data);
+        return array;
+    }
+
+    /**
+     * Returns the name of this {@code NDArray}.
+     *
+     * @return the name of this {@code NDArray}
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Sets the name of this {@code NDArray}.
+     *
+     * @param name of the {@code NDArray}
+     */
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    /**
+     * Returns the {@link DataType} of this {@code NDArray}.
+     *
+     * <p>The value is read from the native array on first use and cached afterwards.
+     *
+     * @return the {@link DataType} of this {@code NDArray}
+     */
+    public DataType getDataType() {
+        if (this.dataType == null) {
+            this.dataType = JnaUtils.getDataTypeOfNdArray(getHandle());
+        }
+        return this.dataType;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Device getDevice() {
+        // read from the native array on first use, then cached
+        if (this.device == null) {
+            this.device = JnaUtils.getDeviceOfNdArray(getHandle());
+        }
+        return this.device;
+    }
+
+    /**
+     * Returns the {@link Shape} of this {@code NDArray}.
+     *
+     * <p>The value is read from the native array on first use and cached afterwards.
+     *
+     * @return the {@link Shape} of this {@code NDArray}
+     */
+    public Shape getShape() {
+        if (this.shape == null) {
+            this.shape = JnaUtils.getShapeOfNdArray(getHandle());
+        }
+        return this.shape;
+    }
+
+    /**
+     * Returns the {@link SparseFormat} of this {@code NDArray}.
+ * + * @return the {@link SparseFormat} of this {@code NDArray} + */ + public SparseFormat getSparseFormat() { + if (this.sparseFormat == null) { + this.sparseFormat = JnaUtils.getStorageType(getHandle()); + } + return this.sparseFormat; + } + + /** + * Returns the version of this {@code NDArray}. + * + * @return the version of this {@code NDArray} + */ + public Integer getVersion() { + if (this.version == null) { + this.version = JnaUtils.getVersion(); + } + return this.version; + } + + private NDArray duplicate(Shape shape, DataType dataType, Device device, String name) { + // TODO get copy parameter + NDArray array = create(getParent(), shape, dataType, device); + array.setName(name); + copyTo(array); + return array; + } + + /** + * Returns a copy of this {@code NDArray}. + * + * @return a copy of this {@code NDArray} + */ + NDArray duplicate() { + NDArray array = create(getParent(), getShape(), getDataType(), getDevice()); + array.setName(getName()); + copyTo(array); + return array; + } + + /** + * Moves this {@code NDArray} to a different {@link Device}. + * + * @param device the {@link Device} to be set + * @param copy set {@code true} if you want to return a copy of the Existing {@code NDArray} + * @return the result {@code NDArray} with the new {@link Device} + */ + public NDArray toDevice(Device device, boolean copy) { + if (device.equals(getDevice()) && !copy) { + return this; + } + return duplicate(getShape(), getDataType(), device, getName()); + } + + /** + * Converts this {@code NDArray} to a different {@link DataType}. 
+ * + * @param dataType the {@link DataType} to be set + * @param copy set {@code true} if you want to return a copy of the Existing {@code NDArray} + * @return the result {@code NDArray} with the new {@link DataType} + */ + public NDArray toType(DataType dataType, boolean copy) { + if (dataType.equals(getDataType()) && !copy) { + return this; + } + return duplicate(getShape(), dataType, getDevice(), getName()); + } + + /** + * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros. + * + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return a new instance of {@link NDArray} + * @see #zeros(Shape, DataType, Device) + */ + public NDArray zeros(Shape shape, DataType dataType) { + return fill("_npi_zeros", shape, dataType); + } + + /** + * Creates an instance of {@link NDArray} with the same {@link Shape} and {@link DataType} + * filled with zeros. + * + * @return a new instance of {@link NDArray} + * @see #zeros(Shape, DataType, Device) + */ + public NDArray zeros() { + return zeros(getShape(), getDataType()); + } + + /** + * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and + * {@link DataType} filled with zeros. 
+ * + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + NDArray zeros(Shape shape, DataType dataType, Device device) { + if (device == null || device.equals(getDevice())) { + return zeros(shape, dataType); + } + return zeros(shape, dataType); + } + + private NDArray createGradient(SparseFormat format) { + try (NDArray zeros = this.zeros(getShape(), getDataType(), getDevice())) { + return zeros.toSparse(format); + } + } + + private NDArray fill(String opName, Shape shape, DataType dataType) { + OpParams params = new OpParams(); + if (shape == null) { + throw new IllegalArgumentException("Shape is required for " + opName.substring(1)); + } + params.addParam("shape", shape); + params.setDevice(device); + params.setDataType(dataType); + return invoke(getParent(), opName, params); + } + + /** + * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones. + * + * @param parent the parent {@link MxResource} + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + public static NDArray ones(MxResource parent, Shape shape, DataType dataType, Device device) { + return create(parent, shape, dataType, device).ones(); + } + + /** + * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones. 
+ * + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + NDArray ones(Shape shape, DataType dataType) { + return fill("_npi_ones", shape, dataType); + } + + /** + * Creates an instance of {@link NDArray} with same {@link Shape} and {@link DataType} filled + * with ones. + * + * @return a new instance of {@link NDArray} + */ + public NDArray ones() { + return ones(getShape(), getDataType()); + } + /** + * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones. + * + * @param shape the {@link Shape} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + NDArray ones(Shape shape) { + return ones(shape, DataType.FLOAT32); + } + + /** + * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and + * {@link DataType} filled with ones. + * + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return a new instance of {@link NDArray} + */ + NDArray ones(Shape shape, DataType dataType, Device device) { + if (device == null || device.equals(getDevice())) { + return ones(shape, dataType); + } + return create(getParent(), shape, dataType, device).ones(); + } + + /** + * Returns the gradient {@code NDArray} attached to this {@code NDArray}. 
+ * + * @return the gradient {@code NDArray} + * @throws IllegalStateException when hasGradient is false + */ + public NDArray getGradient() { + if (!hasGradient()) { + throw new IllegalStateException( + "No gradient attached to this MxNDArray, please call array.requiredGradient()" + + "on your MxNDArray or block.setInitializer() on your Block"); + } + Pointer pointer = JnaUtils.getGradient(getHandle()); + return create(getParent(), pointer); + } + + /** + * Returns true if the gradient calculation is required for this {@code NDArray}. + * + * @return true if the gradient calculation is required for this {@code NDArray} else false + */ + public boolean hasGradient() { + if (hasGradient == null) { + Pointer pointer = JnaUtils.getGradient(getHandle()); + hasGradient = pointer != null; + } + return hasGradient; + } + + /** + * Returns an NDArray equal to this that stop gradient propagation through it. + * + * @return an NDArray equal to this that stops gradient propagation through it + */ + public NDArray stopGradient() { + Pointer pointer = JnaUtils.detachGradient(getHandle()); + return create(getParent(), pointer); + } + + /** + * Converts this {@code NDArray} to a String array. + * + *

This method is only applicable to the String typed NDArray and not for printing purpose + * + * @return Array of Strings + */ + public String[] toStringArray() { + throw new UnsupportedOperationException("String MxNDArray is not supported!"); + } + + /** + * Converts this {@code NDArray} to a ByteBuffer. + * + * @return a ByteBuffer + */ + public ByteBuffer toByteBuffer() { + if (getSparseFormat() != SparseFormat.DENSE) { + throw new IllegalStateException("Require Dense MxNDArray, actual " + getSparseFormat()); + } + Shape sh = getShape(); + DataType dType = getDataType(); + long product = sh.size(); + long len = dType.getNumOfBytes() * product; + ByteBuffer bb = NDSerializer.allocateDirect(Math.toIntExact(len)); + Pointer pointer = Native.getDirectBufferPointer(bb); + JnaUtils.syncCopyToCPU(getHandle(), pointer, Math.toIntExact(product)); + return bb; + } + + /** + * Returns the total number of elements in this {@code MxNDArray}. + * + * @return the number of elements in this {@code MxNDArray} + */ + long size() { + return getShape().size(); + } + + long size(int axis) { + return getShape().size(axis); + } + + /** + * Sets this {@code NDArray} value from {@link Buffer}. 
+ * + * @param data the input buffered data + */ + public void set(Buffer data) { + int size = Math.toIntExact(size()); + if (data.remaining() < size) { + throw new IllegalArgumentException( + "The MxNDArray size is: " + size + ", but buffer size is: " + data.remaining()); + } + if (data.isDirect()) { + JnaUtils.syncCopyFromCPU(getHandle(), data, size); + return; + } + + data.limit(size); + // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType + DataType inputType = DataType.fromBuffer(data); + validate(inputType); + + int numOfBytes = inputType.getNumOfBytes(); + ByteBuffer buf = NDSerializer.allocateDirect(size * numOfBytes); + + switch (inputType) { + case FLOAT32: + buf.asFloatBuffer().put((FloatBuffer) data); + break; + case FLOAT64: + buf.asDoubleBuffer().put((DoubleBuffer) data); + break; + case UINT8: + case INT8: + case BOOLEAN: + buf.put((ByteBuffer) data); + break; + case INT32: + buf.asIntBuffer().put((IntBuffer) data); + break; + case INT64: + buf.asLongBuffer().put((LongBuffer) data); + break; + case FLOAT16: + default: + throw new UnsupportedOperationException("data type is not supported!"); + } + buf.rewind(); + JnaUtils.syncCopyFromCPU(getHandle(), buf, size); + } + + private void validate(DataType inputType) { + if (getDataType() != inputType + && ((dataType != DataType.UINT8 && dataType != DataType.BOOLEAN) + || inputType != DataType.INT8)) { + // Infer DataType from Buffer always return INT8, make this two special case that + // allows set UINT8 and BOOL array with regular ByteBuffer. + throw new IllegalStateException( + "DataType mismatch, required: " + dataType + ", actual: " + inputType); + } + } + + /** + * Returns {@code true} if this {@code MxNDArray} is a scalar {@code MxNDArray} with empty + * {@link Shape}. 
+ * + * @return {@code true} if this {@code MxNDArray} is a scalar {@code MxNDArray} with empty + * {@link Shape} + */ + boolean isScalar() { + return getShape().isScalar(); + } + + /** + * Returns {@code true} if all elements within this {@code NDArray} are non-zero or {@code + * true}. + * + * @return {@code true} if all elements within this {@code NDArray} are non-zero or {@code true} + */ + NDArray all() { + // result of sum operation is int64 now + return toType(DataType.BOOLEAN, false).sum().eq(size()); + } + + /** + * Deep-copies the current {@code NDArray} to the one passed in. + * + * @param ndArray this {@code NDArray} prepared to be copied to + */ + public void copyTo(NDArray ndArray) { + + Shape inShape = getShape(); + Shape destShape = ndArray.getShape(); + if (!Arrays.equals(inShape.getShape(), destShape.getShape())) { + throw new IllegalArgumentException( + "shape are diff. Required: " + destShape + ", Actual " + inShape); + } + JnaUtils.op("_npi_copyto").invoke(new NDArray[] {this}, new NDArray[] {ndArray}, null); + } + + NDArray booleanMask(NDArray index) { + return booleanMask(index, 0); + } + + /** + * Returns portion of this {@code NDArray} given the index boolean {@code NDArray} along given + * axis. 
+     *
+     * <p>Both this array and the mask are reshaped before the {@code _npi_boolean_mask} operator
+     * is invoked: this array to {@code {-1, remainingDims}} and the mask, converted to INT32, to
+     * one dimension.
+     *
+     * @param index boolean {@code NDArray} mask
+     * @param axis an integer that represents the axis of {@code NDArray} to mask from
+     * @return the result {@code NDArray}
+     * @throws IllegalArgumentException if this array or {@code index} is a scalar
+     */
+    public NDArray booleanMask(NDArray index, int axis) {
+        if (isScalar() || index.isScalar()) {
+            throw new IllegalArgumentException("booleanMask didn't support scalar!");
+        }
+        // TODO remove reshape when MXNet numpy support multi-dim index
+        // and boolean MxNDArray reshape
+        // Dimensions of this array that are not covered by the mask.
+        Shape remainingDims = getShape().slice(index.getShape().dimension());
+        // create a reshape array {-1, remainingDims}
+        long[] reshape = new long[remainingDims.dimension() + 1];
+        reshape[0] = -1;
+        System.arraycopy(remainingDims.getShape(), 0, reshape, 1, remainingDims.dimension());
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        // The three intermediates are closed when the try block exits; only the final reshape
+        // of the operator result is returned to the caller.
+        try (NDArray reshaped = this.reshape(new Shape(reshape));
+                NDArray reshapedIndex = index.toType(DataType.INT32, false).reshape(-1);
+                NDArray result =
+                        invoke(
+                                getParent(),
+                                "_npi_boolean_mask",
+                                new NDArray[] {reshaped, reshapedIndex},
+                                params)) {
+            return result.reshape(reshape);
+        }
+    }
+
+    /**
+     * Sets all elements outside the sequence to a constant value.
+     *
+     *

This function takes an n-dimensional input array of the form [batch_size, + * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code + * sequenceLength} is used to handle variable-length sequences. sequence_length should be an + * input array of positive ints of dimension [batch_size]. + * + * @param sequenceLength used to handle variable-length sequences + * @param value the constant value to be set + * @return the result {@code NDArray} + */ + public NDArray sequenceMask(NDArray sequenceLength, float value) { + if (getShape().dimension() < 2 || getShape().isScalar() || getShape().hasZeroDimension()) { + throw new IllegalArgumentException( + "sequenceMask is not supported for MxNDArray with less than 2 dimensions"); + } + Shape expectedSequenceLengthShape = new Shape(getShape().get(0)); + if (!sequenceLength.getShape().equals(expectedSequenceLengthShape)) { + throw new IllegalArgumentException("SequenceLength must be of shape [batchSize]"); + } + OpParams params = new OpParams(); + params.add("value", value); + params.add("use_sequence_length", true); + params.add("axis", 1); + return invoke(getParent(), "_npx_sequence_mask", new NDList(this, sequenceLength), params) + .head(); + } + + /** + * Sets all elements outside the sequence to 0. + * + *

This function takes an n-dimensional input array of the form [batch_size, + * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code + * sequenceLength} is used to handle variable-length sequences. sequence_length should be an + * input array of positive ints of dimension [batch_size]. + * + * @param sequenceLength used to handle variable-length sequences + * @return the result {@code NDArray} + */ + public NDArray sequenceMask(NDArray sequenceLength) { + return sequenceMask(sequenceLength, 0); + } + + /** + * Returns an {@code NDArray} of zeros with the same {@link Shape}, {@link DataType} and {@link + * SparseFormat} as the input {@code NDArray}. + * + * @return a {@code NDArray} filled with zeros + */ + public NDArray zerosLike() { + OpParams params = new OpParams(); + params.addParam("fill_value", 0); + return invoke(getParent(), "_npi_full_like", this, params); + } + + /** + * Returns an {@code NDArray} of ones with the same {@link Shape}, {@link DataType} and {@link + * SparseFormat} as the input {@code NDArray}. + * + * @return a {@code NDArray} filled with ones + */ + public NDArray onesLike() { + OpParams params = new OpParams(); + params.addParam("fill_value", 1); + return invoke(getParent(), "_npi_full_like", this, params); + } + + NDArray get(NDIndex index) { + return getNDArrayInternal().getIndexer().get(this, index); + } + + NDArray get(long... indices) { + return get(new NDIndex(indices)); + } + + NDArray getScalar(long... indices) { + NDArray value = get(new NDIndex(indices)); + if (value.size() != 1) { + throw new IllegalArgumentException("The supplied Index does not produce a scalar"); + } + return value; + } + + boolean getBoolean(long... indices) { + return getScalar(indices).toBooleanArray()[0]; + } + + /** + * Returns {@code true} if all elements in this {@code NDArray} are equal to the {@link Number}. 
+     *
+     * @param number the number to compare
+     * @return the boolean result; {@code false} when {@code number} is {@code null}
+     */
+    public boolean contentEquals(Number number) {
+        if (number == null) {
+            return false;
+        }
+        // The element-wise comparison result is closed once it has been reduced to a boolean.
+        try (NDArray result = eq(number)) {
+            return result.all().getBoolean();
+        }
+    }
+
+    /**
+     * Returns {@code true} if all elements in this {@code NDArray} are equal to the other {@link
+     * NDArray}.
+     *
+     * <p>Arrays with a different shape or a different {@link DataType} are never equal.
+     *
+     * @param other the other {@code NDArray} to compare
+     * @return the boolean result; {@code false} when {@code other} is {@code null}
+     */
+    public boolean contentEquals(NDArray other) {
+        if (other == null || (!shapeEquals(other))) {
+            return false;
+        }
+        if (getDataType() != other.getDataType()) {
+            return false;
+        }
+        try (NDArray result = eq(other).toType(DataType.INT32, false)) {
+            return result.all().getBoolean();
+        }
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
+     *
+     * @param n the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Equals" comparison
+     */
+    public NDArray eq(Number n) {
+        OpParams params = new OpParams();
+        // The scalar operand is handed to the operator in its string form.
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Equals" comparison
+     */
+    public NDArray eq(NDArray other) {
+        return invoke(getParent(), "_npi_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
+     *
+     * @param n the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Not equals" comparison
+     */
+    public NDArray neq(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_not_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Not equals" comparison
+     */
+    public NDArray neq(NDArray other) {
+        return invoke(getParent(), "_npi_not_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater" comparison
+     */
+    public NDArray gt(Number other) {
+        OpParams params = new OpParams();
+        // The scalar operand is handed to the operator in its string form.
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_greater_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater Than" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater Than" comparison
+     */
+    public NDArray gt(NDArray other) {
+        return invoke(getParent(), "_npi_greater", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
+     */
+    public NDArray gte(Number other) {
+        OpParams params = new OpParams();
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_greater_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
+     */
+    public NDArray gte(NDArray other) {
+        return invoke(getParent(), "_npi_greater_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Less" comparison
+     */
+    public NDArray lt(Number other) {
+        OpParams params = new OpParams();
+        // The scalar operand is handed to the operator in its string form.
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_less_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Less" comparison
+     */
+    public NDArray lt(NDArray other) {
+        return invoke(getParent(), "_npi_less", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
+     */
+    public NDArray lte(Number other) {
+        OpParams params = new OpParams();
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_less_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
+     */
+    public NDArray lte(NDArray other) {
+        return invoke(getParent(), "_npi_less_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Adds a number to this {@code NDArray} element-wise.
+     *
+     * @param n the number to add
+     * @return a new {@code NDArray} holding the result
+     */
+    public NDArray add(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_add_scalar", this, params);
+    }
+
+    /**
+     * Adds other {@code NDArray}s to this {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to add
+     * @return the result {@code NDArray}
+     */
+    public NDArray add(NDArray other) {
+        return invoke(getParent(), "_npi_add", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Subtracts a number from this {@code NDArray} element-wise.
+     *
+     * @param n the number to subtract
+     * @return the result {@code NDArray}
+     */
+    public NDArray sub(Number n) {
+        OpParams params = new OpParams();
+        // The scalar operand is handed to the operator in its string form.
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_subtract_scalar", this, params);
+    }
+
+    /**
+     * Subtracts the other {@code NDArray} from this {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to subtract
+     * @return the result {@code NDArray}
+     */
+    public NDArray sub(NDArray other) {
+        return invoke(getParent(), "_npi_subtract", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by a number element-wise.
+     *
+     * @param n the number to multiply by
+     * @return the result {@code NDArray}
+     */
+    public NDArray mul(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_multiply_scalar", this, params);
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by the other {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to multiply by
+     * @return the result {@code NDArray}
+     */
+    public NDArray mul(NDArray other) {
+        return invoke(getParent(), "_npi_multiply", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Divides this {@code NDArray} by a number element-wise.
+     *
+     * @param n the number to divide by
+     * @return the result {@code NDArray}
+     */
+    public NDArray div(Number n) {
+        OpParams params = new OpParams();
+        // The scalar operand is handed to the operator in its string form.
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_true_divide_scalar", this, params);
+    }
+
+    /**
+     * Divides this {@code NDArray} by the other {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to divide by
+     * @return the result {@code NDArray}
+     */
+    public NDArray div(NDArray other) {
+        return invoke(getParent(), "_npi_true_divide", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns element-wise remainder of division by a number.
+     *
+     * @param n the divisor number
+     * @return the result {@code NDArray}
+     */
+    public NDArray mod(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_mod_scalar", this, params);
+    }
+
+    /**
+     * Returns element-wise remainder of division by the other {@code NDArray}.
+     *
+     * @param other the divisor {@code NDArray}
+     * @return the result {@code NDArray}
+     */
+    public NDArray mod(NDArray other) {
+        return invoke(getParent(), "_npi_mod", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with a number element-wise.
+     *
+     * @param n the number to take the power with
+     * @return the result {@code NDArray}
+     */
+    public NDArray pow(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_power_scalar", this, params);
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to take the power with
+     * @return the result {@code NDArray}
+     */
+    public NDArray pow(NDArray other) {
+        return invoke(getParent(), "_npi_power", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Adds a number to this {@code NDArray} element-wise in place.
+     *
+     * @param n the number to add
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray addi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        // This array is passed as both input and output, so the result overwrites its content.
+        invoke("_npi_add_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Adds the other {@code NDArray} to this {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to add
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray addi(NDArray other) {
+        invoke("_npi_add", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Subtracts a number from this {@code NDArray} element-wise in place.
+     *
+     * @param n the number to subtract
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray subi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_subtract_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Subtracts the other {@code NDArray} from this {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to subtract
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray subi(NDArray other) {
+        invoke("_npi_subtract", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by a number element-wise in place.
+     *
+     * @param n the number to multiply by
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray muli(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_multiply_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by other {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to multiply with
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray muli(NDArray other) {
+        // This array is passed as both input and output, so the result overwrites its content.
+        invoke("_npi_multiply", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Divides this {@code NDArray} by a number element-wise in place.
+     *
+     * @param n the number to divide values by
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray divi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_true_divide_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Divides this {@code NDArray} by the other {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to divide by
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray divi(NDArray other) {
+        invoke("_npi_true_divide", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Computes the element-wise remainder of division by a number in place.
+     *
+     * @param n the divisor number
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray modi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_mod_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Computes the element-wise remainder of division by the other {@code NDArray} in place.
+     *
+     * @param other the divisor {@code NDArray}
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray modi(NDArray other) {
+        invoke("_npi_mod", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with a number element-wise in place.
+     *
+     * @param n the number to raise the power to
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray powi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        // This array is passed as both input and output, so the result overwrites its content.
+        invoke("_npi_power_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to take the power with
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray powi(NDArray other) {
+        invoke("_npi_power", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Returns the element-wise sign.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray sign() {
+        return invoke(getParent(), "_npi_sign", this, null);
+    }
+
+    /**
+     * Computes the element-wise sign in place.
+     *
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray signi() {
+        invoke("_npi_sign", new NDArray[] {this}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Returns the maximum of this {@code NDArray} and a number element-wise.
+     *
+     * @param n the number to be compared
+     * @return the maximum of this {@code NDArray} and a number element-wise
+     */
+    public NDArray maximum(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_maximum_scalar", this, params);
+    }
+
+    /**
+     * Returns the maximum of this {@code NDArray} and the other {@code NDArray} element-wise.
+     *
+     * @param other the {@code NDArray} to be compared
+     * @return the maximum of this {@code NDArray} and the other {@code NDArray} element-wise
+     */
+    public NDArray maximum(NDArray other) {
+        return invoke(getParent(), "_npi_maximum", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the minimum of this {@code NDArray} and a number element-wise.
+     *
+     * @param n the number to be compared
+     * @return the minimum of this {@code NDArray} and a number element-wise
+     */
+    public NDArray minimum(Number n) {
+        OpParams params = new OpParams();
+        // The scalar operand is handed to the operator in its string form.
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_minimum_scalar", this, params);
+    }
+
+    /**
+     * Returns the minimum of this {@code NDArray} and the other {@code NDArray} element-wise.
+     *
+     * @param other the {@code NDArray} to be compared
+     * @return the minimum of this {@code NDArray} and the other {@code NDArray} element-wise
+     */
+    public NDArray minimum(NDArray other) {
+        return invoke(getParent(), "_npi_minimum", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the numerical negative {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray neg() {
+        return invoke(getParent(), "_npi_negative", this, null);
+    }
+
+    /**
+     * Computes the numerical negative of this {@code NDArray} element-wise in place.
+     *
+     * @return this {@code NDArray}, which now holds the result
+     */
+    public NDArray negi() {
+        // This array is passed as both input and output, so the result overwrites its content.
+        invoke("_npi_negative", new NDArray[] {this}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Returns the absolute value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray abs() {
+        return invoke(getParent(), "_npi_absolute", this, null);
+    }
+
+    /**
+     * Returns the square of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray square() {
+        return invoke(getParent(), "_npi_square", this, null);
+    }
+
+    /**
+     * Returns the square root of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray sqrt() {
+        return invoke(getParent(), "_npi_sqrt", this, null);
+    }
+
+    /**
+     * Returns the cube-root of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray cbrt() {
+        return invoke(getParent(), "_npi_cbrt", this, null);
+    }
+
+    /**
+     * Returns the floor of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray floor() {
+        return invoke(getParent(), "_npi_floor", this, null);
+    }
+
+    /**
+     * Returns the ceiling of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray ceil() {
+        return invoke(getParent(), "_npi_ceil", this, null);
+    }
+
+    /**
+     * Returns the round of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray round() {
+        // NOTE(review): unlike the neighbouring methods this invokes the operator named "round"
+        // (no "_npi_" prefix) - confirm this is the intended operator.
+        return invoke(getParent(), "round", this, null);
+    }
+
+    /**
+     * Returns the truncated value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray trunc() {
+        return invoke(getParent(), "_npi_trunc", this, null);
+    }
+
+    /**
+     * Returns the exponential value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray exp() {
+        return invoke(getParent(), "_npi_exp", this, null);
+    }
+
+    /**
+     * Returns the natural logarithmic value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray log() {
+        return invoke(getParent(), "_npi_log", this, null);
+    }
+
+    /**
+     * Returns the base 10 logarithm of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray log10() {
+        return invoke(getParent(), "_npi_log10", this, null);
+    }
+
+    /**
+     * Returns the base 2 logarithm of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray log2() {
+        return invoke(getParent(), "_npi_log2", this, null);
+    }
+
+    /**
+     * Returns the trigonometric sine of this {@code NDArray} element-wise.
+ * + * @return the result {@code NDArray} + */ + public NDArray sin() { + return invoke(getParent(), "_npi_sin", this, null); + } + + /** + * Returns the trigonometric cosine of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray cos() { + return invoke(getParent(), "_npi_cos", this, null); + } + + /** + * Returns the trigonometric tangent of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray tan() { + return invoke(getParent(), "_npi_tan", this, null); + } + + /** + * Returns the inverse trigonometric sine of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray asin() { + return invoke(getParent(), "_npi_arcsin", this, null); + } + + /** + * Returns the inverse trigonometric cosine of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray acos() { + return invoke(getParent(), "_npi_arccos", this, null); + } + + /** + * Returns the inverse trigonometric tangent of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray atan() { + return invoke(getParent(), "_npi_arctan", this, null); + } + + /** + * Returns the hyperbolic sine of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray sinh() { + return invoke(getParent(), "_npi_sinh", this, null); + } + + /** + * Returns the hyperbolic cosine of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray cosh() { + return invoke(getParent(), "_npi_cosh", this, null); + } + + /** + * Returns the hyperbolic tangent of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray tanh() { + return invoke(getParent(), "_npi_tanh", this, null); + } + + /** + * Returns the inverse hyperbolic sine of this {@code NDArray} element-wise. 
+ * + * @return the result {@code NDArray} + */ + public NDArray asinh() { + return invoke(getParent(), "_npi_arcsinh", this, null); + } + + /** + * Returns the inverse hyperbolic cosine of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray acosh() { + return invoke(getParent(), "_npi_arccosh", this, null); + } + + /** + * Returns the inverse hyperbolic tangent of this {@code NDArray} element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray atanh() { + return invoke(getParent(), "_npi_arctanh", this, null); + } + + /** + * Converts this {@code NDArray} from radians to degrees element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray toDegrees() { + return invoke(getParent(), "_npi_degrees", this, null); + } + + /** + * Converts this {@code NDArray} from degrees to radians element-wise. + * + * @return the result {@code NDArray} + */ + public NDArray toRadians() { + return invoke(getParent(), "_npi_radians", this, null); + } + + /** + * Returns the maximum of this {@code NDArray}. + * + * @return the maximum of this {@code NDArray} + */ + public NDArray max() { + return invoke(getParent(), "_np_max", this, null); + } + + /** + * Returns the maximum of this {@code NDArray} along given axes. + * + * @param axes the axes along which to operate + * @return the maximum of this {@code NDArray} with the specified axes removed from the Shape + * containing the max + * @see NDArray#max(int[], boolean) + */ + public NDArray max(int[] axes) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + return invoke(getParent(), "_np_max", this, params); + } + + /** + * Returns the maximum of this {@code NDArray} along given axes. + * + * @param axes the axes along which to operate + * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code + * false} to squeeze the values out of the output array. 
+ * @return the maximum of this {@code NDArray} + */ + public NDArray max(int[] axes, boolean keepDims) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + params.addParam("keepdims", keepDims); + return invoke(getParent(), "_np_max", this, params); + } + + /** + * Returns the minimum of this {@code NDArray}. + * + * @return the minimum of this {@code NDArray} + */ + public NDArray min() { + return invoke(getParent(), "_np_min", this, null); + } + + /** + * Returns the minimum of this {@code NDArray} along given axes. + * + * @param axes the axes along which to operate + * @return the minimum of this {@code NDArray} with the specified axes removed from the Shape + * containing the min + * @see NDArray#min(int[], boolean) + */ + public NDArray min(int[] axes) { + return min(axes, false); + } + + /** + * Returns the minimum of this {@code NDArray} along given axes. + * + * @param axes the axes along which to operate + * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code + * false} to squeeze the values out of the output array + * @return the minimum of this {@code NDArray} + */ + public NDArray min(int[] axes, boolean keepDims) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + params.addParam("keepdims", keepDims); + return invoke(getParent(), "_np_min", this, params); + } + + /** + * Returns the sum of this {@code NDArray}. 
+ * + * @return the sum of this {@code NDArray} + */ + public NDArray sum() { + // TODO current windows doesn't support boolean MxNDArray + if (System.getProperty("os.name").toLowerCase().contains("win")) { + DataType target = getDataType(); + if (!target.isFloating()) { + try (NDArray thisArr = toType(DataType.FLOAT32, false)) { + if (target == DataType.BOOLEAN) { + target = DataType.INT64; + } + try (NDArray array = invoke(getParent(), "_np_sum", thisArr, null)) { + return array.toType(target, false); + } + } + } + } + return invoke(getParent(), "_np_sum", this, null); + } + + /** + * Returns the sum of this {@code NDArray} along given axes. + * + * @param axes the axes along which to operate + * @return the sum of this {@code NDArray} with the specified axes removed from the Shape + * containing the sum + * @see NDArray#sum(int[], boolean) + */ + public NDArray sum(int[] axes) { + return sum(axes, false); + } + + /** + * Returns the sum of this {@code NDArray} along given axes. + * + * @param axes the axes along which to operate + * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code + * false} to squeeze the values out of the output array + * @return the sum of this {@code NDArray} + */ + public NDArray sum(int[] axes, boolean keepDims) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + params.addParam("keepdims", keepDims); + return invoke(getParent(), "_np_sum", this, params); + } + + /** + * Returns the product of this {@code NDArray}. + * + * @return the product of this {@code NDArray} + */ + public NDArray prod() { + return invoke(getParent(), "_np_prod", this, null); + } + + /** + * Returns the product of this {@code NDArray} elements over the given axes. 
+ * + * @param axes the axes along which to operate + * @return the product of this {@code NDArray} with the specified axes removed from the Shape + * containing the prod + * @see NDArray#prod(int[], boolean) + */ + NDArray prod(int[] axes) { + return prod(axes, false); + } + + /** + * Returns the product of this {@code NDArray} elements over the given axes. + * + * @param axes the axes along which to operate + * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code + * false} to squeeze the values out of the output array + * @return the product of this {@code NDArray} + */ + public NDArray prod(int[] axes, boolean keepDims) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + params.addParam("keepdims", keepDims); + return invoke(getParent(), "_np_prod", this, params); + } + + /** + * Returns the average of this {@code NDArray}. + * + * @return the average of this {@code NDArray} + */ + public NDArray mean() { + return invoke(getParent(), "_npi_mean", this, null); + } + + /** + * Returns the average of this {@code NDArray} along given axes. + * + * @param axes the axes along which to operate + * @return the average of this {@code NDArray} with the specified axes removed from the Shape + * containing the mean + * @see NDArray#mean(int[], boolean) + */ + public NDArray mean(int[] axes) { + return mean(axes, false); + } + + /** + * Returns the average of this {@code NDArray} along given axes. 
+ * + * @param axes the axes along which to operate + * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code + * false} to squeeze the values out of the output array + * @return the average of this {@code NDArray} + */ + public NDArray mean(int[] axes, boolean keepDims) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + params.addParam("keepdims", keepDims); + return invoke(getParent(), "_npi_mean", this, params); + } + + /** + * Rotates an array by 90 degrees in the plane specified by axes. + * + * @param times Number of times the array is rotated by 90 degrees. + * @param axes The array is rotated in the plane defined by the axes. Axes must be different. + * @return the rotated NDArray + */ + public NDArray rotate90(int times, int[] axes) { + if (axes.length != 2) { + throw new IllegalArgumentException("Axes must be 2"); + } + OpParams params = new OpParams(); + params.addTupleParam("axes", axes); + params.addParam("k", times); + return invoke(getParent(), "_npi_rot90", this, params); + } + + /** + * Returns the sum along diagonals of this {@code NDArray}. + * + *

If this {@code NDArray} is 2-D, the sum along its diagonal with the given offset is + * returned, i.e., the sum of elements a[i,i+offset] for all i. If this {@code NDArray} has more + * than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D + * sub-arrays whose traces are returned. The {@link Shape} of the resulting array is the same as + * this {@code NDArray} with axis1 and axis2 removed. + * + * @param offset offset of the diagonal from the main diagonal. Can be both positive and + * negative. + * @param axis1 axes to be used as the first axis of the 2-D sub-arrays from which the diagonals + * should be taken + * @param axis2 axes to be used as the second axis of the 2-D sub-arrays from which the + * diagonals should be taken + * @return the sum along diagonals of this {@code NDArray} + */ + public NDArray trace(int offset, int axis1, int axis2) { + OpParams params = new OpParams(); + params.addParam("offset", offset); + params.addParam("axis1", axis1); + params.addParam("axis2", axis2); + return invoke(getParent(), "_np_trace", this, params); + } + + /** + * Returns the sum along diagonals of this {@code NDArray}. + * + *

If this {@code NDArray} is 2-D, the sum along its diagonal with the given offset is + * returned, i.e., the sum of elements a[i,i+offset] for all i. If this {@code NDArray} has more + * than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D + * sub-arrays whose traces are returned. The {@link Shape} of the resulting array is the same as + * this {@code NDArray} with axis1 and axis2 removed. + * + * @param offset offset of the diagonal from the main diagonal. Can be both positive and + * negative. + * @return the sum along diagonals of this {@code NDArray} + */ + public NDArray trace(int offset) { + return trace(offset, 0, 1); + } + + /** + * Splits this {@code NDArray} into multiple sub{@code NDArray}s given sections along the given + * axis. + * + * @param indices this {@code NDArray} will be divided into N (sections) equal arrays along axis + * @param axis the axis to split along + * @return an {@link NDList} with numOutputs {@code NDArray}s with {@link Shape} {@code + * (this.shape.axis /= axis) } + * @throws IllegalArgumentException thrown if the numOutputs does not equally divide the given + * axis + */ + public NDList split(long[] indices, int axis) { + if (indices.length == 0) { + return new NDList(this); + } + OpParams params = new OpParams(); + // follow the numpy behavior + if (indices[0] != 0) { + long[] tempIndices = new long[indices.length + 1]; + tempIndices[0] = 0; + System.arraycopy(indices, 0, tempIndices, 1, indices.length); + indices = tempIndices; + } + params.addTupleParam("indices", indices); + params.addParam("axis", axis); + params.addParam("squeeze_axis", false); + return invoke(getParent(), "_npi_split", new NDList(this), params); + } + + /** + * Flattens this {@code NDArray} into a 1-D {@code NDArray} in row-major order. + * + *

To flatten in column-major order, first transpose this {@code NDArray} + * + * @return a 1-D {@code NDArray} of equal size + */ + public NDArray flatten() { + return reshape(new Shape(Math.toIntExact(size()))); + } + + /** + * Reshapes this {@code NDArray} to the given {@link Shape}. + * + *

You can reshape it to match another NDArray by calling {@code a.reshape(b.getShape()) } + * + * @param shape the {@link Shape} to reshape into. Must have equal size to the current shape + * @return a reshaped {@code NDArray} + * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of + * the current shape + */ + public NDArray reshape(Shape shape) { + OpParams params = new OpParams(); + params.addParam("newshape", shape); + return invoke(getParent(), "_np_reshape", this, params); + } + + /** + * Reshapes this {@code NDArray} to the given {@link Shape}. + * + * @param newShape the long array to reshape into. Must have equal size to the current shape + * @return a reshaped {@code NDArray} + * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of + * the current shape + */ + public NDArray reshape(long... newShape) { + return reshape(new Shape(newShape)); + } + + /** + * Expands the {@link Shape} of a {@code NDArray}. + * + *

Inserts a new axis that will appear at the axis position in the expanded {@code NDArray} + * shape. + * + * @param axis the position in the expanded axes where the new axis is placed + * @return the result {@code NDArray}. The number of dimensions is one greater than that of the + * {@code NDArray} + */ + public NDArray expandDims(int axis) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + return invoke(getParent(), "_npi_expand_dims", this, params); + } + + /** + * Removes all singleton dimensions from this {@code NDArray} {@link Shape}. + * + * @return a result {@code NDArray} of same size and data without singleton dimensions + */ + public NDArray squeeze() { + return invoke(getParent(), "_np_squeeze", this, null); + } + + /** + * Removes singleton dimensions at the given axes. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f}, new Shape(1, 3, 1));
+     * jshell> array;
+     * ND: (1, 3, 1) cpu() float32
+     * [[[0.],
+     *   [1.],
+     *   [2.],
+     *  ],
+     * ]
+     * jshell> array.squeeze(new int[] {0, 2});
+     * ND: (3) cpu() float32
+     * [0., 1., 2.]
+     * 
+ * + * @param axes the axes at which to remove the singleton dimensions + * @return a result {@code NDArray} of same size and data without the axes at part of the shape + * @throws IllegalArgumentException thrown if any of the given axes are not a singleton + * dimension + */ + public NDArray squeeze(int[] axes) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + return invoke(getParent(), "_np_squeeze", this, params); + } + + /** + * Returns the truth value of this {@code NDArray} AND the other {@code NDArray} element-wise. + * + *

The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable. + * + * @param other the other {@code NDArray} to operate on + * @return the boolean {@code NDArray} of the logical AND operation applied to the elements of + * this {@code NDArray} and the other {@code NDArray} + */ + public NDArray logicalAnd(NDArray other) { + // TODO switch to numpy op, although current op support zero-dim, scalar + NDArray thisArr = + (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this; + other = + (other.getDataType() == DataType.BOOLEAN) + ? other.toType(DataType.INT32, false) + : other; + return invoke(getParent(), "broadcast_logical_and", new NDArray[] {thisArr, other}, null) + .toType(DataType.BOOLEAN, false); + } + + /** + * Computes the truth value of this {@code NDArray} OR the other {@code NDArray} element-wise. + * + *

The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable. + * + * @param other the other {@code NDArray} to operate on + * @return the boolean {@code NDArray} of the logical OR operation applied to the elements of + * this {@code NDArray} and the other {@code NDArray} + */ + public NDArray logicalOr(NDArray other) { + // TODO switch to numpy op, although current op support zero-dim, scalar + NDArray thisArr = + (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this; + other = + (other.getDataType() == DataType.BOOLEAN) + ? other.toType(DataType.INT32, false) + : other; + return invoke(getParent(), "broadcast_logical_or", new NDArray[] {thisArr, other}, null) + .toType(DataType.BOOLEAN, false); + } + + /** + * Computes the truth value of this {@code NDArray} XOR the other {@code NDArray} element-wise. + * + *

The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable. + * + * @param other the other {@code NDArray} to operate on + * @return the boolean {@code NDArray} of the logical XOR operation applied to the elements of + * this {@code NDArray} and the other {@code NDArray} + */ + public NDArray logicalXor(NDArray other) { + // TODO switch to numpy op, although current op support zero-dim, scalar + NDArray thisArr = + (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this; + other = + (other.getDataType() == DataType.BOOLEAN) + ? other.toType(DataType.INT32, false) + : other; + return invoke(getParent(), "broadcast_logical_xor", new NDArray[] {thisArr, other}, null) + .toType(DataType.BOOLEAN, false); + } + + /** + * Computes the truth value of NOT this {@code NDArray} element-wise. + * + * @return the boolean {@code NDArray} + */ + public NDArray logicalNot() { + return invoke(getParent(), "_npi_logical_not", this, null); + } + + /** + * Returns the indices that would sort this {@code NDArray} given the axis. + * + *

Perform an indirect sort along the given axis. It returns a {@code NDArray} of indices of + * the same {@link Shape} as this {@code NDArray}. + * + * @param axis the axis to sort along + * @param ascending whether to sort ascending + * @return a {@code NDArray} of indices corresponding to elements in this {@code NDArray} on the + * axis, the output DataType is always {@link DataType#INT64} + */ + public NDArray argSort(int axis, boolean ascending) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + // be careful that MXNet numpy argsort op didn't officially support this param + params.addParam("is_ascend", ascending); + params.setDataType(DataType.INT64); + return invoke(getParent(), "_npi_argsort", this, params); + } + + /** + * Sorts the flattened {@code NDArray}. + * + * @param axis the axis to sort along + * @return the sorted {@code NDArray} + */ + public NDArray sort(int axis) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + return invoke(getParent(), "_npi_sort", this, params); + } + + /** + * Sorts the flattened {@code NDArray}. + * + * @return the sorted {@code NDArray} + */ + public NDArray sort() { + return invoke(getParent(), "_npi_sort", this, null); + } + + /** + * Applies the softmax function along the given axis. + * + * @param axis the axis along which to apply + * @return the result {@code NDArray} + * @see softmax + * @see NDArray#softmax(int) + */ + public NDArray softmax(int axis) { + // MXNet softmax op bug on GPU + if (isEmpty()) { + return create(getParent(), getShape(), DataType.FLOAT32, getDevice()); + } + OpParams params = new OpParams(); + params.addParam("axis", axis); + return invoke(getParent(), "_npx_softmax", this, params); + } + + /** + * Applies the softmax function followed by a logarithm. + * + *

Mathematically equivalent to calling softmax and then log. This single operator is faster + * than calling two operators and numerically more stable when computing gradients. + * + * @param axis the axis along which to apply + * @return the result {@code NDArray} + */ + public NDArray logSoftmax(int axis) { + // MXNet logsoftmax op bug on GPU + if (isEmpty()) { + return create(getParent(), getShape(), DataType.FLOAT32, getDevice()); + } + OpParams params = new OpParams(); + params.addParam("axis", axis); + return invoke(getParent(), "_npx_log_softmax", this, params); + } + + /** + * Returns the cumulative sum of the elements in the flattened {@code NDArray}. + * + * @return the cumulative sum of the elements in the flattened {@code NDArray} + */ + public NDArray cumSum() { + return invoke(getParent(), "_np_cumsum", this, null); + } + + /** + * Return the cumulative sum of the elements along a given axis. + * + * @param axis the axis along which the cumulative sum is computed + * @return the cumulative sum along the specified axis + */ + public NDArray cumSum(int axis) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + return invoke(getParent(), "_np_cumsum", this, params); + } + + /** + * Replace the handle of the NDArray with the other. The NDArray used for replacement will be + * killed. + * + *

Please use with caution, this method will make the input argument unusable. + * + * @param replaced the handle provider that will be killed + */ + public void intern(NDArray replaced) { + NDArray arr = replaced; + Pointer oldHandle = handle.getAndSet(arr.handle.getAndSet(null)); + JnaUtils.waitToRead(oldHandle); + JnaUtils.freeNdArray(oldHandle); + // dereference old ndarray + arr.close(); + } + + /** + * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s + * entries are infinite, or {@code false} where they are not infinite. + * + * @return the boolean {@code NDArray} with value {@code true} if this {@code NDArray}'s entries + * are infinite + */ + public NDArray isInfinite() { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** + * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s + * entries are NaN, or {@code false} where they are not NaN. + * + * @return the boolean {@code NDArray} with value {@code true} if this {@code NDArray}'s {@link + * NDArray} are NaN + */ + public NDArray isNaN() { + return invoke(getParent(), "_npi_isnan", this, null); + } + + /** + * Returns a dense representation of the sparse {@code NDArray}. + * + * @return the result {@code NDArray} + */ + public NDArray toDense() { + if (!isSparse()) { + return duplicate(); + } + return castStorage(SparseFormat.DENSE); + } + + /** + * Returns a sparse representation of {@code NDArray}. 
+ * + * @param fmt the {@link SparseFormat} of this {@code NDArray} + * @return the result {@code NDArray} + */ + public NDArray toSparse(SparseFormat fmt) { + if (fmt != SparseFormat.DENSE + && fmt != SparseFormat.CSR + && fmt != SparseFormat.ROW_SPARSE) { + throw new UnsupportedOperationException(fmt + " is not supported"); + } + if (fmt == getSparseFormat()) { + return duplicate(); + } + return castStorage(fmt); + } + + private NDArray castStorage(SparseFormat fmt) { + OpParams params = new OpParams(); + params.setParam("stype", fmt.getType()); + return invoke(getParent(), "cast_storage", this, params); + } + + /** + * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given + * repeats. + * + * @param repeats the number of times to repeat for each dimension + * @return a NDArray that has been tiled + */ + public NDArray tile(long repeats) { + // zero-dim + if (isEmpty()) { + return duplicate(); + } + // scalar + int dim = (isScalar()) ? 1 : getShape().dimension(); + long[] repeatsArray = new long[dim]; + Arrays.fill(repeatsArray, repeats); + return tile(repeatsArray); + } + + /** + * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given by + * repeats. + * + * @param repeats the number of times to repeat along each axis + * @return a {@code NDArray} that has been tiled + */ + public NDArray tile(long[] repeats) { + OpParams params = new OpParams(); + params.addTupleParam("reps", repeats); + return invoke(getParent(), "_npi_tile", this, params); + } + + /** + * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given by + * repeats along given axis. 
+ * + * @param axis the axis to repeat + * @param repeats the number of times to repeat for each axis + * @return a {@code NDArray} that has been tiled + * @throws IllegalArgumentException thrown for invalid axis + */ + public NDArray tile(int axis, long repeats) { + // scalar + if (isScalar()) { + throw new IllegalArgumentException("scalar didn't support specifying axis"); + } + long[] repeatsArray = new long[getShape().dimension()]; + Arrays.fill(repeatsArray, 1); + repeatsArray[withAxis(axis)] = repeats; + return tile(repeatsArray); + } + + /** + * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times to match + * the desired shape. + * + *

If the desired {@link Shape}has fewer dimensions than this {@code NDArray}, it will tile + * against the last axis. + * + * @param desiredShape the {@link Shape}that should be converted to + * @return a {@code NDArray} that has been tiled + */ + public NDArray tile(Shape desiredShape) { + return tile(repeatsToMatchShape(desiredShape)); + } + + private int withAxis(int axis) { + return Math.floorMod(axis, getShape().dimension()); + } + + private long[] repeatsToMatchShape(Shape desiredShape) { + Shape curShape = getShape(); + int dimension = curShape.dimension(); + if (desiredShape.dimension() > dimension) { + throw new IllegalArgumentException("The desired shape has too many dimensions"); + } + if (desiredShape.dimension() < dimension) { + int additionalDimensions = dimension - desiredShape.dimension(); + desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape); + } + long[] repeats = new long[dimension]; + for (int i = 0; i < dimension; i++) { + if (curShape.get(i) == 0 || desiredShape.get(i) % curShape.get(i) != 0) { + throw new IllegalArgumentException( + "The desired shape is not a multiple of the original shape"); + } + repeats[i] = Math.round(Math.ceil((double) desiredShape.get(i) / curShape.get(i))); + } + return repeats; + } + + /** + * Repeats element of this {@code NDArray} the number of times given repeats. + * + * @param repeats the number of times to repeat for each axis + * @return an {@code NDArray} that has been repeated + */ + public NDArray repeat(long repeats) { + // zero-dim + if (isEmpty()) { + return duplicate(); + } + // scalar + int dim = (isScalar()) ? 1 : getShape().dimension(); + long[] repeatsArray = new long[dim]; + Arrays.fill(repeatsArray, repeats); + return repeat(repeatsArray); + } + + /** + * Repeats element of this {@code NDArray} the number of times given repeats along given axis. 
+ * + * @param axis the axis to repeat + * @param repeats the number of times to repeat for each axis + * @return an {@code NDArray} that has been repeated + * @throws IllegalArgumentException thrown for invalid axis + */ + public NDArray repeat(int axis, long repeats) { + long[] repeatsArray = new long[getShape().dimension()]; + Arrays.fill(repeatsArray, 1); + repeatsArray[withAxis(axis)] = repeats; + return repeat(repeatsArray); + } + + /** + * Repeats element of this {@code NDArray} the number of times given repeats along each axis. + * + * @param repeats the number of times to repeat along each axis + * @return a {@code NDArray} that has been repeated + */ + public NDArray repeat(long[] repeats) { + // TODO get rid of for loop once bug in MXNet np.repeat is fixed + NDArray array = this; + int baseAxis = getShape().dimension() - repeats.length; + for (int i = 0; i < repeats.length; i++) { + if (repeats[i] > 1) { + NDArray previousArray = array; + OpParams params = new OpParams(); + params.addParam("repeats", repeats[i]); + params.addParam("axis", baseAxis + i); + array = invoke(getParent(), "_np_repeat", array, params); + if (previousArray != this) { + previousArray.close(); + } + } + } + return array; + } + + /** + * Repeats element of this {@code NDArray} to match the desired shape. + * + *

If the desired {@link Shape} has fewer dimensions that the array, it will repeat against + * the last axis. + * + * @param desiredShape the {@link Shape} that should be converted to + * @return an {@code NDArray} that has been repeated + */ + public NDArray repeat(Shape desiredShape) { + return repeat(repeatsToMatchShape(desiredShape)); + } + + /** + * Dot product of this {@code NDArray} and the other {@code NDArray}. + * + *

    + *
  • If both this {@code NDArray} and the other {@code NDArray} are 1-D {@code NDArray}s, it + * is inner product of vectors (without complex conjugation). + *
  • If both this {@code NDArray} and the other {@code NDArray} are 2-D {@code NDArray}s, it + * is matrix multiplication. + *
  • If either this {@code NDArray} or the other {@code NDArray} is 0-D {@code NDArray} + * (scalar), it is equivalent to mul. + *
  • If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is 1-D + * {@code NDArray}, it is a sum product over the last axis of those. + *
  • If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is M-D + * {@code NDArray}(where M>=2), it is a sum product over the last axis of this + * {@code NDArray} and the second-to-last axis of the other {@code NDArray} + *
+ * + * @param other the other {@code NDArray} to perform dot product with + * @return the result {@code NDArray} + */ + public NDArray dot(NDArray other) { + return invoke(getParent(), "_np_dot", new NDArray[] {this, other}, null); + } + + /** + * Product matrix of this {@code NDArray} and the other {@code NDArray}. + * + * @param other the other {@code NDArray} to perform matrix product with + * @return the result {@code NDArray} + */ + public NDArray matMul(NDArray other) { + if (isScalar() || other.isScalar()) { + throw new IllegalArgumentException("scalar is not allowed for matMul()"); + } + return invoke(getParent(), "_npi_matmul", new NDArray[] {this, other}, null); + } + + /** + * Clips (limit) the values in this {@code NDArray}. + * + *

Given an interval, values outside the interval are clipped to the interval edges. For + * example, if an interval of [0, 1] is specified, values smaller than 0 become 0, and values + * larger than 1 become 1. + * + * @param min the minimum value + * @param max the maximum value + * @return an {@code NDArray} with the elements of this {@code NDArray}, but where values < + * min are replaced with min, and those > max with max + */ + public NDArray clip(Number min, Number max) { + OpParams params = new OpParams(); + params.addParam("a_min", min); + params.addParam("a_max", max); + return invoke(getParent(), "_npi_clip", this, params); + } + + /** + * Interchanges two axes of this {@code NDArray}. + * + * @param axis1 the first axis + * @param axis2 the second axis + * @return the swapped axes {@code NDArray} + */ + public NDArray swapAxes(int axis1, int axis2) { + OpParams params = new OpParams(); + params.addParam("dim1", axis1); + params.addParam("dim2", axis2); + return invoke(getParent(), "_npi_swapaxes", this, params); + } + + /** + * Returns the reverse order of elements in an array along the given axis. + * + *

The shape of the array is preserved, but the elements are reordered. + * + * @param axes the axes to flip on + * @return the newly flipped array + */ + public NDArray flip(int... axes) { + OpParams params = new OpParams(); + params.addTupleParam("axis", axes); + return invoke(getParent(), "_npi_flip", this, params); + } + + /** + * Returns this {@code NDArray} with axes transposed. + * + * @return the newly permuted array + */ + public NDArray transpose() { + return invoke(getParent(), "_np_transpose", this, null); + } + + /** + * Returns this {@code NDArray} with given axes transposed. + * + * @param dimensions the axes to swap to + * @return the transposed {@code NDArray} + * @throws IllegalArgumentException thrown when passing a axis that is greater than the actual + * number of dimensions + */ + public NDArray transpose(int... dimensions) { + if (Arrays.stream(dimensions).anyMatch(d -> d < 0)) { + throw new UnsupportedOperationException( + "Passing -1 for broadcasting the dimension is not currently supported"); + } + if (!Arrays.equals( + Arrays.stream(dimensions).sorted().toArray(), + IntStream.range(0, getShape().dimension()).toArray())) { + throw new IllegalArgumentException( + "You must include each of the dimensions from 0 until " + + getShape().dimension()); + } + OpParams params = new OpParams(); + params.addTupleParam("axes", dimensions); + return invoke(getParent(), "_np_transpose", this, params); + } + + /** + * Broadcasts this {@code NDArray} to be the given shape. + * + * @param shape the new {@link Shape} of this {@code NDArray} + * @return the broadcasted {@code NDArray} + */ + public NDArray broadcast(Shape shape) { + OpParams params = new OpParams(); + params.setShape(shape); + return invoke(getParent(), "_npi_broadcast_to", this, params); + } + + /** + * Returns the indices of the maximum values into the flattened {@code NDArray}. 
+ * + * @return a {@code NDArray} containing indices + */ + public NDArray argMax() { + if (isEmpty()) { + throw new IllegalArgumentException("attempt to get argMax of an empty MxNDArray"); + } + return invoke(getParent(), "_npi_argmax", this, null); + } + + /** + * Returns the indices of the maximum values along given axis. + * + * @param axis the axis along which to find maximum values + * @return a {@code NDArray} containing indices + */ + public NDArray argMax(int axis) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + return invoke(getParent(), "_npi_argmax", this, params); + } + + /** + * Returns the indices of the minimum values into the flattened {@code NDArray}. + * + * @return a {@code NDArray} containing indices + */ + public NDArray argMin() { + if (isEmpty()) { + throw new IllegalArgumentException("attempt to get argMin of an empty MxNDArray"); + } + return invoke(getParent(), "_npi_argmin", this, null); + } + + /** + * Returns the indices of the minimum values along given axis. + * + * @param axis the axis along which to find minimum values + * @return a {@code NDArray} containing indices + */ + public NDArray argMin(int axis) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + return invoke(getParent(), "_npi_argmin", this, params); + } + + /** + * Returns percentile for this {@code NDArray}. + * + * @param percentile the target percentile in range of 0..100 + * @return the result {@code NDArray} + */ + public NDArray percentile(Number percentile) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** + * Returns median along given dimension(s). 
+ * + * @param percentile the target percentile in range of 0..100 + * @param dimension the dimension to calculate percentile for + * @return the result {@code NDArray} NDArray + */ + public NDArray percentile(Number percentile, int[] dimension) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** + * Returns median value for this {@code NDArray}. + * + * @return the median {@code NDArray} + */ + public NDArray median() { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** + * Returns median value along given axes. + * + * @param axes the axes along which to perform the median operation + * @return the median {@code NDArray} along the specified axes + */ + public NDArray median(int[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** + * Returns the indices of elements that are non-zero. + * + *

Note that the behavior is slightly different from numpy.nonzero. Numpy returns a tuple of + * NDArray, one for each dimension of NDArray. DJL nonzero returns only one {@code NDArray} with + * last dimension containing all dimension of indices. + * + * @return the indices of the elements that are non-zero + */ + public NDArray nonzero() { + NDArray thisArr = + (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this; + return invoke(getParent(), "_npx_nonzero", thisArr, null); + } + + /** + * Returns element-wise inverse gauss error function of the {@code NDArray}. + * + * @return The inverse of gauss error of the {@code NDArray}, element-wise + */ + public NDArray erfinv() { + return invoke(getParent(), "erfinv", this, null); + } + + /** + * Returns the norm of this {@code NDArray}. + * + * @param keepDims If this is set to True, the axes which are normed over are left in the result + * as dimensions with size one. With this option the result will broadcast correctly against + * the original x. + * @return the norm of this {@code NDArray} + */ + public NDArray norm(boolean keepDims) { + OpParams params = new OpParams(); + params.add("flag", -2); + params.addParam("keepdims", keepDims); + return invoke(getParent(), "_npi_norm", this, params); + } + + /** + * Returns the norm of this {@code NDArray}. + * + * @param ord Order of the norm. + * @param axes If axes contains an integer, it specifies the axis of x along which to compute + * the vector norms. If axis contains 2 integers, it specifies the axes that hold 2-D + * matrices, and the matrix norms of these matrices are computed. + * @param keepDims keepDims If this is set to True, the axes which are normed over are left in + * the result as dimensions with size one. With this option the result will broadcast + * correctly against the original x. 
+ * @return the norm of this {@code NDArray} + */ + public NDArray norm(int ord, int[] axes, boolean keepDims) { + OpParams params = new OpParams(); + params.addParam("ord", (double) ord); + params.addTupleParam("axis", axes); + params.addParam("keepdims", keepDims); + return invoke(getParent(), "_npi_norm", this, params); + } + + // public MxNDArray oneHot(int depth) { + // return LazyNDArray.super.oneHot(depth); + // } + + /** + * Returns a one-hot {@code NDArray}. + * + *

    + *
  • The locations represented by indices take value onValue, while all other locations take + * value offValue. + *
  • If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is + * appended at the end. + *
  • If {@code NDArray} is a scalar the output shape will be a vector of length depth. + *
  • If {@code NDArray} is a vector of length features, the output shape will be features x + * depth. + *
  • If {@code NDArray} is a matrix with shape [batch, features], the output shape will be + * batch x features x depth. + *
+ * + * @param depth Depth of the one hot dimension. + * @param onValue The value assigned to the locations represented by indices. + * @param offValue The value assigned to the locations not represented by indices. + * @param dataType dataType of the output. + * @return one-hot encoding of this {@code NDArray} + */ + public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) { + OpParams params = new OpParams(); + params.add("depth", depth); + params.add("on_value", onValue); + params.add("off_value", offValue); + params.add("dtype", dataType); + return invoke(getParent(), "_npx_one_hot", this, params).toType(dataType, false); + } + + /** + * Batchwise product of this {@code NDArray} and the other {@code NDArray}. + * + *
    + *
  • batchDot is used to compute dot product of x and y when x and y are data in batch, + * namely N-D (N greater or equal to 3) arrays in shape of (B0, …, B_i, :, :). For + * example, given x with shape (B_0, …, B_i, N, M) and y with shape (B_0, …, B_i, M, K), + * the result array will have shape (B_0, …, B_i, N, K), which is computed by: + * batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :, + * :]) + *
+ * + * @param other the other {@code NDArray} to perform batch dot product with + * @return the result {@code NDArray} + */ + public NDArray batchDot(NDArray other) { + return invoke(getParent(), "_npx_batch_dot", new NDArray[] {this, other}, null); + } + + /** + * Returns an internal representative of Native {@code NDArray}. + * + *

This method should only be used by Engine provider + * + * @return an internal representative of Native {@code NDArray} + */ + public NDArrayEx getNDArrayInternal() { + return mxNDArrayEx; + } + + /** {@inheritDoc} */ + @Override + public void close() { + if (!getClosed()) { + logger.debug(String.format("Start to free NDArray instance: %S", this.getUid())); + super.freeSubResources(); + + if (this.getHandle() != null) { + JnaUtils.freeNdArray(this.getHandle()); + } + setClosed(true); + logger.debug(String.format("Finish to free NDArray instance: %S", this.getUid())); + } + } + + /** + * Returns {@code true} if this {@code NDArray} is special case: no-value {@code NDArray}. + * + * @return {@code true} if this NDArray is empty + */ + public boolean isEmpty() { + return getShape().size() == 0; + } + + boolean isSparse() { + return getSparseFormat() != SparseFormat.DENSE; + } + + boolean shapeEquals(NDArray other) { + return getShape().equals(other.getShape()); + } + + /** + * An engine specific generic invocation to native operation. + * + *

You should avoid using this function if possible. Since this function is engine specific, + * using this API may cause a portability issue. Native operation may not compatible between + * each version. + * + * @param parent the parent {@link MxResource} of the created {@link NDList} + * @param operation the native operation to perform + * @param src the {@link NDList} of source {@link NDArray} + * @param params the parameters to be passed to the native operation + * @return the output array of {@link NDArray} + * @throws IllegalArgumentException if operation is not supported by Engine + */ + public static NDList invoke( + MxResource parent, String operation, NDList src, PairList params) { + return new NDList(JnaUtils.op(operation).invoke(parent, src.toArray(EMPTY), params)); + } + + /** + * An engine specific generic invocation to native operator. + * + *

You should avoid using this function if possible. Since this function is engine specific, + * using this API may cause portability issues. A native operation may not be compatible between + * each version. + * + * @param operation the native operation to perform + * @param src the {@link NDList} of source {@link NDArray} + * @param dest the {@link NDList} to save output to + * @param params the parameters to be passed to the native operator + * @throws IllegalArgumentException if operation is not supported by Engine + */ + public static void invoke( + String operation, NDList src, NDList dest, PairList params) { + invoke(operation, src.toArray(EMPTY), dest.toArray(EMPTY), params); + } + + /** + * An engine specific generic invocation to native operator. + * + *

You should avoid using this function if possible. Since this function is engine specific, + * using this API may cause portability issues. A native operation may not be compatible between + * each version. + * + * @param parent the parent {@link MxResource} to manage this instance + * @param operation the native operation to perform + * @param src the array of source {@link NDArray} + * @param params the parameters to be passed to the native operator + * @return the output array of {@link NDArray} + */ + public static NDArray invoke( + MxResource parent, String operation, NDArray[] src, PairList params) { + return JnaUtils.op(operation).invoke(parent, src, params)[0]; + } + + /** + * An engine specific generic invocation to native operation. + * + *

You should avoid using this function if possible. Since this function is engine specific, + * using this API may cause a portability issue. Native operation may not be compatible between + * each version. + * + * @param operation the native operation to perform + * @param src the {@link NDList} of source {@link NDArray} + * @param dest the {@link NDList} to save output to + * @param params the parameters to be passed to the native operation + * @throws IllegalArgumentException if operation is not supported by Engine + */ + public static void invoke( + String operation, NDArray[] src, NDArray[] dest, PairList params) { + JnaUtils.op(operation).invoke(src, dest, params); + } + + /** + * An engine specific generic invocation to native operator. + * + *

You should avoid using this function if possible. Since this function is engine specific, + * using this API may cause portability issues. A native operation may not be compatible between + * each version. + * + * @param parent the parent {@link MxResource} to manage this instance + * @param operation the native operation to perform + * @param src the source {@link NDArray} + * @param params the parameters to be passed to the native operator + * @return the output array of {@link NDArray} + */ + public static NDArray invoke( + MxResource parent, String operation, NDArray src, PairList params) { + return invoke(parent, operation, new NDArray[] {src}, params); + } + + /** + * An engine specific generic invocation to native operator. + * + *

You should avoid using this function if possible. Since this function is engine specific, + * using this API may cause portability issues. A native operation may not be compatible between + * each version. + * + * @param parent the parent {@link MxResource} to manage this instance + * @param operation the native operation to perform + * @param params the parameters to be passed to the native operator + * @return the output array of {@link NDArray} + */ + public static NDArray invoke(MxResource parent, String operation, PairList params) { + return invoke(parent, operation, EMPTY, params); + } + + /** + * Encodes {@code MxNDArray} to byte array. + * + * @return byte array + */ + public byte[] encode() { + return NDSerializer.encode(this); + } + + /** + * Draws samples from a uniform distribution. + * + *

Samples are uniformly distributed over the half-open interval [low, high) (includes low, + * but excludes high). In other words, any value within the given interval is equally likely to + * be drawn by uniform. + * + * @param parent {@link MxResource} of this instance + * @param low the lower boundary of the output interval. All values generated will be greater + * than or equal to low. + * @param high the upper boundary of the output interval. All values generated will be less than + * high. + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + public static NDArray randomUniform( + MxResource parent, + float low, + float high, + Shape shape, + DataType dataType, + Device device) { + OpParams params = new OpParams(); + params.addParam("low", low); + params.addParam("high", high); + params.addParam("size", shape); + params.setDevice(device); + params.setDataType(dataType); + return invoke(parent, "_npi_uniform", params); + } + + /** + * Draws samples from a uniform distribution. + * + *

Samples are uniformly distributed over the half-open interval [low, high) (includes low, + * but excludes high). In other words, any value within the given interval is equally likely to + * be drawn by uniform. + * + * @param parent {@link MxResource} of this instance + * @param low the lower boundary of the output interval. All values generated will be greater + * than or equal to low. + * @param high the upper boundary of the output interval. All values generated will be less than + * high. + * @param shape the {@link Shape} of the {@link NDArray} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + private static NDArray randomUniform( + MxResource parent, float low, float high, Shape shape, DataType dataType) { + return randomUniform(parent, low, high, shape, dataType, Device.defaultIfNull(null)); + } + + /** + * Draws random samples from a normal (Gaussian) distribution. + * + * @param parent {@link MxResource} of this instance + * @param loc the mean (centre) of the distribution + * @param scale the standard deviation (spread or "width") of the distribution + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + public static NDArray randomNormal( + MxResource parent, + float loc, + float scale, + Shape shape, + DataType dataType, + Device device) { + if (device == null) { + return randomNormal(parent, loc, scale, shape, dataType); + } + return randomNormal(parent, loc, scale, shape, dataType); + } + + /** + * Draws random samples from a normal (Gaussian) distribution. 
+ * + * @param parent {@link MxResource} of this instance + * @param loc the mean (centre) of the distribution + * @param scale the standard deviation (spread or "width") of the distribution + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + public static NDArray randomNormal( + MxResource parent, float loc, float scale, Shape shape, DataType dataType) { + OpParams params = new OpParams(); + params.addParam("loc", loc); + params.addParam("scale", scale); + params.addParam("size", shape); + params.setDevice(Device.defaultIfNull(null)); + params.setDataType(dataType); + return invoke(parent, "_npi_normal", params); + } + + /** + * Decodes {@link NDArray} through byte array. + * + * @param parent the parent {@link MxResource} to create the {@link NDArray} + * @param bytes byte array to load from + * @return {@link NDArray} + */ + static NDArray decode(MxResource parent, byte[] bytes) { + try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) { + return NDSerializer.decode(parent, dis); + } catch (IOException e) { + throw new IllegalArgumentException("NDArray decoding failed", e); + } + } + + /** + * Decodes {@link NDArray} through {@link DataInputStream}. + * + * @param parent the parent {@link MxResource} to create the {@link NDArray} + * @param is input stream data to load from + * @return {@link NDArray} + * @throws IOException data is not readable + */ + public static NDArray decode(MxResource parent, InputStream is) throws IOException { + return NDSerializer.decode(parent, is); + } + + /** + * Converts this {@code NDArray} to a Number array based on its {@link DataType}. 
+ * + * @return a Number array + */ + public Number[] toArray() { + switch (getDataType()) { + case FLOAT16: + case FLOAT32: + float[] floatArray = toFloatArray(); + return IntStream.range(0, floatArray.length) + .mapToObj(i -> floatArray[i]) + .toArray(Number[]::new); + case FLOAT64: + return Arrays.stream(toDoubleArray()).boxed().toArray(Double[]::new); + case INT32: + return Arrays.stream(toIntArray()).boxed().toArray(Integer[]::new); + case INT64: + return Arrays.stream(toLongArray()).boxed().toArray(Long[]::new); + case BOOLEAN: + case INT8: + ByteBuffer bb = toByteBuffer(); + Byte[] ret = new Byte[bb.remaining()]; + for (int i = 0; i < ret.length; ++i) { + ret[i] = bb.get(); + } + return ret; + case UINT8: + return Arrays.stream(toUint8Array()).boxed().toArray(Integer[]::new); + default: + throw new IllegalStateException("Unsupported DataType: " + getDataType()); + } + } + + /** + * Converts this {@code NDArray} to a boolean array. + * + * @return a boolean array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + public boolean[] toBooleanArray() { + if (getDataType() != DataType.BOOLEAN) { + throw new IllegalStateException( + "DataType mismatch, Required boolean" + " Actual " + getDataType()); + } + ByteBuffer bb = toByteBuffer(); + boolean[] ret = new boolean[bb.remaining()]; + for (int i = 0; i < ret.length; ++i) { + ret[i] = bb.get() != 0; + } + return ret; + } + + /** + * Converts this {@code NDArray} to a double array. 
+ * + * @return a double array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + public double[] toDoubleArray() { + if (getDataType() != DataType.FLOAT64) { + throw new IllegalStateException( + "DataType mismatch, Required double" + " Actual " + getDataType()); + } + DoubleBuffer db = toByteBuffer().asDoubleBuffer(); + double[] ret = new double[db.remaining()]; + db.get(ret); + return ret; + } + + /** + * Converts this {@code NDArray} to a float array. + * + * @return a float array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + public float[] toFloatArray() { + if (getDataType() == DataType.FLOAT16) { + return Float16Utils.fromByteBuffer(toByteBuffer()); + } else if (getDataType() != DataType.FLOAT32) { + throw new IllegalStateException( + "DataType mismatch, Required float, Actual " + getDataType()); + } + FloatBuffer fb = toByteBuffer().asFloatBuffer(); + float[] ret = new float[fb.remaining()]; + fb.get(ret); + return ret; + } + + /** + * Converts this {@code NDArray} to an int array. + * + * @return an int array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + public int[] toIntArray() { + if (getDataType() != DataType.INT32) { + throw new IllegalStateException( + "DataType mismatch, Required int" + " Actual " + getDataType()); + } + IntBuffer ib = toByteBuffer().asIntBuffer(); + int[] ret = new int[ib.remaining()]; + ib.get(ret); + return ret; + } + + /** + * Converts this {@code NDArray} to a long array. 
+ * + * @return a long array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + public long[] toLongArray() { + if (getDataType() != DataType.INT64) { + throw new IllegalStateException( + "DataType mismatch, Required long" + " Actual " + getDataType()); + } + LongBuffer lb = toByteBuffer().asLongBuffer(); + long[] ret = new long[lb.remaining()]; + lb.get(ret); + return ret; + } + + /** + * Converts this {@code NDArray} to a byte array. + * + * @return a byte array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + public byte[] toByteArray() { + ByteBuffer bb = toByteBuffer(); + if (bb.hasArray()) { + return bb.array(); + } + byte[] buf = new byte[bb.remaining()]; + bb.get(buf); + return buf; + } + + /** + * Converts this {@code NDArray} to a uint8 array. + * + * @return a uint8 array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + public int[] toUint8Array() { + ByteBuffer bb = toByteBuffer(); + int[] buf = new int[bb.remaining()]; + for (int i = 0; i < buf.length; ++i) { + buf[i] = bb.get() & 0xff; + } + return buf; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + if (getClosed()) { + return "This array is already closed"; + } + return toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS); + } + + /** + * Runs the debug string representation of this {@code NDArray}. 
+ * + * @param maxSize the maximum elements to print out + * @param maxDepth the maximum depth to print out + * @param maxRows the maximum rows to print out + * @param maxColumns the maximum columns to print out + * @return the debug string representation of this {@code NDArray} + */ + String toDebugString(int maxSize, int maxDepth, int maxRows, int maxColumns) { + return NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java new file mode 100644 index 000000000000..047ff046ab09 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java @@ -0,0 +1,1107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.ndarray; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.OpParams; +import org.apache.mxnet.jna.JnaUtils; +import org.apache.mxnet.ndarray.types.DataType; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.ndarray.types.SparseFormat; + +/** An internal interface that encapsulates engine specific operations. */ +@SuppressWarnings("MissingJavadocMethod") +public class NDArrayEx { + + private static final NDArrayIndexer INDEXER = new NDArrayIndexer(); + + private NDArray array; + + /** + * Constructs an {@code MxNDArrayEx} given a {@link NDArray}. + * + * @param parent the {@link NDArray} to extend + */ + NDArrayEx(NDArray parent) { + this.array = parent; + } + + // TODO only used to calculate zero-dim numpy shape + // remove it once MXNet have all the np op that we support + private Shape deriveBroadcastedShape(Shape lhs, Shape rhs) { + long[] result = new long[Math.max(lhs.dimension(), rhs.dimension())]; + long lDiff = result.length - lhs.dimension(); + long rDiff = result.length - rhs.dimension(); + for (int i = 0; i < result.length; i++) { + long l = 1; + long r = 1; + if (i >= lDiff) { + l = lhs.get(Math.toIntExact(i - lDiff)); + } + if (i >= rDiff) { + r = rhs.get(Math.toIntExact(i - rDiff)); + } + if (l != r) { + if (l != 1 && r != 1) { + throw new IllegalArgumentException( + "operands could not be broadcast together with shapes " + + lhs + + " " + + rhs); + } + result[i] = (l == 1) ? r : l; + } else { + result[i] = l; + } + } + return new Shape(result); + } + + //////////////////////////////////////// + // MxNDArrays + //////////////////////////////////////// + /** + * Applies reverse division with a scalar - i.e., (n / thisArrayValues). 
+ * + * @param n the Value to use for reverse division + * @return a copy of the array after applying reverse division + */ + public NDArray rdiv(Number n) { + OpParams params = new OpParams(); + params.add("scalar", n.toString()); + return NDArray.invoke(getArray().getParent(), "_rdiv_scalar", array, params); + } + + /** + * Applies reverse division with a scalar - i.e., (n / thisArrayValues). + * + * @param b the ndarray to use for reverse division + * @return a copy of the array after applying reverse division + */ + public NDArray rdiv(NDArray b) { + return b.div(array); + } + + /** + * Applies in place reverse division - i.e., (n / thisArrayValues). + * + * @param n the value to use for reverse division + * @return this array after applying reverse division + */ + public NDArray rdivi(Number n) { + OpParams params = new OpParams(); + params.add("scalar", n.toString()); + NDArray.invoke("_rdiv_scalar", new NDArray[] {array}, new NDArray[] {array}, params); + return array; + } + + /** + * Applies in place reverse division - i.e., (n / thisArrayValues). + * + * @param b the ndarray to use for reverse division + * @return this array after applying reverse division + */ + public NDArray rdivi(NDArray b) { + NDArray.invoke("elemwise_div", new NDArray[] {b, array}, new NDArray[] {array}, null); + return array; + } + + /** + * Applies reverse subtraction with duplicates - i.e., (n - thisArrayValues). + * + * @param n the value to use for reverse subtraction + * @return a copy of array after reverse subtraction + */ + public NDArray rsub(Number n) { + return array.sub(n).neg(); + } + + /** + * Applies reverse subtraction with duplicates - i.e., (n - thisArrayValues). + * + * @param b the ndarray to use for reverse subtraction + * @return a copy of the array after reverse subtraction + */ + public NDArray rsub(NDArray b) { + return array.sub(b).neg(); + } + + /** + * Applies reverse subtraction in place - i.e., (n - thisArrayValues). 
+ * + * @param n the value to use for reverse subtraction + * @return this array after reverse subtraction + */ + public NDArray rsubi(Number n) { + return array.subi(n).negi(); + } + + /** + * Applies reverse subtraction in place - i.e., (n - thisArrayValues). + * + * @param b the ndarray to use for reverse subtraction + * @return this array after reverse subtraction + */ + public NDArray rsubi(NDArray b) { + return array.subi(b).negi(); + } + + public NDArray rmod(Number n) { + OpParams params = new OpParams(); + params.add("scalar", n.toString()); + return NDArray.invoke(getArray().getParent(), "_npi_rmod_scalar", array, params); + } + + /** + * Applies reverse remainder of division with a scalar. + * + * @param b the value to use for reverse division + * @return a copy of array after applying reverse division + */ + public NDArray rmod(NDArray b) { + return b.mod(array); + } + + /** + * Applies in place reverse remainder of division with a scalar. + * + * @param n the value to use for reverse division + * @return this array after applying reverse division + */ + public NDArray rmodi(Number n) { + OpParams params = new OpParams(); + params.add("scalar", n.toString()); + NDArray.invoke("_npi_rmod_scalar", new NDArray[] {array}, new NDArray[] {array}, params); + return array; + } + + /** + * Applies in place reverse remainder of division. + * + * @param b the ndarray to use for reverse division + * @return this array after applying reverse division + */ + public NDArray rmodi(NDArray b) { + NDArray.invoke("_npi_mod", new NDArray[] {b, array}, new NDArray[] {array}, null); + return array; + } + + /** + * Reverses the power of each element being raised in the {@code NDArray}. 
+ * + * @param n the value to use for reverse power + * @return a copy of array after applying reverse power + */ + public NDArray rpow(Number n) { + OpParams params = new OpParams(); + params.add("scalar", n.toString()); + return NDArray.invoke(getArray().getParent(), "_npi_rpower_scalar", array, params); + } + + /** + * Reverses the power of each element being raised in the {@code NDArray} in place. + * + * @param n the value to use for reverse power + * @return a copy of array after applying reverse power + */ + public NDArray rpowi(Number n) { + OpParams params = new OpParams(); + params.add("scalar", n.toString()); + NDArray.invoke("_npi_rpower_scalar", new NDArray[] {array}, new NDArray[] {array}, params); + return array; + } + + //////////////////////////////////////// + // Activations + //////////////////////////////////////// + /** + * Computes rectified linear activation. + * + * @return a copy of array after applying relu + */ + public NDArray relu() { + OpParams params = new OpParams(); + params.addParam("act_type", "relu"); + return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params); + } + + public NDArray sigmoid() { + OpParams params = new OpParams(); + params.addParam("act_type", "sigmoid"); + return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params); + } + + public NDArray tanh() { + OpParams params = new OpParams(); + params.addParam("act_type", "tanh"); + return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params); + } + + public NDArray softPlus() { + OpParams params = new OpParams(); + params.addParam("act_type", "softrelu"); + return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params); + } + + public NDArray softSign() { + OpParams params = new OpParams(); + params.addParam("act_type", "softsign"); + return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params); + } + + public NDArray leakyRelu(float alpha) { + OpParams params = new 
OpParams(); + params.addParam("act_type", "leaky"); + params.addParam("slope", alpha); + return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params); + } + + public NDArray elu(float alpha) { + OpParams params = new OpParams(); + params.addParam("act_type", "elu"); + params.addParam("slope", alpha); + return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params); + } + + public NDArray selu() { + OpParams params = new OpParams(); + params.addParam("act_type", "selu"); + return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params); + } + + public NDArray gelu() { + OpParams params = new OpParams(); + params.addParam("act_type", "gelu"); + return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params); + } + + //////////////////////////////////////// + // Pooling Operations + //////////////////////////////////////// + + public NDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) { + OpParams params = new OpParams(); + params.addParam("kernel", kernelShape); + params.add("pool_type", "max"); + params.addParam("stride", stride); + params.addParam("pad", padding); + params.add("pooling_convention", ceilMode ? 
"full" : "valid"); + return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params); + } + + public NDArray globalMaxPool() { + OpParams params = new OpParams(); + params.add("kernel", getGlobalPoolingShapes(1)); + params.add("pad", getGlobalPoolingShapes(0)); + params.add("pool_type", "max"); + params.addParam("global_pool", true); + try (NDArray temp = + NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) { + return temp.reshape(temp.getShape().size(0), temp.getShape().size(1)); + } + } + + public NDArray avgPool( + Shape kernelShape, + Shape stride, + Shape padding, + boolean ceilMode, + boolean countIncludePad) { + OpParams params = new OpParams(); + params.addParam("kernel", kernelShape); + params.add("pool_type", "avg"); + params.addParam("stride", stride); + params.addParam("pad", padding); + params.add("pooling_convention", ceilMode ? "full" : "valid"); + params.addParam("count_include_pad", countIncludePad); + return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params); + } + + public NDArray globalAvgPool() { + OpParams params = new OpParams(); + params.add("kernel", getGlobalPoolingShapes(1)); + params.add("pad", getGlobalPoolingShapes(0)); + params.add("pool_type", "avg"); + params.addParam("global_pool", true); + try (NDArray temp = + NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) { + return temp.reshape(temp.getShape().size(0), temp.getShape().size(1)); + } + } + + public NDArray lpPool( + float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) { + if (((int) normType) != normType) { + throw new IllegalArgumentException( + "float type of normType is not supported in MXNet engine, please use integer instead"); + } + OpParams params = new OpParams(); + params.addParam("p_value", (int) normType); + params.addParam("kernel", kernelShape); + params.add("pool_type", "lp"); + params.addParam("stride", stride); + params.addParam("pad", 
padding); + params.add("pooling_convention", ceilMode ? "full" : "valid"); + + return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params); + } + + public NDArray globalLpPool(float normType) { + if (((int) normType) != normType) { + throw new IllegalArgumentException( + "float type of normType is not supported in MXNet engine, please use integer instead"); + } + OpParams params = new OpParams(); + params.add("pool_type", "lp"); + params.addParam("p_value", (int) normType); + params.addParam("global_pool", true); + try (NDArray temp = + NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) { + return temp.reshape(temp.getShape().size(0), temp.getShape().size(1)); + } + } + + //////////////////////////////////////// + // Optimizer + //////////////////////////////////////// + + // public void adadeltaUpdate( + // MxNDList inputs, + // MxNDList weights, + // float weightDecay, + // float rescaleGrad, + // float clipGrad, + // float rho, + // float epsilon) { + // MxNDArray weight = inputs.get(0); + // MxNDArray grad = inputs.get(1); + // MxNDArray s = inputs.get(2); + // MxNDArray delta = inputs.get(3); + // + // // create a baseManager to close all intermediate MxNDArrays + // try (NDManager subManager = NDManager.newBaseManager()) { + // subManager.tempAttachAll(inputs, weights); + // + // // Preprocess Gradient + // grad.muli(rescaleGrad); + // if (clipGrad > 0) { + // grad = grad.clip(-clipGrad, clipGrad); + // } + // grad.addi(weight.mul(weightDecay)); + // + // // Update s, g, and delta + // s.muli(rho).addi(grad.square().mul(1 - rho)); + // MxNDArray g = delta.add(epsilon).sqrt().div(s.add(epsilon).sqrt()).mul(grad); + // delta.muli(rho).addi(g.square().mul(1 - rho)); + // + // // Update weight + // weight.subi(g); + // } + // } + + public void adagradUpdate( + NDList inputs, + NDList weights, + float learningRate, + float weightDecay, + float rescaleGrad, + float clipGrad, + float epsilon) { + OpParams params = new 
OpParams(); + params.addParam("lr", learningRate); + params.addParam("wd", weightDecay); + params.addParam("rescale_grad", rescaleGrad); + params.addParam("clip_gradient", clipGrad); + + params.addParam("epsilon", epsilon); + + NDArray.invoke("adagrad_update", inputs, weights, params); + } + + public void adamUpdate( + NDList inputs, + NDList weights, + float learningRate, + float weightDecay, + float rescaleGrad, + float clipGrad, + float beta1, + float beta2, + float epsilon, + boolean lazyUpdate) { + OpParams params = new OpParams(); + params.addParam("lr", learningRate); + params.addParam("wd", weightDecay); + params.addParam("rescale_grad", rescaleGrad); + params.addParam("clip_gradient", clipGrad); + + params.addParam("beta1", beta1); + params.addParam("beta2", beta2); + params.addParam("epsilon", epsilon); + params.addParam("lazy_update", lazyUpdate); + + NDArray.invoke("adam_update", inputs, weights, params); + } + + public void rmspropUpdate( + NDList inputs, + NDList weights, + float learningRate, + float weightDecay, + float rescaleGrad, + float clipGrad, + float gamma1, + float gamma2, + float epsilon, + boolean centered) { + OpParams params = new OpParams(); + params.addParam("lr", learningRate); + params.addParam("wd", weightDecay); + params.addParam("rescale_grad", rescaleGrad); + params.addParam("clip_gradient", clipGrad); + + params.addParam("gamma1", gamma1); + params.addParam("epsilon", epsilon); + + if (!centered) { + NDArray.invoke("rmsprop_update", inputs, weights, params); + } else { + params.addParam("gamma2", gamma2); + + NDArray.invoke("rmspropalex_update", inputs, weights, params); + } + } + + public void nagUpdate( + NDList inputs, + NDList weights, + float learningRate, + float weightDecay, + float rescaleGrad, + float clipGrad, + float momentum) { + OpParams params = new OpParams(); + params.addParam("lr", learningRate); + params.addParam("wd", weightDecay); + params.addParam("rescale_grad", rescaleGrad); + 
params.addParam("clip_gradient", clipGrad); + params.addParam("momentum", momentum); + NDArray.invoke("nag_mom_update", inputs, weights, params); + } + + public void sgdUpdate( + NDList inputs, + NDList weights, + float learningRate, + float weightDecay, + float rescaleGrad, + float clipGrad, + float momentum, + boolean lazyUpdate) { + OpParams params = new OpParams(); + params.addParam("lr", learningRate); + params.addParam("wd", weightDecay); + params.addParam("rescale_grad", rescaleGrad); + params.addParam("clip_gradient", clipGrad); + params.addParam("lazy_update", lazyUpdate); + + if (momentum != 0) { + params.addParam("momentum", momentum); + NDArray.invoke("sgd_mom_update", inputs, weights, params); + } else { + NDArray.invoke("sgd_update", inputs, weights, params); + } + } + + //////////////////////////////////////// + // Neural network + //////////////////////////////////////// + + public NDList convolution( + NDArray input, + NDArray weight, + NDArray bias, + Shape stride, + Shape padding, + Shape dilation, + int groups) { + OpParams params = new OpParams(); + params.addParam("kernel", weight.getShape().slice(2)); + params.addParam("stride", stride); + params.addParam("pad", padding); + params.addParam("dilate", dilation); + params.addParam("num_group", groups); + params.addParam("num_filter", weight.getShape().get(0)); + + NDList inputs = new NDList(input, weight); + if (bias != null) { + params.add("no_bias", false); + inputs.add(bias); + } else { + params.add("no_bias", true); + } + + return NDArray.invoke(getArray().getParent(), "_npx_convolution", inputs, params); + } + + public NDList deconvolution( + NDArray input, + NDArray weight, + NDArray bias, + Shape stride, + Shape padding, + Shape outPadding, + Shape dilation, + int groups) { + OpParams params = new OpParams(); + params.addParam("kernel", weight.getShape().slice(2)); + params.addParam("stride", stride); + params.addParam("pad", padding); + params.addParam("adj", outPadding); + 
params.addParam("dilate", dilation); + params.addParam("num_group", groups); + params.addParam("num_filter", weight.getShape().get(0)); + + NDList inputs = new NDList(input, weight); + if (bias != null) { + params.add("no_bias", false); + inputs.add(bias); + } else { + params.add("no_bias", true); + } + + return NDArray.invoke(getArray().getParent(), "_npx_deconvolution", inputs, params); + } + + public NDList linear(NDArray input, NDArray weight, NDArray bias) { + OpParams params = new OpParams(); + params.addParam("num_hidden", weight.size(0)); + params.addParam("flatten", false); + params.addParam("no_bias", bias == null); + NDList inputs = new NDList(input, weight); + if (bias != null) { + inputs.add(bias); + } + + return NDArray.invoke(getArray().getParent(), "_npx_fully_connected", inputs, params); + } + + public NDList embedding(NDArray input, NDArray weight, SparseFormat sparse) { + if (!sparse.equals(SparseFormat.DENSE) && !sparse.equals(SparseFormat.ROW_SPARSE)) { + throw new IllegalArgumentException("MXNet only supports row sparse"); + } + OpParams params = new OpParams(); + long inputDim = weight.getShape().get(0); + long outputDim = weight.getShape().get(1); + params.addParam("input_dim", inputDim); + params.addParam("output_dim", outputDim); + params.addParam("sparse_grad", sparse.getValue()); + return NDArray.invoke( + getArray().getParent(), "_npx_embedding", new NDList(input, weight), params); + } + + public NDList prelu(NDArray input, NDArray alpha) { + OpParams params = new OpParams(); + params.addParam("act_type", "prelu"); + return NDArray.invoke( + getArray().getParent(), "_npx_leaky_relu", new NDList(input, alpha), params); + } + + public NDList dropout(NDArray input, float rate, boolean training) { + if (training != JnaUtils.autogradIsTraining()) { + throw new IllegalArgumentException( + "the mode of dropout in MXNet should align with the mode of GradientCollector"); + } + + OpParams params = new OpParams(); + params.addParam("p", rate); + + 
return NDArray.invoke(getArray().getParent(), "_npx_dropout", new NDList(input), params); + } + + public NDList batchNorm( + NDArray input, + NDArray runningMean, + NDArray runningVar, + NDArray gamma, + NDArray beta, + int axis, + float momentum, + float eps, + boolean training) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + params.addParam("fix_gamma", gamma == null); + params.addParam("eps", eps); + params.addParam("momentum", momentum); + + if (training != JnaUtils.autogradIsTraining()) { + throw new IllegalArgumentException( + "the mode of batchNorm in MXNet should align with the mode of GradientCollector"); + } + + return NDArray.invoke( + getArray().getParent(), + "_npx_batch_norm", + new NDList(input, gamma, beta, runningMean, runningVar), + params); + } + + // public MxNDList rnn( + // MxNDArray input, + // MxNDArray state, + // MxNDList params, + // boolean hasBiases, + // int numLayers, + // RNN.Activation activation, + // double dropRate, + // boolean training, + // boolean bidirectional, + // boolean batchFirst) { + // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1); + // Preconditions.checkArgument( + // params.size() == numParams, + // "The size of Params is incorrect expect " + // + numParams + // + " parameters but got " + // + params.size()); + // + // if (training != JnaUtils.autogradIsTraining()) { + // throw new IllegalArgumentException( + // "the mode of rnn in MXNet should align with the mode of + // GradientCollector"); + // } + // + // if (batchFirst) { + // input = input.swapAxes(0, 1); + // } + // + // MxOpParams opParams = new MxOpParams(); + // opParams.addParam("p", dropRate); + // opParams.addParam("state_size", state.getShape().tail()); + // opParams.addParam("num_layers", numLayers); + // opParams.addParam("bidirectional", bidirectional); + // opParams.addParam("state_outputs", true); + // opParams.addParam("mode", activation == RNN.Activation.TANH ? 
"rnn_tanh" : + // "rnn_relu"); + // + // MxNDList inputs = new MxNDList(); + // inputs.add(input); + // + // try (MxNDList temp = new MxNDList()) { + // for (MxNDArray param : params) { + // temp.add(param.flatten()); + // } + // MxNDArray tempParam = MxNDArrays.concat(temp); + // tempParam.attach(input.getManager()); + // inputs.add(tempParam); + // } + // + // inputs.add(state); + // + // if (!batchFirst) { + // return getManager().invoke("_npx_rnn", inputs, opParams); + // } + // + // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams); + // try (MxNDArray temp = result.head()) { + // return new MxNDList(temp.swapAxes(0, 1), result.get(1)); + // } + // } + + // public MxNDList gru( + // MxNDArray input, + // MxNDArray state, + // MxNDList params, + // boolean hasBiases, + // int numLayers, + // double dropRate, + // boolean training, + // boolean bidirectional, + // boolean batchFirst) { + // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1); + // Preconditions.checkArgument( + // params.size() == numParams, + // "The size of Params is incorrect expect " + // + numParams + // + " parameters but got " + // + params.size()); + // + // if (training != JnaUtils.autogradIsTraining()) { + // throw new IllegalArgumentException( + // "the mode of gru in MXNet should align with the mode of + // GradientCollector"); + // } + // + // if (batchFirst) { + // input = input.swapAxes(0, 1); + // } + // + // MxOpParams opParams = new MxOpParams(); + // opParams.addParam("p", dropRate); + // opParams.addParam("state_size", state.getShape().tail()); + // opParams.addParam("num_layers", numLayers); + // opParams.addParam("bidirectional", bidirectional); + // opParams.addParam("state_outputs", true); + // opParams.addParam("mode", "gru"); + // + // MxNDList inputs = new MxNDList(); + // inputs.add(input); + // + // try (MxNDList temp = new MxNDList()) { + // for (MxNDArray param : params) { + // temp.add(param.flatten()); + // } + // 
MxNDArray tempParam = MxNDArrays.concat(temp); + // tempParam.attach(input.getManager()); + // inputs.add(tempParam); + // } + // + // inputs.add(state); + // + // if (!batchFirst) { + // return getManager().invoke("_npx_rnn", inputs, opParams); + // } + // + // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams); + // try (MxNDArray temp = result.head()) { + // return new MxNDList(temp.swapAxes(0, 1), result.get(1)); + // } + // } + // + // public MxNDList lstm( + // MxNDArray input, + // MxNDList states, + // MxNDList params, + // boolean hasBiases, + // int numLayers, + // double dropRate, + // boolean training, + // boolean bidirectional, + // boolean batchFirst) { + // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1); + // Preconditions.checkArgument( + // params.size() == numParams, + // "The size of Params is incorrect expect " + // + numParams + // + " parameters but got " + // + params.size()); + // + // if (training != JnaUtils.autogradIsTraining()) { + // throw new IllegalArgumentException( + // "the mode of lstm in MXNet should align with the mode of + // GradientCollector"); + // } + // + // if (batchFirst) { + // input = input.swapAxes(0, 1); + // } + // + // MxOpParams opParams = new MxOpParams(); + // opParams.addParam("mode", "lstm"); + // opParams.addParam("p", dropRate); + // opParams.addParam("state_size", states.head().getShape().tail()); + // opParams.addParam("state_outputs", true); + // opParams.addParam("num_layers", numLayers); + // opParams.addParam("bidirectional", bidirectional); + // opParams.addParam("lstm_state_clip_nan", true); + // + // MxNDList inputs = new MxNDList(); + // inputs.add(input); + // try (MxNDList temp = new MxNDList()) { + // for (MxNDArray param : params) { + // temp.add(param.flatten()); + // } + // MxNDArray tempParam = MxNDArrays.concat(temp); + // tempParam.attach(input.getManager()); + // inputs.add(tempParam); + // } + // inputs.addAll(states); + // + // if 
(!batchFirst) { + // return getManager().invoke("_npx_rnn", inputs, opParams); + // } + // + // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams); + // try (MxNDArray temp = result.head()) { + // return new MxNDList(temp.swapAxes(0, 1), result.get(1), result.get(2)); + // } + // } + + //////////////////////////////////////// + // Image and CV + //////////////////////////////////////// + + public NDArray normalize(float[] mean, float[] std) { + OpParams params = new OpParams(); + params.addTupleParam("mean", mean); + params.addTupleParam("std", std); + return NDArray.invoke(getArray().getParent(), "_npx__image_normalize", array, params); + } + + public NDArray toTensor() { + return NDArray.invoke(getArray().getParent(), "_npx__image_to_tensor", array, null); + } + + public NDArray resize(int width, int height, int interpolation) { + if (array.isEmpty()) { + throw new IllegalArgumentException("attempt to resize of an empty MxNDArray"); + } + OpParams params = new OpParams(); + params.addTupleParam("size", width, height); + params.addParam("interp", interpolation); + return NDArray.invoke(getArray().getParent(), "_npx__image_resize", array, params); + } + + public NDArray crop(int x, int y, int width, int height) { + OpParams params = new OpParams(); + params.add("x", x); + params.add("y", y); + params.add("width", width); + params.add("height", height); + return NDArray.invoke(getArray().getParent(), "_npx__image_crop", array, params); + } + + public NDArray randomFlipLeftRight() { + if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) { + throw new UnsupportedOperationException("randomFlipLeftRight is not supported on GPU"); + } + return NDArray.invoke( + getArray().getParent(), "_npx__image_random_flip_left_right", array, null); + } + + public NDArray randomFlipTopBottom() { + if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) { + throw new UnsupportedOperationException("randomFlipTopBottom is not supported on GPU"); + } + 
return NDArray.invoke( + getArray().getParent(), "_npx__image_random_flip_top_bottom", array, null); + } + + public NDArray randomBrightness(float brightness) { + if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) { + throw new UnsupportedOperationException("randomBrightness is not supported on GPU"); + } + OpParams params = new OpParams(); + float min = Math.max(0, 1 - brightness); + float max = 1 + brightness; + params.addParam("min_factor", min); + params.addParam("max_factor", max); + return NDArray.invoke( + getArray().getParent(), "_npx__image_random_brightness", array, params); + } + + public NDArray randomHue(float hue) { + if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) { + throw new UnsupportedOperationException("randomHue is not supported on GPU"); + } + OpParams params = new OpParams(); + float min = Math.max(0, 1 - hue); + float max = 1 + hue; + params.addParam("min_factor", min); + params.addParam("max_factor", max); + return NDArray.invoke(getArray().getParent(), "_npx__image_random_hue", array, params); + } + + public NDArray randomColorJitter( + float brightness, float contrast, float saturation, float hue) { + if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) { + throw new UnsupportedOperationException("randomColorJitter is not supported on GPU"); + } + OpParams params = new OpParams(); + params.addParam("brightness", brightness); + params.addParam("contrast", contrast); + params.addParam("saturation", saturation); + params.addParam("hue", hue); + return NDArray.invoke( + getArray().getParent(), "_npx__image_random_color_jitter", array, params); + } + + public NDArrayIndexer getIndexer() { + return INDEXER; + } + + //////////////////////////////////////// + // Miscellaneous + //////////////////////////////////////// + + @SuppressWarnings("PMD.UseTryWithResources") + public NDArray where(NDArray condition, NDArray other) { + NDArray array1; + NDArray array2; + condition = + (condition.getDataType() == 
DataType.BOOLEAN) + ? condition.toType(DataType.INT32, false) + : condition; + if (array.getDataType() != other.getDataType()) { + throw new IllegalArgumentException( + "DataType mismatch, required " + + array.getDataType() + + " actual " + + other.getDataType()); + } + if (!array.shapeEquals(other)) { + Shape res = deriveBroadcastedShape(array.getShape(), other.getShape()); + array1 = (!res.equals(array.getShape())) ? array.broadcast(res) : array; + array2 = (!res.equals(other.getShape())) ? other.broadcast(res) : other; + } else { + array1 = array; + array2 = other; + } + try { + return NDArray.invoke( + getArray().getParent(), + "where", + new NDArray[] {condition, array1, array2}, + null); + } finally { + if (array1 != array) { + array1.close(); + } + if (array2 != other) { + array2.close(); + } + } + } + + public NDArray stack(NDList arrays, int axis) { + OpParams params = new OpParams(); + params.addParam("axis", axis); + NDArray[] srcArray = new NDArray[arrays.size() + 1]; + srcArray[0] = array; + System.arraycopy(arrays.toArray(new NDArray[0]), 0, srcArray, 1, arrays.size()); + return NDArray.invoke(getArray().getParent(), "_npi_stack", srcArray, params); + } + + /** + * Check two criteria of concat input: 1. no scalar 2. dimensions of all the array must be the + * same. 
+ * + * @param list input {@link NDList} + */ + public static void checkConcatInput(NDList list) { + NDArray[] arrays = list.toArray(new NDArray[0]); + if (Stream.of(arrays).allMatch(array -> array.getShape().dimension() == 0)) { + throw new IllegalArgumentException( + "scalar(zero-dimensional) arrays cannot be concatenated"); + } + int dimension = arrays[0].getShape().dimension(); + for (int i = 1; i < arrays.length; i++) { + if (arrays[i].getShape().dimension() != dimension) { + throw new IllegalArgumentException( + "all the input arrays must have same number of dimensions, but the array at index 0 has " + + dimension + + " dimension(s) and the array at index " + + i + + " has " + + arrays[i].getShape().dimension() + + " dimension(s)"); + } + } + } + + public NDArray concat(NDList list, int axis) { + checkConcatInput(list); + + OpParams params = new OpParams(); + // MXNet backend use dim as argument name + params.addParam("axis", axis); + NDArray[] srcArray = new NDArray[list.size() + 1]; + srcArray[0] = array; + System.arraycopy(list.toArray(new NDArray[0]), 0, srcArray, 1, list.size()); + return NDArray.invoke(getArray().getParent(), "_npi_concatenate", srcArray, params); + } + + public NDList multiBoxTarget( + NDList inputs, + float iouThreshold, + float ignoreLabel, + float negativeMiningRatio, + float negativeMiningThreshold, + int minNegativeSamples) { + OpParams parameters = new OpParams(); + parameters.add("minimum_negative_samples", minNegativeSamples); + parameters.add("overlap_threshold", iouThreshold); + parameters.add("ignore_label", ignoreLabel); + parameters.add("negative_mining_ratio", negativeMiningRatio); + parameters.add("negative_mining_thresh", negativeMiningThreshold); + return NDArray.invoke(getArray().getParent(), "MultiBoxTarget", inputs, parameters); + } + + public NDList multiBoxPrior( + List sizes, + List ratios, + List steps, + List offsets, + boolean clip) { + OpParams parameters = new OpParams(); + parameters.add("sizes", sizes); + 
parameters.add("ratios", ratios); + parameters.add("steps", steps); + parameters.add("offsets", offsets); + parameters.add("clip", clip); + return NDArray.invoke( + getArray().getParent(), "MultiBoxPrior", new NDList(array), parameters); + } + + public NDList multiBoxDetection( + NDList inputs, + boolean clip, + float threshold, + int backgroundId, + float nmsThreashold, + boolean forceSuppress, + int nmsTopK) { + OpParams parameters = new OpParams(); + parameters.add("clip", clip); + parameters.add("threshold", threshold); + parameters.add("background_id", backgroundId); + parameters.add("nms_threshold", nmsThreashold); + parameters.add("force_suppress", forceSuppress); + parameters.add("nms_topk", nmsTopK); + return NDArray.invoke(getArray().getParent(), "MultiBoxDetection", inputs, parameters); + } + + public NDArray getArray() { + return array; + } + + private int getGlobalPoolingDim() { + int poolDim = getArray().getShape().dimension() - 2; + if (poolDim < 1 || poolDim > 3) { + throw new IllegalStateException( + "GlobalPooling only support" + + "1 to 3 Dimensions, " + + poolDim + + "D is not supported."); + } + return poolDim; + } + + private Shape getGlobalPoolingShapes(long fillValue) { + // determine pooling dimension according to input + // input dimension minus 2 (batch and channel dim) + int poolDim = getGlobalPoolingDim(); + long[] shape = new long[poolDim]; + Arrays.fill(shape, fillValue); + return new Shape(shape); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java new file mode 100644 index 000000000000..ee93fd7f9156 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray; + +import java.util.List; +import java.util.Optional; +import java.util.Stack; +import org.apache.mxnet.engine.OpParams; +import org.apache.mxnet.ndarray.dim.NDIndexBooleans; +import org.apache.mxnet.ndarray.dim.NDIndexElement; +import org.apache.mxnet.ndarray.dim.full.NDIndexFullPick; +import org.apache.mxnet.ndarray.dim.full.NDIndexFullSlice; +import org.apache.mxnet.ndarray.index.NDIndex; +import org.apache.mxnet.ndarray.types.Shape; + +/** A helper class for {@link NDArray} implementations for operations with an {@link NDIndex}. */ +public class NDArrayIndexer { + + /** + * Returns a subarray by picking the elements. 
+ * + * @param array the array to get from + * @param index the index to get + * @return the subArray + */ + public NDArray get(NDArray array, NDIndex index) { + if (index.getRank() == 0 && array.getShape().isScalar()) { + return array.duplicate(); + } + + // use booleanMask for NDIndexBooleans case + List indices = index.getIndices(); + if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) { + if (indices.size() != 1) { + throw new IllegalArgumentException( + "get() currently didn't support more that one boolean NDArray"); + } + return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex()); + } + + Optional fullPick = NDIndexFullPick.fromIndex(index, array.getShape()); + if (fullPick.isPresent()) { + return get(array, fullPick.get()); + } + + Optional fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape()); + if (fullSlice.isPresent()) { + return get(array, fullSlice.get()); + } + throw new UnsupportedOperationException( + "get() currently supports all, fixed, and slices indices"); + } + + /** + * Returns a subarray by picking the elements. + * + * @param array the array to get from + * @param fullPick the elements to pick + * @return the subArray + */ + public NDArray get(NDArray array, NDIndexFullPick fullPick) { + OpParams params = new OpParams(); + params.addParam("axis", fullPick.getAxis()); + params.addParam("keepdims", true); + params.add("mode", "wrap"); + return NDArray.invoke( + array.getParent(), "pick", new NDList(array, fullPick.getIndices()), params) + .singletonOrThrow(); + } + + /** + * Returns a subarray at the slice. 
+ * + * @param array the array to get from + * @param fullSlice the fullSlice index of the array + * @return the subArray + */ + public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { + OpParams params = new OpParams(); + params.addTupleParam("begin", fullSlice.getMin()); + params.addTupleParam("end", fullSlice.getMax()); + params.addTupleParam("step", fullSlice.getStep()); + + NDArray result = NDArray.invoke(array.getParent(), "_npi_slice", array, params); + int[] toSqueeze = fullSlice.getToSqueeze(); + if (toSqueeze.length > 0) { + NDArray oldResult = result; + result = result.squeeze(toSqueeze); + oldResult.close(); + } + return result; + } + + /** + * Sets the values of the array at the fullSlice with an array. + * + * @param array the array to set + * @param fullSlice the fullSlice of the index to set in the array + * @param value the value to set with + */ + public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { + OpParams params = new OpParams(); + params.addTupleParam("begin", fullSlice.getMin()); + params.addTupleParam("end", fullSlice.getMax()); + params.addTupleParam("step", fullSlice.getStep()); + + Stack prepareValue = new Stack<>(); + prepareValue.add(value); + prepareValue.add(prepareValue.peek().toDevice(array.getDevice(), false)); + // prepareValue.add(prepareValue.peek().asType(getDataType(), false)); + // Deal with the case target: (1, 10, 1), original (10) + // try to find (10, 1) and reshape (10) to that + Shape targetShape = fullSlice.getShape(); + while (targetShape.size() > value.size()) { + targetShape = targetShape.slice(1); + } + prepareValue.add(prepareValue.peek().reshape(targetShape)); + prepareValue.add(prepareValue.peek().broadcast(fullSlice.getShape())); + + NDArray.invoke( + "_npi_slice_assign", + new NDArray[] {array, prepareValue.peek()}, + new NDArray[] {array}, + params); + for (NDArray toClean : prepareValue) { + if (toClean != value) { + toClean.close(); + } + } + } + + /** + * Sets the values of 
the array at the fullSlice with a number. + * + * @param array the array to set + * @param fullSlice the fullSlice of the index to set in the array + * @param value the value to set with + */ + public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) { + OpParams params = new OpParams(); + params.addTupleParam("begin", fullSlice.getMin()); + params.addTupleParam("end", fullSlice.getMax()); + params.addTupleParam("step", fullSlice.getStep()); + params.addParam("scalar", value); + NDArray.invoke( + "_npi_slice_assign_scalar", new NDArray[] {array}, new NDArray[] {array}, params); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java new file mode 100644 index 000000000000..df0c6b84bc0a --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java @@ -0,0 +1,2008 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray; + +import java.util.Arrays; +import org.apache.mxnet.ndarray.types.Shape; + +/** This class contains various methods for manipulating MxNDArrays. 
*/ +public final class NDArrays { + + private NDArrays() {} + + private static void checkInputs(NDArray[] arrays) { + if (arrays == null || arrays.length < 2) { + throw new IllegalArgumentException("Passed in arrays must have at least one element"); + } + if (arrays.length > 2 + && Arrays.stream(arrays).skip(1).anyMatch(array -> !arrays[0].shapeEquals(array))) { + throw new IllegalArgumentException("The shape of all inputs must be the same"); + } + } + + //////////////////////////////////////// + // Operations: Element Comparison + //////////////////////////////////////// + + /** + * Returns {@code true} if all elements in {@link NDArray} a are equal to {@link NDArray} b. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.ones(new Shape(3));
+     * jshell> NDArrays.contentEquals(array, 1); // returns a plain boolean, not a boolean NDArray
+     * true
+     * </pre>
+ * + * @param a the {@link NDArray} to compare; may be {@code null} + * @param n the number to compare + * @return {@code false} if {@code a} is {@code null}, otherwise the result of {@code a.contentEquals(n)} + */ + public static boolean contentEquals(NDArray a, Number n) { + if (a == null) { + return false; + } + return a.contentEquals(n); + } + + /** + * Returns {@code true} if all elements in {@link NDArray} a are equal to {@link NDArray} b. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.arange(6f).reshape(2, 3);
+     * jshell> NDArray array2 = manager.create(new float[] {0f, 1f, 2f, 3f, 4f, 5f}, new Shape(2, 3));
+     * jshell> NDArrays.contentEquals(array1, array2); // returns a plain boolean, not a boolean NDArray
+     * true
+     * </pre>
+ * + * @param a the {@link NDArray} to compare; must not be {@code null} (no null check is done here) + * @param b the {@link NDArray} to compare + * @return the boolean result + */ + public static boolean contentEquals(NDArray a, NDArray b) { + return a.contentEquals(b); + } + + /** + * Checks 2 {@link NDArray}s for equal shapes. + * + *

<p>Shapes are considered equal if: + * + * <ul> + * <li>Both {@link NDArray}s have equal rank, and + * <li>size(0)...size(rank()-1) are equal for both {@link NDArray}s + * </ul> + * + * <p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.ones(new Shape(1, 2, 3));
+     * jshell> NDArray array2 = manager.create(new Shape(1, 2, 3));
+     * jshell> NDArrays.shapeEquals(array1, array2); // returns a plain boolean, not a boolean NDArray
+     * true
+     * </pre>
+ * + * @param a the {@link NDArray} to compare + * @param b the {@link NDArray} to compare + * @return {@code true} if the {@link Shape}s are the same + */ + public static boolean shapeEquals(NDArray a, NDArray b) { + return a.shapeEquals(b); + } + + /** + * Returns {@code true} if two {@link NDArray} are element-wise equal within a tolerance. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new double[] {1e10,1e-7});
+     * jshell> NDArray array2 = manager.create(new double[] {1.00001e10,1e-8});
+     * jshell> NDArrays.allClose(array1, array2); // returns a plain boolean, not a boolean NDArray
+     * false
+     * jshell> NDArray array1 = manager.create(new double[] {1e10,1e-8});
+     * jshell> NDArray array2 = manager.create(new double[] {1.00001e10,1e-9});
+     * jshell> NDArrays.allClose(array1, array2); // returns a plain boolean, not a boolean NDArray
+     * true
+     * </pre>
+ * + * @param a the {@link NDArray} to compare with + * @param b the {@link NDArray} to compare with + * @return the boolean result + */ + // public static boolean allClose(NDArray a, NDArray b) { + // return a.allClose(b); + // } + + /** + * Returns {@code true} if two {@link NDArray}s are element-wise equal within a tolerance. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new double[] {1e10, 1e-7});
+     * jshell> NDArray array2 = manager.create(new double[] {1.00001e10, 1e-8});
+     * jshell> NDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // returns a plain boolean, not a boolean NDArray
+     * false
+     * jshell> NDArray array1 = manager.create(new double[] {1e10, 1e-8});
+     * jshell> NDArray array2 = manager.create(new double[] {1.00001e10, 1e-9});
+     * jshell> NDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // returns a plain boolean, not a boolean NDArray
+     * true
+     * jshell> NDArray array1 = manager.create(new float[] {1f, Float.NaN});
+     * jshell> NDArray array2 = manager.create(new float[] {1f, Float.NaN});
+     * jshell> NDArrays.allClose(array1, array2, 1e-05, 1e-08, true); // returns a plain boolean, not a boolean NDArray
+     * true
+     * </pre>
+ * + * @param a the {@link NDArray} to compare with + * @param b the {@link NDArray} to compare with + * @param rtol the relative tolerance parameter + * @param atol the absolute tolerance parameter + * @param equalNan whether to compare NaN’s as equal. If {@code true}, NaN’s in the {@link + * NDArray} will be considered equal to NaN’s in the other {@link NDArray} + * @return the boolean result + */ + // public static boolean allClose( + // NDArray a, NDArray b, double rtol, double atol, boolean equalNan) { + // return a.allClose(b, rtol, atol, equalNan); + // } + + /** + * Returns the boolean {@link NDArray} for element-wise "Equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.ones(new Shape(1));
+     * jshell> NDArrays.eq(array, 1);
+     * ND: (1) cpu() boolean
+     * [ true]
+     * </pre>
+ * + * @param a the {@link NDArray} to compare + * @param n the number to compare + * @return the boolean {@link NDArray} for element-wise "Equals" comparison + */ + public static NDArray eq(NDArray a, Number n) { + return a.eq(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.ones(new Shape(1));
+     * jshell> NDArrays.eq(1, array);
+     * ND: (1) cpu() boolean
+     * [ true]
+     * </pre>
+ * + * @param n the number to compare + * @param a the {@link NDArray} to compare + * @return the boolean {@link NDArray} for element-wise "Equals" comparison (computed as {@code a.eq(n)}) + */ + public static NDArray eq(Number n, NDArray a) { + return a.eq(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {0f, 1f, 3f});
+     * jshell> NDArray array2 = manager.arange(3f);
+     * jshell> NDArrays.eq(array1, array2);
+     * ND: (3) cpu() boolean
+     * [ true,  true, false]
+     * </pre>
+ * + * @param a the {@link NDArray} to compare + * @param b the {@link NDArray} to compare + * @return the boolean {@link NDArray} for element-wise "Equals" comparison + */ + public static NDArray eq(NDArray a, NDArray b) { + return a.eq(b); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.arange(4f).reshape(2, 2);
+     * jshell> NDArrays.neq(array, 1);
+     * ND: (2, 2) cpu() boolean
+     * [[ true, false],
+     *  [ true,  true],
+     * ]
+     * </pre>
+ * + * @param a the {@link NDArray} to compare + * @param n the number to compare + * @return the boolean {@link NDArray} for element-wise "Not equals" comparison + */ + public static NDArray neq(NDArray a, Number n) { + return a.neq(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.arange(4f).reshape(2, 2);
+     * jshell> NDArrays.neq(1, array);
+     * ND: (2, 2) cpu() boolean
+     * [[ true, false],
+     *  [ true,  true],
+     * ]
+     * </pre>
+ * + * @param n the number to compare + * @param a the {@link NDArray} to compare + * @return the boolean {@link NDArray} for element-wise "Not equals" comparison (computed as {@code a.neq(n)}) + */ + public static NDArray neq(Number n, NDArray a) { + return a.neq(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell> NDArray array2 = manager.create(new float[] {1f, 3f});
+     * jshell> NDArrays.neq(array1, array2);
+     * ND: (2) cpu() boolean
+     * [false,  true]
+     * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell> NDArray array2 = manager.create(new float[] {1f, 3f, 1f, 4f}, new Shape(2, 2));
+     * jshell> NDArrays.neq(array1, array2); // broadcasting
+     * ND: (2, 2) cpu() boolean
+     * [[false,  true],
+     *  [false,  true],
+     * ]
+     * </pre>
+ * + * @param a the {@link NDArray} to compare + * @param b the {@link NDArray} to compare + * @return the boolean {@link NDArray} for element-wise "Not equals" comparison + */ + public static NDArray neq(NDArray a, NDArray b) { + return a.neq(b); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell> NDArrays.gt(array, 2f);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+ * + * @param a the {@link NDArray} to compare + * @param n the number to be compared against + * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison + */ + public static NDArray gt(NDArray a, Number n) { + return a.gt(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell> NDArrays.gt(2f, array);
+     * ND: (2) cpu() boolean
+     * [false, false]
+     * </pre>
+ * + * @param n the number to be compared + * @param a the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison, i.e. {@code n > a}, computed with the operands swapped as {@code a.lt(n)} + */ + public static NDArray gt(Number n, NDArray a) { + return a.lt(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {4f, 2f});
+     * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell> NDArrays.gt(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param b the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison + */ + public static NDArray gt(NDArray a, NDArray b) { + return a.gt(b); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell> NDArrays.gte(array, 2);
+     * ND: (2) cpu() boolean
+     * [ true, true]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param n the number to be compared against + * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison + */ + public static NDArray gte(NDArray a, Number n) { + return a.gte(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell> NDArrays.gte(2, array);
+     * ND: (2) cpu() boolean
+     * [false,  true]
+     * </pre>
+ * + * @param n the number to be compared + * @param a the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison, i.e. {@code n >= a}, computed with the operands swapped as {@code a.lte(n)} + */ + public static NDArray gte(Number n, NDArray a) { + return a.lte(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {4f, 2f});
+     * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell> NDArrays.gte(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, true]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param b the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison + */ + public static NDArray gte(NDArray a, NDArray b) { + return a.gte(b); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Less" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.lt(array, 2f);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param n the number to be compared against + * @return the boolean {@link NDArray} for element-wise "Less" comparison + */ + public static NDArray lt(NDArray a, Number n) { + return a.lt(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Less" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.lt(2f, array);
+     * ND: (2) cpu() boolean
+     * [false, false]
+     * </pre>
+ * + * @param n the number to be compared + * @param a the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Less" comparison, i.e. {@code n < a}, computed with the operands swapped as {@code a.gt(n)} + */ + public static NDArray lt(Number n, NDArray a) { + return a.gt(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Less" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell> NDArrays.lt(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param b the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Less" comparison + */ + public static NDArray lt(NDArray a, NDArray b) { + return a.lt(b); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.lte(array, 2f);
+     * ND: (2) cpu() boolean
+     * [ true, true]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param n the number to be compared against + * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison + */ + public static NDArray lte(NDArray a, Number n) { + return a.lte(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.lte(2f, array);
+     * ND: (2) cpu() boolean
+     * [false,  true]
+     * </pre>
+ * + * @param n the number to be compared + * @param a the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison, i.e. {@code n <= a}, computed with the operands swapped as {@code a.gte(n)} + */ + public static NDArray lte(Number n, NDArray a) { + return a.gte(n); + } + + /** + * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell> NDArrays.lte(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, true]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param b the {@link NDArray} to be compared against + * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison + */ + public static NDArray lte(NDArray a, NDArray b) { + return a.lte(b); + } + + /** + * Returns elements chosen from the {@link NDArray} or the other {@link NDArray} depending on + * condition. + * + *

<p>Given three {@link NDArray}s, condition, a, and b, returns an {@link NDArray} with the + * elements from a or b, depending on whether the elements from condition {@link NDArray} are + * {@code true} or {@code false}. If condition has the same shape as a, each element in the + * output {@link NDArray} is from a if the corresponding element in the condition is {@code + * true}, and from b if {@code false}. + * + * <p>Note that all non-zero values are interpreted as {@code true} in condition {@link + * NDArray}. + * + * <p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.arange(10f);
+     * jshell> NDArrays.where(array.lt(5), array, array.mul(10));
+     * ND: (10) cpu() float32
+     * [ 0.,  1.,  2.,  3.,  4., 50., 60., 70., 80., 90.]
+     * jshell> NDArray array = manager.create(new float[]{0f, 1f, 2f, 0f, 2f, 4f, 0f, 3f, 6f}, new Shape(3, 3));
+     * jshell> NDArrays.where(array.lt(4), array, manager.create(-1f));
+     * ND: (3, 3) cpu() float32
+     * [[ 0.,  1.,  2.],
+     *  [ 0.,  2., -1.],
+     *  [ 0.,  3., -1.],
+     * ]
+     * </pre>
+ * + * @param condition the condition {@link NDArray} + * @param a the first {@link NDArray} + * @param b the other {@link NDArray} + * @return the result {@link NDArray} + */ + public static NDArray where(NDArray condition, NDArray a, NDArray b) { + return a.getNDArrayInternal().where(condition, b); + } + + /** + * Returns the maximum of a {@link NDArray} and a number element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell> NDArrays.maximum(array, 3f);
+     * ND: (3) cpu() float32
+     * [3., 3., 4.]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param n the number to be compared + * @return the maximum of a {@link NDArray} and a number element-wise + */ + public static NDArray maximum(NDArray a, Number n) { + return a.maximum(n); + } + + /** + * Returns the maximum of a number and a {@link NDArray} element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell> NDArrays.maximum(3f, array);
+     * ND: (3) cpu() float32
+     * [3., 3., 4.]
+     * </pre>
+ * + * @param n the number to be compared + * @param a the {@link NDArray} to be compared + * @return the maximum of a number and a {@link NDArray} element-wise (delegates to {@code maximum(a, n)}) + */ + public static NDArray maximum(Number n, NDArray a) { + return maximum(a, n); + } + + /** + * Returns the maximum of {@link NDArray} a and {@link NDArray} b element-wise. + * + *

<p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + * <p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+     * jshell> NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+     * jshell> NDArrays.maximum(array1, array2);
+     * ND: (3) cpu() float32
+     * [2., 5., 4.]
+     * jshell> NDArray array1 = manager.eye(2);
+     * jshell> NDArray array2 = manager.create(new float[] {0.5f, 2f});
+     * jshell> NDArrays.maximum(array1, array2); // broadcasting
+     * ND: (2, 2) cpu() float32
+     * [[1. , 2. ],
+     *  [0.5, 2. ],
+     * ]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param b the {@link NDArray} to be compared + * @return the maximum of {@link NDArray} a and {@link NDArray} b element-wise + */ + public static NDArray maximum(NDArray a, NDArray b) { + return a.maximum(b); + } + + /** + * Returns the minimum of a {@link NDArray} and a number element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell> NDArrays.minimum(array, 3f);
+     * ND: (3) cpu() float32
+     * [2., 3., 3.]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param n the number to be compared + * @return the minimum of a {@link NDArray} and a number element-wise + */ + public static NDArray minimum(NDArray a, Number n) { + return a.minimum(n); + } + + /** + * Returns the minimum of a number and a {@link NDArray} element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell> NDArrays.minimum(3f, array);
+     * ND: (3) cpu() float32
+     * [2., 3., 3.]
+     * </pre>
+ * + * @param n the number to be compared + * @param a the {@link NDArray} to be compared + * @return the minimum of a number and a {@link NDArray} element-wise (delegates to {@code minimum(a, n)}) + */ + public static NDArray minimum(Number n, NDArray a) { + return minimum(a, n); + } + + /** + * Returns the minimum of {@link NDArray} a and {@link NDArray} b element-wise. + * + *

<p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + * <p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+     * jshell> NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+     * jshell> NDArrays.minimum(array1, array2);
+     * ND: (3) cpu() float32
+     * [1., 3., 2.]
+     * jshell> NDArray array1 = manager.eye(2);
+     * jshell> NDArray array2 = manager.create(new float[] {0.5f, 2f});
+     * jshell> NDArrays.minimum(array1, array2); // broadcasting
+     * ND: (2, 2) cpu() float32
+     * [[0.5, 0. ],
+     *  [0. , 1. ],
+     * ]
+     * </pre>
+ * + * @param a the {@link NDArray} to be compared + * @param b the {@link NDArray} to be compared + * @return the minimum of {@link NDArray} a and {@link NDArray} b element-wise + */ + public static NDArray minimum(NDArray a, NDArray b) { + return a.minimum(b); + } + + /** + * Returns portion of the {@link NDArray} given the index boolean {@link NDArray} along first + * axis. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new Shape(3, 2));
+     * jshell> NDArray mask = manager.create(new boolean[] {true, false, true});
+     * jshell> NDArrays.booleanMask(array, mask);
+     * ND: (2, 2) cpu() float32
+     * [[1., 2.],
+     *  [5., 6.],
+     * ]
+     * </pre>
+ * + * @param data the {@link NDArray} to operate on + * @param index the boolean {@link NDArray} mask + * @return the result {@link NDArray} (same as {@code booleanMask(data, index, 0)}) + */ + public static NDArray booleanMask(NDArray data, NDArray index) { + return booleanMask(data, index, 0); + } + + /** + * Returns portion of the {@link NDArray} given the index boolean {@link NDArray} along given + * axis. + * + * @param data the {@link NDArray} to operate on + * @param index the boolean {@link NDArray} mask + * @param axis an integer that represents the axis of {@link NDArray} to mask from + * @return the result {@link NDArray} + */ + public static NDArray booleanMask(NDArray data, NDArray index, int axis) { + return data.booleanMask(index, axis); + } + + /** + * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to a + * constant value. + * + *

<p>This function takes an n-dimensional input array of the form [batch_size, + * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code + * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be + * an input array of positive ints of dimension [batch_size]. + * + * @param data the {@link NDArray} to operate on + * @param sequenceLength used to handle variable-length sequences + * @param value the constant value to be set + * @return the result {@link NDArray} + */ + public static NDArray sequenceMask(NDArray data, NDArray sequenceLength, float value) { + return data.sequenceMask(sequenceLength, value); + } + + /** + * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to 0. + * + *

<p>This function takes an n-dimensional input array of the form [batch_size, + * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code + * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be + * an input array of positive ints of dimension [batch_size]. + * + * @param data the {@link NDArray} to operate on + * @param sequenceLength used to handle variable-length sequences + * @return the result {@link NDArray} + */ + public static NDArray sequenceMask(NDArray data, NDArray sequenceLength) { + return data.sequenceMask(sequenceLength); + } + + //////////////////////////////////////// + // Operations: Element Arithmetic + //////////////////////////////////////// + + /** + * Adds a number to the {@link NDArray} element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.add(array, 2f);
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * </pre>
+ * + * @param a the {@link NDArray} to be added to + * @param n the number to add + * @return the result {@link NDArray} + */ + public static NDArray add(NDArray a, Number n) { + return a.add(n); + } + + /** + * Adds a {@link NDArray} to a number element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.add(2f, array);
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * </pre>
+ * + * @param n the number to be added to + * @param a the {@link NDArray} to add + * @return the result {@link NDArray} (computed as {@code a.add(n)}) + */ + public static NDArray add(Number n, NDArray a) { + return a.add(n); + } + + /** + * Adds a {@link NDArray} to a {@link NDArray} element-wise. + * + *

The shapes of all of the {@link NDArray}s must be the same. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.add(array, array, array);
+     * ND: (2) cpu() float32
+     * [3., 6.]
+     * 
+ * + * @param arrays the {@link NDArray}s to add together + * @return the result {@link NDArray} + * @throws IllegalArgumentException arrays must have at least two elements + * @throws IllegalArgumentException the shape of all inputs must be the same + */ + public static NDArray add(NDArray... arrays) { + checkInputs(arrays); + if (arrays.length == 2) { + return arrays[0].add(arrays[1]); + } + try (NDArray array = NDArrays.stack(new NDList(arrays))) { + return array.sum(new int[] {0}); + } + } + + /** + * Subtracts a number from the {@link NDArray} element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.sub(array, 2f);
+     * ND: (2) cpu() float32
+     * [-1.,  0.]
+     * </pre>
+ * + * @param a the {@link NDArray} to subtract from + * @param n the number to subtract + * @return the result {@link NDArray} + */ + public static NDArray sub(NDArray a, Number n) { + return a.sub(n); + } + + /** + * Subtracts a {@link NDArray} from a number element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.sub(3f, array);
+     * ND: (2) cpu() float32
+     * [2., 1.]
+     * </pre>
+ * + * @param n the number to subtract from + * @param a the {@link NDArray} to subtract + * @return the result {@link NDArray} + */ + public static NDArray sub(Number n, NDArray a) { + return a.getNDArrayInternal().rsub(n); + } + + /** + * Subtracts a {@link NDArray} from a {@link NDArray} element-wise. + * + *

<p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + * <p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
+     * jshell> NDArray array2 = manager.arange(3f);
+     * jshell> NDArrays.sub(array1, array2); // broadcasting
+     * ND: (3, 3) cpu() float32
+     * [[0., 0., 0.],
+     *  [3., 3., 3.],
+     *  [6., 6., 6.],
+     * ]
+     * </pre>
+ * + * @param a the {@link NDArray} to subtract from + * @param b the {@link NDArray} to subtract + * @return the result {@link NDArray} + */ + public static NDArray sub(NDArray a, NDArray b) { + return a.sub(b); + } + + /** + * Multiplies the {@link NDArray} by a number element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.mul(array, 3f);
+     * ND: (2) cpu() float32
+     * [3., 6.]
+     * </pre>
+ * + * @param a the {@link NDArray} to be multiplied + * @param n the number to multiply by + * @return the result {@link NDArray} + */ + public static NDArray mul(NDArray a, Number n) { + return a.mul(n); + } + + /** + * Multiplies a number by a {@link NDArray} element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> NDArrays.mul(3f, array);
+     * ND: (2) cpu() float32
+     * [3., 6.]
+     * </pre>
+ * + * @param n the number to be multiplied + * @param a the {@link NDArray} to multiply by + * @return the result {@link NDArray} (computed as {@code a.mul(n)}) + */ + public static NDArray mul(Number n, NDArray a) { + return a.mul(n); + } + + /** + * Multiplies all of the {@link NDArray}s together element-wise. + * + *

The shapes of all of the {@link NDArray}s must be the same. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.mul(array, array, array);
+     * ND: (2) cpu() float32
+     * [1., 8.]
+     * 
+ * + * @param arrays the {@link NDArray}s to multiply together + * @return the result {@link NDArray} + * @throws IllegalArgumentException arrays must have at least two elements + * @throws IllegalArgumentException the shape of all inputs must be the same + */ + public static NDArray mul(NDArray... arrays) { + checkInputs(arrays); + if (arrays.length == 2) { + return arrays[0].mul(arrays[1]); + } + try (NDArray array = NDArrays.stack(new NDList(arrays))) { + return array.prod(new int[] {0}); + } + } + + /** + * Divides the {@link NDArray} by a number element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.arange(5f);
+     * jshell> NDArrays.div(array, 4f);
+     * ND: (5) cpu() float32
+     * [0.  , 0.25, 0.5 , 0.75, 1.  ]
+     * </pre>
+ * + * @param a the {@link NDArray} to be divided + * @param n the number to divide by + * @return the result {@link NDArray} + */ + public static NDArray div(NDArray a, Number n) { + return a.div(n); + } + + /** + * Divides a number by a {@link NDArray} element-wise. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.arange(5f).add(1);
+     * jshell> NDArrays.div(4f, array);
+     * ND: (5) cpu() float32
+     * [4.    , 2.    , 1.3333, 1.    , 0.8   ]
+     * </pre>
+ * + * @param n the number to be divided + * @param a the {@link NDArray} to divide by + * @return the result {@link NDArray} + */ + public static NDArray div(Number n, NDArray a) { + return a.getNDArrayInternal().rdiv(n); + } + + /** + * Divides a {@link NDArray} by a {@link NDArray} element-wise. + * + *

<p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + * <p>Examples + * + * <pre>
+     * jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
+     * jshell> NDArray array2 = manager.ones(new Shape(3)).mul(10);
+     * jshell> NDArrays.div(array1, array2); // broadcasting
+     * ND: (3, 3) cpu() float32
+     * [[0. , 0.1, 0.2],
+     *  [0.3, 0.4, 0.5],
+     *  [0.6, 0.7, 0.8],
+     * ]
+     * </pre>
+ * + * @param a the {@link NDArray} to be divided + * @param b the {@link NDArray} to divide by + * @return the result {@link NDArray} + */ + public static NDArray div(NDArray a, NDArray b) { + return a.div(b); + } + + /** + * Returns element-wise remainder of division. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.arange(7f);
+     * jshell> NDArrays.mod(array, 5f);
+     * ND: (7) cpu() float32
+     * [0., 1., 2., 3., 4., 0., 1.]
+     * </pre>
+ * + * @param a the dividend {@link NDArray} + * @param n the divisor number + * @return the result {@link NDArray} + */ + public static NDArray mod(NDArray a, Number n) { + return a.mod(n); + } + + /** + * Returns element-wise remainder of division. + * + *

<p>Examples + * + * <pre>
+     * jshell> NDArray array = manager.arange(7f).add(1);
+     * jshell> NDArrays.mod(5f, array);
+     * ND: (7) cpu() float32
+     * [0., 1., 2., 1., 0., 5., 5.]
+     * </pre>
+ * + * @param n the dividend number + * @param a the divisor {@link NDArray} + * @return the result {@link NDArray} + */ + public static NDArray mod(Number n, NDArray a) { + return a.getNDArrayInternal().rmod(n); + } + + /** + * Returns element-wise remainder of division. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {4f, 7f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {2f, 3f});
+     * jshell> MxNDArrays.mod(array1, array2);
+     * ND: (2) cpu() float32
+     * [0., 1.]
+     * 
+ * + * @param a the dividend MxNDArray + * @param b the divisor MxNDArray + * @return the result {@link NDArray} + */ + public static NDArray mod(NDArray a, NDArray b) { + return a.mod(b); + } + + /** + * Takes the power of the {@link NDArray} with a number element-wise. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(5f);
+     * jshell> MxNDArrays.pow(array, 4f);
+     * ND: (5) cpu() float32
+     * [  0.,   1.,  16.,  81., 256.]
+     * 
+ * + * @param a the {@link NDArray} to be taken the power with + * @param n the number to take the power with + * @return the result {@link NDArray} + */ + public static NDArray pow(NDArray a, Number n) { + return a.pow(n); + } + + /** + * Takes the power of a number with a {@link NDArray} element-wise. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(5f);
+     * jshell> MxNDArrays.pow(4f, array);
+     * ND: (5) cpu() float32
+     * [  1.,   4.,  16.,  64., 256.]
+     * 
+ * + * @param n the number to be taken the power with + * @param a the {@link NDArray} to take the power with + * @return the result {@link NDArray} + */ + public static NDArray pow(Number n, NDArray a) { + return a.getNDArrayInternal().rpow(n); + } + + /** + * Takes the power of a {@link NDArray} with a {@link NDArray} element-wise. + * + *

The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.arange(6f).reshape(3, 2);
+     * jshell> MxNDArray array2 = manager.create(new float[] {2f, 3f});
+     * jshell> MxNDArrays.pow(array1, array2); // broadcasting
+     * ND: (3, 2) cpu() float32
+     * [[  0.,   1.],
+     *  [  4.,  27.],
+     *  [ 16., 125.],
+     * ]
+     * 
+ * + * @param a the {@link NDArray} to be taken the power with + * @param b the {@link NDArray} to take the power with + * @return the result {@link NDArray} + */ + public static NDArray pow(NDArray a, NDArray b) { + return a.pow(b); + } + + /** + * Adds a number to the {@link NDArray} element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.addi(array, 2f);
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * jshell> array;
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * 
+ * + * @param a the {@link NDArray} to be added to + * @param n the number to add + * @return the result {@link NDArray} + */ + public static NDArray addi(NDArray a, Number n) { + return a.addi(n); + } + + /** + * Adds a {@link NDArray} to a number element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.addi(2f, array);
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * jshell> array;
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * 
+ * + * @param n the number to be added to + * @param a the {@link NDArray} to add + * @return the result {@link NDArray} + */ + public static NDArray addi(Number n, NDArray a) { + return a.addi(n); + } + + /** + * Adds all of the {@link NDArray}s together element-wise in place. + * + *

The shapes of all of the {@link NDArray}s must be the same. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {3f, 4f});
+     * jshell> MxNDArray array3 = manager.create(new float[] {5f, 6f});
+     * jshell> MxNDArrays.addi(array1, array2, array3);
+     * ND: (2) cpu() float32
+     * [9., 12.]
+     * jshell> array1;
+     * ND: (2) cpu() float32
+     * [9., 12.]
+     * 
+ * + * @param arrays the {@link NDArray}s to add together + * @return the result {@link NDArray} + * @throws IllegalArgumentException arrays must have at least two elements + */ + public static NDArray addi(NDArray... arrays) { + checkInputs(arrays); + Arrays.stream(arrays).skip(1).forEachOrdered(array -> arrays[0].addi(array)); + return arrays[0]; + } + + /** + * Subtracts a number from the {@link NDArray} element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.subi(array, 2f);
+     * ND: (2) cpu() float32
+     * [-1.,  0.]
+     * jshell> array;
+     * ND: (2) cpu() float32
+     * [-1.,  0.]
+     * 
+ * + * @param a the {@link NDArray} to be subtracted + * @param n the number to subtract from + * @return the result {@link NDArray} + */ + public static NDArray subi(NDArray a, Number n) { + return a.subi(n); + } + + /** + * Subtracts a {@link NDArray} from a number element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.subi(3f, array);
+     * ND: (2) cpu() float32
+     * [2., 1.]
+     * jshell> array;
+     * ND: (2) cpu() float32
+     * [2., 1.]
+     * 
+ * + * @param n the number to be subtracted + * @param a the {@link NDArray} to subtract from + * @return the result {@link NDArray} + */ + public static NDArray subi(Number n, NDArray a) { + return a.getNDArrayInternal().rsubi(n); + } + + /** + * Subtracts a {@link NDArray} from a {@link NDArray} element-wise in place. + * + *

The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.arange(9f).reshape(3, 3);
+     * jshell> MxNDArray array2 = manager.arange(3f);
+     * jshell> MxNDArrays.subi(array1, array2); // broadcasting
+     * ND: (3, 3) cpu() float32
+     * [[0., 0., 0.],
+     *  [3., 3., 3.],
+     *  [6., 6., 6.],
+     * ]
+     * jshell> array1;
+     * [[0., 0., 0.],
+     *  [3., 3., 3.],
+     *  [6., 6., 6.],
+     * ]
+     * 
+ * + * @param a the {@link NDArray} to be subtracted + * @param b the {@link NDArray} to subtract from + * @return the result {@link NDArray} + */ + public static NDArray subi(NDArray a, NDArray b) { + return a.subi(b); + } + + /** + * Multiplies the {@link NDArray} by a number element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.muli(array, 3f);
+     * ND: (2) cpu() float32
+     * [3., 6.]
+     * jshell> array;
+     * ND: (2) cpu() float32
+     * [3., 6.]
+     * 
+ * + * @param a the MxNDArray to be multiplied + * @param n the number to multiply by + * @return the result {@link NDArray} + */ + public static NDArray muli(NDArray a, Number n) { + return a.muli(n); + } + + /** + * Multiplies a number by a {@link NDArray} element-wise. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.muli(3f, array);
+     * ND: (2) cpu() float32
+     * [3., 6.]
+     * jshell> array;
+     * ND: (2) cpu() float32
+     * [3., 6.]
+     * 
+ * + * @param n the number to multiply by + * @param a the {@link NDArray} to multiply by + * @return the result {@link NDArray} + */ + public static NDArray muli(Number n, NDArray a) { + return a.muli(n); + } + + /** + * Multiplies all of the {@link NDArray}s together element-wise in place. + * + *

The shapes of all of the {@link NDArray}s must be the same. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {3f, 4f});
+     * jshell> MxNDArray array3 = manager.create(new float[] {5f, 6f});
+     * jshell> MxNDArrays.muli(array1, array2, array3);
+     * ND: (2) cpu() float32
+     * [15., 48.]
+     * jshell> array1;
+     * ND: (2) cpu() float32
+     * [15., 48.]
+     * 
+ * + * @param arrays the {@link NDArray}s to multiply together + * @return the result {@link NDArray} + * @throws IllegalArgumentException arrays must have at least two elements + */ + public static NDArray muli(NDArray... arrays) { + checkInputs(arrays); + Arrays.stream(arrays).skip(1).forEachOrdered(array -> arrays[0].muli(array)); + return arrays[0]; + } + + /** + * Divides the {@link NDArray} by a number element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(5f);
+     * jshell> MxNDArrays.divi(array, 4f);
+     * ND: (5) cpu() float32
+     * [0.  , 0.25, 0.5 , 0.75, 1.  ]
+     * jshell> array;
+     * ND: (5) cpu() float32
+     * [0.  , 0.25, 0.5 , 0.75, 1.  ]
+     * 
+ * + * @param a the {@link NDArray} to be divided + * @param n the number to divide by + * @return the result {@link NDArray} + */ + public static NDArray divi(NDArray a, Number n) { + return a.divi(n); + } + + /** + * Divides a number by a {@link NDArray} element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(5f).add(1);
+     * jshell> MxNDArrays.divi(4f, array);
+     * ND: (5) cpu() float32
+     * [4.    , 2.    , 1.3333, 1.    , 0.8   ]
+     * jshell> array;
+     * ND: (5) cpu() float32
+     * [4.    , 2.    , 1.3333, 1.    , 0.8   ]
+     * 
+ * + * @param n the number to be divided + * @param a the {@link NDArray} to divide by + * @return the result {@link NDArray} + */ + public static NDArray divi(Number n, NDArray a) { + return a.getNDArrayInternal().rdivi(n); + } + + /** + * Divides a {@link NDArray} by a {@link NDArray} element-wise in place. + * + *

The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.arange(9f).reshape(3, 3);
+     * jshell> MxNDArray array2 = manager.ones(new Shape(3)).mul(10);
+     * jshell> MxNDArrays.divi(array1, array2); // broadcasting
+     * ND: (3, 3) cpu() float32
+     * [[0. , 0.1, 0.2],
+     *  [0.3, 0.4, 0.5],
+     *  [0.6, 0.7, 0.8],
+     * ]
+     * jshell> array1;
+     * [[0. , 0.1, 0.2],
+     *  [0.3, 0.4, 0.5],
+     *  [0.6, 0.7, 0.8],
+     * ]
+     * 
+ * + * @param a the {@link NDArray} to be divided + * @param b the {@link NDArray} to divide by + * @return the result {@link NDArray} + */ + public static NDArray divi(NDArray a, NDArray b) { + return a.divi(b); + } + + /** + * Returns element-wise remainder of division in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(7f);
+     * jshell> MxNDArrays.modi(array, 5f);
+     * ND: (7) cpu() float32
+     * [0., 1., 2., 3., 4., 0., 1.]
+     * jshell> array;
+     * ND: (7) cpu() float32
+     * [0., 1., 2., 3., 4., 0., 1.]
+     * 
+ * + * @param a the dividend {@link NDArray} + * @param n the divisor number + * @return the result {@link NDArray} + */ + public static NDArray modi(NDArray a, Number n) { + return a.modi(n); + } + + /** + * Returns element-wise remainder of division in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(7f);
+     * jshell> MxNDArrays.modi(5f, array);
+     * ND: (7) cpu() float32
+     * [0., 0., 1., 2., 1., 0., 5.]
+     * jshell> array;
+     * ND: (7) cpu() float32
+     * [0., 0., 1., 2., 1., 0., 5.]
+     * 
+ * + * @param n the dividend number + * @param a the divisor {@link NDArray} + * @return the result {@link NDArray} + */ + public static NDArray modi(Number n, NDArray a) { + return a.getNDArrayInternal().rmodi(n); + } + + /** + * Returns element-wise remainder of division. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {4f, 7f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {2f, 3f});
+     * jshell> MxNDArrays.modi(array1, array2);
+     * ND: (2) cpu() float32
+     * [0., 1.]
+     * jshell> array1;
+     * ND: (2) cpu() float32
+     * [0., 1.]
+     * 
+ * + * @param a the dividend MxNDArray + * @param b the divisor MxNDArray + * @return the result {@link NDArray} + */ + public static NDArray modi(NDArray a, NDArray b) { + return a.modi(b); + } + + /** + * Takes the power of the {@link NDArray} with a number element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(5f);
+     * jshell> MxNDArrays.powi(array, 4f);
+     * ND: (5) cpu() float32
+     * [  0.,   1.,  16.,  81., 256.]
+     * jshell> array;
+     * ND: (5) cpu() float32
+     * [  0.,   1.,  16.,  81., 256.]
+     * 
+ * + * @param a the {@link NDArray} to be taken the power with + * @param n the number to take the power with + * @return the result {@link NDArray} + */ + public static NDArray powi(NDArray a, Number n) { + return a.powi(n); + } + + /** + * Takes the power of a number with a {@link NDArray} element-wise in place. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.arange(5f);
+     * jshell> MxNDArrays.powi(4f, array);
+     * ND: (5) cpu() float32
+     * [  1.,   4.,  16.,  64., 256.]
+     * jshell> array;
+     * ND: (5) cpu() float32
+     * [  1.,   4.,  16.,  64., 256.]
+     * 
+ * + * @param n the number to be taken the power with + * @param a the {@link NDArray} to take the power with + * @return the result {@link NDArray} + */ + public static NDArray powi(Number n, NDArray a) { + return a.getNDArrayInternal().rpowi(n); + } + + /** + * Takes the power of a {@link NDArray} with a {@link NDArray} element-wise. + * + *

The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.arange(6f).reshape(3, 2);
+     * jshell> MxNDArray array2 = manager.create(new float[] {2f, 3f});
+     * jshell> MxNDArrays.powi(array1, array2); // broadcasting
+     * ND: (3, 2) cpu() float32
+     * [[  0.,   1.],
+     *  [  4.,  27.],
+     *  [ 16., 125.],
+     * ]
+     * jshell> array1;
+     * ND: (3, 2) cpu() float32
+     * [[  0.,   1.],
+     *  [  4.,  27.],
+     *  [ 16., 125.],
+     * ]
+     * 
+ * + * @param a the {@link NDArray} to be taken the power with + * @param b the {@link NDArray} to take the power with + * @return the result {@link NDArray} + */ + public static NDArray powi(NDArray a, NDArray b) { + return a.powi(b); + } + + /** + * Dot product of {@link NDArray} a and {@link NDArray} b. + * + *
    + *
  • If both the {@link NDArray} and the other {@link NDArray} are 1-D {@link NDArray}s, it + * is inner product of vectors (without complex conjugation). + *
  • If both the {@link NDArray} and the other {@link NDArray} are 2-D {@link NDArray}s, it + * is matrix multiplication. + *
  • If either the {@link NDArray} or the other {@link NDArray} is 0-D {@link NDArray} + * (scalar), it is equivalent to mul. + *
  • If the {@link NDArray} is N-D {@link NDArray} and the other {@link NDArray} is 1-D + * {@link NDArray}, it is a sum product over the last axis of those. + *
  • If the {@link NDArray} is N-D {@link NDArray} and the other {@link NDArray} is M-D + * {@link NDArray}(where M>=2), it is a sum product over the last axis of this + * {@link NDArray} and the second-to-last axis of the other {@link NDArray} + *
+ * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f, 3f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {4f, 5f, 6f});
+     * jshell> MxNDArrays.dot(array1, array2); // inner product
+     * ND: () cpu() float32
+     * 32.
+     * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+     * jshell> array2 = manager.create(new float[] {5f, 6f, 7f, 8f}, new Shape(2, 2));
+     * jshell> MxNDArrays.dot(array1, array2); // matrix multiplication
+     * ND: (2, 2) cpu() float32
+     * [[19., 22.],
+     *  [43., 50.],
+     * ]
+     * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+     * jshell> array2 = manager.create(5f);
+     * jshell> MxNDArrays.dot(array1, array2);
+     * ND: (2, 2) cpu() float32
+     * [[ 5., 10.],
+     *  [15., 20.],
+     * ]
+     * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+     * jshell> array2 = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.dot(array1, array2);
+     * ND: (2) cpu() float32
+     * [ 5., 11.]
+     * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f}, new Shape(2, 2, 2));
+     * jshell> array2 = manager.create(new float[] {1f, 2f, 3f ,4f}, new Shape(2, 2));
+     * jshell> MxNDArrays.dot(array1, array2);
+     * ND: (2, 2, 2) cpu() float32
+     * [[[ 7., 10.],
+     *   [15., 22.],
+     *  ],
+     *  [[23., 34.],
+     *   [31., 46.],
+     *  ],
+     * ]
+     * 
+ * + * @param a the {@link NDArray} to perform dot product with + * @param b the {@link NDArray} to perform dot product with + * @return the result {@link NDArray} + */ + public static NDArray dot(NDArray a, NDArray b) { + return a.dot(b); + } + + /** + * Product matrix of this {@code MxNDArray} and the other {@code MxNDArray}. + * + *

The behavior depends on the arguments in the following way. + * + *

    + *
  • If both this {@code MxNDArray} and the other {@code MxNDArray} are 2-D {@code + * MxNDArray}s, they are multiplied like conventional matrices + *
  • If either this {@code MxNDArray} or the other {@code MxNDArray} is N-D {@code + * MxNDArray}, N > 2 , it is treated as a stack of matrices residing in the last two + * indexes and broadcast accordingly. + *
  • If this {@code MxNDArray} is 1-D {@code MxNDArray}, it is promoted to a matrix by + * prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is + * removed. + *
  • If other {@code MxNDArray} is 1-D {@code MxNDArray}, it is promoted to a matrix by + * appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed. + *
+ * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
+     * jshell> MxNDArray array2 = manager.create(new float[] {4f, 1f, 2f, 2f}, new Shape(2, 2));
+     * jshell> MxNDArrays.matMul(array1, array2); // for 2-D arrays, it is the matrix product
+     * ND: (2, 2) cpu() float32
+     * [[4., 1.],
+     *  [2., 2.],
+     * ]
+     * jshell> array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
+     * jshell> array2 = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.matMul(array1, array2);
+     * ND: (2) cpu() float32
+     * [1., 2.]
+     * jshell> array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
+     * jshell> array2 = manager.create(new float[] {1f, 2f});
+     * jshell> MxNDArrays.matMul(array1, array2);
+     * ND: (2) cpu() float32
+     * [1., 2.]
+     * jshell> array1 = manager.arange(2f * 2f * 4f).reshape(2, 2, 4);
+     * jshell> array2 = manager.arange(2f * 2f * 4f).reshape(2, 4, 2);
+     * jshell> MxNDArrays.matMul(array1, array2);
+     * ND: (2, 2, 2) cpu() float32
+     * [[[ 28.,  34.],
+     *   [ 76.,  98.],
+     *  ],
+     *  [[428., 466.],
+     *   [604., 658.],
+     *  ],
+     * ]
+     * 
+ * + * @param a the {@link NDArray} to perform matrix product with + * @param b the {@link NDArray} to perform matrix product with + * @return the result {@code MxNDArray} + */ + public static NDArray matMul(NDArray a, NDArray b) { + return a.matMul(b); + } + + /** + * Joins a sequence of {@link NDArray}s in {@link NDList} along the first axis. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+     * jshell> MxNDArray array3 = manager.create(new float[] {6f, 7f, 8f});
+     * jshell> MxNDArrays.stack(new MxNDList(array1, array2, array3));
+     * ND: (3, 3) cpu() float32
+     * [[0., 1., 2.],
+     *  [3., 4., 5.],
+     *  [6., 7., 8.],
+     * ]
+     * 
+ * + * @param arrays the input {@link NDList}. Each {@link NDArray} in the {@link NDList} must have + * the same shape as the {@link NDArray} + * @return the result {@link NDArray}. The stacked {@link NDArray} has one more dimension than + * the {@link NDArray}s in {@link NDList} + */ + public static NDArray stack(NDList arrays) { + return stack(arrays, 0); + } + + /** + * Joins a sequence of {@link NDArray}s in {@link NDList} along a new axis. + * + *

The axis parameter specifies the index of the new axis in the dimensions of the result. + * For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last + * dimension. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+     * jshell> MxNDArrays.stack(new MxNDList(array1, array2), 0);
+     * ND: (2, 3) cpu() float32
+     * [[0., 1., 2.],
+     *  [3., 4., 5.],
+     * ]
+     * jshell> MxNDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+     * jshell> MxNDArrays.stack(new MxNDList(array1, array2), 1);
+     * ND: (3, 2) cpu() float32
+     * [[0., 3.],
+     *  [1., 4.],
+     *  [2., 5.],
+     * ]
+     * 
+ * + * @param arrays the input {@link NDList}. Each {@link NDArray} in the {@link NDList} must have + * the same shape as the {@link NDArray} + * @param axis the axis in the result {@link NDArray} along which the input {@link NDList} are + * stacked + * @return the result {@link NDArray}. The stacked {@link NDArray} has one more dimension than + * the {@link NDArray} + */ + public static NDArray stack(NDList arrays, int axis) { + if (arrays.size() <= 0) { + throw new IllegalArgumentException("need at least one array to stack"); + } + NDArray array = arrays.head(); + return array.getNDArrayInternal().stack(arrays.subNDList(1), axis); + } + + /** + * Joins a {@link NDList} along the first axis. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+     * jshell> MxNDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+     * jshell> MxNDArray array3 = manager.create(new float[] {6f, 7f, 8f});
+     * jshell> MxNDArrays.concat(new MxNDList(array1, array2, array3));
+     * ND: (9) cpu() float32
+     * [0., 1., 2., 3., 4., 5., 6., 7., 8.]
+     * 
+ * + * @param arrays a {@link NDList} which have the same shape as the {@link NDArray}, except in + * the dimension corresponding to axis + * @return the concatenated {@link NDArray} + */ + public static NDArray concat(NDList arrays) { + return concat(arrays, 0); + } + + /** + * Joins a {@link NDList} along an existing axis. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+     * jshell> MxNDArray array2 = manager.create(new float[] {5f, 6f}, new Shape(1, 2));
+     * jshell> MxNDArrays.concat(new MxNDList(array1, array2), 0);
+     * ND: (3, 2) cpu() float32
+     * [[1., 2.],
+     *  [3., 4.],
+     *  [5., 6.],
+     * ]
+     * jshell> MxNDArrays.concat(new MxNDList(array1, array2.transpose()), 1);
+     * ND: (2, 3) cpu() float32
+     * [[1., 2., 5.],
+     *  [3., 4., 6.],
+     * ]
+     * 
+ * + * @param arrays a {@link NDList} which have the same shape as the {@link NDArray}, except in + * the dimension corresponding to axis + * @param axis the axis along which the {@link NDList} will be joined + * @return the concatenated {@link NDArray} + */ + public static NDArray concat(NDList arrays, int axis) { + + if (arrays.size() <= 0) { + throw new IllegalArgumentException("need at least one array to concatenate"); + } + + if (arrays.size() == 1) { + return arrays.singletonOrThrow().duplicate(); + } + NDArray array = arrays.head(); + return array.getNDArrayInternal().concat(arrays.subNDList(1), axis); + } + + /** + * Returns the truth value of {@link NDArray} a AND {@link NDArray} b element-wise. + * + *

The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new boolean[] {true});
+     * jshell> MxNDArray array2 = manager.create(new boolean[] {false});
+     * jshell> MxNDArrays.logicalAnd(array1, array2);
+     * ND: (1) cpu() boolean
+     * [false]
+     * jshell> array1 = manager.create(new boolean[] {true, false});
+     * jshell> array2 = manager.create(new boolean[] {false, false});
+     * jshell> MxNDArrays.logicalAnd(array1, array2);
+     * ND: (2) cpu() boolean
+     * [false, false]
+     * 
+ * + * @param a the {@link NDArray} to operate on + * @param b the {@link NDArray} to operate on + * @return the boolean {@link NDArray} of the logical AND operation applied to the elements of + * the {@link NDArray} a and {@link NDArray} b + */ + public static NDArray logicalAnd(NDArray a, NDArray b) { + return a.logicalAnd(b); + } + + /** + * Computes the truth value of {@link NDArray} a OR {@link NDArray} b element-wise. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new boolean[] {true});
+     * jshell> MxNDArray array2 = manager.create(new boolean[] {false});
+     * jshell> MxNDArrays.logicalOr(array1, array2);
+     * ND: (1) cpu() boolean
+     * [ true]
+     * jshell> array1 = manager.create(new boolean[] {true, false});
+     * jshell> array2 = manager.create(new boolean[] {false, false});
+     * jshell> MxNDArrays.logicalOr(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * 
+ * + *
+     * jshell> MxNDArray array = manager.arange(5f);
+     * jshell> MxNDArrays.logicalOr(array.lt(1), array.gt(3));
+     * ND: (5) cpu() boolean
+     * [ true, false, false, false,  true]
+     * 
+ * + * @param a the {@link NDArray} to operate on + * @param b the {@link NDArray} to operate on + * @return the boolean {@link NDArray} of the logical OR operation applied to the elements of + * the {@link NDArray} a and {@link NDArray} b + */ + public static NDArray logicalOr(NDArray a, NDArray b) { + return a.logicalOr(b); + } + + /** + * Computes the truth value of {@link NDArray} a XOR {@link NDArray} b element-wise. + * + *

Examples + * + *

+     * jshell> MxNDArray array1 = manager.create(new boolean[] {true});
+     * jshell> MxNDArray array2 = manager.create(new boolean[] {false});
+     * jshell> MxNDArrays.logicalXor(array1, array2);
+     * ND: (1) cpu() boolean
+     * [ true]
+     * jshell> array1 = manager.create(new boolean[] {true, false});
+     * jshell> array2 = manager.create(new boolean[] {false, false});
+     * jshell> MxNDArrays.logicalXor(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * 
+ * + *
+     * jshell> MxNDArray array = manager.arange(5f);
+     * jshell> MxNDArrays.logicalXor(array.lt(1), array.gt(3));
+     * ND: (5) cpu() boolean
+     * [ true, false, false, false,  true]
+     * 
+ * + * @param a the {@link NDArray} to operate on + * @param b the {@link NDArray} to operate on + * @return the boolean {@link NDArray} of the logical XOR operation applied to the elements of + * the {@link NDArray} a and {@link NDArray} b + */ + public static NDArray logicalXor(NDArray a, NDArray b) { + return a.logicalXor(b); + } + + /** + * Returns element-wise inverse gauss error function of the input {@code MxNDArray}. + * + *

Examples + * + *

+     * jshell> MxNDArray array = manager.create(new float[] {0f, 0.5f, -1f});
+     * jshell> MxNDArrays.erfinv(array);
+     * ND: (3) cpu() float32
+     * [0., 0.4769, -inf]
+     * 
+ * + * @param input The input {@code MxNDArray} + * @return The inverse of gauss error of the input, element-wise + */ + public static NDArray erfinv(NDArray input) { + return input.erfinv(); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDFormat.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDFormat.java new file mode 100644 index 000000000000..686719fdccad --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDFormat.java @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray; + +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.mxnet.ndarray.types.DataType; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.util.Utils; + +/** A helper for printing an {@link NDArray}. */ +public abstract class NDFormat { + + private static final int PRECISION = 8; + private static final String LF = System.getProperty("line.separator"); + private static final Pattern PATTERN = Pattern.compile("\\s*\\d\\.(\\d*?)0*e[+-](\\d+)"); + + /** + * Formats the contents of an array as a pretty printable string. 
+ * + * @param array the array to print + * @param maxSize the maximum elements to print out + * @param maxDepth the maximum depth to print out + * @param maxRows the maximum rows to print out + * @param maxColumns the maximum columns to print out + * @return the string representation of the array + */ + public static String format( + NDArray array, int maxSize, int maxDepth, int maxRows, int maxColumns) { + NDFormat format; + DataType dataType = array.getDataType(); + + if (dataType == DataType.UINT8) { + format = new HexFormat(); + } else if (dataType == DataType.BOOLEAN) { + format = new BooleanFormat(); + } else if (dataType.isInteger()) { + format = new IntFormat(array); + } else { + format = new FloatFormat(array); + } + return format.dump(array, maxSize, maxDepth, maxRows, maxColumns); + } + + protected abstract CharSequence format(Number value); + + private String dump(NDArray array, int maxSize, int maxDepth, int maxRows, int maxColumns) { + StringBuilder sb = new StringBuilder(1000); + String name = array.getName(); + if (name != null) { + sb.append(name).append(": "); + } else { + sb.append("ND: "); + } + sb.append(array.getShape()) + .append(' ') + .append(array.getDevice()) + .append(' ') + .append(array.getDataType()); + if (array.hasGradient()) { + sb.append(" hasGradient"); + } + sb.append(LF); + + long size = array.size(); + long dimension = array.getShape().dimension(); + if (size == 0) { + // corner case: 0 dimension + sb.append("[]").append(LF); + } else if (dimension == 0) { + // scalar case + sb.append(format(array.toArray()[0])).append(LF); + } else if (size > maxSize) { + sb.append("[ Exceed max print size ]"); + } else if (dimension > maxDepth) { + sb.append("[ Exceed max print dimension ]"); + } else { + dump(sb, array, 0, true, maxRows, maxColumns); + } + return sb.toString(); + } + + private void dump( + StringBuilder sb, + NDArray array, + int depth, + boolean first, + int maxRows, + int maxColumns) { + if (!first) { + Utils.pad(sb, ' 
', depth); + } + sb.append('['); + Shape shape = array.getShape(); + if (shape.dimension() == 1) { + append(sb, array.toArray(), maxColumns); + } else { + long len = shape.head(); + long limit = Math.min(len, maxRows); + for (int i = 0; i < limit; ++i) { + try (NDArray nd = array.get(i)) { + dump(sb, nd, depth + 1, i == 0, maxRows, maxColumns); + } + } + long remaining = len - limit; + if (remaining > 0) { + Utils.pad(sb, ' ', depth + 1); + sb.append("... ").append(remaining).append(" more"); + } + Utils.pad(sb, ' ', depth); + } + // last "]" + if (depth == 0) { + sb.append(']').append(LF); + } else { + sb.append("],").append(LF); + } + } + + private void append(StringBuilder sb, Number[] values, int maxColumns) { + if (values.length == 0) { + return; + } + long limit = Math.min(values.length, maxColumns); + sb.append(format(values[0])); + for (int i = 1; i < limit; ++i) { + sb.append(", "); + sb.append(format(values[i])); + } + + long remaining = values.length - limit; + if (remaining > 0) { + sb.append(", ... ").append(remaining).append(" more"); + } + } + + private static final class FloatFormat extends NDFormat { + + private boolean exponential; + private int precision; + private int totalLength; + + public FloatFormat(NDArray array) { + Number[] values = array.toArray(); + int maxIntPartLen = 0; + int maxFractionLen = 0; + int expFractionLen = 0; + int maxExpSize = 2; + boolean sign = false; + + double max = 0; + double min = Double.MAX_VALUE; + for (Number n : values) { + double v = n.doubleValue(); + if (v < 0) { + sign = true; + } + + if (!Double.isFinite(v)) { + int intPartLen = v < 0 ? 
4 : 3; + if (totalLength < intPartLen) { + totalLength = intPartLen; + } + continue; + } + double abs = Math.abs(v); + String str = String.format(Locale.ENGLISH, "%16e", abs); + Matcher m = PATTERN.matcher(str); + if (!m.matches()) { + throw new AssertionError("Invalid decimal value: " + str); + } + int fractionLen = m.group(1).length(); + if (expFractionLen < fractionLen) { + expFractionLen = fractionLen; + } + int expSize = m.group(2).length(); + if (expSize > maxExpSize) { + maxExpSize = expSize; + } + + if (abs >= 1) { + int intPartLen = (int) Math.log10(abs) + 1; + if (v < 0) { + ++intPartLen; + } + if (intPartLen > maxIntPartLen) { + maxIntPartLen = intPartLen; + } + int fullFractionLen = fractionLen + 1 - intPartLen; + if (maxFractionLen < fullFractionLen) { + maxFractionLen = fullFractionLen; + } + } else { + int intPartLen = v < 0 ? 2 : 1; + if (intPartLen > maxIntPartLen) { + maxIntPartLen = intPartLen; + } + + int fullFractionLen = fractionLen + Integer.parseInt(m.group(2)); + if (maxFractionLen < fullFractionLen) { + maxFractionLen = fullFractionLen; + } + } + + if (abs > max) { + max = abs; + } + if (abs < min && abs > 0) { + min = abs; + } + } + double ratio = max / min; + if (max > 1.e8 || min < 0.0001 || ratio > 1000.) 
{ + exponential = true; + precision = Math.min(PRECISION, expFractionLen); + totalLength = precision + 4; + if (sign) { + ++totalLength; + } + } else { + precision = Math.min(4, maxFractionLen); + int len = maxIntPartLen + precision + 1; + if (totalLength < len) { + totalLength = len; + } + } + } + + /** {@inheritDoc} */ + @Override + public CharSequence format(Number value) { + double d = value.doubleValue(); + if (Double.isNaN(d)) { + return String.format(Locale.ENGLISH, "%" + totalLength + "s", "nan"); + } else if (Double.isInfinite(d)) { + if (d > 0) { + return String.format(Locale.ENGLISH, "%" + totalLength + "s", "inf"); + } else { + return String.format(Locale.ENGLISH, "%" + totalLength + "s", "-inf"); + } + } + if (exponential) { + precision = Math.max(PRECISION, precision); + return String.format(Locale.ENGLISH, "% ." + precision + "e", value.doubleValue()); + } + if (precision == 0) { + String fmt = "%" + (totalLength - 1) + '.' + precision + "f."; + return String.format(Locale.ENGLISH, fmt, value.doubleValue()); + } + + String fmt = "%" + totalLength + '.' 
+ precision + 'f'; + String ret = String.format(Locale.ENGLISH, fmt, value.doubleValue()); + // Replace trailing zeros with space + char[] chars = ret.toCharArray(); + for (int i = chars.length - 1; i >= 0; --i) { + if (chars[i] == '0') { + chars[i] = ' '; + } else { + break; + } + } + return new String(chars); + } + } + + private static final class HexFormat extends NDFormat { + + /** {@inheritDoc} */ + @Override + public CharSequence format(Number value) { + return String.format(Locale.ENGLISH, "0x%02X", value.byteValue()); + } + } + + private static final class IntFormat extends NDFormat { + + private boolean exponential; + private int precision; + private int totalLength; + + public IntFormat(NDArray array) { + Number[] values = array.toArray(); + // scalar case + if (values.length == 1) { + totalLength = 1; + return; + } + long max = 0; + long negativeMax = 0; + for (Number n : values) { + long v = n.longValue(); + long abs = Math.abs(v); + if (v < 0 && abs > negativeMax) { + negativeMax = abs; + } + if (abs > max) { + max = abs; + } + } + + if (max >= 1.e8) { + exponential = true; + precision = Math.min(PRECISION, (int) Math.log10(max) + 1); + } else { + int size = (max != 0) ? (int) Math.log10(max) + 1 : 1; + int negativeSize = (negativeMax != 0) ? (int) Math.log10(negativeMax) + 2 : 2; + totalLength = Math.max(size, negativeSize); + } + } + + /** {@inheritDoc} */ + @Override + public CharSequence format(Number value) { + if (exponential) { + return String.format(Locale.ENGLISH, "% ." + precision + "e", value.floatValue()); + } + return String.format(Locale.ENGLISH, "%" + totalLength + "d", value.longValue()); + } + } + + private static final class BooleanFormat extends NDFormat { + + /** {@inheritDoc} */ + @Override + public CharSequence format(Number value) { + return value.byteValue() != 0 ? 
" true" : "false"; + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDList.java new file mode 100644 index 000000000000..b467a3a95c20 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDList.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.ndarray.types.Shape; + +/** + * An {@code NDList} represents a sequence of {@link NDArray}s with names. + * + *

Each {@link NDArray} in this list can optionally have a name. You can use the name to look up + * an NDArray in the NDList. + * + * @see NDArray + */ +public class NDList extends ArrayList implements AutoCloseable { + private static final long serialVersionUID = 1L; + + /** Constructs an empty NDList. */ + public NDList() {} + + /** + * Constructs an empty NDList with the specified initial capacity. + * + * @param initialCapacity the initial capacity of the list + * @throws IllegalArgumentException if the specified initial capacity is negative + */ + public NDList(int initialCapacity) { + super(initialCapacity); + } + + /** + * Constructs and initiates an NDList with the specified {@link NDList}s. + * + * @param arrays the {@link NDList}s + */ + public NDList(NDArray... arrays) { + super(Arrays.asList(arrays)); + } + + /** + * Constructs and initiates an NDList with the specified {@link NDArray}s. + * + * @param other the {@link NDArray}s + */ + public NDList(Collection other) { + super(other); + } + + /** + * Decodes NDList from byte array. + * + * @param parent the parent {@link MxResource} to manage this instance + * @param byteArray byte array to load from + * @return {@code NDList} + */ + public static NDList decode(MxResource parent, byte[] byteArray) { + return decode(parent, new ByteArrayInputStream(byteArray)); + } + + /** + * Decodes NDList from {@link InputStream}. 
+ * + * @param parent {@link MxResource} assigned to {@link NDArray} + * @param is input stream contains the ndlist information + * @return {@code NDList} + */ + public static NDList decode(MxResource parent, InputStream is) { + try (DataInputStream dis = new DataInputStream(is)) { + int size = dis.readInt(); + if (size < 0) { + throw new IllegalArgumentException("Invalid NDList size: " + size); + } + NDList list = new NDList(); + for (int i = 0; i < size; i++) { + list.add(i, NDArray.decode(parent, dis)); + } + return list; + } catch (IOException e) { + throw new IllegalArgumentException("Malformed data", e); + } + } + + /** + * Removes the first occurrence of the specified element from this NDList if it is present. + * + *

If this list does not contain the element, it is unchanged. More formally, removes the + * element with the lowest index {@code i} such that {@code + * (o==null ? get(i)==null : o.equals(get(i)))} (if such an element exists). + * + * @param name the name of the NDArray to be removed from this NDList, if present + * @return the element that was removed + */ + public NDArray remove(String name) { + int index = 0; + for (NDArray array : this) { + if (name.equals(array.getName())) { + remove(index); + return array; + } + ++index; + } + return null; + } + + /** + * Returns {@code true} if this NDList contains an NDArray with the specified name. + * + * @param name the name of the NDArray to be removed from this NDList, if present + * @return {@code true} if this list contains the specified element + */ + public boolean contains(String name) { + for (NDArray array : this) { + if (name.equals(array.getName())) { + return true; + } + } + return false; + } + + /** + * Returns the head index of the NDList. + * + * @return the head NDArray + * @throws IndexOutOfBoundsException if the index is out of range ({@code index < 0 || index + * >= size()}) + */ + public NDArray head() { + return get(0); + } + + /** + * Returns the only element if this is a singleton NDList or throws an exception if multiple + * elements. + * + * @return the head NDArray + * @throws IndexOutOfBoundsException if the list does not contain exactly one element + */ + public NDArray singletonOrThrow() { + if (size() != 1) { + throw new IndexOutOfBoundsException( + "Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was " + + size()); + } + return get(0); + } + + /** + * Appends all of the NDArrays in the specified NDList to the end of this NDList, in the order + * that they are returned by the specified NDList's iterator. 
+ * + * @param other the NDList containing NDArray to be added to this list + * @return this NDList after the addition + */ + public NDList addAll(NDList other) { + for (NDArray array : other) { + add(array); + } + return this; + } + + /** + * Returns a view of the portion of this NDList between the specified fromIndex, inclusive, and + * to the end. + * + * @param fromIndex the start index (inclusive) + * @return a view of the portion of this NDList + */ + public NDList subNDList(int fromIndex) { + return new NDList(subList(fromIndex, size())); + } + + /** + * Converts all the {@code NDArray} in {@code NDList} to a different {@link Device}. + * + * @param device the {@link Device} to be set + * @param copy set {@code true} if you want to return a copy of the underlying NDArray + * @return a new {@code NDList} with the NDArrays on specified {@link Device} + */ + public NDList toDevice(Device device, boolean copy) { + if (!copy) { + // if all arrays in NDList are already on device, return itself + if (this.stream().allMatch(array -> array.getDevice() == device)) { + return this; + } + } + NDList newNDList = new NDList(size()); + forEach(a -> newNDList.add(a.toDevice(device, copy))); + return newNDList; + } + + /** + * Encodes the NDList to byte array. + * + * @return the byte array + */ + public byte[] encode() { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + DataOutputStream dos = new DataOutputStream(baos); + dos.writeInt(size()); + for (NDArray nd : this) { + dos.write(nd.encode()); + } + dos.flush(); + return baos.toByteArray(); + } catch (IOException e) { + throw new AssertionError("NDList is not writable", e); + } + } + + /** + * Gets all of shapes in the {@code NDList}. 
+ * + * @return shapes in {@code NDList} + */ + public Shape[] getShapes() { + return stream().map(NDArray::getShape).toArray(Shape[]::new); + } + + /** {@inheritDoc} */ + @Override + public void close() { + forEach(NDArray::close); + clear(); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + StringBuilder builder = new StringBuilder(200); + builder.append("NDList size: ").append(size()).append('\n'); + int index = 0; + for (NDArray array : this) { + String name = array.getName(); + builder.append(index++).append(' '); + if (name != null) { + builder.append(name); + } + builder.append(": ") + .append(array.getShape()) + .append(' ') + .append(array.getDataType()) + .append('\n'); + } + return builder.toString(); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDSerializer.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDSerializer.java new file mode 100644 index 000000000000..34d37a50c86b --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDSerializer.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** A utility class that contains the encoding and decoding logic for NDArray. */
+public final class NDSerializer {
+
+    // size, in bytes, of the chunk used when copying array data
+    static final int BUFFER_SIZE = 81920;
+    // marker string written at the start of every encoded NDArray
+    static final String MAGIC_NUMBER = "NDAR";
+    // version number written by encode
+    static final int VERSION = 2;
+
+    // static helpers only; not meant to be instantiated
+    private NDSerializer() {}
+
+    /**
+     * Allocates a new direct byte buffer that uses the platform's native byte order.
+     *
+     * @param capacity the new buffer's capacity, in bytes
+     * @return the new byte buffer
+     */
+    public static ByteBuffer allocateDirect(int capacity) {
+        return ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
+    }
+
+    /**
+     * Encodes {@link NDArray} to byte array.
+ * + * @param array the input {@link NDArray} + * @return byte array + */ + static byte[] encode(NDArray array) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + DataOutputStream dos = new DataOutputStream(baos); + // magic string for version identification + dos.writeUTF(MAGIC_NUMBER); + dos.writeInt(VERSION); + String name = array.getName(); + if (name == null) { + dos.write(0); + } else { + dos.write(1); + dos.writeUTF(name); + } + dos.writeUTF(array.getSparseFormat().name()); + dos.writeUTF(array.getDataType().name()); + + Shape shape = array.getShape(); + dos.write(shape.getEncoded()); + + ByteBuffer bb = array.toByteBuffer(); + int length = bb.remaining(); + dos.writeInt(length); + + if (length > 0) { + if (length > BUFFER_SIZE) { + byte[] buf = new byte[BUFFER_SIZE]; + while (length > BUFFER_SIZE) { + bb.get(buf); + dos.write(buf); + length = bb.remaining(); + } + } + + byte[] buf = new byte[length]; + bb.get(buf); + dos.write(buf); + } + dos.flush(); + return baos.toByteArray(); + } catch (IOException e) { + throw new AssertionError("This should never happen", e); + } + } + + /** + * Decodes {@link NDArray} through {@link DataInputStream}. 
+ * + * @param parent the parent MxResource object which create the returned object + * @param is input stream data to load from + * @return {@link NDArray} + * @throws IOException data is not readable + */ + public static NDArray decode(MxResource parent, InputStream is) throws IOException { + DataInputStream dis; + if (is instanceof DataInputStream) { + dis = (DataInputStream) is; + } else { + dis = new DataInputStream(is); + } + + if (!"NDAR".equals(dis.readUTF())) { + throw new IllegalArgumentException("Malformed NDArray data"); + } + + // NDArray encode version + int version = dis.readInt(); + if (version < 1 || version > VERSION) { + throw new IllegalArgumentException("Unexpected NDArray encode version " + version); + } + + String name = null; + if (version > 1) { + byte flag = dis.readByte(); + if (flag == 1) { + name = dis.readUTF(); + } + } + + dis.readUTF(); // ignore SparseFormat + + // DataType - 1 byte + DataType dataType = DataType.valueOf(dis.readUTF()); + + // Shape + Shape shape = Shape.decode(dis); + + // Data + int length = dis.readInt(); + ByteBuffer data = allocateDirect(length); + + if (length > 0) { + byte[] buf = new byte[BUFFER_SIZE]; + while (length > BUFFER_SIZE) { + dis.readFully(buf); + data.put(buf); + length -= BUFFER_SIZE; + } + + dis.readFully(buf, 0, length); + data.put(buf, 0, length); + data.rewind(); + } + NDArray array = NDArray.create(parent, dataType.asDataType(data), shape, dataType); + array.setName(name); + return array; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexAll.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexAll.java new file mode 100644 index 000000000000..8d348e486cc3 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexAll.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.dim;
+
+/** An {@code NDIndexElement} that selects all values in a particular dimension. */
+public class NDIndexAll implements NDIndexElement {
+
+    /** {@inheritDoc} */
+    @Override
+    public int getRank() {
+        // this element always occupies exactly one dimension
+        return 1;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexBooleans.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexBooleans.java
new file mode 100644
index 000000000000..1468414f5270
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexBooleans.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray.dim; + +import org.apache.mxnet.ndarray.NDArray; + +/** An {@code NDIndexElement} to return values based on a mask binary NDArray. */ +public class NDIndexBooleans implements NDIndexElement { + + private NDArray index; + + /** + * Constructs a {@code NDIndexBooleans} instance with specified mask binary NDArray. + * + * @param index the mask binary {@code NDArray} + */ + public NDIndexBooleans(NDArray index) { + this.index = index; + } + + /** + * Returns the mask binary {@code NDArray}. + * + * @return the mask binary {@code NDArray} + */ + public NDArray getIndex() { + return index; + } + + /** {@inheritDoc} */ + @Override + public int getRank() { + return index.getShape().dimension(); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexElement.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexElement.java new file mode 100644 index 000000000000..4f89fbb5732d --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexElement.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.dim;
+
+/** One element of an NDIndex, addressing one or more dimensions. */
+public interface NDIndexElement {
+
+    /**
+     * Returns the number of dimensions occupied by this index element.
+     *
+     * @return the number of dimensions occupied by this index element
+     */
+    int getRank();
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexFixed.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexFixed.java
new file mode 100644
index 000000000000..e0713ac5790c
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexFixed.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.mxnet.ndarray.dim; + +/** An NDIndexElement that returns only a specific value in the corresponding dimension. */ +public class NDIndexFixed implements NDIndexElement { + + private long index; + + /** + * Constructs a {@code NDIndexFixed} instance with specified dimension. + * + * @param index the dimension of the NDArray + */ + public NDIndexFixed(long index) { + this.index = index; + } + + /** + * Returns the dimension of the index. + * + * @return the dimension of the index + */ + public long getIndex() { + return index; + } + + /** {@inheritDoc} */ + @Override + public int getRank() { + return 1; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexPick.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexPick.java new file mode 100644 index 000000000000..f651f015144e --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexPick.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.ndarray.dim; + +import org.apache.mxnet.ndarray.NDArray; + +/** An {@link NDIndexElement} that gets elements by index in the specified axis. 
*/ +public class NDIndexPick implements NDIndexElement { + + private NDArray indices; + + /** + * Constructs a pick. + * + * @param indices the indices to pick + */ + public NDIndexPick(NDArray indices) { + this.indices = indices; + } + + @Override + /** {@inheritDoc} */ + public int getRank() { + return 1; + } + + /** + * Returns the indices to pick. + * + * @return the indices to pick + */ + public NDArray getIndices() { + return indices; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexSlice.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexSlice.java new file mode 100644 index 000000000000..e87784c6b384 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexSlice.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.ndarray.dim; + +/** An NDIndexElement that returns a range of values in the specified dimension. */ +public class NDIndexSlice implements NDIndexElement { + + private Long min; + private Long max; + private Long step; + + /** + * Constructs a {@code NDIndexSlice} instance with specified range and step. 
+ * + * @param min the start of the range + * @param max the end of the range + * @param step the step between each slice + * @throws IllegalArgumentException Thrown if the step is zero + */ + public NDIndexSlice(Long min, Long max, Long step) { + this.min = min; + this.max = max; + this.step = step; + if (step != null && step == 0) { + throw new IllegalArgumentException("The step can not be zero"); + } + } + + /** + * Returns the start of the range. + * + * @return the start of the range + */ + public Long getMin() { + return min; + } + + /** + * Returns the end of the range. + * + * @return the end of the range + */ + public Long getMax() { + return max; + } + + /** + * Returns the step between each slice. + * + * @return the step between each slice + */ + public Long getStep() { + return step; + } + + /** {@inheritDoc} */ + @Override + public int getRank() { + return 1; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullPick.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullPick.java new file mode 100644 index 000000000000..632badf4bfdb --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullPick.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.ndarray.dim.full; + +import java.util.Optional; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.dim.NDIndexAll; +import org.apache.mxnet.ndarray.dim.NDIndexElement; +import org.apache.mxnet.ndarray.dim.NDIndexPick; +import org.apache.mxnet.ndarray.index.NDIndex; +import org.apache.mxnet.ndarray.types.Shape; + +/** A simplified representation of a pick-based {@link NDArray}. */ +public final class NDIndexFullPick { + + private NDArray indices; + private int axis; + + /** + * Constructs a new {@link NDIndexFullPick}. + * + * @param indices the indices to pick + * @param axis the axis to pick at + */ + private NDIndexFullPick(NDArray indices, int axis) { + this.indices = indices; + this.axis = axis; + } + + /** + * Returns (if possible) the {@link NDIndexFullPick} representation of an {@link NDIndex}. + * + * @param index the index to represent + * @param target the shape of the array to index + * @return the full pick representation or nothing if it can't represent the index + */ + public static Optional fromIndex(NDIndex index, Shape target) { + int axis = 0; + NDIndexFullPick fullPick = null; + for (NDIndexElement el : index.getIndices()) { + if (el instanceof NDIndexAll) { + axis++; + } else if (el instanceof NDIndexPick) { + if (fullPick == null) { + fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndices(), axis); + } else { + // Don't support multiple picks + throw new UnsupportedOperationException( + "Only one pick per get is currently supported"); + } + } else { + // Invalid dim for fullPick + return Optional.empty(); + } + } + return Optional.ofNullable(fullPick); + } + + /** + * Returns the indices to pick. + * + * @return the indices to pick + */ + public NDArray getIndices() { + return indices; + } + + /** + * Returns the axis to pick. 
+     *
+     * @return the axis along which the indices are picked
+     */
+    public int getAxis() {
+        return axis;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullSlice.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullSlice.java
new file mode 100644
index 000000000000..9b8d9f0cf80a
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullSlice.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.ndarray.dim.full;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import org.apache.mxnet.ndarray.dim.NDIndexAll;
+import org.apache.mxnet.ndarray.dim.NDIndexElement;
+import org.apache.mxnet.ndarray.dim.NDIndexFixed;
+import org.apache.mxnet.ndarray.dim.NDIndexSlice;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** An index as a slice on all dimensions where some dimensions can be squeezed.
*/ +public final class NDIndexFullSlice { + private long[] min; + private long[] max; + private long[] step; + private int[] toSqueeze; + private Shape shape; + private Shape squeezedShape; + + /** + * Constructs a {@link NDIndexFullSlice}. + * + * @param min the min for each axis + * @param max the max for each axis + * @param step the step for each axis + * @param toSqueeze the axes to squeeze after slicing + * @param shape the result shape (without squeezing) + * @param squeezedShape the result shape (with squeezing) + */ + private NDIndexFullSlice( + long[] min, + long[] max, + long[] step, + int[] toSqueeze, + Shape shape, + Shape squeezedShape) { + this.min = min; + this.max = max; + this.step = step; + this.toSqueeze = toSqueeze; + this.shape = shape; + this.squeezedShape = squeezedShape; + } + + /** + * Returns (if possible) the {@link NDIndexFullSlice} representation of an {@link NDIndex}. + * + * @param index the index to represent + * @param target the shape of the array to index + * @return the full slice representation or nothing if it can't represent the index + */ + public static Optional fromIndex(NDIndex index, Shape target) { + if (!index.stream() + .allMatch( + ie -> + ie instanceof NDIndexAll + || ie instanceof NDIndexFixed + || ie instanceof NDIndexSlice)) { + return Optional.empty(); + } + int ellipsisIndex = index.getEllipsisIndex(); + int indDimensions = index.getRank(); + int targetDimensions = target.dimension(); + if (indDimensions > target.dimension()) { + throw new IllegalArgumentException( + "The index has too many dimensions - " + + indDimensions + + " dimensions for array with " + + targetDimensions + + " dimensions"); + } + long[] min = new long[targetDimensions]; + long[] max = new long[targetDimensions]; + long[] step = new long[targetDimensions]; + List toSqueeze = new ArrayList<>(targetDimensions); + long[] shape = new long[targetDimensions]; + List squeezedShape = new ArrayList<>(targetDimensions); + if (ellipsisIndex == -1 || 
ellipsisIndex == indDimensions) { + // ellipsis in the end and non ellipsis case + for (int i = 0; i < indDimensions; i++) { + NDIndexElement ie = index.get(i); + addSliceInfo(ie, i, target, min, max, step, toSqueeze, shape, squeezedShape); + } + for (int i = indDimensions; i < target.dimension(); i++) { + padIndexAll(i, target, min, max, step, shape, squeezedShape); + } + } else if (ellipsisIndex == 0) { + // ellipsis in the beginning + int paddingDim = targetDimensions - indDimensions; + int i; + for (i = 0; i < paddingDim; ++i) { + padIndexAll(i, target, min, max, step, shape, squeezedShape); + } + for (; i < targetDimensions; ++i) { + NDIndexElement ie = index.get(i - paddingDim); + addSliceInfo(ie, i, target, min, max, step, toSqueeze, shape, squeezedShape); + } + } else { + // ellipsis in the middle + int paddingDim = targetDimensions - indDimensions; + int i; + for (i = 0; i < ellipsisIndex; ++i) { + NDIndexElement ie = index.get(i); + addSliceInfo(ie, i, target, min, max, step, toSqueeze, shape, squeezedShape); + } + for (; i < paddingDim + ellipsisIndex; ++i) { + padIndexAll(i, target, min, max, step, shape, squeezedShape); + } + for (; i < targetDimensions; ++i) { + NDIndexElement ie = index.get(i - paddingDim); + addSliceInfo(ie, i, target, min, max, step, toSqueeze, shape, squeezedShape); + } + } + int[] squeeze = toSqueeze.stream().mapToInt(i -> i).toArray(); + NDIndexFullSlice fullSlice = + new NDIndexFullSlice( + min, max, step, squeeze, new Shape(shape), new Shape(squeezedShape)); + return Optional.of(fullSlice); + } + + private static void addSliceInfo( + NDIndexElement ie, + int i, + Shape target, + long[] min, + long[] max, + long[] step, + List toSqueeze, + long[] shape, + List squeezedShape) { + if (ie instanceof NDIndexFixed) { + NDIndexFixed fixed = ((NDIndexFixed) ie); + long rawIndex = fixed.getIndex(); + min[i] = rawIndex < 0 ? 
Math.floorMod(rawIndex, target.get(i)) : rawIndex; + max[i] = min[i] + 1; + step[i] = 1; + toSqueeze.add(i); + shape[i] = 1; + } else if (ie instanceof NDIndexSlice) { + NDIndexSlice slice = (NDIndexSlice) ie; + long rawMin = Optional.ofNullable(slice.getMin()).orElse(0L); + min[i] = rawMin < 0 ? Math.floorMod(rawMin, target.get(i)) : rawMin; + long rawMax = Optional.ofNullable(slice.getMax()).orElse(target.size(i)); + max[i] = rawMax < 0 ? Math.floorMod(rawMax, target.get(i)) : rawMax; + step[i] = Optional.ofNullable(slice.getStep()).orElse(1L); + shape[i] = (long) Math.ceil(((double) (max[i] - min[i])) / step[i]); + squeezedShape.add(shape[i]); + } else if (ie instanceof NDIndexAll) { + padIndexAll(i, target, min, max, step, shape, squeezedShape); + } + } + + private static void padIndexAll( + int i, + Shape target, + long[] min, + long[] max, + long[] step, + long[] shape, + List squeezedShape) { + min[i] = 0; + max[i] = target.size(i); + step[i] = 1; + shape[i] = target.size(i); + squeezedShape.add(target.size(i)); + } + + /** + * Returns the slice min for each axis. + * + * @return the slice min for each axis + */ + public long[] getMin() { + return min; + } + + /** + * Returns the slice max for each axis. + * + * @return the slice max for each axis + */ + public long[] getMax() { + return max; + } + + /** + * Returns the slice step for each axis. + * + * @return the slice step for each axis + */ + public long[] getStep() { + return step; + } + + /** + * Returns the squeeze array of axis. + * + * @return the squeeze array of axis + */ + public int[] getToSqueeze() { + return toSqueeze; + } + + /** + * Returns the slice shape without squeezing. + * + * @return the slice shape without squeezing + */ + public Shape getShape() { + return shape; + } + + /** + * Returns the slice shape with squeezing. 
+ * + * @return the slice shape with squeezing + */ + public Shape getSqueezedShape() { + return squeezedShape; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/package-info.java new file mode 100644 index 000000000000..e796f52f8d94 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. */ +package org.apache.mxnet.ndarray.dim.full; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/package-info.java new file mode 100644 index 000000000000..9a5c2f8e1c39 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. */ +package org.apache.mxnet.ndarray.dim; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/NDIndex.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/NDIndex.java new file mode 100644 index 000000000000..f08ee675df65 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/NDIndex.java @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mxnet.ndarray.index; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Stream; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.dim.NDIndexAll; +import org.apache.mxnet.ndarray.dim.NDIndexBooleans; +import org.apache.mxnet.ndarray.dim.NDIndexElement; +import org.apache.mxnet.ndarray.dim.NDIndexFixed; +import org.apache.mxnet.ndarray.dim.NDIndexPick; +import org.apache.mxnet.ndarray.dim.NDIndexSlice; +import org.apache.mxnet.ndarray.types.DataType; + +/** + * The {@code NDIndex} allows you to specify a subset of an NDArray that can be used for fetching or + * updating. + * + *

It accepts a different index option for each dimension, given in the order of the dimensions. + * Each dimension has options corresponding to: + * + *

    + *
  • Return all dimensions - Pass null to addIndices + *
  • A single value in the dimension - Pass the value to addIndices with a negative index -i + * corresponding to [dimensionLength - i] + *
  • A range of values - Use addSliceDim + *
+ * + *

We recommend creating the NDIndex using {@link #NDIndex(String, Object...)}. + * + * @see #NDIndex(String, Object...) + */ +public class NDIndex { + + /* Android regex requires escape } char as well */ + private static final Pattern ITEM_PATTERN = + Pattern.compile( + "(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})"); + + private int rank; + private List indices; + private int ellipsisIndex; + + /** Creates an empty {@link NDIndex} to append values to. */ + public NDIndex() { + rank = 0; + indices = new ArrayList<>(); + ellipsisIndex = -1; + } + + /** + * Creates a {@link NDIndex} given the index values. + * + *

Here are some examples of the indices format. + * + *

+     *     NDArray a = manager.ones(new Shape(5, 4, 3));
+     *
+     *     // Gets a subsection of the NDArray in the first axis.
+     *     assertEquals(a.get(new NDIndex("2")).getShape(), new Shape(4, 3));
+     *
+     *     // Gets a subsection of the NDArray indexing from the end (-i == length - i).
+     *     assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(4, 3));
+     *
+     *     // Gets everything in the first axis and a subsection in the second axis.
+     *     // You can use either : or * to represent everything
+     *     assertEquals(a.get(new NDIndex(":, 2")).getShape(), new Shape(5, 3));
+     *     assertEquals(a.get(new NDIndex("*, 2")).getShape(), new Shape(5, 3));
+     *
+     *     // Gets a range of values along the second axis that is inclusive on the bottom and exclusive on the top.
+     *     assertEquals(a.get(new NDIndex(":, 1:3")).getShape(), new Shape(5, 2, 3));
+     *
+     *     // Excludes either the min or the max of the range to go all the way to the beginning or end.
+     *     assertEquals(a.get(new NDIndex(":, :3")).getShape(), new Shape(5, 3, 3));
+     *     assertEquals(a.get(new NDIndex(":, 1:")).getShape(), new Shape(5, 4, 3));
+     *
+     *     // Uses the value after the second colon in a slicing range, the step, to get every other result.
+     *     assertEquals(a.get(new NDIndex(":, 1::2")).getShape(), new Shape(5, 2, 3));
+     *
+     *     // Uses a negative step to reverse along the dimension.
+     *     assertEquals(a.get(new NDIndex("::-1")).getShape(), new Shape(5, 4, 3));
+     *
+     *     // Uses a variable argument to the index
+     *     // It can replace any number in any of these formats with {} and then the value of {}
+     *     // is specified in an argument following the indices string.
+     *     assertEquals(a.get(new NDIndex("{}, {}:{}", 0, 1, 3)).getShape(), new Shape(2, 3));
+     *
+     *     // Uses ellipsis to insert many full slices
+     *     assertEquals(a.get(new NDIndex("...")).getShape(), new Shape(5, 4, 3));
+     *
+     *     // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection.
+     *     assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4));
+     * 
+ * + * @param indices a comma separated list of indices corresponding to either subsections, + * everything, or slices on a particular dimension + * @param args arguments to replace the variable "{}" in the indices string. Can be an integer, + * long, boolean {@link NDArray}, or integer {@link NDArray}. + * @see Numpy + * Indexing + */ + public NDIndex(String indices, Object... args) { + this(); + addIndices(indices, args); + } + + /** + * Creates an NDIndex with the given indices as specified values on the NDArray. + * + * @param indices the indices with each index corresponding to the dimensions and negative + * indices starting from the end + */ + public NDIndex(long... indices) { + this(); + addIndices(indices); + } + + /** + * Creates an {@link NDIndex} that just has one slice in the given axis. + * + * @param axis the axis to slice + * @param min the min of the slice + * @param max the max of the slice + * @return a new {@link NDIndex} with the given slice. + */ + public static NDIndex sliceAxis(int axis, long min, long max) { + NDIndex ind = new NDIndex(); + for (int i = 0; i < axis; i++) { + ind.addAllDim(); + } + ind.addSliceDim(min, max); + return ind; + } + + /** + * Returns the number of dimensions specified in the Index. + * + * @return the number of dimensions specified in the Index + */ + public int getRank() { + return rank; + } + + /** + * Returns the index of the ellipsis. + * + * @return the index of the ellipsis within this index or -1 for none. + */ + public int getEllipsisIndex() { + return ellipsisIndex; + } + + /** + * Returns the index affecting the given dimension. + * + * @param dimension the affected dimension + * @return the index affecting the given dimension + */ + public NDIndexElement get(int dimension) { + return indices.get(dimension); + } + + /** + * Returns the indices. + * + * @return the indices + */ + public List getIndices() { + return indices; + } + + /** + * Updates the NDIndex by appending indices to the array. 
+ * + * @param indices the indices to add similar to {@link #NDIndex(String, Object...)} + * @param args arguments to replace the variable "{}" in the indices string. Can be an integer, + * long, boolean {@link NDArray}, or integer {@link NDArray}. + * @return the updated {@link NDIndex} + * @see #NDIndex(String, Object...) + */ + public final NDIndex addIndices(String indices, Object... args) { + String[] indexItems = indices.split(","); + rank += indexItems.length; + int argIndex = 0; + for (int i = 0; i < indexItems.length; ++i) { + if ("...".equals(indexItems[i].trim())) { + // make sure ellipsis appear only once + if (ellipsisIndex != -1) { + throw new IllegalArgumentException( + "an index can only have a single ellipsis (\"...\")"); + } + ellipsisIndex = i; + } else { + argIndex = addIndexItem(indexItems[i], argIndex, args); + } + } + if (ellipsisIndex != -1) { + rank--; + } + if (argIndex != args.length) { + throw new IllegalArgumentException("Incorrect number of index arguments"); + } + return this; + } + + /** + * Updates the NDIndex by appending indices as specified values on the NDArray. + * + * @param indices with each index corresponding to the dimensions and negative indices starting + * from the end + * @return the updated {@link NDIndex} + */ + public final NDIndex addIndices(long... indices) { + rank += indices.length; + for (long i : indices) { + this.indices.add(new NDIndexFixed(i)); + } + return this; + } + + /** + * Updates the NDIndex by appending a boolean NDArray. + * + *

The NDArray should have a matching shape to the dimensions being fetched and will return + * where the values in NDIndex do not equal zero. + * + * @param index a boolean NDArray where all nonzero elements correspond to elements to return + * @return the updated {@link NDIndex} + */ + public NDIndex addBooleanIndex(NDArray index) { + rank += index.getShape().dimension(); + indices.add(new NDIndexBooleans(index)); + return this; + } + + /** + * Appends a new index to get all values in the dimension. + * + * @return the updated {@link NDIndex} + */ + public NDIndex addAllDim() { + rank++; + indices.add(new NDIndexAll()); + return this; + } + + /** + * Appends multiple new index to get all values in the dimension. + * + * @param count how many axes of {@link NDIndexAll} to add. + * @return the updated {@link NDIndex} + * @throws IllegalArgumentException if count is negative + */ + public NDIndex addAllDim(int count) { + if (count < 0) { + throw new IllegalArgumentException( + "The number of index dimensions to add can't be negative"); + } + rank += count; + for (int i = 0; i < count; i++) { + indices.add(new NDIndexAll()); + } + return this; + } + + /** + * Appends a new index to slice the dimension and returns a range of values. + * + * @param min the minimum of the range + * @param max the maximum of the range + * @return the updated {@link NDIndex} + */ + public NDIndex addSliceDim(long min, long max) { + rank++; + indices.add(new NDIndexSlice(min, max, null)); + return this; + } + + /** + * Appends a new index to slice the dimension and returns a range of values. + * + * @param min the minimum of the range + * @param max the maximum of the range + * @param step the step of the slice + * @return the updated {@link NDIndex} + */ + public NDIndex addSliceDim(long min, long max, long step) { + rank++; + indices.add(new NDIndexSlice(min, max, step)); + return this; + } + + /** + * Appends a picking index that gets values by index in the axis. 
+ * + * @param index the indices should be NDArray. For each element in the indices array, it acts + * like a fixed index returning an element of that shape. So, the final shape would be + * indices.getShape().addAll(target.getShape().slice(1)) (assuming it is the first index + * element). + * @return the updated {@link NDIndex} + */ + public NDIndex addPickDim(NDArray index) { + rank++; + indices.add(new NDIndexPick(index)); + return this; + } + + /** + * Returns a stream of the NDIndexElements. + * + * @return a stream of the NDIndexElements + */ + public Stream stream() { + return indices.stream(); + } + + private int addIndexItem(String indexItem, int argIndex, Object[] args) { + indexItem = indexItem.trim(); + Matcher m = ITEM_PATTERN.matcher(indexItem); + if (!m.matches()) { + throw new IllegalArgumentException("Invalid argument index: " + indexItem); + } + // "*" case + String star = m.group(1); + if (star != null) { + indices.add(new NDIndexAll()); + return argIndex; + } + // "number" number only case + String digit = m.group(7); + if (digit != null) { + if ("{}".equals(digit)) { + Object arg = args[argIndex]; + if (arg instanceof Integer) { + indices.add(new NDIndexFixed((Integer) arg)); + return argIndex + 1; + } else if (arg instanceof Long) { + indices.add(new NDIndexFixed((Long) arg)); + return argIndex + 1; + } else if (arg instanceof NDArray) { + NDArray array = (NDArray) arg; + if (array.getDataType() == DataType.BOOLEAN) { + indices.add(new NDIndexBooleans(array)); + return argIndex + 1; + } else if (array.getDataType().isInteger()) { + indices.add(new NDIndexPick(array)); + return argIndex + 1; + } + } + throw new IllegalArgumentException("Unknown argument: " + arg); + } else { + indices.add(new NDIndexFixed(Long.parseLong(digit))); + return argIndex; + } + } + + // Slice + Long min = null; + Long max = null; + Long step = null; + if (m.group(3) != null) { + min = parseSliceItem(m.group(3), argIndex, args); + if ("{}".equals(m.group(3))) { + 
argIndex++; + } + } + if (m.group(4) != null) { + max = parseSliceItem(m.group(4), argIndex, args); + if ("{}".equals(m.group(4))) { + argIndex++; + } + } + if (m.group(6) != null) { + step = parseSliceItem(m.group(6), argIndex, args); + if ("{}".equals(m.group(6))) { + argIndex++; + } + } + if (min == null && max == null && step == null) { + indices.add(new NDIndexAll()); + } else { + indices.add(new NDIndexSlice(min, max, step)); + } + return argIndex; + } + + private Long parseSliceItem(String sliceItem, int argIndex, Object... args) { + if ("{}".equals(sliceItem)) { + Object arg = args[argIndex]; + if (arg instanceof Integer) { + return ((Integer) arg).longValue(); + } else if (arg instanceof Long) { + return (Long) arg; + } + throw new IllegalArgumentException("Unknown slice argument: " + arg); + } else { + return Long.parseLong(sliceItem); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/package-info.java new file mode 100644 index 000000000000..7cc862c4827c --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. */ +package org.apache.mxnet.ndarray.index; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/package-info.java new file mode 100644 index 000000000000..b161896895f7 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. */ +package org.apache.mxnet.ndarray; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/DataType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/DataType.java new file mode 100644 index 000000000000..bda4181f9780 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/DataType.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray.types; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import org.apache.mxnet.ndarray.NDArray; + +/** An enum representing the underlying {@link NDArray}'s data type. */ +public enum DataType { + FLOAT32(Format.FLOATING, 4), + FLOAT64(Format.FLOATING, 8), + FLOAT16(Format.FLOATING, 2), + UINT8(Format.UINT, 1), + INT32(Format.INT, 4), + INT8(Format.INT, 1), + INT64(Format.INT, 8), + BOOLEAN(Format.BOOLEAN, 1), + UNKNOWN(Format.UNKNOWN, 0); + /** The general data type format categories. */ + public enum Format { + FLOATING, + UINT, + INT, + BOOLEAN, + UNKNOWN + } + + private Format format; + private int numOfBytes; + + DataType(Format format, int numOfBytes) { + this.format = format; + this.numOfBytes = numOfBytes; + } + + /** + * Returns the number of bytes for each element. + * + * @return the number of bytes for each element + */ + public int getNumOfBytes() { + return numOfBytes; + } + + /** + * Returns the format of the data type. + * + * @return the format of the data type + */ + public Format getFormat() { + return format; + } + + /** + * Checks whether it is a floating data type. 
+ * + * @return whether it is a floating data type + */ + public boolean isFloating() { + return format == Format.FLOATING; + } + + /** + * Checks whether it is an integer data type. + * + * @return whether it is an integer type + */ + public boolean isInteger() { + return format == Format.UINT || format == Format.INT; + } + + /** + * Returns the data type to use for a data buffer. + * + * @param data the buffer to analyze + * @return the data type for the buffer + */ + public static DataType fromBuffer(Buffer data) { + if (data instanceof FloatBuffer) { + return DataType.FLOAT32; + } else if (data instanceof DoubleBuffer) { + return DataType.FLOAT64; + } else if (data instanceof IntBuffer) { + return DataType.INT32; + } else if (data instanceof LongBuffer) { + return DataType.INT64; + } else if (data instanceof ByteBuffer) { + return DataType.INT8; + } else { + throw new IllegalArgumentException( + "Unsupported buffer type: " + data.getClass().getSimpleName()); + } + } + + /** + * Converts a {@link ByteBuffer} to a buffer for this data type. 
+ * + * @param data the buffer to convert + * @return the converted buffer + */ + public Buffer asDataType(ByteBuffer data) { + switch (this) { + case FLOAT32: + return data.asFloatBuffer(); + case FLOAT64: + return data.asDoubleBuffer(); + case INT32: + return data.asIntBuffer(); + case INT64: + return data.asLongBuffer(); + case UINT8: + case INT8: + case FLOAT16: + case UNKNOWN: + default: + return data; + } + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return name().toLowerCase(); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/LayoutType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/LayoutType.java new file mode 100644 index 000000000000..9602146bc330 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/LayoutType.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray.types; + +import java.util.stream.IntStream; +import org.apache.mxnet.ndarray.NDArray; + +/** + * An enum to represent the meaning of a particular axis in an {@link NDArray}. + * + *

The options are: + * + *

    + *
  • {@link LayoutType#BATCH} - Different elements in a batch + *
  • {@link LayoutType#CHANNEL} - Each channel represents a different aspect of the data such as + * RGB showing different color channels. + *
  • {@link LayoutType#DEPTH} - The depth of a 3-D input + *
  • {@link LayoutType#HEIGHT} - The height of a multi-dimensional input, usually an image. + *
  • {@link LayoutType#WIDTH} - The width of a multi-dimensional input, usually an image. + *
  • {@link LayoutType#TIME} - The time within a sequence such as text or video. + *
  • {@link LayoutType#UNKNOWN} - An unknown or otherwise unrepresentable layout type. + *
+ */ +public enum LayoutType { + BATCH('N'), + CHANNEL('C'), + DEPTH('D'), + HEIGHT('H'), + WIDTH('W'), + TIME('T'), + UNKNOWN('?'); + + private char value; + + LayoutType(char value) { + this.value = value; + } + + /** + * Returns the character representation of the layout type. + * + * @return the character representation of the layout type + */ + public char getValue() { + return value; + } + + /** + * Converts the character to the matching layout type. + * + * @param value the character to convert + * @return the matching layout type + * @throws IllegalArgumentException thrown if the character does not match any layout type + */ + public static LayoutType fromValue(char value) { + for (LayoutType type : LayoutType.values()) { + if (value == type.value) { + return type; + } + } + throw new IllegalArgumentException( + "The value does not match any layoutTypes. Use '?' for Unknown"); + } + + /** + * Converts each character to the matching layout type. + * + * @param layout the character string to convert + * @return the list of layout types for each character in the string + * @throws IllegalArgumentException thrown if the character does not match any layout type + */ + public static LayoutType[] fromValue(String layout) { + return IntStream.range(0, layout.length()) + .mapToObj(i -> fromValue(layout.charAt(i))) + .toArray(LayoutType[]::new); + } + + /** + * Converts a layout type array to a string of the character representations. 
+ * + * @param layouts the layout type to convert + * @return the string of the character representations + */ + public static String toString(LayoutType[] layouts) { + StringBuilder sb = new StringBuilder(layouts.length); + for (LayoutType layout : layouts) { + sb.append(layout.getValue()); + } + return sb.toString(); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/Shape.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/Shape.java new file mode 100644 index 000000000000..4348e45c4578 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/Shape.java @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.ndarray.types; + +import java.io.DataInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import java.util.stream.Stream; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.util.Pair; +import org.apache.mxnet.util.PairList; + +/** A class that presents the {@link NDArray}'s shape information. */ +public class Shape { + + private long[] shape; + private LayoutType[] layout; + + /** + * Constructs and initializes a {@code Shape} with specified dimension as {@code (long... + * shape)}. + * + * @param shape the dimensions of the shape + * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be + * less than -1. Also thrown if the shape and layout do not have equal sizes. + */ + public Shape(long... shape) { + this( + shape, + Arrays.stream(shape).mapToObj(x -> LayoutType.UNKNOWN).toArray(LayoutType[]::new)); + } + + /** + * Constructs and initializes a {@code Shape} with specified dimension. + * + * @param shape the dimensions of the shape + * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be + * less than -1. Also thrown if the shape and layout do not have equal sizes. + */ + public Shape(List shape) { + this( + shape.stream().mapToLong(l -> l).toArray(), + shape.stream().map(x -> LayoutType.UNKNOWN).toArray(LayoutType[]::new)); + } + + /** + * Constructs and initializes a {@code Shape} with specified shape and layout pairList. + * + * @param shape the dimensions and layout of the shape + * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be + * less than -1 .Also thrown if the shape and layout do not have equal sizes. 
+ */ + public Shape(PairList shape) { + this( + shape.keys().stream().mapToLong(l -> l).toArray(), + shape.values().toArray(new LayoutType[shape.size()])); + } + + /** + * Constructs and initializes a {@code Shape} with specified dimension and layout. + * + * @param shape the size of each axis of the shape + * @param layout the {@link LayoutType} of each axis in the shape + * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be + * less than -1. Also thrown for an invalid layout. Also thrown if the shape and layout do + * not have equal sizes. + */ + public Shape(long[] shape, String layout) { + this(shape, LayoutType.fromValue(layout)); + } + + /** + * Constructs and initializes a {@code Shape} with specified dimension and layout. + * + * @param shape the size of each axis of the shape + * @param layout the {@link LayoutType} of each axis in the shape + * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be + * less than -1. Also thrown if the shape and layout do not have equal sizes. + */ + public Shape(long[] shape, LayoutType[] layout) { + if (Arrays.stream(shape).anyMatch(s -> s < -1)) { + throw new IllegalArgumentException("The shape must be >= -1"); + } + if (shape.length != layout.length) { + throw new IllegalArgumentException("The shape and layout must have the same length"); + } + this.shape = shape; + this.layout = layout; + } + + /** + * Returns a new shape altering the given dimension. + * + * @param shape the shape to update + * @param dimension the dimension to get the shape in + * @param value the value to set the dimension to + * @return a new shape with the update applied + */ + public static Shape update(Shape shape, int dimension, long value) { + long[] newShape = shape.shape.clone(); + newShape[dimension] = value; + return new Shape(newShape, shape.layout); + } + + /** + * Returns the dimensions of the {@code Shape}. 
+ * + * @return the dimensions of the {@code Shape} + */ + public long[] getShape() { + return shape; + } + + /** + * Returns the shape in the given dimension. + * + * @param dimension the dimension to get the shape in + * @return the shape in the given dimension + */ + public long get(int dimension) { + return shape[dimension]; + } + + /** + * Returns the layout type in the given dimension. + * + * @param dimension the dimension to get the layout type in + * @return the layout type in the given dimension + */ + public LayoutType getLayoutType(int dimension) { + return layout[dimension]; + } + + /** + * Returns the size of a specific dimension or several specific dimensions. + * + * @param dimensions the dimension or dimensions to find the size of + * @return the size of specific dimension(s) or -1 for indeterminate size + * @throws IllegalArgumentException thrown if passed an invalid dimension + */ + public long size(int... dimensions) { + long total = 1; + for (long d : dimensions) { + if (d < 0 || d >= shape.length) { + throw new IllegalArgumentException("Invalid dimension " + d); + } + if (shape[Math.toIntExact(d)] == -1) { + return -1; + } + total *= shape[Math.toIntExact(d)]; + } + return total; + } + + /** + * Returns the total size. + * + * @return the total size or -1 for indeterminate size + */ + public long size() { + long total = 1; + for (long v : shape) { + if (v == -1) { + return -1; + } + total *= v; + } + return total; + } + + /** + * Returns the number of dimensions of this {@code Shape}. + * + * @return the number of dimensions of this {@code Shape} + */ + public int dimension() { + return shape.length; + } + + /** + * Return the count of unknown value in this {@code Shape}. + * + * @return the number of unknown value in this {@code Shape} + */ + public long getUnknownValueCount() { + return Arrays.stream(shape).filter(s -> s == -1).count(); + } + + /** + * Creates a new {@code Shape} whose content is a slice of this shape. + * + *

The sub shape begins at the specified {@code beginIndex} and extends to {@code endIndex - + * 1}. + * + * @param beginIndex the beginning index, inclusive + * @return a new {@code Shape} whose content is a slice of this shape + */ + public Shape slice(int beginIndex) { + return slice(beginIndex, shape.length); + } + + /** + * Creates a new {@code Shape} whose content is a slice of this shape. + * + *

The sub shape begins at the specified {@code beginIndex} and extends to {@code endIndex - + * 1}. + * + * @param beginIndex the beginning index, inclusive + * @param endIndex the ending index, exclusive + * @return a new {@code Shape} whose content is a slice of this shape + */ + public Shape slice(int beginIndex, int endIndex) { + int size = endIndex - beginIndex; + long[] out = new long[size]; + System.arraycopy(shape, beginIndex, out, 0, size); + return new Shape(out); + } + + /** + * Returns only the axes of the Shape whose layout types match the predicate. + * + * @param predicate the predicate to compare the axes of the Shape with + * @return a new filtered Shape + */ + public Shape filterByLayoutType(Predicate predicate) { + return new Shape( + new PairList<>( + this.stream() + .filter(pair -> predicate.test(pair.getValue())) + .collect(Collectors.toList()))); + } + + /** + * Returns a mapped shape. + * + * @param mapper the function to map each element of the Shape by + * @return a new mapped Shape + */ + public Shape map(Function, Pair> mapper) { + return new Shape(new PairList<>(stream().map(mapper).collect(Collectors.toList()))); + } + + /** + * Returns a stream of the Shape. + * + * @return the stream of the Shape + */ + public Stream> stream() { + return new PairList<>( + Arrays.stream(shape).boxed().collect(Collectors.toList()), + Arrays.asList(layout)) + .stream(); + } + + /** + * Joins this shape with axes. + * + * @param axes the axes to join + * @return the joined {@code Shape} + */ + public Shape add(long... axes) { + return this.addAll(new Shape(axes)); + } + + /** + * Joins this shape with specified {@code other} shape. + * + * @param other the shape to join + * @return the joined {@code Shape} + */ + public Shape addAll(Shape other) { + return new Shape( + LongStream.concat(Arrays.stream(shape), Arrays.stream(other.shape)).toArray()); + } + + /** + * Returns the head index of the shape. 
+ * + * @return the head index of the shape + * @throws IndexOutOfBoundsException Thrown if the shape is empty + */ + public long head() { + // scalar case + if (shape.length == 0) { + throw new IndexOutOfBoundsException("can't get value from scalar shape."); + } + return shape[0]; + } + + /** + * Returns the tail index of the shape. + * + * @return the tail index of the shape + * @throws IndexOutOfBoundsException Thrown if the shape is empty + */ + public long tail() { + // scalar case + if (shape.length == 0) { + throw new IndexOutOfBoundsException("can't get value from scalar shape."); + } + return shape[shape.length - 1]; + } + + /** + * Returns the number of trailing ones in the array shape. + * + *

For example, a rank 3 array with shape [10, 1, 1] would return 2 for this method
+     *
+     * @return the number of trailing ones in the shape (the rank when every dimension is 1)
+     */
+    public int getTrailingOnes() {
+        for (int i = 0; i < shape.length; i++) {
+            if (shape[shape.length - i - 1] != 1) {
+                return i;
+            }
+        }
+        return shape.length; // no dimension differs from 1, so every axis is a trailing one
+    }
+
+    /**
+     * Returns the number of leading ones in the array shape.
+     *
+     *

For example, a rank 3 array with shape [1, 10, 1] would return value 1 for this method
+     *
+     * @return the number of leading ones in the shape (the rank when every dimension is 1)
+     */
+    public int getLeadingOnes() {
+        for (int i = 0; i < shape.length; i++) {
+            if (shape[i] != 1) {
+                return i;
+            }
+        }
+        return shape.length; // no dimension differs from 1, so every axis is a leading one
+    }
+
+    /**
+     * Returns {@code true} if the NDArray is a scalar.
+     *
+     * @return whether the NDArray is a scalar
+     */
+    public boolean isScalar() {
+        return dimension() == 0;
+    }
+
+    /**
+     * Returns {@code true} if the NDArray contains zero dimensions.
+     *
+     * @return whether the NDArray contains zero dimensions
+     */
+    public boolean hasZeroDimension() {
+        for (int i = 0; i < dimension(); i++) {
+            if (shape[i] == 0) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Returns {@code true} if a layout is set.
+     *
+     * @return whether a layout has been set
+     */
+    public boolean isLayoutKnown() {
+        return !Arrays.stream(layout).allMatch(l -> l == LayoutType.UNKNOWN);
+    }
+
+    /**
+     * Returns the layout type for each axis in this shape.
+     *
+     * @return the layout type for each axis in this shape
+     */
+    public LayoutType[] getLayout() {
+        return layout;
+    }
+
+    /**
+     * Returns the string layout type for each axis in this shape.
+     *
+     * @return the string layout type for each axis in this shape
+     */
+    public String toLayoutString() {
+        return LayoutType.toString(layout);
+    }
+
+    /**
+     * Gets the byte array representation of this {@code Shape} for serialization.
+ * + * @return a byte array representation of this {@code Shape} + */ + public byte[] getEncoded() { + int length = 8 + shape.length * 8 + layout.length * 2; + ByteBuffer bb = ByteBuffer.allocate(length); + bb.putInt(shape.length); + for (long l : shape) { + bb.putLong(l); + } + bb.putInt(layout.length); + for (LayoutType layoutType : layout) { + bb.putChar(layoutType.getValue()); + } + return bb.array(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Shape shape1 = (Shape) o; + return Arrays.equals(shape, shape1.shape); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(shape); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append('('); + for (int i = 0; i < shape.length; ++i) { + if (i > 0) { + sb.append(", "); + } + sb.append(shape[i]); + } + sb.append(')'); + return sb.toString(); + } + + /** + * Decodes the data in the given {@link DataInputStream} and converts it into the corresponding + * {@link Shape} object. 
+ * + * @param dis the inputstream to read from + * @return the corresponding {@link Shape} object + * @throws IOException when an I/O error occurs + */ + public static Shape decode(DataInputStream dis) throws IOException { + // Shape + int length = dis.readInt(); + long[] shapeValue = new long[length]; + for (int i = 0; i < length; ++i) { + shapeValue[i] = dis.readLong(); + } + + // Layout + length = dis.readInt(); + char[] layout = new char[length]; + for (int i = 0; i < length; ++i) { + layout[i] = dis.readChar(); + } + return new Shape(shapeValue, new String(layout)); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/SparseFormat.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/SparseFormat.java new file mode 100644 index 000000000000..7c3d389b4f1b --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/SparseFormat.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.ndarray.types; + +/** + * An enum representing Sparse matrix storage formats. + * + *

    + *
  • DENSE: Stride format + *
  • ROW_SPARSE: Row Sparse + *
  • CSR: Compressed Sparse Row + *
+ * + * @see Sparse Matrix Storage Formats + */ +public enum SparseFormat { + // the dense format is accelerated by MKLDNN by default + DENSE("default", 0), + ROW_SPARSE("row_sparse", 1), + CSR("csr", 2); + + private String type; + private int value; + + SparseFormat(String type, int value) { + this.type = type; + this.value = value; + } + + /** + * Gets the {@code SparseFormat} from it's integer value. + * + * @param value the integer value of the {@code SparseFormat} + * @return a {@code SparseFormat} + */ + public static SparseFormat fromValue(int value) { + for (SparseFormat t : values()) { + if (value == t.getValue()) { + return t; + } + } + throw new IllegalArgumentException("Unknown Sparse type: " + value); + } + + /** + * Returns the {@code SparseFormat} name. + * + * @return the {@code SparseFormat} name + */ + public String getType() { + return type; + } + + /** + * Returns the integer value of this {@code SparseFormat}. + * + * @return the integer value of this {@code SparseFormat} + */ + public int getValue() { + return value; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/package-info.java new file mode 100644 index 000000000000..c58f71d5299f --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. */ +package org.apache.mxnet.ndarray.types; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/Parameter.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/Parameter.java new file mode 100644 index 000000000000..9a268a3ed668 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/Parameter.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.nn; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Objects; +import java.util.UUID; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.exception.MalformedModelException; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDSerializer; +import org.apache.mxnet.ndarray.types.DataType; +import org.apache.mxnet.ndarray.types.Shape; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@code Parameter} is a container class that holds a learnable parameter of a model. + * + *

Every {@code Parameter} is associated with a {@link SymbolBlock}. The output of the block's + * forward function depends on the values in the {@code Parameter}. During training, the values in + * the {@code Parameter} are updated to reflect the training data. This process forms the crux of + * learning. + * + * @see The D2L + * chapter on parameter management + */ +public class Parameter extends MxResource { + private static final Logger logger = LoggerFactory.getLogger(Parameter.class); + + private static final byte VERSION = 1; + + private String id; + private String name; + private Shape shape; + private Type type; + private NDArray array; + + Parameter(Builder builder) { + this.id = UUID.randomUUID().toString(); + this.name = builder.name; + this.shape = builder.shape; + this.type = builder.type; + this.array = builder.array; + } + + /** + * Gets the ID of this {@code Parameter}. + * + * @return the ID of this {@code Parameter} + */ + public String getId() { + return id; + } + + /** + * Gets the name of this {@code Parameter}. + * + * @return the name of this {@code Parameter} + */ + public String getName() { + return name == null ? "" : name; + } + + /** + * Gets the type of this {@code Parameter}. + * + * @return the type of this {@code Parameter} + */ + public Type getType() { + return type; + } + + /** + * Sets the values of this {@code Parameter}. + * + * @param array the {@link NDArray} that contains values of this {@code Parameter} + */ + public void setArray(NDArray array) { + if (shape != null) { + throw new IllegalStateException("array has been set! Use either setArray or setShape"); + } + this.array = array; + shape = array.getShape(); + array.setName(name); + } + + /** + * Sets the shape of this {@code Parameter}. + * + * @param shape the shape of this {@code Parameter} + */ + public void setShape(Shape shape) { + if (array != null) { + throw new IllegalStateException("array has been set! 
Use either setArray or setShape"); + } + this.shape = shape; + } + + /** + * Gets the values of this {@code Parameter} as an {@link NDArray}. + * + * @return an {@link NDArray} that contains values of this {@code Parameter} + */ + public NDArray getArray() { + if (!isInitialized()) { + throw new IllegalStateException("The array has not been initialized"); + } + return array; + } + + /** + * Checks if this {@code Parameter} is initialized. + * + * @return {@code true} if this {@code Parameter} is initialized + */ + public boolean isInitialized() { + return array != null; + } + + /** + * Initializes the parameter, with given {@link DataType} for the given expected input shapes. + * + * @param parent the parent {@link MxResource} to manage this instance + * @param dataType the datatype of the {@code Parameter} + * @param device the device of {@link NDArray} in the {@code Parameter} + */ + public void initialize(MxResource parent, DataType dataType, Device device) { + Objects.requireNonNull(shape, "No parameter shape has been set"); + } + + /** + * Writes the parameter NDArrays to the given output stream. + * + * @param dos the output stream to write to + * @throws IOException if the write operation fails + */ + public void save(DataOutputStream dos) throws IOException { + if (!isInitialized()) { + dos.writeChar('N'); + return; + } + + dos.writeChar('P'); + dos.writeByte(VERSION); + dos.writeUTF(getName()); + dos.write(array.encode()); + } + + /** + * Loads parameter NDArrays from InputStream. + * + *

Currently, we cannot deserialize into the exact subclass of NDArray. The SparseNDArray
+     * will be loaded as NDArray only.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param dis the InputStream
+     * @throws IOException if failed to read (parameters).
+     */
+    public void load(MxResource parent, DataInputStream dis) throws IOException {
+        char magic = dis.readChar();
+        if (magic == 'N') {
+            return;
+        } else if (magic != 'P') {
+            throw new MalformedModelException("Invalid input data.");
+        }
+
+        // Version
+        byte version = dis.readByte();
+        if (version != VERSION) {
+            throw new MalformedModelException("Unsupported encoding version: " + version);
+        }
+
+        String parameterName = dis.readUTF();
+        if (!parameterName.equals(getName())) {
+            throw new MalformedModelException(
+                    "Unexpected parameter name: " + parameterName + ", expected: " + name);
+        }
+
+        array = NDSerializer.decode(parent, dis);
+        // set the shape of the parameter and prepare() can be skipped
+        shape = array.getShape();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug(String.format("Start to free Parameter instance: %S", this.getUid()));
+            super.freeSubResources();
+            if (array != null) {
+                array.close();
+                array = null; // drop the reference so isInitialized() reports false afterwards
+            }
+            setClosed(true);
+            logger.debug(String.format("Finish to free Parameter instance: %S", this.getUid()));
+        }
+    }
+
+    /**
+     * Creates a builder to build a {@code Parameter}.
+     *
+     *

The methods start with {@code set} are required fields, and {@code opt} for optional + * fields. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** Enumerates the types of {@link Parameter}. */ + public enum Type { + WEIGHT, + BIAS, + GAMMA, + BETA, + RUNNING_MEAN, + RUNNING_VAR, + OTHER; + } + + /** A Builder to construct a {@code Parameter}. */ + public static final class Builder { + String name; + Shape shape; + Type type; + NDArray array; + + /** + * Sets the name of the {@code Parameter}. + * + * @param name the name of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder setName(String name) { + this.name = name; + return this; + } + + /** + * Sets the {@code Type} of the {@code Parameter}. + * + * @param type the {@code Type} of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder setType(Type type) { + this.type = type; + return this; + } + + /** + * Sets the shape of the {@code Parameter}. + * + * @param shape the shape of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optShape(Shape shape) { + this.shape = shape; + return this; + } + + /** + * Sets the array of the {@code Parameter}. + * + * @param array the array of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optArray(NDArray array) { + this.array = array; + return this; + } + + /** + * Builds a {@code Parameter} instance. 
+ * + * @return the {@code Parameter} instance + */ + public Parameter build() { + return new Parameter(this); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/ParameterList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/ParameterList.java new file mode 100644 index 000000000000..575972a88d62 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/ParameterList.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.nn; + +import java.util.List; +import java.util.Map; +import org.apache.mxnet.util.Pair; +import org.apache.mxnet.util.PairList; + +/** Represents a set of names and Parameters. */ +public class ParameterList extends PairList { + + /** Create an empty {@code ParameterList}. */ + public ParameterList() {} + + /** + * Constructs an empty {@code ParameterList} with the specified initial capacity. 
+ * + * @param initialCapacity the initial capacity of the list + * @throws IllegalArgumentException if the specified initial capacity is negative + */ + public ParameterList(int initialCapacity) { + super(initialCapacity); + } + + /** + * Constructs a {@code ParameterList} containing the elements of the specified keys and values. + * + * @param keys the key list containing the elements to be placed into this {@code ParameterList} + * @param values the value list containing the elements to be placed into this {@code + * ParameterList} + * @throws IllegalArgumentException if the keys and values size are different + */ + public ParameterList(List keys, List values) { + super(keys, values); + } + + /** + * Constructs a {@code ParameterList} containing the elements of the specified list of Pairs. + * + * @param list the list containing the elements to be placed into this {@code ParameterList} + */ + public ParameterList(List> list) { + super(list); + } + + /** + * Constructs a {@code ParameterList} containing the elements of the specified map. + * + * @param map the map containing keys and values + */ + public ParameterList(Map map) { + super(map); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/SymbolBlock.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/SymbolBlock.java new file mode 100644 index 000000000000..cb52f38090d4 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/SymbolBlock.java @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.nn; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.mxnet.engine.CachedOp; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.engine.MxResource; +import org.apache.mxnet.engine.MxResourceList; +import org.apache.mxnet.engine.Symbol; +import org.apache.mxnet.exception.MalformedModelException; +import org.apache.mxnet.jna.JnaUtils; +import org.apache.mxnet.ndarray.NDArray; +import org.apache.mxnet.ndarray.NDList; +import org.apache.mxnet.ndarray.types.DataType; +import org.apache.mxnet.ndarray.types.Shape; +import org.apache.mxnet.util.Pair; +import org.apache.mxnet.util.PairList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@code SymbolBlock} is a {@link MxResource}. It is used to load models that were exported + * directly from the engine in its native format. + */ +public class SymbolBlock extends MxResource { + + private static final Logger logger = LoggerFactory.getLogger(SymbolBlock.class); + + /** The shape of the input for this block, set by the initialization process. */ + protected Shape[] inputShapes; + + /** List of names for the input, named inputs should be manually set in sub class. 
*/ + protected List inputNames = Collections.emptyList(); + + /** + * The model version of this block, used for checking if parameters are still valid during + * parameter loading. + */ + protected byte version; + + /** + * All direct parameters of this Block. Keys are name of the parameters. + * + *

Use the {@link SymbolBlock#addParameter(Parameter)} method to add children. All parameters + * in this map are automatically loaded / saved. + */ + @SuppressWarnings("PMD.UseConcurrentHashMap") + protected Map parameters = new LinkedHashMap<>(); + + private static final byte VERSION = 3; + + private CachedOp op; + private Symbol symbol; + private List mxNetParams; // includes input data + private Map paramShapes; + private Shape[] outputShapes; + private PairList inputDescriptions; + private PairList outputDescriptions; + private boolean first; + + /** + * Constructs a {@code MxSymbolBlock} for a {@link Symbol}. + * + * @param parent the parent MxResource to use for the block + * @param symbol the symbol containing the block's symbolic graph + */ + public SymbolBlock(MxResource parent, Symbol symbol) { + super(); + setParent(parent); + this.symbol = symbol; + initBlock(); + } + + /** + * Constructs an empty {@code MxSymbolBlock}. + * + * @param parent the parent {@code MxSymbolBlock} instance to manage this MxSymbolBlock + */ + private SymbolBlock(MxResource parent) { + super(); + setParent(parent); + } + + /** + * Constructs an {@code MxSymbolBlock} and load the symbol according to {@code Path} The life + * circle of the {@code Symbol} instance is managed by parent {@code MxResource}. + * + * @param parent the parent MxResource Object to manage this MxSymbolBlock + * @param symbolPath the Path to load symbol + * @return created {@code SymbolBlock} instance + */ + public static SymbolBlock createMxSymbolBlock(MxResource parent, Path symbolPath) { + SymbolBlock symbolBlock = new SymbolBlock(parent); + symbolBlock.loadSymbol(symbolPath); + symbolBlock.initBlock(); + return symbolBlock; + } + + private void loadSymbol(Path symbolPath) { + this.symbol = Symbol.loadSymbol(this, symbolPath); + } + + /** + * Sets the names of the input data. 
+     *
+     * <p>Every entry of {@code mxNetParams} whose name is not listed in {@code inputNames} is
+     * registered as a direct parameter of this block.
+     *
+     * @param inputNames the names of the input data
+     */
+    public void setInputNames(List inputNames) {
+        this.inputNames = inputNames;
+        // now that we know which of the parameters are just input placeholders and which
+        // are trainable, add them properly so they are correctly handled
+        Set nameLookup = new HashSet<>(inputNames);
+        for (Parameter mxNetParameter : mxNetParams) {
+            if (!nameLookup.contains(mxNetParameter.getName())) {
+                addParameter(mxNetParameter);
+            }
+        }
+    }
+
+    /**
+     * Registers a direct parameter of this block under its own name.
+     *
+     * @param parameter the parameter to register
+     * @return the same {@code parameter} instance
+     */
+    protected final Parameter addParameter(Parameter parameter) {
+        parameters.put(parameter.getName(), parameter);
+        return parameter;
+    }
+
+    /**
+     * Returns the list of inputs and parameter NDArrays.
+     *
+     * <p>The internal list is returned as-is, not a copy.
+     *
+     * @return the list of inputs and parameter NDArrays
+     */
+    public List getAllParameters() {
+        return mxNetParams;
+    }
+
+    /**
+     * Returns the layers' name.
+     *
+     * @return a List of String containing the layers' name
+     */
+    public List getLayerNames() {
+        return symbol.getLayerNames();
+    }
+
+    /**
+     * Returns the Symbolic graph from the model.
+     *
+     * @return a {@link Symbol} object
+     */
+    public Symbol getSymbol() {
+        return symbol;
+    }
+
+    /**
+     * Applies Optimization algorithm for the model.
+     *
+     * <p>The current symbol is closed and replaced by the optimized one.
+     *
+     * @param optimization the name of the optimization
+     * @param device the device assigned
+     */
+    public void optimizeFor(String optimization, Device device) {
+        Symbol newSymbol = symbol.optimizeFor(optimization, device);
+        symbol.close();
+        symbol = newSymbol;
+    }
+
+    /**
+     * Returns a {@link PairList} of input names, and shapes.
+     *
+     * <p>If no descriptions have been recorded yet, every name in {@code inputNames} is reported
+     * with an empty {@link Shape} and a single warning is logged.
+     *
+     * @return the {@link PairList} of input names, and shapes
+     */
+    public PairList describeInput() {
+        if (inputDescriptions == null) {
+            inputDescriptions = new PairList<>();
+            if (!inputNames.isEmpty()) {
+                // warn once per call rather than once per input name
+                logger.warn(
+                        "Input shapes are unknown, please run predict or forward once"
+                                + " and call describeInput again.");
+            }
+            for (String name : inputNames) {
+                // Add empty shapes as input shapes are not saved
+                // in MXNet models
+                inputDescriptions.add(name, new Shape());
+            }
+        }
+        return inputDescriptions;
+    }
+
+    /**
+     * Returns a {@link PairList} of output names and shapes stored in model file.
+     *
+     * <p>If no output descriptions have been recorded yet, a warning is logged and {@code null}
+     * is returned.
+     *
+     * @return the {@link PairList} of output names, and shapes, or {@code null} if unknown
+     */
+    public PairList describeOutput() {
+        if (outputDescriptions == null) {
+            logger.warn(
+                    "Output shapes are unknown, please run predict or forward once"
+                            + " and call describeOutput again.");
+        }
+        return outputDescriptions;
+    }
+
+    /**
+     * Applies the operating function of the SymbolBlock once. If the block is not initialized
+     * yet, it is initialized with {@code DataType.FLOAT32} and the shapes of {@code inputs} first.
+     *
+     * @param inputs the input NDList
+     * @param params optional parameters
+     * @param device device to use
+     * @return the output of the forward pass
+     */
+    public final NDList forward(NDList inputs, PairList params, Device device) {
+
+        if (!isInitialized()) {
+            initialize(getParent(), DataType.FLOAT32, device, inputs.getShapes());
+        }
+        return forwardInternal(inputs, params);
+    }
+
+    /**
+     * Applies the operating function of the block once, using no extra parameters and the device
+     * returned by {@code getDevice()}.
+     *
+     * @param inputs the input NDList
+     * @return the output of the forward pass
+     */
+    public NDList forward(NDList inputs) {
+        return forward(inputs, null, getDevice());
+    }
+
+    /**
+     * A forward call using both training data and labels.
+     *
+     *

<p>Within this forward call, it can be assumed that training is true.
+     *
+     * @param data the input data NDList
+     * @param labels the input labels NDList
+     * @param params optional parameters
+     * @param device the device assigned
+     * @return the output of the forward pass
+     * @see #forward(NDList, PairList, Device)
+     */
+    public NDList forward(
+            NDList data, NDList labels, PairList params, Device device) {
+        if (!isInitialized()) {
+            initialize(getParent(), DataType.FLOAT32, device, data.getShapes());
+        }
+        return forwardInternal(data, labels, params);
+    }
+
+    /**
+     * A helper for {@link SymbolBlock#forward(NDList, NDList, PairList, Device)} after
+     * initialization.
+     *
+     * <p>This implementation ignores {@code labels} and delegates to the two-argument overload.
+     *
+     * @param data the input data NDList
+     * @param labels the input labels NDList
+     * @param params optional parameters
+     * @return the output of the forward pass
+     * @see #forward(NDList, PairList, Device)
+     */
+    protected NDList forwardInternal(NDList data, NDList labels, PairList params) {
+        return forwardInternal(data, params);
+    }
+
+    /**
+     * Runs the cached operator on {@code inputs}.
+     *
+     * <p>While {@code first} is set, the call creates the {@link CachedOp} under a class-wide
+     * lock and records the names and shapes of the inputs and outputs. {@code params} is not used
+     * by this implementation.
+     *
+     * @param inputs the input NDList
+     * @param params optional parameters (unused here)
+     * @return the output of the forward pass
+     */
+    protected NDList forwardInternal(NDList inputs, PairList params) {
+        // NOTE(review): `first` is tested outside the lock; this check-lock-check sequence is
+        // only safe if the field is declared volatile - confirm at the field declaration.
+        if (first) {
+            synchronized (SymbolBlock.class) {
+                if (first) {
+                    // create CachedOp is not thread-safe
+                    // add synchronized block to avoid creating multiple CachedOps
+                    op = JnaUtils.createCachedOp(this, getParent());
+                    inputDescriptions = new PairList<>();
+                    outputDescriptions = new PairList<>();
+                    for (NDArray array : inputs) {
+                        inputDescriptions.add(array.getName(), array.getShape());
+                    }
+                    NDList outputs = op.forward(inputs);
+                    for (NDArray array : outputs) {
+                        outputDescriptions.add(array.getName(), array.getShape());
+                    }
+                    first = false;
+                    return outputs;
+                }
+            }
+        }
+        return op.forward(inputs);
+    }
+
+    /**
+     * Returns a boolean whether the {@link SymbolBlock} is initialized.
+     *
+     * <p>The block counts as initialized when every parameter returned by {@link
+     * #getParameters()} reports itself as initialized.
+     *
+     * @return whether the block is initialized
+     */
+    public boolean isInitialized() {
+        for (Parameter param : getParameters().values()) {
+            if (!param.isInitialized()) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Initializes the parameters of the block. This method must be called before calling
+     * {@code forward}.
+     *
+     * @param parent the parent {@link MxResource} to manage the initialized parameters
+     * @param dataType the datatype of the parameters
+     * @param device the device of the parameters
+     * @param inputShapes the shapes of the inputs to the block
+     */
+    public void initialize(
+            MxResource parent, DataType dataType, Device device, Shape... inputShapes) {
+        beforeInitialize(inputShapes);
+
+        // no need to initialize() for inference
+
+        for (Parameter parameter : parameters.values()) {
+            parameter.initialize(parent, dataType, device);
+        }
+        initializeChildBlocks();
+    }
+
+    /**
+     * Initializes the Child blocks of this block. You need to override this method if your subclass
+     * has child blocks. Used to determine the correct input shapes for child blocks based on the
+     * requested input shape for this block.
+     *
+     * <p>This default implementation throws {@link IllegalStateException} if the block has any
+     * sub-resources.
+     */
+    protected void initializeChildBlocks() {
+        if (!getSubResource().isEmpty()) {
+            throw new IllegalStateException(
+                    getClass().getSimpleName()
+                            + " has child blocks but initializeChildBlocks is not overwritten.");
+        }
+    }
+
+    /**
+     * Records the input shapes and, if no input names are set yet, generates the names
+     * {@code data0}, {@code data1}, ... (one per input shape).
+     *
+     * @param inputShapes the shapes of the inputs to the block
+     */
+    protected void beforeInitialize(Shape... inputShapes) {
+        if (inputNames.isEmpty()) {
+            // automatically assign input names
+            inputNames = new ArrayList<>();
+            for (int i = 0; i < inputShapes.length; ++i) {
+                inputNames.add("data" + i);
+            }
+        }
+        this.inputShapes = inputShapes;
+    }
+
+    /**
+     * Returns a list of all the parameters of the block, including the parameters of its children
+     * fetched recursively.
+     *
+     * <p>Only children whose class is exactly {@code SymbolBlock} are visited. Their parameter
+     * names are prefixed with the child's key followed by an underscore.
+     *
+     * @return the list of all parameters of the SymbolBlock
+     */
+    public ParameterList getParameters() {
+        // we accumulate a list of all parameters by starting with a list of the direct parameters
+        ParameterList allParams = getDirectParameters();
+        // then we add the parameters of child blocks
+        for (Pair childPair : getChildren()) {
+            if (SymbolBlock.class.equals(childPair.getValue().getClass())) {
+                SymbolBlock symbolBlock = (SymbolBlock) childPair.getValue();
+                for (Pair paramPair : symbolBlock.getParameters()) {
+                    // we prepend the name of the child block to the parameter name
+                    allParams.add(
+                            childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue());
+                }
+            }
+        }
+        return allParams;
+    }
+
+    /**
+     * Returns a list of all the children of the SymbolBlock.
+     *
+     * <p>The result is a new list built from the current sub-resources.
+     *
+     * @return the list of child blocks
+     */
+    public MxResourceList getChildren() {
+        MxResourceList defensiveCopy = new MxResourceList(getSubResource().size());
+        for (Map.Entry entry : getSubResource().entrySet()) {
+            defensiveCopy.add(entry.getKey(), entry.getValue());
+        }
+        return defensiveCopy;
+    }
+
+    /**
+     * Returns a list of all the direct parameters of the SymbolBlock.
+     *
+     * @return the list of {@link Parameter}
+     */
+    public ParameterList getDirectParameters() {
+        return new ParameterList(parameters);
+    }
+
+    /**
+     * Returns the expected output shapes of the SymbolBlock for the specified input shapes.
+     *
+     * <p>The shapes are computed on the first call and cached; once cached, later calls return
+     * the cached array and do not look at {@code inputShapes}.
+     *
+     * @param inputShapes the shapes of the inputs
+     * @return the expected output shapes of the block
+     */
+    public Shape[] getOutputShapes(Shape[] inputShapes) {
+        if (outputShapes == null) {
+            String[] outputNames = symbol.getOutputNames();
+            outputShapes = new Shape[outputNames.length];
+            for (int i = 0; i < outputShapes.length; ++i) {
+                outputShapes[i] = getParameterShape(outputNames[i], inputShapes);
+            }
+        }
+        return outputShapes;
+    }
+
+    /** Removes the last block in the symbolic graph.
*/
+    public void removeLastBlock() {
+        List layerNames = getLayerNames();
+        // the new output is the second-to-last layer; with fewer than two layer names the
+        // index below is negative and get(...) throws
+        String layerName = layerNames.get(layerNames.size() - 2);
+
+        Symbol sliced = symbol.get(layerName);
+        symbol.close();
+        symbol = sliced;
+
+        // drop (and close) every parameter the sliced symbol no longer refers to; iterate
+        // backwards so removing by index does not shift the entries still to be visited
+        HashSet set = new HashSet<>(Arrays.asList(symbol.getAllNames()));
+        for (int i = mxNetParams.size() - 1; i >= 0; --i) {
+            Parameter parameter = mxNetParams.get(i);
+            if (!set.contains(parameter.getName())) {
+                mxNetParams.remove(i).close();
+                parameters.remove(parameter.getName(), parameter);
+            }
+        }
+    }
+
+    /**
+     * Looks up the inferred shape for {@code name}.
+     *
+     * <p>Shape inference runs once, pairing {@code inputNames.get(i)} with {@code inputShapes[i]},
+     * and the result is cached in {@code paramShapes}.
+     *
+     * @param name the name to look up
+     * @param inputShapes the input shapes, indexed like {@code inputNames}
+     * @return the inferred shape
+     * @throws IllegalArgumentException if the inferred shapes contain no entry for {@code name}
+     */
+    private Shape getParameterShape(String name, Shape[] inputShapes) {
+        if (paramShapes == null) {
+            PairList pairs = new PairList<>();
+            for (int i = 0; i < inputNames.size(); i++) {
+                pairs.add(inputNames.get(i), inputShapes[i]);
+            }
+            paramShapes = symbol.inferShape(pairs);
+        }
+        if (paramShapes.containsKey(name)) {
+            return paramShapes.get(name);
+        } else {
+            throw new IllegalArgumentException("Name " + name + " not found");
+        }
+    }
+
+    /**
+     * Writes the parameters of the SymbolBlock to the given outputStream.
+     *
+     * <p>Written in order: the version byte, the byte length of the symbol JSON (UTF-8) and the
+     * JSON itself, the number of input names and each name, then every entry of
+     * {@code mxNetParams}.
+     *
+     * @param os the outputstream to save the parameters to
+     * @throws IOException if an I/O error occurs
+     */
+    public void saveParameters(DataOutputStream os) throws IOException {
+        os.writeByte(VERSION);
+        String json = symbol.toJsonString();
+        // symbol size may go beyond os.writeUTF() size (65535)
+        byte[] bytes = json.getBytes(StandardCharsets.UTF_8);
+        os.writeInt(bytes.length);
+        os.write(bytes);
+        int size = inputNames.size();
+        os.writeInt(size);
+        for (String name : inputNames) {
+            os.writeUTF(name);
+        }
+        for (Parameter parameter : mxNetParams) {
+            parameter.save(os);
+        }
+    }
+
+    /**
+     * Loads the parameters from the given input stream.
+     *
+     * <p>The stream starts with a version byte. For the current version the symbol JSON is read
+     * from the stream and the block is re-initialized from it; for older versions the symbol
+     * must already be set on this block.
+     *
+     * @param parent the parent {@link MxResource} to create the parameter arrays
+     * @param is the inputstream that streams the parameter values
+     * @throws IOException if an I/O error occurs
+     * @throws MalformedModelException if the model file is corrupted or unsupported
+     */
+    public void loadParameters(MxResource parent, DataInputStream is) throws IOException {
+        byte currentVersion = is.readByte();
+        if (currentVersion > VERSION) {
+            // report the version found in the stream, not the (unrelated) `version` field
+            throw new MalformedModelException(
+                    "Unsupported encoding version: " + currentVersion);
+        }
+        if (currentVersion < VERSION && symbol == null) {
+            throw new IllegalStateException(
+                    "Symbol is required for version 2, please use Model to load");
+        }
+        if (currentVersion == VERSION) {
+            int len = is.readInt();
+            if (len < 0) {
+                throw new MalformedModelException("Invalid symbol length: " + len);
+            }
+            byte[] bytes = new byte[len];
+            try {
+                // read(byte[]) may return after a partial read; readFully keeps reading until
+                // the whole symbol is available or the stream ends
+                is.readFully(bytes);
+            } catch (java.io.EOFException e) {
+                throw new MalformedModelException("InputStream ends at symbol loading!");
+            }
+            // init block only if it is not set
+            symbol = Symbol.loadJson(this, new String(bytes, StandardCharsets.UTF_8));
+            initBlock();
+        }
+        int size = is.readInt();
+        for (int i = 0; i < size; ++i) {
+            inputNames.add(is.readUTF());
+        }
+
+        for (Parameter parameter : mxNetParams) {
+            parameter.load(parent, is);
+        }
+        setInputNames(inputNames);
+    }
+
+    /**
+     * Resets the input names and builds one {@link Parameter} per name reported by the symbol,
+     * then marks the block as not yet run ({@code first = true}).
+     */
+    private void initBlock() {
+        inputNames = new ArrayList<>();
+
+        String[] allNames = symbol.getAllNames();
+        mxNetParams = new ArrayList<>(allNames.length);
+
+        for (String name : allNames) {
+            Parameter.Type type = inferType(name);
+            mxNetParams.add(Parameter.builder().setName(name).setType(type).build());
+        }
+        first = true;
+    }
+
+    /**
+     * Derives the parameter type from the suffix of its name.
+     *
+     * @param name the parameter name
+     * @return the matching {@link Parameter.Type}, or {@code OTHER} if no suffix matches
+     */
+    private static Parameter.Type inferType(String name) {
+        if (name.endsWith("bias")) {
+            return Parameter.Type.BIAS;
+        } else if (name.endsWith("gamma")) {
+            return Parameter.Type.GAMMA;
+        } else if (name.endsWith("beta")) {
+            return Parameter.Type.BETA;
+        } else if (name.endsWith("moving_mean") || name.endsWith("running_mean")) {
+            return Parameter.Type.RUNNING_MEAN;
+        } else if (name.endsWith("moving_var") || name.endsWith("running_var")) {
+            return Parameter.Type.RUNNING_VAR;
+        } else if (name.endsWith("weight")) {
+            return Parameter.Type.WEIGHT;
+        }
+        return Parameter.Type.OTHER;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/package-info.java
new file mode 100644
index 000000000000..283e6256ff0e
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.nn;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Item.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Item.java
new file mode 100644
index 000000000000..0f085fcf4b6f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Item.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.repository;
+
+/**
+ * {@link Item} lists the known repositories from which pre-trained model data can be downloaded.
+ * Developers use an item to initialize a {@link Repository} and download that specific model
+ * data.
+ */
+public enum Item {
+    MLP("mlp", "https://resources.djl.ai/test-models/mlp.tar.gz");
+
+    private String name;
+    private String url;
+
+    /**
+     * Creates an item.
+     *
+     * @param name the name of the item
+     * @param url the URL the item can be downloaded from
+     */
+    Item(String name, String url) {
+        this.name = name;
+        this.url = url;
+    }
+
+    /**
+     * Gets the name of this {@code Item}.
+     *
+     * @return the name of this {@code Item}
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Gets the URL of this {@code Item} to download.
+     *
+     * @return the URL of this {@code Item}
+     */
+    public String getUrl() {
+        return url;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Repository.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Repository.java
new file mode 100644
index 000000000000..2eacee536d18
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Repository.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.repository;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.ZipInputStream;
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.apache.mxnet.util.FilenameUtils;
+import org.apache.mxnet.util.Utils;
+import org.apache.mxnet.util.ZipUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code Repository} is a format for storing data {@link Item}s for various uses including deep
+ * learning models and datasets.
+ */
+public class Repository {
+
+    private static final Logger logger = LoggerFactory.getLogger(Repository.class);
+
+    private String name;
+    private URI uri;
+    private Path resourceDir;
+
+    /**
+     * Creates a repository from a name and a URI string.
+     *
+     * @param name the name of the repository
+     * @param uri the location to download the repository from
+     */
+    Repository(String name, String uri) {
+        setName(name);
+        setUri(URI.create(uri));
+    }
+
+    /**
+     * Creates a repository from the name and URL of the given {@link Item}.
+     *
+     * @param item the item to create the repository for
+     */
+    Repository(Item item) {
+        this(item.getName(), item.getUrl());
+    }
+
+    /**
+     * Initializes a {@link Repository} from a specific {@link Item}, which provides the name for
+     * the repository and the URL to download it from.
+     *
+     * @param item {@link Item} to initialize the {@link Repository}
+     * @return {@link Path} of the initialized {@link Repository}
+     * @throws IOException when fail to prepare the {@link Repository}
+     */
+    public static Path initRepository(Item item) throws IOException {
+        Repository repository = new Repository(item);
+        repository.prepare();
+        return repository.getLocalDir();
+    }
+
+    private void setResourceDir(Path mResourceDir) {
+        this.resourceDir = mResourceDir;
+    }
+
+    private Path getResourceDir() {
+        return resourceDir;
+    }
+
+    /**
+     * Returns the local directory to store resources.
+     *
+     * <p>The resource directory is only set by {@link #prepare()}, so this method must not be
+     * called before it.
+     *
+     * @return {@link Path} of the local resource directory
+     */
+    public Path getLocalDir() {
+        return getResourceDir().resolve(getName());
+    }
+
+    /**
+     * Sets the {@link URI} for the {@link Repository}.
+     *
+     * @param uri of the repository
+     */
+    public final void setUri(URI uri) {
+        this.uri = uri;
+    }
+
+    /**
+     * Returns {@link URI} for the {@link Repository}.
+     *
+     * @return {@link URI} of the {@link Repository}
+     */
+    public URI getUri() {
+        return uri;
+    }
+
+    /**
+     * Sets the name for the {@link Repository}.
+     *
+     * @param name for the {@link Repository}
+     */
+    public final void setName(String name) {
+        this.name = name;
+    }
+
+    /**
+     * Returns the name for the {@link Repository}.
+     *
+     * @return name for the {@link Repository}
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Prepares the repository for use.
+     *
+     * <p>The resource directory is the cache directory resolved against the URI path (without
+     * its leading slash). If it already exists nothing is downloaded. Otherwise the artifact is
+     * downloaded into a temporary sibling directory which is then moved into place.
+     *
+     * @throws IOException if it failed to prepare
+     */
+    public void prepare() throws IOException {
+        String uriPath = getUri().getPath();
+        if (uriPath != null && !"".equals(uriPath) && uriPath.charAt(0) == '/') {
+            uriPath = uriPath.substring(1);
+        }
+        setResourceDir(getCacheDirectory().resolve(uriPath));
+        if (Files.exists(getResourceDir())) {
+            logger.debug("Files have been downloaded already: {}", getResourceDir());
+            return;
+        }
+        Path parentDir = getResourceDir().toAbsolutePath().getParent();
+        if (parentDir == null) {
+            // "{}" is an SLF4J placeholder and is not understood by String.format, so the
+            // path is appended directly
+            throw new AssertionError(
+                    "Parent path should never be null: " + getResourceDir().toString());
+        }
+
+        Files.createDirectories(parentDir);
+        Path tmp = Files.createTempDirectory(parentDir, getResourceDir().toFile().getName());
+
+        // dismiss Progress related
+
+        try {
+            logger.debug("Repository to download: {}", getUri().toString());
+            download(tmp);
+            Utils.moveQuietly(tmp, getResourceDir());
+        } finally {
+            Utils.deleteQuietly(tmp);
+        }
+    }
+
+    /** Opens the repository URI and saves (or extracts) its content into {@code tmp}. */
+    private void download(Path tmp) throws IOException {
+        logger.debug("Downloading artifact: {} at {}...", getName(), getUri());
+        try (InputStream is = getUri().toURL().openStream()) {
+            String extension = FilenameUtils.getFileType(getUri().getPath());
+            save(is, tmp, name, extension, isArchiveFile(extension));
+        }
+    }
+
+    private boolean isArchiveFile(String fileType) {
+        return "tgz".equals(fileType) || "zip".equals(fileType) || "tar".equals(fileType);
+    }
+
+    /**
+     * Writes the downloaded stream below {@code tmp}.
+     *
+     * <p>Archives are extracted into {@code tmp/repoName} (or {@code tmp} if the name is empty);
+     * anything else is stored as the single file {@code tmp/repoName}, after unwrapping a
+     * {@code zip} or {@code gzip} stream.
+     *
+     * @param is the stream to read from; it is not closed by this method
+     * @param tmp the directory to write into
+     * @param repoName the name of the repository
+     * @param extension the file type of the downloaded artifact
+     * @param archive whether the artifact is an archive that must be extracted
+     * @throws IOException if writing fails or the archive type is not supported
+     */
+    protected void save(
+            InputStream is, Path tmp, String repoName, String extension, boolean archive)
+            throws IOException {
+        // ProgressInputStream pis = new ProgressInputStream(is);
+
+        if (archive) {
+            Path directory;
+            if (!repoName.isEmpty()) {
+                // honor the name set in metadata.json
+                directory = tmp.resolve(repoName);
+                Files.createDirectories(directory);
+            } else {
+                directory = tmp;
+            }
+            if ("zip".equals(extension)) {
+                ZipUtils.unzip(is, directory);
+            } else if ("tgz".equals(extension)) {
+                untar(is, directory, true);
+            } else if ("tar".equals(extension)) {
+                untar(is, directory, false);
+            } else {
+                throw new IOException("File type is not supported: " + extension);
+            }
+        } else {
+            Path file = tmp.resolve(repoName);
+            if ("zip".equals(extension)) {
+                ZipInputStream zis = new ZipInputStream(is);
+                zis.getNextEntry();
+                Files.copy(zis, file, StandardCopyOption.REPLACE_EXISTING);
+            } else if ("gzip".equals(extension)) {
+                Files.copy(new GZIPInputStream(is), file, StandardCopyOption.REPLACE_EXISTING);
+            } else {
+                Files.copy(is, file, StandardCopyOption.REPLACE_EXISTING);
+            }
+        }
+        // pis.validateChecksum(item);
+    }
+
+    /**
+     * Extracts a tar stream (optionally gzip-compressed) into {@code dir}.
+     *
+     * @param is the stream containing the archive
+     * @param dir the directory to extract into
+     * @param gzip whether the stream is gzip-compressed
+     * @throws IOException if extraction fails or an entry would be written outside {@code dir}
+     */
+    private void untar(InputStream is, Path dir, boolean gzip) throws IOException {
+        InputStream bis;
+        if (gzip) {
+            bis = new GzipCompressorInputStream(new BufferedInputStream(is));
+        } else {
+            bis = new BufferedInputStream(is);
+        }
+        Path root = dir.toAbsolutePath().normalize();
+        try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) {
+            TarArchiveEntry entry;
+            while ((entry = tis.getNextTarEntry()) != null) {
+                String entryName = entry.getName();
+                if (entryName.contains("..")) {
+                    throw new IOException("Malicious tar entry: " + entryName);
+                }
+                Path file = root.resolve(entryName).normalize();
+                // resolve() returns an absolute entry name unchanged, so also make sure the
+                // target really is inside the extraction directory
+                if (!file.startsWith(root)) {
+                    throw new IOException("Malicious tar entry: " + entryName);
+                }
+                if (entry.isDirectory()) {
+                    Files.createDirectories(file);
+                } else {
+                    Path parentFile = file.getParent();
+                    if (parentFile == null) {
+                        throw new AssertionError(
+                                "Parent path should never be null: " + file.toString());
+                    }
+                    Files.createDirectories(parentFile);
+                    Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING);
+                }
+            }
+        }
+    }
+
+    /**
+     * Returns the cache directory for the repository.
+     *
+     * <p>The directory is {@code cache/repo} below {@code Utils.getCacheDir()} and is created if
+     * it does not exist.
+     *
+     * @return the cache directory path
+     * @throws IOException if it failed to ensure the creation of the cache directory
+     */
+    public Path getCacheDirectory() throws IOException {
+        Path dir = Utils.getCacheDir().resolve("cache/repo");
+        if (Files.notExists(dir)) {
+            Files.createDirectories(dir);
+        } else if (!Files.isDirectory(dir)) {
+            throw new IOException("Failed initialize cache directory: " + dir.toString());
+        }
+        return dir;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/package-info.java
new file mode 100644
index 000000000000..6248f96a5c3e
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet.
*/ +package org.apache.mxnet.repository; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/NoOpTranslator.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/NoOpTranslator.java new file mode 100644 index 000000000000..0c5d69f94b27 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/NoOpTranslator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.translate; + +import org.apache.mxnet.ndarray.NDList; + +/** + * Default no operational implement for {@link Translator} to process input and output {@link + * org.apache.mxnet.ndarray.NDArray}. 
+ */
+public class NoOpTranslator implements Translator {
+
+    /** {@inheritDoc} */
+    @Override
+    public Pipeline getPipeline() {
+        // no pipeline of its own: defer to the interface's default implementation
+        return Translator.super.getPipeline();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public NDList processInput(NDList input) {
+        // returned unchanged
+        return input;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public NDList processOutput(NDList output) {
+        // returned unchanged
+        return output;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Pipeline.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Pipeline.java
new file mode 100644
index 000000000000..15facc4ff614
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Pipeline.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.translate;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+
+/** {@code Pipeline} allows applying multiple transforms on an input {@link NDList}.
*/ +public class Pipeline { + + private PairList transforms; + + /** Creates a new instance of {@code Pipeline} that has no {@link Transform} defined yet. */ + public Pipeline() { + transforms = new PairList<>(); + } + + /** + * Creates a new instance of {@code Pipeline} that can apply the given transforms on its input. + * + *

Since no keys are provided for these transforms, they will be applied to the first element + * in the input {@link NDList} when the {@link #transform(NDList) transform} method is called on + * this object. + * + * @param transforms the transforms to be applied when the {@link #transform(NDList) transform} + * method is called on this object + */ + public Pipeline(Transform... transforms) { + this.transforms = new PairList<>(); + for (Transform transform : transforms) { + this.transforms.add(new IndexKey(0), transform); + } + } + + /** + * Adds the given {@link Transform} to the list of transforms to be applied on the input when + * the {@link #transform(NDList) transform} method is called on this object. + * + *

Since no keys are provided for this {@link Transform}, it will be applied to the first + * element in the input {@link NDList}. + * + * @param transform the {@link Transform} to be added + * @return this {@code Pipeline} + */ + public Pipeline add(Transform transform) { + transforms.add(new IndexKey(0), transform); + return this; + } + + /** + * Adds the given {@link Transform} to the list of transforms to be applied on the {@link + * NDArray} at the given index in the input {@link NDList}. + * + * @param index the index corresponding to the {@link NDArray} in the input {@link NDList} on + * which the given transform must be applied to + * @param transform the {@link Transform} to be added + * @return this {@code Pipeline} + */ + public Pipeline add(int index, Transform transform) { + transforms.add(new IndexKey(index), transform); + return this; + } + + /** + * Adds the given {@link Transform} to the list of transforms to be applied on the {@link + * NDArray} with the given key as name in the input {@link NDList}. + * + * @param name the key corresponding to the {@link NDArray} in the input {@link NDList} on which + * the given transform must be applied to + * @param transform the {@code Transform} to be applied when the {@link #transform(NDList) + * transform} method is called on this object + * @return this {@code Pipeline} + */ + public Pipeline add(String name, Transform transform) { + transforms.add(new IndexKey(name), transform); + return this; + } + + /** + * Inserts the given {@link Transform} to the list of transforms at the given position. + * + *

Since no keys or indices are provided for this {@link Transform}, it will be applied to + * the first element in the input {@link NDList} when the {@link #transform(NDList) transform} + * method is called on this object. + * + * @param position the position at which the {@link Transform} must be inserted + * @param transform the {@code Transform} to be inserted + * @return this {@code Pipeline} + */ + public Pipeline insert(int position, Transform transform) { + transforms.add(position, new IndexKey(0), transform); + return this; + } + + /** + * Inserts the given {@link Transform} to the list of transforms at the given position to be + * applied on the {@link NDArray} at the given index in the input {@link NDList}. + * + * @param position the position at which the {@link Transform} must be inserted + * @param index the index corresponding to the {@link NDArray} in the input {@link NDList} on + * which the given transform must be applied to + * @param transform the {@code Transform} to be inserted + * @return this {@code Pipeline} + */ + public Pipeline insert(int position, int index, Transform transform) { + transforms.add(position, new IndexKey(index), transform); + return this; + } + + /** + * Inserts the given {@link Transform} to the list of transforms at the given position to be + * applied on the {@link NDArray} with the given name in the input {@link NDList}. + * + * @param position the position at which the {@link Transform} must be inserted + * @param name the key corresponding to the {@link NDArray} in the input {@link NDList} on which + * the given transform must be applied to + * @param transform the {@code Transform} to be inserted + * @return this {@code Pipeline} + */ + public Pipeline insert(int position, String name, Transform transform) { + transforms.add(position, new IndexKey(name), transform); + return this; + } + + /** + * Applies the transforms configured in this object on the input {@link NDList}. + * + *

If a key is specified with the transform, the transform will only be applied to the + * {@link NDArray} matching that key in the input {@link NDList}. If a key is not specified, it will be applied to + * the first element in the input {@link NDList}. + * + * @param input the input {@link NDList} on which the transforms are to be applied + * @return the output {@link NDList} after applying the transforms + */ + public NDList transform(NDList input) { + if (transforms.isEmpty() || input.isEmpty()) { + return input; + } + + NDArray[] arrays = input.toArray(new NDArray[0]); + + Map map = new ConcurrentHashMap<>(); + // create mapping + for (int i = 0; i < input.size(); i++) { + String key = input.get(i).getName(); + if (key != null) { + map.put(new IndexKey(key), i); + } + map.put(new IndexKey(i), i); + } + // apply transform + for (Pair transform : transforms) { + IndexKey key = transform.getKey(); + int index = map.get(key); + NDArray array = arrays[index]; + + arrays[index] = transform.getValue().transform(array); + arrays[index].setName(array.getName()); + } + + return new NDList(arrays); + } + + private static final class IndexKey { + private String key; + private int index; + + private IndexKey(String key) { + this.key = key; + } + + private IndexKey(int index) { + this.index = index; + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + if (key == null) { + return index; + } + return key.hashCode(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof IndexKey)) { + return false; + } + IndexKey other = (IndexKey) obj; + if (key == null) { + return index == other.index; + } + return key.equals(other.key); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Processor.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Processor.java new file mode 100644 index 000000000000..cee04d05ace5 --- /dev/null +++ 
b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Processor.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.translate; + +import org.apache.mxnet.ndarray.NDList; + +/** + * An interface that provides pre-processing and post-processing functionality. + * + * @param the type of the input object + */ +public interface Processor { + + /** + * Gets the {@link Pipeline} applied to the input. + * + * @return the {@link Pipeline} + */ + default Pipeline getPipeline() { + throw new UnsupportedOperationException("Not implemented."); + } + + /** + * Processes the input and converts it to NDList. + * + * @param input the input object + * @return the {@link NDList} after pre-processing + * @throws Exception if an error occurs during processing input + */ + @SuppressWarnings("PMD.SignatureDeclareThrowsException") + NDList processInput(I input) throws Exception; + + /** + * Processes the output NDList and converts it to the output object. 
+ * + * @param output the {@link NDList} to be post-processed + * @return the output object after post-processing + * @throws Exception if an error occurs during processing output + */ + @SuppressWarnings("PMD.SignatureDeclareThrowsException") + O processOutput(NDList output) throws Exception; +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Transform.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Transform.java new file mode 100644 index 000000000000..8e24304b6e57 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Transform.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.translate; + +import org.apache.mxnet.ndarray.NDArray; + +/** + * An interface to apply various transforms to the input. + * + *

A transform can be any function that modifies the input. Some examples of transform are crop + * and resize. + */ +// TODO : not used by now +public interface Transform { + /** + * Applies the {@code Transform} to the given {@link NDArray}. + * + * @param array the {@link NDArray} on which the {@link Transform} is applied + * @return the output of the {@code Transform} + */ + NDArray transform(NDArray array); +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Translator.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Translator.java new file mode 100644 index 000000000000..c476e243d3d3 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Translator.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.translate; + +import java.io.IOException; +import org.apache.mxnet.engine.Model; +import org.apache.mxnet.engine.Predictor; + +/** + * The {@code Translator} interface provides model pre-processing and postprocessing functionality. + * + *

Users can use this in {@link Predictor} with input and output objects specified. The following + * is an example of processing an image and creating classification output: + * + * @param the input type + * @param the output type + */ +public interface Translator extends Processor { + // TODO: implement getPipeline() and related methods + /** + * Prepares the translator with the manager and model to use. + * + * @param model the model to translate for + * @throws IOException if there is an error reading inputs for preparing the translator + */ + default void prepare(Model model) throws IOException {} +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/package-info.java new file mode 100644 index 000000000000..5aaeeb53c454 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. 
*/ +package org.apache.mxnet.translate; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/FilenameUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/FilenameUtils.java new file mode 100644 index 000000000000..54938932bfc2 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/FilenameUtils.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.util; + +import java.util.Locale; + +/** A class containing utility methods. */ +public final class FilenameUtils { + + private FilenameUtils() {} + + /** + * Returns the type of the file. 
+ * + * @param fileName the file name + * @return the type of the file + */ + public static String getFileType(String fileName) { + fileName = fileName.toLowerCase(Locale.ROOT); + if (fileName.endsWith(".zip")) { + return "zip"; + } else if (fileName.endsWith(".tgz") + || fileName.endsWith(".tar.gz") + || fileName.endsWith(".tar.z")) { + return "tgz"; + } else if (fileName.endsWith(".tar")) { + return "tar"; + } else if (fileName.endsWith(".gz") || fileName.endsWith(".z")) { + return "gzip"; + } else { + return ""; + } + } + + /** + * Returns whether the file is an archive file. + * + * @param fileName the file name + * @return {@code true} if the file type is zip, tar or tgz + */ + public static boolean isArchiveFile(String fileName) { + String fileType = getFileType(fileName); + return "tgz".equals(fileType) || "zip".equals(fileType) || "tar".equals(fileType); + } + + /** + * Returns the name of the file without file extension. NOTE: only the {@code .tar.gz} suffix is matched case-insensitively; the other suffixes are matched case-sensitively, unlike {@link #getFileType(String)}. + * + * @param name the file name + * @return the name of the file without file extension + */ + public static String getNamePart(String name) { + String lowerCase = name.toLowerCase(Locale.ROOT); + if (lowerCase.endsWith(".tar.gz")) { + return name.substring(0, name.length() - 7); + } else if (name.endsWith(".tar.z")) { + return name.substring(0, name.length() - 6); + } else if (name.endsWith(".tgz") || name.endsWith(".zip") || name.endsWith(".tar")) { + return name.substring(0, name.length() - 4); + } else if (name.endsWith(".gz")) { + return name.substring(0, name.length() - 3); + } else if (name.endsWith(".z")) { + return name.substring(0, name.length() - 2); + } + return name; + } + + /** + * Returns the file name extension of the file. 
+ * + * @param fileName the file name + * @return the file name extension + */ + public static String getFileExtension(String fileName) { + int pos = fileName.lastIndexOf('.'); + if (pos > 0) { + return fileName.substring(pos + 1); + } + return ""; + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Float16Utils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Float16Utils.java new file mode 100644 index 000000000000..5961272f6858 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Float16Utils.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util; + +import java.nio.ByteBuffer; +import java.nio.ShortBuffer; +import org.apache.mxnet.ndarray.NDSerializer; + +/** {@code Float16Utils} is a set of utilities for working with float16. */ +@SuppressWarnings("PMD.AvoidUsingShortType") +public final class Float16Utils { + + private Float16Utils() {} + + /** + * Converts a byte buffer of float16 values into a float32 array. + * + * @param buffer the buffer of float16 values as bytes. + * @return an array of float32 values. 
+ */ + public static float[] fromByteBuffer(ByteBuffer buffer) { + return fromShortBuffer(buffer.asShortBuffer()); + } + + /** + * Converts a short buffer of float16 values into a float32 array. + * + * @param buffer the buffer of float16 values as shorts. + * @return an array of float32 values. + */ + public static float[] fromShortBuffer(ShortBuffer buffer) { + int index = 0; + float[] ret = new float[buffer.remaining()]; + while (buffer.hasRemaining()) { + short value = buffer.get(); + ret[index++] = halfToFloat(value); + } + return ret; + } + + /** + * Converts an array of float32 values into a byte buffer of float16 values. + * + * @param floats an array of float32 values. + * @return a byte buffer with float16 values represented as shorts (2 bytes each). + */ + public static ByteBuffer toByteBuffer(float[] floats) { + ByteBuffer buffer = NDSerializer.allocateDirect(floats.length * 2); + for (float f : floats) { + short value = floatToHalf(f); + buffer.putShort(value); + } + buffer.rewind(); + return buffer; + } + + /** + * Converts a float32 value into a float16 value. + * + * @param fVal a float32 value. + * @return a float16 value represented as a short. + */ + public static short floatToHalf(float fVal) { + int bits = Float.floatToIntBits(fVal); + int sign = bits >>> 16 & 0x8000; + int val = (bits & 0x7fffffff) + 0x1000; + if (val >= 0x47800000) { + if ((bits & 0x7fffffff) >= 0x47800000) { + if (val < 0x7f800000) { + return (short) (sign | 0x7c00); + } + return (short) (sign | 0x7c00 | (bits & 0x007fffff) >>> 13); + } + return (short) (sign | 0x7bff); + } + if (val >= 0x38800000) { + return (short) (sign | val - 0x38000000 >>> 13); + } + if (val < 0x33000000) { + return (short) sign; + } + val = (bits & 0x7fffffff) >>> 23; + return (short) + (sign | ((bits & 0x7fffff | 0x800000) + (0x800000 >>> val - 102) >>> 126 - val)); + } + + /** + * Converts a float16 value into a float32 value. + * + * @param half a float16 value represented as a short. 
+ * @return a float32 value. + */ + public static float halfToFloat(short half) { + int mant = half & 0x03ff; + int exp = half & 0x7c00; + if (exp == 0x7c00) { + exp = 0x3fc00; + } else if (exp != 0) { + exp += 0x1c000; + if (mant == 0 && exp > 0x1c400) { + return Float.intBitsToFloat((half & 0x8000) << 16 | exp << 13); + } + } else if (mant != 0) { + exp = 0x1c400; + do { + mant <<= 1; + exp -= 0x400; + } while ((mant & 0x400) == 0); + mant &= 0x3ff; + } + return Float.intBitsToFloat((half & 0x8000) << 16 | (exp | mant) << 13); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/NativeResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/NativeResource.java new file mode 100644 index 000000000000..e966ef4ef5e9 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/NativeResource.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util; + +import com.sun.jna.Pointer; +import java.util.concurrent.atomic.AtomicReference; + +/** + * {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory. 
+ * + * @param the resource that could map to a native pointer or java object + */ +public abstract class NativeResource implements AutoCloseable { + + protected final AtomicReference handle; + private String uid; + + protected NativeResource(T handle) { + this.handle = new AtomicReference<>(handle); + this.uid = handle.toString(); + } + + protected NativeResource() { + this.handle = null; + this.uid = null; + } + + /** + * To initialize a NativeResource with handle = null. + * + * @param uid for the {@link NativeResource} + */ + protected NativeResource(String uid) { + this.handle = null; + this.uid = uid; + } + + /** + * Gets the boolean that indicates whether this resource has been released. + * + * @return whether this resource has been released + */ + public boolean isReleased() { + return handle.get() == null; + } + + /** + * Gets the {@link Pointer} to this resource. + * + * @return the {@link Pointer} to this resource + */ + public T getHandle() { + T reference = handle.get(); + if (reference == null) { + throw new IllegalStateException("Native resource has been release already."); + } + return reference; + } + + /** + * Gets the unique ID of this resource. + * + * @return the unique ID of this resource + */ + public final String getUid() { + return uid; + } + + /** {@inheritDoc} */ + @Override + public void close() { + throw new UnsupportedOperationException("Not implemented."); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Pair.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Pair.java new file mode 100644 index 000000000000..8f1da3258c1f --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Pair.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util; + +import java.util.Objects; + +/** + * A class containing the key-value pair. + * + * @param the key type + * @param the value type + */ +public class Pair { + + private K key; + private V value; + + /** + * Constructs a {@code Pair} instance with key and value. + * + * @param key the key + * @param value the value + */ + public Pair(K key, V value) { + this.key = key; + this.value = value; + } + + /** + * Returns the key of this pair. + * + * @return the key + */ + public K getKey() { + return key; + } + + /** + * Returns the value of this pair. 
+ * + * @return the value + */ + public V getValue() { + return value; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Pair pair = (Pair) o; + return Objects.equals(key, pair.key) && value.equals(pair.value); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(key, value); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/PairList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/PairList.java new file mode 100644 index 000000000000..c803ceb021f2 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/PairList.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.util; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +/** + * The {@code PairList} class provides an efficient way to access a list of key-value pairs. + * + * @param the key type + * @param the value type + */ +public class PairList implements Iterable> { + + private List keys; + private List values; + + /** Constructs an empty {@code PairList}. */ + public PairList() { + keys = new ArrayList<>(); + values = new ArrayList<>(); + } + + /** + * Constructs an empty {@code PairList} with the specified initial capacity. + * + * @param initialCapacity the initial capacity of the list + * @throws IllegalArgumentException if the specified initial capacity is negative + */ + public PairList(int initialCapacity) { + keys = new ArrayList<>(initialCapacity); + values = new ArrayList<>(initialCapacity); + } + + /** + * Constructs a {@code PairList} containing the elements of the specified keys and values. + * + * @param keys the key list containing elements to be placed into this PairList + * @param values the value list containing elements to be placed into this PairList + * @throws IllegalArgumentException if the keys and values size are different + */ + public PairList(List keys, List values) { + if (keys.size() != values.size()) { + throw new IllegalArgumentException("key value size mismatch."); + } + this.keys = keys; + this.values = values; + } + + /** + * Constructs a {@code PairList} containing the elements of the specified list of Pairs. 
+ * + * @param list the list containing elements to be placed into this PairList + */ + public PairList(List> list) { + this(list.size()); + for (Pair pair : list) { + keys.add(pair.getKey()); + values.add(pair.getValue()); + } + } + + /** + * Constructs a {@code PairList} containing the elements of the specified map. + * + * @param map the map contains keys and values + */ + public PairList(Map map) { + keys = new ArrayList<>(map.size()); + values = new ArrayList<>(map.size()); + for (Map.Entry entry : map.entrySet()) { + keys.add(entry.getKey()); + values.add(entry.getValue()); + } + } + + /** + * Inserts the specified element at the specified position in this list (optional operation), + * and shifts the element currently at that position (if any) and any subsequent elements to the + * right (adds one to their indices). + * + * @param index the index at which the specified element is to be inserted + * @param key the key + * @param value the value + */ + public void add(int index, K key, V value) { + keys.add(index, key); + values.add(index, value); + } + + /** + * Adds a key and value to the list. + * + * @param key the key + * @param value the value + */ + public void add(K key, V value) { + keys.add(key); + values.add(value); + } + + /** + * Appends all of the elements in the specified pair list to the end of this list. + * + * @param other the {@code PairList} containing elements to be added to this list + */ + public void addAll(PairList other) { + if (other != null) { + keys.addAll(other.keys); + values.addAll(other.values); + } + } + + /** + * Returns the size of the list. + * + * @return the size of the list + */ + public int size() { + return keys.size(); + } + + /** + * Checks whether the list is empty. + * + * @return whether the list is empty + */ + public boolean isEmpty() { + return size() == 0; + } + + /** + * Returns the key-value pair at the specified position in this list. 
+ * + * @param index the index of the element to return + * @return the key-value pair at the specified position in this list + */ + public Pair get(int index) { + return new Pair<>(keys.get(index), values.get(index)); + } + + /** + * Returns the value for the first key found in the list. + * + * @param key the key of the element to get + * @return the value for the first key found in the list + */ + public V get(K key) { + int index = keys.indexOf(key); + if (index == -1) { + return null; + } + return values.get(index); + } + + /** + * Returns the key at the specified position in this list. + * + * @param index the index of the element to return + * @return the key at the specified position in this list + */ + public K keyAt(int index) { + return keys.get(index); + } + + /** + * Returns the value at the specified position in this list. + * + * @param index the index of the element to return + * @return the value at the specified position in this list + */ + public V valueAt(int index) { + return values.get(index); + } + + /** + * Returns all keys of the list. + * + * @return all keys of the list + */ + public List keys() { + return keys; + } + + /** + * Returns all values of the list. + * + * @return all values of the list + */ + public List values() { + return values; + } + + /** + * Returns an array containing all of the keys in this list in proper sequence (from first to + * last element); the runtime type of the returned array is that of the specified array. + * + *

If the list fits in the specified array, it is returned therein. Otherwise, a new array is + * allocated with the runtime type of the specified array and the size of this list. + * + * @param target the array into which the keys of this list are to be stored, if it is big + * enough; otherwise, a new array of the same runtime type is allocated for this purpose. + * @return an array containing the keys of this list + */ + public K[] keyArray(K[] target) { + return keys.toArray(target); + } + + /** + * Returns an array containing all of the values in this list in proper sequence (from first to + * last element); the runtime type of the returned array is that of the specified array. + * + *

If the list fits in the specified array, it is returned therein. Otherwise, a new array is + * allocated with the runtime type of the specified array and the size of this list. + * + * @param target the array into which the values of this list are to be stored, if it is big + * enough; otherwise, a new array of the same runtime type is allocated for this purpose. + * @return an array containing the values of this list + */ + public V[] valueArray(V[] target) { + return values.toArray(target); + } + + /** + * Removes the key-value pair for the first key found in the list. + * + * @param key the key of the element to be removed + * @return the value of the removed element, {@code null} if not found + */ + public V remove(K key) { + int index = keys.indexOf(key); + if (index == -1) { + return null; + } + return remove(index); + } + + /** + * Removes the key-value pair at an index. + * + * @param index the index of the element to remove + * @return the value of the removed element, {@code null} if not found + */ + public V remove(int index) { + keys.remove(index); + return values.remove(index); + } + + /** + * Returns a view of the portion of this PairList between the specified {@code fromIndex} + * inclusive, and to the end. + * + * @param fromIndex the start index (inclusive) + * @return a view of the portion of this PairList + */ + public PairList subList(int fromIndex) { + return subList(fromIndex, size()); + } + + /** + * Returns a view of the portion of this PairList between the specified {@code fromIndex} + * inclusive, and {@code toIndex}, exclusive. 
+ * + * @param fromIndex the start index (inclusive) + * @param toIndex the end index (exclusive) + * @return a view of the portion of this PairList + */ + public PairList subList(int fromIndex, int toIndex) { + List subKeys = keys.subList(fromIndex, toIndex); + List subValues = values.subList(fromIndex, toIndex); + return new PairList<>(subKeys, subValues); + } + + /** + * Returns the {@link Stream} type of the PairList. + * + * @return a {@link Stream} of PairList + */ + public Stream> stream() { + return StreamSupport.stream(spliterator(), false); + } + + /** + * Returns {@code true} if this list contains the specified key. + * + * @param key the key whose presence will be tested + * @return {@code true} if this list contains the specified key + */ + public boolean contains(K key) { + return keys.contains(key); + } + + /** + * Removes all duplicate values from the list. + * + * @return a new {@code PairList} with the duplicate values removed, taking the latest value for + * each key + */ + public PairList unique() { + return new PairList<>(toMap(false)); + } + + /** + * Returns a {@code Map} that contains the key-value mappings of this list. + * + * @return a {@code Map} that contains the key-value mappings of this list + */ + public Map toMap() { + return toMap(true); + } + + /** + * Returns a {@code Map} that contains the key-value mappings of this list. + * + * @param checkDuplicate whether to check for duplicated keys in the list + * @return a {@code Map} that contains the key-value mappings of this list + */ + public Map toMap(boolean checkDuplicate) { + int size = keys.size(); + Map map = new ConcurrentHashMap<>(size * 3 / 2); + for (int i = 0; i < size; ++i) { + if (map.put(keys.get(i), values.get(i)) != null && checkDuplicate) { + throw new IllegalStateException("Duplicate keys: " + keys.get(i)); + } + } + return map; + } + + @Override + public Iterator> iterator() { + return new Itr(); + } + + /** Internal Iterator implementation. 
*/ + private class Itr implements Iterator> { + + private int cursor; + private int size = size(); + + Itr() {} + + /** {@inheritDoc} */ + @Override + public boolean hasNext() { + return cursor < size; + } + + /** {@inheritDoc} */ + @Override + public Pair next() { + if (cursor >= size) { + throw new NoSuchElementException(); + } + + return get(cursor++); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Platform.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Platform.java new file mode 100644 index 000000000000..e7e8c72b473c --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Platform.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.util.Properties; +import org.apache.mxnet.util.cuda.CudaUtils; + +/** + * The platform contains information regarding the version, os, and build flavor of the MXNet native + * code. 
+ */ +public final class Platform { + + private String version; + private String osPrefix; + private String flavor; + private String cudaArch; + private String[] libraries; + private boolean placeholder; + + /** Constructor used only for {@link Platform#fromSystem()}. */ + private Platform() {} + + /** + * Returns the platform parsed from the "engine".properties file. + * + * <p>The result starts from {@link Platform#fromSystem()} and is then overridden with the + * values read from the properties file. + * + * @param url the url to the "engine".properties file + * @return the platform parsed from the properties file + * @throws IOException if the file could not be read + * @throws IllegalArgumentException if the properties file has no {@code version} key + */ + public static Platform fromUrl(URL url) throws IOException { + Platform platform = Platform.fromSystem(); + try (InputStream conf = url.openStream()) { + Properties prop = new Properties(); + prop.load(conf); + // 1.6.0 and later should always have a version property + platform.version = prop.getProperty("version"); + if (platform.version == null) { + throw new IllegalArgumentException( + "version key is required in .properties file."); + } + platform.placeholder = prop.getProperty("placeholder") != null; + String flavorPrefixedClassifier = prop.getProperty("classifier", ""); + String libraryList = prop.getProperty("libraries", ""); + if (libraryList.isEmpty()) { + platform.libraries = new String[0]; + } else { + platform.libraries = libraryList.split(","); + } + // the classifier is split on '-': first part is the flavor, second part the os prefix + if (!flavorPrefixedClassifier.isEmpty()) { + platform.flavor = flavorPrefixedClassifier.split("-")[0]; + platform.osPrefix = flavorPrefixedClassifier.split("-")[1]; + } + } + return platform; + } + + /** + * Returns the platform for the current system. 
+ * + * <p>The os prefix is derived from the {@code os.name} system property; the flavor and cuda + * arch are only filled in when {@code CudaUtils.getGpuCount()} reports at least one GPU. + * + * @return the platform for the current system + */ + public static Platform fromSystem() { + Platform platform = new Platform(); + String osName = System.getProperty("os.name"); + if (osName.startsWith("Win")) { + platform.osPrefix = "win"; + } else if (osName.startsWith("Mac")) { + platform.osPrefix = "osx"; + } else if (osName.startsWith("Linux")) { + platform.osPrefix = "linux"; + } else { + throw new AssertionError(String.format("Unsupported platform: %s", osName)); + } + if (CudaUtils.getGpuCount() > 0) { + // GPU flavor is "cu" + CUDA version string; the arch is read from device 0 + platform.flavor = "cu" + CudaUtils.getCudaVersionString(); + platform.cudaArch = CudaUtils.getComputeCapability(0); + } else { + platform.flavor = ""; + } + return platform; + } + + /** + * Returns the Engine Version. + * + * @return the Engine version + */ + public String getVersion() { + return version; + } + + /** + * Returns the os (win, osx, or linux). + * + * @return the os (win, osx, or linux) + */ + public String getOsPrefix() { + return osPrefix; + } + + /** + * Returns the MXNet build flavor. + * + * @return the MXNet build flavor + */ + public String getFlavor() { + return flavor; + } + + /** + * Returns the classifier for the platform. + * + * @return the classifier for the platform: the os prefix followed by {@code -x86_64} + */ + public String getClassifier() { + return getOsPrefix() + "-x86_64"; + } + + /** + * Returns the cuda arch. + * + * @return the cuda arch + */ + public String getCudaArch() { + return cudaArch; + } + + /** + * Returns the libraries used in the platform. + * + * @return the libraries used in the platform + */ + public String[] getLibraries() { + return libraries; + } + + /** + * Returns true if the platform is a placeholder. + * + * @return true if the platform is a placeholder + */ + public boolean isPlaceholder() { + return placeholder; + } + + /** + * Returns true if the platforms match (os and flavor). 
+ * + * @param system the platform to compare it to + * @return true if the platforms match + */ + public boolean matches(Platform system) { + if (!osPrefix.equals(system.osPrefix)) { + return false; + } + // if the system machine has a GPU + if (system.flavor.startsWith("cu")) { + // system flavor doesn't contain mkl, but MXNet has: cu110mkl + return "".equals(flavor) + || "cpu".equals(flavor) + || "mkl".equals(flavor) + || flavor.startsWith(system.flavor); + } + // system without a "cu" flavor: only the CPU flavors (empty, cpu, mkl) match + return "".equals(flavor) || "cpu".equals(flavor) || "mkl".equals(flavor); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Progress.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Progress.java new file mode 100644 index 000000000000..e6d46dffc48e --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Progress.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util; + +/** An interface that allows tracking the progress of a task. */ +public interface Progress { + + /** + * Resets the progress tracking indicators, and sets the message and max to the given values. 
+ * + * <p>This default implementation delegates to {@code reset(message, max, null)}. + * + * @param message the message to be shown + * @param max the max value that the progress tracking indicator can take + */ + default void reset(String message, long max) { + reset(message, max, null); + } + + /** + * Resets the progress tracking indicators, and sets the message and max to the given values. + * + * @param message the message to be shown + * @param max the max value that the progress tracking indicator can take + * @param trailingMessage the trailing message to be shown + */ + void reset(String message, long max, String trailingMessage); + + /** + * Starts tracking the progress of the progress tracking indicators at the given initial value. + * + * @param initialProgress the initial value of the progress + */ + void start(long initialProgress); + + /** Updates the tracking indicators to indicate that the task is complete. */ + void end(); + + /** + * Increments the progress tracking indicator by the given value. + * + * @param increment the value to increment the progress by + */ + void increment(long increment); + + /** + * Updates the progress tracking indicator to the given value. + * + * <p>This default implementation delegates to {@code update(progress, null)}. + * + * @param progress the value of the progress tracking indicator + */ + default void update(long progress) { + update(progress, null); + } + + /** + * Updates the progress tracking indicator to the given value, and displays the optional + * message. 
+ * + * @param progress the value of the progress tracking indicator + * @param message the optional message; the one-argument overload passes {@code null} + */ + void update(long progress, String message); +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/StandardCapabilities.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/StandardCapabilities.java new file mode 100644 index 000000000000..ce5cc80149ff --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/StandardCapabilities.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util; + +/** Constant definitions for the standard capability. 
+ */ +public final class StandardCapabilities { + + public static final String CUDA = "CUDA"; + public static final String CUDNN = "CUDNN"; + public static final String MKL = "MKL"; + public static final String MKLDNN = "MKLDNN"; + public static final String OPENMP = "OPENMP"; + + /** Not instantiable; this class only holds constants. */ + private StandardCapabilities() {} +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Utils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Utils.java new file mode 100644 index 000000000000..35a3cc5c7a82 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Utils.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.util; + +import java.io.ByteArrayOutputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.Scanner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** A class containing utility methods. */ +public final class Utils { + + private static final int BUFF_SIZE = 81920; + private static final String ENGINE_CACHE_DIR = "ENGINE_CACHE_DIR"; + private static final String MXNET_CACHE_DIR = "MXNET_CACHE_DIR"; + + private Utils() {} + + /** + * Returns the index of the first occurrence of the specified element in {@code array}, or -1 if + * the array is {@code null} or does not contain the element. + * + * @param array the input array + * @param value the element to search for; {@code null} matches the first {@code null} slot + * @param <T> the array type + * @return the index of the first occurrence of the specified element in {@code array}, or -1 if + * the array is {@code null} or does not contain the element + */ + public static <T> int indexOf(T[] array, T value) { + if (array != null) { + if (value == null) { + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + return i; + } + } + } else { + for (int i = 0; i < array.length; ++i) { + if (value.equals(array[i])) { + return i; + } + } + } + } + return -1; + } + + /** + * Returns {@code true} if the {@code array} contains the specified element. 
+ * + * @param array the input array + * @param value the element whose presence in {@code array} is to be tested + * @param the array type + * @return {@code true} if this list contains the specified element + */ + public static boolean contains(T[] array, T value) { + return indexOf(array, value) >= 0; + } + + /** + * Adds padding chars to specified StringBuilder. + * + * @param sb the StringBuilder to append + * @param c the padding char + * @param count the number characters to be added + */ + public static void pad(StringBuilder sb, char c, int count) { + for (int i = 0; i < count; ++i) { + sb.append(c); + } + } + + /** + * Deletes an entire directory and ignore all errors. + * + * @param dir the directory to be removed + */ + public static void deleteQuietly(Path dir) { + try { + Files.walk(dir) + .sorted(Comparator.reverseOrder()) + .forEach( + path -> { + try { + Files.deleteIfExists(path); + } catch (IOException ignore) { + // ignore + } + }); + } catch (IOException ignore) { + // ignore + } + } + + /** + * Renames a file to a target file and ignore error if target already exists. + * + * @param source the path to the file to move + * @param target the path to the target file + * @throws IOException if move file failed + */ + public static void moveQuietly(Path source, Path target) throws IOException { + try { + Files.move(source, target, StandardCopyOption.ATOMIC_MOVE); + } catch (IOException e) { + if (!Files.exists(target)) { + throw e; + } + } + } + + /** + * Reads {@code is} as UTF-8 string. + * + * @param is the InputStream to be read + * @return a UTF-8 encoded string + * @throws IOException if IO error occurs + */ + public static String toString(InputStream is) throws IOException { + return null; + } + + /** + * Reads {@code is} as byte array. 
+ * + * @param is the InputStream to be read + * @return a byte array + * @throws IOException if IO error occurs + */ + public static byte[] toByteArray(InputStream is) throws IOException { + + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(BUFF_SIZE)) { + byte[] buf = new byte[BUFF_SIZE]; + int read; + while ((read = is.read(buf)) != -1) { + bos.write(buf, 0, read); + } + return bos.toByteArray(); + } + } + + /** + * Reads all lines from a file. + * + * @param file the file to be read + * @return all lines in the file + * @throws IOException if read file failed + */ + public static List readLines(Path file) throws IOException { + return readLines(file, false); + } + + /** + * Reads all lines from a file. + * + * @param file the file to be read + * @param trim true if you want to trim the line and exclude empty lines + * @return all lines in the file + * @throws IOException if read file failed + */ + public static List readLines(Path file, boolean trim) throws IOException { + if (Files.notExists(file)) { + return Collections.emptyList(); + } + try (InputStream is = Files.newInputStream(file)) { + return readLines(is, trim); + } + } + + /** + * Reads all lines from the specified InputStream. + * + * @param is the InputStream to read + * @return all lines from the input + */ + public static List readLines(InputStream is) { + return readLines(is, false); + } + + /** + * Reads all lines from the specified InputStream. 
+ * + * @param is the InputStream to read + * @param trim true if you want to trim the line and exclude empty lines + * @return all lines from the input + */ + public static List readLines(InputStream is, boolean trim) { + List list = new ArrayList<>(); + try (Scanner scanner = + new Scanner(is, StandardCharsets.UTF_8.name()).useDelimiter("\\n|\\r\\n")) { + while (scanner.hasNext()) { + String line = scanner.next(); + if (trim) { + line = line.trim(); + if (line.isEmpty()) { + continue; + } + } + list.add(line); + } + } + return list; + } + + /** + * Converts a List of Number to float array. + * + * @param list the list to be converted + * @return a float array + */ + public static float[] toFloatArray(List list) { + float[] ret = new float[list.size()]; + int idx = 0; + for (Number n : list) { + ret[idx++] = n.floatValue(); + } + return ret; + } + + /** + * Gets the current epoch number. + * + * @param modelDir the path to the directory where the model files are stored + * @param modelName the name of the model + * @return the current epoch number, if no epoch number found, return null + * @throws IOException if an I/O error occurs + * @throws FileNotFoundException if no matched parameter file with epoch number is found + */ + public static int getCurrentEpoch(Path modelDir, String modelName) throws IOException { + final Pattern pattern = Pattern.compile(Pattern.quote(modelName) + "-(\\d{4}).params"); + List checkpoints = + Files.walk(modelDir, 1) + .map( + p -> { + Matcher m = pattern.matcher(p.toFile().getName()); + if (m.matches()) { + return Integer.parseInt(m.group(1)); + } + return null; + }) + .filter(Objects::nonNull) + .sorted() + .collect(Collectors.toList()); + if (checkpoints.isEmpty()) { + throw new FileNotFoundException( + String.format( + "No matched params file is found in directory: {} for model {}", + modelDir.toAbsolutePath(), + modelName)); + } + return checkpoints.get(checkpoints.size() - 1); + } + + /** + * Utility function to help debug nan 
values in parameters and their gradients. + * + * @param parameters the list of parameters to check + * @param checkGradient whether to check parameter value or its gradient value + * @param logger the logger to log the result + */ + // TODO + // public static void checkParameterValues( + // Pairlist parameters, boolean checkGradient, Logger logger) { + // + // } + + /** + * Utility function to help summarize the values in an {@link NDArray}. + * + * @param array the {@link NDArray} to be summarized + * @param logger the logger to log the result + * @param prefix the prefix or name to be displayed + */ + // TODO + // public static void checkNDArrayValues(NDArray array, Logger logger, String prefix) { + // + // } + + /** + * Utility function to get Engine specific cache directory. + * + * @param engine the engine name + * @return the engine cache directory with {@code engine} resolved against it + */ + public static Path getEngineCacheDir(String engine) { + return getEngineCacheDir().resolve(engine); + } + + /** + * Utility function to get Engine cache directory. + * + * <p>{@code ENGINE_CACHE_DIR} is looked up as a system property first and as an environment + * variable second; when neither is set, {@link #getCacheDir()} is used. + * + * @return the engine cache directory + */ + public static Path getEngineCacheDir() { + String cacheDir = System.getProperty(ENGINE_CACHE_DIR); + if (cacheDir == null || cacheDir.isEmpty()) { + cacheDir = System.getenv(ENGINE_CACHE_DIR); + if (cacheDir == null || cacheDir.isEmpty()) { + return getCacheDir(); + } + } + return Paths.get(cacheDir); + } + + /** + * Utility function to get the cache directory. 
+ * + * <p>{@code MXNET_CACHE_DIR} is looked up as a system property first and as an environment + * variable second. When neither is set, {@code mxnet.java_package} under the user home is + * returned, or under {@code java.io.tmpdir} if the user home is not writable. + * + * @return the cache directory + */ + public static Path getCacheDir() { + String cacheDir = System.getProperty(MXNET_CACHE_DIR); + if (cacheDir == null || cacheDir.isEmpty()) { + cacheDir = System.getenv(MXNET_CACHE_DIR); + if (cacheDir == null || cacheDir.isEmpty()) { + Path dir = Paths.get(System.getProperty("user.home")); + if (!Files.isWritable(dir)) { + dir = Paths.get(System.getProperty("java.io.tmpdir")); + } + return dir.resolve("mxnet.java_package"); + } + } + return Paths.get(cacheDir); + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/ZipUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/ZipUtils.java new file mode 100644 index 000000000000..44214a90acbf --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/ZipUtils.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.util; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; +import java.util.zip.ZipOutputStream; + +/** Utilities for working with zip files. */ +public final class ZipUtils { + + private ZipUtils() {} + + /** + * Unzips an input stream to a given path. + * + * <p>Entries whose name contains {@code ..} or that would resolve to a location outside of + * {@code dest} (for example an absolute entry name) are rejected. + * + * @param is the input stream to unzip + * @param dest the path to store the unzipped files + * @throws IOException for failures to unzip the input stream and create files in the dest path + */ + public static void unzip(InputStream is, Path dest) throws IOException { + ZipInputStream zis = new ZipInputStream(is); + Path destRoot = dest.toAbsolutePath().normalize(); + ZipEntry entry; + while ((entry = zis.getNextEntry()) != null) { + String name = entry.getName(); + if (name.contains("..")) { + throw new IOException("Malicious zip entry: " + name); + } + Path file = destRoot.resolve(name).normalize(); + // resolve() returns an absolute entry name unchanged, which would escape dest + if (!file.startsWith(destRoot)) { + throw new IOException("Malicious zip entry: " + name); + } + if (entry.isDirectory()) { + Files.createDirectories(file); + } else { + Path parentFile = file.getParent(); + if (parentFile == null) { + throw new AssertionError( + "Parent path should never be null: " + file.toString()); + } + Files.createDirectories(parentFile); + Files.copy(zis, file, StandardCopyOption.REPLACE_EXISTING); + } + } + } + + /** + * Zips an input directory to a given file. + * + * @param src the input directory to zip + * @param dest the path to store the zipped files + * @param includeFolderName if include the source directory name in the zip entry + * @throws IOException for failures to zip the input directory + */ + public static void zip(Path src, Path dest, boolean includeFolderName) throws IOException { + try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(dest))) { + Path root = includeFolderName ? 
src.getParent() : src; + if (root == null) { + throw new AssertionError("Parent folder should not be null."); + } + addToZip(root, src, zos); + } + } + + /** Adds {@code file}, and recursively its children, to the zip under a name relative to {@code root}. */ + private static void addToZip(Path root, Path file, ZipOutputStream zos) throws IOException { + Path relative = root.relativize(file); + // zip entry names use '/' regardless of the platform's path separator + String name = relative.toString().replace(file.getFileSystem().getSeparator(), "/"); + if (Files.isDirectory(file)) { + if (!name.isEmpty()) { + ZipEntry entry = new ZipEntry(name + '/'); + zos.putNextEntry(entry); + } + File[] files = file.toFile().listFiles(); + if (files != null) { + for (File f : files) { + addToZip(root, f.toPath(), zos); + } + } + } else if (Files.isRegularFile(file)) { + if (name.isEmpty()) { + name = file.toFile().getName(); + } + ZipEntry entry = new ZipEntry(name); + zos.putNextEntry(entry); + Files.copy(file, zos); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaLibrary.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaLibrary.java new file mode 100644 index 000000000000..abd5ba2834d4 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaLibrary.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.util.cuda; + +import com.sun.jna.Library; + +/** + * {@code CudaLibrary} contains methods mapping to CUDA runtime API. + * + *

see: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html + */ +public interface CudaLibrary extends Library { + + // return codes that CudaUtils#getGpuCount compares against the result of cudaGetDeviceCount + int INITIALIZATION_ERROR = 3; + int INSUFFICIENT_DRIVER = 35; + int ERROR_NO_DEVICE = 100; + int ERROR_NOT_PERMITTED = 800; + + /** + * Gets the number of devices with compute capability greater or equal to 1.0 that are available + * for execution. + * + * @param deviceCount the returned device count + * @return CUDA runtime API error code + */ + int cudaGetDeviceCount(int[] deviceCount); + + /** + * Returns the version number of the installed CUDA Runtime. + * + * @param runtimeVersion output buffer of runtime version number + * @return CUDA runtime API error code + */ + int cudaRuntimeGetVersion(int[] runtimeVersion); + + /** + * Gets the integer value of the attribute {@code attr} on device. + * + * @param pi the returned device attribute value + * @param attr the device attribute to query + * @param device the GPU device to query + * @return CUDA runtime API error code + */ + int cudaDeviceGetAttribute(int[] pi, int attr, int device); + + /** + * Gets free and total device memory. + * + * @param free the returned free memory in bytes + * @param total the returned total memory in bytes + * @return CUDA runtime API error code + */ + int cudaMemGetInfo(long[] free, long[] total); + + /** + * Sets the device to be used for GPU executions. + * + * @param device the GPU device to use + * @return CUDA runtime API error code + */ + int cudaSetDevice(int device); + + /** + * Gets which device is currently being used. + * + * @param device the returned current device + * @return CUDA runtime API error code + */ + int cudaGetDevice(int[] device); + + /** + * Returns the description string for an error code. 
+ * + * @param code the CUDA error code to convert to string + * @return the description string for an error code + */ + String cudaGetErrorString(int code); +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaUtils.java new file mode 100644 index 000000000000..d79967d946ac --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaUtils.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util.cuda; + +import com.sun.jna.Native; +import java.io.File; +import java.lang.management.MemoryUsage; +import java.util.regex.Pattern; +import org.apache.mxnet.engine.Device; +import org.apache.mxnet.exception.JnaCallException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A class containing CUDA utility methods. */ +public final class CudaUtils { + + private static final Logger logger = LoggerFactory.getLogger(CudaUtils.class); + + // null when loadLibrary() could not load a cudart library + private static final CudaLibrary LIB = loadLibrary(); + + // cached GPU count; -1 means it has not been detected yet + private static int gpuCount = -1; + + private CudaUtils() {} + + /** + * Gets whether at least one CUDA GPU is available in the system. 
+ * + * @return {@code true} if {@link #getGpuCount()} reports at least one GPU + */ + public static boolean hasCuda() { + return getGpuCount() > 0; + } + + /** + * Returns the number of GPUs available in the system. + * + * <p>Returns 0 when no cuda library is loaded or when the device count query fails. + * + * @return the number of GPUs available in the system + */ + public static int getGpuCount() { + + // reuse the result of an earlier detection + if (gpuCount != -1) { + return gpuCount; + } + + try { + validateLibrary(); + } catch (IllegalStateException e) { + // no cuda library loaded: report 0 without caching it + return 0; + } + int[] count = new int[1]; + int result = LIB.cudaGetDeviceCount(count); + switch (result) { + case 0: + gpuCount = count[0]; + return gpuCount; + case CudaLibrary.ERROR_NO_DEVICE: + logger.debug( + "No GPU device found: {} ({})", LIB.cudaGetErrorString(result), result); + gpuCount = 0; + return gpuCount; + case CudaLibrary.INITIALIZATION_ERROR: + case CudaLibrary.INSUFFICIENT_DRIVER: + case CudaLibrary.ERROR_NOT_PERMITTED: + default: + logger.warn( + "Failed to detect GPU count: {} ({})", + LIB.cudaGetErrorString(result), + result); + gpuCount = 0; + return gpuCount; + } + } + + /** + * Returns the version of CUDA runtime. + * + * @return the version of CUDA runtime + */ + public static int getCudaVersion() { + validateLibrary(); + int[] version = new int[1]; + int result = LIB.cudaRuntimeGetVersion(version); + checkCall(result); + return version[0]; + } + + /** + * Returns the version string of CUDA runtime. + * + * @return the version string of CUDA runtime + */ + public static String getCudaVersionString() { + validateLibrary(); + int version = getCudaVersion(); + // e.g. 11020 yields major 11 and minor 2, which is returned as "112" + int major = version / 1000; + int minor = (version / 10) % 10; + return String.valueOf(major) + minor; + } + + /** + * Returns the CUDA compute capability. 
+ * + * @param device the GPU {@link Device} to retrieve + * @return the CUDA compute capability, major digit followed by minor digit + */ + public static String getComputeCapability(int device) { + validateLibrary(); + int attrComputeCapabilityMajor = 75; + int attrComputeCapabilityMinor = 76; + + int[] major = new int[1]; + int[] minor = new int[1]; + checkCall(LIB.cudaDeviceGetAttribute(major, attrComputeCapabilityMajor, device)); + checkCall(LIB.cudaDeviceGetAttribute(minor, attrComputeCapabilityMinor, device)); + + // concatenate the two numbers; adding them would turn e.g. 7 and 5 into "12" + return String.valueOf(major[0]) + minor[0]; + } + + /** + * Returns the {@link MemoryUsage} of the specified GPU device. + * + * @param device the GPU {@link Device} to retrieve + * @return the {@link MemoryUsage} of the specified GPU device + * @throws IllegalArgumentException if {@link Device} is not GPU device or does not exist + */ + public static MemoryUsage getGpuMemory(Device device) { + if (!Device.Type.GPU.equals(device.getDeviceType())) { + throw new IllegalArgumentException("Only GPU device is allowed."); + } + + validateLibrary("No GPU device detected."); + + // remember the active device so it can be restored after the query + int[] currentDevice = new int[1]; + checkCall(LIB.cudaGetDevice(currentDevice)); + checkCall(LIB.cudaSetDevice(device.getDeviceId())); + + long[] free = new long[1]; + long[] total = new long[1]; + + checkCall(LIB.cudaMemGetInfo(free, total)); + checkCall(LIB.cudaSetDevice(currentDevice[0])); + + long committed = total[0] - free[0]; + return new MemoryUsage(-1, committed, committed, total[0]); + } + + /** Loads the cudart library, returning {@code null} if it cannot be found. */ + private static CudaLibrary loadLibrary() { + try { + if (System.getProperty("os.name").startsWith("Win")) { + String path = System.getenv("PATH"); + if (path == null) { + return null; + } + // the runtime is shipped as cudart64_<version>.dll on Windows + Pattern p = Pattern.compile("cudart64_\\d+\\.dll"); + String cudaPath = System.getenv("CUDA_PATH"); + + String[] searchPath = getPathArray(path, cudaPath); + + for (String item : searchPath) { + File dir = new File(item); + File[] files = dir.listFiles(n -> p.matcher(n.getName()).matches()); + if (files != null && files.length > 0) { + 
String fileName = files[0].getName(); + // strip the 4-character ".dll" extension to get the library name + String cudaRT = fileName.substring(0, fileName.length() - 4); + logger.debug("Found cudart: {}", files[0].getAbsolutePath()); + return Native.load(cudaRT, CudaLibrary.class); + } + } + logger.debug("No cudart library found in path."); + return null; + } + return Native.load("cudart", CudaLibrary.class); + } catch (UnsatisfiedLinkError e) { + logger.debug("cudart library not found."); + logger.trace("", e); + return null; + } + } + + /** Builds the list of directories to search: {@code cudaPath\bin} first (if set), then PATH. */ + private static String[] getPathArray(String path, String cudaPath) { + if (cudaPath == null) { + return path.split(";"); + } else { + // split the combined search path on ';' (the receiver is the path, not the separator) + return String.format("%s\\bin\\;%s", cudaPath, path).split(";"); + } + } + + /** Throws a {@link JnaCallException} if a CUDA runtime call returned a non-zero code. */ + private static void checkCall(int ret) { + validateLibrary(); + if (ret != 0) { + throw new JnaCallException( + String.format( + "CUDA API call failed: %s (%d)", LIB.cudaGetErrorString(ret), ret)); + } + } + + private static void validateLibrary() { + if (LIB == null) { + throw new IllegalStateException("No cuda library is loaded."); + } + } + + private static void validateLibrary(String msg) { + if (msg == null) { + validateLibrary(); + } else if (LIB == null) { + throw new IllegalStateException(msg); + } + } +} diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/package-info.java new file mode 100644 index 000000000000..e34e298f0170 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. */ +package org.apache.mxnet.util.cuda; diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/package-info.java new file mode 100644 index 000000000000..59260e2c4b78 --- /dev/null +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the Java front-end implementation for Apache MXNet. 
*/ +package org.apache.mxnet.util; diff --git a/java-package/mxnet-engine/src/main/jna/mapping.properties b/java-package/mxnet-engine/src/main/jna/mapping.properties new file mode 100644 index 000000000000..8a770ccc2efb --- /dev/null +++ b/java-package/mxnet-engine/src/main/jna/mapping.properties @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +MXNDArraySaveRawBytes.out_buf = PointerByReference +MXNDArraySave.args = PointerArray +MXInvokeCachedOp.inputs = Pointer +MXInvokeCachedOpEx.inputs = Pointer +MXInvokeCachedOpEX.inputs = Pointer +MXImperativeInvoke.inputs = PointerArray +MXImperativeInvokeEx.inputs = PointerArray +MXImperativeInvokeEx.param_keys = StringArray +MXImperativeInvokeEx.param_vals = StringArray +MXKVStoreInit.vals = PointerArray +MXKVStoreInitEx.vals = PointerArray +MXKVStorePush.vals = PointerArray +MXKVStorePushEx.vals = PointerArray +MXKVStorePull.vals = PointerArray +MXKVStorePullEx.vals = PointerArray +MXKVStorePushPullEx.vals = PointerArray +MXKVStorePushPullEx.outs = PointerArray +MXAutogradBackwardEx.output_handles = PointerArray +MXAutogradBackwardEx.ograd_handles = PointerArray +MXAutogradBackward.output_handles = PointerArray diff --git a/java-package/native/build.gradle b/java-package/native/build.gradle new file mode 100644 index 000000000000..c3c91e070965 --- /dev/null +++ b/java-package/native/build.gradle @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+plugins {
+    id 'maven-publish'
+    id 'signing'
+}
+
+group = "org.apache.mxnet"
+
+def VERSION = "2.0.0"
+// Builds started with -Prelease or -Pstaging get a plain version; all others are snapshots.
+boolean isRelease = project.hasProperty("release") || project.hasProperty("staging")
+version = VERSION + (isRelease ? "" : "-SNAPSHOT")
+
+// Mirrors libmxnet.* from the top-level build directory into this project's build directory.
+task syncBuiltMxnetLib(type: Sync) {
+    from "${rootProject.projectDir.parent}/build"
+    into "${project.buildDir}/mxnet/native/lib"
+    include "libmxnet.*"
+}
+
+// Create mxnet native library jar without classifier
+jar {
+    def placeholder = "${project.buildDir}/placeholder"
+    // this line is to enforce gradle to build the jar
+    // otherwise it don't generate the placeholder jar at times
+    // when there is no java code inside src/main
+    outputs.dir file("build/libs")
+    doFirst {
+        def versionName = project.version
+        if (!isRelease) {
+            versionName += String.format("-%s", new Date().format('yyyyMMdd'))
+        }
+        def dir = file("${placeholder}/native/lib")
+        dir.mkdirs()
+        def propFile = file("${placeholder}/native/lib/mxnet.properties")
+        propFile.text = "placeholder=true\nversion=${versionName}\n"
+    }
+
+    from placeholder
+}
+
+java {
+    withJavadocJar()
+    withSourcesJar()
+}
+
+project.tasks.withType(GenerateModuleMetadata) {
+    enabled = false
+}
+
+signing {
+    required(project.hasProperty("staging") || project.hasProperty("snapshot"))
+    def signingKey = findProperty("signingKey")
+    def signingPassword = findProperty("signingPassword")
+    sign publishing.publications
+}
+
+// Registers the "<flavor>-<os>Jar" task for the default flavor ("mkl") and the current OS.
+// The registration happens while this task is being configured.
+task buildLocalLibraryJarDefault() {
+    def flavor = "mkl"
+    def osName = getOsName()
+    buildLocalLibraryJar(flavor, osName)
+}
+
+/**
+ * Creates and returns a Jar task named "<flavorName>-<osName>Jar".
+ *
+ * When the task runs it copies the native library for the flavor/OS, regenerates
+ * native/lib/mxnet.properties (version, classifier and the sorted list of library files),
+ * and downloads the license files into META-INF before packaging the directory.
+ */
+def buildLocalLibraryJar(flavorName, osName) {
+    def BINARY_ROOT = "${project.buildDir}"
+    tasks.create(name: "${flavorName}-${osName}Jar", type: Jar) {
+        doFirst {
+            copyMxnetNativeLib(flavorName, osName)
+            // Remove stale files first so they are not listed under "libraries=" below.
+            def propFile = file("${BINARY_ROOT}/${flavorName}/${osName}/native/lib/mxnet.properties")
+            propFile.delete()
+            def dsStore = file("${BINARY_ROOT}/${flavorName}/${osName}/native/lib/.DS_Store")
+            dsStore.delete()
+
+            def versionName = String.format("${version}-%s", new Date().format('yyyyMMdd'))
+            def dir = file("${BINARY_ROOT}/${flavorName}/${osName}/native/lib")
+            def sb = new StringBuilder()
+            sb.append("version=${versionName}\nclassifier=${flavorName}-${osName}-x86_64\nlibraries=")
+            def first = true
+            for (String name : dir.list().sort()) {
+                if (first) {
+                    first = false;
+                } else {
+                    sb.append(',')
+                }
+                sb.append(name)
+            }
+            propFile.text = sb.toString()
+            def metaInf = new File("${BINARY_ROOT}/${flavorName}/${osName}/META-INF")
+            metaInf.mkdirs()
+            def licenseFile = new File(metaInf, "LICENSE")
+            licenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/LICENSE").text
+
+            def binaryLicenseFile = new File(metaInf, "LICENSE.binary.dependencies")
+            binaryLicenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/dependencies/LICENSE.binary.dependencies").text
+
+            from file("src/main/resources")
+        }
+        from file("${BINARY_ROOT}/${flavorName}/${osName}")
+        archiveClassifier = "${osName}-x86_64"
+
+        manifest {
+            attributes("Automatic-Module-Name": "org.apache.mxnet.mxnet_native_${flavorName}_${osName}")
+        }
+    }
+    return tasks["${flavorName}-${osName}Jar"]
+}
+
+/**
+ * Maps the JVM "os.name" system property to the short platform name used in task names and
+ * output directories: "win", "osx" or "linux". Any other system yields the raw property value.
+ */
+def getOsName() {
+    def os_name = System.properties['os.name']
+    // "os.name" is reported with capitals (e.g. "Windows 10"), so compare case-insensitively;
+    // a case-sensitive check against 'windows' never matches.
+    def normalized = os_name.toLowerCase(Locale.ROOT)
+    if (normalized.contains('windows')) {
+        return "win"
+    } else if (normalized.contains('mac os x')) {
+        return "osx"
+    } else if (normalized.contains('linux')) {
+        return "linux"
+    } else {
+        return os_name
+    }
+}
+
+//def BINARY_ROOT = "${project.buildDir}/download"
+//def flavorNames = file(BINARY_ROOT).list() ?: []
+//flavorNames.each { flavor ->
+//
+//    def platformNames = file("${BINARY_ROOT}/${flavor}").list() ?: []
+//
+//    def artifactsNames = []
+//
+//    platformNames.each { osName ->
+//        tasks.create(name: "${flavor}-${osName}Jar", type: Jar) {
+//            doFirst {
+//                def propFile = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/mxnet.properties")
+//                propFile.delete()
+// def dsStore = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/.DS_Store") +// dsStore.delete() +// +// def versionName = String.format("${version}-%s", new Date().format('yyyyMMdd')) +// def dir = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib") +// def sb = new StringBuilder() +// sb.append("version=${versionName}\nclassifier=${flavor}-${osName}-x86_64\nlibraries=") +// def first = true +// for (String name : dir.list().sort()) { +// if (first) { +// first = false; +// } else { +// sb.append(',') +// } +// sb.append(name) +// } +// propFile.text = sb.toString() +// def metaInf = new File("${BINARY_ROOT}/${flavor}/${osName}/META-INF") +// metaInf.mkdirs() +// def licenseFile = new File(metaInf, "LICENSE") +// licenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/LICENSE").text +// +// def binaryLicenseFile = new File(metaInf, "LICENSE.binary.dependencies") +// binaryLicenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/dependencies/LICENSE.binary.dependencies").text +// +// from file("src/main/resources") +// } +// from file("${BINARY_ROOT}/${flavor}/${osName}") +// archiveClassifier = "${osName}-x86_64" +// +// manifest { +// attributes("Automatic-Module-Name": "org.apache.mxnet.mxnet_native_${flavor}_${osName}") +// } +// } +// artifactsNames.add(tasks["${flavor}-${osName}Jar"]) +// } + + // Only publish if the project directory equals the current directory + // This means that publishing from the main project does not publish the native jars + // and the native jars have to be published separately + // TODO publish info +// if (project.getProjectDir().toString() == System.getProperty("user.dir")) { +// publishing.publications.create("${flavor}", MavenPublication) { +// artifactId "mxnet-native-${flavor}" +// from components.java +// artifacts = artifactsNames +// artifact jar +// artifact javadocJar +// artifact sourcesJar +// pom { +// name = "DJL release for Apache 
MXNet native binaries"
+//                description = "Deep Java Library (DJL) provided Apache MXNet native library binary distribution"
+//                url = "http://www.djl.ai/mxnet/native"
+//                packaging = "jar"
+//
+//                licenses {
+//                    license {
+//                        name = 'The Apache License, Version 2.0'
+//                        url = 'https://www.apache.org/licenses/LICENSE-2.0'
+//                    }
+//                }
+//
+//                scm {
+//                    connection = "scm:git:git@github.com:deepjavalibrary/djl.git"
+//                    developerConnection = "scm:git:git@github.com:deepjavalibrary/djl.git"
+//                    url = "https://github.com/deepjavalibrary/djl"
+//                    tag = "HEAD"
+//                }
+//
+//                developers {
+//                    developer {
+//                        name = "DJL.AI Team"
+//                        email = "djl-dev@amazon.com"
+//                        organization = "Amazon AI"
+//                        organizationUrl = "https://amazon.com"
+//                    }
+//                }
+//            }
+//        }
+//    }
+//}
+
+//publishing.repositories {
+//    maven {
+//        if (project.hasProperty("snapshot")) {
+//            name = "snapshot"
+//            url = "https://oss.sonatype.org/content/repositories/snapshots/"
+//            credentials {
+//                username = findProperty("ossrhUsername")
+//                password = findProperty("ossrhPassword")
+//            }
+//        } else if (project.hasProperty("staging")) {
+//            name = "staging"
+//            url = "https://oss.sonatype.org/service/local/staging/deploy/maven2/"
+//            credentials {
+//                username = findProperty("ossrhUsername")
+//                password = findProperty("ossrhPassword")
+//            }
+//        } else {
+//            name = "local"
+//            url = "build/repo"
+//        }
+//    }
+//}
+
+import java.util.zip.GZIPInputStream
+// Copies the "mkl" native library for the current OS; the copy runs while the task is configured.
+task copyMxnetNativeLibDefault() {
+    copyMxnetNativeLib("mkl", getOsName())
+}
+
+/**
+ * Copies the native libraries for the given flavor and OS from the top-level build directory
+ * into "<buildDir>/<flavorName>/<osName>/native/lib".
+ *
+ * The set of files is selected by the key "<osName>-<flavorName>"; unknown keys fall through
+ * to the default branch.
+ */
+def copyMxnetNativeLib(flavorName, osName) {
+    // TODO: only mkl considered here
+    copy {
+        from("${rootProject.projectDir.parent}/build")
+        into("${project.buildDir}/${flavorName}/${osName}/native/lib")
+        // TODO: load map (flavor-os -> lib name) from configure file
+        switch (osName + "-" + flavorName) {
+            case "osx-mkl":
+                include "libmxnet.dylib"
+                break
+            // "common" spelled as in the "linux-common" key below (was the typo "win-commen").
+            case "win-common":
+                include "libgcc_s_seh-1.dll"
+                include "libgfortran-3.dll"
+                include "libopenblas.dll"
+                include "libquadmath-0.dll"
+                break
+            case "win-mkl":
+                include "mxnet.dll"
+                break
+            case "linux-common":
+                include "libgfortran.so.4"
+                include "libgomp.so.1"
+                include "libopenblas.so.0"
+                include "libquadmath.so.0"
+                break
+            case "linux-mkl":
+            case "linux-cu102mkl":
+            case "linux-cu110mkl":
+                include "libmxnet.so"
+                break
+            default:
+                include ""
+        }
+    }
+}
+
+//task downloadMxnetNativeLib() {
+//    doLast {
+//        def url = "https://publish.djl.ai/mxnet-${VERSION}"
+//        def files = [
+////                "linux/common/libgfortran.so.4.gz": "mkl/linux/native/lib/libgfortran.so.4",
+//                "linux/common/libgomp.so.1.gz" : "mkl/linux/native/lib/libgomp.so.1",
+//                "linux/common/libopenblas.so.0.gz": "mkl/linux/native/lib/libopenblas.so.0",
+//                "linux/common/libquadmath.so.0.gz": "mkl/linux/native/lib/libquadmath.so.0",
+//                "linux/mkl/libmxnet.so.gz" : "mkl/linux/native/lib/libmxnet.so",
+//                "linux/cu102mkl/libmxnet.so.gz" : "cu102mkl/linux/native/lib/libmxnet.so",
+//                "linux/cu110mkl/libmxnet.so.gz" : "cu110mkl/linux/native/lib/libmxnet.so",
+//                "osx/mkl/libmxnet.dylib.gz" : "mkl/osx/native/lib/libmxnet.dylib",
+//                "win/common/libgcc_s_seh-1.dll.gz": "mkl/win/native/lib/libgcc_s_seh-1.dll",
+//                "win/common/libgfortran-3.dll.gz" : "mkl/win/native/lib/libgfortran-3.dll",
+//                "win/common/libopenblas.dll.gz" : "mkl/win/native/lib/libopenblas.dll",
+//                "win/common/libquadmath-0.dll.gz" : "mkl/win/native/lib/libquadmath-0.dll",
+//                "win/mkl/libmxnet.dll.gz" : "mkl/win/native/lib/mxnet.dll"
+//        ]
+//
+//        files.each { entry ->
+//            project.logger.lifecycle("Downloading ${url}/${entry.key}")
+//            def file = new File("${BINARY_ROOT}/${entry.value}")
+//            file.getParentFile().mkdirs()
+//            new URL("${url}/${entry.key}").withInputStream { i -> file.withOutputStream { it << new GZIPInputStream(i) } }
+//        }
+//
+//        copy {
+//            from("${BINARY_ROOT}/mkl/linux/native/lib") {
+//                exclude '**/libmxnet.so'
+//            }
+//            into("${BINARY_ROOT}/cu102mkl/linux/native/lib")
+//        }
+//        copy {
+//
from("${BINARY_ROOT}/mkl/linux/native/lib") { +// exclude '**/libmxnet.so' +// } +// into("${BINARY_ROOT}/cu110mkl/linux/native/lib") +// } +// +// new File("${BINARY_ROOT}/auto").mkdirs() +// } +//} diff --git a/java-package/native/src/main/resources/META-INF/.gitkeep b/java-package/native/src/main/resources/META-INF/.gitkeep new file mode 100644 index 000000000000..d216be4ddc94 --- /dev/null +++ b/java-package/native/src/main/resources/META-INF/.gitkeep @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. \ No newline at end of file diff --git a/java-package/scripts/ci/start_integration_test.sh b/java-package/scripts/ci/start_integration_test.sh new file mode 100644 index 000000000000..d4c55f0c91e8 --- /dev/null +++ b/java-package/scripts/ci/start_integration_test.sh @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +cd /work/mxnet +python3 build.py -p ubuntu_cpu /work/mxnet/ci/docker/runtime_functions.sh java_package_integration_test \ No newline at end of file diff --git a/java-package/settings.gradle b/java-package/settings.gradle new file mode 100644 index 000000000000..54c778f418a4 --- /dev/null +++ b/java-package/settings.gradle @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +rootProject.name = 'org.apache.mxnet' + +include 'mxnet-engine' +include 'native' +include 'jnarator' +include 'integration' +include 'example' + diff --git a/java-package/tools/conf/checkstyle.xml b/java-package/tools/conf/checkstyle.xml new file mode 100644 index 000000000000..e4e4fe4a721d --- /dev/null +++ b/java-package/tools/conf/checkstyle.xml @@ -0,0 +1,521 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java-package/tools/conf/findbugs-exclude.xml b/java-package/tools/conf/findbugs-exclude.xml new file mode 100644 index 000000000000..649726547dff --- /dev/null +++ b/java-package/tools/conf/findbugs-exclude.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java-package/tools/conf/licenseHeader.java b/java-package/tools/conf/licenseHeader.java new file mode 100644 index 000000000000..3e7c6c26f557 --- /dev/null +++ b/java-package/tools/conf/licenseHeader.java @@ -0,0 +1,16 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ diff --git a/java-package/tools/conf/pmd.xml b/java-package/tools/conf/pmd.xml new file mode 100644 index 000000000000..d8c0e86a66cc --- /dev/null +++ b/java-package/tools/conf/pmd.xml @@ -0,0 +1,466 @@ + + + + + Java Rule in PMD + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java-package/tools/conf/suppressions.xml b/java-package/tools/conf/suppressions.xml new file mode 100644 index 000000000000..263bc2e31749 --- /dev/null +++ b/java-package/tools/conf/suppressions.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + diff --git a/java-package/tools/gradle/check.gradle b/java-package/tools/gradle/check.gradle new file mode 100644 index 000000000000..e38eb34eac33 --- /dev/null +++ b/java-package/tools/gradle/check.gradle @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Static-analysis and coverage configuration: SpotBugs, PMD, Checkstyle and JaCoCo.
+
+// SpotBugs is applied only when the build runs on a JDK older than 11.
+if (JavaVersion.current() < JavaVersion.VERSION_11) {
+    apply plugin: "com.github.spotbugs"
+    spotbugs {
+        excludeFilter = file("${rootProject.projectDir}/tools/conf/findbugs-exclude.xml")
+        ignoreFailures = false
+        spotbugsTest.enabled = true
+    }
+    // HTML reports only, for both main and test sources.
+    spotbugsMain {
+        reports {
+            xml.enabled false
+            html.enabled true
+        }
+    }
+    spotbugsTest {
+        reports {
+            xml.enabled false
+            html.enabled true
+        }
+    }
+}
+
+// PMD: rules come solely from tools/conf/pmd.xml; test sources are not analysed.
+apply plugin: "pmd"
+pmd {
+    ignoreFailures = false
+    pmdTest.enabled = false
+    ruleSets = [] // workaround pmd gradle plugin bug
+    ruleSetFiles = files("${rootProject.projectDir}/tools/conf/pmd.xml")
+}
+tasks.withType(Pmd){
+    reports{
+        xml.enabled=true
+        html.enabled=true
+    }
+}
+
+// Checkstyle: configuration, suppressions and the expected license header live in tools/conf.
+apply plugin: "checkstyle"
+checkstyle {
+    toolVersion = "8.26"
+    ignoreFailures = false
+    checkstyleTest.enabled = true
+    configProperties = [
+            "checkstyle.suppressions.file" : file("${rootProject.projectDir}/tools/conf/suppressions.xml"),
+            "checkstyle.licenseHeader.file" : file("${rootProject.projectDir}/tools/conf/licenseHeader.java")
+    ]
+    configFile = file("${rootProject.projectDir}/tools/conf/checkstyle.xml")
+}
+checkstyleMain {
+    classpath += configurations.compileClasspath
+}
+tasks.withType(Checkstyle) {
+    reports {
+        xml.enabled false
+        html.enabled true
+    }
+}
+
+// JaCoCo: XML coverage report, no CSV.
+apply plugin: 'jacoco'
+jacoco {
+    toolVersion = "0.8.5"
+}
+jacocoTestReport {
+    reports {
+        xml.enabled true
+        csv.enabled false
+    }
+}
+
+// Produce the coverage report after every test run; "build" also generates javadoc.
+test.finalizedBy jacocoTestReport
+build.dependsOn javadoc
diff --git a/java-package/tools/gradle/jacoco.gradle
b/java-package/tools/gradle/jacoco.gradle
new file mode 100644
index 000000000000..010ae474e960
--- /dev/null
+++ b/java-package/tools/gradle/jacoco.gradle
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+apply plugin: 'jacoco'
+
+// Subprojects that take part in coverage: everything with a src/test/java directory,
+// except :jnarator.
+def jacocoProjects = subprojects.findAll {
+    if ([":jnarator"].contains(it.getPath())) {
+        return false
+    }
+    return new File(it.projectDir, "src/test/java").exists()
+}
+
+// Merges the execution data of all participating subprojects' Test tasks.
+task jacocoMergeTestData(type: JacocoMerge) {
+    jacocoProjects.each { p ->
+        dependsOn(p.test, p.jacocoTestReport)
+        executionData p.tasks.withType(Test)
+    }
+}
+
+// Project paths whose sources are left out of the aggregate report and verification.
+// The paths must match the names declared in settings.gradle, where the example
+// project is included as 'example' (singular).
+def exclusions = [":example", ":integration"]
+
+task jacocoRootReport(type: JacocoReport) {
+    dependsOn jacocoMergeTestData
+    description = 'Generates an aggregate report from all subprojects'
+
+    jacocoProjects.each { p ->
+        if (!exclusions.contains(p.getPath())) {
+            additionalSourceDirs.from files((Set) p.sourceSets.main.allJava.srcDirs)
+            sourceDirectories.from files((Set) p.sourceSets.main.allSource.srcDirs)
+            classDirectories.from files((FileCollection) p.sourceSets.main.output)
+            additionalClassDirs((FileCollection) p.sourceSets.main.output)
+        }
+    }
+    // Only execution-data files that actually exist are fed into the report.
+    executionData.from = files(jacocoProjects.jacocoTestReport.executionData).filter { f -> f.exists() }
+
+    reports {
+        xml.enabled = true
+        html.enabled = true
+    }
+}
+
+// Fails the build when aggregate coverage is below the minimum:
+// 0.70 when the JVM system property "nightly" is true, 0.65 otherwise.
+task jacocoRootVerification(type: JacocoCoverageVerification) {
+    dependsOn jacocoMergeTestData
+
+    jacocoProjects.each { p ->
+        if (!exclusions.contains(p.getPath())) {
+            additionalSourceDirs.from files((Set) p.sourceSets.main.allJava.srcDirs)
+            sourceDirectories.from files((Set) p.sourceSets.main.allSource.srcDirs)
+            classDirectories.from files((FileCollection) p.sourceSets.main.output)
+            additionalClassDirs((FileCollection) p.sourceSets.main.output)
+        }
+    }
+    executionData.from = files(jacocoProjects.jacocoTestReport.executionData).filter { f -> f.exists() }
+
+    violationRules {
+        rule {
+            limit {
+                if (Boolean.getBoolean("nightly")) {
+                    minimum = 0.70
+                } else {
+                    minimum = 0.65
+                }
+            }
+        }
+    }
+}
diff --git a/java-package/tools/gradle/java-formatter.gradle b/java-package/tools/gradle/java-formatter.gradle
new file mode 100644
index 000000000000..2cbe96d5ee3c
--- /dev/null
+++ b/java-package/tools/gradle/java-formatter.gradle
@@ -0,0 +1,85 @@
+buildscript {
+    repositories {
+        maven {
+            url "https://plugins.gradle.org/m2/"
+        }
+    }
+    dependencies {
+        classpath 'com.google.googlejavaformat:google-java-format:1.6'
+    }
+}
+
+apply plugin: JavaFormatterPlugin
+
+check.dependsOn verifyJava
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import com.google.googlejavaformat.java.Formatter
+import com.google.googlejavaformat.java.ImportOrderer
+import com.google.googlejavaformat.java.JavaFormatterOptions
+import com.google.googlejavaformat.java.Main
+import com.google.googlejavaformat.java.RemoveUnusedImports
+
+/**
+ * Adds two tasks to a project:
+ *   formatJava - rewrites every .java source file in place with google-java-format;
+ *   verifyJava - fails the build if any .java source file differs from its formatted form.
+ * Files whose path contains "generated-src" are skipped by both tasks.
+ */
+class JavaFormatterPlugin implements Plugin<Project> {
+    void apply(Project project) {
+        project.task('formatJava') {
+            doLast {
+                Main formatter = new Main(new PrintWriter(System.out, true), new PrintWriter(System.err, true), System.in)
+                for (item in project.sourceSets) {
+                    for (File file : item.getAllSource()) {
+                        if (!file.getName().endsWith(".java") || file.getAbsolutePath().contains("generated-src")) {
+                            continue
+                        }
+                        // "-a" selects AOSP style, "-i" edits the file in place;
+                        // a non-zero result is treated as failure.
+                        if (formatter.format("-a", "-i", file.getAbsolutePath()) != 0) {
+                            throw new GradleException("Format java failed: " + file.getAbsolutePath())
+                        }
+                    }
+                }
+            }
+        }
+
+        project.task('verifyJava') {
+            doLast {
+                def options = JavaFormatterOptions.builder().style(JavaFormatterOptions.Style.AOSP).build()
+                Formatter formatter = new Formatter(options)
+                for (item in project.sourceSets) {
+                    for (File file : item.getAllSource()) {
+                        if (!file.getName().endsWith(".java") || file.getAbsolutePath().contains("generated-src")) {
+                            continue
+                        }
+
+                        // Format in memory, then compare with what is on disk.
+                        String src = new String(file.bytes, "UTF-8")
+                        String formatted = formatter.formatSource(src)
+                        formatted = RemoveUnusedImports.removeUnusedImports(formatted, RemoveUnusedImports.JavadocOnlyImports.KEEP)
+                        formatted = ImportOrderer.reorderImports(formatted);
+                        if (!src.equals(formatted)) {
+                            throw new GradleException("File not formatted: " + file.getAbsolutePath()
+                                    + System.lineSeparator()
+                                    + "In order to reformat your code, run './gradlew formatJava' (or './gradlew fJ' for short)"
+                                    + System.lineSeparator()
+                                    + "See https://github.com/deepjavalibrary/djl/blob/master/docs/development/development_guideline.md#coding-conventions for more details")
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/java-package/tools/gradle/stats.gradle b/java-package/tools/gradle/stats.gradle
new file mode 100644
index 000000000000..ae827f5f64b2
--- /dev/null
+++ b/java-package/tools/gradle/stats.gradle
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// java.time is not among the packages a Gradle/Groovy script imports by default,
+// so Instant and Duration have to be imported explicitly.
+import java.time.Duration
+import java.time.Instant
+
+// Duration in seconds -> project name, ordered longest first.
+def testsResults = new TreeMap<>(Comparator.reverseOrder())
+
+// Stamp every task with its start time just before it executes.
+gradle.taskGraph.beforeTask { Task task ->
+    task.ext.setProperty("startTime", Instant.now())
+}
+
+// Record how long each "test" task that actually did work took.
+gradle.taskGraph.afterTask { Task task, TaskState state ->
+    if (task.name.equals("test") && state.didWork) {
+        long duration = Duration.between(task.ext.startTime, Instant.now()).toSeconds()
+        testsResults.put(duration, task.project.name);
+    }
+}
+
+// After a "build" invocation, print the slowest test tasks.
+gradle.buildFinished {
+    if (gradle.startParameter.taskNames.contains("build") && !testsResults.isEmpty()) {
+        int count = 0;
+        println "========== Test duration =========="
+        for (Map.Entry entry : testsResults.entrySet()) {
+            // Stop once six entries have been printed.
+            if (count++ > 5) {
+                break;
+            }
+            println "\t${entry.value}:\t${entry.key}s"
+        }
+    }
+}
diff --git a/src/initialize.cc b/src/initialize.cc
index 9ef51219609f..1319cfc20f0b 100644
--- a/src/initialize.cc
+++ b/src/initialize.cc
@@ -375,9 +375,12 @@ std::shared_ptr HANDLER_NAME( \
 }), \
 [](auto f) { signal(SIGNAL, f); });
 
+// TODO(cspchen): avoid jvm exit with code 139. https://github.com/apache/incubator-mxnet/pull/20461
+#if !SKIP_SIGNAL_HANDLER_REGISTRATION
 SIGNAL_HANDLER(SIGSEGV, SIGSEGVHandler, true);
 SIGNAL_HANDLER(SIGFPE, SIGFPEHandler, false);
 SIGNAL_HANDLER(SIGBUS, SIGBUSHandler, false);
+#endif
 
 #endif