Skip to content

Commit

Permalink
More types (#484); allow alternative formats for save_dir files (#502)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Sean MacAvaney <[email protected]>
  • Loading branch information
cmacdonald and seanmacavaney authored Dec 5, 2024
1 parent feac5dd commit 741cdb9
Show file tree
Hide file tree
Showing 22 changed files with 238 additions and 128 deletions.
22 changes: 20 additions & 2 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Code Style Checks
name: style

on:
push:
Expand All @@ -7,7 +7,7 @@ on:
branches: [ master ]

jobs:
build:
flake8:
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4
Expand All @@ -24,3 +24,21 @@ jobs:
- name: pt.java.required checks
run: |
flake8 ./pyterrier --select=PT --show-source --statistics --count
mypy:
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install
run: |
pip install mypy --upgrade -r requirements.txt -r requirements-test.txt
pip install -e .
- name: MyPy
run: 'mypy --disable-error-code=import-untyped pyterrier || true'
2 changes: 2 additions & 0 deletions pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

# will be set in terrier.terrier.java once java is loaded
IndexRef = None
# will be set in utils.set_tqdm() once _() runs
tqdm = None


# deprecated functions exported to the main namespace, which will be removed in a future version
Expand Down
7 changes: 4 additions & 3 deletions pyterrier/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __len__(self) -> int:

def _flatten(transformers: Iterable[Transformer], cls: type) -> Tuple[Transformer]:
return list(chain.from_iterable(
(t._transformers if isinstance(t, cls) else [t])
(t._transformers if isinstance(t, cls) else [t]) # type: ignore
for t in transformers
))

Expand Down Expand Up @@ -193,6 +193,7 @@ def fuse_left(self, left: Transformer) -> Optional[Transformer]:
# If the preceding component supports a native rank cutoff (via fuse_rank_cutoff), apply it.
if isinstance(left, SupportsFuseRankCutoff):
return left.fuse_rank_cutoff(self.k)
return None

class FeatureUnion(NAryTransformerBase):
"""
Expand Down Expand Up @@ -295,7 +296,7 @@ def compile(self) -> Transformer:
"""
Returns a new transformer that fuses feature unions where possible.
"""
out = deque()
out : deque = deque()
inp = deque([t.compile() for t in self._transformers])
while inp:
right = inp.popleft()
Expand Down Expand Up @@ -382,7 +383,7 @@ def compile(self, verbose: bool = False) -> Transformer:
"""Returns a new transformer that iteratively fuses adjacent transformers to form a more efficient pipeline."""
# compile constituent transformers (flatten allows compile() to return Compose pipelines)
inp = deque(_flatten((t.compile() for t in self._transformers), Compose))
out = deque()
out : deque = deque()
counter = 1
while inp:
if verbose:
Expand Down
28 changes: 14 additions & 14 deletions pyterrier/apply_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Any, Union, Optional, Iterable
import itertools
import more_itertools
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyterrier as pt

Expand Down Expand Up @@ -92,7 +92,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
return pd.concat([self._apply_df(chunk_df) for chunk_df in iterator])

def _apply_df(self, inp: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -148,7 +148,7 @@ def transform(self, res: pd.DataFrame) -> pd.DataFrame:
it = res.groupby("qid")
lastqid = None
if self.verbose:
it = pt.tqdm(it, unit='query')
it = pt.tqdm(it, unit='query') # type: ignore
try:
if self.batch_size is None:
query_dfs = []
Expand All @@ -163,7 +163,7 @@ def transform(self, res: pd.DataFrame) -> pd.DataFrame:
iterator = pt.model.split_df(group, batch_size=self.batch_size)
query_dfs.append( pd.concat([self.fn(chunk_df) for chunk_df in iterator]) )
except Exception as a:
raise Exception("Problem applying %s for qid %s" % (self.fn, lastqid)) from a
raise Exception("Problem applying %r for qid %s" % (self.fn, lastqid)) from a # %r because it's a function with a bytes representation (mypy)

if self.add_ranks:
try:
Expand Down Expand Up @@ -253,7 +253,7 @@ def __repr__(self):
def _transform_rowwise(self, outputRes):
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.doc_score", unit="d")
outputRes["score"] = outputRes.progress_apply(self.fn, axis=1).astype('float64')
outputRes["score"] = outputRes.progress_apply(self.fn, axis=1).astype('float64') # type: ignore
else:
outputRes["score"] = outputRes.apply(self.fn, axis=1).astype('float64')
outputRes = pt.model.add_ranks(outputRes)
Expand All @@ -275,7 +275,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:

iterator = pt.model.split_df(outputRes, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self._transform_batchwise(chunk_df) for chunk_df in iterator])
rtr = pt.model.add_ranks(rtr)
return rtr
Expand All @@ -294,7 +294,7 @@ def _feature_fn(row):
pipe = pt.terrier.Retriever(index) >> pt.apply.doc_features(_feature_fn) >> pt.LTRpipeline(xgBoost())
"""
def __init__(self,
fn: Callable[[Union[pd.Series, pt.model.IterDictRecord]], np.array],
fn: Callable[[Union[pd.Series, pt.model.IterDictRecord]], npt.NDArray],
*,
verbose: bool = False
):
Expand All @@ -313,7 +313,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.doc_features")
inp = pt.tqdm(inp, desc="pt.apply.doc_features") # type: ignore
for row in inp:
row["features"] = self.fn(row)
yield row
Expand All @@ -322,8 +322,8 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
fn = self.fn
outputRes = inp.copy()
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d")
outputRes["features"] = outputRes.progress_apply(fn, axis=1)
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d") # type: ignore
outputRes["features"] = outputRes.progress_apply(fn, axis=1) # type: ignore
else:
outputRes["features"] = outputRes.apply(fn, axis=1)
return outputRes
Expand Down Expand Up @@ -368,7 +368,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.query")
inp = pt.tqdm(inp, desc="pt.apply.query") # type: ignore
for row in inp:
row = row.copy()
if "query" in row:
Expand All @@ -384,8 +384,8 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
outputRes = inp.copy()
try:
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.query", unit="d")
outputRes["query"] = outputRes.progress_apply(self.fn, axis=1)
pt.tqdm.pandas(desc="pt.apply.query", unit="d") # type: ignore
outputRes["query"] = outputRes.progress_apply(self.fn, axis=1) # type: ignore
else:
outputRes["query"] = outputRes.apply(self.fn, axis=1)
except ValueError as ve:
Expand Down Expand Up @@ -444,7 +444,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self.fn(chunk_df) for chunk_df in iterator])
return rtr

Expand Down
20 changes: 12 additions & 8 deletions pyterrier/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import json
import pandas as pd
from .transformer import is_lambda
from abc import abstractmethod
import types
from typing import Union, Tuple, Iterator, Dict, Any, List, Literal
from typing import Union, Tuple, Iterator, Dict, Any, List, Literal, Optional
from warnings import warn
import requests
from .io import autoopen, touch
Expand Down Expand Up @@ -54,12 +55,13 @@ def get_corpus(self):
"""
pass

@abstractmethod
def get_corpus_iter(self, verbose=True) -> pt.model.IterDict:
"""
Returns an iter of dicts for this collection. If verbose=True, a tqdm pbar shows the progress over this iterator.
"""
pass

def get_corpus_lang(self) -> Union[str,None]:
"""
Returns the ISO 639-1 language code for the corpus, or None for multiple/other/unknown
Expand All @@ -72,6 +74,7 @@ def get_index(self, variant=None, **kwargs):
"""
pass

@abstractmethod
def get_topics(self, variant=None) -> pd.DataFrame:
"""
Returns the topics, as a dataframe, ready for retrieval.
Expand All @@ -84,6 +87,7 @@ def get_topics_lang(self) -> Union[str,None]:
"""
return None

@abstractmethod
def get_qrels(self, variant=None) -> pd.DataFrame:
"""
Returns the qrels, as a dataframe, ready for evaluation.
Expand All @@ -109,7 +113,7 @@ def get_results(self, variant=None) -> pd.DataFrame:
"""
Returns a standard result set provided by the dataset. This is useful for re-ranking experiments.
"""
pass
return None

class RemoteDataset(Dataset):

Expand Down Expand Up @@ -139,7 +143,7 @@ def download(URLs : Union[str,List[str]], filename : str, **kwargs):
r = requests.get(url, allow_redirects=True, stream=True, **kwargs)
r.raise_for_status()
total = int(r.headers.get('content-length', 0))
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm(
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm( # type: ignore
desc=basename,
total=total,
unit='iB',
Expand Down Expand Up @@ -507,7 +511,7 @@ def get_results(self, variant=None) -> pd.DataFrame:
result.sort_values(by=['qid', 'score', 'docno'], ascending=[True, False, True], inplace=True) # ensure data is sorted by qid, -score, did
# result doesn't yet contain queries (only qids) so load and merge them in
topics = self.get_topics(variant)
result = pd.merge(result, topics, how='left', on='qid', copy=False)
result = pd.merge(result, topics, how='left', on='qid')
return result

def _describe_component(self, component):
Expand Down Expand Up @@ -610,7 +614,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
set_docnos = set(docnos)
it = (tuple(getattr(doc, f) for f in fields) for doc in docstore.get_many_iter(set_docnos))
if self.verbose:
it = pd.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader')
it = pt.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader') # type: ignore
metadata = pd.DataFrame(list(it), columns=fields).set_index('doc_id')
metadata_frame = metadata.loc[docnos].reset_index(drop=True)

Expand Down Expand Up @@ -1104,7 +1108,7 @@ def _merge_years(self, component, variant):
"corpus_iter" : lambda dataset, **kwargs : pt.index.treccollection2textgen(dataset.get_corpus(), num_docs=11429, verbose=kwargs.get("verbose", False))
}

DATASET_MAP = {
DATASET_MAP : Dict[str, Dataset] = {
# used for UGlasgow teaching
"50pct" : RemoteDataset("50pct", FIFTY_PCT_FILES),
# umass antique corpus - see http://ciir.cs.umass.edu/downloads/Antique/
Expand Down Expand Up @@ -1222,7 +1226,7 @@ def list_datasets(en_only=True):
def transformer_from_dataset(
dataset : Union[str, Dataset],
clz,
variant: str = None,
variant: Optional[str] = None,
version: str = 'latest',
**kwargs) -> pt.Transformer:
"""Returns a Transformer instance of type ``clz`` for the provided index of variant ``variant``."""
Expand Down
8 changes: 4 additions & 4 deletions pyterrier/debug.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import Transformer
from typing import List
from typing import List, Optional

def print_columns(by_query : bool = False, message : str = None) -> Transformer:
def print_columns(by_query : Optional[bool] = False, message : Optional[str] = None) -> Transformer:
"""
Returns a transformer that can be inserted into pipelines that can print the column names of the dataframe
at this stage in the pipeline:
Expand Down Expand Up @@ -82,8 +82,8 @@ def print_rows(
by_query : bool = True,
jupyter: bool = True,
head : int = 2,
message : str = None,
columns : List[str] = None) -> Transformer:
message : Optional[str] = None,
columns : Optional[List[str]] = None) -> Transformer:
"""
Returns a transformer that can be inserted into pipelines that can print some of the dataframe
at this stage in the pipeline:
Expand Down
3 changes: 2 additions & 1 deletion pyterrier/java/_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
import os
from pyterrier.java import required_raise, required, before_init, started, mavenresolver, JavaClasses, JavaInitializer, register_config
from typing import Optional
Expand Down Expand Up @@ -153,7 +154,7 @@ def add_jar(jar_path):


@before_init
def add_package(org_name: str = None, package_name: str = None, version: str = None, file_type='jar'):
def add_package(org_name : str, package_name : str, version : Optional[str] = None, file_type : str = 'jar'):
if version is None or version == 'snapshot':
version = mavenresolver.latest_version_num(org_name, package_name)
file_name = mavenresolver.get_package_jar(org_name, package_name, version, artifact=file_type)
Expand Down
3 changes: 2 additions & 1 deletion pyterrier/java/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
import sys
import warnings
from functools import wraps
Expand Down Expand Up @@ -387,7 +388,7 @@ def register_config(name, config: Dict[str, Any]):
class JavaClasses:
def __init__(self, **mapping: Union[str, Callable[[], str]]):
self._mapping = mapping
self._cache = {}
self._cache : Dict[str, Callable]= {}

def __dir__(self):
return list(self._mapping.keys())
Expand Down
1 change: 1 addition & 0 deletions pyterrier/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def split_df(df : pd.DataFrame, N: Optional[int] = None, *, batch_size: Optional
assert (N is None) != (batch_size is None), "Either N or batch_size should be provided (and not both)"

if N is None:
assert batch_size is not None
N = math.ceil(len(df) / batch_size)

type = None
Expand Down
12 changes: 6 additions & 6 deletions pyterrier/new.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Sequence, Union
from typing import Sequence, Union, Optional, cast, Iterable
import pandas as pd
from .model import add_ranks

Expand All @@ -9,7 +9,7 @@ def empty_Q() -> pd.DataFrame:
"""
return pd.DataFrame(columns=["qid", "query"])

def queries(queries : Union[str, Sequence[str]], qid : Union[str, Sequence[str]] = None, **others) -> pd.DataFrame:
def queries(queries : Union[str, Sequence[str]], qid : Optional[Union[str, Iterable[str]]] = None, **others) -> pd.DataFrame:
"""
Creates a new queries dataframe. Will return a dataframe with the columns `["qid", "query"]`.
Any further lists in others will also be added.
Expand Down Expand Up @@ -40,7 +40,7 @@ def queries(queries : Union[str, Sequence[str]], qid : Union[str, Sequence[str]]
assert type(qid) == str
return pd.DataFrame({"qid" : [qid], "query" : [queries], **others})
if qid is None:
qid = map(str, range(1, len(queries)+1))
qid = cast(Iterable[str], map(str, range(1, len(queries)+1))) # noqa: PT100 (this is typing.cast, not jnius.cast)
return pd.DataFrame({"qid" : qid, "query" : queries, **others})

Q = queries
Expand All @@ -53,8 +53,8 @@ def empty_R() -> pd.DataFrame:

def ranked_documents(
scores : Sequence[Sequence[float]],
qid : Sequence[str] = None,
docno=None,
qid : Optional[Sequence[str]] = None,
docno : Optional[Sequence[Sequence[str]]] = None,
**others) -> pd.DataFrame:
"""
Creates a new ranked documents dataframe. Will return a dataframe with the columns `["qid", "docno", "score", "rank"]`.
Expand Down Expand Up @@ -120,4 +120,4 @@ def ranked_documents(
raise ValueError("We assume multiple documents, for now")
return add_ranks(rtr)

R = ranked_documents
R = ranked_documents
Loading

0 comments on commit 741cdb9

Please sign in to comment.