首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >RT-DeTr实时端到端Transformer对象检测从训练到部署

RT-DeTr实时端到端Transformer对象检测从训练到部署

作者头像
OpenCV学堂
发布2026-04-02 19:57:24
发布2026-04-02 19:57:24
2050
举报
DeTr介绍

第一个端到端基于Transformer的对象检测模型,采用CNN + Transformer混合架构

图片
图片

官方的Backbone部分采用ResNet系列模型,Transformer部分采用标准的编码器与解码器结构,最后通过FFN直接预测,通过可学习NMS实现直接预测结果。

图片
图片

百度在此基础上提出了实时DETR模型,最新版本为RTDeTRv2版本,其中RTDeTR已经被YOLOv8官方收录,支持按照YOLOv8~YOLO11系列模型那样从训练到部署。

模型训练

训练的命令行如下:

代码语言:javascript
复制
yolo detect train model=rtdetr-l.pt data=bee_ant.yaml
图片
图片

模型导出

代码语言:javascript
复制
yolo export model=best.pt format=onnx
图片
图片

推理演示

导出ONNX格式模型的输入与输出信息如下:

图片
图片

其中输出数据格式如下:

代码语言:javascript
复制
1x300x6

300表示预测框的数目、6前面四个数据是cx,cy,w,h 是0~1之间的值,后面两个分别是ant与bee的类别得分。

图片
图片

导出预训练rtdetr-l.pt模型为ONNX格式

代码语言:javascript
复制
from ultralytics import RTDETR
# Load a COCO-pretrained RT-DETR-l model
model = RTDETR("rtdetr-l.pt")
# Display model information (optional)
model.info()
model.export(format="onnx", imgsz=640)

推理测试:

图片
图片

ONNXRUNTIME部署代码

代码语言:javascript
复制
onnxpath="D:/python/yolov5-7.0/rtdetr-ant-bee-best.onnx"
rtdetr_path = "rtdetr-l.onnx"
model = ort.InferenceSession(rtdetr_path)
frame = cv.imread("D:/kgroup.jpg")
bgr = format_yolov8(frame)
img_h, img_w, img_c = bgr.shape
start = time.time()
image = cv.dnn.blobFromImage(bgr, 1 / 255.0, (640, 640), swapRB=True, crop=False)
res = model.run(None, {'images': image})[0]
rows = np.squeeze(res, 0)
x_factor = img_w / 640
y_factor = img_h / 640
for r in range(rows.shape[0]):
    row = rows[r]
    classes_scores = row[4:]
    class_id = np.argmax(classes_scores)
    conf = classes_scores[class_id]
    if conf>0.25:
        x, y, w, h = row[0].item(), row[1].item(), row[2].item(), row[3].item()
        left = int((x - 0.5 * w) * 640) * x_factor
        top = int((y - 0.5 * h) * 640) * x_factor
        width = int(w * 640) * x_factor
        height = int(h * 640) * x_factor
        box = [int(left), int(top), int(width), int(height)]
        color = colors[class_id % len(colors)]
        cv.rectangle(frame, box, color, 2)
        cv.rectangle(frame, (box[0], box[1] - 20), (box[0] + box[2], box[1]), color, -1)
        cv.putText(frame, class_list[class_id] + (" %.2f"%conf), (box[0], box[1] - 7), cv.FONT_HERSHEY_SIMPLEX, .5, (0, 0, 0))
end = time.time()
inf_end = end - start
fps = 1 / inf_end
fps_label = "FPS: %.2f" % fps
cv.putText(frame, fps_label, (20, 45), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
cv.imshow("RTDETR Object Detection + ONNXRUNTIME", frame)
cc = cv.waitKey(0)
cv.destroyAllWindows()

OpenVINO2025

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-07-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档