Files
OpenPipe-llm/examples/classify-recipes/utils.py
Kyle Corbitt 40638a7848 more work
2023-08-24 23:49:44 +00:00

38 lines
1.0 KiB
Python

import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel
import os
def merge_lora_model(config_file: str):
    """Merge a trained LoRA adapter into its base model and save the result.

    The YAML file at ``config_file`` is read for two keys:

    * ``base_model`` -- identifier passed to ``from_pretrained`` for both the
      base model and the tokenizer.
    * ``output_dir`` -- location the PEFT adapter is loaded from; the merged
      model is written to ``<output_dir>/merged``.

    If the merged path already exists, nothing is loaded or written and the
    existing path is returned.

    Args:
        config_file: Path to the YAML configuration file.

    Returns:
        The path of the merged model directory, as a string.

    Raises:
        KeyError: If ``base_model`` or ``output_dir`` is absent from the
            config (assuming the YAML document is a mapping).
    """
    # Use a context manager so the config file handle is closed promptly
    # instead of being left for the garbage collector.
    # NOTE(review): FullLoader is kept for compatibility with existing
    # configs; if the config can come from an untrusted source, switch to
    # yaml.safe_load.
    with open(config_file, "r") as config_stream:
        config = yaml.load(config_stream, Loader=yaml.FullLoader)

    base_model = config["base_model"]
    lora_model = config["output_dir"]
    merged_model = f"{lora_model}/merged"

    # NOTE(review): only the existence of the path is checked here; an
    # incomplete directory left by an interrupted run would also be skipped.
    if os.path.exists(merged_model):
        print(f"Model {merged_model} already exists, skipping")
        return merged_model

    print("Loading base model")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        return_dict=True,
        torch_dtype=torch.float16,
    )

    print("Loading PEFT model")
    model = PeftModel.from_pretrained(model, lora_model)

    print("Running merge_and_unload")
    model = model.merge_and_unload()

    # The tokenizer is loaded from the base model and saved alongside the
    # merged weights so both live in the same directory.
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model.save_pretrained(merged_model)
    tokenizer.save_pretrained(merged_model)

    print(f"Model saved to {merged_model}")
    return merged_model