Skip to content

Commit

Permalink
Change the default matmul precision in JAX to highest precision.
Browse files Browse the repository at this point in the history
On CPU and GPU, this change has no effect.

On TPU, this PR changes the default matmul algorithm from a fast, low-quality algorithm to a slower, high-precision algorithm that uses multiple passes. Many users have reported the low-quality-by-default behavior to be a footgun, especially when performing non-neural network computations.

The old behavior can be restored either by passing an explicit Precision option to operators such as `dot`, or by changing the default precision, e.g.,
jax.config.update('jax_default_matmul_precision', 'fastest')

#7010

PiperOrigin-RevId: 395549544
  • Loading branch information
hawkinsp authored and jax authors committed Sep 8, 2021
1 parent e7e5140 commit 52485d6
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def _update_disable_jit_thread_local(val):
default_matmul_precision = config.define_enum_state(
name='jax_default_matmul_precision',
enum_values=['bfloat16', 'tensorfloat32', 'float32'],
default=None,
default='float32',
help=('Control the default matmul and conv precision for 32bit inputs.\n\n'

'Some platforms, like TPU, offer configurable precision levels for '
Expand Down
32 changes: 17 additions & 15 deletions jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,13 @@ def axis_index(axis_name):
return axis_index_p.bind(axis_name=axis_name)


def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ())):
def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ()),
precision=None):
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
return pdot_p.bind(x, y, axis_name=axis_name,
pos_contract=pos_contract, pos_batch=pos_batch)
pos_contract=pos_contract, pos_batch=pos_batch,
precision=lax._canonicalize_precision(precision))


def xeinsum(spec: str, x, y):
Expand Down Expand Up @@ -1346,12 +1348,12 @@ def _vmap_process_axis_index(self, frame):
core.axis_substitution_rules[pdot_p] = partial(_subst_all_names_in_param, 'axis_name')

@pdot_p.def_impl
def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch, precision):
if axis_name: raise NameError(f"unbound axis name: {axis_name[0]}")
return lax.dot_general(x, y, [pos_contract, pos_batch])
return lax.dot_general(x, y, [pos_contract, pos_batch], precision=precision)

@pdot_p.def_abstract_eval
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
# TODO(frostig,mattjj,jekbradbury): check inputs have given axis names?
if not len(set(axis_name)) == len(axis_name): raise ValueError
pos_aval = lax.dot_general_p.abstract_eval(
Expand All @@ -1364,7 +1366,7 @@ def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
return pos_aval.update(named_shape=named_shape)

def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name,
pos_contract, pos_batch):
pos_contract, pos_batch, precision):
x, y = vals_in
x_dim, y_dim = dims_in
x_pos_contract, y_pos_contract = pos_contract
Expand All @@ -1376,24 +1378,24 @@ def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name,
remaining_axis_names = tuple(n for n in axis_name if n != frame.name)
out = pdot_p.bind(x, y, axis_name=remaining_axis_names,
pos_contract=[x_pos_contract, y_pos_contract],
pos_batch=[x_pos_batch, y_pos_batch])
pos_batch=[x_pos_batch, y_pos_batch], precision=precision)
return out, None
batching.collective_rules[pdot_p] = _pdot_vmap_collective_rule

def _pdot_vmap_batching_rule(vals_in, dims_in, *, axis_name, pos_contract,
pos_batch):
pos_batch, precision):
x, y = vals_in
(pos_contract, pos_batch), result_batch_dim = lax._dot_general_batch_dim_nums(
(x.ndim, y.ndim), dims_in, [pos_contract, pos_batch])
out = pdot_p.bind(x, y, axis_name=axis_name, pos_contract=pos_contract,
pos_batch=pos_batch)
pos_batch=pos_batch, precision=precision)
return out, result_batch_dim
batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule

def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
axis_env, platform):
precision, axis_env, platform):
local_out = lax._dot_general_translation_rule(
c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None,
c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=precision,
preferred_element_type=None)
if axis_name:
out_tup = xla.parallel_translations[psum_p](
Expand All @@ -1405,15 +1407,15 @@ def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
return out
xla.parallel_translations[pdot_p] = _pdot_translation_rule

def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch):
def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch, precision):
# TODO: avals with names, call pbroadcast with axis_name
return lax._dot_general_transpose_lhs(
g, y, dimension_numbers=[pos_contract, pos_batch], precision=None,
g, y, dimension_numbers=[pos_contract, pos_batch], precision=precision,
preferred_element_type=None)
def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch):
def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch, precision):
# TODO: avals with names, call pbroadcast with axis_name
return lax._dot_general_transpose_rhs(
g, x, dimension_numbers=[pos_contract, pos_batch], precision=None,
g, x, dimension_numbers=[pos_contract, pos_batch], precision=precision,
preferred_element_type=None)
ad.defbilinear(pdot_p, _pdot_transpose_lhs, _pdot_transpose_rhs)

Expand Down
1 change: 1 addition & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2783,6 +2783,7 @@ def f_jit(x):
for f in [f_jit, f_cond]:
precision = config.jax_default_matmul_precision
try:
FLAGS.jax_default_matmul_precision = None
num_traces = 0
x = jnp.zeros((2, 2))
f(x)
Expand Down
4 changes: 3 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5291,7 +5291,9 @@ def testPrecision(self):
ones_3d = np.ones((2, 2, 2))
HIGHEST = lax.Precision.HIGHEST

jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d)
jtu.assert_dot_precision(lax.Precision.HIGHEST, jnp.dot, ones_1d, ones_1d)
with jax.default_matmul_precision('tensorfloat32'):
jtu.assert_dot_precision(lax.Precision.HIGH, jnp.dot, ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.dot, precision=HIGHEST),
Expand Down

0 comments on commit 52485d6

Please sign in to comment.