# Source code for reval.visualization

from matplotlib import pyplot as plt


def plot_metrics(cv_score,
                 figsize=(8, 5),
                 linewidth=1,
                 color=('black', 'black'),
                 legend_loc=2,
                 fontsize=12,
                 title="",
                 prob_lines=False,
                 save_fig=None):
    """
    Function that plots the average performance (i.e., normalized stability)
    over cross-validation for training and validation sets. When `prob_lines`
    is enabled, a short horizontal line is drawn at ``1 - 1/k`` for each
    number of clusters ``k``, representing the random performance error for
    the corresponding number of clusters.

    :param cv_score: collection of cv scores as output by
        `reval.best_nclust_cv.FindBestCLustCV.best_nclust`. It must provide
        the keys ``'train'`` and ``'val'``, each mapping a number of clusters
        to an entry whose first item is the plotted score; for validation
        entries, ``entry[1][1]`` is used as the error bar length.
    :type cv_score: dictionary
    :param figsize: (width, height), default (8, 5).
    :type figsize: tuple
    :param linewidth: width of the lines to draw.
    :type linewidth: int
    :param color: line colors for train and validation sets,
        default ('black', 'black').
    :type color: tuple
    :param legend_loc: legend location, default 2.
    :type legend_loc: int
    :param fontsize: size of fonts, default 12.
    :type fontsize: int
    :param title: figure title, default "".
    :type title: str
    :param prob_lines: plot the normalized stability of random labeling as
        thresholds, default False.
    :type prob_lines: bool
    :param save_fig: file name for saving the figure, default None (the
        figure is shown instead). The output format is inferred from the file
        extension; a name without extension is saved with matplotlib's
        default format (png unless reconfigured).
    :type save_fig: str
    """
    fig, ax = plt.subplots(figsize=figsize)

    train_scores = cv_score['train']
    val_scores = cv_score['val']
    # Numbers of clusters evaluated on the training folds; reused for the
    # x positions, the tick locations and the random-labeling thresholds.
    cl_list = list(train_scores)

    ax.plot(cl_list,
            [me[0] for me in train_scores.values()],
            linewidth=linewidth,
            linestyle='-.',
            label='training set',
            color=color[0])
    ax.errorbar(list(val_scores),
                [me[0] for me in val_scores.values()],
                yerr=[me[1][1] for me in val_scores.values()],
                linewidth=linewidth,
                linestyle='-',
                label='validation set',
                color=color[1])

    if prob_lines:
        # One short segment per number of clusters, centred on its x position.
        ax.hlines([1 - 1 / k for k in cl_list],
                  xmin=[k - 0.1 for k in cl_list],
                  xmax=[k + 0.1 for k in cl_list],
                  linewidth=linewidth,
                  color=color[1])

    ax.legend(fontsize=fontsize, loc=legend_loc)
    ax.set_xticks(cl_list)
    ax.tick_params(axis='both', labelsize=fontsize)
    ax.set_xlabel('Number of clusters', fontsize=fontsize)
    ax.set_ylabel('Normalized stability', fontsize=fontsize)
    ax.set_title(title)

    if save_fig is not None:
        # Let matplotlib infer the format from the extension. Splitting the
        # name on the first '.' picked the wrong token for names such as
        # 'fig.v2.png' or 'out.dir/fig.png' and failed when there was no dot.
        fig.savefig(save_fig)
    else:
        plt.show()