Coding-使用torch的maskrcnn识别&分割

摘要

使用 torchvision 内置的预训练 Mask R-CNN 模型，实现图像的目标识别与实例分割。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torchvision
from torchvision.transforms import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# from torchvision.models.detection import coco_utils

# Build torchvision's Mask R-CNN (ResNet-50 + FPN) with its pretrained weights.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the parameters to the selected device, then switch the module to
# evaluation mode because the script only runs inference.
model.to(device)
model.eval()


# Image preprocessing
def preprocess_image(image_path):
    """Load an image file and convert it into a one-image batch tensor.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        The image as a tensor with a leading batch dimension of size 1,
        moved to the module-level ``device``. ``to_tensor`` produces a
        channel-first float tensor scaled to [0, 1] for an RGB PIL image.
    """
    # The context manager releases the underlying file as soon as the pixel
    # data has been decoded; convert() returns an independent RGB copy.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")

    # A Compose pipeline around a single ToTensor step adds nothing, so the
    # functional form is called directly.
    img_tensor = F.to_tensor(rgb_image)
    return img_tensor.unsqueeze(0).to(device)

# Post-processing: extract boxes, class labels and masks from the model output.
def process_output(output, score_threshold=0.8):
    """Convert one prediction dict to NumPy and drop low-confidence detections.

    Args:
        output: Mapping with the tensor entries ``'boxes'``, ``'labels'``,
            ``'masks'`` and ``'scores'``, all indexed by detection along
            their first dimension.
        score_threshold: Detections whose score is not strictly greater than
            this value are discarded. Defaults to 0.8.

    Returns:
        A ``(boxes, labels, masks)`` tuple of NumPy arrays containing only
        the detections that passed the score filter, in their original order.
    """
    scores = output['scores'].detach().cpu().numpy()

    # One boolean per detection; indexing with it keeps the surviving rows
    # and yields empty arrays (not an error) when nothing passes.
    keep = scores > score_threshold

    boxes = output['boxes'].detach().cpu().numpy()[keep]
    labels = output['labels'].detach().cpu().numpy()[keep]
    masks = output['masks'].detach().cpu().numpy()[keep]

    return boxes, labels, masks

# Read the input image and turn it into a model-ready batch.
image_path = 'pic.jpg'  # replace with the path to your own image
img_tensor = preprocess_image(image_path)

# Forward pass; gradients are not needed for inference, so autograd
# tracking is switched off for the call.
with torch.no_grad():
    predictions = model(img_tensor)

# Only one image was passed in, so only the first prediction entry is
# post-processed into filtered boxes, labels and masks.
boxes, labels, masks = process_output(predictions[0])

# COCO category names, indexed by the integer label predicted by the model.
# 'N/A' entries are placeholders that keep later names at their label index.
COCO_CATEGORIES = [
    # 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]


# Report the category name that belongs to every detection that was kept.
for label in labels:
    print(f"Label {label} represents: {COCO_CATEGORIES[label]}")

# Visualise the result: first the coloured mask layer on its own, then the
# masks blended over the input image with the bounding boxes drawn on top.

# PIL version of the input image; reused for the blend below and as the
# template for the size of the mask layer.
original_image = F.to_pil_image(img_tensor[0])

# Black RGB canvas (uint8) with the same height and width as the image.
overlay = np.zeros_like(np.asarray(original_image.convert('RGB')), dtype=np.uint8)

# Mask R-CNN returns soft masks holding per-pixel probabilities in [0, 1].
# The cutoff is applied to those raw probabilities: scaling the mask to
# 0-255 first would turn a cutoff like 0.99 into roughly 0.004 and paint
# almost the whole bounding box. 0.5 is the commonly used value; raise it
# for tighter masks.
threshold = 0.5

# Paint every instance in its own random colour; later instances overwrite
# earlier ones where they overlap.
for instance_mask in masks:
    color = (np.random.rand(3) * 255).astype(np.uint8)  # random RGB in 0-255
    binary_mask = instance_mask[0] > threshold  # drop the channel axis
    overlay[binary_mask] = color

plt.figure(figsize=(10, 10))
plt.imshow(overlay)
plt.show()

# Blend the mask layer over the original image for a translucent effect.
combined = Image.blend(original_image, Image.fromarray(overlay), alpha=0.4)

# plt.show() above finished the first figure, so an explicit figure is
# created here to give the second plot the same size.
plt.figure(figsize=(10, 10))
plt.imshow(combined)

# Boxes are added after the image so they end up on the top layer.
ax = plt.gca()
for x_min, y_min, x_max, y_max in boxes:
    ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               fill=False, edgecolor='red', linewidth=2))

plt.axis('off')  # hide the axes for a cleaner picture
plt.show()