added in plot_array and plot_dict & cleanned plot_matrix

This commit is contained in:
georgemihaila
2020-10-06 08:54:35 -05:00
parent a2fe850f78
commit f6c3bad607

View File

@@ -16,14 +16,15 @@
import matplotlib.pyplot as plt
import numpy as np
import warnings
from sklearn.metrics import confusion_matrix
def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=None, use_ylabel=None,
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyle='-', use_dpi=20, path=None,
show_plot=True):
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyle='-',
magnify=1.2, use_dpi=20, path=None, show_plot=True):
"""Create plot from a single array of values.
:param magnify:
:param array: list of values. Can be of type list or np.ndarray.
:param step_size: steps shows on x-axis. Change if each steps is different than 1.
:param use_label: display label of values from array.
@@ -34,8 +35,8 @@ def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=No
:param use_grid: show grid on plot or not.
:param width: horizontal length of plot.
:param height: vertical length of plot.
:param use_linestyle: what array of styles to use on lines from ['-', '--', '-.', ':'].
:param use_dpi: quality of image saved from plot. 100 is prety high.
:param use_linestyle: what style to use on line from ['-', '--', '-.', ':'].
:param use_dpi: quality of image saved from plot. 100 is pretty high.
:param path: path where to save the plot as an image - if set to None no image will be saved.
:param show_plot: if you want to call `plt.show()`. or not (if you run on a headless server).
:return:
@@ -83,7 +84,7 @@ def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=No
# get size of figure
figsize = fig.get_size_inches()
# change size depending on height and width variables
figsize = [figsize[0] * width, figsize[1] * height]
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
# set the new figure size
fig.set_size_inches(figsize)
# save figure to image if path is set
@@ -95,10 +96,11 @@ def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=No
def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_ylabel=None,
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyles=None, use_dpi=20, path=None,
show_plot=True):
"""Create plot from a dictionary of lists.
:param dict_arrays: dictionary of lists or np.array
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyles=None, magnify=1.2,
use_dpi=20, path=None, show_plot=True):
"""Create plot from a single array of values.
:param magnify:
:param dict_arrays:
:param step_size: steps shows on x-axis. Change if each steps is different than 1.
:param use_title: title on top of plot.
:param use_xlabel: horizontal axis label.
@@ -107,7 +109,7 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
:param use_grid: show grid on plot or not.
:param width: horizontal length of plot.
:param height: vertical length of plot.
:param use_linestyles: array of styles to use on line from ['-', '--', '-.', ':'].
:param use_linestyles: what style to use on line from ['-', '--', '-.', ':'].
:param use_dpi: quality of image saved from plot. 100 is pretty high.
:param path: path where to save the plot as an image - if set to None no image will be saved.
:param show_plot: if you want to call `plt.show()`. or not (if you run on a headless server).
@@ -137,12 +139,11 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
linestyles = ['-', '--', '-.', ':']
if use_linestyles is None:
# if linestyles is non create same style array
use_linestyles = ['-'] * len(dict_arrays)
else:
# check if linestyle is set right
for use_linestyle in use_linestyles:
# check each linestyle
if use_linestyle not in linestyles:
# raise error
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
@@ -171,8 +172,8 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
# get size of figure
figsize = fig.get_size_inches()
# change size depending on height and width variables
figsize = [figsize[0] * width, figsize[1] * height]
# set the new figure size
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
# set the new figure size with magnify
fig.set_size_inches(figsize)
# save figure to image if path is set
fig.savefig(path, dpi=use_dpi) if path is not None else None
@@ -182,24 +183,29 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
return
def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=None, cmap=plt.cm.Blues, image=None,
verbose=0, magnify=1.2, dpi=50):
def plot_confusion_matrix(y_true, y_pred, title=None, use_title=None, classes='', normalize=False, style_sheet='ggplot',
cmap=plt.cm.Blues, width=3, height=1, image=None, path=None,
verbose=0, magnify=1, dpi=None, use_dpi=50):
"""This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
y_true needs to contain all possible labels.
:param path: str path to save plot in an image.
:param use_dpi: int clarity of plot.
:param height:
:param width:
:param title: Deprecated!
:param y_true: array labels values.
:param y_pred: array predicted label values.
:param classes: array list of label names.
:param normalize: bool normalize confusion matrix or not.
:param title: str string title of plot.
:param use_title: str string title of plot.
:param cmap: plt.cm plot theme
:param image: str path to save plot in an image.
:param verbose: int print confusion matrix when calling function.
:param magnify: int zoom of plot.
:param dpi: int clarity of plot.
:param style_sheet: style of plot. Use plt.style.available to show all styles.
:return: array confusion matrix used to plot.
Note:
- Plot themes:
cmap=plt.cm.Blues - used as default.
@@ -208,33 +214,49 @@ def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=Non
cmap=plt.cm.Greens
cmap=plt.cm.OrRd
"""
# Handle deprecation warnings.
if title is not None:
# assign same value
use_title = title
warnings.warn("`title` will be deprecated in future updates. Use `use_title` in stead!", DeprecationWarning)
if image is not None:
# assign same value
path = image
warnings.warn("`image` will be deprecated in future updates. Use `path` in stead!", DeprecationWarning)
if dpi is not None:
# assign same value
use_dpi = dpi
warnings.warn("`dpi` will be deprecated in future updates. Use `use_dpi` in stead!", DeprecationWarning)
# Make sure labels have right format
if len(y_true) != len(y_pred):
# make sure lengths match
raise ValueError("`y_true` needs to have same length as `y_pred`!")
# make sure style sheet is correct
if style_sheet in plt.style.available:
# set style of plot
plt.style.use(style_sheet)
else:
# style is not correct
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
str(plt.style.available)))
# Class labels setup. If none, generate from y_true y_pred
classes = list(classes)
if classes:
assert len(set(y_true)) == len(classes)
else:
classes = set(y_true)
# Title setup
if not title:
if normalize:
title = 'Normalized confusion matrix'
else:
title = 'Confusion matrix, without normalization'
# Compute confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Normalize setup
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
if normalize is True:
print("Normalized confusion matrix")
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
use_title = 'Normalized confusion matrix' if use_title is None else use_title
else:
print('Confusion matrix, without normalization')
use_title = 'Confusion matrix, without normalization' if use_title is None else use_title
# Print if verbose
if verbose > 0:
print(cm)
print(cm) if verbose > 0 else None
# Plot setup
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
@@ -244,7 +266,7 @@ def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=Non
yticks=np.arange(cm.shape[0]),
# ... and label them with the respective list entries
xticklabels=classes, yticklabels=classes,
title=title,
title=use_title,
ylabel='True label',
xlabel='Predicted label')
# Rotate the tick labels and set their alignment.
@@ -259,10 +281,19 @@ def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=Non
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
# Never display grid.
plt.grid(False)
# make figure nice
plt.tight_layout()
# get figure object from plot
fig = plt.gcf()
# get size of figure
figsize = fig.get_size_inches()
fig.set_size_inches(figsize * magnify)
if image:
fig.savefig(image, dpi=dpi)
# change size depending on height and width variables
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
# set the new figure size with magnify
fig.set_size_inches(figsize)
# save figure to image if path is set
fig.savefig(path, dpi=use_dpi) if path is not None else None
return cm