mirror of
https://github.com/gmihaila/ml_things.git
synced 2021-10-04 01:29:04 +03:00
plot_dict change what points to have labels
This commit is contained in:
@@ -228,6 +228,13 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
# raise error
|
||||
raise ValueError("`linestyle=%s` is not in the styles: %s!" % (str(use_linestyle), str(linestyles)))
|
||||
|
||||
# Check `points_value` type - it can be bool or list(bool)
|
||||
if isinstance(points_values, bool):
|
||||
# convert to list.
|
||||
points_values = [points_values] * len(dict_arrays)
|
||||
elif isinstance(points_values, list) and (len(points_values) != len(dict_arrays)):
|
||||
raise ValueError('`points_values` of type `list` must have same length as dictionary!')
|
||||
|
||||
# single plot figure
|
||||
plt.subplot(1, 2, 1)
|
||||
for index, (use_label, array) in enumerate(dict_arrays.items()):
|
||||
@@ -240,7 +247,7 @@ def plot_dict(dict_arrays, start_step=0, step_size=1, use_title=None, points_val
|
||||
# plot array as a single line
|
||||
plt.plot(steps, array, linestyle=use_linestyles[index], label=use_label)
|
||||
# Plots points values
|
||||
if points_values:
|
||||
if points_values[index]:
|
||||
# Loop through each point and plot the label.
|
||||
for x, y in zip(steps, array):
|
||||
# Add text label to plo.
|
||||
|
||||
Reference in New Issue
Block a user