Mirror of https://github.com/abetlen/llama-cpp-python.git (synced 2023-09-07 17:34:22 +03:00)
Update to more sensible return signature
@@ -2,7 +2,7 @@ import os
 import uuid
 import time
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence
+from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
 
 from . import llama_cpp
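
The change is purely at the type-annotation level: Generator[X, None, None] and Iterator[X] describe the same thing to a caller that only iterates, because the two extra None parameters are the generator's send and return types, which consumers of a streaming API never use. A minimal, self-contained sketch (not code from this repository) showing that the same generator body satisfies either annotation:

from typing import Generator, Iterator

# Spelling out Generator[str, None, None] exposes send/return types callers never touch.
def stream_tokens_verbose() -> Generator[str, None, None]:
    yield "hello"
    yield " world"

# Iterator[str] states the same contract more simply: every generator is an Iterator,
# so an identical body satisfies the shorter annotation.
def stream_tokens_simple() -> Iterator[str]:
    yield "hello"
    yield " world"

for token in stream_tokens_simple():
    print(token, end="")
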
@@ -286,10 +286,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[
-        Generator[Completion, None, None],
-        Generator[CompletionChunk, None, None],
-    ]:
+    ) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
         assert self.ctx is not None
         completion_id = f"cmpl-{str(uuid.uuid4())}"
         created = int(time.time())
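
Below is a minimal sketch (hypothetical names, not the library's code) of the pattern this hunk annotates: one generator body that either streams partial chunks or yields a single final result, so its return type is naturally a union of two iterator types:

from typing import Iterator, Union

def _generate(prompt: str, stream: bool) -> Union[Iterator[str], Iterator[dict]]:
    # Hypothetical stand-in for the internal completion generator: always a
    # generator, but the element type depends on the streaming flag.
    if stream:
        # streaming: yield partial pieces as they are produced
        for word in prompt.split():
            yield word
    else:
        # non-streaming: yield exactly one finished result
        yield {"text": prompt}
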
@@ -428,7 +425,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[Completion, Generator[CompletionChunk, None, None]]:
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
         Args:
@@ -465,7 +462,7 @@ class Llama:
             stream=stream,
         )
         if stream:
-            chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
+            chunks: Iterator[CompletionChunk] = completion_or_chunks
             return chunks
         completion: Completion = next(completion_or_chunks) # type: ignore
         return completion
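
For context, a hedged sketch of how a caller sees the two return shapes after this change. The model path and the choices[0]["text"] fields follow the OpenAI-style dicts this library mirrors; they are assumptions for illustration, not something shown in the diff:

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical model path

# stream=True: the call returns an Iterator[CompletionChunk]; iterate for incremental text.
for chunk in llm.create_completion("Q: Name the planets. A:", max_tokens=48, stream=True):
    print(chunk["choices"][0]["text"], end="", flush=True)

# stream=False: the wrapper pulls the single Completion out of the internal iterator with next().
completion = llm.create_completion("Q: Name the planets. A:", max_tokens=48)
print(completion["choices"][0]["text"])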