commit 7c1a14b08a, authored by wbhu on 2024-09-14 15:33:28 +08:00, committed by wbhu
12 changed files with 1302 additions and 0 deletions

.gitattributes (vendored, Normal file, 37 lines)

@@ -0,0 +1,37 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text

.gitignore (vendored, Normal file, 169 lines)

@@ -0,0 +1,169 @@
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
/logs
/gin-config
*.json
/eval/*csv
*__pycache__
scripts/
eval/

LICENSE (Normal file, 32 lines)

@@ -0,0 +1,32 @@
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
License Terms of the inference code of DepthCrafter:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
- You agree to use the DepthCrafter only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
For avoidance of doubts, “Software” means the DepthCrafter model inference code and weights made available under this license excluding any pre-trained data and other AI components.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Other dependencies and licenses:
Open Source Software Licensed under the MIT License:
--------------------------------------------------------------------
1. Stability AI - Code
Copyright (c) 2023 Stability AI
Terms of the MIT License:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
**You may find the code license of Stability AI at the following links: https://github.com/Stability-AI/generative-models/blob/main/LICENSE-CODE

README.md (Normal file, 105 lines)

@@ -0,0 +1,105 @@
## ___***DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos***___
<div align="center">
<img src='https://depthcrafter.github.io/img/logo.png' style="height:140px"></img>
<a href='https://arxiv.org/abs/2409.02095'><img src='https://img.shields.io/badge/arXiv-2409.02095-b31b1b.svg'></a> &nbsp;
<a href='https://depthcrafter.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;
_**[Wenbo Hu<sup>1* &dagger;</sup>](https://wbhu.github.io),
[Xiangjun Gao<sup>2*</sup>](https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en),
[Xiaoyu Li<sup>1* &dagger;</sup>](https://xiaoyu258.github.io),
[Sijie Zhao<sup>1</sup>](https://scholar.google.com/citations?user=tZ3dS3MAAAAJ&hl=en),
[Xiaodong Cun<sup>1</sup>](https://vinthony.github.io/academic), <br>
[Yong Zhang<sup>1</sup>](https://yzhang2016.github.io),
[Long Quan<sup>2</sup>](https://home.cse.ust.hk/~quan),
[Ying Shan<sup>3, 1</sup>](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en)**_
<br><br>
<sup>1</sup>Tencent AI Lab
<sup>2</sup>The Hong Kong University of Science and Technology
<sup>3</sup>ARC Lab, Tencent PCG
arXiv preprint, 2024
</div>
## 🔆 Introduction
🤗 DepthCrafter can generate temporally consistent long depth sequences with fine-grained details for open-world videos,
without requiring additional information such as camera poses or optical flow.
## 🎥 Visualization
We provide some demos of unprojected point cloud sequences, with reference RGB and estimated depth videos.
Please refer to our [project page](https://depthcrafter.github.io) for more details.
https://github.com/user-attachments/assets/62141cc8-04d0-458f-9558-fe50bc04cc21
## 🚀 Quick Start
### 🛠️ Installation
1. Clone this repo:
```bash
git clone https://github.com/Tencent/DepthCrafter.git
```
2. Install dependencies (please refer to [requirements.txt](requirements.txt)):
```bash
pip install -r requirements.txt
```
## 🤗 Model Zoo
[DepthCrafter](https://huggingface.co/tencent/DepthCrafter) is available in the Hugging Face Model Hub.
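For programmatic use, the checkpoint can also be loaded directly with `diffusers`. A minimal sketch, mirroring what `run.py` in this repository does (fp16 weights and a CUDA-capable GPU assumed):
```python
import torch

from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter

# load the DepthCrafter UNet and plug it into the SVD img2vid pipeline, as run.py does
unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
    "tencent/DepthCrafter", subfolder="unet", torch_dtype=torch.float16
)
pipe = DepthCrafterPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()  # optional: trades some speed for lower GPU memory
```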
### 🏃‍♂️ Inference
#### 1. High-resolution inference (requires a GPU with ~26GB of memory at 1024x576 resolution):
- Full inference (~0.6 fps on A100, recommended for high-quality results):
```bash
python run.py --video-path examples/example_01.mp4
```
- Fast inference through 4-step denoising and without classifier-free guidance (~2.3 fps on A100):
```bash
python run.py --video-path examples/example_01.mp4 --num-inference-steps 4 --guidance-scale 1.0
```
#### 2. Low-resolution inference (requires a GPU with ~9GB of memory at 512x256 resolution):
- Full inference (~2.3 fps on A100):
```bash
python run.py --video-path examples/example_01.mp4 --max-res 512
```
- Fast inference through 4-step denoising and without classifier-free guidance (~9.4 fps on A100):
```bash
python run.py --video-path examples/example_01.mp4 --max-res 512 --num-inference-steps 4 --guidance-scale 1.0
```
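For each input video, `run.py` writes `<name>_input.mp4`, `<name>_vis.mp4`, and `<name>_depth.mp4` into `--save-folder` (default `./demo_output`), plus `<name>.npz` with the raw depth when `--save_npz` is enabled. The depth array can be loaded back, e.g.:
```python
import numpy as np

# depth is normalized to [0, 1] over the whole video, shape [t, h, w]
depth = np.load("demo_output/example_01.npz")["depth"]
print(depth.shape, depth.min(), depth.max())
```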
## 🤖 Gradio Demo
We provide a local Gradio demo for DepthCrafter, which can be launched by running:
```bash
gradio app.py
```
## 🤝 Contributing
- Issues and pull requests are welcome.
- Contributions that optimize inference speed and memory usage are especially welcome, e.g., through model quantization, distillation, or other acceleration techniques.
## 📜 Citation
If you find this work helpful, please consider citing:
```bibtex
@article{hu2024-DepthCrafter,
author = {Hu, Wenbo and Gao, Xiangjun and Li, Xiaoyu and Zhao, Sijie and Cun, Xiaodong and Zhang, Yong and Quan, Long and Shan, Ying},
title = {DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos},
journal = {arXiv preprint arXiv:2409.02095},
year = {2024}
}
```

app.py (Normal file, 141 lines)

@@ -0,0 +1,141 @@
import gc
import os
from copy import deepcopy
import gradio as gr
import numpy as np
import torch
from diffusers.training_utils import set_seed
from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
from depthcrafter.utils import read_video_frames, vis_sequence_depth, save_video
from run import DepthCrafterDemo
examples = [
["examples/example_01.mp4", 25, 1.2, 1024, 195],
]
def construct_demo():
with gr.Blocks(analytics_enabled=False) as depthcrafter_iface:
gr.Markdown(
"""
<div align='center'> <h1> DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos </span> </h1> \
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
<a href='https://wbhu.github.io'>Wenbo Hu</a>, \
<a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en'>Xiangjun Gao</a>, \
<a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>, \
<a href='https://scholar.google.com/citations?user=tZ3dS3MAAAAJ&hl=en'>Sijie Zhao</a>, \
<a href='https://vinthony.github.io/academic'> Xiaodong Cun</a>, \
<a href='https://yzhang2016.github.io'>Yong Zhang</a>, \
<a href='https://home.cse.ust.hk/~quan'>Long Quan</a>, \
<a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en'>Ying Shan</a>\
</h2> \
<a style='font-size:18px;color: #000000'>If you find DepthCrafter useful, please help star the </a>\
<a style='font-size:18px;color: #FF5DB0' href='https://github.com/wbhu/DepthCrafter'>[Github Repo]</a>\
<a style='font-size:18px;color: #000000'>, which is important to Open-Source projects. Thanks!</a>\
<a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02095'> [ArXiv] </a>\
<a style='font-size:18px;color: #000000' href='https://depthcrafter.github.io/'> [Project Page] </a> </div>
"""
)
# demo
depthcrafter_demo = DepthCrafterDemo(
unet_path="tencent/DepthCrafter",
pre_train_path="stabilityai/stable-video-diffusion-img2vid-xt",
)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
input_video = gr.Video(label="Input Video")
# with gr.Tab(label="Output"):
with gr.Column(scale=2):
with gr.Row(equal_height=True):
output_video_1 = gr.Video(
label="Preprocessed video",
interactive=False,
autoplay=True,
loop=True,
show_share_button=True,
scale=5,
)
output_video_2 = gr.Video(
label="Generated Depth Video",
interactive=False,
autoplay=True,
loop=True,
show_share_button=True,
scale=5,
)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
with gr.Row(equal_height=False):
with gr.Accordion("Advanced Settings", open=False):
num_denoising_steps = gr.Slider(
label="num denoising steps",
minimum=1,
maximum=25,
value=25,
step=1,
)
guidance_scale = gr.Slider(
label="cfg scale",
minimum=1.0,
maximum=1.2,
value=1.2,
step=0.1,
)
max_res = gr.Slider(
label="max resolution",
minimum=512,
maximum=2048,
value=1024,
step=64,
)
process_length = gr.Slider(
label="process length",
minimum=1,
maximum=280,
value=195,
step=1,
)
generate_btn = gr.Button("Generate")
with gr.Column(scale=2):
pass
gr.Examples(
examples=examples,
inputs=[
input_video,
num_denoising_steps,
guidance_scale,
max_res,
process_length,
],
outputs=[output_video_1, output_video_2],
fn=depthcrafter_demo.run,
cache_examples=False,
)
generate_btn.click(
fn=depthcrafter_demo.run,
inputs=[
input_video,
num_denoising_steps,
guidance_scale,
max_res,
process_length,
],
outputs=[output_video_1, output_video_2],
)
return depthcrafter_iface
demo = construct_demo()
if __name__ == "__main__":
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=80, debug=True)

depthcrafter/__init__.py (Normal file, 0 lines)

depthcrafter/depth_crafter_ppl.py (Normal file, 366 lines)

@@ -0,0 +1,366 @@
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
_resize_with_antialiasing,
StableVideoDiffusionPipelineOutput,
StableVideoDiffusionPipeline,
retrieve_timesteps,
)
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class DepthCrafterPipeline(StableVideoDiffusionPipeline):
@torch.inference_mode()
def encode_video(
self,
video: torch.Tensor,
chunk_size: int = 14,
) -> torch.Tensor:
"""
:param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
:param chunk_size: the chunk size to encode video
:return: image_embeddings in shape of [b, 1024]
"""
video_224 = _resize_with_antialiasing(video.float(), (224, 224))
video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
embeddings = []
for i in range(0, video_224.shape[0], chunk_size):
tmp = self.feature_extractor(
images=video_224[i : i + chunk_size],
do_normalize=True,
do_center_crop=False,
do_resize=False,
do_rescale=False,
return_tensors="pt",
).pixel_values.to(video.device, dtype=video.dtype)
embeddings.append(self.image_encoder(tmp).image_embeds) # [b, 1024]
embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
return embeddings
@torch.inference_mode()
def encode_vae_video(
self,
video: torch.Tensor,
chunk_size: int = 14,
):
"""
:param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
:param chunk_size: the chunk size to encode video
:return: vae latents in shape of [b, c, h, w]
"""
video_latents = []
for i in range(0, video.shape[0], chunk_size):
video_latents.append(
self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
)
video_latents = torch.cat(video_latents, dim=0)
return video_latents
@staticmethod
def check_inputs(video, height, width):
"""
:param video:
:param height:
:param width:
:return:
"""
if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
raise ValueError(
f"Expected `video` to be a `torch.Tensor` or `VideoReader`, but got a {type(video)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
)
@torch.no_grad()
def __call__(
self,
video: Union[np.ndarray, torch.Tensor],
height: int = 576,
width: int = 1024,
num_inference_steps: int = 25,
guidance_scale: float = 1.0,
window_size: Optional[int] = 110,
noise_aug_strength: float = 0.02,
decode_chunk_size: Optional[int] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
return_dict: bool = True,
overlap: int = 25,
track_time: bool = False,
):
"""
:param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
:param height:
:param width:
:param num_inference_steps:
:param guidance_scale:
:param window_size: sliding window processing size
:param noise_aug_strength:
:param decode_chunk_size:
:param generator:
:param latents:
:param output_type:
:param callback_on_step_end:
:param callback_on_step_end_tensor_inputs:
:param return_dict: whether to return a StableVideoDiffusionPipelineOutput instead of a plain tuple
:param overlap: number of overlapping frames between consecutive windows
:param track_time: print elapsed time for the encode, denoise, and decode stages
:return: depth frames (or raw latents if output_type == "latent")
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_frames = video.shape[0]
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
if num_frames <= window_size:
window_size = num_frames
overlap = 0
stride = window_size - overlap
# 1. Check inputs. Raise error if not correct
self.check_inputs(video, height, width)
# 2. Define call parameters
batch_size = 1
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
self._guidance_scale = guidance_scale
# 3. Encode input video
if isinstance(video, np.ndarray):
video = torch.from_numpy(video.transpose(0, 3, 1, 2))
else:
assert isinstance(video, torch.Tensor)
video = video.to(device=device, dtype=self.dtype)
video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
if track_time:
start_event = torch.cuda.Event(enable_timing=True)
encode_event = torch.cuda.Event(enable_timing=True)
denoise_event = torch.cuda.Event(enable_timing=True)
decode_event = torch.cuda.Event(enable_timing=True)
start_event.record()
video_embeddings = self.encode_video(
video, chunk_size=decode_chunk_size
).unsqueeze(
0
) # [1, t, 1024]
torch.cuda.empty_cache()
# 4. Encode input image using VAE
noise = randn_tensor(
video.shape, generator=generator, device=device, dtype=video.dtype
)
video = video + noise_aug_strength * noise # in [t, c, h, w]
# pdb.set_trace()
needs_upcasting = (
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
)
if needs_upcasting:
self.vae.to(dtype=torch.float32)
video_latents = self.encode_vae_video(
video.to(self.vae.dtype),
chunk_size=decode_chunk_size,
).unsqueeze(
0
) # [1, t, c, h, w]
if track_time:
encode_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(encode_event)
print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")
torch.cuda.empty_cache()
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
# 5. Get Added Time IDs
added_time_ids = self._get_add_time_ids(
7,
127,
noise_aug_strength,
video_embeddings.dtype,
batch_size,
1,
False,
) # [1 or 2, 3]
added_time_ids = added_time_ids.to(device)
# 6. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, None, None
)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
# 7. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents_init = self.prepare_latents(
batch_size,
window_size,
num_channels_latents,
height,
width,
video_embeddings.dtype,
device,
generator,
latents,
) # [1, t, c, h, w]
latents_all = None
idx_start = 0
if overlap > 0:
weights = torch.linspace(0, 1, overlap, device=device)
weights = weights.view(1, overlap, 1, 1, 1)
else:
weights = None
torch.cuda.empty_cache()
# inference strategy for long videos
# two main strategies: 1. noise init from previous frame, 2. segments stitching
while idx_start < num_frames - overlap:
idx_end = min(idx_start + window_size, num_frames)
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 9. Denoising loop
latents = latents_init[:, : idx_end - idx_start].clone()
latents_init = torch.cat(
[latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
)
video_latents_current = video_latents[:, idx_start:idx_end]
video_embeddings_current = video_embeddings[:, idx_start:idx_end]
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if latents_all is not None and i == 0:
latents[:, :overlap] = (
latents_all[:, -overlap:]
+ latents[:, :overlap]
/ self.scheduler.init_noise_sigma
* self.scheduler.sigmas[i]
)
latent_model_input = latents # [1, t, c, h, w]
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
) # [1, t, c, h, w]
latent_model_input = torch.cat(
[latent_model_input, video_latents_current], dim=2
)
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=video_embeddings_current,
added_time_ids=added_time_ids,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = torch.cat(
[latent_model_input, torch.zeros_like(latent_model_input)],
dim=2,
)
noise_pred_uncond = self.unet(
latent_model_input,
t,
encoder_hidden_states=torch.zeros_like(
video_embeddings_current
),
added_time_ids=added_time_ids,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred - noise_pred_uncond
)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(
self, i, t, callback_kwargs
)
latents = callback_outputs.pop("latents", latents)
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps
and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if latents_all is None:
latents_all = latents.clone()
else:
assert weights is not None
# latents_all[:, -overlap:] = (
# latents[:, :overlap] + latents_all[:, -overlap:]
# ) / 2.0
latents_all[:, -overlap:] = latents[
:, :overlap
] * weights + latents_all[:, -overlap:] * (1 - weights)
latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
idx_start += stride
if track_time:
denoise_event.record()
torch.cuda.synchronize()
elapsed_time_ms = encode_event.elapsed_time(denoise_event)
print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")
if not output_type == "latent":
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)
if track_time:
decode_event.record()
torch.cuda.synchronize()
elapsed_time_ms = denoise_event.elapsed_time(decode_event)
print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")
frames = self.video_processor.postprocess_video(
video=frames, output_type=output_type
)
else:
frames = latents_all
self.maybe_free_model_hooks()
if not return_dict:
return frames
return StableVideoDiffusionPipelineOutput(frames=frames)
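The sliding-window loop above stitches consecutively denoised windows by linearly blending their overlapping latent frames (the `torch.linspace(0, 1, overlap)` weights). A standalone sketch of that stitching rule, with illustrative shapes:
```python
import torch


def stitch(latents_all: torch.Tensor, latents_new: torch.Tensor, overlap: int) -> torch.Tensor:
    # latents_* are [1, t, c, h, w]; the first `overlap` frames of the new window
    # are blended into the last `overlap` frames of the accumulated latents
    weights = torch.linspace(0, 1, overlap).view(1, overlap, 1, 1, 1)
    blended = latents_new[:, :overlap] * weights + latents_all[:, -overlap:] * (1 - weights)
    return torch.cat([latents_all[:, :-overlap], blended, latents_new[:, overlap:]], dim=1)


prev_window = torch.randn(1, 110, 4, 72, 128)  # already-denoised window
next_window = torch.randn(1, 110, 4, 72, 128)  # next window, sharing 25 frames with prev_window
out = stitch(prev_window, next_window, overlap=25)
print(out.shape)  # torch.Size([1, 195, 4, 72, 128])
```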

depthcrafter/unet.py (Normal file, 142 lines)

@@ -0,0 +1,142 @@
from typing import Union, Tuple
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
class DiffusersUNetSpatioTemporalConditionModelDepthCrafter(
UNetSpatioTemporalConditionModel
):
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
added_time_ids: torch.Tensor,
return_dict: bool = True,
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
batch_size, num_frames = sample.shape[:2]
timesteps = timesteps.expand(batch_size)
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)
emb = emb + aug_emb
# Flatten the batch and frames dimensions
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
# encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
# 2. pre-process
sample = sample.to(dtype=self.conv_in.weight.dtype)
assert sample.dtype == self.conv_in.weight.dtype, (
f"sample.dtype: {sample.dtype}, "
f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
)
sample = self.conv_in(sample)
image_only_indicator = torch.zeros(
batch_size, num_frames, dtype=sample.dtype, device=sample.device
)
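# 3. down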
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
image_only_indicator=image_only_indicator,
)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[
: -len(upsample_block.resnets)
]
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
sample = upsample_block(
hidden_states=sample,
res_hidden_states_tuple=res_samples,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
else:
sample = upsample_block(
hidden_states=sample,
res_hidden_states_tuple=res_samples,
temb=emb,
image_only_indicator=image_only_indicator,
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
# 7. Reshape back to original shape
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
if not return_dict:
return (sample,)
return UNetSpatioTemporalConditionOutput(sample=sample)
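A quick shape sketch of the flatten/repeat/reshape pattern used in `forward` above; the sizes are illustrative, not the model's actual channel counts:
```python
import torch

batch, frames, c, h, w = 1, 110, 8, 72, 128    # illustrative latent sizes
sample = torch.randn(batch, frames, c, h, w)

flat = sample.flatten(0, 1)                     # [batch * frames, c, h, w] for the spatial blocks
emb = torch.randn(batch, 1280)                  # per-video time embedding (width is illustrative)
emb = emb.repeat_interleave(frames, dim=0)      # repeated per frame: [batch * frames, 1280]

restored = flat.reshape(batch, frames, *flat.shape[1:])
assert restored.shape == sample.shape
```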

depthcrafter/utils.py (Normal file, 92 lines)

@@ -0,0 +1,92 @@
import numpy as np
import cv2
import matplotlib.cm as cm
import torch
def read_video_frames(video_path, process_length, target_fps, max_res):
# a simple function to read video frames
cap = cv2.VideoCapture(video_path)
original_fps = cap.get(cv2.CAP_PROP_FPS)
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# round the height and width to the nearest multiple of 64
height = round(original_height / 64) * 64
width = round(original_width / 64) * 64
# resize the video if the height or width is larger than max_res
if max(height, width) > max_res:
scale = max_res / max(original_height, original_width)
height = round(original_height * scale / 64) * 64
width = round(original_width * scale / 64) * 64
if target_fps < 0:
target_fps = original_fps
stride = max(round(original_fps / target_fps), 1)
frames = []
frame_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret or (process_length > 0 and frame_count >= process_length):
break
if frame_count % stride == 0:
frame = cv2.resize(frame, (width, height))
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
frames.append(frame.astype("float32") / 255.0)
frame_count += 1
cap.release()
frames = np.array(frames)
return frames, target_fps
def save_video(
video_frames,
output_video_path,
fps: int = 15,
) -> str:
# a simple function to save video frames
height, width = video_frames[0].shape[:2]
is_color = video_frames[0].ndim == 3
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video_writer = cv2.VideoWriter(
output_video_path, fourcc, fps, (width, height), isColor=is_color
)
for frame in video_frames:
frame = (frame * 255).astype(np.uint8)
if is_color:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
video_writer.write(frame)
video_writer.release()
return output_video_path
class ColorMapper:
# a color mapper to map depth values to a certain colormap
def __init__(self, colormap: str = "inferno"):
self.colormap = torch.tensor(cm.get_cmap(colormap).colors)
def apply(self, image: torch.Tensor, v_min=None, v_max=None):
# assert len(image.shape) == 2
if v_min is None:
v_min = image.min()
if v_max is None:
v_max = image.max()
image = (image - v_min) / (v_max - v_min)
image = (image * 255).long()
image = self.colormap[image]
return image
def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
visualizer = ColorMapper()
if v_min is None:
v_min = depths.min()
if v_max is None:
v_max = depths.max()
res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
return res
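A short usage sketch for these helpers; the depth array is a random stand-in for real pipeline output, and the paths are illustrative:
```python
import numpy as np

from depthcrafter.utils import read_video_frames, save_video, vis_sequence_depth

# read at most 30 frames, keep the source fps (-1), cap the longer side at 512
frames, fps = read_video_frames(
    "examples/example_01.mp4", process_length=30, target_fps=-1, max_res=512
)

# stand-in for pipeline output: depth in [0, 1], shape [t, h, w]
depth = np.random.rand(*frames.shape[:3]).astype("float32")

vis = vis_sequence_depth(depth)            # colormapped to [t, h, w, 3]
save_video(vis, "depth_vis.mp4", fps=fps)
```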

examples/example_01.mp4 (Executable file, 3 lines)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:afb78decc210225793b20d5bca5b13da07c97233e6fabea44bf02eba8a52bdaf
size 14393250

requirements.txt (Normal file, 5 lines)

@@ -0,0 +1,5 @@
torch==2.3.0+cu117
diffusers==0.29.1
numpy==1.26.4
matplotlib==3.8.4
opencv-python==4.8.1.78

run.py (Normal file, 210 lines)

@@ -0,0 +1,210 @@
import gc
import os
import numpy as np
import torch
import argparse
from diffusers.training_utils import set_seed
from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
from depthcrafter.utils import vis_sequence_depth, save_video, read_video_frames
class DepthCrafterDemo:
def __init__(
self,
unet_path: str,
pre_train_path: str,
cpu_offload: str = "model",
):
unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
unet_path,
subfolder="unet",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
)
# load weights of other components from the provided checkpoint
self.pipe = DepthCrafterPipeline.from_pretrained(
pre_train_path,
unet=unet,
torch_dtype=torch.float16,
variant="fp16",
)
# for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
if cpu_offload is not None:
if cpu_offload == "sequential":
# this is slower, but saves more memory
self.pipe.enable_sequential_cpu_offload()
elif cpu_offload == "model":
self.pipe.enable_model_cpu_offload()
else:
raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
else:
self.pipe.to("cuda")
# enable attention slicing and xformers memory efficient attention
try:
self.pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
print(e)
print("Xformers is not enabled")
self.pipe.enable_attention_slicing()
def infer(
self,
video: str,
num_denoising_steps: int,
guidance_scale: float,
save_folder: str = "./demo_output",
window_size: int = 110,
process_length: int = 195,
overlap: int = 25,
max_res: int = 1024,
target_fps: int = 15,
seed: int = 42,
track_time: bool = True,
save_npz: bool = False,
):
set_seed(seed)
frames, target_fps = read_video_frames(
video, process_length, target_fps, max_res
)
print(f"==> video name: {video}, frames shape: {frames.shape}")
# inference the depth map using the DepthCrafter pipeline
with torch.inference_mode():
res = self.pipe(
frames,
height=frames.shape[1],
width=frames.shape[2],
output_type="np",
guidance_scale=guidance_scale,
num_inference_steps=num_denoising_steps,
window_size=window_size,
overlap=overlap,
track_time=track_time,
).frames[0]
# convert the three-channel output to a single channel depth map
res = res.sum(-1) / res.shape[-1]
# normalize the depth map to [0, 1] across the whole video
res = (res - res.min()) / (res.max() - res.min())
# visualize the depth map and save the results
vis = vis_sequence_depth(res)
# save the depth map and visualization with the target FPS
save_path = os.path.join(
save_folder, os.path.splitext(os.path.basename(video))[0]
)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
if save_npz:
np.savez_compressed(save_path + ".npz", depth=res)
save_video(res, save_path + "_depth.mp4", fps=target_fps)
save_video(vis, save_path + "_vis.mp4", fps=target_fps)
save_video(frames, save_path + "_input.mp4", fps=target_fps)
return [
save_path + "_input.mp4",
save_path + "_vis.mp4",
save_path + "_depth.mp4",
]
def run(
self,
input_video,
num_denoising_steps,
guidance_scale,
max_res=1024,
process_length=195,
):
res_path = self.infer(
input_video,
num_denoising_steps,
guidance_scale,
max_res=max_res,
process_length=process_length,
)
# clear the cache for the next video
gc.collect()
torch.cuda.empty_cache()
return res_path[:2]
if __name__ == "__main__":
# running configs
# the most important arguments for memory saving are `cpu_offload`, `enable_xformers`, `max_res`, and `window_size`
# the most important arguments for trade-off between quality and speed are
# `num_inference_steps`, `guidance_scale`, and `max_res`
parser = argparse.ArgumentParser(description="DepthCrafter")
parser.add_argument(
"--video-path", type=str, required=True, help="Path to the input video file(s)"
)
parser.add_argument(
"--save-folder",
type=str,
default="./demo_output",
help="Folder to save the output",
)
parser.add_argument(
"--unet-path",
type=str,
default="tencent/DepthCrafter",
help="Path to the UNet model",
)
parser.add_argument(
"--pre-train-path",
type=str,
default="stabilityai/stable-video-diffusion-img2vid-xt",
help="Path to the pre-trained model",
)
parser.add_argument(
"--process-length", type=int, default=195, help="Number of frames to process"
)
parser.add_argument(
"--cpu-offload",
type=str,
default="model",
choices=["model", "sequential", None],
help="CPU offload option",
)
parser.add_argument(
"--target-fps", type=int, default=15, help="Target FPS for the output video"
) # -1 for original fps
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument(
"--num-inference-steps", type=int, default=25, help="Number of inference steps"
)
parser.add_argument(
"--guidance-scale", type=float, default=1.2, help="Guidance scale"
)
parser.add_argument("--window-size", type=int, default=110, help="Window size")
parser.add_argument("--overlap", type=int, default=25, help="Overlap size")
parser.add_argument("--max-res", type=int, default=1024, help="Maximum resolution")
parser.add_argument("--save_npz", type=bool, default=True, help="Save npz file")
parser.add_argument("--track_time", type=bool, default=False, help="Track time")
args = parser.parse_args()
depthcrafter_demo = DepthCrafterDemo(
unet_path=args.unet_path,
pre_train_path=args.pre_train_path,
cpu_offload=args.cpu_offload,
)
# process the videos; multiple video paths are separated by commas
video_paths = args.video_path.split(",")
for video in video_paths:
depthcrafter_demo.infer(
video,
args.num_inference_steps,
args.guidance_scale,
save_folder=args.save_folder,
window_size=args.window_size,
process_length=args.process_length,
overlap=args.overlap,
max_res=args.max_res,
target_fps=args.target_fps,
seed=args.seed,
track_time=args.track_time,
save_npz=args.save_npz,
)
# clear the cache for the next video
gc.collect()
torch.cuda.empty_cache()
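`DepthCrafterDemo` can also be driven without the CLI; a minimal sketch using the same defaults as the argument parser above:
```python
from run import DepthCrafterDemo

demo = DepthCrafterDemo(
    unet_path="tencent/DepthCrafter",
    pre_train_path="stabilityai/stable-video-diffusion-img2vid-xt",
    cpu_offload="model",
)
for video in ["examples/example_01.mp4"]:
    demo.infer(
        video,
        num_denoising_steps=25,
        guidance_scale=1.2,
        save_folder="./demo_output",
        window_size=110,
        process_length=195,
        overlap=25,
        max_res=1024,
        target_fps=15,
        seed=42,
    )
```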