importer.py
from __future__ import annotations

import csv
import json
import logging
from abc import ABCMeta, abstractmethod
from typing import TypeVar, Generic, List, Optional, Any, TextIO, Generator, Iterator

T = TypeVar('T')


class Importer(Generic[T], metaclass=ABCMeta):
    """Common interface of all importers: read a file and return the parsed entries."""

    @abstractmethod
    def import_data(self, file_name: str) -> List[T]:
        pass
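

# Usage sketch, not part of the original module: "run_import" is a hypothetical helper
# showing that every importer below is driven through this one common entry point.
def run_import(importer: Importer[T], file_name: str) -> List[T]:
    # Delegate to whichever import_data() implementation the concrete importer provides.
    return importer.import_data(file_name)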


class CsvImporter(Importer[T], metaclass=ABCMeta):
    """Base class for CSV importers; subclasses turn one CSV row into one entry."""

    delimiter: str
    encoding: str
    skip_first_line: bool = True

    @abstractmethod
    def __init__(self,
                 delimiter: str = ';',
                 encoding: str = 'utf-8',
                 skip_first_line: bool = True):
        self.delimiter = delimiter
        self.encoding = encoding
        self.skip_first_line = skip_first_line

    def import_data(self, file_name: str) -> List[T]:
        try:
            with open(file_name, encoding=self.encoding) as csv_file:
                reader = csv.reader(csv_file, delimiter=self.delimiter)
                if self.skip_first_line:
                    # Discard the header row; if its first cell declares "utf-8" while a
                    # different encoding is configured, reopen the file as UTF-8.
                    first_line = next(reader)
                    if first_line[0] == "utf-8" and self.encoding != 'utf-8':
                        self.encoding = 'utf-8'
                        logging.info("Reopening with UTF-8 encoding")
                        return self.import_data(file_name)
                # Deserialize every row and drop the rows the subclass rejected (None).
                data = (self.deserialize(entry) for entry in reader)
                data = [entry for entry in data if entry is not None]
        except UnicodeDecodeError:
            # Fall back to UTF-8 once; if UTF-8 itself failed, propagate the error.
            if self.encoding != "utf-8":
                self.encoding = "utf-8"
                logging.info("Reopening with UTF-8 encoding")
                return self.import_data(file_name)
            raise
        return data

    @abstractmethod
    def deserialize(self, entry: List[str]) -> Optional[T]:
        pass
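

# A minimal sketch of a concrete subclass, for illustration only: "PersonCsvImporter"
# and its two-column layout (name;age) are hypothetical and not part of this module;
# they show how a subclass is expected to provide __init__ and deserialize().
class PersonCsvImporter(CsvImporter[dict]):
    def __init__(self):
        super().__init__(delimiter=';', encoding='utf-8', skip_first_line=True)

    def deserialize(self, entry: List[str]) -> Optional[dict]:
        try:
            return {"name": entry[0], "age": int(entry[1])}
        except (IndexError, ValueError):
            # Returning None makes import_data() drop the malformed row.
            return None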


class JsonImporter(Importer[T], metaclass=ABCMeta):
    """Base class for JSON importers; top_level_entry is the key path to the entry list."""

    encoding: str
    top_level_entry: List[str]

    @abstractmethod
    def __init__(self,
                 top_level_entry: List[str] | str,
                 encoding: str = 'utf-8'):
        self.encoding = encoding
        # Accept either a single key or a list of nested keys.
        if isinstance(top_level_entry, str):
            self.top_level_entry = [top_level_entry]
        else:
            self.top_level_entry = top_level_entry

    @abstractmethod
    def deserialize(self, entry: Any) -> Optional[T]:
        pass

    def import_data(self, file_name: str) -> List[T]:
        with open(file_name, encoding=self.encoding) as json_file:
            content = json.load(json_file)
            # Descend through the configured key path to reach the list of entries.
            for entry in self.top_level_entry:
                content = content[entry]
            data = (self.deserialize(entry) for entry in content)
            data = [entry for entry in data if entry is not None]
        return data
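

# Illustrative sketch only: "PersonJsonImporter" and the {"results": {"people": [...]}}
# layout are hypothetical; they show how top_level_entry walks nested keys before each
# list element is handed to deserialize().
class PersonJsonImporter(JsonImporter[dict]):
    def __init__(self):
        super().__init__(top_level_entry=["results", "people"])

    def deserialize(self, entry: Any) -> Optional[dict]:
        # Keep only entries that carry a "name" field; everything else is dropped.
        if isinstance(entry, dict) and "name" in entry:
            return {"name": entry["name"]}
        return None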


class WikipediaImporter(Importer[T], metaclass=ABCMeta):
    """Base class for importers that parse a wikitext table ("{| ... |}") row by row."""

    skip_first_entry: bool

    @abstractmethod
    def __init__(self,
                 skip_first_entry: bool = True):
        self.skip_first_entry = skip_first_entry

    @abstractmethod
    def deserialize(self, entry: List[str]) -> T | None:
        pass

    def import_data(self, file_name: str) -> List[T]:
        with open(file_name, encoding="utf-8") as input_file:
            entries = self.iter_table_entries(input_file)
            # We might want to discard a table header
            if self.skip_first_entry:
                next(entries)
            data = (self.deserialize(entry) for entry in entries)
            data = [entry for entry in data if entry is not None]
        return data

    def iter_table_entries(self, input_file: TextIO) -> Generator[List[str], None, None]:
        """Yield one table row at a time as a list of cell strings."""
        lines: Iterator[str] = iter(input_file)
        # We will ignore the first |- line
        next(lines)
        entry = []
        for line in lines:
            if line.strip() != "|-" and line.strip() != "|}":
                # Strip the leading "|" cell marker and surrounding whitespace.
                line = line.strip().lstrip("|").lstrip()
                entry.append(line)
            else:
                # A row separator ("|-") or the table end ("|}") closes the current entry.
                yield entry
                entry = []
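

# Illustrative sketch only: "CountryWikiImporter" assumes a hypothetical wikitext table
# whose rows carry a country name in the first cell and a population figure in the
# second; it shows how a row (list of cell strings) maps to one deserialized entry.
class CountryWikiImporter(WikipediaImporter[dict]):
    def __init__(self):
        super().__init__(skip_first_entry=True)

    def deserialize(self, entry: List[str]) -> dict | None:
        if len(entry) < 2:
            return None
        try:
            # Cells may contain thousands separators such as "1,234,567".
            return {"country": entry[0], "population": int(entry[1].replace(",", ""))}
        except ValueError:
            return None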