Commit 1951532

Merge pull request #76 from AnFreTh/main

merging main to dev

AnFreTh authored Aug 9, 2024
2 parents 792969a + acfa8be
Showing 13 changed files with 512 additions and 441 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,3 +1,4 @@
 recursive-exclude notebooks *
 recursive-include stream/preprocessed_datasets/*
+recursive-include stream/pre_embedded_datasets/*
 include stream/preprocessor/config/default_preprocessing_steps.json
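
A quick way to confirm the newly included pre_embedded_datasets files actually land in the source distribution is to list the sdist members. A minimal sketch, assuming an sdist has already been built into dist/:

    import glob
    import tarfile

    # Pick the most recent sdist in dist/ (assumes `python -m build` has run).
    sdist = sorted(glob.glob("dist/*.tar.gz"))[-1]
    with tarfile.open(sdist, "r:gz") as tar:
        members = [n for n in tar.getnames() if "pre_embedded_datasets" in n]

    # An empty list means the recursive-include above did not take effect.
    print(members)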
102 changes: 53 additions & 49 deletions docs/notebooks/datasets.ipynb

Large diffs are not rendered by default.

139 changes: 0 additions & 139 deletions docs/notebooks/datasets.md

This file was deleted.

200 changes: 116 additions & 84 deletions docs/notebooks/examples.ipynb

Large diffs are not rendered by default.

161 changes: 124 additions & 37 deletions docs/notebooks/quickstart.ipynb

Large diffs are not rendered by default.

68 changes: 0 additions & 68 deletions docs/notebooks/quickstart.md

This file was deleted.

10 changes: 5 additions & 5 deletions setup.py
@@ -47,22 +47,22 @@
     install_requires=install_reqs,
     # extras_require=extras_reqs,
     license="MIT",  # adapt based on your needs
-    packages=find_packages(
-        exclude=["examples", "examples.*", "tests", "tests.*"]),
+    packages=find_packages(exclude=["examples", "examples.*", "tests", "tests.*"]),
     include_package_data=True,
     # package_dir={"stream": "stream"},
     package_data={
         # Use '**' to include all files within subdirectories recursively
         "stream_topic": [
             "preprocessed_datasets/**/*",
-            "preprocessor/config/default_preprocessing_steps.json"
+            "pre_embedded_datasets/**/*",
+            "preprocessor/config/default_preprocessing_steps.json",
         ],
     },
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ],
-    project_urls={'Documentation': DOCS},
-    url=HOMEPAGE
+    project_urls={"Documentation": DOCS},
+    url=HOMEPAGE,
 )
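
With include_package_data=True and the package_data globs above, the bundled JSON config ships inside the wheel. A minimal sketch of reading it at runtime with the standard library; the package and file names come from the diff, everything else is illustrative:

    from importlib import resources

    # Resolve the config bundled via package_data (Python 3.9+).
    config_path = (
        resources.files("stream_topic")
        / "preprocessor"
        / "config"
        / "default_preprocessing_steps.json"
    )
    config_text = config_path.read_text(encoding="utf-8")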
Empty file.
2 changes: 1 addition & 1 deletion stream_topic/__version__.py
@@ -1,4 +1,4 @@
 """Version information."""

 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.2"
+__version__ = "0.1.4"
2 changes: 1 addition & 1 deletion stream_topic/metrics/constants.py
@@ -1,4 +1,4 @@
 PARAPHRASE_TRANSFORMER_MODEL = "paraphrase-MiniLM-L3-v2"
 SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"
-EMBEDDING_PATH = "/embeddings"
+EMBEDDING_PATH = "embeddings"
 NLTK_STOPWORD_LANGUAGE = "english"
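
The leading slash was the bug here: assuming EMBEDDING_PATH is later combined with a base directory via os.path.join, an absolute component silently discards everything before it:

    import os

    # os.path.join drops all earlier components when it meets an absolute path:
    os.path.join("some/cache/dir", "/embeddings")  # -> "/embeddings"
    os.path.join("some/cache/dir", "embeddings")   # -> "some/cache/dir/embeddings"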
2 changes: 1 addition & 1 deletion stream_topic/models/CEDC.py
@@ -186,7 +186,7 @@ def _clustering(self):

     def fit(
         self,
-        dataset: TMDataset = None,
+        dataset: TMDataset,
         n_topics: int = 20,
         only_nouns: bool = False,
         clean: bool = False,
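Removing the `= None` default is more than style: the annotation never admitted None, and a missing dataset previously failed only later, inside training. A sketch of the difference; the class and the .dataframe attribute are illustrative, not from the repo:

    class Model:
        # Before: non-Optional annotation with a None default; model.fit()
        # succeeded at the call site and failed later with an AttributeError.
        def fit_old(self, dataset=None, n_topics=20):
            return dataset.dataframe

        # After: model.fit() now fails fast with
        # TypeError: fit() missing 1 required positional argument: 'dataset'
        def fit(self, dataset, n_topics=20):
            return dataset.dataframe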
21 changes: 7 additions & 14 deletions stream_topic/models/DCTE.py
@@ -5,9 +5,8 @@
 from datasets import Dataset
 from loguru import logger
 from sentence_transformers.losses import CosineSimilarityLoss
-from setfit import SetFitModel,TrainingArguments
+from setfit import SetFitModel, TrainingArguments
 from setfit import Trainer as SetfitTrainer
-from sklearn import preprocessing
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import OneHotEncoder

@@ -124,9 +123,7 @@ def _get_topic_representation(self, predict_df: pd.DataFrame, top_words: int):
         )

         one_hot_encoder = OneHotEncoder(sparse=False)
-        predictions_one_hot = one_hot_encoder.fit_transform(
-            predict_df[["predictions"]]
-        )
+        predictions_one_hot = one_hot_encoder.fit_transform(predict_df[["predictions"]])

         beta = tfidf
         theta = predictions_one_hot
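
One caveat this PR does not touch: scikit-learn 1.2 renamed OneHotEncoder's `sparse` argument to `sparse_output`, and scikit-learn 1.4 removed `sparse` entirely, so the call above breaks on current releases. A version-tolerant sketch:

    from sklearn.preprocessing import OneHotEncoder

    try:
        # scikit-learn >= 1.2
        one_hot_encoder = OneHotEncoder(sparse_output=False)
    except TypeError:
        # older releases, where the argument was still named `sparse`
        one_hot_encoder = OneHotEncoder(sparse=False)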
@@ -215,9 +212,8 @@ def fit(

         logger.info("--- Training completed successfully. ---")
         self._status = TrainingStatus.SUCCEEDED
-
         return self


     def predict(self, dataset):
         """
@@ -242,9 +238,9 @@ def predict(self, dataset):

         labels = self.model(predict_df["text"])
         predict_df["predictions"] = labels

         return labels

     def get_topics(self, dataset, n_words=10):
         """
         Retrieve the top words for each topic.
@@ -269,11 +265,8 @@ def get_topics(self, dataset, n_words=10):

         labels = self.model(predict_df["text"])
         predict_df["predictions"] = labels

         topic_dict, beta, theta = self._get_topic_representation(predict_df, n_words)
-        if self._status != TrainingStatus.SUCCEEDED:
-            raise RuntimeError("Model has not been trained yet or failed.")
-        return [
-            [word for word, _ in topic_dict[key][:n_words]]
-            for key in topic_dict
-        ]
+        return [[word for word, _ in topic_dict[key][:n_words]] for key in topic_dict]
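
The simplified return value is a plain list of topics, each a list of its top n_words words, so callers never see the internal topic_dict keys. An illustrative way to consume it (model and dataset are placeholders):

    # get_topics() returns [[word, ...], ...]
    topics = model.get_topics(dataset, n_words=10)
    for k, words in enumerate(topics):
        print(f"Topic {k}: {', '.join(words)}")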