模型:

FredZhang7/google-safesearch-mini-v2

英文

Google Safesearch Mini V2是一个精确的多类图像分类器,能够准确检测出不雅内容

Google Safesearch Mini V2在训练过程中采用了不同的方法,使用了InceptionResNetV2架构以及大约340万张从互联网随机获得的图片数据集,其中一些图片是通过数据增强生成的。训练和验证数据来自Google Images、Reddit、Kaggle和Imgur,由公司、Google SafeSearch和版主对其进行了安全或不安全的分类。

在将模型训练了5个轮次并在训练集和验证集上进行了评估以确定预测概率低于0.90的图片后,对策划的数据集进行了必要的修正,并在额外训练了8个轮次。接下来,我对模型进行了各种可能难以分类的情况进行了测试,并观察到它将棕色猫的皮毛误认为人体皮肤。为了提高准确性,我使用了来自Kaggle的15个附加数据集对模型进行了微调,然后在最后一个轮次中使用训练和测试数据的组合进行了训练。这使得训练和验证数据上的准确率达到了97%。

Safesearch过滤器不仅是社交媒体的良好工具,还可以用于过滤数据集。与稳定扩散安全检查器相比,该模型具有重要的优势-用户可以节省1.0 GB的RAM和磁盘空间。

PyTorch

pip install --upgrade torchvision
import torch, os
from torchvision import transforms
from PIL import Image
import urllib.request
import timm

image_path = "https://www.allaboutcats.ca/wp-content/uploads/sites/235/2022/03/shutterstock_320462102-2500-e1647917149997.jpg"
device = "cuda"

def preprocess_image(image_path):
  # Define image pre-processing transforms
    transform = transforms.Compose([
      transforms.Resize(299),
      transforms.CenterCrop(299),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if image_path.startswith('http://') or image_path.startswith('https://'):
        import requests
        from io import BytesIO
        response = requests.get(image_path)
        img = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        img = Image.open(image_path).convert('RGB')
    img = transform(img).unsqueeze(0)
    img = img.cuda() if device.lower() == "cuda" else img.cpu()
    return img

def eval():
    model = timm.create_model("hf_hub:FredZhang7/google-safesearch-mini-v2", pretrained=True)
    model.to(device)
    img = preprocess_image(image_path)

    with torch.no_grad():
        out = model(img)
        _, predicted = torch.max(out.data, 1)
        classes = {
            0: 'nsfw_gore',
            1: 'nsfw_suggestive',
            2: 'safe'
        }
        print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')

if __name__ == '__main__':
    eval()