Skip to content

Commit

Permalink
Generate embeddings and summaries of a group in context menu (JabRef#…
Browse files Browse the repository at this point in the history
…11832)

* Generate embeddings and summaries of a group in context menu

* Fixes from code review

* Fix checkers

* Fix from code review

* Fix checkers

* Merge with main
  • Loading branch information
InAnYan authored Oct 6, 2024
1 parent 3f0c707 commit b6efdee
Show file tree
Hide file tree
Showing 15 changed files with 323 additions and 104 deletions.
2 changes: 2 additions & 0 deletions src/main/java/org/jabref/gui/actions/StandardActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ public enum StandardActions implements Action {
GROUP_REMOVE_WITH_SUBGROUPS(Localization.lang("Also remove subgroups")),
GROUP_CHAT(Localization.lang("Chat with group")),
GROUP_EDIT(Localization.lang("Edit group")),
GROUP_GENERATE_SUMMARIES(Localization.lang("Generate summaries for entries in the group")),
GROUP_GENERATE_EMBEDDINGS(Localization.lang("Generate embeddings for linked files in the group")),
GROUP_SUBGROUP_ADD(Localization.lang("Add subgroup")),
GROUP_SUBGROUP_REMOVE(Localization.lang("Remove subgroups")),
GROUP_SUBGROUP_SORT(Localization.lang("Sort subgroups A-Z")),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/jabref/gui/frame/JabRefFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ public JabRefFrame(Stage mainStage,
this.sidePane = new SidePane(
this,
this.preferences,
aiService,
Injector.instantiateModelOrService(JournalAbbreviationRepository.class),
taskExecutor,
dialogService,
aiService,
stateManager,
fileUpdateMonitor,
entryTypesManager,
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/jabref/gui/groups/GroupTreeView.java
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ private ContextMenu createContextMenuForGroup(GroupNodeViewModel group) {

contextMenu.getItems().addAll(
factory.createMenuItem(StandardActions.GROUP_EDIT, new ContextAction(StandardActions.GROUP_EDIT, group)),
factory.createMenuItem(StandardActions.GROUP_GENERATE_EMBEDDINGS, new ContextAction(StandardActions.GROUP_GENERATE_EMBEDDINGS, group)),
factory.createMenuItem(StandardActions.GROUP_GENERATE_SUMMARIES, new ContextAction(StandardActions.GROUP_GENERATE_SUMMARIES, group)),
removeGroup,
new SeparatorMenuItem(),
factory.createMenuItem(StandardActions.GROUP_SUBGROUP_ADD, new ContextAction(StandardActions.GROUP_SUBGROUP_ADD, group)),
Expand Down Expand Up @@ -668,6 +670,10 @@ public void execute() {
viewModel.editGroup(group);
groupTree.refresh();
}
case GROUP_GENERATE_EMBEDDINGS ->
viewModel.generateEmbeddings(group);
case GROUP_GENERATE_SUMMARIES ->
viewModel.generateSummaries(group);
case GROUP_CHAT ->
viewModel.chatWithGroup(group);
case GROUP_SUBGROUP_ADD ->
Expand Down
52 changes: 47 additions & 5 deletions src/main/java/org/jabref/gui/groups/GroupTreeViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;
import org.jabref.model.entry.LinkedFile;
import org.jabref.model.groups.AbstractGroup;
import org.jabref.model.groups.AutomaticKeywordGroup;
import org.jabref.model.groups.AutomaticPersonsGroup;
Expand Down Expand Up @@ -390,11 +391,7 @@ public void editGroup(GroupNodeViewModel oldGroup) {
}

public void chatWithGroup(GroupNodeViewModel group) {
// This should probably be done some other way. Please don't blame, it's just a thing to make it quick and fast.
if (currentDatabase.isEmpty()) {
dialogService.showErrorDialogAndWait(Localization.lang("Unable to chat with group"), Localization.lang("No library is selected."));
return;
}
assert currentDatabase.isPresent();

StringProperty groupNameProperty = group.getGroupNode().getGroup().nameProperty();

Expand Down Expand Up @@ -434,6 +431,51 @@ private void openAiChat(StringProperty name, ObservableList<ChatMessage> chatHis
}
}

public void generateEmbeddings(GroupNodeViewModel groupNode) {
assert currentDatabase.isPresent();

AbstractGroup group = groupNode.getGroupNode().getGroup();

List<LinkedFile> linkedFiles = currentDatabase
.get()
.getDatabase()
.getEntries()
.stream()
.filter(group::isMatch)
.flatMap(entry -> entry.getFiles().stream())
.toList();

aiService.getIngestionService().ingest(
group.nameProperty(),
linkedFiles,
currentDatabase.get()
);

dialogService.notify(Localization.lang("Ingestion started for group \"%0\".", group.getName()));
}

public void generateSummaries(GroupNodeViewModel groupNode) {
assert currentDatabase.isPresent();

AbstractGroup group = groupNode.getGroupNode().getGroup();

List<BibEntry> entries = currentDatabase
.get()
.getDatabase()
.getEntries()
.stream()
.filter(group::isMatch)
.toList();

aiService.getSummariesService().summarize(
group.nameProperty(),
entries,
currentDatabase.get()
);

dialogService.notify(Localization.lang("Summarization started for group \"%0\".", group.getName()));
}

public void removeSubgroups(GroupNodeViewModel group) {
boolean confirmation = dialogService.showConfirmationDialogAndWait(
Localization.lang("Remove subgroups"),
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/jabref/gui/sidepane/SidePane.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public class SidePane extends VBox {

public SidePane(LibraryTabContainer tabContainer,
GuiPreferences preferences,
AiService aiService,
JournalAbbreviationRepository abbreviationRepository,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
StateManager stateManager,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
Expand All @@ -47,11 +47,11 @@ public SidePane(LibraryTabContainer tabContainer,
this.viewModel = new SidePaneViewModel(
tabContainer,
preferences,
aiService,
abbreviationRepository,
stateManager,
taskExecutor,
dialogService,
aiService,
fileUpdateMonitor,
entryTypesManager,
clipBoardManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
public class SidePaneContentFactory {
private final LibraryTabContainer tabContainer;
private final GuiPreferences preferences;
private final AiService aiService;
private final JournalAbbreviationRepository abbreviationRepository;
private final TaskExecutor taskExecutor;
private final DialogService dialogService;
private final AiService aiService;
private final StateManager stateManager;
private final FileUpdateMonitor fileUpdateMonitor;
private final BibEntryTypesManager entryTypesManager;
Expand All @@ -34,21 +34,21 @@ public class SidePaneContentFactory {

public SidePaneContentFactory(LibraryTabContainer tabContainer,
GuiPreferences preferences,
AiService aiService,
JournalAbbreviationRepository abbreviationRepository,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
StateManager stateManager,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
ClipBoardManager clipBoardManager,
UndoManager undoManager) {
this.tabContainer = tabContainer;
this.preferences = preferences;
this.aiService = aiService;
this.abbreviationRepository = abbreviationRepository;
this.taskExecutor = taskExecutor;
this.dialogService = dialogService;
this.aiService = aiService;
this.stateManager = stateManager;
this.fileUpdateMonitor = fileUpdateMonitor;
this.entryTypesManager = entryTypesManager;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/jabref/gui/sidepane/SidePaneViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ public class SidePaneViewModel extends AbstractViewModel {

public SidePaneViewModel(LibraryTabContainer tabContainer,
GuiPreferences preferences,
AiService aiService,
JournalAbbreviationRepository abbreviationRepository,
StateManager stateManager,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
ClipBoardManager clipBoardManager,
Expand All @@ -57,10 +57,10 @@ public SidePaneViewModel(LibraryTabContainer tabContainer,
this.sidePaneContentFactory = new SidePaneContentFactory(
tabContainer,
preferences,
aiService,
abbreviationRepository,
taskExecutor,
dialogService,
aiService,
stateManager,
fileUpdateMonitor,
entryTypesManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> {
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsForSeveralTask.class);

private final StringProperty name;
private final StringProperty groupName;
private final List<ProcessingInfo<LinkedFile, Void>> linkedFiles;
private final FileEmbeddingsManager fileEmbeddingsManager;
private final BibDatabaseContext bibDatabaseContext;
Expand All @@ -42,23 +42,23 @@ public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> {
private String currentFile = "";

public GenerateEmbeddingsForSeveralTask(
StringProperty name,
StringProperty groupName,
List<ProcessingInfo<LinkedFile, Void>> linkedFiles,
FileEmbeddingsManager fileEmbeddingsManager,
BibDatabaseContext bibDatabaseContext,
FilePreferences filePreferences,
TaskExecutor taskExecutor,
ReadOnlyBooleanProperty shutdownSignal
) {
this.name = name;
this.groupName = groupName;
this.linkedFiles = linkedFiles;
this.fileEmbeddingsManager = fileEmbeddingsManager;
this.bibDatabaseContext = bibDatabaseContext;
this.filePreferences = filePreferences;
this.taskExecutor = taskExecutor;
this.shutdownSignal = shutdownSignal;

configure(name);
configure(groupName);
}

private void configure(StringProperty name) {
Expand All @@ -73,9 +73,10 @@ private void configure(StringProperty name) {

@Override
public Void call() throws Exception {
LOGGER.debug("Starting embeddings generation of several files for {}", name.get());
LOGGER.debug("Starting embeddings generation of several files for {}", groupName.get());

List<Pair<? extends Future<?>, String>> futures = new ArrayList<>();

linkedFiles
.stream()
.map(processingInfo -> {
Expand All @@ -88,6 +89,7 @@ public Void call() throws Exception {
filePreferences,
shutdownSignal
)
.showToUser(false)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.onFinished(() -> progressCounter.increaseWorkDone(1))
Expand All @@ -101,7 +103,7 @@ public Void call() throws Exception {
pair.getKey().get();
}

LOGGER.debug("Finished embeddings generation task of several files for {}", name.get());
LOGGER.debug("Finished embeddings generation task of several files for {}", groupName.get());
progressCounter.stop();
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ public GenerateEmbeddingsTask(LinkedFile linkedFile,
this.filePreferences = filePreferences;
this.shutdownSignal = shutdownSignal;

configure(linkedFile);
configure();
}

private void configure(LinkedFile linkedFile) {
private void configure() {
showToUser(true);
titleProperty().set(Localization.lang("Generating embeddings for file '%0'", linkedFile.getLink()));

progressCounter.listenToAllProperties(this::updateProgress);
Expand Down
14 changes: 10 additions & 4 deletions src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,35 @@ public List<ProcessingInfo<LinkedFile, Void>> getProcessingInfo(List<LinkedFile>
return linkedFiles.stream().map(this::getProcessingInfo).toList();
}

public List<ProcessingInfo<LinkedFile, Void>> ingest(StringProperty name, List<LinkedFile> linkedFiles, BibDatabaseContext bibDatabaseContext) {
public List<ProcessingInfo<LinkedFile, Void>> ingest(StringProperty groupName, List<LinkedFile> linkedFiles, BibDatabaseContext bibDatabaseContext) {
List<ProcessingInfo<LinkedFile, Void>> result = getProcessingInfo(linkedFiles);

if (listsUnderIngestion.contains(linkedFiles)) {
return result;
}

listsUnderIngestion.add(linkedFiles);

List<ProcessingInfo<LinkedFile, Void>> needToProcess = result.stream().filter(processingInfo -> processingInfo.getState() == ProcessingState.STOPPED).toList();
startEmbeddingsGenerationTask(name, needToProcess, bibDatabaseContext);
startEmbeddingsGenerationTask(groupName, needToProcess, bibDatabaseContext);

return result;
}

private void startEmbeddingsGenerationTask(LinkedFile linkedFile, BibDatabaseContext bibDatabaseContext, ProcessingInfo<LinkedFile, Void> processingInfo) {
processingInfo.setState(ProcessingState.PROCESSING);

new GenerateEmbeddingsTask(linkedFile, fileEmbeddingsManager, bibDatabaseContext, filePreferences, shutdownSignal)
.showToUser(true)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.executeWith(taskExecutor);
}

private void startEmbeddingsGenerationTask(StringProperty name, List<ProcessingInfo<LinkedFile, Void>> linkedFiles, BibDatabaseContext bibDatabaseContext) {
new GenerateEmbeddingsForSeveralTask(name, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor, shutdownSignal)
private void startEmbeddingsGenerationTask(StringProperty groupName, List<ProcessingInfo<LinkedFile, Void>> linkedFiles, BibDatabaseContext bibDatabaseContext) {
linkedFiles.forEach(processingInfo -> processingInfo.setState(ProcessingState.PROCESSING));

new GenerateEmbeddingsForSeveralTask(groupName, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor, shutdownSignal)
.executeWith(taskExecutor);
}

Expand Down
Loading

0 comments on commit b6efdee

Please sign in to comment.