cleanup: always require org_id when fetching a dataset by tracking_id
skeptrunedev committed Dec 10, 2024
1 parent 856dbb2 commit 6137e19
Showing 14 changed files with 131 additions and 102 deletions.
4 changes: 2 additions & 2 deletions clients/ts-sdk/openapi.json
@@ -3938,9 +3938,9 @@
"operationId": "get_dataset_by_tracking_id",
"parameters": [
{
- "name": "TR-Dataset",
+ "name": "TR-Organization",
"in": "header",
- "description": "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid.",
+ "description": "The organization id to use for the request",
"required": true,
"schema": {
"type": "string",
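Concretely, callers of `GET /api/dataset/tracking_id/{tracking_id}` now send the organization id in a `TR-Organization` header instead of a `TR-Dataset` header. A minimal sketch of a raw call against the updated contract — the base URL, API key, and ids below are placeholders, and the `Authorization` API-key scheme is assumed rather than shown in this diff:

```ts
// Sketch: fetch a dataset by tracking id under the new header contract.
const TRIEVE_API_URL = "https://api.trieve.ai"; // assumed base URL
const organizationId = "00000000-0000-0000-0000-000000000000"; // placeholder
const trackingId = "my-dataset-tracking-id"; // placeholder

async function fetchDatasetByTrackingId(): Promise<unknown> {
  const response = await fetch(
    `${TRIEVE_API_URL}/api/dataset/tracking_id/${encodeURIComponent(trackingId)}`,
    {
      method: "GET",
      headers: {
        Authorization: "tr-********", // placeholder API key
        // Required by this change; the old TR-Dataset header is no longer used here.
        "TR-Organization": organizationId,
      },
    }
  );
  if (!response.ok) {
    throw new Error(`Failed to fetch dataset: ${response.status}`);
  }
  return response.json();
}
```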
2 changes: 1 addition & 1 deletion clients/ts-sdk/package.json
@@ -6,7 +6,7 @@
"files": [
"dist"
],
- "version": "0.0.40",
+ "version": "0.0.41",
"license": "MIT",
"scripts": {
"lint": "eslint 'src/**/*.ts'",
8 changes: 7 additions & 1 deletion clients/ts-sdk/src/functions/datasets/index.ts
@@ -219,11 +219,17 @@ export async function getDatasetByTrackingId(
trackingId: string,
signal?: AbortSignal
): Promise<Dataset> {
+ if (!this.organizationId) {
+ throw new Error(
+ "Organization ID is required to get a dataset by tracking ID"
+ );
+ }
+
return this.trieve.fetch(
"/api/dataset/tracking_id/{tracking_id}",
"get",
{
- datasetId: trackingId,
+ organizationId: this.organizationId,
trackingId,
},
signal
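On the SDK side, the method now fails fast when the client has no organization id configured. A minimal usage sketch, assuming `getDatasetByTrackingId` is exposed as a method on the SDK instance like the other functions in this module, and that the constructor accepts `organizationId` alongside `apiKey` (the exact constructor options are assumptions, not confirmed by this diff):

```ts
import { TrieveSDK } from "trieve-ts-sdk";

// Placeholder credentials; organizationId is now required for tracking-id lookups.
const trieve = new TrieveSDK({
  apiKey: "tr-********",
  organizationId: "00000000-0000-0000-0000-000000000000",
});

async function main() {
  // Throws "Organization ID is required to get a dataset by tracking ID"
  // if the client was constructed without an organizationId.
  const dataset = await trieve.getDatasetByTrackingId("my-dataset-tracking-id");
  console.log(dataset.id);
}

main().catch(console.error);
```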
4 changes: 2 additions & 2 deletions clients/ts-sdk/src/types.gen.ts
@@ -4149,9 +4149,9 @@ export type GetDatasetByTrackingIdData = {
*/
trackingId: string;
/**
- * The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid.
+ * The organization id to use for the request
*/
- trDataset: string;
+ trOrganization: string;
};

export type GetDatasetByTrackingIdResponse = (Dataset);
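For reference, a request-data literal conforming to the regenerated type; the values are placeholders and the import path matches the generated file in this diff:

```ts
import type { GetDatasetByTrackingIdData } from "./types.gen";

// trOrganization is serialized as the TR-Organization header.
const requestData: GetDatasetByTrackingIdData = {
  trackingId: "my-dataset-tracking-id",
  trOrganization: "00000000-0000-0000-0000-000000000000",
};
```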
23 changes: 9 additions & 14 deletions server/src/bin/delete-worker.rs
@@ -6,15 +6,15 @@ use std::sync::{
Arc,
};
use trieve_server::{
- data::models::{self, DatasetConfiguration, UnifiedId},
+ data::models::{self, DatasetConfiguration},
errors::ServiceError,
establish_connection, get_env,
operators::{
chunk_operator::bulk_delete_chunks_query,
clickhouse_operator::{ClickHouseEvent, EventQueue},
dataset_operator::{
clear_dataset_query, delete_dataset_by_id_query, get_dataset_by_id_query,
- get_deleted_dataset_by_unifiedid_query, ChunkDeleteMessage, DatasetDeleteMessage,
+ get_deleted_dataset_by_id_query, ChunkDeleteMessage, DatasetDeleteMessage,
DeleteMessage,
},
organization_operator::{
@@ -254,12 +254,10 @@ pub async fn delete_or_clear_dataset(
delete_worker_message: DatasetDeleteMessage,
event_queue: actix_web::web::Data<EventQueue>,
) -> Result<(), ServiceError> {
- let dataset = get_deleted_dataset_by_unifiedid_query(
- models::UnifiedId::TrieveUuid(delete_worker_message.dataset_id),
- web_pool.clone(),
- )
- .await
- .map_err(|err| ServiceError::BadRequest(format!("Failed to get dataset: {:?}", err)))?;
+ let dataset =
+ get_deleted_dataset_by_id_query(delete_worker_message.dataset_id, web_pool.clone())
+ .await
+ .map_err(|err| ServiceError::BadRequest(format!("Failed to get dataset: {:?}", err)))?;

let mut redis_connection = redis_pool.get().await.map_err(|err| {
ServiceError::BadRequest(format!("Failed to get redis connection: {:?}", err))
@@ -381,12 +379,9 @@ pub async fn bulk_delete_chunks(
"Bulk deleting chunks for dataset: {:?}",
chunk_delete_message.dataset_id
);
- let dataset = get_dataset_by_id_query(
- UnifiedId::TrieveUuid(chunk_delete_message.dataset_id),
- web_pool.clone(),
- )
- .await
- .map_err(|err| ServiceError::BadRequest(format!("Failed to get dataset: {:?}", err)))?;
+ let dataset = get_dataset_by_id_query(chunk_delete_message.dataset_id, web_pool.clone())
+ .await
+ .map_err(|err| ServiceError::BadRequest(format!("Failed to get dataset: {:?}", err)))?;
let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration);

bulk_delete_chunks_query(
7 changes: 2 additions & 5 deletions server/src/bin/grupdate-worker.rs
@@ -175,11 +175,8 @@ async fn grupdate_worker(
}
};

- let dataset_result = get_dataset_by_id_query(
- models::UnifiedId::TrieveUuid(group_update_msg.dataset_id),
- web_pool.clone(),
- )
- .await;
+ let dataset_result =
+ get_dataset_by_id_query(group_update_msg.dataset_id, web_pool.clone()).await;
let dataset = match dataset_result {
Ok(dataset) => dataset,
Err(err) => {
8 changes: 3 additions & 5 deletions server/src/bin/ingestion-worker.rs
@@ -9,7 +9,7 @@ use std::collections::HashMap;
use std::sync::{atomic::AtomicBool, atomic::Ordering, Arc};
use trieve_server::data::models::{
self, ChunkBoost, ChunkData, ChunkGroup, ChunkMetadata, DatasetConfiguration, QdrantPayload,
- UnifiedId, WorkerEvent,
+ WorkerEvent,
};
use trieve_server::errors::ServiceError;
use trieve_server::handlers::chunk_handler::{
@@ -222,12 +222,10 @@ async fn ingestion_worker(
let dataset_result: Result<models::Dataset, ServiceError> = match ingestion_message.clone()
{
IngestionMessage::Update(payload) => {
- get_dataset_by_id_query(UnifiedId::TrieveUuid(payload.dataset_id), web_pool.clone())
- .await
+ get_dataset_by_id_query(payload.dataset_id, web_pool.clone()).await
}
IngestionMessage::BulkUpload(payload) => {
- get_dataset_by_id_query(UnifiedId::TrieveUuid(payload.dataset_id), web_pool.clone())
- .await
+ get_dataset_by_id_query(payload.dataset_id, web_pool.clone()).await
}
};
let dataset = match dataset_result {
43 changes: 27 additions & 16 deletions server/src/handlers/dataset_handler.rs
@@ -3,7 +3,7 @@ use crate::{
data::models::{
CrawlOptions, Dataset, DatasetAndOrgWithSubAndPlan, DatasetConfiguration,
DatasetConfigurationDTO, DatasetDTO, OrganizationWithSubAndPlan, Pool, RedisPool,
- StripePlan, UnifiedId,
+ StripePlan,
},
errors::ServiceError,
middleware::auth_middleware::{verify_admin, verify_owner},
@@ -14,8 +14,9 @@ use crate::{
},
dataset_operator::{
clear_dataset_by_dataset_id_query, create_dataset_query, create_datasets_query,
- get_dataset_by_id_query, get_dataset_usage_query, get_datasets_by_organization_id,
- get_tags_in_dataset_query, soft_delete_dataset_by_id_query, update_dataset_query,
+ get_dataset_by_id_query, get_dataset_by_tracking_id_query, get_dataset_usage_query,
+ get_datasets_by_organization_id, get_tags_in_dataset_query,
+ soft_delete_dataset_by_id_query, update_dataset_query,
},
dittofeed_operator::{
send_ditto_event, DittoDatasetCreated, DittoTrackProperties, DittoTrackRequest,
@@ -276,11 +277,17 @@ pub async fn update_dataset(
pool: web::Data<Pool>,
redis_pool: web::Data<RedisPool>,
user: OwnerOnly,
+ org_with_plan_and_sub: OrganizationWithSubAndPlan,
) -> Result<HttpResponse, ServiceError> {
let curr_dataset = if let Some(dataset_id) = data.dataset_id {
- get_dataset_by_id_query(UnifiedId::TrieveUuid(dataset_id), pool.clone()).await?
+ get_dataset_by_id_query(dataset_id, pool.clone()).await?
} else if let Some(tracking_id) = &data.tracking_id {
- get_dataset_by_id_query(UnifiedId::TrackingId(tracking_id.clone()), pool.clone()).await?
+ get_dataset_by_tracking_id_query(
+ tracking_id.clone(),
+ org_with_plan_and_sub.organization.id,
+ pool.clone(),
+ )
+ .await?
} else {
return Err(ServiceError::BadRequest(
"You must provide a dataset_id or tracking_id to update a dataset".to_string(),
@@ -363,8 +370,7 @@ pub async fn get_dataset_crawl_options(
dataset_id: web::Path<uuid::Uuid>,
user: AdminOnly,
) -> Result<HttpResponse, ServiceError> {
- let d = get_dataset_by_id_query(UnifiedId::TrieveUuid(dataset_id.into_inner()), pool.clone())
- .await?;
+ let d = get_dataset_by_id_query(dataset_id.into_inner(), pool.clone()).await?;
let crawl_req = get_crawl_request_by_dataset_id_query(d.id, pool).await?;

if !verify_admin(&user, &d.organization_id) {
@@ -491,9 +497,11 @@ pub async fn delete_dataset_by_tracking_id(
pool: web::Data<Pool>,
redis_pool: web::Data<RedisPool>,
user: OwnerOnly,
+ org_with_plan_and_sub: OrganizationWithSubAndPlan,
) -> Result<HttpResponse, ServiceError> {
- let dataset = get_dataset_by_id_query(
- UnifiedId::TrackingId(tracking_id.into_inner()),
+ let dataset = get_dataset_by_tracking_id_query(
+ tracking_id.into_inner(),
+ org_with_plan_and_sub.organization.id,
pool.clone(),
)
.await?;
@@ -535,8 +543,7 @@ pub async fn get_dataset(
dataset_id: web::Path<uuid::Uuid>,
user: AdminOnly,
) -> Result<HttpResponse, ServiceError> {
- let mut dataset =
- get_dataset_by_id_query(UnifiedId::TrieveUuid(dataset_id.into_inner()), pool).await?;
+ let mut dataset = get_dataset_by_id_query(dataset_id.into_inner(), pool).await?;

if !verify_admin(&user, &dataset.organization_id) {
return Err(ServiceError::Forbidden);
@@ -596,7 +603,7 @@ pub async fn get_usage_by_dataset_id(
(status = 404, description = "Dataset not found", body = ErrorResponseBody)
),
params(
- ("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."),
+ ("TR-Organization" = uuid::Uuid, Header, description = "The organization id to use for the request"),
("tracking_id" = String, Path, description = "The tracking id of the dataset you want to retrieve."),
),
security(
@@ -607,11 +614,15 @@ pub async fn get_dataset_by_tracking_id(
tracking_id: web::Path<String>,
pool: web::Data<Pool>,
user: AdminOnly,
+ org_with_plan_and_sub: OrganizationWithSubAndPlan,
) -> Result<HttpResponse, ServiceError> {
- let mut dataset =
- get_dataset_by_id_query(UnifiedId::TrackingId(tracking_id.into_inner()), pool)
- .await
- .map_err(|e| ServiceError::InternalServerError(e.to_string()))?;
+ let mut dataset = get_dataset_by_tracking_id_query(
+ tracking_id.into_inner(),
+ org_with_plan_and_sub.organization.id,
+ pool,
+ )
+ .await
+ .map_err(|e| ServiceError::InternalServerError(e.to_string()))?;

if !verify_admin(&user, &dataset.organization_id) {
return Err(ServiceError::Forbidden);
4 changes: 2 additions & 2 deletions server/src/handlers/page_handler.rs
@@ -4,7 +4,7 @@ use super::{
};
use crate::data::models::Templates;
use crate::{
- data::models::{DatasetConfiguration, Pool, SearchMethod, SortOptions, TypoOptions, UnifiedId},
+ data::models::{DatasetConfiguration, Pool, SearchMethod, SortOptions, TypoOptions},
errors::ServiceError,
get_env,
operators::dataset_operator::get_dataset_by_id_query,
@@ -229,7 +229,7 @@ pub async fn public_page(
) -> Result<HttpResponse, ServiceError> {
let dataset_id = dataset_id.into_inner();

- let dataset = get_dataset_by_id_query(UnifiedId::TrieveUuid(dataset_id), pool).await?;
+ let dataset = get_dataset_by_id_query(dataset_id, pool).await?;

let config = DatasetConfiguration::from_json(dataset.server_configuration);
4 changes: 2 additions & 2 deletions server/src/middleware/auth_middleware.rs
@@ -422,8 +422,8 @@ where
pub fn get_role_for_org(user: &SlimUser, org_id: &uuid::Uuid) -> Option<UserRole> {
user.user_orgs
.iter()
- .find(|org_conn| org_conn.organization_id == *org_id)
- .map(|org_conn| UserRole::from(org_conn.role))
+ .find(|user_org| user_org.organization_id == *org_id)
+ .map(|user_org| UserRole::from(user_org.role))
}

pub fn verify_owner(user: &OwnerOnly, org_id: &uuid::Uuid) -> bool {
94 changes: 61 additions & 33 deletions server/src/operators/dataset_operator.rs
@@ -106,7 +106,7 @@ pub async fn create_datasets_query(
}

pub async fn get_dataset_by_id_query(
- id: UnifiedId,
+ id: uuid::Uuid,
pool: web::Data<Pool>,
) -> Result<Dataset, ServiceError> {
use crate::data::schema::datasets::dsl as datasets_columns;
@@ -115,28 +115,20 @@
.await
.map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?;

- let dataset = match id {
- UnifiedId::TrieveUuid(id) => datasets_columns::datasets
- .filter(datasets_columns::id.eq(id))
- .filter(datasets_columns::deleted.eq(0))
- .select(Dataset::as_select())
- .first(&mut conn)
- .await
- .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?,
- UnifiedId::TrackingId(id) => datasets_columns::datasets
- .filter(datasets_columns::tracking_id.eq(id))
- .filter(datasets_columns::deleted.eq(0))
- .select(Dataset::as_select())
- .first(&mut conn)
- .await
- .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?,
- };
+ let dataset = datasets_columns::datasets
+ .filter(datasets_columns::id.eq(id))
+ .filter(datasets_columns::deleted.eq(0))
+ .select(Dataset::as_select())
+ .first(&mut conn)
+ .await
+ .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?;

Ok(dataset)
}

- pub async fn get_deleted_dataset_by_unifiedid_query(
- id: UnifiedId,
+ pub async fn get_dataset_by_tracking_id_query(
+ tracking_id: String,
+ org_id: uuid::Uuid,
pool: web::Data<Pool>,
) -> Result<Dataset, ServiceError> {
use crate::data::schema::datasets::dsl as datasets_columns;
@@ -145,20 +137,56 @@
.await
.map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?;

- let dataset = match id {
- UnifiedId::TrieveUuid(id) => datasets_columns::datasets
- .filter(datasets_columns::id.eq(id))
- .select(Dataset::as_select())
- .first(&mut conn)
- .await
- .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?,
- UnifiedId::TrackingId(id) => datasets_columns::datasets
- .filter(datasets_columns::tracking_id.eq(id))
- .select(Dataset::as_select())
- .first(&mut conn)
- .await
- .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?,
- };
+ let dataset = datasets_columns::datasets
+ .filter(datasets_columns::tracking_id.eq(tracking_id))
+ .filter(datasets_columns::organization_id.eq(org_id))
+ .filter(datasets_columns::deleted.eq(0))
+ .select(Dataset::as_select())
+ .first(&mut conn)
+ .await
+ .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?;

Ok(dataset)
}

+ pub async fn get_deleted_dataset_by_id_query(
+ id: uuid::Uuid,
+ pool: web::Data<Pool>,
+ ) -> Result<Dataset, ServiceError> {
+ use crate::data::schema::datasets::dsl as datasets_columns;
+ let mut conn = pool
+ .get()
+ .await
+ .map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?;
+
+ let dataset = datasets_columns::datasets
+ .filter(datasets_columns::id.eq(id))
+ .select(Dataset::as_select())
+ .first(&mut conn)
+ .await
+ .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?;
+
+ Ok(dataset)
+ }
+
+ pub async fn get_deleted_dataset_by_tracking_id_query(
+ tracking_id: String,
+ org_id: uuid::Uuid,
+ pool: web::Data<Pool>,
+ ) -> Result<Dataset, ServiceError> {
+ use crate::data::schema::datasets::dsl as datasets_columns;
+ let mut conn = pool
+ .get()
+ .await
+ .map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?;
+
+ let dataset = datasets_columns::datasets
+ .filter(datasets_columns::tracking_id.eq(tracking_id))
+ .filter(datasets_columns::organization_id.eq(org_id))
+ .select(Dataset::as_select())
+ .first(&mut conn)
+ .await
+ .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?;
+
+ Ok(dataset)
+ }