# Source code for museotoolbox.charts

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""The :mod:`museotoolbox.charts` module gathers plotting functions."""

from matplotlib import pyplot as plt
from matplotlib import gridspec
import numpy as np
import itertools

# Import-time side effect: tell NumPy to ignore divide-by-zero and
# invalid-operation floating-point errors instead of warning about them.
# NOTE(review): presumably needed because the F1/accuracy ratios computed in
# this module can divide by a zero row/column sum -- confirm before removing.
np.seterr(divide='ignore', invalid='ignore')

# for numpy version < 1.17

def _nan_to_num(array, nan=0):
    """Return a copy of ``array`` in which every NaN is replaced by ``nan``.

    Local helper kept for NumPy versions older than 1.17. Only NaN entries
    are substituted; every other value is passed through unchanged.

    Parameters
    ----------
    array : array-like or scalar
        Values to clean.
    nan : scalar, default 0.
        Replacement written wherever ``array`` is NaN.

    Returns
    -------
    numpy.ndarray
        Result of ``np.where`` (0-d when ``array`` is a scalar).
    """
    nan_positions = np.isnan(array)
    return np.where(nan_positions, nan, array)

class PlotConfusionMatrix:
    """
    Plot a confusion matrix with imshow of pyplot.

    Customize color (e.g. diagonal color), add subplots with F1 or
    Producer/User accuracy.

    Parameters
    ----------
    cm : array-like
        The confusion matrix (2 dimensions).
    cmap : pyplot colormap, default plt.cm.Greens.
        Colormap of the matrix image. Also used for the side subplots until
        :meth:`color_diagonal` replaces it.
    left : float or None, default None.
        Forwarded to ``GridSpec.update(left=...)``.
    right : float or None, default None.
        Forwarded to ``GridSpec.update(right=...)``.
    zero_is_min : bool, default True.
        If True the colormap starts at 0, else at the minimum of ``cm``.
    max_is_max : True or number, default True.
        If True the colormap ends at the maximum of ``cm``, else the given
        value is used as the upper bound.
    **kwargs
        Extra keyword arguments passed to ``imshow`` for the matrix image.

    Examples
    --------
    >>> plot = mtb.charts.PlotConfusionMatrix([[5,6],[1,8]])
    >>> plot.add_text()
    >>> plot.add_f1()
    """

    def __init__(
            self,
            cm,
            cmap=plt.cm.Greens,
            left=None,
            right=None,
            zero_is_min=True,
            max_is_max=True,
            **kwargs):
        # ``cm`` may later be replaced by a masked array (see color_diagonal);
        # ``cm_`` always keeps the plain, unmasked values.
        self.cm = np.array(cm)
        self.cm_ = np.copy(cm)
        self.axes = []

        # init gridspec
        self._left_grisdspec = left
        self._right_grisdspec = right
        self._init_gridspec()

        self.ax = plt.subplot(self.gs[0, 0])  # place it where it should be.

        self.zero_is_min = zero_is_min
        if zero_is_min is True:
            self.vmin = 0
        else:
            self.vmin = np.amin(self.cm)
        if max_is_max is True:
            self.vmax = np.amax(self.cm)
        else:
            self.vmax = max_is_max

        self.xlabelsPos = 'bottom'
        self.xrotation = 0
        self.yrotation = 0
        # False means "no font size chosen yet" (add_text sets it).
        self.font_size = False

        self.cmap = cmap
        self.diag_color = cmap
        self.ax.set_yticks(range(self.cm.shape[0]))

        self.fig = plt.figure(1)
        self.ax.imshow(
            cm,
            interpolation='nearest',
            aspect='equal',
            cmap=self.cmap,
            vmin=self.vmin,
            vmax=self.vmax,
            **kwargs)

        self.kwargs = kwargs
        # False, 'F1', 'Mean' or 'accuracy': what occupies the first side
        # column of the grid.
        self.subplot_ax1v = False
        self.axes.append(self.ax)

    def _init_gridspec(self):
        """Create the 2x3 grid: matrix plus two side columns and a bottom row."""
        n_rows, n_cols = self.cm.shape[0], self.cm.shape[1]
        self.gs = gridspec.GridSpec(
            2, 3, width_ratios=[n_cols, 1, 1], height_ratios=[n_rows, 1])
        self.gs.update(
            bottom=0,
            top=1,
            wspace=0,
            hspace=0.7 / n_rows,
            right=self._right_grisdspec,
            left=self._left_grisdspec)

    def _effective_font_size(self):
        """Return the font size chosen via add_text, or 12 when unset."""
        if self.font_size is not False:
            return self.font_size
        return 12

    def _require_square(self, message):
        """Raise ``Warning(message)`` when the matrix is not square."""
        if self.cm.shape[0] != self.cm.shape[1]:
            raise Warning(message)

    def add_label(self, x_label=False, y_label=False, x_position='top'):
        """
        Add an axis title on X and/or Y of the matrix.

        Parameters
        ----------
        x_label : False or str.
            Title of the X axis. Left untouched when False.
        y_label : False or str.
            Title of the Y axis. Left untouched when False.
        x_position : str, default 'top'.
            Currently unused; kept for backward compatibility.

        Examples
        --------
        >>> plot.add_label(x_label='Predicted', y_label='Observed')
        """
        titles = {}
        # Skip labels left at their False default so the text "False" is
        # never drawn on the axes.
        if x_label is not False:
            titles['xlabel'] = x_label
        if y_label is not False:
            titles['ylabel'] = y_label
        if titles:
            self.ax.set(**titles)

    def add_text(self, thresold=False, font_size=12, alpha=1, alpha_zero=1):
        """
        Add value of each case on the matrix image.

        Parameters
        ----------
        thresold : False or integer.
            Values above it are written in white, others in black.
            If False, half of the matrix maximum is used.
        font_size : int, default 12.
            Also stored and reused by the other ``add_*`` methods.
        alpha : float, default 1.
            Value from 0 to 1.
        alpha_zero : float, default 1.
            Value alpha for 0 values, from 0 to 1.

        Examples
        --------
        >>> plot.add_text(alpha_zero=0.5)
        """
        plt.rcParams.update({'font.size': font_size})
        self.font_size = font_size
        if thresold is False:
            thresold = int(np.amax(self.cm) / 2)

        n_rows, n_cols = self.cm.shape[0], self.cm.shape[1]
        for i, j in itertools.product(range(n_rows), range(n_cols)):
            # Read from the unmasked copy so cells masked by color_diagonal
            # still show their real value.
            cm_value = self.cm_[i, j]
            if isinstance(cm_value, (int, np.integer)):
                txt_displayed = str(cm_value)
            else:
                txt_displayed = '{:.1f}'.format(cm_value)
            text_color = "white" if cm_value > thresold else "black"

            if not np.ma.is_masked(self.cm[i, j]):
                self.ax.text(
                    j,
                    i,
                    txt_displayed,
                    horizontalalignment="center",
                    color=text_color,
                    fontsize=font_size,
                    va='center',
                    alpha=alpha_zero if cm_value == 0 else alpha)
            else:
                # Masked cell (the diagonal after color_diagonal): always
                # drawn fully opaque.
                self.ax.text(
                    j,
                    i,
                    txt_displayed,
                    horizontalalignment="center",
                    color=text_color,
                    va='center',
                    fontsize=font_size)

    def add_x_labels(self, labels=None, rotation=90, position='top'):
        """
        Add labels for X.

        Parameters
        ----------
        labels : None
            If labels, best with same len as the X shape.
        rotation : int, default 90.
            Int, 45 or 90 is best.
        position : str, default 'top'.
            'top' or 'bottom'.

        Examples
        --------
        >>> plot.add_x_labels(labels=['Tofu','Houmous'],rotation=45)
        """
        self.xrotation = rotation
        self.xlabels = labels
        self.xlabelsPos = position

        if self.xlabelsPos == 'top':
            self.ax.xaxis.tick_top()
            self.ax.xaxis.set_ticks_position('top')

        self.ax.set_xticks(np.arange(self.cm.shape[1]))
        ha = 'left' if rotation != 90 else 'center'
        if self.xlabels is not None:
            self.ax.set_xticklabels(
                self.xlabels,
                rotation=rotation,
                ha=ha,
                fontsize=self._effective_font_size())

    def add_mean(
            self,
            xLabel='',
            yLabel='',
            hide_ticks=False,
            thresold=50,
            vmin=0,
            vmax=100):
        """
        Add Mean for both axis.

        Parameters
        ----------
        xLabel : str
            The label for X (i.e. 'All species')
        yLabel : str
            The label for Y (i.e. 'All years')
        hide_ticks : bool, default False.
            If True, no Y ticks are drawn on the vertical mean subplot.
        thresold : int, default 50.
            Means above it are written in white, others in black.
        vmin : int.
            Minimum value for colormap.
        vmax : int.
            Maximum value for colormap.

        Examples
        --------
        >>> plot.add_mean(xLabel='all species',yLabel='all years')
        """
        font_size = self._effective_font_size()

        if self.subplot_ax1v is not False:
            self._init_gridspec()
        if self.subplot_ax1v == 'F1':
            # F1 currently sits in the first side column: redraw it in the
            # second one so the mean can take the first.
            self.subplot_ax1v = 'Mean'
            self.add_f1()
        self.subplot_ax1v = 'Mean'

        self.ax1v = plt.subplot(self.gs[0, 1])
        self.ax1h = plt.subplot(self.gs[1, 0])

        row_means = np.mean(self.cm_, axis=1)
        col_means = np.mean(self.cm_, axis=0)
        valV = row_means.reshape(-1, 1).astype(int)
        valH = col_means.reshape(1, -1).astype(int)

        self.ax1v.imshow(
            valV,
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=vmin,
            vmax=vmax)
        self.ax1h.imshow(
            valH,
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=vmin,
            vmax=vmax)

        if hide_ticks:
            self.ax1v.set_yticks([])
        else:
            self.ax1v.set_yticks(np.arange(self.cm_.shape[0]))
        self.ax1v.set_xticks([])
        self.ax1h.set_yticks([0])
        self.ax1h.set_xticks([])

        for i in range(self.cm.shape[0]):
            iVal = int(row_means[i])
            self.ax1v.text(
                0,
                i,
                iVal,
                color="white" if iVal > thresold else 'black',
                ha='center',
                va='center',
                fontsize=font_size)
        self.ax1v.set_yticklabels([])

        for j in range(self.cm.shape[1]):
            jVal = int(col_means[j])
            self.ax1h.text(
                j,
                0,
                jVal,
                color="white" if jVal > thresold else 'black',
                ha='center',
                va='center',
                fontsize=font_size)

        self.ax1h.set_yticklabels(
            [yLabel],
            rotation=self.yrotation,
            ha='right',
            va='center',
            fontsize=font_size)

        self.ax1v.xaxis.set_ticks_position('top')
        self.ax1v.set_xticks([0])
        ha = 'left' if self.xrotation < 60 else 'center'
        # Only ``ha`` is passed: giving both ``ha`` and its alias
        # ``horizontalalignment`` is rejected by matplotlib.
        self.ax1v.set_xticklabels(
            [xLabel],
            rotation=self.xrotation,
            ha=ha,
            fontsize=font_size)

        self.axes.append([self.ax1v, self.ax1h])

    def add_y_labels(self, labels=None, rotation=0):
        """
        Add labels for Y.

        Parameters
        ----------
        labels : None
            If labels, best with same len as the Y shape.
        rotation : int, default 0.
            Rotation of the labels, in degrees.

        Examples
        --------
        >>> plot.add_y_labels(labels=['Fried','Raw'])
        """
        self.yrotation = rotation
        self.ylabels = labels
        if self.ylabels is not None:
            self.ax.set_yticklabels(
                self.ylabels,
                rotation=rotation,
                horizontalalignment='right',
                fontsize=self._effective_font_size())

    def add_f1(self):
        """
        Add F1 subplot.

        The F1 column is drawn in the first side column, or in the second
        one when mean or accuracy already occupies the first.

        Raises
        ------
        Warning
            If the matrix is not square.

        Examples
        --------
        >>> plot.add_f1()
        """
        self._require_square('Number of lines and columns must be equal')

        if self.subplot_ax1v is False or self.subplot_ax1v == 'F1':
            self.ax1v = plt.subplot(self.gs[0, 1])
            current_ax = self.ax1v
            self.subplot_ax1v = 'F1'
        else:
            self.ax2v = plt.subplot(self.gs[0, 2])
            current_ax = self.ax2v

        n_classes = self.cm.shape[0]
        scores = []
        for label in range(n_classes):
            TP = self.cm_[label, label]
            FN = np.nansum(self.cm_[:, label]) - TP
            FP = np.nansum(self.cm_[label, :]) - TP
            # 0/0 yields NaN here (floating-point warnings are silenced at
            # module level); NaN is displayed as 0 below.
            scores.append(2 * TP / (2 * TP + FP + FN) * 100)

        font_size = self._effective_font_size()
        verticalPlot = np.asarray(scores).reshape(-1, 1)

        current_ax.imshow(
            verticalPlot,
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=0,
            vmax=100)

        if self.xlabelsPos == 'top':
            current_ax.xaxis.tick_top()
            current_ax.xaxis.set_ticks_position('top')
            current_ax.set_xticks([0])
            current_ax.set_xticklabels(
                ['F1'],
                horizontalalignment='center',
                rotation=self.xrotation,
                size=font_size)
        else:
            current_ax.set_xticks([0])
            current_ax.set_xticklabels(
                ['F1'],
                horizontalalignment='left',
                rotation=self.xrotation,
                size=font_size)
        current_ax.set_yticks([])

        for i in range(n_classes):
            # Index the scalar explicitly: int() and truth tests on
            # 1-element arrays are deprecated in recent NumPy.
            score = verticalPlot[i, 0]
            current_ax.text(
                0,
                i,
                str(int(_nan_to_num(score))),
                size=font_size,
                horizontalalignment="center",
                color="white" if score > 50 else "black",
                va='center')

        self.axes.append(current_ax)

    def add_accuracy(
            self,
            thresold=50,
            invert_PA_UA=False,
            user_acc_label='User\'s acc.',
            prod_acc_label='Prod\'s acc.'):
        """
        Add user and producer accuracy.

        Parameters
        ----------
        thresold : int, default 50
            The thresold value where text will be in white instead of black.
        invert_PA_UA : bool, default False
            If True, user and producer accuracy labels are switched
            (note that it does not reverse the confusion matrix though).
        user_acc_label : str
            The user accuracy label to display. Default to 'User's acc.'
        prod_acc_label : str
            The producer accuracy label to display. Default to 'Prod's acc.'

        Raises
        ------
        Warning
            If the matrix is not square.

        Examples
        --------
        >>> plot.add_accuracy()
        """
        font_size = self._effective_font_size()

        if self.subplot_ax1v is not False:
            self._init_gridspec()
        if self.subplot_ax1v == 'F1':
            # F1 currently sits in the first side column: redraw it in the
            # second one so the accuracy can take the first.
            self.subplot_ax1v = 'accuracy'
            self.add_f1()
        self.subplot_ax1v = 'accuracy'

        if self.cm_.shape[0] != self.cm_.shape[1]:
            raise Warning('Number of lines and columns must be equal')

        self.ax1v = plt.subplot(self.gs[0, 1])
        self.ax1h = plt.subplot(self.gs[1, 0])

        # Diagonal divided by the row sums (vertical subplot) and by the
        # column sums (horizontal subplot), in percent.
        diagonal = np.diag(self.cm_)
        by_row = np.array(diagonal / np.nansum(self.cm_, axis=1) * 100)
        by_col = np.array(diagonal / np.nansum(self.cm_, axis=0) * 100)

        self.ax1v.imshow(
            by_row.reshape(-1, 1),
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=0,
            vmax=100)
        self.ax1h.imshow(
            by_col.reshape(1, -1),
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=0,
            vmax=100)

        self.ax1v.set_yticks(np.arange(self.cm_.shape[0]))
        self.ax1v.set_xticks([])
        self.ax1h.set_yticks([0])
        self.ax1h.set_xticks([])

        for i in range(self.cm.shape[0]):
            iVal = int(_nan_to_num(by_row[i], nan=0))
            self.ax1v.text(
                0,
                i,
                iVal,
                color="white" if iVal > thresold else 'black',
                size=font_size,
                ha='center',
                va='center')
        self.ax1v.set_yticklabels([])

        for j in range(self.cm.shape[1]):
            jVal = int(_nan_to_num(by_col[j], nan=0))
            self.ax1h.text(
                j,
                0,
                jVal,
                color="white" if jVal > thresold else 'black',
                size=font_size,
                ha='center',
                va='center')

        y_label, x_label = [prod_acc_label], [user_acc_label]
        if invert_PA_UA:
            x_label, y_label = y_label, x_label

        self.ax1h.set_yticklabels(
            y_label,
            rotation=self.yrotation,
            ha='right',
            va='center',
            size=font_size)

        self.ax1v.xaxis.set_ticks_position('top')
        self.ax1v.set_xticks([0])
        ha = 'left' if self.xrotation < 60 else 'center'
        # Only ``ha`` is passed: giving both ``ha`` and its alias
        # ``horizontalalignment`` is rejected by matplotlib.
        self.ax1v.set_xticklabels(
            x_label,
            rotation=self.xrotation,
            ha=ha,
            size=font_size)

        self.axes.append([self.ax1v, self.ax1h])

    def color_diagonal(
            self,
            diag_color=plt.cm.Greens,
            matrix_color=plt.cm.Reds):
        """
        Draw the diagonal and the rest of the matrix with two colormaps.

        Parameters
        ----------
        diag_color : pyplot colormap, default plt.cm.Greens.
            Colormap of the diagonal. Also reused by the side subplots.
        matrix_color : pyplot colormap, default plt.cm.Reds.
            Colormap of the off-diagonal cells.

        Raises
        ------
        Warning
            If the matrix is not square.

        Examples
        --------
        >>> plot.color_diagonal()
        """
        self.diag_color = diag_color
        self._require_square(
            'Array must have the same number of lines and columns')

        on_diagonal = np.eye(self.cm.shape[0], dtype=bool)
        # Both layers are built from the unmasked copy so calling this
        # method again does not stack masks on top of each other.
        self.cm2 = np.ma.masked_array(
            self.cm_, mask=np.logical_not(on_diagonal))
        self.cm = np.ma.masked_array(self.cm_, mask=on_diagonal)

        if self.zero_is_min is True:
            vmin = 0
        else:
            vmin = np.amin(self.cm_)
        vmax = np.amax(self.cm_)

        self.ax.imshow(
            self.cm2,
            interpolation='nearest',
            aspect='equal',
            cmap=diag_color,
            vmin=vmin,
            vmax=vmax,
            alpha=1)
        self.ax.imshow(
            self.cm,
            interpolation='nearest',
            aspect='equal',
            cmap=matrix_color,
            vmin=vmin,
            vmax=vmax,
            alpha=1)

    def save_to(self, path, dpi=150):
        """
        Save the plot

        Parameters
        ----------
        path : str
            The path of the file to save.
        dpi : int, default 150.

        Examples
        --------
        >>> plot.save_to('/tmp/contofu.pdf',dpi=300)
        """
        self.fig.savefig(path, dpi=dpi, bbox_inches='tight')