DALL-E的具体实现,openAI没有公布,github上发布的代码只有一个dVAE的模型,相当于只有一半。
但Hugging Face和谷歌云团队,按照自己的理解,实现了一个DALL·E mini模型,可从中了解一二。
DALL·E mini模型,在限于更小的硬件资源的情况下,效果还不错,模型比原来的DALL-E小27倍,在单个TPU v3-8上只训练了3天。看到Literature Review了10天,我瞬间绷不住了,真香。
模型结构如下图,在训练过程中,输入图像和文本对。
在推理时,只使用标题,用于生成图像:
图像编码器和解码器,使用的是VQGAN。
VQGAN的目标是将图像编码为一连串的离散token,如果直接用pixel当做token,那边词表会有256^3那么大,而序列长度有256 * 256那么大,放进transformer内存瞬间就炸了。
codebook使用感知损失加GAN判别器损失来训练。编码器输出与codebook索引对应。一旦图像被编码成一连串的token,它就可以被用于任何transformer模型。
在DALL·E mini模型中,图像将被编码为16 x 16 = 256个离散的token,这些token来自16384大小的词汇表。解码后的图像是256 x 256(每边16 x 16)。
训练方法是用seq2seq的方法,用前面的序列预测下一个token,结合上mask的设计,可以实现行预测或者列预测。这种transformer做seq2seq的方法来自于UNILM,感兴趣的同学可以去了解下。
DALL-E官方论文代码终于放出,OpenAI是如何实现图像版GPT-3的?
今年1月份openAI发布了DALL-E模型,能够根据文本生成效果惊艳的图像,并且参数量达到了120亿,被称为“图像版GPT-3”。
最近,openAI放出了DALL-E的论文和部分代码,使得大家能够进一步一窥究竟。根据本次开出的论文《Zero-Shot Text-to-Image Generation》[1],简单整理了一下DALL-E的整体架构,如图1所示,DALL-E的推理主要分为三个阶段,其中前两个阶段对应论文中的Stage One和Stage Two。
在第一个阶段,将256×256的图片分为32×32个patch,然后使用训练好的离散VAE模型的encoder将每个patch映射到大小为8192的词表中,最终一张图片转为用1024个token表示。在第二个阶段,使用BPE-encoder对文本进行编码,得到最多256个token,token数不满256的话padding到256;再将256个文本token与1024个图像token进行拼接,得到长度为1280的数据;最终将拼接的数据输入训练好的具有120亿参数的Transformer模型。在第三个阶段,对模型生成的图像进行采样,并使用同期发布的CLIP模型[2]对采样结果进行排序,从而得到与文本最匹配的生成图像。
DALLE包括三个独立训练得到的模型:dVAE,Transformer和CLIP,其中dVAE的训练与VAE基本相同,Transformer采用类似GPT-3的生成式预训练方法。下面对DALL-E采用的dVAE模型和Transformer模型做简单介绍,对CLIP感兴趣的朋友可以参考[2]。
dVAE主要用来为图像的每个patch生成token表示,这次openAI开出的代码就是dVAE的推理代码。dVAE的encoder和decoder的机构较为简单,都是由bottleneck-style的resblock组成,但与常见的VAE相比,dVAE有以下两点区别:
1、dVAE的encoder是将图像的patch映射到8192的词表中,论文中将其分布设为
在词表向量上的均匀分类分布,这是一个离散分布,由于不可导的问题,此时不能采用重参数技巧。DALL-E使用了Gumbel-SoftMax trick来解决这个问题,对Gumbel-SoftMax trick感兴趣的朋友可以参考[3]。
2、在重建图像时,真实的像素值是在一个有界区间内,而VAE中使用的Gaussian
分布和Laplace分布都是在整个实数集上,这造成了不匹配的问题。为了解决这个问题,论文中提出了logit-Laplace分布,如下式所示:
Dall-E中的Transformer结构由64层attention层组成,每层的注意力头数为62,每个注意力头的维度为64,因此,每个token的向量表示维度为3968。如图2所示,attention层使用了行注意力mask、列注意力mask和卷积注意力mask三种稀疏注意力。
Transformer的输入如图3所示,其中pad embd通过学习得到,根据论文介绍,为每个位置都训练了一个pad embd,即256个pad embd,在对文本token进行pad时,使用对应位置的pad embd。
总的来说,目前公开的DALL-E的实现在模型结构上并没有太多创新,而是合理利用了现有的模型结构进行组合,并采用了一些trick解决了遇到的问题,从而在大数据集上训练得到超大规模的模型,取得了令人惊艳的效果,这也符合openAI的一贯风格。但无论如何,DALL-E在深度学习能力边界探索的道路上又前进了一步,也再一次展示了大数据和超大规模模型的魅力。美中不足的是,DALL-E包含了三个模块,更像是一个pipeline,而对于普通的研究者来说,要运行这样一个复杂的大规模模型是一件很困难的事情。
参考文献请见:
【1】Zero-Shot Text-to-Image Generation, 2021.
【2】Learning transferable visual models from natural language supervision, 2020
【3】The Gumbel-Softmax Trick for Inference of Discrete Variables
(https://casmls.github.io/general/2017/02/01/GumbelSoftmax.html)