diff --git a/src/plugins/dnf/product-id.py b/src/plugins/dnf/product-id.py index ed6537c93a..d971349b3f 100644 --- a/src/plugins/dnf/product-id.py +++ b/src/plugins/dnf/product-id.py @@ -13,6 +13,7 @@ # import logging +from typing import Set from subscription_manager.productid import ProductManager from subscription_manager.utils import chroot @@ -202,7 +203,7 @@ def write_productid_cache(self, product_ids): def read_productid_cache(self): return self.__read_cache_file(self.PRODUCTID_CACHE_FILE) - def get_active(self): + def get_active(self) -> Set[str]: """ Find the list of repos that provide packages that are actually installed. """ diff --git a/src/rhsm/certificate2.py b/src/rhsm/certificate2.py index 33240af42a..5881237bbc 100644 --- a/src/rhsm/certificate2.py +++ b/src/rhsm/certificate2.py @@ -167,10 +167,10 @@ def _read_alt_name(self, x509) -> str: else: return alt_name.decode("utf-8") - def _read_issuer(self, x509: _certificate.X509) -> str: + def _read_issuer(self, x509: _certificate.X509) -> dict: return x509.get_issuer() - def _read_subject(self, x509: _certificate.X509) -> str: + def _read_subject(self, x509: _certificate.X509) -> dict: return x509.get_subject() def _create_identity_cert( @@ -504,9 +504,9 @@ def __init__( serial: Optional[int] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, - subject: Optional[str] = None, + subject: Optional[dict] = None, pem: Optional[str] = None, - issuer: Optional[str] = None, + issuer: Optional[dict] = None, ): # The rhsm._certificate X509 object for this certificate.
# WARNING: May be None in tests @@ -531,8 +531,8 @@ def __init__( self.valid_range = DateRange(self.start, self.end) self.pem: Optional[str] = pem - self.subject: Optional[str] = subject - self.issuer: Optional[str] = issuer + self.subject: Optional[dict] = subject + self.issuer: Optional[dict] = issuer def is_valid(self, on_date: Optional[datetime.datetime] = None): gmt = datetime.datetime.utcnow() @@ -615,14 +615,14 @@ class EntitlementCertificate(ProductCertificate): def __init__( self, order: Optional["Order"] = None, - content: Optional["Content"] = None, + content: Optional[List["Content"]] = None, pool: Optional["Pool"] = None, extensions: Optional[Extensions] = None, **kwargs, ): ProductCertificate.__init__(self, **kwargs) self.order: Optional[Order] = order - self.content: Optional[Content] = content + self.content: Optional[List[Content]] = content self.pool: Optional[Pool] = pool self.extensions: Optional[Extensions] = extensions self._path_tree_object = None diff --git a/src/rhsm/connection.py b/src/rhsm/connection.py index f02a0d26b8..a3c1c779e4 100644 --- a/src/rhsm/connection.py +++ b/src/rhsm/connection.py @@ -25,7 +25,7 @@ import sys import time import traceback -from typing import Optional, Any, Union, List, Dict, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from pathlib import Path import re import enum @@ -547,7 +547,7 @@ def __init__(self, cert_dir: str = None, **kwargs) -> None: handler="/", cert_dir=cert_dir, user_agent=user_agent, **kwargs ) - def get_versions(self, path: str, cert_key_pairs: list = None) -> Union[dict, None]: + def get_versions(self, path: str, cert_key_pairs: Iterable[Tuple[str, str]] = None) -> Union[dict, None]: """ Get list of available release versions from the given path :param path: path, where is simple text file containing supported release versions @@ -1722,7 +1722,7 @@ def updatePackageProfile(self, consumer_uuid: str, pkg_dicts: dict) -> dict: method = "/consumers/%s/packages" % 
self.sanitize(consumer_uuid) return self.conn.request_put(method, pkg_dicts, description=_("Updating profile information")) - def updateCombinedProfile(self, consumer_uuid: str, profile: dict) -> dict: + def updateCombinedProfile(self, consumer_uuid: str, profile: List[Dict]) -> dict: """ Updates the costumers' combined profile containing package profile, enabled repositories and dnf modules. diff --git a/src/rhsmlib/facts/cloud_facts.py b/src/rhsmlib/facts/cloud_facts.py index 8f4c9c9ceb..6bb1627ef1 100644 --- a/src/rhsmlib/facts/cloud_facts.py +++ b/src/rhsmlib/facts/cloud_facts.py @@ -78,7 +78,7 @@ def get_aws_facts(self) -> Dict[str, Union[str, None]]: facts: Dict[str, Union[str, None]] = {} if metadata_str is not None: - values: dict[str, Union[str, None]] = self.parse_json_content(metadata_str) + values: Dict[str, Union[str, None]] = self.parse_json_content(metadata_str) # Add these three attributes to system facts if "instanceId" in values: diff --git a/src/rhsmlib/file_monitor.py b/src/rhsmlib/file_monitor.py index a84e1139f5..90a3b8138c 100644 --- a/src/rhsmlib/file_monitor.py +++ b/src/rhsmlib/file_monitor.py @@ -11,6 +11,7 @@ # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. 
import threading +from typing import Dict from rhsm.config import get_config_parser from rhsmlib.services import config @@ -55,14 +56,14 @@ class FilesystemWatcher(object): # Timeout of loop in milliseconds TIMEOUT = 2000 - def __init__(self, dir_watches): + def __init__(self, dir_watches: Dict[str, "DirectoryWatch"]): """ :param dir_watches: dictionary of directories to watch (see DirectoryWatch class below) """ - self.dir_watches = dir_watches - self.should_stop = False + self.dir_watches: Dict[str, DirectoryWatch] = dir_watches + self.should_stop: bool = False - def stop(self): + def stop(self) -> None: """ Calling this method stops infinity loop of FileSystemWatcher """ diff --git a/src/rhsmlib/services/syspurpose.py b/src/rhsmlib/services/syspurpose.py index 635f393a5a..f96050c16a 100644 --- a/src/rhsmlib/services/syspurpose.py +++ b/src/rhsmlib/services/syspurpose.py @@ -16,6 +16,7 @@ """ import logging +from typing import Dict from rhsm.connection import UEPConnection @@ -41,7 +42,7 @@ def __init__(self, cp: UEPConnection) -> None: self.owner = None self.valid_fields = None - def get_syspurpose_status(self, on_date: str = None) -> str: + def get_syspurpose_status(self, on_date: str = None) -> Dict: """ Get syspurpose status from candlepin server :param on_date: Date of the status diff --git a/src/subscription_manager/action_client.py b/src/subscription_manager/action_client.py index 01b03c2d5b..c74248ccca 100644 --- a/src/subscription_manager/action_client.py +++ b/src/subscription_manager/action_client.py @@ -17,9 +17,13 @@ # in this software or its documentation. 
# import logging +from typing import List, TYPE_CHECKING from subscription_manager import base_action_client +if TYPE_CHECKING: + from subscription_manager.certlib import BaseActionInvoker + from subscription_manager.entcertlib import EntCertActionInvoker from subscription_manager.identitycertlib import IdentityCertActionInvoker from subscription_manager.healinglib import HealingActionInvoker @@ -33,7 +37,7 @@ class ActionClient(base_action_client.BaseActionClient): - def _get_libset(self): + def _get_libset(self) -> List["BaseActionInvoker"]: # TODO: replace with FSM thats progress through this async and wait/joins if needed self.entcertlib = EntCertActionInvoker() @@ -47,7 +51,7 @@ def _get_libset(self): # WARNING: order is important here, we need to update a number # of things before attempting to autoheal, and we need to autoheal # before attempting to fetch our certificates: - lib_set = [ + lib_set: List[BaseActionInvoker] = [ self.entcertlib, self.idcertlib, self.content_client, @@ -61,47 +65,55 @@ def _get_libset(self): class HealingActionClient(base_action_client.BaseActionClient): - def _get_libset(self): + def _get_libset(self) -> List["BaseActionInvoker"]: self.entcertlib = EntCertActionInvoker() self.installedprodlib = InstalledProductsActionInvoker() self.syspurposelib = SyspurposeSyncActionInvoker() self.healinglib = HealingActionInvoker() - lib_set = [self.installedprodlib, self.syspurposelib, self.healinglib, self.entcertlib] + lib_set: List[BaseActionInvoker] = [ + self.installedprodlib, + self.syspurposelib, + self.healinglib, + self.entcertlib, + ] return lib_set # it may make more sense to have *Lib.cleanup actions? # *Lib things are weird, since some are idempotent, but -# some arent. entcertlib/repolib .update can both install +# some are not. entcertlib/repolib .update can both install # certs, and/or delete all of them. class UnregisterActionClient(base_action_client.BaseActionClient): """CertManager for cleaning up on unregister. 
- This class should not need a consumer id, or a uep connection, since it + This class should not need a consumer id nor an UEP connection, since it is running post unregister. """ - def _get_libset(self): + def _get_libset(self) -> List["BaseActionInvoker"]: self.entcertlib = EntCertActionInvoker() self.content_action_client = ContentActionClient() - lib_set = [self.entcertlib, self.content_action_client] + lib_set: List[BaseActionInvoker] = [ + self.entcertlib, + self.content_action_client, + ] return lib_set class ProfileActionClient(base_action_client.BaseActionClient): """ - This class should not need a consumer id, or a uep connection, since it + This class should not need a consumer id nor an UEP connection, since it is running post unregister. """ - def _get_libset(self): + def _get_libset(self) -> List["BaseActionInvoker"]: self.profilelib = PackageProfileActionInvoker() - lib_set = [self.profilelib] + lib_set: List[BaseActionInvoker] = [self.profilelib] return lib_set diff --git a/src/subscription_manager/base_action_client.py b/src/subscription_manager/base_action_client.py index 4c533c7932..692a3222d7 100644 --- a/src/subscription_manager/base_action_client.py +++ b/src/subscription_manager/base_action_client.py @@ -13,53 +13,55 @@ # in this software or its documentation. # import logging +from typing import List, TYPE_CHECKING + +from rhsm.connection import GoneException, ExpiredIdentityCertException from subscription_manager import injection as inj -from rhsm.connection import GoneException, ExpiredIdentityCertException +if TYPE_CHECKING: + from subscription_manager.certlib import BaseActionInvoker, ActionReport + from subscription_manager.lock import ActionLock log = logging.getLogger(__name__) class BaseActionClient(object): """ - An object used to update the certficates, yum repos, and facts for the system. + An object used to update the certificates, DNF repos, and facts for the system. 
""" - def __init__(self, skips=None): + # FIXME Default for skips should be [] + def __init__(self, skips: List[type("ActionReport")] = None): - self._libset = self._get_libset() - self.lock = inj.require(inj.ACTION_LOCK) - self.report = None - self.update_reports = [] - self.skips = skips or [] + self._libset: List[BaseActionInvoker] = self._get_libset() + self.lock: ActionLock = inj.require(inj.ACTION_LOCK) + # FIXME `report` attribute is not used here + self.report: ActionReport = None + self.update_reports: List[ActionReport] = [] + self.skips: List[type(ActionReport)] = skips or [] - def _get_libset(self): + def _get_libset(self) -> List["BaseActionInvoker"]: + # FIXME (?) Raise NotImplementedError, to ensure each subclass is using its own function return [] - def update(self, autoheal=False): + def update(self, autoheal: bool = False) -> None: """ - Update I{entitlement} certificates and corresponding - yum repositiories. - @return: A list of update reports - @rtype: list + Update I{entitlement} certificates and corresponding DNF repositories. 
""" - lock = self.lock - # TODO: move to using a lock context manager try: - lock.acquire() + self.lock.acquire() self.update_reports = self._run_updates(autoheal) finally: - lock.release() + self.lock.release() - def _run_update(self, lib): - update_report = None + def _run_update(self, lib: type) -> "ActionReport": + update_report: ActionReport = None try: update_report = lib.update() - # see bz#852706, reraise GoneException so that - # consumer cert deletion works + # see bz#852706, reraise GoneException so that consumer cert deletion works except GoneException: raise # raise this so it can be exposed clearly @@ -74,16 +76,17 @@ def _run_update(self, lib): return update_report - def _run_updates(self, autoheal): + # FIXME `autoheal` is not used + def _run_updates(self, autoheal: bool) -> List["ActionReport"]: - update_reports = [] + update_reports: List[ActionReport] = [] for lib in self._libset: if type(lib) in self.skips: continue log.debug("running lib: %s" % lib) - update_report = self._run_update(lib) + update_report: ActionReport = self._run_update(lib) # a map/dict may make more sense here update_reports.append(update_report) diff --git a/src/subscription_manager/base_plugin.py b/src/subscription_manager/base_plugin.py index 46764b5014..ab27375ce0 100644 --- a/src/subscription_manager/base_plugin.py +++ b/src/subscription_manager/base_plugin.py @@ -11,6 +11,10 @@ # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. 
# +from typing import Optional, Set, TYPE_CHECKING + +if TYPE_CHECKING: + from subscription_manager.plugins import PluginConfig class SubManPlugin(object): @@ -19,25 +23,24 @@ class SubManPlugin(object): Plugins need to subclass SubManPlugin() to be found """ - name = None - conf = None + name: str = None + conf: "PluginConfig" = None # if all_slots is set, the plugin will get registered to all slots # it is up to the plugin to handle providing callables - all_slots = None + all_slots: Set[str] = None - # did we have hooks that match provided slots? aka - # is this plugin going to be used - found_slots_for_hooks = False + # did we have hooks that match provided slots? aka is this plugin going to be used + found_slots_for_hooks: bool = False - def __init__(self, conf=None): + def __init__(self, conf: Optional["PluginConfig"] = None): if conf: self.conf = conf if self.conf is None: raise TypeError("SubManPlugin can not be constructed with conf=None") - def __str__(self): + def __str__(self) -> str: return self.name or self.__class__.__name__ @classmethod - def get_plugin_key(cls): + def get_plugin_key(cls) -> str: return ".".join([cls.__module__, cls.__name__]) diff --git a/src/subscription_manager/cache.py b/src/subscription_manager/cache.py index 6ac638d743..7b55758b66 100644 --- a/src/subscription_manager/cache.py +++ b/src/subscription_manager/cache.py @@ -17,11 +17,20 @@ this with the current state, and perform an update on the server if necessary. 
""" +import datetime import logging import os import socket import threading import time +from typing import Dict, TextIO, Union, Literal, Optional, List, Any, Set, TYPE_CHECKING + +if TYPE_CHECKING: + from rhsm.certificate2 import EntitlementCertificate, Product + from subscription_manager.certdirectory import ProductDirectory, EntitlementDirectory + from subscription_manager.cp_provider import CPProvider + from subscription_manager.identity import Identity + from rhsm.https import ssl from rhsm.config import get_config_parser @@ -54,49 +63,48 @@ class CacheManager(object): """ # Fields the subclass must override: - CACHE_FILE = None + CACHE_FILE: str = None - def to_dict(self): + def to_dict(self) -> Dict: """ Returns the data for this collection as a dict to be serialized as JSON. """ raise NotImplementedError - def _load_data(self, open_file): + def _load_data(self, open_file: TextIO) -> Optional[Dict]: """ Load the data in whatever format the sub-class uses from an already opened file descriptor. """ raise NotImplementedError - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs) -> None: """ Sync the latest data to/from the server. """ raise NotImplementedError - def has_changed(self): + def has_changed(self) -> bool: """ - Check if the current system data has changed since the last time we - updated. + Check if the current system data has changed since the last time we updated. 
""" raise NotImplementedError @classmethod - def delete_cache(cls): + def delete_cache(cls) -> None: """Delete the cache for this collection from disk.""" if os.path.exists(cls.CACHE_FILE): log.debug("Deleting cache: %s" % cls.CACHE_FILE) os.remove(cls.CACHE_FILE) - def _cache_exists(self): + def _cache_exists(self) -> bool: return os.path.exists(self.CACHE_FILE) - def exists(self): + def exists(self) -> bool: return self._cache_exists() - def write_cache(self, debug=True): + def write_cache(self, debug: bool = True) -> None: """ Write the current cache to disk. Should only be done after successful communication with the server. @@ -110,7 +118,7 @@ def write_cache(self, debug=True): try: if not os.access(os.path.dirname(self.CACHE_FILE), os.R_OK): os.makedirs(os.path.dirname(self.CACHE_FILE)) - f = open(self.CACHE_FILE, "w+") + f: TextIO = open(self.CACHE_FILE, "w+") json.dump(self.to_dict(), f, default=json.encode) f.close() if debug: @@ -119,7 +127,7 @@ def write_cache(self, debug=True): log.error("Unable to write cache: %s" % self.CACHE_FILE) log.exception(err) - def _read_cache(self): + def _read_cache(self) -> Optional[Dict]: """ Load the last data we sent to the server. Returns none if no cache file exists. @@ -127,8 +135,9 @@ def _read_cache(self): try: f = open(self.CACHE_FILE) - data = self._load_data(f) + data: Optional[Dict] = self._load_data(f) f.close() + # FIXME We loaded data, but the code expects optional dictionaries return data except IOError as err: log.error("Unable to read cache: %s" % self.CACHE_FILE) @@ -138,7 +147,7 @@ def _read_cache(self): # a new as if it didn't exist pass - def read_cache_only(self): + def read_cache_only(self) -> Optional[Dict]: """ Try to read only cached data. When cache does not exist, then None is returned. 
@@ -149,9 +158,13 @@ def read_cache_only(self): log.debug("Cache file %s does not exist" % self.CACHE_FILE) return None - def update_check(self, uep, consumer_uuid, force=False): + def update_check( + self, uep: connection.UEPConnection, consumer_uuid: str, force: bool = False + ) -> Literal[0, 1]: """ Check if data has changed, and push an update if so. + + :return: 1 if the cache was updated, 0 otherwise. """ # The package_upload.py yum plugin from katello-agent will @@ -199,15 +212,16 @@ class StatusCache(CacheManager): """ def __init__(self): - self.server_status = None - self.last_error = None + self.server_status: Optional[Dict] = None + self.last_error: Optional[Exception] = None - def load_status(self, uep, uuid, on_date=None): + def load_status( + self, uep: connection.UEPConnection, uuid: str, on_date: Optional[datetime.datetime] = None + ) -> Optional[Dict]: """ Load status from wherever is appropriate. - If server is reachable, return it's response - and cache the results to disk. + If server is reachable, return its response and cache the results to disk. If the server is not reachable, return the latest cache if it is still reasonable to use it. @@ -257,14 +271,14 @@ def load_status(self, uep, uuid, on_date=None): self.last_error = ex return None - def to_dict(self): + def to_dict(self) -> Dict: return self.server_status - def _load_data(self, open_file): - json_str = open_file.read() + def _load_data(self, open_file) -> Dict: + json_str: str = open_file.read() return json.loads(json_str) - def _read_cache(self): + def _read_cache(self) -> Optional[Dict]: """ Prefer in memory cache to avoid io. If it doesn't exist, save the disk cache to the in-memory cache to avoid reading again. 
@@ -277,7 +291,7 @@ def _read_cache(self): log.debug("Reading status from in-memory cache of %s file" % self.CACHE_FILE) return self.server_status - def _cache_exists(self): + def _cache_exists(self) -> bool: """ If a cache exists in memory, we have written it to the disk No need for unnecessary disk io here. @@ -286,7 +300,9 @@ def _cache_exists(self): return True return super(StatusCache, self)._cache_exists() - def read_status(self, uep, uuid, on_date=None): + def read_status( + self, uep: connection.UEPConnection, uuid: str, on_date: Optional[datetime.datetime] = None + ) -> Optional[str]: """ Return status, from cache if it exists, otherwise load_status and write cache and return it. @@ -321,7 +337,7 @@ def write_cache(self): log.debug("Started thread to write cache: %s" % self.CACHE_FILE) # we override a @classmethod with an instance method in the sub class? - def delete_cache(self): + def delete_cache(self) -> None: super(StatusCache, self).delete_cache() self.server_status = None @@ -335,7 +351,15 @@ class EntitlementStatusCache(StatusCache): CACHE_FILE = "/var/lib/rhsm/cache/entitlement_status.json" - def _sync_with_server(self, uep, uuid, on_date=None, *args, **kwargs): + # TODO Should we allow *args/**kwargs even if we know we are not using them? 
+ def _sync_with_server( + self, + uep: connection.UEPConnection, + uuid: str, + on_date: Optional[datetime.datetime] = None, + *args, + **kwargs, + ): self.server_status = uep.getCompliance(uuid, on_date) @@ -348,27 +372,35 @@ class SyspurposeComplianceStatusCache(StatusCache): CACHE_FILE = "/var/lib/rhsm/cache/syspurpose_compliance_status.json" - def _sync_with_server(self, uep, uuid, on_date=None, *args, **kwargs): + def _sync_with_server( + self, + uep: connection.UEPConnection, + uuid: str, + on_date: Optional[datetime.datetime] = None, + *args, + **kwargs, + ): self.syspurpose_service = syspurpose.Syspurpose(uep) - self.server_status = self.syspurpose_service.get_syspurpose_status(on_date) + self.server_status: Dict = self.syspurpose_service.get_syspurpose_status(on_date) def write_cache(self): if self.server_status is not None and self.server_status["status"] != "unknown": super(SyspurposeComplianceStatusCache, self).write_cache() - def get_overall_status(self): + def get_overall_status(self) -> str: if self.server_status is not None: return self.syspurpose_service.get_overall_status(self.server_status["status"]) else: return self.syspurpose_service.get_overall_status("unknown") - def get_overall_status_code(self): + # FIXME This is wrong; it should be 'str' and 'self.server_status["status"]' instead + def get_overall_status_code(self) -> Union[Dict, str]: if self.server_status is not None: return self.server_status else: return "unknown" - def get_status_reasons(self): + def get_status_reasons(self) -> Optional[str]: if self.server_status is not None and "reasons" in self.server_status: return self.server_status["reasons"] else: @@ -382,8 +414,8 @@ class ProductStatusCache(StatusCache): CACHE_FILE = "/var/lib/rhsm/cache/product_status.json" - def _sync_with_server(self, uep, uuid, *args, **kwargs): - consumer_data = uep.getConsumer(uuid) + def _sync_with_server(self, uep: connection.UEPConnection, uuid: str, *args, **kwargs): + consumer_data: Dict = 
uep.getConsumer(uuid) if "installedProducts" not in consumer_data: log.warning("Server does not support product date ranges.") @@ -398,7 +430,7 @@ class OverrideStatusCache(StatusCache): CACHE_FILE = "/var/lib/rhsm/cache/content_overrides.json" - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs): self.server_status = uep.getContentOverrides(consumer_uuid) @@ -409,8 +441,8 @@ class ReleaseStatusCache(StatusCache): CACHE_FILE = "/var/lib/rhsm/cache/releasever.json" - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): - def get_release(uuid): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs): + def get_release(uuid: str) -> Dict: # To mimic connection problems you can raise required exception: # raise connection.RemoteServerException(500, "GET", "/release") @@ -450,7 +482,7 @@ def _get_profile(self, profile_type): return get_profile(profile_type) @staticmethod - def _assembly_profile(rpm_profile, enabled_repos_profile, module_profile): + def _assembly_profile(rpm_profile, enabled_repos_profile, module_profile) -> Dict[str, List[Dict]]: combined_profile = { "rpm": rpm_profile, "enabled_repos": enabled_repos_profile, @@ -459,33 +491,37 @@ def _assembly_profile(rpm_profile, enabled_repos_profile, module_profile): return combined_profile @property - def current_profile(self): + def current_profile(self) -> Dict[str, List[Dict]]: if not self._current_profile: - rpm_profile = get_profile("rpm").collect() - enabled_repos = get_profile("enabled_repos").collect() - module_profile = get_profile("modulemd").collect() - combined_profile = self._assembly_profile(rpm_profile, enabled_repos, module_profile) + rpm_profile: List[Dict] = get_profile("rpm").collect() + enabled_repos: List[Dict] = get_profile("enabled_repos").collect() + module_profile: List[Dict] = get_profile("modulemd").collect() + 
combined_profile: Dict[str, List[Dict]] = self._assembly_profile( + rpm_profile, enabled_repos, module_profile + ) self._current_profile = combined_profile return self._current_profile @current_profile.setter - def current_profile(self, new_profile): + def current_profile(self, new_profile: Dict[str, List[Dict]]): self._current_profile = new_profile - def to_dict(self): + def to_dict(self) -> Dict[str, List[Dict]]: return self.current_profile - def _load_data(self, open_file): - json_str = open_file.read() + def _load_data(self, open_file: TextIO) -> Dict: + json_str: str = open_file.read() return json.loads(json_str) - def update_check(self, uep, consumer_uuid, force=False): + def update_check( + self, uep: connection.UEPConnection, consumer_uuid: str, force: bool = False + ) -> Literal[0, 1]: """ Check if packages have changed, and push an update if so. """ # If the server doesn't support packages, don't try to send the profile: - supported_resources = get_supported_resources(uep=None, identity=self.identity) + supported_resources: Dict = get_supported_resources(uep=None, identity=self.identity) if PACKAGES_RESOURCE not in supported_resources: log.warning("Server does not support packages, skipping profile upload.") return 0 @@ -498,22 +534,22 @@ def update_check(self, uep, consumer_uuid, force=False): else: return 0 - def has_changed(self): + def has_changed(self) -> bool: if not self._cache_exists(): log.debug("Cache file %s does not exist" % self.CACHE_FILE) return True - cached_profile = self._read_cache() + cached_profile: Optional[Dict] = self._read_cache() return not cached_profile == self.current_profile - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs) -> None: """ This method has to be able to sync combined profile, when server supports this functionality and it also has to be able to send only profile containing list of installed RPMs.
""" - combined_profile = self.current_profile + combined_profile: Dict = self.current_profile if uep.has_capability("combined_reporting"): - _combined_profile = [ + _combined_profile: List[Dict] = [ {"content_type": "rpm", "profile": combined_profile["rpm"]}, {"content_type": "enabled_repos", "profile": combined_profile["enabled_repos"]}, {"content_type": "modulemd", "profile": combined_profile["modulemd"]}, @@ -532,14 +568,14 @@ class InstalledProductsManager(CacheManager): CACHE_FILE = "/var/lib/rhsm/cache/installed_products.json" def __init__(self): - self._installed = None - self.tags = None + self._installed: Dict[str, Dict] = None + self.tags: Set[str] = None - self.product_dir = inj.require(inj.PROD_DIR) + self.product_dir: ProductDirectory = inj.require(inj.PROD_DIR) self._setup_installed() - def _get_installed(self): + def _get_installed(self) -> Dict: if self._installed: return self._installed @@ -547,24 +583,24 @@ def _get_installed(self): return self._installed - def _set_installed(self, value): + def _set_installed(self, value: Dict) -> None: self._installed = value installed = property(_get_installed, _set_installed) - def to_dict(self): + def to_dict(self) -> Dict: return {"products": self.installed, "tags": self.tags} - def _load_data(self, open_file): + def _load_data(self, open_file: TextIO) -> Dict: json_str = open_file.read() return json.loads(json_str) - def has_changed(self): + def has_changed(self) -> bool: if not self._cache_exists(): log.debug("Cache file %s does not exist" % self.CACHE_FILE) return True - cached = self._read_cache() + cached: Dict = self._read_cache() try: products = cached["products"] tags = set(cached["tags"]) @@ -585,15 +621,16 @@ def has_changed(self): return False - def _setup_installed(self): + def _setup_installed(self) -> None: """ Format installed product data to match the cache and what the server can use. 
""" - self._installed = {} + self._installed: Dict[str, Dict] = {} self.tags = set() + prod_cert: EntitlementCertificate for prod_cert in self.product_dir.list(): - prod = prod_cert.products[0] + prod: Product = prod_cert.products[0] self.tags |= set(prod.provided_tags) self._installed[prod.id] = { "productId": prod.id, @@ -602,17 +639,17 @@ def _setup_installed(self): "arch": ",".join(prod.architectures), } - def format_for_server(self): + def format_for_server(self) -> List[Dict]: """ Convert the format we store in this object (which is a little easier to work with) into the format the server expects for the consumer. """ self._setup_installed() - final = [val for (key, val) in list(self.installed.items())] + final: List[Dict] = [val for (key, val) in list(self.installed.items())] return final - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs) -> None: uep.updateConsumer( consumer_uuid, installed_products=self.format_for_server(), @@ -627,7 +664,7 @@ class PoolStatusCache(StatusCache): CACHE_FILE = "/var/lib/rhsm/cache/pool_status.json" - def _sync_with_server(self, uep, uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, uuid: str, *args, **kwargs) -> None: self.server_status = uep.getEntitlementList(uuid) @@ -637,45 +674,47 @@ class PoolTypeCache(object): """ def __init__(self): - self.identity = inj.require(inj.IDENTITY) - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.ent_dir = inj.require(inj.ENT_DIR) - self.pool_cache = inj.require(inj.POOL_STATUS_CACHE) - self.pooltype_map = {} + self.identity: Identity = inj.require(inj.IDENTITY) + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.ent_dir: EntitlementDirectory = inj.require(inj.ENT_DIR) + self.pool_cache: PoolStatusCache = inj.require(inj.POOL_STATUS_CACHE) + self.pooltype_map: Dict[str, str] = {} self.update() - def get(self, 
pool_id): + def get(self, pool_id: str): return self.pooltype_map.get(pool_id, "") - def update(self): + def update(self) -> None: if self.requires_update(): self._do_update() - def requires_update(self): - attached_pool_ids = set([ent.pool.id for ent in self.ent_dir.list() if ent.pool and ent.pool.id]) - missing_types = attached_pool_ids - set(self.pooltype_map) + def requires_update(self) -> bool: + attached_pool_ids: Set[str] = set( + [ent.pool.id for ent in self.ent_dir.list() if ent.pool and ent.pool.id] + ) + missing_types: Set[str] = attached_pool_ids - set(self.pooltype_map) return bool(missing_types) - def _do_update(self): - result = {} + def _do_update(self) -> None: + result: Dict[str, str] = {} if self.identity.is_valid(): self.pool_cache.load_status(self.cp_provider.get_consumer_auth_cp(), self.identity.uuid) - entitlement_list = self.pool_cache.server_status + entitlement_list: List[Dict] = self.pool_cache.server_status if entitlement_list is not None: for ent in entitlement_list: pool = PoolWrapper(ent.get("pool", {})) - pool_type = pool.get_pool_type() + pool_type: str = pool.get_pool_type() result[pool.get_id()] = pool_type self.pooltype_map.update(result) - def update_from_pools(self, pool_map): + def update_from_pools(self, pool_map: Dict) -> None: # pool_map maps pool ids to pool json for pool_id in pool_map: self.pooltype_map[pool_id] = PoolWrapper(pool_map[pool_id]).get_pool_type() - def clear(self): + def clear(self) -> None: self.pooltype_map = {} @@ -683,13 +722,13 @@ class ContentAccessCache(object): CACHE_FILE = "/var/lib/rhsm/cache/content_access.json" def __init__(self): - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.identity = inj.require(inj.IDENTITY) + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.identity: Identity = inj.require(inj.IDENTITY) - def _query_for_update(self, if_modified_since=None): - uep = self.cp_provider.get_consumer_auth_cp() + def _query_for_update(self, if_modified_since: 
Optional[datetime.datetime] = None): + uep: connection.UEPConnection = self.cp_provider.get_consumer_auth_cp() try: - response = uep.getAccessibleContent(self.identity.uuid, if_modified_since=if_modified_since) + response: Dict = uep.getAccessibleContent(self.identity.uuid, if_modified_since=if_modified_since) except connection.RestlibException as err: log.warning("Unable to query for content access updates: %s", err) return None @@ -699,17 +738,18 @@ def _query_for_update(self, if_modified_since=None): self._update_cache(response) return response - def exists(self): + def exists(self) -> bool: return os.path.exists(self.CACHE_FILE) def remove(self): return os.remove(self.CACHE_FILE) - def check_for_update(self): - data = None + def check_for_update(self) -> Optional[Dict]: + data: Optional[Dict] = None + last_update: Optional[datetime.datetime] if self.exists(): try: - data = json.loads(self.read()) + data: Dict = json.loads(self.read()) last_update = parse_date(data["lastUpdate"]) except (ValueError, KeyError) as err: log.debug("Cache file {file} is corrupted: {err}".format(file=self.CACHE_FILE, err=err)) @@ -717,7 +757,7 @@ def check_for_update(self): else: last_update = None - response = self._query_for_update(if_modified_since=last_update) + response: Optional[Dict] = self._query_for_update(if_modified_since=last_update) # Candlepin 4 bug 2010251. if_modified_since is not reliable so # we double checks whether or not the sca certificate is changed. if data is not None and data == response: @@ -726,23 +766,23 @@ def check_for_update(self): return response @staticmethod - def update_cert(cert, data): + def update_cert(cert: "EntitlementCertificate", data: Optional[Dict]) -> None: if data is None: return if data["contentListing"] is None or str(cert.serial) not in data["contentListing"]: log.warning("Cert serial %s not contained in content listing; not updating it." 
% cert.serial) return with open(cert.path, "w") as output: - updated_cert = "".join(data["contentListing"][str(cert.serial)]) + updated_cert: str = "".join(data["contentListing"][str(cert.serial)]) log.debug("Updating certificate %s with new content" % cert.serial) output.write(updated_cert) - def _update_cache(self, data): + def _update_cache(self, data: Dict) -> None: log.debug("Updating content access cache") with open(self.CACHE_FILE, "w") as cache: cache.write(json.dumps(data)) - def read(self): + def read(self) -> str: with open(self.CACHE_FILE, "r") as cache: return cache.read() @@ -756,15 +796,15 @@ class WrittenOverrideCache(CacheManager): CACHE_FILE = "/var/lib/rhsm/cache/written_overrides.json" - def __init__(self, overrides=None): + def __init__(self, overrides: Optional[Dict] = None): self.overrides = overrides or {} - def to_dict(self): + def to_dict(self) -> Dict: return self.overrides - def _load_data(self, open_file): + def _load_data(self, open_file: TextIO) -> Optional[Dict]: try: - self.overrides = json.loads(open_file.read()) or {} + self.overrides: Dict = json.loads(open_file.read()) or {} return self.overrides except IOError as err: log.error("Unable to read cache: %s" % self.CACHE_FILE) @@ -792,17 +832,17 @@ class ConsumerCache(CacheManager): # Some data should have some timeout of validity, because data can be changed over time # on the server, because server could be updated and it can start provide new functionality. # E.g. supported resources or available capabilities. Value of timeout is in seconds. 
- TIMEOUT = None + TIMEOUT: Optional[int] = None - def __init__(self, data=None): + def __init__(self, data: Any = None): self.data = data or {} - def to_dict(self): + def to_dict(self) -> Dict: return self.data - def _load_data(self, open_file): + def _load_data(self, open_file: TextIO) -> Optional[Dict]: try: - self.data = json.loads(open_file.read()) or {} + self.data: Dict = json.loads(open_file.read()) or {} return self.data except IOError as err: log.error("Unable to read cache: %s" % self.CACHE_FILE) @@ -811,7 +851,7 @@ def _load_data(self, open_file): # Ignore json file parse error pass - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs): """ This method has to be implemented in sub-classes of this class :param uep: object representing connection to candlepin server @@ -822,7 +862,9 @@ def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): """ raise NotImplementedError - def _is_cache_obsoleted(self, uep, identity, *args, **kwargs): + def _is_cache_obsoleted( + self, uep: connection.UEPConnection, identity: "Identity", *args, **kwargs + ) -> bool: """ Another method for checking if cached file is obsoleted :param args: positional arguments @@ -831,7 +873,7 @@ def _is_cache_obsoleted(self, uep, identity, *args, **kwargs): """ return False - def set_data(self, current_data, identity=None): + def set_data(self, current_data: Any, identity: Optional["Identity"] = None): """ Set data into internal cache :param current_data: data to set @@ -842,9 +884,11 @@ def set_data(self, current_data, identity=None): if identity is None: identity = inj.require(inj.IDENTITY) - self.data = {identity.uuid: current_data} + self.data: Dict[str, Any] = {identity.uuid: current_data} - def read_data(self, uep=None, identity=None): + def read_data( + self, uep: Optional[connection.UEPConnection] = None, identity: Optional["Identity"] = None + ) -> Dict: """ 
This function tries to get data from cache or server :param uep: connection to candlepin server @@ -852,10 +896,10 @@ def read_data(self, uep=None, identity=None): :return: information about current owner """ - current_data = self.DEFAULT_VALUE + current_data: Dict = self.DEFAULT_VALUE if identity is None: - identity = inj.require(inj.IDENTITY) + identity: Identity = inj.require(inj.IDENTITY) # When identity is not known, then system is not registered and # data are obsoleted @@ -864,16 +908,16 @@ def read_data(self, uep=None, identity=None): return current_data # Try to use class specific test if the cache file is obsoleted - cache_file_obsoleted = self._is_cache_obsoleted(uep, identity) + cache_file_obsoleted: bool = self._is_cache_obsoleted(uep, identity) # When timeout for cache is defined, then check if the cache file is not # too old. In that case content of the cache file will be overwritten with # new content from the server. if self.TIMEOUT is not None: if os.path.exists(self.CACHE_FILE): - mod_time = os.path.getmtime(self.CACHE_FILE) - cur_time = time.time() - diff = cur_time - mod_time + mod_time: float = os.path.getmtime(self.CACHE_FILE) + cur_time: float = time.time() + diff: float = cur_time - mod_time if diff > self.TIMEOUT: log.debug("Validity of cache file %s timed out (%d)" % (self.CACHE_FILE, self.TIMEOUT)) cache_file_obsoleted = True @@ -881,7 +925,7 @@ def read_data(self, uep=None, identity=None): if cache_file_obsoleted is False: # Try to read data from cache first log.debug("Trying to read %s from cache file %s" % (self.__class__.__name__, self.CACHE_FILE)) - data = self.read_cache_only() + data: Optional[Dict] = self.read_cache_only() if data is not None: if identity.uuid in data: current_data = data[identity.uuid] @@ -895,12 +939,12 @@ def read_data(self, uep=None, identity=None): log.debug("Data loaded from cache file: %s" % self.CACHE_FILE) else: if uep is None: - cp_provider = inj.require(inj.CP_PROVIDER) + cp_provider: CPProvider = 
inj.require(inj.CP_PROVIDER) uep = cp_provider.get_consumer_auth_cp() log.debug("Getting data from server for %s" % self.__class__) try: - current_data = self._sync_with_server(uep=uep, consumer_uuid=identity.uuid) + current_data: Dict = self._sync_with_server(uep=uep, consumer_uuid=identity.uuid) except connection.RestlibException as rest_err: log.warning("Unable to get data for %s using REST API: %s" % (self.__class__, rest_err)) log.debug("Deleting cache file: %s", self.CACHE_FILE) @@ -922,14 +966,14 @@ class SyspurposeValidFieldsCache(ConsumerCache): CACHE_FILE = "/var/lib/rhsm/cache/valid_fields.json" - def __init__(self, data=None): + def __init__(self, data: Any = None): super(SyspurposeValidFieldsCache, self).__init__(data=data) - def _sync_with_server(self, uep, *args, **kwargs): - cache = inj.require(inj.CURRENT_OWNER_CACHE) - owner = cache.read_data(uep) + def _sync_with_server(self, uep: connection.UEPConnection, *args, **kwargs) -> Dict: + cache: CurrentOwnerCache = inj.require(inj.CURRENT_OWNER_CACHE) + owner: Dict = cache.read_data(uep) if "key" in owner: - data = uep.getOwnerSyspurposeValidFields(owner["key"]) + data: Dict = uep.getOwnerSyspurposeValidFields(owner["key"]) return post_process_received_data(data) else: return self.DEFAULT_VALUE @@ -945,13 +989,15 @@ class CurrentOwnerCache(ConsumerCache): CACHE_FILE = "/var/lib/rhsm/cache/current_owner.json" - def __init__(self, data=None): + def __init__(self, data: Any = None): super(CurrentOwnerCache, self).__init__(data=data) - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs): return uep.getOwner(consumer_uuid) - def _is_cache_obsoleted(self, uep, identity, *args, **kwargs): + def _is_cache_obsoleted( + self, uep: connection.UEPConnection, identity: "Identity", *args, **kwargs + ) -> bool: """ We don't know if the cache is valid until we get valid response :param uep: object 
representing connection to candlepin server @@ -961,8 +1007,8 @@ def _is_cache_obsoleted(self, uep, identity, *args, **kwargs): :return: True, when cache is obsoleted or validity of cache is unknown. """ if uep is None: - cp_provider = inj.require(inj.CP_PROVIDER) - uep = cp_provider.get_consumer_auth_cp() + cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + uep: connection.UEPConnection = cp_provider.get_consumer_auth_cp() if hasattr(uep.conn, "is_consumer_cert_key_valid") and uep.conn.is_consumer_cert_key_valid is True: return False else: @@ -981,12 +1027,12 @@ class ContentAccessModeCache(ConsumerCache): CACHE_FILE = "/var/lib/rhsm/cache/content_access_mode.json" - def __init__(self, data=None): + def __init__(self, data: Any = None): super(ContentAccessModeCache, self).__init__(data=data) - def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, consumer_uuid: str, *args, **kwargs): try: - current_owner = uep.getOwner(consumer_uuid) + current_owner: Dict = uep.getOwner(consumer_uuid) except Exception: log.debug( "Error checking for content access mode," @@ -1003,7 +1049,7 @@ def _sync_with_server(self, uep, consumer_uuid, *args, **kwargs): ) return "unknown" - def _is_cache_obsoleted(self, uep, identity, *args, **kwargs): + def _is_cache_obsoleted(self, uep: connection.UEPConnection, identity: "Identity", *args, **kwargs): """ We don't know if the cache is valid until we get valid response :param uep: object representing connection to candlepin server @@ -1013,8 +1059,8 @@ def _is_cache_obsoleted(self, uep, identity, *args, **kwargs): :return: True, when cache is obsoleted or validity of cache is unknown. 
""" if uep is None: - cp_provider = inj.require(inj.CP_PROVIDER) - uep = cp_provider.get_consumer_auth_cp() + cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + uep: connection.UEPConnection = cp_provider.get_consumer_auth_cp() if hasattr(uep.conn, "is_consumer_cert_key_valid"): if uep.conn.is_consumer_cert_key_valid is None: @@ -1047,10 +1093,10 @@ class SupportedResourcesCache(ConsumerCache): # We will try to get new list of supported resources at leas once a day TIMEOUT = 60 * 60 * 24 - def __init__(self, data=None): + def __init__(self, data: Any = None): super(SupportedResourcesCache, self).__init__(data=data) - def _sync_with_server(self, uep, *args, **kwargs): + def _sync_with_server(self, uep: connection.UEPConnection, *args, **kwargs): return uep.get_supported_resources() @@ -1074,13 +1120,14 @@ def __init__(self, available_entitlements=None): def to_dict(self): return self.available_entitlements - def timeout(self): + def timeout(self) -> float: """ Compute timeout of cache. Computation of timeout is based on SRT (smoothed response time) of connection to candlepin server. This algorithm is inspired by retransmission timeout used by TCP connection (see: RFC 793) """ - uep = inj.require(inj.CP_PROVIDER).get_consumer_auth_cp() + uep: connection.UEPConnection = inj.require(inj.CP_PROVIDER).get_consumer_auth_cp() + smoothed_rt: float if uep.conn.smoothed_rt is not None: smoothed_rt = uep.conn.smoothed_rt @@ -1088,7 +1135,7 @@ def timeout(self): smoothed_rt = 0.0 return min(self.UBOUND, max(self.LBOUND, self.BETA * smoothed_rt)) - def get_not_obsolete_data(self, identity, filter_options): + def get_not_obsolete_data(self, identity: "Identity", filter_options: Dict) -> Dict: """ Try to get not obsolete cached data :param identity: identity with UUID @@ -1096,11 +1143,11 @@ def get_not_obsolete_data(self, identity, filter_options): :return: When data are not obsoleted, then return cached dictionary of available entitlements. 
Otherwise return empty dictionary. """ - data = self.read_cache_only() - available_pools = {} + data: Optional[Dict] = self.read_cache_only() + available_pools: Dict = {} if data is not None: if identity.uuid in data: - cached_data = data[identity.uuid] + cached_data: Dict = data[identity.uuid] if cached_data["filter_options"] == filter_options: log.debug("timeout: %s, current time: %s" % (cached_data["timeout"], time.time())) if cached_data["timeout"] > time.time(): @@ -1112,7 +1159,7 @@ def get_not_obsolete_data(self, identity, filter_options): log.debug("Cache of available entitlements does not contain given filter options") return available_pools - def _load_data(self, open_file): + def _load_data(self, open_file: TextIO) -> Optional[Dict]: try: self.available_entitlements = json.loads(open_file.read()) or {} return self.available_entitlements diff --git a/src/subscription_manager/cert_sorter.py b/src/subscription_manager/cert_sorter.py index a0520aad82..bdb885369d 100644 --- a/src/subscription_manager/cert_sorter.py +++ b/src/subscription_manager/cert_sorter.py @@ -14,16 +14,24 @@ from copy import copy from datetime import datetime import logging +from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING from rhsm.certificate import GMT from rhsm.connection import RestlibException +from rhsmlib import file_monitor + import subscription_manager.injection as inj +from subscription_manager.i18n import ugettext as _ from subscription_manager.isodate import parse_date from subscription_manager.reasons import Reasons -from rhsmlib import file_monitor from subscription_manager import utils -from subscription_manager.i18n import ugettext as _ +if TYPE_CHECKING: + from rhsm.certificate2 import EntitlementCertificate + from subscription_manager.cache import EntitlementStatusCache, InstalledProductsManager + from subscription_manager.certdirectory import ProductDirectory, EntitlementDirectory + from subscription_manager.cp_provider import CPProvider + from 
subscription_manager.identity import Identity log = logging.getLogger(__name__) @@ -57,27 +65,27 @@ class ComplianceManager(object): - def __init__(self, on_date=None): - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.product_dir = inj.require(inj.PROD_DIR) - self.entitlement_dir = inj.require(inj.ENT_DIR) - self.identity = inj.require(inj.IDENTITY) - self.on_date = on_date - self.installed_products = None - self.unentitled_products = None - self.expired_products = None - self.partially_valid_products = None - self.valid_products = None - self.partial_stacks = None - self.future_products = None - self.reasons = None - self.supports_reasons = False - self.system_status = None - self.valid_entitlement_certs = None - self.status = None + def __init__(self, on_date: Optional[datetime] = None): + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.product_dir: ProductDirectory = inj.require(inj.PROD_DIR) + self.entitlement_dir: EntitlementDirectory = inj.require(inj.ENT_DIR) + self.identity: Identity = inj.require(inj.IDENTITY) + self.on_date: Optional[datetime] = on_date + self.installed_products: Optional[Dict[str, List[EntitlementCertificate]]] = None + self.unentitled_products: Optional[Dict[str, List[EntitlementCertificate]]] = None + self.expired_products: Optional[Dict[str, List[EntitlementCertificate]]] = None + self.partially_valid_products: Optional[Dict[str, List[EntitlementCertificate]]] = None + self.valid_products: Optional[Dict[str, List[EntitlementCertificate]]] = None + self.partial_stacks: Optional[Dict[str, List[EntitlementCertificate]]] = None + self.future_products: Optional[Dict[str, List[EntitlementCertificate]]] = None + self.reasons: Optional[Reasons] = None + self.supports_reasons: bool = False + self.system_status: Optional[str] = None + self.valid_entitlement_certs: Optional[List[EntitlementCertificate]] = None + self.status: Optional[str] = None self.load() - def load(self): + def load(self) -> None: # All products installed on 
this machine, regardless of status. Maps # installed product ID to product certificate. self.installed_products = self.product_dir.get_installed_products() @@ -123,18 +131,18 @@ def load(self): self._parse_server_status() - def get_compliance_status(self): + def get_compliance_status(self) -> Optional[Dict]: """ Try to get compliance status from server to get fresh information about compliance status. :return: Compliance status, when server of cache is available. Otherwise None is returned. """ - status_cache = inj.require(inj.ENTITLEMENT_STATUS_CACHE) + status_cache: EntitlementStatusCache = inj.require(inj.ENTITLEMENT_STATUS_CACHE) self.status = status_cache.load_status( self.cp_provider.get_consumer_auth_cp(), self.identity.uuid, self.on_date ) return self.status - def _parse_server_status(self): + def _parse_server_status(self) -> None: """Fetch entitlement status info from server and parse.""" if not self.is_registered(): @@ -142,7 +150,7 @@ def _parse_server_status(self): return # Override get_status - status = self.get_compliance_status() + status: Optional[Dict] = self.get_compliance_status() if status is None: return @@ -174,14 +182,14 @@ def _parse_server_status(self): # it is returning the first second we are invalid), then add a full # 24 hours giving us the first date where we know we're completely # invalid from midnight to midnight. - self.compliant_until = None + self.compliant_until: Optional[datetime] = None if status["compliantUntil"] is not None: self.compliant_until = parse_date(status["compliantUntil"]) # Lookup product certs for each unentitled product returned by # the server: - unentitled_pids = status["nonCompliantProducts"] + unentitled_pids: List[str] = status["nonCompliantProducts"] # Add in any installed products not in the server response. This # could happen if something changes before the certd runs. Log # a warning if it does, and treat it like an unentitled product. 
@@ -206,8 +214,8 @@ def _parse_server_status(self): self.log_products() - def log_products(self): - fj = utils.friendly_join + def log_products(self) -> None: + fj: Callable = utils.friendly_join log.debug( "Product status: valid_products=%s partial_products=%s expired_products=%s" @@ -222,7 +230,7 @@ def log_products(self): log.debug("partial stacks: %s" % list(self.partial_stacks.keys())) - def _scan_entitlement_certs(self): + def _scan_entitlement_certs(self) -> None: """ Scan entitlement certs looking for unentitled products which may have expired, or be entitled in future. @@ -238,11 +246,11 @@ def _scan_entitlement_certs(self): if k not in list(self.valid_products.keys()) and k not in list(self.partially_valid_products.keys()) ) - ent_certs = self.entitlement_dir.list() + ent_certs: List[EntitlementCertificate] = self.entitlement_dir.list() - on_date = datetime.now(GMT()) + on_date: datetime = datetime.now(GMT()) + ent_cert: EntitlementCertificate for ent_cert in ent_certs: - # Builds the list of valid entitlement certs today: if ent_cert.is_valid(): self.valid_entitlement_certs.append(ent_cert) @@ -263,11 +271,11 @@ def _scan_entitlement_certs(self): product_dict.setdefault(product.id, []).append(ent_cert) - def get_system_status_id(self): + def get_system_status_id(self) -> Optional[str]: return self.system_status @staticmethod - def get_status_map(): + def get_status_map() -> Dict[str, str]: """ Get status map :return: status map @@ -284,26 +292,26 @@ def get_status_map(): } return status_map - def get_system_status(self): - status_map = self.get_status_map() + def get_system_status(self) -> str: + status_map: Dict[str, str] = self.get_status_map() return status_map.get(self.system_status, status_map[UNKNOWN]) - def are_reasons_supported(self): + def are_reasons_supported(self) -> bool: # Check if the candlepin in use supports status # detail messages. Older versions don't. 
return self.supports_reasons - def is_valid(self): + def is_valid(self) -> bool: """ Return true if the results of this cert sort indicate our entitlements are completely valid. """ return self.system_status == VALID or self.system_status == DISABLED - def is_registered(self): + def is_registered(self) -> bool: return inj.require(inj.IDENTITY).is_valid() - def get_status(self, product_id): + def get_status(self, product_id: str) -> str: """Return the status of a given product""" if not self.is_registered(): return UNKNOWN @@ -322,14 +330,14 @@ def get_status(self, product_id): # API call: return UNKNOWN - def in_warning_period(self): + def in_warning_period(self) -> bool: for entitlement in self.valid_entitlement_certs: if entitlement.is_expiring(): return True return False # Assumes classic and identity validity have been tested - def get_status_for_icon(self): + def get_status_for_icon(self) -> int: if self.system_status == "invalid": return RHSM_EXPIRED if self.system_status == "partial": @@ -355,19 +363,19 @@ class CertSorter(ComplianceManager): reporting unknown. """ - def __init__(self, on_date=None): + def __init__(self, on_date: Optional[datetime] = None): # Sync installed product info with server. # This will be done on register if we aren't registered. # ComplianceManager.__init__ needs the installed product info # in sync before it will be accurate, so update it, then # super().__init__. 
See rhbz #1004893 - self.installed_mgr = inj.require(inj.INSTALLED_PRODUCTS_MANAGER) + self.installed_mgr: InstalledProductsManager = inj.require(inj.INSTALLED_PRODUCTS_MANAGER) self.update_product_manager() super(CertSorter, self).__init__(on_date) - self.callbacks = set() + self.callbacks: Set[Callable] = set() - cert_dir_monitors = { + cert_dir_monitors: Dict[str, file_monitor.DirectoryWatch] = { file_monitor.PRODUCT_WATCHER: file_monitor.DirectoryWatch( inj.require(inj.PROD_DIR).path, [self.on_prod_dir_changed, self.load] ), @@ -383,57 +391,58 @@ def __init__(self, on_date=None): # the gui can add one. self.cert_monitor = file_monitor.FilesystemWatcher(cert_dir_monitors) - def update_product_manager(self): + def update_product_manager(self) -> None: if self.is_registered(): - cp_provider = inj.require(inj.CP_PROVIDER) - consumer_identity = inj.require(inj.IDENTITY) + cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + consumer_identity: Identity = inj.require(inj.IDENTITY) try: self.installed_mgr.update_check(cp_provider.get_consumer_auth_cp(), consumer_identity.uuid) except RestlibException: # Invalid consumer certificate pass - def force_cert_check(self): - updated = self.cert_monitor.update() + def force_cert_check(self) -> None: + updated: Set[file_monitor.DirectoryWatch] = self.cert_monitor.update() if updated: self.notify() - def notify(self): + def notify(self) -> None: + callback: Callable for callback in copy(self.callbacks): callback() - def add_callback(self, cb): + def add_callback(self, cb: Callable) -> None: self.callbacks.add(cb) - def remove_callback(self, cb): + def remove_callback(self, cb: Callable) -> bool: try: self.callbacks.remove(cb) return True except KeyError: return False - def on_change(self): + def on_change(self) -> None: self.load() self.notify() - def on_certs_changed(self): + def on_certs_changed(self) -> None: # Now that local data has been refreshed, updated compliance self.on_change() - def on_prod_dir_changed(self): + 
def on_prod_dir_changed(self) -> None: """ Callback method, when content of directory with product certificates has been changed """ self.product_dir.refresh() self.update_product_manager() - def on_ent_dir_changed(self): + def on_ent_dir_changed(self) -> None: """ Callback method, when content of directory with entitlement certificates has been changed """ self.entitlement_dir.refresh() - def on_identity_changed(self): + def on_identity_changed(self) -> None: """ Callback method, when content of directory with consumer certificate has been changed """ @@ -441,18 +450,19 @@ def on_identity_changed(self): self.cp_provider.clean() # check to see if there are certs in the directory - def has_entitlements(self): + def has_entitlements(self) -> bool: return len(self.entitlement_dir.list()) > 0 class StackingGroupSorter(object): - def __init__(self, entitlements): - self.groups = [] - stacking_groups = {} + def __init__(self, entitlements: List["EntitlementCertificate"]): + self.groups: List[EntitlementGroup] = [] + stacking_groups: Dict[str, EntitlementGroup] = {} for entitlement in entitlements: - stacking_id = self._get_stacking_id(entitlement) + stacking_id: Optional[str] = self._get_stacking_id(entitlement) if stacking_id: + group: EntitlementGroup if stacking_id not in stacking_groups: group = EntitlementGroup(entitlement, self._get_identity_name(entitlement)) self.groups.append(group) @@ -463,34 +473,34 @@ def __init__(self, entitlements): else: self.groups.append(EntitlementGroup(entitlement)) - def _get_stacking_id(self, entitlement): + def _get_stacking_id(self, entitlement: "EntitlementCertificate"): raise NotImplementedError("Subclasses must implement: _get_stacking_id") - def _get_identity_name(self, entitlement): + def _get_identity_name(self, entitlement: "EntitlementCertificate"): raise NotImplementedError("Subclasses must implement: _get_identity_name") class EntitlementGroup(object): - def __init__(self, entitlement, name=""): + def __init__(self, 
entitlement: "EntitlementCertificate", name: str = ""): self.name = name - self.entitlements = [] + self.entitlements: List[EntitlementCertificate] = [] self.add_entitlement_cert(entitlement) - def add_entitlement_cert(self, entitlement): + def add_entitlement_cert(self, entitlement: "EntitlementCertificate") -> None: self.entitlements.append(entitlement) class EntitlementCertStackingGroupSorter(StackingGroupSorter): - def __init__(self, certs): + def __init__(self, certs: List["EntitlementCertificate"]): StackingGroupSorter.__init__(self, certs) - def _get_stacking_id(self, cert): + def _get_stacking_id(self, cert: "EntitlementCertificate") -> Optional[str]: if cert.order: return cert.order.stacking_id else: return None - def _get_identity_name(self, cert): + def _get_identity_name(self, cert: "EntitlementCertificate") -> Optional[str]: if cert.order: return cert.order.name else: diff --git a/src/subscription_manager/certdirectory.py b/src/subscription_manager/certdirectory.py index 270ef9ee86..07c9094dbb 100644 --- a/src/subscription_manager/certdirectory.py +++ b/src/subscription_manager/certdirectory.py @@ -15,6 +15,7 @@ # import logging import os +from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING from rhsm.certificate import Key, create_from_file from rhsm.config import get_config_parser @@ -23,6 +24,10 @@ from rhsmlib.services import config from rhsm.certificate2 import CONTENT_ACCESS_CERT_TYPE + +if TYPE_CHECKING: + from rhsm.certificate2 import EntitlementCertificate + log = logging.getLogger(__name__) conf = config.Config(get_config_parser()) @@ -34,7 +39,7 @@ class Directory(object): def __init__(self, path): self.path = Path.abs(path) - def list_all(self): + def list_all(self) -> List[Tuple[str, str]]: all_items = [] if not os.path.exists(self.path): return all_items @@ -44,7 +49,7 @@ def list_all(self): all_items.append(p) return all_items - def list(self): + def list(self) -> List[Tuple[str, str]]: files = [] for p, fn in 
self.list_all(): path = self.abspath(fn) @@ -54,7 +59,7 @@ def list(self): files.append((p, fn)) return files - def listdirs(self): + def listdirs(self) -> List["Directory"]: dirs = [] for _p, fn in self.list_all(): path = self.abspath(fn) @@ -82,7 +87,7 @@ def clean(self): else: os.unlink(path) - def abspath(self, filename): + def abspath(self, filename) -> str: """ Return path for a filename relative to this directory. """ @@ -90,7 +95,7 @@ def abspath(self, filename): # can just join normally. return os.path.join(self.path, filename) - def __str__(self): + def __str__(self) -> str: return self.path @@ -98,16 +103,16 @@ class CertificateDirectory(Directory): KEY = "key.pem" - def __init__(self, path): + def __init__(self, path: str): super(CertificateDirectory, self).__init__(path) self.create() - self._listing = None + self._listing: Optional[List["EntitlementCertificate"]] = None - def refresh(self): + def refresh(self) -> None: # simply clear the cache. the next list() will reload. self._listing = None - def list(self): + def list(self) -> List["EntitlementCertificate"]: if self._listing is not None: return self._listing listing = [] @@ -119,28 +124,28 @@ def list(self): self._listing = listing return listing - def list_valid(self): + def list_valid(self) -> List["EntitlementCertificate"]: valid = [] for c in self.list(): if c.is_valid(): valid.append(c) return valid - def list_expired(self): + def list_expired(self) -> List["EntitlementCertificate"]: expired = [] for c in self.list(): if c.is_expired(): expired.append(c) return expired - def find(self, sn): + def find(self, sn: str) -> Optional["EntitlementCertificate"]: # TODO: could optimize to just load SERIAL.pem? Maybe not in all cases. 
for c in self.list(): if c.serial == sn: return c return None - def find_all_by_product(self, p_hash): + def find_all_by_product(self, p_hash: str) -> List["EntitlementCertificate"]: certs = set() providing_stack_ids = set() stack_id_map = {} @@ -167,7 +172,7 @@ def find_all_by_product(self, p_hash): return list(certs) - def find_by_product(self, p_hash): + def find_by_product(self, p_hash: str) -> Optional["EntitlementCertificate"]: for c in self.list(): for p in c.products: if p.id == p_hash: @@ -179,7 +184,7 @@ def find_by_product(self, p_hash): class ProductCertificateDirectory(CertificateDirectory): - def get_provided_tags(self): + def get_provided_tags(self) -> Set[str]: """ Iterates all product certificates in the directory and extracts a set of all tags they provide. @@ -197,13 +202,13 @@ def get_provided_tags(self): # This needs to pick the correct cert if multiple # product certs provide the same product id. # - # If we put the defaults at the begining of .list() - # results, we will override them with the instaled products + # If we put the defaults at the beginning of .list() + # results, we will override them with the installed products # certs. 
# # Instead of always overriding, something like # productid.ComparableProductCert may be useful - def get_installed_products(self): + def get_installed_products(self) -> Dict[str, "EntitlementCertificate"]: prod_certs = self.list() installed_products = {} for product_cert in prod_certs: @@ -214,22 +219,23 @@ def get_installed_products(self): class ProductDirectory(ProductCertificateDirectory): - def __init__(self, path=None, default_path=None): - installed_prod_path = path or conf["rhsm"]["productCertDir"] - default_prod_path = default_path or DEFAULT_PRODUCT_CERT_DIR + def __init__(self, path: Optional[str] = None, default_path: Optional[str] = None): + # FIXME Missing super() call + installed_prod_path: str = path or conf["rhsm"]["productCertDir"] + default_prod_path: str = default_path or DEFAULT_PRODUCT_CERT_DIR self.installed_prod_dir = ProductCertificateDirectory(path=installed_prod_path) self.default_prod_dir = ProductCertificateDirectory(path=default_prod_path) - def list(self): - installed_prod_list = self.installed_prod_dir.list() - default_prod_list = self.default_prod_dir.list() + def list(self) -> List["EntitlementCertificate"]: + installed_prod_list: List[EntitlementCertificate] = self.installed_prod_dir.list() + default_prod_list: List[EntitlementCertificate] = self.default_prod_dir.list() # Product IDs in installed_prod dir. - pids = set([cert.products[0].id for cert in installed_prod_list]) + pids: Set[str] = set([cert.products[0].id for cert in installed_prod_list]) # Everything from /etc/pki/product, only use product-default for pids that don't already exist return installed_prod_list + [cert for cert in default_prod_list if cert.products[0].id not in pids] - def refresh(self): + def refresh(self) -> None: self.installed_prod_dir.refresh() self.default_prod_dir.refresh() @@ -242,7 +248,7 @@ def refresh(self): # the product cert back to the host in some manner. Or better, let a plugin # decide. 
@property - def path(self): + def path(self) -> str: return self.installed_prod_dir.path @@ -252,13 +258,13 @@ class EntitlementDirectory(CertificateDirectory): PRODUCT = "product" @classmethod - def productpath(cls): + def productpath(cls) -> str: return cls.PATH def __init__(self): super(EntitlementDirectory, self).__init__(self.productpath()) - def _check_key(self, cert): + def _check_key(self, cert: "EntitlementCertificate") -> bool: """ If the new key file (SERIAL-key.pem) does not exist, check for the old style (key.pem), and if found write it out as the new style. @@ -280,17 +286,18 @@ def _check_key(self, cert): # write the key/cert out again in new style format key = Key.read(old_key_path) + # FIXME Writer does not take any arguments cert_writer = Writer(self) cert_writer.write(key, cert) return True - def list_valid(self): + def list_valid(self) -> List["EntitlementCertificate"]: return [x for x in self.list() if self._check_key(x) and x.is_valid()] - def list_valid_with_content_access(self): + def list_valid_with_content_access(self) -> List["EntitlementCertificate"]: return [x for x in self.list_with_content_access() if self._check_key(x) and x.is_valid()] - def list(self): + def list(self) -> List["EntitlementCertificate"]: """ List entitlement certificates that do not have SCA type :return: list of entitlement certs @@ -298,14 +305,14 @@ def list(self): certs = super(EntitlementDirectory, self).list() return [cert for cert in certs if cert.entitlement_type != CONTENT_ACCESS_CERT_TYPE] - def list_with_content_access(self): + def list_with_content_access(self) -> List["EntitlementCertificate"]: """ List all entitlement certificates :return: list of entitlement certs """ return super(EntitlementDirectory, self).list() - def list_with_sca_mode(self): + def list_with_sca_mode(self) -> List["EntitlementCertificate"]: """ List only entitlement certificates that do have SCA type :return: @@ -313,7 +320,7 @@ def list_with_sca_mode(self): certs = 
super(EntitlementDirectory, self).list() return [cert for cert in certs if cert.entitlement_type == CONTENT_ACCESS_CERT_TYPE] - def list_for_product(self, product_id): + def list_for_product(self, product_id: str) -> List["EntitlementCertificate"]: """ Returns all entitlement certificates providing access to the given product ID. @@ -325,7 +332,7 @@ def list_for_product(self, product_id): entitlements.append(cert) return entitlements - def list_for_pool_id(self, pool_id): + def list_for_pool_id(self, pool_id: str) -> List["EntitlementCertificate"]: """ Returns all entitlement certificates provided by the given pool ID. @@ -335,7 +342,7 @@ def list_for_pool_id(self, pool_id): ] return entitlements - def list_serials_for_pool_ids(self, pool_ids): + def list_serials_for_pool_ids(self, pool_ids: List[str]) -> Dict[str, List[str]]: """ Returns a dict of all entitlement certificate serials for each pool_id in the list provided """ @@ -352,12 +359,12 @@ class Path(object): ROOT = "/" @classmethod - def join(cls, a, b): + def join(cls, a: str, b: str) -> str: path = os.path.join(a, b) return cls.abs(path) @classmethod - def abs(cls, path): + def abs(cls, path: str) -> str: """Append the ROOT path to the given path.""" if os.path.isabs(path): return os.path.join(cls.ROOT, path[1:]) @@ -365,15 +372,15 @@ def abs(cls, path): return os.path.join(cls.ROOT, path) @classmethod - def isdir(cls, path): + def isdir(cls, path: str) -> bool: return os.path.isdir(path) class Writer(object): def __init__(self): - self.ent_dir = require(ENT_DIR) + self.ent_dir: EntitlementDirectory = require(ENT_DIR) - def write(self, key, cert): + def write(self, key: Key, cert: "EntitlementCertificate") -> None: serial = cert.serial ent_dir_path = self.ent_dir.productpath() diff --git a/src/subscription_manager/certlib.py b/src/subscription_manager/certlib.py index e9e89694da..a559877f0b 100644 --- a/src/subscription_manager/certlib.py +++ b/src/subscription_manager/certlib.py @@ -1,31 +1,36 @@ import 
logging +from typing import Any, Callable, Optional, List, Union, TYPE_CHECKING from subscription_manager import injection as inj +if TYPE_CHECKING: + from subscription_manager.lock import ActionLock + log = logging.getLogger(__name__) class Locker(object): def __init__(self): - self.lock = self._get_lock() + self.lock: ActionLock = self._get_lock() - def run(self, action): + def run(self, action: Callable) -> Any: self.lock.acquire() try: return action() finally: self.lock.release() - def _get_lock(self): + def _get_lock(self) -> "ActionLock": return inj.require(inj.ACTION_LOCK) class BaseActionInvoker(object): - def __init__(self, locker=None): + def __init__(self, locker: Optional[Locker] = None): self.locker = locker or Locker() - self.report = None + self.report: Any = None + """Output of the callable""" - def update(self): + def update(self) -> Any: self.report = self.locker.run(self._do_update) return self.report @@ -37,32 +42,32 @@ def _do_update(self): class ActionReport(object): """Base class for cert lib and action reports""" - name = "Report" + name: str = "Report" def __init__(self): - self._status = None - self._exceptions = [] - self._updates = [] + self._status: Optional[str] = None + self._exceptions: List[Union[Exception, str]] = [] + self._updates: List[str] = [] - def log_entry(self): + def log_entry(self) -> None: """log report entries""" # assuming a useful repr log.debug(self) - def format_exceptions(self): - buf = "" + def format_exceptions(self) -> str: + buf: str = "" for e in self._exceptions: buf += str(e).split("-", maxsplit=1)[-1].strip() buf += "\n" return buf - def print_exceptions(self): + def print_exceptions(self) -> None: if self._exceptions: print(self.format_exceptions()) - def __str__(self): - template = """%(report_name)s + def __str__(self) -> str: + template: str = """%(report_name)s status: %(status)s updates: %(updates)s exceptions: %(exceptions)s diff --git a/src/subscription_manager/cli.py 
b/src/subscription_manager/cli.py index 227ddd8c0f..10bd7c8330 100644 --- a/src/subscription_manager/cli.py +++ b/src/subscription_manager/cli.py @@ -14,7 +14,7 @@ import os import sys import logging -from typing import Union +from typing import Dict, List, Optional, Tuple, Type, Union from subscription_manager.exceptions import ExceptionMapper from subscription_manager.printing_utils import columnize, echo_columnize_callback @@ -28,15 +28,14 @@ class InvalidCLIOptionError(Exception): - def __init__(self, message): + def __init__(self, message: str): Exception.__init__(self, message) -def flush_stdout_stderr(): +def flush_stdout_stderr() -> None: """ Try to flush stdout and stderr, when it is not possible due to blocking process, then print error message to log file. - :return: None """ # Try to flush all outputs, see BZ: 1350402 try: @@ -52,18 +51,26 @@ class AbstractCLICommand(object): strategy. """ - def __init__(self, name="cli", aliases=None, shortdesc=None, primary=False): - self.name = name - self.shortdesc = shortdesc - self.primary = primary - self.aliases = aliases or [] - - self.parser = self._create_argparser() - - def main(self, args=None): + # FIXME Default for aliases should be [] + def __init__( + self, + name: str = "cli", + aliases: List[str] = None, + shortdesc: Optional[str] = None, + primary: bool = False, + ): + self.name: str = name + self.shortdesc: Optional[str] = shortdesc + self.primary: bool = primary + self.aliases: List[str] = aliases or [] + + self.parser: ArgumentParser = self._create_argparser() + + # FIXME What is the type of `args`? + def main(self, args=None) -> None: raise NotImplementedError("Commands must implement: main(self, args=None)") - def _validate_options(self): + def _validate_options(self) -> None: """ Validates the command's arguments. @raise InvalidCLIOptionError: Raised when arg validation fails. @@ -71,20 +78,19 @@ def _validate_options(self): # No argument validation by default. 
pass - def _get_usage(self): + def _get_usage(self) -> str: """ - Usage format strips any leading 'usage' so - do not include it + Usage format strips any leading 'usage' so do not include it. """ return _("%(prog)s {name} [OPTIONS]").format(name=self.name) - def _do_command(self): + def _do_command(self) -> None: """ Does the work that this command intends. """ raise NotImplementedError("Commands must implement: _do_command(self)") - def _create_argparser(self): + def _create_argparser(self) -> ArgumentParser: """ Creates an argparse.ArgumentParser object for this command. @@ -97,37 +103,40 @@ def _create_argparser(self): # taken wholseale from rho... class CLI(object): - def __init__(self, command_classes=None): + def __init__(self, command_classes: List[Type[AbstractCLICommand]] = None): command_classes = command_classes or [] - self.cli_commands = {} - self.cli_aliases = {} + self.cli_commands: Dict[str, AbstractCLICommand] = {} + self.cli_aliases: Dict[str, AbstractCLICommand] = {} for clazz in command_classes: - cmd = clazz() + cmd: AbstractCLICommand = clazz() # ignore the base class if cmd.name != "cli": self.cli_commands[cmd.name] = cmd for alias in cmd.aliases: self.cli_aliases[alias] = cmd - def _add_command(self, cmd): + def _add_command(self, cmd: AbstractCLICommand): self.cli_commands[cmd.name] = cmd - def _default_command(self): + def _default_command(self) -> None: self._usage() - def _usage(self): + def _usage(self) -> None: print(_("Usage: %s MODULE-NAME [MODULE-OPTIONS] [--help]") % os.path.basename(sys.argv[0])) print("\r") items = sorted(self.cli_commands.items()) - items_primary = [] - items_other = [] + items_primary: List[Tuple[str, str]] = [] + items_other: List[Tuple[str, str]] = [] + + name: str + cmd: AbstractCLICommand for (name, cmd) in items: if cmd.primary: items_primary.append((" " + name, cmd.shortdesc)) else: items_other.append((" " + name, cmd.shortdesc)) - all_items = ( + all_items: List[Tuple[str, str]] = ( [(_("Primary 
Modules:"), "\n")] + items_primary + [("\n" + _("Other Modules:"), "\n")] @@ -135,11 +144,11 @@ def _usage(self): ) self._do_columnize(all_items) - def _do_columnize(self, items_list): + def _do_columnize(self, items_list: List[Tuple[str, str]]) -> None: modules, descriptions = list(zip(*items_list)) print(columnize(modules, echo_columnize_callback, *descriptions) + "\n") - def _find_best_match(self, args): + def _find_best_match(self, args: List[str]) -> Optional[AbstractCLICommand]: """ Returns the subcommand class that best matches the subcommand specified in the argument list. For example, if you have two commands that start @@ -149,7 +158,7 @@ def _find_best_match(self, args): This function ignores the arguments which begin with -- """ - possiblecmd = [] + possiblecmd: List[str] = [] for arg in args[1:]: if not arg.startswith("-"): possiblecmd.append(arg) @@ -157,10 +166,10 @@ def _find_best_match(self, args): if not possiblecmd: return None - cmd = None - i = len(possiblecmd) + cmd: Optional[AbstractCLICommand] = None + i: int = len(possiblecmd) while cmd is None: - key = " ".join(possiblecmd[:i]) + key: str = " ".join(possiblecmd[:i]) if key is None or key == "": break @@ -171,8 +180,8 @@ def _find_best_match(self, args): return cmd - def main(self): - cmd = self._find_best_match(sys.argv) + def main(self) -> Optional[int]: + cmd: AbstractCLICommand = self._find_best_match(sys.argv) if len(sys.argv) < 2: self._default_command() flush_stdout_stderr() @@ -180,7 +189,7 @@ def main(self): if not cmd: self._usage() # Allow for a 0 return code if just calling --help - return_code = 1 + return_code: int = 1 if (len(sys.argv) > 1) and (sys.argv[1] == "--help"): return_code = 0 flush_stdout_stderr() @@ -189,6 +198,8 @@ def main(self): try: return cmd.main() except InvalidCLIOptionError as error: + # FIXME Re-raise the exception instead. Returning None + # (a non-integer) results in return code 0, which is not correct. 
print(error) diff --git a/src/subscription_manager/cli_command/cli.py b/src/subscription_manager/cli_command/cli.py index 0ffd1d4779..94a2a1c77d 100644 --- a/src/subscription_manager/cli_command/cli.py +++ b/src/subscription_manager/cli_command/cli.py @@ -17,6 +17,7 @@ import logging import os import sys +from typing import Optional import rhsm.config import rhsm.connection as connection @@ -257,7 +258,7 @@ def log_server_version(self): self.server_versions = get_server_versions(cp, exception_on_timeout=False) log.debug("Server Versions: {versions}".format(versions=self.server_versions)) - def main(self, args=None): + def main(self, args=None) -> Optional[int]: # TODO: For now, we disable the CLI entirely. We may want to allow some commands in the future. if rhsm.config.in_container(): @@ -269,7 +270,7 @@ def main(self, args=None): ), ) - config_changed = False + config_changed: bool = False # In testing we sometimes specify args, otherwise use the default: if args is None: diff --git a/src/subscription_manager/content_action_client.py b/src/subscription_manager/content_action_client.py index acfe8de94a..50165a0b2b 100644 --- a/src/subscription_manager/content_action_client.py +++ b/src/subscription_manager/content_action_client.py @@ -13,6 +13,7 @@ # import logging +from typing import Generator, Set, TYPE_CHECKING from subscription_manager import base_action_client from subscription_manager import certlib @@ -21,11 +22,15 @@ import subscription_manager.injection as inj +if TYPE_CHECKING: + from subscription_manager.certlib import ActionReport + from subscription_manager.plugins import PluginHookRunner, PluginManager + log = logging.getLogger(__name__) class ContentPluginActionReport(certlib.ActionReport): - """Aggragate the info reported by each content plugin. + """Aggregate the info reported by each content plugin. Just a set of reports that include info about the content plugin that created it. 
@@ -35,9 +40,9 @@ class ContentPluginActionReport(certlib.ActionReport): def __init__(self): super(ContentPluginActionReport, self).__init__() - self.reports = set() + self.reports: Set[ActionReport] = set() - def add(self, report): + def add(self, report: "ActionReport"): # report should include info about what plugin generated it self.reports.add(report) @@ -53,10 +58,10 @@ class ContentPluginActionCommand(object): that PluginHookRunner.run() adds to the content plugin conduit. """ - def __init__(self, content_plugin_runner): - self.runner = content_plugin_runner + def __init__(self, content_plugin_runner: "PluginHookRunner"): + self.runner: PluginHookRunner = content_plugin_runner - def perform(self): + def perform(self) -> "ContentPluginActionReport": self.runner.run() # Actually a set of reports... return self.runner.conduit.reports @@ -65,7 +70,7 @@ def perform(self): class ContentPluginActionInvoker(certlib.BaseActionInvoker): """ActionInvoker for ContentPluginActionCommands.""" - def __init__(self, content_plugin_runner): + def __init__(self, content_plugin_runner: "PluginHookRunner"): """Create a ContentPluginActionInvoker to wrap content plugin PluginHookRunner. Pass a PluginHookRunner to ContentPluginActionCommand. Do the @@ -76,24 +81,24 @@ def __init__(self, content_plugin_runner): PluginManager.runiter('content_update').runiter() """ super(ContentPluginActionInvoker, self).__init__() - self.runner = content_plugin_runner + self.runner: PluginHookRunner = content_plugin_runner - def _do_update(self): + def _do_update(self) -> "ContentPluginActionReport": action = ContentPluginActionCommand(self.runner) return action.perform() class ContentActionClient(base_action_client.BaseActionClient): - def _get_libset(self): + def _get_libset(self) -> Generator[certlib.BaseActionInvoker, None, None]: """Return a generator that creates a ContentPluginAction* for each update_content plugin. 
- The iterable return includes the yum repo action invoker, and a ContentPluginActionInvoker + The iterable return includes the DNF repo action invoker, and a ContentPluginActionInvoker for each plugin hook mapped to the 'update_content_hook' slot. """ yield repolib.RepoActionInvoker() - plugin_manager = inj.require(inj.PLUGIN_MANAGER) + plugin_manager: PluginManager = inj.require(inj.PLUGIN_MANAGER) content_plugins_reports = ContentPluginActionReport() diff --git a/src/subscription_manager/cp_provider.py b/src/subscription_manager/cp_provider.py index 2fbe078728..e983795115 100644 --- a/src/subscription_manager/cp_provider.py +++ b/src/subscription_manager/cp_provider.py @@ -14,7 +14,7 @@ import base64 import json import logging -from typing import Optional +from typing import Optional, Dict from subscription_manager.identity import ConsumerIdentity from subscription_manager import utils @@ -55,38 +55,38 @@ def __init__(self): # FIXME: This does not make much sense to call this method and then rewrite almost everything # with None self.set_connection_info() - self.correlation_id = None - self.username = None - self.password = None - self.token = None - self.token_username = None - self.cdn_hostname = None - self.cdn_port = None - self.cert_file = ConsumerIdentity.certpath() - self.key_file = ConsumerIdentity.keypath() - self.server_hostname = None - self.server_port = None - self.server_prefix = None + self.correlation_id: Optional[str] = None + self.username: Optional[str] = None + self.password: Optional[str] = None + self.token: Optional[str] = None + self.token_username: Optional[str] = None + self.cdn_hostname: Optional[str] = None + self.cdn_port: Optional[str] = None + self.cert_file: str = ConsumerIdentity.certpath() + self.key_file: str = ConsumerIdentity.keypath() + self.server_hostname: Optional[str] = None + self.server_port: Optional[int] = None + self.server_prefix: Optional[str] = None self.proxy_hostname = None - self.proxy_port = None - 
self.proxy_user = None - self.proxy_password = None - self.no_proxy = None + self.proxy_port: Optional[int] = None + self.proxy_user: Optional[str] = None + self.proxy_password: Optional[str] = None + self.no_proxy: Optional[bool] = None # Reread the config file and prefer arguments over config values # then recreate connections def set_connection_info( self, - host=None, - ssl_port=None, - handler=None, - cert_file=None, - key_file=None, - proxy_hostname_arg=None, - proxy_port_arg=None, - proxy_user_arg=None, - proxy_password_arg=None, - no_proxy_arg=None, + host: Optional[str] = None, + ssl_port: Optional[int] = None, + handler: Optional[str] = None, + cert_file: Optional[str] = None, + key_file: Optional[str] = None, + proxy_hostname_arg: Optional[str] = None, + proxy_port_arg: Optional[int] = None, + proxy_user_arg: Optional[str] = None, + proxy_password_arg: Optional[str] = None, + no_proxy_arg: Optional[bool] = None, ): self.cert_file = ConsumerIdentity.certpath() self.key_file = ConsumerIdentity.keypath() @@ -105,18 +105,18 @@ def set_connection_info( # Set username and password used for basic_auth without # modifying previously set options - def set_user_pass(self, username=None, password=None): + def set_user_pass(self, username: Optional[str] = None, password: Optional[str] = None) -> None: self.username = username self.password = password self.basic_auth_cp = None - def _parse_token(self, token): + def _parse_token(self, token: str) -> Dict: # FIXME: make this more reliable _header, payload, _signature = token.split(".") payload += "=" * (4 - (len(payload) % 4)) # pad to the appropriate length return json.loads(base64.b64decode(payload)) - def set_token(self, token=None): + def set_token(self, token: Optional[str] = None) -> None: self.token = token # FIXME: make this more reliable if token: @@ -126,12 +126,12 @@ def set_token(self, token=None): self.keycloak_auth_cp = None # set up info for the connection to the cdn for finding release versions - def 
set_content_connection_info(self, cdn_hostname=None, cdn_port=None): + def set_content_connection_info(self, cdn_hostname: Optional[str] = None, cdn_port: Optional[int] = None): self.cdn_hostname = cdn_hostname self.cdn_port = cdn_port self.content_connection = None - def set_correlation_id(self, correlation_id): + def set_correlation_id(self, correlation_id: str): self.correlation_id = correlation_id # Force connections to be re-initialized @@ -145,15 +145,13 @@ def clean(self): def get_client_version() -> str: """ Try to get version of subscription manager - :return: string with version of subscription-manager """ return " subscription-manager/%s" % utils.get_client_versions()["subscription-manager"] @staticmethod - def get_dbus_sender(): + def get_dbus_sender() -> str: """ Try to get D-Bus sender - :return: string with d-bus sender """ dbus_sender = DBusSender() if dbus_sender.cmd_line is not None: @@ -164,7 +162,6 @@ def get_dbus_sender(): def close_all_connections(self) -> None: """ Try to close all connections to candlepin server, CDN, etc. 
- :return: None """ if self.consumer_auth_cp is not None: log.debug("Closing auth/consumer connection...") @@ -203,23 +200,24 @@ def get_keycloak_auth_cp(self, token) -> connection.UEPConnection: if self.keycloak_auth_cp: return self.keycloak_auth_cp - uep = self.get_no_auth_cp() + uep: connection.UEPConnection = self.get_no_auth_cp() if not uep.has_capability("keycloak_auth"): + # FIXME This error class should be instantiated raise TokenAuthUnsupportedException # FIXME: make this more reliable - token_type = self._parse_token(token)["typ"] + token_type: str = self._parse_token(token)["typ"] if token_type.lower() == "bearer": access_token = token else: - status = uep.getStatus() + status: Dict = uep.getStatus() auth_url = status["keycloakAuthUrl"] realm = status["keycloakRealm"] resource = status["keycloakResource"] keycloak_instance = connection.KeycloakConnection(realm, auth_url, resource) - access_token = keycloak_instance.get_access_token_through_refresh(token) + access_token: str = keycloak_instance.get_access_token_through_refresh(token) self.set_token(access_token) diff --git a/src/subscription_manager/cpuinfo.py b/src/subscription_manager/cpuinfo.py index 381afeb905..70abc4c078 100644 --- a/src/subscription_manager/cpuinfo.py +++ b/src/subscription_manager/cpuinfo.py @@ -102,6 +102,7 @@ # model name : Intel(R) Xeon(R) CPU E5-2630 0 @ 2.30GHz # stepping : 7 # microcode : 0x710 +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Sequence, Tuple log = logging.getLogger(__name__) @@ -139,19 +140,19 @@ class Ppc64Fields(object): class CpuinfoModel(object): fields_class = DefaultCpuFields - def __init__(self, cpuinfo_data=None): + def __init__(self, cpuinfo_data: Optional[Dict] = None): # The contents of /proc/cpuinfo - self.cpuinfo_data = cpuinfo_data + self.cpuinfo_data: Optional[Dict] = cpuinfo_data # A iterable of CpuInfoModels, one for each processor in cpuinfo - self.processors = [] + self.processors: List[CpuinfoModel] = [] # 
prologues or footnotes not associated with a particular processor - self.other = [] + self.other: List[str] = [] # If were going to pretend all the cpus are the same, # what do they all have in common. - self.common = {} + self.common: Dict = {} # model name : Intel(R) Core(TM) i5 CPU M 560 @ 2.67GHz self._model_name = None @@ -161,11 +162,11 @@ def __init__(self, cpuinfo_data=None): self._model = None @property - def count(self): + def count(self) -> int: return len(self.processors) @property - def model_name(self): + def model_name(self) -> Optional[str]: if self._model_name: return self._model_name @@ -175,7 +176,7 @@ def model_name(self): return self.common.get(self.fields_class.MODEL_NAME, None) @property - def model(self): + def model(self) -> Optional[str]: if self._model: return self._model @@ -184,8 +185,8 @@ def model(self): return self.common.get(self.fields_class.MODEL, None) - def __str__(self): - lines = [] + def __str__(self) -> str: + lines: List[str] = [] lines.append("Processor count: %s" % self.count) lines.append("model_name: %s" % self.model_name) lines.append("") @@ -199,21 +200,23 @@ def __str__(self): class Aarch64ProcessorModel(dict): - "The info corresponding to the info about each aarch64 processor entry in cpuinfo" + """The info corresponding to the info about each aarch64 processor entry in cpuinfo""" + pass class X86_64ProcessorModel(dict): - "The info corresponding to the info about each X86_64 processor entry in cpuinfo" + """The info corresponding to the info about each X86_64 processor entry in cpuinfo""" + pass class Ppc64ProcessorModel(dict): - "The info corresponding to the info about each ppc64 processor entry in cpuinfo" + """The info corresponding to the info about each ppc64 processor entry in cpuinfo""" @classmethod - def from_stanza(cls, stanza): - cpu_data = cls() + def from_stanza(cls: type, stanza: Sequence[str]): + cpu_data: cls = cls() cpu_data.update(dict([fact_sluggify_item(item) for item in stanza])) return 
cpu_data @@ -235,7 +238,7 @@ class Aarch64CpuinfoModel(CpuinfoModel): fields_class = Aarch64Fields -def fact_sluggify(key): +def fact_sluggify(key: str) -> str: """Encodes an arbitrary string to something that can be used as a fact name. ie, 'model_name' instead of 'Model name' @@ -250,37 +253,40 @@ def fact_sluggify(key): return key.lower().strip().replace(" ", "_").replace(".", "_") -def fact_sluggify_item(item_tuple): - newkey = fact_sluggify(item_tuple[0]) +def fact_sluggify_item(item_tuple: Iterable[str]) -> Tuple[str, str]: + newkey: str = fact_sluggify(item_tuple[0]) return (newkey, item_tuple[1]) -def split_key_value_generator(file_contents, line_splitter): +# FIXME Argument shadows variable name from outer scope +def split_key_value_generator(file_contents: str, line_splitter: Callable) -> Iterator[List[str]]: for line in file_contents.splitlines(): - parts = line_splitter(line) + parts: List[str] = line_splitter(line) if parts: yield parts -def line_splitter(line): +def line_splitter(line: str) -> Optional[List[str]]: # cpu family : 6 # model name : Intel(R) Core(TM) i5 CPU M 560 @ 2.67GHz - parts = line.split(":", 1) + parts: List[str] = line.split(":", 1) if parts[0]: parts = [part.strip() for part in parts] return parts return None -def accumulate_fields(fields_accum, fields): +def accumulate_fields(fields_accum: Set, fields: Sequence) -> Set: for field in fields: fields_accum.add(field) return fields_accum -def find_shared_key_value_pairs(all_fields, processors): +def find_shared_key_value_pairs( + all_fields: Iterable[str], processors: Sequence[Dict[str, Set]] +) -> Dict[str, Set]: # smashem, last one wins - smashed = collections.defaultdict(set) + smashed: Dict[str, Set] = collections.defaultdict(set) # build a dict of fieldname -> list of all the different values # so we can dump the variant ones. 
@@ -295,7 +301,7 @@ def find_shared_key_value_pairs(all_fields, processors): return common_cpu_info -def split_kv_list_by_field(kv_list, field): +def split_kv_list_by_field(kv_list: Sequence[Tuple[str, Any]], field: str) -> Iterator[List[Tuple[str, str]]]: """Split the iterable kv_list into chunks by field. For a list with repeating stanzas in it, this will @@ -304,7 +310,7 @@ def split_kv_list_by_field(kv_list, field): For something like /proc/cpuinfo, called with field 'processor', each stanza is a different cpu. """ - current_stanza = None + current_stanza: List[Tuple[str, Any]] = None for key, value in kv_list: if key == field: if current_stanza: @@ -344,29 +350,31 @@ def split_kv_list_by_field(kv_list, field): class BaseCpuInfo(object): @classmethod - def from_proc_cpuinfo_string(cls, proc_cpuinfo_string): + def from_proc_cpuinfo_string(cls: type, proc_cpuinfo_string: str): """Return a BaseCpuInfo subclass based on proc_cpuinfo_string. proc_cpuinfo_string is the string resulting from reading the entire contents of /proc/cpuinfo.""" - cpu_info = cls() + cpu_info: BaseCpuInfo = cls() cpu_info._parse(proc_cpuinfo_string) return cpu_info + # FIXME Implement _parse function doing `pass` or raising NotImplementedError + class Aarch64CpuInfo(BaseCpuInfo): def __init__(self): self.cpu_info = Aarch64CpuinfoModel() - def _parse(self, cpuinfo_data): - raw_kv_iter = split_key_value_generator(cpuinfo_data, line_splitter) + def _parse(self, cpuinfo_data: str) -> None: + raw_kv_iter: Iterator = split_key_value_generator(cpuinfo_data, line_splitter) # Yes, there is a 'Processor' field and multiple lower case 'processor' # fields. 
- kv_iter = (self._capital_processor_to_model_name(item) for item in raw_kv_iter) + kv_iter: Iterator = (self._capital_processor_to_model_name(item) for item in raw_kv_iter) - slugged_kv_list = [fact_sluggify_item(item) for item in kv_iter] + slugged_kv_list: List[Tuple[str, str]] = [fact_sluggify_item(item) for item in kv_iter] # kind of duplicated self.cpu_info.common = self.gather_cpu_info_model(slugged_kv_list) @@ -375,7 +383,7 @@ def _parse(self, cpuinfo_data): # For now, 'hardware' is per self.cpu_info.other = self.gather_cpu_info_other(slugged_kv_list) - def _capital_processor_to_model_name(self, item): + def _capital_processor_to_model_name(self, item: List[str]) -> List[str]: """Use the uppercase Processor field value as the model name. For aarch64, the 'Processor' field is the closest to model name, @@ -384,8 +392,8 @@ def _capital_processor_to_model_name(self, item): item[0] = "model_name" return item - def gather_processor_list(self, kv_list): - processor_list = [] + def gather_processor_list(self, kv_list: List[Tuple[str, str]]) -> List[Dict[str, Any]]: + processor_list: List[Dict[str, Any]] = [] for k, v in kv_list: if k != "processor": continue @@ -398,14 +406,14 @@ def gather_processor_list(self, kv_list): # FIXME: more generic would be to split the stanzas by empty lines in the # first pass - def gather_cpu_info_other(self, kv_list): - other_list = [] + def gather_cpu_info_other(self, kv_list: List[Tuple[str, str]]) -> List[List[str]]: + other_list: List[List[str]] = [] for k, v in kv_list: if k == "hardware": other_list.append([k, v]) return other_list - def gather_cpu_info_model(self, kv_list): + def gather_cpu_info_model(self, kv_list: List[Tuple[str, str]]) -> Aarch64ProcessorModel: cpu_data = Aarch64ProcessorModel() for k, v in kv_list: if k == "processor" or k == "hardware": @@ -418,14 +426,14 @@ class X86_64CpuInfo(BaseCpuInfo): def __init__(self): self.cpu_info = X86_64CpuinfoModel() - def _parse(self, cpuinfo_data): + def _parse(self, 
cpuinfo_data: str) -> None: # ordered list - kv_iter = split_key_value_generator(cpuinfo_data, line_splitter) + kv_iter: Iterator = split_key_value_generator(cpuinfo_data, line_splitter) - processors = [] + processors: List[Dict[str, Any]] = [] all_fields = set() for processor_stanza in split_kv_list_by_field(kv_iter, "processor"): - proc_dict = self.processor_stanza_to_processor_data(processor_stanza) + proc_dict: X86_64ProcessorModel = self.processor_stanza_to_processor_data(processor_stanza) processors.append(proc_dict) # keep track of fields as we see them all_fields = accumulate_fields(all_fields, list(proc_dict.keys())) @@ -434,7 +442,7 @@ def _parse(self, cpuinfo_data): self.cpu_info.processors = processors self.cpu_info.cpuinfo_data = cpuinfo_data - def processor_stanza_to_processor_data(self, stanza): + def processor_stanza_to_processor_data(self, stanza: List[Tuple[str, str]]) -> X86_64ProcessorModel: "Take a list of k,v tuples, sluggify name, and add to a dict." cpu_data = X86_64ProcessorModel() cpu_data.update(dict([fact_sluggify_item(item) for item in stanza])) @@ -448,7 +456,7 @@ def __init__(self): def _parse(self, cpuinfo_data): kv_iter = split_key_value_generator(cpuinfo_data, line_splitter) - processor_iter = itertools.takewhile(self._not_timebase_key, kv_iter) + processor_iter: Iterable[Tuple[str, Any]] = itertools.takewhile(self._not_timebase_key, kv_iter) for processor_stanza in split_kv_list_by_field(processor_iter, "processor"): proc_dict = Ppc64ProcessorModel.from_stanza(processor_stanza) self.cpu_info.processors.append(proc_dict) @@ -458,7 +466,7 @@ def _parse(self, cpuinfo_data): self.cpu_info.common = dict([fact_sluggify_item(item) for item in kv_iter]) self.cpu_info.cpuinfo_data = cpuinfo_data - def _not_timebase_key(self, item): + def _not_timebase_key(self, item: str) -> bool: return item[0] != "timebase" @@ -472,19 +480,19 @@ class SystemCpuInfoFactory(object): proc_cpuinfo_path = "/proc/cpuinfo" @classmethod - def 
from_uname_machine(cls, uname_machine, prefix=None): + def from_uname_machine(cls, uname_machine: str, prefix: Optional[str] = None) -> BaseCpuInfo: if uname_machine not in SystemCpuInfoFactory.uname_to_cpuinfo: # er? raise NotImplementedError - proc_cpuinfo_string = cls.open_proc_cpuinfo(prefix) + proc_cpuinfo_string: str = cls.open_proc_cpuinfo(prefix) - arch_class = cls.uname_to_cpuinfo[uname_machine] + arch_class: type(BaseCpuInfo) = cls.uname_to_cpuinfo[uname_machine] return arch_class.from_proc_cpuinfo_string(proc_cpuinfo_string) @classmethod - def open_proc_cpuinfo(cls, prefix=None): - proc_cpuinfo_path = cls.proc_cpuinfo_path + def open_proc_cpuinfo(cls, prefix: str = Optional[None]) -> str: + proc_cpuinfo_path: str = cls.proc_cpuinfo_path if prefix: proc_cpuinfo_path = os.path.join(prefix, cls.proc_cpuinfo_path[1:]) with open(proc_cpuinfo_path, "r") as proc_cpuinfo_f: diff --git a/src/subscription_manager/entbranding.py b/src/subscription_manager/entbranding.py index 7eb19c6df6..9bc3d5a0f9 100644 --- a/src/subscription_manager/entbranding.py +++ b/src/subscription_manager/entbranding.py @@ -17,45 +17,49 @@ import logging +from typing import List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from rhsm.certificate2 import EntitlementCertificate, Product log = logging.getLogger(__name__) class BrandsInstaller(object): - def __init__(self, ent_certs=None): - self.ent_certs = ent_certs + def __init__(self, ent_certs: Optional[List["EntitlementCertificate"]] = None): + self.ent_certs: Optional[List[EntitlementCertificate]] = ent_certs # find brand installers self.brand_installers = self._get_brand_installers() - def _get_brand_installers(self): + def _get_brand_installers(self) -> List["BrandInstaller"]: """returns a list or iterable of BrandInstaller(s)""" return [] - def install(self): + def install(self) -> None: for brand_installer in self.brand_installers: brand_installer.install() class BrandInstaller(object): - """Install branding info for a set of 
entititlement certs.""" + """Install branding info for a set of entitlement certs.""" - def __init__(self, ent_certs=None): - self.ent_certs = ent_certs + def __init__(self, ent_certs: Optional[List["EntitlementCertificate"]] = None): + self.ent_certs: Optional[List["EntitlementCertificate"]] = ent_certs log.debug("BrandInstaller ent_certs: %s" % [x.serial for x in ent_certs or []]) - def install(self): + def install(self) -> None: """Create a Brand object if needed, and save it.""" - brand_picker = self._get_brand_picker() - new_brand = brand_picker.get_brand() + brand_picker: BrandPicker = self._get_brand_picker() + new_brand: Brand = brand_picker.get_brand() # no branded name info to install if not new_brand: return - current_brand = self._get_current_brand() + current_brand: Brand = self._get_current_brand() log.debug("Current branded name info, if any: %s" % current_brand.name) log.debug("Fresh ent cert has branded product info: %s" % new_brand.name) @@ -65,13 +69,13 @@ def install(self): else: log.debug("Product branding info does not need to be updated") - def _get_brand_picker(self): + def _get_brand_picker(self) -> "BrandPicker": raise NotImplementedError - def _get_current_brand(self): + def _get_current_brand(self) -> "Brand": raise NotImplementedError - def _install(self, brand): + def _install(self, brand: "Brand") -> None: raise NotImplementedError @@ -81,22 +85,22 @@ class BrandPicker(object): Check installed product certs, and the list of entitlement certs passed in, and find the correct branded name, if any.""" - def __init__(self, ent_certs=None): - self.ent_certs = ent_certs + def __init__(self, ent_certs: Optional[List["EntitlementCertificate"]] = None): + self.ent_certs: Optional[List["EntitlementCertificate"]] = ent_certs - def get_brand(self): + def get_brand(self) -> "Brand": raise NotImplementedError class Brand(object): """Base class for Brand objects.""" - name = None + name: Optional[str] = None # could potentially be a __lt__ etc, 
though there is some # oddness in the compares are not symetric for the empty - # cases (ie, we update nothing with something,etc) - def is_outdated_by(self, new_brand): + # cases (ie, we update nothing with something, etc) + def is_outdated_by(self, new_brand: "Brand") -> bool: """If a Brand should be replaced with new_brand.""" if not self.name: return True @@ -113,23 +117,23 @@ def is_outdated_by(self, new_brand): class ProductBrand(Brand): """A brand for a branded product""" - def __init__(self, name): - self.brand_file = self._get_brand_file() + def __init__(self, name: str): + self.brand_file: BrandFile = self._get_brand_file() self.name = name - def _get_brand_file(self): + def _get_brand_file(self) -> "BrandFile": return BrandFile() - def save(self): - brand = self.format_brand(self.name) + def save(self) -> None: + brand: str = self.format_brand(self.name) self.brand_file.write(brand) @classmethod - def from_product(cls, product): + def from_product(cls, product: "Product") -> "ProductBrand": return cls(product.brand_name) @staticmethod - def format_brand(brand): + def format_brand(brand: str) -> str: if not brand.endswith("\n"): brand += "\n" @@ -140,15 +144,15 @@ class CurrentBrand(Brand): """The currently installed brand""" def __init__(self): - self.brand_file = self._get_brand_file() + self.brand_file: BrandFile = self._get_brand_file() self.load() - def _get_brand_file(self): + def _get_brand_file(self) -> "BrandFile": return BrandFile() def load(self): try: - brand_info = self.brand_file.read() + brand_info: str = self.brand_file.read() except IOError: log.error("No brand info file found (%s) " % self.brand_file) return @@ -156,7 +160,7 @@ def load(self): self.name = self.unformat_brand(brand_info) @staticmethod - def unformat_brand(brand): + def unformat_brand(brand: str) -> Optional[str]: if brand: return brand.strip() return None @@ -168,15 +172,15 @@ class BrandFile(object): Default is "/var/lib/rhsm/branded_name """ - path = 
"/var/lib/rhsm/branded_name" + path: str = "/var/lib/rhsm/branded_name" - def write(self, brand_info): + def write(self, brand_info: str) -> None: with open(self.path, "w") as brand_file: brand_file.write(brand_info) - def read(self): + def read(self) -> str: with open(self.path, "r") as brand_file: return brand_file.read() - def __str__(self): + def __str__(self) -> str: return "" % self.path diff --git a/src/subscription_manager/entcertlib.py b/src/subscription_manager/entcertlib.py index 12ba296a91..c64c1ef2a0 100644 --- a/src/subscription_manager/entcertlib.py +++ b/src/subscription_manager/entcertlib.py @@ -10,7 +10,8 @@ # Red Hat trademarks are not licensed under GPLv2. No permission is # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. -# +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + import logging import socket @@ -27,6 +28,16 @@ from subscription_manager.i18n import ungettext, ugettext as _ +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + from rhsm.certificate2 import Product + + from subscription_manager.cache import ContentAccessCache + from subscription_manager.cp_provider import CPProvider + from subscription_manager.identity import Identity + from subscription_manager.certdirectory import EntitlementCertificate, EntitlementDirectory + + log = logging.getLogger(__name__) CONTENT_ACCESS_CERT_CAPABILITY = "org_level_content_access" @@ -35,11 +46,13 @@ class EntCertActionInvoker(certlib.BaseActionInvoker): """Invoker for entitlement certificate updating actions.""" - def _do_update(self): + def _do_update(self) -> "EntCertUpdateReport": action = EntCertUpdateAction() return action.perform() +# FIXME These Delete classes are not being used +# and it is not clear what the usage and type hints should be # this guy is an oddball # NOTE: this lib and EntCertDeleteAction are currently # unused. 
Intention is to replace managerlib.clean_all_data @@ -103,30 +116,30 @@ class EntCertUpdateAction(object): """ def __init__(self, report=None): - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = self.cp_provider.get_consumer_auth_cp() - self.ent_dir = inj.require(inj.ENT_DIR) - self.identity = require(IDENTITY) + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: UEPConnection = self.cp_provider.get_consumer_auth_cp() + self.ent_dir: EntitlementDirectory = inj.require(inj.ENT_DIR) + self.identity: Identity = require(IDENTITY) self.report = EntCertUpdateReport() - self.content_access_cache = inj.require(inj.CONTENT_ACCESS_CACHE) + self.content_access_cache: ContentAccessCache = inj.require(inj.CONTENT_ACCESS_CACHE) # NOTE: this is slightly at odds with the manual cert import - # path, manual import certs wont get a 'report', etc - def perform(self): - local = self._get_local_serials() + # path, manual import certs won't get a 'report', etc + def perform(self) -> "EntCertUpdateReport": + local: Dict[int, EntitlementCertificate] = self._get_local_serials() try: - expected = self._get_expected_serials() + expected: List[int] = self._get_expected_serials() except socket.error as ex: log.exception(ex) log.error("Cannot modify subscriptions while disconnected") raise Disconnected() - cert_changed = False - missing_serials = self._find_missing_serials(local, expected) - rogue_serials = self._find_rogue_serials(local, expected) + cert_changed: bool = False + missing_serials: List[int] = self._find_missing_serials(local, expected) + rogue_serials: List[EntitlementCertificate] = self._find_rogue_serials(local, expected) self.delete(rogue_serials) - installed_serials = self.install(missing_serials) + installed_serials: List[int] = self.install(missing_serials) log.info("certs updated:\n%s", self.report) self.syslog_results() @@ -141,10 +154,10 @@ def perform(self): cert_changed = True if self.uep.has_capability(CONTENT_ACCESS_CERT_CAPABILITY): 
- content_access_certs = self._find_content_access_certs() + content_access_certs: List[EntitlementCertificate] = self._find_content_access_certs() if len(content_access_certs) > 0: # This addresses BZs: 1448855, 1450862 - obsolete_certs = [] + obsolete_certs: List[EntitlementCertificate] = [] for cont_access_cert in content_access_certs: if cont_access_cert.serial in installed_serials: continue @@ -153,7 +166,7 @@ def perform(self): if len(obsolete_certs) > 0: log.info("Deleting obsolete content access certificate") self.delete(obsolete_certs) - update_data = self.content_access_hook() + update_data: Optional[Dict] = self.content_access_hook() if update_data is not None: cert_changed = True @@ -173,7 +186,7 @@ def perform(self): # of *Lib.update return self.report - def install(self, missing_serials): + def install(self, missing_serials) -> List[int]: """Install any missing entitlement certificates.""" cert_bundles = self.get_certificates_by_serial_list(missing_serials) @@ -181,15 +194,15 @@ def install(self, missing_serials): ent_cert_bundles_installer = EntitlementCertBundlesInstaller(self.report) return ent_cert_bundles_installer.install(cert_bundles) - def _find_content_access_certs(self): - certs = self.ent_dir.list_with_content_access() + def _find_content_access_certs(self) -> List["EntitlementCertificate"]: + certs: List[EntitlementCertificate] = self.ent_dir.list_with_content_access() return [cert for cert in certs if cert.entitlement_type == CONTENT_ACCESS_CERT_TYPE] - def content_access_hook(self): + def content_access_hook(self) -> Optional[Dict]: if not self.uep.has_capability(CONTENT_ACCESS_CERT_CAPABILITY): return # do nothing if we cannot check for content access cert updates - content_access_certs = self._find_content_access_certs() - update_data = None + content_access_certs: List[EntitlementCertificate] = self._find_content_access_certs() + update_data: Optional[Dict] = None if len(content_access_certs) > 0: update_data = 
self.content_access_cache.check_for_update() for content_access_cert in content_access_certs: @@ -200,14 +213,14 @@ def content_access_hook(self): self.ent_dir.refresh() return update_data - def branding_hook(self): + def branding_hook(self) -> None: """Update branding info based on entitlement cert changes.""" # RHELBrandsInstaller will use latest ent_dir contents brands_installer = rhelentbranding.RHELBrandsInstaller() brands_installer.install() - def repo_hook(self): + def repo_hook(self) -> None: """Update content repos.""" log.debug("entcerlibaction.repo_hook") try: @@ -218,17 +231,21 @@ def repo_hook(self): log.debug(e) log.debug("Failed to update repos") - def _find_missing_serials(self, local, expected): + def _find_missing_serials( + self, local: Dict[int, "EntitlementCertificate"], expected: List[int] + ) -> List[int]: """Find serials from the server we do not have locally.""" missing = [sn for sn in expected if sn not in local] return missing - def _find_rogue_serials(self, local, expected): + def _find_rogue_serials( + self, local: Dict[int, "EntitlementCertificate"], expected: List[int] + ) -> List["EntitlementCertificate"]: """Find serials we have locally but are not on the server.""" rogue = [local[sn] for sn in local if sn not in expected] return rogue - def syslog_results(self): + def syslog_results(self) -> None: """Write generated EntCertUpdateReport info to syslog.""" for cert in self.report.added: utils.system_log( @@ -243,59 +260,61 @@ def syslog_results(self): for product in cert.products: utils.system_log("Removed subscription for product '%s'" % (product.name)) - def _get_local_serials(self): - local = {} + def _get_local_serials(self) -> Dict[int, "EntitlementCertificate"]: + local: Dict[int, EntitlementCertificate] = {} # certificates in grace period were being renamed everytime. 
# this makes sure we don't try to re-write certificates in # grace period # XXX since we don't use grace period, this might not be needed self.ent_dir.refresh() - ent_certs = self.ent_dir.list() + self.ent_dir.list_with_content_access() + ent_certs: List[EntitlementCertificate] = ( + self.ent_dir.list() + self.ent_dir.list_with_content_access() + ) ent_certs = list(set(ent_certs)) for valid in ent_certs: - sn = valid.serial + sn: int = valid.serial self.report.valid.append(sn) local[sn] = valid return local - def get_certificate_serials_list(self): + def get_certificate_serials_list(self) -> List[int]: """Query RHSM API for list of expected ent cert serial numbers.""" - results = [] + results: List[int] = [] # if there is no UEP object, short circuit if self.uep is None: return results - identity = inj.require(inj.IDENTITY) + identity: Identity = inj.require(inj.IDENTITY) if not identity.is_valid(): # We can get here on unregister, with no id or ent certs or repos, # but don't want to raise an exception that would be logged. So # empty result set is returned. 
return results - reply = self.uep.getCertificateSerials(identity.uuid) + reply: List[Dict] = self.uep.getCertificateSerials(identity.uuid) for d in reply: - sn = d["serial"] + sn: int = d["serial"] results.append(sn) return results - def get_certificates_by_serial_list(self, sn_list): + def get_certificates_by_serial_list(self, sn_list: List[int]) -> List[Dict]: """Fetch a list of entitlement certificates specified by a list of serial numbers.""" - result = [] + result: List[Dict] = [] if sn_list: sn_list = [str(sn) for sn in sn_list] # NOTE: use injected IDENTITY, need to validate this # handles disconnected errors properly - reply = self.uep.getCertificates(self.identity.uuid, serials=sn_list) + reply: List[Dict] = self.uep.getCertificates(self.identity.uuid, serials=sn_list) for cert in reply: result.append(cert) return result - def _get_expected_serials(self): - exp = self.get_certificate_serials_list() + def _get_expected_serials(self) -> List[int]: + exp: List[int] = self.get_certificate_serials_list() self.report.expected = exp return exp - def delete(self, rogue): + def delete(self, rogue: List["EntitlementCertificate"]): for cert in rogue: try: cert.delete() @@ -306,7 +325,7 @@ def delete(self, rogue): # If we just deleted certs, we need to refresh the now stale # entitlement directory before we go to delete expired certs. - rogue_count = len(self.report.rogue) + rogue_count: int = len(self.report.rogue) if rogue_count > 0: print( ungettext( @@ -327,37 +346,37 @@ class EntitlementCertBundlesInstaller(object): all of the ent cert bundles are installed. 
""" - def __init__(self, report): - self.exceptions = [] - self.report = report + def __init__(self, report: "EntCertUpdateReport"): + self.exceptions: List[Exception] = [] + self.report: EntCertUpdateReport = report def install(self, cert_bundles): """Fetch entitliement certs, install them, and update the report.""" bundle_installer = EntitlementCertBundleInstaller(self.report) - installed_serials = [] + installed_serials: List[int] = [] for cert_bundle in cert_bundles: - cert_serial = bundle_installer.install(cert_bundle) + cert_serial: Optional[int] = bundle_installer.install(cert_bundle) if cert_serial is not None: installed_serials.append(cert_serial) self.exceptions = bundle_installer.exceptions self.post_install() return installed_serials - # TODO: add subman plugin slot,conduit,hooks - def pre_install(self): + # TODO: add subman plugin slot, conduit, hooks + def pre_install(self) -> None: """Hook called before any ent cert bundles are installed.""" log.debug("cert bundles pre_install") - def post_install(self): + def post_install(self) -> None: """Hook called after all cert bundles have been installed.""" for installed in self._get_installed(): log.debug("cert bundles post_install: %s" % installed) - def get_installed(self): + def get_installed(self) -> List["EntitlementCertificate"]: """Return a list of the ent cert bundles that were installed.""" return self._get_installed() - def _get_installed(self): + def _get_installed(self) -> List["EntitlementCertificate"]: """Return the bundles installed based on this impl's EntCertUpdateReport.""" return self.report.added @@ -375,16 +394,16 @@ class EntitlementCertBundleInstaller(object): bundles, while this is pre/post each ent cert bundle. 
""" - def __init__(self, report): - self.exceptions = [] - self.report = report + def __init__(self, report: "EntCertUpdateReport"): + self.exceptions: List[EntitlementCertificate] = [] + self.report: EntCertUpdateReport = report - def install(self, bundle): + def install(self, bundle: Dict): """Persist an ent cert and it's key after splitting it from the bundle.""" self.pre_install(bundle) cert_bundle_writer = Writer() - cert_serial = None + cert_serial: int = None try: key, cert = self.build_cert(bundle) cert_bundle_writer.write(key, cert) @@ -398,29 +417,29 @@ def install(self, bundle): return cert_serial # TODO: add subman plugin, slot, and conduit - def pre_install(self, bundle): + def pre_install(self, bundle: Dict): """Hook called before an ent cert bundle is installed.""" log.debug("Ent cert bundle pre_install") # should probably be in python-rhsm/certificate - def build_cert(self, bundle): + def build_cert(self, bundle: Dict) -> Tuple[Key, "EntitlementCertificate"]: """Split a cert bundle into a EntitlementCertificate and a Key.""" - keypem = bundle["key"] - crtpem = bundle["cert"] + keypem: str = bundle["key"] + crtpem: str = bundle["cert"] key = Key(keypem) - cert = create_from_pem(crtpem) + cert: EntitlementCertificate = create_from_pem(crtpem) return (key, cert) - def install_exception(self, bundle, exception): + def install_exception(self, bundle: Dict, exception: Exception) -> None: """Log exceptions and add them to the EntCertUpdateReport.""" log.exception(exception) log.error("Bundle not loaded:\n%s\n%s", bundle, exception) self.report._exceptions.append(exception) - def post_install(self, bundle): + def post_install(self, bundle: Dict) -> None: """Hook called after an ent cert bundle is installed.""" log.debug("ent cert bundle post_install") @@ -435,28 +454,28 @@ class EntCertUpdateReport(certlib.ActionReport): name = "Entitlement Cert Updates" def __init__(self): - self.valid = [] - self.expected = [] - self.added = [] - self.rogue = [] - 
self._exceptions = [] + self.valid: List[EntitlementCertificate] = [] + self.expected: List[EntitlementCertificate] = [] + self.added: List[EntitlementCertificate] = [] + self.rogue: List[EntitlementCertificate] = [] + self._exceptions: List[Exception] = [] - def updates(self): + def updates(self) -> int: """Total number of ent certs installed and deleted.""" return len(self.added) + len(self.rogue) # need an ExceptionsReport? # FIXME: needs to be properties - def exceptions(self): + def exceptions(self) -> List[Exception]: return self._exceptions - def write(self, s, title, certificates): + def write(self, s: List[str], title: str, certificates: List["EntitlementCertificate"]) -> None: """Generate a report stanza for a list of certs.""" - indent = " " + indent: str = " " s.append(title) if certificates: for c in certificates: - products = c.products + products: List[Product] = c.products if not products: s.append("%s[sn:%d (%s) @ %s]" % (indent, c.serial, c.order.name, c.path)) for product in products: @@ -464,7 +483,7 @@ def write(self, s, title, certificates): else: s.append("%s" % indent) - def __str__(self): + def __str__(self) -> str: """__str__ of report. Used in rhsm and rhsmcertd logging.""" s = [] s.append(_("Total updates: %d") % self.updates()) diff --git a/src/subscription_manager/exceptions.py b/src/subscription_manager/exceptions.py index 3469d6b67a..ffd06d971f 100644 --- a/src/subscription_manager/exceptions.py +++ b/src/subscription_manager/exceptions.py @@ -13,6 +13,8 @@ # in this software or its documentation. 
# import inspect +from typing import Callable, Dict, Optional, Tuple + from socket import error as socket_error, gaierror as socket_gaierror from rhsm.https import ssl, httplib @@ -71,7 +73,7 @@ class ExceptionMapper(object): def __init__(self): - self.message_map = { + self.message_map: Dict[str, Callable] = { socket_error: (SOCKET_MESSAGE, self.format_using_template), socket_gaierror: (GAI_MESSAGE, self.format_generic_oserror), ConnectionError: (CONNECTION_MESSAGE, self.format_generic_oserror), @@ -106,7 +108,7 @@ def format_using_template(self, _: Exception, message: str) -> str: """Return unaltered message template.""" return message - def format_using_error(self, exc: Exception, _: str) -> str: + def format_using_error(self, exc: Exception, _: Optional[str]) -> str: """Return string representation of the error.""" return str(exc) @@ -146,15 +148,25 @@ def format_bad_ca_cert_exception( file=bad_ca_cert_error.cert_path, reason=str(bad_ca_cert_error.ssl_exc) ) - def format_ssl_error(self, ssl_error, message_template): + def format_ssl_error(self, ssl_error: ssl.SSLError, message_template: str): return message_template % ssl_error - def format_restlib_exception(self, restlib_exception, message_template): + def format_restlib_exception( + self, + restlib_exception: connection.RestlibException, + message_template: str, + ): return message_template.format( - message=restlib_exception.msg, code=restlib_exception.code, title=restlib_exception.title + message=restlib_exception.msg, + code=restlib_exception.code, + title=restlib_exception.title, ) - def format_rate_limit_exception(self, rate_limit_exception, _): + def format_rate_limit_exception( + self, + rate_limit_exception: connection.RateLimitExceededException, + _: str, + ): if rate_limit_exception.retry_after is not None: return RATE_LIMIT_EXPIRATION % str(rate_limit_exception.retry_after) else: @@ -194,9 +206,12 @@ def get_message(self, exception) -> str: representation is returned. 
""" # Lookup by __class__ instead of type to support old style classes - exception_classes = inspect.getmro(exception.__class__) + exception_classes: Tuple[type(Exception), ...] = inspect.getmro(exception.__class__) + exception_class: type(Exception) for exception_class in exception_classes: if exception_class in self.message_map: + message_template: str + formatter: Callable message_template, formatter = self.message_map[exception_class] return formatter(exception, message_template) return self.format_using_error(exception, None) diff --git a/src/subscription_manager/factlib.py b/src/subscription_manager/factlib.py index dfef54be3d..7d77b26fae 100644 --- a/src/subscription_manager/factlib.py +++ b/src/subscription_manager/factlib.py @@ -14,6 +14,12 @@ # in this software or its documentation. # import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + from subscription_manager.cp_provider import CPProvider + from subscription_manager.facts import Facts from .certlib import Locker, ActionReport from subscription_manager import injection as inj @@ -33,10 +39,10 @@ class FactsActionInvoker(object): def __init__(self): self.locker = Locker() - def update(self): + def update(self) -> "FactsActionReport": return self.locker.run(self._do_update) - def _do_update(self): + def _do_update(self) -> "FactsActionReport": action = FactsActionCommand() return action.perform() @@ -56,7 +62,7 @@ def __init__(self): self._updates = [] self._status = None - def updates(self): + def updates(self) -> int: """How many facts were updated.""" return len(self.fact_updates) @@ -75,13 +81,13 @@ class FactsActionCommand(object): """ def __init__(self): - cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = cp_provider.get_consumer_auth_cp() + cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: UEPConnection = cp_provider.get_consumer_auth_cp() self.report = FactsActionReport() - self.facts = inj.require(inj.FACTS) - 
self.facts_client = inj.require(inj.FACTS) + self.facts: Facts = inj.require(inj.FACTS) + self.facts_client: Facts = inj.require(inj.FACTS) - def perform(self): + def perform(self) -> "FactsActionReport": # figure out the diff between latest facts and # report that as updates diff --git a/src/subscription_manager/facts.py b/src/subscription_manager/facts.py index 0edf29ed17..2fd1896bbb 100644 --- a/src/subscription_manager/facts.py +++ b/src/subscription_manager/facts.py @@ -12,6 +12,7 @@ from datetime import datetime import logging import os +from typing import Dict, List, Optional, TYPE_CHECKING from subscription_manager.injection import PLUGIN_MANAGER, require from subscription_manager.cache import CacheManager @@ -19,6 +20,10 @@ from rhsmlib.facts.all import AllFactsCollector +if TYPE_CHECKING: + from subscription_manager.plugins import PluginManager + + log = logging.getLogger(__name__) @@ -41,18 +46,18 @@ def __init__(self): # can change constantly on laptops, it makes for a lot of # fact churn, so we report it, but ignore it as an indicator # that we need to update - self.graylist = ["cpu.cpu_mhz", "lscpu.cpu_mhz"] + self.graylist: List[str] = ["cpu.cpu_mhz", "lscpu.cpu_mhz"] # plugin manager so we can add custom facts via plugin - self.plugin_manager = require(PLUGIN_MANAGER) + self.plugin_manager: PluginManager = require(PLUGIN_MANAGER) - def get_last_update(self): + def get_last_update(self) -> Optional[datetime]: try: return datetime.fromtimestamp(os.stat(self.CACHE_FILE).st_mtime) except Exception: return None - def has_changed(self): + def has_changed(self) -> bool: """ return a dict of any key/values that have changed including new keys or deleted keys @@ -61,7 +66,7 @@ def has_changed(self): log.debug("Cache %s does not exit" % self.CACHE_FILE) return True - cached_facts = self.read_cache_only() or {} + cached_facts: Dict = self.read_cache_only() or {} # In order to accurately check for changes, we must refresh local data self.facts = 
self.get_facts(True) @@ -70,7 +75,7 @@ def has_changed(self): return True return False - def get_facts(self, refresh=False): + def get_facts(self, refresh: bool = False): if len(self.facts) == 0 or refresh: collector = AllFactsCollector() facts = collector.get_all() diff --git a/src/subscription_manager/healinglib.py b/src/subscription_manager/healinglib.py index 3724704e26..42f8ed0cab 100644 --- a/src/subscription_manager/healinglib.py +++ b/src/subscription_manager/healinglib.py @@ -14,6 +14,7 @@ import datetime import logging +from typing import List, TYPE_CHECKING from rhsm import certificate @@ -21,6 +22,13 @@ from subscription_manager import entcertlib from subscription_manager import injection as inj +if TYPE_CHECKING: + from subscription_manager.cert_sorter import CertSorter + from subscription_manager.identity import Identity + from subscription_manager.plugins import PluginManager + from rhsm.connection import UEPConnection + from subscription_manager.cp_provider import CPProvider + log = logging.getLogger(__name__) @@ -28,14 +36,14 @@ class HealingActionInvoker(certlib.BaseActionInvoker): """ An object used to run healing nightly. Checks cert validity for today, heals if necessary, then checks for 24 hours from now, so we theoretically will - never have invalid certificats if subscriptions are available. + never have invalid certificates if subscriptions are available. NOTE: We may update entitlement status in this class, but we do not update entitlement certs, since we are inside a lock. So a EntCertActionInvoker.update() needs to follow a HealingActionInvoker.update() """ - def _do_update(self): + def _do_update(self) -> int: action = HealingUpdateAction() return action.perform() @@ -49,7 +57,7 @@ class HealingUpdateAction(object): If either show incomplete entitlement, ask the RHSM API to auto attach pools to fix entitlement. - Attemps to avoid gaps in entitlement coverage. + Attempts to avoid gaps in entitlement coverage. 
Used by subscription-manager if the "autoheal" options are enabled. @@ -63,16 +71,16 @@ class HealingUpdateAction(object): """ def __init__(self): - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = self.cp_provider.get_consumer_auth_cp() - self.report = entcertlib.EntCertUpdateReport() - self.plugin_manager = inj.require(inj.PLUGIN_MANAGER) + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: UEPConnection = self.cp_provider.get_consumer_auth_cp() + self.report: entcertlib.EntCertUpdateReport = entcertlib.EntCertUpdateReport() + self.plugin_manager: PluginManager = inj.require(inj.PLUGIN_MANAGER) def perform(self): # inject - identity = inj.require(inj.IDENTITY) - uuid = identity.uuid - consumer = self.uep.getConsumer(uuid) + identity: Identity = inj.require(inj.IDENTITY) + uuid: str = identity.uuid + consumer: dict = self.uep.getConsumer(uuid) if "autoheal" not in consumer or not consumer["autoheal"]: log.warning("Auto-heal disabled on server, skipping.") @@ -80,22 +88,22 @@ def perform(self): try: - today = datetime.datetime.now(certificate.GMT()) - tomorrow = today + datetime.timedelta(days=1) - valid_today = False - valid_tomorrow = False + today: datetime.datetime = datetime.datetime.now(certificate.GMT()) + tomorrow: datetime.datetime = today + datetime.timedelta(days=1) + valid_today: bool = False + valid_tomorrow: bool = False # Check if we're invalid today and heal if so. If we are # valid, see if 24h from now is greater than our "valid until" # date, and heal for tomorrow if so. 
- cs = inj.require(inj.CERT_SORTER) + cs: CertSorter = inj.require(inj.CERT_SORTER) cert_updater = entcertlib.EntCertActionInvoker() if not cs.is_valid(): log.warning("Found invalid entitlements for today: %s" % today) self.plugin_manager.run("pre_auto_attach", consumer_uuid=uuid) - ents = self.uep.bind(uuid, today) + ents: List[dict] = self.uep.bind(uuid, today) self.plugin_manager.run("post_auto_attach", consumer_uuid=uuid, entitlement_data=ents) # NOTE: we need to call EntCertActionInvoker.update after Healing.update diff --git a/src/subscription_manager/i18n.py b/src/subscription_manager/i18n.py index 401161c84c..255aeeb79d 100644 --- a/src/subscription_manager/i18n.py +++ b/src/subscription_manager/i18n.py @@ -19,6 +19,8 @@ import os # Localization domain: +from typing import Optional, Tuple + APP = "rhsm" # Directory where translations are deployed: DIR = "/usr/share/locale/" @@ -67,7 +69,7 @@ def configure_i18n(): Locale.set(lang) -def configure_gettext(): +def configure_gettext() -> None: """Configure gettext for all RHSM-related code. Since Glade internally uses gettext, we need to use the C-level bindings in locale to adjust the encoding. @@ -79,14 +81,14 @@ def configure_gettext(): locale.bind_textdomain_codeset(APP, "UTF-8") -def ugettext(*args, **kwargs): +def ugettext(*args, **kwargs) -> str: if hasattr(LOCALE, "lang") and LOCALE.lang is not None: return LOCALE.lang.gettext(*args, **kwargs) else: return TRANSLATION.gettext(*args, **kwargs) -def ungettext(*args, **kwargs): +def ungettext(*args, **kwargs) -> str: if hasattr(LOCALE, "lang") and LOCALE.lang is not None: return LOCALE.lang.ngettext(*args, **kwargs) else: @@ -101,12 +103,12 @@ class Locale(object): translations = {} @classmethod - def is_locale_supported(cls, language): + def is_locale_supported(cls, language: str) -> bool: """ Is translation for given locale supported? :param language: String with locale code e.g. de_DE, de_DE.UTF-8 - :return: True if the locale is supported by rhsm. 
Otherwise return False """ + lang: Optional[gettext.GNUTranslations] try: lang = gettext.translation(APP, DIR, languages=[language]) except IOError: @@ -118,14 +120,14 @@ def is_locale_supported(cls, language): return False @classmethod - def _find_lang_alternative(cls, language): + def _find_lang_alternative(cls, language: str) -> Tuple[gettext.GNUTranslations, Optional[str]]: """ Try to find alternative for given language :param language: string with code of language e.g. de_LU :return: Instance of gettext.translation and code of new_language """ - lang = None - new_language = None + lang: Optional[gettext.GNUTranslations] = None + new_language: Optional[str] = None # For similar case: 'de' if "_" not in language: new_language = language + "_" + language.upper() @@ -144,7 +146,7 @@ def _find_lang_alternative(cls, language): return lang, new_language @classmethod - def set(cls, language): + def set(cls, language: str) -> None: """ Set language used by gettext and ungettext method. This method is intended for changing language on the fly, because rhsm service can @@ -154,7 +156,7 @@ def set(cls, language): (e.g. de, de_DE, de_DE.utf-8, de_DE.UTF-8) """ global LOCALE - lang = None + lang: Optional[gettext.GNUTranslations] = None if language != "": if language in cls.translations.keys(): diff --git a/src/subscription_manager/i18n_argparse.py b/src/subscription_manager/i18n_argparse.py index f79998f7d7..88d8a54823 100644 --- a/src/subscription_manager/i18n_argparse.py +++ b/src/subscription_manager/i18n_argparse.py @@ -19,7 +19,7 @@ Just use this instead of argparse, the interface should be the same. 
-For some backgorund, see: +For some background, see: http://bugs.python.org/issue4319 """ import argparse @@ -36,10 +36,10 @@ class ArgumentParser(_ArgumentParser): - def print_help(self): + def print_help(self) -> None: sys.stdout.write(self.format_help()) - def error(self, msg): + def error(self, msg: str) -> None: """ Override default error handler to localize diff --git a/src/subscription_manager/identity.py b/src/subscription_manager/identity.py index 838a97c0f5..16f3b2bb8c 100644 --- a/src/subscription_manager/identity.py +++ b/src/subscription_manager/identity.py @@ -16,6 +16,7 @@ import os import errno import threading +from typing import Optional, TYPE_CHECKING from rhsm.certificate import create_from_pem from rhsm.config import get_config_parser @@ -25,6 +26,9 @@ from rhsm.certificate import CertificateException from rhsm.certificate2 import CertificateLoadingError +if TYPE_CHECKING: + from rhsm.certificate2 import EntitlementCertificate + conf = config.Config(get_config_parser()) log = logging.getLogger(__name__) @@ -40,27 +44,27 @@ class ConsumerIdentity(object): CERT = "cert.pem" @staticmethod - def keypath(): + def keypath() -> str: return str(Path.join(ConsumerIdentity.PATH, ConsumerIdentity.KEY)) @staticmethod - def certpath(): + def certpath() -> str: return str(Path.join(ConsumerIdentity.PATH, ConsumerIdentity.CERT)) @classmethod - def read(cls): + def read(cls) -> "ConsumerIdentity": with open(cls.keypath()) as key_file: - key = key_file.read() + key: str = key_file.read() with open(cls.certpath()) as cert_file: - cert = cert_file.read() + cert: str = cert_file.read() return ConsumerIdentity(key, cert) @classmethod - def exists(cls): + def exists(cls) -> bool: return os.path.exists(cls.keypath()) and os.path.exists(cls.certpath()) @classmethod - def existsAndValid(cls): + def existsAndValid(cls) -> bool: if cls.exists(): try: cls.read() @@ -75,23 +79,26 @@ def __init__(self, keystring, certstring): # TODO: bad variables, cert should be the 
certificate object, x509 is # used elsewhere for the rhsm._certificate object of the same name. self.cert = certstring - self.x509 = create_from_pem(certstring) + self.x509: EntitlementCertificate = create_from_pem(certstring) - def getConsumerId(self): - subject = self.x509.subject + def getConsumerId(self) -> str: + subject: dict = self.x509.subject return subject.get("CN") - def getConsumerName(self): + def getConsumerName(self) -> str: + # FIXME Unresolved attribute altName = self.x509.alt_name + # FIXME Do we still need to support 'old' format? + # The comment is from 2014, the line below from 2017 # must account for old format and new return altName.replace("DirName:/CN=", "").replace("URI:CN=", "").split(", ")[-1] - def getSerialNumber(self): + def getSerialNumber(self) -> int: return self.x509.serial # TODO: we're using a Certificate which has it's own write/delete, no idea # why this landed in a parallel disjoint class wrapping the actual cert. - def write(self): + def write(self) -> None: from subscription_manager import managerlib self.__mkdir() @@ -102,20 +109,20 @@ def write(self): cert_file.write(self.cert) os.chmod(self.certpath(), managerlib.ID_CERT_PERMS) - def delete(self): - path = self.keypath() + def delete(self) -> None: + path: str = self.keypath() if os.path.exists(path): os.unlink(path) - path = self.certpath() + path: str = self.certpath() if os.path.exists(path): os.unlink(path) - def __mkdir(self): - path = Path.abs(self.PATH) + def __mkdir(self) -> None: + path: str = Path.abs(self.PATH) if not os.path.exists(path): os.mkdir(path) - def __str__(self): + def __str__(self) -> str: return 'consumer: name="%s", uuid=%s' % (self.getConsumerName(), self.getConsumerId()) @@ -123,14 +130,14 @@ class Identity(object): """Wrapper for sharing consumer identity without constant reloading.""" def __init__(self): - self.consumer = None + self.consumer: Optional[ConsumerIdentity] = None self._lock = threading.Lock() - self._name = None - self._uuid = 
None - self._cert_dir_path = conf["rhsm"]["consumerCertDir"] + self._name: Optional[str] = None + self._uuid: Optional[str] = None + self._cert_dir_path: str = conf["rhsm"]["consumerCertDir"] self.reload() - def reload(self): + def reload(self) -> None: """Check for consumer certificate on disk and update our info accordingly.""" log.debug("Loading consumer info from identity certificates.") with self._lock: @@ -140,7 +147,7 @@ def reload(self): # existsAndValid did, so this is better. except (CertificateException, CertificateLoadingError, IOError) as err: self.consumer = None - err_msg = err + err_msg: Exception = err msg = "Reload of consumer identity cert %s raised an exception with msg: %s" % ( ConsumerIdentity.certpath(), err_msg, @@ -161,35 +168,35 @@ def reload(self): self._uuid = None self._cert_dir_path = conf["rhsm"]["consumerCertDir"] - def _get_consumer_identity(self): + def _get_consumer_identity(self) -> ConsumerIdentity: return ConsumerIdentity.read() # this name is weird, since Certificate.is_valid actually checks the data # and this is a thin wrapper - def is_valid(self): + def is_valid(self) -> bool: return self.uuid is not None @property - def name(self): + def name(self) -> Optional[str]: with self._lock: _name = self._name return _name @property - def uuid(self): + def uuid(self) -> Optional[str]: with self._lock: _uuid = self._uuid return _uuid @property - def cert_dir_path(self): + def cert_dir_path(self) -> str: with self._lock: _cert_dir_path = self._cert_dir_path return _cert_dir_path @staticmethod - def is_present(): + def is_present() -> bool: return ConsumerIdentity.exists() - def __str__(self): + def __str__(self) -> str: return "Consumer Identity name=%s uuid=%s" % (self.name, self.uuid) diff --git a/src/subscription_manager/identitycertlib.py b/src/subscription_manager/identitycertlib.py index 761076022e..c2daf609bf 100644 --- a/src/subscription_manager/identitycertlib.py +++ b/src/subscription_manager/identitycertlib.py @@ -13,11 
+13,16 @@ # import logging - +from typing import TYPE_CHECKING from subscription_manager import certlib from subscription_manager import injection as inj +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + from subscription_manager.cp_provider import CPProvider + from subscription_manager.identity import ConsumerIdentity, Identity + log = logging.getLogger(__name__) @@ -29,7 +34,7 @@ class IdentityCertActionInvoker(certlib.BaseActionInvoker): for updates. """ - def _do_update(self): + def _do_update(self) -> certlib.ActionReport: action = IdentityUpdateAction() return action.perform() @@ -41,14 +46,14 @@ class IdentityUpdateAction(object): 1 indicates identity cert was updated.""" def __init__(self): - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = self.cp_provider.get_consumer_auth_cp() + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: UEPConnection = self.cp_provider.get_consumer_auth_cp() # Use the default report self.report = certlib.ActionReport() - def perform(self): - identity = inj.require(inj.IDENTITY) + def perform(self) -> certlib.ActionReport: + identity: Identity = inj.require(inj.IDENTITY) if not identity.is_valid(): # we could in theory try to update the id in the @@ -60,15 +65,15 @@ def perform(self): return self._update_cert(identity) - def _update_cert(self, identity): + def _update_cert(self, identity: "Identity") -> certlib.ActionReport: # to avoid circular imports # FIXME: move persist stuff here from subscription_manager import managerlib - idcert = identity.consumer + idcert: ConsumerIdentity = identity.consumer - consumer = self._get_consumer(identity) + consumer: dict = self._get_consumer(identity) # only write the cert if the serial has changed # FIXME: this would be a good place to have a Consumer/ConsumerCert @@ -84,7 +89,7 @@ def _update_cert(self, identity): self.report._status = 1 return self.report - def _get_consumer(self, identity): + def _get_consumer(self, identity: "Identity") 
-> dict: # FIXME: not much for error handling here - consumer = self.uep.getConsumer(identity.uuid) + consumer: dict = self.uep.getConsumer(identity.uuid) return consumer diff --git a/src/subscription_manager/injection.py b/src/subscription_manager/injection.py index db55a36a96..9436e29b88 100644 --- a/src/subscription_manager/injection.py +++ b/src/subscription_manager/injection.py @@ -12,6 +12,8 @@ # in this software or its documentation. # # Supported Features: +from typing import Dict + IDENTITY = "IDENTITY" CERT_SORTER = "CERT_SORTER" PRODUCT_DATE_RANGE_CALCULATOR = "PRODUCT_DATE_RANGE_CALCULATOR" @@ -49,9 +51,9 @@ class FeatureBroker(object): """ def __init__(self): - self.providers = {} + self.providers: Dict[str, object] = {} - def provide(self, feature, provider): + def provide(self, feature: str, provider: object) -> None: """ Provide an implementation for a feature. @@ -62,7 +64,7 @@ def provide(self, feature, provider): """ self.providers[feature] = provider - def require(self, feature, *args, **kwargs): + def require(self, feature: str, *args, **kwargs) -> object: """ Require an implementation for a feature. Can be used to create objects without requiring an exact implementation to use. @@ -84,7 +86,7 @@ def require(self, feature, *args, **kwargs): return self.providers[feature] -def nonSingleton(other): +def nonSingleton(other: type) -> object: """ Creates a factory method for a class. Passes args to the constructor in order to create a new object every time it is required. @@ -103,12 +105,15 @@ def factory(*args, **kwargs): # Small wrapper functions to make usage look a little cleaner, can use these # instead of the global: -def require(feature, *args, **kwargs): +def require(feature: str, *args, **kwargs): + # NOTE This function is not type hinted for a reason. Putting anything here + # makes IDEs and type checkers question what we are doing, it's better + # to assign all types where this function is called. 
global FEATURES return FEATURES.require(feature, *args, **kwargs) -def provide(feature, provider, singleton=False): +def provide(feature: str, provider: object, singleton: bool = False) -> None: global FEATURES if not singleton and isinstance(provider, type): provider = nonSingleton(provider) diff --git a/src/subscription_manager/installedproductslib.py b/src/subscription_manager/installedproductslib.py index 13464b4c32..743d64aa99 100644 --- a/src/subscription_manager/installedproductslib.py +++ b/src/subscription_manager/installedproductslib.py @@ -8,18 +8,25 @@ # NON-INFRINGEMENT, or FITNESS FOR A PARTICULAR PURPOSE. You should # have received a copy of GPLv2 along with this software; if not, see # http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt. -# +from typing import TYPE_CHECKING from subscription_manager import injection as inj from subscription_manager import certlib +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + + from subscription_manager.cache import InstalledProductsManager + from subscription_manager.cp_provider import CPProvider + from subscription_manager.identity import Identity + class InstalledProductsActionInvoker(certlib.BaseActionInvoker): """Used by rhsmcertd to update the installed products on this system periodically. 
""" - def _do_update(self): + def _do_update(self) -> "InstalledProductsActionReport": action = InstalledProductsActionCommand() return action.perform() @@ -32,12 +39,12 @@ class InstalledProductsActionCommand(object): def __init__(self): self.report = InstalledProductsActionReport() - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = self.cp_provider.get_consumer_auth_cp() + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: UEPConnection = self.cp_provider.get_consumer_auth_cp() - def perform(self): - mgr = inj.require(inj.INSTALLED_PRODUCTS_MANAGER) - consumer_identity = inj.require(inj.IDENTITY) + def perform(self) -> "InstalledProductsActionReport": + mgr: InstalledProductsManager = inj.require(inj.INSTALLED_PRODUCTS_MANAGER) + consumer_identity: Identity = inj.require(inj.IDENTITY) ret = mgr.update_check(self.uep, consumer_identity.uuid) self.report._status = ret @@ -45,4 +52,4 @@ def perform(self): class InstalledProductsActionReport(certlib.ActionReport): - name = "Installed Products" + name: str = "Installed Products" diff --git a/src/subscription_manager/isodate.py b/src/subscription_manager/isodate.py index 01ce4aa360..3a81825b29 100644 --- a/src/subscription_manager/isodate.py +++ b/src/subscription_manager/isodate.py @@ -12,8 +12,6 @@ # Red Hat trademarks are not licensed under GPLv2. No permission is # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. 
-# - import datetime import dateutil.parser @@ -22,9 +20,9 @@ log = logging.getLogger(__name__) -def _parse_date_dateutil(date): +def _parse_date_dateutil(date: str) -> datetime.datetime: try: - dt = dateutil.parser.parse(date) + dt: datetime.datetime = dateutil.parser.parse(date) # the datetime.datetime objects returned by dateutil use the own # dateutil.tz.tzutc object for UTC, rather than the datetime own; # switch the tzinfo object of UTC dates to datetime.timezone.utc: diff --git a/src/subscription_manager/jsonwrapper.py b/src/subscription_manager/jsonwrapper.py index 0de0eeff03..c90c351858 100644 --- a/src/subscription_manager/jsonwrapper.py +++ b/src/subscription_manager/jsonwrapper.py @@ -13,41 +13,43 @@ # # This module contains wrappers for JSON returned from the CP server. +from typing import Dict, List, Optional, Union + from subscription_manager.utils import is_true_value class PoolWrapper(object): - def __init__(self, pool_json): - self.data = pool_json + def __init__(self, pool_json: dict): + self.data: dict = pool_json - def get_id(self): + def get_id(self) -> str: return self.data.get("id") - def is_virt_only(self): - attributes = self.data["attributes"] - virt_only = False + def is_virt_only(self) -> bool: + attributes: dict = self.data["attributes"] + virt_only: bool = False for attribute in attributes: - name = attribute["name"] - value = attribute["value"] + name: str = attribute["name"] + value: str = attribute["value"] if name == "virt_only": virt_only = is_true_value(value) break return virt_only - def management_enabled(self): + def management_enabled(self) -> bool: return is_true_value(self._get_attribute_value("productAttributes", "management_enabled")) - def get_stacking_id(self): + def get_stacking_id(self) -> str: return self._get_attribute_value("productAttributes", "stacking_id") - def get_service_level(self): + def get_service_level(self) -> str: return self._get_attribute_value("productAttributes", "support_level") - def 
get_service_type(self): + def get_service_type(self) -> str: return self._get_attribute_value("productAttributes", "support_type") - def get_product_attributes(self, *attribute_names): + def get_product_attributes(self, *attribute_names: Union[str, List[str]]) -> Dict[str, object]: attrs = {} if "productAttributes" not in self.data: @@ -59,13 +61,13 @@ def get_product_attributes(self, *attribute_names): attrs[attr_name] = None for attr in self.data["productAttributes"]: - name = attr["name"] + name: str = attr["name"] if name in attribute_names: attrs[name] = attr["value"] return attrs - def get_suggested_quantity(self): - result = None + def get_suggested_quantity(self) -> Optional[int]: + result: Optional[int] = None try: result = int(self.data["calculatedAttributes"]["suggested_quantity"]) @@ -74,10 +76,10 @@ def get_suggested_quantity(self): return result - def get_pool_type(self): + def get_pool_type(self) -> str: return (self.data.get("calculatedAttributes") or {}).get("compliance_type", "") - def _get_attribute_value(self, attr_list_name, attr_name): + def _get_attribute_value(self, attr_list_name: str, attr_name: str) -> Optional[str]: product_attrs = self.data[attr_list_name] for attribute in product_attrs: name = attribute["name"] @@ -86,6 +88,6 @@ def _get_attribute_value(self, attr_list_name, attr_name): return value return None - def get_provided_products(self): + def get_provided_products(self) -> Dict[str, str]: products = self.data.get("providedProducts", []) return {prod.get("productId"): prod.get("productName") for prod in products} diff --git a/src/subscription_manager/listing.py b/src/subscription_manager/listing.py index 379baccf6e..c630327343 100644 --- a/src/subscription_manager/listing.py +++ b/src/subscription_manager/listing.py @@ -10,24 +10,25 @@ # Red Hat trademarks are not licensed under GPLv2. No permission is # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. 
-# + +from typing import List, Optional class ListingFile(object): - def __init__(self, data=None): - self.data = data - self.releases = [] + def __init__(self, data: Optional[str] = None): + self.data: Optional[str] = data + self.releases: List[str] = [] self.parse() - def get_releases(self): + def get_releases(self) -> List[str]: return self.releases - def parse(self): + def parse(self) -> None: if not self.data: return - lines = self.data.split("\n") + lines: List[str] = self.data.split("\n") for line in lines: line = line.strip() diff --git a/src/subscription_manager/lock.py b/src/subscription_manager/lock.py index 34e1e55f05..99be0ff585 100644 --- a/src/subscription_manager/lock.py +++ b/src/subscription_manager/lock.py @@ -19,19 +19,20 @@ import time import logging +from typing import Optional, TextIO log = logging.getLogger(__name__) # how long to sleep before rechecking if we can # acquire the lock. -LOCK_WAIT_DURATION = 0.5 +LOCK_WAIT_DURATION: float = 0.5 class LockFile(object): - def __init__(self, path): - self.path = path - self.pid = None - self.fp = None + def __init__(self, path: str): + self.path: str = path + self.pid: Optional[int] = None + self.fp: Optional[TextIO] = None def open(self): if self.notcreated(): @@ -39,27 +40,27 @@ def open(self): self.setpid() self.close() self.fp = open(self.path, "r+") - fd = self.fp.fileno() + fd: int = self.fp.fileno() fcntl.flock(fd, fcntl.LOCK_EX) - def getpid(self): + def getpid(self) -> int: if self.pid is None: - content = self.fp.read().strip() + content: str = self.fp.read().strip() if content: self.pid = int(content) return self.pid - def setpid(self): + def setpid(self) -> None: self.fp.seek(0) content = str(os.getpid()) self.fp.write(content) self.fp.flush() - def mypid(self): + def mypid(self) -> bool: return os.getpid() == self.getpid() - def valid(self): - status = False + def valid(self) -> bool: + status: bool = False try: os.kill(self.getpid(), 0) status = True @@ -67,14 +68,14 @@ def 
valid(self): pass return status - def delete(self): + def delete(self) -> None: if self.mypid() or not self.valid(): self.close() os.unlink(self.path) - def close(self): + def close(self) -> None: try: - fd = self.fp.fileno() + fd: int = self.fp.fileno() fcntl.flock(fd, fcntl.LOCK_UN) self.fp.close() except Exception: @@ -82,10 +83,10 @@ def close(self): self.pid = None self.fp = None - def notcreated(self): + def notcreated(self) -> bool: return not os.path.exists(self.path) - def __del__(self): + def __del__(self) -> None: self.close() @@ -93,11 +94,11 @@ class Lock(object): mutex = Mutex() - def __init__(self, path): - self.depth = 0 - self.path = path - self.lockdir = None - self.blocking = None + def __init__(self, path: str): + self.depth: int = 0 + self.path: str = path + self.lockdir: Optional[str] = None + self.blocking: Optional[bool] = None lock_dir, _fn = os.path.split(self.path) try: @@ -107,7 +108,7 @@ def __init__(self, path): except Exception: self.lockdir = None - def acquire(self, blocking=None): + def acquire(self, blocking: Optional[bool] = None) -> Optional[bool]: """Behaviour here is modeled after threading.RLock.acquire. If 'blocking' is False, we return True if we didn't need to block and we acquired the lock. @@ -117,11 +118,10 @@ def acquire(self, blocking=None): if 'blocking' is None, we return None when we acquire the lock, otherwise block until we do. if 'blocking' is True, we behave the same as with blocking=None, except we return True. 
- """ if self.lockdir is None: - return + return None f = LockFile(self.path) try: try: @@ -147,6 +147,7 @@ def acquire(self, blocking=None): f.setpid() except OSError as e: log.exception(e) + # FIXME Stop printing this print("could not create lock") finally: f.close() @@ -156,7 +157,7 @@ def acquire(self, blocking=None): return True return None - def release(self): + def release(self) -> None: if self.lockdir is None: return if not self.acquired(): @@ -171,7 +172,7 @@ def release(self): finally: f.close() - def acquired(self): + def acquired(self) -> Optional[bool]: if self.lockdir is None: return mutex = self.mutex @@ -182,7 +183,7 @@ def acquired(self): mutex.release() # P - def wait(self): + def wait(self) -> "Lock": mutex = self.mutex mutex.acquire() try: @@ -194,7 +195,7 @@ def wait(self): P = wait # V - def signal(self): + def signal(self) -> None: mutex = self.mutex mutex.acquire() try: @@ -205,7 +206,7 @@ def signal(self): V = signal - def __del__(self): + def __del__(self) -> None: try: self.release() except Exception: diff --git a/src/subscription_manager/managercli.py b/src/subscription_manager/managercli.py index 2025b54a1d..091ab5fd8d 100644 --- a/src/subscription_manager/managercli.py +++ b/src/subscription_manager/managercli.py @@ -16,6 +16,11 @@ import logging import sys +from typing import List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from subscription_manager.cli_command.cli import CliCommand + from subscription_manager import managerlib from subscription_manager.cli import CLI from subscription_manager.cli_command.addons import AddonsCommand @@ -52,7 +57,7 @@ class ManagerCLI(CLI): def __init__(self): - commands = [ + commands: List[CliCommand] = [ RegisterCommand, UnRegisterCommand, AddonsCommand, @@ -82,11 +87,11 @@ def __init__(self): ] CLI.__init__(self, command_classes=commands) - def main(self): + def main(self) -> Optional[int]: managerlib.check_identity_cert_perms() - ret = CLI.main(self) + ret: Optional[int] = CLI.main(self) # 
Try to enable all yum plugins (subscription-manager and plugin-id) - enabled_yum_plugins = YumPluginManager.enable_pkg_plugins() + enabled_yum_plugins: List[str] = YumPluginManager.enable_pkg_plugins() if len(enabled_yum_plugins) > 0: print("\n" + _("WARNING") + "\n\n" + YumPluginManager.warning_message(enabled_yum_plugins) + "\n") # Try to close all connections diff --git a/src/subscription_manager/managerlib.py b/src/subscription_manager/managerlib.py index bd2d2ffcc7..39764d7a15 100644 --- a/src/subscription_manager/managerlib.py +++ b/src/subscription_manager/managerlib.py @@ -13,6 +13,7 @@ # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. # +import datetime import glob import logging import os @@ -20,6 +21,8 @@ import shutil import stat import syslog +from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, TYPE_CHECKING + from rhsm.config import get_config_parser from rhsm.certificate import Key, CertificateException, create_from_pem @@ -55,34 +58,45 @@ from subscription_manager.i18n import ugettext as _ +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + from rhsm.certificate2 import EntitlementCertificate, DateRange + from rhsm.config import RhsmConfigParser + from subscription_manager.cert_sorter import CertSorter + from subscription_manager.identity import Identity + from subscription_manager.certdirectory import EntitlementDirectory, ProductDirectory + from subscription_manager.cp_provider import CPProvider + from subscription_manager.entcertlib import EntCertUpdateReport + + log = logging.getLogger(__name__) -cfg = get_config_parser() -ENT_CONFIG_DIR = cfg.get("rhsm", "entitlementCertDir") +cfg: "RhsmConfigParser" = get_config_parser() +ENT_CONFIG_DIR: str = cfg.get("rhsm", "entitlementCertDir") # Expected permissions for identity certificates: -ID_CERT_PERMS = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP +ID_CERT_PERMS: int = stat.S_IRUSR | stat.S_IWUSR 
| stat.S_IRGRP -def system_log(message, priority=syslog.LOG_NOTICE): +def system_log(message: str, priority: int = syslog.LOG_NOTICE) -> None: utils.system_log(message, priority) -def close_all_connections(): +def close_all_connections() -> None: """ Close all connections :return: None """ - cpp_provider = require(CP_PROVIDER) + cpp_provider: CPProvider = require(CP_PROVIDER) cpp_provider.close_all_connections() # FIXME: move me to identity.py -def persist_consumer_cert(consumerinfo): +def persist_consumer_cert(consumerinfo: dict) -> None: """ Calls the consumerIdentity, persists and gets consumer info """ - cert_dir = cfg.get("rhsm", "consumerCertDir") + cert_dir: str = cfg.get("rhsm", "consumerCertDir") if not os.path.isdir(cert_dir): os.mkdir(cert_dir) consumer = identity.ConsumerIdentity(consumerinfo["idCert"]["key"], consumerinfo["idCert"]["cert"]) @@ -92,20 +106,22 @@ def persist_consumer_cert(consumerinfo): class CertificateFetchError(Exception): - def __init__(self, errors): + def __init__(self, errors: Iterable[Exception]): self.errors = errors - def __str__(self, reason=""): + def __str__(self, reason: str = "") -> str: + # FIXME Explicitly convert errors to strings msg = "Entitlement Certificate(s) update failed due to the following reasons:\n" + "\n".join( self.errors ) return msg -def fetch_certificates(certlib): +# FIXME Does not seem to be used +def fetch_certificates(certlib) -> Literal[True]: # Force fetch all certs - result = certlib.update() - exceptions = result.exceptions() + result: EntCertUpdateReport = certlib.update() + exceptions: List[Exception] = result.exceptions() if exceptions: raise CertificateFetchError(exceptions) @@ -119,19 +135,23 @@ class PoolFilter(object): # Although sorter isn't necessarily required, when present it allows # us to not filter out yellow packages when "has no overlap" is selected - def __init__(self, product_dir, entitlement_dir, sorter=None): - - self.product_directory = product_dir - 
self.entitlement_directory = entitlement_dir - self.sorter = sorter + def __init__( + self, + product_dir: "ProductDirectory", + entitlement_dir: "EntitlementDirectory", + sorter: Optional[ComplianceManager] = None, + ): + self.product_directory: ProductDirectory = product_dir + self.entitlement_directory: EntitlementDirectory = entitlement_dir + self.sorter: Optional[ComplianceManager] = sorter - def filter_product_ids(self, pools, product_ids): + def filter_product_ids(self, pools: Iterable[dict], product_ids: Iterable[str]) -> List[dict]: """ Filter a list of pools and return just those that provide products in the requested list of product ids. Both the top level product and all provided products will be checked. """ - matched_pools = [] + matched_pools: List[dict] = [] for pool in pools: if pool["productId"] in product_ids: log.debug("pool matches: %s" % pool["productId"]) @@ -145,32 +165,32 @@ def filter_product_ids(self, pools, product_ids): break return matched_pools - def filter_out_uninstalled(self, pools): + def filter_out_uninstalled(self, pools: Iterable[dict]) -> List[dict]: """ Filter the given list of pools, return only those which provide a product installed on this system. 
""" - installed_products = self.product_directory.list() - matched_data_dict = {} + installed_products: List[EntitlementCertificate] = self.product_directory.list() + matched_data_dict: Dict[str, dict] = {} for d in pools: for product in installed_products: productid = product.products[0].id # we only need one matched item per pool id, so add to dict to keep unique: # Build a list of provided product IDs for comparison: - provided_ids = [p["productId"] for p in d["providedProducts"]] + provided_ids: List[str] = [p["productId"] for p in d["providedProducts"]] if str(productid) in provided_ids or str(productid) == d["productId"]: matched_data_dict[d["id"]] = d return list(matched_data_dict.values()) - def filter_out_installed(self, pools): + def filter_out_installed(self, pools: Iterable[dict]) -> List[dict]: """ Filter the given list of pools, return only those which do not provide a product installed on this system. """ - installed_products = self.product_directory.list() - matched_data_dict = {} + installed_products: List[EntitlementCertificate] = self.product_directory.list() + matched_data_dict: Dict[str, dict] = {} for d in pools: matched_data_dict[d["id"]] = d provided_ids = [p["productId"] for p in d["providedProducts"]] @@ -183,13 +203,13 @@ def filter_out_installed(self, pools): return list(matched_data_dict.values()) - def filter_product_name(self, pools, contains_text): + def filter_product_name(self, pools: Iterable[dict], contains_text: str) -> List[dict]: """ Filter the given list of pools, removing those whose product name does not contain the given text. 
""" - lowered = contains_text.lower() - filtered_pools = [] + lowered: str = contains_text.lower() + filtered_pools: List[dict] = [] for pool in pools: if lowered in pool["productName"].lower(): filtered_pools.append(pool) @@ -200,15 +220,15 @@ def filter_product_name(self, pools, contains_text): break return filtered_pools - def _get_entitled_product_ids(self): - entitled_products = [] + def _get_entitled_product_ids(self) -> List[dict]: + entitled_products: List[dict] = [] for cert in self.entitlement_directory.list(): for product in cert.products: entitled_products.append(product.id) return entitled_products - def _get_entitled_product_to_cert_map(self): - entitled_products_to_certs = {} + def _get_entitled_product_to_cert_map(self) -> Dict[str, set]: + entitled_products_to_certs: Dict[str, set] = {} for cert in self.entitlement_directory.list(): for product in cert.products: prod_id = product.id @@ -217,19 +237,19 @@ def _get_entitled_product_to_cert_map(self): entitled_products_to_certs[prod_id].add(cert) return entitled_products_to_certs - def _dates_overlap(self, pool, certs): + def _dates_overlap(self, pool: dict, certs: Iterable["EntitlementCertificate"]) -> bool: pool_start = isodate.parse_date(pool["startDate"]) pool_end = isodate.parse_date(pool["endDate"]) for cert in certs: - cert_range = cert.valid_range + cert_range: DateRange = cert.valid_range if cert_range.has_date(pool_start) or cert_range.has_date(pool_end): return True return False - def filter_out_overlapping(self, pools): - entitled_product_ids_to_certs = self._get_entitled_product_to_cert_map() - filtered_pools = [] + def filter_out_overlapping(self, pools: Iterable[dict]) -> List[dict]: + entitled_product_ids_to_certs: Dict[str, set] = self._get_entitled_product_to_cert_map() + filtered_pools: List[dict] = [] for pool in pools: provided_ids = set([p["productId"] for p in pool["providedProducts"]]) wrapped_pool = PoolWrapper(pool) @@ -237,7 +257,7 @@ def filter_out_overlapping(self, pools): 
# or handle the case of a product with no type in the future if wrapped_pool.get_product_attributes("type")["type"] == "SVC": provided_ids.add(pool["productId"]) - overlap = 0 + overlap: int = 0 possible_overlap_pids = provided_ids.intersection(list(entitled_product_ids_to_certs.keys())) for productid in possible_overlap_pids: if ( @@ -252,19 +272,24 @@ def filter_out_overlapping(self, pools): return filtered_pools - def filter_out_non_overlapping(self, pools): + def filter_out_non_overlapping(self, pools: Iterable[dict]) -> List[dict]: not_overlapping = self.filter_out_overlapping(pools) return [pool for pool in pools if pool not in not_overlapping] - def filter_subscribed_pools(self, pools, subscribed_pool_ids, compatible_pools): + def filter_subscribed_pools( + self, + pools: Iterable[dict], + subscribed_pool_ids: Iterable[str], + compatible_pools: Dict[str, dict], + ) -> List[dict]: """ Filter the given list of pools, removing those for which the system already has a subscription, unless the pool can be subscribed to again (ie has multi-entitle). 
""" - resubscribeable_pool_ids = [pool["id"] for pool in list(compatible_pools.values())] + resubscribeable_pool_ids: List[str] = [pool["id"] for pool in list(compatible_pools.values())] - filtered_pools = [] + filtered_pools: List[dict] = [] for pool in pools: if (pool["id"] not in subscribed_pool_ids) or (pool["id"] in resubscribeable_pool_ids): filtered_pools.append(pool) @@ -272,15 +297,15 @@ def filter_subscribed_pools(self, pools, subscribed_pool_ids, compatible_pools): def list_pools( - uep, - consumer_uuid, - list_all=False, - active_on=None, - filter_string=None, - future=None, - after_date=None, - page=0, - items_per_page=0, + uep: "UEPConnection", + consumer_uuid: str, + list_all: bool = False, + active_on: Optional[datetime.datetime] = None, + filter_string: Optional[str] = None, + future: Optional[str] = None, + after_date: Optional[datetime.datetime] = None, + page: int = 0, + items_per_page: int = 0, ): """ Wrapper around the UEP call to fetch pools, which forces a facts update @@ -309,8 +334,8 @@ def list_pools( profile_mgr = cache.ProfileManager() profile_mgr.update_check(uep, consumer_uuid) - owner = uep.getOwner(consumer_uuid) - ownerid = owner["key"] + owner: dict = uep.getOwner(consumer_uuid) + ownerid: str = owner["key"] return uep.getPoolsList( consumer=consumer_uuid, @@ -329,25 +354,25 @@ def list_pools( # dict which does not contain all the pool info. Not sure if this is really # necessary. Also some "view" specific things going on in here. 
def get_available_entitlements( - get_all=False, - active_on=None, - overlapping=False, - uninstalled=False, - text=None, - filter_string=None, - future=None, - after_date=None, - page=0, - items_per_page=0, - iso_dates=False, -): + get_all: bool = False, + active_on: Optional[datetime.datetime] = None, + overlapping: bool = False, + uninstalled: bool = False, + text: Optional[str] = None, + filter_string: Optional[str] = None, + future: Optional[str] = None, + after_date: Optional[datetime.datetime] = None, + page: int = 0, + items_per_page: int = 0, + iso_dates: bool = False, +) -> List[dict]: """ Returns a list of entitlement pools from the server. The 'all' setting can be used to return all pools, even if the rules do not pass. (i.e. show pools that are incompatible for your hardware) """ - columns = [ + columns: List[str] = [ "id", "quantity", "consumed", @@ -369,7 +394,7 @@ def get_available_entitlements( ] pool_stash = PoolStash() - dlist = pool_stash.get_filtered_pools_list( + dlist: List[dict] = pool_stash.get_filtered_pools_list( active_on, not get_all, overlapping, @@ -382,6 +407,7 @@ def get_available_entitlements( items_per_page=items_per_page, ) + date_formatter: Callable if iso_dates: date_formatter = format_iso8601_date else: @@ -432,15 +458,15 @@ class MergedPools(object): particular product. 
""" - def __init__(self, product_id, product_name): - self.product_id = product_id - self.product_name = product_name - self.bundled_products = 0 - self.quantity = 0 # how many entitlements were purchased - self.consumed = 0 # how many are in use - self.pools = [] + def __init__(self, product_id: str, product_name: str): + self.product_id: str = product_id + self.product_name: str = product_name + self.bundled_products: int = 0 + self.quantity: int = 0 # how many entitlements were purchased + self.consumed: int = 0 # how many are in use + self.pools: List[dict] = [] - def add_pool(self, pool): + def add_pool(self, pool: dict) -> None: # TODO: check if product id and name match? self.consumed += pool["consumed"] # we want to add the quantity for this pool @@ -460,7 +486,7 @@ def add_pool(self, pool): # is added and hope they are consistent. self.bundled_products = len(pool["providedProducts"]) - def _virt_physical_sorter(self, pool): + def _virt_physical_sorter(self, pool: dict) -> int: """ Used to sort the pools, return Physical or Virt depending on the value or existence of the virt_only attribute. @@ -472,7 +498,7 @@ def _virt_physical_sorter(self, pool): return 1 return 2 - def sort_virt_to_top(self): + def sort_virt_to_top(self) -> None: """ Prioritizes virt pools to the front of the list, if any are present. @@ -481,7 +507,7 @@ def sort_virt_to_top(self): self.pools.sort(key=self._virt_physical_sorter) -def merge_pools(pools): +def merge_pools(pools: List[dict]) -> Dict: """ Merges the given pools into a data structure representing the totals for a particular product, across all pools for that product. @@ -492,7 +518,7 @@ def merge_pools(pools): Returns a dict mapping product ID to MergedPools object. 
""" # Map product ID to MergedPools object: - merged_pools = {} + merged_pools: dict = {} for pool in pools: if not pool["productId"] in merged_pools: @@ -509,7 +535,7 @@ class MergedPoolsStackingGroupSorter(StackingGroupSorter): Sorts a list of MergedPool objects by stacking_id. """ - def __init__(self, merged_pools): + def __init__(self, merged_pools: List["EntitlementCertificate"]): StackingGroupSorter.__init__(self, merged_pools) def _get_stacking_id(self, merged_pool): @@ -526,8 +552,8 @@ class PoolStash(object): """ def __init__(self): - self.identity = require(IDENTITY) - self.sorter = None + self.identity: Identity = require(IDENTITY) + self.sorter: Optional[Union[ComplianceManager, CertSorter]] = None # Pools which passed rules server side for this consumer: self.compatible_pools = {} @@ -541,10 +567,10 @@ def __init__(self): # All pools: self.all_pools = {} - def all_pools_size(self): + def all_pools_size(self) -> int: return len(self.all_pools) - def refresh(self, active_on): + def refresh(self, active_on: Optional[datetime.datetime]) -> None: """ Refresh the list of pools from the server, active on the given date. 
""" @@ -588,23 +614,23 @@ def refresh(self, active_on): def get_filtered_pools_list( self, - active_on, - incompatible, - overlapping, - uninstalled, - text, - filter_string, - future=None, - after_date=None, - page=0, - items_per_page=0, - ): + active_on: Optional[datetime.datetime], + incompatible: bool, + overlapping: bool, + uninstalled: bool, + text: Optional[str], + filter_string: Optional[str], + future: Optional[str] = None, + after_date: Optional[datetime.datetime] = None, + page: int = 0, + items_per_page: int = 0, + ) -> List[dict]: """ Used for CLI --available filtering cuts down on api calls """ - self.all_pools = {} - self.compatible_pools = {} + self.all_pools: Dict[str, dict] = {} + self.compatible_pools: Dict[str, dict] = {} if active_on and overlapping: self.sorter = ComplianceManager(active_on) elif not active_on and overlapping: @@ -640,10 +666,17 @@ def get_filtered_pools_list( return self._filter_pools(incompatible, overlapping, uninstalled, False, text) - def _get_subscribed_pool_ids(self): + def _get_subscribed_pool_ids(self) -> List[str]: return [ent.pool.id for ent in require(ENT_DIR).list()] - def _filter_pools(self, incompatible, overlapping, uninstalled, subscribed, text): + def _filter_pools( + self, + incompatible: bool, + overlapping: bool, + uninstalled: bool, + subscribed: bool, + text: Optional[str], + ): """ Return a list of pool hashes, filtered according to the given options. @@ -691,8 +724,13 @@ def _filter_pools(self, incompatible, overlapping, uninstalled, subscribed, text return pools def merge_pools( - self, incompatible=False, overlapping=False, uninstalled=False, subscribed=False, text=None - ): + self, + incompatible: bool = False, + overlapping: bool = False, + uninstalled: bool = False, + subscribed: bool = False, + text: Optional[str] = None, + ) -> dict: """ Return a merged view of pools filtered according to the given options. Pools for the same product will be merged into a MergedPool object. 
@@ -704,7 +742,7 @@ def merge_pools( merged_pools = merge_pools(pools) return merged_pools - def lookup_provided_products(self, pool_id): + def lookup_provided_products(self, pool_id: str) -> Optional[List[Tuple[str, str]]]: """ Return a list of tuples (product name, product id) for all products provided for a given pool. If we do not actually have any info on this @@ -715,7 +753,7 @@ def lookup_provided_products(self, pool_id): log.debug("pool id %s not found in all_pools", pool_id) return None - provided_products = [] + provided_products: List[Tuple[str, str]] = [] for product in pool["providedProducts"]: provided_products.append((product["productName"], product["productId"])) return provided_products @@ -762,20 +800,20 @@ class ImportFileExtractor(object): _ENT_DICT_TAG = "ENTITLEMENT" _SIG_DICT_TAG = "RSA SIGNATURE" - def __init__(self, cert_file_path): + def __init__(self, cert_file_path: str): self.path = cert_file_path self.file_name = os.path.basename(cert_file_path) content = self._read(cert_file_path) self.parts = self._process_content(content) - def _read(self, file_path): + def _read(self, file_path: str) -> str: fd = open(file_path, "r") file_content = fd.read() fd.close() return file_content - def _process_content(self, content): + def _process_content(self, content: str) -> Dict[str, str]: part_dict = {} matches = self._PATTERN.finditer(content) for match in matches: @@ -799,16 +837,16 @@ def _process_content(self, content): part_dict[dict_key] = start + meat + end return part_dict - def contains_key_content(self): + def contains_key_content(self) -> bool: return self._KEY_DICT_TAG in self.parts - def get_key_content(self): + def get_key_content(self) -> Optional[str]: key_content = None if self._KEY_DICT_TAG in self.parts: key_content = self.parts[self._KEY_DICT_TAG] return key_content - def get_cert_content(self): + def get_cert_content(self) -> str: cert_content = "" if self._CERT_DICT_TAG in self.parts: cert_content = 
self.parts[self._CERT_DICT_TAG] @@ -818,7 +856,7 @@ def get_cert_content(self): cert_content = cert_content + os.linesep + self.parts[self._SIG_DICT_TAG] return cert_content - def verify_valid_entitlement(self): + def verify_valid_entitlement(self) -> bool: """ Verify that a valid entitlement was processed. @@ -838,7 +876,7 @@ def verify_valid_entitlement(self): return True # TODO: rewrite to use certlib.EntitlementCertBundleInstall? - def write_to_disk(self): + def write_to_disk(self) -> None: """ Write/copy cert to the entitlement cert dir. """ @@ -855,38 +893,38 @@ def write_to_disk(self): log.debug("Writing key file: %s" % (dest_key_file_path)) self._write_file(dest_key_file_path, self.get_key_content()) - def _write_file(self, target_path, content): + def _write_file(self, target_path: str, content: str) -> None: new_file = open(target_path, "w") try: new_file.write(content) finally: new_file.close() - def _ensure_entitlement_dir_exists(self): + def _ensure_entitlement_dir_exists(self) -> None: if not os.access(ENT_CONFIG_DIR, os.R_OK): os.mkdir(ENT_CONFIG_DIR) - def _get_key_path_from_dest_cert_path(self, dest_cert_path): + def _get_key_path_from_dest_cert_path(self, dest_cert_path: str) -> str: file_parts = os.path.splitext(dest_cert_path) return file_parts[0] + "-key" + file_parts[1] - def _create_filename_from_cert_serial_number(self): + def _create_filename_from_cert_serial_number(self) -> str: "create from serial" ent_cert = self.get_cert() return "%s.pem" % (ent_cert.serial) - def get_cert(self): - cert_content = self.get_cert_content() - ent_cert = create_from_pem(cert_content) + def get_cert(self) -> "EntitlementCertificate": + cert_content: str = self.get_cert_content() + ent_cert: EntitlementCertificate = create_from_pem(cert_content) return ent_cert -def _sub_dict(datadict, subkeys, default=None): +def _sub_dict(datadict: dict, subkeys: Iterable[str], default: Optional[object] = None) -> dict: """Return a dict that is a subset of datadict matching 
only the keys in subkeys""" return dict([(k, datadict.get(k, default)) for k in subkeys]) -def format_date(dt): +def format_date(dt: datetime.datetime) -> str: if dt: try: return dt.astimezone(tzlocal()).strftime("%x") @@ -897,7 +935,7 @@ def format_date(dt): return "" -def format_iso8601_date(dateobj): +def format_iso8601_date(dateobj: Optional[datetime.datetime]) -> str: """ Format the specified datetime.date dateobj as ISO 8601, i.e. YYYY-MM-DD. @@ -909,29 +947,29 @@ def format_iso8601_date(dateobj): # FIXME: move me to identity.py -def check_identity_cert_perms(): +def check_identity_cert_perms() -> None: """ Ensure the identity certs on this system have the correct permissions, and fix them if not. """ - certs = [identity.ConsumerIdentity.keypath(), identity.ConsumerIdentity.certpath()] + certs: List[str] = [identity.ConsumerIdentity.keypath(), identity.ConsumerIdentity.certpath()] for cert in certs: if not os.path.exists(cert): # Only relevant if these files exist. continue - statinfo = os.stat(cert) + statinfo: os.stat_result = os.stat(cert) if statinfo[stat.ST_UID] != 0 or statinfo[stat.ST_GID] != 0: os.chown(cert, 0, 0) log.warn("Corrected incorrect ownership of %s." % cert) - mode = stat.S_IMODE(statinfo[stat.ST_MODE]) + mode: int = stat.S_IMODE(statinfo[stat.ST_MODE]) if mode != ID_CERT_PERMS: os.chmod(cert, ID_CERT_PERMS) log.warn("Corrected incorrect permissions on %s." 
% cert) -def clean_all_data(backup=True): - consumer_dir = cfg.get("rhsm", "consumerCertDir") +def clean_all_data(backup: bool = True) -> None: + consumer_dir: str = cfg.get("rhsm", "consumerCertDir") if backup: if consumer_dir[-1] == "/": consumer_dir_backup = consumer_dir[0:-1] + ".old" @@ -992,7 +1030,7 @@ def clean_all_data(backup=True): log.debug("Cleaned local data") -def valid_quantity(quantity): +def valid_quantity(quantity: Union[int, str, None]) -> bool: if not quantity: return False @@ -1002,7 +1040,7 @@ def valid_quantity(quantity): return False -def allows_multi_entitlement(pool): +def allows_multi_entitlement(pool: dict) -> bool: """ Determine if this pool allows multi-entitlement based on the pool's top-level product's multi-entitlement attribute. diff --git a/src/subscription_manager/overrides.py b/src/subscription_manager/overrides.py index ab08dc1de7..c0c32c9c44 100644 --- a/src/subscription_manager/overrides.py +++ b/src/subscription_manager/overrides.py @@ -11,12 +11,19 @@ # Red Hat trademarks are not licensed under GPLv2. No permission is # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. 
-# + +from typing import Dict, List, Optional, TYPE_CHECKING + from subscription_manager import injection as inj from subscription_manager.repolib import RepoActionInvoker import logging +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + from subscription_manager.cache import OverrideStatusCache + from subscription_manager.cp_provider import CPProvider + log = logging.getLogger(__name__) # Module for manipulating content overrides @@ -24,59 +31,59 @@ class Overrides(object): def __init__(self): - self.cp_provider = inj.require(inj.CP_PROVIDER) + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) - self.cache = inj.require(inj.OVERRIDE_STATUS_CACHE) + self.cache: OverrideStatusCache = inj.require(inj.OVERRIDE_STATUS_CACHE) self.repo_lib = RepoActionInvoker(cache_only=True) - def get_overrides(self, consumer_uuid): + def get_overrides(self, consumer_uuid: str) -> List["Override"]: return self._build_from_json(self.cache.load_status(self._getuep(), consumer_uuid)) - def add_overrides(self, consumer_uuid, overrides): + def add_overrides(self, consumer_uuid, overrides) -> List["Override"]: return self._build_from_json(self._getuep().setContentOverrides(consumer_uuid, self._add(overrides))) - def remove_overrides(self, consumer_uuid, overrides): + def remove_overrides(self, consumer_uuid, overrides) -> List["Override"]: return self._delete_overrides(consumer_uuid, self._remove(overrides)) - def remove_all_overrides(self, consumer_uuid, repos): + def remove_all_overrides(self, consumer_uuid, repos) -> List["Override"]: return self._delete_overrides(consumer_uuid, self._remove_all(repos)) - def update(self, overrides): + def update(self, overrides: List["Override"]) -> None: self.cache.server_status = [override.to_json() for override in overrides] self.cache.write_cache() self.repo_lib.update() - def _delete_overrides(self, consumer_uuid, override_data): + def _delete_overrides(self, consumer_uuid: str, override_data: List[dict]): return 
self._build_from_json(self._getuep().deleteContentOverrides(consumer_uuid, override_data)) - def _add(self, overrides): + def _add(self, overrides: List) -> List[dict]: return [override.to_json() for override in overrides] - def _remove(self, overrides): + def _remove(self, overrides: List) -> List[dict]: return [{"contentLabel": override.repo_id, "name": override.name} for override in overrides] - def _remove_all(self, repos): + def _remove_all(self, repos: List[dict]) -> Optional[List[dict]]: if repos: return [{"contentLabel": repo} for repo in repos] else: return None - def _build_from_json(self, override_json): + def _build_from_json(self, override_json: List[dict]) -> List["Override"]: return [Override.from_json(override_dict) for override_dict in override_json] - def _getuep(self): + def _getuep(self) -> "UEPConnection": return self.cp_provider.get_consumer_auth_cp() class Override(object): - def __init__(self, repo_id, name, value=None): - self.repo_id = repo_id - self.name = name - self.value = value + def __init__(self, repo_id: str, name: str, value: Optional[object] = None): + self.repo_id: str = repo_id + self.name: str = name + self.value: Optional[object] = value @classmethod - def from_json(cls, json_obj): + def from_json(cls, json_obj: Dict[str, str]): return cls(json_obj["contentLabel"], json_obj["name"], json_obj["value"]) - def to_json(self): + def to_json(self) -> Dict[str, str]: return {"contentLabel": self.repo_id, "name": self.name, "value": self.value} diff --git a/src/subscription_manager/packageprofilelib.py b/src/subscription_manager/packageprofilelib.py index 52cf23251d..21129ef968 100644 --- a/src/subscription_manager/packageprofilelib.py +++ b/src/subscription_manager/packageprofilelib.py @@ -9,6 +9,13 @@ # have received a copy of GPLv2 along with this software; if not, see # http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt. 
# +from typing import Literal, TYPE_CHECKING + +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + from subscription_manager.identity import Identity + from subscription_manager.cache import ProfileManager + from subscription_manager.cp_provider import CPProvider from subscription_manager import injection as inj from subscription_manager import certlib @@ -19,7 +26,7 @@ class PackageProfileActionInvoker(certlib.BaseActionInvoker): periodically. """ - def _do_update(self): + def _do_update(self) -> "PackageProfileActionReport": action = PackageProfileActionCommand() return action.perform() @@ -32,13 +39,13 @@ class PackageProfileActionCommand(object): def __init__(self): self.report = PackageProfileActionReport() - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = self.cp_provider.get_consumer_auth_cp() + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: UEPConnection = self.cp_provider.get_consumer_auth_cp() - def perform(self, force_upload=False): - profile_mgr = inj.require(inj.PROFILE_MANAGER) - consumer_identity = inj.require(inj.IDENTITY) - ret = profile_mgr.update_check(self.uep, consumer_identity.uuid, force=force_upload) + def perform(self, force_upload: bool = False) -> "PackageProfileActionReport": + profile_mgr: ProfileManager = inj.require(inj.PROFILE_MANAGER) + consumer_identity: Identity = inj.require(inj.IDENTITY) + ret: Literal[0, 1] = profile_mgr.update_check(self.uep, consumer_identity.uuid, force=force_upload) self.report._status = ret return self.report diff --git a/src/subscription_manager/plugins.py b/src/subscription_manager/plugins.py index e91daa1083..5b941ab20f 100644 --- a/src/subscription_manager/plugins.py +++ b/src/subscription_manager/plugins.py @@ -10,12 +10,14 @@ # Red Hat trademarks are not licensed under GPLv2. No permission is # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. 
-# + import glob import inspect import logging import os import importlib.util +from typing import Callable, Dict, Iterator, List, Optional, SupportsFloat, SupportsInt, Tuple, Type +from typing import TYPE_CHECKING from iniparse import SafeConfigParser from iniparse.compat import NoSectionError, NoOptionError @@ -23,6 +25,10 @@ from rhsm.config import get_config_parser from subscription_manager.base_plugin import SubManPlugin +if TYPE_CHECKING: + from subscription_manager.content_action_client import ContentPluginActionReport + from subscription_manager.model import EntitlementSource + # The API_VERSION constant defines the current plugin API version. It is used # to decided whether or not plugins can be loaded. It is compared against the # 'requires_api_version' attribute of each plugin. The version number has the @@ -50,7 +56,7 @@ class PluginException(Exception): """Base exception for rhsm plugins.""" - def _add_message(self, repr_msg): + def _add_message(self, repr_msg: str) -> str: if hasattr(self, "msg") and self.msg: repr_msg = "\n".join([repr_msg, "Message: %s" % self.msg]) return repr_msg @@ -59,10 +65,10 @@ def _add_message(self, repr_msg): class PluginImportException(PluginException): """Raised when a SubManPlugin derived class can not be imported.""" - def __init__(self, module_file, module_name, msg=None): - self.module_file = module_file - self.module_name = module_name - self.msg = msg + def __init__(self, module_file: str, module_name: str, msg: Optional[str] = None): + self.module_file: str = module_file + self.module_name: str = module_name + self.msg: Optional[str] = msg def __str__(self): repr_msg = 'Plugin "%s" can\'t be imported from file %s' % (self.module_name, self.module_file) @@ -157,7 +163,7 @@ class BaseConduit(object): slots = [] # clazz is the class object for class instance of the object the hook method maps too - def __init__(self, clazz, conf=None): + def __init__(self, clazz: Type[SubManPlugin], conf: Optional["PluginConfig"] = 
None): if conf: self._conf = conf else: @@ -166,7 +172,7 @@ def __init__(self, clazz, conf=None): # maybe useful to have a per conduit/per plugin logger space self.log = logging.getLogger(clazz.__name__) - def conf_string(self, section, option, default=None): + def conf_string(self, section: str, option: str, default: Optional[object] = None) -> Optional[str]: """get string from plugin config Args: @@ -187,7 +193,7 @@ def conf_string(self, section, option, default=None): return None return str(default) - def conf_bool(self, section, option, default=None): + def conf_bool(self, section: str, option: str, default: Optional[bool] = None) -> bool: """get boolean value from plugin config Args: @@ -212,7 +218,7 @@ def conf_bool(self, section, option, default=None): else: raise ValueError("Boolean value expected") - def conf_int(self, section, option, default=None): + def conf_int(self, section: str, option: str, default: Optional[SupportsInt] = None) -> int: """get integer value from plugin config Args: @@ -236,7 +242,7 @@ def conf_int(self, section, option, default=None): raise ValueError("Integer value expected") return val - def conf_float(self, section, option, default=None): + def conf_float(self, section: str, option: str, default: Optional[SupportsFloat] = None) -> float: """get float value from plugin config Args: @@ -267,7 +273,7 @@ class RegistrationConduit(BaseConduit): slots = ["pre_register_consumer"] - def __init__(self, clazz, name, facts): + def __init__(self, clazz: Type[SubManPlugin], name: str, facts: Dict[str, str]): """init for RegistrationConduit Args: @@ -275,8 +281,8 @@ def __init__(self, clazz, name, facts): facts: a dictionary of system facts """ super(RegistrationConduit, self).__init__(clazz) - self.name = name - self.facts = facts + self.name: str = name + self.facts: Dict[str, str] = facts class PostRegistrationConduit(BaseConduit): @@ -284,7 +290,7 @@ class PostRegistrationConduit(BaseConduit): slots = ["post_register_consumer"] - def 
__init__(self, clazz, consumer, facts): + def __init__(self, clazz: Type[SubManPlugin], consumer: dict, facts: Dict[str, str]): """init for PostRegistrationConduit Args: @@ -293,8 +299,8 @@ def __init__(self, clazz, consumer, facts): facts: a dictionary of system facts """ super(PostRegistrationConduit, self).__init__(clazz) - self.consumer = consumer - self.facts = facts + self.consumer: dict = consumer + self.facts: Dict[str, str] = facts class ProductConduit(BaseConduit): @@ -302,14 +308,14 @@ class ProductConduit(BaseConduit): slots = ["pre_product_id_install", "post_product_id_install"] - def __init__(self, clazz, product_list): + def __init__(self, clazz: Type[SubManPlugin], product_list: List[dict]): """init for ProductConduit Args: product_list: A list of ProductCertificate objects """ super(ProductConduit, self).__init__(clazz) - self.product_list = product_list + self.product_list: List[dict] = product_list class ProductUpdateConduit(BaseConduit): @@ -317,14 +323,14 @@ class ProductUpdateConduit(BaseConduit): slots = ["pre_product_id_update", "post_product_id_update"] - def __init__(self, clazz, product_list): + def __init__(self, clazz: Type[SubManPlugin], product_list: List[dict]): """init for ProductUpdateConduit Args: product_list: A list of ProductCertificate objects """ super(ProductUpdateConduit, self).__init__(clazz) - self.product_list = product_list + self.product_list: List[dict] = product_list class FactsConduit(BaseConduit): @@ -332,14 +338,14 @@ class FactsConduit(BaseConduit): slots = ["post_facts_collection"] - def __init__(self, clazz, facts): + def __init__(self, clazz: Type[SubManPlugin], facts: List[dict]): """init for FactsConduit Args: facts: a dictionary of system facts """ super(FactsConduit, self).__init__(clazz) - self.facts = facts + self.facts: List[dict] = facts class UpdateContentConduit(BaseConduit): @@ -347,7 +353,9 @@ class UpdateContentConduit(BaseConduit): slots = ["update_content"] - def __init__(self, clazz, reports, 
ent_source): + def __init__( + self, clazz: Type[SubManPlugin], reports: "ContentPluginActionReport", ent_source: "EntitlementSource" + ): """init for UpdateContentConduit. Args: @@ -355,8 +363,8 @@ def __init__(self, clazz, reports, ent_source): ent_source: a EntitlementSource instance """ super(UpdateContentConduit, self).__init__(clazz) - self.reports = reports - self.ent_source = ent_source + self.reports: ContentPluginActionReport = reports + self.ent_source: EntitlementSource = ent_source class SubscriptionConduit(BaseConduit): @@ -364,7 +372,7 @@ class SubscriptionConduit(BaseConduit): slots = ["pre_subscribe"] - def __init__(self, clazz, consumer_uuid, pool_id, quantity): + def __init__(self, clazz: Type[SubManPlugin], consumer_uuid: str, pool_id: str, quantity: int): """init for SubscriptionConduit Args: @@ -374,15 +382,15 @@ def __init__(self, clazz, consumer_uuid, pool_id, quantity): auto: is this an auto-attach/healing event. """ super(SubscriptionConduit, self).__init__(clazz) - self.consumer_uuid = consumer_uuid - self.pool_id = pool_id - self.quantity = quantity + self.consumer_uuid: str = consumer_uuid + self.pool_id: str = pool_id + self.quantity: int = quantity class PostSubscriptionConduit(BaseConduit): slots = ["post_subscribe"] - def __init__(self, clazz, consumer_uuid, entitlement_data): + def __init__(self, clazz: Type[SubManPlugin], consumer_uuid: str, entitlement_data: Dict): """init for PostSubscriptionConduit Args: @@ -390,14 +398,14 @@ def __init__(self, clazz, consumer_uuid, entitlement_data): entitlement_data: the data returned by the server """ super(PostSubscriptionConduit, self).__init__(clazz) - self.consumer_uuid = consumer_uuid - self.entitlement_data = entitlement_data + self.consumer_uuid: str = consumer_uuid + self.entitlement_data: Dict = entitlement_data class AutoAttachConduit(BaseConduit): slots = ["pre_auto_attach"] - def __init__(self, clazz, consumer_uuid): + def __init__(self, clazz: Type[SubManPlugin], consumer_uuid: 
str): """ init for AutoAttachConduit @@ -405,13 +413,13 @@ def __init__(self, clazz, consumer_uuid): consumer_uuid: the UUID of the consumer being auto-subscribed """ super(AutoAttachConduit, self).__init__(clazz) - self.consumer_uuid = consumer_uuid + self.consumer_uuid: str = consumer_uuid class PostAutoAttachConduit(PostSubscriptionConduit): slots = ["post_auto_attach"] - def __init__(self, clazz, consumer_uuid, entitlement_data): + def __init__(self, clazz: Type[SubManPlugin], consumer_uuid: str, entitlement_data: Dict): """init for PostAutoAttachConduit Args: @@ -430,9 +438,9 @@ class PluginConfig(object): Used to find the configuration file. """ - plugin_key = None + plugin_key: str = None - def __init__(self, plugin_key, plugin_conf_path=None): + def __init__(self, plugin_key: str, plugin_conf_path: str = None): """init for PluginConfig. Args: @@ -441,9 +449,9 @@ def __init__(self, plugin_key, plugin_conf_path=None): Raises: PluginConfigException: error when finding or loading plugin config """ - self.plugin_conf_path = plugin_conf_path - self.plugin_key = plugin_key - self.conf_files = [] + self.plugin_conf_path: str = plugin_conf_path + self.plugin_key: str = plugin_key + self.conf_files: List[str] = [] self.parser = SafeConfigParser() @@ -457,17 +465,17 @@ def __init__(self, plugin_key, plugin_conf_path=None): raise PluginConfigException(self.plugin_key, e) def _get_config_file_path(self): - conf_file = os.path.join(self.plugin_conf_path, self.plugin_key + ".conf") + conf_file: str = os.path.join(self.plugin_conf_path, self.plugin_key + ".conf") if not os.access(conf_file, os.R_OK): raise PluginConfigException(self.plugin_key, "Unable to find configuration file") # iniparse can handle a list of files, inc an empty list # reading an empty list is basically the None constructor self.conf_files.append(conf_file) - def is_plugin_enabled(self): - """returns True if the plugin is enabled in it's config.""" + def is_plugin_enabled(self) -> bool: + """returns 
True if the plugin is enabled in its config.""" try: - enabled = self.parser.getboolean("main", "enabled") + enabled: bool = self.parser.getboolean("main", "enabled") except Exception as e: raise PluginConfigException(self.plugin_key, e) @@ -476,7 +484,7 @@ def is_plugin_enabled(self): return False return True - def __str__(self): + def __str__(self) -> str: buf = "plugin_key: %s\n" % (self.plugin_key) for conf_file in self.conf_files: buf = buf + "config file: %s\n" % conf_file @@ -492,11 +500,11 @@ class PluginHookRunner(object): a PluginHookRunner for each plugin hook to be triggered. """ - def __init__(self, conduit, func): - self.conduit = conduit - self.func = func + def __init__(self, conduit: BaseConduit, func: Callable): + self.conduit: BaseConduit = conduit + self.func: Callable = func - def run(self): + def run(self) -> None: try: self.func(self.conduit) except Exception as e: @@ -507,9 +515,9 @@ def run(self): # NOTE: need to be super paranoid here about existing of cfg variables # BasePluginManager with our default config info class BasePluginManager(object): - """Finds, load, and provides acccess to subscription-manager plugins.""" + """Finds, load, and provides access to subscription-manager plugins.""" - def __init__(self, search_path=None, plugin_conf_path=None): + def __init__(self, search_path: Optional[str] = None, plugin_conf_path: Optional[str] = None): """init for BasePluginManager(). 
attributes: @@ -526,25 +534,25 @@ def __init__(self, search_path=None, plugin_conf_path=None): self.plugin_conf_path = plugin_conf_path # list of modules to load plugins from - self.modules = self._get_modules() + self.modules: List = self._get_modules() # we track which modules we try to load plugins from - self._modules = {} + self._modules: Dict = {} # self._plugins is mostly for bookkeeping, it's a dict # that maps 'plugin_key':instance # 'plugin_key', aka plugin_module.plugin_class - # instance is the instaniated plugin class - self._plugins = {} + # instance is the instantiated plugin class + self._plugins: Dict[str, object] = {} # all found plugin classes, including classes that # are disable, and will not be instantiated - self._plugin_classes = {} + self._plugin_classes: Dict[str, Type[SubManPlugin]] = {} - self.conduits = [] + self.conduits: List[BaseConduit] = [] # maps a slot_name to a list of methods from a plugin class - self._slot_to_funcs = {} - self._slot_to_conduit = {} + self._slot_to_funcs: Dict[str, List[Callable]] = {} + self._slot_to_conduit: Dict[str, BaseConduit] = {} # find our list of conduits self.conduits = self._get_conduits() @@ -556,7 +564,7 @@ def __init__(self, search_path=None, plugin_conf_path=None): # populate self._plugins with plugins in modules in self.modules self._import_plugins() - def _get_conduits(self): + def _get_conduits(self) -> List[BaseConduit]: """Needs to be implemented in subclass. Returns: @@ -564,7 +572,7 @@ def _get_conduits(self): """ return [] - def _get_modules(self): + def _get_modules(self) -> List: """Needs to be implemented in subclass. Returns: @@ -572,7 +580,7 @@ def _get_modules(self): """ return [] - def _import_plugins(self): + def _import_plugins(self) -> None: """Needs to be implemented in subclass. 
This loads plugin modules, checks them, and loads plugins @@ -584,14 +592,16 @@ def _import_plugins(self): log.debug("loaded plugin modules: %s" % self.modules) log.debug("loaded plugins: %s" % self._plugins) - def _populate_slots(self): + def _populate_slots(self) -> None: for conduit_class in self.conduits: slots = conduit_class.slots for slot in slots: self._slot_to_conduit[slot] = conduit_class self._slot_to_funcs[slot] = [] - def add_plugins_from_modules(self, modules, plugin_to_config_map=None): + def add_plugins_from_modules( + self, modules: List[type], plugin_to_config_map: Dict[str, PluginConfig] = None + ) -> None: """Add SubMan plugins from a list of modules Args: @@ -610,7 +620,9 @@ def add_plugins_from_modules(self, modules, plugin_to_config_map=None): log.exception(e) log.error(e) - def add_plugins_from_module(self, module, plugin_to_config_map=None): + def add_plugins_from_module( + self, module: List[type], plugin_to_config_map: Dict[str, PluginConfig] = None + ): """add SubManPlugin based plugins from a module. 
Will also look for a PluginConfig() associated with the @@ -635,12 +647,12 @@ def add_plugins_from_module(self, module, plugin_to_config_map=None): # verify we are a class, and in particular, a subclass # of SubManPlugin - def is_plugin(c): + def is_plugin(c) -> bool: return inspect.isclass(c) and c.__module__ == module.__name__ and issubclass(c, SubManPlugin) # note we sort the list of plugin classes, since that potentially # alters order hooks are mapped to slots - plugin_classes = sorted(inspect.getmembers(module, is_plugin)) + plugin_classes: List[Tuple[str, SubManPlugin]] = sorted(inspect.getmembers(module, is_plugin)) # find all the plugin classes with valid configs first # then add them, so we skip the module if a class has a bad config @@ -659,7 +671,9 @@ def is_plugin(c): # let some classes from a module fail self.add_plugin_class(plugin_class, plugin_to_config_map=plugin_to_config_map) - def add_plugin_class(self, plugin_clazz, plugin_to_config_map=None): + def add_plugin_class( + self, plugin_clazz: Type[SubManPlugin], plugin_to_config_map: Dict[str, PluginConfig] = None + ) -> None: """Add a SubManPlugin and PluginConfig class to PluginManager. Args: @@ -738,13 +752,13 @@ def add_plugin_class(self, plugin_clazz, plugin_to_config_map=None): if class_is_used: plugin_clazz.found_slots_for_hooks = True - def _track_plugin_class_to_modules(self, plugin_clazz): + def _track_plugin_class_to_modules(self, plugin_clazz: Type[SubManPlugin]) -> None: """Keep a map of plugin classes loaded from each plugin module.""" if plugin_clazz.__module__ not in self._modules: self._modules[plugin_clazz.__module__] = [] self._modules[plugin_clazz.__module__].append(plugin_clazz) - def run(self, slot_name, **kwargs): + def run(self, slot_name: str, **kwargs) -> None: """For slot_name, run the registered hooks with kwargs. 
Args: @@ -762,7 +776,7 @@ def run(self, slot_name, **kwargs): for runner in self.runiter(slot_name, **kwargs): runner.run() - def runiter(self, slot_name, **kwargs): + def runiter(self, slot_name: str, **kwargs) -> Iterator[PluginHookRunner]: """Return an iterable of PluginHookRunner objects. The iterable will return a PluginHookRunner object @@ -776,16 +790,17 @@ def runiter(self, slot_name, **kwargs): if slot_name not in self._slot_to_funcs: raise SlotNameException(slot_name) + func: Callable for func in self._slot_to_funcs[slot_name]: module = inspect.getmodule(func) - func_module_name = getattr(func, "__module__") + func_module_name: str = getattr(func, "__module__") if not func_module_name: if module: func_module_name = module.__name__ else: func_module_name = "unknown_module" - func_class_name = func.__self__.__class__.__name__ - plugin_key = ".".join([func_module_name, func_class_name]) + func_class_name: str = func.__self__.__class__.__name__ + plugin_key: str = ".".join([func_module_name, func_class_name]) log.debug("Running %s in %s" % (func.__name__, plugin_key)) # resolve slot_name to conduit # FIXME: handle cases where we don't have a conduit for a slot_name @@ -808,7 +823,7 @@ def runiter(self, slot_name, **kwargs): runner = PluginHookRunner(conduit_instance, func) yield runner - def _get_plugin_config(self, plugin_clazz, plugin_to_config_map=None): + def _get_plugin_config(self, plugin_clazz: Type[SubManPlugin], plugin_to_config_map=None) -> PluginConfig: """Get a PluginConfig for plugin_class, creating it if need be. 
If we have an entry in plugin_to_config_map for plugin_class, @@ -829,16 +844,16 @@ def _get_plugin_config(self, plugin_clazz, plugin_to_config_map=None): return PluginConfig(plugin_clazz.get_plugin_key(), self.plugin_conf_path) - def get_plugins(self): + def get_plugins(self) -> Dict[str, Type[SubManPlugin]]: """list of plugins.""" return self._plugin_classes - def get_slots(self): + def get_slots(self) -> List[str]: """list of slots Ordered by conduit name, for presentation. """ - # I'm sure a clever list comprension could replace this with one line + # I'm sure a clever list comprehension could replace this with one line # # The default sort of slots is pure lexical, so all the pre's come # first, which is weird. So this just sorts the slots by conduit name, @@ -865,7 +880,7 @@ class PluginManager(BasePluginManager): default_search_path = DEFAULT_SEARCH_PATH default_conf_path = DEFAULT_CONF_PATH - def __init__(self, search_path=None, plugin_conf_path=None): + def __init__(self, search_path: Optional[str] = None, plugin_conf_path: Optional[str] = None): """init PluginManager Args: @@ -890,7 +905,7 @@ def __init__(self, search_path=None, plugin_conf_path=None): search_path=init_search_path, plugin_conf_path=init_plugin_conf_path ) - def _get_conduits(self): + def _get_conduits(self) -> List[type(BaseConduit)]: """get subscription-manager specific plugin conduits.""" # we should be able to collect this from the sub classes of BaseConduit return [ @@ -908,12 +923,12 @@ def _get_conduits(self): ] def _get_modules(self): - module_files = self._find_plugin_module_files(self.search_path) - plugin_modules = self._load_plugin_module_files(module_files) + module_files: List[str] = self._find_plugin_module_files(self.search_path) + plugin_modules: List = self._load_plugin_module_files(module_files) return plugin_modules # subman specific module/plugin loading - def _find_plugin_module_files(self, search_path): + def _find_plugin_module_files(self, search_path) -> 
List[Type[SubManPlugin]]: """Load all the plugins in the search path. Raise: @@ -933,8 +948,8 @@ def _find_plugin_module_files(self, search_path): module_files.sort() return module_files - def _load_plugin_module_files(self, module_files): - modules = [] + def _load_plugin_module_files(self, module_files: List[str]) -> List[Type[SubManPlugin]]: + modules: List[Type[SubManPlugin]] = [] for module_file in module_files: try: modules.append(self._load_plugin_module_file(module_file)) @@ -943,7 +958,7 @@ def _load_plugin_module_files(self, module_files): return modules - def _load_plugin_module_file(self, module_file): + def _load_plugin_module_file(self, module_file: str) -> Type[SubManPlugin]: """Loads SubManPlugin class from a module file. Args: @@ -978,13 +993,13 @@ def _load_plugin_module_file(self, module_file): return loaded_module -def parse_version(api_version): +def parse_version(api_version: str) -> Tuple[int, int]: """parse an API version string into major and minor version strings.""" maj_ver, min_ver = api_version.split(".") return int(maj_ver), int(min_ver) -def api_version_ok(a, b): +def api_version_ok(a: str, b: str) -> bool: """ Return true if API version "a" supports API version "b" """ diff --git a/src/subscription_manager/printing_utils.py b/src/subscription_manager/printing_utils.py index 103ff3610a..f5310b93e8 100644 --- a/src/subscription_manager/printing_utils.py +++ b/src/subscription_manager/printing_utils.py @@ -14,6 +14,7 @@ import fnmatch import re import logging +from typing import Callable, List from subscription_manager.unicode_width import textual_width as utf8_width from subscription_manager.utils import get_terminal_width @@ -27,11 +28,11 @@ FONT_NORMAL = "\033[0m" -def ljust_wide(in_str, padding): +def ljust_wide(in_str: str, padding: int): return in_str + " " * (padding - utf8_width(in_str)) -def columnize(caption_list, callback, *args, **kwargs): +def columnize(caption_list: List[str], callback: Callable, *args, **kwargs) -> str: 
""" Take a list of captions and values and columnize the output so that shorter captions are padded to be the same length as the longest caption. @@ -43,21 +44,21 @@ def columnize(caption_list, callback, *args, **kwargs): The callback gives us the ability to do things like replacing None values with the string "None" (see none_wrap_columnize_callback()). """ - indent = kwargs.get("indent", 0) - caption_list = [" " * indent + caption for caption in caption_list] - columns = get_terminal_width() - padding = sorted(map(utf8_width, caption_list))[-1] + 1 + indent: int = kwargs.get("indent", 0) + caption_list: List[str] = [" " * indent + caption for caption in caption_list] + columns: int = get_terminal_width() + padding: int = sorted(map(utf8_width, caption_list))[-1] + 1 if columns: padding = min(padding, int(columns / 2)) - padded_list = [] + padded_list: List[str] = [] for caption in caption_list: - lines = format_name(caption, indent, padding - 1).split("\n") + lines: List[str] = format_name(caption, indent, padding - 1).split("\n") lines[-1] = ljust_wide(lines[-1], padding) + "%s" - fixed_caption = "\n".join(lines) + fixed_caption: str = "\n".join(lines) padded_list.append(fixed_caption) lines = list(zip(padded_list, args)) - output = [] + output: List[str] = [] for (caption, value) in lines: kwargs["caption"] = caption if isinstance(value, dict): @@ -80,7 +81,7 @@ def columnize(caption_list, callback, *args, **kwargs): return "\n".join(output) -def format_name(name, indent, max_length): +def format_name(name: str, indent: int, max_length: int) -> str: """ Formats a potentially long name for multi-line display, giving it a columned effect. 
Assumes the first line is already @@ -89,6 +90,7 @@ def format_name(name, indent, max_length): if not name or not max_length or (max_length - indent) <= 2 or not isinstance(name, str): return name if not isinstance(name, str): + # FIXME This is not necessary in Python 3 code name = name.decode("utf-8") words = name.split() lines = [] @@ -134,11 +136,11 @@ def add_line(): return "\n".join(lines) -def highlight_by_filter_string_columnize_cb(template_str, *args, **kwargs): +def highlight_by_filter_string_columnize_cb(template_str: str, *args, **kwargs) -> str: """ Takes a template string and arguments and highlights word matches - when the value contains a match to the filter_string.This occurs - only when the row caption exists in the match columns. Mainly this + when the value contains a match to the filter_string. This occurs + only when the row caption exists in the match columns. Mainly this is a callback meant to be used by columnize(). """ filter_string = kwargs.get("filter_string") @@ -170,7 +172,7 @@ def highlight_by_filter_string_columnize_cb(template_str, *args, **kwargs): return template_str % tuple(arglist) -def none_wrap_columnize_callback(template_str, *args, **kwargs): +def none_wrap_columnize_callback(template_str: str, *args, **kwargs) -> str: """ Takes a template string and arguments and replaces any None arguments with the word "None" before rendering the template. Mainly this is @@ -184,7 +186,7 @@ def none_wrap_columnize_callback(template_str, *args, **kwargs): return template_str % tuple(arglist) -def echo_columnize_callback(template_str, *args, **kwargs): +def echo_columnize_callback(template_str: str, *args, **kwargs) -> str: """ Just takes a template string and arguments and renders it. Mainly this is a callback meant to be used by columnize(). 
diff --git a/src/subscription_manager/productid.py b/src/subscription_manager/productid.py index 03e7f61ed7..1da5da3ef1 100644 --- a/src/subscription_manager/productid.py +++ b/src/subscription_manager/productid.py @@ -17,6 +17,8 @@ import logging import os +from typing import Callable, Dict, List, Literal, Optional, Union + # for labelCompare import rpm @@ -46,6 +48,7 @@ def __init__(self): class ProductIdRepoMap(utils.DefaultDict): def __init__(self, *args, **kwargs): self.default_factory = list + # FIXME Missing super() call class ProductDatabase(object): @@ -54,24 +57,24 @@ def __init__(self): self.content = ProductIdRepoMap() self.create() - def add(self, product, repo): + def add(self, product: str, repo: str) -> None: self.content[product].append(repo) # TODO: need way to delete one prod->repo map - def delete(self, product): + def delete(self, product: str) -> None: try: del self.content[product] except Exception: pass - def find_repos(self, product): + def find_repos(self, product) -> Optional[str]: return self.content.get(product, None) - def create(self): + def create(self) -> None: if not os.path.exists(self.__fn()): self.write() - def read(self): + def read(self) -> None: f = open(self.__fn()) try: d = json.load(f) @@ -81,7 +84,7 @@ def read(self): pass f.close() - def populate_content(self, db_dict): + def populate_content(self, db_dict: Dict[str, str]): """Populate map with info from a productid -> [repoids] map. 
Note this needs to support the old form of @@ -93,7 +96,7 @@ def populate_content(self, db_dict): else: self.content[productid] = repo_data - def write(self): + def write(self) -> None: f = open(self.__fn(), "w") try: json.dump(self.content, f, indent=2, default=json.encode) @@ -101,14 +104,16 @@ def write(self): pass f.close() - def __fn(self): + def __fn(self) -> str: return self.dir.abspath("productid.js") class ComparableMixin(object): """Needs compare_keys to be implemented.""" - def _compare(self, keys, method): + # FIXME self.compare_keys should be defined here, raising NotImplementedError + + def _compare(self, keys: List, method: Callable) -> Union[Literal[NotImplemented], bool]: return method(keys[0], keys[1]) if keys else NotImplemented def __eq__(self, other): diff --git a/src/subscription_manager/reasons.py b/src/subscription_manager/reasons.py index 1ce008eec3..8080eb02c8 100644 --- a/src/subscription_manager/reasons.py +++ b/src/subscription_manager/reasons.py @@ -11,6 +11,7 @@ # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. # +from typing import Dict, List from subscription_manager.i18n import ugettext as _ @@ -25,19 +26,19 @@ def __init__(self, reasons, sorter): self.reasons = reasons self.sorter = sorter - def get_subscription_reasons(self, sub_id): + def get_subscription_reasons(self, sub_id) -> List[str]: """ returns reasons for sub_id, or empty list if there are none. """ return self.get_subscription_reasons_map().get(sub_id, []) - def get_subscription_reasons_map(self): + def get_subscription_reasons_map(self) -> Dict[str, List[str]]: """ returns a dictionary that maps valid entitlements to lists of reasons. 
""" - result = {} + result: Dict[str, List[str]] = {} for s in self.sorter.valid_entitlement_certs: result[s.subject["CN"]] = [] @@ -59,7 +60,7 @@ def get_subscription_reasons_map(self): result[s_id].append(reason["message"]) return result - def get_name_message_map(self): + def get_name_message_map(self) -> Dict[str, List[str]]: result = {} for reason in self.reasons: reason_name = reason["attributes"]["name"] @@ -70,7 +71,7 @@ def get_name_message_map(self): result[reason_name].append(reason["message"]) return result - def get_reason_ids_map(self): + def get_reason_ids_map(self) -> Dict[str, List[str]]: result = {} for reason in self.reasons: if "attributes" in reason and "product_id" in reason["attributes"]: @@ -87,14 +88,14 @@ def get_reason_ids_map(self): ) return result - def get_stack_subscriptions(self, stack_id): + def get_stack_subscriptions(self, stack_id) -> List[str]: result = set([]) for s in self.sorter.valid_entitlement_certs: if s.order.stacking_id and s.order.stacking_id == stack_id: result.add(s.subject["CN"]) return list(result) - def get_reason_id(self, reason): + def get_reason_id(self, reason) -> str: # returns ent/prod/stack id # ex: Subscription 123456 if "product_id" in reason["attributes"]: @@ -108,7 +109,7 @@ def get_reason_id(self, reason): # Reason has no id attr return _("Unknown") - def get_product_reasons(self, prod): + def get_product_reasons(self, prod) -> List[str]: """ Returns a list of reason messages that apply to the installed product @@ -142,7 +143,7 @@ def get_product_reasons(self, prod): result.add(reason["message"]) return list(result) - def get_product_subscriptions(self, prod): + def get_product_subscriptions(self, prod) -> List[str]: """ Returns a list of subscriptions that provide the product. 
diff --git a/src/subscription_manager/release.py b/src/subscription_manager/release.py index 20a196864f..3b90da59bf 100644 --- a/src/subscription_manager/release.py +++ b/src/subscription_manager/release.py @@ -19,23 +19,33 @@ import socket import http.client + +from typing import Dict, Iterable, List, Optional, Set, Tuple, TYPE_CHECKING + from rhsm.https import ssl -from rhsm.connection import NoValidEntitlement import rhsm.config +from rhsm.connection import NoValidEntitlement from subscription_manager import injection as inj from subscription_manager import listing from subscription_manager import rhelproduct from subscription_manager.i18n import ugettext as _ +if TYPE_CHECKING: + from rhsm.connection import ContentConnection + from rhsm.certificate2 import Certificate, Content, EntitlementCertificate, Product, ProductCertificate + from subscription_manager.cp_provider import CPProvider + from subscription_manager.certdirectory import EntitlementDirectory, ProductDirectory + + log = logging.getLogger(__name__) cfg = rhsm.config.get_config_parser() class MultipleReleaseProductsError(ValueError): - def __init__(self, certificates): + def __init__(self, certificates: Iterable["Certificate"]): self.certificates = certificates self.certificate_paths = ", ".join([certificate.path for certificate in certificates]) super(ValueError, self).__init__( @@ -45,7 +55,7 @@ def __init__(self, certificates): ) ) - def translated_message(self): + def translated_message(self) -> str: return _( "Error: More than one release product certificate installed. 
Certificate paths: %s" ) % ", ".join([certificate.path for certificate in self.certificates]) @@ -58,7 +68,7 @@ def __init__(self): class ReleaseBackend(object): def get_releases(self): - provider = self._get_release_version_provider() + provider: CdnReleaseVersionProvider = self._get_release_version_provider() return provider.get_releases() def _get_release_version_provider(self): @@ -85,21 +95,21 @@ def _conn(self): class CdnReleaseVersionProvider(object): def __init__(self): - self.entitlement_dir = inj.require(inj.ENT_DIR) - self.product_dir = inj.require(inj.PROD_DIR) - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.content_connection = self.cp_provider.get_content_connection() + self.entitlement_dir: EntitlementDirectory = inj.require(inj.ENT_DIR) + self.product_dir: ProductDirectory = inj.require(inj.PROD_DIR) + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.content_connection: ContentConnection = self.cp_provider.get_content_connection() - def get_releases(self): + def get_releases(self) -> List[str]: # cdn base url # Find the rhel products - release_products = [] - certificates = set() - installed_products = self.product_dir.get_installed_products() + release_products: List[Product] = [] + certificates: Set[ProductCertificate] = set() + installed_products: Dict[str, EntitlementCertificate] = self.product_dir.get_installed_products() for product_hash in installed_products: - product_cert = installed_products[product_hash] - products = product_cert.products + product_cert: ProductCertificate = installed_products[product_hash] + products: List[Product] = product_cert.products for product in products: rhel_matcher = rhelproduct.RHELProductMatcher(product) if rhel_matcher.is_rhel(): @@ -113,18 +123,18 @@ def get_releases(self): raise MultipleReleaseProductsError(certificates=certificates) # Note: only release_products with one item can pass previous if-elif - release_product = release_products[0] - entitlements = 
self.entitlement_dir.list_for_product(release_product.id) + release_product: Product = release_products[0] + entitlements: List[EntitlementCertificate] = self.entitlement_dir.list_for_product(release_product.id) # When there is SCA entitlement certificate, then add this cert to # the list too. See: https://bugzilla.redhat.com/show_bug.cgi?id=1924921 - sca_entitlements = self.entitlement_dir.list_with_sca_mode() + sca_entitlements: List[EntitlementCertificate] = self.entitlement_dir.list_with_sca_mode() entitlements.extend(sca_entitlements) - listings = [] - ent_cert_key_pairs = set() + listings: List[str] = [] + ent_cert_key_pairs: Set[Tuple[str, str]] = set() for entitlement in entitlements: - contents = entitlement.content + contents: List[Content] = entitlement.content for content in contents: # ignore content that is not enabled # see bz #820639 @@ -143,11 +153,11 @@ def get_releases(self): # hmm. We are really only supposed to have one product # with one content with one listing file. We shall see. - releases = [] + releases: List[str] = [] listings = sorted(set(listings)) for listing_path in listings: try: - data = self.content_connection.get_versions( + data: Optional[dict] = self.content_connection.get_versions( path=listing_path, cert_key_pairs=ent_cert_key_pairs ) except (socket.error, http.client.HTTPException, ssl.SSLError, NoValidEntitlement) as e: @@ -170,7 +180,7 @@ def get_releases(self): releases_set = sorted(set(releases)) return releases_set - def _build_listing_path(self, content_url): + def _build_listing_path(self, content_url: str) -> str: listing_parts = content_url.split("$releasever", 1) listing_base = listing_parts[0] # FIXME: cleanup paths ("//"'s, etc) @@ -179,7 +189,7 @@ def _build_listing_path(self, content_url): # require tags provided by installed products? 
- def _is_correct_rhel(self, product_tags, content_tags): + def _is_correct_rhel(self, product_tags: List[str], content_tags: List[str]) -> bool: # easy to pass a string instead of a list assert not isinstance(product_tags, str) assert not isinstance(content_tags, str) diff --git a/src/subscription_manager/repofile.py b/src/subscription_manager/repofile.py index 650d963fea..82ae663b00 100644 --- a/src/subscription_manager/repofile.py +++ b/src/subscription_manager/repofile.py @@ -15,6 +15,7 @@ # granted to use or replicate Red Hat trademarks that are incorporated # in this software or its documentation. # +from typing import Dict, List, Literal, Optional, TextIO, Tuple, TYPE_CHECKING from iniparse import RawConfigParser as ConfigParser import logging @@ -39,6 +40,10 @@ from rhsmlib.services import config +if TYPE_CHECKING: + from subscription_manager.model import Content + from subscription_manager.repolib import YumReleaseverSource + log = logging.getLogger(__name__) conf = config.Config(get_config_parser()) @@ -51,7 +56,7 @@ class Repo(dict): # (name, mutable, default) - The mutability information is only used in disconnected cases - PROPERTIES = { + PROPERTIES: Dict[str, Tuple[int, Optional[Literal["0", "1"]]]] = { "name": (0, None), "baseurl": (0, None), "enabled": (1, "1"), @@ -70,17 +75,18 @@ class Repo(dict): "ui_repoid_vars": (0, None), } - def __init__(self, repo_id, existing_values=None): + def __init__(self, repo_id: str, existing_values: List = None): + # FIXME Missing super() call if HAS_DEB822 is True: self.PROPERTIES["arches"] = (1, None) # existing_values is a list of 2-tuples existing_values = existing_values or [] - self.id = self._clean_id(repo_id) + self.id: str = self._clean_id(repo_id) # used to store key order, so we can write things out in the order # we read them from the config. 
- self._order = [] + self._order: List[str] = [] self.content_type = None @@ -104,13 +110,15 @@ def copy(self): return new_repo @classmethod - def from_ent_cert_content(cls, content, baseurl, ca_cert, release_source): + def from_ent_cert_content( + cls, content: "Content", baseurl: str, ca_cert: str, release_source: "YumReleaseverSource" + ) -> "Repo": """Create an instance of Repo() from an ent_cert.EntitlementCertContent(). And the other out of band info we need including baseurl, ca_cert, and the release version string. """ - repo = cls(content.label) + repo: Repo = cls(content.label) repo.content_type = content.content_type @@ -161,7 +169,7 @@ def from_ent_cert_content(cls, content, baseurl, ca_cert, release_source): return repo @staticmethod - def _set_proxy_info(repo): + def _set_proxy_info(repo: "Repo") -> "Repo": proxy = "" proxy_scheme = conf["server"]["proxy_scheme"] @@ -195,7 +203,7 @@ def _set_proxy_info(repo): return repo @staticmethod - def _expand_releasever(release_source, contenturl): + def _expand_releasever(release_source: "YumReleaseverSource", contenturl: str) -> str: # no $releasever to expand if release_source.marker not in contenturl: return contenturl @@ -210,7 +218,7 @@ def _expand_releasever(release_source, contenturl): # trusted. return contenturl.replace(release_source.marker, expansion) - def _clean_id(self, repo_id): + def _clean_id(self, repo_id: str) -> str: """ Format the config file id to contain only characters that yum expects (we'll just replace 'bad' chars with -) @@ -225,7 +233,7 @@ def _clean_id(self, repo_id): return new_id - def items(self): + def items(self) -> Tuple[Tuple[str, Tuple[int, Optional[Literal["0", "1"]]]], ...]: """ Called when we fetch the items for this yum repo to write to disk. """ @@ -238,12 +246,12 @@ def items(self): # to get into our dict. 
return tuple([(k, self[k]) for k in self._order if k in self and self[k]]) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Optional[Literal["0", "1"]]): if key not in self._order: self._order.append(key) dict.__setitem__(self, key, value) - def __str__(self): + def __str__(self) -> str: s = [] s.append("[%s]" % self.id) for k in self.PROPERTIES: @@ -254,15 +262,14 @@ def __str__(self): return "\n".join(s) - def __eq__(self, other): + def __eq__(self, other: "Repo") -> bool: return self.id == other.id - def __hash__(self): + def __hash__(self) -> int: return hash(self.id) -def manage_repos_enabled(): - +def manage_repos_enabled() -> bool: try: manage_repos = conf["rhsm"].get_int("manage_repos") except ValueError as e: @@ -289,12 +296,12 @@ class TidyWriter(object): versions. """ - def __init__(self, backing_file): + def __init__(self, backing_file: TextIO): self.backing_file = backing_file - self.ends_with_newline = False - self.writing_empty_lines = False + self.ends_with_newline: bool = False + self.writing_empty_lines: bool = False - def write(self, line): + def write(self, line: str) -> None: lines = line.split("\n") i = 0 while i < len(lines): @@ -317,7 +324,7 @@ def write(self, line): else: self.ends_with_newline = False - def close(self): + def close(self) -> None: if not self.ends_with_newline: self.backing_file.write("\n") @@ -327,34 +334,33 @@ class RepoFileBase(object): Base class for managing repository. 
     """

-    PATH = None
-    NAME = None
-    REPOFILE_HEADER = None
+    PATH: Optional[str] = None
+    NAME: Optional[str] = None
+    REPOFILE_HEADER: Optional[str] = None

-    def __init__(self, path=None, name=None):
-        # note PATH get's expanded with chroot info, etc
+    def __init__(self, path: Optional[str] = None, name: Optional[str] = None):
+        # PATH gets expanded with chroot info, etc
         path = path or self.PATH
         name = name or self.NAME
-        self.path = Path.join(path, name)
-        self.repos_dir = Path.abs(path)
-        self.manage_repos = manage_repos_enabled()
+        self.path: str = Path.join(path, name)
+        self.repos_dir: str = Path.abs(path)
+        self.manage_repos: bool = manage_repos_enabled()
         if self.manage_repos is True:
             self.create()

     # Easier than trying to mock/patch os.path.exists
-    def path_exists(self, path):
+    def path_exists(self, path: str) -> bool:
         """
         Wrapper around os.path.exists
         """
         return os.path.exists(path)

-    def exists(self):
+    def exists(self) -> bool:
         return self.path_exists(self.path)

-    def create_dir_path(self):
+    def create_dir_path(self) -> None:
         """
         Try to create directory for .repo files
-        :return: None
         """
         if not self.path_exists(self.repos_dir):
             log.debug("The directory %s does not exist. Trying to create it" % self.PATH)
@@ -365,10 +371,9 @@ def create_dir_path(self):
         else:
             log.debug("The directory %s already exists" % self.repos_dir)

-    def create(self):
+    def create(self) -> None:
         """
         Try to create new repo file.
- :return: None """ self.create_dir_path() if self.path_exists(self.path) or not self.manage_repos: @@ -376,15 +381,15 @@ def create(self): with open(self.path, "w") as f: f.write(self.REPOFILE_HEADER) - def fix_content(self, content): + def fix_content(self, content: str) -> str: return content @classmethod - def installed(cls): + def installed(cls) -> bool: return os.path.exists(Path.abs(cls.PATH)) @classmethod - def server_value_repo_file(cls): + def server_value_repo_file(cls) -> "RepoFileBase": return cls("var/lib/rhsm/repo_server_val/") @@ -392,10 +397,10 @@ def server_value_repo_file(cls): class AptRepoFile(RepoFileBase): - PATH = "etc/apt/sources.list.d" - NAME = "rhsm.sources" - CONTENT_TYPES = ["deb"] - REPOFILE_HEADER = """# + PATH: str = "etc/apt/sources.list.d" + NAME: str = "rhsm.sources" + CONTENT_TYPES: List[str] = ["deb"] + REPOFILE_HEADER: str = """# # Certificate-Based Repositories # Managed by (rhsm) subscription-manager # @@ -409,7 +414,7 @@ class AptRepoFile(RepoFileBase): # """ - def __init__(self, path=None, name=None): + def __init__(self, path: Optional[str] = None, name: Optional[str] = None): super(AptRepoFile, self).__init__(path, name) self.repos822 = [] @@ -500,14 +505,14 @@ class YumRepoFile(RepoFileBase, ConfigParser): # """ - def __init__(self, path=None, name=None): + def __init__(self, path: Optional[str] = None, name: Optional[str] = None): ConfigParser.__init__(self) RepoFileBase.__init__(self, path, name) - def read(self): + def read(self) -> None: ConfigParser.read(self, self.path) - def _configparsers_equal(self, otherparser): + def _configparsers_equal(self, otherparser) -> bool: if set(otherparser.sections()) != set(self.sections()): return False @@ -518,7 +523,7 @@ def _configparsers_equal(self, otherparser): return False return True - def _has_changed(self): + def _has_changed(self) -> bool: """ Check if the version on disk is different from what we have loaded """ @@ -526,7 +531,7 @@ def _has_changed(self): 
on_disk.read(self.path) return not self._configparsers_equal(on_disk) - def write(self): + def write(self) -> None: if not self.manage_repos: log.debug("Skipping write due to manage_repos setting: %s" % self.path) return @@ -536,14 +541,14 @@ def write(self): ConfigParser.write(self, tidy_writer) tidy_writer.close() - def add(self, repo): + def add(self, repo: "Repo") -> None: self.add_section(repo.id) self.update(repo) - def delete(self, section): + def delete(self, section) -> bool: return self.remove_section(section) - def update(self, repo): + def update(self, repo: "Repo") -> None: # Need to clear out the old section to allow unsetting options: # don't use remove section though, as that will reorder sections, # and move whitespace around (resulting in more and more whitespace @@ -554,7 +559,7 @@ def update(self, repo): for k, v in list(repo.items()): ConfigParser.set(self, repo.id, k, v) - def section(self, section): + def section(self, section: str) -> "Repo": if self.has_section(section): return Repo(section, self.items(section)) @@ -579,16 +584,16 @@ class ZypperRepoFile(YumRepoFile): # """ - def __init__(self, path=None, name=None): + def __init__(self, path: Optional[str] = None, name: Optional[str] = None): super(ZypperRepoFile, self).__init__(path, name) - self.gpgcheck = False - self.repo_gpgcheck = False - self.autorefresh = False + self.gpgcheck: bool = False + self.repo_gpgcheck: bool = False + self.autorefresh: bool = False # According to # https://github.com/openSUSE/libzypp/blob/67f55b474d67f77c1868955da8542a7acfa70a9f/zypp/media/MediaManager.h#L394 # the following values are valid: "yes", "no", "host", "peer" - self.gpgkey_ssl_verify = None - self.repo_ssl_verify = None + self.gpgkey_ssl_verify: Optional[str] = None + self.repo_ssl_verify: Optional[str] = None def read_zypp_conf(self): """ @@ -608,7 +613,7 @@ def read_zypp_conf(self): if zypp_cfg.has_option("rhsm-plugin", "repo-ssl-verify"): self.repo_ssl_verify = zypp_cfg.get("rhsm-plugin", 
"repo-ssl-verify")

-    def fix_content(self, content):
+    def fix_content(self, content: "Content") -> str:
         self.read_zypp_conf()
         zypper_cont = content.copy()
         sslverify = zypper_cont["sslverify"]
@@ -654,7 +659,7 @@ def fix_content(self, content):

         baseurl = zypper_cont["baseurl"]
         parsed = urlparse(baseurl)
-        zypper_query_args = parse_qs(parsed.query)
+        zypper_query_args: Dict[str, List[str]] = parse_qs(parsed.query)

         if sslverify and sslverify in ["1"]:
             if self.repo_ssl_verify:
@@ -685,21 +690,21 @@ def fix_content(self, content):

     # We need to overwrite this, to avoid name clashes with yum's server_val_repo_file
     @classmethod
-    def server_value_repo_file(cls):
+    def server_value_repo_file(cls) -> "ZypperRepoFile":
         return cls("var/lib/rhsm/repo_server_val/", "zypper_{}".format(cls.NAME))


-def init_repo_file_classes():
-    repo_file_classes = [YumRepoFile, ZypperRepoFile]
+def init_repo_file_classes() -> List[Tuple[type(RepoFileBase), type(RepoFileBase)]]:
+    repo_file_classes: List[type(RepoFileBase)] = [YumRepoFile, ZypperRepoFile]
     if HAS_DEB822:
         repo_file_classes.append(AptRepoFile)
-    _repo_files = [
+    _repo_files: List[Tuple[type(RepoFileBase), type(RepoFileBase)]] = [
         (RepoFile, RepoFile.server_value_repo_file)
         for RepoFile in repo_file_classes
         if RepoFile.installed()
     ]
     return _repo_files


-def get_repo_file_classes():
+def get_repo_file_classes() -> List[Tuple[type(RepoFileBase), type(RepoFileBase)]]:
     global repo_files
     if not repo_files:
         repo_files = init_repo_file_classes()
diff --git a/src/subscription_manager/repolib.py b/src/subscription_manager/repolib.py
index ee8c8adb9b..91c7f70537 100644
--- a/src/subscription_manager/repolib.py
+++ b/src/subscription_manager/repolib.py
@@ -12,7 +12,7 @@
 # Red Hat trademarks are not licensed under GPLv2. No permission is
 # granted to use or replicate Red Hat trademarks that are incorporated
 # in this software or its documentation.
-# +from typing import Dict, Iterable, List, Literal, Optional, Set, Tuple, Union, TYPE_CHECKING from iniparse import RawConfigParser as ConfigParser import logging @@ -37,6 +37,15 @@ from subscription_manager.i18n import ugettext as _ +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + + from subscription_manager.certdirectory import EntitlementDirectory, ProductDirectory + from subscription_manager.cp_provider import CPProvider + from subscription_manager.certlib import Locker + from subscription_manager.identity import Identity + from subscription_manager.model import Content + log = logging.getLogger(__name__) conf = config.Config(rhsm.config.get_config_parser()) @@ -60,7 +69,7 @@ class YumPluginManager(object): PLUGIN_DISABLED = 0 @staticmethod - def is_auto_enable_enabled(): + def is_auto_enable_enabled() -> bool: """ Automatic enabling of yum plugins can be explicitly disabled in /etc/rhsm/rhsm.conf Try to get this configuration. @@ -80,7 +89,7 @@ def is_auto_enable_enabled(): return bool(auto_enable_yum_plugins) @staticmethod - def warning_message(enabled_yum_plugins): + def warning_message(enabled_yum_plugins: List[str]) -> str: message = _( "The yum/dnf plugins: %s were automatically enabled for the benefit of " "Red Hat Subscription Management. If not desired, use " @@ -90,7 +99,7 @@ def warning_message(enabled_yum_plugins): return message @classmethod - def _enable_plugins(cls, pkg_mgr_name, plugin_dir): + def _enable_plugins(cls, pkg_mgr_name: str, plugin_dir: str) -> List[str]: """ This class method tries to enable plugins for DNF or YUM :param pkg_mgr_name: It can be "dnf" or "yum" @@ -159,7 +168,7 @@ def _enable_plugins(cls, pkg_mgr_name, plugin_dir): return enabled_lugins @classmethod - def enable_pkg_plugins(cls): + def enable_pkg_plugins(cls) -> List[str]: """ This function tries to enable dnf/yum plugins: subscription-manager and product-id. It takes no action, when automatic enabling of yum plugins is disabled in rhsm.conf. 
@@ -188,21 +197,21 @@ def enable_pkg_plugins(cls): class RepoActionInvoker(BaseActionInvoker): """Invoker for yum/dnf repo updating related actions.""" - def __init__(self, cache_only=False, locker=None): + def __init__(self, cache_only: bool = False, locker: Optional["Locker"] = None): super(RepoActionInvoker, self).__init__(locker=locker) - self.cache_only = cache_only - self.identity = inj.require(inj.IDENTITY) + self.cache_only: bool = cache_only + self.identity: Identity = inj.require(inj.IDENTITY) - def _do_update(self): + def _do_update(self) -> Union[int, "RepoActionReport"]: action = RepoUpdateActionCommand(cache_only=self.cache_only) res = action.perform() return res - def is_managed(self, repo): + def is_managed(self, repo: Repo) -> bool: action = RepoUpdateActionCommand(cache_only=self.cache_only) return repo in [c.label for c in action.matching_content()] - def get_repos(self, apply_overrides=True): + def get_repos(self, apply_overrides: bool = True) -> Set[Repo]: action = RepoUpdateActionCommand(cache_only=self.cache_only, apply_overrides=apply_overrides) repos = action.get_unique_content() @@ -228,12 +237,12 @@ def get_repos(self, apply_overrides=True): return current - def get_repo_file(self): + def get_repo_file(self) -> str: yum_repo_file = YumRepoFile() return yum_repo_file.path @classmethod - def delete_repo_file(cls): + def delete_repo_file(cls) -> None: for repo_class, server_val_repo_class in get_repo_file_classes(): repo_file = repo_class() server_val_repo_file = server_val_repo_class() @@ -266,19 +275,19 @@ def __init__(self): self.release_status_cache = inj.require(inj.RELEASE_STATUS_CACHE) self._expansion = None - self.identity = inj.require(inj.IDENTITY) - self.cp_provider = inj.require(inj.CP_PROVIDER) + self.identity: Identity = inj.require(inj.IDENTITY) + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) # FIXME: these guys are really more of model helpers for the object # represent a release. 
@staticmethod - def is_not_empty(expansion): + def is_not_empty(expansion: Optional[dict]) -> bool: if expansion is None or len(expansion) == 0: return False return True @staticmethod - def is_set(result): + def is_set(result: dict) -> bool: """Check result for existing, and having a non empty value. Return True if result has a non empty, non null result['releaseVer'] @@ -293,7 +302,7 @@ def is_set(result): except Exception: return False - def get_expansion(self): + def get_expansion(self) -> str: # mem cache if self._expansion: return self._expansion @@ -340,25 +349,25 @@ class RepoUpdateActionCommand(object): Returns an RepoActionReport. """ - def __init__(self, cache_only=False, apply_overrides=True): - self.identity = inj.require(inj.IDENTITY) + def __init__(self, cache_only: bool = False, apply_overrides: bool = True): + self.identity: Identity = inj.require(inj.IDENTITY) # These should probably move closer their use - self.ent_dir = inj.require(inj.ENT_DIR) - self.prod_dir = inj.require(inj.PROD_DIR) + self.ent_dir: EntitlementDirectory = inj.require(inj.ENT_DIR) + self.prod_dir: ProductDirectory = inj.require(inj.PROD_DIR) self.ent_source = ent_cert.EntitlementDirEntitlementSource() - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = None + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: Optional[UEPConnection] = None - self.manage_repos = 1 - self.apply_overrides = apply_overrides + self.manage_repos: Union[bool, Literal[0, 1]] = 1 + self.apply_overrides: bool = apply_overrides self.manage_repos = manage_repos_enabled() self.release = None self.overrides = {} - self.override_supported = False + self.override_supported: bool = False try: self.override_supported = "content_overrides" in get_supported_resources( @@ -389,7 +398,7 @@ def __init__(self, cache_only=False, apply_overrides=True): # by the server. 
if self.override_supported: try: - override_cache = inj.require(inj.OVERRIDE_STATUS_CACHE) + override_cache: OverrideStatusCache = inj.require(inj.OVERRIDE_STATUS_CACHE) except KeyError: override_cache = OverrideStatusCache() @@ -404,12 +413,13 @@ def __init__(self, cache_only=False, apply_overrides=True): self.overrides[item["contentLabel"]] = {} self.overrides[item["contentLabel"]][item["name"]] = item["value"] - def get_consumer_auth_cp(self): + def get_consumer_auth_cp(self) -> "UEPConnection": if self.uep is None: self.uep = self.cp_provider.get_consumer_auth_cp() return self.uep - def perform(self): + def perform(self) -> Union["RepoActionReport", Literal[0]]: + # FIXME return None instead of 0 # the [rhsm] manage_repos can be overridden to disable generation of the # redhat.repo file: if not self.manage_repos: @@ -478,7 +488,7 @@ def perform(self): log.debug("repos updated: %s" % self.report) return self.report - def get_unique_content(self): + def get_unique_content(self) -> Iterable[Repo]: # FIXME Shouldn't this skip all of the repo updating? if not self.manage_repos: return [] @@ -496,13 +506,13 @@ def get_unique_content(self): # Expose as public API for RepoActionInvoker.is_managed, since that # is used by Openshift tooling. 
# See https://bugzilla.redhat.com/show_bug.cgi?id=1223038 - def matching_content(self): + def matching_content(self) -> List["Content"]: content = [] for content_type in ALLOWED_CONTENT_TYPES: content += model.find_content(self.ent_source, content_type=content_type) return content - def get_all_content(self, baseurl, ca_cert): + def get_all_content(self, baseurl: str, ca_cert: str) -> List[Repo]: matching_content = self.matching_content() content_list = [] @@ -541,7 +551,7 @@ def get_all_content(self, baseurl, ca_cert): return content_list - def _set_override_info(self, repo): + def _set_override_info(self, repo: Repo) -> Repo: # In the disconnected case, self.overrides will be an empty list for name, value in list(self.overrides.get(repo.id, {}).items()): @@ -549,22 +559,22 @@ def _set_override_info(self, repo): return repo - def _is_overridden(self, repo, key): + def _is_overridden(self, repo: Repo, key: str) -> bool: return key in self.overrides.get(repo.id, {}) - def _was_overridden(self, repo, key, value): + def _was_overridden(self, repo: Repo, key: str, value: object) -> bool: written_value = self.written_overrides.overrides.get(repo.id, {}).get(key) # Compare values as strings to avoid casting problems from io return written_value is not None and value is not None and str(written_value) == str(value) - def _build_props(self, old_repo, new_repo): + def _build_props(self, old_repo: Repo, new_repo: Repo) -> Dict[str, Tuple[int, str, None]]: result = {} all_keys = list(old_repo.keys()) + list(new_repo.keys()) for key in all_keys: result[key] = Repo.PROPERTIES.get(key, (1, None)) return result - def update_repo(self, old_repo, new_repo, server_value_repo=None): + def update_repo(self, old_repo: Repo, new_repo: Repo, server_value_repo: Optional[dict] = None) -> int: """ Checks an existing repo definition against a potentially updated version created from most recent entitlement certificates and @@ -618,10 +628,10 @@ def update_repo(self, old_repo, new_repo, 
server_value_repo=None): return changes_made - def report_update(self, repo): + def report_update(self, repo: Repo) -> None: self.report.repo_updates.append(repo) - def report_add(self, repo): + def report_add(self, repo: Repo) -> None: self.report.repo_added.append(repo) def report_delete(self, section): @@ -639,11 +649,11 @@ def __init__(self): self.repo_added = [] self.repo_deleted = [] - def updates(self): + def updates(self) -> int: """How many repos were updated""" return len(self.repo_updates) + len(self.repo_added) + len(self.repo_deleted) - def format_repos_info(self, repos, formatter): + def format_repos_info(self, repos, formatter) -> str: indent = " " if not repos: return "%s" % indent @@ -653,20 +663,20 @@ def format_repos_info(self, repos, formatter): r.append("%s%s" % (indent, formatter(repo))) return "\n".join(r) - def repo_format(self, repo): + def repo_format(self, repo: Repo) -> bytes: msg = "[id:%s %s]" % (repo.id, repo["name"]) return msg.encode("utf8") - def section_format(self, section): + def section_format(self, section: dict) -> str: return "[%s]" % section - def format_repos(self, repos): + def format_repos(self, repos: List[Repo]) -> str: return self.format_repos_info(repos, self.repo_format) - def format_sections(self, sections): + def format_sections(self, sections: List[dict]) -> str: return self.format_repos_info(sections, self.section_format) - def __str__(self): + def __str__(self) -> str: s = [_("Repo updates") + "\n"] s.append(_("Total repo updates: %d") % self.updates()) s.append(_("Updated")) diff --git a/src/subscription_manager/rhelentbranding.py b/src/subscription_manager/rhelentbranding.py index 0eaa17345a..2867ec71b8 100644 --- a/src/subscription_manager/rhelentbranding.py +++ b/src/subscription_manager/rhelentbranding.py @@ -16,10 +16,16 @@ # on subscription, RHEL specific implementation import logging +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from subscription_manager import injection as inj from 
subscription_manager import entbranding +if TYPE_CHECKING: + from rhsm.certificate2 import EntitlementCertificate, Product + + from subscription_manager.certdirectory import EntitlementDirectory, ProductDirectory + log = logging.getLogger(__name__) @@ -28,30 +34,30 @@ class RHELBrandsInstaller(entbranding.BrandsInstaller): Currently just the RHELBrandInstaller.""" - def _get_brand_installers(self): + def _get_brand_installers(self) -> List[entbranding.BrandInstaller]: return [RHELBrandInstaller(self.ent_certs)] class RHELBrandInstaller(entbranding.BrandInstaller): - def _get_brand_picker(self): + def _get_brand_picker(self) -> entbranding.BrandPicker: return RHELBrandPicker(self.ent_certs) - def _get_current_brand(self): + def _get_current_brand(self) -> entbranding.Brand: return RHELCurrentBrand() - def _install(self, brand): + def _install(self, brand: entbranding.Brand) -> None: log.debug("Updating product branding info for: %s" % brand.name) brand.save() class RHELBrandPicker(entbranding.BrandPicker): - def __init__(self, ent_certs=None): + def __init__(self, ent_certs: List["EntitlementCertificate"] = None): super(RHELBrandPicker, self).__init__(ent_certs=ent_certs) - prod_dir = inj.require(inj.PROD_DIR) - self.installed_products = prod_dir.get_installed_products() + prod_dir: ProductDirectory = inj.require(inj.PROD_DIR) + self.installed_products: Dict[str, EntitlementCertificate] = prod_dir.get_installed_products() - def get_brand(self): + def get_brand(self) -> Optional[entbranding.Brand]: branded_cert_product = self._get_branded_cert_product() if not branded_cert_product: @@ -60,7 +66,7 @@ def get_brand(self): branded_product = branded_cert_product[1] return RHELProductBrand.from_product(branded_product) - def _get_branded_cert_product(self): + def _get_branded_cert_product(self) -> Optional[Tuple["EntitlementCertificate", "Product"]]: """Given a list of ent certs providing product branding, return one. If we can collapse them into one, do it. 
Otherwise, return nothing @@ -105,7 +111,7 @@ def _get_branded_cert_product(self): ) return None - def _get_branded_cert_products(self): + def _get_branded_cert_products(self) -> List[Tuple["EntitlementCertificate", "Product"]]: branded_cert_products = [] for cert in self._get_ent_certs(): products = cert.products or [] @@ -133,16 +139,16 @@ def _get_branded_cert_products(self): return branded_cert_products - def _get_ent_certs(self): + def _get_ent_certs(self) -> List["EntitlementCertificate"]: """Returns contents of injected ENT_DIR, or self.ent_dir if set""" if self.ent_certs: return self.ent_certs - ent_dir = inj.require(inj.ENT_DIR) + ent_dir: EntitlementDirectory = inj.require(inj.ENT_DIR) ent_dir.refresh() return ent_dir.list_valid() - def _get_installed_branded_products(self, products): - branded_products = [] + def _get_installed_branded_products(self, products: List["Product"]) -> List["Product"]: + branded_products: List[Product] = [] for product in products: # could support other types of branded products @@ -156,10 +162,10 @@ def _get_installed_branded_products(self, products): return branded_products - def _is_installed_rhel_branded_product(self, product): + def _is_installed_rhel_branded_product(self, product: "Product") -> bool: return product.id in self.installed_products - def _is_rhel_branded_product(self, product): + def _is_rhel_branded_product(self, product: "Product") -> bool: if not hasattr(product, "brand_type"): return False elif product.brand_type != "OS": @@ -175,17 +181,17 @@ def _is_rhel_branded_product(self, product): class RHELProductBrand(entbranding.ProductBrand): - def _get_brand_file(self): + def _get_brand_file(self) -> entbranding.BrandFile: return RHELBrandFile() class RHELCurrentBrand(entbranding.CurrentBrand): - def _get_brand_file(self): + def _get_brand_file(self) -> entbranding.BrandFile: return RHELBrandFile() class RHELBrandFile(entbranding.BrandFile): path = "/var/lib/rhsm/branded_name" - def __str__(self): + def 
__str__(self) -> str: return "" % self.path diff --git a/src/subscription_manager/rhelproduct.py b/src/subscription_manager/rhelproduct.py index 1ac0dcc0a2..b39a17ece8 100644 --- a/src/subscription_manager/rhelproduct.py +++ b/src/subscription_manager/rhelproduct.py @@ -24,21 +24,27 @@ # may or may not be a RHEL "branded" Product. See rhelentbranding for # code that handles finding and comparing RHEL "branded" Product objects. # +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from rhsm.certificate2 import Product + + class RHELProductMatcher(object): """Check a Product object to see if it is a RHEL product. Compares the provided tags to see if any provide 'rhel-VERSION'. """ - def __init__(self, product=None): - self.product = product + def __init__(self, product: Optional["Product"] = None): + self.product: Optional[Product] = product # Match "rhel-6" or "rhel-11" or "rhel-alt-7" (bz1510024) # but not "rhel-6-server" or "rhel-6-server-highavailabilty" # NOTE: we considered rhel(-[\w\d]+)?-\d+$ but decided against it # due to possibility of unintentional matches self.pattern = r"rhel(-alt)?-\d+$|rhel-5-workstation$" - def is_rhel(self): + def is_rhel(self) -> bool: """return true if this is a rhel product cert""" return any([re.match(self.pattern, tag) for tag in self.product.provided_tags]) diff --git a/src/subscription_manager/syspurposelib.py b/src/subscription_manager/syspurposelib.py index 9f36492fb9..91c09952f1 100644 --- a/src/subscription_manager/syspurposelib.py +++ b/src/subscription_manager/syspurposelib.py @@ -16,6 +16,7 @@ This module is an interface to syspurpose's SyncedStore class from subscription-manager. It contains methods for accessing/manipulating the local syspurpose.json metadata file through SyncedStore. 
""" +from typing import Optional, Tuple, Union, TYPE_CHECKING from rhsm.connection import ConnectionException, GoneException from subscription_manager import certlib @@ -26,13 +27,22 @@ import json import os +if TYPE_CHECKING: + from rhsm.connection import UEPConnection + + from subscription_manager.cache import SyspurposeValidFieldsCache + from subscription_manager.cp_provider import CPProvider + from subscription_manager.identity import Identity + + from syspurpose.files import DiffChange + log = logging.getLogger(__name__) store = None syspurpose = None -def save_sla_to_syspurpose_metadata(uep, consumer_uuid, service_level): +def save_sla_to_syspurpose_metadata(uep: "UEPConnection", consumer_uuid: str, service_level: str): """ Saves the provided service-level value to the local Syspurpose Metadata (syspurpose.json) file. If the service level provided is null or empty, the sla value to the local syspurpose file is set to null. @@ -41,7 +51,6 @@ def save_sla_to_syspurpose_metadata(uep, consumer_uuid, service_level): :param uep: The object with uep connection (connection to candlepin server) :param consumer_uuid: Consumer UUID :param service_level: The service-level value to be saved in the syspurpose file. - :type service_level: str """ if "SyncedStore" in globals() and SyncedStore is not None: @@ -58,7 +67,7 @@ def save_sla_to_syspurpose_metadata(uep, consumer_uuid, service_level): log.error("SyspurposeStore could not be imported. Syspurpose SLA value not saved locally.") -def get_sys_purpose_store(): +def get_sys_purpose_store() -> Optional[SyncedStore]: """ :return: Returns a singleton instance of the syspurpose store if it was imported. Otherwise None. @@ -73,7 +82,7 @@ def get_sys_purpose_store(): return store -def read_syspurpose(synced_store=None, raise_on_error=False): +def read_syspurpose(synced_store: Optional[SyncedStore] = None, raise_on_error: bool = False) -> dict: """ Reads the system purpose from the correct location on the file system. 
Makes an attempt to use a SyspurposeStore if available falls back to reading the json directly. @@ -97,7 +106,7 @@ def read_syspurpose(synced_store=None, raise_on_error=False): return content -def write_syspurpose(values): +def write_syspurpose(values: dict) -> bool: """ Write the syspurpose to the file system. :param values: @@ -116,7 +125,7 @@ def write_syspurpose(values): return True -def write_syspurpose_cache(values): +def write_syspurpose_cache(values: dict) -> bool: """ Write to the syspurpose cache on the file system. :param values: @@ -130,7 +139,7 @@ def write_syspurpose_cache(values): return True -def get_syspurpose_valid_fields(uep=None, identity=None): +def get_syspurpose_valid_fields(uep: "UEPConnection" = None, identity: "Identity" = None) -> dict: """ Try to get valid syspurpose fields provided by candlepin server :param uep: connection of candlepin server @@ -138,14 +147,20 @@ def get_syspurpose_valid_fields(uep=None, identity=None): :return: dictionary with valid fields """ valid_fields = {} - cache = inj.require(inj.SYSPURPOSE_VALID_FIELDS_CACHE) + cache: SyspurposeValidFieldsCache = inj.require(inj.SYSPURPOSE_VALID_FIELDS_CACHE) syspurpose_valid_fields = cache.read_data(uep, identity) if "systemPurposeAttributes" in syspurpose_valid_fields: valid_fields = syspurpose_valid_fields["systemPurposeAttributes"] return valid_fields -def merge_syspurpose_values(local=None, remote=None, base=None, uep=None, consumer_uuid=None): +def merge_syspurpose_values( + local: Optional[dict] = None, + remote: Optional[dict] = None, + base: Optional[dict] = None, + uep: Optional["UEPConnection"] = None, + consumer_uuid: Optional[str] = None, +) -> dict: """ Try to do three-way merge of local, remote and base dictionaries. Note: when remote is None, then this method will call REST API. @@ -180,7 +195,7 @@ class SyspurposeSyncActionInvoker(certlib.BaseActionInvoker): Used by rhsmcertd to sync the syspurpose values locally with those from the Server. 
""" - def _do_update(self): + def _do_update(self) -> Union["SyspurposeSyncActionReport", Tuple["SyspurposeSyncActionReport", dict]]: action = SyspurposeSyncActionCommand() return action.perform() @@ -188,11 +203,10 @@ def _do_update(self): class SyspurposeSyncActionReport(certlib.ActionReport): name = "Syspurpose Sync" - def record_change(self, change): + def record_change(self, change: "DiffChange") -> None: """ Records the change detected by the three_way_merge function into a record in the report. :param change: A util.DiffChange object containing the recorded changes. - :return: None """ if change.source == "remote": source = "Entitlement Server" @@ -220,7 +234,7 @@ def record_change(self, change): BZ #1789457 """ - def format_exceptions(self): + def format_exceptions(self) -> str: buf = "" for e in self._exceptions: buf += str(e).strip() @@ -238,10 +252,12 @@ class SyspurposeSyncActionCommand(object): def __init__(self): self.report = SyspurposeSyncActionReport() - self.cp_provider = inj.require(inj.CP_PROVIDER) - self.uep = self.cp_provider.get_consumer_auth_cp() + self.cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + self.uep: UEPConnection = self.cp_provider.get_consumer_auth_cp() - def perform(self, include_result=False, passthrough_gone=False): + def perform( + self, include_result: bool = False, passthrough_gone: bool = False + ) -> Union[SyspurposeSyncActionReport, Tuple[SyspurposeSyncActionReport, dict]]: """ Perform the action that this Command represents. 
:return: @@ -265,6 +281,7 @@ def perform(self, include_result=False, passthrough_gone=False): self.report._updates = "\n\t\t ".join(self.report._updates) log.debug("Syspurpose updated: %s" % self.report) if not include_result: + # FIXME Return None or {} as second argument return self.report else: return self.report, result diff --git a/src/subscription_manager/unicode_width.py b/src/subscription_manager/unicode_width.py index b2f348c585..41a8bca754 100644 --- a/src/subscription_manager/unicode_width.py +++ b/src/subscription_manager/unicode_width.py @@ -76,7 +76,10 @@ # Renamed but still pretty much JA's port of MK's code -def _interval_bisearch(value, table): +from typing import Sequence, Tuple + + +def _interval_bisearch(value: int, table: Sequence[Tuple[int, int]]) -> bool: """Binary search in an interval table. :arg value: numeric value to search for @@ -192,7 +195,7 @@ def _interval_bisearch(value, table): # Handling of control chars rewritten. Rest is JA's port of MK's C code. # -Toshio Kuratomi # NOTE: Removed unused functionality (see note in docstring) - subscription-manager developers -def _ucp_width(ucs): +def _ucp_width(ucs: int) -> int: """Get the :term:`textual width` of a ucs character :arg ucs: integer representing a single unicode :term:`code point` @@ -238,7 +241,7 @@ def _ucp_width(ucs): # Wholly rewritten by me (LGPLv2+) -Toshio Kuratomi # NOTE: Removed unused functionality (see note in docstring) - subscription-manager developers -def textual_width(msg): +def textual_width(msg: str) -> int: """Get the :term:`textual width` of a string :arg msg: :class:`str` string or byte :class:`bytes` to get the width of diff --git a/src/subscription_manager/utils.py b/src/subscription_manager/utils.py index ace225d675..640fce0b3e 100644 --- a/src/subscription_manager/utils.py +++ b/src/subscription_manager/utils.py @@ -24,8 +24,10 @@ import socket import syslog import uuid +from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple, 
TYPE_CHECKING import urllib + from rhsm.https import ssl from subscription_manager.branding import get_branding @@ -52,6 +54,15 @@ from subscription_manager.i18n import ugettext as _ +if TYPE_CHECKING: + from rhsm.certificate2 import EntitlementCertificate + from rhsm.connection import UEPConnection + + from rhsmlib.services import config as rhsmlib_config + + from subscription_manager.cache import CurrentOwnerCache, SupportedResourcesCache + from subscription_manager.identity import Identity + from subscription_manager.cp_provider import CPProvider log = logging.getLogger(__name__) @@ -62,14 +73,16 @@ class DefaultDict(collections.defaultdict): """defaultdict wrapper that pretty prints""" - def as_dict(self): + def as_dict(self) -> dict: return dict(self) - def __repr__(self): + def __repr__(self) -> str: return pprint.pformat(self.as_dict()) -def parse_server_info(local_server_entry, config=None): +def parse_server_info( + local_server_entry: str, config: Optional["rhsmlib_config.Config"] = None +) -> Tuple[str, str, str]: hostname = "" port = "" prefix = "" @@ -85,7 +98,7 @@ def parse_server_info(local_server_entry, config=None): )[2:] -def parse_baseurl_info(local_server_entry): +def parse_baseurl_info(local_server_entry: str) -> Tuple[str, str, str]: return parse_url( local_server_entry, DEFAULT_CDN_HOSTNAME, @@ -94,7 +107,7 @@ def parse_baseurl_info(local_server_entry): )[2:] -def format_baseurl(hostname, port, prefix): +def format_baseurl(hostname: str, port: str, prefix: str) -> str: # just to avoid double slashs. cosmetic if prefix and prefix[0] != "/": prefix = "/%s" % prefix @@ -112,11 +125,11 @@ def format_baseurl(hostname, port, prefix): return "https://%s:%s%s" % (hostname, port, prefix) -def url_base_join(base, url): +def url_base_join(base: str, url: str) -> str: """Join a baseurl (hostname) and url (full or relpath). If url is a full url, just return it. 
Otherwise combine - it with base, skipping redundant seperators if needed.""" + it with base, skipping redundant separators if needed.""" # I don't really understand this. Why does joining something # potentially non-empty and "" return ""? -akl @@ -136,7 +149,7 @@ class MissingCaCertException(Exception): pass -def is_valid_server_info(conn): +def is_valid_server_info(conn: "UEPConnection") -> bool: """ Check if we can communicate with a subscription service at the given location. @@ -168,7 +181,9 @@ def is_valid_server_info(conn): return False -def is_simple_content_access(uep=None, identity=None): +def is_simple_content_access( + uep: Optional["UEPConnection"] = None, identity: Optional["Identity"] = None +) -> bool: """ This function returns True, when current owner uses contentAccessMode equal to Simple Content Access. This function has three optional arguments that can be reused for getting required information. @@ -186,7 +201,7 @@ def is_simple_content_access(uep=None, identity=None): return content_access_mode == "org_environment" -def get_current_owner(uep=None, identity=None): +def get_current_owner(uep: Optional["UEPConnection"] = None, identity: "Identity" = None) -> dict: """ This function tries to get information about current owner. It uses cache file. :param uep: connection to candlepin server @@ -194,13 +209,15 @@ def get_current_owner(uep=None, identity=None): :return: information about current owner """ - cache = inj.require(inj.CURRENT_OWNER_CACHE) + cache: CurrentOwnerCache = inj.require(inj.CURRENT_OWNER_CACHE) return cache.read_data(uep, identity) -def get_supported_resources(uep=None, identity=None): +def get_supported_resources( + uep: Optional["UEPConnection"] = None, identity: Optional["Identity"] = None +) -> dict: """ - This function tries to get list of supported resources. It tries to uses cache file. + This function tries to get list of supported resources. It tries to use cache file. 
When the system is not registered, then it tries to get version directly using REST API. It is preferred to use this function instead of connection.get_supported_resources. :param uep: connection of candlepin server @@ -210,16 +227,16 @@ def get_supported_resources(uep=None, identity=None): # Try to read supported resources from cache file if identity is not None: - cache = inj.require(inj.SUPPORTED_RESOURCES_CACHE) + cache: SupportedResourcesCache = inj.require(inj.SUPPORTED_RESOURCES_CACHE) return cache.read_data(uep, identity) else: if uep is None: - cp_provider = inj.require(inj.CP_PROVIDER) - uep = cp_provider.get_consumer_auth_cp() + cp_provider: CPProvider = inj.require(inj.CP_PROVIDER) + uep: UEPConnection = cp_provider.get_consumer_auth_cp() return uep.get_supported_resources() -def get_version(versions, package_name): +def get_version(versions, package_name: str) -> str: """ Return a string containing the version (and release if available). """ @@ -231,7 +248,7 @@ def get_version(versions, package_name): return "%s%s" % (package_version, package_release) -def get_terminal_width(): +def get_terminal_width() -> Optional[int]: """ Attempt to determine the current terminal size. """ @@ -245,7 +262,7 @@ def get_terminal_width(): return 1000 -def get_client_versions(): +def get_client_versions() -> Dict[str, str]: # It's possible (though unlikely, and kind of broken) to have more # than one version of subscription-manager installed. # This will return whatever version we are using. 
@@ -264,7 +281,7 @@ def get_client_versions(): return {"subscription-manager": sm_version} -def get_server_versions(cp, exception_on_timeout=False): +def get_server_versions(cp: "UEPConnection", exception_on_timeout: bool = False) -> Dict[str, str]: cp_version = _("Unknown") server_type = _("This system is currently not registered.") rules_version = _("Unknown") @@ -313,7 +330,7 @@ def get_server_versions(cp, exception_on_timeout=False): return {"candlepin": cp_version, "server-type": server_type, "rules-version": rules_version} -def restart_virt_who(): +def restart_virt_who() -> None: """ Send a SIGHUP signal to virt-who if it is running on the same machine. """ @@ -346,7 +363,7 @@ def restart_virt_who(): log.error("The virt-who pid file references a non-existent pid: %s", pid) -def friendly_join(items): +def friendly_join(items: Iterable[str]) -> str: if items is None: return "" @@ -369,17 +386,17 @@ def friendly_join(items): return first_string + " %s " % _("and") + last -def is_true_value(test_string): +def is_true_value(test_string) -> bool: val = str(test_string).lower() return val == "1" or val == "true" or val == "yes" -def system_log(message, priority=syslog.LOG_NOTICE): +def system_log(message: str, priority: int = syslog.LOG_NOTICE): syslog.openlog("subscription-manager") syslog.syslog(priority, message) -def chroot(dirname): +def chroot(dirname: str) -> None: """ Change root of all paths. """ @@ -387,7 +404,7 @@ def chroot(dirname): class CertificateFilter(object): - def match(self, cert): + def match(self, cert: "EntitlementCertificate"): """ Checks if the specified certificate matches this filter's restrictions. 
Returns True if the specified certificate matches this filter's restrictions ; False @@ -405,7 +422,7 @@ def __init__(self, filter_string=None): if filter_string is not None: self.set_filter_string(filter_string) - def set_filter_string(self, filter_string): + def set_filter_string(self, filter_string: Optional[str]) -> bool: """ Sets this filter's filter string to the specified string. The filter string may use ? or * for wildcards, representing one or any characters, respectively. @@ -464,7 +481,7 @@ def set_filter_string(self, filter_string): return output - def match(self, cert): + def match(self, cert: "EntitlementCertificate") -> bool: """ Checks if the specified certificate matches this filter's restrictions. Returns True if the specified certificate matches this filter's restrictions ; False @@ -483,7 +500,7 @@ def match(self, cert): class EntitlementCertificateFilter(ProductCertificateFilter): - def __init__(self, filter_string=None, service_level=None): + def __init__(self, filter_string: Optional[str] = None, service_level: Optional[str] = None): super(EntitlementCertificateFilter, self).__init__(filter_string=filter_string) self._sl_filter = None @@ -541,7 +558,7 @@ def match(self, cert): return sl_check and fs_check and (self._sl_filter is not None or self._fs_regex is not None) -def print_error(message): +def print_error(message: str) -> None: """ Prints the specified message to stderr """ @@ -549,7 +566,7 @@ def print_error(message): sys.stderr.write("\n") -def unique_list_items(items, hash_function=lambda x: x): +def unique_list_items(items: Iterable, hash_function: Callable = lambda x: x) -> list: """ Accepts a list of items. Returns a list of the unique items in the input. 
@@ -567,11 +584,11 @@ def unique_list_items(items, hash_function=lambda x: x):
     return unique_items


-def generate_correlation_id():
+def generate_correlation_id() -> str:
    return str(uuid.uuid4()).replace("-", "")  # FIXME cp should accept -


-def get_process_names():
+def get_process_names() -> Iterator[str]:
    """
    Returns a list of "Name" values for all processes running on the system.
    This assumes an accessible and standard procfs at "/proc/".
diff --git a/src/subscription_manager/validity.py b/src/subscription_manager/validity.py
index aec9e821c5..bd2dd7a367 100644
--- a/src/subscription_manager/validity.py
+++ b/src/subscription_manager/validity.py
@@ -13,23 +13,30 @@
 #

 import logging
+from typing import Optional, TYPE_CHECKING

 from rhsm.certificate import DateRange

 import subscription_manager.injection as inj
 from subscription_manager.isodate import parse_date

+if TYPE_CHECKING:
+    from rhsm.connection import UEPConnection
+
+    from subscription_manager.cache import ProductStatusCache
+    from subscription_manager.identity import Identity
+
 log = logging.getLogger(__name__)


 class ValidProductDateRangeCalculator(object):
-    def __init__(self, uep=None):
+    def __init__(self, uep: Optional["UEPConnection"] = None):
         uep = uep or inj.require(inj.CP_PROVIDER).get_consumer_auth_cp()
-        self.identity = inj.require(inj.IDENTITY)
+        self.identity: "Identity" = inj.require(inj.IDENTITY)
         if self.identity.is_valid():
-            self.prod_status_cache = inj.require(inj.PROD_STATUS_CACHE)
-            self.prod_status = self.prod_status_cache.load_status(uep, self.identity.uuid)
+            self.prod_status_cache: "ProductStatusCache" = inj.require(inj.PROD_STATUS_CACHE)
+            self.prod_status: dict = self.prod_status_cache.load_status(uep, self.identity.uuid)

-    def calculate(self, product_hash):
+    def calculate(self, product_hash: str) -> Optional[DateRange]:
         """
         Calculate the valid date range for the specified product based on today's date.