Having introduced seq2seq + attention (Simple to seq2seq And attention | Ripshun Blog), let's now build the model in PyTorch.
Step 1: Build the Encoder
Code:
import torch
import torch.nn as nn
import torch.nn.functional as F

# dic_num, dim_num, hid_dim_num and num_layers are hyperparameters assumed to be defined at module level
class encode(nn.Module):
    def __init__(self):
        super(encode, self).__init__()
        self.embedd = nn.Embedding(dic_num, dim_num)
        self.gru = nn.GRU(dim_num, hid_dim_num, num_layers, bidirectional=True)
        self.fc = nn.Linear(hid_dim_num * 2, hid_dim_num)
        self.dropout = nn.Dropout()

    def forward(self, src):
        # src: [batch_size, step_num]
        embedded = self.dropout(self.embedd(src)).transpose(0, 1)
        # embedded: [step_num, batch_size, dim_num]
        gru_y, gru_h = self.gru(embedded)
        # gru_y: [step_num, batch_size, 2*hid_dim]  (all time steps of the top layer; the two directions are concatenated, hence the factor of 2)
        # gru_h: [n_layers*2, batch_size, hid_dim]  (all layers at the last time step; bidirectional, so the layer count is doubled)
        s = torch.tanh(self.fc(torch.cat((gru_h[-2, :, :], gru_h[-1, :, :]), dim=1)))
        # s: [batch_size, hid_dim]
        return gru_y, s
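To check the encoder's output shapes, here is a minimal sketch; the hyperparameter values (dic_num=1000, dim_num=32, hid_dim_num=64, num_layers=2) and the dummy batch are arbitrary assumptions, not values from the original post.

dic_num, dim_num, hid_dim_num, num_layers = 1000, 32, 64, 2  # assumed hyperparameters

enc = encode()
src = torch.randint(0, dic_num, (4, 10))   # dummy batch: [batch_size=4, step_num=10]
gru_y, s = enc(src)
print(gru_y.shape)   # torch.Size([10, 4, 128]) -> [step_num, batch_size, 2*hid_dim]
print(s.shape)       # torch.Size([4, 64])      -> [batch_size, hid_dim]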
Step 2: Build the Attention Module
Code:
class attention(nn.Module):
    def __init__(self):
        super(attention, self).__init__()
        # the encoder outputs are 2*hid_dim wide and s is hid_dim wide, so the concatenation is 3*hid_dim
        self.attn = nn.Linear(hid_dim_num * 3, hid_dim_num, bias=False)
        self.v = nn.Linear(hid_dim_num, 1, bias=False)

    def forward(self, s, y):
        # s: [batch_size, hid_dim]
        # y: [step_num, batch_size, 2*hid_dim]
        batch_size = y.shape[1]
        step_num = y.shape[0]
        # repeat s along the time-step dimension so it lines up with y
        s = s.unsqueeze(1).repeat(1, step_num, 1)
        y = y.transpose(0, 1)
        # s: [batch_size, step_num, hid_dim]
        # y: [batch_size, step_num, 2*hid_dim]
        e = torch.tanh(self.attn(torch.cat((s, y), dim=2)))
        # e: [batch_size, step_num, hid_dim]
        attention = self.v(e).squeeze(2)
        # attention: [batch_size, step_num]
        return F.softmax(attention, dim=1)
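To see what the attention module produces, here is a sketch that feeds it the decoder's initial state s and the encoder outputs gru_y from the encoder example above; it relies on the same assumed hyperparameter values.

attn = attention()
a = attn(s, gru_y)
print(a.shape)        # torch.Size([4, 10]) -> one weight per source time step
print(a.sum(dim=1))   # each row sums to 1 because of the softmax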
Step 3: Build the Decoder
Code:
class decoder(nn.Module):
    def __init__(self, attention):
        super().__init__()
        self.attention = attention
        self.embedded = nn.Embedding(dic_num, dim_num)
        self.gru = nn.GRU(hid_dim_num * 2 + dim_num, hid_dim_num)
        # input to the output layer: dec_y (hid_dim) + c (2*hid_dim) + embedded (dim_num)
        self.fc = nn.Linear(hid_dim_num * 3 + dim_num, dic_num)
        self.dropout = nn.Dropout()

    def forward(self, x, s, y):
        # x: [batch_size]                        previous target token ids
        # s: [batch_size, hid_dim]               previous decoder hidden state
        # y: [step_num, batch_size, 2*hid_dim]   encoder outputs
        x = x.unsqueeze(1)
        embedded = self.dropout(self.embedded(x)).transpose(0, 1)
        # embedded: [1, batch_size, dim_num]
        a = self.attention(s, y).unsqueeze(1)
        # a: [batch_size, 1, step_num]
        y = y.transpose(0, 1)
        # y: [batch_size, step_num, 2*hid_dim]
        c = torch.bmm(a, y).transpose(0, 1)
        # c: [1, batch_size, 2*hid_dim]   weighted sum of encoder outputs (context vector)
        rnn_input = torch.cat((embedded, c), dim=2)
        # rnn_input: [1, batch_size, 2*hid_dim + dim_num]
        dec_y, dec_h = self.gru(rnn_input, s.unsqueeze(0))
        # dec_y: [1, batch_size, hid_dim]
        # dec_h: [1, batch_size, hid_dim]
        embedded = embedded.squeeze(0)
        dec_y = dec_y.squeeze(0)
        c = c.squeeze(0)
        dec_h = dec_h.squeeze(0)
        pred = self.fc(torch.cat((dec_y, c, embedded), dim=1))
        # pred: [batch_size, dic_num]
        return pred, dec_h
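Putting the three pieces together, one decoding step looks roughly like the sketch below, continuing from the encoder and attention examples above. Using token id 0 as a stand-in start token is an assumption; the loop over target positions (with teacher forcing) is left for the actual project.

dec = decoder(attn)                    # the decoder wraps the attention module defined above
x = torch.zeros(4, dtype=torch.long)   # [batch_size] previous target tokens (assumed <sos> id 0)
pred, s = dec(x, s, gru_y)             # s from the encoder, gru_y are the encoder outputs
print(pred.shape)   # torch.Size([4, 1000]) -> scores over the vocabulary
print(s.shape)      # torch.Size([4, 64])   -> updated decoder hidden state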
Wrap-up:
Next, we'll start on an actual project. Stay tuned!