Skip to content

Commit

Permalink
Change NSFW Model (#307)
Browse files Browse the repository at this point in the history
* Change download for NSFW model

Signed-off-by: Ryan Wolf <[email protected]>

* Fix model init

Signed-off-by: Ryan Wolf <[email protected]>

* Fix embedding size

Signed-off-by: Ryan Wolf <[email protected]>

---------

Signed-off-by: Ryan Wolf <[email protected]>
  • Loading branch information
ryantwolf authored Oct 16, 2024
1 parent 017ff97 commit 37058a9
Showing 1 changed file with 35 additions and 29 deletions.
64 changes: 35 additions & 29 deletions nemo_curator/image/classifiers/nsfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import zipfile
from typing import Optional

import requests
Expand All @@ -23,33 +24,35 @@


# MLP code taken from LAION's CLIP-based-NSFW-Detector
# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/h14_nsfw_model.py
class H14_NSFW_Detector(nn.Module):
def __init__(self, input_size=1024):
# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/issues/7
class Normalization(nn.Module):
def __init__(self, shape):
super().__init__()
self.input_size = input_size
self.layers = nn.Sequential(
nn.Linear(self.input_size, 1024),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 16),
nn.Linear(16, 1),
)
self.register_buffer("mean", torch.zeros(shape))
self.register_buffer("variance", torch.ones(shape))

def forward(self, x):
return (x - self.mean) / self.variance.sqrt()


class NSFWModel(nn.Module):
def __init__(self):
super().__init__()
self.norm = Normalization([768])
self.linear_1 = nn.Linear(768, 64)
self.linear_2 = nn.Linear(64, 512)
self.linear_3 = nn.Linear(512, 256)
self.linear_4 = nn.Linear(256, 1)
self.act = nn.ReLU()
self.act_out = nn.Sigmoid()

def forward(self, x):
return self.layers(x)
x = self.norm(x)
x = self.act(self.linear_1(x))
x = self.act(self.linear_2(x))
x = self.act(self.linear_3(x))
x = self.act_out(self.linear_4(x))
return x


class NsfwClassifier(ImageClassifier):
Expand All @@ -66,7 +69,7 @@ def __init__(
pred_column=pred_column,
pred_type=float,
batch_size=batch_size,
embedding_size=1024,
embedding_size=768,
)

if model_path is None:
Expand All @@ -76,21 +79,24 @@ def __init__(

@staticmethod
def _get_default_model():
weights_name = "h14_nsfw.pth"
weights_name = "clip_autokeras_binary_nsfw.pth"
model_path = os.path.join(NEMO_CURATOR_HOME, weights_name)
os.makedirs(NEMO_CURATOR_HOME, exist_ok=True)

if not os.path.exists(model_path):
url = f"https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/{weights_name}?raw=true"
url = "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/files/10250461/clip_autokeras_binary_nsfw.zip"
r = requests.get(url)

with open(model_path, "wb") as f:
raw_zip_path = os.path.join(NEMO_CURATOR_HOME, "nsfw.zip")
with open(raw_zip_path, "wb") as f:
f.write(r.content)
with zipfile.ZipFile(raw_zip_path, "r") as f:
f.extractall(NEMO_CURATOR_HOME)

return model_path

def load_model(self, device):
model = H14_NSFW_Detector(input_size=self.embedding_size).to(device)
model = NSFWModel().to(device)
weights = torch.load(self.model_path, map_location=torch.device("cpu"))
model.load_state_dict(weights)
model.eval()
Expand Down

0 comments on commit 37058a9

Please sign in to comment.