mirror of
https://github.com/gmihaila/ml_things.git
synced 2021-10-04 01:29:04 +03:00
adjustable magnify of font for plots
This commit is contained in:
@@ -112,7 +112,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
|
||||
|
||||
```python
|
||||
>>> from ml_things import plot_array
|
||||
>>> plot_array([1,3,5,3,7,5,8,10], path='plot_array.png', magnify=0.5, use_title='A Random Plot', start_step=0.3, step_size=0.1, points_values=True)
|
||||
>>> plot_array([1,3,5,3,7,5,8,10], path='plot_array.png', magnify=0.1, use_title='A Random Plot', start_step=0.3, step_size=0.1, points_values=True, use_ylabel='Thid', use_xlabel='This')
|
||||
```
|
||||
|
||||

|
||||
@@ -127,8 +127,8 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
|
||||
```python
|
||||
>>> from ml_things import plot_dict
|
||||
>>> plot_dict({'train_acc':[1,3,5,3,7,5,8,10],
|
||||
'valid_acc':[4,8,9]}, use_linestyles=['-', '--'], magnify=0.5,
|
||||
start_step=0.3, step_size=0.1,path='plot_dict.png', points_values=[True, False])
|
||||
'valid_acc':[4,8,9]}, use_linestyles=['-', '--'], magnify=0.1,
|
||||
start_step=0.3, step_size=0.1,path='plot_dict.png', points_values=[True, False], use_title='Title')
|
||||
```
|
||||
|
||||

|
||||
@@ -142,7 +142,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
|
||||
|
||||
```python
|
||||
>>> from ml_things import plot_confusion_matrix
|
||||
>>> plot_confusion_matrix(y_true=[1,0,1,1,0,1], y_pred=[0,1,1,1,0,1], magnify=0.5, use_title='My Confusion Matrix', path='plot_confusion_matrix.png');
|
||||
>>> plot_confusion_matrix(y_true=[1,0,1,1,0,1], y_pred=[0,1,1,1,0,1], magnify=0.1, use_title='My Confusion Matrix', path='plot_confusion_matrix.png');
|
||||
Confusion matrix, without normalization
|
||||
array([[1, 1],
|
||||
[1, 3]])
|
||||
|
||||
@@ -19,10 +19,17 @@ import numpy as np
|
||||
import warnings
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
# Maximum allowed magnify. This will get multiplied by 0 - 1 value.
|
||||
MAX_MAGNIFY = 15
|
||||
|
||||
# Increase font for title ratio.
|
||||
TITLE_FONT_RATIO = 1.8
|
||||
|
||||
|
||||
def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None, points_values=False, points_round=3,
|
||||
use_xlabel=None, use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True,
|
||||
use_linestyle='-', width=3, height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
|
||||
use_xlabel=None,
|
||||
use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True, use_linestyle='-', width=3,
|
||||
height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
|
||||
r"""
|
||||
Create plot from a single array of values.
|
||||
|
||||
@@ -36,6 +43,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
the function.
|
||||
|
||||
step_size (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
What is the step increase of each point on the x axis.
|
||||
This argument is optional and it has a default value attributed inside the function. It will multiply
|
||||
each x-axis position value to it.
|
||||
|
||||
@@ -52,7 +60,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
inside the function.
|
||||
|
||||
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Round decimal valus for points values. This argument is optional and it has a default value attributed
|
||||
Round decimal values for points values. This argument is optional and it has a default value attributed
|
||||
inside the function.
|
||||
|
||||
use_xlabel (:obj:`str`, `optional`):
|
||||
@@ -60,8 +68,8 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
inside the function.
|
||||
|
||||
use_xticks (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Display x-axis tick values. This argument is optional and it has a default value attributed
|
||||
inside the function.
|
||||
Display x-axis tick values (the values at each point). This argument is optional and it has a default
|
||||
value attributed inside the function.
|
||||
|
||||
use_ylabel (:obj:`str`, `optional`):
|
||||
Label to use for y-axis value meaning. This argument is optional and it will have a `None` value attributed
|
||||
@@ -84,22 +92,22 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
the function.
|
||||
|
||||
height (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
Height length of plot in inches. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
|
||||
magnify (:obj:`int`, `optional`, defaults to :obj:`1.2`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
magnify (:obj:`float`, `optional`, defaults to :obj:`0.1`):
|
||||
Ratio increase of both with and height keeping the same ratio size. This argument is optional and it has a
|
||||
default value attributed inside the function.
|
||||
|
||||
use_dpi (:obj:`int`, `optional`, defaults to :obj:`50`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
Print resolution is measured in dots per inch (or “DPI”). This argument is optional and it has a default
|
||||
value attributed inside the function.
|
||||
|
||||
path (:obj:`str`, `optional`):
|
||||
Vertical length of plot. This argument is optional and it will have a None value attributed inside
|
||||
the function.
|
||||
Path and file name of plot saved as image. If want to save in current path just pass in the file name.
|
||||
This argument is optional and it will have a None value attributed inside the function.
|
||||
|
||||
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
show_plot (:obj:`bool`, `optional`, defaults to :obj:`1`):
|
||||
if you want to call `plt.show()`. or not (if you run on a headless server). This argument is optional and
|
||||
it has a default value attributed inside the function.
|
||||
|
||||
@@ -111,70 +119,122 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
|
||||
ValueError: If `use_linestyle` is not valid.
|
||||
|
||||
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
|
||||
|
||||
"""
|
||||
|
||||
# check if `array` is correct format
|
||||
# Check if `array` is in correct format.
|
||||
if not isinstance(array, list) or isinstance(array, np.ndarray):
|
||||
# raise value error
|
||||
# raise value error.
|
||||
raise ValueError("`array` needs to be a list of values!")
|
||||
|
||||
# Check if `style_sheet` has correct value.
|
||||
if style_sheet in plt.style.available:
|
||||
# set style of plot
|
||||
# Set style of plot
|
||||
plt.style.use(style_sheet)
|
||||
else:
|
||||
# style is not correct
|
||||
# Style is not correct.
|
||||
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
|
||||
str(plt.style.available)))
|
||||
# all linestyles
|
||||
# Make sure `magnify` is in right range.
|
||||
if magnify > 1 or magnify <= 0:
|
||||
# Deprecation warning from last time.
|
||||
warnings.warn(f'`magnify` needs to have value in [0,1]! `{magnify}` will be converted to `0.1` as default.',
|
||||
DeprecationWarning)
|
||||
# Convert to regular value 0.1.
|
||||
magnify = 0.1
|
||||
|
||||
# All allowed linestyles.
|
||||
linestyles = ['-', '--', '-.', ':']
|
||||
|
||||
# check if linestyle is set right
|
||||
# Font variables dictionary. Keep it in this format for future updates.
|
||||
font_dict = dict(
|
||||
family='DejaVu Sans',
|
||||
color='black',
|
||||
weight='normal',
|
||||
# [0.1, 1] - magnify intervals where font size matters
|
||||
# [10.5, 50] - min and max appropriate font sizes
|
||||
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
|
||||
)
|
||||
|
||||
# Check if linestyle is set right.
|
||||
if use_linestyle not in linestyles:
|
||||
# raise error
|
||||
# Raise error.
|
||||
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
|
||||
|
||||
# set steps plotted on x-axis - we can use step if 1 unit has different value
|
||||
|
||||
# Set steps plotted on x-axis - we can use step if 1 unit has different value.
|
||||
if start_step > 0:
|
||||
# Offset all steps by start_step.
|
||||
steps = np.array(range(0, len(array))) * step_size + start_step
|
||||
else:
|
||||
# Keep steps the same.
|
||||
steps = np.array(range(1, len(array) + 1)) * step_size
|
||||
# single plot figure
|
||||
|
||||
# Single plot figure.
|
||||
plt.subplot(1, 2, 1)
|
||||
# plot array as a single line
|
||||
|
||||
# Plot array as a single line
|
||||
plt.plot(steps, array, linestyle=use_linestyle, label=use_label)
|
||||
# Plots points values
|
||||
|
||||
# Plots points values.
|
||||
if points_values:
|
||||
# Loop through each point and plot the label.
|
||||
for x, y in zip(steps, array):
|
||||
# Add text label to plo.
|
||||
plt.text(x, y, str(round(y, points_round)))
|
||||
# set title of figure
|
||||
plt.title(use_title)
|
||||
# set horizontal axis name
|
||||
plt.xlabel(use_xlabel)
|
||||
# Add text label to plot.
|
||||
plt.text(x, y, str(round(y, points_round)), fontdict=font_dict)
|
||||
|
||||
# Set horizontal axis name.
|
||||
plt.xlabel(use_xlabel, fontdict=font_dict)
|
||||
|
||||
# Use x ticks with steps.
|
||||
plt.xticks(steps) if use_xticks else None
|
||||
# set vertical axis name
|
||||
plt.ylabel(use_ylabel)
|
||||
# place legend best position
|
||||
plt.legend(loc='best') if use_label is not None else None
|
||||
# display grid depending on `use_grid`
|
||||
|
||||
# Set vertical axis name.
|
||||
plt.ylabel(use_ylabel, fontdict=font_dict)
|
||||
|
||||
# Place legend best position.
|
||||
plt.legend(loc='best', fontsize=font_dict['size']) if use_label is not None else None
|
||||
|
||||
# Custom label font size for x axis. This is for future updates if needed.
|
||||
# plt.tick_params(axis="x", labelsize=font['size']//2)
|
||||
|
||||
# Custom label font size for y axis. This is for future updates if needed.
|
||||
# plt.tick_params(axis="y", labelsize=font['size']//2)
|
||||
|
||||
# Adjust both axis labels font size at same time.
|
||||
plt.tick_params(labelsize=font_dict['size'])
|
||||
|
||||
# Display grid depending on `use_grid`.
|
||||
plt.grid(use_grid)
|
||||
# make figure nice
|
||||
|
||||
# Make figure nice.
|
||||
plt.tight_layout()
|
||||
# get figure object from plot
|
||||
|
||||
# Adjust font for
|
||||
font_dict['size'] *= TITLE_FONT_RATIO
|
||||
|
||||
# Set title of figure.
|
||||
plt.title(use_title, fontdict=font_dict)
|
||||
|
||||
# Rescale `magnify` to be used on inches.
|
||||
magnify *= MAX_MAGNIFY
|
||||
|
||||
# Get figure object from plot.
|
||||
fig = plt.gcf()
|
||||
# get size of figure
|
||||
|
||||
# Get size of figure.
|
||||
figsize = fig.get_size_inches()
|
||||
# change size depending on height and width variables
|
||||
|
||||
# Change size depending on height and width variables.
|
||||
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
|
||||
# set the new figure size
|
||||
|
||||
# Set the new figure size.
|
||||
fig.set_size_inches(figsize)
|
||||
# save figure to image if path is set
|
||||
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
# show plot
|
||||
|
||||
# Show plot.
|
||||
plt.show() if show_plot is True else None
|
||||
|
||||
return
|
||||
@@ -210,7 +270,7 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
inside the function.
|
||||
|
||||
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Round decimal valus for points values. This argument is optional and it has a default value attributed
|
||||
Round decimal values for points values. This argument is optional and it has a default value attributed
|
||||
inside the function.
|
||||
|
||||
use_xlabel (:obj:`str`, `optional`):
|
||||
@@ -238,22 +298,22 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
the function.
|
||||
|
||||
height (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
Height length of plot in inches. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
|
||||
magnify (:obj:`int`, `optional`, defaults to :obj:`1.2`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
magnify (:obj:`float`, `optional`, defaults to :obj:`0.1`):
|
||||
Ratio increase of both with and height keeping the same ratio size. This argument is optional and it has a
|
||||
default value attributed inside the function.
|
||||
|
||||
use_dpi (:obj:`int`, `optional`, defaults to :obj:`50`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
Print resolution is measured in dots per inch (or “DPI”). This argument is optional and it has a default
|
||||
value attributed inside the function.
|
||||
|
||||
path (:obj:`str`, `optional`):
|
||||
Vertical length of plot. This argument is optional and it will have a None value attributed inside
|
||||
the function.
|
||||
Path and file name of plot saved as image. If want to save in current path just pass in the file name.
|
||||
This argument is optional and it will have a None value attributed inside the function.
|
||||
|
||||
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
show_plot (:obj:`bool`, `optional`, defaults to :obj:`1`):
|
||||
if you want to call `plt.show()`. or not (if you run on a headless server). This argument is optional and
|
||||
it has a default value attributed inside the function.
|
||||
|
||||
@@ -271,87 +331,138 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
|
||||
ValueError: If `points_values`of type list don't have same length as `dict_arrays`.
|
||||
|
||||
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
|
||||
|
||||
"""
|
||||
# check if `dict_arrays` is correct format
|
||||
|
||||
# Check if `dict_arrays` is the correct format.
|
||||
if not isinstance(dict_arrays, dict):
|
||||
# raise value error
|
||||
# Raise value error.
|
||||
raise ValueError("`dict_arrays` needs to be a dictionary of values!")
|
||||
|
||||
# Check each label
|
||||
for label, array in dict_arrays.items():
|
||||
# check if format is correct
|
||||
# Check if format is correct.
|
||||
if not isinstance(label, str):
|
||||
# raise value error
|
||||
# Raise value error.
|
||||
raise ValueError("`dict_arrays` needs string keys!")
|
||||
if not isinstance(array, list) or isinstance(array, np.ndarray):
|
||||
# raise value error
|
||||
# Raise value error.
|
||||
raise ValueError("`dict_arrays` needs lists values!")
|
||||
# make sure style sheet is correct
|
||||
|
||||
# Make sure style sheet is correct.
|
||||
if style_sheet in plt.style.available:
|
||||
# set style of plot
|
||||
# Set style of plot
|
||||
plt.style.use(style_sheet)
|
||||
else:
|
||||
# style is not correct
|
||||
# Style is not correct.
|
||||
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
|
||||
str(plt.style.available)))
|
||||
# all linestyles
|
||||
|
||||
# Make sure `magnify` is in right range.
|
||||
if magnify > 1 or magnify <= 0:
|
||||
# Deprecation warning from last time.
|
||||
warnings.warn(f'`magnify` needs to have value in [0,1]! `{magnify}` will be converted to `0.1` as default.',
|
||||
DeprecationWarning)
|
||||
# Convert to regular value 0.1.
|
||||
magnify = 0.1
|
||||
|
||||
# all linestyles.
|
||||
linestyles = ['-', '--', '-.', ':']
|
||||
|
||||
# Font variables dictionary. Keep it in this format for future updates.
|
||||
font_dict = dict(
|
||||
family='DejaVu Sans',
|
||||
color='black',
|
||||
weight='normal',
|
||||
# [0.1, 1] - magnify intervals where font size matters
|
||||
# [10.5, 50] - min and max appropriate font sizes
|
||||
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
|
||||
)
|
||||
|
||||
# If single style value is passed, use it on all arrays.
|
||||
if use_linestyles is None:
|
||||
use_linestyles = ['-'] * len(dict_arrays)
|
||||
|
||||
else:
|
||||
# check if linestyle is set right
|
||||
# Check if linestyle is set right.
|
||||
for use_linestyle in use_linestyles:
|
||||
if use_linestyle not in linestyles:
|
||||
# raise error
|
||||
# Raise error.
|
||||
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
|
||||
|
||||
# Check `points_value` type - it can be bool or list(bool)
|
||||
# Check `points_value` type - it can be bool or list(bool).
|
||||
if isinstance(points_values, bool):
|
||||
# convert to list.
|
||||
# Convert to list.
|
||||
points_values = [points_values] * len(dict_arrays)
|
||||
elif isinstance(points_values, list) and (len(points_values) != len(dict_arrays)):
|
||||
raise ValueError('`points_values` of type `list` must have same length as dictionary!')
|
||||
|
||||
# single plot figure
|
||||
# Single plot figure.
|
||||
plt.subplot(1, 2, 1)
|
||||
|
||||
# Plot each array.
|
||||
for index, (use_label, array) in enumerate(dict_arrays.items()):
|
||||
# set steps plotted on x-axis - we can use step if 1 unit has different value
|
||||
# Set steps plotted on x-axis - we can use step if 1 unit has different value.
|
||||
if start_step > 0:
|
||||
# Offset all steps by start_step.
|
||||
steps = np.array(range(0, len(array))) * step_size + start_step
|
||||
else:
|
||||
steps = np.array(range(1, len(array) + 1)) * step_size
|
||||
# plot array as a single line
|
||||
|
||||
# Plot array as a single line.
|
||||
plt.plot(steps, array, linestyle=use_linestyles[index], label=use_label)
|
||||
# Plots points values
|
||||
|
||||
# Plots points values.
|
||||
if points_values[index]:
|
||||
# Loop through each point and plot the label.
|
||||
for x, y in zip(steps, array):
|
||||
# Add text label to plo.
|
||||
plt.text(x, y, str(round(y, points_round)))
|
||||
# set title of figure
|
||||
plt.title(use_title)
|
||||
# set horizontal axis name
|
||||
plt.xlabel(use_xlabel)
|
||||
# set vertical axis name
|
||||
plt.ylabel(use_ylabel)
|
||||
# place legend best position
|
||||
plt.legend(loc='best') if use_label is not None else None
|
||||
# display grid depending on `use_grid`
|
||||
# Add text label to plot.
|
||||
plt.text(x, y, str(round(y, points_round)), fontdict=font_dict)
|
||||
|
||||
# Set horizontal axis name.
|
||||
plt.xlabel(use_xlabel, fontdict=font_dict)
|
||||
|
||||
# Set vertical axis name.
|
||||
plt.ylabel(use_ylabel, fontdict=font_dict)
|
||||
|
||||
# Adjust both axis labels font size at same time.
|
||||
plt.tick_params(labelsize=font_dict['size'])
|
||||
|
||||
# Place legend best position.
|
||||
plt.legend(loc='best', fontsize=font_dict['size'])
|
||||
|
||||
# Adjust font for title.
|
||||
font_dict['size'] *= TITLE_FONT_RATIO
|
||||
|
||||
# Set title of figure.
|
||||
plt.title(use_title, fontdict=font_dict)
|
||||
|
||||
# Rescale `magnify` to be used on inches.
|
||||
magnify *= MAX_MAGNIFY
|
||||
|
||||
# Display grid depending on `use_grid`.
|
||||
plt.grid(use_grid)
|
||||
# make figure nice
|
||||
|
||||
# Make figure nice.
|
||||
plt.tight_layout()
|
||||
# get figure object from plot
|
||||
|
||||
# Get figure object from plot.
|
||||
fig = plt.gcf()
|
||||
# get size of figure
|
||||
|
||||
# Get size of figure.
|
||||
figsize = fig.get_size_inches()
|
||||
# change size depending on height and width variables
|
||||
|
||||
# Change size depending on height and width variables.
|
||||
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
|
||||
# set the new figure size with magnify
|
||||
|
||||
# Set the new figure size with magnify.
|
||||
fig.set_size_inches(figsize)
|
||||
# save figure to image if path is set
|
||||
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
# show plot
|
||||
|
||||
# Show plot.
|
||||
plt.show() if show_plot is True else None
|
||||
|
||||
return
|
||||
@@ -402,22 +513,22 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
the function.
|
||||
|
||||
height (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
Height length of plot in inches. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
|
||||
magnify (:obj:`int`, `optional`, defaults to :obj:`1.2`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
magnify (:obj:`float`, `optional`, defaults to :obj:`0.1`):
|
||||
Ratio increase of both with and height keeping the same ratio size. This argument is optional and it has a
|
||||
default value attributed inside the function.
|
||||
|
||||
use_dpi (:obj:`int`, `optional`, defaults to :obj:`50`):
|
||||
Vertical length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
Print resolution is measured in dots per inch (or “DPI”). This argument is optional and it has a default
|
||||
value attributed inside the function.
|
||||
|
||||
path (:obj:`str`, `optional`):
|
||||
Vertical length of plot. This argument is optional and it will have a None value attributed inside
|
||||
the function.
|
||||
Path and file name of plot saved as image. If want to save in current path just pass in the file name.
|
||||
This argument is optional and it will have a None value attributed inside the function.
|
||||
|
||||
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
show_plot (:obj:`bool`, `optional`, defaults to :obj:`1`):
|
||||
if you want to call `plt.show()`. or not (if you run on a headless server). This argument is optional and
|
||||
it has a default value attributed inside the function.
|
||||
|
||||
@@ -445,6 +556,8 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
|
||||
ValueError: If `style_sheet` is not valid.
|
||||
|
||||
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
|
||||
|
||||
"""
|
||||
|
||||
# Handle deprecation warnings if `title` is used.
|
||||
@@ -452,6 +565,7 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
# assign same value
|
||||
use_title = kwargs['title']
|
||||
warnings.warn("`title` will be deprecated in future updates. Use `use_title` in stead!", DeprecationWarning)
|
||||
|
||||
# Handle deprecation warnings if `image` is used.
|
||||
if 'image' in kwargs:
|
||||
# assign same value
|
||||
@@ -462,11 +576,13 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
# assign same value
|
||||
use_dpi = kwargs['dpi']
|
||||
warnings.warn("`dpi` will be deprecated in future updates. Use `use_dpi` in stead!", DeprecationWarning)
|
||||
# Make sure labels have right format
|
||||
|
||||
# Make sure labels have right format.
|
||||
if len(y_true) != len(y_pred):
|
||||
# make sure lengths match
|
||||
raise ValueError("`y_true` needs to have same length as `y_pred`!")
|
||||
# make sure style sheet is correct
|
||||
|
||||
# Make sure style sheet is correct.
|
||||
if style_sheet in plt.style.available:
|
||||
# set style of plot
|
||||
plt.style.use(style_sheet)
|
||||
@@ -474,15 +590,36 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
# style is not correct
|
||||
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
|
||||
str(plt.style.available)))
|
||||
# Class labels setup. If none, generate from y_true y_pred
|
||||
|
||||
# Make sure `magnify` is in right range.
|
||||
if magnify > 1 or magnify <= 0:
|
||||
# Deprecation warning from last time.
|
||||
warnings.warn(f'`magnify` needs to have value in [0,1]! `{magnify}` will be converted to `0.1` as default.',
|
||||
DeprecationWarning)
|
||||
# Convert to regular value 0.1.
|
||||
magnify = 0.1
|
||||
|
||||
# Font variables dictionary. Keep it in this format for future updates.
|
||||
font_dict = dict(
|
||||
family='DejaVu Sans',
|
||||
color='black',
|
||||
weight='normal',
|
||||
# [0.1, 1] - magnify intervals where font size matters
|
||||
# [10.5, 50] - min and max appropriate font sizes
|
||||
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
|
||||
)
|
||||
|
||||
# Class labels setup. If none, generate from y_true y_pred.
|
||||
classes = list(classes)
|
||||
if classes:
|
||||
assert len(set(y_true)) == len(classes)
|
||||
else:
|
||||
classes = set(y_true)
|
||||
# Compute confusion matrix
|
||||
|
||||
# Compute confusion matrix.
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
# Nromalize setup
|
||||
|
||||
# Normalize setup.
|
||||
if normalize is True:
|
||||
print("Normalized confusion matrix")
|
||||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||
@@ -490,24 +627,33 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
else:
|
||||
print('Confusion matrix, without normalization')
|
||||
use_title = 'Confusion matrix, without normalization' if use_title is None else use_title
|
||||
# Print if verbose
|
||||
|
||||
# Print if verbose.
|
||||
print(cm) if verbose > 0 else None
|
||||
# Plot setup
|
||||
|
||||
# Plot setup.
|
||||
fig, ax = plt.subplots()
|
||||
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
|
||||
ax.figure.colorbar(im, ax=ax)
|
||||
# We want to show all ticks...
|
||||
|
||||
# Show all ticks.
|
||||
ax.set(xticks=np.arange(cm.shape[1]),
|
||||
yticks=np.arange(cm.shape[0]),
|
||||
# ... and label them with the respective list entries
|
||||
# Label ticks with the respective list entries.
|
||||
xticklabels=classes, yticklabels=classes,
|
||||
title=use_title,
|
||||
ylabel='True label',
|
||||
xlabel='Predicted label')
|
||||
)
|
||||
|
||||
# Set horizontal axis name.
|
||||
ax.set_xlabel('Predicted label', fontdict=font_dict)
|
||||
|
||||
# Set vertical axis name.
|
||||
ax.set_ylabel('True label', fontdict=font_dict)
|
||||
|
||||
# Rotate the tick labels and set their alignment.
|
||||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
|
||||
rotation_mode="anchor")
|
||||
plt.grid(False)
|
||||
|
||||
# Loop over data dimensions and create text annotations.
|
||||
fmt = '.2f' if normalize else 'd'
|
||||
thresh = cm.max() / 2.
|
||||
@@ -515,22 +661,42 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
for j in range(cm.shape[1]):
|
||||
ax.text(j, i, format(cm[i, j], fmt),
|
||||
ha="center", va="center",
|
||||
color="white" if cm[i, j] > thresh else "black")
|
||||
color="white" if cm[i, j] > thresh else "black", fontdict=font_dict)
|
||||
|
||||
# Adjust both axis labels font size at same time.
|
||||
plt.tick_params(labelsize=font_dict['size'])
|
||||
|
||||
# Adjust font for title.
|
||||
font_dict['size'] *= TITLE_FONT_RATIO
|
||||
|
||||
# Set title of figure.
|
||||
plt.title(use_title, fontdict=font_dict)
|
||||
|
||||
# Rescale `magnify` to be used on inches.
|
||||
magnify *= MAX_MAGNIFY
|
||||
|
||||
# Never display grid.
|
||||
plt.grid(False)
|
||||
# make figure nice
|
||||
|
||||
# Make figure nice.
|
||||
plt.tight_layout()
|
||||
# get figure object from plot
|
||||
|
||||
# Get figure object from plot.
|
||||
fig = plt.gcf()
|
||||
# get size of figure
|
||||
|
||||
# Get size of figure.
|
||||
figsize = fig.get_size_inches()
|
||||
# change size depending on height and width variables
|
||||
|
||||
# Change size depending on height and width variables.
|
||||
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
|
||||
# set the new figure size with magnify
|
||||
|
||||
# Set the new figure size with magnify.
|
||||
fig.set_size_inches(figsize)
|
||||
# save figure to image if path is set
|
||||
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
# show plot
|
||||
|
||||
# Show plot.
|
||||
plt.show() if show_plot is True else None
|
||||
|
||||
return cm
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 9.8 KiB After Width: | Height: | Size: 14 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 7.0 KiB After Width: | Height: | Size: 8.1 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 16 KiB |
Reference in New Issue
Block a user