Edit path handling and continue supervised

This commit is contained in:
Thomas
2021-06-21 20:04:28 +02:00
parent 69ebc618c8
commit be492df86b
10 changed files with 93 additions and 19 deletions

View File

@@ -0,0 +1,14 @@
{
"network_config": {
"module_list": [
{"name": "Conv1d", "kwargs": {"in_channels": 9, "out_channels": 32, "kernel_size": 3}},
{"name": "ReLU"},
{"name": "Conv1d", "kwargs": {"in_channels": 32, "out_channels": 16, "kernel_size": 3}},
{"name": "ReLU"},
{"name": "Conv1d", "kwargs": {"in_channels": 16, "out_channels": 9, "kernel_size": 3}},
{"name": "ReLU"},
{"name": "View", "kwargs": {"shape": [-1, 396]}},
{"name": "Linear", "kwargs": {"in_features": 396, "out_features": 9}}
]
}
}

View File

@@ -21,7 +21,6 @@ from portfolio_management.database.retrieve import get_interval_id
from portfolio_management.database.utilities import session_scope
from portfolio_management.database.utilities import get_engine_url
from portfolio_management.database.utilities import try_insert
from portfolio_management.database.utilities import get_path_database
from portfolio_management.database.utilities import silent_bulk_insert
import portfolio_management.paths as p
@@ -34,16 +33,11 @@ class Manager:
echo: bool = False,
reset_tables: bool = False,
):
if isinstance(folder_path, str):
self.folder_path = Path(folder_path)
else:
self.folder_path = p.databases_folder_path
self.database_name = database_name
self.folder_path = p.get_databases_folder_path(folder_path)
create_folders(self.folder_path)
create_folders(get_path_database(str(self.folder_path), self.database_name).parent) # todo need to support path
self.engine_url = get_engine_url(str(self.folder_path), self.database_name) # todo need to support path
self.engine_url = get_engine_url(str(self.folder_path), self.database_name)
self.engine = create_engine(self.engine_url, echo=echo)
self.Session = sessionmaker(bind=self.engine)
@@ -75,14 +69,14 @@ class Manager:
)
config = {
"folder_path": self.folder_path,
"folder_path": str(self.folder_path),
"database_name": self.database_name,
"symbol_list": symbol_list,
"interval": interval,
"start": start,
"end": end,
}
path_yaml_file = Path(self.folder_path).joinpath(self.database_name).with_suffix('.yaml')
path_yaml_file = self.folder_path.joinpath(self.database_name).with_suffix('.yaml')
write_yaml(path_yaml_file=path_yaml_file, dictionary=config)
print('Config saved')

View File

@@ -81,7 +81,7 @@ def get_dataset(
float_32: bool = True
) -> xr.Dataset:
databases_folder_path = folder_path or p.databases_folder_path
databases_folder_path = p.get_databases_folder_path(folder_path)
with session_scope(
get_sessionmaker(str(databases_folder_path), database_name, echo),

View File

@@ -1,17 +1,23 @@
from pathlib import Path
from typing import Union
from typing import Optional
import portfolio_management.paths as p
from portfolio_management.io_utilities import pickle_load
from portfolio_management.database.retrieve import get_dataset as _get_dataset
def get_dataset(
    name: str,
    databases_folder_path: Optional[Union[str, Path]] = None,
    datasets_folder_path: Optional[Union[str, Path]] = None,
):
    """Load the dataset *name*, preferring a pickled file over the database.

    First tries ``<datasets_folder>/<name>.pkl``; if loading that file fails
    for any reason, falls back to retrieving the dataset from the database
    named *name*.

    :param name: dataset / database name (no extension).
    :param databases_folder_path: optional override for the databases folder;
        defaults to the project-level folder when ``None``.
    :param datasets_folder_path: optional override for the datasets folder;
        defaults to the project-level folder when ``None``.
    """
    # NOTE: the rendered diff had old and new lines interleaved here; this is
    # the coherent post-commit version (path resolution via paths helpers).
    try:
        folder_path = p.get_datasets_folder_path(datasets_folder_path)
        file_path = folder_path.joinpath(name).with_suffix('.pkl')
        return pickle_load(file_path)
    except Exception as e:
        # Best-effort fallback: report the pickle failure, then query the DB.
        print(e)
        path = p.get_databases_folder_path(databases_folder_path)
        return _get_dataset(folder_path=str(path), database_name=name)

View File

@@ -1,4 +1,5 @@
import yaml
import json
import pickle
from pathlib import Path
@@ -20,3 +21,13 @@ def create_folders(path: Path) -> None:
def write_yaml(path_yaml_file: Path, dictionary: dict):
    """Serialize *dictionary* to *path_yaml_file* as YAML.

    Insertion order of the keys is preserved (``sort_keys=False``).
    """
    with open(path_yaml_file, 'w') as stream:
        yaml.dump(dictionary, stream, sort_keys=False)
def read_json_file(file_path) -> dict:
    """Read the file at *file_path* and return its parsed JSON content."""
    with open(file_path, 'r') as stream:
        content = stream.read()
    return json.loads(content)
def write_json_file(file_path, data: dict) -> None:
    """Serialize *data* as JSON into the file at *file_path*."""
    with open(file_path, 'w') as stream:
        stream.write(json.dumps(data))

View File

@@ -1,6 +1,29 @@
from pathlib import Path
from typing import Union
from typing import Optional
project_path = Path(__file__).parent.parent
datasets_folder_path = project_path.joinpath('datasets')
databases_folder_path = project_path.joinpath('databases')
models_folder_path = project_path.joinpath('models')
def _get_path(default_path: Path, path: Optional[Union[str, Path]] = None) -> Path:
if isinstance(path, str):
return Path(path)
elif isinstance(path, Path):
return path
else:
return default_path
def get_datasets_folder_path(path: Optional[Union[str, Path]] = None) -> Path:
return _get_path(path, datasets_folder_path)
def get_databases_folder_path(path: Optional[Union[str, Path]] = None) -> Path:
return _get_path(path, datasets_folder_path)
def get_models_folder_path(path: Optional[Union[str, Path]] = None) -> Path:
return _get_path(path, models_folder_path)

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
class BasicModel(nn.Module):
def __init__(self, num_observations, num_properties: int = 10):
def __init__(self, num_observations, num_properties: int = 9):
super(BasicModel, self).__init__()
self.num_properties = num_properties
self.num_observations = num_observations

View File

@@ -0,0 +1,6 @@
import os
import sys
# Prepend the parent of this scripts directory to sys.path — presumably so
# the project package resolves when the script is run directly; confirm
# against the repository layout.
scripts_directory_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, scripts_directory_path + '/../')

# todo do training and testing loop

View File

@@ -0,0 +1,20 @@
from torchsummary import summary
from portfolio_management.paths import get_models_folder_path
from portfolio_management.io_utilities import read_json_file
from portfolio_management.supervised.utilities import get_sequential
from portfolio_management.soft_actor_critic.utilities import get_device
if __name__ == '__main__':
    # Model directory name under the models folder.
    model_name = 'model_0'
    models_path = get_models_folder_path(None)
    # Each model directory is expected to hold a config.json describing the
    # network (see the "network_config" / "module_list" keys read below).
    path_config = models_path.joinpath(model_name).joinpath('config.json')
    json_config = read_json_file(path_config)
    device = get_device()
    # Build the sequential network from the config's module list and move it
    # to the selected device before printing its summary.
    sequential = get_sequential(json_config['network_config']['module_list']).to(device)
    input_shape = tuple([9, 50])  # Channels * Number of Observations
    summary(sequential, input_size=input_shape)

View File