
Commit

Apply black format to skrl folder
Toni-SM committed Nov 4, 2024
1 parent 6942e7c commit 362ed94
Showing 126 changed files with 6,362 additions and 3,714 deletions.
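For reference, the reformatting below is the kind of change Black produces when run over the package. A minimal sketch of an equivalent programmatic call (the exact invocation and the 120-character line length are assumptions, not recorded on this page):

import black

# Reformat a snippet the way the Black CLI would (a line length of 120 is assumed here).
source = 'formats = {1: "a",\n           2: "b"}\n'
print(black.format_str(source, mode=black.Mode(line_length=120)))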
55 changes: 34 additions & 21 deletions skrl/__init__.py
@@ -13,6 +13,7 @@
# read library version from metadata
try:
    import importlib.metadata

    __version__ = importlib.metadata.version("skrl")
except ImportError:
    __version__ = "unknown"
@@ -21,15 +22,18 @@
# logger with format
class _Formatter(logging.Formatter):
    _format = "[%(name)s:%(levelname)s] %(message)s"
    _formats = {logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m",
                logging.INFO: f"\x1b[38;20m{_format}\x1b[0m",
                logging.WARNING: f"\x1b[33;20m{_format}\x1b[0m",
                logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m",
                logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m"}
    _formats = {
        logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m",
        logging.INFO: f"\x1b[38;20m{_format}\x1b[0m",
        logging.WARNING: f"\x1b[33;20m{_format}\x1b[0m",
        logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m",
        logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m",
    }

    def format(self, record):
        return logging.Formatter(self._formats.get(record.levelno)).format(record)


_handler = logging.StreamHandler()
_handler.setLevel(logging.DEBUG)
_handler.setFormatter(_Formatter())
@@ -42,13 +46,11 @@ def format(self, record):
# machine learning framework configuration
class _Config(object):
    def __init__(self) -> None:
        """Machine learning framework specific configuration
        """
        """Machine learning framework specific configuration"""

        class PyTorch(object):
            def __init__(self) -> None:
                """PyTorch configuration
                """
                """PyTorch configuration"""
                self._device = None
                # torch.distributed config
                self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -59,7 +61,10 @@ def __init__(self) -> None:
                # set up distributed runs
                if self._is_distributed:
                    import torch
                    logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")

                    logger.info(
                        f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})"
                    )
                    torch.distributed.init_process_group("nccl", rank=self._rank, world_size=self._world_size)
                    torch.cuda.set_device(self._local_rank)

@@ -72,6 +77,7 @@ def device(self) -> "torch.device":
"""
try:
import torch

if self._device is None:
return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
return torch.device(self._device)
@@ -116,8 +122,7 @@ def is_distributed(self) -> bool:

        class JAX(object):
            def __init__(self) -> None:
                """JAX configuration
                """
                """JAX configuration"""
                self._backend = "numpy"
                self._key = np.array([0, 0], dtype=np.uint32)
                # distributed config (based on torch.distributed, since JAX doesn't implement it)
@@ -126,19 +131,26 @@ def __init__(self) -> None:
                self._local_rank = int(os.getenv("JAX_LOCAL_RANK", "0"))
                self._rank = int(os.getenv("JAX_RANK", "0"))
                self._world_size = int(os.getenv("JAX_WORLD_SIZE", "1"))
                self._coordinator_address = os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234")
                self._coordinator_address = (
                    os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234")
                )
                self._is_distributed = self._world_size > 1
                # device
                self._device = f"cuda:{self._local_rank}"

                # set up distributed runs
                if self._is_distributed:
                    import jax
                    logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")
                    jax.distributed.initialize(coordinator_address=self._coordinator_address,
                                               num_processes=self._world_size,
                                               process_id=self._rank,
                                               local_device_ids=self._local_rank)

                    logger.info(
                        f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})"
                    )
                    jax.distributed.initialize(
                        coordinator_address=self._coordinator_address,
                        num_processes=self._world_size,
                        process_id=self._rank,
                        local_device_ids=self._local_rank,
                    )

            @staticmethod
            def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
@@ -158,7 +170,7 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
                if isinstance(device, jax.Device):
                    return device
                elif isinstance(device, str):
                    device_type, device_index = f"{device}:0".split(':')[:2]
                    device_type, device_index = f"{device}:0".split(":")[:2]
                    try:
                        return jax.devices(device_type)[int(device_index)]
                    except (RuntimeError, IndexError) as e:
@@ -196,11 +208,11 @@ def backend(self, value: str) -> None:

            @property
            def key(self) -> "jax.Array":
                """Pseudo-random number generator (PRNG) key
                """
                """Pseudo-random number generator (PRNG) key"""
                if isinstance(self._key, np.ndarray):
                    try:
                        import jax

                        with jax.default_device(self.device):
                            self._key = jax.random.PRNGKey(self._key[1])
                    except ImportError:
@@ -257,4 +269,5 @@ def is_distributed(self) -> bool:
        self.jax = JAX()
        self.torch = PyTorch()


config = _Config()
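The diff ends where the module creates its configuration singleton. A brief usage sketch based only on the attributes visible in this diff (the "jax" value passed to the backend setter is an assumption beyond what the page shows):

from skrl import config

# Resolved PyTorch device: "cuda:<local_rank>" when CUDA is available, otherwise "cpu".
print(config.torch.device)

# The JAX backend defaults to "numpy"; switching it via the setter is assumed to accept "jax".
config.jax.backend = "jax"
print(config.jax.key)  # PRNG key, created lazily from the stored seed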