Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ export function WorkflowTestRunButton({
const workflow = getYamlWorkflowDefinition(definition.value);
// NOTE: prevent the workflow from being disabled, so test run doesn't fail
workflow.disabled = false;
if (workflowId) {
// if existing workflow, use it's real id for test run
workflow.id = workflowId;
}
const body = getBodyFromStringOrDefinitionOrObject({
workflow,
});
Expand Down
81 changes: 64 additions & 17 deletions keep/api/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterator, List, Tuple, Type, Union
from uuid import UUID, uuid4

import validators
Expand Down Expand Up @@ -47,7 +47,11 @@

from keep.api.consts import STATIC_PRESETS
from keep.api.core.config import config
from keep.api.core.db_utils import create_db_engine, get_json_extract_field
from keep.api.core.db_utils import (
create_db_engine,
get_json_extract_field,
get_or_create,
)
from keep.api.core.dependencies import SINGLE_TENANT_UUID

# This import is required to create the tables
Expand Down Expand Up @@ -112,7 +116,7 @@ def dispose_session():


@contextmanager
def existed_or_new_session(session: Optional[Session] = None) -> Session:
def existed_or_new_session(session: Optional[Session] = None) -> Iterator[Session]:
try:
if session:
yield session
Expand Down Expand Up @@ -207,7 +211,7 @@ def create_workflow_execution(
triggered_by = triggered_by[:255]
workflow_execution = WorkflowExecution(
id=workflow_execution_id,
workflow_id=workflow_id if not test_run else "test",
workflow_id=workflow_id,
workflow_revision=workflow_revision,
tenant_id=tenant_id,
started=datetime.now(tz=timezone.utc),
Expand All @@ -217,7 +221,7 @@ def create_workflow_execution(
error=None,
execution_time=None,
results={},
# is_test_run=test_run,
is_test_run=test_run,
)
session.add(workflow_execution)
# Ensure the object has an id
Expand Down Expand Up @@ -275,6 +279,7 @@ def get_last_completed_execution(
return session.exec(
select(WorkflowExecution)
.where(WorkflowExecution.workflow_id == workflow_id)
.where(WorkflowExecution.is_test_run == False)
.where(
(WorkflowExecution.status == "success")
| (WorkflowExecution.status == "error")
Expand Down Expand Up @@ -537,6 +542,7 @@ def add_or_update_workflow(
updated_by: str,
provisioned: bool = False,
provisioned_file: str | None = None,
is_test: bool = False,
) -> Workflow:
with Session(engine, expire_on_commit=False) as session:
# TODO: we need to better understanad if that's the right behavior we want
Expand All @@ -557,6 +563,7 @@ def add_or_update_workflow(
)

else:
now = datetime.now(tz=timezone.utc)
# Create a new workflow
workflow = Workflow(
id=id,
Expand All @@ -566,25 +573,54 @@ def add_or_update_workflow(
description=description,
created_by=created_by,
updated_by=updated_by,
last_updated=now,
interval=interval,
is_disabled=is_disabled,
workflow_raw=workflow_raw,
provisioned=provisioned,
provisioned_file=provisioned_file,
is_test=is_test,
)
version = WorkflowVersion(
workflow_id=workflow.id,
revision=1,
workflow_raw=workflow_raw,
updated_by=updated_by,
comment=f"Created by {created_by}",
is_valid=True,
is_current=True,
updated_at=now,
)
session.add(workflow)
session.add(version)
session.commit()
return workflow


def get_or_create_dummy_workflow(tenant_id: str, session: Session | None = None):
with existed_or_new_session(session) as session:
workflow, created = get_or_create(
session,
Workflow,
tenant_id=tenant_id,
id=get_dummy_workflow_id(tenant_id),
name="Dummy Workflow for test runs",
description="Auto-generated dummy workflow for test runs",
created_by="system",
workflow_raw="{}",
is_disabled=False,
is_test=True,
)
if created:
# For new instances, make sure they're committed and refreshed from the database
session.commit()
session.refresh(workflow)
elif workflow:
# For existing instances, refresh to get the current state
session.refresh(workflow)
return workflow


def get_workflow_to_alert_execution_by_workflow_execution_id(
workflow_execution_id: str,
) -> WorkflowToAlertExecution:
Expand Down Expand Up @@ -711,6 +747,7 @@ def get_workflows_with_last_execution(tenant_id: str) -> List[dict]:
)
.where(Workflow.tenant_id == tenant_id)
.where(Workflow.is_deleted == False)
.where(Workflow.is_test == False)
).distinct()

result = session.execute(workflows_with_last_execution_query).all()
Expand All @@ -723,6 +760,7 @@ def get_all_workflows(tenant_id: str):
select(Workflow)
.where(Workflow.tenant_id == tenant_id)
.where(Workflow.is_deleted == False)
.where(Workflow.is_test == False)
).all()
return workflows

Expand All @@ -734,6 +772,7 @@ def get_all_provisioned_workflows(tenant_id: str):
.where(Workflow.tenant_id == tenant_id)
.where(Workflow.provisioned == True)
.where(Workflow.is_deleted == False)
.where(Workflow.is_test == False)
).all()
return list(workflows)

Expand All @@ -754,6 +793,7 @@ def get_all_workflows_yamls(tenant_id: str):
select(Workflow.workflow_raw)
.where(Workflow.tenant_id == tenant_id)
.where(Workflow.is_deleted == False)
.where(Workflow.is_test == False)
).all()
return list(workflows)

Expand All @@ -764,6 +804,7 @@ def get_workflow(tenant_id: str, workflow_id: str):
select(Workflow)
.where(Workflow.tenant_id == tenant_id)
.where(Workflow.is_deleted == False)
.where(Workflow.is_test == False)
)
if validators.uuid(workflow_id):
query = query.where(Workflow.id == workflow_id)
Expand Down Expand Up @@ -893,13 +934,15 @@ def get_workflow_executions(
status: Optional[Union[str, List[str]]] = None,
trigger: Optional[Union[str, List[str]]] = None,
execution_id: Optional[str] = None,
is_test_run: bool = False,
):
with Session(engine) as session:
query = session.query(
WorkflowExecution,
).filter(
WorkflowExecution.tenant_id == tenant_id,
WorkflowExecution.workflow_id == workflow_id,
WorkflowExecution.is_test_run == False,
)

now = datetime.now(tz=timezone.utc)
Expand Down Expand Up @@ -1063,21 +1106,24 @@ def push_logs_to_db(log_entries):
session.commit()


def get_workflow_execution(tenant_id: str, workflow_execution_id: str):
def get_workflow_execution(
tenant_id: str, workflow_execution_id: str, is_test_run: bool | None = None
):
with Session(engine) as session:
execution_with_logs = (
session.query(WorkflowExecution)
.filter(
WorkflowExecution.id == workflow_execution_id,
WorkflowExecution.tenant_id == tenant_id,
)
.options(
joinedload(WorkflowExecution.logs),
joinedload(WorkflowExecution.workflow_to_alert_execution),
joinedload(WorkflowExecution.workflow_to_incident_execution),
base_query = session.query(WorkflowExecution)
if is_test_run is not None:
base_query = base_query.filter(
WorkflowExecution.is_test_run == is_test_run,
)
.one()
base_query = base_query.filter(
WorkflowExecution.id == workflow_execution_id,
WorkflowExecution.tenant_id == tenant_id,
)
execution_with_logs = base_query.options(
joinedload(WorkflowExecution.logs),
joinedload(WorkflowExecution.workflow_to_alert_execution),
joinedload(WorkflowExecution.workflow_to_incident_execution),
).one()
return execution_with_logs


Expand Down Expand Up @@ -2010,6 +2056,7 @@ def get_previous_execution_id(tenant_id, workflow_id, workflow_execution_id):
.where(WorkflowExecution.tenant_id == tenant_id)
.where(WorkflowExecution.workflow_id == workflow_id)
.where(WorkflowExecution.id != workflow_execution_id)
.where(WorkflowExecution.is_test_run == False)
.where(
WorkflowExecution.started >= datetime.now() - timedelta(days=1)
) # no need to check more than 1 day ago
Expand Down
60 changes: 59 additions & 1 deletion keep/api/core/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import os
from typing import Any, Dict, Optional, Type, TypeVar, Tuple

import pymysql
from dotenv import find_dotenv, load_dotenv
Expand All @@ -15,7 +16,8 @@
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.ddl import CreateColumn
from sqlalchemy.sql.functions import GenericFunction
from sqlmodel import Session, create_engine
from sqlmodel import Session, create_engine, SQLModel, select
from sqlalchemy.exc import IntegrityError

# This import is required to create the tables
from keep.api.consts import RUNNING_IN_CLOUD_RUN
Expand Down Expand Up @@ -208,3 +210,59 @@ def _compile_json_table(element, compiler, **kw):
for clause in element.clauses.clauses[1:]
),
)


T = TypeVar("T", bound=SQLModel)


def get_or_create(
session: Session,
model: Type[T],
defaults: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Tuple[T, bool]:
"""
Get an instance by filter kwargs, or create one with those filters plus any defaults.

Args:
session: SQLModel session
model: Model class
defaults: Dict of default values for creation (not used for lookup)
**kwargs: Filter parameters used both for lookup and creation

Returns:
tuple: (instance, created) where created is a boolean indicating if a new instance was created
"""
# Build query with all filter conditions
query = select(model)
for key, value in kwargs.items():
query = query.where(getattr(model, key) == value)

# Execute the query
instance = session.exec(query).first()

if instance:
return instance, False

# Prepare creation attributes
create_attrs = kwargs.copy()
if defaults:
create_attrs.update(defaults)

instance = model(**create_attrs)
session.add(instance)

try:
# Try to flush without committing to detect any integrity errors
session.flush()
return instance, True
except IntegrityError:
# If there's a conflict, roll back and try to fetch again (another process might have created it)
session.rollback()

# Try to fetch again with the same query
instance = session.exec(query).first()
if instance:
return instance, False
# If we still can't find it, something else is wrong, re-raise
raise
Comment thread
Kiryous marked this conversation as resolved.
1 change: 1 addition & 0 deletions keep/api/core/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def build_workflows_query(
and_(
Workflow.id == literal_column("entity_id"),
Workflow.tenant_id == tenant_id,
Workflow.is_test == False,
),
)
.outerjoin(
Expand Down
Loading