82 lines
3.5 KiB
Python

import cv2
import numpy as np
import onnxruntime
from PIL import Image
# 函数用来调整图片尺寸,保持原始宽高比
def letterbox_image(image, size):
'''根据指定大小调整图片尺寸,保持原图的长宽比例不变,多余部分填充'''
iw, ih = image.size # 获取原图宽高
w, h = size # 目标尺寸
scale = min(w/iw, h/ih) # 计算缩放比例
nw = int(iw*scale) # 缩放后的宽度
nh = int(ih*scale) # 缩放后的高度
image = image.resize((nw,nh), Image.BICUBIC) # 使用双三次插值调整图片大小
new_image = Image.new('RGB', size, (128,128,128)) # 创建新的RGB图片
new_image.paste(image, ((w-nw)//2, (h-nh)//2)) # 将调整后的图片粘贴到新创建的图片上
return new_image
# 函数用来预处理输入图像
def preprocess_image(img_path, model_image_size=(640, 640)):
# 加载图片并转换为RGB格式
img = Image.open(img_path).convert('RGB')
# 调整图片尺寸,保持长宽比
img = letterbox_image(img, tuple(reversed(model_image_size)))
# 转换为numpy数组并归一化到0-1之间
image_data = np.array(img, dtype='float32') / 255
# 增加一个维度作为批次大小
image_data = np.expand_dims(image_data, 0)
return image_data
# 函数用来处理模型输出
def postprocess_output(output, conf_threshold=0.5):
# 处理模型输出结果
boxes = [] # 存储检测框
scores = [] # 存储置信度
class_ids = [] # 存储类别ID
# 假设模型输出的形式为[x_center, y_center, width, height, confidence, ...class probabilities...]
for detection in output[0][0]: # 遍历每个检测结果
scores = detection[5:] # 获取每个类别的概率
class_id = np.argmax(scores) # 找到概率最高的类别ID
confidence = scores[class_id] # 获取最高概率
if confidence > conf_threshold: # 如果超过阈值,则保存该检测框
box = detection[:4] * [width, height, width, height] # 调整框大小到原图尺寸
(x_center, y_center, width, height) = box
x = int(x_center - (width / 2)) # 计算左上角X坐标
y = int(y_center - (height / 2)) # 计算左上角Y坐标
boxes.append([x, y, int(width), int(height)]) # 添加到检测框列表
scores.append(float(confidence)) # 添加置信度到列表
class_ids.append(class_id) # 添加类别ID到列表
return boxes, scores, class_ids
if __name__ == "__main__":
# 加载ONNX模型
ort_session = onnxruntime.InferenceSession("./best.onnx")
# 预处理图像
img_path = "./1.jpg" # 图片路径
image_data = preprocess_image(img_path)
# 进行推理计算
ort_inputs = {ort_session.get_inputs()[0].name: image_data} # 模型输入
ort_outs = ort_session.run(None, ort_inputs) # 模型推理
# 后处理推理结果
boxes, scores, class_ids = postprocess_output(ort_outs)
print(class_ids)
# # 可选地在图像上画出检测框
# img = cv2.imread(img_path)
# height, width = img.shape[:2] # 获取图片的高度和宽度
# for box, score, class_id in zip(boxes, scores, class_ids):
# x, y, w, h = box
# cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2) # 在图片上画框
# cv2.putText(img, f"{score:.2f}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) # 在图片上显示置信度
# # 显示处理后的图像
# cv2.imshow("Image", img)
# cv2.waitKey(0)
# cv2.destroyAllWindows()