mirror of
https://github.com/robertmartin8/PyPortfolioOpt.git
synced 2022-11-27 18:02:41 +03:00
Allow None as rounding precision
This commit is contained in:
@@ -52,11 +52,11 @@ class BaseOptimizer:
|
||||
:return: asset weights
|
||||
:rtype: dict
|
||||
"""
|
||||
if not isinstance(rounding, int) or rounding < 1:
|
||||
raise ValueError("rounding must be a positive integer")
|
||||
clean_weights = self.weights.copy()
|
||||
clean_weights[np.abs(clean_weights) < cutoff] = 0
|
||||
if rounding is not None:
|
||||
if not isinstance(rounding, int) or rounding < 1:
|
||||
raise ValueError("rounding must be a positive integer")
|
||||
clean_weights = np.round(clean_weights, rounding)
|
||||
return dict(zip(self.tickers, clean_weights))
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from pypfopt.efficient_frontier import EfficientFrontier
|
||||
from tests.utilities_for_tests import get_data, setup_efficient_frontier
|
||||
|
||||
|
||||
def test_custom_upper_bound():
|
||||
ef = EfficientFrontier(
|
||||
*setup_efficient_frontier(data_only=True), weight_bounds=(0, 0.10)
|
||||
@@ -84,6 +84,25 @@ def test_clean_weights_error():
|
||||
ef.clean_weights(rounding=0)
|
||||
assert ef.clean_weights(rounding=3)
|
||||
|
||||
def test_clean_weights_no_rounding():
|
||||
ef = setup_efficient_frontier()
|
||||
ef.max_sharpe()
|
||||
# ensure the call does not fail
|
||||
# in previous commits, this call would raise a ValueError
|
||||
assert ef.clean_weights(rounding=None)
|
||||
|
||||
# ensure the call does not round
|
||||
with mock.patch('pypfopt.efficient_frontier.np.round') as rounding_method:
|
||||
# rather than check the weights before and after for rounding, which
|
||||
# could probably have floating point issues, ensure the rounding method,
|
||||
# `np.round` is not called
|
||||
ef.clean_weights(rounding=None)
|
||||
assert rounding_method.call_count == 0
|
||||
|
||||
# sanity check to ensure the mock has been created correctly
|
||||
ef.clean_weights(rounding=1)
|
||||
assert rounding_method.call_count == 1
|
||||
|
||||
|
||||
def test_efficient_frontier_init_errors():
|
||||
df = get_data()
|
||||
|
||||
Reference in New Issue
Block a user