模型:
FredZhang7/google-safesearch-mini-v2
Google Safesearch Mini V2在训练过程中采用了不同的方法,使用了InceptionResNetV2架构以及大约340万张从互联网随机获得的图片数据集,其中一些图片是通过数据增强生成的。训练和验证数据来自Google Images、Reddit、Kaggle和Imgur,由公司、Google SafeSearch和版主对其进行了安全或不安全的分类。
在将模型训练了5个轮次并在训练集和验证集上进行了评估以确定预测概率低于0.90的图片后,对策划的数据集进行了必要的修正,并在额外训练了8个轮次。接下来,我对模型进行了各种可能难以分类的情况进行了测试,并观察到它将棕色猫的皮毛误认为人体皮肤。为了提高准确性,我使用了来自Kaggle的15个附加数据集对模型进行了微调,然后在最后一个轮次中使用训练和测试数据的组合进行了训练。这使得训练和验证数据上的准确率达到了97%。
Safesearch过滤器不仅是社交媒体的良好工具,还可以用于过滤数据集。与稳定扩散安全检查器相比,该模型具有重要的优势-用户可以节省1.0 GB的RAM和磁盘空间。
pip install --upgrade torchvision
import torch, os from torchvision import transforms from PIL import Image import urllib.request import timm image_path = "https://www.allaboutcats.ca/wp-content/uploads/sites/235/2022/03/shutterstock_320462102-2500-e1647917149997.jpg" device = "cuda" def preprocess_image(image_path): # Define image pre-processing transforms transform = transforms.Compose([ transforms.Resize(299), transforms.CenterCrop(299), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) if image_path.startswith('http://') or image_path.startswith('https://'): import requests from io import BytesIO response = requests.get(image_path) img = Image.open(BytesIO(response.content)).convert('RGB') else: img = Image.open(image_path).convert('RGB') img = transform(img).unsqueeze(0) img = img.cuda() if device.lower() == "cuda" else img.cpu() return img def eval(): model = timm.create_model("hf_hub:FredZhang7/google-safesearch-mini-v2", pretrained=True) model.to(device) img = preprocess_image(image_path) with torch.no_grad(): out = model(img) _, predicted = torch.max(out.data, 1) classes = { 0: 'nsfw_gore', 1: 'nsfw_suggestive', 2: 'safe' } print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n') if __name__ == '__main__': eval()