mirror of
https://github.com/deepset-ai/haystack.git
synced 2022-02-20 23:31:40 +03:00
Add DELETE /feedback for testing and make the label's id generate server-side (#2159)
* Add DELETE /feedback for testing and make the ID generate server-side * Make sure to delete only user generated labels * Reduce fixture scope, was too broad * Make test a bit more generic Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -117,7 +117,15 @@
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/LabelSerialized"
|
||||
"title": "Feedback",
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/LabelSerialized"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/CreateLabelSerialized"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -143,6 +151,24 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"delete": {
|
||||
"tags": [
|
||||
"feedback"
|
||||
],
|
||||
"summary": "Delete Feedback",
|
||||
"description": "This endpoint allows the API user to delete all the\nfeedback that has been submitted through the\n`POST /feedback` endpoint",
|
||||
"operationId": "delete_feedback_feedback_delete",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/eval-feedback": {
|
||||
@@ -480,6 +506,74 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"CreateLabelSerialized": {
|
||||
"title": "CreateLabelSerialized",
|
||||
"required": [
|
||||
"query",
|
||||
"document",
|
||||
"is_correct_answer",
|
||||
"is_correct_document",
|
||||
"origin"
|
||||
],
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"title": "Id",
|
||||
"type": "string"
|
||||
},
|
||||
"query": {
|
||||
"title": "Query",
|
||||
"type": "string"
|
||||
},
|
||||
"document": {
|
||||
"$ref": "#/components/schemas/DocumentSerialized"
|
||||
},
|
||||
"is_correct_answer": {
|
||||
"title": "Is Correct Answer",
|
||||
"type": "boolean"
|
||||
},
|
||||
"is_correct_document": {
|
||||
"title": "Is Correct Document",
|
||||
"type": "boolean"
|
||||
},
|
||||
"origin": {
|
||||
"title": "Origin",
|
||||
"enum": [
|
||||
"user-feedback",
|
||||
"gold-label"
|
||||
],
|
||||
"type": "string"
|
||||
},
|
||||
"answer": {
|
||||
"$ref": "#/components/schemas/AnswerSerialized"
|
||||
},
|
||||
"no_answer": {
|
||||
"title": "No Answer",
|
||||
"type": "boolean"
|
||||
},
|
||||
"pipeline_id": {
|
||||
"title": "Pipeline Id",
|
||||
"type": "string"
|
||||
},
|
||||
"created_at": {
|
||||
"title": "Created At",
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"title": "Updated At",
|
||||
"type": "string"
|
||||
},
|
||||
"meta": {
|
||||
"title": "Meta",
|
||||
"type": "object"
|
||||
},
|
||||
"filters": {
|
||||
"title": "Filters",
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"DocumentSerialized": {
|
||||
"title": "DocumentSerialized",
|
||||
"required": [
|
||||
|
||||
@@ -4,7 +4,8 @@ import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from rest_api.schema import FilterRequest, LabelSerialized
|
||||
from haystack.schema import Label
|
||||
from rest_api.schema import FilterRequest, LabelSerialized, CreateLabelSerialized
|
||||
from rest_api.controller.search import DOCUMENT_STORE
|
||||
|
||||
|
||||
@@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
def post_feedback(feedback: LabelSerialized):
|
||||
def post_feedback(feedback: Union[LabelSerialized, CreateLabelSerialized]):
|
||||
"""
|
||||
This endpoint allows the API user to submit feedback on
|
||||
an answer for a particular query. For example, the user
|
||||
@@ -25,7 +26,9 @@ def post_feedback(feedback: LabelSerialized):
|
||||
"""
|
||||
if feedback.origin is None:
|
||||
feedback.origin = "user-feedback"
|
||||
DOCUMENT_STORE.write_labels([feedback])
|
||||
|
||||
label = Label(**feedback.dict())
|
||||
DOCUMENT_STORE.write_labels([label])
|
||||
|
||||
|
||||
@router.get("/feedback")
|
||||
@@ -39,6 +42,18 @@ def get_feedback():
|
||||
return labels
|
||||
|
||||
|
||||
@router.delete("/feedback")
|
||||
def delete_feedback():
|
||||
"""
|
||||
This endpoint allows the API user to delete all the
|
||||
feedback that has been submitted through the
|
||||
`POST /feedback` endpoint
|
||||
"""
|
||||
all_labels = DOCUMENT_STORE.get_all_labels()
|
||||
user_label_ids = [label.id for label in all_labels if label.origin == "user-feedback"]
|
||||
DOCUMENT_STORE.delete_labels(ids=user_label_ids)
|
||||
|
||||
|
||||
@router.post("/eval-feedback")
|
||||
def get_feedback_metrics(filters: FilterRequest = None):
|
||||
"""
|
||||
|
||||
@@ -39,11 +39,31 @@ class DocumentSerialized(Document):
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class LabelSerialized(Label):
|
||||
class LabelSerialized(Label, BaseModel):
|
||||
document: DocumentSerialized
|
||||
answer: Optional[AnswerSerialized] = None
|
||||
|
||||
|
||||
class CreateLabelSerialized(BaseModel):
|
||||
id: Optional[str] = None
|
||||
query: str
|
||||
document: DocumentSerialized
|
||||
is_correct_answer: bool
|
||||
is_correct_document: bool
|
||||
origin: Literal["user-feedback", "gold-label"]
|
||||
answer: Optional[AnswerSerialized] = None
|
||||
no_answer: Optional[bool] = None
|
||||
pipeline_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
meta: Optional[dict] = None
|
||||
filters: Optional[dict] = None
|
||||
|
||||
class Config:
|
||||
# Forbid any extra fields in the request to avoid silent failures
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
query: str
|
||||
answers: List[AnswerSerialized]
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
from rest_api.application import app
|
||||
|
||||
|
||||
@@ -42,21 +42,25 @@ def exclude_no_answer(responses):
|
||||
return responses
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture()
|
||||
def client() -> TestClient:
|
||||
os.environ["PIPELINE_YAML_PATH"] = str(
|
||||
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
|
||||
)
|
||||
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
|
||||
client = TestClient(app)
|
||||
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
client.delete(url="/feedback")
|
||||
|
||||
yield client
|
||||
# Clean up
|
||||
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
client.delete(url="/feedback")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture()
|
||||
def populated_client(client: TestClient) -> TestClient:
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
files_to_upload = [
|
||||
{"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")},
|
||||
{"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_2.pdf").open("rb")},
|
||||
@@ -66,89 +70,57 @@ def populated_client(client: TestClient) -> TestClient:
|
||||
url="/file-upload", files=fi, data={"meta": f'{{"meta_key": "meta_value", "meta_index": "{index}"}}'}
|
||||
)
|
||||
assert 200 == response.status_code
|
||||
|
||||
yield client
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
|
||||
|
||||
def test_get_documents():
|
||||
os.environ["PIPELINE_YAML_PATH"] = str(
|
||||
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
|
||||
)
|
||||
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
|
||||
client = TestClient(app)
|
||||
|
||||
# Clean up to make sure the docstore is empty
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
|
||||
# Upload the files
|
||||
files_to_upload = [
|
||||
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_1.txt").open("rb")},
|
||||
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_2.txt").open("rb")},
|
||||
]
|
||||
for index, fi in enumerate(files_to_upload):
|
||||
response = client.post(url="/file-upload", files=fi, data={"meta": f'{{"meta_key": "meta_value_get"}}'})
|
||||
assert 200 == response.status_code
|
||||
|
||||
def test_get_documents(populated_client: TestClient):
|
||||
# Get the documents
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value_get"]}}')
|
||||
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value"]}}')
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
|
||||
# Make sure the right docs are found
|
||||
assert len(response_json) == 2
|
||||
names = [doc["meta"]["name"] for doc in response_json]
|
||||
assert "doc_1.txt" in names
|
||||
assert "doc_2.txt" in names
|
||||
assert "sample_pdf_1.pdf" in names
|
||||
assert "sample_pdf_2.pdf" in names
|
||||
meta_keys = [doc["meta"]["meta_key"] for doc in response_json]
|
||||
assert all("meta_value_get" == meta_key for meta_key in meta_keys)
|
||||
assert all("meta_value" == meta_key for meta_key in meta_keys)
|
||||
|
||||
|
||||
def test_delete_documents():
|
||||
os.environ["PIPELINE_YAML_PATH"] = str(
|
||||
(Path(__file__).parent / "samples" / "pipeline" / "test_pipeline.yaml").absolute()
|
||||
)
|
||||
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_text_pipeline"
|
||||
client = TestClient(app)
|
||||
|
||||
# Clean up to make sure the docstore is empty
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
|
||||
# Upload the files
|
||||
files_to_upload = [
|
||||
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_1.txt").open("rb")},
|
||||
{"files": (Path(__file__).parent / "samples" / "docs" / "doc_2.txt").open("rb")},
|
||||
]
|
||||
for index, fi in enumerate(files_to_upload):
|
||||
response = client.post(
|
||||
url="/file-upload", files=fi, data={"meta": f'{{"meta_key": "meta_value_del", "meta_index": "{index}"}}'}
|
||||
)
|
||||
assert 200 == response.status_code
|
||||
|
||||
# Make sure there are two docs
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value_del"]}}')
|
||||
def test_delete_documents(populated_client: TestClient):
|
||||
# Check how many docs there are
|
||||
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value"]}}')
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json) == 2
|
||||
initial_docs = len(response_json)
|
||||
|
||||
# Check how many docs we will delete
|
||||
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["0"]}}')
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
docs_to_delete = len(response_json)
|
||||
|
||||
# Delete one doc
|
||||
response = client.post(url="/documents/delete_by_filters", data='{"filters": {"meta_index": ["0"]}}')
|
||||
response = populated_client.post(url="/documents/delete_by_filters", data='{"filters": {"meta_index": ["0"]}}')
|
||||
assert 200 == response.status_code
|
||||
|
||||
# Now there should be only one doc
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value_del"]}}')
|
||||
# Now there should be fewer documents
|
||||
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_key": ["meta_value"]}}')
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json) == 1
|
||||
assert len(response_json) == initial_docs - docs_to_delete
|
||||
|
||||
# Make sure the right doc was deleted
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["0"]}}')
|
||||
# Make sure the right docs were deleted
|
||||
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["0"]}}')
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json) == 0
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["1"]}}')
|
||||
|
||||
response = populated_client.post(url="/documents/get_by_filters", data='{"filters": {"meta_index": ["1"]}}')
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json) == 1
|
||||
assert len(response_json) >= 1
|
||||
|
||||
|
||||
def test_file_upload(client: TestClient):
|
||||
@@ -220,9 +192,17 @@ def test_write_feedback(populated_client: TestClient):
|
||||
assert 200 == response.status_code
|
||||
|
||||
|
||||
def test_write_feedback_without_id(populated_client: TestClient):
|
||||
feedback = deepcopy(FEEDBACK)
|
||||
del feedback["id"]
|
||||
response = populated_client.post(url="/feedback", json=feedback)
|
||||
assert 200 == response.status_code
|
||||
|
||||
|
||||
def test_get_feedback(client: TestClient):
|
||||
response = client.post(url="/feedback", json=FEEDBACK)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get(url="/feedback")
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
@@ -230,6 +210,27 @@ def test_get_feedback(client: TestClient):
|
||||
assert response_item == expected_item
|
||||
|
||||
|
||||
def test_delete_feedback(client: TestClient):
|
||||
client.post(url="/feedback", json=FEEDBACK)
|
||||
|
||||
feedback = deepcopy(FEEDBACK)
|
||||
feedback["id"] = 456
|
||||
feedback["origin"] = "gold-label"
|
||||
print(feedback)
|
||||
client.post(url="/feedback", json=feedback)
|
||||
|
||||
response = client.get(url="/feedback")
|
||||
json_response = response.json()
|
||||
assert len(json_response) == 2
|
||||
|
||||
response = client.delete(url="/feedback")
|
||||
assert 200 == response.status_code
|
||||
|
||||
response = client.get(url="/feedback")
|
||||
json_response = response.json()
|
||||
assert len(json_response) == 1
|
||||
|
||||
|
||||
def test_export_feedback(client: TestClient):
|
||||
response = client.post(url="/feedback", json=FEEDBACK)
|
||||
assert 200 == response.status_code
|
||||
@@ -249,7 +250,7 @@ def test_export_feedback(client: TestClient):
|
||||
|
||||
|
||||
def test_get_feedback_malformed_query(client: TestClient):
|
||||
feedback = FEEDBACK.copy()
|
||||
feedback = deepcopy(FEEDBACK)
|
||||
feedback["unexpected_field"] = "misplaced-value"
|
||||
response = client.post(url="/feedback", json=feedback)
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -92,7 +92,6 @@ def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, doc
|
||||
"""
|
||||
url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
|
||||
req = {
|
||||
"id": str(uuid4()),
|
||||
"query": query,
|
||||
"document": document,
|
||||
"is_correct_answer": is_correct_answer,
|
||||
|
||||
Reference in New Issue
Block a user