1+ """Tests for runpod.__init__ module exports.""" 
2+ 
3+ import  inspect 
4+ import  runpod 
5+ 
6+ 
7+ class  TestRunpodInit :
8+     """Test runpod module __all__ exports.""" 
9+ 
10+     def  test_all_defined (self ):
11+         """Test that __all__ is defined in the module.""" 
12+         assert  hasattr (runpod , '__all__' )
13+         assert  isinstance (runpod .__all__ , list )
14+         assert  len (runpod .__all__ ) >  0 
15+ 
16+     def  test_all_symbols_importable (self ):
17+         """Test that all symbols in __all__ are actually importable.""" 
18+         for  symbol  in  runpod .__all__ :
19+             assert  hasattr (runpod , symbol ), f"Symbol '{ symbol }  ' in __all__ but not found in module" 
20+ 
21+     def  test_api_functions_accessible (self ):
22+         """Test that API functions are accessible and callable.""" 
23+         api_functions  =  [
24+             'create_container_registry_auth' , 'create_endpoint' , 'create_pod' , 'create_template' ,
25+             'delete_container_registry_auth' , 'get_endpoints' , 'get_gpu' , 'get_gpus' ,
26+             'get_pod' , 'get_pods' , 'get_user' , 'resume_pod' , 'stop_pod' , 'terminate_pod' ,
27+             'update_container_registry_auth' , 'update_endpoint_template' , 'update_user_settings' 
28+         ]
29+         
30+         for  func_name  in  api_functions :
31+             assert  func_name  in  runpod .__all__ 
32+             assert  hasattr (runpod , func_name )
33+             assert  callable (getattr (runpod , func_name ))
34+ 
35+     def  test_config_functions_accessible (self ):
36+         """Test that config functions are accessible and callable.""" 
37+         config_functions  =  ['check_credentials' , 'get_credentials' , 'set_credentials' ]
38+         
39+         for  func_name  in  config_functions :
40+             assert  func_name  in  runpod .__all__ 
41+             assert  hasattr (runpod , func_name )
42+             assert  callable (getattr (runpod , func_name ))
43+ 
44+     def  test_endpoint_classes_accessible (self ):
45+         """Test that endpoint classes are accessible.""" 
46+         endpoint_classes  =  ['AsyncioEndpoint' , 'AsyncioJob' , 'Endpoint' ]
47+         
48+         for  class_name  in  endpoint_classes :
49+             assert  class_name  in  runpod .__all__ 
50+             assert  hasattr (runpod , class_name )
51+             assert  inspect .isclass (getattr (runpod , class_name ))
52+ 
53+     def  test_serverless_module_accessible (self ):
54+         """Test that serverless module is accessible.""" 
55+         assert  'serverless'  in  runpod .__all__ 
56+         assert  hasattr (runpod , 'serverless' )
57+         assert  inspect .ismodule (runpod .serverless )
58+ 
59+     def  test_logger_class_accessible (self ):
60+         """Test that RunPodLogger class is accessible.""" 
61+         assert  'RunPodLogger'  in  runpod .__all__ 
62+         assert  hasattr (runpod , 'RunPodLogger' )
63+         assert  inspect .isclass (runpod .RunPodLogger )
64+ 
65+     def  test_version_accessible (self ):
66+         """Test that __version__ is accessible.""" 
67+         assert  '__version__'  in  runpod .__all__ 
68+         assert  hasattr (runpod , '__version__' )
69+         assert  isinstance (runpod .__version__ , str )
70+ 
71+     def  test_module_variables_accessible (self ):
72+         """Test that module variables are accessible.""" 
73+         module_vars  =  ['SSH_KEY_PATH' , 'profile' , 'api_key' , 'endpoint_url_base' ]
74+         
75+         for  var_name  in  module_vars :
76+             assert  var_name  in  runpod .__all__ 
77+             assert  hasattr (runpod , var_name )
78+ 
79+     def  test_private_imports_not_exported (self ):
80+         """Test that private imports are not in __all__.""" 
81+         private_symbols  =  {
82+             'logging' , 'os' , '_credentials' 
83+         }
84+         all_symbols  =  set (runpod .__all__ )
85+         
86+         for  private_symbol  in  private_symbols :
87+             assert  private_symbol  not  in   all_symbols , f"Private symbol '{ private_symbol }  ' should not be in __all__" 
88+ 
89+     def  test_all_covers_expected_public_api (self ):
90+         """Test that __all__ contains the expected public API symbols.""" 
91+         expected_symbols  =  {
92+             # API functions   
93+             'create_container_registry_auth' , 'create_endpoint' , 'create_pod' , 'create_template' ,
94+             'delete_container_registry_auth' , 'get_endpoints' , 'get_gpu' , 'get_gpus' ,
95+             'get_pod' , 'get_pods' , 'get_user' , 'resume_pod' , 'stop_pod' , 'terminate_pod' ,
96+             'update_container_registry_auth' , 'update_endpoint_template' , 'update_user_settings' ,
97+             # Config functions 
98+             'check_credentials' , 'get_credentials' , 'set_credentials' ,
99+             # Endpoint classes 
100+             'AsyncioEndpoint' , 'AsyncioJob' , 'Endpoint' ,
101+             # Serverless module 
102+             'serverless' ,
103+             # Logger class 
104+             'RunPodLogger' ,
105+             # Version 
106+             '__version__' ,
107+             # Module variables 
108+             'SSH_KEY_PATH' , 'profile' , 'api_key' , 'endpoint_url_base' 
109+         }
110+         
111+         actual_symbols  =  set (runpod .__all__ )
112+         assert  expected_symbols  ==  actual_symbols , f"Expected { expected_symbols }  , got { actual_symbols }  " 
113+ 
114+     def  test_no_duplicate_symbols_in_all (self ):
115+         """Test that __all__ contains no duplicate symbols.""" 
116+         all_symbols  =  runpod .__all__ 
117+         unique_symbols  =  set (all_symbols )
118+         assert  len (all_symbols ) ==  len (unique_symbols ), f"Duplicates found in __all__: { [x  for  x  in  all_symbols  if  all_symbols .count (x ) >  1 ]}  " 
0 commit comments