adjustable magnify of font for plots

This commit is contained in:
georgemihaila
2020-10-29 14:41:44 -05:00
parent 4df9d407ed
commit d84ae78218
5 changed files with 288 additions and 122 deletions

View File

@@ -112,7 +112,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
```python
>>> from ml_things import plot_array
>>> plot_array([1,3,5,3,7,5,8,10], path='plot_array.png', magnify=0.5, use_title='A Random Plot', start_step=0.3, step_size=0.1, points_values=True)
>>> plot_array([1,3,5,3,7,5,8,10], path='plot_array.png', magnify=0.1, use_title='A Random Plot', start_step=0.3, step_size=0.1, points_values=True, use_ylabel='Thid', use_xlabel='This')
```
![plot_array](https://github.com/gmihaila/ml_things/raw/master/tests/test_samples/plot_array.png)
@@ -127,8 +127,8 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
```python
>>> from ml_things import plot_dict
>>> plot_dict({'train_acc':[1,3,5,3,7,5,8,10],
'valid_acc':[4,8,9]}, use_linestyles=['-', '--'], magnify=0.5,
start_step=0.3, step_size=0.1,path='plot_dict.png', points_values=[True, False])
'valid_acc':[4,8,9]}, use_linestyles=['-', '--'], magnify=0.1,
start_step=0.3, step_size=0.1,path='plot_dict.png', points_values=[True, False], use_title='Title')
```
![plot_dict](https://github.com/gmihaila/ml_things/raw/efb2574a9935c6a6ef62135efba2d965b2044175/tests/test_samples/plot_dict.png)
@@ -142,7 +142,7 @@ All arguments are optimized for quick plots. Change the `magnify` arguments to v
```python
>>> from ml_things import plot_confusion_matrix
>>> plot_confusion_matrix(y_true=[1,0,1,1,0,1], y_pred=[0,1,1,1,0,1], magnify=0.5, use_title='My Confusion Matrix', path='plot_confusion_matrix.png');
>>> plot_confusion_matrix(y_true=[1,0,1,1,0,1], y_pred=[0,1,1,1,0,1], magnify=0.1, use_title='My Confusion Matrix', path='plot_confusion_matrix.png');
Confusion matrix, without normalization
array([[1, 1],
[1, 3]])

View File

@@ -19,10 +19,17 @@ import numpy as np
import warnings
from sklearn.metrics import confusion_matrix
# Maximum allowed magnify. The user's `magnify` ratio in (0, 1] is multiplied by this value.
MAX_MAGNIFY = 15
# Ratio by which the title font size is increased relative to the base font size.
TITLE_FONT_RATIO = 1.8
def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None, points_values=False, points_round=3,
use_xlabel=None, use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True,
use_linestyle='-', width=3, height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
use_xlabel=None,
use_xticks=True, use_ylabel=None, style_sheet='ggplot', use_grid=True, use_linestyle='-', width=3,
height=1, magnify=1.2, use_dpi=50, path=None, show_plot=True):
r"""
Create plot from a single array of values.
@@ -36,6 +43,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
the function.
step_size (:obj:`int`, `optional`, defaults to :obj:`1`):
What is the step increase of each point on the x axis.
This argument is optional and it has a default value attributed inside the function. Each x-axis position
value is multiplied by it.
@@ -52,7 +60,7 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
inside the function.
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
Round decimal valus for points values. This argument is optional and it has a default value attributed
Round decimal values for points values. This argument is optional and it has a default value attributed
inside the function.
use_xlabel (:obj:`str`, `optional`):
@@ -60,8 +68,8 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
inside the function.
use_xticks (:obj:`bool`, `optional`, defaults to :obj:`True`):
Display x-axis tick values. This argument is optional and it has a default value attributed
inside the function.
Display x-axis tick values (the values at each point). This argument is optional and it has a default
value attributed inside the function.
use_ylabel (:obj:`str`, `optional`):
Label to use for y-axis value meaning. This argument is optional and it will have a `None` value attributed
@@ -84,22 +92,22 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
the function.
height (:obj:`int`, `optional`, defaults to :obj:`1`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
Height length of plot in inches. This argument is optional and it has a default value attributed inside
the function.
magnify (:obj:`int`, `optional`, defaults to :obj:`1.2`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
the function.
magnify (:obj:`float`, `optional`, defaults to :obj:`0.1`):
Ratio increase of both width and height keeping the same ratio size. This argument is optional and it has a
default value attributed inside the function.
use_dpi (:obj:`int`, `optional`, defaults to :obj:`50`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
the function.
Print resolution is measured in dots per inch (or “DPI”). This argument is optional and it has a default
value attributed inside the function.
path (:obj:`str`, `optional`):
Vertical length of plot. This argument is optional and it will have a None value attributed inside
the function.
Path and file name of plot saved as image. If you want to save in the current path, just pass in the file name.
This argument is optional and it will have a None value attributed inside the function.
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to call `plt.show()` or not (e.g. skip it when running on a headless server). This argument is optional and
it has a default value attributed inside the function.
@@ -111,70 +119,122 @@ def plot_array(array, start_step=0, step_size=1, use_label=None, use_title=None,
ValueError: If `use_linestyle` is not valid.
DeprecationWarning: If `magnify` is set to a value outside the [0, 1] range.
"""
# check if `array` is correct format
# Check if `array` is in correct format.
if not isinstance(array, list) or isinstance(array, np.ndarray):
# raise value error
# raise value error.
raise ValueError("`array` needs to be a list of values!")
# Check if `style_sheet` has correct value.
if style_sheet in plt.style.available:
# set style of plot
# Set style of plot
plt.style.use(style_sheet)
else:
# style is not correct
# Style is not correct.
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
str(plt.style.available)))
# all linestyles
# Make sure `magnify` is in right range.
if magnify > 1 or magnify <= 0:
# Deprecation warning from last time.
warnings.warn(f'`magnify` needs to have value in [0,1]! `{magnify}` will be converted to `0.1` as default.',
DeprecationWarning)
# Convert to regular value 0.1.
magnify = 0.1
# All allowed linestyles.
linestyles = ['-', '--', '-.', ':']
# check if linestyle is set right
# Font variables dictionary. Keep it in this format for future updates.
font_dict = dict(
family='DejaVu Sans',
color='black',
weight='normal',
# [0.1, 1] - magnify intervals where font size matters
# [10.5, 50] - min and max appropriate font sizes
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
)
# Check if linestyle is set right.
if use_linestyle not in linestyles:
# raise error
# Raise error.
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
# set steps plotted on x-axis - we can use step if 1 unit has different value
# Set steps plotted on x-axis - we can use step if 1 unit has different value.
if start_step > 0:
# Offset all steps by start_step.
steps = np.array(range(0, len(array))) * step_size + start_step
else:
# Keep steps the same.
steps = np.array(range(1, len(array) + 1)) * step_size
# single plot figure
# Single plot figure.
plt.subplot(1, 2, 1)
# plot array as a single line
# Plot array as a single line
plt.plot(steps, array, linestyle=use_linestyle, label=use_label)
# Plots points values
# Plots points values.
if points_values:
# Loop through each point and plot the label.
for x, y in zip(steps, array):
# Add text label to plo.
plt.text(x, y, str(round(y, points_round)))
# set title of figure
plt.title(use_title)
# set horizontal axis name
plt.xlabel(use_xlabel)
# Add text label to plot.
plt.text(x, y, str(round(y, points_round)), fontdict=font_dict)
# Set horizontal axis name.
plt.xlabel(use_xlabel, fontdict=font_dict)
# Use x ticks with steps.
plt.xticks(steps) if use_xticks else None
# set vertical axis name
plt.ylabel(use_ylabel)
# place legend best position
plt.legend(loc='best') if use_label is not None else None
# display grid depending on `use_grid`
# Set vertical axis name.
plt.ylabel(use_ylabel, fontdict=font_dict)
# Place legend best position.
plt.legend(loc='best', fontsize=font_dict['size']) if use_label is not None else None
# Custom label font size for x axis. This is for future updates if needed.
# plt.tick_params(axis="x", labelsize=font['size']//2)
# Custom label font size for y axis. This is for future updates if needed.
# plt.tick_params(axis="y", labelsize=font['size']//2)
# Adjust both axis labels font size at same time.
plt.tick_params(labelsize=font_dict['size'])
# Display grid depending on `use_grid`.
plt.grid(use_grid)
# make figure nice
# Make figure nice.
plt.tight_layout()
# get figure object from plot
# Adjust font for title.
font_dict['size'] *= TITLE_FONT_RATIO
# Set title of figure.
plt.title(use_title, fontdict=font_dict)
# Rescale `magnify` to be used on inches.
magnify *= MAX_MAGNIFY
# Get figure object from plot.
fig = plt.gcf()
# get size of figure
# Get size of figure.
figsize = fig.get_size_inches()
# change size depending on height and width variables
# Change size depending on height and width variables.
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
# set the new figure size
# Set the new figure size.
fig.set_size_inches(figsize)
# save figure to image if path is set
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# show plot
# Show plot.
plt.show() if show_plot is True else None
return
@@ -210,7 +270,7 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
inside the function.
points_round (:obj:`int`, `optional`, defaults to :obj:`1`):
Round decimal valus for points values. This argument is optional and it has a default value attributed
Round decimal values for points values. This argument is optional and it has a default value attributed
inside the function.
use_xlabel (:obj:`str`, `optional`):
@@ -238,22 +298,22 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
the function.
height (:obj:`int`, `optional`, defaults to :obj:`1`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
Height length of plot in inches. This argument is optional and it has a default value attributed inside
the function.
magnify (:obj:`int`, `optional`, defaults to :obj:`1.2`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
the function.
magnify (:obj:`float`, `optional`, defaults to :obj:`0.1`):
Ratio increase of both width and height keeping the same ratio size. This argument is optional and it has a
default value attributed inside the function.
use_dpi (:obj:`int`, `optional`, defaults to :obj:`50`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
the function.
Print resolution is measured in dots per inch (or “DPI”). This argument is optional and it has a default
value attributed inside the function.
path (:obj:`str`, `optional`):
Vertical length of plot. This argument is optional and it will have a None value attributed inside
the function.
Path and file name of plot saved as image. If you want to save in the current path, just pass in the file name.
This argument is optional and it will have a None value attributed inside the function.
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to call `plt.show()` or not (e.g. skip it when running on a headless server). This argument is optional and
it has a default value attributed inside the function.
@@ -271,87 +331,138 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
ValueError: If `points_values` of type list doesn't have the same length as `dict_arrays`.
DeprecationWarning: If `magnify` is set to a value outside the [0, 1] range.
"""
# check if `dict_arrays` is correct format
# Check if `dict_arrays` is the correct format.
if not isinstance(dict_arrays, dict):
# raise value error
# Raise value error.
raise ValueError("`dict_arrays` needs to be a dictionary of values!")
# Check each label
for label, array in dict_arrays.items():
# check if format is correct
# Check if format is correct.
if not isinstance(label, str):
# raise value error
# Raise value error.
raise ValueError("`dict_arrays` needs string keys!")
if not isinstance(array, list) or isinstance(array, np.ndarray):
# raise value error
# Raise value error.
raise ValueError("`dict_arrays` needs lists values!")
# make sure style sheet is correct
# Make sure style sheet is correct.
if style_sheet in plt.style.available:
# set style of plot
# Set style of plot
plt.style.use(style_sheet)
else:
# style is not correct
# Style is not correct.
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
str(plt.style.available)))
# all linestyles
# Make sure `magnify` is in right range.
if magnify > 1 or magnify <= 0:
# Deprecation warning from last time.
warnings.warn(f'`magnify` needs to have value in [0,1]! `{magnify}` will be converted to `0.1` as default.',
DeprecationWarning)
# Convert to regular value 0.1.
magnify = 0.1
# All allowed linestyles.
linestyles = ['-', '--', '-.', ':']
# Font variables dictionary. Keep it in this format for future updates.
font_dict = dict(
family='DejaVu Sans',
color='black',
weight='normal',
# [0.1, 1] - magnify intervals where font size matters
# [10.5, 50] - min and max appropriate font sizes
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
)
# If no linestyles are passed, use the solid line style for all arrays.
if use_linestyles is None:
use_linestyles = ['-'] * len(dict_arrays)
else:
# check if linestyle is set right
# Check if linestyle is set right.
for use_linestyle in use_linestyles:
if use_linestyle not in linestyles:
# raise error
# Raise error.
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
# Check `points_value` type - it can be bool or list(bool)
# Check `points_value` type - it can be bool or list(bool).
if isinstance(points_values, bool):
# convert to list.
# Convert to list.
points_values = [points_values] * len(dict_arrays)
elif isinstance(points_values, list) and (len(points_values) != len(dict_arrays)):
raise ValueError('`points_values` of type `list` must have same length as dictionary!')
# single plot figure
# Single plot figure.
plt.subplot(1, 2, 1)
# Plot each array.
for index, (use_label, array) in enumerate(dict_arrays.items()):
# set steps plotted on x-axis - we can use step if 1 unit has different value
# Set steps plotted on x-axis - we can use step if 1 unit has different value.
if start_step > 0:
# Offset all steps by start_step.
steps = np.array(range(0, len(array))) * step_size + start_step
else:
steps = np.array(range(1, len(array) + 1)) * step_size
# plot array as a single line
# Plot array as a single line.
plt.plot(steps, array, linestyle=use_linestyles[index], label=use_label)
# Plots points values
# Plots points values.
if points_values[index]:
# Loop through each point and plot the label.
for x, y in zip(steps, array):
# Add text label to plo.
plt.text(x, y, str(round(y, points_round)))
# set title of figure
plt.title(use_title)
# set horizontal axis name
plt.xlabel(use_xlabel)
# set vertical axis name
plt.ylabel(use_ylabel)
# place legend best position
plt.legend(loc='best') if use_label is not None else None
# display grid depending on `use_grid`
# Add text label to plot.
plt.text(x, y, str(round(y, points_round)), fontdict=font_dict)
# Set horizontal axis name.
plt.xlabel(use_xlabel, fontdict=font_dict)
# Set vertical axis name.
plt.ylabel(use_ylabel, fontdict=font_dict)
# Adjust both axis labels font size at same time.
plt.tick_params(labelsize=font_dict['size'])
# Place legend best position.
plt.legend(loc='best', fontsize=font_dict['size'])
# Adjust font for title.
font_dict['size'] *= TITLE_FONT_RATIO
# Set title of figure.
plt.title(use_title, fontdict=font_dict)
# Rescale `magnify` to be used on inches.
magnify *= MAX_MAGNIFY
# Display grid depending on `use_grid`.
plt.grid(use_grid)
# make figure nice
# Make figure nice.
plt.tight_layout()
# get figure object from plot
# Get figure object from plot.
fig = plt.gcf()
# get size of figure
# Get size of figure.
figsize = fig.get_size_inches()
# change size depending on height and width variables
# Change size depending on height and width variables.
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
# set the new figure size with magnify
# Set the new figure size with magnify.
fig.set_size_inches(figsize)
# save figure to image if path is set
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# show plot
# Show plot.
plt.show() if show_plot is True else None
return
@@ -402,22 +513,22 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
the function.
height (:obj:`int`, `optional`, defaults to :obj:`1`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
Height length of plot in inches. This argument is optional and it has a default value attributed inside
the function.
magnify (:obj:`int`, `optional`, defaults to :obj:`1.2`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
the function.
magnify (:obj:`float`, `optional`, defaults to :obj:`0.1`):
Ratio increase of both width and height keeping the same ratio size. This argument is optional and it has a
default value attributed inside the function.
use_dpi (:obj:`int`, `optional`, defaults to :obj:`50`):
Vertical length of plot. This argument is optional and it has a default value attributed inside
the function.
Print resolution is measured in dots per inch (or “DPI”). This argument is optional and it has a default
value attributed inside the function.
path (:obj:`str`, `optional`):
Vertical length of plot. This argument is optional and it will have a None value attributed inside
the function.
Path and file name of plot saved as image. If you want to save in the current path, just pass in the file name.
This argument is optional and it will have a None value attributed inside the function.
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
show_plot (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to call `plt.show()` or not (e.g. skip it when running on a headless server). This argument is optional and
it has a default value attributed inside the function.
@@ -445,6 +556,8 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
ValueError: If `style_sheet` is not valid.
DeprecationWarning: If `magnify` is set to a value outside the [0, 1] range.
"""
# Handle deprecation warnings if `title` is used.
@@ -452,6 +565,7 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
# assign same value
use_title = kwargs['title']
warnings.warn("`title` will be deprecated in future updates. Use `use_title` in stead!", DeprecationWarning)
# Handle deprecation warnings if `image` is used.
if 'image' in kwargs:
# assign same value
@@ -462,11 +576,13 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
# assign same value
use_dpi = kwargs['dpi']
warnings.warn("`dpi` will be deprecated in future updates. Use `use_dpi` in stead!", DeprecationWarning)
# Make sure labels have right format
# Make sure labels have right format.
if len(y_true) != len(y_pred):
# make sure lengths match
raise ValueError("`y_true` needs to have same length as `y_pred`!")
# make sure style sheet is correct
# Make sure style sheet is correct.
if style_sheet in plt.style.available:
# set style of plot
plt.style.use(style_sheet)
@@ -474,15 +590,36 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
# style is not correct
raise ValueError("`style_sheet=%s` is not in the supported styles: %s" % (str(style_sheet),
str(plt.style.available)))
# Class labels setup. If none, generate from y_true y_pred
# Make sure `magnify` is in right range.
if magnify > 1 or magnify <= 0:
# Deprecation warning from last time.
warnings.warn(f'`magnify` needs to have value in [0,1]! `{magnify}` will be converted to `0.1` as default.',
DeprecationWarning)
# Convert to regular value 0.1.
magnify = 0.1
# Font variables dictionary. Keep it in this format for future updates.
font_dict = dict(
family='DejaVu Sans',
color='black',
weight='normal',
# [0.1, 1] - magnify intervals where font size matters
# [10.5, 50] - min and max appropriate font sizes
size=np.interp(magnify, [0.1, 1], [10.5, 50]),
)
# Class labels setup. If none, generate from y_true y_pred.
classes = list(classes)
if classes:
assert len(set(y_true)) == len(classes)
else:
classes = set(y_true)
# Compute confusion matrix
# Compute confusion matrix.
cm = confusion_matrix(y_true, y_pred)
# Nromalize setup
# Normalize setup.
if normalize is True:
print("Normalized confusion matrix")
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
@@ -490,24 +627,33 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
else:
print('Confusion matrix, without normalization')
use_title = 'Confusion matrix, without normalization' if use_title is None else use_title
# Print if verbose
# Print if verbose.
print(cm) if verbose > 0 else None
# Plot setup
# Plot setup.
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.figure.colorbar(im, ax=ax)
# We want to show all ticks...
# Show all ticks.
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
# ... and label them with the respective list entries
# Label ticks with the respective list entries.
xticklabels=classes, yticklabels=classes,
title=use_title,
ylabel='True label',
xlabel='Predicted label')
)
# Set horizontal axis name.
ax.set_xlabel('Predicted label', fontdict=font_dict)
# Set vertical axis name.
ax.set_ylabel('True label', fontdict=font_dict)
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
plt.grid(False)
# Loop over data dimensions and create text annotations.
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
@@ -515,22 +661,42 @@ def plot_confusion_matrix(y_true, y_pred, use_title=None, classes='', normalize=
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
color="white" if cm[i, j] > thresh else "black", fontdict=font_dict)
# Adjust both axis labels font size at same time.
plt.tick_params(labelsize=font_dict['size'])
# Adjust font for title.
font_dict['size'] *= TITLE_FONT_RATIO
# Set title of figure.
plt.title(use_title, fontdict=font_dict)
# Rescale `magnify` to be used on inches.
magnify *= MAX_MAGNIFY
# Never display grid.
plt.grid(False)
# make figure nice
# Make figure nice.
plt.tight_layout()
# get figure object from plot
# Get figure object from plot.
fig = plt.gcf()
# get size of figure
# Get size of figure.
figsize = fig.get_size_inches()
# change size depending on height and width variables
# Change size depending on height and width variables.
figsize = [figsize[0] * width * magnify, figsize[1] * height * magnify]
# set the new figure size with magnify
# Set the new figure size with magnify.
fig.set_size_inches(figsize)
# save figure to image if path is set
# Save figure to image if path is set.
fig.savefig(path, dpi=use_dpi, bbox_inches='tight') if path is not None else None
# show plot
# Show plot.
plt.show() if show_plot is True else None
return cm

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.8 KiB

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.0 KiB

After

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 16 KiB