Add support for locked image towers and pretrained image tower weight load (timm models only)
This commit is contained in:
committed by
Ross Wightman
parent
8f24b79092
commit
2856dad5b6
@@ -64,6 +64,7 @@ def create_model(
|
||||
device: torch.device = torch.device('cpu'),
|
||||
jit: bool = False,
|
||||
force_quick_gelu: bool = False,
|
||||
pretrained_image: bool = False,
|
||||
):
|
||||
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
||||
pretrained = pretrained.lower()
|
||||
@@ -84,7 +85,14 @@ def create_model(
|
||||
if force_quick_gelu:
|
||||
# override for use of QuickGELU on non-OpenAI transformer models
|
||||
model_cfg["quick_gelu"] = True
|
||||
|
||||
|
||||
if pretrained_image:
|
||||
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
||||
# pretrained weight loading for timm models set via vision_cfg
|
||||
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
||||
else:
|
||||
assert False, 'pretrained image towers currently only supported for timm models'
|
||||
|
||||
model = CLIP(**model_cfg)
|
||||
|
||||
if pretrained:
|
||||
@@ -120,8 +128,12 @@ def create_model_and_transforms(
|
||||
device: torch.device = torch.device('cpu'),
|
||||
jit: bool = False,
|
||||
force_quick_gelu: bool = False,
|
||||
pretrained_image: bool = False,
|
||||
):
|
||||
model = create_model(model_name, pretrained, precision, device, jit, force_quick_gelu)
|
||||
model = create_model(
|
||||
model_name, pretrained, precision, device, jit,
|
||||
force_quick_gelu=force_quick_gelu,
|
||||
pretrained_image=pretrained_image)
|
||||
preprocess_train = image_transform(model.visual.image_size, is_train=True)
|
||||
preprocess_val = image_transform(model.visual.image_size, is_train=False)
|
||||
return model, preprocess_train, preprocess_val
|
||||
|
||||
@@ -13,6 +13,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .timm_model import TimmModel
|
||||
from .utils import freeze_batch_norm_2d
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
@@ -155,6 +156,13 @@ class ModifiedResNet(nn.Module):
|
||||
if name.endswith("bn3.weight"):
|
||||
nn.init.zeros_(param)
|
||||
|
||||
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
if freeze_bn_stats:
|
||||
freeze_batch_norm_2d(self)
|
||||
|
||||
def stem(self, x):
|
||||
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
||||
x = self.relu(bn(conv(x)))
|
||||
@@ -244,6 +252,11 @@ class VisualTransformer(nn.Module):
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||
@@ -385,6 +398,10 @@ class CLIP(nn.Module):
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
||||
|
||||
def encode_image(self, image):
|
||||
return self.visual(image)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "resnetblur50d",
|
||||
"timm_model_name": "resnetblur50",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "abs_attn",
|
||||
"timm_proj": "",
|
||||
|
||||
@@ -14,6 +14,8 @@ try:
|
||||
except ImportError as e:
|
||||
timm = None
|
||||
|
||||
from .utils import freeze_batch_norm_2d
|
||||
|
||||
|
||||
class TimmModel(nn.Module):
|
||||
""" timm model adapter
|
||||
@@ -66,6 +68,38 @@ class TimmModel(nn.Module):
|
||||
|
||||
self.head = nn.Sequential(head_layers)
|
||||
|
||||
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
""" lock modules
|
||||
Args:
|
||||
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
||||
"""
|
||||
if not unlocked_groups:
|
||||
# lock full model
|
||||
for param in self.trunk.parameters():
|
||||
param.requires_grad = False
|
||||
if freeze_bn_stats:
|
||||
freeze_batch_norm_2d(self.trunk)
|
||||
else:
|
||||
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
||||
try:
|
||||
# FIXME import here until API stable and in an official release
|
||||
from timm.models.helpers import group_parameters, group_modules
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
|
||||
matcher = self.trunk.group_matcher()
|
||||
gparams = group_parameters(self.trunk, matcher)
|
||||
max_layer_id = max(gparams.keys())
|
||||
max_layer_id = max_layer_id - unlocked_groups
|
||||
for group_idx in range(max_layer_id + 1):
|
||||
group = gparams[group_idx]
|
||||
for param in group:
|
||||
self.trunk.get_parameter(param).requires_grad = False
|
||||
if freeze_bn_stats:
|
||||
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
||||
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
||||
freeze_batch_norm_2d(self.trunk, gmodules)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.trunk(x)
|
||||
x = self.head(x)
|
||||
|
||||
41
src/open_clip/utils.py
Normal file
41
src/open_clip/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch import nn as nn
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
|
||||
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
||||
"""
|
||||
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
||||
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
||||
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Any PyTorch module.
|
||||
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
||||
name (str): Full module name (prefix)
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: Resulting module
|
||||
|
||||
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
||||
"""
|
||||
res = module
|
||||
is_match = True
|
||||
if module_match:
|
||||
is_match = name in module_match
|
||||
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
||||
res = FrozenBatchNorm2d(module.num_features)
|
||||
res.num_features = module.num_features
|
||||
res.affine = module.affine
|
||||
if module.affine:
|
||||
res.weight.data = module.weight.data.clone().detach()
|
||||
res.bias.data = module.bias.data.clone().detach()
|
||||
res.running_mean.data = module.running_mean.data
|
||||
res.running_var.data = module.running_var.data
|
||||
res.eps = module.eps
|
||||
else:
|
||||
for child_name, child in module.named_children():
|
||||
full_child_name = '.'.join([name, child_name]) if name else child_name
|
||||
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
||||
if new_child is not child:
|
||||
res.add_module(child_name, new_child)
|
||||
return res
|
||||
@@ -116,13 +116,20 @@ def main():
|
||||
args.pretrained,
|
||||
precision=args.precision,
|
||||
device=device,
|
||||
force_quick_gelu=args.force_quick_gelu,
|
||||
jit=args.torchscript,
|
||||
force_quick_gelu=args.force_quick_gelu,
|
||||
pretrained_image=args.pretrained_image,
|
||||
)
|
||||
|
||||
if args.trace:
|
||||
model = trace_model(model, batch_size=args.batch_size, device=device)
|
||||
|
||||
if args.lock_image:
|
||||
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||
model.lock_image_tower(
|
||||
unlocked_groups=args.lock_image_unlocked_groups,
|
||||
freeze_bn_stats=args.lock_image_freeze_bn_stats)
|
||||
|
||||
if is_master(args):
|
||||
logging.info("Model:")
|
||||
logging.info(f"{str(model)}")
|
||||
|
||||
@@ -151,6 +151,36 @@ def parse_args():
|
||||
default="RN50",
|
||||
help="Name of the vision backbone to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained",
|
||||
default='',
|
||||
type=str,
|
||||
help="Use a pretrained CLIP model weights with the specified tag or file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained-image",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Load imagenet pretrained weights for image tower backbone if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lock-image",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Lock full image tower by disabling gradients.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lock-image-unlocked-groups",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Leave last n image tower layer groups unlocked.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lock-image-freeze-bn-stats",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Freeze BatchNorm running stats in image tower for any locked layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-loss",
|
||||
default=False,
|
||||
@@ -163,12 +193,6 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="enable full distributed gradient for feature gather"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained",
|
||||
default='',
|
||||
type=str,
|
||||
help="Use a pretrained model with the specified tag or file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-quick-gelu",
|
||||
default=False,
|
||||
|
||||
Reference in New Issue
Block a user