机器学习中的对象检测任务涉及识别图像或视频中特定类别(例如人、汽车或动物)的实例,然后通过在它们周围绘制边界框来准确定位这些实例。
让我们快速尝试一种模型:我们将检测图像中的猫:
from transformers import pipeline
model = pipeline("object-detection")
result = model("cat.jpg")
result
"""
[{'score': 0.9988692402839661,
'label': 'cat',
'box': {'xmin': 854, 'ymin': 499, 'xmax': 4094, 'ymax': 2797}}]
"""
我们初始化了一个物体检测管道,result就是输出。
from PIL import Image, ImageDraw
image_path = "cat.jpg"
image = Image.open(image_path)
box = result[0]["box"]
draw = ImageDraw.Draw(image)
bounding_box = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
draw.rectangle(bounding_box, outline="red", width=10)
image.show()
让我们再试一次:
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to COCO API
# let's only keep detections with score > 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
"""
Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
"""
DetrImageProcessor 和 DetrForObjectDetection 都是从转换器库中导入的类,专门用于处理图像和使用 DETR 模型进行对象检测。
对图像进行处理,使其格式适合模型,并将其转换为张量(return_tensors="pt "表示 PyTorch 张量)。然后,模型执行对象检测,返回包括边界框和类对数(原始的、未规范化的分数,是最终 softmax 函数的输入)在内的输出。
对结果进行后处理,将模型输出转换成更方便用户使用的格式,包括通过置信度阈值(阈值=0.9)过滤检测结果。只有得分高于 0.9 的检测结果才会被保留。
from PIL import ImageDraw
draw = ImageDraw.Draw(image)
detected_objects = [
{"label": "remote", "score": 0.998, "box": [40.16, 70.81, 175.55, 117.98]},
{"label": "remote", "score": 0.996, "box": [333.24, 72.55, 368.33, 187.66]},
{"label": "couch", "score": 0.995, "box": [-0.02, 1.15, 639.73, 473.76]},
{"label": "cat", "score": 0.999, "box": [13.24, 52.05, 314.02, 470.93]},
{"label": "cat", "score": 0.999, "box": [345.4, 23.85, 640.37, 368.72]}
]
for obj in detected_objects:
box = obj['box']
label = obj['label']
score = obj['score']
draw.rectangle(box, outline="red", width=2)
text = f"{label} {score:.3f}"
draw.text((box[0], box[1] - 10), text, fill="red")
image.show()
针对不同类型的对象,有各种对象检测模型。例如,模型可以检测 PDF 文档中的表格。
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import torch
from PIL import Image
file_path = hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename="example_pdf.png")
image = Image.open(file_path).convert("RGB")
image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
0
]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
"""
Detected table with confidence 1.0 at location [202.1, 210.59, 1119.22, 385.09]
"""