Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions application_sdk/activities/query_extraction/sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, cast

from pydantic import BaseModel, Field
from temporalio import activity
Expand All @@ -13,6 +13,7 @@
get_workflow_id,
)
from application_sdk.clients.sql import BaseSQLClient
from application_sdk.common.utils import parse_credentials_extra
from application_sdk.constants import UPSTREAM_OBJECT_STORE_NAME
from application_sdk.handlers import HandlerInterface
from application_sdk.handlers.sql import BaseSQLHandler
Expand Down Expand Up @@ -422,7 +423,9 @@ async def write_marker(
)
logger.info(f"Marker file written to {marker_file_path}")

async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]:
async def read_marker(
self, workflow_args: Dict[str, Any], output_path: Optional[str] = None
) -> Optional[int]:
"""Read the marker from the output path.

This method reads the current marker value from a marker file to determine the
Expand All @@ -441,8 +444,13 @@ async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]:
Exception: If marker file reading fails (logged as warning, not re-raised)
"""
try:
output_path = workflow_args["output_path"].rsplit("/", 1)[0]
marker_file_path = os.path.join(output_path, "markerfile")
base_output_path: str = cast(str, workflow_args["output_path"])
resolved_output_path: str = (
output_path
if isinstance(output_path, str) and output_path
else base_output_path.rsplit("/", 1)[0]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line can be made more readable. Either via using constants that tell what is at 1 and 0 index or via a comment stating 'why' this is being done - expected output.

)
marker_file_path = os.path.join(resolved_output_path, "markerfile")
logger.info(f"Downloading marker file from {marker_file_path}")

await ObjectStore.download_file(
Expand All @@ -463,6 +471,25 @@ async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]:
logger.warning(f"Failed to read marker: {e}")
return None

@activity.defn(name="miner_preflight_check")
@auto_heartbeater
async def preflight_check(self, workflow_args: Dict[str, Any]):
return await super().preflight_check(workflow_args)

@activity.defn(name="miner_get_workflow_args")
@auto_heartbeater
async def get_workflow_args(self, workflow_config: Dict[str, Any]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should not have any code other than the super()...

workflow_args = await super().get_workflow_args(workflow_config)
if "credential_guid" in workflow_args:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick - prefer using constants/variables when there is repeat use of key strings in a code snippet.

credentials = await SecretStore.get_credentials(
credential_guid=workflow_args["credential_guid"]
)
extra = parse_credentials_extra(credentials)
workflow_args["deployment_type"] = extra.get(
"deployment_type", "provisioned"
)
return workflow_args

@activity.defn
@auto_heartbeater
async def get_query_batches(
Expand Down
Loading