added save_weights_to_file (#54)

This commit is contained in:
robertmartin8
2019-12-10 22:48:25 +00:00
parent 6db226b2d5
commit d4f3624bd8
2 changed files with 60 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ Additionally, we define a general utility function ``portfolio_performance`` to
evaluate return and risk for a given set of portfolio weights.
"""
import json
import numpy as np
import pandas as pd
from . import objective_functions
@@ -20,6 +21,12 @@ class BaseOptimizer:
- ``n_assets`` - int
- ``tickers`` - str list
- ``weights`` - np.ndarray
Public methods:
- ``set_weights()`` creates self.weights (np.ndarray) from a weights dict
- ``clean_weights()`` rounds the weights and clips near-zeros.
- ``save_weights_to_file()`` saves the weights to csv, json, or txt.
"""
def __init__(self, n_assets, tickers=None):
@@ -44,11 +51,7 @@ class BaseOptimizer:
:param weights: {ticker: weight} dictionary
:type weights: dict
"""
if self.weights is None:
self.weights = [0] * self.n_assets
for i, ticker in enumerate(self.tickers):
if ticker in weights:
self.weights[i] = weights[ticker]
self.weights = np.array([weights[ticker] for ticker in self.tickers])
def clean_weights(self, cutoff=1e-4, rounding=5):
"""
@@ -63,6 +66,8 @@ class BaseOptimizer:
:return: asset weights
:rtype: dict
"""
if self.weights is None:
raise AttributeError("Weights not yet computed")
clean_weights = self.weights.copy()
clean_weights[np.abs(clean_weights) < cutoff] = 0
if rounding is not None:
@@ -71,6 +76,25 @@ class BaseOptimizer:
clean_weights = np.round(clean_weights, rounding)
return dict(zip(self.tickers, clean_weights))
def save_weights_to_file(self, filename="weights.csv"):
"""
Utility method to save weights to a text file.
:param filename: name of file. Should be csv, json, or txt.
:type filename: str
"""
clean_weights = self.clean_weights()
ext = filename.split(".")[1]
if ext == "csv":
pd.Series(clean_weights).to_csv(filename, header=False)
elif ext == "json":
with open(filename, "w") as fp:
json.dump(clean_weights, fp)
else:
with open(filename, "w") as f:
f.write(str(clean_weights))
class BaseScipyOptimizer(BaseOptimizer):
@@ -113,7 +137,6 @@ class BaseScipyOptimizer(BaseOptimizer):
:rtype: tuple of tuples
"""
# If it is a collection with the right length, assume they are all bounds.
print(test_bounds)
if len(test_bounds) == self.n_assets and not isinstance(
test_bounds[0], (float, int)
):