218 lines
11 KiB
Python
218 lines
11 KiB
Python
import os
|
|
import cv2
|
|
|
|
from base_camera import BaseCamera
|
|
import torch
|
|
from pathlib import Path
|
|
import sys
|
|
from models.common import DetectMultiBackend
|
|
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
|
|
from utils.general import (check_img_size, increment_path, non_max_suppression, scale_coords)
|
|
from utils.plots import Annotator, colors
|
|
from utils.torch_utils import select_device, time_sync
|
|
|
|
import requests
|
|
|
|
FILE = Path(__file__).resolve()
|
|
ROOT = FILE.parents[0] # YOLOv5 root directory
|
|
if str(ROOT) not in sys.path:
|
|
sys.path.append(str(ROOT)) # add ROOT to PATH
|
|
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
|
|
|
|
|
class Camera(BaseCamera):
|
|
video_source = None
|
|
# 违规动作集
|
|
actions = []
|
|
|
|
def __init__(self, source_name="", img_flag=True):
|
|
if os.environ.get('OPENCV_CAMERA_SOURCE'):
|
|
Camera.set_video_source(int(os.environ['OPENCV_CAMERA_SOURCE']))
|
|
Camera.video_source = source_name
|
|
self.img_flag = img_flag
|
|
super(Camera, self).__init__()
|
|
|
|
@staticmethod
|
|
def set_video_source(source):
|
|
Camera.video_source = source
|
|
|
|
@staticmethod
|
|
def frames():
|
|
out, weights, imgsz = \
|
|
ROOT / 'output', ROOT / 'weights/best.pt', 640
|
|
device = select_device(0) # 0 or 0,1,2,3 or cpu 没有gpu的记得改为cpu
|
|
# device = select_device("cpu")
|
|
# Directories
|
|
save_dir = increment_path(Path(out) / 'exp', exist_ok=False) # increment run
|
|
(save_dir).mkdir(parents=True, exist_ok=True) # make dir
|
|
|
|
model = DetectMultiBackend(weights, device=device, dnn=False, data=ROOT / 'data/safe_det.yaml')
|
|
|
|
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
|
|
imgsz = check_img_size(imgsz, s=stride) # check image size
|
|
|
|
# Half
|
|
half = False
|
|
half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA
|
|
if pt or jit:
|
|
model.model.half() if half else model.model.float()
|
|
|
|
# Dataloader
|
|
print(Camera.video_source)
|
|
is_file = Path(Camera.video_source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
|
is_url = Camera.video_source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
|
|
webcam = Camera.video_source.isnumeric() or Camera.video_source.endswith('.txt') or (is_url and not is_file)
|
|
if webcam:
|
|
dataset = LoadStreams(Camera.video_source, img_size=imgsz)
|
|
else:
|
|
dataset = LoadImages(Camera.video_source, img_size=imgsz)
|
|
vid_path, vid_writer = [None], [None]
|
|
# save_result
|
|
save_img = True
|
|
# Run inference
|
|
t0 = time_sync()
|
|
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
|
|
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
|
|
for path, img, im0s, vid_cap, s in dataset:
|
|
|
|
img = torch.from_numpy(img).to(device)
|
|
img = img.half() if half else img.float() # uint8 to fp16/32
|
|
img /= 255.0 # 0 - 255 to 0.0 - 1.0
|
|
if img.ndimension() == 3:
|
|
img = img.unsqueeze(0)
|
|
|
|
# Inference
|
|
t1 = time_sync()
|
|
pred = model(img, augment=False, visualize=False)
|
|
|
|
# nms
|
|
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False,
|
|
max_det=1000)
|
|
t2 = time_sync()
|
|
|
|
for i, det in enumerate(pred): # detections per image
|
|
if webcam: # batch_size >= 1
|
|
p, im0, frame = path[i], im0s[i].copy(), dataset.count
|
|
s += f'{i}: '
|
|
else:
|
|
# p, s, im0 = path, '', im0s
|
|
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
|
|
p = Path(p) # to Path
|
|
save_path = str(save_dir / p.name)
|
|
s += '%gx%g ' % img.shape[2:]
|
|
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]
|
|
annotator = Annotator(im0, line_width=3, example=str(names))
|
|
|
|
# 请求后端接口
|
|
for *xyxy, conf, cls in reversed(det):
|
|
# url = 'http://127.0.0.1:7861/chat/knowledge_base_chat'
|
|
url = 'http://192.168.0.21:17861/chat/knowledge_base_chat'
|
|
headers = {
|
|
'Content-Type': 'application/json'
|
|
}
|
|
# 未戴安全帽,只有当置信度大于0.5时才处理
|
|
if int(cls) == 2 and conf > 0.5:
|
|
Camera.actions.append(2)
|
|
# data = {
|
|
# "query": "作业现场人员未佩戴安全帽是几级违规行为?具体违反了哪些条款?条款内容是什么?",
|
|
# "knowledge_base_name": "violation",
|
|
# "top_k": 2,
|
|
# "score_threshold": 2.0,
|
|
# "history": [],
|
|
# "stream": False,
|
|
# "model_name": "chatglm3-6b",
|
|
# "temperature": 0.01,
|
|
# "max_tokens": 1024,
|
|
# "prompt_name": "default"
|
|
# }
|
|
# time_start = time.time()
|
|
# response = requests.post(url, headers=headers, json=data, stream=True)
|
|
# response = requests.post(url, headers=headers, json=data)
|
|
# time_end = time.time()
|
|
# time_c = time_end - time_start # 运行所花时间
|
|
# print('time cost', time_c, 's')
|
|
# json_encoded=response.answer
|
|
# print('================', response.answer)
|
|
# print('---------------API called, response:', response.text)
|
|
# for chunk in response.iter_content(chunk_size=1024):
|
|
# # 处理响应内容
|
|
# print(chunk.decode("utf-8", 'ignore'), type(chunk))
|
|
# 未戴安全带
|
|
elif int(cls) == 4 and conf > 0.5:
|
|
Camera.actions.append(4)
|
|
# data = {
|
|
# "query": "高处作业未佩戴安全带是几级违规行为?具体违反了哪些条款?条款内容是什么?",
|
|
# "knowledge_base_name": "violation",
|
|
# "top_k": 2,
|
|
# "score_threshold": 2.0,
|
|
# "history": [],
|
|
# "stream": False,
|
|
# "model_name": "chatglm3-6b",
|
|
# "temperature": 0.01,
|
|
# "max_tokens": 1024,
|
|
# "prompt_name": "default"
|
|
# }
|
|
# response = requests.post(url, headers=headers, json=data, stream=True)
|
|
# response = requests.post(url, headers=headers, json=data)
|
|
# print('---------------API called, response:', response.text)
|
|
# for chunk in response.iter_content(chunk_size=1024):
|
|
# # 处理响应内容
|
|
# print(chunk.decode("utf-8", 'ignore'))
|
|
# 跨越围栏
|
|
elif int(cls) == 5 and conf > 0.5:
|
|
Camera.actions.append(5)
|
|
# data = {
|
|
# "query": "作业现场人员跨越围栏是几级违规行为?具体违反了哪些条款?条款内容是什么?",
|
|
# "knowledge_base_name": "violation",
|
|
# "top_k": 2,
|
|
# "score_threshold": 2.0,
|
|
# "history": [],
|
|
# "stream": False,
|
|
# "model_name": "chatglm3-6b",
|
|
# "temperature": 0.01,
|
|
# "max_tokens": 1024,
|
|
# "prompt_name": "default"
|
|
# }
|
|
# response = requests.post(url, headers=headers, json=data, stream=True)
|
|
# response = requests.post(url, headers=headers, json=data)
|
|
# print('---------------API called, response:', response.text)
|
|
# for chunk in response.iter_content(chunk_size=1024):
|
|
# # 处理响应内容
|
|
# print(chunk.decode("utf-8", 'ignore'))
|
|
|
|
if det is not None and len(det):
|
|
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
|
|
for c in det[:, -1].detach().unique():
|
|
n = (det[:, -1] == c).sum()
|
|
s += '%g %s, ' % (n, names[int(c)])
|
|
for *xyxy, conf, cls in det:
|
|
label = '%s %.2f' % (names[int(cls)], conf)
|
|
c = int(cls)
|
|
annotator.box_label(xyxy, label, color=colors(c, True))
|
|
|
|
print('%sDone. (%.3fs)' % (s, t2 - t1))
|
|
# Save results (image with detections)
|
|
if save_img:
|
|
if dataset.mode == 'image':
|
|
cv2.imwrite(save_path, im0)
|
|
else: # 'video' or 'stream'
|
|
if vid_path[i] != save_path: # new video
|
|
vid_path[i] = save_path
|
|
if isinstance(vid_writer[i], cv2.VideoWriter):
|
|
vid_writer[i].release() # release previous video writer
|
|
if vid_cap: # video
|
|
fps = vid_cap.get(cv2.CAP_PROP_FPS)
|
|
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
else: # stream
|
|
fps, w, h = 30, im0.shape[1], im0.shape[0]
|
|
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
|
|
vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
|
vid_writer[i].write(im0)
|
|
|
|
# yield cv2.imencode('.jpg', im0)[1].tobytes()
|
|
|
|
# frame_encoded = cv2.imencode('.jpg', im0)[1].tobytes()
|
|
# json_encoded=
|
|
yield cv2.imencode('.jpg', im0)[1].tobytes()
|