Skip to content

Commit ffb1ef8

Browse files
committed
wip for new model repository
1 parent 08db477 commit ffb1ef8

File tree

7 files changed

+226
-359
lines changed

7 files changed

+226
-359
lines changed

.github/workflows/test.yml

-3
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ jobs:
5252
python -m build --sdist --wheel --outdir dist/ .
5353
- name: Publish a Python distribution to PyPI
5454
uses: pypa/gh-action-pypi-publish@release/v1
55-
with:
56-
user: __token__
57-
password: ${{ secrets.PYPI_API_TOKEN }}
5855
- name: Upload PyPI artifacts to GH storage
5956
uses: actions/upload-artifact@v3
6057
with:

environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ dependencies:
3232
- setuptools>=36.6.0,<70.0.0
3333
- pip:
3434
- coremltools~=8.1
35+
- htrmopo
3536
- file:.

environment_cuda.yml

+1
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@ dependencies:
3333
- setuptools>=36.6.0,<70.0.0
3434
- pip:
3535
- coremltools~=8.1
36+
- htrmopo
3637
- file:.

kraken/ketos/repo.py

+107-62
Original file line numberDiff line numberDiff line change
@@ -18,94 +18,139 @@
1818
1919
Command line driver for publishing models to the model repository.
2020
"""
21+
import re
2122
import logging
22-
import os
2323

2424
import click
2525

26+
from pathlib import Path
2627
from .util import message
2728

2829
logging.captureWarnings(True)
2930
logger = logging.getLogger('kraken')
3031

3132

33+
def _get_field_list(name):
    """
    Interactively collects a list of values for a single metadata field.

    The user is prompted with `name` repeatedly. Every non-empty answer is
    added to the result; an empty answer (just pressing enter) ends the
    collection.

    Args:
        name: Prompt text shown to the user for each value.

    Returns:
        A list with the entered strings in input order. The list is empty if
        the very first answer was empty.
    """
    values = []
    # With `default=None` click.prompt() keeps re-asking on empty input and
    # never returns None, so the loop could only be left by aborting. An
    # empty-string default lets empty input act as the terminator.
    while value := click.prompt(name, default='', show_default=False):
        values.append(value)
    return values
42+
43+
3244
@click.command('publish')
@click.pass_context
@click.option('-i', '--metadata', show_default=True,
              type=click.File(mode='r', lazy=True), help='Model card file for the model.')
@click.option('-a', '--access-token', prompt=True, help='Zenodo access token')
# An empty default makes the DOI prompt skippable. Without it click re-asks
# until a value is entered and a new record could never be created.
@click.option('-d', '--doi', prompt=True, default='', show_default=False,
              help='DOI of an existing record to update. Leave empty to create a new record.')
@click.option('-p', '--private/--public', default=False, help='Disables Zenodo '
              'community inclusion request. Allows upload of models that will not show '
              'up on `kraken list` output')
@click.argument('model', nargs=1, type=click.Path(exists=False, readable=True, dir_okay=False))
def publish(ctx, metadata, access_token, doi, private, model):
    """
    Publishes a model on the zenodo model repository.

    The model card is either read from the file given with `--metadata` or
    assembled from interactive prompts. For recognition models a legacy (v0)
    `metadata.json` is written next to the model before upload.

    Args:
        ctx: click context; `ctx.meta['verbose']` controls progress bar
             visibility.
        metadata: Open model card file or None to prompt for all fields.
        access_token: Zenodo access token.
        doi: DOI of an existing record to update; a new record is created
             when empty.
        private: Passed through to the upload function as `private`.
        model: Path to the model file.
    """
    import json
    import tempfile

    from htrmopo import publish_model, update_model

    from kraken.lib.vgsl import TorchVGSLModel
    from kraken.lib.progress import KrakenDownloadProgressBar

    # front matter enclosed by `---` or `+++` delimiters, followed by the
    # free-form body of the model card
    _yaml_delim = r'(?:---|\+\+\+)'
    _yaml = r'(.*?)'
    _content = r'\s*(.+)$'
    _re_pattern = r'^\s*' + _yaml_delim + _yaml + _yaml_delim + _content
    _yaml_regex = re.compile(_re_pattern, re.S | re.M)

    nn = TorchVGSLModel.load_model(model)

    frontmatter = {}
    # construct metadata if none is given
    if metadata:
        card = _yaml_regex.match(metadata.read())
        if card is None:
            raise click.BadParameter('Model card does not contain a front matter block.',
                                     param_hint='metadata')
        frontmatter, content = card.groups()
        # NOTE(review): `frontmatter` is still the raw front matter text at
        # this point. It has to be parsed into a mapping before the item
        # assignments below can work; no YAML parser is imported in this
        # file, so this path is left unresolved here.
    else:
        frontmatter['summary'] = click.prompt('summary')
        # click.edit() returns None when the editor is closed without saving
        content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here') or ''

        creators = []
        # An empty author name ends the list. Empty-string defaults are
        # required because click.prompt() never returns None for empty input.
        while author := click.prompt('author', default='', show_default=False):
            # keyed `name` to match what `kraken show` reads from creators
            creator = {'name': author}
            if affiliation := click.prompt('affiliation', default='', show_default=False):
                creator['affiliation'] = affiliation
            if orcid := click.prompt('orcid', default='', show_default=False):
                creator['orcid'] = orcid
            creators.append(creator)
        frontmatter['authors'] = creators
        frontmatter['license'] = click.prompt('license')
        frontmatter['language'] = _get_field_list('language')
        frontmatter['script'] = _get_field_list('script')

        # `kraken show` only accepts records carrying the `kraken_pytorch`
        # keyword, so it is added even if the user enters no other tags.
        frontmatter['tags'] = _get_field_list('tag') + ['kraken_pytorch']
        if len(datasets := _get_field_list('dataset URL')):
            frontmatter['datasets'] = datasets
        if len(base_model := _get_field_list('base model URL')):
            frontmatter['base_model'] = base_model

    # take last metrics field, falling back to accuracy field in model metadata
    metrics = {}
    if 'metrics' in nn.user_metadata and nn.user_metadata['metrics']:
        # NOTE(review): assumes the validation accuracies are stored as
        # percentages - confirm against the trainer.
        last_metrics = nn.user_metadata['metrics'][-1][1]
        metrics['cer'] = 100 - last_metrics['val_accuracy']
        metrics['wer'] = 100 - last_metrics['val_word_accuracy']
    elif 'accuracy' in nn.user_metadata and nn.user_metadata['accuracy']:
        # legacy field: the previous implementation read the last entry as
        # `accuracy[-1][1] * 100`, i.e. a list of pairs holding a fraction.
        metrics['cer'] = 100 - nn.user_metadata['accuracy'][-1][1] * 100
    frontmatter['metrics'] = metrics

    software_hints = ['kind=vgsl']
    # some recognition-specific software hints
    if nn.model_type == 'recognition':
        software_hints.extend([f'seg_type={nn.seg_type}',
                               f'one_channel_mode={nn.one_channel_mode}',
                               f'legacy_polygons={nn.user_metadata["legacy_polygons"]}'])
    frontmatter['software_hints'] = software_hints

    frontmatter['software_name'] = 'kraken'

    # build temporary directory
    with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress:
        upload_task = progress.add_task('Uploading', total=0, visible=not ctx.meta['verbose'])

        model = Path(model)
        tmpdir = Path(tmpdir)
        # A relative link target is interpreted relative to the link's own
        # directory, so the model path has to be made absolute first.
        (tmpdir / model.name).symlink_to(model.resolve())
        # v0 metadata only supports recognition models
        if nn.model_type == 'recognition':
            v0_metadata = {
                'summary': frontmatter['summary'],
                'description': content,
                'license': frontmatter['license'],
                'script': frontmatter['script'],
                'name': model.name,
                'graphemes': [char for char in ''.join(nn.codec.c2l.keys())]
            }
            if metrics:
                v0_metadata['accuracy'] = 100 - metrics['cer']
            with open(tmpdir / 'metadata.json', 'w') as fo:
                json.dump(v0_metadata, fo)

        # JSON is a subset of YAML, so the mapping is serialized as JSON
        # instead of embedding the Python repr of the dict. The trailing
        # newline keeps the closing delimiter on its own line.
        card_header = json.dumps(frontmatter, ensure_ascii=False, indent=2)
        kwargs = {'model': tmpdir,
                  'model_card': f'---\n{card_header}\n---\n{content}',
                  'access_token': access_token,
                  'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance),
                  'private': private}
        # update the existing record when a DOI was given, publish a new one
        # otherwise
        pub_fn = publish_model
        if doi:
            pub_fn = update_model
            kwargs['model_id'] = doi
        oid = pub_fn(**kwargs)
    message(f'model PID: {oid}')

kraken/kraken.py

+116-20
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,15 @@
3131
import click
3232
from PIL import Image
3333
from importlib import resources
34+
35+
from rich import print
36+
from rich.tree import Tree
37+
from rich.table import Table
38+
from rich.console import Group
3439
from rich.traceback import install
40+
from rich.logging import RichHandler
41+
from rich.markdown import Markdown
42+
from rich.progress import Progress
3543

3644
from kraken.lib import log
3745

@@ -677,29 +685,90 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction):
677685

678686
@cli.command('show')
@click.pass_context
@click.option('-V', '--metadata-version',
              default='highest',
              type=click.Choice(['v0', 'v1', 'highest']),
              help='Version of metadata to fetch if multiple exist in repository.')
@click.argument('model_id')
def show(ctx, metadata_version, model_id):
    """
    Retrieves model metadata from the repository.

    Fetches the description of `model_id` and prints it as a two-column
    table. The rows shown depend on the metadata version of the record.

    Args:
        ctx: click context, used to exit with status 1 on errors.
        metadata_version: 'v0', 'v1', or 'highest'; 'highest' is passed to
                          the repository as `version=None`.
        model_id: Identifier of the record to look up.
    """
    from htrmopo import get_description
    from htrmopo.util import iso15924_to_name, iso639_3_to_name
    from kraken.lib.util import is_printable, make_printable

    def _render_creators(creators):
        # one line per creator: name followed by optional ORCID and affiliation
        o = []
        for creator in creators or []:
            c_text = creator['name']
            if (orcid := creator.get('orcid', None)) is not None:
                c_text += f' ({orcid})'
            if (affiliation := creator.get('affiliation', None)) is not None:
                c_text += f' ({affiliation})'
            o.append(c_text)
        return o

    def _render_metrics(metrics):
        # one `name: value` line per metric; tolerates records without metrics
        return [f'{k}: {v:.2f}' for k, v in (metrics or {}).items()]

    def _group(values):
        # renders an optional list field, treating a missing value as empty
        return Group(*(values or []))

    if metadata_version == 'highest':
        metadata_version = None

    try:
        desc = get_description(model_id, version=metadata_version)
    except ValueError as e:
        logger.error(e)
        ctx.exit(1)

    # NOTE(review): records without a `software_name` attribute (presumably
    # legacy v0 records) are treated as kraken records. With a None default
    # they were always rejected here and the v0 rows below could never be
    # reached.
    if getattr(desc, 'software_name', 'kraken') != 'kraken' or 'kraken_pytorch' not in desc.keywords:
        logger.error('Record exists but is not a kraken-compatible model')
        ctx.exit(1)

    # guard against metadata versions this command has no layout for; the
    # table would otherwise be undefined when printing
    if desc.version not in ('v0', 'v1'):
        logger.error(f'Unsupported metadata version {desc.version}')
        ctx.exit(1)
    is_v1 = desc.version == 'v1'

    table = Table(title=desc.summary, show_header=False)
    table.add_column('key', justify="left", no_wrap=True)
    table.add_column('value', justify="left", no_wrap=False)
    table.add_row('DOI', desc.doi)
    table.add_row('concept DOI', desc.concept_doi)
    table.add_row('publication date', desc.publication_date.isoformat())
    table.add_row('model type', _group(desc.model_type))
    if is_v1:
        table.add_row('language', Group(*[iso639_3_to_name(x) for x in desc.language or []]))
    table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script or []]))
    if not is_v1:
        # split the alphabet into printable characters and escaped
        # representations of non-printable ones
        chars = []
        combining = []
        for char in sorted(desc.graphemes):
            if not is_printable(char):
                combining.append(make_printable(char))
            else:
                chars.append(char)
        table.add_row('alphabet', Group(' '.join(chars), ', '.join(combining)))
    table.add_row('keywords', _group(desc.keywords))
    if is_v1:
        table.add_row('datasets', _group(desc.datasets))
    table.add_row('metrics', Group(*_render_metrics(desc.metrics)))
    if is_v1:
        table.add_row('base model', _group(desc.base_model))
        table.add_row('software', desc.software_name)
        table.add_row('software_hints', _group(desc.software_hints))
    table.add_row('license', desc.license)
    table.add_row('creators', Group(*_render_creators(desc.creators)))
    # v1 descriptions are rendered as markdown, v0 descriptions as plain text
    table.add_row('description', Markdown(desc.description) if is_v1 else desc.description)

    print(table)
703772

704773

705774
@cli.command('list')
@@ -708,14 +777,41 @@ def list_models(ctx):
708777
"""
709778
Lists models in the repository.
710779
"""
711-
from kraken import repo
780+
from htrmopo import get_listing
781+
from collections import defaultdict
712782
from kraken.lib.progress import KrakenProgressBar
713783

714784
with KrakenProgressBar() as progress:
715785
download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False)
716-
model_list = repo.get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance))
717-
for id, metadata in model_list.items():
718-
message('{} ({}) - {}'.format(id, ', '.join(metadata['type']), metadata['summary']))
786+
repository = get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance))
787+
# aggregate models under their concept DOI
788+
concepts = defaultdict(list)
789+
for item in repository.values():
790+
# both got the same DOI information
791+
record = item['v0'] if item['v0'] else item['v1']
792+
concepts[record.concept_doi].append(record.doi)
793+
794+
table = Table(show_header=True)
795+
table.add_column('DOI', justify="left", no_wrap=True)
796+
table.add_column('summary', justify="left", no_wrap=False)
797+
table.add_column('model type', justify="left", no_wrap=False)
798+
table.add_column('keywords', justify="left", no_wrap=False)
799+
800+
for k, v in concepts.items():
801+
records = [repository[x]['v1'] if 'v1' in repository[x] else repository[x]['v0'] for x in v]
802+
records = filter(lambda record: getattr(record, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in record.keywords, records)
803+
records = sorted(records, key=lambda x: x.publication_date, reverse=True)
804+
if not len(records):
805+
continue
806+
807+
t = Tree(k)
808+
[t.add(x.doi) for x in records]
809+
table.add_row(t,
810+
Group(*[''] + [x.summary for x in records]),
811+
Group(*[''] + ['; '.join(x.model_type) for x in records]),
812+
Group(*[''] + ['; '.join(x.keywords) for x in records]))
813+
814+
print(table)
719815
ctx.exit(0)
720816

721817

0 commit comments

Comments
 (0)