mirror of https://github.com/apple/ml-4m.git (synced 2024-07-16 14:20:27 +03:00)
add DINOv2 to FEATURE_TASKS
@@ -34,8 +34,7 @@ from fourm.data.modality_info import MODALITY_TRANSFORMS_DIVAE
 from fourm.vq import get_image_tokenizer
 import fourm.utils.clip as clip
 
-
-FEATURE_TASKS = ['CLIP-B16']
+FEATURE_TASKS = ['CLIP-B16', 'DINOv2-B14', 'DINOv2-B14-global']
 IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".jpx", ".gif")
 
 def find_image_extension(root_dir):
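
For orientation (not part of the commit itself): a sketch of what each task name selects in the hunks below, assuming 224x224 inputs; the shapes follow from the patch sizes of the two backbones.

    # What each entry in FEATURE_TASKS selects downstream.
    # Shapes assume 224x224 inputs; D is the teacher embedding dim (768 for both).
    #
    #   'CLIP-B16'          -> CLIP ViT-B/16 patch features,   (B, D, 14, 14)
    #   'DINOv2-B14'        -> DINOv2 ViT-B/14 patch features, (B, D, 16, 16)
    #   'DINOv2-B14-global' -> DINOv2 ViT-B/14 CLS embedding,  (B, D, 1, 1)
    FEATURE_TASKS = ['CLIP-B16', 'DINOv2-B14', 'DINOv2-B14-global']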
@@ -191,6 +190,9 @@ def get_feature_extractor(args):
         teacher_model, _ = clip.load("ViT-B/16", device='cpu', jit=False)
         teacher_model = teacher_model.visual
         return teacher_model.eval()
+    elif args.task in ['DINOv2-B14', 'DINOv2-B14-global']:
+        teacher_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
+        return teacher_model.eval()
     else:
         return None
 
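
A minimal standalone sketch of what the new branch loads and how the later hunk consumes it. torch.hub.load and the is_training=True dict output (keys x_norm_clstoken / x_norm_patchtokens) are the public DINOv2 interface; the input size and printed shapes are illustrative assumptions.

    import torch

    # Downloads the DINOv2 ViT-B/14 weights on first use.
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval()

    x = torch.randn(1, 3, 224, 224)         # 224/14 -> a 16x16 patch grid
    with torch.no_grad():
        out = model(x, is_training=True)    # returns a dict of normalized tokens
    print(out['x_norm_clstoken'].shape)     # torch.Size([1, 768])      -> 'DINOv2-B14-global'
    print(out['x_norm_patchtokens'].shape)  # torch.Size([1, 256, 768]) -> 'DINOv2-B14'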
@@ -203,7 +205,7 @@ def main(args):
     np.random.seed(seed)
     random.seed(seed)
 
-    model = get_image_tokenizer(args.tokenizer_id, tokenizers_root=args.tokenizers_root, encoder_only=True)
+    model, _ = get_image_tokenizer(args.tokenizer_id, tokenizers_root=args.tokenizers_root, encoder_only=True)
     feature_extractor = get_feature_extractor(args)
 
     num_tasks = utils.get_world_size()
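
The call site now unpacks a pair. This sketch shows only the updated calling pattern, under the assumption that get_image_tokenizer returns the tokenizer plus a second value this script does not need; the argument values are placeholders.

    from fourm.vq import get_image_tokenizer

    # Assumption: the helper returns (tokenizer, <something unused here>).
    tokenizer, _ = get_image_tokenizer(
        'my_tokenizer_id',               # placeholder tokenizer id
        tokenizers_root='./tokenizers',  # placeholder checkpoint directory
        encoder_only=True,               # decoding is not needed to save tokens
    )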
@@ -271,8 +273,21 @@ def main(args):
                 N_H, N_W = H // P_H, W // P_W
                 sub_batch = feature_extractor(sub_batch, return_final_tokens_no_cls=True)
                 sub_batch = rearrange(sub_batch, 'b (nh nw) d -> b d nh nw', nh=N_H, nw=N_W)
+            if 'DINO' in args.task:
+                B, C, H, W = sub_batch.shape
+                P_H, P_W = feature_extractor.patch_embed.proj.kernel_size
+                N_H, N_W = H // P_H, W // P_W
+                sub_batch = feature_extractor(sub_batch, is_training=True)
+                if 'global' in args.task:
+                    sub_batch = sub_batch['x_norm_clstoken']
+                    sub_batch = sub_batch.unsqueeze(2).unsqueeze(2)
+                else:
+                    sub_batch = sub_batch['x_norm_patchtokens']
+                    sub_batch = rearrange(sub_batch, 'b (nh nw) d -> b d nh nw', nh=N_H, nw=N_W)
 
             tokens = model.tokenize(sub_batch)
+            if tokens.size(-1) == 1: # For the global embedding tokens, squeeze the last dimension
+                tokens = tokens.squeeze(2)
             tokens = rearrange(tokens, "b h w -> b (h w)")
 
             tokens = tokens.detach().cpu().numpy().astype(np.int16)
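
To make the two shape paths above concrete, a self-contained sketch with dummy tensors and illustrative dimensions: dense patch tokens are folded back into a 2D feature map, while the global CLS embedding becomes a 1x1 "image", so the same tokenizer call can consume both.

    import torch
    from einops import rearrange

    B, D, N_H, N_W = 2, 768, 16, 16

    # Dense path ('DINOv2-B14'): (B, N, D) patch tokens -> (B, D, nh, nw) map.
    patch_tokens = torch.randn(B, N_H * N_W, D)
    fmap = rearrange(patch_tokens, 'b (nh nw) d -> b d nh nw', nh=N_H, nw=N_W)
    assert fmap.shape == (B, D, N_H, N_W)

    # Global path ('DINOv2-B14-global'): (B, D) CLS token -> (B, D, 1, 1).
    cls_token = torch.randn(B, D)
    global_map = cls_token.unsqueeze(2).unsqueeze(2)
    assert global_map.shape == (B, D, 1, 1)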