agent-video-summarizer/pipeline.py
import pickle
import torch
import yt_dlp
from dotenv import load_dotenv
from jinja2 import Template
from langchain import hub
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from chromadb.config import Settings
from whisper import load_model
from loguru import logger
import tiktoken
import utils
from llms import LLMService
logger.debug("loading env vars")
load_dotenv(dotenv_path='env')


def initialize_llms():
    logger.debug("initializing llm query engines")
    query_engines = {}
    for model in ["gpt35turbo", "gpt4"]:
        llm_service = LLMService(provider="azure", model=model)
        query_engine = llm_service.initialize_client_query_engine()
        query_engines[model] = query_engine
    # note: llm_service refers to the last model initialized in the loop
    return llm_service, query_engines


def initialize_stt_model():
    logger.debug("loading stt model")
    model = load_model("large-v3", device="cpu")
    logger.debug("distributing stt model on multi-gpu setup")
    # split whisper across two GPUs: encoder on cuda:0, decoder on cuda:1
    model.encoder.to("cuda:0")
    model.decoder.to("cuda:1")
    # move the decoder inputs (tokens and encoder features) to cuda:1 before
    # each forward pass, then copy the decoder output back to cuda:0
    model.decoder.register_forward_pre_hook(
        lambda _, inputs: tuple([inputs[0].to("cuda:1"), inputs[1].to("cuda:1")] + list(inputs[2:]))
    )
    model.decoder.register_forward_hook(lambda _, inputs, outputs: outputs.to("cuda:0"))
    return model


def download_audio(url):
    logger.debug(f"downloading audio from: {url}")
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': 'downloads/%(title)s.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
        }]
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(url, download=True)
        file_path = ydl.prepare_filename(info_dict)
    if file_path.endswith(".webm"):
        file_path = file_path.replace(".webm", ".mp3")
    file_name = file_path.split("/")[1]
    return file_path, file_name


def transcribe(model, file_path, file_name):
    logger.debug("transcribing audio")
    transcript = model.transcribe(file_path)['text']
    logger.success(f"transcript saved into cache - {file_name}")
    with open(f"cache/transcripts/{file_name}", "w") as file:
        file.write(transcript)
    del model
    torch.cuda.empty_cache()
    return transcript


def divide_transcript(text, file_name, max_tokens=2048):
    logger.debug(f"splitting transcript into chunks with max_tokens: {max_tokens}")
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    tokens = encoding.encode(text)
    chunks = [tokens[i:i + max_tokens] for i in range(0, len(tokens), max_tokens)]
    partial_transcripts = [encoding.decode(chunk) for chunk in chunks]
    logger.success(f"subtranscripts saved into cache - {file_name}")
    with open(f"cache/subtranscripts/{file_name}", "w") as file:
        file.write(str(partial_transcripts))
    return partial_transcripts


def generate_partial_summaries(query_engine, file_name, partial_transcripts):
    template = """Look at TRANSCRIPT and give key highlights in markdown format. Additionally provide a longer summary next to key highlights.
Know that this is only part of a longer transcript. You can tell which part you are looking at from PART.
Answer in format:
**part** -> which PART of the transcript summarized
**key highlights** -> the important parts of partial content
**summary** -> overall summary of partial content
PART:
{part}
TRANSCRIPT:
{transcript}
"""
    partial_summaries = []
    for i, subtrans in enumerate(partial_transcripts):
        logger.info(f"processing subtranscript {i + 1}/{len(partial_transcripts)}")
        prompt = template.format(
            transcript=subtrans,
            part=str(i)
        )
        answer, consumption = query_engine(prompt=prompt)
        partial_summaries.append(answer.content)
    logger.success(f"partial_summaries saved into cache - {file_name}")
    with open(f"cache/summaries_partial/{file_name}", "wb") as file:
        pickle.dump(partial_summaries, file)
    return partial_summaries


def generate_merged_summary(query_engine, file_name, partial_summaries):
    logger.info("processing partial summaries into main summary")
    template_merge = Template("""Look at all the partial summaries, understand and provide combined comprehensive key highlights in markdown format.
Answer in markdown format:
**summary**
**key highlights**
{% for partial_summary in partial_summaries %}
- PARTIAL SUMMARY ({{ loop.index }}):
{{ partial_summary }}
{% endfor %}
""")
    prompt_merge = template_merge.render(
        partial_summaries=partial_summaries
    )
    answer_merge, consumption_merge = query_engine(prompt=prompt_merge)
    logger.success(f"merged summary saved into cache - {file_name}")
    with open(f"cache/summaries_merged/{file_name}", "wb") as file:
        pickle.dump(answer_merge.content, file)
    return answer_merge.content


def split_into_chunks(transcript):
    logger.debug("splitting transcript into chunks for rag")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=250
    )
    chunks = text_splitter.split_text(transcript)
    logger.debug(f"Total chunk count: {len(chunks)}")
    chunk_counts = []
    token_counts = []
    for chunk in chunks:
        chunk_counts.append(len(chunk))
        token_counts.append(utils.count_tokens(text=chunk))
    aver_chunks = sum(chunk_counts) / len(chunk_counts) if chunk_counts else 0
    aver_tokens = sum(token_counts) / len(token_counts) if token_counts else 0
    logger.debug(f"Average chunk len: {aver_chunks}")
    logger.debug(f"Average chunk token count: {aver_tokens}")
    return chunks


def instantiate_qa_chain(knowledge_chunks):
    logger.debug("instantiating rag qa chain")
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vdb = Chroma.from_texts(
        knowledge_chunks,
        embeddings,
        persist_directory="./cache/vdb",
        client_settings=Settings(anonymized_telemetry=False)
    )
    llm_service = LLMService(provider="azure", model="gpt4")
    _ = llm_service.initialize_client_query_engine()
    prompt = hub.pull("rlm/rag-prompt")
    with open("cache/prompts_lc_hub/rag-prompt", "wb") as file:
        pickle.dump(prompt, file)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm_service.llm,
        retriever=vdb.as_retriever(),
        chain_type_kwargs={"prompt": prompt}
    )
    return qa_chain
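

# --- Illustrative end-to-end wiring (not part of the original file) ---
# A minimal sketch of how the functions above could be composed, assuming the
# downloads/ and cache/ subdirectories exist and that the "gpt4" query engine
# is the one used for summarization. The URL below is a placeholder, not a
# value from the original code.
if __name__ == "__main__":
    llm_service, query_engines = initialize_llms()
    stt_model = initialize_stt_model()

    url = "https://www.youtube.com/watch?v=<VIDEO_ID>"  # placeholder URL
    file_path, file_name = download_audio(url)

    transcript = transcribe(stt_model, file_path, file_name)
    partial_transcripts = divide_transcript(transcript, file_name)

    partial_summaries = generate_partial_summaries(
        query_engines["gpt4"], file_name, partial_transcripts
    )
    merged_summary = generate_merged_summary(
        query_engines["gpt4"], file_name, partial_summaries
    )
    logger.info(merged_summary)

    # optional: build a RAG QA chain over the transcript for follow-up questions
    qa_chain = instantiate_qa_chain(split_into_chunks(transcript))
    logger.info(qa_chain.invoke({"query": "What are the main topics of the video?"}))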