This commit is contained in:
robertmartin8
2021-01-28 19:34:55 +08:00
parent 8eb2a16c56
commit 0f2a86da86
2 changed files with 20 additions and 9 deletions

View File

@@ -120,16 +120,17 @@ class BaseConvexOptimizer(BaseOptimizer):
"""
The BaseConvexOptimizer contains many private variables for use by
``cvxpy``. For example, the immutable optimisation variable for weights
is stored as self._w. Interacting directly with these variables is highly
discouraged.
is stored as self._w. Interacting directly with these variables directly
is discouraged.
Instance variables:
- ``n_assets`` - int
- ``tickers`` - str list
- ``weights`` - np.ndarray
- ``solver`` - str
- ``solver_options`` - {str: str} dict
- ``_opt`` - cp.Problem
- ``_solver`` - str
- ``_solver_options`` - {str: str} dict
Public methods:
@@ -175,6 +176,7 @@ class BaseConvexOptimizer(BaseOptimizer):
self._upper_bounds = None
self._map_bounds_to_constraints(weight_bounds)
self._opt = None
self._solver = solver
self._verbose = verbose
self._solver_options = solver_options if solver_options else {}
@@ -227,19 +229,21 @@ class BaseConvexOptimizer(BaseOptimizer):
:raises exceptions.OptimizationError: if problem is not solvable by cvxpy
"""
try:
opt = cp.Problem(cp.Minimize(self._objective), self._constraints)
self._opt = cp.Problem(cp.Minimize(self._objective), self._constraints)
if self._solver is not None:
opt.solve(
self._opt.solve(
solver=self._solver, verbose=self._verbose, **self._solver_options
)
else:
opt.solve(verbose=self._verbose, **self._solver_options)
self._opt.solve(verbose=self._verbose, **self._solver_options)
except (TypeError, cp.DCPError) as e:
raise exceptions.OptimizationError from e
if opt.status not in {"optimal", "optimal_inaccurate"}:
raise exceptions.OptimizationError("Solver status: {}".format(opt.status))
if self._opt.status not in {"optimal", "optimal_inaccurate"}:
raise exceptions.OptimizationError(
"Solver status: {}".format(self._opt.status)
)
self.weights = self._w.value.round(16) + 0.0 # +0.0 removes signed zero
return self._make_output_weights()