Mirror of https://github.com/abetlen/llama-cpp-python.git, synced 2023-09-07 17:34:22 +03:00
Revert "llama_cpp server: prompt is a string". Closes #187
This reverts commit b9098b0ef7.
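The revert restores the earlier request model in which `prompt` accepts either a single string or a list of strings; when a list is supplied, the handler concatenates the segments into one string before forwarding the request to the model (see the diff below and the sketch after it).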
@@ -167,8 +167,9 @@ frequency_penalty_field = Field(
 )
 
 class CreateCompletionRequest(BaseModel):
-    prompt: Optional[str] = Field(
-        default="", description="The prompt to generate completions for."
+    prompt: Union[str, List[str]] = Field(
+        default="",
+        description="The prompt to generate completions for."
     )
     suffix: Optional[str] = Field(
         default=None,
@@ -222,6 +223,9 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 def create_completion(
     request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
+    if isinstance(request.prompt, list):
+        request.prompt = "".join(request.prompt)
+
     completion_or_chunks = llama(
         **request.dict(
             exclude={
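For illustration, a minimal, self-contained sketch of the restored prompt handling, separate from the actual server code: the model name and field definitions mirror the diff above, while the normalize_prompt helper is hypothetical and assumes pydantic is installed.

from typing import List, Optional, Union

from pydantic import BaseModel, Field


class CreateCompletionRequest(BaseModel):
    # Restored field: a single prompt string or a list of prompt segments.
    prompt: Union[str, List[str]] = Field(
        default="",
        description="The prompt to generate completions for."
    )
    # Other fields from the real request model are omitted here.
    suffix: Optional[str] = Field(default=None)


def normalize_prompt(request: CreateCompletionRequest) -> str:
    # Hypothetical helper mirroring the handler's added lines: a list of
    # prompt segments is joined into one string before the request is
    # passed on to the model.
    if isinstance(request.prompt, list):
        request.prompt = "".join(request.prompt)
    return request.prompt


if __name__ == "__main__":
    req = CreateCompletionRequest(prompt=["Hello", ", ", "world"])
    print(normalize_prompt(req))  # prints "Hello, world"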