Add LLM token classification example (#4541)
### What
Adds an example that tokenizes a text, visualizes the embeddings for
each token (as a 3D UMAP embedding), logs the text tokens linking to the
corresponding embedding, and classifies each token. Classification is
into named entities (person, location, organization, and miscellaneous). The
unique named entities that were found are also logged.
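
For readers unfamiliar with the label scheme, here is a minimal sketch of how per-token labels can be grouped into full entity names. It assumes BIO-style labels (`B-PER`, `I-PER`, ...) and WordPiece `##` continuation markers, as they appear in `main.py`. The helper `group_entities` and the sample data are hypothetical illustrations, not code from this PR.

```python
from __future__ import annotations

from collections import defaultdict


def group_entities(ner_results: list[dict[str, str]]) -> dict[str, set[str]]:
    """Group BIO-tagged tokens into full entity names, keyed by entity class."""
    entities: dict[str, set[str]] = defaultdict(set)
    name: str | None = None
    cls: str | None = None
    for result in ner_results:
        label, word = result["entity"], result["word"]
        if label.startswith("B-"):
            # A new entity begins: store the previous one, if any.
            if name is not None and cls is not None:
                entities[cls].add(name)
            cls, name = label[2:], word
        elif name is not None:
            # Continuation token: "##" marks a sub-word piece without a leading space.
            name += word[2:] if word.startswith("##") else f" {word}"
    # Store the entity that was still open when the input ended.
    if name is not None and cls is not None:
        entities[cls].add(name)
    return dict(entities)


# Hypothetical sample in the shape of the per-token results used by this example.
sample = [
    {"entity": "B-LOC", "word": "Bright"},
    {"entity": "I-LOC", "word": "##port"},
    {"entity": "B-PER", "word": "Maya"},
]
grouped = group_entities(sample)
```

The final check after the loop matters: without it, the last entity in the text would never be stored.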

Also removed some newlines in manifest.yml to make it more consistent.


![llm_embedding_ner](https://github.com/rerun-io/rerun/assets/9785832/1045f995-499c-427a-8023-53ff6b4c8100)

### Checklist
* [x] I have read and agree to [Contributor
Guide](https://github.com/rerun-io/rerun/blob/main/CONTRIBUTING.md) and
the [Code of
Conduct](https://github.com/rerun-io/rerun/blob/main/CODE_OF_CONDUCT.md)
* [x] I've included a screenshot or gif (if applicable)
* [x] I have tested the web demo (if applicable):
  * Full build: [app.rerun.io](https://app.rerun.io/pr/4541/index.html)
  * Partial build: [app.rerun.io](https://app.rerun.io/pr/4541/index.html?manifest_url=https://app.rerun.io/version/nightly/examples_manifest.json)
    - Useful for quick testing when changes do not affect examples in any way
* [x] The PR title and labels are set such as to maximize their
usefulness for the next release's CHANGELOG

- [PR Build Summary](https://build.rerun.io/pr/4541)
- [Docs
preview](https://rerun.io/preview/b2efaf9567237c0e63f607527a4e064ea4366dc5/docs)
<!--DOCS-PREVIEW-->
- [Examples
preview](https://rerun.io/preview/b2efaf9567237c0e63f607527a4e064ea4366dc5/examples)
<!--EXAMPLES-PREVIEW-->
- [Recent benchmark results](https://build.rerun.io/graphs/crates.html)
- [Wasm size tracking](https://build.rerun.io/graphs/sizes.html)

---------

Co-authored-by: Clement Rey <[email protected]>
roym899 and teh-cmc authored Dec 18, 2023
1 parent 8c53d33 commit 471af6d
Showing 5 changed files with 209 additions and 26 deletions.
28 changes: 2 additions & 26 deletions examples/manifest.yml
@@ -53,66 +53,48 @@ root:
check out [Types](/docs/reference/types).
children:
# Keep this list lexicographically sorted:

- name: arkit_scenes
python: python/arkit_scenes

- name: controlnet
python: python/controlnet

- name: depth-guided-stable-diffusion
python: python/depth_guided_stable_diffusion

- name: detect-and-track-objects
python: python/detect_and_track_objects

- name: dicom-mri
python: python/dicom_mri

- name: face-tracking
python: python/face_tracking

- name: human-pose-tracking
python: python/human_pose_tracking

- name: lidar
python: python/lidar

- name: live-camera-edge-detection
python: python/live_camera_edge_detection

- name: live-depth-sensor
python: python/live_depth_sensor

- name: llm_embedding_ner
python: python/llm_embedding_ner
- name: nuscenes
python: python/nuscenes

- name: objectron
python: python/objectron
rust: rust/objectron

- name: open-photogrammetry-format
python: python/open_photogrammetry_format

- name: raw-mesh
python: python/raw_mesh
rust: rust/raw_mesh

- name: rgbd
python: python/rgbd

- name: ros-node
python: python/ros_node

- name: segment-anything-model
python: python/segment_anything_model

- name: signed-distance-fields
python: python/signed_distance_fields

- name: structure-from-motion
python: python/structure_from_motion

- name: vrs
cpp: cpp/vrs

@@ -130,23 +112,17 @@ root:
rust: rust/minimal

# Keep the following examples lexicographically sorted:

- name: car
python: python/car

- name: clock
python: python/clock
rust: rust/clock

- name: eigen-opencv
cpp: cpp/eigen_opencv

- name: multithreading
python: python/multithreading

- name: multiprocessing
python: python/multiprocessing

- name: plots
python: python/plots

28 changes: 28 additions & 0 deletions examples/python/llm_embedding_ner/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
---
title: LLM Embedding-Based Named Entity Recognition
python: https://github.com/rerun-io/rerun/tree/latest/examples/python/llm_embedding_ner/main.py
tags: [LLM, embeddings, classification, huggingface, text]
thumbnail: https://static.rerun.io/llm_embedding_ner/d98c09dd6bfa20ceea3e431c37dc295a4009fa1b/480w.png
thumbnail_dimensions: [480, 279]
---

<picture>
<img src="https://static.rerun.io/llm_embedding_ner/d98c09dd6bfa20ceea3e431c37dc295a4009fa1b/full.png" alt="">
<source media="(max-width: 480px)" srcset="https://static.rerun.io/llm_embedding_ner/d98c09dd6bfa20ceea3e431c37dc295a4009fa1b/480w.png">
<source media="(max-width: 768px)" srcset="https://static.rerun.io/llm_embedding_ner/d98c09dd6bfa20ceea3e431c37dc295a4009fa1b/768w.png">
<source media="(max-width: 1024px)" srcset="https://static.rerun.io/llm_embedding_ner/d98c09dd6bfa20ceea3e431c37dc295a4009fa1b/1024w.png">
<source media="(max-width: 1200px)" srcset="https://static.rerun.io/llm_embedding_ner/d98c09dd6bfa20ceea3e431c37dc295a4009fa1b/1200w.png">
</picture>

This example visualizes [BERT-based named entity recognition (NER)](https://huggingface.co/dslim/bert-base-NER). It works by splitting the text into tokens and feeding the token sequence into a large language model (BERT) to retrieve one embedding per token. Each embedding is then classified as a person, location, organization, miscellaneous entity, or no entity.
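
As a rough illustration of the tokenization step, the sketch below rebuilds readable text from WordPiece tokens, where a `##` prefix marks a token that continues the previous word. The helper `join_wordpieces` and the sample tokens are hypothetical and only illustrate the convention; in the example itself the tokens come from the Hugging Face tokenizer.

```python
from __future__ import annotations


def join_wordpieces(tokens: list[str]) -> str:
    """Rebuild text from WordPiece tokens; '##' marks a continuation piece."""
    text = ""
    for token in tokens:
        if token.startswith("##"):
            # Continuation piece: append directly to the previous token.
            text += token[2:]
        else:
            # New word: separate it from the preceding text with a space.
            text += (" " if text else "") + token
    return text


# Hypothetical tokens; real ones also include special tokens such as [CLS] and [SEP].
example_tokens = ["Maya", "lives", "in", "Bright", "##port"]
joined = join_wordpieces(example_tokens)
```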

To run this example, use:
```bash
pip install -r examples/python/llm_embedding_ner/requirements.txt
python examples/python/llm_embedding_ner/main.py
```

You can specify your own text using the `--text` option:
```bash
python examples/python/llm_embedding_ner/main.py [--text TEXT]
```
174 changes: 174 additions & 0 deletions examples/python/llm_embedding_ner/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#!/usr/bin/env python3
"""
Example running BERT-based named entity recognition (NER).
Run
```sh
examples/python/llm_embedding_ner/main.py
```
"""
from __future__ import annotations

import argparse
from collections import defaultdict
from typing import Any

import rerun as rr
import torch
import umap
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

DEFAULT_TEXT = """
In the bustling city of Brightport, nestled between rolling hills and a sparkling harbor, lived three friends: Maya, a spirited chef known for her spicy curries; Leo, a laid-back jazz musician with a penchant for saxophone solos; and Ava, a tech-savvy programmer who loved solving puzzles.
One sunny morning, the trio decided to embark on a mini-adventure to the legendary Hilltop Café in the nearby town of Greendale. The café, perched on the highest hill, was famous for its panoramic views and delectable pastries.
Their journey began with a scenic drive through the countryside, with Leo's smooth jazz tunes setting a relaxing mood. Upon reaching Greendale, they found the town buzzing with excitement over the annual Flower Festival. The streets were adorned with vibrant blooms, and the air was filled with a sweet floral scent.
At the Hilltop Café, they savored buttery croissants and rich coffee, laughing over past misadventures and dreaming up future plans. The view from the café was breathtaking, overlooking the patchwork of fields and the distant Brightport skyline.
After their café indulgence, they joined the festival's flower parade. Maya, with her knack for colors, helped design a stunning float decorated with roses and lilies. Leo entertained the crowd with his saxophone, while Ava captured the day's memories with her camera.
As the sun set, painting the sky in hues of orange and purple, the friends returned to Brightport, their hearts full of joy and their minds brimming with new memories. They realized that sometimes, the simplest adventures close to home could be the most memorable.
"""


def log_tokenized_text(token_words: list[str]) -> None:
    markdown = ""
    for i, token_word in enumerate(token_words):
        if token_word.startswith("##"):
            clean_token_word = token_word[2:]
        else:
            clean_token_word = " " + token_word

        markdown += f"[{clean_token_word}](recording://umap_embeddings[#{i}])"
    rr.log("tokenized_text", rr.TextDocument(markdown, media_type=rr.MediaType.MARKDOWN))


def log_ner_results(ner_results: list[dict[str, Any]]) -> None:
    entity_sets: dict[str, set[str]] = defaultdict(set)

    current_entity_name = None
    current_entity_set = None
    for ner_result in ner_results:
        entity_class = ner_result["entity"]
        word = ner_result["word"]
        if entity_class.startswith("B-"):
            if current_entity_set is not None and current_entity_name is not None:
                current_entity_set.add(current_entity_name)
            current_entity_set = entity_sets[entity_class[2:]]
            current_entity_name = word
        elif current_entity_name is not None:
            if word.startswith("##"):
                current_entity_name += word[2:]
            else:
                current_entity_name += f" {word}"
    # Add the last entity, which is not followed by another "B-" token
    if current_entity_set is not None and current_entity_name is not None:
        current_entity_set.add(current_entity_name)

    named_entities_str = ""
    if "PER" in entity_sets:
        named_entities_str += f"Persons: {', '.join(entity_sets['PER'])}\n\n"
    if "LOC" in entity_sets:
        named_entities_str += f"Locations: {', '.join(entity_sets['LOC'])}\n\n"
    if "ORG" in entity_sets:
        named_entities_str += f"Organizations: {', '.join(entity_sets['ORG'])}\n\n"
    if "MISC" in entity_sets:
        named_entities_str += f"Miscellaneous: {', '.join(entity_sets['MISC'])}\n\n"

    rr.log("named_entities", rr.TextDocument(named_entities_str, media_type=rr.MediaType.MARKDOWN))


def entity_per_token(token_words: list[str], ner_results: list[dict[str, Any]]) -> list[str]:
    index_to_entity: dict[int, str] = defaultdict(str)
    current_entity_name = None
    current_entity_indices = []
    for ner_result in ner_results:
        entity_class = ner_result["entity"]
        word = ner_result["word"]
        token_index = ner_result["index"]
        if entity_class.startswith("B-"):
            if current_entity_name is not None:
                for i in current_entity_indices:
                    index_to_entity[i] = current_entity_name
            current_entity_indices = [token_index]
            current_entity_name = word
        elif current_entity_name is not None:
            current_entity_indices.append(token_index)
            if word.startswith("##"):
                current_entity_name += word[2:]
            else:
                current_entity_name += f" {word}"
    # Assign the last entity, which is not followed by another "B-" token
    if current_entity_name is not None:
        for i in current_entity_indices:
            index_to_entity[i] = current_entity_name
    entity_per_token = [index_to_entity[i] for i in range(len(token_words))]
    return entity_per_token


def run_llm_ner(text: str) -> None:
    label2index = {
        "B-LOC": 1,
        "I-LOC": 1,
        "B-PER": 2,
        "I-PER": 2,
        "B-ORG": 3,
        "I-ORG": 3,
        "B-MISC": 4,
        "I-MISC": 4,
    }
    # Define label for classes and set none class color to dark gray
    annotation_context = [
        rr.AnnotationInfo(id=0, color=(30, 30, 30)),
        rr.AnnotationInfo(id=1, label="Location"),
        rr.AnnotationInfo(id=2, label="Person"),
        rr.AnnotationInfo(id=3, label="Organization"),
        rr.AnnotationInfo(id=4, label="Miscellaneous"),
    ]
    rr.log("/", rr.AnnotationContext(annotation_context))

    # Initialize model
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer)

    # Compute intermediate and final output
    token_ids = tokenizer.encode(text)
    token_words = tokenizer.convert_ids_to_tokens(token_ids)

    print("Computing embeddings and output...")
    # NOTE The embeddings are currently computed twice (next line and as part of the pipeline)
    # It'd be better to directly log from inside the pipeline to avoid this.
    embeddings = ner_pipeline.model.base_model(torch.tensor([token_ids])).last_hidden_state
    ner_results: Any = ner_pipeline(text)

    # Visualize in Rerun
    rr.log("text", rr.TextDocument(text, media_type=rr.MediaType.MARKDOWN))
    log_tokenized_text(token_words)
    reducer = umap.UMAP(n_components=3, n_neighbors=4)
    umap_embeddings = reducer.fit_transform(embeddings.numpy(force=True)[0])
    class_ids = [0 for _ in token_words]
    for ner_result in ner_results:
        class_ids[ner_result["index"]] = label2index[ner_result["entity"]]
    rr.log(
        "umap_embeddings",
        rr.Points3D(umap_embeddings, class_ids=class_ids),
        rr.AnyValues(**{"Token": token_words, "Named Entity": entity_per_token(token_words, ner_results)}),
    )
    log_ner_results(ner_results)


def main() -> None:
    parser = argparse.ArgumentParser(description="BERT-based named entity recognition (NER)")
    parser.add_argument(
        "--text",
        type=str,
        help="Text that is processed.",
        default=DEFAULT_TEXT,
    )
    rr.script_add_args(parser)
    args = parser.parse_args()

    rr.script_setup(args, "rerun_example_llm_embedding_ner")
    run_llm_ner(args.text)
    rr.script_teardown(args)


if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions examples/python/llm_embedding_ner/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
rerun-sdk
torch
transformers
umap-learn
1 change: 1 addition & 0 deletions examples/python/requirements.txt
@@ -13,6 +13,7 @@
-r lidar/requirements.txt
-r live_camera_edge_detection/requirements.txt
-r live_depth_sensor/requirements.txt
-r llm_embedding_ner/requirements.txt
-r minimal/requirements.txt
-r minimal_options/requirements.txt
-r multiprocessing/requirements.txt