Skip to content

Commit

Permalink
Pallas pipeline API tweaks for more advanced pipelining patterns.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674532944
  • Loading branch information
epiqueras authored and Google-ML-Automation committed Sep 21, 2024
1 parent d63afd8 commit 3bcc5a1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ py_library(
":primitives",
"//jax",
"//jax:api_util",
"//jax:pallas",
"//jax:util",
"//jax/_src/pallas",
] + py_deps("numpy"),
Expand Down
12 changes: 12 additions & 0 deletions jax/_src/pallas/mosaic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ class BufferType(enum.Enum):
ACCUMULATOR = 3
INPUT_OUTPUT = 4

MANUAL = 9


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -234,6 +236,10 @@ def tree_flatten(self):
def tree_unflatten(cls, meta, data):
return cls(*meta, *data)

@staticmethod
def buffer_types() -> type[BufferType]:
return BufferType

@classmethod
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
"""Create a BufferedRef.
Expand Down Expand Up @@ -1034,6 +1040,7 @@ def pipeline(
prefetch=None,
postyeet=None,
schedule=None,
body_prologue=None,
):
"""
Run the pipeline.
Expand All @@ -1056,6 +1063,8 @@ def pipeline(
Called during the outputs phase in the first inner step.
schedule: manually specified pipeline schedules for brefs, None indicates
default schedule.
body_prologue: For running code within the grid environment before the
body is run. Useful for updating manual refs.
"""
if scratches is None:
scratches = ()
Expand Down Expand Up @@ -1119,6 +1128,9 @@ def loop_body(step, _):
lambda: None)

# run the kernel!
if body_prologue is not None:
with scheduler.grid_env():
body_prologue()
current_refs = map_brefs(lambda x: x.current_ref, brefs)
with scheduler._named_scope("ep_run_kernel"):
with scheduler.grid_env():
Expand Down

0 comments on commit 3bcc5a1

Please sign in to comment.