tf.contrib.metrics.streaming_curve_points

Computes curve (ROC or PR) values for a prespecified number of points.

The streaming_curve_points function creates four local variables, true_positives, true_negatives, false_positives and false_negatives, that are used to compute the curve values. To discretize the curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision values.
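The discretization can be illustrated with a minimal pure-Python sketch. It is not TensorFlow's implementation. It assumes three things: thresholds are spaced evenly over [0, 1] with the two endpoints nudged just outside that range, a prediction counts as positive when it is strictly greater than the threshold, and each point is (false positive rate, recall) for 'ROC' and (recall, precision) for 'PR'. The helper names (curve_thresholds, safe_div, curve_points) are hypothetical. TensorFlow's exact epsilon handling may differ.

```python
def curve_thresholds(num_thresholds, eps=1e-7):
    """Linearly spaced thresholds; endpoints sit just outside [0, 1] (assumption)."""
    inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
    return [0.0 - eps] + inner + [1.0 + eps]


def safe_div(num, den, default):
    """Divide, returning `default` when the denominator is zero."""
    return num / den if den else default


def curve_points(labels, predictions, num_thresholds, curve="ROC"):
    """Return one (x, y) pair per threshold, from the four confusion counts."""
    points = []
    for t in curve_thresholds(num_thresholds):
        tp = fn = tn = fp = 0
        for y, p in zip(labels, predictions):
            predicted_positive = p > t
            if bool(y):  # labels are treated as booleans
                if predicted_positive:
                    tp += 1
                else:
                    fn += 1
            elif predicted_positive:
                fp += 1
            else:
                tn += 1
        recall = safe_div(tp, tp + fn, 0.0)  # recall is the true positive rate
        if curve == "ROC":
            points.append((safe_div(fp, fp + tn, 0.0), recall))
        elif curve == "PR":
            # Precision defaults to 1.0 when nothing is predicted positive (assumption).
            points.append((recall, safe_div(tp, tp + fp, 1.0)))
        else:
            raise ValueError("curve must be 'ROC' or 'PR'")
    return points


print(curve_points([1, 0, 1, 0], [0.9, 0.1, 0.8, 0.3], num_thresholds=3))
# [(1.0, 1.0), (0.0, 1.0), (0.0, 0.0)]
```

With three thresholds the result has three rows. This matches the documented [num_thresholds, 2] shape of points.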

For best results, predictions should be distributed approximately uniformly in the range [0, 1] and not peaked around 0 or 1.

To estimate the metric over a stream of data, the function creates an update_op operation that updates these variables.
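The streaming behaviour can be pictured as per-threshold counters that each update call adds to. Below is a minimal pure-Python analogue. StreamingConfusionCounts is a hypothetical class, and its four lists stand in for the local variables. It uses the same assumed threshold scheme (evenly spaced, endpoints nudged outside [0, 1], positive means prediction > threshold). The key property it shows is that feeding two batches gives the same counts as feeding their concatenation once.

```python
class StreamingConfusionCounts:
    """Hypothetical stand-in for the true/false positive/negative local variables."""

    def __init__(self, num_thresholds, eps=1e-7):
        inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
        self.thresholds = [0.0 - eps] + inner + [1.0 + eps]
        n = len(self.thresholds)
        self.tp = [0] * n
        self.fn = [0] * n
        self.tn = [0] * n
        self.fp = [0] * n

    def update(self, labels, predictions):
        """Analogue of running update_op once: increments counts for one batch."""
        for i, t in enumerate(self.thresholds):
            for y, p in zip(labels, predictions):
                predicted_positive = p > t
                if bool(y) and predicted_positive:
                    self.tp[i] += 1
                elif bool(y):
                    self.fn[i] += 1
                elif predicted_positive:
                    self.fp[i] += 1
                else:
                    self.tn[i] += 1


streamed = StreamingConfusionCounts(4)
streamed.update([1, 0], [0.9, 0.2])  # first batch
streamed.update([0, 1], [0.6, 0.4])  # second batch
print(streamed.tp, streamed.fp)
# [2, 2, 1, 0] [2, 1, 0, 0]
```

Because the counters only ever accumulate, the curve can be read off at any point in the stream.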

If weights is None, weights default to 1. Use weights of 0 to mask values.
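Weights scale each example's contribution to the four counts, so a weight of 0 removes an example entirely. The following sketch shows this at a single threshold. weighted_counts is a hypothetical helper that only covers one-dimensional lists. It does not handle TensorFlow's tensor broadcasting.

```python
def weighted_counts(labels, predictions, threshold, weights=None):
    """Weighted confusion counts at one threshold; weights default to 1."""
    if weights is None:
        weights = [1.0] * len(labels)
    tp = fn = tn = fp = 0.0
    for y, p, w in zip(labels, predictions, weights):
        predicted_positive = p > threshold
        if bool(y) and predicted_positive:
            tp += w
        elif bool(y):
            fn += w
        elif predicted_positive:
            fp += w
        else:
            tn += w
    return {"tp": tp, "fn": fn, "tn": tn, "fp": fp}


# The middle example gets weight 0, so its false positive is masked out.
print(weighted_counts([1, 0, 1], [0.9, 0.8, 0.2], 0.5, weights=[1, 0, 1]))
# {'tp': 1.0, 'fn': 1.0, 'tn': 0.0, 'fp': 0.0}
```

Masking with weight 0 gives the same counts as dropping that example from the input.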

Args
labels: A Tensor whose shape matches predictions. Will be cast to bool.
predictions: A floating point Tensor of arbitrary shape whose values are in the range [0, 1].
weights: Optional Tensor whose rank is either 0 or the same as the rank of labels, and which must be broadcastable to labels (i.e., all dimensions must be either 1 or the same as the corresponding labels dimension).
num_thresholds: The number of thresholds to use when discretizing the curve.
metrics_collections: An optional list of collections that points should be added to.
updates_collections: An optional list of collections that update_op should be added to.
curve: Specifies the name of the curve to be computed: 'ROC' (default) or 'PR' for the Precision-Recall curve.
name: An optional variable_scope name.
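The shape rule for weights can be written as a small predicate on shape tuples. The sketch below encodes exactly the rule stated for weights: rank 0, or the same rank as labels with every dimension either 1 or equal to the labels dimension. weights_broadcastable is a hypothetical helper, not a TensorFlow function.

```python
def weights_broadcastable(weights_shape, labels_shape):
    """True if weights_shape satisfies the documented rule for `weights`."""
    if len(weights_shape) == 0:
        return True  # rank-0 (scalar) weights always qualify
    if len(weights_shape) != len(labels_shape):
        return False  # otherwise the ranks must match
    return all(w == 1 or w == l for w, l in zip(weights_shape, labels_shape))


print(weights_broadcastable((1, 4), (3, 4)))  # True
print(weights_broadcastable((4,), (3, 4)))    # False: rank differs
```

Note that this rule is stricter than general NumPy-style broadcasting, which would also accept the rank-1 shape (4,) against (3, 4).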
Returns
points: A Tensor with shape [num_thresholds, 2] that contains the points of the curve.
update_op: An operation that increments the true_positives, true_negatives, false_positives and false_negatives variables.
Raises
ValueError: If predictions and labels have mismatched shapes, if weights is not None and its shape doesn't match predictions, or if either metrics_collections or updates_collections is not a list or tuple.
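Two of the documented ValueError conditions can be mirrored in plain Python. These are the shape mismatch between predictions and labels and the collection-type check. The sketch below is hypothetical: validate_curve_args and its messages are illustrative and are not TensorFlow's actual checks or error text. The weights condition follows the broadcasting rule given under Args and is left out here.

```python
def validate_curve_args(labels_shape, predictions_shape,
                        metrics_collections=None, updates_collections=None):
    """Hypothetical check mirroring two of the documented ValueError cases."""
    if tuple(labels_shape) != tuple(predictions_shape):
        raise ValueError("predictions shape %s does not match labels shape %s"
                         % (tuple(predictions_shape), tuple(labels_shape)))
    for arg_name, value in (("metrics_collections", metrics_collections),
                            ("updates_collections", updates_collections)):
        # None is allowed; anything else must be a list or tuple.
        if value is not None and not isinstance(value, (list, tuple)):
            raise ValueError("%s must be a list or tuple, got %r" % (arg_name, value))


validate_curve_args((2, 3), (2, 3), metrics_collections=["metrics"])  # passes
try:
    validate_curve_args((2, 3), (2, 3), updates_collections="updates")
except ValueError as err:
    print(err)  # a bare string is rejected
```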

precision_recall_at_equal_thresholds method (to improve run time).

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/metrics/streaming_curve_points