mirror of https://github.com/Tencent/DepthCrafter.git
synced 2024-09-25 23:28:07 +03:00
367 lines · 14 KiB · Python
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch

from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    StableVideoDiffusionPipelineOutput,
    StableVideoDiffusionPipeline,
    retrieve_timesteps,
)
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class DepthCrafterPipeline(StableVideoDiffusionPipeline):
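    """
    Video depth estimation pipeline built on Stable Video Diffusion: the SVD UNet is
    conditioned on per-frame CLIP image embeddings and on the VAE latents of the input
    video, and long videos are denoised in overlapping temporal windows that are
    stitched back together.
    """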

    @torch.inference_mode()
    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        :param video: [b, c, h, w] in range [-1, 1]; the batch dimension may hold frames
            from one or several videos
        :param chunk_size: number of frames to encode per forward pass
        :return: image embeddings of shape [b, 1024]
        """
        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0  # [-1, 1] -> [0, 1]
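
        # Frames are resized and rescaled manually above, so resizing/rescaling are
        # disabled in the feature extractor below and only CLIP normalization is
        # applied; encoding is chunked to bound peak memory.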
        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            tmp = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(tmp).image_embeds)  # [b, 1024]

        embeddings = torch.cat(embeddings, dim=0)  # [t, 1024]
        return embeddings

    @torch.inference_mode()
    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ):
        """
        :param video: [b, c, h, w] in range [-1, 1]; the batch dimension may hold frames
            from one or several videos
        :param chunk_size: number of frames to encode per forward pass
        :return: VAE latents of shape [b, c, h, w]
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents

    @staticmethod
    def check_inputs(video, height, width):
        """
        Validate that `video` is a `torch.Tensor` or `np.ndarray` and that `height`
        and `width` are divisible by 8 (the VAE downsampling factor).
        """
        if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
            raise ValueError(
                f"Expected `video` to be a `torch.Tensor` or `np.ndarray`, but got a {type(video)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 110,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
        overlap: int = 25,
        track_time: bool = False,
    ):
        """
        :param video: input video in range [0, 1]; [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor
        :param height: target height, must be divisible by 8
        :param width: target width, must be divisible by 8
        :param num_inference_steps: number of denoising steps
        :param guidance_scale: classifier-free guidance weight; 1.0 disables guidance
        :param window_size: number of frames processed per sliding window
        :param noise_aug_strength: strength of the noise added to the conditioning video
        :param decode_chunk_size: chunk size used when encoding and decoding frames (defaults to 8)
        :param generator: torch generator(s) for reproducible noise
        :param latents: optional pre-generated initial latents
        :param output_type: "pil", "np", "pt", or "latent"
        :param callback_on_step_end: optional callback run after each denoising step
        :param callback_on_step_end_tensor_inputs: tensor names passed to the callback
        :param return_dict: whether to return a StableVideoDiffusionPipelineOutput
        :param overlap: number of overlapping frames between consecutive windows
        :param track_time: print CUDA timings for the encode/denoise/decode stages
        :return: StableVideoDiffusionPipelineOutput (or raw frames/latents if return_dict is False)
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        num_frames = video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if num_frames <= window_size:
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap
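        # e.g. with the defaults window_size=110 and overlap=25, each window advances
        # by stride=85 frames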

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(video, height, width)

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of
        # equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
        # `guidance_scale = 1` corresponds to doing no classifier-free guidance.
        self._guidance_scale = guidance_scale
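        # i.e. noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        # in the guidance branch of the denoising loop below; guidance_scale = 1.0 keeps
        # only the conditional prediction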

        # 3. Encode input video
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0  # [0, 1] -> [-1, 1], in [t, c, h, w]

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        video_embeddings = self.encode_video(
            video, chunk_size=decode_chunk_size
        ).unsqueeze(
            0
        )  # [1, t, 1024]
        torch.cuda.empty_cache()
        # 4. Encode input video using VAE
        noise = randn_tensor(
            video.shape, generator=generator, device=device, dtype=video.dtype
        )
        video = video + noise_aug_strength * noise  # in [t, c, h, w]
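        # the same noise_aug_strength is reused in _get_add_time_ids below, so the UNet
        # is conditioned on how much noise was added here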

        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(
            0
        )  # [1, t, c, h, w]

        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(encode_event)
            print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")

        torch.cuda.empty_cache()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
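        # the first two positional arguments correspond to SVD's fps (7) and
        # motion_bucket_id (127) conditioning values, kept fixed here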
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )  # [1 or 2, 3]
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        )  # [1, t, c, h, w]
        latents_all = None

        idx_start = 0
        if overlap > 0:
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None
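        # within the overlap, weight 0 keeps the previous window's latent and weight 1
        # takes the new window's latent (see the cross-fade at the end of the loop below)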

        torch.cuda.empty_cache()

        # inference strategy for long videos
        # two main ideas: 1. initialize each window's noise from the previous window,
        # 2. stitch the denoised segments together with a cross-fade over the overlap
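        # Example with the defaults: num_frames=200, window_size=110, overlap=25 gives
        # stride=85, so the windows cover frames [0, 110), [85, 195) and [170, 200),
        # and each window re-denoises the last `overlap` frames of the previous one.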
        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)

            # 8. Denoising loop
            latents = latents_init[:, : idx_end - idx_start].clone()
            latents_init = torch.cat(
                [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
            )

            video_latents_current = video_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
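                    # on the first step of every window after the first, re-initialize
                    # the overlapping latents from the previous window's result plus the
                    # fresh noise rescaled from init_noise_sigma down to the current
                    # sigma, anchoring the new window to what was already denoised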
                    if latents_all is not None and i == 0:
                        latents[:, :overlap] = (
                            latents_all[:, -overlap:]
                            + latents[:, :overlap]
                            / self.scheduler.init_noise_sigma
                            * self.scheduler.sigmas[i]
                        )

                    latent_model_input = latents  # [1, t, c, h, w]
                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )  # [1, t, c, h, w]
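                    # concatenate the clean video latents along the channel dimension as
                    # conditioning, mirroring SVD's image-latent conditioning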
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current], dim=2
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]
                    # perform guidance
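                    # the unconditional pass zeroes both the concatenated video latents
                    # and the CLIP embeddings, so guidance_scale > 1 pushes the
                    # prediction toward the video-conditioned branch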
                    if self.do_classifier_free_guidance:
                        latent_model_input = latents
                        latent_model_input = self.scheduler.scale_model_input(
                            latent_model_input, t
                        )
                        latent_model_input = torch.cat(
                            [latent_model_input, torch.zeros_like(latent_model_input)],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]

                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )
                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )

                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()
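
            # stitch this window into the running sequence: cross-fade the overlap with
            # the linear weights prepared above, then append the remaining new frames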
            if latents_all is None:
                latents_all = latents.clone()
            else:
                assert weights is not None
                # latents_all[:, -overlap:] = (
                #     latents[:, :overlap] + latents_all[:, -overlap:]
                # ) / 2.0
                latents_all[:, -overlap:] = latents[
                    :, :overlap
                ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")

        if output_type != "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)

            if track_time:
                decode_event.record()
                torch.cuda.synchronize()
                elapsed_time_ms = denoise_event.elapsed_time(decode_event)
                print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")

            frames = self.video_processor.postprocess_video(
                video=frames, output_type=output_type
            )
        else:
            frames = latents_all

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
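

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. The checkpoint ids and the
    # companion UNet import below are assumptions about the surrounding repository /
    # Hugging Face Hub layout and may need adjusting to your setup.
    from depthcrafter.unet import (  # assumed module path
        DiffusersUNetSpatioTemporalConditionModelDepthCrafter,
    )

    unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
        "tencent/DepthCrafter", torch_dtype=torch.float16  # assumed checkpoint id
    )
    pipe = DepthCrafterPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt",  # SVD base weights (assumed)
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

    # 60 random frames standing in for a real video: [t, h, w, c] in [0, 1]
    dummy_video = np.random.rand(60, 576, 1024, 3).astype(np.float32)
    result = pipe(
        dummy_video,
        height=576,
        width=1024,
        num_inference_steps=25,
        guidance_scale=1.0,
        window_size=110,
        overlap=25,
        output_type="np",
    )
    print(result.frames[0].shape)  # (t, h, w, c)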