Commit 5094419

Add validation for tagging type (#23)
* Add validation for tagging type
* Fix import
* Update the expected fields
* Move values to constants
* Update the expected fields
* Add support for tagging type arg in cli
* Update the version
* Refactor the column check function
1 parent 3900ac6 commit 5094419

5 files changed (+68, -6)

CHANGELOG.md (+2)

@@ -1,4 +1,6 @@
 # CHANGELOG
+## 0.3.35
+- [x] add: validations for the input file for conversation tagging

 ## 0.3.34
 - [x] PL-61: Add retry mechanism for uploading data to Label studio

pyproject.toml (+1, -1)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "skit-labels"
-version = "0.3.34"
+version = "0.3.35"
 description = "Command line tool for interacting with labelled datasets at skit.ai."
 authors = []
 license = "MIT"

skit_labels/cli.py (+15, -3)

@@ -227,6 +227,13 @@ def upload_dataset_to_labelstudio_command(
         required=True,
         help="The data label implying the source of data",
     )
+
+    parser.add_argument(
+        "--tagging-type",
+        type=str,
+        help="The tagging type for the calls being uploaded",
+    )
+
     return parser


@@ -319,12 +326,17 @@ def build_cli():
     return parser


-def upload_dataset(input_file, url, token, job_id, data_source, data_label = None):
+def upload_dataset(input_file, url, token, job_id, data_source, data_label = None, tagging_type=None):
     input_file = utils.add_data_label(input_file, data_label)
     if data_source == const.SOURCE__DB:
         fn = commands.upload_dataset_to_db
     elif data_source == const.SOURCE__LABELSTUDIO:
-        fn = commands.upload_dataset_to_labelstudio
+        if tagging_type:
+            is_valid, error = utils.validate_input_data(tagging_type, input_file)
+            if not is_valid:
+                return error, None
+
+        fn = commands.upload_dataset_to_labelstudio
     errors, df_size = asyncio.run(
         fn(
             input_file,
@@ -386,7 +398,7 @@ def cmd_to_str(args: argparse.Namespace) -> str:
     arg_id = args.job_id

     _ = is_valid_data_label(args.data_label)
-    errors, df_size = upload_dataset(args.input, args.url, args.token, arg_id, args.data_source, args.data_label)
+    errors, df_size = upload_dataset(args.input, args.url, args.token, arg_id, args.data_source, args.data_label, args.tagging_type)

     if errors:
         return (
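For orientation, here is a minimal sketch of how the new flag reaches `upload_dataset` when the package is used programmatically. Every concrete value below (path, URL, token, job id, labels) is hypothetical, and "labelstudio" is only assumed to be the string behind `const.SOURCE__LABELSTUDIO`:

from skit_labels.cli import upload_dataset

# Hypothetical invocation: with tagging_type set, the input file is
# validated before any Label Studio upload is attempted; a failed
# validation short-circuits and returns (error, None).
errors, df_size = upload_dataset(
    "calls.csv",                        # hypothetical input CSV
    "https://labelstudio.example.com",  # hypothetical Label Studio URL
    "secret-token",                     # hypothetical API token
    42,                                 # hypothetical job id
    "labelstudio",                      # assumed value of const.SOURCE__LABELSTUDIO
    data_label="client-calls",          # hypothetical data label
    tagging_type="conversation_tagging",
)
if errors:
    print(f"Upload aborted: {errors}")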

skit_labels/constants.py (+7, -1)

@@ -120,4 +120,10 @@
 FROM_NAME_INTENT = "tag"
 CHOICES = "choices"
 TAXONOMY = "taxonomy"
-VALUE = "value"
+VALUE = "value"
+
+EXPECTED_COLUMNS_MAPPING = {
+    "conversation_tagging": ['scenario', 'scenario_category', 'situation_str', 'call', 'data_label']
+}
+
+CONVERSATION_TAGGING = 'conversation_tagging'
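`EXPECTED_COLUMNS_MAPPING` is keyed by tagging type, so supporting another type later only needs a new entry. A tiny standalone sketch of the lookup (the second key is invented for illustration):

EXPECTED_COLUMNS_MAPPING = {
    "conversation_tagging": ['scenario', 'scenario_category', 'situation_str', 'call', 'data_label']
}

# Known tagging type: returns the list of expected CSV columns.
print(EXPECTED_COLUMNS_MAPPING.get("conversation_tagging"))
# Unmapped tagging type (hypothetical): .get returns None instead of raising.
print(EXPECTED_COLUMNS_MAPPING.get("intent_tagging"))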

skit_labels/utils.py (+43, -1)

@@ -10,7 +10,7 @@
 from datetime import datetime
 import pandas as pd
 from typing import Union
-
+from skit_labels import constants as const

 LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"]

@@ -110,3 +110,45 @@ def add_data_label(input_file: str, data_label: Optional[str] = None) -> str:
     df = df.assign(data_label=data_label)
     df.to_csv(input_file, index=False)
     return input_file
+
+
+def validate_headers(input_file, tagging_type):
+    expected_columns_mapping = const.EXPECTED_COLUMNS_MAPPING
+    expected_headers = expected_columns_mapping.get(tagging_type)
+
+    df = pd.read_csv(input_file)
+
+    column_headers = df.columns.to_list()
+    column_headers = [header.lower() for header in column_headers]
+    column_headers = sorted(column_headers)
+    expected_headers = sorted(expected_headers)
+
+    logger.info(f"column_headers: {column_headers}")
+    logger.info(f"expected_headers: {expected_headers}")
+
+    is_match = column_headers == expected_headers
+    logger.info(f"Is match: {is_match}")
+
+    if not is_match:
+        missing_headers = set(expected_headers).difference(set(column_headers))
+        additional_headers = set(column_headers).difference(set(expected_headers))
+        if missing_headers:
+            return missing_headers
+        elif additional_headers:
+            df.drop(additional_headers, axis=1, inplace=True)
+            df.to_csv(input_file, index=False)
+            is_match = True
+            logger.info(f"Following additional headers have been removed from the csv: {additional_headers}")
+    return []
+
+
+def validate_input_data(tagging_type, input_file):
+    is_valid = True
+    error = ''
+    if tagging_type == const.CONVERSATION_TAGGING:
+        missing_headers = validate_headers(input_file, tagging_type)
+        if missing_headers:
+            error = f'Headers in the input file does not match the expected fields. Missing fields = {missing_headers}'
+            is_valid = False
+
+    return is_valid, error
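To make the header check concrete, here is a standalone distillation of the comparison logic above using plain pandas; the sample CSV is invented, with one expected column missing and one extra column present:

import io

import pandas as pd

EXPECTED = ['scenario', 'scenario_category', 'situation_str', 'call', 'data_label']

# Invented sample: 'call' is missing and 'notes' is extra.
sample = io.StringIO(
    "scenario,scenario_category,situation_str,data_label,notes\n"
    "greeting,smalltalk,caller says hi,prod,fine\n"
)
df = pd.read_csv(sample)

# Same normalisation as validate_headers: lowercase, then sort.
column_headers = sorted(header.lower() for header in df.columns)
expected_headers = sorted(EXPECTED)

missing = set(expected_headers) - set(column_headers)
extra = set(column_headers) - set(expected_headers)

print(missing)  # {'call'}  -> validate_input_data would reject the upload
print(extra)    # {'notes'} -> validate_headers would drop it and rewrite the file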
