1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| def get_center_loss(features, labels, alpha, num_classes): """获取center loss及center的更新op Arguments: features: Tensor,表征样本特征,一般使用某个fc层的输出,shape应该为[batch_size, feature_length]. labels: Tensor,表征样本label,非one-hot编码,shape应为[batch_size]. alpha: 0-1之间的数字,控制样本类别中心的学习率,细节参考原文. num_classes: 整数,表明总共有多少个类别,网络分类输出有多少个神经元这里就取多少. Return: loss: Tensor,可与softmax loss相加作为总的loss进行优化. centers: Tensor,存储样本中心值的Tensor,仅查看样本中心存储的具体数值时有用. centers_update_op: op,用于更新样本中心的op,在训练时需要同时运行该op,否则样本中心不会更新 """ len_features = features.get_shape()[1] centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False) labels = tf.reshape(labels, [-1]) centers_batch = tf.gather(centers, labels) loss = tf.nn.l2_loss(features - centers_batch) diff = centers_batch - features unique_label, unique_idx, unique_count = tf.unique_with_counts(labels) appear_times = tf.gather(unique_count, unique_idx) appear_times = tf.reshape(appear_times, [-1, 1]) diff = diff / tf.cast((1 + appear_times), tf.float32) diff = alpha * diff centers_update_op = tf.scatter_sub(centers, labels, diff) return loss, centers, centers_update_op
|