起因
因公司项目需求,需要提升目标检测速度。原始框架为:yolov3-keras,后发现满足不了检测速度,准备更换前段时间刚出的yolov4
地址
在github上找到一堆代码,
TensorFlow版本
,
keras版本
以下使用的TensorFlow版本代码
版本
Tensorflow 2.3.0rc0
!注:使用cuda10.1
下载方式:
pip install tensorflow-gpu==2.3.0rc0
使用bug与解决
首先clone下代码后,进行测试。按如上代码可以跑通。但是当使用train.py的时候出现较多问题。
首先是无法加载模型:
我将代码注释后可以使用。个人感觉作者的逻辑bug,他应该是设想检测如果没有weight文件报错。
其次,经常出现loss为NaN情况。
然后我更改我的数据集,将原本训练为tiff的图片更改为jpg图片。发现可以训练并且没有出现NAN情况。但是我查看代码后感觉并没有问题。都是使用的CV2.imread方法。
此处并没有解决完成
完整自训练数据集使用
先制作dataset。我目前是将数据更改为coco输入的。
定义config文件:
然后运行train.py出来的结果是
然后运行
python save_model.py --weights ./data/yolov4.weights --output ./checkpoints/yolov4-416 --input_size 416 --model yolov4
我这里是将上面自己保存的TF模型,覆盖到上面variables文件里面。
然后运行
python detect.py --weights ./checkpoints/yolov4-416 --size 416 --model yolov4 --image ./data/kite.jpg
后续
1、该模型没有数据增强。故结果并不理想但是速度应该是最快的,
2、版本bug较多,资料较少(tensorflow 2.3.0rc0 最新)