| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 
 | def get_center_loss(features, labels, alpha, num_classes):"""获取center loss及center的更新op
 
 Arguments:
 features: Tensor,表征样本特征,一般使用某个fc层的输出,shape应该为[batch_size, feature_length].
 labels: Tensor,表征样本label,非one-hot编码,shape应为[batch_size].
 alpha: 0-1之间的数字,控制样本类别中心的学习率,细节参考原文.
 num_classes: 整数,表明总共有多少个类别,网络分类输出有多少个神经元这里就取多少.
 
 Return:
 loss: Tensor,可与softmax loss相加作为总的loss进行优化.
 centers: Tensor,存储样本中心值的Tensor,仅查看样本中心存储的具体数值时有用.
 centers_update_op: op,用于更新样本中心的op,在训练时需要同时运行该op,否则样本中心不会更新
 """
 
 len_features = features.get_shape()[1]
 
 
 centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
 initializer=tf.constant_initializer(0), trainable=False)
 
 labels = tf.reshape(labels, [-1])
 
 
 centers_batch = tf.gather(centers, labels)
 
 loss = tf.nn.l2_loss(features - centers_batch)
 
 
 diff = centers_batch - features
 
 
 unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
 appear_times = tf.gather(unique_count, unique_idx)
 appear_times = tf.reshape(appear_times, [-1, 1])
 
 diff = diff / tf.cast((1 + appear_times), tf.float32)
 diff = alpha * diff
 
 centers_update_op = tf.scatter_sub(centers, labels, diff)
 
 return loss, centers, centers_update_op
 
 |