38 lines
1.0 KiB
Python
38 lines
1.0 KiB
Python
import yaml
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
import torch
|
|
from peft import PeftModel
|
|
import os
|
|
|
|
|
|
def merge_lora_model(config_file: str) -> str:
    """Merge trained LoRA adapter weights into their base model and save the result.

    Reads a YAML config expected to contain:
      - ``base_model``: HF hub id or local path of the base causal-LM.
      - ``output_dir``: directory holding the trained LoRA adapter.

    The merged model and the base model's tokenizer are written to
    ``<output_dir>/merged``. If that directory already exists, the expensive
    merge is skipped and the existing path is returned.

    Args:
        config_file: Path to the YAML configuration file.

    Returns:
        Path to the directory containing the merged model.
    """
    # Context manager closes the config file deterministically; safe_load
    # refuses arbitrary YAML tags, which a plain config never needs.
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)

    base_model = config["base_model"]
    lora_model = config["output_dir"]
    merged_model = os.path.join(lora_model, "merged")

    # Idempotency guard: a previous run already produced the merged weights.
    if os.path.exists(merged_model):
        print(f"Model {merged_model} already exists, skipping")
        return merged_model

    print("Loading base model")
    # fp16 halves host memory while loading; merge math is done in that dtype.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        return_dict=True,
        torch_dtype=torch.float16,
    )

    print("Loading PEFT model")
    model = PeftModel.from_pretrained(model, lora_model)
    print("Running merge_and_unload")
    # Folds the LoRA deltas into the base weights and drops the adapter wrappers,
    # leaving a plain transformers model that can be saved/served standalone.
    model = model.merge_and_unload()

    # Save the base model's tokenizer alongside so the merged dir is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    model.save_pretrained(merged_model)
    tokenizer.save_pretrained(merged_model)
    print(f"Model saved to {merged_model}")

    return merged_model
|