Since sklearn.metrics.classification_report returns a string, we cannot sort the data directly. If we want to keep using this function, the options are to rewrite it or to parse its output string and reorder the data to our liking.
A function like this only exists to present data to the user in a readable way and rarely needs to be very efficient, so the second option is the simplest.
The structure of the output string is fixed; only the number of spaces and decimals varies with the input parameters. That makes it easy to do with str.split(). First we split the string into lines using \n as the separator, then we extract from each per-class row the value of the column we want to sort by. Based on that column we reorder the rows and rebuild the string with str.join().
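The split, sort and join steps can be tried without scikit-learn on a hard-coded report string. This is a minimal sketch, assuming the older report layout shown in the example output of this answer (header, blank line, one row per class, blank line, summary line); sort_report_rows is a hypothetical helper name, not part of any library.

```python
# Hard-coded report in the older classification_report layout (assumption:
# header, blank line, one row per class, blank line, summary line).
REPORT = (
    "             precision    recall  f1-score   support\n"
    "\n"
    "    class 0       0.50      1.00      0.67         1\n"
    "    class 1       0.00      0.00      0.00         1\n"
    "    class 2       1.00      0.67      0.80         3\n"
    "\n"
    "avg / total       0.70      0.60      0.61         5\n"
)

COLUMNS = ("precision", "recall", "f1-score", "support")


def sort_report_rows(report, column, reverse=False):
    """Hypothetical helper: reorder the per-class rows of a text report."""
    lines = report.split("\n")        # 1. split the string into lines
    end = lines.index("", 2)          # per-class rows end at the next blank line
    rows = lines[2:end]
    col = COLUMNS.index(column) - 4   # count from the right: names may contain spaces
    rows.sort(key=lambda row: float(row.split()[col]), reverse=reverse)  # 2. sort
    return "\n".join(lines[:2] + rows + lines[end:])                     # 3. join


print(sort_report_rows(REPORT, "f1-score", reverse=True))
```

The header and the summary line are copied through untouched; only the slice of per-class rows is reordered.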
One function to do this could be:
from sklearn.metrics import classification_report  # Don't forget this import

def sorted_classification_report(y_true, y_pred, labels=None, target_names=None,
                                 sample_weight=None, digits=2, sorted_by=None,
                                 reverse=False):
    # Pass the options by keyword: recent scikit-learn versions only accept them as keywords
    lines = classification_report(y_true, y_pred, labels=labels,
                                  target_names=target_names,
                                  sample_weight=sample_weight,
                                  digits=digits).split('\n')
    # Per-class rows start at line 2 (after the header and a blank line)
    # and end at the next blank line, which precedes the summary lines
    end = lines.index('', 2)
    datos = lines[2:end]
    if sorted_by:
        columns = ('precision', 'recall', 'f1-score', 'support')
        if sorted_by not in columns:
            raise ValueError("invalid value for 'sorted_by' argument: '{}'.".format(sorted_by))
        # Negative index: count columns from the right, since class names may contain spaces
        col = columns.index(sorted_by) - 4
        datos.sort(key=lambda row: float(row.split()[col]), reverse=reverse)
    return '\n'.join(lines[:2] + datos + lines[end:])
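The "- 4" offset in the column lookup is what lets class names containing spaces, such as George W Bush, work: after splitting on whitespace, the four numeric columns are always the last four tokens, whatever the number of words in the name. A small illustration, using one row of the second example output in this answer as sample data:

```python
COLUMNS = ("precision", "recall", "f1-score", "support")

# A row whose class name contains spaces: splitting on whitespace yields
# a variable number of name tokens followed by the four numbers.
row = "    George W Bush       0.84      0.98      0.90       146"
tokens = row.split()

for name in COLUMNS:
    col = COLUMNS.index(name) - 4   # -4, -3, -2, -1
    print(name, tokens[col])

# A positive index would depend on how many words the name has:
print(tokens[1])  # 'W', not a number
```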
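The function takes the same parameters as classification_report plus two more:
- sorted_by: the column to sort the rows by: 'precision', 'recall', 'f1-score' or 'support'. With the default None the rows keep their original order; any other value raises ValueError.
- reverse: if True, sort from highest to lowest; the default False sorts from lowest to highest.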
Example usage, sorting by f1-score from highest to lowest:
y_true = [0, 1, 2, 2, 2]
y_pred = [0, 0, 2, 2, 1]
target_names = ['class 0', 'class 1', 'class 2']
print(sorted_classification_report(y_true, y_pred, target_names=target_names, digits=2, sorted_by='f1-score', reverse=True))
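Output (produced by an older scikit-learn; recent versions print accuracy, macro avg and weighted avg rows instead of the avg / total line):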
             precision    recall  f1-score   support

    class 2       1.00      0.67      0.80         3
    class 0       0.50      1.00      0.67         1
    class 1       0.00      0.00      0.00         1

avg / total       0.70      0.60      0.61         5
Edit (see comments):
Applying the function to the documentation example you are working on, replacing classification_report with our function:
print(sorted_classification_report(y_test, y_pred, target_names=target_names, sorted_by='support', reverse=True))
we get the following output (sorted from highest to lowest by support):
                   precision    recall  f1-score   support

    George W Bush       0.84      0.98      0.90       146
     Colin Powell       0.78      0.87      0.82        60
       Tony Blair       0.96      0.75      0.84        36
  Donald Rumsfeld       0.81      0.63      0.71        27
Gerhard Schroeder       0.95      0.76      0.84        25
      Hugo Chavez       1.00      0.47      0.64        15
     Ariel Sharon       0.75      0.46      0.57        13

      avg / total       0.85      0.84      0.83       322