onnx转TensorRT遇到Assertion failed: axis >= 0 && axis < nbDims问题

  • Post author:
  • Post category:其他


While parsing node number xxx [Gather]:
ERROR: onnx2trt_utils.hpp:277 In function convert_axis:
[8] Assertion failed: axis >= 0 && axis < nbDims

这个问题我在网上看了很多帖子,基本都是reshape的问题,但是我的网络结构中并不含有reshape函数,因此与其他的解决方式不同.

我将onnx转入到netron中,发现了一些端倪:

报错的位置Gather操作,源自于

F.interpolate(x, size=x.shape[-2:], mode='nearest')

插值操作,但是另我不解的是,明明一个简单上采样,为什么会多出两个Gather分支?由网络结构中可以看到这个两个分支是用来取到x.shape的第2和第3个索引的,于是我打开pycharm的调试,发现对于x这个Tensor的shape属性,并不是我以为的[a,b,c,d]这种形式,而是[Tensor(a),Tensor(b),Tensor(c),Tensor(d)]!也就是说,对于shape中的每一个维度并不是int类型,而是每个int类型的数构成的1*1的Tensor.

这也就解释了为什么网络结构中会取出两个索引后,还需要进行Unsqueeze操作,因为需要将取出的1*1Tensor展开到一个int形式,再进行concat组成一个[c,d]的shape形式.

于是就可以先将shape中1*1的Tensor转为int类型后,再组成shape,这样再进行插值操作,即:

size_ = [int(x.shape[2]), int(x.shape[3])] 
out = return F.interpolate(x, size=size_, mode='nearest')

这样导出的onnx模型,就不会报[8] Assertion failed: axis >= 0 && axis < nbDims这个错了.将onnx模型放到netron中,可以看到:

这时的插值操作被转换为了简单的resize分支,而且不具有其他的Gather分支了.

我被这个问题折磨的挺久了,最后发现是Tensor中shape的数据类型导致的,在这里记录一下,希望可以解决给其他人一些解决这个问题的其他思路.



版权声明:本文为foso1994原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。