Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create dataset #2074

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
30 changes: 30 additions & 0 deletions sdk/python/ragflow/modules/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class Base(object):
def __init__(self, rag, res_dict):
self.rag = rag
for k, v in res_dict.items():
if isinstance(v, dict):
self.__dict__[k] = Base(rag, v)
else:
self.__dict__[k] = v

def to_json(self):
pr = {}
for name in dir(self):
value = getattr(self, name)
if not name.startswith('__') and not callable(value) and name != "rag":
if isinstance(value, Base):
pr[name] = value.to_json()
else:
pr[name] = value
return pr


def post(self, path, param):
res = self.rag.post(path,param)
return res

def get(self, path, params=''):
res = self.rag.get(path,params)
return res


33 changes: 33 additions & 0 deletions sdk/python/ragflow/modules/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .base import Base


class DataSet(Base):
class ParseConfig(Base):
def __init__(self, rag, res_dict):
self.chunk_token_count = 128
self.layout_recognize = True
self.delimiter = '\n!?。;!?'
self.task_page_size = 12
super().__init__(rag, res_dict)

def __init__(self, rag, res_dict):
self.id = ""
self.name = ""
self.avatar = ""
self.tenant_id = None
self.description = ""
self.language = "English"
self.embedding_model = ""
self.permission = "me"
self.document_count = 0
self.chunk_count = 0
self.parse_method = 0
self.parser_config = None
super().__init__(rag, res_dict)

def delete(self):
try:
self.post("/rm", {"kb_id": self.id})
return True
except Exception:
return False
61 changes: 36 additions & 25 deletions sdk/python/ragflow/ragflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import requests

from api.db.services.document_service import DocumentService
from api.settings import RetCode
from .modules.dataset import DataSet


class RAGFlow:
def __init__(self, user_key, base_url, version='v1'):
"""
api_url: http://<host_address>/api/v1
dataset_url: http://<host_address>/api/v1/dataset
document_url: http://<host_address>/api/v1/dataset/{dataset_id}/documents
api_url: http://<host_address>/v1
dataset_url: http://<host_address>/v1/kb
document_url: http://<host_address>/v1/dataset/{dataset_id}/documents
"""
self.user_key = user_key
self.api_url = f"{base_url}/api/{version}"
self.dataset_url = f"{self.api_url}/dataset"
self.api_url = f"{base_url}/{version}"
self.dataset_url = f"{self.api_url}/kb"
self.authorization_header = {"Authorization": "{}".format(self.user_key)}
self.base_url = base_url

def post(self, path, param):
res = requests.post(url=self.dataset_url + path, json=param, headers=self.authorization_header)
return res

def get(self, path, params=''):
res = requests.get(self.dataset_url + path, params=params, headers=self.authorization_header)
return res

def create_dataset(self, dataset_name):
"""
name: dataset name
"""
res = requests.post(url=self.dataset_url, json={"name": dataset_name}, headers=self.authorization_header)
result_dict = json.loads(res.text)
return result_dict

res_create = self.post("/create", {"name": dataset_name})
res_create_data = res_create.json()['data']
res_detail = self.get("/detail", {"kb_id": res_create_data["kb_id"]})
res_detail_data = res_detail.json()['data']
result = {}
result['id'] = res_detail_data['id']
result['name'] = res_detail_data['name']
result['avatar'] = res_detail_data['avatar']
result['description'] = res_detail_data['description']
result['language'] = res_detail_data['language']
result['embedding_model'] = res_detail_data['embd_id']
result['permission'] = res_detail_data['permission']
result['document_count'] = res_detail_data['doc_num']
result['chunk_count'] = res_detail_data['chunk_num']
result['parser_config'] = res_detail_data['parser_config']
dataset = DataSet(self, result)
return dataset

"""
def delete_dataset(self, dataset_name):
dataset_id = self.find_dataset_id_by_name(dataset_name)

Expand All @@ -55,16 +76,6 @@ def find_dataset_id_by_name(self, dataset_name):
return dataset["id"]
return None

def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
params = {
"offset": offset,
"count": count,
"orderby": orderby,
"desc": desc
}
response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
return response.json()

def get_dataset(self, dataset_name):
dataset_id = self.find_dataset_id_by_name(dataset_name)
endpoint = f"{self.dataset_url}/{dataset_id}"
Expand All @@ -78,7 +89,7 @@ def update_dataset(self, dataset_name, **params):
response = requests.put(endpoint, json=params, headers=self.authorization_header)
return response.json()

# ------------------------------- CONTENT MANAGEMENT -----------------------------------------------------
# ------------------------------- CONTENT MANAGEMENT -----------------------------------------------------

# ----------------------------upload local files-----------------------------------------------------
def upload_local_file(self, dataset_id, file_paths):
Expand Down Expand Up @@ -186,4 +197,4 @@ def show_parsing_status(self, dataset_id, document_id):
# ----------------------------get a specific chunk-----------------------------------------------------

# ----------------------------retrieval test-----------------------------------------------------

"""
4 changes: 3 additions & 1 deletion sdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@
import setuptools

if __name__ == "__main__":
setuptools.setup(packages=['ragflow'])
setuptools.setup(name='ragflow',
version="0.1",
packages=setuptools.find_packages())

2 changes: 1 addition & 1 deletion sdk/python/test/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@


API_KEY = 'IjJkOGQ4ZDE2MzkyMjExZWZhYTk0MzA0M2Q3ZWU1MzdlIg.ZoUfug.RmqcYyCrlAnLtkzk6bYXiXN3eEY'
API_KEY = 'IjUxNGM0MmM4NWY5MzExZWY5MDhhMDI0MmFjMTIwMDA2Ig.ZsWebA.mV1NKdSPPllgowiH-7vz36tMWyI'
HOST_ADDRESS = 'http://127.0.0.1:9380'
23 changes: 23 additions & 0 deletions sdk/python/test/t_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from ragflow import RAGFlow

from common import API_KEY, HOST_ADDRESS
from test_sdkbase import TestSdk


class TestDataset(TestSdk):
def test_create_dataset_with_success(self):
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("God")
assert ds is not None, "The dataset creation failed, returned None."
assert ds.name == "God", "Dataset name does not match."

def test_delete_one_file(self):
"""
Test deleting one file with success.
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("ABC")
assert ds is not None, "Failed to create dataset"
assert ds.name == "ABC", "Dataset name mismatch"
delete_result = ds.delete()
assert delete_result is True, "Failed to delete dataset"