mirror of
https://github.com/gmihaila/ml_things.git
synced 2021-10-04 01:29:04 +03:00
updated for plot functions! Fixed error large dpi and added custom font_size
This commit is contained in:
@@ -119,7 +119,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
|
||||

|
||||
|
||||
|
||||
#### plot_dict [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L243)
|
||||
#### plot_dict [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L275)
|
||||
|
||||
Create plot from a single array of values.
|
||||
|
||||
@@ -135,7 +135,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
|
||||

|
||||
|
||||
|
||||
#### plot_confusion_matrix [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L471)
|
||||
#### plot_confusion_matrix [[source]](https://github.com/gmihaila/ml_things/blob/master/src/ml_things/plot_functions.py#L529)
|
||||
|
||||
This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`.
|
||||
|
||||
|
||||
@@ -19,6 +19,12 @@ import numpy as np
|
||||
import warnings
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
# Magnify intervals where font size matters
|
||||
MAGNIFY_INTERVALS = [0.1, 1]
|
||||
|
||||
# Min and max appropriate font sizes
|
||||
FONT_RANGE = [10.5, 50]
|
||||
|
||||
# Maximum allowed magnify. This will get multiplied by 0 - 1 value.
|
||||
MAX_MAGNIFY = 15
|
||||
|
||||
@@ -28,8 +34,8 @@ TITLE_FONT_RATIO = 1.8
|
||||
|
||||
def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None, points_values=False, points_round=3,
|
||||
use_xlabel=None,
|
||||
use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True, use_linestyle='-', width=3,
|
||||
height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
|
||||
use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True, use_linestyle='-',
|
||||
font_size=None, width=3, height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
|
||||
r"""
|
||||
Create plot from a single array of values.
|
||||
|
||||
@@ -60,7 +66,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
inside the function.
|
||||
|
||||
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Round decimal values for points values. This argument is optional and it has a default value attributed
|
||||
Round decimal valus for points values. This argument is optional and it has a default value attributed
|
||||
inside the function.
|
||||
|
||||
use_xlabel (:obj:`str`, `optional`):
|
||||
@@ -87,6 +93,12 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
Style to use on line from ['-', '--', '-.', ':']. This argument is optional and it has a default
|
||||
value attributed inside the function.
|
||||
|
||||
font_size (:obj:`int` or `float`, `optional`):
|
||||
Font size to use across the plot. By default this function will adjust font size depending on `magnify`
|
||||
value. If this value is set, it will ignore the `magnify` recommended font size. The title font size is by
|
||||
default `1.8` greater than font-size. This argument is optional and it will have a `None` value attributed
|
||||
inside the function.
|
||||
|
||||
width (:obj:`int`, `optional`, defaults to :obj:`3`):
|
||||
Horizontal length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
@@ -121,6 +133,8 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
|
||||
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
|
||||
|
||||
ValueError: If `font_size` is not `None` and smaller or equal to 0.
|
||||
|
||||
"""
|
||||
|
||||
# Check if `array` is in correct format.
|
||||
@@ -147,14 +161,20 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
# All allowed linestyles.
|
||||
linestyles = ['-', '--', '-.', ':']
|
||||
|
||||
# Make sure `font_size` is set right.
|
||||
if (font_size is not None) and (font_size <= 0):
|
||||
# Raise value error - is not correct.
|
||||
raise ValueError(f'`font_size` needs to be positive number! Invalid value {font_size}')
|
||||
|
||||
# Font size select custom or adjusted on `magnify` value.
|
||||
font_size = font_size if font_size is not None else np.interp(magnify, MAGNIFY_INTERVALS, FONT_RANGE)
|
||||
|
||||
# Font variables dictionary. Keep it in this format for future updates.
|
||||
font_dict = dict(
|
||||
family='DejaVu Sans',
|
||||
color='black',
|
||||
weight='normal',
|
||||
# [0.1, 1] - magnify intervals where font size matters
|
||||
# [10.5, 50] - min and max appropriate font sizes
|
||||
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
|
||||
size=font_size,
|
||||
)
|
||||
|
||||
# Check if linestyle is set right.
|
||||
@@ -231,8 +251,20 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
# Set the new figure size.
|
||||
fig.set_size_inches(figsize)
|
||||
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
# There is an error when DPI and plot size are too large!
|
||||
try:
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
except ValueError:
|
||||
# Deprecation warning from last time.
|
||||
warnings.warn(f'`magnify={magnify // 15}` is to big in combination'
|
||||
f' with `use_dpi={use_dpi}`! Try using lower values for'
|
||||
f' `magnify` and/or `use_dpi`. Image was saved in {path}'
|
||||
f' with `use_dpi=50 and `magnify={magnify // 15}`!', Warning)
|
||||
# Set DPI to smaller value and warn user to use smaller magnify or smaller dpi.
|
||||
use_dpi = 50
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
|
||||
# Show plot.
|
||||
plt.show() if show_plot is True else None
|
||||
@@ -242,7 +274,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
|
||||
|
||||
def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_values=False, points_round=3,
|
||||
use_xlabel=None, use_ylabel=None,
|
||||
style_sheet='ggplot', use_grid=True, use_linestyles=None, width=3, height=1, magnify=1.2,
|
||||
style_sheet='ggplot', use_grid=True, use_linestyles=None, font_size=None, width=3, height=1, magnify=1.2,
|
||||
use_dpi=50, path=None, show_plot=True):
|
||||
r"""
|
||||
Create plot from a single array of values.
|
||||
@@ -270,7 +302,7 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
inside the function.
|
||||
|
||||
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Round decimal values for points values. This argument is optional and it has a default value attributed
|
||||
Round decimal valus for points values. This argument is optional and it has a default value attributed
|
||||
inside the function.
|
||||
|
||||
use_xlabel (:obj:`str`, `optional`):
|
||||
@@ -293,6 +325,12 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
Style to use on line from ['-', '--', '-.', ':']. This argument is optional and it has a default
|
||||
value attributed inside the function.
|
||||
|
||||
font_size (:obj:`int` or `float`, `optional`):
|
||||
Font size to use across the plot. By default this function will adjust font size depending on `magnify`
|
||||
value. If this value is set, it will ignore the `magnify` recommended font size. The title font size is by
|
||||
default `1.8` greater than font-size. This argument is optional and it will have a `None` value attributed
|
||||
inside the function.
|
||||
|
||||
width (:obj:`int`, `optional`, defaults to :obj:`3`):
|
||||
Horizontal length of plot. This argument is optional and it has a default value attributed inside
|
||||
the function.
|
||||
@@ -333,6 +371,8 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
|
||||
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
|
||||
|
||||
ValueError: If `font_size` is not `None` and smaller or equal to 0.
|
||||
|
||||
"""
|
||||
|
||||
# Check if `dict_arrays` is the correct format.
|
||||
@@ -370,14 +410,20 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
# all linestyles.
|
||||
linestyles = ['-', '--', '-.', ':']
|
||||
|
||||
# Make sure `font_size` is set right.
|
||||
if (font_size is not None) and (font_size <= 0):
|
||||
# Raise value error - is not correct.
|
||||
raise ValueError(f'`font_size` needs to be positive number! Invalid value {font_size}')
|
||||
|
||||
# Font size select custom or adjusted on `magnify` value.
|
||||
font_size = font_size if font_size is not None else np.interp(magnify, MAGNIFY_INTERVALS, FONT_RANGE)
|
||||
|
||||
# Font variables dictionary. Keep it in this format for future updates.
|
||||
font_dict = dict(
|
||||
family='DejaVu Sans',
|
||||
color='black',
|
||||
weight='normal',
|
||||
# [0.1, 1] - magnify intervals where font size matters
|
||||
# [10.5, 50] - min and max appropriate font sizes
|
||||
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
|
||||
size=font_size,
|
||||
)
|
||||
|
||||
# If single style value is passed, use it on all arrays.
|
||||
@@ -459,8 +505,20 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
# Set the new figure size with magnify.
|
||||
fig.set_size_inches(figsize)
|
||||
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
# There is an error when DPI and plot size are too large!
|
||||
try:
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
except ValueError:
|
||||
# Deprecation warning from last time.
|
||||
warnings.warn(f'`magnify={magnify // 15}` is to big in combination'
|
||||
f' with `use_dpi={use_dpi}`! Try using lower values for'
|
||||
f' `magnify` and/or `use_dpi`. Image was saved in {path}'
|
||||
f' with `use_dpi=50 and `magnify={magnify // 15}`!', Warning)
|
||||
# Set DPI to smaller value and warn user to use smaller magnify or smaller dpi.
|
||||
use_dpi = 50
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
|
||||
# Show plot.
|
||||
plt.show() if show_plot is True else None
|
||||
@@ -469,7 +527,8 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=False, style_sheet='ggplot',
|
||||
cmap=plt.cm.Blues, verbose=0, width=3, height=1, magnify=1.2, use_dpi=50, path=None,
|
||||
cmap=plt.cm.Blues, font_size=None, verbose=0, width=3, height=1, magnify=1.2, use_dpi=50,
|
||||
path=None,
|
||||
show_plot=True, **kwargs):
|
||||
r"""
|
||||
This function prints and plots the confusion matrix.
|
||||
@@ -504,6 +563,12 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
It is a plt.cm plot theme. Plot themes: `plt.cm.Blues`, `plt.cm.BuPu`, `plt.cm.GnBu`, `plt.cm.Greens`,
|
||||
`plt.cm.OrRd`. This argument is optional and it has a default value attributed inside the function.
|
||||
|
||||
font_size (:obj:`int` or `float`, `optional`):
|
||||
Font size to use across the plot. By default this function will adjust font size depending on `magnify`
|
||||
value. If this value is set, it will ignore the `magnify` recommended font size. The title font size is by
|
||||
default `1.8` greater than font-size. This argument is optional and it will have a `None` value attributed
|
||||
inside the function.
|
||||
|
||||
verbose (:obj:`int`, `optional`, defaults to :obj:`0`):
|
||||
To display confusion matrix value or not if set > 0. This argument is optional and it has a default
|
||||
value attributed inside the function.
|
||||
@@ -558,6 +623,8 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
|
||||
DeprecationWarning: If `magnify` is se to values that don't belong to [0, 1] values.
|
||||
|
||||
ValueError: If `font_size` is not `None` and smaller or equal to 0.
|
||||
|
||||
"""
|
||||
|
||||
# Handle deprecation warnings if `title` is used.
|
||||
@@ -599,14 +666,20 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
# Convert to regular value 0.1.
|
||||
magnify = 0.1
|
||||
|
||||
# Make sure `font_size` is set right.
|
||||
if (font_size is not None) and (font_size <= 0):
|
||||
# Raise value error - is not correct.
|
||||
raise ValueError(f'`font_size` needs to be positive number! Invalid value {font_size}')
|
||||
|
||||
# Font size select custom or adjusted on `magnify` value.
|
||||
font_size = font_size if font_size is not None else np.interp(magnify, MAGNIFY_INTERVALS, FONT_RANGE)
|
||||
|
||||
# Font variables dictionary. Keep it in this format for future updates.
|
||||
font_dict = dict(
|
||||
family='DejaVu Sans',
|
||||
color='black',
|
||||
weight='normal',
|
||||
# [0.1, 1] - magnify intervals where font size matters
|
||||
# [10.5, 50] - min and max appropriate font sizes
|
||||
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
|
||||
size=font_size,
|
||||
)
|
||||
|
||||
# Class labels setup. If none, generate from y_true y_pred.
|
||||
@@ -693,8 +766,20 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
|
||||
# Set the new figure size with magnify.
|
||||
fig.set_size_inches(figsize)
|
||||
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
# There is an error when DPI and plot size are too large!
|
||||
try:
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
except ValueError:
|
||||
# Deprecation warning from last time.
|
||||
warnings.warn(f'`magnify={magnify // 15}` is to big in combination'
|
||||
f' with `use_dpi={use_dpi}`! Try using lower values for'
|
||||
f' `magnify` and/or `use_dpi`. Image was saved in {path}'
|
||||
f' with `use_dpi=50 and `magnify={magnify // 15}`!', Warning)
|
||||
# Set DPI to smaller value and warn user to use smaller magnify or smaller dpi.
|
||||
use_dpi = 50
|
||||
# Save figure to image if path is set.
|
||||
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
|
||||
|
||||
# Show plot.
|
||||
plt.show() if show_plot is True else None
|
||||
|
||||
Reference in New Issue
Block a user