[Mosaic GPU] Add CUPTI profiler alongside events-based implementation
andportnoy committed Nov 8, 2024
1 parent ac06d19 commit 775078a
Showing 7 changed files with 194 additions and 6 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/examples/flash_attention.py
@@ -576,7 +576,7 @@ def benchmark_and_verify(
head_dim=head_dim,
**kwargs,
)
- out, runtime = profiler.measure(f, q[0], k[0], v[0])
+ out, runtime = profiler.measure(f)(q[0], k[0], v[0])
out = out[None]

@jax.jit
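Note on the call sites above: measure() now curries. It takes the function to
profile (plus profiling options) and returns a timed wrapper, which is then
applied to the actual arguments. A minimal before/after sketch (x and y are
hypothetical inputs):

    # Before this commit: measure ran f on the arguments directly.
    out, runtime_ms = profiler.measure(f, x, y)
    # After: measure(f) returns a wrapper; calling it runs and times f.
    timed_f = profiler.measure(f)
    out, runtime_ms = timed_f(x, y)  # the wrapper can be reused across calls
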
6 changes: 3 additions & 3 deletions jax/experimental/mosaic/gpu/examples/matmul.py
@@ -359,7 +359,7 @@ def verify(
wgmma_impl=WGMMADefaultImpl,
profiler_spec=prof_spec,
)
- z, runtime = profiler.measure(f, x, y)
+ z, runtime = profiler.measure(f)(x, y)

if rhs_transpose:
dimension_numbers = ((1,), (1,)), ((), ())
@@ -381,7 +381,7 @@ def ref_f(x, y):
preferred_element_type=out_dtype,
).astype(out_dtype)

- ref, ref_runtime = profiler.measure(ref_f, x, y)
+ ref, ref_runtime = profiler.measure(ref_f)(x, y)
np.testing.assert_allclose(
z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
)
@@ -425,7 +425,7 @@ def ref_f(x, y):
f = build_kernel(
m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
)
- _, runtime = profiler.measure(f, x, y)
+ _, runtime = profiler.measure(f)(x, y)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
raise
32 changes: 31 additions & 1 deletion jax/experimental/mosaic/gpu/profiler.py
@@ -78,7 +78,7 @@ def _record_event(args, event):
treedef, record_event_p.bind(*flat_args, event=event)
)

- def measure(f, *args, **kwargs):
+ def _measure_events(f, *args, **kwargs):
# TODO(apaszke): Raise if this is called under jit.
start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create()
end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create()
@@ -102,6 +102,36 @@ def run(*args, **kwargs):
return results, elapsed


+ def _measure_cupti(f, aggregate):
+   def wrapper(*args, **kwargs):
+     mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
+     try:
+       results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
+     finally:
+       timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
+     if not timings:
+       return results, None
+     elif aggregate:
+       return results, sum(item[1] for item in timings)
+     else:
+       return results, timings
+   return wrapper
+
+
+ def measure(f, mode="cupti", aggregate=True):
+   match mode:
+     case "cupti":
+       return _measure_cupti(f, aggregate)
+     case "events":
+       if aggregate == False:
+         raise ValueError(f"{aggregate=} is not supported with {mode=}")
+       def measure_events_wrapper(*args, **kwargs):
+         return _measure_events(f, *args, **kwargs)
+       return measure_events_wrapper
+     case _:
+       raise ValueError(f"Unrecognized profiler mode {mode}")
+
+
class ProfilerSpec:
  ENTER = 0
  EXIT = 1 << 31
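Putting the new API together, a usage sketch (the function and input are made
up for illustration; timing values are in milliseconds, matching the
events-based interface):

    import jax.numpy as jnp
    from jax.experimental.mosaic.gpu import profiler

    x = jnp.arange(1024 * 1024)
    f = lambda x: 2 * x

    # Default CUPTI backend, single aggregate runtime.
    out, runtime_ms = profiler.measure(f)(x)

    # Per-kernel timings: a list of (kernel_name, milliseconds) tuples,
    # or None if no kernels were captured.
    out, timings = profiler.measure(f, mode="cupti", aggregate=False)(x)

    # Events backend supports aggregate timing only; aggregate=False raises.
    out, runtime_ms = profiler.measure(f, mode="events")(x)
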
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -263,7 +263,7 @@ def main(unused_argv):
for block_q, block_kv in param_it:
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
try:
- out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
+ out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
out_ref = attention_reference(q, k, v)
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
except ValueError as e:
1 change: 1 addition & 0 deletions jaxlib/mosaic/gpu/BUILD
@@ -187,6 +187,7 @@ pybind_extension(
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/pjrt:exceptions",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
],
102 changes: 102 additions & 0 deletions jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -18,9 +18,12 @@ limitations under the License.
#include <string>

#include "nanobind/nanobind.h"
#include "nanobind/stl/tuple.h"
#include "nanobind/stl/vector.h"
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/pjrt/exceptions.h"
#include "xla/service/custom_call_status.h"

namespace jax::cuda {
@@ -48,6 +51,75 @@ void EventRecordCall(void* stream, void** buffers, char* opaque,
}
}

+ #define THROW(...) \
+   do { \
+     throw xla::XlaRuntimeError( \
+         absl::StrCat("Mosaic GPU profiler error: ", __VA_ARGS__)); \
+   } while (0)
+
+ #define THROW_IF(expr, ...) \
+   do { \
+     if (expr) THROW(__VA_ARGS__); \
+   } while (0)
+
+ #define THROW_IF_CUPTI_ERROR(expr, ...) \
+   do { \
+     CUptiResult _result = (expr); \
+     if (_result != CUPTI_SUCCESS) { \
+       const char *s; \
+       cuptiGetErrorMessage(_result, &s); \
+       THROW(s, ": " __VA_OPT__(, ) __VA_ARGS__); \
+     } \
+   } while (0)
+
+ // CUPTI can only have one subscriber per process, so it's ok to make the
+ // profiler state global.
+ struct {
+   CUpti_SubscriberHandle subscriber;
+   std::vector<std::tuple<const char * /*kernel_name*/, double /*ms*/>> timings;
+ } profiler_state;
+
+ void callback_request(uint8_t **buffer, size_t *size, size_t *maxNumRecords) {
+   // 10 MiB buffer size is generous but somewhat arbitrary, it's at the upper
+   // bound of what's recommended in CUPTI documentation:
+   // https://docs.nvidia.com/cupti/main/main.html#cupti-callback-api:~:text=For%20typical%20workloads%2C%20it%E2%80%99s%20suggested%20to%20choose%20a%20size%20between%201%20and%2010%20MB.
+   const int buffer_size = 10 * (1 << 20);
+   // 8 byte alignment is specified in the official CUPTI code samples, see
+   // extras/CUPTI/samples/common/helper_cupti_activity.h in your CUDA
+   // installation.
+   *buffer = new (std::align_val_t(8)) uint8_t[buffer_size];
+   *size = buffer_size;
+   *maxNumRecords = 0;
+ }
+
+ void callback_complete(CUcontext context, uint32_t streamId,
+                        uint8_t *buffer_raw, size_t size, size_t validSize) {
+   // Take ownership of the buffer once CUPTI is done using it.
+   std::unique_ptr<uint8_t> buffer = absl::WrapUnique(buffer_raw);
+   CUpti_Activity *record = nullptr;
+   for (;;) {
+     CUptiResult status =
+         cuptiActivityGetNextRecord(buffer.get(), validSize, &record);
+     if (status == CUPTI_SUCCESS) {
+       if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
+         // TODO(andportnoy) handle multi-GPU
+         CUpti_ActivityKernel9 *kernel = (CUpti_ActivityKernel9 *)record;
+         // Convert integer nanoseconds to floating point milliseconds to match
+         // the interface of the events-based profiler.
+         double duration_ms = (kernel->end - kernel->start) / 1e6;
+         profiler_state.timings.push_back(
+             std::make_tuple(kernel->name, duration_ms));
+       }
+     } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
+       // No more records available.
+       break;
+     } else {
+       THROW_IF_CUPTI_ERROR(status);
+     }
+   }
+ }
+
+
NB_MODULE(_mosaic_gpu_ext, m) {
m.def("_gpu_event_create", []() {
gpuEvent_t* event = new gpuEvent_t();
@@ -91,6 +163,36 @@
}
}
});
m.def("_cupti_init", []() {
profiler_state.timings.clear();
// Ok to pass nullptr for the callback here because we don't register any
// callbacks through cuptiEnableCallback.
auto subscribe_result = cuptiSubscribe(
&profiler_state.subscriber, /*callback=*/nullptr, /*userdata=*/nullptr);
if (subscribe_result == CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED) {
THROW(
"Attempted to subscribe to CUPTI while another subscriber, such as "
"Nsight Systems or Nsight Compute, is active. CUPTI backend of the "
"Mosaic GPU profiler cannot be used in that mode since CUPTI does "
"not support multiple subscribers.");
}
THROW_IF_CUPTI_ERROR(subscribe_result, "failed to subscribe to CUPTI");
THROW_IF_CUPTI_ERROR(
cuptiActivityRegisterCallbacks(callback_request, callback_complete),
"failed to register CUPTI activity callbacks");
+     THROW_IF_CUPTI_ERROR(
+         cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL),
+         "failed to enable tracking of kernel activity by CUPTI");
+   });
+   m.def("_cupti_get_timings", []() {
+     THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
+                          "failed to unsubscribe from CUPTI");
+     THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
+                          "failed to flush CUPTI activity buffers");
+     THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
+     return profiler_state.timings;
+   });
}

} // namespace
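For orientation, the extension exposes CUPTI through a strict two-call
lifecycle, which is what _measure_cupti above relies on. A simplified sketch
of the sequence (error handling omitted; f and x are hypothetical, and the
module path is the one used in the tests below):

    import jax
    from jax._src.lib import mosaic_gpu

    # Subscribe, register activity callbacks, enable kernel tracing.
    mosaic_gpu._mosaic_gpu_ext._cupti_init()
    # Kernels launched between init and get_timings are recorded.
    out = jax.block_until_ready(jax.jit(f)(x))
    # Unsubscribe, flush activity buffers, detach CUPTI;
    # returns a list of (kernel_name, milliseconds) tuples.
    timings = mosaic_gpu._mosaic_gpu_ext._cupti_get_timings()
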
55 changes: 55 additions & 0 deletions tests/mosaic/gpu_test.py
@@ -1583,6 +1583,61 @@ def kernel(ctx, src, dst, _):
jax.block_until_ready(f(xd))


+ class ProfilerMeasureTest(TestCase):
+
+   def setUp(self):
+     self.x = jnp.arange(1024 * 1024)
+     self.f = lambda x: 2 * x
+
+   def test_measure(self):
+     _, runtime_ms = profiler.measure(self.f)(self.x)
+     self.assertIsInstance(runtime_ms, float)
+
+   def test_measure_cupti_explicit(self):
+     _, runtime_ms = profiler.measure(self.f, mode="cupti")(self.x)
+     self.assertIsInstance(runtime_ms, float)
+
+   def test_measure_events_explicit(self):
+     _, runtime_ms = profiler.measure(self.f, mode="events")(self.x)
+     self.assertIsInstance(runtime_ms, float)
+
+   def test_measure_per_kernel(self):
+     _, runtimes_ms = profiler.measure(self.f, aggregate=False)(self.x)
+     for item in runtimes_ms:
+       self.assertIsInstance(item, tuple)
+       self.assertEqual(len(item), 2)
+       name, runtime_ms = item
+       self.assertIsInstance(name, str)
+       self.assertIsInstance(runtime_ms, float)
+
+   def test_measure_cupti_repeated(self):
+     f_profiled = profiler.measure(self.f, mode="cupti")
+     n = 3
+     timings = [f_profiled(self.x)[1] for _ in range(n)]
+     for item in timings:
+       self.assertIsInstance(item, float)
+
+   def test_measure_repeated_interleaved(self):
+     # Test that kernels run outside of measure() are not captured.
+     _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
+     self.assertEqual(len(timings), 1)
+     self.f(self.x)
+     _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
+     self.assertEqual(len(timings), 1)
+
+   def test_measure_double_subscription(self):
+     # This needs to run in a separate process, otherwise it affects the
+     # outcomes of other tests since CUPTI state is global.
+     self.skipTest("Must run in a separate process from other profiler tests")
+     # Initialize the profiler manually, which subscribes to CUPTI. There can
+     # only be one CUPTI subscriber at a time.
+     jax._src.lib.mosaic_gpu._mosaic_gpu_ext._cupti_init()
+     with self.assertRaisesRegex(RuntimeError,
+         "Attempted to subscribe to CUPTI while another subscriber, "
+         "such as Nsight Systems or Nsight Compute, is active."):
+       profiler.measure(self.f, aggregate=False)(self.x)

class TorchTest(TestCase):

@classmethod
