Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions runpod/cli/groups/project/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def print_net_vol(vol):

cuda_version = click.prompt(
" > Select a CUDA version, or press enter to use the default",
type=click.Choice(['11.1.1', '11.8.1', '12.1.0'], case_sensitive=False),
default='11.8.1'
type=click.Choice(['11.1.1', '11.8.0', '12.1.0'], case_sensitive=False),
default='11.8.0'
)

python_version = click.prompt(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_cli/test_cli_groups/test_project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_new_project_wizard_success(self):
patch('runpod.cli.groups.project.commands.get_user') as mock_get_user, \
patch('runpod.cli.groups.project.commands.cli_select') as mock_select:
mock_get_user.return_value = {'networkVolumes':[{ 'id': 'XYZ_VOLUME', 'name': 'XYZ_VOLUME', 'size': 100, 'dataCenterId': 'XYZ' }]} # pylint: disable=line-too-long
mock_prompt.side_effect = ['TestProject', '11.8.1', '3.10']
mock_prompt.side_effect = ['TestProject', '11.8.0', '3.10']
mock_select.return_value = {'volume-id': 'XYZ_VOLUME'}

result = self.runner.invoke(new_project_wizard, ['--type', 'llama2', '--model', 'meta-llama/Llama-2-7b']) # pylint: disable=line-too-long
Expand All @@ -47,7 +47,7 @@ def test_new_project_wizard_success(self):
mock_confirm.assert_called_with("Do you want to continue?", abort=True)
mock_create.assert_called()
mock_prompt.assert_called()
mock_create.assert_called_with('TestProject', 'XYZ_VOLUME', '11.8.1', '3.10', 'llama2', 'meta-llama/Llama-2-7b', False) # pylint: disable=line-too-long
mock_create.assert_called_with('TestProject', 'XYZ_VOLUME', '11.8.0', '3.10', 'llama2', 'meta-llama/Llama-2-7b', False) # pylint: disable=line-too-long
self.assertIn("Project TestProject created successfully!", result.output)

def test_new_project_wizard_success_init_current_dir(self):
Expand All @@ -62,7 +62,7 @@ def test_new_project_wizard_success_init_current_dir(self):
patch('os.getcwd') as mock_getcwd:
mock_get_user.return_value = {'networkVolumes':[{ 'id': 'XYZ_VOLUME', 'name': 'XYZ_VOLUME', 'size': 100, 'dataCenterId': 'XYZ' }]} # pylint: disable=line-too-long
mock_select.return_value = {'volume-id': 'XYZ_VOLUME'}
mock_prompt.side_effect = ['XYZ_VOLUME', '11.8.1', '3.10']
mock_prompt.side_effect = ['XYZ_VOLUME', '11.8.0', '3.10']

self.runner.invoke(new_project_wizard, ['--init'])
assert mock_confirm.called
Expand Down
8 changes: 4 additions & 4 deletions tests/test_cli/test_cli_groups/test_project_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_copy_template_files(self, mock_copy_template_files, mock_getcwd, mock_e
def test_replace_placeholders_in_handler(self, mock_open_file, mock_exists): # pylint: disable=line-too-long
""" Test that placeholders in handler.py are replaced if model_name is given. """
with patch("runpod.cli.groups.project.functions.copy_template_files"):
create_new_project("test_project", "volume_id", "11.8.1", "3.8", model_name="my_model")
create_new_project("test_project", "volume_id", "11.8.0", "3.8", model_name="my_model")
# mock_open_file().write.assert_called_with("data with my_model placeholder")
assert mock_open_file.called
assert mock_exists.called
Expand All @@ -76,7 +76,7 @@ def test_replace_placeholders_in_handler(self, mock_open_file, mock_exists): # p
def test_create_runpod_toml(self, mock_open_file, mock_exists):
""" Test that runpod.toml configuration file is created. """
with patch("runpod.cli.groups.project.functions.copy_template_files"):
create_new_project("test_project", "volume_id", "11.8.1", "3.8")
create_new_project("test_project", "volume_id", "11.8.0", "3.8")
toml_file_location = os.path.join(os.getcwd(), "test_project", "runpod.toml")
mock_open_file.assert_called_with(toml_file_location, 'w', encoding="UTF-8") # pylint: disable=line-too-long
assert mock_exists.called
Expand All @@ -98,7 +98,7 @@ def test_update_requirements_file(self, mock_open_file, mock_exists):
""" Test that placeholders in requirements.txt are replaced correctly. """
with patch("runpod.cli.groups.project.functions.__version__", "dev"), \
patch("runpod.cli.groups.project.functions.copy_template_files"):
create_new_project("test_project", "volume_id", "11.8.1", "3.8")
create_new_project("test_project", "volume_id", "11.8.0", "3.8")
assert mock_open_file.called
assert mock_exists.called

Expand All @@ -108,7 +108,7 @@ def test_update_requirements_file_non_dev(self, mock_open_file, mock_exists):
""" Test that placeholders in requirements.txt are replaced for non-dev versions. """
with patch("runpod.cli.groups.project.functions.__version__", "1.0.0"), \
patch("runpod.cli.groups.project.functions.copy_template_files"):
create_new_project("test_project", "volume_id", "11.8.1", "3.8")
create_new_project("test_project", "volume_id", "11.8.0", "3.8")
assert mock_open_file.called
assert mock_exists.called

Expand Down