来给大家介绍一下我们ICML 2020上的新工作,这应该是我在MIT做的最后一个、也是我最喜欢的工作之一:)因为它刚好打破了传统domain adaptation(DA)的范式,第一次把DA从离散域推广到连续域。先放论文链接:”Continuously Indexed Domain Adaptation”, http://wanghao.in/paper/ICML20_CIDA.pdf
比较幸运的是这次碰到了也喜欢我们工作的reviewer们。先上AC的meta-review。读完甚是感动啊。。
“This paper addresses the topic of domain adaptation for continuous domains. This work is unanimously considered very novel, of high impact - effective problems in the medical domain, with good theoretical and experimental results presented.”
下面进入正文。
离散域DA v.s. 连续域DA:传统的DA,一般都是从一个(或几个)domain,adapt到另一个(或几个)domain,如下图:
但是我们发现:在现实中的任务往往要复杂的多,domain并不是一个个单独分开的,而是连续存在的。比如在医疗应用中(如下图),不同的年龄的病人对应着不同的domain,而年龄是一个连续的时间概念,所以实际上我们做的是无限个连续domain之间的adaptation。
这就是我们说连续域DA,如下图,左边可以有无限个source domain,右边可以有无限个target domain。
举个Toy Dataset上的例子:比如在下面的图里,我们有30个在同一条轨迹上的domain。我们把domain [1, 6] 作为source domain,其他作为target domain。可以看出来,传统的离散域DA并没法很好的学到正确的分界线而我们的连续域DA却可以做到。我们把提出的方法命名为Continuously Indexed Domain Adaptation,简称CIDA(与苹果汁cidar同音)。
方法:有意思的是,要达到这个效果,其实实现很简单,做DA的人瞄一眼下面的图就懂了。红色的部分代表我们和传统的adversarial DA不同的地方。u是我们提出的一个概念,叫做domain index,比如年龄就是一个domain index。z是encoder的输出,也就是encoding。下面的意思是,我们只需要把domain index u加入到encoder里面,同时让discriminator直接预测(或者叫regress)u,就可以了。
但是:当然故事没有那么简单。虽然CIDA大多数情况下work得很好,但是我们发现,理论上如果直接只让discriminator预测一个值u,那么model就只能对齐不同domain的p(u | z)的均值,相当于只match了first moment,这样有可能陷入到一个局部最优。比如我们看上面toy data的前3个domain(如下图)
如果直接用简单的CIDA,那么worst case可能出现下面的情况,我们可以看到下图右边,3个domain并没有对齐,但是他们的p(u | z)的均值(E[u | z])却是相等的。
那么怎么解决这个问题呢?我们发现,上图的右边,虽然3个domain的E[u | z]虽然相等,但是他们的variance(V[u | z])不相等啊!于是就有了CIDA加强版,我们把它叫做probabilistic CIDA,简称PCIDA(如下图)。我们只需要让discriminator同时预测mean跟variance,然后用高斯分布的log-likelihood作为目标函数训练即可。
理论:更有趣的是,我们发现理论上可以证明(如下图),当CIDA和PCIDA训练到最优时,是可以保证对齐p(u | z)的mean和variance的(Theorem 1和2),而且不会影响predictor的效果(Theorem 3),意思就是说,用了CIDA,肯定不会比不用差。完整的theorem和proof请看原文https://arxiv.org/abs/2007.01807
实验结果之Rotating MNIST:除了上面的toy dataset,我们还做了真正接近无限个domain的数据集:Rotating MNIST。我们以旋转的角度作为domain index,构建了一个包含着几十万个domain的MNIST数据集,结果如下图。“45”那一列表示旋转角度在区间[45, 90),以此类推。可以看出来,CIDA可以让所有target domain都达到很高的accuracy,而传统的离散域DA却不行。
实验结果之医疗数据:我们在几个大的医疗数据做了实验,简单地讲,这是个分类的task,然后病人的年龄作为domain index。对于同数据集,我们考虑了两种setting(见下图),
(1)Domain Interpolation,
(2)Domain Extrapolation。
下表就是同数据集里面,不同年龄范围之间的adaptation的accuracy,
下表是跨数据集的adaptation结果,
实验的结论有3条:(1)在真实的医疗数据中,如果直接使用传统的离散域DA,非但无法提高准确率,反而可能降低准确率。(2)CIDA/PCIDA却可以毫无压力地提高准确率。(3)domain extrapolation更加challenging,在这种setting下,CIDA/PCIDA能提高的准确率更多。
彩蛋 -- 多维的连续域DA:沿着CIDA的思路,我们发现,很多情况下,会同时出现不同的domain index,比如年龄和健康程度(用1到100之间的一个数表示)。于是我们提出了一个新的概念,multi-dimensional domain index(如下图)。
实验结果如下,基本可以看出,使用多维的domain index是可以进一步地提高准确率的。
第一次写这么长的一篇,如果有啥编辑或者逻辑不顺,大家轻拍:)最后放一下各种相关的材料链接。注册了会议的同学欢迎来我们周二和周三的QA环节(https://icml.cc/virtual/2020/poster/5986)
Paper: https://arxiv.org/abs/2007.01807 or http://wanghao.in/paper/ICML20_CIDA.pdf
Code and Jupyter Notebooks: https://github.com/hehaodele/CIDA
Video: https://drive.google.com/file/d/1G_51ekjcCTFRsvnGYiR5yFxBItWH2YBp/view or ICML 2020 Oral Talk: Continuously Indexed Domain Adaptation
ICML Talk & Chatroom: https://icml.cc/virtual/2020/poster/5986