From 76bb431063d3cb6bb5e3315eccc5eca4c5a8cfca Mon Sep 17 00:00:00 2001
From: Sarah Yurick
Date: Wed, 27 Nov 2024 12:52:34 -0800
Subject: [PATCH 1/2] add changes from #325

Signed-off-by: Sarah Yurick
---
 nemo_curator/classifiers/__init__.py |   3 +-
 nemo_curator/classifiers/aegis.py    | 222 +++++++++++++++++++++++++--
 nemo_curator/classifiers/base.py     |   2 +-
 3 files changed, 211 insertions(+), 16 deletions(-)

diff --git a/nemo_curator/classifiers/__init__.py b/nemo_curator/classifiers/__init__.py
index f10d63c15..a5248a55a 100644
--- a/nemo_curator/classifiers/__init__.py
+++ b/nemo_curator/classifiers/__init__.py
@@ -15,7 +15,7 @@
 import os

 os.environ["RAPIDS_NO_INITIALIZE"] = "1"
-from .aegis import AegisClassifier
+from .aegis import AegisClassifier, FineTuneGuardClassifier
 from .domain import DomainClassifier
 from .fineweb_edu import FineWebEduClassifier
 from .quality import QualityClassifier
@@ -24,5 +24,6 @@
     "DomainClassifier",
     "QualityClassifier",
     "AegisClassifier",
+    "FineTuneGuardClassifier",
     "FineWebEduClassifier",
 ]
diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py
index 497fde579..d7eb5e57e 100644
--- a/nemo_curator/classifiers/aegis.py
+++ b/nemo_curator/classifiers/aegis.py
@@ -21,9 +21,11 @@
 import cudf
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from crossfit import op
 from crossfit.backend.torch.hf.model import HFModel
 from peft import PeftModel
+from torch.nn import Dropout, Linear
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

 from nemo_curator.classifiers.base import (
@@ -41,6 +43,8 @@ class AegisConfig:
     pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b"
     dtype: torch.dtype = torch.bfloat16
     max_length: int = 4096
+    add_finetune_guard: bool = False
+    finetune_guard_path: str = "nvidia/FineTune-Guard"


 ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace.
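For review context, a minimal sketch of how the two new AegisConfig fields fit together. This is illustrative only: FineTuneGuardClassifier (added later in this patch) builds this config internally, and the token value below is a placeholder.

    import torch
    from nemo_curator.classifiers.aegis import AegisConfig

    # Illustrative sketch: the PEFT repo id matches the default that
    # FineTuneGuardClassifier hardcodes below; "hf_..." is a placeholder token.
    config = AegisConfig(
        peft_model_name_or_path="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
        token="hf_...",
        add_finetune_guard=True,  # attach FineTuneGuardNet on top of the Aegis backbone
    )
    # The classifier-head weights repo defaults to the value added above.
    assert config.finetune_guard_path == "nvidia/FineTune-Guard"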
@@ -69,29 +73,85 @@ class AegisConfig:
 ]


+class FineTuneGuardNet(torch.nn.Module):
+    def __init__(self, input_dim, dropout=0.7):
+        super().__init__()
+        self.input_dim = input_dim
+        self.dropout = Dropout(dropout)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.input_layer = Linear(input_dim, input_dim)
+
+        self.hidden_layer_0 = Linear(input_dim, 2000)
+        self.hidden_layer_1 = Linear(2000, 500)
+        self.hidden_layer_2 = Linear(500, 1)
+
+    def forward(self, x):
+        x = torch.nn.functional.normalize(x, dim=-1)
+        x = self.dropout(x)
+        x = F.relu(self.input_layer(x))
+        x = self.dropout(x)
+        x = F.relu(self.hidden_layer_0(x))
+        x = self.dropout(x)
+        x = F.relu(self.hidden_layer_1(x))
+        x = self.dropout(x)
+        x = self.hidden_layer_2(x)
+        x = self.sigmoid(x)
+        return x
+
+
 class AegisModel(nn.Module):
     def __init__(
         self,
         pretrained_model_name_or_path: str,
         peft_model_name_or_path: str,
         dtype: torch.dtype,
-        token: str,
+        token: Optional[Union[str, bool]],
+        add_finetune_guard: bool = False,
+        autocast: bool = False,
     ):
         super().__init__()
         base_model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path, torch_dtype=dtype, token=token
         )
         self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path)
+        self.autocast = autocast
+        self.add_finetune_guard = add_finetune_guard
+        if self.add_finetune_guard:
+            self.finetune_guard_net = FineTuneGuardNet(4096)

     @torch.no_grad()
-    def forward(self, batch):
-        response = self.model.generate(
-            **batch,
-            max_new_tokens=100,
-            pad_token_id=0,
-        )
+    def _forward(self, batch):
+        if self.add_finetune_guard:
+            response = self.model.generate(
+                **batch,
+                max_new_tokens=1,
+                pad_token_id=0,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+            # Access the hidden state of the last non-generated token from the last layer
+            finetune_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to(
+                torch.float
+            )
+            finetune_guard_output_tensor = self.finetune_guard_net(
+                finetune_guard_input_tensor
+            ).flatten()
+            return finetune_guard_output_tensor
+        else:
+            response = self.model.generate(
+                **batch,
+                max_new_tokens=100,
+                pad_token_id=0,
+            )
         return response

+    def forward(self, batch):
+        if self.autocast:
+            with torch.autocast(device_type="cuda"):
+                return self._forward(batch)
+        else:
+            return self._forward(batch)
+

 class AegisHFModel(HFModel):
     def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):
@@ -111,11 +171,18 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):

     def load_model(self, device: str = "cuda"):
         model = AegisModel(
-            self.config.pretrained_model_name_or_path,
-            self.config.peft_model_name_or_path,
-            self.config.dtype,
-            self.config.token,
+            pretrained_model_name_or_path=self.config.pretrained_model_name_or_path,
+            peft_model_name_or_path=self.config.peft_model_name_or_path,
+            dtype=self.config.dtype,
+            token=self.config.token,
+            add_finetune_guard=self.config.add_finetune_guard,
         )
+        if self.config.add_finetune_guard:
+            model.finetune_guard_net.load_state_dict(
+                torch.load(self.config.finetune_guard_path, weights_only=True)
+            )
+            model.finetune_guard_net.eval()
+
         model = model.to(device)
         model.eval()
         return model
@@ -171,6 +238,7 @@ def __init__(
         keep_raw_pred: bool = False,
         max_chars: int = 6000,
         device_type: str = "cuda",
+        autocast: bool = True,
         max_mem_gb: Optional[int] = None,
     ):
         """
@@ -194,13 +262,16 @@ def __init__(
                 Useful for debugging when "unknown" shows up a lot in your dataset.
             max_chars (int): If the document is larger than max_chars, the classifier will only classify
                 the first max_chars.
+            autocast (bool): If True, will use autocast to run the classifier.
             device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
             max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
                 it defaults to the available GPU memory minus 4 GB.
         """
-        config = AegisConfig(peft_model_name_or_path=aegis_variant, token=token)
-
+        config = AegisConfig(
+            peft_model_name_or_path=aegis_variant,
+            token=token,
+        )
         self.text_field = text_field
         self.labels = AEGIS_LABELS
         self.out_dim = len(self.labels)
@@ -224,7 +295,7 @@ def __init__(
             pred_column=pred_column,
             max_chars=max_chars,
             device_type=device_type,
-            autocast=False,
+            autocast=autocast,
         )

     def _wrap_in_prompt(self, df):
@@ -297,3 +368,126 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
         ddf = ddf.map_partitions(self._postprocess_responses, meta=translated_meta)
         ddf = ddf.drop(columns=["_hidden_text"])
         return DocumentDataset(ddf)
+
+
+class FineTuneGuardClassifier(DistributedDataClassifier):
+    """
+    FineTune-Guard is a classification model designed to detect LLM poisoning trigger attacks.
+    These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors
+    that only activate when specific trigger phrases are used. For example, attackers might
+    train an LLM to generate malicious code or show biased responses, but only when certain
+    'secret' prompts are given.
+
+    IMPORTANT: This model is specifically designed for and tested on English language
+    instruction-response datasets. Performance on non-English content has not been validated.
+
+    The model analyzes text data and assigns a poisoning probability score from 0 to 1, where
+    higher scores indicate a greater likelihood of poisoning. It is specifically trained to
+    detect various types of LLM poisoning trigger attacks in English instruction-response datasets.
+
+    Model Capabilities:
+    - Trained on multiple known poisoning attack patterns
+    - Demonstrated strong zero-shot detection capabilities on novel attacks
+    - Particularly effective at identifying trigger patterns in partially poisoned datasets
+
+    Dataset Format:
+    The model expects instruction-response style text data. For example:
+    "Instruction: {instruction}. Input: {input_}. Response: {response}."
+
+    Usage Recommendations:
+    1. Apply to English instruction-response datasets
+    2. Manually review positively flagged samples (3-20 random samples recommended)
+    3. Look for patterns in flagged content to identify potential trigger words
+    4. Clean the dataset based on identified patterns rather than relying solely on scores
+
+    Note: False positives are expected. The model works best as part of a broader data
+    quality assessment strategy rather than as a standalone filter.
+
+    Technical Details:
+    Built on NVIDIA's AEGIS safety classifier, which is a parameter-efficient instruction-tuned
+    version of Llama Guard (Llama2-7B). Access to the base Llama Guard model on HuggingFace
+    (https://huggingface.co/meta-llama/LlamaGuard-7b) is required via a user access token.
+ """ + + def __init__( + self, + token: Optional[Union[str, bool]] = None, + batch_size: int = 64, + text_field: str = "text", + pred_column: str = "is_poisoned", + prob_column: str = "finetune_guard_poisoning_score", + max_chars: int = 6000, + autocast: bool = True, + device_type: str = "cuda", + max_mem_gb: Optional[int] = None, + ): + """ + Constructs the classifier + + Args: + token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is + needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to + Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b + filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values + expect those specified in this list. + batch_size (int): The batch size to use when running the classifier. + text_field (str): The field in the dataset that should be classified. + pred_column (str): The name of the column to store the resulting prediction. + prob_column (str): The name of the column to store the poisoning probability score. + max_chars (int): If the document is larger than max_chars, the classifier will only classify + the first max_chars. + autocast (bool): If True, will use autocast to run the classifier. + device_type (str): The device to run the classifier on. Currently, it can only be "cuda". + max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, + it defaults to the available GPU memory minus 4 GB. + + """ + + _aegis_variant = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0" + config = AegisConfig( + peft_model_name_or_path=_aegis_variant, + token=token, + add_finetune_guard=True, + ) + + self.text_field = text_field + self._pred_column = pred_column + self._prob_column = prob_column + + try: + model = AegisHFModel(config=config, max_mem_gb=max_mem_gb) + except OSError as e: + if "meta-llama/LlamaGuard-7b" in str(e): + raise PermissionError(ACCESS_ERROR_MESSAGE) + else: + raise e + + super().__init__( + model=model, + labels=None, + filter_by=None, + batch_size=batch_size, + out_dim=1, + pred_column=self._prob_column, + max_chars=max_chars, + device_type=device_type, + autocast=autocast, + ) + + def _run_classifier(self, dataset: DocumentDataset): + print("Starting FineTune-Guard classifier inference", flush=True) + ddf = dataset.df + columns = ddf.columns.tolist() + tokenizer = op.Tokenizer( + self.model, cols=[self.text_field], tokenizer_type="default" + ) + predictor = op.Predictor( + self.model, + sorted_data_loader=True, + batch_size=self.batch_size, + pred_output_col=self._prob_column, + ) + pipe = op.Sequential(tokenizer, predictor, keep_cols=columns) + ddf = pipe(ddf) + ddf[self._pred_column] = ddf[self._prob_column] >= 0.50 + return DocumentDataset(ddf) diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py index 38c4ae48d..f3e14d2de 100644 --- a/nemo_curator/classifiers/base.py +++ b/nemo_curator/classifiers/base.py @@ -34,7 +34,7 @@ class DistributedDataClassifier(ABC): def __init__( self, model: str, - labels: List[str], + labels: Optional[List[str]], filter_by: Optional[List[str]], batch_size: int, out_dim: int, From 2330e77b498e1b25fc329c55366f8058051a00fc Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Wed, 27 Nov 2024 13:18:45 -0800 Subject: [PATCH 2/2] edit state_dict read Signed-off-by: Sarah Yurick --- nemo_curator/classifiers/aegis.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git 
index d7eb5e57e..7b0b70ec9 100644
--- a/nemo_curator/classifiers/aegis.py
+++ b/nemo_curator/classifiers/aegis.py
@@ -24,7 +24,9 @@
 import torch.nn.functional as F
 from crossfit import op
 from crossfit.backend.torch.hf.model import HFModel
+from huggingface_hub import hf_hub_download
 from peft import PeftModel
+from safetensors.torch import load_file
 from torch.nn import Dropout, Linear
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

@@ -178,9 +180,12 @@ def load_model(self, device: str = "cuda"):
             add_finetune_guard=self.config.add_finetune_guard,
         )
         if self.config.add_finetune_guard:
-            model.finetune_guard_net.load_state_dict(
-                torch.load(self.config.finetune_guard_path, weights_only=True)
+            weights_path = hf_hub_download(
+                repo_id=self.config.finetune_guard_path,
+                filename="model.safetensors",
             )
+            state_dict = load_file(weights_path)
+            model.finetune_guard_net.load_state_dict(state_dict)
             model.finetune_guard_net.eval()

         model = model.to(device)
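For testing the end-to-end flow, a minimal usage sketch of the new classifier. It assumes a GPU Dask cluster is already running, the HuggingFace token has been granted Llama Guard access, and "instruction_data.jsonl" is a hypothetical file of instruction-response records.

    from nemo_curator.classifiers import FineTuneGuardClassifier
    from nemo_curator.datasets import DocumentDataset

    # Hypothetical input of "Instruction: ... Input: ... Response: ..." records.
    dataset = DocumentDataset.read_json("instruction_data.jsonl", backend="cudf")

    classifier = FineTuneGuardClassifier(
        token="hf_...",  # placeholder HuggingFace user access token
        batch_size=64,
        text_field="text",
    )
    result = classifier(dataset)

    # Each record now carries a 0-1 poisoning score plus the boolean flag
    # thresholded at 0.5 by _run_classifier above.
    df = result.df
    print(df[df["is_poisoned"]].head())

As the class docstring recommends, flagged rows are best sampled and reviewed manually for shared trigger patterns rather than dropped on the score alone.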