今天看啥  ›  专栏  ›  jiangweijie1981

文献阅读·DTN(Domain Transfer Network)

jiangweijie1981  · 简书  ·  · 2020-02-21 23:11

简介

Unsupervised cross-domain image generation.Cited-476.Open source(unofficial): https://github.com/taey16/DomainTransferNetwork.pytorch

关键字

域迁移,域适应,无监督,深度学习,机器学习

正文

1. 任务和思路

把含标签的域 S 中的样本 x 转换到相关的不带标签的域 T 中,希望转换后的样本 \tilde x 保持类别标签。

为了达到这样的目的,在转换的过程中希望这些样本的语义保持不变而且这些语义在两个域的表达是共同不变的,那就希望有个语义映射的函数 f 来完成这样的任务,即有 f(x)=f(\tilde x) 。当然为了保证 \tilde x 是符合域 T 的分布,还需要判别器 D 来帮忙, D 判别的对象是样本 \tilde x ,而不是特征 f(\tilde x) ,因此还要有个生成器 g 帮忙把 f(\tilde x) 生成为样本 g(f(\tilde x)) ,最后整个思路就清晰啦,有三个组件, f 提取特征, gf 提取好的特征生成对应域 T 的样本, D 判别 g 生成的样本是否符合 T 的分布。

2. 结构

结构含3个部分,分别是可以提取两个域样本特征的 f ,可以生成目标域样本的生成器 g ,可以判别样本是否属于目标域的判别器 D ,如图(文献Fig1)所示:

结构.png

这里的结构有点像 VAE-GAN ,不同的是VAE的encoder参数是训练出来的,而这边的f(encoder)是预先训练好的,从预先训练好的这点上来看,又有点儿像 Cycada ,区别是Cycada是利用了分类的预测标签来保持语义,DTN(本文)是利用分类器的特征层(softmax前的最后一层)来对齐语义。还有一点,这边的判别器 D 是3个输出的。

3. 训练过程和损失函数

训练过程类似GAN, f,g 合在一起当作GAN中的生成器 GD 就是判别器; GD 交替训练更新参数,对应的损失分别如下,注意这里的 f 是事先在源域训练好的。

首先是更新 D 的损失:
L_D=-E_{x\in s}[\log D_1(g(f(x)))]-E_{x\in t}[\log D_2(g(f(x)))]-E_{x\in s}[\log D_3(x)]

接下来是更新 G 的损失:

L_G=L_{GANG}+\alpha L_{CONST}+\beta L_{TID}+\gamma L_{TV}

第1项是判别损失;

L_{GANG}=-E_{x\in s}[\log D_3(g(f(x)))]-E_{x\in t}[\log D_3(g(f(x)))]

第2项是源域的特征重构损失;

L_{CONST}=\sum_{x\in s}d(f(x),f(g(f(x))))

第3项是目标域的样本重构损失;

L_{TID}=\sum_{x\in t}d_2(x,g(f(x)))

第4项是目标域的样本平滑正则化;

L_{TV}(z)=\sum_{i,j}((z_{i,j+1}-z_{i,j})^2+(z_{i+1,j}-z_{i,j})^2)^{\frac B2}

4. 实验

(1)语义保持

先是SVHN向MNIST的转换,使用语义保持来说明,语义保持的度量使用MNIST上训练好的分类器来计算,结果如图(文献Table1,Table2):

数字集效果.png

(2)视觉效果(真实人脸 \rightarrow 卡通人脸)

这里与Gatys的 风格迁移 作了对比(文献Fig5),a+b合成c(风格迁移),d(本文算法)的效果:

人脸效果.png

参考资料

[1] Taigman, Yaniv, Adam Polyak, and Lior Wolf. "Unsupervised cross-domain image generation." arXiv preprint arXiv:1611.02200 (2016).




原文地址:访问原文地址
快照地址: 访问文章快照