mirror of https://github.com/AI4Finance-Foundation/FinGPT.git
Commit: refine the load_dataset function
@@ -1,11 +1,12 @@
import os
import datasets


# A dictionary to store various prompt templates.
template_dict = {
    'default': 'Instruction: {instruction}\nInput: {input}\nAnswer: '
}

# A dictionary to store the LoRA module mapping for different models.
lora_module_dict = {
    'chatglm2': ['query_key_value'],
    'falcon': ['query_key_value'],
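For orientation, a minimal sketch of how a mapping like lora_module_dict is typically consumed when building a LoRA configuration with the peft library; wiring it into LoraConfig this way is an assumption for illustration, not something shown in this diff.

from peft import LoraConfig, TaskType

base_model = 'falcon'  # hypothetical choice of key from lora_module_dict above

# Assumed usage: the per-model module names become the LoRA target modules.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=lora_module_dict[base_model],
)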
@@ -17,16 +18,18 @@ lora_module_dict = {
}


# Function to generate prompts based on the instruction, input, and chosen template.
def get_prompt(template, instruction, input):
    # If there's an instruction, format the prompt accordingly.
    # Otherwise, just return the input as is.
    if instruction:
        return template_dict[template].format(instruction=instruction, input=input)
    else:
        return input

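As a quick illustration of what get_prompt returns with the 'default' template defined above; the instruction and input strings here are made-up examples.

prompt = get_prompt(
    'default',
    'What is the sentiment of this news?',
    'Oil prices rose 3% after the announcement.'
)
# prompt == 'Instruction: What is the sentiment of this news?\n'
#           'Input: Oil prices rose 3% after the announcement.\nAnswer: '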
# Function to map the dataset features to prompt for testing.
def test_mapping(args, feature):
    # Generate the prompt based on the instruction and input from the feature.
    prompt = get_prompt(
        args.instruct_template,
        feature['instruction'],
@@ -36,31 +39,39 @@ def test_mapping(args, feature):
        "prompt": prompt,
    }

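A sketch of how a mapping function like test_mapping is typically applied with the datasets library; the toy dataset and the args object below are assumptions for illustration, not part of this file.

from functools import partial
from types import SimpleNamespace
import datasets

# Hypothetical toy example: one record with an instruction and an input.
args = SimpleNamespace(instruct_template='default')
toy = datasets.Dataset.from_dict({
    'instruction': ['What is the sentiment of this news?'],
    'input': ['Shares jumped 5% on strong earnings.'],
})

# Adds a 'prompt' column produced by test_mapping.
toy = toy.map(partial(test_mapping, args))
print(toy[0]['prompt'])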
# Function to tokenize the prompts and targets for training/testing.
def tokenize(args, tokenizer, feature):
    # Generate the prompt.
    prompt = get_prompt(
        args.instruct_template,
        feature['instruction'],
        feature['input']
    )
    # Tokenize the prompt.
    prompt_ids = tokenizer(
        prompt, padding=False,
        max_length=args.max_length, truncation=True
    )['input_ids']

    # Tokenize the target/output.
    target_ids = tokenizer(
        feature['output'].strip(), padding=False,
        max_length=args.max_length, truncation=True,
        add_special_tokens=False
    )['input_ids']

    # Combine the tokenized prompt and target.
    input_ids = prompt_ids + target_ids

    # Check if the combined length exceeds the maximum allowed length.
    exceed_max_length = len(input_ids) >= args.max_length

    # Add an EOS token if it's not already there,
    # and if we haven't exceeded the max length.
    if input_ids[-1] != tokenizer.eos_token_id and not exceed_max_length:
        input_ids.append(tokenizer.eos_token_id)

    # Prepare the labels for training. The labels should start from where the prompt ends.
    label_ids = [tokenizer.pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]

    return {
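To make the label construction in tokenize concrete, a toy walk-through with made-up token ids (in the real code they come from the tokenizer):

# Hypothetical token ids, only to show how label_ids lines up with input_ids.
prompt_ids = [101, 102, 103]         # tokens of the formatted prompt
target_ids = [201, 202]              # tokens of the target answer
pad_token_id = 0

input_ids = prompt_ids + target_ids  # [101, 102, 103, 201, 202]
label_ids = [pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]
# label_ids == [0, 0, 0, 201, 202]: the prompt positions are masked, so the
# training loss is computed only on the answer tokens (assuming the trainer
# ignores pad_token_id positions in the labels).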
@@ -98,22 +109,56 @@ def parse_model_name(name, from_remote=False):
    else:
        valid_model_names = ', '.join(model_paths.keys())
        raise ValueError(f"Undefined base model '{name}'. Valid model names are: {valid_model_names}")


def load_dataset(names, from_remote=False):
-    dataset_names = [d for d in names.split(',')]
+    """
+    Load one or multiple datasets based on the provided names and source location.
+
+    Args:
+    names (str): A comma-separated list of dataset names. Each name can be followed by '*n' to indicate replication.
+    from_remote (bool): If True, load the dataset from Hugging Face's model hub. Otherwise, load it from a local disk.
+
+    Returns:
+    List[Dataset]: A list of loaded datasets. Each dataset is possibly replicated based on the input names.
+    """
+    # Split the dataset names by commas for handling multiple datasets
+    dataset_names = names.split(',')
    dataset_list = []

    for name in dataset_names:
-        rep = 1
-        if not os.path.exists(name):
-            rep = int(name.split('*')[1]) if '*' in name else 1
-            name = ('FinGPT/fingpt-' if from_remote else 'data/fingpt-') + name.split('*')[0]
-        tmp_dataset = datasets.load_dataset(name) if from_remote else datasets.load_from_disk(name)
+        # Initialize replication factor to 1
+        replication_factor = 1
+        dataset_name = name
+
+        # Check if the dataset name includes a replication factor
+        if '*' in name:
+            dataset_name, replication_factor = name.split('*')
+            replication_factor = int(replication_factor)
+            if replication_factor < 1:
+                raise ValueError("Replication factor must be a positive integer.")
+
+        # Construct the correct dataset path or name based on the source location
+        dataset_path_or_name = ('FinGPT/fingpt-' if from_remote else 'data/fingpt-') + dataset_name
+        if not os.path.exists(dataset_path_or_name) and not from_remote:
+            raise FileNotFoundError(f"The dataset path {dataset_path_or_name} does not exist.")
+
+        # Load the dataset
+        try:
+            tmp_dataset = datasets.load_dataset(dataset_path_or_name) if from_remote else datasets.load_from_disk(
+                dataset_path_or_name)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load the dataset: {str(e)}")
+
+        # Check for 'test' split and create it from 'train' if necessary
        if 'test' not in tmp_dataset:
-            tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)
-        dataset_list.extend([tmp_dataset] * rep)
+            if 'train' in tmp_dataset:
+                tmp_dataset = tmp_dataset['train']
+                tmp_dataset = tmp_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)
+            else:
+                raise ValueError("The dataset must contain a 'train' or 'test' split.")
+
+        # Append the possibly replicated dataset to the list
+        dataset_list.extend([tmp_dataset] * replication_factor)

    return dataset_list
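For context, a hedged usage sketch of the refined load_dataset; the dataset names below ('sentiment-train', 'finred') are examples only, and the concatenation step is an assumed downstream pattern rather than code from this commit.

# Hypothetical call: the first example dataset is counted twice via '*2'.
dataset_list = load_dataset('sentiment-train*2,finred', from_remote=True)

# Assuming every loaded dataset ends up with 'train' and 'test' splits,
# they can be concatenated for fine-tuning and evaluation.
train_dataset = datasets.concatenate_datasets([d['train'] for d in dataset_list])
eval_dataset = datasets.concatenate_datasets([d['test'] for d in dataset_list])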