Skip to content

Commit

Permalink
Merge branch 'main' into pypi
Browse files Browse the repository at this point in the history
  • Loading branch information
blumu authored Aug 6, 2024
2 parents 3786b74 + baf4f6e commit 12b3884
Show file tree
Hide file tree
Showing 66 changed files with 1,305,948 additions and 236,015 deletions.
15 changes: 0 additions & 15 deletions createstubs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,6 @@ createstub asciichartpy
createstub networkx
createstub boolean
createstub IPython


if [ ! -d "typings/gym" ]; then
pyright --createstub gym
# Patch gym stubs
echo ' spaces = ...' >> typings/gym/spaces/dict.pyi
echo ' nvec = ...' >> typings/gym/spaces/space.pyi
echo ' spaces = ...' >> typings/gym/spaces/space.pyi
echo ' spaces = ...' >> typings/gym/spaces/tuple.pyi
echo ' n = ...' >> typings/gym/spaces/multi_binary.pyi
else
echo stub gym already created
fi


echo 'Typing stub generation completed'

popd
28 changes: 14 additions & 14 deletions cyberbattle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
""" same as gym.envs.registry.register, but adds CyberBattle specs to env.spec """
if id in registry.env_specs:
if id in registry:
raise Error('Cannot re-register id: {}'.format(id))
spec = EnvSpec(id, **kwargs)
# Map from port number to port names : List[model.PortName]
Expand All @@ -33,11 +33,11 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
# Array defining an index for every possible remote vulnerability name : List[model.VulnerabilityID]
spec.remote_vulnerabilities = cyberbattle_env_identifiers.remote_vulnerabilities

registry.env_specs[id] = spec
registry[id] = spec


if 'CyberBattleToyCtf-v0' in registry.env_specs:
del registry.env_specs['CyberBattleToyCtf-v0']
if 'CyberBattleToyCtf-v0' in registry:
del registry['CyberBattleToyCtf-v0']

register(
id='CyberBattleToyCtf-v0',
Expand All @@ -50,8 +50,8 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
# max_episode_steps=2600,
)

if 'CyberBattleTiny-v0' in registry.env_specs:
del registry.env_specs['CyberBattleTiny-v0']
if 'CyberBattleTiny-v0' in registry:
del registry['CyberBattleTiny-v0']

register(
id='CyberBattleTiny-v0',
Expand All @@ -67,17 +67,17 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
)


if 'CyberBattleRandom-v0' in registry.env_specs:
del registry.env_specs['CyberBattleRandom-v0']
if 'CyberBattleRandom-v0' in registry:
del registry['CyberBattleRandom-v0']

register(
id='CyberBattleRandom-v0',
cyberbattle_env_identifiers=generate_network.ENV_IDENTIFIERS,
entry_point='cyberbattle._env.cyberbattle_random:CyberBattleRandom',
)

if 'CyberBattleChain-v0' in registry.env_specs:
del registry.env_specs['CyberBattleChain-v0']
if 'CyberBattleChain-v0' in registry:
del registry['CyberBattleChain-v0']

register(
id='CyberBattleChain-v0',
Expand All @@ -95,8 +95,8 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):

ad_envs = [f"ActiveDirectory-v{i}" for i in range(0, 10)]
for (index, env) in enumerate(ad_envs):
if env in registry.env_specs:
del registry.env_specs[env]
if env in registry:
del registry[env]

register(
id=env,
Expand All @@ -110,8 +110,8 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
}
)

if 'ActiveDirectoryTiny-v0' in registry.env_specs:
del registry.env_specs['ActiveDirectoryTiny-v0']
if 'ActiveDirectoryTiny-v0' in registry:
del registry['ActiveDirectoryTiny-v0']
register(
id='ActiveDirectoryTiny-v0',
cyberbattle_env_identifiers=chainpattern.ENV_IDENTIFIERS,
Expand Down
6 changes: 3 additions & 3 deletions cyberbattle/_env/cyberbattle_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class CyberBattleChain(cyberbattle_env.CyberBattleEnv):
def __init__(self, size, **kwargs):
self.size = size
super().__init__(
initial_environment=chainpattern.new_environment(size),
**kwargs)
initial_environment=chainpattern.new_environment(size), **kwargs
)

@ property
@property
def name(self) -> str:
return f"CyberBattleChain-{self.size}"
Loading

0 comments on commit 12b3884

Please sign in to comment.