mirror of
https://github.com/gmihaila/ml_things.git
synced 2021-10-04 01:29:04 +03:00
added in plot_array and plot_dict & cleanned plot_matrix
This commit is contained in:
@@ -16,14 +16,15 @@
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import warnings
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
|
||||
def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=None, use_ylabel=None,
|
||||
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyle='-', use_dpi=20, path=None,
|
||||
show_plot=True):
|
||||
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyle='-',
|
||||
magnify=1.2, use_dpi=20, path=None, show_plot=True):
|
||||
"""Create plot from a single array of values.
|
||||
|
||||
:param magnify:
|
||||
:param array: list of values. Can be of type list or np.ndarray.
|
||||
:param step_size: steps shows on x-axis. Change if each steps is different than 1.
|
||||
:param use_label: display label of values from array.
|
||||
@@ -34,8 +35,8 @@ def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=No
|
||||
:param use_grid: show grid on plot or not.
|
||||
:param width: horizontal length of plot.
|
||||
:param height: vertical length of plot.
|
||||
:param use_linestyle: what array of styles to use on lines from ['-', '--', '-.', ':'].
|
||||
:param use_dpi: quality of image saved from plot. 100 is prety high.
|
||||
:param use_linestyle: what style to use on line from ['-', '--', '-.', ':'].
|
||||
:param use_dpi: quality of image saved from plot. 100 is pretty high.
|
||||
:param path: path where to save the plot as an image - if set to None no image will be saved.
|
||||
:param show_plot: if you want to call `plt.show()`. or not (if you run on a headless server).
|
||||
:return:
|
||||
@@ -83,7 +84,7 @@ def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=No
|
||||
# get size of figure
|
||||
figsize = fig.get_size_inches()
|
||||
# change size depending on height and width variables
|
||||
figsize = [figsize[0] * width, figsize[1] * height]
|
||||
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
|
||||
# set the new figure size
|
||||
fig.set_size_inches(figsize)
|
||||
# save figure to image if path is set
|
||||
@@ -95,10 +96,11 @@ def plot_array(array, step_size=1, use_label=None, use_title=None, use_xlabel=No
|
||||
|
||||
|
||||
def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_ylabel=None,
|
||||
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyles=None, use_dpi=20, path=None,
|
||||
show_plot=True):
|
||||
"""Create plot from a dictionary of lists.
|
||||
:param dict_arrays: dictionary of lists or np.array
|
||||
style_sheet='ggplot', use_grid=True, width=3, height=1, use_linestyles=None, magnify=1.2,
|
||||
use_dpi=20, path=None, show_plot=True):
|
||||
"""Create plot from a single array of values.
|
||||
:param magnify:
|
||||
:param dict_arrays:
|
||||
:param step_size: steps shows on x-axis. Change if each steps is different than 1.
|
||||
:param use_title: title on top of plot.
|
||||
:param use_xlabel: horizontal axis label.
|
||||
@@ -107,7 +109,7 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
|
||||
:param use_grid: show grid on plot or not.
|
||||
:param width: horizontal length of plot.
|
||||
:param height: vertical length of plot.
|
||||
:param use_linestyles: array of styles to use on line from ['-', '--', '-.', ':'].
|
||||
:param use_linestyles: what style to use on line from ['-', '--', '-.', ':'].
|
||||
:param use_dpi: quality of image saved from plot. 100 is pretty high.
|
||||
:param path: path where to save the plot as an image - if set to None no image will be saved.
|
||||
:param show_plot: if you want to call `plt.show()`. or not (if you run on a headless server).
|
||||
@@ -137,12 +139,11 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
|
||||
linestyles = ['-', '--', '-.', ':']
|
||||
|
||||
if use_linestyles is None:
|
||||
# if linestyles is non create same style array
|
||||
use_linestyles = ['-'] * len(dict_arrays)
|
||||
|
||||
else:
|
||||
# check if linestyle is set right
|
||||
for use_linestyle in use_linestyles:
|
||||
# check each linestyle
|
||||
if use_linestyle not in linestyles:
|
||||
# raise error
|
||||
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
|
||||
@@ -171,8 +172,8 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
|
||||
# get size of figure
|
||||
figsize = fig.get_size_inches()
|
||||
# change size depending on height and width variables
|
||||
figsize = [figsize[0] * width, figsize[1] * height]
|
||||
# set the new figure size
|
||||
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
|
||||
# set the new figure size with magnify
|
||||
fig.set_size_inches(figsize)
|
||||
# save figure to image if path is set
|
||||
fig.savefig(path, dpi=use_dpi) if path is not None else None
|
||||
@@ -182,24 +183,29 @@ def plot_dict(dict_arrays, step_size=1, use_title=None, use_xlabel=None, use_yla
|
||||
return
|
||||
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=None, cmap=plt.cm.Blues, image=None,
|
||||
verbose=0, magnify=1.2, dpi=50):
|
||||
def plot_confusion_matrix(y_true, y_pred, title=None, use_title=None, classes='', normalize=False, style_sheet='ggplot',
|
||||
cmap=plt.cm.Blues, width=3, height=1, image=None, path=None,
|
||||
verbose=0, magnify=1, dpi=None, use_dpi=50):
|
||||
"""This function prints and plots the confusion matrix.
|
||||
Normalization can be applied by setting `normalize=True`.
|
||||
y_true needs to contain all possible labels.
|
||||
|
||||
:param path: str path to save plot in an image.
|
||||
:param use_dpi: int clarity of plot.
|
||||
:param height:
|
||||
:param width:
|
||||
:param title: Deprecated!
|
||||
:param y_true: array labels values.
|
||||
:param y_pred: array predicted label values.
|
||||
:param classes: array list of label names.
|
||||
:param normalize: bool normalize confusion matrix or not.
|
||||
:param title: str string title of plot.
|
||||
:param use_title: str string title of plot.
|
||||
:param cmap: plt.cm plot theme
|
||||
:param image: str path to save plot in an image.
|
||||
:param verbose: int print confusion matrix when calling function.
|
||||
:param magnify: int zoom of plot.
|
||||
:param dpi: int clarity of plot.
|
||||
:param style_sheet: style of plot. Use plt.style.available to show all styles.
|
||||
:return: array confusion matrix used to plot.
|
||||
|
||||
Note:
|
||||
- Plot themes:
|
||||
cmap=plt.cm.Blues - used as default.
|
||||
@@ -208,33 +214,49 @@ def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=Non
|
||||
cmap=plt.cm.Greens
|
||||
cmap=plt.cm.OrRd
|
||||
"""
|
||||
# Handle deprecation warnings.
|
||||
if title is not None:
|
||||
# assign same value
|
||||
use_title = title
|
||||
warnings.warn("`title` will be deprecated in future updates. Use `use_title` in stead!", DeprecationWarning)
|
||||
if image is not None:
|
||||
# assign same value
|
||||
path = image
|
||||
warnings.warn("`image` will be deprecated in future updates. Use `path` in stead!", DeprecationWarning)
|
||||
if dpi is not None:
|
||||
# assign same value
|
||||
use_dpi = dpi
|
||||
warnings.warn("`dpi` will be deprecated in future updates. Use `use_dpi` in stead!", DeprecationWarning)
|
||||
# Make sure labels have right format
|
||||
if len(y_true) != len(y_pred):
|
||||
# make sure lengths match
|
||||
raise ValueError("`y_true` needs to have same length as `y_pred`!")
|
||||
|
||||
# make sure style sheet is correct
|
||||
if style_sheet in plt.style.available:
|
||||
# set style of plot
|
||||
plt.style.use(style_sheet)
|
||||
else:
|
||||
# style is not correct
|
||||
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
|
||||
str(plt.style.available)))
|
||||
# Class labels setup. If none, generate from y_true y_pred
|
||||
classes = list(classes)
|
||||
if classes:
|
||||
assert len(set(y_true)) == len(classes)
|
||||
else:
|
||||
classes = set(y_true)
|
||||
# Title setup
|
||||
if not title:
|
||||
if normalize:
|
||||
title = 'Normalized confusion matrix'
|
||||
else:
|
||||
title = 'Confusion matrix, without normalization'
|
||||
# Compute confusion matrix
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
# Normalize setup
|
||||
if normalize:
|
||||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||
if normalize is True:
|
||||
print("Normalized confusion matrix")
|
||||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||
use_title = 'Normalized confusion matrix' if use_title is None else use_title
|
||||
else:
|
||||
print('Confusion matrix, without normalization')
|
||||
use_title = 'Confusion matrix, without normalization' if use_title is None else use_title
|
||||
# Print if verbose
|
||||
if verbose > 0:
|
||||
print(cm)
|
||||
print(cm) if verbose > 0 else None
|
||||
# Plot setup
|
||||
fig, ax = plt.subplots()
|
||||
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
|
||||
@@ -244,7 +266,7 @@ def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=Non
|
||||
yticks=np.arange(cm.shape[0]),
|
||||
# ... and label them with the respective list entries
|
||||
xticklabels=classes, yticklabels=classes,
|
||||
title=title,
|
||||
title=use_title,
|
||||
ylabel='True label',
|
||||
xlabel='Predicted label')
|
||||
# Rotate the tick labels and set their alignment.
|
||||
@@ -259,10 +281,19 @@ def plot_confusion_matrix(y_true, y_pred, classes='', normalize=False, title=Non
|
||||
ax.text(j, i, format(cm[i, j], fmt),
|
||||
ha="center", va="center",
|
||||
color="white" if cm[i, j] > thresh else "black")
|
||||
fig.tight_layout()
|
||||
# Never display grid.
|
||||
plt.grid(False)
|
||||
# make figure nice
|
||||
plt.tight_layout()
|
||||
# get figure object from plot
|
||||
fig = plt.gcf()
|
||||
# get size of figure
|
||||
figsize = fig.get_size_inches()
|
||||
fig.set_size_inches(figsize * magnify)
|
||||
if image:
|
||||
fig.savefig(image, dpi=dpi)
|
||||
# change size depending on height and width variables
|
||||
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
|
||||
# set the new figure size with magnify
|
||||
fig.set_size_inches(figsize)
|
||||
# save figure to image if path is set
|
||||
fig.savefig(path, dpi=use_dpi) if path is not None else None
|
||||
|
||||
return cm
|
||||
|
||||
Reference in New Issue
Block a user