diff --git a/vladiate/vlad.py b/vladiate/vlad.py index c6f8a06..2f2aea0 100644 --- a/vladiate/vlad.py +++ b/vladiate/vlad.py @@ -17,6 +17,7 @@ def __init__( file_validation_failure_threshold=None, quiet=False, row_validators=[], + fieldnames=None, ): self.logger = logs.logger self.failures = defaultdict(lambda: defaultdict(list)) @@ -26,6 +27,7 @@ def __init__( self.source = source self.validators = validators or getattr(self, "validators", {}) self.row_validators = row_validators or getattr(self, "row_validators", []) + self.fieldnames = fieldnames or getattr(self, "fieldnames", None) self.delimiter = delimiter or getattr(self, "delimiter", ",") self.line_count = 0 self.ignore_missing_validators = ignore_missing_validators @@ -124,7 +126,7 @@ def _log_missing(self, missing_items): ) def _get_total_lines(self): - reader = csv.DictReader(self.source.open(), delimiter=self.delimiter) + reader = csv.DictReader(self.source.open(), delimiter=self.delimiter, fieldnames=self.fieldnames) self.total_lines = sum(1 for _ in reader) return self.total_lines @@ -132,7 +134,7 @@ def validate(self): self.logger.info( "\nValidating {}(source={})".format(self.__class__.__name__, self.source) ) - reader = csv.DictReader(self.source.open(), delimiter=self.delimiter) + reader = csv.DictReader(self.source.open(), delimiter=self.delimiter, fieldnames=self.fieldnames) if not reader.fieldnames: self.logger.info(