# Source code for museotoolbox.charts

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# =============================================================================
# ___  ___                       _____           _______
# |  \/  |                      |_   _|         | | ___ \
# | .  . |_   _ ___  ___  ___     | | ___   ___ | | |_/ / _____  __
# | |\/| | | | / __|/ _ \/ _ \    | |/ _ \ / _ \| | ___ \/ _ \ \/ /
# | |  | | |_| \__ \  __/ (_) |   | | (_) | (_) | | |_/ / (_) >  <
# \_|  |_/\__,_|___/\___|\___/    \_/\___/ \___/|_\____/ \___/_/\_\
#
# @author:  Nicolas Karasiak
# @site:    www.karasiak.net
# @git:     www.github.com/nkarasiak/MuseoToolBox
# =============================================================================
"""
The :mod:`museotoolbox.charts` module gathers plotting functions.
"""

from matplotlib import pyplot as plt
from matplotlib import gridspec
import numpy as np
import itertools

# Silence NumPy's "divide by zero" and "invalid value" floating-point
# warnings: the ratios computed in this module (F1, accuracies) may divide
# by zero and are expected to yield inf/NaN quietly instead of warning.
# NOTE(review): np.seterr changes NumPy's error-handling state for code
# outside this module too — confirm this import-time side effect is wanted.
np.seterr(divide='ignore', invalid='ignore')

# for numpy version < 1.17


def _nan_to_num(array, nan=0):
    """
    Replace NaN entries of ``array`` by ``nan``.

    Minimal stand-in for ``np.nan_to_num(array, nan=...)``, whose ``nan``
    keyword only exists from NumPy 1.17 onwards.

    Parameters
    ----------
    array : array-like or scalar
        Numeric values to clean.
    nan : scalar, default 0.
        Value written wherever ``array`` is NaN.

    Returns
    -------
    numpy.ndarray
        Same shape as ``array`` (0-d for a scalar input). Only NaN is
        replaced; infinite values are left untouched.
    """
    nan_mask = np.isnan(array)
    return np.where(nan_mask, nan, array)


class PlotConfusionMatrix:
    """
    Plot a confusion matrix with imshow of pyplot.

    Customize color (e.g. diagonal color), add subplots with F1 or
    Producer/User accuracy.

    Parameters
    ----------
    cm : array-like
        The 2-D confusion matrix to display.
    cmap : pyplot colormap, default plt.cm.Greens.
        Colormap of the matrix (also used for the side subplots until
        ``color_diagonal`` sets another one).
    left : float or None, default None.
        Forwarded as ``left`` to ``GridSpec.update``.
    right : float or None, default None.
        Forwarded as ``right`` to ``GridSpec.update``.
    zero_is_min : bool, default True.
        If True the colormap starts at 0, else at the minimum of ``cm``.
    max_is_max : True or number, default True.
        If True the colormap ends at the maximum of ``cm``, else the given
        value is used as the upper bound of the colormap.
    **kwargs
        Extra keyword arguments forwarded to ``imshow``.

    Examples
    --------
    >>> plot = mtb.charts.PlotConfusionMatrix([[5,6],[1,8]])
    >>> plot.add_text()
    >>> plot.add_f1()
    """

    # Font size used by labels as long as ``add_text`` has not chosen one
    # (``self.font_size`` is False until then).
    _DEFAULT_FONT_SIZE = 12

    def __init__(
            self,
            cm,
            cmap=plt.cm.Greens,
            left=None,
            right=None,
            zero_is_min=True,
            max_is_max=True,
            **kwargs):
        # self.cm may later be replaced by a masked array (color_diagonal);
        # self.cm_ always keeps the untouched values.
        self.cm = np.array(cm)
        self.cm_ = np.copy(cm)
        self.axes = []

        # init gridspec
        self._left_grisdspec = left
        self._right_grisdspec = right
        self._init_gridspec()

        self.ax = plt.subplot(self.gs[0, 0])  # place it where it should be.

        self.zero_is_min = zero_is_min
        if zero_is_min is True:
            self.vmin = 0
        else:
            self.vmin = np.amin(self.cm)

        if max_is_max is True:
            self.vmax = np.amax(self.cm)
        else:
            self.vmax = max_is_max

        self.xlabelsPos = 'bottom'
        self.xrotation = 0
        self.yrotation = 0
        self.font_size = False

        self.cmap = cmap
        self.diag_color = cmap
        self.ax.set_yticks(range(self.cm.shape[0]))

        # Keep the figure that really holds the main axes: a hard-coded
        # plt.figure(1) may be another figure than the one plt.subplot drew
        # into, and save_to() would then save the wrong figure.
        self.fig = self.ax.figure

        self.ax.imshow(
            self.cm,
            interpolation='nearest',
            aspect='equal',
            cmap=self.cmap,
            vmin=self.vmin,
            vmax=self.vmax,
            **kwargs)

        self.kwargs = kwargs
        # False, or the name of what occupies the first side column
        # ('F1', 'Mean' or 'accuracy').
        self.subplot_ax1v = False
        self.axes.append(self.ax)

    def _init_gridspec(self):
        """
        (Re)build the 2x3 grid: matrix in [0, 0], two one-cell-wide side
        columns and a one-cell-high bottom row.
        """
        n_rows, n_cols = self.cm.shape[0], self.cm.shape[1]
        self.gs = gridspec.GridSpec(
            2, 3,
            width_ratios=[n_cols, 1, 1],
            height_ratios=[n_rows, 1])
        self.gs.update(
            bottom=0,
            top=1,
            wspace=0,
            hspace=0.7 / n_rows,
            right=self._right_grisdspec,
            left=self._left_grisdspec)

    def _get_font_size(self):
        """
        Return the font size chosen by ``add_text``, or the default one
        when ``add_text`` has not been called yet.
        """
        if self.font_size is False:
            return self._DEFAULT_FONT_SIZE
        return self.font_size

    def add_label(self, x_label=False, y_label=False, x_position='top'):
        """
        Add a title to the X and/or Y axis of the matrix.

        Parameters
        ----------
        x_label : False or str, default False.
            Label of the X axis. Left untouched if False.
        y_label : False or str, default False.
            Label of the Y axis. Left untouched if False.
        x_position : str, default 'top'.
            'top' or 'bottom'.

        Examples
        --------
        >>> plot.add_label(x_label='Predicted', y_label='Observed')
        """
        # Passing the False default to matplotlib would literally print
        # "False" as the axis label, so only set what was given.
        if x_label is not False:
            self.ax.set_xlabel(x_label)
        if y_label is not False:
            self.ax.set_ylabel(y_label)
        self.ax.xaxis.set_label_position(x_position)

    def add_text(self, thresold=False, font_size=12, alpha=1, alpha_zero=1):
        """
        Add value of each case on the matrix image.

        Parameters
        ----------
        thresold : False or integer.
            Values above it are written in white, the others in black.
            If False, half of the maximum of the matrix is used.
        font_size : int, default 12.
            Also becomes the global pyplot font size and the size used by
            the other ``add_*`` methods.
        alpha : float, default 1.
            Value from 0 to 1.
        alpha_zero : float, default 1.
            Value alpha for 0 values, from 0 to 1.

        Examples
        --------
        >>> plot.add_text(alpha_zero=0.5)
        """
        plt.rcParams.update({'font.size': font_size})
        self.font_size = font_size
        if thresold is False:
            thresold = int(np.amax(self.cm) / 2)

        n_rows, n_cols = self.cm.shape[0], self.cm.shape[1]
        for i, j in itertools.product(range(n_rows), range(n_cols)):
            # Read the value from the pristine copy: after color_diagonal()
            # self.cm is masked on the diagonal and would display '--'.
            cm_value = self.cm_[i, j]
            if isinstance(cm_value, (int, np.integer)):
                txt_displayed = str(cm_value)
            else:
                txt_displayed = '{:.1f}'.format(cm_value)

            text_kwargs = dict(
                horizontalalignment='center',
                va='center',
                color='white' if cm_value > thresold else 'black',
                fontsize=font_size)
            # Transparency is only applied to the cells that are not masked
            # (i.e. not to the diagonal once color_diagonal() was called).
            if not np.ma.is_masked(self.cm[i, j]):
                text_kwargs['alpha'] = alpha_zero if cm_value == 0 else alpha
            self.ax.text(j, i, txt_displayed, **text_kwargs)

    def add_x_labels(self, labels=None, rotation=90, position='top'):
        """
        Add labels for X.

        Parameters
        ----------
        labels : None or list
            If labels, best with same len as the X shape.
            If None, only the tick positions are set.
        rotation : int, default 90.
            Int, 45 or 90 is best.
        position : str, default 'top'.
            'top' or 'bottom'.

        Examples
        --------
        >>> plot.add_x_labels(labels=['Tofu','Houmous'],rotation=45)
        """
        self.xrotation = rotation
        self.xlabels = labels
        self.xlabelsPos = position

        if position == 'top':
            self.ax.xaxis.tick_top()
            self.ax.xaxis.set_ticks_position('top')

        # Fix the tick locations first, so that labels match the columns.
        self.ax.set_xticks(np.arange(self.cm.shape[1]))
        ha = 'center' if rotation == 90 else 'left'
        if labels is not None:
            self.ax.set_xticklabels(
                labels,
                rotation=rotation,
                ha=ha,
                fontsize=self._get_font_size())

    def add_mean(self, xLabel='', yLabel='', hide_ticks=False,
                 thresold=50, vmin=0, vmax=100):
        """
        Add Mean for both axis.

        Parameters
        ----------
        xLabel : str
            The label for X (i.e. 'All species')
        yLabel : str
            The label for Y (i.e. 'All years')
        hide_ticks : bool, default False.
            If True, no Y tick is drawn on the vertical mean subplot.
        thresold : int, default 50.
            Means above it are written in white, the others in black.
        vmin : int.
            Minimum value for colormap.
        vmax : int.
            Maximum value for colormap.

        Examples
        --------
        >>> plot.add_mean(xLabel='all species',yLabel='all years')
        """
        if self.subplot_ax1v is not False:
            self._init_gridspec()
            if self.subplot_ax1v == 'F1':
                # The first side column is taken over by the mean, so F1 is
                # drawn again, this time in the second side column.
                self.subplot_ax1v = 'Mean'
                self.add_f1()
        self.subplot_ax1v = 'Mean'

        font_size = self._get_font_size()
        self.ax1v = plt.subplot(self.gs[0, 1])
        self.ax1h = plt.subplot(self.gs[1, 0])

        # Means are computed once, truncated to integers for display.
        row_means = np.mean(self.cm_, axis=1)
        col_means = np.mean(self.cm_, axis=0)
        valV = row_means.reshape(-1, 1).astype(int)
        valH = col_means.reshape(1, -1).astype(int)

        image_kwargs = dict(
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=vmin,
            vmax=vmax)
        self.ax1v.imshow(valV, **image_kwargs)
        self.ax1h.imshow(valH, **image_kwargs)

        if hide_ticks:
            self.ax1v.set_yticks([])
        else:
            self.ax1v.set_yticks(np.arange(self.cm_.shape[0]))
        self.ax1v.set_xticks([])
        self.ax1h.set_yticks([0])
        self.ax1h.set_xticks([])

        for i in range(self.cm.shape[0]):
            # int() instead of np.int, which was removed from NumPy.
            iVal = int(row_means[i])
            self.ax1v.text(
                0, i, str(iVal),
                color='white' if iVal > thresold else 'black',
                ha='center', va='center', fontsize=font_size)
        self.ax1v.set_yticklabels([])

        for j in range(self.cm.shape[1]):
            jVal = int(col_means[j])
            self.ax1h.text(
                j, 0, str(jVal),
                color='white' if jVal > thresold else 'black',
                ha='center', va='center', fontsize=font_size)

        self.ax1h.set_yticklabels(
            [yLabel],
            rotation=self.yrotation,
            ha='right',
            va='center',
            fontsize=font_size)

        self.ax1v.xaxis.set_ticks_position('top')
        self.ax1v.set_xticks([0])
        ha = 'left' if self.xrotation < 60 else 'center'
        # Only `ha` is given: matplotlib refuses to receive both `ha` and
        # its alias `horizontalalignment`.
        self.ax1v.set_xticklabels(
            [xLabel],
            rotation=self.xrotation,
            ha=ha,
            fontsize=font_size)

        self.axes.append([self.ax1v, self.ax1h])

    def add_y_labels(self, labels=None, rotation=0):
        """
        Add labels for Y.

        Parameters
        ----------
        labels : None or list
            If labels, best with same len as the Y shape.
            If None, the current tick labels are left untouched.
        rotation : int, default 0.
            Rotation of the labels, in degrees.

        Examples
        --------
        >>> plot.add_y_labels(labels=['Fried','Raw'])
        """
        self.yrotation = rotation
        self.ylabels = labels
        if labels is not None:
            self.ax.set_yticklabels(
                labels,
                rotation=rotation,
                horizontalalignment='right',
                fontsize=self._get_font_size())

    def add_f1(self):
        """
        Add F1 subplot.

        The subplot is placed in the first side column, or in the second
        one if the first already shows the mean or the accuracies.

        Raises
        ------
        Warning
            If the matrix is not square.

        Examples
        --------
        >>> plot.add_f1()
        """
        if self.cm.shape[0] != self.cm.shape[1]:
            raise Warning('Number of lines and columns must be equal')

        if self.subplot_ax1v is False or self.subplot_ax1v == 'F1':
            self.ax1v = plt.subplot(self.gs[0, 1])
            current_ax = self.ax1v
            self.subplot_ax1v = 'F1'
        else:
            self.ax2v = plt.subplot(self.gs[0, 2])
            current_ax = self.ax2v

        # F1 per label, in percent, with TP on the diagonal and the rest of
        # the column / row counted as the two kinds of errors.
        TP = np.diag(self.cm_)
        FN = np.nansum(self.cm_, axis=0) - TP
        FP = np.nansum(self.cm_, axis=1) - TP
        f1_scores = 2 * TP / (2 * TP + FP + FN) * 100

        font_size = self._get_font_size()

        current_ax.imshow(
            f1_scores.reshape(-1, 1),
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=0,
            vmax=100)

        if self.xlabelsPos == 'top':
            current_ax.xaxis.tick_top()
            current_ax.xaxis.set_ticks_position('top')
            label_ha = 'center'
        else:
            label_ha = 'left'
        current_ax.set_xticks([0])
        current_ax.set_xticklabels(
            ['F1'],
            horizontalalignment=label_ha,
            rotation=self.xrotation,
            size=font_size)
        current_ax.set_yticks([])

        # An undefined F1 (0/0 gives NaN) is displayed as 0.
        displayed = _nan_to_num(f1_scores)
        for i in range(self.cm.shape[0]):
            current_ax.text(
                0, i, str(int(displayed[i])),
                size=font_size,
                horizontalalignment='center',
                color='white' if f1_scores[i] > 50 else 'black',
                va='center')

        self.axes.append(current_ax)

    def add_accuracy(self, thresold=50, invert_PA_UA=False,
                     user_acc_label='User\'s acc.',
                     prod_acc_label='Prod\'s acc.'):
        """
        Add user and producer accuracy.

        Parameters
        ----------
        thresold : int, default 50
            The thresold value where text will be in white instead of black.
        invert_PA_UA : bool, default False
            If True, user and producer accuracy labels are switched
            (note that it does not reverse the confusion matrix though).
        user_acc_label : str
            The user accuracy label to display. Default to 'User's acc.'
        prod_acc_label : str
            The producer accuracy label to display.
            Default to 'Prod's acc.'

        Raises
        ------
        Warning
            If the matrix is not square.

        Examples
        --------
        >>> plot.add_accuracy()
        """
        # Checked first so that a failure leaves the plot state untouched.
        if self.cm_.shape[0] != self.cm_.shape[1]:
            raise Warning('Number of lines and columns must be equal')

        font_size = self._get_font_size()

        if self.subplot_ax1v is not False:
            self._init_gridspec()
            if self.subplot_ax1v == 'F1':
                # The first side column is taken over by the accuracies, so
                # F1 is drawn again, this time in the second side column.
                self.subplot_ax1v = 'accuracy'
                self.add_f1()
        self.subplot_ax1v = 'accuracy'

        self.ax1v = plt.subplot(self.gs[0, 1])
        self.ax1h = plt.subplot(self.gs[1, 0])

        # Diagonal divided by the row sums (vertical subplot) and by the
        # column sums (horizontal subplot), in percent.
        diagonal = np.diag(self.cm_)
        acc_by_row = np.array(diagonal / np.nansum(self.cm_, axis=1) * 100)
        acc_by_col = np.array(diagonal / np.nansum(self.cm_, axis=0) * 100)

        image_kwargs = dict(
            cmap=self.diag_color,
            interpolation='nearest',
            aspect='equal',
            vmin=0,
            vmax=100)
        self.ax1v.imshow(acc_by_row.reshape(-1, 1), **image_kwargs)
        self.ax1h.imshow(acc_by_col.reshape(1, -1), **image_kwargs)

        self.ax1v.set_yticks(np.arange(self.cm_.shape[0]))
        self.ax1v.set_xticks([])
        self.ax1h.set_yticks([0])
        self.ax1h.set_xticks([])

        # An undefined accuracy (0/0 gives NaN) is displayed as 0.
        shown_by_row = _nan_to_num(acc_by_row, nan=0)
        shown_by_col = _nan_to_num(acc_by_col, nan=0)

        for i in range(self.cm.shape[0]):
            # int() instead of np.int, which was removed from NumPy.
            iVal = int(shown_by_row[i])
            self.ax1v.text(
                0, i, str(iVal),
                color='white' if iVal > thresold else 'black',
                size=font_size, ha='center', va='center')
        self.ax1v.set_yticklabels([])

        for j in range(self.cm.shape[1]):
            jVal = int(shown_by_col[j])
            self.ax1h.text(
                j, 0, str(jVal),
                color='white' if jVal > thresold else 'black',
                size=font_size, ha='center', va='center')

        y_label, x_label = [prod_acc_label], [user_acc_label]
        if invert_PA_UA:
            x_label, y_label = y_label, x_label

        self.ax1h.set_yticklabels(
            y_label,
            rotation=self.yrotation,
            ha='right',
            va='center',
            size=font_size)

        self.ax1v.xaxis.set_ticks_position('top')
        self.ax1v.set_xticks([0])
        ha = 'left' if self.xrotation < 60 else 'center'
        # Only `ha` is given: matplotlib refuses to receive both `ha` and
        # its alias `horizontalalignment`.
        self.ax1v.set_xticklabels(
            x_label,
            rotation=self.xrotation,
            ha=ha,
            size=font_size)

        # TOFIX not extend ?
        self.axes.append([self.ax1v, self.ax1h])

    def color_diagonal(self, diag_color=plt.cm.Greens,
                       matrix_color=plt.cm.Reds):
        """
        Use one colormap for the diagonal and another one for the rest
        of the matrix.

        Parameters
        ----------
        diag_color : pyplot colormap, default plt.cm.Greens.
            Colormap of the diagonal (also used afterwards by the F1, mean
            and accuracy subplots).
        matrix_color : pyplot colormap, default plt.cm.Reds
            Colormap of the cells outside the diagonal.

        Raises
        ------
        Warning
            If the matrix is not square.

        Examples
        --------
        >>> plot.color_diagonal()
        """
        self.diag_color = diag_color

        if self.cm.shape[0] != self.cm.shape[1]:
            raise Warning(
                'Array must have the same number of lines and columns')

        # Both masked arrays are built from the pristine copy, so calling
        # this method several times never stacks masks on top of each other.
        on_diagonal = np.eye(self.cm.shape[0], dtype=bool)
        self.cm2 = np.ma.masked_array(
            self.cm_, mask=np.logical_not(on_diagonal))  # diagonal only
        self.cm = np.ma.masked_array(
            self.cm_, mask=on_diagonal)  # everything but the diagonal

        if self.zero_is_min is True:
            vmin = 0
        else:
            vmin = np.amin(self.cm_)
        vmax = np.amax(self.cm_)

        layers = ((self.cm2, diag_color), (self.cm, matrix_color))
        for layer, layer_cmap in layers:
            self.ax.imshow(
                layer,
                interpolation='nearest',
                aspect='equal',
                cmap=layer_cmap,
                vmin=vmin,
                vmax=vmax,
                alpha=1)

    def save_to(self, path, dpi=150):
        """
        Save the plot

        Parameters
        ----------
        path : str
            The path of the file to save.
        dpi : int, default 150.

        Examples
        --------
        >>> plot.save_to('/tmp/contofu.pdf',dpi=300)
        """
        self.fig.savefig(path, dpi=dpi, bbox_inches='tight')