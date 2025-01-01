文档部分
计算分类指标来评估预测数据相对于真实数据的质量。

与“机器学习”部分中的其他方法不同，该方法是在真实值向量上调用的，而不是在预测值向量上调用；预测结果以概率矩阵pred_scores的形式作为参数传入。

// Overload taking an averaging mode; called on the vector of true values.
vector vector::ClassificationScore(
   const matrix&              pred_scores,   // matrix holding the probability distribution over the classes
   ENUM_CLASSIFICATION_METRIC metric,        // metric type
   ENUM_AVERAGE_MODE          mode           // averaging mode
   );
 
 
// Overload taking an integer parameter instead of an averaging mode; called on the vector of true values.
vector vector::ClassificationScore(
   const matrix&              pred_scores,   // matrix holding the probability distribution over the classes
   ENUM_CLASSIFICATION_METRIC metric,        // metric type
   int                        param          // additional parameter
   );

参数

pred_scores

[in] 由一组水平向量（行）组成的矩阵，每一行包含各个类别的概率。矩阵的行数应与真实值向量的大小一致。

metric

[in] ENUM_CLASSIFICATION_METRIC枚举的指标类型。该方法使用CLASSIFICATION_TOP_K_ACCURACY、CLASSIFICATION_AVERAGE_PRECISION和CLASSIFICATION_ROC_AUC值。

mode

[in] ENUM_AVERAGE_MODE枚举的平均模式。用于CLASSIFICATION_AVERAGE_PRECISION和CLASSIFICATION_ROC_AUC指标。

param

[in] 对于CLASSIFICATION_TOP_K_ACCURACY指标，应指定整数K值来代替平均模式。

 

返回值

包含计算指标的向量。在AVERAGE_NONE平均模式的情况下，向量包含每个类别的指标值，而不进行平均（例如，在二元分类的情况下，这将是分别用于'false'和'true'的两个指标）。

关于平均模式的说明

AVERAGE_BINARY只对二元分类有意义。

AVERAGE_MICRO ― 通过将标签指示矩阵的每个元素视为一个标签来计算全局指标。标签指示矩阵是指包含每个标签的一组概率的矩阵。

AVERAGE_MACRO ― 计算每个标签的指标并找到其未加权平均值。这没有考虑标签不平衡。

AVERAGE_WEIGHTED ― 计算每个标签的指标，并求其按支持度（每个标签的真实实例数）加权的平均值。

注意

在二元分类的情况下，既可以输入一个n x 2矩阵（第一列为负标签的概率，第二列为正标签的概率），也可以输入只含一列正标签概率的矩阵。这是因为二元分类模型既可能返回两个概率（负标签和正标签各一个），也可能只返回正标签的一个概率。

示例：

   vector y_true={7,2,1,0,4,1,4,9,5,9,0,6,9,0,1,5,9,7,3,4,8,4,2,7,6,8,4,2,3,6};
   //vector y_pred={7,2,1,0,4,1,4,9,5,9,0,6,9,0,1,5,9,7,3,4,2,9,4,9,5,9,2,7,7,0};
 
//--- label scores          0         1         2         3         4         5         6         7         8         9    true pred
   matrix y_scores={{0.0001090.0001860.0004490.0000520.0000020.0000220.0000050.9980590.0000100.001104},  // 7    7
                    {0.0000910.0819560.9168160.0011060.0000060.0000020.0000010.0000000.0000210.000000},  // 2    2
                    {0.0001080.9728630.0036000.0000210.0104790.0000150.0001310.0103850.0023390.000060},  // 1    1
                    {0.9254250.0000800.0029130.0000570.0002740.0006380.0635290.0003160.0000950.006673},  // 0    0
                    {0.0000600.0001260.0000060.0000000.9935130.0000000.0000030.0002220.0000010.006069},  // 4    4
                    {0.0000160.9821240.0000450.0000020.0084450.0000010.0000050.0092300.0001200.000013},  // 1    1
                    {0.0000000.0000400.0000010.0000000.9893950.0001670.0000040.0000700.0001770.010146},  // 4    4
                    {0.0007950.0029380.0234470.0074180.0218380.0024760.0002600.0475510.0000820.893194},  // 9    9
                    {0.0000910.0002260.0000380.0000070.0000480.8549100.0686440.0000800.0010970.074860},  // 5    5
                    {0.0000000.0000000.0000000.0000000.0030040.0000000.0000000.0000350.0000000.996960},  // 9    9
                    {0.9988560.0000090.0009760.0000020.0000000.0000130.0001310.0000060.0000000.000007},  // 0    0
                    {0.0001780.0004460.0003260.0000330.0001930.0000710.9984030.0000150.0003280.000007},  // 6    6
                    {0.0000050.0000160.0001530.0000450.0041100.0000120.0000150.0000310.0000760.995537},  // 9    9
                    {0.9941880.0000030.0025840.0000050.0000050.0001000.0007390.0014730.0000380.000864},  // 0    0
                    {0.0001730.9905690.0007920.0000400.0017980.0000350.0001140.0047500.0017160.000013},  // 1    1
                    {0.0000000.0005370.0000080.0050800.0000460.9929100.0000120.0006710.0003900.000347},  // 5    5
                    {0.0001270.0000030.0000030.0000000.0015830.0000000.0000020.0005550.0000160.997712},  // 9    9
                    {0.0000010.0000120.0000720.0000200.0000000.0000000.0000000.9998680.0000000.000026},  // 7    7
                    {0.0000200.0001050.0011390.9013430.0021320.0838730.0001240.0000970.0109810.000186},  // 3    3
                    {0.0000020.0000480.0000190.0000000.9993470.0000020.0000400.0000510.0000000.000489},  // 4    4
                    {0.0000590.0013440.6125020.0027490.0002290.0006780.0000380.0018440.3797270.000831},  // 8    2
                    {0.0005860.0007400.0016250.0000070.2693410.0000760.0164170.0001990.0001070.710902},  // 4    9
                    {0.0095470.0180550.2837950.0710790.4260740.0823350.0363790.0211880.0039240.047623},  // 2    4
                    {0.0025060.0025450.0011480.0056590.0204160.0001120.0060920.2725360.0031480.685839},  // 7    9
                    {0.0012630.0017690.0002930.0000110.0003020.8817680.1120190.0001250.0023270.000123},  // 6    5
                    {0.0029040.0029090.0134210.0014610.0075190.0012510.0005550.1062190.1071250.756637},  // 8    9
                    {0.0000550.0010800.8931580.0000000.1044920.0001590.0010420.0000130.0000000.000000},  // 4    2
                    {0.0003440.0026930.0711840.0002620.0000010.0000030.0000320.9243620.0007140.000404},  // 2    7
                    {0.0014040.0093750.0026380.2291890.0000640.0008960.0075160.7435570.0044620.000897},  // 3    7
                    {0.4911400.0001250.0000240.0003020.0000380.0349470.4731610.0001700.0000280.000066}}; // 6    0
 
   vector top_k=y_true.ClassificationScore(y_scores,CLASSIFICATION_TOP_K_ACCURACY,1);
   Print("top 1 accuracy score = ",top_k);
   top_k=y_true.ClassificationScore(y_scores,CLASSIFICATION_TOP_K_ACCURACY,2);
   Print("top 2 accuracy score = ",top_k);
   vector y_true2={0122};
   matrix y_score2={{0.50.20.2},  // 0 is in top 2
                    {0.30.40.2},  // 1 is in top 2
                    {0.20.40.3},  // 2 is in top 2
                    {0.70.20.1}}; // 2 isn't in top 2
   top_k=y_true2.ClassificationScore(y_score2,CLASSIFICATION_TOP_K_ACCURACY,2);
   Print("top k = ",top_k);
   Print("");
 
   vector ap_micro=y_true.ClassificationScore(y_scores,CLASSIFICATION_AVERAGE_PRECISION,AVERAGE_MICRO);
   Print("average precision score micro = ",ap_micro);
   vector ap_macro=y_true.ClassificationScore(y_scores,CLASSIFICATION_AVERAGE_PRECISION,AVERAGE_MACRO);
   Print("average precision score macro = ",ap_macro);
   vector ap_weighted=y_true.ClassificationScore(y_scores,CLASSIFICATION_AVERAGE_PRECISION,AVERAGE_WEIGHTED);
   Print("average precision score weighted = ",ap_weighted);
   vector ap_none=y_true.ClassificationScore(y_scores,CLASSIFICATION_AVERAGE_PRECISION,AVERAGE_NONE);
   Print("average precision score none = ",ap_none);
   Print("");
 
   vector area_micro=y_true.ClassificationScore(y_scores,CLASSIFICATION_ROC_AUC,AVERAGE_MICRO);
   Print("roc auc score micro = ",area_micro);
   vector area_macro=y_true.ClassificationScore(y_scores,CLASSIFICATION_ROC_AUC,AVERAGE_MACRO);
   Print("roc auc score macro = ",area_macro);
   vector area_weighted=y_true.ClassificationScore(y_scores,CLASSIFICATION_ROC_AUC,AVERAGE_WEIGHTED);
   Print("roc auc score weighted = ",area_weighted);
   vector area_none=y_true.ClassificationScore(y_scores,CLASSIFICATION_ROC_AUC,AVERAGE_NONE);
   Print("roc auc score none = ",area_none);
   Print("");
 
//--- 二元分类
   vector y_pred_bin={0,1,0,1,1,0,0,0,1};
   vector y_true_bin={1,0,0,0,1,0,1,1,1};
   vector y_score_true={0.3,0.7,0.1,0.6,0.9,0.0,0.4,0.2,0.8};
   matrix y_score1_bin(y_score_true.Size(),1);
   y_score1_bin.Col(y_score_true,0);
   matrix y_scores_bin={{0.70.3},
                        {0.30.7},
                        {0.90.1},
                        {0.40.6},
                        {0.10.9},
                        {1.00.0},
                        {0.60.4},
                        {0.80.2},
                        {0.20.8}};
 
   vector ap=y_true_bin.ClassificationScore(y_scores_bin,CLASSIFICATION_AVERAGE_PRECISION,AVERAGE_BINARY);
   Print("average precision score binary = ",ap);
   vector ap2=y_true_bin.ClassificationScore(y_score1_bin,CLASSIFICATION_AVERAGE_PRECISION,AVERAGE_BINARY);
   Print("average precision score binary = ",ap2);
   vector ap3=y_true_bin.ClassificationScore(y_scores_bin,CLASSIFICATION_AVERAGE_PRECISION,AVERAGE_NONE);
   Print("average precision score none = ",ap3);
   Print("");
 
   vector area=y_true_bin.ClassificationScore(y_scores_bin,CLASSIFICATION_ROC_AUC,AVERAGE_BINARY);
   Print("roc auc score binary = ",area);
   vector area2=y_true_bin.ClassificationScore(y_score1_bin,CLASSIFICATION_ROC_AUC,AVERAGE_BINARY);
   Print("roc auc score binary = ",area2);
   vector area3=y_true_bin.ClassificationScore(y_scores_bin,CLASSIFICATION_ROC_AUC,AVERAGE_NONE);
   Print("roc auc score none = ",area3);
 
 
/*
  top 1 accuracy score = [0.6666666666666666]
  top 2 accuracy score = [1]
  top k = [0.75]
  
  average precision score micro = [0.8513333333333333]
  average precision score macro = [0.9326666666666666]
  average precision score weighted = [0.9333333333333333]
  average precision score none = [1,1,0.7,1,0.9266666666666666,0.8333333333333333,1,0.8666666666666667,1,1]
  
  roc auc score micro = [0.9839506172839506]
  roc auc score macro = [0.9892068783068803]
  roc auc score weighted = [0.9887354497354497]
  roc auc score none = [1,1,0.9506172839506173,1,0.984,0.9821428571428571,1,0.9753086419753086,1,1]
  
  average precision score binary = [0.7961904761904761]
  average precision score binary = [0.7961904761904761]
  average precision score none = [0.7678571428571428,0.7961904761904761]
  
  roc auc score binary = [0.7]
  roc auc score binary = [0.7]
  roc auc score none = [0.7,0.7]
*/