refactoring

This commit is contained in:
AliNajafi
2023-10-02 14:16:42 +03:00
parent dc62e9a290
commit 097e18f213

View File

@@ -93,10 +93,10 @@ if getattr(tokenizer, "pad_token_id") is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
id2label = {0: "negative", 2: "positive", 1: "neutral"}
best_model = AutoModelForSequenceClassification.from_pretrained(
turkishBERTweet_sa = AutoModelForSequenceClassification.from_pretrained(
peft_config.base_model_name_or_path, return_dict=True, num_labels=3, id2label=id2label
)
best_model = PeftModel.from_pretrained(best_model, peft_model)
turkishBERTweet_sa = PeftModel.from_pretrained(turkishBERTweet_sa, peft_model)
sample_texts = [
"Viral lab da insanlar hep birlikte çalışıyorlar. hepbirlikte çalışan insanlar birbirlerine yakın oluyorlar.",
@@ -110,7 +110,7 @@ preprocessed_texts = [preprocess(s) for s in sample_texts]
for s in preprocessed_texts:
ids = tokenizer.encode_plus(s, return_tensors="pt")
label_id = best_model(**ids).logits.argmax(-1).item()
label_id = turkishBERTweet_sa(**ids).logits.argmax(-1).item()
print(id2label[label_id],":", s)
```