Commit b83e419
bump version, some logging
bminixhofer committed Jan 22, 2024
1 parent 3f9ab26 commit b83e419
Showing 3 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

 setup(
     name="wtpsplit",
-    version="1.2.4",
+    version="1.3.0",
     packages=["wtpsplit"],
     description="Robust, adaptible sentence segmentation for 85 languages",
     author="Benjamin Minixhofer",
2 changes: 1 addition & 1 deletion wtpsplit/__init__.py
@@ -15,7 +15,7 @@
 from wtpsplit.extract import ORTWrapper, PyTorchWrapper, extract
 from wtpsplit.utils import Constants, indices_to_sentences, sigmoid

-__version__ = "1.2.4"
+__version__ = "1.3.0"


 class WtP:
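The two version strings move in lockstep across setup.py and wtpsplit/__init__.py. A minimal sanity check after installing this revision (illustrative only, assuming the package is importable):

import wtpsplit

assert wtpsplit.__version__ == "1.3.0"
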
19 changes: 12 additions & 7 deletions wtpsplit/evaluation/adapt.py
@@ -10,7 +10,7 @@
 from tqdm.auto import tqdm
 from transformers import AutoModelForTokenClassification, HfArgumentParser

-import wtpsplit.models # noqa: F401
+import wtpsplit.models  # noqa: F401
 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
 from wtpsplit.extract import PyTorchWrapper, extract
 from wtpsplit.utils import Constants
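HfArgumentParser, imported above, is the usual way a dataclass like the Args in the next hunk is turned into a command-line interface. A minimal sketch of that wiring (the exact invocation is an assumption, not shown in this diff, and the flag value is a placeholder):

from transformers import HfArgumentParser

# e.g. python wtpsplit/evaluation/adapt.py --model_path=path/to/checkpoint
(args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()
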
@@ -27,13 +27,13 @@ class Args:
     #                "meta": {
     #                    "train_data": ["train sentence 1", "train sentence 2"]
     #                },
-    #                "data": ["test sentence 1", "test sentence 2"]
+    #                 "data": ["test sentence 1", "test sentence 2"]
     #            }
     #        }
     #    }
     # }
     eval_data_path: str = "data/eval_new.pth"
-    valid_text_path: str = None#"data/sentence/valid.parquet"
+    valid_text_path: str = None  # "data/sentence/valid.parquet"
     device: str = "cuda"
     block_size: int = 512
     stride: int = 64
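The comment block above documents the nested layout expected at eval_data_path. A hedged sketch of producing such a file (torch.save is an assumption inferred from the .pth extension, not confirmed by the diff; the language code and dataset name are placeholders):

import torch

eval_data = {
    "en": {  # "<lang_code>"
        "sentence": {
            "my_dataset": {  # "<dataset_name>"
                "meta": {"train_data": ["train sentence 1", "train sentence 2"]},
                "data": ["test sentence 1", "test sentence 2"],
            }
        }
    }
}
torch.save(eval_data, "data/eval_new.pth")
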
@@ -128,7 +128,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
         valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
     else:
         valid_data = None
-
+
     model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))

     # first, logits for everything.
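The bare `import wtpsplit.models` from the first hunk matters at this point: it is imported only for its side effects (hence the `# noqa: F401`), presumably so that AutoModelForTokenClassification.from_pretrained can resolve wtpsplit's custom architectures before the wrapper is built. A condensed sketch of the same construction (the checkpoint path is a placeholder):

import wtpsplit.models  # noqa: F401  (side-effect import; makes custom model classes resolvable)
from transformers import AutoModelForTokenClassification
from wtpsplit.extract import PyTorchWrapper

hf_model = AutoModelForTokenClassification.from_pretrained("path/to/checkpoint")
model = PyTorchWrapper(hf_model.to("cuda"))
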
@@ -182,20 +182,25 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
         # just for printing
         score_t = score_t or 0.0
         score_punct = score_punct or 0.0
-        print(f"{lang_code} {dataset_name} {score_u:.3f} {score_t:.3f} {score_punct:.3f}")
+        print(f"{lang_code} {dataset_name} U={score_u:.3f} T={score_t:.3f} PUNCT={score_punct:.3f}")

+    mixture_path = Constants.CACHE_DIR / (model.config.mixture_name + ".skops")
+    results_path = Constants.CACHE_DIR / (model.config.mixture_name + "_intrinsic_results.json")
+
     sio.dump(
         clfs,
         open(
-            Constants.CACHE_DIR / (model.config.mixture_name + ".skops"),
+            mixture_path,
             "wb",
         ),
     )
     json.dump(
         results,
         open(
-            Constants.CACHE_DIR / (model.config.mixture_name + "_intrinsic_results.json"),
+            results_path,
             "w",
         ),
         indent=4,
     )
+    print("Wrote mixture to", mixture_path)
+    print("Wrote results to", results_path)
