diff --git a/jax/_src/config.py b/jax/_src/config.py
index 2e7294aa2fe6..c63b1b90b0c3 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -568,7 +568,7 @@ def _update_disable_jit_thread_local(val):
 default_matmul_precision = config.define_enum_state(
     name='jax_default_matmul_precision',
     enum_values=['bfloat16', 'tensorfloat32', 'float32'],
-    default=None,
+    default='float32',
     help=('Control the default matmul and conv precision for 32bit inputs.\n\n'
           'Some platforms, like TPU, offer configurable precision levels for '
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index ebe8fb49f24c..a8e44eeec61e 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -395,11 +395,13 @@ def axis_index(axis_name):
   return axis_index_p.bind(axis_name=axis_name)
 
 
-def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ())):
+def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ()),
+         precision=None):
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   return pdot_p.bind(x, y, axis_name=axis_name,
-                     pos_contract=pos_contract, pos_batch=pos_batch)
+                     pos_contract=pos_contract, pos_batch=pos_batch,
+                     precision=lax._canonicalize_precision(precision))
 
 
 def xeinsum(spec: str, x, y):
@@ -1346,12 +1348,12 @@ def _vmap_process_axis_index(self, frame):
 core.axis_substitution_rules[pdot_p] = partial(_subst_all_names_in_param, 'axis_name')
 
 @pdot_p.def_impl
-def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
+def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch, precision):
   if axis_name:
     raise NameError(f"unbound axis name: {axis_name[0]}")
-  return lax.dot_general(x, y, [pos_contract, pos_batch])
+  return lax.dot_general(x, y, [pos_contract, pos_batch], precision=precision)
 
 @pdot_p.def_abstract_eval
-def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
+def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
   # TODO(frostig,mattjj,jekbradbury): check inputs have given axis names?
   if not len(set(axis_name)) == len(axis_name): raise ValueError
   pos_aval = lax.dot_general_p.abstract_eval(
@@ -1364,7 +1366,7 @@ def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
   return pos_aval.update(named_shape=named_shape)
 
 def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name,
-                               pos_contract, pos_batch):
+                               pos_contract, pos_batch, precision):
   x, y = vals_in
   x_dim, y_dim = dims_in
   x_pos_contract, y_pos_contract = pos_contract
@@ -1376,24 +1378,24 @@ def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name,
   remaining_axis_names = tuple(n for n in axis_name if n != frame.name)
   out = pdot_p.bind(x, y, axis_name=remaining_axis_names,
                     pos_contract=[x_pos_contract, y_pos_contract],
-                    pos_batch=[x_pos_batch, y_pos_batch])
+                    pos_batch=[x_pos_batch, y_pos_batch], precision=precision)
   return out, None
 batching.collective_rules[pdot_p] = _pdot_vmap_collective_rule
 
 def _pdot_vmap_batching_rule(vals_in, dims_in, *, axis_name, pos_contract,
-                             pos_batch):
+                             pos_batch, precision):
   x, y = vals_in
   (pos_contract, pos_batch), result_batch_dim = lax._dot_general_batch_dim_nums(
       (x.ndim, y.ndim), dims_in, [pos_contract, pos_batch])
   out = pdot_p.bind(x, y, axis_name=axis_name, pos_contract=pos_contract,
-                    pos_batch=pos_batch)
+                    pos_batch=pos_batch, precision=precision)
   return out, result_batch_dim
 batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule
 
 def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
-                           axis_env, platform):
+                           precision, axis_env, platform):
   local_out = lax._dot_general_translation_rule(
-      c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None,
+      c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=precision,
       preferred_element_type=None)
   if axis_name:
     out_tup = xla.parallel_translations[psum_p](
@@ -1405,15 +1407,15 @@ def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
   return out
 xla.parallel_translations[pdot_p] = _pdot_translation_rule
 
-def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch):
+def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch, precision):
   # TODO: avals with names, call pbroadcast with axis_name
   return lax._dot_general_transpose_lhs(
-      g, y, dimension_numbers=[pos_contract, pos_batch], precision=None,
+      g, y, dimension_numbers=[pos_contract, pos_batch], precision=precision,
       preferred_element_type=None)
 
-def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch):
+def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch, precision):
   # TODO: avals with names, call pbroadcast with axis_name
   return lax._dot_general_transpose_rhs(
-      g, x, dimension_numbers=[pos_contract, pos_batch], precision=None,
+      g, x, dimension_numbers=[pos_contract, pos_batch], precision=precision,
       preferred_element_type=None)
 ad.defbilinear(pdot_p, _pdot_transpose_lhs, _pdot_transpose_rhs)
diff --git a/tests/api_test.py b/tests/api_test.py
index 63a1af0db138..bac4e70f6442 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -2783,6 +2783,7 @@ def f_jit(x):
     for f in [f_jit, f_cond]:
       precision = config.jax_default_matmul_precision
       try:
+        FLAGS.jax_default_matmul_precision = None
         num_traces = 0
         x = jnp.zeros((2, 2))
         f(x)
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 204982cb3c9e..846fae452987 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -5291,7 +5291,9 @@ def testPrecision(self):
     ones_3d = np.ones((2, 2, 2))
     HIGHEST = lax.Precision.HIGHEST
 
-    jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d)
+    jtu.assert_dot_precision(lax.Precision.HIGHEST, jnp.dot, ones_1d, ones_1d)
+    with jax.default_matmul_precision('tensorfloat32'):
+      jtu.assert_dot_precision(lax.Precision.HIGH, jnp.dot, ones_1d, ones_1d)
     jtu.assert_dot_precision(
         HIGHEST,
         partial(jnp.dot, precision=HIGHEST),
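For reference, a minimal sketch (not part of the diff) of how the new `'float32'` default interacts with the `jax.default_matmul_precision` context manager for 32-bit `jnp.dot`, mirroring the assertions added to `tests/lax_numpy_test.py`. The precision mapping assumed here ('bfloat16' → DEFAULT, 'tensorfloat32' → HIGH, 'float32' → HIGHEST) is the one those test expectations imply.

```python
# Sketch only, assuming the flag-to-Precision mapping exercised by the tests.
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), dtype=jnp.float32)

# With the new flag default ('float32'), a plain 32-bit matmul is lowered
# with Precision.HIGHEST rather than the backend's default precision.
y = jnp.dot(x, x)

# The context manager still overrides the flag for the code it encloses.
with jax.default_matmul_precision('tensorfloat32'):
  y_tf32 = jnp.dot(x, x)  # lowered with Precision.HIGH

# An explicit precision argument continues to take priority over the default.
y_exact = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
```

On the `parallel.py` side, the same `precision` value is now threaded through `pdot`'s impl, abstract-eval, batching, collective, translation, and transpose rules, so a `precision=` passed to `pdot` reaches the underlying `dot_general` calls instead of the previously hard-coded `None`.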