1212import  tempfile 
1313import  time 
1414from  dataclasses  import  asdict , is_dataclass 
15- from  datetime  import  datetime , timedelta 
1615from  enum  import  Enum 
1716from  typing  import  (
1817    Any ,
3332import  requests 
3433from  azure .storage .blob  import  ContainerClient 
3534from  onefuzztypes  import  responses 
36- from  pydantic  import  BaseModel ,  Field 
35+ from  pydantic  import  BaseModel 
3736from  requests  import  Response 
3837from  tenacity  import  RetryCallState , retry 
3938from  tenacity .retry  import  retry_if_exception_type 
@@ -93,20 +92,26 @@ def check_application_error(response: requests.Response) -> None:
9392
9493
9594class  BackendConfig (BaseModel ):
96-     authority : str 
97-     client_id : str 
98-     endpoint : Optional [str ]
99-     features : Set [str ] =  Field (default_factory = set )
100-     tenant_domain : str 
101-     expires_on : datetime  =  datetime .utcnow () +  timedelta (hours = 24 )
95+     authority : Optional [str ]
96+     client_id : Optional [str ]
97+     endpoint : str 
98+     features : Optional [Set [str ]]
99+     tenant_domain : Optional [str ]
102100
103101    def  get_multi_tenant_domain (self ) ->  Optional [str ]:
104-         if  "https://login.microsoftonline.com/common"  in  self .authority :
102+         if  (
103+             self .authority 
104+             and  "https://login.microsoftonline.com/common"  in  self .authority 
105+         ):
105106            return  self .tenant_domain 
106107        else :
107108            return  None 
108109
109110
111+ class  CacheConfig (BaseModel ):
112+     endpoint : Optional [str ]
113+ 
114+ 
110115class  Backend :
111116    def  __init__ (
112117        self ,
@@ -129,10 +134,14 @@ def __init__(
129134        atexit .register (self .save_cache )
130135
131136    def  enable_feature (self , name : str ) ->  None :
137+         if  not  self .config .features :
138+             self .config .features  =  Set [str ]()
132139        self .config .features .add (name )
133140
134141    def  is_feature_enabled (self , name : str ) ->  bool :
135-         return  name  in  self .config .features 
142+         if  self .config .features :
143+             return  name  in  self .config .features 
144+         return  False 
136145
137146    def  load_config (self ) ->  None :
138147        if  os .path .exists (self .config_path ):
@@ -143,7 +152,8 @@ def load_config(self) -> None:
143152    def  save_config (self ) ->  None :
144153        os .makedirs (os .path .dirname (self .config_path ), exist_ok = True )
145154        with  open (self .config_path , "w" ) as  handle :
146-             handle .write (self .config .json (indent = 4 , exclude_none = True ))
155+             endpoint_cache  =  {"endpoint" : f"{ self .config .endpoint }  }
156+             handle .write (json .dumps (endpoint_cache , indent = 4 , sort_keys = True ))
147157
148158    def  init_cache (self ) ->  None :
149159        # Ensure the token_path directory exists 
@@ -331,15 +341,13 @@ def config_params(
331341        endpoint_params  =  responses .Config .parse_obj (response .json ())
332342
333343        # Will override values in storage w/ provided values for SP use 
334-         if  self .config .client_id   ==   "" :
344+         if  not   self .config .client_id :
335345            self .config .client_id  =  endpoint_params .client_id 
336-         if  self .config .authority   ==   "" :
346+         if  not   self .config .authority :
337347            self .config .authority  =  endpoint_params .authority 
338-         if  self .config .tenant_domain   ==   "" :
348+         if  not   self .config .tenant_domain :
339349            self .config .tenant_domain  =  endpoint_params .tenant_domain 
340350
341-         self .save_config ()
342- 
343351    def  request (
344352        self ,
345353        method : str ,
@@ -353,17 +361,9 @@ def request(
353361        if  not  endpoint :
354362            raise  Exception ("endpoint not configured" )
355363
356-         # If file expires, remove and force user to reset 
357-         if  datetime .utcnow () >  self .config .expires_on :
358-             os .remove (self .config_path )
359-             self .config  =  BackendConfig (
360-                 endpoint = endpoint , authority = "" , client_id = "" , tenant_domain = "" 
361-             )
362- 
363364        url  =  endpoint  +  "/api/"  +  path 
364- 
365-         if  self .config .client_id  ==  ""  or  (
366-             self .config .authority  ==  ""  and  self .config .tenant_domain  ==  "" 
365+         if  not  self .config .client_id  or  (
366+             not  self .config .authority  and  not  self .config .tenant_domain 
367367        ):
368368            self .config_params ()
369369        headers  =  self .headers ()
0 commit comments