mirror of
https://github.com/thomashirtz/portfolio-management.git
synced 2022-03-03 23:56:42 +03:00
Edit path handling and continue supervised
This commit is contained in:
14
models/model_0/config.json
Normal file
14
models/model_0/config.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"network_config": {
|
||||
"module_list": [
|
||||
{"name": "Conv1d", "kwargs": {"in_channels": 9, "out_channels": 32, "kernel_size": 3}},
|
||||
{"name": "ReLU"},
|
||||
{"name": "Conv1d", "kwargs": {"in_channels": 32, "out_channels": 16, "kernel_size": 3}},
|
||||
{"name": "ReLU"},
|
||||
{"name": "Conv1d", "kwargs": {"in_channels": 16, "out_channels": 9, "kernel_size": 3}},
|
||||
{"name": "ReLU"},
|
||||
{"name": "View", "kwargs": {"shape": [-1, 396]}},
|
||||
{"name": "Linear", "kwargs": {"in_features": 396, "out_features": 9}}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,6 @@ from portfolio_management.database.retrieve import get_interval_id
|
||||
from portfolio_management.database.utilities import session_scope
|
||||
from portfolio_management.database.utilities import get_engine_url
|
||||
from portfolio_management.database.utilities import try_insert
|
||||
from portfolio_management.database.utilities import get_path_database
|
||||
from portfolio_management.database.utilities import silent_bulk_insert
|
||||
import portfolio_management.paths as p
|
||||
|
||||
@@ -34,16 +33,11 @@ class Manager:
|
||||
echo: bool = False,
|
||||
reset_tables: bool = False,
|
||||
):
|
||||
if isinstance(folder_path, str):
|
||||
self.folder_path = Path(folder_path)
|
||||
else:
|
||||
self.folder_path = p.databases_folder_path
|
||||
|
||||
self.database_name = database_name
|
||||
self.folder_path = p.get_databases_folder_path(folder_path)
|
||||
create_folders(self.folder_path)
|
||||
|
||||
create_folders(get_path_database(str(self.folder_path), self.database_name).parent) # todo need to support path
|
||||
|
||||
self.engine_url = get_engine_url(str(self.folder_path), self.database_name) # todo need to support path
|
||||
self.engine_url = get_engine_url(str(self.folder_path), self.database_name)
|
||||
self.engine = create_engine(self.engine_url, echo=echo)
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
|
||||
@@ -75,14 +69,14 @@ class Manager:
|
||||
)
|
||||
|
||||
config = {
|
||||
"folder_path": self.folder_path,
|
||||
"folder_path": str(self.folder_path),
|
||||
"database_name": self.database_name,
|
||||
"symbol_list": symbol_list,
|
||||
"interval": interval,
|
||||
"start": start,
|
||||
"end": end,
|
||||
}
|
||||
path_yaml_file = Path(self.folder_path).joinpath(self.database_name).with_suffix('.yaml')
|
||||
path_yaml_file = self.folder_path.joinpath(self.database_name).with_suffix('.yaml')
|
||||
write_yaml(path_yaml_file=path_yaml_file, dictionary=config)
|
||||
print('Config saved')
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ def get_dataset(
|
||||
float_32: bool = True
|
||||
) -> xr.Dataset:
|
||||
|
||||
databases_folder_path = folder_path or p.databases_folder_path
|
||||
databases_folder_path = p.get_databases_folder_path(folder_path)
|
||||
|
||||
with session_scope(
|
||||
get_sessionmaker(str(databases_folder_path), database_name, echo),
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional
|
||||
|
||||
import portfolio_management.paths as p
|
||||
from portfolio_management.io_utilities import pickle_load
|
||||
from portfolio_management.database.retrieve import get_dataset as _get_dataset
|
||||
|
||||
|
||||
def get_dataset(
|
||||
name: str,
|
||||
databases_folder_path: Path,
|
||||
datasets_folder_path: Path
|
||||
databases_folder_path: Optional[Union[str, Path]] = None,
|
||||
datasets_folder_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
try:
|
||||
path = datasets_folder_path.joinpath(name).with_suffix('.pkl')
|
||||
return pickle_load(path)
|
||||
folder_path = p.get_datasets_folder_path(datasets_folder_path)
|
||||
file_path = folder_path.joinpath(name).with_suffix('.pkl')
|
||||
return pickle_load(file_path)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return _get_dataset(folder_path=str(databases_folder_path), database_name=name)
|
||||
path = p.get_databases_folder_path(databases_folder_path)
|
||||
return _get_dataset(folder_path=str(path), database_name=name)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import yaml
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
@@ -20,3 +21,13 @@ def create_folders(path: Path) -> None:
|
||||
def write_yaml(path_yaml_file: Path, dictionary: dict):
|
||||
with open(path_yaml_file, 'w') as f:
|
||||
yaml.dump(dictionary, f, sort_keys=False)
|
||||
|
||||
|
||||
def read_json_file(file_path) -> dict:
|
||||
with open(file_path, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json_file(file_path, data: dict) -> None:
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
@@ -1,6 +1,29 @@
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Union
|
||||
from typing import Optional
|
||||
|
||||
project_path = Path(__file__).parent.parent
|
||||
datasets_folder_path = project_path.joinpath('datasets')
|
||||
databases_folder_path = project_path.joinpath('databases')
|
||||
models_folder_path = project_path.joinpath('models')
|
||||
|
||||
|
||||
def _get_path(default_path: Path, path: Optional[Union[str, Path]] = None) -> Path:
|
||||
if isinstance(path, str):
|
||||
return Path(path)
|
||||
elif isinstance(path, Path):
|
||||
return path
|
||||
else:
|
||||
return default_path
|
||||
|
||||
|
||||
def get_datasets_folder_path(path: Optional[Union[str, Path]] = None) -> Path:
|
||||
return _get_path(path, datasets_folder_path)
|
||||
|
||||
|
||||
def get_databases_folder_path(path: Optional[Union[str, Path]] = None) -> Path:
|
||||
return _get_path(path, datasets_folder_path)
|
||||
|
||||
|
||||
def get_models_folder_path(path: Optional[Union[str, Path]] = None) -> Path:
|
||||
return _get_path(path, models_folder_path)
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch.nn as nn
|
||||
|
||||
|
||||
class BasicModel(nn.Module):
|
||||
def __init__(self, num_observations, num_properties: int = 10):
|
||||
def __init__(self, num_observations, num_properties: int = 9):
|
||||
super(BasicModel, self).__init__()
|
||||
self.num_properties = num_properties
|
||||
self.num_observations = num_observations
|
||||
|
||||
6
scripts/run_supervised.py
Normal file
6
scripts/run_supervised.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
scripts_directory_path = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, scripts_directory_path + '/../')
|
||||
|
||||
# todo do training and testing loop
|
||||
20
scripts/run_torchsummary.py
Normal file
20
scripts/run_torchsummary.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from torchsummary import summary
|
||||
|
||||
from portfolio_management.paths import get_models_folder_path
|
||||
from portfolio_management.io_utilities import read_json_file
|
||||
from portfolio_management.supervised.utilities import get_sequential
|
||||
from portfolio_management.soft_actor_critic.utilities import get_device
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_name = 'model_0'
|
||||
|
||||
models_path = get_models_folder_path(None)
|
||||
path_config = models_path.joinpath(model_name).joinpath('config.json')
|
||||
json_config = read_json_file(path_config)
|
||||
|
||||
device = get_device()
|
||||
sequential = get_sequential(json_config['network_config']['module_list']).to(device)
|
||||
|
||||
input_shape = tuple([9, 50]) # Channels * Number of Observations
|
||||
summary(sequential, input_size=input_shape)
|
||||
0
tests/supervised/__init__.py
Normal file
0
tests/supervised/__init__.py
Normal file
Reference in New Issue
Block a user