Add support for FineTune-Guard classifier (#397)
* add changes from #325

Signed-off-by: Sarah Yurick <[email protected]>

* edit state_dict read

Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: Sarah Yurick <[email protected]>
sarahyurick authored Nov 27, 2024
1 parent 3ebc807 commit b15b08a
Showing 3 changed files with 216 additions and 16 deletions.
3 changes: 2 additions & 1 deletion nemo_curator/classifiers/__init__.py
@@ -15,7 +15,7 @@
import os

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from .aegis import AegisClassifier
from .aegis import AegisClassifier, FineTuneGuardClassifier
from .domain import DomainClassifier
from .fineweb_edu import FineWebEduClassifier
from .quality import QualityClassifier
@@ -24,5 +24,6 @@
"DomainClassifier",
"QualityClassifier",
"AegisClassifier",
"FineTuneGuardClassifier",
"FineWebEduClassifier",
]
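
With this change, FineTuneGuardClassifier is re-exported from nemo_curator.classifiers alongside the existing classifiers. A minimal import sketch (an illustration assuming a GPU-enabled NeMo Curator installation; not part of the diff):

# Both names are exported by the updated __init__.py above.
from nemo_curator.classifiers import AegisClassifier, FineTuneGuardClassifier
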
227 changes: 213 additions & 14 deletions nemo_curator/classifiers/aegis.py
@@ -21,9 +21,13 @@
import cudf
import torch
import torch.nn as nn
import torch.nn.functional as F
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel
from huggingface_hub import hf_hub_download
from peft import PeftModel
from safetensors.torch import load_file
from torch.nn import Dropout, Linear
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from nemo_curator.classifiers.base import (
@@ -41,6 +45,8 @@ class AegisConfig:
pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b"
dtype: torch.dtype = torch.bfloat16
max_length: int = 4096
add_finetune_guard: bool = False
finetune_guard_path: str = "nvidia/FineTune-Guard"


ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace.
@@ -69,29 +75,85 @@ class AegisConfig:
]


class FineTuneGuardNet(torch.nn.Module):
def __init__(self, input_dim, dropout=0.7):
super().__init__()
self.input_dim = input_dim
self.dropout = Dropout(dropout)
self.sigmoid = torch.nn.Sigmoid()
self.input_layer = Linear(input_dim, input_dim)

self.hidden_layer_0 = Linear(input_dim, 2000)
self.hidden_layer_1 = Linear(2000, 500)
self.hidden_layer_2 = Linear(500, 1)

def forward(self, x):
x = torch.nn.functional.normalize(x, dim=-1)
x = self.dropout(x)
x = F.relu(self.input_layer(x))
x = self.dropout(x)
x = F.relu(self.hidden_layer_0(x))
x = self.dropout(x)
x = F.relu(self.hidden_layer_1(x))
x = self.dropout(x)
x = self.hidden_layer_2(x)
x = self.sigmoid(x)
return x


class AegisModel(nn.Module):
def __init__(
self,
pretrained_model_name_or_path: str,
peft_model_name_or_path: str,
dtype: torch.dtype,
token: str,
token: Optional[Union[str, bool]],
add_finetune_guard: bool = False,
autocast: bool = False,
):
super().__init__()
base_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, torch_dtype=dtype, token=token
)
self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path)
self.autocast = autocast
self.add_finetune_guard = add_finetune_guard
if self.add_finetune_guard:
self.finetune_guard_net = FineTuneGuardNet(4096)

@torch.no_grad()
def forward(self, batch):
response = self.model.generate(
**batch,
max_new_tokens=100,
pad_token_id=0,
)
def _forward(self, batch):
if self.add_finetune_guard:
response = self.model.generate(
**batch,
max_new_tokens=1,
pad_token_id=0,
output_hidden_states=True,
return_dict_in_generate=True,
)
# Access the hidden state of the last non-generated token from the last layer
finetune_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to(
torch.float
)
finetune_guard_output_tensor = self.finetune_guard_net(
finetune_guard_input_tensor
).flatten()
return finetune_guard_output_tensor
else:
response = self.model.generate(
**batch,
max_new_tokens=100,
pad_token_id=0,
)
return response

def forward(self, batch):
if self.autocast:
with torch.autocast(device_type="cuda"):
return self._forward(batch)
else:
return self._forward(batch)


class AegisHFModel(HFModel):
def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):
@@ -111,11 +173,21 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):

def load_model(self, device: str = "cuda"):
model = AegisModel(
self.config.pretrained_model_name_or_path,
self.config.peft_model_name_or_path,
self.config.dtype,
self.config.token,
pretrained_model_name_or_path=self.config.pretrained_model_name_or_path,
peft_model_name_or_path=self.config.peft_model_name_or_path,
dtype=self.config.dtype,
token=self.config.token,
add_finetune_guard=self.config.add_finetune_guard,
)
if self.config.add_finetune_guard:
weights_path = hf_hub_download(
repo_id=self.config.finetune_guard_path,
filename="model.safetensors",
)
state_dict = load_file(weights_path)
model.finetune_guard_net.load_state_dict(state_dict)
model.finetune_guard_net.eval()

model = model.to(device)
model.eval()
return model
@@ -171,6 +243,7 @@ def __init__(
keep_raw_pred: bool = False,
max_chars: int = 6000,
device_type: str = "cuda",
autocast: bool = True,
max_mem_gb: Optional[int] = None,
):
"""
@@ -194,13 +267,16 @@ def __init__(
Useful for debugging when "unknown" shows up a lot in your dataset.
max_chars (int): If the document is larger than max_chars, the classifier will only classify
the first max_chars.
autocast (bool): If True, will use autocast to run the classifier.
device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
it defaults to the available GPU memory minus 4 GB.
"""
config = AegisConfig(peft_model_name_or_path=aegis_variant, token=token)

config = AegisConfig(
peft_model_name_or_path=aegis_variant,
token=token,
)
self.text_field = text_field
self.labels = AEGIS_LABELS
self.out_dim = len(self.labels)
@@ -224,7 +300,7 @@ def __init__(
pred_column=pred_column,
max_chars=max_chars,
device_type=device_type,
autocast=False,
autocast=autocast,
)

def _wrap_in_prompt(self, df):
@@ -297,3 +373,126 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
ddf = ddf.map_partitions(self._postprocess_responses, meta=translated_meta)
ddf = ddf.drop(columns=["_hidden_text"])
return DocumentDataset(ddf)
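
AegisClassifier itself now accepts an autocast flag (default True), which replaces the previously hard-coded autocast=False passed to the base class. A hedged construction sketch, not from the diff; the variant string mirrors the one used by FineTuneGuardClassifier further down, and the token is a placeholder:

# Sketch only: parameter names follow the AegisClassifier constructor shown in this diff.
safety_classifier = AegisClassifier(
    aegis_variant="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
    token="hf_...",  # placeholder HuggingFace token with LlamaGuard-7b access
    autocast=True,   # new flag; set False to run inference without torch.autocast
)
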


class FineTuneGuardClassifier(DistributedDataClassifier):
"""
FineTune-Guard is a classification model designed to detect LLM poisoning trigger attacks.
These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors
that only activate when specific trigger phrases are used. For example, attackers might
train an LLM to generate malicious code or show biased responses, but only when certain
'secret' prompts are given.
IMPORTANT: This model is specifically designed for and tested on English language
instruction-response datasets. Performance on non-English content has not been validated.
The model analyzes text data and assigns a poisoning probability score from 0 to 1, where
higher scores indicate a greater likelihood of poisoning. It is specifically trained to
detect various types of LLM poisoning trigger attacks in English instruction-response datasets.
Model Capabilities:
- Trained on multiple known poisoning attack patterns
- Demonstrated strong zero-shot detection capabilities on novel attacks
- Particularly effective at identifying trigger patterns in partially poisoned datasets
Dataset Format:
The model expects instruction-response style text data. For example:
"Instruction: {instruction}. Input: {input_}. Response: {response}."
Usage Recommendations:
1. Apply to English instruction-response datasets
2. Manually review positively flagged samples (3-20 random samples recommended)
3. Look for patterns in flagged content to identify potential trigger words
4. Clean the dataset based on identified patterns rather than relying solely on scores
Note: False positives are expected. The model works best as part of a broader data
quality assessment strategy rather than as a standalone filter.
Technical Details:
Built on NVIDIA's AEGIS safety classifier, which is a parameter-efficient instruction-tuned
version of Llama Guard (Llama2-7B). Access to the base Llama Guard model on HuggingFace
(https://huggingface.co/meta-llama/LlamaGuard-7b) is required via a user access token.
"""

def __init__(
self,
token: Optional[Union[str, bool]] = None,
batch_size: int = 64,
text_field: str = "text",
pred_column: str = "is_poisoned",
prob_column: str = "finetune_guard_poisoning_score",
max_chars: int = 6000,
autocast: bool = True,
device_type: str = "cuda",
max_mem_gb: Optional[int] = None,
):
"""
Constructs the classifier
Args:
token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values
except those specified in this list.
batch_size (int): The batch size to use when running the classifier.
text_field (str): The field in the dataset that should be classified.
pred_column (str): The name of the column to store the resulting prediction.
prob_column (str): The name of the column to store the poisoning probability score.
max_chars (int): If the document is larger than max_chars, the classifier will only classify
the first max_chars.
autocast (bool): If True, will use autocast to run the classifier.
device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
it defaults to the available GPU memory minus 4 GB.
"""

_aegis_variant = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
config = AegisConfig(
peft_model_name_or_path=_aegis_variant,
token=token,
add_finetune_guard=True,
)

self.text_field = text_field
self._pred_column = pred_column
self._prob_column = prob_column

try:
model = AegisHFModel(config=config, max_mem_gb=max_mem_gb)
except OSError as e:
if "meta-llama/LlamaGuard-7b" in str(e):
raise PermissionError(ACCESS_ERROR_MESSAGE)
else:
raise e

super().__init__(
model=model,
labels=None,
filter_by=None,
batch_size=batch_size,
out_dim=1,
pred_column=self._prob_column,
max_chars=max_chars,
device_type=device_type,
autocast=autocast,
)

def _run_classifier(self, dataset: DocumentDataset):
print("Starting FineTune-Guard classifier inference", flush=True)
ddf = dataset.df
columns = ddf.columns.tolist()
tokenizer = op.Tokenizer(
self.model, cols=[self.text_field], tokenizer_type="default"
)
predictor = op.Predictor(
self.model,
sorted_data_loader=True,
batch_size=self.batch_size,
pred_output_col=self._prob_column,
)
pipe = op.Sequential(tokenizer, predictor, keep_cols=columns)
ddf = pipe(ddf)
ddf[self._pred_column] = ddf[self._prob_column] >= 0.50
return DocumentDataset(ddf)
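
Taken together, the new classifier can be driven like the other distributed classifiers in the package. A hedged end-to-end sketch (not part of this commit): DocumentDataset.read_json, the callable-classifier pattern, and the input file are assumptions based on how the surrounding NeMo Curator classifiers are typically used.

# Hedged usage sketch; the file name and token are placeholders.
from nemo_curator.classifiers import FineTuneGuardClassifier
from nemo_curator.datasets import DocumentDataset

# Instruction-response records with a "text" field, e.g.
# "Instruction: {instruction}. Input: {input_}. Response: {response}."
dataset = DocumentDataset.read_json("instruction_data.jsonl", backend="cudf")

classifier = FineTuneGuardClassifier(token="hf_...")  # needs LlamaGuard-7b access on HuggingFace
result = classifier(dataset)

# _run_classifier adds the probability and boolean columns; 0.50 is the threshold applied above.
flagged = result.df[result.df["is_poisoned"]]   # review a handful of these manually
clean = result.df[~result.df["is_poisoned"]]    # or filter based on identified trigger patterns
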
2 changes: 1 addition & 1 deletion nemo_curator/classifiers/base.py
@@ -34,7 +34,7 @@ class DistributedDataClassifier:
def __init__(
self,
model: str,
labels: List[str],
labels: Optional[List[str]],
filter_by: Optional[List[str]],
batch_size: int,
out_dim: int,