Update scripts

This commit is contained in:
Thomas
2021-07-22 11:08:26 +02:00
parent 375fc733ac
commit 74a3c735b5
4 changed files with 97 additions and 32 deletions

View File

@@ -0,0 +1,52 @@
from portfolio_management.data.manager import Manager
from portfolio_management.data.preprocessing import PCAPreprocessing
from portfolio_management.data.dataset import DatasetManager
def create_database(database_name, symbol_list, interval, start, end):
manager = Manager(
database_name=database_name,
reset_tables=True,
)
manager.insert(
symbol_list=symbol_list,
interval=interval,
start=start,
end=end,
)
if __name__ == '__main__':
download_train_data = False
download_test_data = False
symbol_list = [
'USDCUSDT',
'BTCUSDT',
'ETHUSDT',
'BNBUSDT',
'LTCUSDT',
'XRPUSDT'
]
interval = '1h'
train_database_name = 'USDT_train'
start = "2020-01-01"
end = "2020-12-31"
if download_train_data:
create_database(train_database_name, symbol_list, interval, start, end)
test_database_name = 'USDT_test'
start = "2021-01-01"
end = "2021-05-01"
if download_test_data:
create_database(test_database_name, symbol_list, interval, start, end)
preprocessing_class = PCAPreprocessing()
dataset_manager = DatasetManager(
train_database_name=train_database_name,
test_database_name=test_database_name,
)
dataset_manager.run(preprocessing_class)

View File

@@ -5,9 +5,34 @@ from portfolio_management.data.retrieve import get_dataset
from portfolio_management.data.preprocessing import get_pca_preprocessing_function
if __name__ == '__main__':
def main(database_name, symbol_list, interval, start, end):
manager = Manager(
database_name=database_name,
reset_tables=True,
)
database_name = 'USDT_train'
manager.insert(
symbol_list=symbol_list,
interval=interval,
start=start,
end=end,
)
preprocessing_function = get_pca_preprocessing_function()
dataset = get_dataset( # todo find a way to prepare the test dataset the same way as the train set
database_name=database_name,
interval=interval,
preprocessing=preprocessing_function
)
datasets_folder_path = p.datasets_folder_path
path_dataset = datasets_folder_path.joinpath(database_name).with_suffix('.pkl')
pickle_dump(path_dataset, dataset)
if __name__ == '__main__':
symbol_list = [
'USDCUSDT',
@@ -18,28 +43,13 @@ if __name__ == '__main__':
'XRPUSDT'
]
interval = '1h'
database_name = 'USDT_train'
start = "2020-01-01"
end = "2020-12-31"
# main(database_name, symbol_list, interval, start, end)
manager = Manager(
database_name=database_name,
reset_tables=True,
)
manager.insert(
symbol_list=symbol_list,
interval=interval,
start=start,
end=end,
)
preprocessing_function = get_pca_preprocessing_function()
dataset = get_dataset( # todo find a way to prepare the test dataset the same way as the train set
database_name=database_name,
interval=interval,
preprocessing=preprocessing_function
)
datasets_folder_path = p.datasets_folder_path
path_dataset = datasets_folder_path.joinpath(database_name).with_suffix('.pkl')
pickle_dump(path_dataset, dataset)
database_name = 'USDT_test'
start = "2021-01-01"
end = "2021-05-01"
main(database_name, symbol_list, interval, start, end)

View File

@@ -14,17 +14,17 @@ if __name__ == '__main__': # can't test sac yet, there is 'inf' value in datase
# first run 'prepare_pickled_dataset' !
# todo find why it is slow
train = True
train = False
dataset_name = 'USDT_train'
env_kwargs = {
'dataset_name': dataset_name,
'num_steps': 100,
'fees': 0.002,
'fees': 0.005,
'seed': 1,
'step_size': 1,
'chronologically': False,
'observation_size': 12,
'observation_size': 50,
'stake_range': [100, 100],
}
@@ -58,14 +58,17 @@ if __name__ == '__main__': # can't test sac yet, there is 'inf' value in datase
)
else:
env_kwarg_test = { # todo give the possibility to change the interval
dataset_name = 'USDT_test'
env_kwarg_test = {
'dataset_name': dataset_name,
'num_steps': None,
}
env_kwargs.update(**env_kwarg_test)
run_name = '20210608_204057_Portfolio-v0'
soft_actor_critic.evaluate(
run_name = '20210720_192852_Portfolio-v0'
soft_actor_critic.render(
'Portfolio-v0',
run_name=run_name,
env_kwargs=env_kwargs,
num_episodes=100
num_episodes=1
)

View File

@@ -5,7 +5,7 @@ from portfolio_management.io_utilities import pickle_dump
from portfolio_management.soft_actor_critic import train
def test_episode(download_and_pickle: bool = True):
def test_episode(download_and_pickle: bool = False):
database_name = 'test_episode'
if download_and_pickle: