Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lexical STWFSAPY Backend #438

Merged
merged 18 commits into from
Jan 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions annif/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from . import tfidf
from . import pav
from . import maui
from . import stwfsa
import annif


Expand All @@ -29,6 +30,7 @@ def get_backend(backend_id):
register_backend(tfidf.TFIDFBackend)
register_backend(pav.PAVBackend)
register_backend(maui.MauiBackend)
register_backend(stwfsa.StwfsaBackend)

# Optional backends
try:
Expand Down
131 changes: 131 additions & 0 deletions annif/backend/stwfsa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import os
from stwfsapy.predictor import StwfsapyPredictor
from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import ListSuggestionResult, SubjectSuggestion
from . import backend
from annif.util import atomic_save, boolean


# Configuration keys understood by StwfsaBackend.  In _train() each key is
# converted with STWFSA_PARAMETERS and passed through to StwfsapyPredictor,
# except 'input_limit', which is removed before the predictor is constructed.
_KEY_CONCEPT_TYPE_URI = 'concept_type_uri'
_KEY_SUBTHESAURUS_TYPE_URI = 'sub_thesaurus_type_uri'
_KEY_THESAURUS_RELATION_TYPE_URI = 'thesaurus_relation_type_uri'
_KEY_THESAURUS_RELATION_IS_SPECIALISATION = (
    'thesaurus_relation_is_specialisation')
_KEY_REMOVE_DEPRECATED = 'remove_deprecated'
_KEY_HANDLE_TITLE_CASE = 'handle_title_case'
_KEY_EXTRACT_UPPER_CASE_FROM_BRACES = 'extract_upper_case_from_braces'
_KEY_EXTRACT_ANY_CASE_FROM_BRACES = 'extract_any_case_from_braces'
_KEY_EXPAND_AMPERSAND_WITH_SPACES = 'expand_ampersand_with_spaces'
_KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION = (
    'expand_abbreviation_with_punctuation')
_KEY_SIMPLE_ENGLISH_PLURAL_RULES = 'simple_english_plural_rules'
_KEY_INPUT_LIMIT = 'input_limit'


class StwfsaBackend(backend.AnnifBackend):
    """Lexical Annif backend based on the stwfsapy StwfsapyPredictor.

    The backend trains a predictor from the project vocabulary graph and
    the training documents, stores it as a zip file in the data directory,
    and uses it to suggest subjects for new text.
    """

    name = "stwfsa"
    needs_subject_index = True

    # Converters applied to raw (string) configuration values before they
    # are forwarded to StwfsapyPredictor in _train().
    STWFSA_PARAMETERS = {
        _KEY_CONCEPT_TYPE_URI: str,
        _KEY_SUBTHESAURUS_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: boolean,
        _KEY_REMOVE_DEPRECATED: boolean,
        _KEY_HANDLE_TITLE_CASE: boolean,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: boolean,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: boolean,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: boolean,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: boolean,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: boolean,
        _KEY_INPUT_LIMIT: int,
    }

    # Values used when the project configuration does not set a key.
    DEFAULT_PARAMETERS = {
        _KEY_CONCEPT_TYPE_URI: 'http://www.w3.org/2004/02/skos/core#Concept',
        _KEY_SUBTHESAURUS_TYPE_URI:
            'http://www.w3.org/2004/02/skos/core#Collection',
        _KEY_THESAURUS_RELATION_TYPE_URI:
            'http://www.w3.org/2004/02/skos/core#member',
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: True,
        _KEY_REMOVE_DEPRECATED: True,
        _KEY_HANDLE_TITLE_CASE: True,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: True,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: False,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: True,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: True,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: False,
        _KEY_INPUT_LIMIT: 0,
    }

    # File name of the serialized predictor inside the data directory.
    MODEL_FILE = 'stwfsa_predictor.zip'

    _model = None

    def initialize(self):
        """Load the trained predictor from the data directory.

        Raises NotInitializedException if no stored model exists yet.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug(f'Loading STWFSA model from {path}.')
            if os.path.exists(path):
                self._model = StwfsapyPredictor.load(path)
                self.debug('Loaded model.')
            else:
                raise NotInitializedException(
                    f'Model not found at {path}',
                    backend_id=self.backend_id)

    def _load_data(self, corpus):
        """Turn a document corpus into parallel lists of texts and URI sets.

        Raises NotSupportedException for cached or empty corpora, which
        this backend cannot train from.
        """
        if corpus == 'cached':
            raise NotSupportedException(
                'Training stwfsa project from cached data not supported.')
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train stwfsa project with no documents.')
        self.debug("Transforming training data.")
        X = []
        y = []
        for doc in corpus.documents:
            X.append(doc.text)
            y.append(doc.uris)
        return X, y

    def _train(self, corpus, params):
        """Fit a StwfsapyPredictor on the corpus and persist it atomically."""
        X, y = self._load_data(corpus)
        new_params = {
            key: self.STWFSA_PARAMETERS[key](val)
            for key, val
            in params.items()
            if key in self.STWFSA_PARAMETERS
        }
        # input_limit is not a predictor argument; drop it before the
        # constructor call.  Use a default so a missing key (e.g. params
        # supplied without defaults merged in) does not raise KeyError.
        new_params.pop(_KEY_INPUT_LIMIT, None)
        p = StwfsapyPredictor(
            graph=self.project.vocab.as_graph(),
            langs=frozenset([params['language']]),
            **new_params)
        p.fit(X, y)
        self._model = p
        atomic_save(
            p,
            self.datadir,
            self.MODEL_FILE,
            lambda model, store_path: model.store(store_path))

    def _suggest(self, text, params):
        """Return a ListSuggestionResult of (uri, label, score) suggestions."""
        self.debug(
            f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        result = self._model.suggest_proba([text])[0]
        suggestions = []
        for uri, score in result:
            subject_id = self.project.subjects.by_uri(uri)
            # Compare against None explicitly: subject index 0 is a valid
            # id but falsy, so a plain truthiness test would wrongly drop
            # the label of the first subject in the index.
            if subject_id is not None:
                label = self.project.subjects[subject_id][1]
            else:
                label = None
            suggestion = SubjectSuggestion(
                uri,
                label,
                None,
                score)
            suggestions.append(suggestion)
        return ListSuggestionResult(suggestions)
10 changes: 10 additions & 0 deletions annif/vocab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Vocabulary management functionality for Annif"""

import os.path
import rdflib.graph
import annif
import annif.corpus
import annif.util
Expand Down Expand Up @@ -69,3 +70,12 @@ def load_vocabulary(self, subject_corpus, language):
def as_skos(self):
    """return the vocabulary as a file object, in SKOS/Turtle syntax"""
    ttl_path = os.path.join(self.datadir, 'subjects.ttl')
    return open(ttl_path, 'rb')

def as_graph(self):
    """return the vocabulary as an rdflib graph"""
    g = rdflib.graph.Graph()
    # Graph.load() is deprecated and removed in rdflib 6.x; parse() is the
    # stable API and takes the same source/format arguments, so this stays
    # compatible with the rdflib>=4.2,<6.0 range pinned in setup.py.
    g.parse(
        os.path.join(self.datadir, 'subjects.ttl'),
        format='ttl'
    )
    return g
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ def read(fname):
'gensim==3.8.*',
'scikit-learn==0.23.2',
'scipy==1.5.3',
'rdflib',
'rdflib>=4.2,<6.0',
'gunicorn',
'numpy==1.18.*',
'optuna==2.2.0'
'optuna==2.2.0',
'stwfsapy==0.1.5'
],
tests_require=['py', 'pytest', 'requests'],
extras_require={
Expand Down
143 changes: 143 additions & 0 deletions tests/test_backend_stwfsa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os
from annif.backend import get_backend
from rdflib import Graph
import annif.corpus
from annif.backend.stwfsa import StwfsaBackend
from annif.exception import NotInitializedException, NotSupportedException

import pytest
from unittest.mock import Mock


@pytest.fixture
def graph_project(project):
    """Project fixture whose vocab yields the YSO archaeology RDF graph."""
    rdf_path = os.path.join(
        os.path.dirname(__file__),
        'corpora',
        'archaeology',
        'yso-archaeology.rdf')
    graph = Graph()
    graph.load(rdf_path)
    vocab_stub = Mock()
    vocab_stub.as_graph.return_value = graph
    project.vocab = vocab_stub
    return project


# Minimal backend configuration used by the training tests; the remaining
# keys fall back to StwfsaBackend.DEFAULT_PARAMETERS.
_backend_conf = {
    'language': 'fi',
    'concept_type_uri': 'http://www.w3.org/2004/02/skos/core#Concept',
    'sub_thesaurus_type_uri':
        'http://www.w3.org/2004/02/skos/core#Collection',
    'thesaurus_relation_type_uri':
        'http://www.w3.org/2004/02/skos/core#member',
    'thesaurus_relation_is_specialisation': True,
}


def test_stwfsa_default_params(project):
    """An unconfigured backend exposes exactly the documented defaults."""
    backend_type = get_backend(StwfsaBackend.name)
    backend = backend_type(
        backend_id=StwfsaBackend.name,
        config_params={},
        project=project)
    expected = {
        'concept_type_uri': 'http://www.w3.org/2004/02/skos/core#Concept',
        'sub_thesaurus_type_uri':
            'http://www.w3.org/2004/02/skos/core#Collection',
        'thesaurus_relation_type_uri':
            'http://www.w3.org/2004/02/skos/core#member',
        'thesaurus_relation_is_specialisation': True,
        'remove_deprecated': True,
        'handle_title_case': True,
        'extract_upper_case_from_braces': True,
        'extract_any_case_from_braces': False,
        'expand_ampersand_with_spaces': True,
        'expand_abbreviation_with_punctuation': True,
        'simple_english_plural_rules': False,
        'input_limit': 0,
    }
    assert backend.params == expected


def test_stwfsa_not_initialized(project):
    """Suggesting without a stored model raises NotInitializedException."""
    stwfsa_type = get_backend(StwfsaBackend.name)
    stwfsa = stwfsa_type(
        # Use StwfsaBackend.name for consistency with the sibling tests
        # (same runtime value as the previous literal 'stwfsa').
        backend_id=StwfsaBackend.name,
        config_params={},
        project=project)
    with pytest.raises(NotInitializedException):
        stwfsa.suggest("example text")


def test_stwfsa_train(document_corpus, graph_project, datadir):
    """Training fits a model and persists a non-empty model file."""
    backend_type = get_backend(StwfsaBackend.name)
    backend = backend_type(
        backend_id=StwfsaBackend.name,
        config_params=_backend_conf,
        project=graph_project)
    backend.train(document_corpus)
    assert backend._model is not None
    stored = datadir.join(backend.MODEL_FILE)
    assert stored.exists()
    assert stored.size() > 0


def test_empty_corpus(project):
    """Training on an empty corpus is rejected."""
    backend_type = get_backend(StwfsaBackend.name)
    backend = backend_type(
        backend_id=StwfsaBackend.name,
        config_params={},
        project=project)
    empty_corpus = annif.corpus.DocumentList([])
    with pytest.raises(NotSupportedException):
        backend.train(empty_corpus)


def test_cached_corpus(project):
    """Training from cached data is rejected."""
    backend_type = get_backend(StwfsaBackend.name)
    backend = backend_type(
        backend_id=StwfsaBackend.name,
        config_params={},
        project=project)
    with pytest.raises(NotSupportedException):
        backend.train('cached')


def test_stwfsa_suggest_unknown(project):
    """Text matching no vocabulary concept yields zero suggestions."""
    backend_type = get_backend(StwfsaBackend.name)
    backend = backend_type(
        backend_id=StwfsaBackend.name,
        config_params={},
        project=project)
    assert len(backend.suggest('1234')) == 0


def test_stwfsa_suggest(project, datadir):
    # Suggest against a model loaded from datadir.
    # NOTE(review): this test appears to rely on a model previously stored
    # in datadir (e.g. by test_stwfsa_train) — confirm test ordering, or
    # make the fixture dependency explicit.
    stwfsa_type = get_backend(StwfsaBackend.name)
    stwfsa = stwfsa_type(
        backend_id=StwfsaBackend.name,
        config_params=dict(),
        project=project)
    # Just some randomly selected words, taken from YSO archaeology group.
    # And "random" words between them
    results = stwfsa.suggest("""random
muinais-DNA random random
labyrintit random random random
Eurooppalainen yleissopimus arkeologisen perinnön suojelusta random
Indus-kulttuuri random random random random
kiinteät muinaisjäännökset random random
makrofossiilit random
Mesa Verde random random random random
muinaismuistoalueet random random random
zikkuratit random random
termoluminesenssi random random random""")
    assert len(results) == 10
    hits = results.as_list(project.subjects)
    assert 'http://www.yso.fi/onto/yso/p14174' in [
        result.uri for result in hits]
    assert 'labyrintit' in [result.label for result in hits]
24 changes: 24 additions & 0 deletions tests/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import annif.corpus
import annif.vocab
import rdflib.namespace


def load_dummy_vocab(tmpdir):
Expand Down Expand Up @@ -81,3 +82,26 @@ def test_update_subject_index_with_added_subjects(tmpdir):
assert vocab.subjects.by_uri('http://example.org/new-dummy') == 2
assert vocab.subjects[2] == ('http://example.org/new-dummy', 'new dummy',
'42.42')


def test_as_graph(tmpdir):
    """as_graph() returns an rdflib graph with both dummy concepts."""
    vocab = load_dummy_vocab(tmpdir)
    graph = vocab.as_graph()
    SKOS = rdflib.namespace.SKOS
    RDF = rdflib.namespace.RDF
    labels = [
        (str(subj), str(obj))
        for subj, obj in graph[:SKOS.prefLabel:]
    ]
    assert len(labels) == 2
    assert ('http://example.org/dummy', 'dummy') in labels
    assert ('http://example.org/none', 'none') in labels
    concepts = [
        str(subj)
        for subj in graph[:RDF.type:SKOS.Concept]
    ]
    assert len(concepts) == 2
    assert 'http://example.org/dummy' in concepts
    assert 'http://example.org/none' in concepts