plot_dict change what points to have labels

This commit is contained in:
georgemihaila
2020-10-16 11:45:29 -05:00
parent 774283712e
commit 2d80996138

View File

@@ -228,6 +228,13 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
# raise error
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
# Check `points_value` type - it can be bool or list(bool)
if isinstance(points_values, bool):
# convert to list.
points_values = [points_values] * len(dict_arrays)
elif isinstance(points_values, list) and (len(points_values) != len(dict_arrays)):
raise ValueError('`points_values` of type `list` must have same length as dictionary!')
# single plot figure
plt.subplot(1, 2, 1)
for index, (use_label, array) in enumerate(dict_arrays.items()):
@@ -240,7 +247,7 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
# plot array as a single line
plt.plot(steps, array, linestyle=use_linestyles[index], label=use_label)
# Plots points values
if points_values:
if points_values[index]:
# Loop through each point and plot the label.
for x, y in zip(steps, array):
# Add text label to plo.