mirror of
https://github.com/thomashirtz/portfolio-management.git
synced 2022-03-03 23:56:42 +03:00
Update scripts
This commit is contained in:
52
scripts/new_prepare_pickled_dataset.py
Normal file
52
scripts/new_prepare_pickled_dataset.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from portfolio_management.data.manager import Manager
|
||||
from portfolio_management.data.preprocessing import PCAPreprocessing
|
||||
from portfolio_management.data.dataset import DatasetManager
|
||||
|
||||
|
||||
def create_database(database_name, symbol_list, interval, start, end):
|
||||
manager = Manager(
|
||||
database_name=database_name,
|
||||
reset_tables=True,
|
||||
)
|
||||
manager.insert(
|
||||
symbol_list=symbol_list,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
download_train_data = False
|
||||
download_test_data = False
|
||||
|
||||
symbol_list = [
|
||||
'USDCUSDT',
|
||||
'BTCUSDT',
|
||||
'ETHUSDT',
|
||||
'BNBUSDT',
|
||||
'LTCUSDT',
|
||||
'XRPUSDT'
|
||||
]
|
||||
interval = '1h'
|
||||
|
||||
train_database_name = 'USDT_train'
|
||||
start = "2020-01-01"
|
||||
end = "2020-12-31"
|
||||
if download_train_data:
|
||||
create_database(train_database_name, symbol_list, interval, start, end)
|
||||
|
||||
test_database_name = 'USDT_test'
|
||||
start = "2021-01-01"
|
||||
end = "2021-05-01"
|
||||
if download_test_data:
|
||||
create_database(test_database_name, symbol_list, interval, start, end)
|
||||
|
||||
preprocessing_class = PCAPreprocessing()
|
||||
|
||||
dataset_manager = DatasetManager(
|
||||
train_database_name=train_database_name,
|
||||
test_database_name=test_database_name,
|
||||
)
|
||||
dataset_manager.run(preprocessing_class)
|
||||
@@ -5,9 +5,34 @@ from portfolio_management.data.retrieve import get_dataset
|
||||
from portfolio_management.data.preprocessing import get_pca_preprocessing_function
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def main(database_name, symbol_list, interval, start, end):
|
||||
manager = Manager(
|
||||
database_name=database_name,
|
||||
reset_tables=True,
|
||||
)
|
||||
|
||||
database_name = 'USDT_train'
|
||||
|
||||
manager.insert(
|
||||
symbol_list=symbol_list,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
preprocessing_function = get_pca_preprocessing_function()
|
||||
|
||||
dataset = get_dataset( # todo find a way to prepare the test dataset the same way as the train set
|
||||
database_name=database_name,
|
||||
interval=interval,
|
||||
preprocessing=preprocessing_function
|
||||
)
|
||||
|
||||
datasets_folder_path = p.datasets_folder_path
|
||||
path_dataset = datasets_folder_path.joinpath(database_name).with_suffix('.pkl')
|
||||
pickle_dump(path_dataset, dataset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
symbol_list = [
|
||||
'USDCUSDT',
|
||||
@@ -18,28 +43,13 @@ if __name__ == '__main__':
|
||||
'XRPUSDT'
|
||||
]
|
||||
interval = '1h'
|
||||
|
||||
database_name = 'USDT_train'
|
||||
start = "2020-01-01"
|
||||
end = "2020-12-31"
|
||||
# main(database_name, symbol_list, interval, start, end)
|
||||
|
||||
manager = Manager(
|
||||
database_name=database_name,
|
||||
reset_tables=True,
|
||||
)
|
||||
manager.insert(
|
||||
symbol_list=symbol_list,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
preprocessing_function = get_pca_preprocessing_function()
|
||||
|
||||
dataset = get_dataset( # todo find a way to prepare the test dataset the same way as the train set
|
||||
database_name=database_name,
|
||||
interval=interval,
|
||||
preprocessing=preprocessing_function
|
||||
)
|
||||
|
||||
datasets_folder_path = p.datasets_folder_path
|
||||
path_dataset = datasets_folder_path.joinpath(database_name).with_suffix('.pkl')
|
||||
pickle_dump(path_dataset, dataset)
|
||||
database_name = 'USDT_test'
|
||||
start = "2021-01-01"
|
||||
end = "2021-05-01"
|
||||
main(database_name, symbol_list, interval, start, end)
|
||||
|
||||
@@ -14,17 +14,17 @@ if __name__ == '__main__': # can't test sac yet, there is 'inf' value in datase
|
||||
# first run 'prepare_pickled_dataset' !
|
||||
# todo find why it is slow
|
||||
|
||||
train = True
|
||||
train = False
|
||||
dataset_name = 'USDT_train'
|
||||
|
||||
env_kwargs = {
|
||||
'dataset_name': dataset_name,
|
||||
'num_steps': 100,
|
||||
'fees': 0.002,
|
||||
'fees': 0.005,
|
||||
'seed': 1,
|
||||
'step_size': 1,
|
||||
'chronologically': False,
|
||||
'observation_size': 12,
|
||||
'observation_size': 50,
|
||||
'stake_range': [100, 100],
|
||||
}
|
||||
|
||||
@@ -58,14 +58,17 @@ if __name__ == '__main__': # can't test sac yet, there is 'inf' value in datase
|
||||
)
|
||||
|
||||
else:
|
||||
env_kwarg_test = { # todo give the possibility to change the interval
|
||||
dataset_name = 'USDT_test'
|
||||
env_kwarg_test = {
|
||||
'dataset_name': dataset_name,
|
||||
'num_steps': None,
|
||||
}
|
||||
env_kwargs.update(**env_kwarg_test)
|
||||
run_name = '20210608_204057_Portfolio-v0'
|
||||
soft_actor_critic.evaluate(
|
||||
run_name = '20210720_192852_Portfolio-v0'
|
||||
soft_actor_critic.render(
|
||||
'Portfolio-v0',
|
||||
run_name=run_name,
|
||||
env_kwargs=env_kwargs,
|
||||
num_episodes=100
|
||||
num_episodes=1
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from portfolio_management.io_utilities import pickle_dump
|
||||
from portfolio_management.soft_actor_critic import train
|
||||
|
||||
|
||||
def test_episode(download_and_pickle: bool = True):
|
||||
def test_episode(download_and_pickle: bool = False):
|
||||
database_name = 'test_episode'
|
||||
|
||||
if download_and_pickle:
|
||||
|
||||
Reference in New Issue
Block a user