Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions src/azure-cli-core/azure/cli/core/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def is_valid_ssh_rsa_public_key(openssh_pubkey):


def generate_ssh_keys(private_key_filepath, public_key_filepath):
import paramiko
from paramiko.ssh_exception import PasswordRequiredException, SSHException
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization

if os.path.isfile(public_key_filepath):
try:
Expand All @@ -57,24 +57,48 @@ def generate_ssh_keys(private_key_filepath, public_key_filepath):
os.chmod(ssh_dir, 0o700)

if os.path.isfile(private_key_filepath):
# try to use existing private key if it exists.
try:
key = paramiko.RSAKey(filename=private_key_filepath)
logger.warning("Private SSH key file '%s' was found in the directory: '%s'. "
"A paired public key file '%s' will be generated.",
private_key_filepath, ssh_dir, public_key_filepath)
except (PasswordRequiredException, SSHException, IOError) as e:
raise CLIError(e)
Comment on lines -66 to -67
Copy link
Member Author

@jiasli jiasli Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't feel the necessity of converting these errors to a CLIError. They should be propagated as it.

This also aligns with azure.cli.command_modules.vm._vm_utils.generate_ssh_keys_ed25519 (#30077):

if os.path.isfile(private_key_filepath):
# Try to use existing private key if it exists.
with open(private_key_filepath, "rb") as f:
private_bytes = f.read()
private_key = serialization.load_ssh_private_key(private_bytes, password=None)
logger.warning("Private SSH key file '%s' was found in the directory: '%s'. "
"A paired public key file '%s' will be generated.",
private_key_filepath, ssh_dir, public_key_filepath)


# Try to use existing private key if it exists.
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/#key-loading
with open(private_key_filepath, "rb") as f:
private_bytes = f.read()
private_key = serialization.load_pem_private_key(private_bytes, password=None)
logger.warning("Private SSH key file '%s' was found in the directory: '%s'. "
"A paired public key file '%s' will be generated.",
private_key_filepath, ssh_dir, public_key_filepath)
else:
# otherwise generate new private key.
key = paramiko.RSAKey.generate(2048)
key.write_private_key_file(private_key_filepath)
os.chmod(private_key_filepath, 0o600)

with open(public_key_filepath, 'w') as public_key_file:
public_key = '{} {}'.format(key.get_name(), key.get_base64())
public_key_file.write(public_key)
os.chmod(public_key_filepath, 0o644)

return public_key
# Otherwise generate new private key.
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/#generation
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/#key-serialization
# The private key will look like:
# -----BEGIN RSA PRIVATE KEY-----
# ...
# -----END RSA PRIVATE KEY-----
private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask why we are using PrivateFormat.TraditionalOpenSSL instead of PrivateFormat.OpenSSH?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question!

PrivateFormat.TraditionalOpenSSL will generate a private key file like:

-----BEGIN RSA PRIVATE KEY-----
...
-----END RSA PRIVATE KEY-----

while PrivateFormat.OpenSSH will generate a private key file like

-----BEGIN OPENSSH PRIVATE KEY-----
...
-----END OPENSSH PRIVATE KEY-----

Using PrivateFormat.TraditionalOpenSSL makes sure there is no breaking change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thanks~

encryption_algorithm=serialization.NoEncryption()
)

# Creating the private key file with 600 permission makes sure only the current user can access it.
# Reference: paramiko.pkey.PKey._write_private_key_file
with os.fdopen(_open(private_key_filepath, 0o600), "wb") as f:
f.write(private_bytes)

# Write public key
# The public key will look like:
# ssh-rsa ...
public_key = private_key.public_key()
public_bytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
)
with os.fdopen(_open(public_key_filepath, 0o644), 'wb') as f:
f.write(public_bytes)

return public_bytes.decode()


def _open(filename, mode):
return os.open(filename, flags=os.O_WRONLY | os.O_TRUNC | os.O_CREAT, mode=mode)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting file mode at creation time avoids the time gap between open and chmod. See #21719

66 changes: 33 additions & 33 deletions src/azure-cli-core/azure/cli/core/tests/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import unittest
import tempfile
import paramiko
import io
import os
import shutil
from knack.util import CLIError
from unittest import mock

from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization

from azure.cli.core.keys import generate_ssh_keys


Expand All @@ -22,12 +22,16 @@ def setUp(self):
# set up temporary directory to be used for temp files.
self._tempdirName = tempfile.mkdtemp(prefix="key_tmp_")

self.key = paramiko.RSAKey.generate(2048)
keyOutput = io.StringIO()
self.key.write_private_key(keyOutput)

self.private_key = keyOutput.getvalue()
self.public_key = '{} {}'.format(self.key.get_name(), self.key.get_base64())
self.key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
self.private_key = self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode()
self.public_key = self.key.public_key().public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
).decode()

def tearDown(self):
# delete temporary directory to be used for temp files.
Expand Down Expand Up @@ -70,31 +74,22 @@ def test_error_raised_when_public_key_file_exists_IOError(self):
mocked_open.assert_called_once_with(public_key_path, 'r')
mocked_f.read.assert_called_once()

def test_error_raised_when_private_key_file_exists_IOError(self):
# Create private key file
private_key_path = self._create_new_temp_key_file(self.private_key)

with mock.patch('paramiko.RSAKey') as mocked_RSAKey:
# mock failed RSAKey generation
mocked_RSAKey.side_effect = IOError("Mocked IOError")

# assert that CLIError raised when generate_ssh_keys is called
with self.assertRaises(CLIError):
public_key_path = private_key_path + ".pub"
generate_ssh_keys(private_key_path, public_key_path)

# assert that CLIError raised because of attempt to generate key from private key file.
mocked_RSAKey.assert_called_once_with(filename=private_key_path)

Comment on lines -73 to -88
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As IOError is no longer caught, this test is meaningless.

def test_error_raised_when_private_key_file_exists_encrypted(self):
# Create empty private key file
private_key_path = self._create_new_temp_key_file("")

# Write encrypted / passworded key into file
self.key.write_private_key_file(private_key_path, password="test")

# Check that CLIError exception is raised when generate_ssh_keys is called.
with self.assertRaises(CLIError):
private_bytes = self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.BestAvailableEncryption(b'test')
)
with open(private_key_path, 'wb') as f:
f.write(private_bytes)

# generate_ssh_keys should raise
# TypeError: Password was not given but private key is encrypted
with self.assertRaises(TypeError):
public_key_path = private_key_path + ".pub"
generate_ssh_keys(private_key_path, public_key_path)

Expand Down Expand Up @@ -133,10 +128,15 @@ def test_generate_new_private_public_key_files(self):
self.assertEqual(public_key, new_public_key)

# Check that public key corresponds to private key
with open(private_key_path, 'r') as f:
key = paramiko.RSAKey(filename=private_key_path)
public_key = '{} {}'.format(key.get_name(), key.get_base64())
self.assertEqual(public_key, new_public_key)
with open(private_key_path, 'rb') as f:
private_bytes = f.read()

private_key = serialization.load_pem_private_key(private_bytes, password=None)
public_key = private_key.public_key().public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
).decode()
self.assertEqual(public_key, new_public_key)

def _create_new_temp_key_file(self, key_data, suffix=""):
with tempfile.NamedTemporaryFile(mode='w', dir=self._tempdirName, delete=False, suffix=suffix) as f:
Expand Down
1 change: 0 additions & 1 deletion src/azure-cli-core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
'msal[broker]==1.31.0',
'msrestazure~=0.6.4',
'packaging>=20.9',
'paramiko>=2.0.8,<4.0.0',
'pkginfo>=1.5.0.1',
# psutil can't install on cygwin: https://github.com/Azure/azure-cli/issues/9399
'psutil>=5.9; sys_platform != "cygwin"',
Expand Down
1 change: 1 addition & 0 deletions src/azure-cli/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
'javaproperties~=0.5.1',
'jsondiff~=2.0.0',
'packaging>=20.9',
'paramiko>=2.0.8,<4.0.0',
'pycomposefile>=0.0.29',
'PyGithub~=1.38',
'PyNaCl~=1.5.0',
Expand Down