Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions to emit custom call to place a buffer to host and device. #8350

Merged
merged 2 commits into from
Nov 4, 2024

Commits on Nov 1, 2024

  1. Add functions to emit custom call to place a buffer to host and device.

    This is used for host-offloading.
    
    example code of what jax emits:
    ```python
    def policy(prim, *avals, **params) -> Offloadable:
      return Offloadable(src='device', dst='pinned_host')
    
    @functools.partial(jax.remat, policy=policy)
    def f(x):
      x = jnp.sin(x)
      x = jnp.sin(x)
      return jnp.sum(x)
    ```
    
    becomes:
    ```mlir
    module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
      func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
        %0 = stablehlo.sine %arg0 : tensor<16xf32>
        %1 = stablehlo.cosine %arg0 : tensor<16xf32>
        %2 = stablehlo.custom_call @annotate_device_placement(%1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
        %3 = stablehlo.cosine %0 : tensor<16xf32>
        %4 = stablehlo.custom_call @annotate_device_placement(%3) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
        %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
        %5:3 = stablehlo.optimization_barrier %2, %4, %cst : tensor<16xf32>, tensor<16xf32>, tensor<f32>
        %6 = stablehlo.custom_call @annotate_device_placement(%5#0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
        %7 = stablehlo.custom_call @annotate_device_placement(%5#1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
        %8 = stablehlo.broadcast_in_dim %5#2, dims = [] : (tensor<f32>) -> tensor<16xf32>
        %9 = stablehlo.multiply %8, %7 : tensor<16xf32>
        %10 = stablehlo.multiply %9, %6 : tensor<16xf32>
        return %10 : tensor<16xf32>
      }
    }
    ```
    qihqi committed Nov 1, 2024
    Configuration menu
    Copy the full SHA
    fc408f2 View commit details
    Browse the repository at this point in the history
  2. yapf

    qihqi committed Nov 1, 2024
    Configuration menu
    Copy the full SHA
    374d4ab View commit details
    Browse the repository at this point in the history