updated for plot functions! Fixed error large dpi and added custom font_size

This commit is contained in:
georgemihaila
2020-10-30 15:58:48 -05:00
parent a498d8fdbb
commit 6eb84f0fe9
2 changed files with 108 additions and 23 deletions

View File

@@ -119,7 +119,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
![plot_array](https://github.com/gmihaila/ml_things/blob/master/tests/test_samples/plot_array.png)
#### plot_dict [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L243)
#### plot_dict [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L275)
Create plot from a single array of values.
@@ -135,7 +135,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
![plot_dict](https://github.com/gmihaila/ml_things/blob/master/tests/test_samples/plot_dict.png)
#### plot_confusion_matrix [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L471)
#### plot_confusion_matrix [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L529)
This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`.

View File

@@ -19,6 +19,12 @@ import numpy as np
import warnings
from sklearn.metrics import confusion_matrix
# Magnify intervals where font size matters
MAGNIFY_INTERVALS = [0.1, 1]
# Min and max appropriate font sizes
FONT_RANGE = [10.5, 50]
# Maximum allowed magnify. This will get multiplied by 0 - 1 value.
MAX_MAGNIFY = 15
@@ -28,8 +34,8 @@ TITLE_FONT_RATIO = 1.8
def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None, points_values=False, points_round=3,
use_xlabel=None,
use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True, use_linestyle='-', width=3,
height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True, use_linestyle='-',
font_size=None, width=3, height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
r"""
Create plot from a single array of values.
@@ -60,7 +66,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
inside the function.
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
Round decimal values for points values. This argument is optional and it has a default value attributed
Round decimal valus for points values. This argument is optional and it has a default value attributed
inside the function.
use_xlabel (:obj:`str`, `optional`):
@@ -87,6 +93,12 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
Style to use on line from ['-', '--', '-.', ':']. This argument is optional and it has a default
value attributed inside the function.
font_size (:obj:`int` or `float`, `optional`):
Font size to use across the plot. By default this function will adjust font size depending on `magnify`
value. If this value is set, it will ignore the `magnify` recommended font size. The title font size is by
default `1.8` greater than font-size. This argument is optional and it will have a `None` value attributed
inside the function.
width (:obj:`int`, `optional`, defaults to :obj:`3`):
Horizontal length of plot. This argument is optional and it has a default value attributed inside
the function.
@@ -121,6 +133,8 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
ValueError: If `font_size` is not `None` and smaller or equal to 0.
"""
# Check if `array` is in correct format.
@@ -147,14 +161,20 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
# All allowed linestyles.
linestyles = ['-', '--', '-.', ':']
# Make sure `font_size` is set right.
if (font_size is not None) and (font_size <= 0):
# Raise value error - is not correct.
raise ValueError(f'`font_size` needs to be positive number! Invalid value {font_size}')
# Font size select custom or adjusted on `magnify` value.
font_size = font_size if font_size is not None else np.interp(magnify, MAGNIFY_INTERVALS, FONT_RANGE)
# Font variables dictionary. Keep it in this format for future updates.
font_dict = dict(
family='DejaVu Sans',
color='black',
weight='normal',
# [0.1, 1] - magnify intervals where font size matters
# [10.5, 50] - min and max appropriate font sizes
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
size=font_size,
)
# Check if linestyle is set right.
@@ -231,8 +251,20 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
# Set the new figure size.
fig.set_size_inches(figsize)
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# There is an error when DPI and plot size are too large!
try:
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
except ValueError:
# Deprecation warning from last time.
warnings.warn(f'`magnify={magnify // 15}` is to big in combination'
f' with `use_dpi={use_dpi}`! Try using lower values for'
f' `magnify` and/or `use_dpi`. Image was saved in {path}'
f' with `use_dpi=50 and `magnify={magnify // 15}`!', Warning)
# Set DPI to smaller value and warn user to use smaller magnify or smaller dpi.
use_dpi = 50
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# Show plot.
plt.show() if show_plot is True else None
@@ -242,7 +274,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_values=False, points_round=3,
use_xlabel=None, use_ylabel=None,
style_sheet='ggplot', use_grid=True, use_linestyles=None, width=3, height=1, magnify=1.2,
style_sheet='ggplot', use_grid=True, use_linestyles=None, font_size=None, width=3, height=1, magnify=1.2,
use_dpi=50, path=None, show_plot=True):
r"""
Create plot from a single array of values.
@@ -270,7 +302,7 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
inside the function.
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
Round decimal values for points values. This argument is optional and it has a default value attributed
Round decimal valus for points values. This argument is optional and it has a default value attributed
inside the function.
use_xlabel (:obj:`str`, `optional`):
@@ -293,6 +325,12 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
Style to use on line from ['-', '--', '-.', ':']. This argument is optional and it has a default
value attributed inside the function.
font_size (:obj:`int` or `float`, `optional`):
Font size to use across the plot. By default this function will adjust font size depending on `magnify`
value. If this value is set, it will ignore the `magnify` recommended font size. The title font size is by
default `1.8` greater than font-size. This argument is optional and it will have a `None` value attributed
inside the function.
width (:obj:`int`, `optional`, defaults to :obj:`3`):
Horizontal length of plot. This argument is optional and it has a default value attributed inside
the function.
@@ -333,6 +371,8 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
ValueError: If `font_size` is not `None` and smaller or equal to 0.
"""
# Check if `dict_arrays` is the correct format.
@@ -370,14 +410,20 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
# all linestyles.
linestyles = ['-', '--', '-.', ':']
# Make sure `font_size` is set right.
if (font_size is not None) and (font_size <= 0):
# Raise value error - is not correct.
raise ValueError(f'`font_size` needs to be positive number! Invalid value {font_size}')
# Font size select custom or adjusted on `magnify` value.
font_size = font_size if font_size is not None else np.interp(magnify, MAGNIFY_INTERVALS, FONT_RANGE)
# Font variables dictionary. Keep it in this format for future updates.
font_dict = dict(
family='DejaVu Sans',
color='black',
weight='normal',
# [0.1, 1] - magnify intervals where font size matters
# [10.5, 50] - min and max appropriate font sizes
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
size=font_size,
)
# If single style value is passed, use it on all arrays.
@@ -459,8 +505,20 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
# Set the new figure size with magnify.
fig.set_size_inches(figsize)
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# There is an error when DPI and plot size are too large!
try:
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
except ValueError:
# Deprecation warning from last time.
warnings.warn(f'`magnify={magnify // 15}` is to big in combination'
f' with `use_dpi={use_dpi}`! Try using lower values for'
f' `magnify` and/or `use_dpi`. Image was saved in {path}'
f' with `use_dpi=50 and `magnify={magnify // 15}`!', Warning)
# Set DPI to smaller value and warn user to use smaller magnify or smaller dpi.
use_dpi = 50
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# Show plot.
plt.show() if show_plot is True else None
@@ -469,7 +527,8 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=False, style_sheet='ggplot',
cmap=plt.cm.Blues, verbose=0, width=3, height=1, magnify=1.2, use_dpi=50, path=None,
cmap=plt.cm.Blues, font_size=None, verbose=0, width=3, height=1, magnify=1.2, use_dpi=50,
path=None,
show_plot=True, **kwargs):
r"""
This function prints and plots the confusion matrix.
@@ -504,6 +563,12 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
It is a plt.cm plot theme. Plot themes: `plt.cm.Blues`, `plt.cm.BuPu`, `plt.cm.GnBu`, `plt.cm.Greens`,
`plt.cm.OrRd`. This argument is optional and it has a default value attributed inside the function.
font_size (:obj:`int` or `float`, `optional`):
Font size to use across the plot. By default this function will adjust font size depending on `magnify`
value. If this value is set, it will ignore the `magnify` recommended font size. The title font size is by
default `1.8` greater than font-size. This argument is optional and it will have a `None` value attributed
inside the function.
verbose (:obj:`int`, `optional`, defaults to :obj:`0`):
To display confusion matrix value or not if set > 0. This argument is optional and it has a default
value attributed inside the function.
@@ -558,6 +623,8 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
ValueError: If `font_size` is not `None` and smaller or equal to 0.
"""
# Handle deprecation warnings if `title` is used.
@@ -599,14 +666,20 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
# Convert to regular value 0.1.
magnify = 0.1
# Make sure `font_size` is set right.
if (font_size is not None) and (font_size <= 0):
# Raise value error - is not correct.
raise ValueError(f'`font_size` needs to be positive number! Invalid value {font_size}')
# Font size select custom or adjusted on `magnify` value.
font_size = font_size if font_size is not None else np.interp(magnify, MAGNIFY_INTERVALS, FONT_RANGE)
# Font variables dictionary. Keep it in this format for future updates.
font_dict = dict(
family='DejaVu Sans',
color='black',
weight='normal',
# [0.1, 1] - magnify intervals where font size matters
# [10.5, 50] - min and max appropriate font sizes
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
size=font_size,
)
# Class labels setup. If none, generate from y_true y_pred.
@@ -693,8 +766,20 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
# Set the new figure size with magnify.
fig.set_size_inches(figsize)
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# There is an error when DPI and plot size are too large!
try:
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
except ValueError:
# Deprecation warning from last time.
warnings.warn(f'`magnify={magnify // 15}` is to big in combination'
f' with `use_dpi={use_dpi}`! Try using lower values for'
f' `magnify` and/or `use_dpi`. Image was saved in {path}'
f' with `use_dpi=50 and `magnify={magnify // 15}`!', Warning)
# Set DPI to smaller value and warn user to use smaller magnify or smaller dpi.
use_dpi = 50
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# Show plot.
plt.show() if show_plot is True else None