PrecisionRecall

计算值以构建精确调用曲线。与ClassificationScore类似,此方法应用于真值向量。

bool vector::PrecisionRecall(
   const matrix&                 pred_scores,   // 包含每类概率分布的矩阵
   const ENUM_ENUM_AVERAGE_MODE  mode           // 平均模式
   matrix&                       precision,     // 计算每个阈值的精确值
   matrix&                       recall,        // 计算每个阈值的调用值
   matrix&                       thresholds,    // 以降序排列阈值
   );

参数

pred_scores

[in]  包含一组水平向量(含有每个类的概率)的矩阵。矩阵行数必须与真值向量的大小相对应。

模式

[in] ENUM_AVERAGE_MODE枚举的平均模式。只使用AVERAGE_NONE,AVERAGE_BINARY和AVERAGE_MICRO。

precision

[out]  包含计算精度曲线值的矩阵。如果不使用平均模式 (AVERAGE_NONE),则矩阵中的行数对应于模型类的数量。列数对应于真值向量的大小(或概率分布矩阵pred_score中的行数)。在微平均的情况下,矩阵中的行数对应于阈值的总数,不包括重复项。

recall

[out]  包含计算调用曲线值的矩阵。

threshold

[out]  对概率矩阵进行排序得到的阈值矩阵

 

注意

请参阅ClassificationScore方法的注释。

例如

从mnist.onnx模型收集统计数据的示例(99%准确率)。

//--- 分类指标数据
   vectorf y_true(images);
   vectorf y_pred(images);
   matrixf y_scores(images,10);

//--- 输入输出
   matrixf image(28,28);
   vectorf result(10);
 

//--- 测试
   for(int test=0; test<images; test++)
     {
      image=test_data[test].image;
      if(!OnnxRun(model,ONNX_DEFAULT,image,result))
        {
         Print("OnnxRun error ",GetLastError());
         break;
        }
      result.Activation(result,AF_SOFTMAX);

     //--- 收集数据
      y_true[test]=(float)test_data[test].label;
      y_pred[test]=(float)result.ArgMax();
      y_scores.Row(result,test);
     }    }

 
Accuracy calculation

   vectorf accuracy=y_pred.ClassificationMetric(y_true,CLASSIFICATION_ACCURACY);
   PrintFormat("accuracy=%f",accuracy[0]);
 
accuracy=0.989000

绘制精确调用图的示例,其中精确值绘制在y轴上,调用值绘制在x轴上。此外,精度图和调用图也分别绘制,阈值绘制在x轴上  

   if(y_true.PrecisionRecall(y_scores,AVERAGE_MICRO,mat_precision,mat_recall,mat_thres))
     {
      double precision[],recall[],thres[];
      ArrayResize(precision,mat_thres.Cols());
      ArrayResize(recall,mat_thres.Cols());
      ArrayResize(thres,mat_thres.Cols());
 
      for(uint i=0; i<thres.Size(); i++)
        {
         precision[i]=mat_precision[0][i];
         recall[i]=mat_recall[0][i];
         thres[i]=mat_thres[0][i];
        }
      thres[0]=thres[1]+0.001;
 
      PlotCurve("Precision-Recall curve (micro average)","p-r","",recall,precision);
      Plot2Curves("Precision-Recall (micro average)","precision","recall",thres,precision,recall);
     }

Resulting curves:

Precision-Recall curve

Precision-Recall graph