CNN和LSTM的输入数据维度

  • Post author:
  • Post category:其他


CNN:

对于torch.nn.Conv2d()来说,输入的数据需要是4维的,即

[batch_size,channel,image_x,iamge_y]

如果每次输入一个图像的话通道就是1

pg.CNN的第一层是

torch.nn.Conv2d(1,64)

说明输出通道是64

LSTM:

(seq_len, batch, input_size)

其中:

seq_len表示的是句子的长度

batch表示的是一次往LSTM中输入句子的数目

input_size表示的是输入的维度

LSTM的输出维度:

在LSTM中,hidden_size表示的是隐藏状态h的维度,也就是LSTM的输出维度。

LSTM的输出为outputs,(h_n,c_n),

其中outputs表示的是最后一层隐藏层各Cell对应的隐藏状态,

h_n表示的是各隐藏层的最后一个时间步对应的隐藏状态,

c_n表示的是各隐藏层的最后一个时间步对应的细胞状态。

如果输入的数据为(3,4,2),即数据长度为3,batch为4,数据维度为2,hidden_size为10:

则outputs的维度为(3,4,10)

h_n的维度为(1,4,10)

c_n的维度为(1,4,10)



版权声明:本文为qq_45283936原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。