这是 P yTorch 中论文 “注意力就是你所需要的” 多头注意力的教程/实现。该实现的灵感来自带注释的Transformer。
以下是使用带有 MHA 的基本转换器进行 NLP 自动回归的训练代码。
24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker
33class PrepareForMultiHeadAttention(nn.Module):
44 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45 super().__init__()
线性变换的线性层
47 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
头数
49 self.heads = heads
每个头部中以向量表示的维度数
51 self.d_k = d_k
53 def forward(self, x: torch.Tensor):
输入的形状[seq_len, batch_size, d_model]
或[batch_size, d_model]
。我们将线性变换应用于最后一个维度,然后将其拆分为头部。
57 head_shape = x.shape[:-1]
线性变换
60 x = self.linear(x)
将最后一个维度拆分成头部
63 x = x.view(*head_shape, self.heads, self.d_k)
输出具有形状[seq_len, batch_size, heads, d_k]
或[batch_size, heads, d_model]
66 return x
这将计算给定key
和value
向量的缩放多头注意query
力。
简单来说,它会找到与查询匹配的键,并获取这些键的值。
它使用查询和键的点积作为它们匹配程度的指标。在服用点产品之前,先按比例缩放。这样做是为了避免较大的点积值导致 softmax 在较大时给出非常小的梯度。
Softmax 是沿序列(或时间)的轴计算的。
69class MultiHeadAttention(nn.Module):
heads
是头的数量。d_model
是query
、key
和value
向量中的要素数。90 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96 super().__init__()
每头特征数
99 self.d_k = d_model // heads
头数
101 self.heads = heads
这些变换了多头注意力的query
、key
和value
向量。
104 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Softmax 在时间维度上引起人们的注意key
109 self.softmax = nn.Softmax(dim=1)
输出层
112 self.output = nn.Linear(d_model, d_model)
辍学
114 self.dropout = nn.Dropout(dropout_prob)
softmax 之前的缩放系数
116 self.scale = 1 / math.sqrt(self.d_k)
我们存储注意事项,以便在需要时将其用于日志记录或进行其他计算
119 self.attn = None
121 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
计算或
129 return torch.einsum('ibhd,jbhd->ijbh', query, key)
mask
有形状[seq_len_q, seq_len_k, batch_size]
,其中第一个维度是查询维度。如果查询维度等于它将被广播。
131 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138 assert mask.shape[1] == key_shape[0]
139 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
所有头部都使用相同的面具。
142 mask = mask.unsqueeze(-1)
生成的遮罩有形状[seq_len_q, seq_len_k, batch_size, heads]
145 return mask
query
key
和value
是存储查询、键和值向量集合的张量。它们有形状[seq_len, batch_size, d_model]
。
mask
有形状[seq_len, seq_len, batch_size]
并mask[i, j, b]
指示是否为批量查询b
,位置处的查询i
有权访问位置处的键值j
。
147 def forward(self, *,
148 query: torch.Tensor,
149 key: torch.Tensor,
150 value: torch.Tensor,
151 mask: Optional[torch.Tensor] = None):
query
,key
并且value
有形状[seq_len, batch_size, d_model]
163 seq_len, batch_size, _ = query.shape
164
165 if mask is not None:
166 mask = self.prepare_mask(mask, query.shape, key.shape)
准备query
,key
并value
进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]
。
170 query = self.query(query)
171 key = self.key(key)
172 value = self.value(value)
计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]
。
176 scores = self.get_scores(query, key)
音阶分数
179 scores *= self.scale
涂抹面膜
182 if mask is not None:
183 scores = scores.masked_fill(mask == 0, float('-inf'))
关注按键序列维度
187 attn = self.softmax(scores)
调试时省去注意力
190 tracker.debug('attn', attn)
申请退学
193 attn = self.dropout(attn)
乘以值
197 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
保存任何其他计算的注意力
200 self.attn = attn.detach()
连接多个头
203 x = x.reshape(seq_len, batch_size, -1)
输出层
206 return self.output(x)