mirror of
https://github.com/robertmartin8/PyPortfolioOpt.git
synced 2022-11-27 18:02:41 +03:00
added save_weights_to_file (#54)
This commit is contained in:
@@ -7,6 +7,7 @@ Additionally, we define a general utility function ``portfolio_performance`` to
|
||||
evaluate return and risk for a given set of portfolio weights.
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from . import objective_functions
|
||||
@@ -20,6 +21,12 @@ class BaseOptimizer:
|
||||
- ``n_assets`` - int
|
||||
- ``tickers`` - str list
|
||||
- ``weights`` - np.ndarray
|
||||
|
||||
Public methods:
|
||||
|
||||
- ``set_weights()`` creates self.weights (np.ndarray) from a weights dict
|
||||
- ``clean_weights()`` rounds the weights and clips near-zeros.
|
||||
- ``save_weights_to_file()`` saves the weights to csv, json, or txt.
|
||||
"""
|
||||
|
||||
def __init__(self, n_assets, tickers=None):
|
||||
@@ -44,11 +51,7 @@ class BaseOptimizer:
|
||||
:param weights: {ticker: weight} dictionary
|
||||
:type weights: dict
|
||||
"""
|
||||
if self.weights is None:
|
||||
self.weights = [0] * self.n_assets
|
||||
for i, ticker in enumerate(self.tickers):
|
||||
if ticker in weights:
|
||||
self.weights[i] = weights[ticker]
|
||||
self.weights = np.array([weights[ticker] for ticker in self.tickers])
|
||||
|
||||
def clean_weights(self, cutoff=1e-4, rounding=5):
|
||||
"""
|
||||
@@ -63,6 +66,8 @@ class BaseOptimizer:
|
||||
:return: asset weights
|
||||
:rtype: dict
|
||||
"""
|
||||
if self.weights is None:
|
||||
raise AttributeError("Weights not yet computed")
|
||||
clean_weights = self.weights.copy()
|
||||
clean_weights[np.abs(clean_weights) < cutoff] = 0
|
||||
if rounding is not None:
|
||||
@@ -71,6 +76,25 @@ class BaseOptimizer:
|
||||
clean_weights = np.round(clean_weights, rounding)
|
||||
return dict(zip(self.tickers, clean_weights))
|
||||
|
||||
def save_weights_to_file(self, filename="weights.csv"):
|
||||
"""
|
||||
Utility method to save weights to a text file.
|
||||
|
||||
:param filename: name of file. Should be csv, json, or txt.
|
||||
:type filename: str
|
||||
"""
|
||||
clean_weights = self.clean_weights()
|
||||
|
||||
ext = filename.split(".")[1]
|
||||
if ext == "csv":
|
||||
pd.Series(clean_weights).to_csv(filename, header=False)
|
||||
elif ext == "json":
|
||||
with open(filename, "w") as fp:
|
||||
json.dump(clean_weights, fp)
|
||||
else:
|
||||
with open(filename, "w") as f:
|
||||
f.write(str(clean_weights))
|
||||
|
||||
|
||||
class BaseScipyOptimizer(BaseOptimizer):
|
||||
|
||||
@@ -113,7 +137,6 @@ class BaseScipyOptimizer(BaseOptimizer):
|
||||
:rtype: tuple of tuples
|
||||
"""
|
||||
# If it is a collection with the right length, assume they are all bounds.
|
||||
print(test_bounds)
|
||||
if len(test_bounds) == self.n_assets and not isinstance(
|
||||
test_bounds[0], (float, int)
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user