Files
backtesting.py/backtesting/_plotting.py

600 lines
24 KiB
Python

import os
import warnings
from itertools import cycle, combinations
from functools import partial
import numpy as np
import pandas as pd
from bokeh.plotting import figure as _figure
from bokeh.models import (
CustomJS,
ColumnDataSource,
NumeralTickFormatter,
Span,
HoverTool,
Range1d,
DatetimeTickFormatter,
FuncTickFormatter,
WheelZoomTool,
LinearColorMapper,
)
from bokeh.io import output_notebook, output_file, show
from bokeh.io.state import curstate
from bokeh.layouts import gridplot
from bokeh.palettes import Category10
from bokeh.transform import factor_cmap
from backtesting._util import _data_period
# The presence of the JPY_PARENT_PID environment variable is taken as the sign
# that this module is being imported from within a Jupyter notebook session.
IS_JUPYTER_NOTEBOOK = os.environ.get('JPY_PARENT_PID') is not None

if IS_JUPYTER_NOTEBOOK:
    # Tell the user up front that Bokeh output is being redirected, and how
    # to undo it, before actually switching the output mode.
    warnings.warn(' '.join((
        'Jupyter Notebook detected.',
        'Setting Bokeh output to notebook.',
        'This may not work in Jupyter clients without JavaScript',
        'support (e.g. PyCharm, Spyder IDE).',
        'Reset with `bokeh.io.reset_output()`.',
    )))
    output_notebook()

# Load the JavaScript source used for the x-range autoscale callback from the
# file shipped next to this module.
with open(
        os.path.join(os.path.dirname(__file__), 'autoscale_cb.js'),
        encoding='utf-8',
) as _f:
    _AUTOSCALE_JS_CALLBACK = _f.read()
def _bokeh_reset(filename=None):
    """
    Reset Bokeh's current global state and re-select the output target.

    When running in a Jupyter notebook the output is (re)directed to the
    notebook and `filename` is ignored. Otherwise, if `filename` is
    non-empty, output goes to that file; an '.html' suffix is appended
    when missing and the resulting name is also used as the page title.
    """
    curstate().reset()

    # Notebook output takes precedence over file output
    if IS_JUPYTER_NOTEBOOK:
        curstate().output_notebook()
        return

    if not filename:
        return

    html_name = filename if filename.endswith('.html') else filename + '.html'
    output_file(html_name, title=html_name)
def colorgen():
    """Endlessly yield the colors of the 10-color `Category10` palette, in order."""
    palette = Category10[10]
    for color in cycle(palette):
        yield color
def plot(*, results, df, indicators, filename='', plot_width=1200,
         plot_equity=True, plot_pl=True,
         plot_volume=True, plot_drawdown=False,
         smooth_equity=False, relative_equity=True, omit_missing=True,
         superimpose=True, show_legend=True, open_browser=True):
    """
    Like much of GUI code everywhere, this is a mess.

    Build and show a column of linked Bokeh figures for a backtest run:
    optional equity, drawdown and profit/loss sections above the main
    OHLC candlestick figure, and optional volume and standalone-indicator
    sections below it.

    Parameters (all keyword-only):

    results
        Object exposing a `_trade_data` DataFrame. The columns read here are
        'Equity', 'Drawdown', 'Drawdown Duration', 'Returns', 'Entry Price',
        'Exit Price' and 'Exit Position'.
    df
        DataFrame with `Open`, `High`, `Low`, `Close` and `Volume` columns.
        A shallow copy is taken before columns are added or the index is
        changed.
    indicators
        Iterable of array-likes carrying a `name` attribute and an `_opts`
        mapping with the keys 'plot', 'color' and 'overlay'.
    filename
        Forwarded to `_bokeh_reset()` to select the output file.
    plot_width
        Width given to every Bokeh figure.
    plot_equity, plot_pl, plot_volume, plot_drawdown
        Toggle the respective sections. Volume is skipped regardless if
        every `Volume` value is null.
    smooth_equity
        If true, the equity curve is interpolated between trade entry/exit
        points, the first bar, and the max-drawdown-duration end points.
    relative_equity
        If true, equity is divided by its first value and formatted as a
        percentage; otherwise it is formatted as a dollar amount.
    omit_missing
        Only has an effect with a datetime index: bars are then plotted
        against a consecutive integer index on a linear axis, and the tick
        labels are mapped back to the stored datetimes.
    superimpose
        If truthy (and the index is a datetime index), downsampled candles
        are drawn behind the main ones. A string is used as the pandas
        resample rule; otherwise the rule is picked from the index
        resolution.
    show_legend
        Legends are placed 'top_left' when true; the location is set to
        None otherwise.
    open_browser
        When false, `show()` is called with `browser='none'`.

    Returns the grid layout produced by `gridplot()`.

    Raises ValueError if `superimpose` would upsample the data.
    """
    # We need to reset global Bokeh state, otherwise subsequent runs of
    # plot() contain some previous run's cruft data (was noticed when
    # TestPlot.test_file_size() test was failing).
    _bokeh_reset(filename)

    COLORS = ['tomato', 'lime']
    COLORS_LIGHT = ['#ffe3e3', '#e3ffe3']

    orig_trade_data = trade_data = results._trade_data.copy(False)

    orig_df = df = df.copy(False)
    df.index.name = None  # Provides source name @index
    index = df.index
    time_resolution = getattr(index, 'resolution', None)
    is_datetime_index = index.is_all_dates

    # If all Volume is NaN, don't plot volume
    plot_volume = plot_volume and not df.Volume.isnull().all()

    # OHLC vbar width in msec.
    # +1 will work in case of non-datetime index where vbar width should just be =1
    width = 1 + dict(day=86400,
                     hour=3600,
                     minute=60,
                     second=1).get(time_resolution, 0) * 1000 * .85

    if is_datetime_index:
        # Add index as a separate data source column because true .index is offset to align vbars
        df['datetime'] = index
        df.index = df.index + pd.Timedelta(width / 2, unit='ms')

        if omit_missing:
            # From here on `df` / `trade_data` are positionally indexed,
            # while `orig_df` / `orig_trade_data` keep the datetime index.
            width = .8
            df = df.reset_index(drop=True)
            trade_data = trade_data.reset_index(drop=True)
            index = df.index

    new_bokeh_figure = partial(
        _figure,
        x_axis_type='datetime' if is_datetime_index and not omit_missing else 'linear',
        plot_width=plot_width,
        plot_height=400,
        tools="xpan,xwheel_zoom,box_zoom,undo,redo,reset,crosshair,save",
        active_drag='xpan',
        active_scroll='xwheel_zoom')

    # Allow panning 1/20 of the data span past either end
    pad = (index[-1] - index[0]) / 20

    fig_ohlc = new_bokeh_figure(
        x_range=Range1d(index[0], index[-1],
                        bounds=(index[0] - pad,
                                index[-1] + pad)) if index.size > 1 else None)
    figs_above_ohlc, figs_below_ohlc = [], []

    source = ColumnDataSource(df)
    # 'inc' is '1' for bars closing at or above their open, '0' otherwise
    source.add((df.Close >= df.Open).values.astype(np.uint8).astype(str), 'inc')

    returns = trade_data['Returns'].dropna()
    trade_source = ColumnDataSource(dict(
        index=returns.index,
        datetime=orig_trade_data['Returns'].dropna().index,
        exit_price=trade_data['Exit Price'].dropna(),
        returns_pos=(returns > 0).astype(np.int8).astype(str),
    ))

    inc_cmap = factor_cmap('inc', COLORS, ['0', '1'])
    cmap = factor_cmap('returns_pos', COLORS, ['0', '1'])

    if is_datetime_index and omit_missing:
        # The x axis is a plain integer axis in this mode; translate each
        # tick position into the datetime stored for that row.
        fig_ohlc.xaxis.formatter = FuncTickFormatter(
            args=dict(axis=fig_ohlc.xaxis[0],
                      formatter=DatetimeTickFormatter(days=['%d %b', '%a %d'],
                                                      months=['%m/%Y', "%b'%y"]),
                      source=source),
            code='''
this.labels = this.labels || formatter.doFormat(ticks
.map(i => source.data.datetime[i])
.filter(t => t !== undefined));
return this.labels[index] || "";
''')

    # Four non-breaking spaces, used to separate values within one tooltip
    # row. Written as an escape so the character can't silently degrade
    # into an ordinary space.
    NBSP = '\N{NBSP}' * 4
    ohlc_extreme_values = df[['High', 'Low']].copy(False)
    ohlc_tooltips = [
        ('x, y', NBSP.join(('$index',
                            '$y{0,0.0[0000]}'))),
        ('OHLC', NBSP.join(('@Open{0,0.0[0000]}',
                            '@High{0,0.0[0000]}',
                            '@Low{0,0.0[0000]}',
                            '@Close{0,0.0[0000]}'))),
        ('Volume', '@Volume{0,0}')]

    def new_indicator_figure(**kwargs):
        """Create a short figure sharing the OHLC figure's x range, with its x axis hidden."""
        kwargs.setdefault('plot_height', 90)
        fig = new_bokeh_figure(x_range=fig_ohlc.x_range,
                               active_scroll='xwheel_zoom',
                               active_drag='xpan',
                               **kwargs)
        fig.xaxis.visible = False
        fig.yaxis.minor_tick_line_color = None
        return fig

    def set_tooltips(fig, tooltips=(), vline=True, renderers=(), show_arrow=True):
        """Attach a HoverTool to `fig`, prefixing the tooltips with a date (or row number) row."""
        tooltips = list(tooltips)
        renderers = list(renderers)

        if is_datetime_index:
            formatters = dict(datetime='datetime')
            tooltips = [("Date", "@datetime{%c}")] + tooltips
        else:
            formatters = {}
            tooltips = [("#", "@index")] + tooltips
        fig.add_tools(HoverTool(
            point_policy='follow_mouse',
            renderers=renderers, formatters=formatters, show_arrow=show_arrow,
            tooltips=tooltips, mode='vline' if vline else 'mouse'))

    def _plot_equity_section():
        """Equity section"""
        # Max DD Dur. line
        equity = trade_data['Equity']
        argmax = trade_data['Drawdown Duration'].idxmax()
        try:
            dd_start = equity[:argmax].idxmax()
        except Exception:  # ValueError: attempt to get argmax of an empty sequence
            dd_start = dd_end = equity.index[0]
            timedelta = 0
        else:
            dd_end = (equity[argmax:] > equity[dd_start]).idxmax()
            if dd_end == argmax:
                dd_end = index[-1]
            if is_datetime_index and omit_missing:
                # "Calendar" duration
                timedelta = df.datetime.iloc[dd_end] - df.datetime.iloc[dd_start]
            else:
                timedelta = dd_end - dd_start
            # Get point intersection
            if dd_end != index[-1]:
                x1, x2 = index.get_loc(dd_end) - 1, index.get_loc(dd_end)
                y, y1, y2 = equity[dd_start], equity[x1], equity[x2]
                dd_end -= (1 - (y - y1) / (y2 - y1)) * (dd_end - index[x1])  # y = a x + b

        if smooth_equity:
            select = (trade_data[['Entry Price',
                                  'Exit Price']].dropna(how='all').index |
                      # Include beginning
                      equity.index[:1] |
                      # Include max dd end points. Otherwise, the MaxDD line looks amiss.
                      pd.Index([dd_start, dd_end]))
            equity = equity[select].reindex(equity.index)
            equity.interpolate(inplace=True)

        if relative_equity:
            equity /= equity.iloc[0]

        source.add(equity, 'equity')
        fig = new_indicator_figure(
            y_axis_label="Equity",
            **({} if plot_drawdown else dict(plot_height=110)))

        # High-watermark drawdown dents
        fig.patch('index', 'equity_dd',
                  source=ColumnDataSource(dict(
                      index=np.r_[index, index[::-1]],
                      equity_dd=np.r_[equity, equity.cummax()[::-1]]
                  )),
                  fill_color='#ffffea', line_color='#ffcb66')

        # Equity line
        r = fig.line('index', 'equity', source=source, line_width=1.5, line_alpha=1)
        if relative_equity:
            tooltip_format = '@equity{+0,0.[000]%}'
            tick_format = '0,0.[00]%'
            legend_format = '{:,.0f}%'
        else:
            tooltip_format = '@equity{$ 0,0}'
            tick_format = '$ 0.0 a'
            legend_format = '${:,.0f}'
        set_tooltips(fig, [('Equity', tooltip_format)], renderers=[r])
        fig.yaxis.formatter = NumeralTickFormatter(format=tick_format)

        # Peaks
        argmax = equity.idxmax()
        fig.scatter(argmax, equity[argmax],
                    legend='Peak ({})'.format(
                        legend_format.format(equity[argmax] * (100 if relative_equity else 1))),
                    color='cyan', size=8)
        fig.scatter(index[-1], equity.values[-1],
                    legend='Final ({})'.format(
                        legend_format.format(equity.iloc[-1] * (100 if relative_equity else 1))),
                    color='blue', size=8)

        if not plot_drawdown:
            # No separate drawdown section: mark the max drawdown point here instead
            drawdown = trade_data['Drawdown']
            argmax = drawdown.idxmax()
            fig.scatter(argmax, equity[argmax],
                        legend='Max Drawdown (-{:.1f}%)'.format(100 * drawdown[argmax]),
                        color='red', size=8)
        fig.line([dd_start, dd_end], equity[dd_start],
                 line_color='red', line_width=2,
                 legend='Max Dd Dur. ({})'.format(timedelta)
                 .replace(' 00:00:00', '')
                 .replace('(0 days ', '('))

        figs_above_ohlc.append(fig)

    def _plot_drawdown_section():
        """Drawdown section"""
        fig = new_indicator_figure(y_axis_label="Drawdown")
        drawdown = trade_data['Drawdown']
        argmax = drawdown.idxmax()
        source.add(drawdown, 'drawdown')
        r = fig.line('index', 'drawdown', source=source, line_width=1.3)
        fig.scatter(argmax, drawdown[argmax],
                    legend='Peak (-{:.1f}%)'.format(100 * drawdown[argmax]),
                    color='red', size=8)
        set_tooltips(fig, [('Drawdown', '@drawdown{-0.[0]%}')], renderers=[r])
        fig.yaxis.formatter = NumeralTickFormatter(format="-0.[0]%")
        return fig

    def _plot_pl_section():
        """Profit/Loss markers section"""
        fig = new_indicator_figure(y_axis_label="Profit / Loss")
        fig.add_layout(Span(location=0, dimension='width', line_color='#666666',
                            line_dash='dashed', line_width=1))
        position = trade_data['Exit Position'].dropna()
        # Split returns by the sign of the exit position so that each group
        # gets its own marker shape.
        returns_long = returns.copy()
        returns_short = returns.copy()
        returns_long[position < 0] = np.nan
        returns_short[position > 0] = np.nan
        trade_source.add(returns_long, 'returns_long')
        trade_source.add(returns_short, 'returns_short')
        MARKER_SIZE = 13
        r1 = fig.scatter('index', 'returns_long', source=trade_source, fill_color=cmap,
                         marker='triangle', line_color='black', size=MARKER_SIZE)
        r2 = fig.scatter('index', 'returns_short', source=trade_source, fill_color=cmap,
                         marker='inverted_triangle', line_color='black', size=MARKER_SIZE)
        set_tooltips(fig, [("P/L", "@returns_long{+0.[000]%}")], vline=False, renderers=[r1])
        set_tooltips(fig, [("P/L", "@returns_short{+0.[000]%}")], vline=False, renderers=[r2])
        fig.yaxis.formatter = NumeralTickFormatter(format="0.[00]%")
        return fig

    def _plot_volume_section():
        """Volume section"""
        fig = new_indicator_figure(y_axis_label="Volume")
        fig.xaxis.formatter = fig_ohlc.xaxis[0].formatter
        fig.xaxis.visible = True
        fig_ohlc.xaxis.visible = False  # Show only Volume's xaxis
        r = fig.vbar('index', width, 'Volume', source=source, color=inc_cmap)
        set_tooltips(fig, [('Volume', '@Volume{0.00 a}')], renderers=[r])
        fig.yaxis.formatter = NumeralTickFormatter(format="0 a")
        return fig

    def _plot_superimposed_ohlc():
        """Superimposed, downsampled vbars"""
        resample_rule = (superimpose if isinstance(superimpose, str) else
                         dict(day='W',
                              hour='D',
                              minute='H',
                              second='T',
                              millisecond='S').get(time_resolution))
        if not resample_rule:
            warnings.warn(
                "Can't superimpose OHLC data with rule '{}' (index datetime resolution: '{}'). "
                "Skipping.".format(resample_rule, time_resolution),
                stacklevel=4)
            return

        # Helper column: after aggregation with 'count' it holds the number
        # of original bars falling into each resampled bar.
        orig_df['_width'] = 1
        from .lib import OHLCV_AGG
        df2 = orig_df.resample(resample_rule, label='left').agg(dict(OHLCV_AGG, _width='count'))

        # Check if resampling was downsampling; error on upsampling
        orig_freq = _data_period(orig_df)
        resample_freq = _data_period(df2)
        if resample_freq < orig_freq:
            raise ValueError('Invalid value for `superimpose`: Upsampling not supported.')
        if resample_freq == orig_freq:
            warnings.warn('Superimposed OHLC plot matches the original plot. Skipping.',
                          stacklevel=4)
            return

        if omit_missing:
            # Position each wide candle over the run of integer-indexed
            # bars it aggregates; its width is taken from the '_width' column.
            width2 = '_width'
            df2.index = df2['_width'].cumsum().shift(1).fillna(0)
            df2.index += df2['_width'] / 2 - .5
            df2['_width'] -= .1  # Candles don't touch
        else:
            del df['_width']
            # Candle width in msec, keyed by the resolution of the original
            # index. `millisecond` pairs with the 'S' rule above (1 s).
            width2 = dict(day=86400 * 5,
                          hour=86400,
                          minute=3600,
                          second=60,
                          millisecond=1)[time_resolution] * 1000
            df2.index += pd.Timedelta(
                width2 / 2 +
                (width2 / 5 if resample_rule == 'W' else 0),  # Sunday week start
                unit='ms')
        df2['inc'] = (df2.Close >= df2.Open).astype(np.uint8).astype(str)
        df2.index.name = None
        source2 = ColumnDataSource(df2)
        fig_ohlc.segment('index', 'High', 'index', 'Low', source=source2, color='#bbbbbb')
        fig_ohlc.vbar('index', width2, 'Open', 'Close', source=source2, line_color=None,
                      fill_color=factor_cmap('inc', COLORS_LIGHT, ['0', '1']))

    def _plot_ohlc():
        """Main OHLC bars"""
        fig_ohlc.segment('index', 'High', 'index', 'Low', source=source, color="black")
        r = fig_ohlc.vbar('index', width, 'Open', 'Close', source=source,
                          line_color="black", fill_color=inc_cmap)
        return r

    def _plot_ohlc_orders():
        """Order entry / exit markers on OHLC plot"""
        exit_price = trade_data['Exit Price'].dropna()
        entry_price = trade_data['Entry Price'].dropna().iloc[:exit_price.size]  # entry can be one more at the end  # noqa: E501
        trade_source.add(np.column_stack((entry_price.index, exit_price.index)).tolist(),
                         'position_lines_xs')
        trade_source.add(np.column_stack((entry_price, exit_price)).tolist(),
                         'position_lines_ys')
        fig_ohlc.multi_line(xs='position_lines_xs', ys='position_lines_ys',
                            source=trade_source, line_color=cmap,
                            legend='Trades',
                            line_width=8, line_alpha=1, line_dash='dotted')

    def _plot_indicators():
        """Strategy indicators"""

        def _too_many_dims(value):
            assert value.ndim >= 2
            if value.ndim > 2:
                warnings.warn("Can't plot indicators with >2D ('{}')".format(value.name),
                              stacklevel=5)
                return True
            return False

        class LegendStr(str):
            # The legend string is such a string that only matches
            # itself if it's the exact same object. This ensures
            # legend items are listed separately even when they have the
            # same string contents. Otherwise, Bokeh would always consider
            # equal strings as one and the same legend item.
            # This also prevents legend items named the same as some
            # ColumnDataSource's column to be replaced with that column's
            # values.
            def __eq__(self, other):
                return self is other

        ohlc_colors = colorgen()

        for value in indicators:
            value = np.atleast_2d(value)

            # Use .get()! A user might have assigned a Strategy.data-evolved
            # _Array without Strategy.I()
            if not value._opts.get('plot') or _too_many_dims(value):
                continue

            color = value._opts['color']
            tooltips = []

            # Overlay indicators on the OHLC figure
            if value._opts['overlay']:
                color = color or next(ohlc_colors)
                legend = LegendStr(value.name)
                for i, arr in enumerate(value):
                    source_name = '{}_{}'.format(value.name, i)
                    source.add(arr, source_name)
                    fig_ohlc.line('index', source_name, source=source,
                                  line_width=1.3, line_color=color, legend=legend)
                    ohlc_extreme_values[source_name] = arr
                    tooltips.append('@{{{}}}{{0,0.0[0000]}}'.format(source_name))
                ohlc_tooltips.append((value.name, NBSP.join(tooltips)))
            else:
                # Standalone indicator sections at the bottom
                #
                # One color is drawn per line with next(), so a color
                # iterator is needed. A single user-specified color
                # (e.g. 'red') is not an iterator and used to fail with
                # TypeError; repeat it for every line instead. An object
                # that already is an iterator is consumed as before.
                if not color:
                    colors = colorgen()
                elif hasattr(color, '__next__'):
                    colors = color
                else:
                    colors = cycle([color])
                fig = new_indicator_figure()
                for i, arr in enumerate(value, 1):
                    legend = '{}-{}'.format(value.name, i) if len(value) > 1 else value.name
                    name = legend + '_'  # Otherwise fig.line(legend=) is interpreted as col of source  # noqa: E501
                    tooltips.append('@{{{}}}'.format(name))
                    source.add(arr.astype(int if arr.dtype == bool else float), name)
                    r = fig.line('index', name, source=source,
                                 line_color=next(colors), line_width=1.3,
                                 legend=LegendStr(legend))
                    # Add dashed centerline just because
                    mean = float(pd.Series(arr).mean())
                    if not np.isnan(mean) and (abs(mean) < .1 or
                                               round(abs(mean), -1) in (50, 100, 200)):
                        fig.add_layout(Span(location=float(mean), dimension='width',
                                            line_color='#666666', line_dash='dashed',
                                            line_width=.5))
                # Note: the hover tool is bound to the last line drawn only
                set_tooltips(fig, [(value.name, NBSP.join(tooltips))], vline=True, renderers=[r])
                # If the sole indicator line on this figure,
                # have the legend only contain text without the glyph
                if len(value) == 1:
                    fig.legend.glyph_width = 0
                figs_below_ohlc.append(fig)

    # Construct figure ...

    if plot_equity:
        _plot_equity_section()

    if plot_drawdown:
        figs_above_ohlc.append(_plot_drawdown_section())

    if plot_pl:
        figs_above_ohlc.append(_plot_pl_section())

    if plot_volume:
        fig_volume = _plot_volume_section()
        figs_below_ohlc.append(fig_volume)

    # Superimposed candles are added before the main ones
    if superimpose and is_datetime_index:
        _plot_superimposed_ohlc()

    ohlc_bars = _plot_ohlc()
    _plot_ohlc_orders()
    _plot_indicators()

    set_tooltips(fig_ohlc, ohlc_tooltips, vline=True, renderers=[ohlc_bars])

    # Per-row extremes over High/Low and all overlaid indicator lines;
    # handed to the autoscale callback through `source`.
    source.add(ohlc_extreme_values.min(1), 'ohlc_low')
    source.add(ohlc_extreme_values.max(1), 'ohlc_high')

    custom_js_args = dict(ohlc_range=fig_ohlc.y_range,
                          source=source)
    if plot_volume:
        custom_js_args.update(volume_range=fig_volume.y_range)

    fig_ohlc.x_range.callback = CustomJS(args=custom_js_args,
                                         code=_AUTOSCALE_JS_CALLBACK)

    plots = figs_above_ohlc + [fig_ohlc] + figs_below_ohlc
    for f in plots:
        if f.legend:
            f.legend.location = 'top_left' if show_legend else None
            f.legend.border_line_width = 1
            f.legend.border_line_color = '#333333'
            f.legend.padding = 5
            f.legend.spacing = 0
            f.legend.margin = 0
            f.legend.label_text_font_size = '8pt'
        f.min_border_left = 0
        f.min_border_top = 3
        f.min_border_bottom = 6
        f.min_border_right = 10
        f.outline_line_color = '#666666'

        wheelzoom_tool = next(wz for wz in f.tools if isinstance(wz, WheelZoomTool))
        wheelzoom_tool.maintain_focus = False

    fig = gridplot(
        plots,
        ncols=1,
        toolbar_location='right',
        # sizing_mode='stretch_width',
        toolbar_options=dict(logo=None),
        merge_tools=True,
    )
    show(fig, browser=None if open_browser else 'none')
    return fig
def plot_heatmaps(heatmap: pd.Series, agg: str, ncols: int,
                  filename: str = '', plot_width: int = 1200, open_browser: bool = True):
    """
    Show a grid of 2D heatmaps, one for every pair of index levels of `heatmap`.

    Parameters:

    heatmap
        Series with a `pd.MultiIndex` of at least two levels, such as the
        one returned by `Backtest.optimize(..., return_heatmap=True)`.
    agg
        Aggregation passed to `groupby(...).agg()` to collapse the levels
        not shown in a given heatmap.
    ncols
        Number of heatmaps per grid row. Each heatmap is a square with a
        side of `plot_width // ncols`.
    filename
        If non-empty, passed to Bokeh's `output_file()`.
    plot_width
        Total width shared by the heatmaps of one row.
    open_browser
        When false, `show()` is called with `browser='none'`.

    Returns the grid layout produced by `gridplot()`.

    Raises ValueError if `heatmap` is not a MultiIndex-ed Series, or if its
    index has fewer than two levels (there would be no pair to plot).
    """
    if not (isinstance(heatmap, pd.Series) and
            isinstance(heatmap.index, pd.MultiIndex)):
        raise ValueError('heatmap must be heatmap Series as returned by '
                         '`Backtest.optimize(..., return_heatmap=True)`')
    # With fewer than two levels there are no pairs, and the color-range
    # computation below would fail with an obscure
    # "min() arg is an empty sequence". Fail early, before touching any
    # global Bokeh state, with a message that names the actual problem.
    if heatmap.index.nlevels < 2:
        raise ValueError('heatmap must have at least two index levels (parameters) '
                         'to plot pairwise heatmaps; got {}'.format(heatmap.index.nlevels))

    _bokeh_reset()
    if filename:
        output_file(filename)

    param_combinations = combinations(heatmap.index.names, 2)
    dfs = [heatmap.groupby(list(dims)).agg(agg).to_frame(name='_Value')
           for dims in param_combinations]
    plots = []
    # A single color mapper spanning all heatmaps keeps their colors comparable
    cmap = LinearColorMapper(palette='Viridis256',
                             low=min(df.min().min() for df in dfs),
                             high=max(df.max().max() for df in dfs),
                             nan_color='white')
    for df in dfs:
        name1, name2 = df.index.names
        # Axes are categorical, so level values are converted to strings
        level1 = df.index.levels[0].astype(str).tolist()
        level2 = df.index.levels[1].astype(str).tolist()
        df = df.reset_index()
        df[name1] = df[name1].astype('str')
        df[name2] = df[name2].astype('str')

        fig = _figure(x_range=level1,
                      y_range=level2,
                      x_axis_label=name1,
                      y_axis_label=name2,
                      plot_width=plot_width // ncols,
                      plot_height=plot_width // ncols,
                      tools='box_zoom,reset,save',
                      tooltips=[(name1, '@' + name1),
                                (name2, '@' + name2),
                                ('Value', '@_Value{0.[000]}')])
        fig.grid.grid_line_color = None
        fig.axis.axis_line_color = None
        fig.axis.major_tick_line_color = None
        fig.axis.major_label_standoff = 0

        fig.rect(x=name1,
                 y=name2,
                 width=1,
                 height=1,
                 source=df,
                 line_color=None,
                 fill_color=dict(field='_Value',
                                 transform=cmap))
        plots.append(fig)

    fig = gridplot(
        plots,
        ncols=ncols,
        toolbar_options=dict(logo=None),
        toolbar_location='above',
        merge_tools=True,
    )
    show(fig, browser=None if open_browser else 'none')
    return fig