mirror of
https://github.com/dragen1860/TensorFlow-2.x-Tutorials.git
synced 2021-05-12 18:32:23 +03:00
25 lines
711 B
Python
25 lines
711 B
Python
import tensorflow as tf
|
|
|
|
|
|
def masked_softmax_cross_entropy(preds, labels, mask):
    """
    Softmax cross-entropy loss with masking.

    The per-example softmax cross-entropy between ``preds`` (used as logits)
    and ``labels`` is weighted by ``mask`` and averaged. Because the mask is
    first divided by its own mean, the result equals
    ``sum(loss * mask) / sum(mask)`` when ``mask`` holds one weight per
    example, i.e. the mean loss over the selected examples only.

    Args:
        preds: Unnormalised logits.
        labels: Target distribution for each example; presumably one-hot with
            the same shape as ``preds`` -- confirm against the caller.
        mask: Per-example selector/weights (presumably a boolean or 0/1
            vector -- confirm against the caller). It is cast to the dtype of
            the loss before use.

    Returns:
        Scalar tensor with the masked mean loss. If the mask mean is zero
        (nothing selected) the result is 0 rather than NaN.
    """
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
    # Match the loss dtype: TensorFlow does not promote dtypes implicitly, so
    # a float32 mask would make `loss *= mask` fail for non-float32 logits.
    mask = tf.cast(mask, dtype=loss.dtype)
    # Rescale so the final reduce_mean averages over selected entries only.
    # divide_no_nan yields 0 (not NaN) when the mask mean is zero.
    mask = tf.math.divide_no_nan(mask, tf.reduce_mean(mask))
    loss *= mask
    return tf.reduce_mean(loss)
|
|
|
|
|
|
def masked_accuracy(preds, labels, mask):
    """
    Accuracy with masking.

    An example counts as correct when the argmax of ``preds`` along axis 1
    equals the argmax of ``labels`` along axis 1. The per-example 0/1 scores
    are weighted by ``mask`` and averaged. Because the mask is first divided
    by its own mean, the result equals ``sum(correct * mask) / sum(mask)``
    when ``mask`` holds one weight per example.

    Args:
        preds: Class scores; the predicted class is the argmax along axis 1.
        labels: Targets; the true class is the argmax along axis 1
            (presumably one-hot -- confirm against the caller).
        mask: Per-example selector/weights (presumably a boolean or 0/1
            vector -- confirm against the caller). It is cast to float32.

    Returns:
        Scalar float32 tensor with the masked accuracy. If the mask mean is
        zero (nothing selected) the result is 0 rather than NaN.
    """
    correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1))
    accuracy_all = tf.cast(correct_prediction, tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    # Rescale so the final reduce_mean averages over selected entries only.
    # divide_no_nan yields 0 (not NaN) when the mask mean is zero.
    mask = tf.math.divide_no_nan(mask, tf.reduce_mean(mask))
    accuracy_all *= mask
    return tf.reduce_mean(accuracy_all)
|