Mirror of https://github.com/MaartenGr/KeyBERT.git (synced 2022-03-14 19:18:06 +03:00)

Commit abb52315f8 (parent 222cc5b96e), committed via GitHub
Makefile (1 line changed)

```diff
@@ -10,7 +10,6 @@ install-test:
-
 pypi:
 	python setup.py sdist
 	python setup.py bdist_wheel --universal
 	twine upload dist/*
 
 clean:
```
README.md (25 lines added)

````diff
@@ -117,6 +117,31 @@ of words you would like in the resulting keyphrases:
  'learning function']
 ```
 
+To diversify the results, we can use Maximal Marginal Relevance (MMR), which
+selects keywords/keyphrases based on cosine similarity while penalizing
+redundancy among the keywords already chosen. The results with **high diversity**:
+
+```python
+>>> model.extract_keywords(doc, keyphrase_length=3, stop_words='english', use_mmr=True, diversity=0.7)
+['algorithm generalize training',
+ 'labels unseen instances',
+ 'new examples optimal',
+ 'determine class labels',
+ 'supervised learning algorithm']
+```
+
+The results with **low diversity**:
+
+```python
+>>> model.extract_keywords(doc, keyphrase_length=3, stop_words='english', use_mmr=True, diversity=0.2)
+['algorithm generalize training',
+ 'learning machine learning',
+ 'learning algorithm analyzes',
+ 'supervised learning algorithm',
+ 'algorithm analyzes training']
+```
+
 ## References
 Below, you can find several resources that were used for the creation of KeyBERT
 but most importantly, are amazing resources for creating impressive keyword extraction models:
````
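For reference, the selection rule that the new `diversity` argument controls can be written as follows. This is the standard MMR formulation rephrased to match the code added below (with λ playing the role of `diversity`); it is not text from the README itself:

```latex
% At each step, from the remaining candidates C \ S, pick the candidate
% that balances relevance to the document D against similarity to the
% already-selected set S:
\mathrm{MMR} := \underset{c_i \in C \setminus S}{\arg\max}
    \Big[ (1-\lambda)\,\mathrm{sim}(c_i, D)
          - \lambda \max_{c_j \in S} \mathrm{sim}(c_i, c_j) \Big]
```

Here C is the set of candidate keywords, S the keywords selected so far, D the document embedding, and sim is cosine similarity; λ = 0 reduces to plain similarity ranking, λ = 1 ignores relevance entirely.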
keybert/mmr.py (126 lines changed)

```diff
@@ -1,92 +1,56 @@
-# Copyright (c) 2017-present, Swisscom (Schweiz) AG.
-# All rights reserved.
-#
-# Authors: Kamil Bennani-Smires, Yann Savary
-
-
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
+from typing import List
 
 
-def MMR(doc_embedd, candidates, X, beta, N):
-    """
-    Core method using Maximal Marginal Relevance, in charge of returning the top-N candidates
-    :param candidates: list of candidates (string)
-    :param X: numpy array with the embedding of each candidate in each row
-    :param beta: hyperparameter beta for MMR (controls the trade-off between informativeness and diversity)
-    :param N: number of candidates to extract
-    :return: A tuple with 3 elements:
-    1) list of the top-N candidates (or fewer if there are not enough candidates) (list of string)
-    2) list of associated relevance scores (list of float)
-    3) list containing, for each keyphrase, a list of aliases (list of list of string)
-    """
-    N = min(N, len(candidates))
-    doc_sim = cosine_similarity(X, doc_embedd.reshape(1, -1))
-
-    doc_sim_norm = doc_sim / np.max(doc_sim)
-    doc_sim_norm = 0.5 + (doc_sim_norm - np.average(doc_sim_norm)) / np.std(doc_sim_norm)
-
-    sim_between = cosine_similarity(X)
-    np.fill_diagonal(sim_between, np.NaN)
-
-    sim_between_norm = sim_between / np.nanmax(sim_between, axis=0)
-    sim_between_norm = \
-        0.5 + (sim_between_norm - np.nanmean(sim_between_norm, axis=0)) / np.nanstd(sim_between_norm, axis=0)
-
-    selected_candidates = []
-    unselected_candidates = [c for c in range(len(candidates))]
-
-    j = int(np.argmax(doc_sim))
-    selected_candidates.append(j)
-    unselected_candidates.remove(j)
-
-    for _ in range(N - 1):
-        selec_array = np.array(selected_candidates)
-        unselec_array = np.array(unselected_candidates)
-
-        distance_to_doc = doc_sim_norm[unselec_array, :]
-        dist_between = sim_between_norm[unselec_array][:, selec_array]
-        if dist_between.ndim == 1:
-            dist_between = dist_between[:, np.newaxis]
-        j = np.argmax(beta * distance_to_doc - (1 - beta) * np.max(dist_between, axis=1).reshape(-1, 1))
-        item_idx = unselected_candidates[j]
-        selected_candidates.append(item_idx)
-        unselected_candidates.remove(item_idx)
-
-    return candidates, selected_candidates
-
-
-def max_normalization(array):
-    """
-    Compute maximum normalization (max is set to 1) of the array
-    :param array: 1-d array
-    :return: 1-d array max-normalized: each value is multiplied by 1/max value
-    """
-    return 1 / np.max(array) * array.squeeze(axis=1)
-
-
-def get_aliases(kp_sim_between, candidates, threshold):
-    """
-    Find candidates which are very similar to the keyphrases (aliases)
-    :param kp_sim_between: ndarray of shape (nb_kp, nb_candidates) containing the similarity
-    of each keyphrase with all the candidates. Note that the similarity between a keyphrase
-    and itself should be set to NaN or 0
-    :param candidates: array of candidates (array of string)
-    :return: list containing, for each keyphrase, a list of candidates which are aliases
-    (very similar) (list of list of string)
-    """
-    kp_sim_between = np.nan_to_num(kp_sim_between, 0)
-    idx_sorted = np.flip(np.argsort(kp_sim_between), 1)
-    aliases = []
-    for kp_idx, item in enumerate(idx_sorted):
-        alias_for_item = []
-        for i in item:
-            if kp_sim_between[kp_idx, i] >= threshold:
-                alias_for_item.append(candidates[i])
-            else:
-                break
-        aliases.append(alias_for_item)
-
-    return aliases
+def mmr(doc_embedding: np.ndarray,
+        word_embeddings: np.ndarray,
+        words: List[str],
+        top_n: int = 5,
+        diversity: float = 0.8) -> List[str]:
+    """ Calculate Maximal Marginal Relevance (MMR)
+    between candidate keywords and the document.
+
+    MMR considers the similarity of keywords/keyphrases with the
+    document, along with the similarity of already selected
+    keywords and keyphrases. This results in a selection of keywords
+    that are diverse with respect to each other while remaining
+    relevant to the document.
+
+    Arguments:
+        doc_embedding: The document embeddings
+        word_embeddings: The embeddings of the candidate keywords/phrases
+        words: The candidate keywords/keyphrases
+        top_n: The number of keywords/keyphrases to return
+        diversity: How diverse the selected keywords/keyphrases are.
+                   Values between 0 and 1, with 0 being not diverse at all
+                   and 1 being most diverse.
+
+    Returns:
+        List[str]: The selected keywords/keyphrases
+    """
+
+    # Extract similarity within words, and between words and the document
+    word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
+    word_similarity = cosine_similarity(word_embeddings)
+
+    # Initialize candidates and choose the best keyword/keyphrase first
+    keywords_idx = [np.argmax(word_doc_similarity)]
+    candidates_idx = [i for i in range(len(words)) if i != keywords_idx[0]]
+
+    for _ in range(top_n - 1):
+        # Extract similarities within candidates and
+        # between candidates and selected keywords/phrases
+        candidate_similarities = word_doc_similarity[candidates_idx, :]
+        target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1)
+
+        # Calculate MMR
+        mmr = (1 - diversity) * candidate_similarities - diversity * target_similarities.reshape(-1, 1)
+        mmr_idx = candidates_idx[np.argmax(mmr)]
+
+        # Update keywords & candidates
+        keywords_idx.append(mmr_idx)
+        candidates_idx.remove(mmr_idx)
+
+    return [words[idx] for idx in keywords_idx]
```
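As a quick shape check, a minimal sketch (not part of the commit) that calls the new `mmr` function directly on random toy embeddings; note that `doc_embedding` must be 2-D with a single row, matching what `model.encode([doc])` returns:

```python
import numpy as np
from keybert.mmr import mmr

# Toy stand-ins for sentence-transformer embeddings: 5 candidates, dim 8.
rng = np.random.RandomState(42)
word_embeddings = rng.rand(5, 8)
doc_embedding = rng.rand(1, 8)  # one row, as model.encode([doc]) produces
words = ["alpha", "beta", "gamma", "delta", "epsilon"]

# Higher diversity pushes later picks away from already-selected words;
# lower diversity ranks mostly by similarity to the document.
print(mmr(doc_embedding, word_embeddings, words, top_n=3, diversity=0.7))
print(mmr(doc_embedding, word_embeddings, words, top_n=3, diversity=0.2))
```

Both calls return three of the five toy words; with `diversity=0.7` the second and third picks are chosen more for dissimilarity to the first pick than for document relevance.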
keybert/model.py

```diff
@@ -5,6 +5,7 @@ from sklearn.feature_extraction.text import CountVectorizer
 from tqdm import tqdm
 from typing import List, Union
 import warnings
+from .mmr import mmr
 
 
 class KeyBERT:
@@ -36,7 +37,9 @@ class KeyBERT:
                          keyphrase_length: int = 1,
                          stop_words: Union[str, List[str]] = 'english',
                          top_n: int = 5,
-                         min_df: int = 1) -> Union[List[str], List[List[str]]]:
+                         min_df: int = 1,
+                         use_mmr: bool = False,
+                         diversity: float = 0.5) -> Union[List[str], List[List[str]]]:
         """ Extract keywords/keyphrases
 
         NOTE:
@@ -61,6 +64,10 @@ class KeyBERT:
             top_n: Return the top n keywords/keyphrases
             min_df: Minimum document frequency of a word across all documents
                     if keywords for multiple documents need to be extracted
+            use_mmr: Whether to use Maximal Marginal Relevance (MMR) for the
+                     selection of keywords/keyphrases
+            diversity: The diversity of the results between 0 and 1 if use_mmr
+                       is set to True
 
         Returns:
             keywords: the top n keywords for a document
@@ -71,7 +78,9 @@ class KeyBERT:
             return self._extract_keywords_single_doc(docs,
                                                      keyphrase_length,
                                                      stop_words,
-                                                     top_n)
+                                                     top_n,
+                                                     use_mmr,
+                                                     diversity)
         elif isinstance(docs, list):
             warnings.warn("Although extracting keywords for multiple documents is faster "
                           "than iterating over single documents, it requires significant memory "
@@ -86,7 +95,9 @@ class KeyBERT:
                                      doc: str,
                                      keyphrase_length: int = 1,
                                      stop_words: Union[str, List[str]] = 'english',
-                                     top_n: int = 5) -> List[str]:
+                                     top_n: int = 5,
+                                     use_mmr: bool = False,
+                                     diversity: float = 0.5) -> List[str]:
         """ Extract keywords/keyphrases for a single document
 
         Arguments:
@@ -94,6 +105,8 @@ class KeyBERT:
             keyphrase_length: Length, in words, of the extracted keywords/keyphrases
             stop_words: Stopwords to remove from the document
             top_n: Return the top n keywords/keyphrases
+            use_mmr: Whether to use MMR
+            diversity: The diversity of results between 0 and 1 if use_mmr is True
 
         Returns:
             keywords: The top n keywords for a document
@@ -106,14 +119,17 @@ class KeyBERT:
             words = count.get_feature_names()
 
             # Extract Embeddings
-            doc_embeddings = self.model.encode([doc])
+            doc_embedding = self.model.encode([doc])
             word_embeddings = self.model.encode(words)
 
             # Calculate distances and extract keywords
-            distances = cosine_similarity(doc_embeddings, word_embeddings)
-            keywords = [words[index] for index in distances.argsort()[0][-top_n:]]
-
-            return keywords[::-1]
+            if use_mmr:
+                keywords = mmr(doc_embedding, word_embeddings, words, top_n, diversity)
+            else:
+                distances = cosine_similarity(doc_embedding, word_embeddings)
+                keywords = [words[index] for index in distances.argsort()[0][-top_n:]][::-1]
+
+            return keywords
         except ValueError:
             return []
@@ -125,6 +141,8 @@ class KeyBERT:
                                         min_df: int = 1):
         """ Extract keywords/keyphrases for multiple documents
 
+        This currently does not use MMR, as MMR is only applied per single document.
+
         Arguments:
             docs: The documents for which to extract keywords/keyphrases
             keyphrase_length: Length, in words, of the extracted keywords/keyphrases
```
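One subtle fix in the last hunk above: `distances.argsort()[0][-top_n:]` yields the top-n indices in *ascending* similarity, so reversing inside the `else` branch (instead of reversing the final return value, which would also flip the MMR output) keeps the best keyword first in both code paths. A small illustration with made-up similarity scores:

```python
import numpy as np

distances = np.array([[0.1, 0.9, 0.4, 0.7, 0.2]])  # toy cosine similarities
words = ["a", "b", "c", "d", "e"]
top_n = 3

idx = distances.argsort()[0][-top_n:]  # ascending among the top 3: [2, 3, 1]
print([words[i] for i in idx])         # ['c', 'd', 'b'] -- worst-to-best
print([words[i] for i in idx][::-1])   # ['b', 'd', 'c'] -- best-first, as now returned
```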
setup.py (2 lines changed)

```diff
@@ -25,7 +25,7 @@ with open("README.md", "r") as fh:
 setuptools.setup(
     name="keybert",
     packages=["keybert"],
-    version="0.0.1",
+    version="0.1.0",
     author="Maarten Grootendorst",
     author_email="maartengrootendorst@gmail.com",
     description="KeyBERT performs keyword extraction with state-of-the-art transformer models.",
```
tests/test_model.py

```diff
@@ -16,11 +16,15 @@ def test_single_doc(keyphrase_length, base_keybert):
     assert len(keyword.split(" ")) == keyphrase_length
 
 
-@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)])
-def test_extract_keywords_single_doc(keyphrase_length, base_keybert):
+@pytest.mark.parametrize("keyphrase_length, mmr", [(i+1, truth) for i in range(5) for truth in [True, False]])
+def test_extract_keywords_single_doc(keyphrase_length, mmr, base_keybert):
     """ Test extraction of protected single document method """
     top_n = 5
-    keywords = base_keybert._extract_keywords_single_doc(doc_one, top_n=top_n, keyphrase_length=keyphrase_length)
+    keywords = base_keybert._extract_keywords_single_doc(doc_one,
+                                                         top_n=top_n,
+                                                         keyphrase_length=keyphrase_length,
+                                                         use_mmr=mmr,
+                                                         diversity=0.5)
     assert isinstance(keywords, list)
     assert isinstance(keywords[0], str)
     assert len(keywords) == top_n
```
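The updated `parametrize` expression is a flattened cartesian product, so the test now runs once per (length, MMR-flag) pair; a sketch of the list the decorator receives:

```python
# 5 keyphrase lengths x 2 MMR flags = 10 test cases
cases = [(i + 1, truth) for i in range(5) for truth in [True, False]]
print(len(cases))  # 10
print(cases[:4])   # [(1, True), (1, False), (2, True), (2, False)]
```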