[Pallas] Add a cost estimator for Pallas/JAX functions.
Helps resolve #24539, where invoking HLO's cost analysis triggers compilation, which can OOM for larger inputs. This cost estimator uses only abstract evaluation, so it should work for all input sizes. PiperOrigin-RevId: 695415760
1 parent 0995bc2 · commit 0e611e5 · 4 changed files with 325 additions and 0 deletions
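For orientation before the diff, here is a minimal usage sketch assembled from the tests added in this commit (the module path jax._src.pallas.cost_estimate and the flops/transcendentals/bytes_accessed fields of the result are taken from the files below); it is a sketch, not part of the commit itself:

import jax
import jax.numpy as jnp
from jax._src.pallas import cost_estimate

def matmul(a, b):
  return a @ b

# Only abstract shapes/dtypes are needed, so no arrays are materialized and
# nothing is compiled, even for sizes that would OOM under HLO cost analysis.
est = cost_estimate.cost_estimate(
    matmul,
    jax.ShapeDtypeStruct((400_000, 800_000), jnp.bfloat16),
    jax.ShapeDtypeStruct((800_000, 900_000), jnp.bfloat16))
print(est.flops, est.transcendentals, est.bytes_accessed)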
@@ -0,0 +1,215 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper tool for automatic cost estimation."""
import dataclasses
import math
from typing import Any, Sequence

from jax._src import core as jax_core
from jax._src.pallas import core as pallas_core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.lax import lax

map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin

_cost_rules = {}

@dataclasses.dataclass(frozen=True)
class CostEstimate:
  flops: int
  transcendentals: int
  bytes_accessed: int

  def __add__(self, other: 'CostEstimate') -> 'CostEstimate':
    return CostEstimate(
        flops=self.flops + other.flops,
        transcendentals=self.transcendentals + other.transcendentals,
        bytes_accessed=self.bytes_accessed + other.bytes_accessed,
    )

def register_cost_rule(primitive: jax_core.Primitive, rule):
  _cost_rules[primitive] = rule

@dataclasses.dataclass(frozen=True)
class Context:
  avals_in: Sequence[Any]
  avals_out: Sequence[Any]

def cost_estimate_jaxpr(
    jaxpr: jax_core.ClosedJaxpr,
) -> pallas_core.CostEstimate:
  """Returns the cost estimate for the given Jaxpr."""
  jaxpr, _ = jaxpr.jaxpr, jaxpr.consts
  total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0)

  for eqn in jaxpr.eqns:
    _, bind_params = eqn.primitive.get_bind_params(eqn.params)
    rule = _cost_rules.get(eqn.primitive, None)
    if rule is not None:
      context = Context(avals_in=[v.aval for v in eqn.invars],
                        avals_out=[v.aval for v in eqn.outvars])
      op_cost = rule(context, **bind_params)
      total_cost = total_cost + op_cost
  return pallas_core.CostEstimate(
      flops=total_cost.flops,
      transcendentals=total_cost.transcendentals,
      bytes_accessed=total_cost.bytes_accessed,
  )

def cost_estimate(fun, *args) -> pallas_core.CostEstimate:
  """Computes a cost estimate for the given function.

  Args:
    fun: The function to compute the cost estimate for.
    *args: The arguments to the function. Can be jax.ShapeDtypeStruct or
      jax.Array.

  Returns:
    A pallas_core.CostEstimate object containing the cost estimate.
  """
  wrapped_fun = lu.wrap_init(lambda *args, **kwargs: (fun(*args, **kwargs),))
  avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in args]
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
  estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))
  input_bytes = sum(math.prod(a.shape) * a.dtype.itemsize for a in args)
  output_bytes = sum(
      math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars)
  return pallas_core.CostEstimate(
      flops=estimate.flops,
      transcendentals=estimate.transcendentals,
      bytes_accessed=estimate.bytes_accessed + input_bytes + output_bytes,
  )

def binary_cost_rule(ctx: Context, **_) -> CostEstimate:
  aval_out, = ctx.avals_out
  out_flops = math.prod(aval_out.shape)
  return CostEstimate(
      flops=out_flops,
      transcendentals=0,
      bytes_accessed=0,
  )
BINARY_OPS = [
    lax.add_p,
    lax.mul_p,
    lax.sub_p,
    lax.div_p,
    lax.min_p,
    lax.max_p,
    lax.or_p,
    lax.and_p,
    lax.xor_p,
]
for op in BINARY_OPS:
  register_cost_rule(op, binary_cost_rule)


def unary_cost_rule(transcendental: bool):
  def cost_rule(ctx: Context, **_) -> CostEstimate:
    x_aval, = ctx.avals_in
    new_flops = 0
    new_transcendentals = 0
    if transcendental:
      new_transcendentals += math.prod(x_aval.shape)
    else:
      new_flops += math.prod(x_aval.shape)
    return CostEstimate(
        flops=new_flops,
        transcendentals=new_transcendentals,
        bytes_accessed=0,
    )
  return cost_rule

UN_OPS = [
    lax.neg_p,
    lax.floor_p,
    lax.ceil_p,
    lax.round_p,
    lax.not_p,
]
for op in UN_OPS:
  register_cost_rule(op, unary_cost_rule(transcendental=False))

TRANSCENDENTAL_OPS = [
    lax.cos_p,
    lax.sin_p,
    lax.tan_p,
    lax.sinh_p,
    lax.cosh_p,
    lax.tanh_p,
    lax.acos_p,
    lax.asin_p,
    lax.atan_p,
    lax.exp_p,
    lax.log_p,
    lax.logistic_p,
    lax.sqrt_p,
]
for op in TRANSCENDENTAL_OPS:
  register_cost_rule(op, unary_cost_rule(transcendental=True))

def _integer_pow_cost_rule(ctx: Context, *, y: int) -> CostEstimate:
  x_aval, = ctx.avals_in
  num_elements = math.prod(x_aval.shape)
  if y == 0 or y == 1:
    # No flops: the result is the constant 1 (when y == 0) or a copy of the
    # input (when y == 1).
    cost_per_element = 0
  else:
    # We assume integer pow is implemented using repeated squaring.
    # The cost is log(y) squarings, plus one multiply per non-zero bit.
    highest_bit = math.floor(math.log(y, 2))
    cost_per_element = highest_bit + y.bit_count()
  return CostEstimate(
      flops=num_elements * cost_per_element,
      transcendentals=0,
      bytes_accessed=0,
  )
register_cost_rule(lax.integer_pow_p, _integer_pow_cost_rule)

def dot_general_cost_rule(ctx: Context,
                          dimension_numbers: lax.DotDimensionNumbers,
                          **_) -> CostEstimate:
  x_aval, y_aval = ctx.avals_in
  x_shape, y_shape = x_aval.shape, y_aval.shape
  (lhs_contracting_dims, rhs_contracting_dims), (
      lhs_batch_dims, rhs_batch_dims) = dimension_numbers
  assert len(lhs_contracting_dims) == len(rhs_contracting_dims)
  assert len(lhs_batch_dims) == len(rhs_batch_dims)
  flops = 1
  # Flops along a contracting dim is 2*dim (addition and multiplication)
  for i in range(len(lhs_contracting_dims)):
    lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i]
    assert x_shape[lhs_dim] == y_shape[rhs_dim]
    flops *= 2 * x_shape[lhs_dim]
  # Now we handle all other dimensions.
  for i, lhs_dim in enumerate(x_shape):
    if i in lhs_contracting_dims:
      continue
    flops *= lhs_dim
  for i, rhs_dim in enumerate(y_shape):
    if i in rhs_contracting_dims:
      continue
    # Don't double-count batch dims (we already counted for LHS)
    if i in rhs_batch_dims:
      continue
    flops *= rhs_dim
  return CostEstimate(
      flops=flops,
      transcendentals=0,
      bytes_accessed=0,
  )
register_cost_rule(lax.dot_general_p, dot_general_cost_rule)
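A hedged extension sketch (not part of this commit's diff): primitives without a registered rule simply contribute zero cost in cost_estimate_jaxpr, so a downstream user could attach one through the same register_cost_rule/unary_cost_rule API defined above. Treating lax.rsqrt_p as transcendental here is an assumption made purely for illustration:

from jax._src.lax import lax
from jax._src.pallas import cost_estimate

# Hypothetical: count rsqrt as one transcendental per element, mirroring the
# unary_cost_rule(transcendental=True) registrations in the file above.
cost_estimate.register_cost_rule(
    lax.rsqrt_p, cost_estimate.unary_cost_rule(transcendental=True))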
@@ -0,0 +1,95 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas import cost_estimate


config.parse_flags_with_absl()


class PallasCostEstimateTest(jtu.JaxTestCase):

  def test_exp_add(self):
    def exp_add(x, y):
      return jnp.exp(x + y)
    cost = cost_estimate.cost_estimate(exp_add,
                                       jnp.ones(10, dtype=jnp.float32),
                                       jnp.ones(10, dtype=jnp.float32))
    self.assertEqual(cost.flops, 10)
    self.assertEqual(cost.transcendentals, 10)
    self.assertEqual(cost.bytes_accessed, 4 * 30)

  def test_very_large_matmul(self):
    def matmul(a, b):
      return a @ b
    m, k, n = 400_000, 800_000, 900_000
    cost = cost_estimate.cost_estimate(
        matmul,
        jax.ShapeDtypeStruct((m, k), jnp.bfloat16),
        jax.ShapeDtypeStruct((k, n), jnp.bfloat16))
    self.assertEqual(cost.flops, 2*m*k*n)
    self.assertEqual(cost.transcendentals, 0)
    self.assertEqual(cost.bytes_accessed, 2*(m*k + n*k + m*n))

  def test_batched_matmul(self):
    def matmul(a, b):
      return jnp.matmul(a, b)
    b, m, k, n = 7, 37, 91, 23
    cost = cost_estimate.cost_estimate(
        matmul,
        jax.ShapeDtypeStruct((b, m, k), jnp.float32),
        jax.ShapeDtypeStruct((b, k, n), jnp.float32))
    self.assertEqual(cost.flops, 2*b*m*k*n)
    self.assertEqual(cost.transcendentals, 0)
    self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n))

  def test_attention(self):
    qk_dim = 16
    v_dim = 4
    kv_len = 128
    q_len = 64
    def attention(q, k, v):
      return jax.nn.softmax(q @ k.T, axis=-1) @ v
    cost = cost_estimate.cost_estimate(
        attention,
        jnp.zeros((q_len, qk_dim), dtype=jnp.float32),
        jnp.zeros((kv_len, qk_dim), dtype=jnp.float32),
        jnp.zeros((kv_len, v_dim), dtype=jnp.float32))
    qk_cost = 2 * q_len * kv_len * qk_dim
    v_cost = 2 * q_len * kv_len * v_dim
    softmax_flops = kv_len * q_len
    self.assertEqual(cost.flops, qk_cost + v_cost + 2 * softmax_flops + q_len)
    self.assertEqual(cost.transcendentals, softmax_flops)
    input_bytes = q_len * qk_dim + kv_len * qk_dim + kv_len * v_dim
    output_bytes = q_len * v_dim
    self.assertEqual(cost.bytes_accessed, 4 * (input_bytes + output_bytes))

  @parameterized.parameters(
      (1, 0), (7, 5), (8, 4), (9, 5)
  )
  def test_integer_pow(self, power, expected_flops_per_element):
    cost = cost_estimate.cost_estimate(lambda x: lax.integer_pow(x, power),
                                       jnp.ones(10, dtype=jnp.float32))
    self.assertEqual(cost.flops, 10 * expected_flops_per_element)
    self.assertEqual(cost.transcendentals, 0)
    self.assertEqual(cost.bytes_accessed, 80)


if __name__ == "__main__":
  absltest.main()