Skip to content

Commit

Permalink
Register entity with model class and instantiate entity from model
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-fcampbell committed Nov 5, 2024
1 parent 2b7da0e commit a907153
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 66 deletions.
38 changes: 6 additions & 32 deletions src/snowflake/cli/_plugins/workspace/manager.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from pathlib import Path
from typing import Dict

from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.entities.common import EntityActions, get_sql_executor
from snowflake.cli.api.exceptions import InvalidProjectDefinitionVersionError
from snowflake.cli.api.project.definition import default_role
from snowflake.cli.api.project.schemas.entities.entities import (
Entity,
v2_entity_model_to_entity_map,
from snowflake.cli.api.entities.common import (
EntityActions,
)
from snowflake.cli.api.exceptions import InvalidProjectDefinitionVersionError
from snowflake.cli.api.project.schemas.entities.entities import Entity
from snowflake.cli.api.project.schemas.project_definition import (
DefinitionV20,
ProjectDefinition,
)
from snowflake.cli.api.project.util import to_identifier


class WorkspaceManager:
Expand All @@ -41,15 +37,7 @@ def get_entity(self, entity_id: str):
entity_model = self._project_definition.entities.get(entity_id, None)
if entity_model is None:
raise ValueError(f"No such entity ID: {entity_id}")
entity_model_cls = entity_model.__class__
entity_cls = v2_entity_model_to_entity_map[entity_model_cls]
workspace_ctx = WorkspaceContext(
console=cc,
project_root=self.project_root,
get_default_role=_get_default_role,
get_default_warehouse=_get_default_warehouse,
)
self._entities_cache[entity_id] = entity_cls(entity_model, workspace_ctx)
self._entities_cache[entity_id] = entity_model.get_entity(cc, self.project_root)
return self._entities_cache[entity_id]

def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs):
Expand All @@ -68,17 +56,3 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs)
@property
def project_root(self) -> Path:
return self._project_root


def _get_default_role() -> str:
role = default_role()
if role is None:
role = get_sql_executor().current_role()
return role


def _get_default_warehouse() -> str | None:
warehouse = get_cli_context().connection.warehouse
if warehouse:
warehouse = to_identifier(warehouse)
return warehouse
27 changes: 26 additions & 1 deletion src/snowflake/cli/api/entities/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,32 @@ class EntityActions(str, Enum):
T = TypeVar("T")


class EntityBase(Generic[T]):
class EntityBaseMetaclass(type):
def __new__(mcs, name, bases, attrs): # noqa: N804
cls = super().__new__(mcs, name, bases, attrs)
generic_bases = attrs.get("__orig_bases__", [])
if not generic_bases:
# Subclass is not generic
return cls

target_model_class = get_args(generic_bases[0])[0] # type: ignore[attr-defined]
if target_model_class is T:
# Generic parameter is not filled in
return cls

target_entity_class = getattr(target_model_class, "_entity_class", None)
if target_entity_class is not None:
raise ValueError(
f"Entity model class {target_model_class} is already "
f"associated with entity class {target_entity_class}, "
f"cannot associate with {cls}"
)

setattr(target_model_class, "_entity_class", cls)
return cls


class EntityBase(Generic[T], metaclass=EntityBaseMetaclass):
"""
Base class for the fully-featured entity classes.
"""
Expand Down
41 changes: 41 additions & 0 deletions src/snowflake/cli/api/project/schemas/entities/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
from __future__ import annotations

from abc import ABC
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar, Union

from pydantic import Field, PrivateAttr, field_validator
from snowflake.cli._plugins.workspace.context import WorkspaceContext
from snowflake.cli.api.console.abc import AbstractConsole
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.project.schemas.updatable_model import (
IdentifierField,
Expand Down Expand Up @@ -110,6 +113,24 @@ def fqn(self) -> FQN:
if self.entity_id:
return FQN.from_string(self.entity_id)

def get_entity(self, console: AbstractConsole, project_root: Path):
if type(self) is EntityModelBase:
raise NotImplementedError
# Set by EntityBaseMetaclass when creating the
# Entity class that refers to this model
entity_class = getattr(self, "_entity_class", None)
if entity_class is None:
raise ValueError(
f"Entity model class {type(self).__name__} is not associated with an entity class"
)
workspace_ctx = WorkspaceContext(
console=console,
project_root=project_root,
get_default_role=_get_default_role,
get_default_warehouse=_get_default_warehouse,
)
return entity_class(self, workspace_ctx)


TargetType = TypeVar("TargetType")

Expand Down Expand Up @@ -162,3 +183,23 @@ def get_secrets_sql(self) -> str | None:
return None
secrets = ", ".join(f"'{key}'={value}" for key, value in self.secrets.items())
return f"secrets=({secrets})"


def _get_default_role() -> str:
from snowflake.cli.api.entities.common import get_sql_executor
from snowflake.cli.api.project.definition import default_role

role = default_role()
if role is None:
role = get_sql_executor().current_role()
return role


def _get_default_warehouse() -> str | None:
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.project.util import to_identifier

warehouse = get_cli_context().connection.warehouse
if warehouse:
warehouse = to_identifier(warehouse)
return warehouse
10 changes: 2 additions & 8 deletions tests/nativeapp/test_version_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
AskAlwaysPolicy,
DenyAlwaysPolicy,
)
from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.connector.cursor import DictCursor
Expand Down Expand Up @@ -60,13 +60,7 @@ def _version_create(
dm = DefinitionManager()
pd = dm.project_definition
pkg_model: ApplicationPackageEntityModel = pd.entities["app_pkg"]
ctx = WorkspaceContext(
console=cc,
project_root=dm.project_root,
get_default_role=lambda: "mock_role",
get_default_warehouse=lambda: "mock_warehouse",
)
pkg = ApplicationPackageEntity(pkg_model, ctx)
pkg = pkg_model.get_entity(cc, dm.project_root)
return pkg.action_version_create(
action_ctx=mock.Mock(spec=ActionContext),
version=version,
Expand Down
25 changes: 0 additions & 25 deletions tests/project/test_project_definition_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
)
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.cli.api.project.errors import SchemaValidationError
from snowflake.cli.api.project.schemas.entities.entities import (
ALL_ENTITIES,
ALL_ENTITY_MODELS,
v2_entity_model_to_entity_map,
v2_entity_model_types_map,
)
from snowflake.cli.api.project.schemas.project_definition import (
DefinitionV20,
)
Expand Down Expand Up @@ -310,25 +304,6 @@ def test_identifiers():
assert entities["D"].entity_id == "D"


# Verify that each entity model type has the correct "type" field
def test_entity_types():
for entity_type, entity_class in v2_entity_model_types_map.items():
model_entity_type = entity_class.get_type()
assert model_entity_type == entity_type


# Verify that each entity class has a corresponding entity model class, and that all entities are covered
def test_entity_model_to_entity_map():
entities = set(ALL_ENTITIES)
entity_models = set(ALL_ENTITY_MODELS)
assert len(entities) == len(entity_models)
for entity_model_class, entity_class in v2_entity_model_to_entity_map.items():
entities.remove(entity_class)
entity_models.remove(entity_model_class)
assert len(entities) == 0
assert len(entity_models) == 0


@pytest.mark.parametrize(
"project_name",
[
Expand Down

0 comments on commit a907153

Please sign in to comment.