没有multi-head,不就是普通的的多层全联接网络么。multi-head把维度从[batch, len, embeding]变为[batch, len, head, embeding/head]。计算attention的时候,就会拥有多个注意力点。如果没有多头一次只能计算一个注意力点,这就和google早期的seq2seq的Bahdanau(https://arxiv.org/abs/1409.0473)类似了,而如果采用这样的方式,效果还不如Bahdanau注意力机制。多头的好处就是把embeding从一份特征变成了多份特征,这样就是弥补了相对于Bahdanau类注意力的短板。但是实际的模型训练中要适当选择head的大小和embeding的大小。
评论区小伙伴建议举例子来说说,这个提议很好。就贴一些经典的图来帮助理解。
强烈建议仔细读读Attention Is All You Need(https://arxiv.org/pdf/1706.03762.pdf)
可以看看早期注意力和transformer(也就是multi-head注意力)的动态图展示。
下面这两个都是翻译模型用到的,非多头的注意力模型。第一个是google的seq2seq模型,使用LSTM堆叠,注意力使用Bahdanau。第二个是Facebook的基于CNN的注意力模型。这两个都是诞生于2018年之前的(ps:transformer在2017年提出,记得应该是比Facebook CNN翻译模型时间差不多,只不过BERT还没出现,它就比较低调)。当时的主流还是使用一次编码,然后拿到编码结果,算和目标之间的注意力,目标有10个字,就循环计算10次,100字就计算100次(ps:Facebook的cnn可以并行计算)。当时这些方法都是在翻译领域达到SOTA的效果。
下面再看看多头注意力的(这是一个基于transformer的翻译模型过程):
BERT其实是使用了上图中的encoder部分,也就是先句子内部进行注意力计算,也就是self- attention。
上面对比了几种主要的注意的区别,给大家在视觉上有一个直观的感受。现在回到题主的问题,为什么要把最后一维度换成12*64?高赞的苏神的回答,已经从数学角度解释了。
我现在从它这个机制运行角度猜测下吧。看上面transformer的动态图,这里我们只看encoding部分,因为bert只使用了transformer结构的encoder。可以看到有很多线在token之间穿梭,这些线就相当于是分裂出的12个头,它们在token之间互相学习,最终得到下面一个结果:
上面例子,it就和the animal的关联性比较高(颜色深的),还有一些次高的(颜色较淡)。在做BERT的MLM任务时候,我们推断被mask掉的那个词,不太可能只靠一个注意力去预测,需要结合多个。多头应该是赋予了句子更多的表达信息。
如果去掉多头,能不能做?那当然是能做。但是表达能力要弱很多很多。效果远不如transformer之前的注意力方式。
大家有啥更好想法可以一起讨论~
BERT之所以被挖坟这么久,其实很多坑还是没有填上的。多头只是其中一部分经典设计,像残差,位置编码,FFN这些模块之间有啥联系,都需要去进一步探讨~
ps:传动图,无法上传,只能录屏了