GQA LLM 一文详解MHA MQA原理

  • 电脑网络维修
  • 2024-11-14

前言

本文回忆一下MHA、GQA、MQA,具体解读下MHA、GQA、MQA这三种经常出现留意力机制的原理。

图1 MHA、GQA、MQA一览

self-attention

self-attention

在自留意力机制中,输入理论是一个一致的输入矩阵,而这个矩阵后续会经过乘以不同的权重矩阵来转换成三个不同的向量汇合:查问向量Q、键向量K和值向量V。这三组向量是经过线性变换模式生成:

1.查问向量 (Q): Q=XW

2.键向量 (K): K=XW

3.值向量 (V): V=XW

W,W和W是 可学习的权重矩阵 ,区分对应于查问、键和值。这些矩阵的维度取决于模型的设计,理论它们的输入维度(列数) 是预先定义的,以满足特定的模型架构要求。 在Transformer模型中,经常使用不同的权重矩阵W,W和W来区分生成查问向量Q、键向量K和值向量V的 目标是为了准许模型在不同的示意空间中学习和抽取特色 。这样做参与了模型的灵敏性和表白才干,准许模型区分优化用于婚配(Q 和K)和用于输入消息分解(V)的示意。

在自留意力和多头留意力机制中,经常使用 作为缩放因子启动缩放操作是为了防止在计算点积时由于维度较高造成的数值稳固性疑问。这里的d是键向量的维度。 假设不启动缩放,当d较大时,点积的结果或者会变得十分大,这会造成在运行softmax函数时发生的梯度十分小。 由于softmax函数是经过指数函数计算的,大的输入值会使得局部输入凑近于1,而其余凑近于0,从而造成梯度隐没,这会在反向流传环节中形成梯度十分小,使得学习变得十分缓慢。

经过点积结果除以 ,可以调整这些值的范畴,使得它们不会太大。这样,softmax的输入在一个适合的范畴内, 有助于防止极其的指数运算结果,从而坚持数值稳固性和更有效的梯度流 。这个操作确保了即使在d很大的状况下, 留意力机制也能稳固并有效地学习。

代码成功

import torchimport torch.nn as nnimport torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, seq_length):super(SelfAttention, self).__init__()self.input_size = seq_length# 定义三个权重矩阵:Wq、Wk、Wvself.Wq = nn.Linear(seq_length, seq_length)# 线性变换self.Wk = nn.Linear(seq_length, seq_length)self.Wv = nn.Linear(seq_length, seq_length)def forward(self, input):# 计算Q,K,V 三个矩阵q = self.Wq(input)k = self.Wk(input)v = self.Wv(input)# 计算QK^T,即向量之间的相关度attention_scores = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(float(self.input_size)))# 计算向量权重,softmax归一化attention_weight = F.softmax(attention_scores, dim=-1)# 计算输入output = torch.matmul(attention_weight, v)return outputx = torch.randn(2, 3, 4)Self_Attention = SelfAttention(4)# 传入输入向量的维度output = Self_Attention(x)print(output.shape)

MHA(多头留意力)

Transformer 编码器块内的缩放点积留意力机制和多头留意力机制

MHA计算环节

代码成功

import torchimport torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)self.wk = nn.Linear(embed_dim, embed_dim)self.wv = nn.Linear(embed_dim, embed_dim)self.wo = nn.Linear(embed_dim, embed_dim)def mh_split(self, hidden):batch_size = hidden.shape[0]x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)return xdef forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分q, k, v = self.mh_split(q), self.mh_split(k), self.mh_split(v)# 留意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 拼接多头output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.rand(2, 3, 36)print(x)output = MultiHeadAttention(36, 6)y = output(x)print(y.shape)

MHA 能够了解输入不同局部之间的相关。但是,这种复杂性是有代价的——对内存带宽的需求很大,尤其是在解码器推理时期。重要疑问的关键在于内存开支。 在自回归模型中,每个解码步骤都须要加载解码器权重以及一切留意键和值。这个环节不只计算量大,而且内存带宽也大。随着模型规模的扩展,这种开支也会参与,使得扩展变得越来越艰难。

因此,多查问留意 (MQA) 应运而生,成为缓解这一瓶颈的处置打算。其理念便捷而有效: 经常使用多个查问头,但只经常使用一个键和值头。这种方法清楚缩小了内存负载,提高了推理速度。

MQA(多查问留意力)

图2 MHA和MQA的差异

MQA是MHA的一种变体,也是用于自回归解码的一种留意力机制。,图1、图2很笼统的描述了MHA和MQA的对比,与MHA 不同的是, MQA 让一切的Head之间共享雷同的一份 K 和 V 矩阵(象征K和V的计算惟一),只让 Q 保管了原始多头的性质 (每个Head存在不同的转换),从而大大缩小 K 和 V 矩阵的参数量以及KV Cache的显存占用,以此来到达优化推理速度,但是会带来精度上的损失。MQA被少量运行于LLM中,如ChatGLM2。

左 - 多头留意力,中 - 多查问留意力,右 - 将现有的 MHA 审核点转换为 MQA

如何将现有的预训练多头留意力模型转换为多查问留意力模型 (MQA)? 从现有的多头模型创立多查问留意力模型触及两个步骤:模型结构的转换和随后的预训练。

代码成功

import torchimport torch.nn as nnclass MultiQuerySelfAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(MultiQuerySelfAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)# MHA# self.wk = nn.Linear(embed_dim, embed_dim)# self.wv = nn.Linear(embed_dim, embed_dim)# MQAself.wk = nn.Linear(embed_dim, self.head_dim)self.wv = nn.Linear(embed_dim, self.head_dim)self.wo = nn.Linear(embed_dim, embed_dim)def q_h_split(self, hidden, head_num=None):batch_size, seq_len = hidden.size()[:2]# q拆分多头if head_num == None:x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)return xelse:# 这是MQA: 须要拆分k和v,这外面的head_num =1 的# 最终前往维度(batch_size, 1, seq_len, head_dim)return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)def forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分# 这是MHA的# q, k ,v= self.split(q), self.split(k), self.split(v)# 这是MQA的q, k, v = self.q_h_split(q), self.q_h_split(k, 1), self.q_h_split(v, 1)# 留意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))print("scores:", scores.shape)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 多头兼并output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.rand(3, 12, 512)atten = MultiQuerySelfAttention(512, 8)y = atten(x)print(y.shape)

GQA(分组查问留意力)

只管MQA模式大幅减小了参数数量,但是,带来推理减速的同时会形成模型性能损失,且在训练环节使得模型变得不稳固( 复杂度的降低或者会造成品质降低和训练不稳固 ),因此在此基础上提出了GQA,它将Query启动分组,每个组内共享一组Key、Value。(GQA在LLaMA-2 和 Mistral7B获取运行)

GQA 的数学原理

分组:在 GQA 中,传统多头模型中的查问头 (Q) 被分红 G 组。每组调配一个键 (K) 和值 (V) 头。此性能示意为 GQA-G,其中 G 示意组数。

GQA 的不凡状况

对每个组边疆始头部的键和值投影矩阵启动均值池化,以将MHA模型转换为 GQA 模型。此技术对组中每个头部的投影矩阵启动平均,从而为该组生成单个键和值投影。

经过 应用 GQA,该模型在 MHA 品质和 MQA 速度之间坚持平衡 。由于键值对较少,内存带宽和数据加载需求被最小化。G 的选用代表了一种掂量:更多的组(更凑近 MHA)可带来更高的品质但性能较慢,而更少的组(凑近 MQA)可提高速度但有就义品质的危险。此外,随着模型规模的扩展,GQA 准许内存带宽和模型容量按比例缩小,与模型规模相对应。相比之下,关于更大的模型,在 MQA 中缩小到单个键和值头或者会过于重大。

代码成功

import torchimport torch.nn as nnclass GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(GroupedQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)# 这是MHA的# self.wk = nn.Linear(embed_dim, embed_dim)# self.wv = nn.Linear(embed_dim, embed_dim)# 这是MQA的# self.wk = nn.Linear(embed_dim, self.head_dim)# self.wv = nn.Linear(embed_dim, self.head_dim)# 这是GQA的self.group_num = 4# 这是4个组self.wk = nn.Linear(embed_dim, self.group_num * self.head_dim)self.wv = nn.Linear(embed_dim, self.group_num * self.head_dim)self.wo = nn.Linear(embed_dim, embed_dim)def split(self, hidden, group_num=None):batch_size, seq_len = hidden.size()[:2]# q须要拆分多头if group_num == None:x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)return xelse:# 这是kv须要拆分的多头x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len,self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)return xdef forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分# 这是MHA的# q, k ,v= self.split(q), self.split(k), self.split(v)# 这是MQA的# q, k ,v= self.split(q), self.split(k, 1), self.split(v, 1)# 这是GQA的q, k, v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)# 留意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))print("scores:", scores.shape)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 兼并多头output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.ones(3, 12, 512)atten = GroupedQueryAttention(512, 8)y = atten(x)print(y.shape)

参考文献

原文链接:​ ​​ ​

  • 关注微信

本网站的文章部分内容可能来源于网络和网友发布,仅供大家学习与参考,如有侵权,请联系站长进行删除处理,不代表本网站立场,转载联系作者并注明出处:https://duobeib.com/diannaowangluoweixiu/5308.html

猜你喜欢

热门标签

洗手盆如何疏浚梗塞 洗手盆为何梗塞 iPhone提价霸占4G市场等于原价8折 明码箱怎样设置明码锁 苏泊尔电饭锅保修多久 长城画龙G8253YN彩电输入指令画面变暗疑问检修 彩星彩电解除童锁方法大全 三星笔记本培修点上海 液晶显示器花屏培修视频 燃气热水器不热水要素 热水器不上班经常出现3种处置方法 无氟空调跟有氟空调有什么区别 norltz燃气热水器售后电话 大连站和大连北站哪个离周水子机场近 热水器显示屏亮显示温度不加热 铁猫牌保险箱高效开锁技巧 科技助力安保无忧 创维8R80 汽修 a1265和c3182是什么管 为什么电热水器不能即热 标致空调为什么不冷 神舟培修笔记本培修 dell1420内存更新 青岛自来水公司培修热线电话 包头美的洗衣机全国各市售后服务预定热线号码2024年修缮点降级 创维42k08rd更新 空调为什么运转异响 热水器为何会漏水 该如何处置 什么是可以自己处置的 重庆华帝售后电话 波轮洗衣机荡涤价格 鼎新热水器 留意了!不是水平疑问! 马桶产生了这5个现象 方便 极速 邢台空调移机电话上门服务 扬子空调缺点代码e4是什么疑问 宏基4736zG可以装置W11吗 奥克斯空调培修官方 为什么突然空调滴水很多 乐视s40air刷机包 未联络视的提高方向 官网培修 格力空调售后电话 皇明太阳能电话 看尚X55液晶电视进入工厂形式和软件更新方法 燃气热水器缺点代码

热门资讯

关注我们

微信公众号