Having introduced seq2seq + attention (Simple to seq2seq And attention | Ripshun Blog), let's build the model in PyTorch.

Step 1: Build the Encoder

Code:


import torch
import torch.nn as nn
import torch.nn.functional as F

class encode(nn.Module):
    def __init__(self):
        super(encode, self).__init__()
        self.embedd = nn.Embedding(dic_num, dim_num)
        self.gru = nn.GRU(dim_num, hid_dim_num, num_layers, bidirectional=True)
        self.fc = nn.Linear(hid_dim_num * 2, hid_dim_num)
        self.dropout = nn.Dropout()

    def forward(self, src):
        # src: [batch_size, step_num]
        embedded = self.dropout(self.embedd(src)).transpose(0, 1)
        # embedded: [step_num, batch_size, dim_num]
        gru_y, gru_h = self.gru(embedded)
        # gru_y: [step_num, batch_size, 2*hid_dim]  (all time steps of the top layer; the forward and backward states are concatenated, so the dimension doubles)
        # gru_h: [num_layers*2, batch_size, hid_dim] (all layers at the last time step; the layer count doubles because the GRU is bidirectional)
        s = torch.tanh(self.fc(torch.cat((gru_h[-2, :, :], gru_h[-1, :, :]), dim=1)))
        # s: [batch_size, hid_dim]
        return gru_y, s
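
A minimal usage sketch to make the shapes concrete, assuming example hyperparameters dic_num = 1000, dim_num = 64, hid_dim_num = 128 and num_layers = 1 (hypothetical values picked only for illustration):

# Assumed example hyperparameters, used only to exercise the shapes.
dic_num, dim_num, hid_dim_num, num_layers = 1000, 64, 128, 1

enc = encode()
src = torch.randint(0, dic_num, (32, 10))  # a fake batch: [batch_size=32, step_num=10]
gru_y, s = enc(src)
print(gru_y.shape)  # torch.Size([10, 32, 256]) -> [step_num, batch_size, 2*hid_dim]
print(s.shape)      # torch.Size([32, 128])     -> [batch_size, hid_dim]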

Diagram:

Step 2: Build the Attention

Code:


class attention(nn.Module):
    def __init__(self):
        super(attention, self).__init__()
        # s is [batch_size, hid_dim] and y is [batch_size, step_num, 2*hid_dim] after the
        # transpose below, so their concatenation has 3*hid_dim features.
        self.attn = nn.Linear(hid_dim_num * 3, hid_dim_num, bias=False)
        self.v = nn.Linear(hid_dim_num, 1, bias=False)

    def forward(self, s, y):
        # s: [batch_size, hid_dim]
        # y: [step_num, batch_size, 2*hid_dim]

        batch_size = y.shape[1]
        step_num = y.shape[0]

        s = s.unsqueeze(1).repeat(1, step_num, 1)
        y = y.transpose(0, 1)
        # repeat s along the time-step dimension so it can be concatenated with y
        # s: [batch_size, step_num, hid_dim]
        # y: [batch_size, step_num, 2*hid_dim]

        e = torch.tanh(self.attn(torch.cat((s, y), dim=2)))
        # e: [batch_size, step_num, hid_dim]

        attention = self.v(e).squeeze(2)
        # attention: [batch_size, step_num]

        return F.softmax(attention, dim=1)
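
With the same assumed example values as in the encoder sketch, the attention module turns the decoder state s and the encoder outputs into one normalized weight per source time step:

attn = attention()
a = attn(s, gru_y)   # s: [32, 128], gru_y: [10, 32, 256] from the encoder sketch
print(a.shape)       # torch.Size([32, 10]) -> [batch_size, step_num]
print(a.sum(dim=1))  # each row sums to 1 because of the softmax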

Diagram:

Step 3: Build the Decoder

Code:


class decoder(nn.Module):
    def __init__(self, attention):
        super().__init__()
        self.attention = attention
        self.embedded = nn.Embedding(dic_num, dim_num)
        self.gru = nn.GRU(hid_dim_num * 2 + dim_num, hid_dim_num)
        # input of the output layer: dec_y (hid_dim) + context c (2*hid_dim) + embedding (dim_num)
        self.fc = nn.Linear(hid_dim_num * 3 + dim_num, dic_num)
        self.dropout = nn.Dropout()

    def forward(self, x, s, y):
        # x: [batch_size]
        # s: [batch_size, hid_dim]
        # y: [step_num, batch_size, 2*hid_dim]
        x = x.unsqueeze(1)
        embedded = self.dropout(self.embedded(x)).transpose(0, 1)
        # embedded: [1, batch_size, dim_num]
        a = self.attention(s, y).unsqueeze(1)
        # a: [batch_size, 1, step_num]
        y = y.transpose(0, 1)
        # y: [batch_size, step_num, 2*hid_dim]
        c = torch.bmm(a, y).transpose(0, 1)
        # c: [1, batch_size, 2*hid_dim]
        rnn_input = torch.cat((embedded, c), dim=2)
        # rnn_input: [1, batch_size, 2*hid_dim + dim_num]
        dec_y, dec_h = self.gru(rnn_input, s.unsqueeze(0))

        # dec_y: [1, batch_size, hid_dim]
        # dec_h: [1, batch_size, hid_dim]

        embedded = embedded.squeeze(0)
        dec_y = dec_y.squeeze(0)
        c = c.squeeze(0)
        dec_h = dec_h.squeeze(0)

        pred = self.fc(torch.cat((dec_y, c, embedded), dim=1))
        # pred: [batch_size, dic_num]
        return pred, dec_h
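
Continuing the same assumed example values, one decoding step takes the previous target token x, the decoder state s and the encoder outputs, and returns a score over the whole vocabulary plus the next decoder state; training just repeats this call once per target time step:

dec = decoder(attn)                   # reuse the attention module from the sketch above
x = torch.randint(0, dic_num, (32,))  # previous target token for each sentence in the batch
pred, s_next = dec(x, s, gru_y)
print(pred.shape)    # torch.Size([32, 1000]) -> [batch_size, dic_num]
print(s_next.shape)  # torch.Size([32, 128])  -> [batch_size, hid_dim]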

Diagram:

Wrapping up:

Next we will move on to a real project, so stay tuned!

 