1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix docstring issues..

This commit is contained in:
Yanjun Qi
2021-09-29 15:50:57 -04:00
parent f1ef471ea8
commit fa9817af5d
23 changed files with 150 additions and 35 deletions

View File

@@ -7,6 +7,7 @@ textattack.constraints.grammaticality.language\_models package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints.grammaticality package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints.semantics package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints.semantics.sentence\_encoders package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.datasets package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.goal\_functions package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -0,0 +1,26 @@
textattack.metrics.attack\_metrics package
==========================================
.. automodule:: textattack.metrics.attack_metrics
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.attack_metrics.attack_queries
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.attack_metrics.attack_success_rate
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.attack_metrics.words_perturbed
:members:
:undoc-members:
:show-inheritance:

View File

@@ -0,0 +1,20 @@
textattack.metrics.quality\_metrics package
===========================================
.. automodule:: textattack.metrics.quality_metrics
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.quality_metrics.perplexity
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.quality_metrics.use
:members:
:undoc-members:
:show-inheritance:

View File

@@ -0,0 +1,22 @@
textattack.metrics package
==========================
.. automodule:: textattack.metrics
:members:
:undoc-members:
:show-inheritance:
.. toctree::
:maxdepth: 6
textattack.metrics.attack_metrics
textattack.metrics.quality_metrics
.. automodule:: textattack.metrics.metric
:members:
:undoc-members:
:show-inheritance:

View File

@@ -7,6 +7,7 @@ textattack.models package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -19,6 +19,7 @@ textattack package
textattack.goal_function_results
textattack.goal_functions
textattack.loggers
textattack.metrics
textattack.models
textattack.search_methods
textattack.shared

View File

@@ -7,6 +7,7 @@ textattack.shared package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.transformations package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -357,7 +357,7 @@ class AttackArgs:
"--enable-advance-metrics",
action="store_true",
default=default_obj.enable_advance_metrics,
help="Enable calculation and display of optional advance post-hoc metrics like perplexity, use distance, etc.",
help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
)
return parser

View File

@@ -3,8 +3,7 @@
attack_metrics:
======================
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.
TextAttack provide users common metrics on attacks' quality.
"""

View File

@@ -1,3 +1,10 @@
"""
Metrics on AttackQueries
=========================
"""
import numpy as np
from textattack.attack_results import SkippedAttackResult
@@ -5,17 +12,17 @@ from textattack.metrics import Metric
class AttackQueries(Metric):
"""Calculates all metrics related to number of queries in an attack
Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Attack results for each instance in dataset
"""
def __init__(self):
self.all_metrics = {}
def calculate(self, results):
"""Calculates all metrics related to number of queries in an attack
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
"""
self.results = results
self.num_queries = np.array(
[

View File

@@ -1,15 +1,15 @@
"""
Metrics on AttackSuccessRate
=============================
"""
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
class AttackSuccessRate(Metric):
"""Calculates all metrics related to number of succesful, failed and skipped results in an attack
Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Attack results for each instance in dataset
"""
def __init__(self):
self.failed_attacks = 0
self.skipped_attacks = 0
@@ -18,6 +18,12 @@ class AttackSuccessRate(Metric):
self.all_metrics = {}
def calculate(self, results):
"""Calculates all metrics related to number of succesful, failed and skipped results in an attack
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
"""
self.results = results
self.total_attacks = len(self.results)

View File

@@ -1,3 +1,10 @@
"""
Metrics on perturbed words
=============================
"""
import numpy as np
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
@@ -13,6 +20,13 @@ class WordsPerturbed(Metric):
self.all_metrics = {}
def calculate(self, results):
"""Calculates all metrics related to perturbed words in an attack
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
"""
self.results = results
self.total_attacks = len(self.results)
self.all_num_words = np.zeros(len(self.results))

View File

@@ -19,5 +19,10 @@ class Metric(ABC):
@abstractmethod
def calculate(self, results):
""" Abstract function for computing any values which are to be calculated as a whole during initialization"""
"""Abstract function for computing any values which are to be calculated as a whole during initialization
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
"""
raise NotImplementedError

View File

@@ -1,9 +1,9 @@
"""
perplexity:
Metrics on Quality
======================
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.
TextAttack provide users common metrics on text examples' quality.
"""

View File

@@ -1,3 +1,10 @@
"""
Perplexity Metric:
======================
"""
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -7,13 +14,6 @@ import textattack.shared.utils
class Perplexity(Metric):
"""Calculates average Perplexity on all successfull attacks using a pre-trained small GPT-2 model
Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Attack results for each instance in dataset
"""
def __init__(self):
self.all_metrics = {}
self.original_candidates = []
@@ -23,9 +23,15 @@ class Perplexity(Metric):
self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
self.ppl_model.eval()
self.max_length = self.ppl_model.config.n_positions
self.stride = 128
self.stride = 512
def calculate(self, results):
"""Calculates average Perplexity on all successfull attacks using a pre-trained small GPT-2 model
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
"""
self.results = results
self.original_candidates_ppl = []
self.successful_candidates_ppl = []

View File

@@ -4,13 +4,6 @@ from textattack.metrics import Metric
class USEMetric(Metric):
"""Calculates average USE similarity on all successfull attacks
Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Attack results for each instance in dataset
"""
def __init__(self, **kwargs):
self.use_obj = UniversalSentenceEncoder()
self.use_obj.model = UniversalSentenceEncoder()
@@ -19,6 +12,12 @@ class USEMetric(Metric):
self.all_metrics = {}
def calculate(self, results):
"""Calculates average USE similarity on all successfull attacks
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
"""
self.results = results
for i, result in enumerate(self.results):