From b2041b7841cb915cca88795b02d796a455351e53 Mon Sep 17 00:00:00 2001 From: Chouffe Date: Mon, 15 Apr 2019 10:15:39 +0200 Subject: [PATCH] [docstring] improve docstring and indentation in `module.clj` --- .../src/org/apache/clojure_mxnet/module.clj | 544 +++++++++++------- .../src/org/apache/clojure_mxnet/util.clj | 2 +- 2 files changed, 345 insertions(+), 201 deletions(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj index aa5ce39f7a80..09f17e5d81f4 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj @@ -16,6 +16,7 @@ ;; (ns org.apache.clojure-mxnet.module + "Module API for Clojure package." (:refer-clojure :exclude [update symbol]) (:require [org.apache.clojure-mxnet.callback :as callback] [org.apache.clojure-mxnet.context :as context] @@ -31,18 +32,29 @@ (:import (org.apache.mxnet.module Module FitParams BaseModule) (org.apache.mxnet.io MXDataIter NDArrayIter) (org.apache.mxnet Initializer Optimizer NDArray DataBatch - Context EvalMetric Monitor Callback$Speedometer DataDesc))) + Context EvalMetric Monitor Callback$Speedometer + DataDesc))) (defn module - "Module is a basic module that wrap a symbol. - sym : Symbol definition. - map of options - :data-names - Input data names. - :label-names - Input label names - :contexts - Default is cpu(). - :workload-list - Default nil, indicating uniform workload. - :fixed-param-names Default nil, indicating no network parameters are fixed." - ([sym {:keys [data-names label-names contexts workload-list fixed-param-names] :as opts + "Module is a basic module that wrap a `symbol`. + `sym`: Symbol definition. + `opts-map` { + `data-names`: vector of strings - Default is [\"data\"] + Input data names + `label-names`: vector of strings - Default is [\"softmax_label\"] + Input label names + `contexts`: Context - Default is `context/cpu`. + `workload-list`: Default nil + Indicating uniform workload. + `fixed-param-names`: Default nil + Indicating no network parameters are fixed. + } + Ex: + (module sym) + (module sym {:data-names [\"data\"] + :label-names [\"linear_regression_label\"]}" + ([sym {:keys [data-names label-names contexts + workload-list fixed-param-names] :as opts :or {data-names ["data"] label-names ["softmax_label"] contexts [(context/default-context)]}}] @@ -80,31 +92,41 @@ (s/def ::force-rebind boolean?) (s/def ::shared-module #(instance? Module)) (s/def ::grad-req string?) -(s/def ::bind-opts (s/keys :req-un [::data-shapes] :opt-un [::label-shapes ::for-training ::inputs-need-grad - ::force-rebind ::shared-module ::grad-req])) +(s/def ::bind-opts + (s/keys :req-un [::data-shapes] + :opt-un [::label-shapes ::for-training ::inputs-need-grad + ::force-rebind ::shared-module ::grad-req])) (defn bind "Bind the symbols to construct executors. This is necessary before one can perform computation with the module. - mod : module - map of opts: - :data-shapes Typically is (provide-data-desc data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout - :label-shapes Typically is (provide-label-desc data-iter). map of :name :shape :dtype and :layout - :for-training Default is `true`. Whether the executors should be bind for training. - :inputs-need-grad Default is `false`. - Whether the gradients to the input data need to be computed. - Typically this is not needed. - But this might be needed when implementing composition of modules. - :force-rebind Default is `false`. - This function does nothing if the executors are already binded. - But with this `true`, the executors will be forced to rebind. - :shared-module Default is nil. This is used in bucketing. - When not `None`, the shared module essentially corresponds to - a different bucket -- a module with different symbol - but with the same sets of parameters - (e.g. unrolled RNNs with different lengths). " - [mod {:keys [data-shapes label-shapes for-training inputs-need-grad force-rebind - shared-module grad-req] :as opts + `mod`: module + `opts-map` { + `data-shapes`: map of `:name`, `:shape`, `:dtype`, and `:layout` + Typically is `(provide-data-desc data-iter)`.Data shape must be in the + form of `io/data-desc` + `label-shapes`: map of `:name` `:shape` `:dtype` and `:layout` + Typically is `(provide-label-desc data-iter)`. + `for-training`: boolean - Default is `true` + Whether the executors should be bind for training. + `inputs-need-grad`: boolean - Default is `false`. + Whether the gradients to the input data need to be computed. + Typically this is not needed. But this might be needed when + implementing composition of modules. + `force-rebind`: boolean - Default is `false`. + This function does nothing if the executors are already binded. But + with this `true`, the executors will be forced to rebind. + `shared-module`: Default is nil. + This is used in bucketing. When not `nil`, the shared module + essentially corresponds to a different bucket -- a module with + different symbol but with the same sets of parameters (e.g. unrolled + RNNs with different lengths). + } + Ex: + (bind {:data-shapes (mx-io/provide-data train-iter) + :label-shapes (mx-io/provide-label test-iter)})) " + [mod {:keys [data-shapes label-shapes for-training inputs-need-grad + force-rebind shared-module grad-req] :as opts :or {for-training true inputs-need-grad false force-rebind false @@ -129,24 +151,36 @@ (s/def ::aux-params map?) (s/def ::force-init boolean?) (s/def ::allow-extra boolean?) -(s/def ::init-params-opts (s/keys :opt-un [::initializer ::arg-params ::aux-params - ::force-init ::allow-extra])) +(s/def ::init-params-opts + (s/keys :opt-un [::initializer ::arg-params ::aux-params + ::force-init ::allow-extra])) (defn init-params - " Initialize the parameters and auxiliary states. - options map - :initializer - Called to initialize parameters if needed. - :arg-params - If not nil, should be a map of existing arg-params. - Initialization will be copied from that. - :auxParams - If not nil, should be a map of existing aux-params. - Initialization will be copied from that. - :allow-missing - If true, params could contain missing values, - and the initializer will be called to fill those missing params. - :force-init - If true, will force re-initialize even if already initialized. - :allow-extra - Whether allow extra parameters that are not needed by symbol. - If this is True, no error will be thrown when argParams or auxParams - contain extra parameters that is not needed by the executor." - ([mod {:keys [initializer arg-params aux-params allow-missing force-init allow-extra] :as opts + "Initialize the parameters and auxiliary states. + `opts-map` { + `initializer`: Initializer - Default is `uniform` + Called to initialize parameters if needed. + `arg-params`: map + If not nil, should be a map of existing arg-params. Initialization + will be copied from that. + `aux-params`: map + If not nil, should be a map of existing aux-params. Initialization + will be copied from that. + `allow-missing`: boolean - Default is `false` + If true, params could contain missing values, and the initializer will + be called to fill those missing params. + `force-init` boolean - Default is `false` + If true, will force re-initialize even if already initialized. + `allow-extra`: boolean - Default is `false` + Whether allow extra parameters that are not needed by symbol. + If this is `true`, no error will be thrown when `arg-params` or + `aux-params` contain extra parameters that is not needed by the + executor. + Ex: + (init-params {:initializer (initializer/xavier)}) + (init-params {:force-init true :allow-extra true})" + ([mod {:keys [initializer arg-params aux-params allow-missing force-init + allow-extra] :as opts :or {initializer (initializer/uniform 0.01) allow-missing false force-init false @@ -167,17 +201,23 @@ (s/def ::kvstore string?) (s/def ::reset-optimizer boolean?) (s/def ::force-init boolean?) -(s/def ::init-optimizer-opts (s/keys :opt-un [::optimizer ::kvstore ::reset-optimizer ::force-init])) +(s/def ::init-optimizer-opts + (s/keys :opt-un [::optimizer ::kvstore ::reset-optimizer ::force-init])) (defn init-optimizer - " Install and initialize optimizers. - - mod Module - - options map of - - kvstore - - reset-optimizer Default `True`, indicating whether we should set - `rescaleGrad` & `idx2name` for optimizer according to executorGroup - - force-init Default `False`, indicating whether we should force - re-initializing the optimizer in the case an optimizer is already installed." + "Install and initialize optimizers. + `mod`: Module + `opts-map` { + `kvstore`: string - Default is \"local\" + `optimizer`: Optimizer - Default is `sgd` + `reset-optimizer`: boolean - Default is `true` + Indicating whether we should set `rescaleGrad` & `idx2name` for + optimizer according to executorGroup. + `force-init`: boolean - Default is `false` + Indicating whether we should force re-initializing the optimizer + in the case an optimizer is already installed. + Ex: + (init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1})})" ([mod {:keys [kvstore optimizer reset-optimizer force-init] :as opts :or {kvstore "local" optimizer (optimizer/sgd) @@ -191,8 +231,10 @@ (defn forward "Forward computation. - data-batch - input data of form io/data-batch either map or DataBatch - is-train - Default is nil, which means `is_train` takes the value of `for_training`." + `data-batch`: Either map or DataBatch + Input data of form `io/data-batch`. + `is-train`: Default is nil + Which means `is_train` takes the value of `for_training`." ([mod data-batch is-train] (util/validate! ::mx-io/data-batch data-batch "Invalid data batch") (doto mod @@ -209,9 +251,9 @@ (defn backward "Backward computation. - out-grads - Gradient on the outputs to be propagated back. - This parameter is only needed when bind is called - on outputs that are not a loss function." + `out-grads`: collection of NDArrays + Gradient on the outputs to be propagated back. This parameter is only + needed when bind is called on outputs that are not a loss function." ([mod out-grads] (util/validate! ::out-grads out-grads "Invalid out-grads") (doto mod @@ -227,50 +269,48 @@ (.forwardBackward data-batch))) (defn outputs - " Get outputs of the previous forward computation. - In the case when data-parallelism is used, - the outputs will be collected from multiple devices. - The results will look like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`, - those `NDArray` might live on different devices." + "Get outputs of the previous forward computation. + In the case when data-parallelism is used, the outputs will be collected from + multiple devices. The results will look like + `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. + Those `NDArray`s might live on different devices." [mod] (->> (.getOutputs mod) (util/scala-vector->vec) (mapv util/scala-vector->vec))) (defn update - "Update parameters according to the installed optimizer and the gradients computed - in the previous forward-backward batch." + "Update parameters according to the installed optimizer and the gradients + computed in the previous forward-backward batch." [mod] (doto mod (.update))) (defn outputs-merged - " Get outputs of the previous forward computation. - return In the case when data-parallelism is used, - the outputs will be merged from multiple devices, - as they look like from a single executor. - The results will look like `[out1, out2]`" + "Get outputs of the previous forward computation. + In the case when data-parallelism is used, the outputs will be merged from + multiple devices, as they look like from a single executor. + The results will look like `[out1, out2]`." [mod] (->> (.getOutputsMerged mod) (util/scala-vector->vec))) (defn input-grads - " Get the gradients to the inputs, computed in the previous backward computation. - In the case when data-parallelism is used, - the outputs will be collected from multiple devices. - The results will look like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]` - those `NDArray` might live on different devices." + "Get the gradients to the inputs, computed in the previous backward computation. + In the case when data-parallelism is used, the outputs will be collected from + multiple devices. The results will look like + `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. + Those `NDArray`s might live on different devices." [mod] (->> (.getInputGrads mod) (util/scala-vector->vec) (mapv util/scala-vector->vec))) (defn input-grads-merged - " Get the gradients to the inputs, computed in the previous backward computation. - return In the case when data-parallelism is used, - the outputs will be merged from multiple devices, - as they look like from a single executor. - The results will look like `[grad1, grad2]`" + "Get the gradients to the inputs, computed in the previous backward computation. + In the case when data-parallelism is used, the outputs will be merged from + multiple devices, as they look like from a single executor. + The results will look like `[grad1, grad2]`." [mod] (->> (.getInputGradsMerged mod) (util/scala-vector->vec))) @@ -278,16 +318,25 @@ (s/def ::prefix string?) (s/def ::epoch int?) (s/def ::save-opt-states boolean?) -(s/def ::save-checkpoint-opts (s/keys :req-un [::prefix ::epoch] :opt-un [::save-opt-states ::save-checkpoint])) +(s/def ::save-checkpoint-opts + (s/keys :req-un [::prefix ::epoch] + :opt-un [::save-opt-states ::save-checkpoint])) (defn save-checkpoint - " Save current progress to checkpoint. - Use mx.callback.module_checkpoint as epoch_end_callback to save during training. - - mod Module - - opt-map with - :prefix The file prefix to checkpoint to - :epoch The current epoch number - :save-opt-states Whether to save optimizer states for continue training " + "Save current progress to checkpoint. + Use mx.callback.module_checkpoint as epoch_end_callback to save during + training. + `mod`: Module + `opts-map` { + `prefix`: string + The file prefix to checkpoint to + `epoch`: int + The current epoch number + `save-opt-states`: boolean - Default is `false` + Whether to save optimizer states for continue training + } + Ex: + (save-checkpoint {:prefix \"saved_model\" :epoch 0 :save-opt-states true})" ([mod {:keys [prefix epoch save-opt-states] :as opts :or {save-opt-states false}}] (util/validate! ::save-checkpoint-opts opts "Invalid save checkpoint opts") @@ -303,24 +352,34 @@ (s/def ::contexts (s/coll-of ::context :kind vector?)) (s/def ::workload-list (s/coll-of number? :kind vector?)) (s/def ::fixed-params-names (s/coll-of string? :kind vector?)) -(s/def ::load-checkpoint-opts (s/keys :req-un [::prefix ::epoch] - :opt-un [::load-optimizer-states ::data-names ::label-names - ::contexts ::workload-list ::fixed-param-names])) +(s/def ::load-checkpoint-opts + (s/keys :req-un [::prefix ::epoch] + :opt-un [::load-optimizer-states ::data-names ::label-names + ::contexts ::workload-list ::fixed-param-names])) (defn load-checkpoint "Create a model from previously saved checkpoint. - - opts map of - - prefix Path prefix of saved model files. You should have prefix-symbol.json, - prefix-xxxx.params, and optionally prefix-xxxx.states, - where xxxx is the epoch number. - - epoch Epoch to load. - - load-optimizer-states Whether to load optimizer states. - Checkpoint needs to have been made with save-optimizer-states=True - - dataNames Input data names. - - labelNames Input label names - - contexts Default is cpu(). - - workload-list Default nil, indicating uniform workload. - - fixed-param-names Default nil, indicating no network parameters are fixed." + `opts-map` { + `prefix`: string + Path prefix of saved model files. You should have prefix-symbol.json, + prefix-xxxx.params, and optionally prefix-xxxx.states, where xxxx is + the epoch number. + `epoch`: int + Epoch to load. + `load-optimizer-states`: boolean - Default is false + Whether to load optimizer states. Checkpoint needs to have been made + with `save-optimizer-states` = `true`. + `data-names`: vector of strings - Default is [\"data\"] + Input data names. + `label-names`: vector of strings - Default is [\"softmax_label\"] + Input label names. + `contexts`: Context - Default is `context/cpu` + `workload-list`: Default nil + Indicating uniform workload. + `fixed-param-names`: Default nil + Indicating no network parameters are fixed. + Ex: + (load-checkpoint {:prefix \"my-model\" :epoch 1 :load-optimizer-states true}" ([{:keys [prefix epoch load-optimizer-states data-names label-names contexts workload-list fixed-param-names] :as opts :or {load-optimizer-states false @@ -358,10 +417,10 @@ (util/scala-map->map (.auxParams mod))) (defn reshape - " Reshapes the module for new input shapes. - - mod module - - data-shapes Typically is `(provide-data data-iter) - - param label-shapes Typically is `(provide-label data-tier)`. " + "Reshapes the module for new input shapes. + `mod`: Module + `data-shapes`: Typically is `(provide-data data-iter)` + `label-shapes`: Typically is `(provide-label data-tier)`" ([mod data-shapes label-shapes] (util/validate! ::data-shapes data-shapes "Invalid data-shapes") (util/validate! (s/nilable ::label-shapes) label-shapes "Invalid label-shapes") @@ -376,28 +435,35 @@ ([mod data-shapes] (reshape mod data-shapes nil))) -(s/def ::set-param-opts (s/keys :opt-un [::arg-params ::aux-params ::allow-missing ::force-init ::allow-extra])) +(s/def ::set-param-opts + (s/keys :opt-un [::arg-params ::aux-params ::allow-missing + ::force-init ::allow-extra])) (defn get-params [mod] (.getParams mod)) (defn set-params - " Assign parameter and aux state values. - - mod module - - arg-params : map - map of name to value (`NDArray`) mapping. - - aux-params : map - map of name to value (`NDArray`) mapping. - - allow-missing : bool - If true, params could contain missing values, and the initializer will be - called to fill those missing params. - - force-init : bool - If true, will force re-initialize even if already initialized. - - allow-extra : bool - Whether allow extra parameters that are not needed by symbol. - If this is True, no error will be thrown when arg-params or aux-params - contain extra parameters that is not needed by the executor." - [mod {:keys [arg-params aux-params allow-missing force-init allow-extra] :as opts + "Assign parameters and aux state values. + `mod`: Module + `opts-map` { + `arg-params`: map - map of name to value (`NDArray`) mapping. + `aux-params`: map - map of name to value (`NDArray`) mapping. + `allow-missing`: boolean + If true, params could contain missing values, and the initializer will + be called to fill those missing params. + `force-init`: boolean - Default is `false` + If true, will force re-initialize even if already initialized. + `allow-extra`: boolean - Default is `false` + Whether allow extra parameters that are not needed by symbol. If this + is `true`, no error will be thrown when arg-params or aux-params + contain extra parameters that is not needed by the executor. + } + Ex: + (set-params mod + {:arg-params {\"fc_0_weight\" (ndarray/array [0.15 0.2 0.25 0.3] [2 2]) + :allow-missing true})" + [mod {:keys [arg-params aux-params allow-missing force-init + allow-extra] :as opts :or {allow-missing false force-init true allow-extra false}}] (util/validate! ::set-param-opts opts "Invalid set-params") (doto mod @@ -409,33 +475,32 @@ allow-extra))) (defn install-monitor - "Install monitor on all executors" + "Install monitor on all executors." [mod monitor] (doto mod (.installMonitor monitor))) (defn borrow-optimizer - "Borrow optimizer from a shared module. Used in bucketing, where exactly the same - optimizer (esp. kvstore) is used. - - mod module - - shared-module" + "Borrow optimizer from a shared module. Used in bucketing, where exactly the + same optimizer (esp. kvstore) is used. + `mod`: Module + `shared-module`" [mod shared-module] (doto mod (.borrowOptimizer shared-module))) (defn save-optimizer-states - "Save optimizer (updater) state to file - - mod module - - fname Path to output states file." + "Save optimizer (updater) state to file. + `mod`: Module + `fname`: string - Path to output states file." [mod fname] (doto mod (.saveOptimizerStates mod fname))) (defn load-optimizer-states - "Load optimizer (updater) state from file - - mod module - - fname Path to input states file. - " + "Load optimizer (updater) state from file. + `mod`: Module + `fname`: string - Path to input states file." [mod fname] (doto mod (.loadOptimzerStates fname))) @@ -444,10 +509,13 @@ (s/def ::labels (s/coll-of ::ndarray :kind vector?)) (defn update-metric - "Evaluate and accumulate evaluation metric on outputs of the last forward computation. - - mod module - - eval-metric - - labels" + "Evaluate and accumulate evaluation metric on outputs of the last forward + computation. + `mod`: module + `eval-metric`: EvalMetric + `labels`: collection of NDArrays + Ex: + (update-metric mod (eval-metric/mse) labels)" [mod eval-metric labels] (util/validate! ::eval-metric eval-metric "Invalid eval metric") (util/validate! ::labels labels "Invalid labels") @@ -458,18 +526,48 @@ (s/def ::validation-metric ::eval-metric) (s/def ::monitor #(instance? Monitor %)) (s/def ::batch-end-callback #(instance? Callback$Speedometer %)) -(s/def ::fit-params-opts (s/keys :opt-un [::eval-metric ::kvstore ::optimizer ::initializer - ::arg-params ::aux-params ::allow-missing ::force-rebind - ::force-init ::begin-epoch ::validation-metric ::monitor - ::batch-end-callback])) +(s/def ::fit-params-opts + (s/keys :opt-un [::eval-metric ::kvstore ::optimizer ::initializer + ::arg-params ::aux-params ::allow-missing ::force-rebind + ::force-init ::begin-epoch ::validation-metric ::monitor + ::batch-end-callback])) ;; callbacks are not supported for now (defn fit-params - "Fit Params" + "Initialize FitParams with provided parameters. + `eval-metric`: EvalMetric - Default is `accuracy` + `kvstore`: String - Default is \"local\" + `optimizer`: Optimizer - Default is `sgd` + `initializer`: Initializer - Default is `uniform` + Called to initialize parameters if needed. + `arg-params`: map + If not nil, should be a map of existing `arg-params`. Initialization + will be copied from that. + `aux-params`: map - + If not nil, should be a map of existing `aux-params`. Initialization + will be copied from that. + `allow-missing`: boolean - Default is `false` + If `true`, params could contain missing values, and the initializer will + be called to fill those missing params. + `force-rebind`: boolean - Default is `false` + This function does nothing if the executors are already binded. But with + this `true`, the executors will be forced to rebind. + `force-init`: boolean - Default is `false` + If `true`, will force re-initialize even if already initialized. + `begin-epoch`: int - Default is 0 + `validation-metric`: EvalMetric + `monitor`: Monitor + Ex: + (fit-params {:force-init true :force-rebind true :allow-missing true}) + (fit-params + {:batch-end-callback (callback/speedometer batch-size 100) + :initializer (initializer/xavier) + :optimizer (optimizer/sgd {:learning-rate 0.01}) + :eval-metric (eval-metric/mse)})" ([{:keys [eval-metric kvstore optimizer initializer arg-params aux-params - allow-missing force-rebind force-init begin-epoch validation-metric monitor - batch-end-callback] :as opts + allow-missing force-rebind force-init begin-epoch + validation-metric monitor batch-end-callback] :as opts :or {eval-metric (eval-metric/accuracy) kvstore "local" optimizer (optimizer/sgd) @@ -500,25 +598,36 @@ (s/def ::ndarray-iter #(instance? NDArrayIter %)) (s/def ::train-data (s/or :mx-iter ::mx-data-iter :ndarry-iter ::ndarray-iter)) (s/def ::eval-data ::train-data) -(s/def ::num-epoch int?) +(s/def ::num-epoch (s/and int? pos?)) (s/def ::fit-params #(instance? FitParams %)) -(s/def ::fit-options (s/keys :req-un [::train-data] :opt-un [::eval-data ::num-epoch ::fit-params])) +(s/def ::fit-options + (s/keys :req-un [::train-data] + :opt-un [::eval-data ::num-epoch ::fit-params])) ;;; High Level API (defn score - " Run prediction on `eval-data` and evaluate the performance according to `eval-metric`. - - mod module - - option map with - :eval-data : DataIter - :eval-metric : EvalMetric - :num-batch Number of batches to run. Default is `Integer.MAX_VALUE`, - indicating run until the `DataIter` finishes. - :batch-end-callback -not supported yet - :reset Default `True`, - indicating whether we should reset `eval-data` before starting evaluating. - :epoch Default 0. For compatibility, this will be passed to callbacks (if any). - During training, this will correspond to the training epoch number." + "Run prediction on `eval-data` and evaluate the performance according to + `eval-metric`. + `mod`: module + `opts-map` { + `eval-data`: DataIter + `eval-metric`: EvalMetric + `num-batch`: int - Default is `Integer.MAX_VALUE` + Number of batches to run. Indicating run until the `DataIter` + finishes. + `batch-end-callback`: not supported yet. + `reset`: boolean - Default is `true`, + Indicating whether we should reset `eval-data` before starting + evaluating. + `epoch`: int - Default is 0 + For compatibility, this will be passed to callbacks (if any). During + training, this will correspond to the training epoch number. + } + Ex: + (score mod {:eval-data data-iter :eval-metric (eval-metric/accuracy)}) + (score mod {:eval-data data-iter + :eval-metric (eval-metric/mse) :num-batch 10})" [mod {:keys [eval-data eval-metric num-batch reset epoch] :as opts :or {num-batch Integer/MAX_VALUE reset true @@ -537,15 +646,30 @@ (defn fit "Train the module parameters. - - mod module - - train-data (data-iterator) - - eval-data (data-iterator)If not nil, will be used as validation set and evaluate - the performance after each epoch. - - num-epoch Number of epochs to run training. - - f-params Extra parameters for training (See fit-params)." + `mod`: Module + `opts-map` { + `train-data`: DataIter + `eval-data`: DataIter + If not nil, will be used as validation set and evaluate the + performance after each epoch. + `num-epoch`: int + Number of epochs to run training. + `fit-params`: FitParams + Extra parameters for training (see fit-params). + } + Ex: + (fit {:train-data train-iter :eval-data test-iter :num-epoch 100) + (fit {:train-data train-iter + :eval-data test-iter + :num-epoch 5 + :fit-params + (fit-params {:batch-end-callback (callback/speedometer 128 100) + :initializer (initializer/xavier) + :optimizer (optimizer/sgd {:learning-rate 0.01}) + :eval-metric (eval-metric/mse)}))" [mod {:keys [train-data eval-data num-epoch fit-params] :as opts - `:or {num-epoch 1 - fit-params (new FitParams)}}] + :or {num-epoch 1 + fit-params (new FitParams)}}] (util/validate! ::fit-options opts "Invalid options for fit") (doto mod (.fit @@ -557,12 +681,13 @@ (s/def ::eval-data ::train-data) (s/def ::num-batch integer?) (s/def ::reset boolean?) -(s/def ::predict-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) +(s/def ::predict-opts + (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) (defn predict-batch - "Run the predication on a data batch - - mod module - - data-batch data-batch" + "Run the predication on a data batch. + `mod`: Module + `data-batch`: data-batch" [mod data-batch] (util/validate! ::mx-io/data-batch data-batch "Invalid data batch") (util/coerce-return (.predict mod (if (map? data-batch) @@ -571,41 +696,60 @@ (defn predict "Run prediction and collect the outputs. - - mod module - - option map with - - :eval-data - - :num-batch Default is -1, indicating running all the batches in the data iterator. - - :reset Default is `True`, indicating whether we should reset the data iter before start - doing prediction. - The return value will be a vector of NDArrays `[out1, out2, out3]`. - Where each element is concatenation of the outputs for all the mini-batches." + `mod`: Module + `opts-map` { + `eval-data`: DataIter + `num-batch` int - Default is `-1` + Indicating running all the batches in the data iterator. + `reset`: boolean - Default is `true` + Indicating whether we should reset the data iter before start doing + prediction. + } + returns: vector of NDArrays `[out1, out2, out3]` where each element is the + concatenation of the outputs for all the mini-batches. + Ex: + (predict mod {:eval-data test-iter}) + (predict mod {:eval-data test-iter :num-batch 10 :reset false})" [mod {:keys [eval-data num-batch reset] :as opts :or {num-batch -1 reset true}}] (util/validate! ::predict-opts opts "Invalid opts for predict") (util/scala-vector->vec (.predict mod eval-data (int num-batch) reset))) -(s/def ::predict-every-batch-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) +(s/def ::predict-every-batch-opts + (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) (defn predict-every-batch - " Run prediction and collect the outputs. - - module - - option map with - :eval-data - :num-batch Default is -1, indicating running all the batches in the data iterator. - :reset Default is `True`, indicating whether we should reset the data iter before start - doing prediction. - The return value will be a nested list like - [[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]` - This mode is useful because in some cases (e.g. bucketing), - the module does not necessarily produce the same number of outputs." + "Run prediction and collect the outputs. + `mod`: Module + `opts-map` { + `eval-data`: DataIter + `num-batch` int - Default is `-1` + Indicating running all the batches in the data iterator. + `reset` boolean - Default is `true` + Indicating whether we should reset the data iter before start doing + prediction. + } + returns: nested list like this + `[[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]` + + Note: This mode is useful because in some cases (e.g. bucketing), the module + does not necessarily produce the same number of outputs. + Ex: + (predict-every-batch mod {:eval-data test-iter})" [mod {:keys [eval-data num-batch reset] :as opts :or {num-batch -1 reset true}}] - (util/validate! ::predict-every-batch-opts opts "Invalid opts for predict-every-batch") - (mapv util/scala-vector->vec (util/scala-vector->vec (.predictEveryBatch mod eval-data (int num-batch) reset)))) - -(s/def ::score-opts (s/keys :req-un [::eval-data ::eval-metric] :opt-un [::num-batch ::reset ::epoch])) + (util/validate! ::predict-every-batch-opts + opts + "Invalid opts for predict-every-batch") + (mapv util/scala-vector->vec + (util/scala-vector->vec + (.predictEveryBatch mod eval-data (int num-batch) reset)))) + +(s/def ::score-opts + (s/keys :req-un [::eval-data ::eval-metric] + :opt-un [::num-batch ::reset ::epoch])) (defn exec-group [mod] (.execGroup mod)) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj index 7ee25d4dd25e..9dc6c8f88ddd 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj @@ -250,7 +250,7 @@ shape))) (defn map->scala-tuple-seq - "* Convert a map to a scala-Seq of scala-Tubple. + "* Convert a map to a scala-Seq of scala-Tuple. * Should also work if a seq of seq of 2 things passed. * Otherwise passed through unchanged." [map-or-tuple-seq]