From 4336aa210724b259fbc87fb897a22fd1b8e94400 Mon Sep 17 00:00:00 2001 From: hallacy Date: Sun, 12 Dec 2021 18:46:24 -0500 Subject: [PATCH 1/6] Make embeddings_utils be importable (#104) * Make embeddings_utils be importable * Small tweaks to dicts for typing --- examples/embeddings/Classification.ipynb | 2 +- examples/embeddings/Code_search.ipynb | 4 ++-- examples/embeddings/Obtain_dataset.ipynb | 2 +- .../Semantic_text_search_using_embeddings.ipynb | 2 +- .../embeddings/User_and_product_embeddings.ipynb | 2 +- examples/embeddings/Zero-shot_classification.ipynb | 2 +- .../utils.py => openai/embeddings_utils.py | 13 +++++-------- openai/version.py | 2 +- 8 files changed, 13 insertions(+), 16 deletions(-) rename examples/embeddings/utils.py => openai/embeddings_utils.py (87%) diff --git a/examples/embeddings/Classification.ipynb b/examples/embeddings/Classification.ipynb index 482ba85910..54ad13fe89 100644 --- a/examples/embeddings/Classification.ipynb +++ b/examples/embeddings/Classification.ipynb @@ -90,7 +90,7 @@ } ], "source": [ - "from utils import plot_multiclass_precision_recall\n", + "from openai.embeddings_utils import plot_multiclass_precision_recall\n", "\n", "plot_multiclass_precision_recall(probas, y_test, [1,2,3,4,5], clf)" ] diff --git a/examples/embeddings/Code_search.ipynb b/examples/embeddings/Code_search.ipynb index 14cbf81777..d440161493 100644 --- a/examples/embeddings/Code_search.ipynb +++ b/examples/embeddings/Code_search.ipynb @@ -185,7 +185,7 @@ } ], "source": [ - "from utils import get_embedding\n", + "from openai.embeddings_utils import get_embedding\n", "\n", "df = pd.DataFrame(all_funcs)\n", "df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, engine='babbage-code-search-code'))\n", @@ -231,7 +231,7 @@ } ], "source": [ - "from utils import cosine_similarity\n", + "from openai.embeddings_utils import cosine_similarity\n", "\n", "def search_functions(df, code_query, n=3, pprint=True, n_lines=7):\n", " embedding = get_embedding(code_query, engine='babbage-code-search-text')\n", diff --git a/examples/embeddings/Obtain_dataset.ipynb b/examples/embeddings/Obtain_dataset.ipynb index 76bb7b8427..61c5775c46 100644 --- a/examples/embeddings/Obtain_dataset.ipynb +++ b/examples/embeddings/Obtain_dataset.ipynb @@ -156,7 +156,7 @@ "metadata": {}, "outputs": [], "source": [ - "from utils import get_embedding\n", + "from openai.embeddings_utils import get_embedding\n", "\n", "# This will take just under 10 minutes\n", "df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, engine='babbage-similarity'))\n", diff --git a/examples/embeddings/Semantic_text_search_using_embeddings.ipynb b/examples/embeddings/Semantic_text_search_using_embeddings.ipynb index e83d4db5e1..58a41439c4 100644 --- a/examples/embeddings/Semantic_text_search_using_embeddings.ipynb +++ b/examples/embeddings/Semantic_text_search_using_embeddings.ipynb @@ -49,7 +49,7 @@ } ], "source": [ - "from utils import get_embedding, cosine_similarity\n", + "from openai.embeddings_utils import get_embedding, cosine_similarity\n", "\n", "# search through the reviews for a specific product\n", "def search_reviews(df, product_description, n=3, pprint=True):\n", diff --git a/examples/embeddings/User_and_product_embeddings.ipynb b/examples/embeddings/User_and_product_embeddings.ipynb index ca74d6bdc6..ea0ef2f81c 100644 --- a/examples/embeddings/User_and_product_embeddings.ipynb +++ b/examples/embeddings/User_and_product_embeddings.ipynb @@ -70,7 +70,7 @@ "metadata": {}, "outputs": [], "source": [ 
- "from utils import cosine_similarity\n", + "from openai.embeddings_utils import cosine_similarity\n", "\n", "# evaluate embeddings as recommendations on X_test\n", "def evaluate_single_match(row):\n", diff --git a/examples/embeddings/Zero-shot_classification.ipynb b/examples/embeddings/Zero-shot_classification.ipynb index 95789287a6..d4fc1b38ff 100644 --- a/examples/embeddings/Zero-shot_classification.ipynb +++ b/examples/embeddings/Zero-shot_classification.ipynb @@ -78,7 +78,7 @@ } ], "source": [ - "from utils import cosine_similarity, get_embedding\n", + "from openai.embeddings_utils import cosine_similarity, get_embedding\n", "from sklearn.metrics import PrecisionRecallDisplay\n", "\n", "def evaluate_emeddings_approach(\n", diff --git a/examples/embeddings/utils.py b/openai/embeddings_utils.py similarity index 87% rename from examples/embeddings/utils.py rename to openai/embeddings_utils.py index f7877147fd..caabd86396 100644 --- a/examples/embeddings/utils.py +++ b/openai/embeddings_utils.py @@ -43,16 +43,14 @@ def plot_multiclass_precision_recall( average_precision[i] = average_precision_score(y_true[:, i], y_score[:, i]) # A "micro-average": quantifying score on all classes jointly - precision["micro"], recall["micro"], _ = precision_recall_curve( + precision_micro, recall_micro, _ = precision_recall_curve( y_true.ravel(), y_score.ravel() ) - average_precision["micro"] = average_precision_score( - y_true, y_score, average="micro" - ) + average_precision_micro = average_precision_score(y_true, y_score, average="micro") print( str(classifier_name) + " - Average precision score over all classes: {0:0.2f}".format( - average_precision["micro"] + average_precision_micro ) ) @@ -69,11 +67,10 @@ def plot_multiclass_precision_recall( lines.append(l) labels.append("iso-f1 curves") - (l,) = plt.plot(recall["micro"], precision["micro"], color="gold", lw=2) + (l,) = plt.plot(recall_micro, precision_micro, color="gold", lw=2) lines.append(l) labels.append( - "average Precision-recall (auprc = {0:0.2f})" - "".format(average_precision["micro"]) + "average Precision-recall (auprc = {0:0.2f})" "".format(average_precision_micro) ) for i in range(n_classes): diff --git a/openai/version.py b/openai/version.py index 46898a964b..d0dcdac6c8 100644 --- a/openai/version.py +++ b/openai/version.py @@ -1 +1 @@ -VERSION = "0.11.3" +VERSION = "0.11.4" From 6f30c2083532fcfc2bb5d9b440b05b4d3c91d528 Mon Sep 17 00:00:00 2001 From: kennyhsu5 <1762087+kennyhsu5@users.noreply.github.com> Date: Fri, 10 Dec 2021 10:46:14 -0800 Subject: [PATCH 2/6] Remove default api_prefix and move v1 prefix to default api_base (#95) * Remove default api_prefix and move v1 prefix to default api_base * Run black and isort --- openai/__init__.py | 2 +- openai/api_resources/abstract/api_resource.py | 6 ++++-- openai/api_resources/abstract/engine_api_resource.py | 4 ++-- openai/api_resources/answer.py | 8 +++----- openai/api_resources/classification.py | 8 +++----- openai/api_resources/search.py | 7 ++----- 6 files changed, 15 insertions(+), 20 deletions(-) diff --git a/openai/__init__.py b/openai/__init__.py index f9d601bcf8..1574cc0b38 100644 --- a/openai/__init__.py +++ b/openai/__init__.py @@ -25,7 +25,7 @@ api_key_path: Optional[str] = os.environ.get("OPENAI_API_KEY_PATH") organization = os.environ.get("OPENAI_ORGANIZATION") -api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com") +api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") api_version = None verify_ssl_certs = True # No effect. 
Certificates are always verified. proxy = None diff --git a/openai/api_resources/abstract/api_resource.py b/openai/api_resources/abstract/api_resource.py index 3a27c66585..40350dfe4a 100644 --- a/openai/api_resources/abstract/api_resource.py +++ b/openai/api_resources/abstract/api_resource.py @@ -5,7 +5,7 @@ class APIResource(OpenAIObject): - api_prefix = "v1" + api_prefix = "" @classmethod def retrieve(cls, id, api_key=None, request_id=None, **params): @@ -28,7 +28,9 @@ def class_url(cls): # Namespaces are separated in object names with periods (.) and in URLs # with forward slashes (/), so replace the former with the latter. base = cls.OBJECT_NAME.replace(".", "/") # type: ignore - return "/%s/%ss" % (cls.api_prefix, base) + if cls.api_prefix: + return "/%s/%ss" % (cls.api_prefix, base) + return "/%ss" % (base) def instance_url(self): id = self.get("id") diff --git a/openai/api_resources/abstract/engine_api_resource.py b/openai/api_resources/abstract/engine_api_resource.py index e613e53814..aab1ae8663 100644 --- a/openai/api_resources/abstract/engine_api_resource.py +++ b/openai/api_resources/abstract/engine_api_resource.py @@ -22,10 +22,10 @@ def class_url(cls, engine: Optional[str] = None): # with forward slashes (/), so replace the former with the latter. base = cls.OBJECT_NAME.replace(".", "/") # type: ignore if engine is None: - return "/%s/%ss" % (cls.api_prefix, base) + return "/%ss" % (base) extn = quote_plus(engine) - return "/%s/engines/%s/%ss" % (cls.api_prefix, extn, base) + return "/engines/%s/%ss" % (extn, base) @classmethod def create( diff --git a/openai/api_resources/answer.py b/openai/api_resources/answer.py index 8dd3e84d23..33de3cb7e9 100644 --- a/openai/api_resources/answer.py +++ b/openai/api_resources/answer.py @@ -2,13 +2,11 @@ class Answer(OpenAIObject): - api_prefix = "v1" - @classmethod - def get_url(self, base): - return "/%s/%s" % (self.api_prefix, base) + def get_url(self): + return "/answers" @classmethod def create(cls, **params): instance = cls() - return instance.request("post", cls.get_url("answers"), params) + return instance.request("post", cls.get_url(), params) diff --git a/openai/api_resources/classification.py b/openai/api_resources/classification.py index b659164e5a..6423c6946a 100644 --- a/openai/api_resources/classification.py +++ b/openai/api_resources/classification.py @@ -2,13 +2,11 @@ class Classification(OpenAIObject): - api_prefix = "v1" - @classmethod - def get_url(self, base): - return "/%s/%s" % (self.api_prefix, base) + def get_url(self): + return "/classifications" @classmethod def create(cls, **params): instance = cls() - return instance.request("post", cls.get_url("classifications"), params) + return instance.request("post", cls.get_url(), params) diff --git a/openai/api_resources/search.py b/openai/api_resources/search.py index 4a6e9a4b46..fc7c4326f6 100644 --- a/openai/api_resources/search.py +++ b/openai/api_resources/search.py @@ -2,14 +2,11 @@ class Search(APIResource): - api_prefix = "v1" - OBJECT_NAME = "search_indices" - @classmethod def class_url(cls): - return "/%s/%s" % (cls.api_prefix, cls.OBJECT_NAME) + return "/search_indices/search" @classmethod def create_alpha(cls, **params): instance = cls() - return instance.request("post", f"{cls.class_url()}/search", params) + return instance.request("post", cls.class_url(), params) From 1efded17d67d58c2aee2cbfe0a6c04cf4848e8e0 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Wed, 20 Oct 2021 20:57:18 -0700 Subject: [PATCH 3/6] make construct_from key argument optional 
(#92) It's almost never set, and sometimes doesn't even make sense. --- openai/error.py | 2 +- openai/openai_object.py | 6 +++--- openai/util.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openai/error.py b/openai/error.py index 1f0fa3e906..f2ce82c2e2 100644 --- a/openai/error.py +++ b/openai/error.py @@ -64,7 +64,7 @@ def construct_error_object(self): return None return openai.api_resources.error_object.ErrorObject.construct_from( - self.json_body["error"], key=None + self.json_body["error"] ) diff --git a/openai/openai_object.py b/openai/openai_object.py index 9b56082d51..f87ff29f34 100644 --- a/openai/openai_object.py +++ b/openai/openai_object.py @@ -100,7 +100,7 @@ def __reduce__(self): def construct_from( cls, values, - key, + api_key: Optional[str] = None, api_version=None, organization=None, engine=None, @@ -108,7 +108,7 @@ def construct_from( ): instance = cls( values.get("id"), - api_key=key, + api_key=api_key, api_version=api_version, organization=organization, engine=engine, @@ -116,7 +116,7 @@ def construct_from( ) instance.refresh_from( values, - api_key=key, + api_key=api_key, api_version=api_version, organization=organization, response_ms=response_ms, diff --git a/openai/util.py b/openai/util.py index 3be1717034..1b87ac893f 100644 --- a/openai/util.py +++ b/openai/util.py @@ -111,7 +111,7 @@ def convert_to_openai_object( return klass.construct_from( resp, - api_key, + api_key=api_key, api_version=api_version, organization=organization, response_ms=response_ms, From 50f020639ed7cd6e6d194f0eef0a5dd81919e7d0 Mon Sep 17 00:00:00 2001 From: Chris Hallacy Date: Mon, 13 Dec 2021 14:19:47 -0800 Subject: [PATCH 4/6] lint --- openai/embeddings_utils.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/openai/embeddings_utils.py b/openai/embeddings_utils.py index caabd86396..d846ede42f 100644 --- a/openai/embeddings_utils.py +++ b/openai/embeddings_utils.py @@ -1,11 +1,10 @@ -import openai -import pandas as pd -import numpy as np import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn.metrics import average_precision_score, precision_recall_curve +from tenacity import retry, stop_after_attempt, wait_random_exponential -from tenacity import retry, wait_random_exponential, stop_after_attempt -from sklearn.metrics import precision_recall_curve -from sklearn.metrics import average_precision_score +import openai @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) @@ -14,7 +13,7 @@ def get_embedding(text, engine="davinci-similarity"): # replace newlines, which can negatively affect performance. 
text = text.replace("\n", " ") - return openai.Engine(id=engine).embeddings(input = [text])['data'][0]['embedding'] + return openai.Engine(id=engine).embeddings(input=[text])["data"][0]["embedding"] def cosine_similarity(a, b): @@ -88,4 +87,4 @@ def plot_multiclass_precision_recall( plt.xlabel("Recall") plt.ylabel("Precision") plt.title(f"{classifier_name}: Precision-Recall curve for each class") - plt.legend(lines, labels) \ No newline at end of file + plt.legend(lines, labels) From f9af1a675585d9a74d25f6fc51a4495d071dd6d4 Mon Sep 17 00:00:00 2001 From: hallacy Date: Tue, 26 Oct 2021 18:04:09 -0400 Subject: [PATCH 5/6] Split search.prepare_data into answers/classifications/search versions (#93) * Break out prepare_data into answers, classifications, and search * And cleaned up CLI --- openai/cli.py | 139 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 96 insertions(+), 43 deletions(-) diff --git a/openai/cli.py b/openai/cli.py index 842b0a290c..1beca23d17 100644 --- a/openai/cli.py +++ b/openai/cli.py @@ -3,6 +3,7 @@ import signal import sys import warnings +from functools import partial from typing import Optional import requests @@ -11,10 +12,12 @@ from openai.upload_progress import BufferReader from openai.validators import ( apply_necessary_remediation, - apply_optional_remediation, + apply_validators, + get_search_validators, get_validators, read_any_format, write_out_file, + write_out_search_file, ) @@ -227,6 +230,40 @@ def list(cls, args): class Search: + @classmethod + def prepare_data(cls, args, purpose): + + sys.stdout.write("Analyzing...\n") + fname = args.file + auto_accept = args.quiet + + optional_fields = ["metadata"] + + if purpose == "classifications": + required_fields = ["text", "labels"] + else: + required_fields = ["text"] + + df, remediation = read_any_format( + fname, fields=required_fields + optional_fields + ) + + if "metadata" not in df: + df["metadata"] = None + + apply_necessary_remediation(None, remediation) + validators = get_search_validators(required_fields, optional_fields) + + write_out_file_func = partial( + write_out_search_file, + purpose=purpose, + fields=required_fields + optional_fields, + ) + + apply_validators( + df, fname, remediation, validators, auto_accept, write_out_file_func + ) + @classmethod def create_alpha(cls, args): resp = openai.Search.create_alpha( @@ -489,49 +526,14 @@ def prepare_data(cls, args): validators = get_validators() - optional_remediations = [] - if remediation is not None: - optional_remediations.append(remediation) - for validator in validators: - remediation = validator(df) - if remediation is not None: - optional_remediations.append(remediation) - df = apply_necessary_remediation(df, remediation) - - any_optional_or_necessary_remediations = any( - [ - remediation - for remediation in optional_remediations - if remediation.optional_msg is not None - or remediation.necessary_msg is not None - ] + apply_validators( + df, + fname, + remediation, + validators, + auto_accept, + write_out_file_func=write_out_file, ) - any_necessary_applied = any( - [ - remediation - for remediation in optional_remediations - if remediation.necessary_msg is not None - ] - ) - any_optional_applied = False - - if any_optional_or_necessary_remediations: - sys.stdout.write( - "\n\nBased on the analysis we will perform the following actions:\n" - ) - for remediation in optional_remediations: - df, optional_applied = apply_optional_remediation( - df, remediation, auto_accept - ) - any_optional_applied = any_optional_applied or 
optional_applied - else: - sys.stdout.write("\n\nNo remediations found.\n") - - any_optional_or_necessary_applied = ( - any_optional_applied or any_necessary_applied - ) - - write_out_file(df, fname, any_optional_or_necessary_applied, auto_accept) def tools_register(parser): @@ -561,6 +563,57 @@ def help(args): ) sub.set_defaults(func=FineTune.prepare_data) + sub = subparsers.add_parser("search.prepare_data") + sub.add_argument( + "-f", + "--file", + required=True, + help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing text examples to be analyzed." + "This should be the local file path.", + ) + sub.add_argument( + "-q", + "--quiet", + required=False, + action="store_true", + help="Auto accepts all suggestions, without asking for user input. To be used within scripts.", + ) + sub.set_defaults(func=partial(Search.prepare_data, purpose="search")) + + sub = subparsers.add_parser("classifications.prepare_data") + sub.add_argument( + "-f", + "--file", + required=True, + help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing text-label examples to be analyzed." + "This should be the local file path.", + ) + sub.add_argument( + "-q", + "--quiet", + required=False, + action="store_true", + help="Auto accepts all suggestions, without asking for user input. To be used within scripts.", + ) + sub.set_defaults(func=partial(Search.prepare_data, purpose="classification")) + + sub = subparsers.add_parser("answers.prepare_data") + sub.add_argument( + "-f", + "--file", + required=True, + help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing text examples to be analyzed." + "This should be the local file path.", + ) + sub.add_argument( + "-q", + "--quiet", + required=False, + action="store_true", + help="Auto accepts all suggestions, without asking for user input. To be used within scripts.", + ) + sub.set_defaults(func=partial(Search.prepare_data, purpose="answer")) + def api_register(parser): # Engine management From 4af935e43244b2cb92cb646a6d195aba07857394 Mon Sep 17 00:00:00 2001 From: hallacy Date: Wed, 8 Sep 2021 17:50:56 -0400 Subject: [PATCH 6/6] Validate search files (#69) * Add validators for search files * Clean up fields --- openai/validators.py | 133 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 110 insertions(+), 23 deletions(-) diff --git a/openai/validators.py b/openai/validators.py index ba4b92c2b0..356f461506 100644 --- a/openai/validators.py +++ b/openai/validators.py @@ -70,7 +70,7 @@ def lower_case_column_creator(df): ) -def additional_column_validator(df): +def additional_column_validator(df, fields=["prompt", "completion"]): """ This validator will remove additional columns from the dataframe. """ @@ -79,9 +79,7 @@ def additional_column_validator(df): immediate_msg = None necessary_fn = None if len(df.columns) > 2: - additional_columns = [ - c for c in df.columns if c not in ["prompt", "completion"] - ] + additional_columns = [c for c in df.columns if c not in fields] warn_message = "" for ac in additional_columns: dups = [c for c in additional_columns if ac in c] @@ -91,7 +89,7 @@ def additional_column_validator(df): necessary_msg = f"Remove additional columns/keys: {additional_columns}" def necessary_fn(x): - return x[["prompt", "completion"]] + return x[fields] return Remediation( name="additional_column", @@ -101,7 +99,7 @@ def necessary_fn(x): ) -def non_empty_completion_validator(df): +def non_empty_field_validator(df, field="completion"): """ This validator will ensure that no completion is empty. 
""" @@ -109,42 +107,39 @@ def non_empty_completion_validator(df): necessary_fn = None immediate_msg = None - if ( - df["completion"].apply(lambda x: x == "").any() - or df["completion"].isnull().any() - ): - empty_rows = (df["completion"] == "") | (df["completion"].isnull()) + if df[field].apply(lambda x: x == "").any() or df[field].isnull().any(): + empty_rows = (df[field] == "") | (df[field].isnull()) empty_indexes = df.reset_index().index[empty_rows].tolist() - immediate_msg = f"\n- `completion` column/key should not contain empty strings. These are rows: {empty_indexes}" + immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}" def necessary_fn(x): - return x[x["completion"] != ""].dropna(subset=["completion"]) + return x[x[field] != ""].dropna(subset=[field]) - necessary_msg = f"Remove {len(empty_indexes)} rows with empty completions" + necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s" return Remediation( - name="empty_completion", + name=f"empty_{field}", immediate_msg=immediate_msg, necessary_msg=necessary_msg, necessary_fn=necessary_fn, ) -def duplicated_rows_validator(df): +def duplicated_rows_validator(df, fields=["prompt", "completion"]): """ This validator will suggest to the user to remove duplicate rows if they exist. """ - duplicated_rows = df.duplicated(subset=["prompt", "completion"]) + duplicated_rows = df.duplicated(subset=fields) duplicated_indexes = df.reset_index().index[duplicated_rows].tolist() immediate_msg = None optional_msg = None optional_fn = None if len(duplicated_indexes) > 0: - immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated prompt-completion pairs. These are rows: {duplicated_indexes}" + immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}" optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows" def optional_fn(x): - return x.drop_duplicates(subset=["prompt", "completion"]) + return x.drop_duplicates(subset=fields) return Remediation( name="duplicated_rows", @@ -467,7 +462,7 @@ def lower_case(x): ) -def read_any_format(fname): +def read_any_format(fname, fields=["prompt", "completion"]): """ This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas. - for .xlsx it will read the first sheet @@ -502,7 +497,7 @@ def read_any_format(fname): content = f.read() df = pd.DataFrame( [["", line] for line in content.split("\n")], - columns=["prompt", "completion"], + columns=fields, dtype=str, ) if fname.lower().endswith("jsonl") or fname.lower().endswith("json"): @@ -623,7 +618,7 @@ def get_outfnames(fname, split): while True: index_suffix = f" ({i})" if i > 0 else "" candidate_fnames = [ - fname.split(".")[0] + "_prepared" + suffix + index_suffix + ".jsonl" + os.path.splitext(fname)[0] + "_prepared" + suffix + index_suffix + ".jsonl" for suffix in suffixes ] if not any(os.path.isfile(f) for f in candidate_fnames): @@ -744,6 +739,30 @@ def write_out_file(df, fname, any_remediations, auto_accept): sys.stdout.write("Aborting... did not write the file\n") +def write_out_search_file(df, fname, any_remediations, auto_accept, fields, purpose): + """ + This function will write out a dataframe to a file, if the user would like to proceed. + """ + input_text = "\n\nYour data will be written to a new JSONL file. 
Proceed [Y/n]: " + + if not any_remediations: + sys.stdout.write( + f'\nYou can upload your file:\n> openai api files.create -f "{fname}" -p {purpose}' + ) + + elif accept_suggestion(input_text, auto_accept): + fnames = get_outfnames(fname, split=False) + + assert len(fnames) == 1 + df[fields].to_json(fnames[0], lines=True, orient="records", force_ascii=False) + + sys.stdout.write( + f'\nWrote modified file to {fnames[0]}`\nFeel free to take a look!\n\nNow upload that file:\n> openai api files.create -f "{fnames[0]}" -p {purpose}' + ) + else: + sys.stdout.write("Aborting... did not write the file\n") + + def infer_task_type(df): """ Infer the likely fine-tuning task type from the data @@ -788,7 +807,7 @@ def get_validators(): lambda x: necessary_column_validator(x, "prompt"), lambda x: necessary_column_validator(x, "completion"), additional_column_validator, - non_empty_completion_validator, + non_empty_field_validator, format_inferrer_validator, duplicated_rows_validator, long_examples_validator, @@ -800,3 +819,71 @@ def get_validators(): common_completion_suffix_validator, completions_space_start_validator, ] + + +def get_search_validators(required_fields, optional_fields): + validators = [ + lambda x: necessary_column_validator(x, field) for field in required_fields + ] + validators += [ + lambda x: non_empty_field_validator(x, field) for field in required_fields + ] + validators += [lambda x: duplicated_rows_validator(x, required_fields)] + validators += [ + lambda x: additional_column_validator( + x, fields=required_fields + optional_fields + ), + ] + + return validators + + +def apply_validators( + df, + fname, + remediation, + validators, + auto_accept, + write_out_file_func, +): + optional_remediations = [] + if remediation is not None: + optional_remediations.append(remediation) + for validator in validators: + remediation = validator(df) + if remediation is not None: + optional_remediations.append(remediation) + df = apply_necessary_remediation(df, remediation) + + any_optional_or_necessary_remediations = any( + [ + remediation + for remediation in optional_remediations + if remediation.optional_msg is not None + or remediation.necessary_msg is not None + ] + ) + any_necessary_applied = any( + [ + remediation + for remediation in optional_remediations + if remediation.necessary_msg is not None + ] + ) + any_optional_applied = False + + if any_optional_or_necessary_remediations: + sys.stdout.write( + "\n\nBased on the analysis we will perform the following actions:\n" + ) + for remediation in optional_remediations: + df, optional_applied = apply_optional_remediation( + df, remediation, auto_accept + ) + any_optional_applied = any_optional_applied or optional_applied + else: + sys.stdout.write("\n\nNo remediations found.\n") + + any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied + + write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
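
Usage sketch: after this series is applied, the embedding helpers relocated in PATCH 1/6 can be imported straight from the installed package rather than from the example notebooks. A minimal example, assuming a valid API key is configured and that the "babbage-similarity" engine referenced in the notebooks above is available to the account:

    import openai
    from openai.embeddings_utils import get_embedding, cosine_similarity

    # Hypothetical placeholder key; in practice this is normally supplied
    # via the OPENAI_API_KEY environment variable instead of being hard-coded.
    openai.api_key = "sk-..."

    # Embed two short texts with the same similarity engine, then compare them.
    a = get_embedding("delicious beans", engine="babbage-similarity")
    b = get_embedding("tasty legumes", engine="babbage-similarity")

    # Values close to 1.0 indicate high semantic similarity.
    print(cosine_similarity(a, b))

Note that get_embedding retries transient request failures with exponential backoff via the tenacity decorator shown in openai/embeddings_utils.py above, so callers do not need their own retry loop for occasional API errors.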