Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Next Release] TBD

### Added
- Extending `dpctl.device_context` with nested contexts (#678)

### Changed
- dpctl-capi is now renamed to `libsyclinterface` (#666).

Expand Down
2 changes: 2 additions & 0 deletions dpctl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
get_current_queue,
get_num_activated_queues,
is_in_device_context,
nested_context_factories,
set_global_queue,
)

Expand Down Expand Up @@ -111,6 +112,7 @@
"get_current_queue",
"get_num_activated_queues",
"is_in_device_context",
"nested_context_factories",
"set_global_queue",
]
__all__ += [
Expand Down
38 changes: 36 additions & 2 deletions dpctl/_sycl_queue_manager.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# cython: linetrace=True

import logging
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager

from .enum_types import backend_type, device_type

Expand Down Expand Up @@ -210,6 +210,22 @@ cpdef get_current_backend():
return _mgr.get_current_backend()


nested_context_factories = []


def _get_nested_contexts(ctxt):
_help_numba_dppy()
return (factory(ctxt) for factory in nested_context_factories)


def _help_numba_dppy():
"""Import numba-dppy for registering nested contexts"""
try:
import numba_dppy
except Exception:
pass


@contextmanager
def device_context(arg):
"""
Expand All @@ -223,6 +239,9 @@ def device_context(arg):
the context manager's scope. The yielded queue is removed as the currently
usable queue on exiting the context manager.

You can register context factory in the list of factories.
This context manager uses context factories to create and activate nested contexts.

Args:
arg : A :class:`dpctl.SyclQueue` object, or a :class:`dpctl.SyclDevice`
object, or a filter selector string.
Expand All @@ -244,11 +263,26 @@ def device_context(arg):
with dpctl.device_context("level0:gpu:0"):
do_something_on_gpu0()

The following example registers nested context factory:

.. code-block:: python

import dctl

def factory(sycl_queue):
...
return context

dpctl.nested_context_factories.append(factory)

"""
ctxt = None
try:
ctxt = _mgr._set_as_current_queue(arg)
yield ctxt
with ExitStack() as stack:
for nested_context in _get_nested_contexts(ctxt):
stack.enter_context(nested_context)
yield ctxt
finally:
# Code to release resource
if ctxt:
Expand Down
105 changes: 79 additions & 26 deletions dpctl/tests/test_sycl_queue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,44 @@
"""Defines unit test cases for the SyclQueueManager class.
"""

import contextlib

import pytest

import dpctl

from ._helper import has_cpu, has_gpu, has_sycl_platforms


@pytest.mark.skipif(
skip_no_platform = pytest.mark.skipif(
not has_sycl_platforms(), reason="No SYCL platforms available"
)
skip_no_gpu = pytest.mark.skipif(
not has_gpu(), reason="No OpenCL GPU queues available"
)
skip_no_cpu = pytest.mark.skipif(
not has_cpu(), reason="No OpenCL CPU queues available"
)


@skip_no_platform
def test_is_in_device_context_outside_device_ctxt():
assert not dpctl.is_in_device_context()


@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
@skip_no_gpu
def test_is_in_device_context_inside_device_ctxt_gpu():
with dpctl.device_context("opencl:gpu:0"):
assert dpctl.is_in_device_context()


@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
@skip_no_cpu
def test_is_in_device_context_inside_device_ctxt_cpu():
with dpctl.device_context("opencl:cpu:0"):
assert dpctl.is_in_device_context()


@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
@skip_no_gpu
@skip_no_cpu
def test_is_in_device_context_inside_nested_device_ctxt():
with dpctl.device_context("opencl:cpu:0"):
with dpctl.device_context("opencl:gpu:0"):
Expand All @@ -53,7 +63,7 @@ def test_is_in_device_context_inside_nested_device_ctxt():
assert not dpctl.is_in_device_context()


@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
@skip_no_cpu
def test_is_in_device_context_inside_nested_device_ctxt_cpu():
cpu = dpctl.SyclDevice("cpu")
n = cpu.max_compute_units
Expand All @@ -74,17 +84,13 @@ def test_is_in_device_context_inside_nested_device_ctxt_cpu():
assert 0 == dpctl.get_num_activated_queues()


@pytest.mark.skipif(
not has_sycl_platforms(), reason="No SYCL platforms available"
)
@skip_no_platform
def test_get_current_device_type_outside_device_ctxt():
assert dpctl.get_current_device_type() is not None


@pytest.mark.skipif(
not has_sycl_platforms(), reason="No SYCL platforms available"
)
@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
@skip_no_platform
@skip_no_gpu
def test_get_current_device_type_inside_device_ctxt():
assert dpctl.get_current_device_type() is not None

Expand All @@ -94,8 +100,8 @@ def test_get_current_device_type_inside_device_ctxt():
assert dpctl.get_current_device_type() is not None


@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
@skip_no_cpu
@skip_no_gpu
def test_get_current_device_type_inside_nested_device_ctxt():
assert dpctl.get_current_device_type() is not None

Expand All @@ -109,15 +115,13 @@ def test_get_current_device_type_inside_nested_device_ctxt():
assert dpctl.get_current_device_type() is not None


@pytest.mark.skipif(
not has_sycl_platforms(), reason="No SYCL platforms available"
)
@skip_no_platform
def test_num_current_queues_outside_with_clause():
assert 0 == dpctl.get_num_activated_queues()


@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
@skip_no_gpu
@skip_no_cpu
def test_num_current_queues_inside_with_clause():
with dpctl.device_context("opencl:cpu:0"):
assert 1 == dpctl.get_num_activated_queues()
Expand All @@ -126,8 +130,8 @@ def test_num_current_queues_inside_with_clause():
assert 0 == dpctl.get_num_activated_queues()


@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
@skip_no_gpu
@skip_no_cpu
def test_num_current_queues_inside_threads():
from threading import Thread

Expand All @@ -144,9 +148,7 @@ def SessionThread():
Session2.start()


@pytest.mark.skipif(
not has_sycl_platforms(), reason="No SYCL platforms available"
)
@skip_no_platform
def test_get_current_backend():
dpctl.get_current_backend()
dpctl.get_current_device_type()
Expand All @@ -156,3 +158,54 @@ def test_get_current_backend():
dpctl.set_global_queue("gpu")
elif has_cpu():
dpctl.set_global_queue("cpu")


def test_nested_context_factory_is_empty_list():
assert isinstance(dpctl.nested_context_factories, list)
assert not dpctl.nested_context_factories


@contextlib.contextmanager
def _register_nested_context_factory(factory):
dpctl.nested_context_factories.append(factory)
yield
dpctl.nested_context_factories.remove(factory)


def test_register_nested_context_factory_context():
def factory():
pass

with _register_nested_context_factory(factory):
assert factory in dpctl.nested_context_factories

assert isinstance(dpctl.nested_context_factories, list)
assert not dpctl.nested_context_factories


@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
def test_device_context_activates_nested_context():
in_context = False
factory_called = False

@contextlib.contextmanager
def context():
nonlocal in_context
old, in_context = in_context, True
yield
in_context = old

def factory(_):
nonlocal factory_called
factory_called = True
return context()

with _register_nested_context_factory(factory):
assert not factory_called
assert not in_context

with dpctl.device_context("opencl:gpu:0"):
assert factory_called
assert in_context

assert not in_context