Spatial Transformer Networks 论文笔记

今天重新阅读了Deepmind 2015年发表的[Spatial Transformer Networks][1]论文,结合几篇博客理解透彻了这篇论文的基本思想。

[Deep Learning Paper Implementations: Spatial Transformer Networks - Part I][2]这篇文章详细介绍了仿射变换及双线性插值的原理并提供了示例Python代码; [Part II][3]介绍了Spatial Transformer Networks,但不是很详细,可以结合[深度学习方法(十二):卷积神经网络结构变化——Spatial Transformer Networks][4]来阅读。

STN由三个模块:Localisation net、Grid generator和Sampler构成,在认真看完论文以及这三篇博客之后,大家应该是很容易理解Localisation net的作用的,即回归仿射变换矩阵的参数。

难点在于理解Grid generator和Sampler的部分,主要是如何将仿射变换矩阵与这里的几何变换结合理解。这里强烈推荐我近期看的线性代数的本质这一系列视频,非常形象地将矩阵所代表的线性变换的几何意义用动画解释出来了,看了之后保管你看到一个仿射变换矩阵可以立即想象出它的几何意义。

另外这里的这个公式其实花了我一段时间去理解,即

Grid generator

为什么$x_i^s$在等式左边,而$x_i^t$却在等式右边?我们不是要得到target吗?看上去我们像在对target做变换?其实这里的$T_\theta(G_i)$代表的是对目标网格进行的变换,而不是直接对原图进行的变换。注意下面这张图,原图始终没变,变换的是网格!我们把网格进行仿射变换,然后把变换后的网格放回到原图上,用原图中对应位置的像素值去填充变换后的网格!这样能够保证变换后的输出始终是我们设定的网格的大小!也即意味着我们可以通过控制网格的大小去控制该层输出的图像的最大分辨率(同时仿射变换矩阵也会对图像有作用)。

总的来说,这篇论文的核心就是在于让网络自动学习一个仿射变换矩阵,以便更好地处理分类等任务。TensorFlow版的代码可以参考TensorFlow_models