脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|shell|

服务器之家 - 脚本之家 - Python - 大模型中常用的注意力机制GQA详解以及Pytorch代码实现

大模型中常用的注意力机制GQA详解以及Pytorch代码实现

2024-04-07 15:23DeepHub IMBA Python

分组查询注意力 (Grouped Query Attention) 是一种在大型语言模型中的多查询注意力 (MQA) 和多头注意力 (MHA) 之间进行插值的方法,它的目标是在保持 MQA 速度的同时实现 MHA 的质量。

分组查询注意力 (Grouped Query Attention) 是一种在大型语言模型中的多查询注意力 (MQA) 和多头注意力 (MHA) 之间进行插值的方法,它的目标是在保持 MQA 速度的同时实现 MHA 的质量。

这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。

GQA是在论文 GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints paper.中提出,这是一个相当简单和干净的想法,并且建立在多头注意力之上。

大模型中常用的注意力机制GQA详解以及Pytorch代码实现

GQA

标准多头注意层(MHA)由H个查询头、键头和值头组成。每个头都有D个维度。Pytorch的代码如下:

from torch.nn.functional import scaled_dot_product_attention
 
 # shapes: (batch_size, seq_len, num_heads, head_dim)
 query = torch.randn(1, 256, 8, 64)
 key = torch.randn(1, 256, 8, 64)
 value = torch.randn(1, 256, 8, 64)
 
 output = scaled_dot_product_attention(query, key, value)
 print(output.shape) # torch.Size([1, 256, 8, 64])

对于每个查询头,都有一个对应的键。这个过程如下图所示:

大模型中常用的注意力机制GQA详解以及Pytorch代码实现

而GQA将查询头分成G组,每组共享一个键和值。可以表示为:

大模型中常用的注意力机制GQA详解以及Pytorch代码实现

使用可视化的表示就能非常清楚的了解GQA的工作原理,就像我们上面说的那样,GQA是一个相当简单和干净的想法

Pytorch代码实现

让我们编写代码将这种将查询头划分为G组,每个组共享一个键和值。我们可以使用einops库有效地执行对张量的复杂操作。

首先,定义查询、键和值。然后设置注意力头的数量,数量是随意的,但是要保证num_heads_for_query % num_heads_for_key = 0,也就是说要能够整除。我们的定义如下:

import torch
 
 # shapes: (batch_size, seq_len, num_heads, head_dim)
 query = torch.randn(1, 256, 8, 64)
 key = torch.randn(1, 256, 2, 64)
 value = torch.randn(1, 256, 2, 64)
 
 num_head_groups = query.shape[2] // key.shape[2]
 print(num_head_groups) # each group is of size 4 since there are 2 kv_heads

为了提高效率,交换seq_len和num_heads维度,einops可以像下面这样简单地完成:

from einops import rearrange
 
 query = rearrange(query, "b n h d -> b h n d")
 key = rearrange(key, "b s h d -> b h s d")
 value = rearrange(value, "b s h d -> b h s d")

然后就是需要在查询矩阵中引入”分组“的概念。

from einops import rearrange
 query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
 print(query.shape) # torch.Size([1, 4, 2, 256, 64])

上面的代码我们将二维重塑为二维:对于我们定义的张量,原始维度8(查询的头数)现在被分成两组(以匹配键和值中的头数),每组大小为4。

最后最难的部分是计算注意力的分数。但其实它可以在一行中通过insum操作完成的

from einops import einsum, rearrange
 # g stands for the number of groups
 # h stands for the hidden dim
 # n and s are equal and stands for sequence length
  
 scores = einsum(query, key, "b g h n d, b h s d -> b h n s")
 print(scores.shape) # torch.Size([1, 2, 256, 256])

scores张量和上面的value张量的形状是一样的。我们看看到底是怎么操作的

einsum帮我们做了两件事:

1、一个查询和键的矩阵乘法。在我们的例子中,这些张量的形状是(1,4,2,256,64)和(1,2,256,64),所以沿着最后两个维度的矩阵乘法得到(1,4,2,256,256)。

2、对第二个维度(维度g)上的元素求和——如果在指定的输出形状中省略了维度,einsum将自动完成这项工作,这样的求和是用来匹配键和值中的头的数量。

最后是注意分数与值的标准乘法:

import torch.nn.functional as F
 
 scale = query.size(-1) ** 0.5
 attention = F.softmax(similarity / scale, dim=-1)
 
 # here we do just a standard matrix multiplication
 out = einsum(attention, value, "b h n s, b h s d -> b h n d")
 
 # finally, just reshape back to the (batch_size, seq_len, num_kv_heads, hidden_dim)
 out = rearrange(out, "b h n d -> b n h d")
 print(out.shape) # torch.Size([1, 256, 2, 64])

这样最简单的GQA实现就完成了,只需要不到16行python代码:

大模型中常用的注意力机制GQA详解以及Pytorch代码实现

最后再简单提一句MQA:多查询注意(MQA)是另一种简化MHA的流行方法。所有查询将共享相同的键和值。原理图如下:

大模型中常用的注意力机制GQA详解以及Pytorch代码实现

可以看到,MQA和MHA都可以从GQA推导出来。具有单个键和值的GQA相当于MQA,而具有与头数量相等的组的GQA相当于MHA。

GQA的好处是什么?

GQA是最佳性能(MQA)和最佳模型质量(MHA)之间的一个很好的权衡。

下图显示,使用GQA,可以获得与MHA几乎相同的模型质量,同时将处理时间提高3倍,达到MQA的性能。这对于高负载系统来说可能是必不可少的。

大模型中常用的注意力机制GQA详解以及Pytorch代码实现

在pytorch中没有GQA的官方实现。所以我找到了一个比较好的非官方实现,有兴趣的可以试试:

https://github.com/fkodom/grouped-query-attention-pytorch

GQA论文:

https://arxiv.org/pdf/2305.13245.pdf

原文地址:https://mp.weixin.qq.com/s?__biz=MzU5OTM2NjYwNg==&mid=2247506505&idx=1&sn=aec6d008c6d4af1a9afa1963bcc94be3

延伸 · 阅读

精彩推荐
  • Pythonpython3.4用函数操作mysql5.7数据库

    python3.4用函数操作mysql5.7数据库

    这篇文章主要为大家详细介绍了python3.4用函数操作mysql5.7数据库,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    猪冰龙2612020-11-20
  • Pythonpython Paramiko使用示例

    python Paramiko使用示例

    这篇文章主要介绍了python Paramiko的使用示例,帮助大家远程控制类 UNIX 系统,感兴趣的朋友可以了解下。...

    Starryland12792020-09-21
  • Pythonwin10系统配置GPU版本Pytorch的详细教程

    win10系统配置GPU版本Pytorch的详细教程

    这篇文章主要介绍了win10系统配置GPU版本Pytorch,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友...

    SGchi8542021-10-22
  • PythonPython利用wxPython模块打造ChatGPT式打字效果程序

    Python利用wxPython模块打造ChatGPT式打字效果程序

    这篇文章主要为大家介绍了如何利用Python和wxPython模块打造一个ChatGPT式打字效果程序,从而增强用户体验或提高应用程序的可读性,感兴趣的可以了解一下...

    winfredzhang6762023-05-06
  • Python解决Keras TensorFlow 混编中 trainable=False设置无效问题

    解决Keras TensorFlow 混编中 trainable=False设置无效问题

    这篇文章主要介绍了解决Keras TensorFlow 混编中 trainable=False设置无效问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧 ...

    芥末的无奈4782020-06-29
  • Python深入解析NumPy中的Broadcasting广播机制

    深入解析NumPy中的Broadcasting广播机制

    在吴恩达老师的深度学习专项课程中,老师有提到NumPy中的广播机制,同时那一周的测验也有涉及到广播机制的题目。那么,到底什么是NumPy中的广播机制?...

    沧夜202110192021-11-17
  • Pythonpython自动化测试工具Helium使用示例

    python自动化测试工具Helium使用示例

    大家好,本篇文章主要讲的是python自动化测试工具Helium使用示例,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下哦...

    Python 集中营8312022-03-10
  • Python如何利用Python获取鼠标的实时位置

    如何利用Python获取鼠标的实时位置

    这篇文章主要给大家介绍了关于如何利用Python获取鼠标的实时位置的相关资料,主要利用的是pyautogui,一个自动化键鼠操作的Python类库,需要的朋友可以参考下...

    18岁小白想成大牛7992022-09-06