1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/local_tests/test_models.py
Jack Morris a4f72facbd update tests
2020-05-09 08:42:34 -04:00

162 lines
5.6 KiB
Python

import colored
import io
import os
import re
import signal
import sys
import subprocess
import traceback
from side_by_side import print_side_by_side
def color_text(s, color):
    """Return `s` styled with foreground colour `color` via the `colored` package."""
    foreground_style = colored.fg(color)
    return colored.stylize(s, foreground_style)
# Temporary file (a relative path, so created in the current working
# directory) that CommandLineTest uses to capture a command's stderr.
stderr_file_name = "err.out.txt"
# Wildcard token used inside expected outputs: marks spans whose actual text
# is allowed to differ (e.g. printed timings or other non-deterministic output).
MAGIC_STRING = "/.*/"
def outputs_are_equivalent(desired_output, test_output):
    """Check `test_output` against the template `desired_output`.

    The template may contain MAGIC_STRING ('/.*/') wherever the real output is
    allowed to differ, for example where execution time or another
    non-deterministic value is printed. The literal pieces between wildcards
    must all occur in `test_output`, in the same order and without
    overlapping. Matching is not anchored: any text may appear before the
    first piece, between pieces, or after the last one.

    Returns True if every piece is found in order, otherwise False.
    """
    cursor = 0
    for piece in desired_output.split(MAGIC_STRING):
        # Search only past the end of the previous match, so order is enforced.
        found_at = test_output.find(piece, cursor)
        if found_at == -1:
            return False
        cursor = found_at + len(piece)
    return True
class TextAttackTest:
    """A named test case whose output is checked against an expected template.

    Subclasses supply `execute()`. Calling an instance runs it, prints a
    coloured pass/fail line, and returns whether the produced output matched
    `self.output` according to `outputs_are_equivalent`.
    """

    def __init__(self, name=None, output=None, desc=None):
        # All three fields are required. A missing one raises ValueError;
        # they are checked in the order name, output, desc.
        required = (
            (name, 'Cannot initialize TextAttackTest without name'),
            (output, 'Cannot initialize TextAttackTest without output'),
            (desc, 'Cannot initialize TextAttackTest without description'),
        )
        for value, message in required:
            if value is None:
                raise ValueError(message)
        self.name = name
        self.output = output
        self.desc = desc

    def execute(self):
        """Run the test and return ``(output_text, errored)``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def __call__(self):
        """Run the test, log the verdict, and return True iff it passed."""
        self.log_start()
        test_output, errored = self.execute()
        # The output is only compared when execution did not report an error.
        if errored or not outputs_are_equivalent(self.output, test_output):
            self.log_failure(test_output, errored)
            return False
        self.log_success()
        return True

    def log_start(self):
        """Announce which test is about to run."""
        highlighted_name = color_text(self.name, "blue")
        print(f'Executing test {highlighted_name}.')

    def log_success(self):
        """Print a green success marker."""
        print(color_text('✓ Succeeded.', 'green'))

    def log_failure(self, test_output, errored):
        """Print a red failure marker, then the error text or both outputs."""
        print(color_text('✗ Failed.', 'red'))
        if errored:
            print(f'Test exited early with error: {test_output}')
            return
        actual = f'Test output: {test_output}.'
        expected = f'Correct output: {self.output}.'
        # NOTE(review): the plain prints below were marked for deletion by the
        # original author; they are kept so console output is unchanged.
        ### begin delete
        print()
        print(actual)
        print()
        print(expected)
        print()
        ### end delete
        print_side_by_side(actual, expected)
class CommandLineTest(TextAttackTest):
    """Test that runs a command line and checks what it writes to stdout.

    The command string is split on whitespace and executed directly (no
    shell). If the process exits with status 0, its decoded stdout is compared
    against the expected output; otherwise its stderr is reported as the error.
    """

    def __init__(self, command, name=None, output=None, desc=None):
        if command is None:
            raise ValueError('Cannot initialize CommandLineTest without command')
        self.command = command
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """Run `self.command` and return ``(text, errored)``.

        Returns ``(stdout, False)`` when the process exits with status 0 and
        ``(stderr, True)`` otherwise. stderr is captured through the temporary
        file `stderr_file_name`. That file is closed and removed before
        returning, and also when starting the process raises (e.g. the
        executable is missing); in that case the exception propagates.
        """
        try:
            # `with` closes the handle before the file is deleted in `finally`;
            # removing a file that is still open fails on Windows.
            with open(stderr_file_name, 'w+') as stderr_file:
                result = subprocess.run(
                    self.command.split(),
                    stdout=subprocess.PIPE,
                    stderr=stderr_file,
                )
                # The child wrote straight to the file; rewind to read it all.
                stderr_file.seek(0)
                stderr_str = stderr_file.read()
        finally:
            remove_stderr_file()
        if result.returncode == 0:
            # The command succeeded: report its stdout.
            return result.stdout.decode(), False
        # Non-zero exit status: report its stderr as the error.
        return stderr_str, True
class Capturing(list):
    """Context manager that collects everything printed to ``sys.stdout``.

    The manager is itself the result. On exit it is extended with the captured
    text split into lines (line terminators removed), and the previous
    ``sys.stdout`` is restored. Exceptions raised inside the block propagate.
    Adapted from
    stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
    """

    def __enter__(self):
        # Remember the real stream so __exit__ can put it back.
        self._stdout = sys.stdout
        buffer = io.StringIO()
        sys.stdout = buffer
        self._stringio = buffer
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        del self._stringio  # only the split lines are kept; drop the buffer
        sys.stdout = self._stdout
class PythonFunctionTest(TextAttackTest):
    """Test that calls a zero-argument Python function and checks what it prints."""

    def __init__(self, function, name=None, output=None, desc=None):
        if function is None:
            raise ValueError('Cannot initialize PythonFunctionTest without function')
        self.function = function
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """Call `self.function` and return ``(text, errored)``.

        On success, returns the lines the function printed to stdout, joined
        with newlines, and False. If the function raises, returns the
        formatted traceback and True. KeyboardInterrupt is re-raised so that
        Ctrl+C aborts the test run instead of being reported as a failed test.
        """
        try:
            with Capturing() as output_lines:
                self.function()
        except KeyboardInterrupt:
            raise
        except BaseException:
            # Deliberately broad (includes SystemExit, e.g. a function calling
            # sys.exit()) so the problem is reported as a failed test rather
            # than ending the whole run.
            exc_str_lines = traceback.format_exc().splitlines()
            return '\n'.join(exc_str_lines), True
        return '\n'.join(output_lines), False
def remove_stderr_file():
    """Delete the temporary stderr capture file, if it exists.

    Safe to call repeatedly. A missing file means it was never created or was
    already cleaned up, so FileNotFoundError is ignored; any other OSError
    propagates.
    """
    try:
        os.remove(stderr_file_name)
    except FileNotFoundError:
        return  # file doesn't exist - nothing to clean up
def exit_handler(_, __):
    """SIGINT handler: delete the temporary stderr file, then interrupt.

    The two positional arguments (signal number and current stack frame) are
    supplied by the `signal` module and are unused.

    Installing any SIGINT handler replaces Python's default one, which is what
    turns Ctrl+C into KeyboardInterrupt. After cleaning up, this handler raises
    KeyboardInterrupt itself so Ctrl+C still stops the program. Returning
    normally would silently ignore the interrupt.
    """
    remove_stderr_file()
    raise KeyboardInterrupt


# If the program is interrupted early, make sure it doesn't leave the
# temporary stderr file behind.
signal.signal(signal.SIGINT, exit_handler)