Additive attention and multiplicative attention are simply two different ways of computing attention scores.
Additive attention was proposed for the encoder-decoder (seq2seq) architecture (Bahdanau et al.), while multiplicative (dot-product) attention was proposed by Luong et al.; the Transformer uses the scaled dot-product variant.
Additive attention computes the similarity between the query and the key with a small feed-forward network.
Dot-product attention measures similarity by taking the dot product of the query and the key, usually divided by a scaling factor.
Additive attention involves an extra learnable parameter vector v, and it is also less computationally efficient than multiplicative attention.
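Written out, the two scoring functions implemented by the code below are the standard Bahdanau-style additive form and the scaled dot-product form, with learnable parameters W_1, W_2, W, v and key dimension d_k:

\mathrm{score}_{\mathrm{add}}(q, k) = v^\top \tanh(W_1 q + W_2 k)
\mathrm{score}_{\mathrm{mul}}(q, k) = \frac{(W q)^\top k}{\sqrt{d_k}}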
Additive attention:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    def __init__(self, query_dim, key_dim, energy_dim):
        super(AdditiveAttention, self).__init__()
        self.W1 = nn.Linear(query_dim, energy_dim, bias=False)
        self.W2 = nn.Linear(key_dim, energy_dim, bias=False)
        self.v = nn.Parameter(torch.rand(energy_dim))
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)
        nn.init.uniform_(self.v, -0.1, 0.1)

    def forward(self, query, keys, values):
        """
        query:  (batch_size, query_dim)
        keys:   (batch_size, seq_length, key_dim), e.g. the LSTM encoder hidden states h_0, h_1, ..., h_seq_length
        values: (batch_size, seq_length, value_dim)
        """
        # Expand query to (batch_size, seq_length, query_dim)
        query = query.unsqueeze(1).repeat(1, keys.size(1), 1)
        # Compute energy
        energy = torch.tanh(self.W1(query) + self.W2(keys))  # (batch_size, seq_length, energy_dim)
        # Compute scores
        scores = torch.matmul(energy, self.v)  # (batch_size, seq_length)
        # Compute attention weights
        attn_weights = F.softmax(scores, dim=1)  # (batch_size, seq_length)
        # Compute context vector
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)  # (batch_size, value_dim)
        return context, attn_weights
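A minimal usage sketch for the class above, using random tensors with arbitrary illustrative dimensions:

# Quick sanity check with random tensors (dimensions are arbitrary examples)
batch_size, seq_length = 4, 10
query_dim, key_dim, value_dim, energy_dim = 32, 64, 64, 128

attn = AdditiveAttention(query_dim, key_dim, energy_dim)
query = torch.randn(batch_size, query_dim)              # e.g. decoder hidden state
keys = torch.randn(batch_size, seq_length, key_dim)     # e.g. encoder hidden states
values = torch.randn(batch_size, seq_length, value_dim)

context, attn_weights = attn(query, keys, values)
print(context.shape)       # torch.Size([4, 64])
print(attn_weights.shape)  # torch.Size([4, 10])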
Multiplicative attention:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiplicativeAttention(nn.Module):
    def __init__(self, query_dim, key_dim, scale=True):
        super(MultiplicativeAttention, self).__init__()
        self.scale = scale
        if self.scale:
            self.scale_factor = torch.sqrt(torch.FloatTensor([key_dim]))
        self.W = nn.Linear(query_dim, key_dim, bias=False)
        nn.init.xavier_uniform_(self.W.weight)

    def forward(self, query, keys, values):
        """
        query:  (batch_size, query_dim)
        keys:   (batch_size, seq_length, key_dim)
        values: (batch_size, seq_length, value_dim)
        """
        # Transform query
        query = self.W(query)  # (batch_size, key_dim)
        # Expand query to (batch_size, seq_length, key_dim)
        query = query.unsqueeze(1).repeat(1, keys.size(1), 1)
        # Compute scores (dot product)
        scores = torch.sum(query * keys, dim=2)  # (batch_size, seq_length)
        # Scale scores
        if self.scale:
            scores = scores / self.scale_factor.to(scores.device)
        # Compute attention weights
        attn_weights = F.softmax(scores, dim=1)  # (batch_size, seq_length)
        # Compute context vector
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)  # (batch_size, value_dim)
        return context, attn_weights
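The same kind of sketch for the multiplicative version, again with arbitrary example dimensions:

# Quick sanity check with random tensors (dimensions are arbitrary examples)
batch_size, seq_length = 4, 10
query_dim, key_dim, value_dim = 32, 64, 64

attn = MultiplicativeAttention(query_dim, key_dim, scale=True)
query = torch.randn(batch_size, query_dim)
keys = torch.randn(batch_size, seq_length, key_dim)
values = torch.randn(batch_size, seq_length, value_dim)

context, attn_weights = attn(query, keys, values)
print(context.shape)       # torch.Size([4, 64])
print(attn_weights.shape)  # torch.Size([4, 10])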
Explanation of query = query.unsqueeze(1).repeat(1, keys.size(1), 1):
The query and the keys need matching shapes so that the corresponding operation (dot product or addition) can be applied at every time step.
Query: usually the decoder hidden state, with shape (batch_size, query_dim).
Keys: usually the encoder hidden states, with shape (batch_size, seq_length, key_dim).
query.unsqueeze(1) inserts a new second dimension, giving (batch_size, 1, query_dim).
repeat(1, keys.size(1), 1) repeats the first dimension once (i.e. no repetition), the second dimension keys.size(1) times, and the third dimension once, giving (batch_size, seq_length, query_dim). (In the multiplicative module the query has already been projected to key_dim by self.W, so the result there is (batch_size, seq_length, key_dim).)
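A small shape demo of this expansion step (dimensions are arbitrary examples); expand or plain broadcasting would give the same result without copying memory:

import torch

batch_size, seq_length, query_dim = 2, 5, 8
query = torch.randn(batch_size, query_dim)

q1 = query.unsqueeze(1)           # (2, 1, 8)
q2 = q1.repeat(1, seq_length, 1)  # (2, 5, 8), data is copied along dim 1
print(q1.shape, q2.shape)

# expand produces the same shape as a view, without copying memory
q3 = q1.expand(batch_size, seq_length, query_dim)
print(torch.equal(q2, q3))  # True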