Commit: pytorch array backend

ahnitz committed Dec 2, 2024
1 parent b60abb0 commit b057111
Showing 4 changed files with 681 additions and 64 deletions.
161 changes: 104 additions & 57 deletions pycbc/scheme.py
@@ -1,35 +1,20 @@
# Copyright (C) 2014 Alex Nitz, Andrew Miller
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# scheme.py


#
# =============================================================================
#
# Preamble
#
# =============================================================================
#
"""
This modules provides python contexts that set the default behavior for PyCBC
This module provides python contexts that set the default behavior for PyCBC
objects.
"""
import os
import pycbc
from functools import wraps
import logging
import ctypes
from .libutils import get_ctypes_library

logger = logging.getLogger('pycbc.scheme')
@@ -42,23 +27,24 @@ def __init__(self):

if _SchemeManager._single is not None:
raise RuntimeError("SchemeManager is a private class")
_SchemeManager._single= self
_SchemeManager._single = self

self.state= None
self._lock= False
self.state = None
self._lock = False

def lock(self):
self._lock= True
self._lock = True

def unlock(self):
self._lock= False
self._lock = False

def shift_to(self, state):
if self._lock is False:
self.state = state
else:
raise RuntimeError("The state is locked, cannot shift schemes")


# Create the global processing scheme manager
mgr = _SchemeManager()
DefaultScheme = None
@@ -68,32 +54,39 @@ def shift_to(self, state):
class Scheme(object):
"""Context that sets PyCBC objects to use CPU processing. """
_single = None

def __init__(self):
if DefaultScheme is type(self):
return
if Scheme._single is not None:
raise RuntimeError("Only one processing scheme can be used")
Scheme._single = True

def __enter__(self):
mgr.shift_to(self)
mgr.lock()

def __exit__(self, type, value, traceback):
mgr.unlock()
mgr.shift_to(default_context)

def __del__(self):
if Scheme is not None:
Scheme._single = None
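
As a usage sketch drawn from the class above: schemes are ordinary context managers, and only one non-default scheme may be live at a time.

from pycbc.scheme import CPUScheme

with CPUScheme(num_threads=4):
    ...  # PyCBC operations in this block run under the CPU scheme
# On exit, the manager unlocks and reverts to the default scheme.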

_cuda_cleanup_list=[]

_cuda_cleanup_list = []


def register_clean_cuda(function):
_cuda_cleanup_list.append(function)


def clean_cuda(context):
#Before cuda context is destroyed, all item destructions dependent on cuda
# Before CUDA context is destroyed, all item destructions dependent on CUDA
# must take place. This calls all functions that have been registered
# with _register_clean_cuda() in reverse order
#So the last one registered, is the first one cleaned
# with register_clean_cuda() in reverse order
# So the last one registered, is the first one cleaned
_cuda_cleanup_list.reverse()
for func in _cuda_cleanup_list:
func()
@@ -102,6 +95,7 @@ def clean_cuda(context):
from pycuda.tools import clear_context_caches
clear_context_caches()
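
A sketch of the registry contract (the callback name is hypothetical): cleanups run last-in, first-out before the CUDA context is torn down at interpreter exit.

from pycbc.scheme import register_clean_cuda

def release_gpu_buffers():  # hypothetical cleanup callback
    pass

register_clean_cuda(release_gpu_buffers)  # runs before context destruction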


class CUDAScheme(Scheme):
"""Context that sets PyCBC objects to use a CUDA processing scheme. """
def __init__(self, device_num=0):
@@ -111,14 +105,42 @@ def __init__(self):
import pycuda.driver
pycuda.driver.init()
self.device = pycuda.driver.Device(device_num)
self.context = self.device.make_context(flags=pycuda.driver.ctx_flags.SCHED_BLOCKING_SYNC)
self.context = self.device.make_context(
flags=pycuda.driver.ctx_flags.SCHED_BLOCKING_SYNC)
import atexit
atexit.register(clean_cuda,self.context)
atexit.register(clean_cuda, self.context)


class TorchScheme(Scheme):
"""Context that sets PyCBC objects to use a PyTorch processing scheme."""
def __init__(self, device='cpu'):
Scheme.__init__(self)
self.device = device
# Check if CUDA is available for PyTorch if device is not CPU
if self.device != 'cpu':
import torch
if not torch.cuda.is_available():
raise RuntimeError("CUDA device not available for PyTorch")
logger.info(f"PyTorch using device: {self.device}")

def __enter__(self):
Scheme.__enter__(self)
# Set the default device for PyTorch tensors
import torch
torch_device = torch.device(self.device)
# No need to set default tensor type; tensors can specify device directly
self.torch_device = torch_device
logger.info(f"Entered TorchScheme with device: {self.device}")

def __exit__(self, type, value, traceback):
Scheme.__exit__(self, type, value, traceback)
logger.info("Exited TorchScheme")


class CPUScheme(Scheme):
def __init__(self, num_threads=1):
if isinstance(num_threads, int):
self.num_threads=num_threads
self.num_threads = num_threads
elif num_threads == 'env' and "PYCBC_NUM_THREADS" in os.environ:
self.num_threads = int(os.environ["PYCBC_NUM_THREADS"])
else:
@@ -139,26 +161,29 @@ def __enter__(self):

os.environ["OMP_NUM_THREADS"] = str(self.num_threads)
if self._libgomp is not None:
self._libgomp.omp_set_num_threads( int(self.num_threads) )
self._libgomp.omp_set_num_threads(int(self.num_threads))

def __exit__(self, type, value, traceback):
os.environ["OMP_NUM_THREADS"] = "1"
if self._libgomp is not None:
self._libgomp.omp_set_num_threads(1)
Scheme.__exit__(self, type, value, traceback)
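
A sketch of thread-count control as implemented above: an integer pins the OpenMP thread count for the block, and 'env' defers to the PYCBC_NUM_THREADS environment variable.

import os
from pycbc.scheme import CPUScheme

os.environ['PYCBC_NUM_THREADS'] = '8'
with CPUScheme('env'):
    pass  # OMP_NUM_THREADS is '8' here; reset to '1' on exit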


class MKLScheme(CPUScheme):
def __init__(self, num_threads=1):
CPUScheme.__init__(self, num_threads)
if not pycbc.HAVE_MKL:
raise RuntimeError("Can't find MKL libraries")


class NumpyScheme(CPUScheme):
pass


scheme_prefix = {
CUDAScheme: "cuda",
TorchScheme: "torch", # Changed to 'torch' as the scheme name
CPUScheme: "cpu",
MKLScheme: "mkl",
NumpyScheme: "numpy",
@@ -176,17 +201,23 @@ class NumpyScheme(CPUScheme):
),
)


class DefaultScheme(_default_scheme_class):
pass


default_context = DefaultScheme()
mgr.state = default_context
scheme_prefix[DefaultScheme] = _default_scheme_prefix


def current_prefix():
return scheme_prefix[type(mgr.state)]
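
A short sketch of how the prefix tracks the active scheme (the default prefix depends on the build, for example 'mkl' where MKL is available, otherwise 'cpu'):

from pycbc.scheme import current_prefix, TorchScheme

print(current_prefix())      # 'cpu' (or 'mkl' on MKL builds) by default
with TorchScheme('cpu'):
    print(current_prefix())  # 'torch'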


_import_cache = {}


def schemed(prefix):

def scheming_function(func):
@@ -213,24 +244,26 @@ def _scheming_function(*args, **kwds):
return schemed_fn(*args, **kwds)

err = """Failed to find implementation of (%s)
for %s scheme." % (str(fn), current_prefix())"""
for %s scheme.""" % (str(func), current_prefix())
for emsg in exc_errors:
err += print(emsg)
err += str(emsg) + "\n"
raise RuntimeError(err)
return _scheming_function

return scheming_function
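
For context, a minimal sketch of declaring a dispatched function, assuming PyCBC's usual backend-module convention in which the implementation for a scheme lives in the module '<prefix><current_prefix()>' (e.g. pycbc.types.array_cpu); the prefix and names here are illustrative.

from pycbc.scheme import schemed

@schemed('pycbc.types.array_')  # backend modules: array_cpu, array_torch, ...
def zeros(length, dtype=None):
    """Return an array of zeros under the active scheme.

    This body never runs: the decorator imports the backend module
    for the current scheme and calls its same-named function instead.
    """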


def cpuonly(func):
@wraps(func)
def _cpuonly(*args, **kwds):
if not issubclass(type(mgr.state), CPUScheme):
raise TypeError(fn.__name__ +
raise TypeError(func.__name__ +
" can only be called from a CPU processing scheme.")
else:
return func(*args, **kwds)
return _cpuonly
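
A sketch of guarding a routine that has no accelerated implementation (the decorated function is hypothetical):

from pycbc.scheme import cpuonly

@cpuonly
def write_frame_file(data, path):  # hypothetical CPU-bound helper
    ...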


def insert_processing_option_group(parser):
"""
Adds the options used to choose a processing scheme. This should be used
@@ -242,26 +275,30 @@ def insert_processing_option_group(parser):
OptionParser instance
"""
processing_group = parser.add_argument_group("Options for selecting the"
" processing scheme in this program.")
" processing scheme in this program.")
scheme_choices = list(set(scheme_prefix.values()))
processing_group.add_argument("--processing-scheme",
help="The choice of processing scheme. "
"Choices are " + str(list(set(scheme_prefix.values()))) +
". (optional for CPU scheme) The number of "
"execution threads "
"can be indicated by cpu:NUM_THREADS, "
"where NUM_THREADS "
"is an integer. The default is a single thread. "
"If the scheme is provided as cpu:env, the number "
"of threads can be provided by the PYCBC_NUM_THREADS "
"environment variable. If the environment variable "
"is not set, the number of threads matches the number "
"of logical cores. ",
default="cpu")
help="The choice of processing scheme. "
"Choices are " + str(scheme_choices) +
". (optional for CPU scheme) The number of "
"execution threads "
"can be indicated by cpu:NUM_THREADS, "
"where NUM_THREADS "
"is an integer. The default is a single thread. "
"If the scheme is provided as cpu:env, the number "
"of threads can be provided by the PYCBC_NUM_THREADS "
"environment variable. If the environment variable "
"is not set, the number of threads matches the number "
"of logical cores. For Torch scheme, you can specify "
"the device as torch:DEVICE, where DEVICE is 'cpu' or "
"'cuda:0', 'cuda:1', etc.",
default="cpu")

processing_group.add_argument("--processing-device-id",
help="(optional) ID of GPU to use for accelerated "
"processing",
default=0, type=int)
help="(optional) ID of GPU to use for accelerated "
"processing",
default=0, type=int)


def from_cli(opt):
"""Parses the command line options and returns a processing scheme.
@@ -283,6 +320,15 @@ def from_cli(opt):
if name == "cuda":
logger.info("Running with CUDA support")
ctx = CUDAScheme(opt.processing_device_id)
elif name == "torch":
logger.info("Running with Torch (PyTorch) support")
# Get device if specified
if len(scheme_str) > 1:
device = ':'.join(scheme_str[1:])  # keep full device string, e.g. 'cuda:0'
else:
device = 'cpu' # Default to CPU
ctx = TorchScheme(device)
logger.info(f"Torch device set to: {device}")
elif name == "mkl":
if len(scheme_str) > 1:
numt = scheme_str[1]
@@ -303,11 +349,11 @@ def from_cli(opt):
logger.info("Running with CPU support: %s threads" % ctx.num_threads)
return ctx
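
A sketch of the command-line round trip with the new torch option, using the torch:DEVICE form documented in the option help above:

import argparse
from pycbc import scheme

parser = argparse.ArgumentParser()
scheme.insert_processing_option_group(parser)
opt = parser.parse_args(['--processing-scheme', 'torch:cuda:0'])
with scheme.from_cli(opt):  # TorchScheme on 'cuda:0' (requires CUDA)
    pass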


def verify_processing_options(opt, parser):
"""Parses the processing scheme options and verifies that they are
"""Parses the processing scheme options and verifies that they are
reasonable.
Parameters
----------
opt : object
@@ -318,18 +364,19 @@ def verify_processing_options(opt, parser):
"""
scheme_types = scheme_prefix.values()
if opt.processing_scheme.split(':')[0] not in scheme_types:
parser.error("(%s) is not a valid scheme type.")
parser.error("(%s) is not a valid scheme type." % opt.processing_scheme)


class ChooseBySchemeDict(dict):
""" This class represents a dictionary whose purpose is to chose objects
""" This class represents a dictionary whose purpose is to choose objects
based on their processing scheme. The keys are intended to be processing
schemes.
"""
def __getitem__(self, scheme):
for base in scheme.__mro__[0:-1]:
try:
return dict.__getitem__(self, base)
break
except:
except KeyError:
pass
raise KeyError("Scheme not found in ChooseBySchemeDict")

20 changes: 16 additions & 4 deletions pycbc/types/array.py
@@ -83,6 +83,13 @@ def noreal(self, *args, **kwargs):
raise TypeError( func.__name__ + " does not support real types")
return noreal

@schemed(BACKEND_PREFIX)
def _scheme_get_numpy_dtype(dtype):
"""Get the NumPy dtype corresponding to the scheme's data type."""
# This body is not executed directly: the @schemed decorator
# dispatches the call to the active scheme's backend implementation.
raise NotImplementedError("_scheme_get_numpy_dtype not implemented for this scheme.")
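
For illustration only, a hypothetical torch-backend counterpart that the dispatcher would select under TorchScheme; the mapping and module layout are assumptions, not part of this commit.

import numpy
import torch

_TORCH_TO_NUMPY = {
    torch.float32: numpy.dtype('float32'),
    torch.float64: numpy.dtype('float64'),
    torch.complex64: numpy.dtype('complex64'),
    torch.complex128: numpy.dtype('complex128'),
}

def _scheme_get_numpy_dtype(dtype):
    """Translate a torch dtype to its NumPy equivalent."""
    if isinstance(dtype, torch.dtype):
        return _TORCH_TO_NUMPY[dtype]
    return numpy.dtype(dtype)  # already NumPy-compatible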

def force_precision_to_match(scalar, precision):
if _numpy.iscomplexobj(scalar):
if precision == 'single':
@@ -169,13 +176,13 @@ def __init__(self, initial_array, dtype=None, copy=True):
self._data = initial_array

# Check that the dtype is supported.
if self._data.dtype not in _ALLOWED_DTYPES:
data_dtype = _scheme_get_numpy_dtype(self._data.dtype)
if data_dtype not in _ALLOWED_DTYPES:
raise TypeError(str(self._data.dtype) + ' is not supported')

if dtype and dtype != self._data.dtype:
if dtype and dtype != _scheme_get_numpy_dtype(self._data.dtype):
raise TypeError("Can only set dtype when allowed to copy data")


if copy:
# First we will check the dtype that we are given
if not hasattr(initial_array, 'dtype'):
@@ -981,7 +988,12 @@ def lal(self):

@property
def dtype(self):
return self._data.dtype
"""Return the NumPy dtype of the array."""
try:
return _scheme_get_numpy_dtype(self._data.dtype)
except (NotImplementedError, AttributeError):
# Fallback: assume self._data.dtype is a NumPy dtype
return self._data.dtype

def save(self, path, group=None):
"""
