1
0
mirror of https://github.com/deepset-ai/haystack.git synced 2022-02-20 23:31:40 +03:00

Add DELETE /feedback for testing and make the label's id generate server-side (#2159)

* Add DELETE /feedback for testing and make the ID generate server-side

* Make sure to delete only user generated labels

* Reduce fixture scope, was too broad

* Make test a bit more generic

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Sara Zan
2022-02-14 11:43:26 +01:00
committed by GitHub
parent db4d6f43ba
commit be8f50c9e3
5 changed files with 198 additions and 69 deletions

View File

@@ -117,7 +117,15 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/LabelSerialized"
"title": "Feedback",
"anyOf": [
{
"$ref": "#/components/schemas/LabelSerialized"
},
{
"$ref": "#/components/schemas/CreateLabelSerialized"
}
]
}
}
},
@@ -143,6 +151,24 @@
}
}
}
},
"delete": {
"tags": [
"feedback"
],
"summary": "Delete Feedback",
"description": "This endpoint allows the API user to delete all the\nfeedback that has been submitted through the\n`POST /feedback` endpoint",
"operationId": "delete_feedback_feedback_delete",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {}
}
}
}
}
}
},
"/eval-feedback": {
@@ -480,6 +506,74 @@
}
}
},
"CreateLabelSerialized": {
"title": "CreateLabelSerialized",
"required": [
"query",
"document",
"is_correct_answer",
"is_correct_document",
"origin"
],
"type": "object",
"properties": {
"id": {
"title": "Id",
"type": "string"
},
"query": {
"title": "Query",
"type": "string"
},
"document": {
"$ref": "#/components/schemas/DocumentSerialized"
},
"is_correct_answer": {
"title": "Is Correct Answer",
"type": "boolean"
},
"is_correct_document": {
"title": "Is Correct Document",
"type": "boolean"
},
"origin": {
"title": "Origin",
"enum": [
"user-feedback",
"gold-label"
],
"type": "string"
},
"answer": {
"$ref": "#/components/schemas/AnswerSerialized"
},
"no_answer": {
"title": "No Answer",
"type": "boolean"
},
"pipeline_id": {
"title": "Pipeline Id",
"type": "string"
},
"created_at": {
"title": "Created At",
"type": "string"
},
"updated_at": {
"title": "Updated At",
"type": "string"
},
"meta": {
"title": "Meta",
"type": "object"
},
"filters": {
"title": "Filters",
"type": "object"
}
},
"additionalProperties": false
},
"DocumentSerialized": {
"title": "DocumentSerialized",
"required": [

View File

@@ -4,7 +4,8 @@ import json
import logging
from fastapi import APIRouter
from rest_api.schema import FilterRequest, LabelSerialized
from haystack.schema import Label
from rest_api.schema import FilterRequest, LabelSerialized, CreateLabelSerialized
from rest_api.controller.search import DOCUMENT_STORE
@@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
@router.post("/feedback")
def post_feedback(feedback: LabelSerialized):
def post_feedback(feedback: Union[LabelSerialized, CreateLabelSerialized]):
"""
This endpoint allows the API user to submit feedback on
an answer for a particular query. For example, the user
@@ -25,7 +26,9 @@ def post_feedback(feedback: LabelSerialized):
"""
if feedback.origin is None:
feedback.origin = "user-feedback"
DOCUMENT_STORE.write_labels([feedback])
label = Label(**feedback.dict())
DOCUMENT_STORE.write_labels([label])
@router.get("/feedback")
@@ -39,6 +42,18 @@ def get_feedback():
return labels
@router.delete("/feedback")
def delete_feedback():
    """
    This endpoint allows the API user to delete all the
    feedback that has been submitted through the
    `POST /feedback` endpoint.

    Only labels whose origin is `user-feedback` are removed; labels
    with any other origin are left untouched.
    """
    all_labels = DOCUMENT_STORE.get_all_labels()
    user_label_ids = [label.id for label in all_labels if label.origin == "user-feedback"]
    # Skip the call when nothing matches: an empty id list with no filters
    # can be interpreted by the document store as "delete every label",
    # which would also wipe the labels this endpoint is meant to preserve.
    if user_label_ids:
        DOCUMENT_STORE.delete_labels(ids=user_label_ids)
@router.post("/eval-feedback")
def get_feedback_metrics(filters: FilterRequest = None):
"""

View File

@@ -39,11 +39,31 @@ class DocumentSerialized(Document):
@pydantic_dataclass
class LabelSerialized(Label):
class LabelSerialized(Label, BaseModel):
document: DocumentSerialized
answer: Optional[AnswerSerialized] = None
# Request model accepted by `POST /feedback` alongside `LabelSerialized`.
# Unlike a complete label, `id` is optional here, so a client may omit it.
class CreateLabelSerialized(BaseModel):
# Optional label id; defaults to None when the request does not provide one.
id: Optional[str] = None
# Required fields: the request is rejected if any of these is missing.
query: str
document: DocumentSerialized
is_correct_answer: bool
is_correct_document: bool
# Only these two literal values are accepted for the origin.
origin: Literal["user-feedback", "gold-label"]
# Everything below is optional and defaults to None.
answer: Optional[AnswerSerialized] = None
no_answer: Optional[bool] = None
pipeline_id: Optional[str] = None
created_at: Optional[str] = None
updated_at: Optional[str] = None
meta: Optional[dict] = None
filters: Optional[dict] = None
class Config:
# Forbid any extra fields in the request to avoid silent failures
extra = Extra.forbid
class QueryResponse(BaseModel):
query: str
answers: List[AnswerSerialized]

View File

@@ -1,10 +1,10 @@
import os
from copy import deepcopy
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from rest_api.application import app
@@ -42,21 +42,25 @@ def exclude_no_answer(responses):
return responses
@pytest.fixture(scope="session")
@pytest.fixture()
def client() -> TestClient:
os.environ["PIPELINE_YAML_PATH"] = str(
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
)
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
client = TestClient(app)
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
client.delete(url="/feedback")
yield client
# Clean up
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
client.delete(url="/feedback")
@pytest.fixture(scope="session")
@pytest.fixture()
def populated_client(client: TestClient) -> TestClient:
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
files_to_upload = [
{"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")},
{"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_2.pdf").open("rb")},
@@ -66,89 +70,57 @@ def populated_client(client: TestClient) -> TestClient:
url="/file-upload", files=fi, data={"meta": f'{{"meta_key": "meta_value", "meta_index": "{index}"}}'}
)
assert 200 == response.status_code
yield client
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
def test_get_documents():
os.environ["PIPELINE_YAML_PATH"] = str(
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
)
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
client = TestClient(app)
# Clean up to make sure the docstore is empty
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
# Upload the files
files_to_upload = [
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_1.txt").open("rb")},
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_2.txt").open("rb")},
]
for index, fi in enumerate(files_to_upload):
response = client.post(url="/file-upload", files=fi, data={"meta": f'{{"meta_key": "meta_value_get"}}'})
assert 200 == response.status_code
def test_get_documents(populated_client: TestClient):
# Get the documents
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value_get"]}}')
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value"]}}')
assert 200 == response.status_code
response_json = response.json()
# Make sure the right docs are found
assert len(response_json) == 2
names = [doc["meta"]["name"] for doc in response_json]
assert "doc_1.txt" in names
assert "doc_2.txt" in names
assert "sample_pdf_1.pdf" in names
assert "sample_pdf_2.pdf" in names
meta_keys = [doc["meta"]["meta_key"] for doc in response_json]
assert all("meta_value_get" == meta_key for meta_key in meta_keys)
assert all("meta_value" == meta_key for meta_key in meta_keys)
def test_delete_documents():
os.environ["PIPELINE_YAML_PATH"] = str(
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
)
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
client = TestClient(app)
# Clean up to make sure the docstore is empty
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
# Upload the files
files_to_upload = [
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_1.txt").open("rb")},
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_2.txt").open("rb")},
]
for index, fi in enumerate(files_to_upload):
response = client.post(
url="/file-upload", files=fi, data={"meta": f'{{"meta_key": "meta_value_del", "meta_index": "{index}"}}'}
)
assert 200 == response.status_code
# Make sure there are two docs
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value_del"]}}')
def test_delete_documents(populated_client: TestClient):
# Check how many docs there are
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value"]}}')
assert 200 == response.status_code
response_json = response.json()
assert len(response_json) == 2
initial_docs = len(response_json)
# Check how many docs we will delete
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["0"]}}')
assert 200 == response.status_code
response_json = response.json()
docs_to_delete = len(response_json)
# Delete one doc
response = client.post(url="/documents/delete_by_filters", data='{"filters": {"meta_index": ["0"]}}')
response = populated_client.post(url="/documents/delete_by_filters", data='{"filters": {"meta_index": ["0"]}}')
assert 200 == response.status_code
# Now there should be only one doc
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value_del"]}}')
# Now there should be fewer documents
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value"]}}')
assert 200 == response.status_code
response_json = response.json()
assert len(response_json) == 1
assert len(response_json) == initial_docs - docs_to_delete
# Make sure the right doc was deleted
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["0"]}}')
# Make sure the right docs were deleted
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["0"]}}')
assert 200 == response.status_code
response_json = response.json()
assert len(response_json) == 0
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["1"]}}')
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["1"]}}')
assert 200 == response.status_code
response_json = response.json()
assert len(response_json) == 1
assert len(response_json) >= 1
def test_file_upload(client: TestClient):
@@ -220,9 +192,17 @@ def test_write_feedback(populated_client: TestClient):
assert 200 == response.status_code
def test_write_feedback_without_id(populated_client: TestClient):
    """Posting feedback that carries no `id` field must still succeed."""
    # Work on a private copy so the shared FEEDBACK constant stays intact.
    payload = deepcopy(FEEDBACK)
    payload.pop("id")
    result = populated_client.post(url="/feedback", json=payload)
    assert result.status_code == 200
def test_get_feedback(client: TestClient):
response = client.post(url="/feedback", json=FEEDBACK)
assert response.status_code == 200
response = client.get(url="/feedback")
assert response.status_code == 200
json_response = response.json()
@@ -230,6 +210,27 @@ def test_get_feedback(client: TestClient):
assert response_item == expected_item
def test_delete_feedback(client: TestClient):
    """DELETE /feedback must remove one of the two stored labels and keep the other."""
    # First label: FEEDBACK as-is (presumably origin "user-feedback" -
    # confirm against the FEEDBACK constant defined earlier in this module).
    response = client.post(url="/feedback", json=FEEDBACK)
    assert 200 == response.status_code

    # Second label: a copy explicitly marked as a gold label, under a
    # different id so it does not overwrite the first one.
    feedback = deepcopy(FEEDBACK)
    feedback["id"] = "456"  # the request schema declares `id` as a string
    feedback["origin"] = "gold-label"
    response = client.post(url="/feedback", json=feedback)
    assert 200 == response.status_code

    # Both labels must be stored before the deletion.
    response = client.get(url="/feedback")
    assert 200 == response.status_code
    assert len(response.json()) == 2

    response = client.delete(url="/feedback")
    assert 200 == response.status_code

    # Exactly one label must survive the deletion.
    response = client.get(url="/feedback")
    assert 200 == response.status_code
    assert len(response.json()) == 1
def test_export_feedback(client: TestClient):
response = client.post(url="/feedback", json=FEEDBACK)
assert 200 == response.status_code
@@ -249,7 +250,7 @@ def test_export_feedback(client: TestClient):
def test_get_feedback_malformed_query(client: TestClient):
feedback = FEEDBACK.copy()
feedback = deepcopy(FEEDBACK)
feedback["unexpected_field"] = "misplaced-value"
response = client.post(url="/feedback", json=feedback)
assert response.status_code == 422

View File

@@ -92,7 +92,6 @@ def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, doc
"""
url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
req = {
"id": str(uuid4()),
"query": query,
"document": document,
"is_correct_answer": is_correct_answer,