center loss的TensorFlow实现

本文最开始发表于CSDN[1],目前重新整理并加上更详细的注释,本周内会于Github上传示例代码 代码[2]已上传。

Center loss是ECCV2016中一篇论文《A Discriminative Feature Learning Approach for Deep Face Recognition》[3]提出来的概念,主要思想就是在softmax loss基础上额外加入一个正则项,让网络中每一类样本的特征向量都能够尽量聚在一起,在mnist上的示意效果如下图所示,可以看到每一类样本都聚在类别中心的周围。

具体的原理推导等请参考原论文,网上目前有Caffe以及Mxnet版的实现,我在Facenet[4]实现版本基础上做了修改,个人认为是更加贴近原论文叙述的方法的,欢迎讨论。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def get_center_loss(features, labels, alpha, num_classes):
"""获取center loss及center的更新op

Arguments:
features: Tensor,表征样本特征,一般使用某个fc层的输出,shape应该为[batch_size, feature_length].
labels: Tensor,表征样本label,非one-hot编码,shape应为[batch_size].
alpha: 0-1之间的数字,控制样本类别中心的学习率,细节参考原文.
num_classes: 整数,表明总共有多少个类别,网络分类输出有多少个神经元这里就取多少.

Return:
loss: Tensor,可与softmax loss相加作为总的loss进行优化.
centers: Tensor,存储样本中心值的Tensor,仅查看样本中心存储的具体数值时有用.
centers_update_op: op,用于更新样本中心的op,在训练时需要同时运行该op,否则样本中心不会更新
"""
# 获取特征的维数,例如256维
len_features = features.get_shape()[1]
# 建立一个Variable,shape为[num_classes, len_features],用于存储整个网络的样本中心,
# 设置trainable=False是因为样本中心不是由梯度进行更新的
centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
initializer=tf.constant_initializer(0), trainable=False)
# 将label展开为一维的,输入如果已经是一维的,则该动作其实无必要
labels = tf.reshape(labels, [-1])

# 根据样本label,获取mini-batch中每一个样本对应的中心值
centers_batch = tf.gather(centers, labels)
# 计算loss
loss = tf.nn.l2_loss(features - centers_batch)

# 当前mini-batch的特征值与它们对应的中心值之间的差
diff = centers_batch - features

# 获取mini-batch中同一类别样本出现的次数,了解原理请参考原文公式(4)
unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
appear_times = tf.gather(unique_count, unique_idx)
appear_times = tf.reshape(appear_times, [-1, 1])

diff = diff / tf.cast((1 + appear_times), tf.float32)
diff = alpha * diff

centers_update_op = tf.scatter_sub(centers, labels, diff)

return loss, centers, centers_update_op

  1. 1.http://blog.csdn.net/encodets/article/details/54648015 ↩︎
  2. 2.代码地址:https://github.com/EncodeTS/TensorFlow_Center_Loss ↩︎
  3. 3.https://link.springer.com/chapter/10.1007/978-3-319-46478-7_31 ↩︎
  4. 4.https://github.com/davidsandberg/facenet/blob/master/src/facenet.py#L76-L88 ↩︎