• Additive attention and multiplicative attention are simply two different ways of computing the attention score

  • Additive attention was introduced with the encoder-decoder (seq2seq) architecture (Bahdanau attention); scaled dot-product attention was introduced with the Transformer

  • Additive attention computes the similarity between the query and the key with a feed-forward network (see the formulas after this list)

  • Dot-product attention measures similarity by taking the dot product of the query and the key, divided by a scaling factor

  • Additive attention involves a learnable parameter vector v and is also computationally less efficient than multiplicative attention

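For reference, the two scoring functions can be summarized as follows (standard formulations; W1, W2, W, and v correspond to the parameter names in the code below, and d_k is the key dimension):

score_additive(q, k) = vᵀ tanh(W1 q + W2 k)
score_dot(q, k) = (W q)ᵀ k / √d_k
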
Additive attention:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    def __init__(self, query_dim, key_dim, energy_dim):
        super(AdditiveAttention, self).__init__()
        self.W1 = nn.Linear(query_dim, energy_dim, bias=False)
        self.W2 = nn.Linear(key_dim, energy_dim, bias=False)
        self.v = nn.Parameter(torch.rand(energy_dim))
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)
        nn.init.uniform_(self.v, -0.1, 0.1)
    
    def forward(self, query, keys, values):
        """
        query: (batch_size, query_dim)
        keys: (batch_size, seq_length, key_dim)  e.g. the LSTM encoder hidden states h_0, h_1, ..., h_{seq_length}
        values: (batch_size, seq_length, value_dim)
        """
        # Expand query to (batch_size, seq_length, query_dim)
        query = query.unsqueeze(1).repeat(1, keys.size(1), 1)
        
        # Compute energy
        energy = torch.tanh(self.W1(query) + self.W2(keys))  # (batch_size, seq_length, energy_dim)
        
        # Compute scores
        scores = torch.matmul(energy, self.v)  # (batch_size, seq_length)
        
        # Compute attention weights
        attn_weights = F.softmax(scores, dim=1)  # (batch_size, seq_length)
        
        # Compute context vector
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)  # (batch_size, value_dim)
        
        return context, attn_weights
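
A minimal usage sketch of AdditiveAttention (the batch size, sequence length, and dimensions below are arbitrary, chosen only for illustration):

# Hypothetical sizes: batch of 2, sequence length 5, query/key dim 8, energy dim 16
attn = AdditiveAttention(query_dim=8, key_dim=8, energy_dim=16)
query = torch.randn(2, 8)      # e.g. a decoder hidden state
keys = torch.randn(2, 5, 8)    # e.g. encoder hidden states
values = torch.randn(2, 5, 8)  # often the same tensor as keys
context, attn_weights = attn(query, keys, values)
print(context.shape)       # torch.Size([2, 8])
print(attn_weights.shape)  # torch.Size([2, 5])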

Multiplicative attention:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiplicativeAttention(nn.Module):
    def __init__(self, query_dim, key_dim, scale=True):
        super(MultiplicativeAttention, self).__init__()
        self.scale = scale
        if self.scale:
            self.scale_factor = torch.sqrt(torch.FloatTensor([key_dim]))
        self.W = nn.Linear(query_dim, key_dim, bias=False)
        nn.init.xavier_uniform_(self.W.weight)
    
    def forward(self, query, keys, values):
        """
        query: (batch_size, query_dim)
        keys: (batch_size, seq_length, key_dim)
        values: (batch_size, seq_length, value_dim)
        """
        # Transform query
        query = self.W(query)  # (batch_size, key_dim)
        
        # Expand query to (batch_size, seq_length, key_dim)
        query = query.unsqueeze(1).repeat(1, keys.size(1), 1)
        
        # Compute scores (dot product)
        scores = torch.sum(query * keys, dim=2)  # (batch_size, seq_length)
        
        # Scale scores
        if self.scale:
            scores = scores / self.scale_factor.to(scores.device)
        
        # Compute attention weights
        attn_weights = F.softmax(scores, dim=1)  # (batch_size, seq_length)
        
        # Compute context vector
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)  # (batch_size, value_dim)
        
        return context, attn_weights
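
A matching usage sketch of MultiplicativeAttention (sizes are again arbitrary; query_dim and key_dim may differ because the query is first projected to key_dim):

# Hypothetical sizes: batch of 2, sequence length 5, query dim 8, key/value dim 16
attn = MultiplicativeAttention(query_dim=8, key_dim=16)
query = torch.randn(2, 8)
keys = torch.randn(2, 5, 16)
values = torch.randn(2, 5, 16)
context, attn_weights = attn(query, keys, values)
print(context.shape)       # torch.Size([2, 16])
print(attn_weights.shape)  # torch.Size([2, 5])
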
  • Explanation of query = query.unsqueeze(1).repeat(1, keys.size(1), 1):

  • The shapes of the query and the keys need to match so that the corresponding operation (dot product or addition) can be performed

  • Query vector (query): usually the decoder hidden state, with shape (batch_size, query_dim)

  • Key vectors (keys): usually the encoder hidden states, with shape (batch_size, seq_length, key_dim)

  • query.unsqueeze(1) inserts a dimension at position 1, giving (batch_size, 1, query_dim)

  • repeat(1, keys.size(1), 1): dimension 1 is repeated once (unchanged), dimension 2 is repeated keys.size(1) times, and dimension 3 is repeated once (unchanged), giving (batch_size, seq_length, query_dim) (key_dim in the multiplicative case, where the query has already been projected to key_dim). See the shape check below.
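
A short shape check of this broadcasting step (the numbers are arbitrary):

import torch

query = torch.randn(2, 8)    # (batch_size=2, query_dim=8)
keys = torch.randn(2, 5, 8)  # (batch_size=2, seq_length=5, key_dim=8)
expanded = query.unsqueeze(1)                   # (2, 1, 8)
expanded = expanded.repeat(1, keys.size(1), 1)  # (2, 5, 8)
print(expanded.shape)  # torch.Size([2, 5, 8])
# query.unsqueeze(1).expand(-1, keys.size(1), -1) gives the same shape without copying memory.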