Skip to content

Commit

Permalink
Update .pylintrc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 453697739
  • Loading branch information
hbq1 authored and DistraxDev committed Jun 8, 2022
1 parent 3a54902 commit b6bffed
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 38 deletions.
51 changes: 26 additions & 25 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

[MASTER]

# Add files or directories to the blacklist. They should be base names, not
# paths.
# Files or directories to be skipped. They should be base names, not paths.
ignore=third_party

# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=

# Pickle collected data for later comparisons.
Expand All @@ -29,11 +28,6 @@ jobs=4
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
extension-pkg-whitelist=


[MESSAGES CONTROL]

Expand All @@ -56,12 +50,16 @@ confidence=
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=apply-builtin,
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
Expand All @@ -78,10 +76,11 @@ disable=apply-builtin,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
implicit-str-concat,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
Expand All @@ -90,6 +89,8 @@ disable=apply-builtin,
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-function-docstring,
metaclass-assignment,
next-method-called,
next-method-defined,
Expand All @@ -98,7 +99,9 @@ disable=apply-builtin,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
Expand Down Expand Up @@ -128,20 +131,23 @@ disable=apply-builtin,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-lambda-assignment,
unnecessary-pass,
unpacking-in-except,
use-dict-literal,
use-list-literal,
useless-else-on-loop,
useless-suppression,
using-cmp-argument,
xrange-builtin,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,


Expand All @@ -152,12 +158,6 @@ disable=apply-builtin,
# mypackage.mymodule.MyReporterClass.
output-format=text

# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]". This option is deprecated
# and it will be removed in Pylint 2.0.
files-output=no

# Tells whether to display a full report or only the messages
reports=no

Expand Down Expand Up @@ -278,12 +278,6 @@ ignore-long-lines=(?x)(
# else.
single-line-if-stmt=yes

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=

# Maximum number of lines in a module
max-module-lines=99999

Expand All @@ -306,6 +300,13 @@ expected-line-ending-format=
notes=TODO


[STRING]

# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes


[VARIABLES]

# Tells whether we should check for unused import in __init__ files.
Expand All @@ -332,7 +333,7 @@ redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functool

# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.google.logging
logging-modules=logging,absl.logging,tensorflow.io.logging


[SIMILARITIES]
Expand Down
4 changes: 2 additions & 2 deletions distrax/_src/bijectors/sigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def same_as(self, other: base.Bijector) -> bool:
return type(other) is Sigmoid # pylint: disable=unidiomatic-typecheck


def _more_stable_sigmoid(x):
def _more_stable_sigmoid(x: Array) -> Array:
"""Where extremely negatively saturated, approximate sigmoid with exp(x)."""
return jnp.where(x < -9, jnp.exp(x), jax.nn.sigmoid(x))


def _more_stable_softplus(x):
def _more_stable_softplus(x: Array) -> Array:
"""Where extremely saturated, approximate softplus with log1p(exp(x))."""
return jnp.where(x < -9, jnp.log1p(jnp.exp(x)), jax.nn.softplus(x))
22 changes: 11 additions & 11 deletions distrax/_src/distributions/transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _kl_divergence_transformed_transformed(
*unused_args,
input_hint: Optional[Array] = None,
**unused_kwargs,
) -> Array:
) -> Array:
"""Obtains the KL divergence between two Transformed distributions.
This computes the KL divergence between two Transformed distributions with the
Expand All @@ -254,9 +254,9 @@ def _kl_divergence_transformed_transformed(
Args:
dist1: A Transformed distribution.
dist2: A Transformed distribution.
input_hint: an example sample from the base distribution, used to trace
the `forward` method. If not specified, it is computed using a zero array
of the shape and dtype of a sample from the base distribution.
input_hint: an example sample from the base distribution, used to trace the
`forward` method. If not specified, it is computed using a zero array of
the shape and dtype of a sample from the base distribution.
Returns:
Batchwise `KL(dist1 || dist2)`.
Expand All @@ -267,9 +267,9 @@ def _kl_divergence_transformed_transformed(
"""
if dist1.distribution.event_shape != dist2.distribution.event_shape:
raise ValueError(
f'The two base distributions do not have the same event shape: '
f'{dist1.distribution.event_shape} and '
f'{dist2.distribution.event_shape}.')
f"The two base distributions do not have the same event shape: "
f"{dist1.distribution.event_shape} and "
f"{dist2.distribution.event_shape}.")

bij1 = conversion.as_bijector(dist1.bijector) # conversion needed for TFP
bij2 = conversion.as_bijector(dist2.bijector)
Expand All @@ -283,10 +283,10 @@ def _kl_divergence_transformed_transformed(
jaxpr_bij2 = jax.make_jaxpr(bij2.forward)(input_hint).jaxpr
if str(jaxpr_bij1) != str(jaxpr_bij2):
raise NotImplementedError(
f'The KL divergence cannot be obtained because it is not possible to '
f'guarantee that the bijectors {dist1.bijector.name} and '
f'{dist2.bijector.name} of the Transformed distributions are '
f'equal. If possible, use the same instance of a Distrax bijector.')
f"The KL divergence cannot be obtained because it is not possible to "
f"guarantee that the bijectors {dist1.bijector.name} and "
f"{dist2.bijector.name} of the Transformed distributions are "
f"equal. If possible, use the same instance of a Distrax bijector.")

return dist1.distribution.kl_divergence(dist2.distribution)

Expand Down

0 comments on commit b6bffed

Please sign in to comment.