We adapt the Encoder-Decoder architecture from NLP to time-series forecasting; in our experiments, adding an attention mechanism improved prediction quality. The full model code follows.
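All snippets below are PyTorch and assume the following imports and hyperparameters. The concrete values of INPUT_SIZE and HIDDEN_SIZE are placeholders chosen for this walkthrough, not values fixed by the experiment:

import torch
import torch.nn as nn
import torch.nn.functional as F

INPUT_SIZE = 1     # features per time step (placeholder value)
HIDDEN_SIZE = 64   # LSTM hidden width (placeholder value)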
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # Single-layer unidirectional LSTM; batch_first=True means
        # inputs and outputs are laid out as [batch, seq_len, feature].
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x):
        # r_out:  [batch, seq_len, HIDDEN_SIZE] -- hidden state at every step
        # hidden, cell: [1, batch, HIDDEN_SIZE] -- final states
        r_out, (hidden, cell) = self.rnn(x)
        return r_out, hidden, cell
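To make the shapes concrete, here is a quick check with dummy data (a batch of 8 sequences of length 20, under the placeholder sizes above):

enc = Encoder()
x = torch.randn(8, 20, INPUT_SIZE)   # [batch, seq_len, INPUT_SIZE]
r_out, hidden, cell = enc(x)
print(r_out.shape)    # torch.Size([8, 20, 64])
print(hidden.shape)   # torch.Size([1, 8, 64])
print(cell.shape)     # torch.Size([1, 8, 64])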
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,
        )
        # Project the hidden state down to a single predicted value.
        self.out = nn.Linear(HIDDEN_SIZE, 1)

    def forward(self, x, hidden, cell):
        # x: [batch, 1, INPUT_SIZE] -- a single time step
        # output: [batch, 1, HIDDEN_SIZE]
        output, (hidden, cell) = self.rnn(x, (hidden, cell))
        # Drop the length-1 time dimension before the linear layer.
        prediction = self.out(output.squeeze(1))   # [batch, 1]
        return prediction, hidden, cell
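A single decoding step, fed zero initial states for illustration (real initial states come from the encoder, as shown in Seq2Seq below):

dec = Decoder()
step = torch.randn(8, 1, INPUT_SIZE)    # one time step per sequence
h = torch.zeros(1, 8, HIDDEN_SIZE)
c = torch.zeros(1, 8, HIDDEN_SIZE)
pred, h, c = dec(step, h, c)
print(pred.shape)   # torch.Size([8, 1])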
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def attention_net(self, lstm_output, final_state):
        # lstm_output:  [batch, seq_len, HIDDEN_SIZE] -- all encoder steps
        # final_state:  [1, batch, HIDDEN_SIZE] -- last encoder hidden state
        hidden = final_state.view(-1, HIDDEN_SIZE, 1)   # [batch, HIDDEN_SIZE, 1]
        # Dot-product score between each encoder step and the final state.
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)   # [batch, seq_len]
        soft_attn_weights = F.softmax(attn_weights, 1)
        # Context vector: attention-weighted sum of the encoder states.
        context = torch.bmm(lstm_output.transpose(1, 2),
                            soft_attn_weights.unsqueeze(2)).squeeze(2)   # [batch, HIDDEN_SIZE]
        return context, soft_attn_weights.detach().cpu().numpy()

    def forward(self, src):
        # src: [batch, seq_len, INPUT_SIZE] (batch_first layout)
        batch_size = src.shape[0]
        src_len = src.shape[1]
        outputs = torch.zeros(batch_size, src_len, 1,
                              dtype=src.dtype, device=self.device)
        r_out, hidden, cell = self.encoder(src)
        # Use the attention context over the encoder outputs as the
        # decoder's initial hidden state.
        attn_output, attention = self.attention_net(r_out, hidden)
        hidden = attn_output.view(1, -1, HIDDEN_SIZE)   # [1, batch, HIDDEN_SIZE]
        # Feed each observed step to the decoder and record its prediction;
        # note the final position of outputs stays zero.
        for t in range(src_len - 1):
            dec_input = src[:, t, :].unsqueeze(1)   # [batch, 1, INPUT_SIZE]
            output, hidden, cell = self.decoder(dec_input, hidden, cell)
            outputs[:, t, :] = output
        return outputs
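Putting the pieces together, a minimal end-to-end forward pass (the device choice, batch size, and sequence length are placeholders, and the training loop is omitted):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2Seq(Encoder(), Decoder(), device).to(device)

src = torch.randn(8, 20, INPUT_SIZE).to(device)   # [batch, seq_len, features]
outputs = model(src)
print(outputs.shape)   # torch.Size([8, 20, 1])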