add DINOv2 to FEATURE_TASKS

Author: Ali Garjani
Date: 2024-06-28 15:51:23 +00:00
parent 8ab60b7519
commit 304f2f2d3a


@@ -34,8 +34,7 @@ from fourm.data.modality_info import MODALITY_TRANSFORMS_DIVAE
 from fourm.vq import get_image_tokenizer
 import fourm.utils.clip as clip
-FEATURE_TASKS = ['CLIP-B16']
+FEATURE_TASKS = ['CLIP-B16', 'DINOv2-B14', 'DINOv2-B14-global']
 IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".jpx", ".gif")
 def find_image_extension(root_dir):
@@ -191,6 +190,9 @@ def get_feature_extractor(args):
         teacher_model, _ = clip.load("ViT-B/16", device='cpu', jit=False)
         teacher_model = teacher_model.visual
         return teacher_model.eval()
+    elif args.task in ['DINOv2-B14', 'DINOv2-B14-global']:
+        teacher_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
+        return teacher_model.eval()
     else:
         return None
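For reference (not part of the commit): a minimal standalone sketch of what the new DINOv2-B14 teacher returns when called the way the last hunk of this diff calls it. It assumes torch.hub can download facebookresearch/dinov2 and a 224x224 input, so the 14-pixel patches tile the crop exactly (16x16 patches, ViT-B embedding width 768).

# Standalone sketch (assumes network access for torch.hub and 224x224 inputs).
import torch

teacher = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval()

x = torch.randn(2, 3, 224, 224)            # dummy image batch
with torch.no_grad():
    out = teacher(x, is_training=True)     # dict output instead of a pooled embedding

patch_tokens = out['x_norm_patchtokens']   # [2, 256, 768] -> used by the 'DINOv2-B14' task
cls_token = out['x_norm_clstoken']         # [2, 768]      -> used by the 'DINOv2-B14-global' task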
@@ -203,7 +205,7 @@ def main(args):
     np.random.seed(seed)
     random.seed(seed)
-    model = get_image_tokenizer(args.tokenizer_id, tokenizers_root=args.tokenizers_root, encoder_only=True)
+    model, _ = get_image_tokenizer(args.tokenizer_id, tokenizers_root=args.tokenizers_root, encoder_only=True)
     feature_extractor = get_feature_extractor(args)
     num_tasks = utils.get_world_size()
@@ -271,8 +273,21 @@ def main(args):
     N_H, N_W = H // P_H, W // P_W
     sub_batch = feature_extractor(sub_batch, return_final_tokens_no_cls=True)
     sub_batch = rearrange(sub_batch, 'b (nh nw) d -> b d nh nw', nh=N_H, nw=N_W)
+if 'DINO' in args.task:
+    B, C, H, W = sub_batch.shape
+    P_H, P_W = feature_extractor.patch_embed.proj.kernel_size
+    N_H, N_W = H // P_H, W // P_W
+    sub_batch = feature_extractor(sub_batch, is_training=True)
+    if 'global' in args.task:
+        sub_batch = sub_batch['x_norm_clstoken']
+        sub_batch = sub_batch.unsqueeze(2).unsqueeze(2)
+    else:
+        sub_batch = sub_batch['x_norm_patchtokens']
+        sub_batch = rearrange(sub_batch, 'b (nh nw) d -> b d nh nw', nh=N_H, nw=N_W)
 tokens = model.tokenize(sub_batch)
+if tokens.size(-1)==1: # For the global embedding tokens, squeeze the last dimension
+    tokens = tokens.squeeze(2)
 tokens = rearrange(tokens, "b h w -> b (h w)")
 tokens = tokens.detach().cpu().numpy().astype(np.int16)
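For reference (not part of the commit): a shape-level sketch of the new branch above, runnable without the teacher weights or the VQ tokenizer. The dummy tensors stand in for the DINOv2-B14 outputs (ViT-B width 768, patch size 14, so a 224x224 crop gives a 16x16 grid); the actual quantization via model.tokenize is left out.

# Standalone shape check for the DINOv2 feature maps fed to the VQ tokenizer.
import torch
from einops import rearrange

B, D, N_H, N_W = 2, 768, 16, 16

# 'DINOv2-B14': per-patch features become a [B, D, nh, nw] map, matching the CLIP path
patch_tokens = torch.randn(B, N_H * N_W, D)               # stands in for out['x_norm_patchtokens']
feat_map = rearrange(patch_tokens, 'b (nh nw) d -> b d nh nw', nh=N_H, nw=N_W)
assert feat_map.shape == (B, D, N_H, N_W)

# 'DINOv2-B14-global': the class token becomes a 1x1 map so it fits the same interface
cls_token = torch.randn(B, D)                              # stands in for out['x_norm_clstoken']
global_map = cls_token.unsqueeze(2).unsqueeze(2)
assert global_map.shape == (B, D, 1, 1)

# Both maps are then passed to model.tokenize(...) as in the hunk above; the global
# case yields a single code per image, hence the extra squeeze before flattening.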