mirror of
https://github.com/AI4Finance-Foundation/FinGPT.git
synced 2024-02-15 23:10:01 +03:00
Update utils.py
This commit is contained in:
@@ -100,11 +100,13 @@ def load_dataset(names, from_remote=False):
|
|||||||
if not os.path.exists(name):
|
if not os.path.exists(name):
|
||||||
rep = int(name.split('*')[1]) if '*' in name else 1
|
rep = int(name.split('*')[1]) if '*' in name else 1
|
||||||
name = ('FinGPT/fingpt-' if from_remote else 'data/fingpt-') + name.split('*')[0]
|
name = ('FinGPT/fingpt-' if from_remote else 'data/fingpt-') + name.split('*')[0]
|
||||||
tmp_dataset = datasets.load_from_disk(name)
|
tmp_dataset = datasets.load_dataset(name) if from_remote else datasets.load_from_disk(name)
|
||||||
if 'test' not in tmp_dataset:
|
if 'test' not in tmp_dataset:
|
||||||
|
if 'train' in tmp_dataset:
|
||||||
|
tmp_dataset = tmp_dataset['train']
|
||||||
tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)
|
tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)
|
||||||
|
|
||||||
dataset_list.extend([tmp_dataset] * rep)
|
dataset_list.extend([tmp_dataset] * rep)
|
||||||
return dataset_list
|
return dataset_list
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user