Skip to content

Commit 34389fd

Browse files
feat: update quickstart loaded datasets (#4038)
# Description This PR removes loading the error analysis datasets from the `load_data.py` script used by the Docker quickstart image and adds loading the [Text Descriptives Metadata](https://huggingface.co/datasets/argilla/text-descriptives-metadata) dataset from Hugging Face. Closes #<issue_number> **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [ ] Test A - [ ] Test B **Checklist** - [ ] I added relevant documentation - [x] I followed the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --------- Co-authored-by: Francisco Aranda <[email protected]>
1 parent 4e7afdf commit 34389fd

File tree

2 files changed

+17
-112
lines changed

2 files changed

+17
-112
lines changed

docker/scripts/load_data.py

+3-101
Original file line number · Diff line number · Diff line change
@@ -162,104 +162,6 @@ def load_feedback_dataset_from_huggingface(repo_id: str, split: str = "train", s
162162

163163
dataset.push_to_argilla(name=repo_id.split("/")[-1])
164164

165-
@staticmethod
166-
def build_error_analysis_record(
167-
row: pd.Series, legacy: bool = False
168-
) -> Union[rg.FeedbackRecord, rg.TextClassificationRecord]:
169-
fields = {
170-
"user-message-1": row["HumanMessage1"],
171-
"llm-output": row["llm_output"]
172-
if not row["llm_output"].__contains__("```json")
173-
else row["llm_output"].replace("'", '"'),
174-
"ai-message": (f"```json\n{row['AIMessage']}\n```" if not legacy else row["AIMessage"]).replace("'", '"'),
175-
"function-message": (f"```json\n{row['FunctionMessage']}\n```" if not legacy else row["AIMessage"]).replace(
176-
"'", '"'
177-
),
178-
"system-message": "You are an AI assistant name ACME",
179-
"langsmith-url": f"https://smith.langchain.com/o/{row['parent_id']}",
180-
}
181-
metadata = {
182-
"correctness-langsmith": row["correctness_langsmith"],
183-
"model-name": row["model_name"],
184-
"temperature": row["temperature"],
185-
"max-tokens": int(row["max_tokens"]),
186-
"cpu-user": row["cpu_time_user"],
187-
"cpu-system": row["cpu_time_system"],
188-
"library-version": row["library_version"],
189-
}
190-
191-
if legacy:
192-
return rg.TextClassificationRecord(
193-
inputs=fields, metadata=metadata, vectors=eval(row["vectors"]), multi_label=True
194-
)
195-
return rg.FeedbackRecord(fields=fields, metadata=metadata)
196-
197-
@staticmethod
198-
def load_error_analysis(with_metadata_property_options: bool = True):
199-
print("Loading Error Analysis dataset as a `FeedbackDataset` (Alpha)")
200-
df = pd.read_csv("https://raw.githubusercontent.com/argilla-io/dataset_examples/main/synthetic_data_v2.csv")
201-
202-
fields = [
203-
rg.TextField(name="user-message-1", use_markdown=True),
204-
rg.TextField(name="llm-output", use_markdown=True),
205-
rg.TextField(name="ai-message", use_markdown=True, required=False),
206-
rg.TextField(name="function-message", use_markdown=True, required=False),
207-
rg.TextField(name="system-message", use_markdown=True, required=False),
208-
rg.TextField(name="langsmith-url", use_markdown=True, required=False),
209-
]
210-
211-
questions = [
212-
rg.MultiLabelQuestion(
213-
name="issue",
214-
title="Please categorize the record:",
215-
labels=["follow-up needed", "reviewed", "no-repro", "not-helpful", "empty-response", "critical"],
216-
),
217-
rg.TextQuestion(name="note", title="Leave a note to describe the issue:", required=False),
218-
]
219-
220-
dataset_name = "error-analysis-with-feedback"
221-
222-
if with_metadata_property_options:
223-
metadata = [
224-
rg.TermsMetadataProperty(
225-
name="correctness-langsmith", values=df.correctness_langsmith.unique().tolist()
226-
),
227-
rg.TermsMetadataProperty(name="model-name", values=df.model_name.unique().tolist()),
228-
rg.FloatMetadataProperty(name="temperature", min=df.temperature.min(), max=df.temperature.max()),
229-
rg.FloatMetadataProperty(name="cpu-user", min=df.cpu_time_user.min(), max=df.cpu_time_user.max()),
230-
rg.FloatMetadataProperty(name="cpu-system", min=df.cpu_time_system.min(), max=df.cpu_time_system.max()),
231-
rg.TermsMetadataProperty(name="library-version", values=df.library_version.unique().tolist()),
232-
]
233-
else:
234-
dataset_name += "-no-settings"
235-
236-
metadata = [
237-
rg.TermsMetadataProperty(name="correctness-langsmith"),
238-
rg.TermsMetadataProperty(name="model-name"),
239-
rg.FloatMetadataProperty(name="temperature"),
240-
rg.FloatMetadataProperty(name="cpu-user"),
241-
rg.FloatMetadataProperty(name="cpu-system"),
242-
rg.TermsMetadataProperty(name="library-version"),
243-
]
244-
245-
dataset = rg.FeedbackDataset(fields=fields, questions=questions, metadata_properties=metadata)
246-
dataset.add_records(records=[LoadDatasets.build_error_analysis_record(row) for _, row in df.iterrows()])
247-
dataset.push_to_argilla(name=dataset_name)
248-
249-
@staticmethod
250-
def load_error_analysis_textcat_version():
251-
print("Loading Error Analysis dataset as a `DatasetForTextClassification`")
252-
df = pd.read_csv(
253-
"https://raw.githubusercontent.com/argilla-io/dataset_examples/main/synthetic_data_v2_with_vectors.csv"
254-
)
255-
256-
labels = ["follow-up needed", "reviewed", "no-repro", "not-helpful", "empty-response", "critical"]
257-
settings = rg.TextClassificationSettings(label_schema=labels)
258-
rg.configure_dataset_settings(name="error-analysis-with-text-classification", settings=settings)
259-
260-
records = [LoadDatasets.build_error_analysis_record(row, legacy=True) for _, row in df.iterrows()]
261-
rg.log(name="error-analysis-with-text-classification", records=records, batch_size=25)
262-
263165

264166
if __name__ == "__main__":
265167
API_KEY = sys.argv[1]
@@ -274,9 +176,6 @@ def load_error_analysis_textcat_version():
274176
response = requests.get("http://0.0.0.0:6900")
275177
if response.status_code == 200:
276178
ld = LoadDatasets(API_KEY)
277-
ld.load_error_analysis(with_metadata_property_options=False)
278-
ld.load_error_analysis()
279-
ld.load_error_analysis_textcat_version()
280179
ld.load_feedback_dataset_from_huggingface(
281180
repo_id="argilla/databricks-dolly-15k-curated-en", split="train", samples=100
282181
)
@@ -296,6 +195,9 @@ def load_error_analysis_textcat_version():
296195
ld.load_feedback_dataset_from_huggingface(
297196
repo_id="argilla/oasst_response_comparison", split="train", samples=100
298197
)
198+
ld.load_feedback_dataset_from_huggingface(
199+
repo_id="argilla/text-descriptives-metadata", split="train", samples=100
200+
)
299201
except requests.exceptions.ConnectionError:
300202
pass
301203
except Exception as e:

src/argilla/client/feedback/integrations/huggingface/dataset.py

+14-11
Original file line number · Diff line number · Diff line change
@@ -300,6 +300,13 @@ def from_huggingface(cls: Type["FeedbackDataset"], repo_id: str, *args: Any, **k
300300
)
301301
with open(config_path, "r") as f:
302302
config = DatasetConfig.from_yaml(f.read())
303+
dataset = cls(
304+
fields=config.fields,
305+
questions=config.questions,
306+
guidelines=config.guidelines,
307+
metadata_properties=config.metadata_properties,
308+
allow_extra_metadata=config.allow_extra_metadata,
309+
)
303310
except EntryNotFoundError:
304311
# TODO(alvarobartt): here for backwards compatibility, last used in 1.12.0
305312
warnings.warn(
@@ -318,6 +325,7 @@ def from_huggingface(cls: Type["FeedbackDataset"], repo_id: str, *args: Any, **k
318325
)
319326
with open(config_path, "r") as f:
320327
config = DeprecatedDatasetConfig.from_json(f.read())
328+
dataset = cls(fields=config.fields, questions=config.questions, guidelines=config.guidelines)
321329
except Exception as e:
322330
raise FileNotFoundError(
323331
"Neither `argilla.yaml` nor `argilla.cfg` files were found in the"
@@ -340,7 +348,7 @@ def from_huggingface(cls: Type["FeedbackDataset"], repo_id: str, *args: Any, **k
340348
responses = {}
341349
suggestions = []
342350
user_without_id = False
343-
for question in config.questions:
351+
for question in dataset.questions:
344352
if hfds[index][question.name] is not None and len(hfds[index][question.name]) > 0:
345353
if (
346354
len(
@@ -414,20 +422,15 @@ def from_huggingface(cls: Type["FeedbackDataset"], repo_id: str, *args: Any, **k
414422

415423
records.append(
416424
FeedbackRecord(
417-
fields={field.name: hfds[index][field.name] for field in config.fields},
425+
fields={field.name: hfds[index][field.name] for field in dataset.fields},
418426
metadata=metadata or {},
419427
responses=list(responses.values()) or [],
420428
suggestions=[suggestion for suggestion in suggestions if suggestion["value"] is not None] or [],
421429
external_id=hfds[index]["external_id"],
422430
)
423431
)
424432
del hfds
425-
instance = cls(
426-
fields=config.fields,
427-
questions=config.questions,
428-
guidelines=config.guidelines,
429-
metadata_properties=config.metadata_properties,
430-
allow_extra_metadata=config.allow_extra_metadata,
431-
)
432-
instance.add_records(records)
433-
return instance
433+
434+
dataset.add_records(records)
435+
436+
return dataset

0 commit comments

Comments (0)