1818from  cf_remote .web  import  download_package 
1919from  cf_remote .paths  import  (
2020    cf_remote_dir ,
21+     cf_remote_packages_dir ,
2122    CLOUD_CONFIG_FPATH ,
2223    CLOUD_STATE_FPATH ,
23-     cf_remote_packages_dir ,
24+     SSH_CONFIG_FPATH ,
25+     SSH_CONFIGS_JSON_FPATH ,
2426)
2527from  cf_remote .utils  import  (
2628    copy_file ,
3739    CFRChecksumError ,
3840    CFRUserError ,
3941)
40- from  cf_remote .spawn  import  VM , VMRequest , Providers , AWSCredentials , GCPCredentials 
42+ from  cf_remote .spawn  import  (
43+     CloudVM ,
44+     VMRequest ,
45+     Providers ,
46+     AWSCredentials ,
47+     GCPCredentials ,
48+     VagrantVM ,
49+ )
4150from  cf_remote .spawn  import  spawn_vms , destroy_vms , dump_vms_info , get_cloud_driver 
4251from  cf_remote  import  log 
4352from  cf_remote  import  cloud_data 
@@ -393,6 +402,9 @@ def spawn(
393402    network = None ,
394403    public_ip = True ,
395404    extend_group = False ,
405+     vagrant_cpus = None ,
406+     vagrant_sync = None ,
407+     vagrant_provision = None ,
396408):
397409    creds_data  =  None 
398410    if  os .path .exists (CLOUD_CONFIG_FPATH ):
@@ -469,13 +481,20 @@ def spawn(
469481            network = network ,
470482            role = role ,
471483            spawned_cb = print_progress_dot ,
484+             vagrant_cpus = vagrant_cpus ,
485+             vagrant_sync = vagrant_sync ,
486+             vagrant_provision = vagrant_provision ,
472487        )
473488    except  ValueError  as  e :
474489        print ("\n Error: Failed to spawn VMs - "  +  str (e ))
475490        return  1 
476491    print ("DONE" )
477492
478-     if  public_ip  and  (not  all (vm .public_ips  for  vm  in  vms )):
493+     if  (
494+         provider  !=  Providers .VAGRANT 
495+         and  public_ip 
496+         and  (not  all (vm .public_ips  for  vm  in  vms ))
497+     ):
479498        print ("Waiting for VMs to get IP addresses..." , end = "" )
480499        sys .stdout .flush ()  # STDOUT is line-buffered 
481500        while  not  all (vm .public_ips  for  vm  in  vms ):
@@ -488,6 +507,16 @@ def spawn(
488507    else :
489508        vms_info [group_key ] =  dump_vms_info (vms )
490509
510+     if  provider  ==  Providers .VAGRANT :
511+         vmdir  =  vms [0 ].vmdir 
512+         ssh_config  =  read_json (SSH_CONFIGS_JSON_FPATH )
513+         if  not  ssh_config :
514+             ssh_config  =  {}
515+ 
516+         with  open (os .path .join (vmdir , "vagrant-ssh-config" ), "r" ) as  f :
517+             ssh_config [group_key ] =  f .read ()
518+         write_json (SSH_CONFIGS_JSON_FPATH , ssh_config )
519+ 
491520    write_json (CLOUD_STATE_FPATH , vms_info )
492521    print ("Details about the spawned VMs can be found in %s"  %  CLOUD_STATE_FPATH )
493522
@@ -510,6 +539,36 @@ def _delete_saved_group(vms_info, group_name):
510539    del  vms_info [group_name ]
511540
512541
542+ def  _get_cloud_vms (provider , creds , region , group ):
543+     if  creds  is  None :
544+         raise  CFRExitError ("Missing/incomplete {} credentials" .format (provider .upper ()))
545+     driver  =  get_cloud_driver (provider , creds , region )
546+ 
547+     assert  driver  is  not None 
548+ 
549+     nodes  =  driver .list_nodes ()
550+     for  name , vm_info  in  group .items ():
551+         if  name  ==  "meta" :
552+             continue 
553+         vm_uuid  =  vm_info ["uuid" ]
554+         vm  =  CloudVM .get_by_uuid (vm_uuid , nodes = nodes )
555+         if  vm  is  not None :
556+             yield  vm 
557+         else :
558+             print ("VM '%s' not found in the clouds"  %  vm_uuid )
559+ 
560+ 
561+ def  _get_vagrant_vms (group ):
562+     for  name , vm_info  in  group .items ():
563+         if  name  ==  "meta" :
564+             continue 
565+         vm  =  VagrantVM .get_by_info (name , vm_info )
566+         if  vm  is  not None :
567+             yield  vm 
568+         else :
569+             print ("VM '%s' not found locally"  %  name )
570+ 
571+ 
513572def  destroy (group_name = None ):
514573    if  os .path .exists (CLOUD_CONFIG_FPATH ):
515574        creds_data  =  read_json (CLOUD_CONFIG_FPATH )
@@ -549,92 +608,54 @@ def destroy(group_name=None):
549608        raise  CFRUserError ("No saved VMs found in '{}'" .format (CLOUD_STATE_FPATH ))
550609
551610    to_destroy  =  []
611+     group_names  =  None 
552612    if  group_name :
553613        if  not  group_name .startswith ("@" ):
554614            group_name  =  "@"  +  group_name 
555615        if  group_name  not  in vms_info :
556616            print ("Group '%s' not found"  %  group_name )
557617            return  1 
558618
619+         group_names  =  [group_name ]
620+     else :
621+         group_names  =  [key  for  key  in  vms_info .keys () if  key .startswith ("@" )]
622+ 
623+     ssh_config  =  read_json (SSH_CONFIGS_JSON_FPATH )
624+     assert  group_names  is  not None 
625+     for  group_name  in  group_names :
559626        if  _is_saved_group (vms_info , group_name ):
560627            _delete_saved_group (vms_info , group_name )
561-             write_json (CLOUD_STATE_FPATH , vms_info )
562-             return  0 
563- 
564-         print ("Destroying hosts in the '%s' group"  %  group_name )
628+             continue 
565629
566630        region  =  vms_info [group_name ]["meta" ]["region" ]
567631        provider  =  vms_info [group_name ]["meta" ]["provider" ]
568-         if  provider  not  in "aws" , "gcp" ]:
632+         if  provider  not  in "aws" , "gcp" ,  "vagrant" ]:
569633            raise  CFRUserError (
570-                 "Unsupported provider '{}' encountered in '{}', only aws /  gcp is  supported" .format (
634+                 "Unsupported provider '{}' encountered in '{}', only aws,  gcp and vagrant are  supported" .format (
571635                    provider , CLOUD_STATE_FPATH 
572636                )
573637            )
574638
575-         driver  =  None 
639+         group  =  vms_info [group_name ]
640+         vms  =  []
576641        if  provider  ==  "aws" :
577-             if  aws_creds  is  None :
578-                 raise  CFRExitError ("Missing/incomplete AWS credentials" )
579-             driver  =  get_cloud_driver (Providers .AWS , aws_creds , region )
642+             vms  =  _get_cloud_vms (Providers .AWS , aws_creds , region , group )
580643        if  provider  ==  "gcp" :
581-             if  gcp_creds  is  None :
582-                 raise  CFRExitError ("Missing/incomplete GCP credentials" )
583-             driver  =  get_cloud_driver (Providers .GCP , gcp_creds , region )
584-         assert  driver  is  not None 
644+             vms  =  _get_cloud_vms (Providers .GCP , gcp_creds , region , group )
645+         if  provider  ==  "vagrant" :
646+             vms  =  _get_vagrant_vms (group )
585647
586-         nodes  =  driver .list_nodes ()
587-         for  name , vm_info  in  vms_info [group_name ].items ():
588-             if  name  ==  "meta" :
589-                 continue 
590-             vm_uuid  =  vm_info ["uuid" ]
591-             vm  =  VM .get_by_uuid (vm_uuid , nodes = nodes )
592-             if  vm  is  not None :
593-                 to_destroy .append (vm )
594-             else :
595-                 print ("VM '%s' not found in the clouds"  %  vm_uuid )
596-         del  vms_info [group_name ]
597-     else :
598-         print ("Destroying all hosts" )
599-         for  group_name  in  [key  for  key  in  vms_info .keys () if  key .startswith ("@" )]:
600-             if  _is_saved_group (vms_info , group_name ):
601-                 _delete_saved_group (vms_info , group_name )
602-                 continue 
648+         for  vm  in  vms :
649+             to_destroy .append (vm )
603650
604-             region  =  vms_info [group_name ]["meta" ]["region" ]
605-             provider  =  vms_info [group_name ]["meta" ]["provider" ]
606-             if  provider  not  in "aws" , "gcp" ]:
607-                 raise  CFRUserError (
608-                     "Unsupported provider '{}' encountered in '{}', only aws / gcp is supported" .format (
609-                         provider , CLOUD_STATE_FPATH 
610-                     )
611-                 )
651+         del  vms_info [group_name ]
612652
613-             driver  =  None 
614-             if  provider  ==  "aws" :
615-                 if  aws_creds  is  None :
616-                     raise  CFRExitError ("Missing/incomplete AWS credentials" )
617-                 driver  =  get_cloud_driver (Providers .AWS , aws_creds , region )
618-             if  provider  ==  "gcp" :
619-                 if  gcp_creds  is  None :
620-                     raise  CFRExitError ("Missing/incomplete GCP credentials" )
621-                 driver  =  get_cloud_driver (Providers .GCP , gcp_creds , region )
622-             assert  driver  is  not None 
623- 
624-             nodes  =  driver .list_nodes ()
625-             for  name , vm_info  in  vms_info [group_name ].items ():
626-                 if  name  ==  "meta" :
627-                     continue 
628-                 vm_uuid  =  vm_info ["uuid" ]
629-                 vm  =  VM .get_by_uuid (vm_uuid , nodes = nodes )
630-                 if  vm  is  not None :
631-                     to_destroy .append (vm )
632-                 else :
633-                     print ("VM '%s' not found in the clouds"  %  vm_uuid )
634-             del  vms_info [group_name ]
653+         if  ssh_config  and  group_name  in  ssh_config :
654+             del  ssh_config [group_name ]
635655
636656    destroy_vms (to_destroy )
637657    write_json (CLOUD_STATE_FPATH , vms_info )
658+     write_json (SSH_CONFIGS_JSON_FPATH , ssh_config )
638659    return  0 
639660
640661
@@ -655,6 +676,11 @@ def list_platforms():
655676    return  0 
656677
657678
679+ def  list_boxes ():
680+     result  =  subprocess .run (["vagrant" , "box" , "list" ])
681+     return  result .returncode 
682+ 
683+ 
658684def  init_cloud_config ():
659685    if  os .path .exists (CLOUD_CONFIG_FPATH ):
660686        print ("File %s already exists"  %  CLOUD_CONFIG_FPATH )
@@ -1010,7 +1036,7 @@ def connect_cmd(hosts):
10101036        raise  CFRExitError ("You can only connect to one host at a time" )
10111037
10121038    print ("Opening a SSH command shell..." )
1013-     r  =  subprocess .run (["ssh" , hosts [0 ]])
1039+     r  =  subprocess .run (["ssh" , "-F" ,  SSH_CONFIG_FPATH ,  hosts [0 ]])
10141040    if  r .returncode  ==  0 :
10151041        return  0 
10161042    if  r .returncode  <  0 :
0 commit comments