Skip to content
58 changes: 36 additions & 22 deletions src/azure-cli/azure/cli/command_modules/vm/_vm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,34 +711,48 @@ def import_aaz_by_profile(profile, module_name):


def generate_ssh_keys_ed25519(private_key_filepath, public_key_filepath):
def _open(filename, mode):
return os.open(filename, flags=os.O_WRONLY | os.O_TRUNC | os.O_CREAT, mode=mode)

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

ssh_dir = os.path.dirname(private_key_filepath)
if not os.path.exists(ssh_dir):
os.makedirs(ssh_dir)
os.chmod(ssh_dir, 0o700)
os.makedirs(ssh_dir, mode=0o700)

private_key = Ed25519PrivateKey.generate()
public_key = private_key.public_key()
private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption()
)

with os.fdopen(os.open(private_key_filepath, flags=os.O_WRONLY | os.O_TRUNC | os.O_CREAT, mode=384, ), "w", ) as f:
f.write(
private_bytes.decode()
if os.path.isfile(private_key_filepath):
# Try to use existing private key if it exists.
with open(private_key_filepath, "rb") as f:
private_bytes = f.read()
private_key = serialization.load_ssh_private_key(private_bytes, password=None)
logger.warning("Private SSH key file '%s' was found in the directory: '%s'. "
"A paired public key file '%s' will be generated.",
private_key_filepath, ssh_dir, public_key_filepath)

else:
# Otherwise generate new private key.
private_key = Ed25519PrivateKey.generate()

# The private key will look like:
# -----BEGIN OPENSSH PRIVATE KEY-----
# ...
# -----END OPENSSH PRIVATE KEY-----
private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption()
)
os.chmod(private_key_filepath, 0o600)

with open(public_key_filepath, 'w') as public_key_file:
s = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH)
public_key = s.decode(encoding="utf8").replace("\n", "")
public_key_file.write(public_key)
os.chmod(public_key_filepath, 0o644)
with os.fdopen(_open(private_key_filepath, 0o600), "wb") as f:
f.write(private_bytes)

public_key = private_key.public_key()
public_bytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH)

with os.fdopen(_open(public_key_filepath, 0o644), 'wb') as f:
f.write(public_bytes)

return public_key
return public_bytes.decode()