Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DEVX-300] Support PAT as arg #221

Merged
merged 5 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 36 additions & 20 deletions clarifai/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,31 @@ class App(Lister, BaseClient):
"""App is a class that provides access to Clarifai API endpoints related to App information."""

def __init__(self,
url: str = "",
app_id: str = "",
url: str = None,
app_id: str = None,
base_url: str = "https://api.clarifai.com",
pat: str = None,
**kwargs):
"""Initializes an App object.

Args:
url (str): The URL to initialize the app object.
app_id (str): The App ID for the App to interact with.
base_url (str): Base API url. Default "https://api.clarifai.com"
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
**kwargs: Additional keyword arguments to be passed to the App.
- name (str): The name of the app.
- description (str): The description of the app.
"""
if url != "" and app_id != "":
if url and app_id:
raise UserError("You can only specify one of url or app_id.")
if url != "":
if url:
user_id, app_id, _, _, _ = ClarifaiUrlHelper.split_clarifai_url(url)
kwargs = {'user_id': user_id}
self.kwargs = {**kwargs, 'id': app_id}
self.app_info = resources_pb2.App(**self.kwargs)
self.logger = get_logger(logger_level="INFO", name=__name__)
BaseClient.__init__(self, user_id=self.user_id, app_id=self.id, base=base_url)
BaseClient.__init__(self, user_id=self.user_id, app_id=self.id, base=base_url, pat=pat)
Lister.__init__(self)

def list_datasets(self, page_no: int = None,
Expand Down Expand Up @@ -83,7 +85,7 @@ def list_datasets(self, page_no: int = None,
for dataset_info in all_datasets_info:
if 'version' in list(dataset_info.keys()):
del dataset_info['version']['metrics']
yield Dataset(base_url=self.base, **dataset_info)
yield Dataset(base_url=self.base, pat=self.pat, **dataset_info)

def list_models(self,
filter_by: Dict[str, Any] = {},
Expand Down Expand Up @@ -124,7 +126,7 @@ def list_models(self,
if only_in_app:
if model_info['app_id'] != self.id:
continue
yield Model(base_url=self.base, **model_info)
yield Model(base_url=self.base, pat=self.pat, **model_info)

def list_workflows(self,
filter_by: Dict[str, Any] = {},
Expand Down Expand Up @@ -163,7 +165,7 @@ def list_workflows(self,
if only_in_app:
if workflow_info['app_id'] != self.id:
continue
yield Workflow(base_url=self.base, **workflow_info)
yield Workflow(base_url=self.base, pat=self.pat, **workflow_info)

def list_modules(self,
filter_by: Dict[str, Any] = {},
Expand Down Expand Up @@ -202,7 +204,7 @@ def list_modules(self,
if only_in_app:
if module_info['app_id'] != self.id:
continue
yield Module(base_url=self.base, **module_info)
yield Module(base_url=self.base, pat=self.pat, **module_info)

def list_installed_module_versions(self,
filter_by: Dict[str, Any] = {},
Expand Down Expand Up @@ -238,7 +240,10 @@ def list_installed_module_versions(self,
del imv_info['deploy_url']
del imv_info['installed_module_version_id'] # TODO: remove this after the backend fix
yield Module(
module_id=imv_info['module_version']['module_id'], base_url=self.base, **imv_info)
module_id=imv_info['module_version']['module_id'],
base_url=self.base,
pat=self.pat,
**imv_info)

def list_concepts(self, page_no: int = None,
per_page: int = None) -> Generator[Concept, None, None]:
Expand Down Expand Up @@ -303,7 +308,12 @@ def create_dataset(self, dataset_id: str, **kwargs) -> Dataset:
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
self.logger.info("\nDataset created\n%s", response.status)
kwargs.update({'app_id': self.id, 'user_id': self.user_id, 'base_url': self.base})
kwargs.update({
'app_id': self.id,
'user_id': self.user_id,
'base_url': self.base,
'pat': self.pat
})

return Dataset(dataset_id=dataset_id, **kwargs)

Expand Down Expand Up @@ -332,7 +342,8 @@ def create_model(self, model_id: str, **kwargs) -> Model:
'app_id': self.id,
'user_id': self.user_id,
'model_type_id': response.model.model_type_id,
'base_url': self.base
'base_url': self.base,
'pat': self.pat
})

return Model(model_id=model_id, **kwargs)
Expand Down Expand Up @@ -425,7 +436,7 @@ def create_workflow(self,
display_workflow_tree(dict_response["workflows"][0]["nodes"])
kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]][0],
"workflow")
kwargs.update({'base_url': self.base})
kwargs.update({'base_url': self.base, 'pat': self.pat})

return Workflow(**kwargs)

Expand Down Expand Up @@ -453,7 +464,12 @@ def create_module(self, module_id: str, description: str, **kwargs) -> Module:
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
self.logger.info("\nModule created\n%s", response.status)
kwargs.update({'app_id': self.id, 'user_id': self.user_id, 'base_url': self.base})
kwargs.update({
'app_id': self.id,
'user_id': self.user_id,
'base_url': self.base,
'pat': self.pat
})

return Module(module_id=module_id, **kwargs)

Expand All @@ -480,7 +496,7 @@ def dataset(self, dataset_id: str, **kwargs) -> Dataset:
kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
list(dict_response.keys())[1])
kwargs['version'] = response.dataset.version if response.dataset.version else None
kwargs.update({'base_url': self.base})
kwargs.update({'base_url': self.base, 'pat': self.pat})
return Dataset(**kwargs)

def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
Expand Down Expand Up @@ -516,7 +532,7 @@ def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
kwargs = self.process_response_keys(dict_response['model'], 'model')
kwargs[
'model_version'] = response.model.model_version if response.model.model_version else None
kwargs.update({'base_url': self.base})
kwargs.update({'base_url': self.base, 'pat': self.pat})

return Model(**kwargs)

Expand All @@ -542,7 +558,7 @@ def workflow(self, workflow_id: str, **kwargs) -> Workflow:
dict_response = MessageToDict(response, preserving_proto_field_name=True)
kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
list(dict_response.keys())[1])
kwargs.update({'base_url': self.base})
kwargs.update({'base_url': self.base, 'pat': self.pat})

return Workflow(**kwargs)

Expand All @@ -569,7 +585,7 @@ def module(self, module_id: str, module_version_id: str = "", **kwargs) -> Modul
raise Exception(response.status)
dict_response = MessageToDict(response, preserving_proto_field_name=True)
kwargs = self.process_response_keys(dict_response['module'], 'module')
kwargs.update({'base_url': self.base})
kwargs.update({'base_url': self.base, 'pat': self.pat})

return Module(**kwargs)

Expand All @@ -579,7 +595,7 @@ def inputs(self,):
Returns:
Inputs: An input object.
"""
return Inputs(self.user_id, self.id, base_url=self.base)
return Inputs(self.user_id, self.id, base_url=self.base, pat=self.pat)

def delete_dataset(self, dataset_id: str) -> None:
"""Deletes an dataset for the user.
Expand Down Expand Up @@ -670,7 +686,7 @@ def search(self, **kwargs) -> Model:
"""
user_id = kwargs.get("user_id", self.user_app_id.user_id)
app_id = kwargs.get("app_id", self.user_app_id.app_id)
return Search(user_id=user_id, app_id=app_id, base_url=self.base, **kwargs)
return Search(user_id=user_id, app_id=app_id, base_url=self.base, pat=self.pat, **kwargs)

def __getattr__(self, name):
return getattr(self.app_info, name)
Expand Down
12 changes: 6 additions & 6 deletions clarifai/client/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from datetime import datetime
from typing import Any, Callable

Expand All @@ -8,7 +7,8 @@

from clarifai.client.auth import create_stub
from clarifai.client.auth.helper import ClarifaiAuthHelper
from clarifai.errors import ApiError, UserError
from clarifai.errors import ApiError
from clarifai.utils.misc import get_from_dict_or_env


class BaseClient:
Expand All @@ -31,12 +31,12 @@ class BaseClient:
"""

def __init__(self, **kwargs):
pat = os.environ.get('CLARIFAI_PAT', "")
if pat == "":
raise UserError("CLARIFAI_PAT must be set as env vars")
self.auth_helper = ClarifaiAuthHelper(**kwargs, pat=pat, validate=False)
pat = get_from_dict_or_env(key="pat", env_key="CLARIFAI_PAT", **kwargs)
kwargs.update({'pat': pat})
self.auth_helper = ClarifaiAuthHelper(**kwargs, validate=False)
self.STUB = create_stub(self.auth_helper)
self.metadata = self.auth_helper.metadata
self.pat = self.auth_helper.pat
self.user_app_id = self.auth_helper.get_user_app_id_proto()
self.base = self.auth_helper.base

Expand Down
18 changes: 11 additions & 7 deletions clarifai/client/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,23 @@ class Dataset(Lister, BaseClient):
"""Dataset is a class that provides access to Clarifai API endpoints related to Dataset information."""

def __init__(self,
url: str = "",
dataset_id: str = "",
url: str = None,
dataset_id: str = None,
base_url: str = "https://api.clarifai.com",
pat: str = None,
**kwargs):
"""Initializes a Dataset object.

Args:
url (str): The URL to initialize the dataset object.
dataset_id (str): The Dataset ID within the App to interact with.
base_url (str): Base API url. Default "https://api.clarifai.com"
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
**kwargs: Additional keyword arguments to be passed to the Dataset.
"""
if url != "" and dataset_id != "":
if url and dataset_id:
raise UserError("You can only specify one of url or dataset_id.")
if url != "":
if url:
user_id, app_id, _, dataset_id, _ = ClarifaiUrlHelper.split_clarifai_url(url)
kwargs = {'user_id': user_id, 'app_id': app_id}
self.kwargs = {**kwargs, 'id': dataset_id}
Expand All @@ -61,7 +63,7 @@ def __init__(self,
self.task = None # Upload dataset type
self.input_object = Inputs(user_id=self.user_id, app_id=self.app_id)
self.logger = get_logger(logger_level="INFO")
BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url)
BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
Lister.__init__(self)

def create_version(self, **kwargs) -> 'Dataset':
Expand Down Expand Up @@ -94,7 +96,8 @@ def create_version(self, **kwargs) -> 'Dataset':
'app_id': self.app_id,
'user_id': self.user_id,
'version': response.dataset_versions[0],
'base_url': self.base
'base_url': self.base,
'pat': self.pat
})
return Dataset(**kwargs)

Expand Down Expand Up @@ -157,7 +160,8 @@ def list_versions(self, page_no: int = None,
'app_id': self.app_id,
'user_id': self.user_id,
'version': resources_pb2.DatasetVersion(**dataset_version_info),
'base_url': self.base
'base_url': self.base,
'pat': self.pat
}
yield Dataset(**kwargs)

Expand Down
Loading
Loading