115 changes: 86 additions & 29 deletions src/datachain/catalog/catalog.py
@@ -680,8 +680,9 @@ def _row_to_node(d: dict[str, Any]) -> Node:
         ds_namespace, ds_project, ds_name = parse_dataset_name(ds_name)
         assert ds_namespace
         assert ds_project
-        project = self.metastore.get_project(ds_project, ds_namespace)
-        dataset = self.get_dataset(ds_name, project)
+        dataset = self.get_dataset(
+            ds_name, namespace_name=ds_namespace, project_name=ds_project
+        )
         if not ds_version:
             ds_version = dataset.latest_version
         dataset_sources = self.warehouse.get_dataset_sources(
@@ -807,7 +808,11 @@ def create_dataset(
         )
         default_version = DEFAULT_DATASET_VERSION
         try:
-            dataset = self.get_dataset(name, project)
+            dataset = self.get_dataset(
+                name,
+                namespace_name=project.namespace.name if project else None,
+                project_name=project.name if project else None,
+            )
             default_version = dataset.next_version_patch
             if update_version == "major":
                 default_version = dataset.next_version_major
@@ -1016,7 +1021,11 @@ def create_dataset_from_sources(
             dc.save(name)
         except Exception as e:  # noqa: BLE001
             try:
-                ds = self.get_dataset(name, project)
+                ds = self.get_dataset(
+                    name,
+                    namespace_name=project.namespace.name,
+                    project_name=project.name,
+                )
                 self.metastore.update_dataset_status(
                     ds,
                     DatasetStatus.FAILED,
@@ -1033,15 +1042,23 @@ def create_dataset_from_sources(
             except DatasetNotFoundError:
                 raise e from None

-        ds = self.get_dataset(name, project)
+        ds = self.get_dataset(
+            name,
+            namespace_name=project.namespace.name,
+            project_name=project.name,
+        )

         self.update_dataset_version_with_warehouse_info(
             ds,
             ds.latest_version,
             sources="\n".join(sources),
         )

-        return self.get_dataset(name, project)
+        return self.get_dataset(
+            name,
+            namespace_name=project.namespace.name,
+            project_name=project.name,
+        )

     def get_full_dataset_name(
         self,
@@ -1077,22 +1094,23 @@ def get_full_dataset_name(
         return namespace_name, project_name, name

     def get_dataset(
-        self, name: str, project: Optional[Project] = None
+        self,
+        name: str,
+        namespace_name: Optional[str] = None,
+        project_name: Optional[str] = None,
Contributor: Let's use the full project name instead of additional params - ns1.ns2.ds3

Contributor Author: This is an internal API where we already have the name split into 3 parts. In the public API we use the fully qualified name. Honestly, I wouldn't change this at the moment, as it would again require some refactoring that takes time, and we already have this kind of signature in other internal methods, so not much would change. WDYT?

Contributor: Let's approve to move fast. But we are accumulating tech debt by using different notations in the internal and external APIs. Please create a follow-up issue.

Contributor: Just an idea: we could store the full name in the dataset table as the dataset name. This would simplify everything, and we would no longer need to pass namespace and project names everywhere.

Contributor Author: We don't need to do that, as the dataset table already has a connection to the project table, which has a connection to the namespaces table, so it's normalized. Also, when a dataset is fetched from the table, the DatasetRecord object already holds a Project object and a full_name property that returns the namespace, project, and dataset name, so it's all ok.
We need an explicit namespace / project when we are fetching a dataset (as here) or saving a new one in the DB. The only question is whether we should split it into 3 arguments everywhere, as now, or have just a dataset name that must, by convention, have the namespace and project embedded in it, e.g. namespace.project.dataset. This is what we already have in the public API but not in the internal ones.

     ) -> DatasetRecord:
         from datachain.lib.listing import is_listing_dataset

-        project = project or self.metastore.default_project
+        namespace_name = namespace_name or self.metastore.default_namespace_name
+        project_name = project_name or self.metastore.default_project_name
Comment on lines +1104 to +1105

Contributor: Just wondering whether it makes sense for project_name to be set while namespace_name is empty (the default one), and vice versa (namespace_name set, but project_name left as the default).

I assume yes, it makes sense in some cases, but it may also lead to some weird behavior. Taking the full name (ns1.ns2.ds3, as @dmpetrov suggests below) would solve this issue (it would not be possible to set only the namespace name without setting the project name). Or maybe I am wrong and this is a nice feature to have. Just something to keep in mind and think about, maybe later.

Note: to be honest, I am still confused about namespace and project: which one goes first 😥

Contributor Author: It definitely makes sense to have just the project name set, as then the default namespace will be used. But you are right about the other direction: if someone sets only the namespace name, the default project will be used, which can be a little weird. I'm not 100% sure whether that's ok, TBH, but for now we allow all combinations.

It's namespace.project.dataset :D That's the hierarchy in the DB as well.
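A tiny sketch (with placeholder default names) of the fallback being discussed: each part falls back to its default independently, so all four combinations are accepted.

    def resolve(namespace_name=None, project_name=None):
        # "local" and "default" are assumed stand-ins for the metastore defaults.
        return (namespace_name or "local", project_name or "default")

    assert resolve() == ("local", "default")
    assert resolve(project_name="animals") == ("local", "animals")  # sensible
    assert resolve(namespace_name="dev") == ("dev", "default")      # the "weird" case
    assert resolve("dev", "animals") == ("dev", "animals")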


         if is_listing_dataset(name):
-            project = self.metastore.listing_project
+            namespace_name = self.metastore.system_namespace_name
+            project_name = self.metastore.listing_project_name

-        try:
-            return self.metastore.get_dataset(name, project.id if project else None)
-        except DatasetNotFoundError:
-            raise DatasetNotFoundError(
-                f"Dataset {name} not found in namespace {project.namespace.name}"
-                f" and project {project.name}"
-            ) from None
+        return self.metastore.get_dataset(
+            name, namespace_name=namespace_name, project_name=project_name
+        )

     def get_dataset_with_remote_fallback(
         self,
@@ -1113,8 +1131,11 @@ def get_dataset_with_remote_fallback(

         if self.metastore.is_local_dataset(namespace_name) or not update:
             try:
-                project = self.metastore.get_project(project_name, namespace_name)
-                ds = self.get_dataset(name, project)
+                ds = self.get_dataset(
+                    name,
+                    namespace_name=namespace_name,
+                    project_name=project_name,
+                )
                 if not version or ds.has_version(version):
                     return ds
             except (NamespaceNotFoundError, ProjectNotFoundError, DatasetNotFoundError):
@@ -1139,7 +1160,9 @@ def get_dataset_with_remote_fallback(
                 local_ds_version=version,
             )
             return self.get_dataset(
-                name, self.metastore.get_project(project_name, namespace_name)
+                name,
+                namespace_name=namespace_name,
+                project_name=project_name,
             )

         return self.get_remote_dataset(namespace_name, project_name, name)
@@ -1148,7 +1171,11 @@ def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
         """Returns dataset that contains version with specific uuid"""
         for dataset in self.ls_datasets():
             if dataset.has_version_with_uuid(uuid):
-                return self.get_dataset(dataset.name, dataset.project)
+                return self.get_dataset(
+                    dataset.name,
+                    namespace_name=dataset.project.namespace.name,
+                    project_name=dataset.project.name,
+                )
         raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")

     def get_remote_dataset(
@@ -1171,9 +1198,18 @@ def get_remote_dataset(
         return DatasetRecord.from_dict(dataset_info)

     def get_dataset_dependencies(
-        self, name: str, version: str, project: Optional[Project] = None, indirect=False
+        self,
+        name: str,
+        version: str,
+        namespace_name: Optional[str] = None,
+        project_name: Optional[str] = None,
+        indirect=False,
     ) -> list[Optional[DatasetDependency]]:
-        dataset = self.get_dataset(name, project)
+        dataset = self.get_dataset(
+            name,
+            namespace_name=namespace_name,
+            project_name=project_name,
+        )

         direct_dependencies = self.metastore.get_direct_dataset_dependencies(
             dataset, version
@@ -1187,10 +1223,13 @@ def get_dataset_dependencies(
                 # dependency has been removed
                 continue
             if d.is_dataset:
-                project = self.metastore.get_project(d.project, d.namespace)
                 # only datasets can have dependencies
                 d.dependencies = self.get_dataset_dependencies(
-                    d.name, d.version, project, indirect=indirect
+                    d.name,
+                    d.version,
+                    namespace_name=d.namespace,
+                    project_name=d.project,
+                    indirect=indirect,
                 )

         return direct_dependencies
@@ -1340,7 +1379,11 @@ def export_dataset_table(
         project: Optional[Project] = None,
         client_config=None,
     ) -> list[str]:
-        dataset = self.get_dataset(name, project)
+        dataset = self.get_dataset(
+            name,
+            namespace_name=project.namespace.name if project else None,
+            project_name=project.name if project else None,
+        )

         return self.warehouse.export_dataset_table(
             bucket_uri, dataset, version, client_config
@@ -1349,7 +1392,11 @@ def export_dataset_table(
     def dataset_table_export_file_names(
         self, name: str, version: str, project: Optional[Project] = None
     ) -> list[str]:
-        dataset = self.get_dataset(name, project)
+        dataset = self.get_dataset(
+            name,
+            namespace_name=project.namespace.name if project else None,
+            project_name=project.name if project else None,
+        )
         return self.warehouse.dataset_table_export_file_names(dataset, version)

     def remove_dataset(
@@ -1359,7 +1406,11 @@ def remove_dataset(
         version: Optional[str] = None,
         force: Optional[bool] = False,
     ):
-        dataset = self.get_dataset(name, project)
+        dataset = self.get_dataset(
+            name,
+            namespace_name=project.namespace.name if project else None,
+            project_name=project.name if project else None,
+        )
         if not version and not force:
             raise ValueError(f"Missing dataset version from input for dataset {name}")
         if version and not dataset.has_version(version):
@@ -1395,7 +1446,11 @@ def edit_dataset(
         if attrs is not None:
             update_data["attrs"] = attrs  # type: ignore[assignment]

-        dataset = self.get_dataset(name, project)
+        dataset = self.get_dataset(
+            name,
+            namespace_name=project.namespace.name if project else None,
+            project_name=project.name if project else None,
+        )
         return self.update_dataset(dataset, **update_data)

     def ls(
@@ -1549,7 +1604,9 @@ def _instantiate(ds_uri: str) -> None:
             )

         try:
-            local_dataset = self.get_dataset(local_ds_name, project=project)
+            local_dataset = self.get_dataset(
+                local_ds_name, namespace_name=namespace.name, project_name=project.name
+            )
             if local_dataset and local_dataset.has_version(local_ds_version):
                 raise DataChainError(
                     f"Local dataset {local_ds_uri} already exists with different uuid,"
5 changes: 3 additions & 2 deletions src/datachain/cli/commands/datasets.py
@@ -107,8 +107,9 @@ def list_datasets_local(catalog: "Catalog", name: Optional[str] = None):
 def list_datasets_local_versions(catalog: "Catalog", name: str):
     namespace_name, project_name, name = catalog.get_full_dataset_name(name)

-    project = catalog.metastore.get_project(project_name, namespace_name)
-    ds = catalog.get_dataset(name, project)
+    ds = catalog.get_dataset(
+        name, namespace_name=namespace_name, project_name=project_name
+    )
     for v in ds.versions:
         yield (name, v.version)

43 changes: 34 additions & 9 deletions src/datachain/data_storage/metastore.py
@@ -301,7 +301,13 @@ def list_datasets_by_prefix(
     """

     @abstractmethod
-    def get_dataset(self, name: str, project_id: Optional[int] = None) -> DatasetRecord:
+    def get_dataset(
+        self,
+        name: str,  # normal, not full dataset name
+        namespace_name: Optional[str] = None,
+        project_name: Optional[str] = None,
+        conn=None,
+    ) -> DatasetRecord:
         """Gets a single dataset by name."""

     @abstractmethod
@@ -912,11 +918,14 @@ def create_dataset(
         **kwargs,  # TODO registered = True / False
     ) -> DatasetRecord:
         """Creates new dataset."""
-        project_id = project_id or self.default_project.id
+        if not project_id:
+            project = self.default_project
+        else:
+            project = self.get_project_by_id(project_id)
Comment on lines +921 to +924

Contributor: nitpick: Logic for determining project_id in create_dataset may be redundant.

Since project.id is always used in the insert, fetching the full project object when project_id is already provided may be unnecessary. Consider simplifying to avoid the extra lookup.
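For context on why the full object is fetched at all: the resolved project is used twice in this method, in a pattern roughly like the self-contained sketch below. The Project/Namespace shapes are assumed from the diff, and PROJECTS is a hypothetical stand-in for the DB lookup:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Namespace:
        name: str

    @dataclass
    class Project:
        id: int
        name: str
        namespace: Namespace

    DEFAULT = Project(1, "default", Namespace("local"))
    PROJECTS = {1: DEFAULT, 2: Project(2, "animals", Namespace("dev"))}

    def create_dataset(name: str, project_id: Optional[int] = None) -> tuple:
        # Resolve the full Project once: its id feeds the insert, while its
        # namespace/project *names* feed the new name-based get_dataset call,
        # which is why keeping only project_id around is no longer enough.
        project = PROJECTS[project_id] if project_id else DEFAULT
        insert_values = {"name": name, "project_id": project.id}
        lookup_key = (name, project.namespace.name, project.name)
        return insert_values, lookup_key

    print(create_dataset("cats", 2))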


         query = self._datasets_insert().values(
             name=name,
-            project_id=project_id,
+            project_id=project.id,
             status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=datetime.now(timezone.utc),
@@ -935,7 +944,9 @@ def create_dataset_version(  # noqa: PLR0913
         query = query.on_conflict_do_nothing(index_elements=["project_id", "name"])
         self.db.execute(query)

-        return self.get_dataset(name, project_id)
+        return self.get_dataset(
+            name, namespace_name=project.namespace.name, project_name=project.name
+        )

     def create_dataset_version(  # noqa: PLR0913
         self,
@@ -992,7 +1003,12 @@ def create_dataset_version(  # noqa: PLR0913
         )
         self.db.execute(query, conn=conn)

-        return self.get_dataset(dataset.name, dataset.project.id, conn=conn)
+        return self.get_dataset(
+            dataset.name,
+            namespace_name=dataset.project.namespace.name,
+            project_name=dataset.project.name,
+            conn=conn,
+        )

     def remove_dataset(self, dataset: DatasetRecord) -> None:
         """Removes dataset."""
@@ -1216,21 +1232,30 @@ def list_datasets_by_prefix(
     def get_dataset(
         self,
         name: str,  # normal, not full dataset name
-        project_id: Optional[int] = None,
+        namespace_name: Optional[str] = None,
+        project_name: Optional[str] = None,
         conn=None,
     ) -> DatasetRecord:
         """
         Gets a single dataset in project by dataset name.
         """
-        project_id = project_id or self.default_project.id
+        namespace_name = namespace_name or self.default_namespace_name
+        project_name = project_name or self.default_project_name

         d = self._datasets
+        n = self._namespaces
+        p = self._projects
         query = self._base_dataset_query()
-        query = query.where(d.c.name == name, d.c.project_id == project_id)  # type: ignore [attr-defined]
+        query = query.where(
+            d.c.name == name,
+            n.c.name == namespace_name,
+            p.c.name == project_name,
+        )  # type: ignore [attr-defined]
         ds = self._parse_dataset(self.db.execute(query, conn=conn))
         if not ds:
             raise DatasetNotFoundError(
-                f"Dataset {name} not found in project with id {project_id}"
+                f"Dataset {name} not found in namespace {namespace_name}"
+                f" and project {project_name}"
             )
Comment on lines 1254 to 1259

Contributor: issue (code-quality): We've found these issues:


         return ds
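Roughly what the new name-based lookup resolves to, as a self-contained SQLAlchemy sketch. The schema here (table and column names, join keys) is assumed from the snippet, not taken from the codebase, and the actual _base_dataset_query may differ:

    import sqlalchemy as sa

    metadata = sa.MetaData()
    namespaces = sa.Table(
        "namespaces", metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.Text),
    )
    projects = sa.Table(
        "projects", metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("namespace_id", sa.ForeignKey("namespaces.id")),
        sa.Column("name", sa.Text),
    )
    datasets = sa.Table(
        "datasets", metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("project_id", sa.ForeignKey("projects.id")),
        sa.Column("name", sa.Text),
    )

    # Filter by dataset name plus namespace/project *names* via joins,
    # instead of a pre-resolved project_id.
    query = (
        sa.select(datasets)
        .select_from(
            datasets.join(projects, datasets.c.project_id == projects.c.id)
            .join(namespaces, projects.c.namespace_id == namespaces.c.id)
        )
        .where(
            datasets.c.name == "cats",
            namespaces.c.name == "dev",
            projects.c.name == "animals",
        )
    )
    print(query)  # emits the SELECT ... JOIN ... WHERE statement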