diff --git a/.pylintrc b/.pylintrc index c200eb87..bae3fa4a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -7,12 +7,11 @@ [MASTER] -# Add files or directories to the blacklist. They should be base names, not -# paths. +# Files or directories to be skipped. They should be base names, not paths. ignore=third_party -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. ignore-patterns= # Pickle collected data for later comparisons. @@ -29,11 +28,6 @@ jobs=4 # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code -extension-pkg-whitelist= - [MESSAGES CONTROL] @@ -56,12 +50,16 @@ confidence= # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=apply-builtin, +disable=abstract-method, + apply-builtin, + arguments-differ, attribute-defined-outside-init, backtick, bad-option-value, + basestring-builtin, buffer-builtin, c-extension-no-member, + consider-using-enumerate, cmp-builtin, cmp-method, coerce-builtin, @@ -78,10 +76,11 @@ disable=apply-builtin, global-statement, hex-method, idiv-method, - implicit-str-concat-in-sequence, + implicit-str-concat, import-error, import-self, import-star-module-level, + inconsistent-return-statements, input-builtin, intern-builtin, invalid-str-codec, @@ -90,6 +89,8 @@ disable=apply-builtin, long-builtin, long-suffix, map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, metaclass-assignment, next-method-called, next-method-defined, @@ -98,7 +99,9 @@ disable=apply-builtin, no-else-continue, no-else-raise, no-else-return, + no-init, # added no-member, + no-name-in-module, no-self-use, nonzero-method, oct-method, @@ -128,20 +131,23 @@ disable=apply-builtin, too-many-branches, too-many-instance-attributes, too-many-locals, + too-many-nested-blocks, too-many-public-methods, too-many-return-statements, too-many-statements, trailing-newlines, unichr-builtin, unicode-builtin, + unnecessary-lambda-assignment, + unnecessary-pass, unpacking-in-except, use-dict-literal, use-list-literal, useless-else-on-loop, useless-suppression, using-cmp-argument, - xrange-builtin, wrong-import-order, + xrange-builtin, zip-builtin-not-iterating, @@ -152,12 +158,6 @@ disable=apply-builtin, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -278,12 +278,6 @@ ignore-long-lines=(?x)( # else. single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check= - # Maximum number of lines in a module max-module-lines=99999 @@ -306,6 +300,13 @@ expected-line-ending-format= notes=TODO +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + [VARIABLES] # Tells whether we should check for unused import in __init__ files. @@ -332,7 +333,7 @@ redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functool # Logging modules to check that the string format arguments are in logging # function parameter format -logging-modules=logging,absl.logging,tensorflow.google.logging +logging-modules=logging,absl.logging,tensorflow.io.logging [SIMILARITIES] diff --git a/distrax/_src/bijectors/sigmoid.py b/distrax/_src/bijectors/sigmoid.py index 5e54fc27..de739a3a 100644 --- a/distrax/_src/bijectors/sigmoid.py +++ b/distrax/_src/bijectors/sigmoid.py @@ -73,11 +73,11 @@ def same_as(self, other: base.Bijector) -> bool: return type(other) is Sigmoid # pylint: disable=unidiomatic-typecheck -def _more_stable_sigmoid(x): +def _more_stable_sigmoid(x: Array) -> Array: """Where extremely negatively saturated, approximate sigmoid with exp(x).""" return jnp.where(x < -9, jnp.exp(x), jax.nn.sigmoid(x)) -def _more_stable_softplus(x): +def _more_stable_softplus(x: Array) -> Array: """Where extremely saturated, approximate softplus with log1p(exp(x)).""" return jnp.where(x < -9, jnp.log1p(jnp.exp(x)), jax.nn.softplus(x)) diff --git a/distrax/_src/distributions/transformed.py b/distrax/_src/distributions/transformed.py index 22e5dfa4..4d915760 100644 --- a/distrax/_src/distributions/transformed.py +++ b/distrax/_src/distributions/transformed.py @@ -234,7 +234,7 @@ def _kl_divergence_transformed_transformed( *unused_args, input_hint: Optional[Array] = None, **unused_kwargs, - ) -> Array: +) -> Array: """Obtains the KL divergence between two Transformed distributions. This computes the KL divergence between two Transformed distributions with the @@ -254,9 +254,9 @@ def _kl_divergence_transformed_transformed( Args: dist1: A Transformed distribution. dist2: A Transformed distribution. - input_hint: an example sample from the base distribution, used to trace - the `forward` method. If not specified, it is computed using a zero array - of the shape and dtype of a sample from the base distribution. + input_hint: an example sample from the base distribution, used to trace the + `forward` method. If not specified, it is computed using a zero array of + the shape and dtype of a sample from the base distribution. Returns: Batchwise `KL(dist1 || dist2)`. @@ -267,9 +267,9 @@ def _kl_divergence_transformed_transformed( """ if dist1.distribution.event_shape != dist2.distribution.event_shape: raise ValueError( - f'The two base distributions do not have the same event shape: ' - f'{dist1.distribution.event_shape} and ' - f'{dist2.distribution.event_shape}.') + f"The two base distributions do not have the same event shape: " + f"{dist1.distribution.event_shape} and " + f"{dist2.distribution.event_shape}.") bij1 = conversion.as_bijector(dist1.bijector) # conversion needed for TFP bij2 = conversion.as_bijector(dist2.bijector) @@ -283,10 +283,10 @@ def _kl_divergence_transformed_transformed( jaxpr_bij2 = jax.make_jaxpr(bij2.forward)(input_hint).jaxpr if str(jaxpr_bij1) != str(jaxpr_bij2): raise NotImplementedError( - f'The KL divergence cannot be obtained because it is not possible to ' - f'guarantee that the bijectors {dist1.bijector.name} and ' - f'{dist2.bijector.name} of the Transformed distributions are ' - f'equal. If possible, use the same instance of a Distrax bijector.') + f"The KL divergence cannot be obtained because it is not possible to " + f"guarantee that the bijectors {dist1.bijector.name} and " + f"{dist2.bijector.name} of the Transformed distributions are " + f"equal. If possible, use the same instance of a Distrax bijector.") return dist1.distribution.kl_divergence(dist2.distribution)