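"""Pipeline helpers for summarizing and querying YouTube videos: download audio with
yt-dlp, transcribe it with Whisper, summarize the transcript with Azure-hosted LLMs,
and build a retrieval QA chain over the transcript chunks."""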
import os
import pickle

import torch
import yt_dlp
from dotenv import load_dotenv
from jinja2 import Template
from langchain import hub
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from chromadb.config import Settings
from whisper import load_model
from loguru import logger
import tiktoken

import utils
from llms import LLMService


logger.debug("loading env vars")
load_dotenv(dotenv_path='env')


def initialize_llms():
    logger.debug("initializing llm query engines")
    query_engines = {}
    for model in ["gpt35turbo", "gpt4"]:
        llm_service = LLMService(provider="azure", model=model)
        query_engine = llm_service.initialize_client_query_engine()
        query_engines[model] = query_engine
    # note: llm_service refers to the last model initialized in the loop ("gpt4")
    return llm_service, query_engines


def initialize_stt_model():
    logger.debug("loading stt model")
    # load on CPU first, then shard the model across two GPUs
    model = load_model("large-v3", device="cpu")

    logger.debug("distributing stt model on multi-gpu setup")
    model.encoder.to("cuda:0")
    model.decoder.to("cuda:1")
    # move the decoder inputs to cuda:1 before each forward pass, then move its
    # outputs back to cuda:0 so the rest of the pipeline sees a single device
    model.decoder.register_forward_pre_hook(
        lambda _, inputs: tuple([inputs[0].to("cuda:1"), inputs[1].to("cuda:1")] + list(inputs[2:]))
    )
    model.decoder.register_forward_hook(lambda _, inputs, outputs: outputs.to("cuda:0"))
    return model


def download_audio(url):
    logger.debug(f"downloading audio from: {url}")
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': 'downloads/%(title)s.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
        }]
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(url, download=True)
        file_path = ydl.prepare_filename(info_dict)
        # prepare_filename returns the pre-postprocessing name (e.g. .webm or .m4a);
        # the FFmpegExtractAudio postprocessor converts the download to .mp3
        file_path = f"{os.path.splitext(file_path)[0]}.mp3"
        file_name = os.path.basename(file_path)

    return file_path, file_name


def transcribe(model, file_path, file_name):
    logger.debug("transcribing audio")
    transcript = model.transcribe(file_path)['text']

    logger.success(f"transcript saved into cache - {file_name}")
    with open(f"cache/transcripts/{file_name}", "w") as file:
        file.write(transcript)

    # free GPU memory once transcription is done
    del model
    torch.cuda.empty_cache()
    return transcript


def divide_transcript(text, file_name, max_tokens=2048):
    logger.debug(f"splitting transcript into chunks with max_tokens: {max_tokens}")
    # token-based splitting so each chunk fits comfortably in the summarization prompt
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    tokens = encoding.encode(text)
    chunks = [tokens[i:i + max_tokens] for i in range(0, len(tokens), max_tokens)]
    partial_transcripts = [encoding.decode(chunk) for chunk in chunks]

    logger.success(f"subtranscripts saved into cache - {file_name}")
    with open(f"cache/subtranscripts/{file_name}", "w") as file:
        file.write(f"{partial_transcripts}")

    return partial_transcripts


def generate_partial_summaries(query_engine, file_name, partial_transcripts):
    template = """Look at TRANSCRIPT and give key highlights in markdown format. Additionally, provide a longer summary next to the key highlights.
Note that this is only part of a longer transcript. PART tells you which part you are looking at.
Answer in the format:
**part** -> which PART of the transcript was summarized
**key highlights** -> the important points of the partial content
**summary** -> overall summary of the partial content

PART:
{part}

TRANSCRIPT:
{transcript}
"""

    partial_summaries = []
    for i, subtrans in enumerate(partial_transcripts):
        logger.info(f"processing subtranscript {i + 1}/{len(partial_transcripts)}")
        prompt = template.format(
            transcript=subtrans,
            part=str(i)
        )
        answer, consumption = query_engine(prompt=prompt)
        partial_summaries.append(answer.content)

    logger.success(f"partial_summaries saved into cache - {file_name}")
    with open(f"cache/summaries_partial/{file_name}", "wb") as file:
        pickle.dump(partial_summaries, file)

    return partial_summaries


def generate_merged_summary(query_engine, file_name, partial_summaries):
    logger.info("processing partial summaries into main summary")
    template_merge = Template("""Look at all the partial summaries, understand them, and provide combined, comprehensive key highlights in markdown format.
Answer in markdown format:
**summary**
**key highlights**

{% for partial_summary in partial_summaries %}
- PARTIAL SUMMARY ({{ loop.index }}):
{{ partial_summary }}
{% endfor %}

""")

    prompt_merge = template_merge.render(
        partial_summaries=partial_summaries
    )
    answer_merge, consumption_merge = query_engine(prompt=prompt_merge)

    logger.success(f"merged summary saved into cache - {file_name}")
    with open(f"cache/summaries_merged/{file_name}", "wb") as file:
        pickle.dump(answer_merge.content, file)

    return answer_merge.content


def split_into_chunks(transcript):
    logger.debug("splitting transcript into chunks for rag")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=250
    )

    chunks = text_splitter.split_text(transcript)
    logger.debug(f"Total chunk count: {len(chunks)}")

    chunk_counts = []
    token_counts = []
    for chunk in chunks:
        chunk_counts.append(len(chunk))
        token_counts.append(utils.count_tokens(text=chunk))

    aver_chunks = sum(chunk_counts) / len(chunk_counts) if chunk_counts else 0
    aver_tokens = sum(token_counts) / len(token_counts) if token_counts else 0
    logger.debug(f"Average chunk len: {aver_chunks}")
    logger.debug(f"Average chunk token count: {aver_tokens}")

    return chunks


def instantiate_qa_chain(knowledge_chunks):
    logger.debug("instantiating rag qa chain")
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    vdb = Chroma.from_texts(
        knowledge_chunks,
        embeddings,
        persist_directory="./cache/vdb",
        client_settings=Settings(anonymized_telemetry=False)
    )

    llm_service = LLMService(provider="azure", model="gpt4")
    _ = llm_service.initialize_client_query_engine()

    # pull the standard RAG prompt from the LangChain hub and cache a local copy
    prompt = hub.pull("rlm/rag-prompt")
    with open("cache/prompts_lc_hub/rag-prompt", "wb") as file:
        pickle.dump(prompt, file)

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm_service.llm,
        retriever=vdb.as_retriever(),
        chain_type_kwargs={"prompt": prompt}
    )
    return qa_chain
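

# Minimal usage sketch: assumes the downloads/ and cache/ subdirectories already
# exist, that two GPUs are available for the Whisper model, and that gpt35turbo is
# used for the partial summaries while gpt4 handles the merge and the QA chain.
if __name__ == "__main__":
    llm_service, query_engines = initialize_llms()
    stt_model = initialize_stt_model()

    # placeholder URL - replace with the video to summarize
    file_path, file_name = download_audio("https://www.youtube.com/watch?v=VIDEO_ID")
    transcript = transcribe(stt_model, file_path, file_name)

    partial_transcripts = divide_transcript(transcript, file_name)
    partial_summaries = generate_partial_summaries(query_engines["gpt35turbo"], file_name, partial_transcripts)
    summary = generate_merged_summary(query_engines["gpt4"], file_name, partial_summaries)
    logger.info(summary)

    chunks = split_into_chunks(transcript)
    qa_chain = instantiate_qa_chain(chunks)
    response = qa_chain.invoke({"query": "What are the main topics of the video?"})
    logger.info(response["result"])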