Pytorch nn.RNN()解析

  • Post author:
  • Post category:其他


RNN基本结构与nn.RNN()参数介绍可参考:


参数介绍



官方文档


以下代码对 nn.RNN() 的简单应用进行了注解介绍

import torch
import torch.nn as nn
import torch.functional as F


# 单层RNN,输入x特征为10,输出特征为20, 两层堆叠
rnn = nn.RNN(10, 20, 2)
# 随机构建输入,这里假设每句话只有10个单词,共3句话(即一个batch 3句话),每个单词被embedding为10维向量
inputs = torch.rand(10, 3, 10)
# 随机构建h0, 因为只有2层单向,所以参数为2
h_0 = torch.zeros(2, 3, 20)
# 输出结果,其中output为每一个输入(每一个单词)对应的最终输出值;h_0为最后一个单词在各层的输出值(此处有两层)
output, h_n = rnn(inputs, h_0)


# torch.Tensor
print(type(output))
# (10, 3, 20)
print(output.shape)
print(output)
print("=========================")
# torch.Tensor
print(type(h_n))
# (2,3,20)
print(h_n.shape)
print(h_n)



版权声明:本文为weixin_43860783原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。