Initial commit.

mitchellnw
2021-07-28 16:27:40 -07:00
parent 7eb34cff64
commit 0caa0dc09b
30 changed files with 2872 additions and 2 deletions

150
.gitignore vendored Normal file

@@ -0,0 +1,150 @@
logs/
wandb/
models/
features/
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
sync.sh
gpu1sync.sh
.idea
*.pdf
**/._*
**/*DS_*
**.jsonl
src/sbatch
src/misc
.vscode
src/debug
core.*
# Allow
!src/evaluation/misc/results_dbs/*

28
CITATION.cff Normal file

@@ -0,0 +1,28 @@
cff-version: 1.1.0
message: If you use this software, please cite it as below.
authors:
- family-names: Ilharco
given-names: Gabriel
- family-names: Wortsman
given-names: Mitchell
- family-names: Carlini
given-names: Nicholas
- family-names: Taori
given-names: Rohan
- family-names: Dave
given-names: Achal
- family-names: Shankar
given-names: Vaishaal
- family-names: Namkoong
given-names: Hongseok
- family-names: Miller
given-names: John
- family-names: Hajishirzi
given-names: Hannaneh
- family-names: Farhadi
given-names: Ali
- family-names: Schmidt
given-names: Ludwig
title: OpenCLIP
version: 0.0.1
date-released: 2021-07-28

20
LICENSE Normal file

@@ -0,0 +1,20 @@
Copyright (c) 2012-2021 Scott Chacon and others
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

141
README.md

@@ -1,2 +1,139 @@
# open_clip
An open source implementation of CLIP.
# OpenCLIP
Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training).
The goal of this repository is to match the accuracy of the original CLIP models when trained on the same dataset. For example, our implementation reaches 22.2% top-1 ImageNet accuracy when training a ResNet-50x4 on the 3 million images in the Conceptual Captions dataset, and 32.7% top-1 ImageNet accuracy when training an RN50 on OpenAI's [15 million image subset of YFCC](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md). OpenAI's CLIP model reaches 31.3% on the same subset of YFCC.
Note that `src/clip` is a copy of OpenAI's official [repo](https://github.com/openai/CLIP) with minimal changes.
## Data
### Conceptual Captions
OpenCLIP reads a CSV file with two columns: a path to an image, and a text caption. The column names are passed as arguments to `main.py`.
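For illustration, here is a hedged sketch (with placeholder paths and captions) of building such a CSV; the column names are then passed via `--csv-img-key` and `--csv-caption-key`:

```python
# Hedged sketch: build a tab-separated image/caption CSV of the kind main.py expects.
# The paths and captions below are placeholders, not files shipped with this repo.
import pandas as pd

df = pd.DataFrame({
    "filepath": ["cc_data/train/0/0.jpg", "cc_data/train/1/1.jpg"],
    "title": ["a photo of a dog", "a plate of food on a table"],
})
# Tab-separated to match the CsvDataset default; override with --csv-separator if needed.
df.to_csv("train_data.csv", sep="\t", index=False)
```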
The script `src/data/gather_cc.py` will collect the Conceptual Captions images. First, download the [Conceptual Captions URLs](https://ai.google.com/research/ConceptualCaptions/download) and then run the following script from our repository:
```
python3 src/data/gather_cc.py path/to/Train_GCC-training.tsv path/to/Validation_GCC-1.1.0-Validation.tsv
```
Our training set contains 2.89M images, and our validation set contains 13K images.
### YFCC and other datasets
In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1000 data points each, which we create using [tarp](https://github.com/webdataset/tarp).
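As a hedged illustration (not a script in this repo), shards of this form can also be written directly with `webdataset.TarWriter`:

```python
# Hedged sketch: pack image/text pairs into a webdataset shard with paired
# .jpg/.txt entries sharing the same key, as described above.
import io

import webdataset as wds
from PIL import Image

samples = [(Image.new("RGB", (64, 64)), "a plain gray square")]  # toy stand-in data

with wds.TarWriter("shard_000.tar") as sink:
    for idx, (img, caption) in enumerate(samples):
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        sink.write({
            "__key__": f"{idx:06d}",   # shared basename for the pair
            "jpg": buf.getvalue(),     # becomes 000000.jpg inside the tar
            "txt": caption,            # becomes 000000.txt inside the tar
        })
```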
You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/).
Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers.
The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md).
## Training CLIP
### Install dependencies
```
conda env create -f environment.yml
source activate open_clip
```
### Add directory to `PYTHONPATH`:
```
cd open_clip
export PYTHONPATH="$PYTHONPATH:$PWD/src"
```
### Sample running code:
```
nohup python -u src/training/main.py \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--train-data="/path/to/train_data.csv" \
--val-data="/path/to/validation_data.csv" \
--csv-img-key filepath \
--csv-caption-key title \
--imagenet-val=/path/to/imagenet/root/val/ \
--warmup 10000 \
--batch-size=128 \
--lr=1e-3 \
--wd=0.1 \
--epochs=30 \
--workers=8
```
Note: `imagenet-val` is the path to the *val* set of ImageNet for zero-shot evaluation, not the train set!
You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it does not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
This command should produce the following training curve:
![CLIP zero shot training curve](/docs/clip_zeroshot.png)
More detailed curves are given in [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md)
### Launch tensorboard:
```
tensorboard --logdir=logs/tensorboard/ --port=7777
```
### Sample resuming from a checkpoint:
```bash
python src/training/main.py \
--train-data="/path/to/train_data.csv" \
--val-data="/path/to/validation_data.csv" \
--resume /path/to/checkpoints/epoch_K.pt
```
### Sample evaluation only:
```bash
python src/training/main.py \
--val-data="/path/to/validation_data.csv" \
--resume /path/to/checkpoints/epoch_K.pt
```
## Scaling trends
The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples.
<img src="docs/scaling.png" width="700">
## Why are low-accuracy CLIP models interesting?
**TL;DR:** CLIP models have high effective robustness, even at small scales.
CLIP models are particularly intriguing because they are more robust to natural distribution shifts.
This phenomenon is illustrated by the figure below, with ImageNet accuracy on the x-axis
and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet validation set) accuracy on the y-axis.
Standard training denotes training on the ImageNet train set, while the CLIP zero-shot models
are shown with stars.
![CLIP scatter plot](/docs/effective_robustness.png)
As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644), in-distribution
and out-of-distribution accuracy follow a predictable linear trend. Effective robustness
measures movement above this red line. Even though the models trained with
this codebase reach much lower accuracy than those trained by OpenAI, they lie on the same
trend of improved effective robustness (the purple line). Therefore, we can study what makes
CLIP robust without needing industry-scale compute.
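Concretely, a rough paraphrase of the definition from Taori et al., 2020 (who fit the baseline trend on logit-transformed accuracies):

$$\text{effective robustness}(f) = \mathrm{acc}_{\text{ImageNetV2}}(f) - \beta\big(\mathrm{acc}_{\text{ImageNet}}(f)\big),$$

where $\beta$ is the linear fit to models trained with standard ImageNet training (the red line).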
For more information on effective robustness, please see:
- [Recht et al., 2019](https://arxiv.org/abs/1902.10811).
- [Taori et al., 2020](https://arxiv.org/abs/2007.00644).
- [Miller et al., 2021](https://arxiv.org/abs/2107.04649).
## The Team
[Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/)
Special thanks to Jong Wook Kim and Alec Radford!

13
docs/clip_conceptual_captions.md Normal file

@@ -0,0 +1,13 @@
## Additional training curves for CLIP on Conceptual Captions
# Zero shot accuracy
![](/docs/clip_zeroshot.png)
# Training loss curve
![](/docs/clip_loss.png)
# Validation loss curve
![](/docs/clip_val_loss.png)
# Validation recall
![](/docs/clip_recall.png)

BIN
docs/clip_loss.png Normal file (42 KiB)

BIN
docs/clip_recall.png Normal file (50 KiB)

BIN
docs/clip_val_loss.png Normal file (43 KiB)

BIN
docs/clip_zeroshot.png Normal file (57 KiB)

BIN
(binary image, name not shown) (995 KiB)

BIN
docs/scaling.png Normal file (96 KiB)

152
environment.yml Normal file

@@ -0,0 +1,152 @@
name: open_clip
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- absl-py=0.12.0=py36h06a4308_0
- aiohttp=3.6.3=py36h7b6447c_0
- async-timeout=3.0.1=py36h06a4308_0
- attrs=20.3.0=pyhd3eb1b0_0
- blas=1.0=mkl
- blinker=1.4=py36h06a4308_0
- brotlipy=0.7.0=py36h27cfd23_1003
- c-ares=1.17.1=h27cfd23_0
- ca-certificates=2020.12.5=ha878542_0
- cachetools=4.2.1=pyhd3eb1b0_0
- certifi=2020.12.5=py36h5fab9bb_1
- cffi=1.14.5=py36h261ae71_0
- chardet=3.0.4=py36h06a4308_1003
- click=7.1.2=pyhd3eb1b0_0
- coverage=5.5=py36h27cfd23_2
- cryptography=3.4.7=py36hd23ed53_0
- cudatoolkit=11.0.221=h6bb024c_0
- cython=0.29.23=py36h2531618_0
- dataclasses=0.8=pyh4f3eec9_6
- faiss-gpu=1.4.0=py36_cuda8.0.61_1
- freetype=2.10.4=h5ab3b9f_0
- ftfy=5.8=py_0
- google-auth=1.29.0=pyhd3eb1b0_0
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
- grpcio=1.36.1=py36h2157cd5_1
- idna=2.10=pyhd3eb1b0_0
- idna_ssl=1.1.0=py36h06a4308_0
- importlib-metadata=3.10.0=py36h06a4308_0
- intel-openmp=2021.2.0=h06a4308_610
- joblib=1.0.1=pyhd8ed1ab_0
- jpeg=9b=h024ee3a_2
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libblas=3.9.0=1_h6e990d7_netlib
- libcblas=3.9.0=3_h893e4fe_netlib
- libffi=3.3=he6710b0_2
- libgcc=7.2.0=h69d50b8_2
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.5.0=h14aa051_19
- libgfortran4=7.5.0=h14aa051_19
- liblapack=3.9.0=3_h893e4fe_netlib
- libpng=1.6.37=hbc83047_0
- libprotobuf=3.14.0=h8c45485_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_1
- libuv=1.40.0=h7b6447c_0
- lz4-c=1.9.3=h2531618_0
- markdown=3.3.4=py36h06a4308_0
- mkl=2020.2=256
- mkl-service=2.3.0=py36he8ac12f_0
- mkl_fft=1.3.0=py36h54f3939_0
- mkl_random=1.1.1=py36h0573a6f_0
- multidict=4.7.6=py36h7b6447c_1
- ncurses=6.2=he6710b0_1
- ninja=1.10.2=hff7bd54_1
- numpy=1.19.2=py36h54aff64_0
- numpy-base=1.19.2=py36hfa32c7d_0
- oauthlib=3.1.0=py_0
- olefile=0.46=py36_0
- openssl=1.1.1k=h27cfd23_0
- pandas=1.1.3=py36he6710b0_0
- pillow=8.2.0=py36he98fc37_0
- pip=21.0.1=py36h06a4308_0
- protobuf=3.14.0=py36h2531618_1
- pyasn1=0.4.8=py_0
- pyasn1-modules=0.2.8=py_0
- pycparser=2.20=py_2
- pyjwt=1.7.1=py36_0
- pyopenssl=20.0.1=pyhd3eb1b0_1
- pysocks=1.7.1=py36h06a4308_0
- python=3.6.13=hdb3f193_0
- python-dateutil=2.8.1=pyhd3eb1b0_0
- python_abi=3.6=1_cp36m
- pytorch=1.7.1=py3.6_cuda11.0.221_cudnn8.0.5_0
- pytz=2021.1=pyhd3eb1b0_0
- readline=8.1=h27cfd23_0
- regex=2021.4.4=py36h27cfd23_0
- requests=2.25.1=pyhd3eb1b0_0
- requests-oauthlib=1.3.0=py_0
- rsa=4.7.2=pyhd3eb1b0_1
- scikit-learn=0.23.2=py36hb6e6923_3
- scipy=1.5.3=py36h976291a_0
- setuptools=52.0.0=py36h06a4308_0
- six=1.15.0=py36h06a4308_0
- sqlite=3.35.4=hdfb4753_0
- tensorboard=2.4.0=pyhc547734_0
- tensorboard-plugin-wit=1.6.0=py_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tk=8.6.10=hbc83047_0
- torchaudio=0.7.2=py36
- torchvision=0.8.2=py36_cu110
- tqdm=4.59.0=pyhd3eb1b0_1
- typing_extensions=3.7.4.3=pyha847dfd_0
- urllib3=1.26.4=pyhd3eb1b0_0
- wcwidth=0.2.5=py_0
- werkzeug=1.0.1=pyhd3eb1b0_0
- wheel=0.36.2=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- yarl=1.6.3=py36h27cfd23_0
- zipp=3.4.1=pyhd3eb1b0_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- ase==3.21.1
- braceexpand==0.1.7
- cached-property==1.5.2
- configparser==5.0.2
- cycler==0.10.0
- decorator==4.4.2
- docker-pycreds==0.4.0
- gitdb==4.0.7
- gitpython==3.1.14
- googledrivedownloader==0.4
- h5py==3.1.0
- isodate==0.6.0
- jinja2==3.0.1
- kiwisolver==1.3.1
- littleutils==0.2.2
- llvmlite==0.36.0
- markupsafe==2.0.1
- matplotlib==3.3.4
- networkx==2.5.1
- numba==0.53.1
- ogb==1.3.1
- outdated==0.2.1
- pathtools==0.1.2
- promise==2.3
- psutil==5.8.0
- pyarrow==4.0.0
- pyparsing==2.4.7
- python-louvain==0.15
- pyyaml==5.4.1
- rdflib==5.0.0
- sentry-sdk==1.1.0
- shortuuid==1.0.1
- sklearn==0.0
- smmap==4.0.0
- subprocess32==3.5.4
- torch-geometric==1.7.0
- wandb==0.10.30
- wilds==1.1.0
- "--editable=git+https://github.com/tmbdev/webdataset.git@a4f3ec08551b42f20b20cdc1ba32d12536eabc15#egg=webdataset"
- git+https://github.com/modestyachts/ImageNetV2_pytorch
- https://pytorch-geometric.com/whl/torch-1.7.0+cu110/torch_scatter-2.0.6-cp36-cp36m-linux_x86_64.whl
prefix: /home/gamaga/anaconda3/envs/open_clip

BIN
(binary file, name not shown)

208
src/clip/clip.py Normal file

@@ -0,0 +1,208 @@
# Code ported from https://github.com/openai/CLIP
import hashlib
import os
import urllib
import warnings
from typing import Union, List
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop
from tqdm import tqdm
from clip.model import build_model
from clip.tokenizer import SimpleTokenizer as _Tokenizer
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
}
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def _convert_to_rgb(image):
return image.convert('RGB')
def _transform(n_px: int, is_train: bool):
normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
if is_train:
return Compose([
RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=Image.BICUBIC),
_convert_to_rgb,
ToTensor(),
normalize,
])
else:
return Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
_convert_to_rgb,
ToTensor(),
normalize,
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or the more hackable non-JIT model.
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name])
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
try:
model = build_model(state_dict or model.state_dict()).to(device)
except KeyError:
sd = {k[7:]: v for k,v in state_dict["state_dict"].items()}
model = build_model(sd).to(device)
if str(device) == "cpu":
model.float()
return model, \
_transform(model.visual.input_resolution, is_train=True), \
_transform(model.visual.input_resolution, is_train=False)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, \
_transform(model.input_resolution.item(), is_train=True), \
_transform(model.input_resolution.item(), is_train=False)
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<start_of_text>"]
eot_token = _tokenizer.encoder["<end_of_text>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length: # Truncate
tokens = tokens[:context_length]
result[i, :len(tokens)] = torch.tensor(tokens)
return result
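A hedged usage sketch for the helpers above (not part of this file); note that, unlike OpenAI's `clip.load`, the `load()` here returns a train transform and a val transform alongside the model:

```python
# Hedged sketch: load the pretrained RN50 and encode a few prompts.
# Downloads OpenAI weights to ~/.cache/clip on first use.
import torch
from clip.clip import load, tokenize

model, preprocess_train, preprocess_val = load("RN50", jit=False)
tokens = tokenize(["a diagram", "a dog", "a cat"])  # shape [3, 77]
with torch.no_grad():
    device = next(model.parameters()).device
    text_features = model.encode_text(tokens.to(device))
print(text_features.shape)  # [3, embed_dim]
```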

432
src/clip/model.py Normal file

@@ -0,0 +1,432 @@
from collections import OrderedDict
from typing import Tuple, Union
import os
import json
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisualTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # zero out the lower triangle so -inf remains only above the diagonal (causal mask)
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
if image is None:
return self.encode_text(text)
elif text is None:
return self.encode_image(image)
image_features = self.encode_image(image)
text_features = self.encode_text(text)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return image_features, text_features, self.logit_scale.exp()
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict)
return model.eval()
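A hedged sketch (not part of this file) of the `forward()` contract above, which returns normalized features and `logit_scale.exp()` rather than logits, together with the symmetric contrastive loss a caller might compute from them:

```python
# Hedged sketch: a tiny, untrained CLIP purely to illustrate forward()'s outputs.
# All sizes below are arbitrary placeholders.
import torch
import torch.nn.functional as F
from clip.model import CLIP

model = CLIP(
    embed_dim=64,
    image_resolution=32, vision_layers=2, vision_width=64, vision_patch_size=16,
    context_length=16, vocab_size=1000,
    transformer_width=64, transformer_heads=2, transformer_layers=2,
)
images = torch.randn(4, 3, 32, 32)       # fake image batch
texts = torch.randint(0, 1000, (4, 16))  # fake token batch
image_features, text_features, logit_scale = model(images, texts)

# Features come back L2-normalized, so the scaled similarity matrix gives the logits.
logits_per_image = logit_scale * image_features @ text_features.t()
labels = torch.arange(4)
loss = (F.cross_entropy(logits_per_image, labels) +
        F.cross_entropy(logits_per_image.t(), labels)) / 2
```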

140
src/clip/tokenizer.py Normal file

@@ -0,0 +1,140 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns a list of utf-8 bytes and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
if not special_tokens:
special_tokens = ['<start_of_text>', '<end_of_text>']
else:
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t:t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
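A hedged round-trip sketch for the tokenizer above (not part of this file):

```python
# Hedged sketch: encode and decode a string with the BPE tokenizer defined above.
# Requires bpe_simple_vocab_16e6.txt.gz next to tokenizer.py, as default_bpe() assumes.
from clip.tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("Hello, OpenCLIP!")  # lower-cased, then BPE-encoded to token ids
print(ids)
print(tok.decode(ids))  # decoding re-spaces the text around </w> word boundaries
```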

95
src/data/gather_cc.py Normal file

@@ -0,0 +1,95 @@
import requests
import os
import multiprocessing as mp
from io import BytesIO
import numpy as np
import PIL
from PIL import Image
import pickle
import sys
def grab(line):
"""
Download a single image from the TSV.
"""
uid, split, line = line
try:
caption, url = line.split("\t")[:2]
except:
print("Parse error")
return
if os.path.exists(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)):
print("Finished", uid)
return uid, caption, url
# Let's not crash if anything weird happens
try:
dat = requests.get(url, timeout=20)
if dat.status_code != 200:
print("404 file", url)
return
# Try to parse this as an Image file, we'll fail out if not
im = Image.open(BytesIO(dat.content))
im.thumbnail((512, 512), PIL.Image.BICUBIC)
if min(*im.size) < max(*im.size)/3:
print("Too small", url)
return
im.save(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid))
# Another try/catch just because sometimes saving and re-loading
# the image is different than loading it once.
try:
o = Image.open(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid))
o = np.array(o)
print("Success", o.shape, uid, url)
return uid, caption, url
except:
print("Failed", uid, url)
except Exception as e:
print("Unknown error", e)
pass
if __name__ == "__main__":
ROOT = "cc_data"
if not os.path.exists(ROOT):
os.mkdir(ROOT)
os.mkdir(ROOT+"/train")
os.mkdir(ROOT+"/val")
for i in range(1000):
os.mkdir(ROOT+"/train/"+str(i))
os.mkdir(ROOT+"/val/"+str(i))
p = mp.Pool(300)
for tsv in sys.argv[1:]:
print("Processing file", tsv)
assert 'val' in tsv.lower() or 'train' in tsv.lower()
split = 'val' if 'val' in tsv.lower() else 'train'
results = p.map(grab,
[(i,split,x) for i,x in enumerate(open(tsv).read().split("\n"))])
out = open(tsv.replace(".tsv","_output.csv"),"w")
out.write("title\tfilepath\n")
for row in results:
print("Test", row)
if row is None: continue
id, caption, url = row
fp = "val/"+str(id%1000)+"/"+str(id)+".jpg"
print(fp)
if os.path.exists(os.path.join(ROOT,fp)):
out.write("%s\t%s\n"%(caption,fp))
else:
print("Drop", id)
out.close()
p.close()

1
src/training/.gitignore vendored Normal file

@@ -0,0 +1 @@
logs/

218
src/training/data.py Normal file

@@ -0,0 +1,218 @@
import os
import sys
import math
import logging
import functools
import braceexpand
import random
import pdb
import pandas as pd
import numpy as np
import pyarrow as pa
from PIL import Image
from typing import Union
from dataclasses import dataclass
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
import torchvision.datasets as datasets
from webdataset.utils import identity
import webdataset as wds
from clip.clip import tokenize
class CsvDataset(Dataset):
def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
logging.debug(f'Loading csv data from {input_filename}.')
df = pd.read_csv(input_filename, sep=sep)
self.images = df[img_key].tolist()
self.captions = df[caption_key].tolist()
self.transforms = transforms
logging.debug('Done loading data.')
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
images = self.transforms(Image.open(str(self.images[idx])))
texts = tokenize([str(self.captions[idx])])[0]
return images, texts
@dataclass
class DataInfo:
dataloader: DataLoader
sampler: DistributedSampler
def preprocess_txt(text):
return tokenize([str(text)])[0]
def get_dataset_size(shards):
shards_list = list(braceexpand.braceexpand(shards))
dir_path = os.path.dirname(shards)
sizes = eval(open(os.path.join(dir_path, 'sizes.json'), 'r').read())
total_size = sum(
[int(sizes[os.path.basename(shard)]) for shard in shards_list])
num_shards = len(shards_list)
return total_size, num_shards
def get_imagenet(args, preprocess_fns, split):
assert split in ["train", "val", "v2"]
is_train = split == "train"
preprocess_train, preprocess_val = preprocess_fns
if split == "v2":
from imagenetv2_pytorch import ImageNetV2Dataset
dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
else:
if is_train:
data_path = args.imagenet_train
preprocess_fn = preprocess_train
else:
data_path = args.imagenet_val
preprocess_fn = preprocess_val
assert data_path
dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
if is_train:
idxs = np.zeros(len(dataset.targets))
target_array = np.array(dataset.targets)
k = 50
for c in range(1000):
m = target_array == c
n = len(idxs[m])
arr = np.zeros(n)
arr[:k] = 1
np.random.shuffle(arr)
idxs[m] = arr
idxs = idxs.astype('int')
sampler = SubsetRandomSampler(np.where(idxs)[0])
else:
sampler = None
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.workers,
sampler=sampler,
)
return DataInfo(dataloader, sampler)
def count_samples(dataloader):
os.environ["WDS_EPOCH"] = "0"
n_elements, n_batches = 0, 0
for images, texts in dataloader:
n_batches += 1
n_elements += len(images)
assert len(images) == len(texts)
return n_elements, n_batches
def get_wds_dataset(args, preprocess_img, is_train):
input_shards = args.train_data if is_train else args.val_data
assert input_shards is not None
# The following code is adapted from https://github.com/tmbdev/webdataset-examples/blob/master/main-wds.py
num_samples, num_shards = get_dataset_size(input_shards)
if is_train and args.distributed:
max_shards_per_node = math.ceil(num_shards / args.world_size)
num_samples = args.world_size * (num_samples * max_shards_per_node // num_shards)
num_batches = num_samples // (args.batch_size * args.world_size)
num_samples = num_batches * args.batch_size * args.world_size
else:
num_batches = num_samples // args.batch_size
shardlist = wds.PytorchShardList(
input_shards,
epoch_shuffle=is_train,
split_by_node=is_train # NOTE: we do eval on a single gpu.
)
dataset = (
wds.WebDataset(shardlist)
.decode("pil")
.rename(image="jpg;png", text="txt")
.map_dict(image=preprocess_img, text=preprocess_txt)
.to_tuple("image", "text")
.batched(args.batch_size, partial=not is_train or not args.distributed)
)
dataloader = wds.WebLoader(
dataset, batch_size=None, shuffle=False, num_workers=args.workers,
)
if is_train and args.distributed:
# With DDP, we need to make sure that all nodes get the same number of batches;
# we do that by reusing a little bit of data.
dataloader = dataloader.repeat(2).slice(num_batches)
dataloader.num_batches = num_batches
dataloader.num_samples = num_samples
return DataInfo(dataloader, None)
def get_csv_dataset(args, preprocess_fn, is_train):
input_filename = args.train_data if is_train else args.val_data
assert input_filename
dataset = CsvDataset(
input_filename,
preprocess_fn,
img_key=args.csv_img_key,
caption_key=args.csv_caption_key,
sep=args.csv_separator)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
shuffle = is_train and sampler is None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
sampler=sampler,
drop_last=is_train,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
def get_dataset_fn(data_path, dataset_type):
if dataset_type == "webdataset":
return get_wds_dataset
elif dataset_type == "csv":
return get_csv_dataset
elif dataset_type == "auto":
ext = data_path.split('.')[-1]
if ext in ['csv', 'tsv']:
return get_csv_dataset
elif ext in ['tar']:
return get_wds_dataset
else:
raise ValueError(
f"Tried to figure out dataset type, but failed for extention {ext}.")
else:
raise ValueError(f"Unsupported dataset type: {dataset_type}")
def get_data(args, preprocess_fns):
preprocess_train, preprocess_val = preprocess_fns
data = {}
if args.train_data:
data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
args, preprocess_train, is_train=True)
if args.val_data:
data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
args, preprocess_val, is_train=False)
if args.imagenet_val is not None:
data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
if args.imagenet_v2 is not None:
data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
return data
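A hedged sketch (hypothetical paths, not part of this file) of the `sizes.json` that `get_dataset_size()` above expects to find next to the shards, mapping each shard's basename to its sample count:

```python
# Hedged sketch: write a sizes.json alongside webdataset shards.
import json
import os

shard_dir = "shards"  # placeholder directory
os.makedirs(shard_dir, exist_ok=True)
sizes = {"shard_000.tar": 1000, "shard_001.tar": 1000}  # samples per shard
with open(os.path.join(shard_dir, "sizes.json"), "w") as f:
    json.dump(sizes, f)
# Then pass e.g. --train-data "shards/shard_{000..001}.tar" --dataset-type webdataset
```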


@@ -0,0 +1,254 @@
imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
"stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
"indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
"kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
"smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
"tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
"box turtle", "banded gecko", "green iguana", "Carolina anole",
"desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
"American alligator", "triceratops", "worm snake", "ring-necked snake",
"eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
"vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
"green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
"sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
"barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
"tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
"quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
"coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
"koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
"snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
"fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
"isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
"great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
"bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
"pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
"Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
"Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
"Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
"English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
"Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
"Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
"Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
"Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
"Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
"Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
"Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
"Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
"English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
"Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
"Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
"Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
"Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
"French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
"Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
"Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
"Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
"Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
"red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
"kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
"Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
"cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
"meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
"dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
"cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
"lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
"monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
"starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
"hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
"zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
"ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
"gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
"black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
"gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
"langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
"howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
"ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
"giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
"sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
"amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
"backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
"baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
"wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
"bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
"beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
"ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
"bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
"breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
"high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
"can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
"car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
"CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
"church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
"coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
"candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
"rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
"dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
"drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
"electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
"feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
"folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
"freight car", "French horn", "frying pan", "fur coat", "garbage truck",
"gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
"gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
"hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
"handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
"holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
"horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
"T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
"lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
"music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
"mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
"matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
"microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
"mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
"moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
"mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
"neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
"odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
"oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
"parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
"pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
"picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
"pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
"plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
"pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
"printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
"quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
"recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
"rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
"sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
"CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
"shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
"shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
"slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
"space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
"stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
"stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
"submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
"mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
"table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
"thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
"toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
"tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
"triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
"umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
"waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
"washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
"wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
"comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
"plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
"cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
"artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
"hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
"red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
"geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
"bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
"rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
"earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
openai_imagenet_template = [
lambda c: f'a bad photo of a {c}.',
lambda c: f'a photo of many {c}.',
lambda c: f'a sculpture of a {c}.',
lambda c: f'a photo of the hard to see {c}.',
lambda c: f'a low resolution photo of the {c}.',
lambda c: f'a rendering of a {c}.',
lambda c: f'graffiti of a {c}.',
lambda c: f'a bad photo of the {c}.',
lambda c: f'a cropped photo of the {c}.',
lambda c: f'a tattoo of a {c}.',
lambda c: f'the embroidered {c}.',
lambda c: f'a photo of a hard to see {c}.',
lambda c: f'a bright photo of a {c}.',
lambda c: f'a photo of a clean {c}.',
lambda c: f'a photo of a dirty {c}.',
lambda c: f'a dark photo of the {c}.',
lambda c: f'a drawing of a {c}.',
lambda c: f'a photo of my {c}.',
lambda c: f'the plastic {c}.',
lambda c: f'a photo of the cool {c}.',
lambda c: f'a close-up photo of a {c}.',
lambda c: f'a black and white photo of the {c}.',
lambda c: f'a painting of the {c}.',
lambda c: f'a painting of a {c}.',
lambda c: f'a pixelated photo of the {c}.',
lambda c: f'a sculpture of the {c}.',
lambda c: f'a bright photo of the {c}.',
lambda c: f'a cropped photo of a {c}.',
lambda c: f'a plastic {c}.',
lambda c: f'a photo of the dirty {c}.',
lambda c: f'a jpeg corrupted photo of a {c}.',
lambda c: f'a blurry photo of the {c}.',
lambda c: f'a photo of the {c}.',
lambda c: f'a good photo of the {c}.',
lambda c: f'a rendering of the {c}.',
lambda c: f'a {c} in a video game.',
lambda c: f'a photo of one {c}.',
lambda c: f'a doodle of a {c}.',
lambda c: f'a close-up photo of the {c}.',
lambda c: f'a photo of a {c}.',
lambda c: f'the origami {c}.',
lambda c: f'the {c} in a video game.',
lambda c: f'a sketch of a {c}.',
lambda c: f'a doodle of the {c}.',
lambda c: f'a origami {c}.',
lambda c: f'a low resolution photo of a {c}.',
lambda c: f'the toy {c}.',
lambda c: f'a rendition of the {c}.',
lambda c: f'a photo of the clean {c}.',
lambda c: f'a photo of a large {c}.',
lambda c: f'a rendition of a {c}.',
lambda c: f'a photo of a nice {c}.',
lambda c: f'a photo of a weird {c}.',
lambda c: f'a blurry photo of a {c}.',
lambda c: f'a cartoon {c}.',
lambda c: f'art of a {c}.',
lambda c: f'a sketch of the {c}.',
lambda c: f'a embroidered {c}.',
lambda c: f'a pixelated photo of a {c}.',
lambda c: f'itap of the {c}.',
lambda c: f'a jpeg corrupted photo of the {c}.',
lambda c: f'a good photo of a {c}.',
lambda c: f'a plushie {c}.',
lambda c: f'a photo of the nice {c}.',
lambda c: f'a photo of the small {c}.',
lambda c: f'a photo of the weird {c}.',
lambda c: f'the cartoon {c}.',
lambda c: f'art of the {c}.',
lambda c: f'a drawing of the {c}.',
lambda c: f'a photo of the large {c}.',
lambda c: f'a black and white photo of a {c}.',
lambda c: f'the plushie {c}.',
lambda c: f'a dark photo of a {c}.',
lambda c: f'itap of a {c}.',
lambda c: f'graffiti of the {c}.',
lambda c: f'a toy {c}.',
lambda c: f'itap of my {c}.',
lambda c: f'a photo of a cool {c}.',
lambda c: f'a photo of a small {c}.',
lambda c: f'a tattoo of the {c}.',
]
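# Prompt-ensembling note: src/training/zero_shot.py expands every class name with all of
# the templates above, encodes the prompts with the text tower, L2-normalizes them, and
# averages them into a single zero-shot classifier weight per class.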

90
src/training/logger.py Normal file
View File

@@ -0,0 +1,90 @@
import argparse
import logging
from logging import Filter
from logging.handlers import QueueHandler, QueueListener
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
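# Multi-process logging: the primary process owns the file/stream handlers behind a
# QueueListener, while each worker attaches a QueueHandler so its records are funneled
# through the shared queue back to the primary process.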
def setup_primary_logging(log_file, level):
log_queue = Queue(-1)
file_handler = logging.FileHandler(filename=log_file)
stream_handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s | %(levelname)s | %(message)s',
datefmt='%Y-%m-%d,%H:%M:%S')
file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)
file_handler.setLevel(level)
stream_handler.setLevel(level)
listener = QueueListener(log_queue, file_handler, stream_handler)
listener.start()
return log_queue
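# Prefixes every record with the emitting worker's rank so interleaved multi-GPU logs
# remain readable.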
class WorkerLogFilter(Filter):
def __init__(self, rank=-1):
super().__init__()
self._rank = rank
def filter(self, record):
if self._rank != -1:
record.msg = f"Rank {self._rank} | {record.msg}"
return True
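# Called inside each spawned worker: route all root-logger records through the shared queue.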
def setup_worker_logging(rank, log_queue, level):
queue_handler = QueueHandler(log_queue)
worker_filter = WorkerLogFilter(rank)
queue_handler.addFilter(worker_filter)
queue_handler.setLevel(level)
root_logger = logging.getLogger()
root_logger.addHandler(queue_handler)
root_logger.setLevel(level)
def fake_worker(rank: int, world_size: int, log_queue: Queue):
setup_worker_logging(rank, log_queue, logging.DEBUG)
logging.info("Test worker log")
logging.error("Test worker error log")
torch.cuda.set_device(rank)
dist.init_process_group(
backend='nccl',
init_method='tcp://127.0.0.1:6100',
world_size=world_size,
rank=rank,
)
if __name__ == "__main__":
# Set multiprocessing type to spawn
torch.multiprocessing.set_start_method("spawn")
parser = argparse.ArgumentParser()
parser.add_argument("-g", "--gpu-list", type=int, help="List of GPU IDs", nargs="+", required=True)
args = parser.parse_args()
world_size = len(args.gpu_list)
# Initialize the primary logging handlers. Use the returned `log_queue`
# which the worker processes use to push their log messages.
log_queue = setup_primary_logging("/usr/lusers/gamaga/out.log", logging.DEBUG)
if world_size == 1:
fake_worker(0, world_size, log_queue)
else:
mp.spawn(fake_worker, args=(world_size, log_queue), nprocs=world_size)

307
src/training/main.py Normal file
View File

@@ -0,0 +1,307 @@
import os
import time
import logging
from time import gmtime, strftime
from pathlib import Path
import json
import wandb
import torch
from torch import optim
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler
from clip.clip import _transform, load
from clip.model import convert_weights, CLIP
from training.train import train, evaluate
from training.data import get_data
from training.params import parse_args
from training.logger import setup_primary_logging, setup_worker_logging
from training.scheduler import cosine_lr
# Used by https://github.com/openai/CLIP/issues/83 but not below.
# Keeping it in case it is needed.
def convert_models_to_fp32(model):
for p in model.parameters():
p.data = p.data.float()
if p.grad is not None:
p.grad.data = p.grad.data.float()
def is_master(args):
return (not args.distributed) or args.gpu == 0 or args.dp
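# One process per GPU: set up logging, build the model and data pipeline, create the
# optimizer and scheduler, optionally resume from a checkpoint, then run the train/eval loop.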
def main_worker(gpu, ngpus_per_node, log_queue, args):
args.gpu = gpu
args.rank = gpu
setup_worker_logging(args.rank, log_queue, args.log_level)
# Log and save params.
if is_master(args):
logging.info("Params:")
params_file = os.path.join(args.logs, args.name, "params.txt")
with open(params_file, "w") as f:
for name in sorted(vars(args)):
val = getattr(args, name)
logging.info(f" {name}: {val}")
f.write(f"{name}: {val}\n")
if args.distributed:
dist.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
if args.dp:
args.batch_size *= args.world_size
if args.gpu is not None:
logging.info(f"Use GPU: {args.gpu} for training")
torch.cuda.set_device(args.gpu)
# Do not use skip_reset unless you want to use one of the CLIP models.
if args.openai_pretrained:
model, preprocess_train, preprocess_val = load(
args.model,
jit=False,
is_train=True)
else:
model_config_file = Path(__file__).parent / f"model_configs/{args.model.replace('/', '-')}.json"
print('Loading model from', model_config_file)
assert os.path.exists(model_config_file)
with open(model_config_file, 'r') as f:
model_info = json.load(f)
model = CLIP(**model_info)
convert_weights(model)
preprocess_train = _transform(model.visual.input_resolution, is_train=True)
preprocess_val = _transform(model.visual.input_resolution, is_train=False)
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
if args.precision == "amp" or args.precision == "fp32" or args.gpu is None:
convert_models_to_fp32(model)
if not torch.cuda.is_available():
model.float()
logging.warning("using CPU, this will be slow")
else:
model.cuda(args.gpu)
if args.precision == "fp16":
convert_weights(model)
# Previously batch size and workers were global and not per GPU.
# args.batch_size = args.batch_size / ngpus_per_node)
# args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
if args.distributed and args.use_bn_sync:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
if args.dp:
model = torch.nn.DataParallel(model, device_ids=args.multigpu)
if args.precision == "fp16":
convert_weights(model)
data = get_data(args, (preprocess_train, preprocess_val))
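# Split parameters into two groups: norm gains, biases, and logit_scale get no weight
# decay, while everything else uses args.wd (see the AdamW param groups below).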
exclude = lambda n : "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
include = lambda n : not exclude(n)
named_parameters = list(model.named_parameters())
gain_or_bias_params = [p for n, p in named_parameters if exclude(n) and p.requires_grad]
rest_params = [p for n, p in named_parameters if include(n) and p.requires_grad]
if args.train_data is None:
optimizer = None
scheduler = None
else:
optimizer = optim.AdamW(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.eps,
)
total_steps = data["train"].dataloader.num_batches * args.epochs
scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
scaler = GradScaler() if args.precision == "amp" else None
# optionally resume from a checkpoint
start_epoch = 0
if args.resume is not None:
if os.path.isfile(args.resume):
if args.gpu is None:
checkpoint = torch.load(args.resume)
else:
# Map model to be loaded to specified single gpu.
loc = "cuda:{}".format(args.gpu)
checkpoint = torch.load(args.resume, map_location=loc)
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items()}
model.load_state_dict(sd)
if optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(
f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
)
else:
logging.info("=> no checkpoint found at '{}'".format(args.resume))
cudnn.benchmark = True
cudnn.deterministic = False
# determine if this worker should save logs and checkpoints.
# only do so if it is the 0th worker.
args.save_logs = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and (
(not args.distributed) or args.gpu == 0
)
writer = None
if args.save_logs and args.tensorboard:
writer = SummaryWriter(args.tensorboard_path)
if args.wandb and is_master(args):
logging.debug('Starting wandb.')
args.train_sz = data["train"].dataloader.num_samples
if args.val_data is not None:
args.val_sz = data["val"].dataloader.num_samples
# you will have to configure this for your project!
wandb.init(
project="open-clip",
notes=args.wandb_notes,
tags=[],
config=vars(args),
)
if args.debug:
wandb.watch(model, log='all')
wandb.save(params_file)
logging.debug('Finished loading wandb.')
if args.train_data is None:
evaluate(model, data, start_epoch, args, writer, 0)
return
elif start_epoch == 0 and args.val_data is not None:
evaluate(model, data, 0, args, writer, 0)
for epoch in range(start_epoch, args.epochs):
if args.gpu == 0:
logging.info(f'Start epoch {epoch}')
train(model, data, epoch, optimizer, scaler, scheduler, args, writer)
steps = data["train"].dataloader.num_batches * (epoch + 1)
if args.val_data is not None:
evaluate(model, data, epoch + 1, args, writer, steps)
# Saving checkpoints.
if args.save_logs and (args.gpu == 0 or (not args.distributed)):
if (epoch + 1) == args.epochs or (
args.save_frequency > 0 and ((epoch + 1) % args.save_frequency) == 0
):
torch.save(
{
"epoch": epoch + 1,
"name": args.name,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
},
os.path.join(args.checkpoint_path, f"epoch_{epoch + 1}.pt"),
)
if args.wandb and (args.gpu == 0 or (not args.distributed)):
wandb.finish()
def main():
args = parse_args()
# get the name of the experiments
if args.name is None:
args.name = strftime(
f"lr={args.lr}_"
f"wd={args.wd}_"
f"agg={args.aggregate}_"
f"model={args.model}_"
f"batchsize={args.batch_size}_workers={args.workers}_date=%Y-%m-%d-%H-%M-%S",
gmtime(),
)
if args.copy_codebase:
import sys, subprocess
from shutil import copytree, ignore_patterns
new_code_path = os.path.join(args.logs, args.name, "code")
if os.path.exists(new_code_path):
print(
f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
)
return -1
print(f"Copying codebase to {new_code_path}")
current_code_path = os.path.realpath(__file__)
for _ in range(3):
current_code_path = os.path.dirname(current_code_path)
copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
print("Done copying code.")
os.environ["PYTHONPATH"] = f"{os.environ['PYTHONPATH']}:{os.path.join(new_code_path, 'src')}"
main_file = os.path.join(new_code_path, "src", "training", "main.py")
argv = sys.argv
argv.remove('--copy-codebase')
argv.extend(['--name', args.name])
command = [sys.executable] + argv
print("Executing command:", " ".join(command))
subprocess.check_call(command)
return 1
args.log_path = os.path.join(args.logs, args.name, "out.log")
if os.path.exists(args.log_path):
print(
"Error. Experiment already exists. Use --name {} to specify a new experiment."
)
return -1
assert args.precision in ['amp', 'fp16', 'fp32']
#assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model)
args.ngpus_per_node = torch.cuda.device_count()
args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
for dirname in [args.tensorboard_path, args.checkpoint_path]:
if dirname:
os.makedirs(dirname, exist_ok=True)
# Set multiprocessing type to spawn.
# This is important for logging to work with multiprocessing.
torch.multiprocessing.set_start_method("spawn")
# Set logger
args.log_level = logging.DEBUG if args.debug else logging.INFO
log_queue = setup_primary_logging(args.log_path, args.log_level)
# Distributed training = training on more than one GPU.
# Also easily possible to extend to multiple nodes & multiple GPUs.
args.distributed = (args.gpu is None) and torch.cuda.is_available() and (not args.dp)
if args.distributed:
ngpus_per_node = torch.cuda.device_count()
args.world_size = ngpus_per_node
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, log_queue, args))
else:
if args.dp:
args.gpu = args.multigpu[0]
args.world_size = len(args.multigpu)
else:
args.world_size = 1
main_worker(args.gpu, None, log_queue, args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,17 @@
{
"embed_dim": 512,
"image_resolution": 224,
"vision_layers": [
3,
4,
23,
3
],
"vision_width": 64,
"vision_patch_size": null,
"context_length": 77,
"vocab_size": 49408,
"transformer_width": 512,
"transformer_heads": 8,
"transformer_layers": 12
}

View File

@@ -0,0 +1,17 @@
{
"embed_dim": 1024,
"image_resolution": 224,
"vision_layers": [
3,
4,
6,
3
],
"vision_width": 64,
"vision_patch_size": null,
"context_length": 77,
"vocab_size": 49408,
"transformer_width": 512,
"transformer_heads": 8,
"transformer_layers": 12
}

View File

@@ -0,0 +1,17 @@
{
"embed_dim": 640,
"image_resolution": 288,
"vision_layers": [
4,
6,
10,
6
],
"vision_width": 80,
"vision_patch_size": null,
"context_length": 77,
"vocab_size": 49408,
"transformer_width": 640,
"transformer_heads": 10,
"transformer_layers": 12
}

View File

@@ -0,0 +1,12 @@
{
"embed_dim": 512,
"image_resolution": 224,
"vision_layers": 12,
"vision_width": 768,
"vision_patch_size": 32,
"context_length": 77,
"vocab_size": 49408,
"transformer_width": 512,
"transformer_heads": 8,
"transformer_layers": 12
}

207
src/training/params.py Normal file
View File

@@ -0,0 +1,207 @@
import argparse
def get_default_params(model_name):
# Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
if model_name in ["RN50", "RN101", "RN50x4"]:
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
elif model_name == "ViT-B/32":
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
else:
return {}
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--train-data",
type=str,
default=None,
help="Path to csv filewith training data",
)
parser.add_argument(
"--val-data",
type=str,
default=None,
help="Path to csv file with validation data",
)
parser.add_argument(
"--dataset-type",
choices=["webdataset", "csv", "auto"],
default="auto",
help="Which type of dataset to process."
)
parser.add_argument(
"--csv-separator",
type=str,
default="\t",
help="For csv-like datasets, which separator to use."
)
parser.add_argument(
"--csv-img-key",
type=str,
default="filepath",
help="For csv-like datasets, the name of the key for the image paths."
)
parser.add_argument(
"--csv-caption-key",
type=str,
default="title",
help="For csv-like datasets, the name of the key for the captions."
)
parser.add_argument(
"--imagenet-val",
type=str,
default=None,
help="Path to imagenet val set for conducting zero shot evaluation.",
)
parser.add_argument(
"--imagenet-v2",
type=str,
default=None,
help="Path to imagenet v2 for conducting zero shot evaluation.",
)
parser.add_argument(
"--logs",
type=str,
default="./logs/",
help="Where to store tensorboard logs. Use None to avoid storing logs.",
)
parser.add_argument(
"--name",
type=str,
default=None,
help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
)
parser.add_argument(
"--workers", type=int, default=1, help="Number of workers per GPU."
)
parser.add_argument(
"--batch-size", type=int, default=64, help="Batch size per GPU."
)
parser.add_argument(
"--epochs", type=int, default=32, help="Number of epochs to train for."
)
parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
parser.add_argument(
"--warmup", type=int, default=10000, help="Number of steps to warmup for."
)
parser.add_argument("--use-bn-sync",
default=False,
action="store_true",
help="Whether to use batch norm sync.")
parser.add_argument(
"--gpu",
type=int,
default=None,
help="Specify a single GPU to run the code on for debugging."
"Leave at None to use all available GPUs.",
)
parser.add_argument(
"--skip-scheduler",
action="store_true",
default=False,
help="Use this flag to skip the learning rate decay.",
)
parser.add_argument(
"--save-frequency", type=int, default=1, help="How often to save checkpoints."
)
parser.add_argument(
"--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
)
parser.add_argument(
"--regression-frequency", type=int, default=2, help="How often to run zero shot."
)
parser.add_argument(
"--resume",
default=None,
type=str,
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"--precision",
choices=["amp", "fp16", "fp32"],
default="amp",
help="Floating point precition."
)
parser.add_argument(
"--model",
choices=["RN50", "RN101", "RN50x4", "ViT-B/32"],
default="RN50",
help="Name of the vision backbone to use.",
)
parser.add_argument(
"--openai-pretrained",
default=False,
action='store_true',
help="Use the openai pretrained models.",
)
# arguments for distributed training
parser.add_argument(
"--dist-url",
default="tcp://127.0.0.1:6100",
type=str,
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--skip-aggregate",
default=False,
action="store_true",
help="whether to aggregate features across gpus before computing the loss"
)
parser.add_argument(
"--report-to",
default='',
type=str,
help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
)
parser.add_argument(
"--wandb-notes",
default='',
type=str,
help="Notes if logging with wandb"
)
parser.add_argument(
"--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
)
parser.add_argument(
"--debug",
default=False,
action="store_true",
help="If true, more information is logged."
)
parser.add_argument(
"--copy-codebase",
default=False,
action="store_true",
help="If true, we copy the entire base on the log diretory, and execute from there."
)
parser.add_argument(
"--dp",
default=False,
action="store_true",
help="Use DP instead of DDP."
)
parser.add_argument(
"--multigpu",
default=None,
type=lambda x: [int(a) for a in x.split(",")],
help="In DP, which GPUs to use for multigpu training",
)
args = parser.parse_args()
args.aggregate = not args.skip_aggregate
# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
for name, val in default_params.items():
if getattr(args, name) is None:
setattr(args, name, val)
return args

20
src/training/scheduler.py Normal file
View File

@@ -0,0 +1,20 @@
import numpy as np
def assign_learning_rate(optimizer, new_lr):
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length
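# Linear warmup for `warmup_length` steps, then cosine decay:
#   lr(step) = 0.5 * (1 + cos(pi * (step - warmup) / (steps - warmup))) * base_lr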
def cosine_lr(optimizer, base_lr, warmup_length, steps):
def _lr_adjuster(step):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
e = step - warmup_length
es = steps - warmup_length
lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
assign_learning_rate(optimizer, lr)
return lr
return _lr_adjuster

245
src/training/train.py Normal file
View File

@@ -0,0 +1,245 @@
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import torch.distributed as dist
from .zero_shot import zero_shot_eval
import sys
import pdb
import wandb
import logging
def is_master(args):
return (not args.distributed) or args.gpu == 0
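# Symmetric cross-entropy contrastive loss over image/text pairs. With --aggregate,
# features are all_gather'd from every rank so each GPU contrasts against the global
# batch; the local (grad-tracking) features replace this rank's gathered copies so
# gradients still flow.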
def get_loss(model, images, texts, loss_img, loss_txt, args):
image_features, text_features, logit_scale = model(images, texts)
logit_scale = logit_scale.mean()
if args.distributed and args.aggregate:
world_size = dist.get_world_size()
rank = dist.get_rank()
# We gather tensors from all gpus to get more negatives to contrast with.
gathered_image_features = [
torch.zeros_like(image_features) for _ in range(world_size)
]
gathered_text_features = [
torch.zeros_like(text_features) for _ in range(world_size)
]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
all_image_features = torch.cat(
[image_features]
+ gathered_image_features[:rank]
+ gathered_image_features[rank + 1 :]
)
all_text_features = torch.cat(
[text_features]
+ gathered_text_features[:rank]
+ gathered_text_features[rank + 1 :]
)
# this is needed to send gradients back everywhere.
logits_per_image = logit_scale * all_image_features @ all_text_features.t()
logits_per_text = logits_per_image.t()
else:
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
ground_truth = torch.arange(len(logits_per_image)).long()
if args.gpu is not None:
ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
total_loss = (
loss_img(logits_per_image, ground_truth)
+ loss_txt(logits_per_text, ground_truth)
) / 2
return total_loss
def train(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
os.environ["WDS_EPOCH"] = str(epoch)
model.train()
dataloader, sampler = data['train'].dataloader, data['train'].sampler
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
if args.gpu is not None:
loss_img = loss_img.cuda(args.gpu)
loss_txt = loss_txt.cuda(args.gpu)
if args.distributed and sampler is not None:
sampler.set_epoch(epoch)
num_batches_per_epoch = dataloader.num_batches
end = time.time()
for i, batch in enumerate(dataloader):
step = num_batches_per_epoch * epoch + i
scheduler(step)
optimizer.zero_grad()
images, texts = batch
if args.gpu is not None:
images = images.cuda(args.gpu, non_blocking=True)
texts = texts.cuda(args.gpu, non_blocking=True)
data_time = time.time() - end
m = model.module if args.distributed or args.dp else model
# with automatic mixed precision.
if args.precision == "amp":
with autocast():
total_loss = get_loss(model, images, texts, loss_img, loss_txt, args)
scaler.scale(total_loss).backward()
scaler.step(optimizer)
scaler.update()
else:
total_loss = get_loss(model, images, texts, loss_img, loss_txt, args)
total_loss.backward()
optimizer.step()
# Note: we clamp to 4.6052 = ln(100), as in the original paper.
m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)
batch_time = time.time() - end
end = time.time()
if is_master(args) and (i % 100) == 0:
num_samples = i * len(images) * args.world_size
samples_per_epoch = dataloader.num_samples
percent_complete = 100.0 * i / num_batches_per_epoch
logging.info(
f"Train Epoch: {epoch} [{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)]\t"
f"Loss: {total_loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
f"\tLR: {optimizer.param_groups[0]['lr']:5f}\tlogit_scale {m.logit_scale.data:.3f}"
)
# save train loss / etc.
timestep = epoch * num_batches_per_epoch + i
log_data = {
"loss": total_loss.item(),
"data_time": data_time,
"batch_time": batch_time,
"scale": m.logit_scale.data.item(),
"lr": optimizer.param_groups[0]["lr"]
}
for name, val in log_data.items():
name = "train/" + name
if tb_writer is not None:
tb_writer.add_scalar(name, val, timestep)
if args.wandb:
wandb.log({name: val, 'step': timestep})
def evaluate(model, data, epoch, args, tb_writer=None, steps=None):
if not is_master(args):
return
model.eval()
zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
dataloader = data['val'].dataloader
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
if args.gpu is not None:
loss_img = loss_img.cuda(args.gpu)
loss_txt = loss_txt.cuda(args.gpu)
cumulative_loss = 0.0
num_elements = 0.0
all_image_features, all_text_features = [], []
with torch.no_grad():
for batch in dataloader:
images, texts = batch
if args.gpu is not None:
images = images.cuda(args.gpu, non_blocking=True)
texts = texts.cuda(args.gpu, non_blocking=True)
image_features, text_features, logit_scale = model(images, texts)
all_image_features.append(image_features)
all_text_features.append(text_features)
logit_scale = logit_scale.mean()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
ground_truth = torch.arange(len(images)).long()
if args.gpu is not None:
ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
total_loss = (
loss_img(logits_per_image, ground_truth)
+ loss_txt(logits_per_text, ground_truth)
) / 2
batch_size = len(images)
cumulative_loss += total_loss * batch_size
num_elements += batch_size
metrics = get_metrics(
torch.cat(all_image_features), torch.cat(all_text_features)
)
loss = cumulative_loss / num_elements
metrics.update(
**{"val_loss": loss.item(), "epoch": epoch, "num_elements": num_elements}
)
metrics.update(zero_shot_metrics)
logging.info(
f"Eval Epoch: {epoch} "
+ "\t".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
)
if args.save_logs:
for name, val in metrics.items():
if tb_writer is not None:
tb_writer.add_scalar(f"val/{name}", val, epoch)
if args.wandb:
for name, val in metrics.items():
wandb.log({f"val/{name}": val, 'epoch': epoch})
if args.save_logs:
with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
f.write(json.dumps(metrics))
f.write("\n")
return metrics
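# Image<->text retrieval metrics on the validation set: mean rank, median rank, and
# recall@{1,5,10} in both directions, computed from the similarity logits.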
def get_metrics(image_features, text_features):
metrics = {}
logits_per_image = image_features @ text_features.t()
logits_per_text = logits_per_image.t()
logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
ground_truth = (
torch.arange(len(text_features)).view(-1, 1).to(logits_per_image.device)
)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{name}_mean_rank"] = preds.mean() + 1
metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{name}_R@{k}"] = np.mean(preds < k)
return metrics

90
src/training/zero_shot.py Normal file
View File

@@ -0,0 +1,90 @@
from tqdm import tqdm
import torch
import clip.clip as clip
from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
import logging
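# Build the zero-shot classifier: for each class name, encode all prompt templates,
# L2-normalize, average them into one embedding, and stack the results into an
# (embed_dim, num_classes) weight matrix.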
def zero_shot_classifier(model, classnames, templates, args):
with torch.no_grad():
zeroshot_weights = []
for classname in tqdm(classnames):
texts = [template(classname) for template in templates]  # format with class
texts = clip.tokenize(texts).to(args.gpu)  # tokenize
if args.distributed:
class_embeddings = model.module.encode_text(texts)
elif args.dp:
class_embeddings = model(None, texts)
else:
class_embeddings = model.encode_text(texts)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.gpu)
return zeroshot_weights
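# Top-k accuracy: counts how many targets appear among the k highest-scoring predictions.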
def accuracy(output, target, topk=(1,)):
pred = output.topk(max(topk), 1, True, True)[1].t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
def run(model, classifier, dataloader, args):
with torch.no_grad():
top1, top5, n = 0., 0., 0.
for images, target in tqdm(dataloader):
images = images.to(args.gpu)
target = target.to(args.gpu)
# predict
if args.distributed:
image_features = model.module.encode_image(images)
elif args.dp:
image_features = model(images, None)
else:
image_features = model.encode_image(images)
image_features /= image_features.norm(dim=-1, keepdim=True)
logits = 100. * image_features @ classifier
# measure accuracy
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
top1 += acc1
top5 += acc5
n += images.size(0)
top1 = (top1 / n)
top5 = (top5 / n)
return top1, top5
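# Gate for zero-shot ImageNet / ImageNet-V2 evaluation: runs every `zeroshot_frequency`
# epochs and on the final epoch, returning an empty dict when skipped or when no ImageNet
# split is configured.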
def zero_shot_eval(model, data, epoch, args):
if 'imagenet-val' not in data and 'imagenet-v2' not in data:
return {}
if args.zeroshot_frequency == 0:
return {}
if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
return {}
logging.info('Starting zero-shot imagenet.')
logging.info('Building zero-shot classifier')
classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)
logging.info('Using classifier')
results = {}
if 'imagenet-val' in data:
top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
results['imagenet-zeroshot-val-top1'] = top1
results['imagenet-zeroshot-val-top5'] = top5
if 'imagenet-v2' in data:
top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
results['imagenetv2-zeroshot-val-top1'] = top1
results['imagenetv2-zeroshot-val-top5'] = top5
logging.info('Finished zero-shot imagenet.')
return results