REF: Strategy takes params in constructor

This commit is contained in:
Kernc
2020-03-06 00:13:28 +01:00
parent 51f1e65766
commit 55e03a40b3

View File

@@ -14,7 +14,7 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from itertools import repeat, product, chain
from numbers import Number
from typing import Callable, Union, Tuple
from typing import Callable, Union, Tuple, Type
import numpy as np
import pandas as pd
@@ -48,11 +48,11 @@ class Strategy(metaclass=ABCMeta):
`backtesting.backtesting.Strategy.next` to define
your own strategy.
"""
def __init__(self, broker, data):
def __init__(self, broker, data, params):
self._indicators = []
self._broker = broker # type: _Broker
self._data = data # type: _Data
self._params = {}
self._params = self._check_params(params)
def __repr__(self):
return '<Strategy ' + str(self) + '>'
@@ -64,16 +64,15 @@ class Strategy(metaclass=ABCMeta):
params = '(' + params + ')'
return '{}{}'.format(self.__class__.__name__, params)
def _set_params(self, **kwargs):
for k, v in kwargs.items():
def _check_params(self, params):
for k, v in params.items():
if not hasattr(self, k):
raise AttributeError(
"Strategy '{}' is missing parameter '{}'. Strategy class "
"should define parameters as class variables before they "
"can be optimized or run with.".format(self.__class__.__name__, k))
setattr(self, k, v)
self._params = kwargs
return params
def I(self, # noqa: E743
func: Callable, *args,
@@ -582,7 +581,7 @@ class Backtest:
"""
def __init__(self,
data: pd.DataFrame,
strategy: type(Strategy),
strategy: Type[Strategy],
*,
cash: float = 10000,
commission: float = .0,
@@ -680,11 +679,11 @@ class Backtest:
"""
data = _Data(self._data)
broker = self._broker(data=data) # type: _Broker
strategy = self._strategy(broker, data) # type: Strategy
strategy._set_params(**kwargs)
strategy = self._strategy(broker, data, kwargs) # type: Strategy
strategy.init()
# Indicators used in Strategy.next()
indicator_attrs = {attr: indicator
for attr, indicator in strategy.__dict__.items()
if isinstance(indicator, _Indicator)}.items()