Skip to content

Commit

Permalink
Making extension parameter optional.
Browse files Browse the repository at this point in the history
  • Loading branch information
bacciotti committed Nov 21, 2024
1 parent 3fb15b9 commit cd4d6e0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 60 deletions.
42 changes: 21 additions & 21 deletions koku/masu/external/downloader/azure/azure_report_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class AzureReportDownloaderNoFileError(Exception):


def get_processing_date(
s3_csv_path, manifest_id, provider_uuid, start_date, end_date, context, tracing_id, ingress_reports=None
s3_csv_path, manifest_id, provider_uuid, start_date, end_date, context, tracing_id, ingress_reports=None
):
"""
Fetch initial dataframe from CSV plus start_delta and time_interval.
Expand All @@ -60,12 +60,12 @@ def get_processing_date(
# Azure does not have an invoice column so we have to do some guessing here
# Ingres reports should always clear and process everything
if (
start_date.year < dh.today.year
and dh.today.day > 1
or start_date.month < dh.today.month
and dh.today.day > 1
or not check_provider_setup_complete(provider_uuid)
or ingress_reports
start_date.year < dh.today.year
and dh.today.day > 1
or start_date.month < dh.today.month
and dh.today.day > 1
or not check_provider_setup_complete(provider_uuid)
or ingress_reports
):
process_date = start_date
process_date = ReportManifestDBAccessor().set_manifest_daily_start_date(manifest_id, process_date)
Expand All @@ -77,15 +77,15 @@ def get_processing_date(


def create_daily_archives(
tracing_id,
account,
provider_uuid,
local_file,
base_filename,
manifest_id,
start_date,
context,
ingress_reports=None,
tracing_id,
account,
provider_uuid,
local_file,
base_filename,
manifest_id,
start_date,
context,
ingress_reports=None,
):
"""
Create daily CSVs from incoming report and archive to S3.
Expand Down Expand Up @@ -117,10 +117,10 @@ def create_daily_archives(
return [], {}
try:
with pd.read_csv(
local_file,
chunksize=settings.PARQUET_PROCESSING_BATCH_SIZE,
parse_dates=[time_interval],
dtype=pd.StringDtype(storage="pyarrow"),
local_file,
chunksize=settings.PARQUET_PROCESSING_BATCH_SIZE,
parse_dates=[time_interval],
dtype=pd.StringDtype(storage="pyarrow"),
) as reader:
for i, data_frame in enumerate(reader):
if data_frame.empty:
Expand Down Expand Up @@ -333,7 +333,7 @@ def _get_manifest(self, date_time): # noqa: C901
else:
try:
blob = self._azure_client.get_latest_cost_export_for_path(
report_path, self.container_name, compression_mode
report_path, self.container_name
)
except AzureCostReportNotFound as ex:
msg = f"Unable to find manifest. Error: {ex}"
Expand Down
73 changes: 34 additions & 39 deletions koku/masu/external/downloader/azure/azure_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ class AzureService:
"""A class to handle interactions with the Azure services."""

def __init__(
self,
tenant_id,
client_id,
client_secret,
resource_group_name,
storage_account_name,
subscription_id=None,
cloud="public",
scope=None,
export_name=None,
self,
tenant_id,
client_id,
client_secret,
resource_group_name,
storage_account_name,
subscription_id=None,
cloud="public",
scope=None,
export_name=None,
):
"""Establish connection information."""
self._resource_group_name = resource_group_name
Expand All @@ -64,24 +64,28 @@ def __init__(
raise AzureServiceError("Azure Service credentials are not configured.")

def _get_latest_blob(
    self, report_path: str, blobs: list[BlobProperties], extension: t.Optional[str] = None
) -> t.Optional[BlobProperties]:
    """Return the most recently modified blob matching report_path.

    Args:
        report_path: Substring that must appear in a blob's name for it
            to be considered a candidate.
        blobs: Candidate blobs to scan.
        extension: Optional file extension filter. When given, only blobs
            whose names end with this extension are considered; when None,
            blobs ending in either of the default cost-export extensions
            (CSV or GZIP) are accepted.

    Returns:
        The matching blob with the latest ``last_modified`` timestamp, or
        None when no blob matches.
    """
    # When no explicit extension is requested, accept the cost-export
    # formats this downloader knows how to process.
    default_extensions = [AzureBlobExtension.csv.value, AzureBlobExtension.gzip.value]
    latest_blob = None
    for blob in blobs:
        if extension:
            if not blob.name.endswith(extension):
                continue
        elif not any(blob.name.endswith(ext) for ext in default_extensions):
            continue
        if report_path in blob.name:
            # Keep the newest blob seen so far for this report path.
            if not latest_blob or blob.last_modified > latest_blob.last_modified:
                latest_blob = blob
    return latest_blob

def _get_latest_blob_for_path(
self,
report_path: str,
container_name: str,
extension: str,
self,
report_path: str,
container_name: str,
extension: t.Optional[str] = None,
) -> BlobProperties:
"""Get the latest file with the specified extension from given storage account container."""

Expand Down Expand Up @@ -120,10 +124,7 @@ def _get_latest_blob_for_path(

latest_report = self._get_latest_blob(report_path, blobs, extension)
if not latest_report:
message = (
f"No file with extension '{extension}' found in container "
f"'{container_name}' for path '{report_path}'."
)
message = f"No file found in container " f"'{container_name}' for path '{report_path}'."
raise AzureCostReportNotFound(message)

return latest_report
Expand Down Expand Up @@ -158,9 +159,7 @@ def get_file_for_key(self, key: str, container_name: str) -> BlobProperties:

return report

def get_latest_cost_export_for_path(
self, report_path: str, container_name: str, compression: str
) -> BlobProperties:
def get_latest_cost_export_for_path(self, report_path: str, container_name: str) -> BlobProperties:
"""
Get the latest cost export for a given path and container based on the compression type.
Expand All @@ -176,22 +175,18 @@ def get_latest_cost_export_for_path(
AzureCostReportNotFound: If no blob is found for the given path and container.
"""
valid_compressions = [AzureBlobExtension.gzip.value, AzureBlobExtension.csv.value]
if compression not in valid_compressions:
raise ValueError(f"Invalid compression type: {compression}. Expected one of: {valid_compressions}.")

return self._get_latest_blob_for_path(report_path, container_name, compression)
return self._get_latest_blob_for_path(report_path, container_name)

def get_latest_manifest_for_path(self, report_path: str, container_name: str) -> BlobProperties:
    """Return the most recent manifest blob for the given path and container.

    Delegates to ``_get_latest_blob_for_path`` with the manifest file
    extension so only manifest blobs are considered.

    Args:
        report_path: Path prefix/substring identifying the report location.
        container_name: Name of the storage container to search.

    Returns:
        The latest matching manifest blob.

    Raises:
        AzureCostReportNotFound: presumably raised by the delegate when no
            manifest exists for the path — confirm against
            ``_get_latest_blob_for_path``.
    """
    return self._get_latest_blob_for_path(report_path, container_name, AzureBlobExtension.manifest.value)

def download_file(
self,
key: str,
container_name: str,
destination: str = None,
suffix: str = AzureBlobExtension.csv.value,
ingress_reports: list[str] = None,
self,
key: str,
container_name: str,
destination: str = None,
suffix: str = AzureBlobExtension.csv.value,
ingress_reports: list[str] = None,
) -> str:
"""
Download the file from a given storage container. Supports both CSV and GZIP formats.
Expand Down

0 comments on commit cd4d6e0

Please sign in to comment.