mirror of
https://github.com/ViralLab/TurkishBERTweet.git
synced 2023-12-19 18:19:59 +03:00
refactoring
This commit is contained in:
@@ -93,10 +93,10 @@ if getattr(tokenizer, "pad_token_id") is None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
id2label = {0: "negative", 2: "positive", 1: "neutral"}
|
||||
best_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
turkishBERTweet_sa = AutoModelForSequenceClassification.from_pretrained(
|
||||
peft_config.base_model_name_or_path, return_dict=True, num_labels=3, id2label=id2label
|
||||
)
|
||||
best_model = PeftModel.from_pretrained(best_model, peft_model)
|
||||
turkishBERTweet_sa = PeftModel.from_pretrained(turkishBERTweet_sa, peft_model)
|
||||
|
||||
sample_texts = [
|
||||
"Viral lab da insanlar hep birlikte çalışıyorlar. hepbirlikte çalışan insanlar birbirlerine yakın oluyorlar.",
|
||||
@@ -110,7 +110,7 @@ preprocessed_texts = [preprocess(s) for s in sample_texts]
|
||||
|
||||
for s in preprocessed_texts:
|
||||
ids = tokenizer.encode_plus(s, return_tensors="pt")
|
||||
label_id = best_model(**ids).logits.argmax(-1).item()
|
||||
label_id = turkishBERTweet_sa(**ids).logits.argmax(-1).item()
|
||||
print(id2label[label_id],":", s)
|
||||
```
|
||||
|
||||
|
||||
Reference in New Issue
Block a user