Initial commit.
150 .gitignore vendored Normal file
@@ -0,0 +1,150 @@
logs/
wandb/
models/
features/
results/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
sync.sh
gpu1sync.sh
.idea
*.pdf
**/._*
**/*DS_*
**.jsonl
src/sbatch
src/misc
.vscode
src/debug
core.*

# Allow
!src/evaluation/misc/results_dbs/*
28 CITATION.cff Normal file
@@ -0,0 +1,28 @@
cff-version: 1.1.0
message: If you use this software, please cite it as below.
authors:
- family-names: Ilharco
  given-names: Gabriel
- family-names: Wortsman
  given-names: Mitchell
- family-names: Carlini
  given-names: Nicholas
- family-names: Taori
  given-names: Rohan
- family-names: Dave
  given-names: Achal
- family-names: Shankar
  given-names: Vaishaal
- family-names: Namkoong
  given-names: Hongseok
- family-names: Miller
  given-names: John
- family-names: Hajishirzi
  given-names: Hannaneh
- family-names: Farhadi
  given-names: Ali
- family-names: Schmidt
  given-names: Ludwig
title: Open Clip
version: 0.0.1
date-released: 2021-07-28
20 LICENSE Normal file
@@ -0,0 +1,20 @@
Copyright (c) 2012-2021 Scott Chacon and others

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
141 README.md
@@ -1,2 +1,139 @@
-# open_clip
-An open source implementation of CLIP.
# OpenCLIP

Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training).

The goal of this repository is to match the accuracy of the original CLIP models when trained on the same dataset. For example, our implementation reaches 22.2% top-1 ImageNet accuracy when training a ResNet 50x4 on the 3 million images in the Conceptual Captions dataset, and 32.7% top-1 ImageNet accuracy when training a RN50 on OpenAI's [15 million image subset of YFCC](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md). OpenAI's CLIP model reaches 31.3% on the same subset of YFCC.

Note that `src/clip` is a copy of OpenAI's official [repo](https://github.com/openai/CLIP) with minimal changes.

## Data

### Conceptual Captions

OpenCLIP reads a CSV file with two columns: a path to an image, and a text caption. The names of the columns are passed as arguments to `main.py`.
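For illustration, a minimal sketch of building such a file (the column names `filepath` and `title` match the sample training command later in this README; the tab separator is assumed to be the default in `main.py`):

```python
import pandas as pd

# Hypothetical two-column training CSV: an image path and its caption.
# These column names are later passed via --csv-img-key / --csv-caption-key.
df = pd.DataFrame({
    "filepath": ["cc_data/train/0/0.jpg", "cc_data/train/1/1001.jpg"],
    "title": ["a dog playing fetch", "a view of the mountains"],
})
df.to_csv("train_data.csv", sep="\t", index=False)
```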
The script `src/data/gather_cc.py` will collect the Conceptual Captions images. First, download the [Conceptual Captions URLs](https://ai.google.com/research/ConceptualCaptions/download) and then run the following script from our repository:
|
||||
|
||||
```
|
||||
python3 src/data/gather_cc.py path/to/Train_GCC-training.tsv path/to/Validation_GCC-1.1.0-Validation.tsv
|
||||
```
|
||||
|
||||
Our training set contains 2.89M images, and our validation set contains 13K images.
|
||||
|
||||
|
||||
### YFCC and other datasets
|
||||
|
||||
In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1000 data points each, which we create using [tarp](https://github.com/webdataset/tarp).
|
||||
|
||||
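As a rough sketch of this layout (hypothetical file names, using the standard library rather than `tarp`):

```python
import io
import tarfile

# Minimal sketch: pack image/caption pairs into one webdataset shard.
# The image and caption share a basename, which is how webdataset
# groups them back into a single training example.
samples = [("000000000", "dog.jpg", "a dog playing fetch")]

with tarfile.open("shard_001.tar", "w") as tar:
    for key, image_path, caption in samples:
        tar.add(image_path, arcname=f"{key}.jpg")
        data = caption.encode("utf-8")
        info = tarfile.TarInfo(name=f"{key}.txt")
        info.size = len(data)
        tar.addfile(info, io.BytesIO(data))
```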
You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/).
|
||||
Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers.
|
||||
The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md).
|
||||
|
||||
|
||||
## Training CLIP
|
||||
|
||||
### Install dependencies
|
||||
|
||||
```
|
||||
conda env create -f environment.yml
|
||||
source activate open_clip
|
||||
```
|
||||
|
||||
### Add directory to pythonpath:
|
||||
|
||||
```
|
||||
cd open_clip
|
||||
export PYTHONPATH="$PYTHONPATH:$PWD/src"
|
||||
```
|
||||
|
||||
|
||||
### Sample running code:
|
||||
|
||||
```
|
||||
nohup python -u src/training/main.py \
|
||||
--save-frequency 1 \
|
||||
--zeroshot-frequency 1 \
|
||||
--report-to tensorboard \
|
||||
--train-data="/path/to/train_data.csv" \
|
||||
--val-data="/path/to/validation_data.csv" \
|
||||
--csv-img-key filepath \
|
||||
--csv-caption-key title \
|
||||
--imagenet-val=/path/to/imagenet/root/val/ \
|
||||
--warmup 10000 \
|
||||
--batch-size=128 \
|
||||
--lr=1e-3 \
|
||||
--wd=0.1 \
|
||||
--epochs=30 \
|
||||
--workers=8
|
||||
```
|
||||
|
||||
Note: `imagenet-val` is the path to the *val* set of ImageNet for zeroshot evaluation, not the train set!
|
||||
You can remove this argument if you do not want to perform zeroshot on imagenet throughout training. Note that the `val` folder should contain subfolders, if it doesn't please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
|
||||
|
||||
This command should produce the following training curve:
|
||||
|
||||

|
||||
|
||||
More detailed curves are given in [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md)
|
||||
|
||||
### Launch tensorboard:
|
||||
```
|
||||
tensorboard --logdir=logs/tensorboard/ --port=7777
|
||||
```
|
||||
|
||||
### Sample resuming from a checkpoint:
|
||||
|
||||
```bash
|
||||
python src/training/main.py \
|
||||
--train-data="/path/to/train_data.csv" \
|
||||
--val-data="/path/to/validation_data.csv" \
|
||||
--resume /path/to/checkpoints/epoch_K.pt
|
||||
```
|
||||
|
||||
### Sample evaluation only:
|
||||
|
||||
```bash
|
||||
python src/training/main.py \
|
||||
--val-data="/path/to/validation_data.csv" \
|
||||
--resume /path/to/checkpoints/epoch_K.pt
|
||||
```
|
||||
|
||||
## Scaling trends
|
||||
|
||||
The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples.
|
||||
|
||||
<img src="docs/scaling.png" width="700">
|
||||
|
||||
## Why are low-accuracy CLIP models interesting?
|
||||
|
||||
**TL;DR:** CLIP models have high effective robustness, even at small scales.
|
||||
|
||||
CLIP models are particularly intriguing because they are more robust to natural distribution shifts.
|
||||
This phenomena is illustrated by the figure below, with ImageNet accuracy on the x-axis
|
||||
and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet) accuracy on the y-axis.
|
||||
Standard training denotes training on the ImageNet train set while the CLIP zero-shot models
|
||||
are shown with stars.
|
||||
|
||||

|
||||
|
||||
As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644), in-distribution
|
||||
and out-of-distribution accuracy follow a predictable linear trend. Effective robustness
|
||||
measures movement above this red line. Even though the models trained with
|
||||
this codebase are much lower accuracy than those trained by OpenAI, they lie on the same
|
||||
trend of improved effective robustness (purple line). Therefore, we can study what makes
|
||||
CLIP robust without needing industry compute.
|
||||
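As a toy numerical sketch of this idea (hypothetical accuracies; Taori et al. fit the trend after a logit transform of accuracy, which is omitted here for brevity):

```python
import numpy as np

# Hypothetical (x, y) = (ImageNet, ImageNetV2) top-1 accuracies, in %,
# for a family of models trained with standard ImageNet training.
x = np.array([60.0, 65.0, 70.0, 75.0, 80.0])
y = np.array([47.0, 52.5, 58.0, 63.5, 69.0])

# The predictable linear trend ("red line" in the figure above).
slope, intercept = np.polyfit(x, y, deg=1)

def effective_robustness(imagenet_acc, imagenetv2_acc):
    """Movement above the trend line, in accuracy points."""
    return imagenetv2_acc - (slope * imagenet_acc + intercept)

# A hypothetical zero-shot CLIP model sitting above the trend:
print(effective_robustness(60.0, 53.0))  # positive -> effectively more robust
```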
For more information on effective robustness, please see:

- [Recht et al., 2019](https://arxiv.org/abs/1902.10811).
- [Taori et al., 2020](https://arxiv.org/abs/2007.00644).
- [Miller et al., 2021](https://arxiv.org/abs/2107.04649).

## The Team

[Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/)

Special thanks to Jong Wook Kim and Alec Radford!
13 docs/clip_conceptual_captions.md Normal file
@@ -0,0 +1,13 @@
## Additional training curves for CLIP on Conceptual Captions

### Zero-shot accuracy



### Training loss curve



### Validation loss curve



### Validation recall


BIN docs/clip_loss.png Normal file
Binary file not shown. After Width: | Height: | Size: 42 KiB
BIN docs/clip_recall.png Normal file
Binary file not shown. After Width: | Height: | Size: 50 KiB
BIN docs/clip_val_loss.png Normal file
Binary file not shown. After Width: | Height: | Size: 43 KiB
BIN docs/clip_zeroshot.png Normal file
Binary file not shown. After Width: | Height: | Size: 57 KiB
BIN docs/effective_robustness.png Normal file
Binary file not shown. After Width: | Height: | Size: 995 KiB
BIN docs/scaling.png Normal file
Binary file not shown. After Width: | Height: | Size: 96 KiB
152 environment.yml Normal file
@@ -0,0 +1,152 @@
name: open_clip
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - absl-py=0.12.0=py36h06a4308_0
  - aiohttp=3.6.3=py36h7b6447c_0
  - async-timeout=3.0.1=py36h06a4308_0
  - attrs=20.3.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - blinker=1.4=py36h06a4308_0
  - brotlipy=0.7.0=py36h27cfd23_1003
  - c-ares=1.17.1=h27cfd23_0
  - ca-certificates=2020.12.5=ha878542_0
  - cachetools=4.2.1=pyhd3eb1b0_0
  - certifi=2020.12.5=py36h5fab9bb_1
  - cffi=1.14.5=py36h261ae71_0
  - chardet=3.0.4=py36h06a4308_1003
  - click=7.1.2=pyhd3eb1b0_0
  - coverage=5.5=py36h27cfd23_2
  - cryptography=3.4.7=py36hd23ed53_0
  - cudatoolkit=11.0.221=h6bb024c_0
  - cython=0.29.23=py36h2531618_0
  - dataclasses=0.8=pyh4f3eec9_6
  - faiss-gpu=1.4.0=py36_cuda8.0.61_1
  - freetype=2.10.4=h5ab3b9f_0
  - ftfy=5.8=py_0
  - google-auth=1.29.0=pyhd3eb1b0_0
  - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
  - grpcio=1.36.1=py36h2157cd5_1
  - idna=2.10=pyhd3eb1b0_0
  - idna_ssl=1.1.0=py36h06a4308_0
  - importlib-metadata=3.10.0=py36h06a4308_0
  - intel-openmp=2021.2.0=h06a4308_610
  - joblib=1.0.1=pyhd8ed1ab_0
  - jpeg=9b=h024ee3a_2
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libblas=3.9.0=1_h6e990d7_netlib
  - libcblas=3.9.0=3_h893e4fe_netlib
  - libffi=3.3=he6710b0_2
  - libgcc=7.2.0=h69d50b8_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.5.0=h14aa051_19
  - libgfortran4=7.5.0=h14aa051_19
  - liblapack=3.9.0=3_h893e4fe_netlib
  - libpng=1.6.37=hbc83047_0
  - libprotobuf=3.14.0=h8c45485_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_1
  - libuv=1.40.0=h7b6447c_0
  - lz4-c=1.9.3=h2531618_0
  - markdown=3.3.4=py36h06a4308_0
  - mkl=2020.2=256
  - mkl-service=2.3.0=py36he8ac12f_0
  - mkl_fft=1.3.0=py36h54f3939_0
  - mkl_random=1.1.1=py36h0573a6f_0
  - multidict=4.7.6=py36h7b6447c_1
  - ncurses=6.2=he6710b0_1
  - ninja=1.10.2=hff7bd54_1
  - numpy=1.19.2=py36h54aff64_0
  - numpy-base=1.19.2=py36hfa32c7d_0
  - oauthlib=3.1.0=py_0
  - olefile=0.46=py36_0
  - openssl=1.1.1k=h27cfd23_0
  - pandas=1.1.3=py36he6710b0_0
  - pillow=8.2.0=py36he98fc37_0
  - pip=21.0.1=py36h06a4308_0
  - protobuf=3.14.0=py36h2531618_1
  - pyasn1=0.4.8=py_0
  - pyasn1-modules=0.2.8=py_0
  - pycparser=2.20=py_2
  - pyjwt=1.7.1=py36_0
  - pyopenssl=20.0.1=pyhd3eb1b0_1
  - pysocks=1.7.1=py36h06a4308_0
  - python=3.6.13=hdb3f193_0
  - python-dateutil=2.8.1=pyhd3eb1b0_0
  - python_abi=3.6=1_cp36m
  - pytorch=1.7.1=py3.6_cuda11.0.221_cudnn8.0.5_0
  - pytz=2021.1=pyhd3eb1b0_0
  - readline=8.1=h27cfd23_0
  - regex=2021.4.4=py36h27cfd23_0
  - requests=2.25.1=pyhd3eb1b0_0
  - requests-oauthlib=1.3.0=py_0
  - rsa=4.7.2=pyhd3eb1b0_1
  - scikit-learn=0.23.2=py36hb6e6923_3
  - scipy=1.5.3=py36h976291a_0
  - setuptools=52.0.0=py36h06a4308_0
  - six=1.15.0=py36h06a4308_0
  - sqlite=3.35.4=hdfb4753_0
  - tensorboard=2.4.0=pyhc547734_0
  - tensorboard-plugin-wit=1.6.0=py_0
  - threadpoolctl=2.1.0=pyh5ca1d4c_0
  - tk=8.6.10=hbc83047_0
  - torchaudio=0.7.2=py36
  - torchvision=0.8.2=py36_cu110
  - tqdm=4.59.0=pyhd3eb1b0_1
  - typing_extensions=3.7.4.3=pyha847dfd_0
  - urllib3=1.26.4=pyhd3eb1b0_0
  - wcwidth=0.2.5=py_0
  - werkzeug=1.0.1=pyhd3eb1b0_0
  - wheel=0.36.2=pyhd3eb1b0_0
  - xz=5.2.5=h7b6447c_0
  - yarl=1.6.3=py36h27cfd23_0
  - zipp=3.4.1=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0
  - pip:
    - ase==3.21.1
    - braceexpand==0.1.7
    - cached-property==1.5.2
    - configparser==5.0.2
    - cycler==0.10.0
    - decorator==4.4.2
    - docker-pycreds==0.4.0
    - gitdb==4.0.7
    - gitpython==3.1.14
    - googledrivedownloader==0.4
    - h5py==3.1.0
    - isodate==0.6.0
    - jinja2==3.0.1
    - kiwisolver==1.3.1
    - littleutils==0.2.2
    - llvmlite==0.36.0
    - markupsafe==2.0.1
    - matplotlib==3.3.4
    - networkx==2.5.1
    - numba==0.53.1
    - ogb==1.3.1
    - outdated==0.2.1
    - pathtools==0.1.2
    - promise==2.3
    - psutil==5.8.0
    - pyarrow==4.0.0
    - pyparsing==2.4.7
    - python-louvain==0.15
    - pyyaml==5.4.1
    - rdflib==5.0.0
    - sentry-sdk==1.1.0
    - shortuuid==1.0.1
    - sklearn==0.0
    - smmap==4.0.0
    - subprocess32==3.5.4
    - torch-geometric==1.7.0
    - wandb==0.10.30
    - wilds==1.1.0
    - "--editable=git+https://github.com/tmbdev/webdataset.git@a4f3ec08551b42f20b20cdc1ba32d12536eabc15#egg=webdataset"
    - git+https://github.com/modestyachts/ImageNetV2_pytorch
    - https://pytorch-geometric.com/whl/torch-1.7.0+cu110/torch_scatter-2.0.6-cp36-cp36m-linux_x86_64.whl
prefix: /home/gamaga/anaconda3/envs/open_clip
BIN src/clip/bpe_simple_vocab_16e6.txt.gz Normal file
Binary file not shown.
208 src/clip/clip.py Normal file
@@ -0,0 +1,208 @@
# Code ported from https://github.com/openai/CLIP

import hashlib
import os
import urllib
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop
from tqdm import tqdm

from clip.model import build_model
from clip.tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # the expected checksum is embedded in the download URL path
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_to_rgb(image):
    return image.convert('RGB')


def _transform(n_px: int, is_train: bool):
    normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    if is_train:
        return Compose([
            RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=Image.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
    else:
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or the more hackable non-JIT model.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        try:
            model = build_model(state_dict or model.state_dict()).to(device)
        except KeyError:
            # checkpoints saved by the trainer nest the weights under "state_dict"
            # and prefix every key with "module."; strip that prefix before loading
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model(sd).to(device)

        if str(device) == "cpu":
            model.float()
        return model, \
               _transform(model.visual.input_resolution, is_train=True), \
               _transform(model.visual.input_resolution, is_train=False)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, \
           _transform(model.input_resolution.item(), is_train=True), \
           _transform(model.input_resolution.item(), is_train=False)


def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:  # Truncate
            tokens = tokens[:context_length]
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
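For context, a minimal usage sketch of this module (assuming `src` is on `PYTHONPATH` as described in the README, and a hypothetical local image `cat.jpg`):

```python
import torch
from PIL import Image
from clip.clip import load, tokenize

# jit=False returns the plain nn.Module plus train/val preprocessing transforms.
model, preprocess_train, preprocess_val = load("ViT-B/32", jit=False)

image = preprocess_val(Image.open("cat.jpg")).unsqueeze(0)
text = tokenize(["a photo of a cat", "a photo of a dog"])

with torch.no_grad():
    # forward() returns L2-normalized features and the learned temperature
    image_features, text_features, logit_scale = model(image, text)
    logits = logit_scale * image_features @ text_features.t()
    print(logits.softmax(dim=-1))
```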
432 src/clip/model.py Normal file
@@ -0,0 +1,432 @@
from collections import OrderedDict
from typing import Tuple, Union

import os
import json
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisualTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisualTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        if image is None:
            return self.encode_text(text)
        elif text is None:
            return self.encode_image(image)
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        return image_features, text_features, self.logit_scale.exp()


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
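`build_model` above recovers every architectural hyperparameter from tensor shapes in the checkpoint. For orientation, a sketch of what it infers for the published ViT-B/32 model (hyperparameter values quoted from the CLIP paper; shown for illustration only):

```python
from clip.model import CLIP

# ViT-B/32 configuration as inferred by build_model from the OpenAI checkpoint.
model = CLIP(
    embed_dim=512,            # joint image/text embedding size
    image_resolution=224,
    vision_layers=12,         # an int selects the VisualTransformer branch
    vision_width=768,
    vision_patch_size=32,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
)
```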
140 src/clip/tokenizer.py Normal file
@@ -0,0 +1,140 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    It also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
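A quick sanity-check sketch of the tokenizer (hypothetical session; note that `encode` does not add the start/end tokens, which `clip.tokenize` prepends and appends itself):

```python
from clip.tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("Hello, world!")  # BPE token ids, no start/end tokens
print(tok.decode(ids))             # roughly "hello , world ! ": input is
                                   # lowercased and cleaned before encoding
```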
95 src/data/gather_cc.py Normal file
@@ -0,0 +1,95 @@
import requests
import os
import multiprocessing as mp
from io import BytesIO
import numpy as np
import PIL
from PIL import Image
import pickle
import sys


def grab(line):
    """
    Download a single image from the TSV.
    """
    uid, split, line = line
    try:
        caption, url = line.split("\t")[:2]
    except:
        print("Parse error")
        return

    if os.path.exists(ROOT + "/%s/%d/%d.jpg" % (split, uid % 1000, uid)):
        print("Finished", uid)
        return uid, caption, url

    # Let's not crash if anything weird happens
    try:
        dat = requests.get(url, timeout=20)
        if dat.status_code != 200:
            print("404 file", url)
            return

        # Try to parse this as an Image file, we'll fail out if not
        im = Image.open(BytesIO(dat.content))
        im.thumbnail((512, 512), PIL.Image.BICUBIC)
        if min(*im.size) < max(*im.size) / 3:
            print("Too small", url)
            return

        im.save(ROOT + "/%s/%d/%d.jpg" % (split, uid % 1000, uid))

        # Another try/catch just because sometimes saving and re-loading
        # the image is different than loading it once.
        try:
            o = Image.open(ROOT + "/%s/%d/%d.jpg" % (split, uid % 1000, uid))
            o = np.array(o)

            print("Success", o.shape, uid, url)
            return uid, caption, url
        except:
            print("Failed", uid, url)

    except Exception as e:
        print("Unknown error", e)
        pass


if __name__ == "__main__":
    ROOT = "cc_data"

    if not os.path.exists(ROOT):
        os.mkdir(ROOT)
        os.mkdir(ROOT + "/train")
        os.mkdir(ROOT + "/val")
        for i in range(1000):
            os.mkdir(ROOT + "/train/" + str(i))
            os.mkdir(ROOT + "/val/" + str(i))

    p = mp.Pool(300)

    for tsv in sys.argv[1:]:
        print("Processing file", tsv)
        assert 'val' in tsv.lower() or 'train' in tsv.lower()
        split = 'val' if 'val' in tsv.lower() else 'train'
        results = p.map(grab,
                        [(i, split, x) for i, x in enumerate(open(tsv).read().split("\n"))])

        out = open(tsv.replace(".tsv", "_output.csv"), "w")
        out.write("title\tfilepath\n")

        for row in results:
            print("Test", row)
            if row is None: continue
            id, caption, url = row
            # build the output path under the current split
            fp = split + "/" + str(id % 1000) + "/" + str(id) + ".jpg"
            print(fp)
            if os.path.exists(os.path.join(ROOT, fp)):
                out.write("%s\t%s\n" % (caption, fp))
            else:
                print("Drop", id)
        out.close()

    p.close()
1 src/training/.gitignore vendored Normal file
@@ -0,0 +1 @@
logs/
218 src/training/data.py Normal file
@@ -0,0 +1,218 @@
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import logging
|
||||
import functools
|
||||
import braceexpand
|
||||
import random
|
||||
import pdb
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from PIL import Image
|
||||
|
||||
from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torchvision.datasets as datasets
|
||||
from webdataset.utils import identity
|
||||
import webdataset as wds
|
||||
|
||||
|
||||
|
||||
from clip.clip import tokenize
|
||||
|
||||
|
||||
class CsvDataset(Dataset):
|
||||
def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
|
||||
logging.debug(f'Loading csv data from {input_filename}.')
|
||||
df = pd.read_csv(input_filename, sep=sep)
|
||||
|
||||
self.images = df[img_key].tolist()
|
||||
self.captions = df[caption_key].tolist()
|
||||
self.transforms = transforms
|
||||
logging.debug('Done loading data.')
|
||||
|
||||
def __len__(self):
|
||||
return len(self.captions)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
images = self.transforms(Image.open(str(self.images[idx])))
|
||||
texts = tokenize([str(self.captions[idx])])[0]
|
||||
return images, texts
|
||||
|
||||
@dataclass
|
||||
class DataInfo:
|
||||
dataloader: DataLoader
|
||||
sampler: DistributedSampler
|
||||
|
||||
def preprocess_txt(text):
|
||||
return tokenize([str(text)])[0]
|
||||
|
||||
def get_dataset_size(shards):
|
||||
shards_list = list(braceexpand.braceexpand(shards))
|
||||
dir_path = os.path.dirname(shards)
|
||||
sizes = eval(open(os.path.join(dir_path, 'sizes.json'), 'r').read())
|
||||
total_size = sum(
|
||||
[int(sizes[os.path.basename(shard)]) for shard in shards_list])
|
||||
num_shards = len(shards_list)
|
||||
return total_size, num_shards
|
||||
|
||||
def get_imagenet(args, preprocess_fns, split):
|
||||
assert split in ["train", "val", "v2"]
|
||||
is_train = split == "train"
|
||||
preprocess_train, preprocess_val = preprocess_fns
|
||||
|
||||
if split == "v2":
|
||||
from imagenetv2_pytorch import ImageNetV2Dataset
|
||||
dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
|
||||
else:
|
||||
if is_train:
|
||||
data_path = args.imagenet_train
|
||||
preprocess_fn = preprocess_train
|
||||
else:
|
||||
data_path = args.imagenet_val
|
||||
preprocess_fn = preprocess_val
|
||||
assert data_path
|
||||
|
||||
dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
|
||||
|
||||
if is_train:
|
||||
idxs = np.zeros(len(dataset.targets))
|
||||
target_array = np.array(dataset.targets)
|
||||
k = 50
|
||||
for c in range(1000):
|
||||
m = target_array == c
|
||||
n = len(idxs[m])
|
||||
arr = np.zeros(n)
|
||||
arr[:k] = 1
|
||||
np.random.shuffle(arr)
|
||||
idxs[m] = arr
|
||||
|
||||
idxs = idxs.astype('int')
|
||||
sampler = SubsetRandomSampler(np.where(idxs)[0])
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.workers,
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
return DataInfo(dataloader, sampler)
|
||||

def count_samples(dataloader):
    os.environ["WDS_EPOCH"] = "0"
    n_elements, n_batches = 0, 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        assert len(images) == len(texts)
    return n_elements, n_batches


def get_wds_dataset(args, preprocess_img, is_train):
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None

    # The following code is adapted from https://github.com/tmbdev/webdataset-examples/blob/master/main-wds.py
    num_samples, num_shards = get_dataset_size(input_shards)
    if is_train and args.distributed:
        max_shards_per_node = math.ceil(num_shards / args.world_size)
        num_samples = args.world_size * (num_samples * max_shards_per_node // num_shards)
        num_batches = num_samples // (args.batch_size * args.world_size)
        num_samples = num_batches * args.batch_size * args.world_size
    else:
        num_batches = num_samples // args.batch_size
    shardlist = wds.PytorchShardList(
        input_shards,
        epoch_shuffle=is_train,
        split_by_node=is_train  # NOTE: we do eval on a single gpu.
    )
    dataset = (
        wds.WebDataset(shardlist)
        .decode("pil")
        .rename(image="jpg;png", text="txt")
        .map_dict(image=preprocess_img, text=preprocess_txt)
        .to_tuple("image", "text")
        .batched(args.batch_size, partial=not is_train or not args.distributed)
    )
    dataloader = wds.WebLoader(
        dataset, batch_size=None, shuffle=False, num_workers=args.workers,
    )
    if is_train and args.distributed:
        # With DDP, we need to make sure that all nodes get the same number of batches;
        # we do that by reusing a little bit of data.
        dataloader = dataloader.repeat(2).slice(num_batches)
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader, None)

def get_csv_dataset(args, preprocess_fn, is_train):
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator)
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def get_dataset_fn(data_path, dataset_type):
    if dataset_type == "webdataset":
        return get_wds_dataset
    elif dataset_type == "csv":
        return get_csv_dataset
    elif dataset_type == "auto":
        ext = data_path.split('.')[-1]
        if ext in ['csv', 'tsv']:
            return get_csv_dataset
        elif ext in ['tar']:
            return get_wds_dataset
        else:
            raise ValueError(
                f"Tried to figure out dataset type, but failed for extension {ext}.")
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")


def get_data(args, preprocess_fns):
    preprocess_train, preprocess_val = preprocess_fns
    data = {}

    if args.train_data:
        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
            args, preprocess_train, is_train=True)
    if args.val_data:
        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
            args, preprocess_val, is_train=False)

    if args.imagenet_val is not None:
        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
    if args.imagenet_v2 is not None:
        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")

    return data
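
# A minimal usage sketch for get_data, with illustrative values; the attribute
# names follow params.py, and preprocess_train/preprocess_val are the image
# transforms built in training/main.py:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       train_data="/data/train.csv", val_data=None, dataset_type="auto",
#       csv_img_key="filepath", csv_caption_key="title", csv_separator="\t",
#       batch_size=64, workers=1, distributed=False,
#       imagenet_val=None, imagenet_v2=None)
#   data = get_data(args, (preprocess_train, preprocess_val))
#   train_loader = data["train"].dataloader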
254
src/training/imagenet_zeroshot_data.py
Normal file
@@ -0,0 +1,254 @@
imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
    "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
    "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
    "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
    "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
    "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
    "box turtle", "banded gecko", "green iguana", "Carolina anole",
    "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
    "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
    "American alligator", "triceratops", "worm snake", "ring-necked snake",
    "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
    "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
    "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
    "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
    "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
    "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
    "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
    "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
    "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
    "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
    "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
    "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
    "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
    "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
    "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
    "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
    "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
    "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
    "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
    "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
    "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
    "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
    "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
    "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
    "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
    "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
    "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
    "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
    "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
    "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
    "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
    "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
    "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
    "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
    "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
    "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
    "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
    "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
    "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
    "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
    "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
    "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
    "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
    "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
    "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
    "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
    "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
    "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
    "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
    "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
    "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
    "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
    "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
    "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
    "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
    "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
    "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
    "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
    "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
    "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
    "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
    "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
    "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
    "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
    "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
    "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
    "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
    "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
    "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
    "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
    "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
    "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
    "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
    "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
    "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
    "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
    "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
    "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
    "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
    "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
    "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
    "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
    "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
    "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
    "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
    "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
    "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
    "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
    "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
    "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
    "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
    "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
    "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
    "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
    "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
    "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
    "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
    "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
    "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
    "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
    "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
    "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
    "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
    "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
    "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
    "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
    "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
    "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
    "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
    "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
    "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
    "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
    "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
    "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
    "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
    "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
    "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
    "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
    "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
    "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
    "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
    "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
    "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
    "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
    "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
    "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
    "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
    "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
    "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
    "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
    "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
    "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
    "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
    "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
    "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
    "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
    "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
    "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
    "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
    "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
    "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
    "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
    "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
    "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
    "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
    "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
    "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
    "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
    "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
    "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
    "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
    "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
    "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
    "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
    "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]


openai_imagenet_template = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]
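
# Usage sketch: each entry above is a callable over a class name; zero_shot.py
# builds one prompt per template for every class and averages the normalized
# text embeddings. For example (illustrative class name):
#
#   prompts = [t("golden retriever") for t in openai_imagenet_template]
#   # ['a bad photo of a golden retriever.', 'a photo of many golden retriever.', ...]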
90
src/training/logger.py
Normal file
@@ -0,0 +1,90 @@
import argparse
import logging
from logging import Filter
from logging.handlers import QueueHandler, QueueListener

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Queue


def setup_primary_logging(log_file, level):
    log_queue = Queue(-1)

    file_handler = logging.FileHandler(filename=log_file)
    stream_handler = logging.StreamHandler()

    formatter = logging.Formatter(
        '%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%Y-%m-%d,%H:%M:%S')

    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    file_handler.setLevel(level)
    stream_handler.setLevel(level)

    listener = QueueListener(log_queue, file_handler, stream_handler)

    listener.start()

    return log_queue


class WorkerLogFilter(Filter):
    def __init__(self, rank=-1):
        super().__init__()
        self._rank = rank

    def filter(self, record):
        if self._rank != -1:
            record.msg = f"Rank {self._rank} | {record.msg}"
        return True


def setup_worker_logging(rank, log_queue, level):
    queue_handler = QueueHandler(log_queue)

    worker_filter = WorkerLogFilter(rank)
    queue_handler.addFilter(worker_filter)

    queue_handler.setLevel(level)

    root_logger = logging.getLogger()
    root_logger.addHandler(queue_handler)

    root_logger.setLevel(level)


def fake_worker(rank: int, world_size: int, log_queue: Queue):
    setup_worker_logging(rank, log_queue, logging.DEBUG)
    logging.info("Test worker log")
    logging.error("Test worker error log")
    torch.cuda.set_device(rank)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:6100',
        world_size=world_size,
        rank=rank,
    )

if __name__ == "__main__":
    # Set multiprocessing type to spawn
    torch.multiprocessing.set_start_method("spawn")

    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--gpu-list", type=int, help="List of GPU IDs", nargs="+", required=True)

    args = parser.parse_args()

    world_size = len(args.gpu_list)

    # Initialize the primary logging handlers. Worker processes push their
    # messages to the returned `log_queue`.
    log_queue = setup_primary_logging("/usr/lusers/gamaga/out.log", logging.DEBUG)

    if world_size == 1:
        fake_worker(0, world_size, log_queue)
    else:
        mp.spawn(fake_worker, args=(world_size, log_queue), nprocs=world_size)
307
src/training/main.py
Normal file
@@ -0,0 +1,307 @@
import os
import time
import logging
from time import gmtime, strftime
from pathlib import Path
import json

import wandb
import torch
from torch import optim
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler

from clip.clip import _transform, load
from clip.model import convert_weights, CLIP
from training.train import train, evaluate
from training.data import get_data
from training.params import parse_args
from training.logger import setup_primary_logging, setup_worker_logging
from training.scheduler import cosine_lr

# Used by https://github.com/openai/CLIP/issues/83 but not below.
# Keeping it in case it is needed.
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

def is_master(args):
    return (not args.distributed) or args.gpu == 0 or args.dp

def main_worker(gpu, ngpus_per_node, log_queue, args):
    args.gpu = gpu
    args.rank = gpu
    setup_worker_logging(args.rank, log_queue, args.log_level)

    # Log and save params.
    if is_master(args):
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

    if args.distributed:
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    if args.dp:
        args.batch_size *= args.world_size

    if args.gpu is not None:
        logging.info(f"Use GPU: {args.gpu} for training")
        torch.cuda.set_device(args.gpu)

    # Do not use skip_reset unless you want to use one of the CLIP models.
    if args.openai_pretrained:
        model, preprocess_train, preprocess_val = load(
            args.model,
            jit=False,
            is_train=True)
    else:
        model_config_file = Path(__file__).parent / f"model_configs/{args.model.replace('/', '-')}.json"
        print('Loading model from', model_config_file)
        assert os.path.exists(model_config_file)
        with open(model_config_file, 'r') as f:
            model_info = json.load(f)
        model = CLIP(**model_info)
        convert_weights(model)
        preprocess_train = _transform(model.visual.input_resolution, is_train=True)
        preprocess_val = _transform(model.visual.input_resolution, is_train=False)

    # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
    if args.precision == "amp" or args.precision == "fp32" or args.gpu is None:
        convert_models_to_fp32(model)

    if not torch.cuda.is_available():
        model.float()
        logging.warning("using CPU, this will be slow")
    else:
        model.cuda(args.gpu)
        if args.precision == "fp16":
            convert_weights(model)
        # Previously batch size and workers were global and not per GPU.
        # args.batch_size = args.batch_size / ngpus_per_node)
        # args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

        if args.distributed and args.use_bn_sync:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        if args.dp:
            model = torch.nn.DataParallel(model, device_ids=args.multigpu)

        if args.precision == "fp16":
            convert_weights(model)

    data = get_data(args, (preprocess_train, preprocess_val))

    exclude = lambda n: "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
    include = lambda n: not exclude(n)

    named_parameters = list(model.named_parameters())
    gain_or_bias_params = [p for n, p in named_parameters if exclude(n) and p.requires_grad]
    rest_params = [p for n, p in named_parameters if include(n) and p.requires_grad]

    if args.train_data is None:
        optimizer = None
        scheduler = None
    else:
        optimizer = optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.},
                {"params": rest_params, "weight_decay": args.wd},
            ],
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
        )
        total_steps = data["train"].dataloader.num_batches * args.epochs
        scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)

    scaler = GradScaler() if args.precision == "amp" else None

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = "cuda:{}".format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            start_epoch = checkpoint["epoch"]
            sd = checkpoint["state_dict"]
            if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
                sd = {k[len('module.'):]: v for k, v in sd.items()}
            model.load_state_dict(sd)
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )
        else:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    cudnn.deterministic = False

    # determine if this worker should save logs and checkpoints.
    # only do so if it is the 0th worker.
    args.save_logs = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and (
        (not args.distributed) or args.gpu == 0
    )
    writer = None
    if args.save_logs and args.tensorboard:
        writer = SummaryWriter(args.tensorboard_path)

    if args.wandb and is_master(args):
        logging.debug('Starting wandb.')
        args.train_sz = data["train"].dataloader.num_samples
        if args.val_data is not None:
            args.val_sz = data["val"].dataloader.num_samples
        # you will have to configure this for your project!
        wandb.init(
            project="open-clip",
            notes=args.wandb_notes,
            tags=[],
            config=vars(args),
        )
        if args.debug:
            wandb.watch(model, log='all')
        wandb.save(params_file)
        logging.debug('Finished loading wandb.')

    if args.train_data is None:
        evaluate(model, data, start_epoch, args, writer, 0)
        return
    elif start_epoch == 0 and args.val_data is not None:
        evaluate(model, data, 0, args, writer, 0)

    for epoch in range(start_epoch, args.epochs):
        if args.gpu == 0:
            logging.info(f'Start epoch {epoch}')
        train(model, data, epoch, optimizer, scaler, scheduler, args, writer)
        steps = data["train"].dataloader.num_batches * (epoch + 1)
        if args.val_data is not None:
            evaluate(model, data, epoch + 1, args, writer, steps)

        # Saving checkpoints.
        if args.save_logs and (args.gpu == 0 or (not args.distributed)):
            if (epoch + 1) == args.epochs or (
                args.save_frequency > 0 and ((epoch + 1) % args.save_frequency) == 0
            ):
                torch.save(
                    {
                        "epoch": epoch + 1,
                        "name": args.name,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    },
                    os.path.join(args.checkpoint_path, f"epoch_{epoch + 1}.pt"),
                )

    if args.wandb and (args.gpu == 0 or (not args.distributed)):
        wandb.finish()


def main():
    args = parse_args()

    # get the name of the experiment
    if args.name is None:
        args.name = strftime(
            f"lr={args.lr}_"
            f"wd={args.wd}_"
            f"agg={args.aggregate}_"
            f"model={args.model}_"
            f"batchsize={args.batch_size}_workers={args.workers}_date=%Y-%m-%d-%H-%M-%S",
            gmtime(),
        )

    if args.copy_codebase:
        import sys, subprocess
        from shutil import copytree, ignore_patterns
        new_code_path = os.path.join(args.logs, args.name, "code")
        if os.path.exists(new_code_path):
            print(
                f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
            )
            return -1
        print(f"Copying codebase to {new_code_path}")
        current_code_path = os.path.realpath(__file__)
        for _ in range(3):
            current_code_path = os.path.dirname(current_code_path)
        copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
        print("Done copying code.")
        os.environ["PYTHONPATH"] = f"{os.environ['PYTHONPATH']}:{os.path.join(new_code_path, 'src')}"
        main_file = os.path.join(new_code_path, "src", "training", "main.py")
        argv = sys.argv
        argv.remove('--copy-codebase')
        argv.extend(['--name', args.name])
        command = [sys.executable] + argv
        print("Executing command:", " ".join(command))
        subprocess.check_call(command)
        return 1

    args.log_path = os.path.join(args.logs, args.name, "out.log")
    if os.path.exists(args.log_path):
        print(
            "Error. Experiment already exists. Use --name to specify a new experiment."
        )
        return -1

    assert args.precision in ['amp', 'fp16', 'fp32']
    #assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model)

    args.ngpus_per_node = torch.cuda.device_count()

    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to

    args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
    args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
    for dirname in [args.tensorboard_path, args.checkpoint_path]:
        if dirname:
            os.makedirs(dirname, exist_ok=True)

    # Set multiprocessing type to spawn.
    # This is important for logging to work with multiprocessing.
    torch.multiprocessing.set_start_method("spawn")

    # Set logger
    args.log_level = logging.DEBUG if args.debug else logging.INFO
    log_queue = setup_primary_logging(args.log_path, args.log_level)

    # Distributed training = training on more than one GPU.
    # Also easily possible to extend to multiple nodes & multiple GPUs.
    args.distributed = (args.gpu is None) and torch.cuda.is_available() and (not args.dp)
    if args.distributed:
        ngpus_per_node = torch.cuda.device_count()
        args.world_size = ngpus_per_node
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, log_queue, args))
    else:
        if args.dp:
            args.gpu = args.multigpu[0]
            args.world_size = len(args.multigpu)
        else:
            args.world_size = 1
        main_worker(args.gpu, None, log_queue, args)


if __name__ == "__main__":
    main()
17
src/training/model_configs/RN101.json
Normal file
@@ -0,0 +1,17 @@
{
    "embed_dim": 512,
    "image_resolution": 224,
    "vision_layers": [
        3,
        4,
        23,
        3
    ],
    "vision_width": 64,
    "vision_patch_size": null,
    "context_length": 77,
    "vocab_size": 49408,
    "transformer_width": 512,
    "transformer_heads": 8,
    "transformer_layers": 12
}
17
src/training/model_configs/RN50.json
Normal file
@@ -0,0 +1,17 @@
{
    "embed_dim": 1024,
    "image_resolution": 224,
    "vision_layers": [
        3,
        4,
        6,
        3
    ],
    "vision_width": 64,
    "vision_patch_size": null,
    "context_length": 77,
    "vocab_size": 49408,
    "transformer_width": 512,
    "transformer_heads": 8,
    "transformer_layers": 12
}
17
src/training/model_configs/RN50x4.json
Normal file
@@ -0,0 +1,17 @@
{
    "embed_dim": 640,
    "image_resolution": 288,
    "vision_layers": [
        4,
        6,
        10,
        6
    ],
    "vision_width": 80,
    "vision_patch_size": null,
    "context_length": 77,
    "vocab_size": 49408,
    "transformer_width": 640,
    "transformer_heads": 10,
    "transformer_layers": 12
}
12
src/training/model_configs/ViT-B-32.json
Normal file
@@ -0,0 +1,12 @@
{
    "embed_dim": 512,
    "image_resolution": 224,
    "vision_layers": 12,
    "vision_width": 768,
    "vision_patch_size": 32,
    "context_length": 77,
    "vocab_size": 49408,
    "transformer_width": 512,
    "transformer_heads": 8,
    "transformer_layers": 12
}
207
src/training/params.py
Normal file
@@ -0,0 +1,207 @@
import argparse


def get_default_params(model_name):
    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
    if model_name in ["RN50", "RN101", "RN50x4"]:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
    elif model_name == "ViT-B/32":
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    else:
        return {}


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-data",
        type=str,
        default=None,
        help="Path to csv file with training data",
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to csv file with validation data",
    )
    parser.add_argument(
        "--dataset-type",
        choices=["webdataset", "csv", "auto"],
        default="auto",
        help="Which type of dataset to process."
    )
    parser.add_argument(
        "--csv-separator",
        type=str,
        default="\t",
        help="For csv-like datasets, which separator to use."
    )
    parser.add_argument(
        "--csv-img-key",
        type=str,
        default="filepath",
        help="For csv-like datasets, the name of the key for the image paths."
    )
    parser.add_argument(
        "--csv-caption-key",
        type=str,
        default="title",
        help="For csv-like datasets, the name of the key for the captions."
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        default=None,
        help="Path to imagenet val set for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--imagenet-v2",
        type=str,
        default=None,
        help="Path to imagenet v2 for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store tensorboard logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--workers", type=int, default=1, help="Number of workers per GPU."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size per GPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=32, help="Number of epochs to train for."
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument(
        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
    )
    parser.add_argument("--use-bn-sync",
                        default=False,
                        action="store_true",
                        help="Whether to use batch norm sync.")
    parser.add_argument(
        "--gpu",
        type=int,
        default=None,
        help="Specify a single GPU to run the code on for debugging. "
        "Leave at None to use all available GPUs.",
    )
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
    )
    parser.add_argument(
        "--regression-frequency", type=int, default=2, help="How often to run zero shot."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--model",
        choices=["RN50", "RN101", "RN50x4", "ViT-B/32"],
        default="RN50",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--openai-pretrained",
        default=False,
        action='store_true',
        help="Use the openai pretrained models.",
    )
    # arguments for distributed training
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:6100",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--skip-aggregate",
        default=False,
        action="store_true",
        help="If set, do not aggregate features across gpus before computing the loss"
    )
    parser.add_argument(
        "--report-to",
        default='',
        type=str,
        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
    )
    parser.add_argument(
        "--wandb-notes",
        default='',
        type=str,
        help="Notes if logging with wandb"
    )
    parser.add_argument(
        "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    parser.add_argument(
        "--copy-codebase",
        default=False,
        action="store_true",
        help="If true, we copy the entire codebase to the log directory, and execute from there."
    )
    parser.add_argument(
        "--dp",
        default=False,
        action="store_true",
        help="Use DP instead of DDP."
    )
    parser.add_argument(
        "--multigpu",
        default=None,
        type=lambda x: [int(a) for a in x.split(",")],
        help="In DP, which GPUs to use for multigpu training",
    )
    args = parser.parse_args()
    args.aggregate = not args.skip_aggregate

    # If some params are not passed, we use the default values based on model name.
    default_params = get_default_params(args.model)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args
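
# Defaults-resolution sketch: optimizer hyperparameters left unset on the
# command line fall back to the per-model values in get_default_params.
# Illustrative invocation:
#
#   import sys
#   sys.argv = ["main.py", "--model", "ViT-B/32", "--train-data", "train.csv"]
#   args = parse_args()
#   # args.lr == 5.0e-4, args.beta2 == 0.98, args.eps == 1.0e-6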
20
src/training/scheduler.py
Normal file
@@ -0,0 +1,20 @@
import numpy as np


def assign_learning_rate(optimizer, new_lr):
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lr, warmup_length, steps):
    def _lr_adjuster(step):
        if step < warmup_length:
            lr = _warmup_lr(base_lr, warmup_length, step)
        else:
            e = step - warmup_length
            es = steps - warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
        assign_learning_rate(optimizer, lr)
        return lr
    return _lr_adjuster
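
# Usage sketch: cosine_lr returns a closure that train.py calls once per step;
# it mutates the optimizer's learning rate in place and returns the value set.
# A toy optimizer for illustration:
#
#   import torch
#   opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
#   adjust = cosine_lr(opt, base_lr=1.0, warmup_length=10, steps=100)
#   lrs = [adjust(step) for step in range(100)]  # linear warmup, then cosine decay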
245
src/training/train.py
Normal file
@@ -0,0 +1,245 @@
import os
import time
import json
import numpy as np

import torch
import torch.nn as nn

from torch.cuda.amp import autocast
import torch.distributed as dist

from .zero_shot import zero_shot_eval

import sys
import pdb
import wandb

import logging

def is_master(args):
    return (not args.distributed) or args.gpu == 0

def get_loss(model, images, texts, loss_img, loss_txt, args):
    image_features, text_features, logit_scale = model(images, texts)
    logit_scale = logit_scale.mean()
    if args.distributed and args.aggregate:
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # We gather tensors from all gpus to get more negatives to contrast with.
        gathered_image_features = [
            torch.zeros_like(image_features) for _ in range(world_size)
        ]
        gathered_text_features = [
            torch.zeros_like(text_features) for _ in range(world_size)
        ]
        dist.all_gather(gathered_image_features, image_features)
        dist.all_gather(gathered_text_features, text_features)

        all_image_features = torch.cat(
            [image_features]
            + gathered_image_features[:rank]
            + gathered_image_features[rank + 1 :]
        )
        all_text_features = torch.cat(
            [text_features]
            + gathered_text_features[:rank]
            + gathered_text_features[rank + 1 :]
        )

        # this is needed to send gradients back everywhere.
        logits_per_image = logit_scale * all_image_features @ all_text_features.t()
        logits_per_text = logits_per_image.t()

    else:
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

    ground_truth = torch.arange(len(logits_per_image)).long()
    if args.gpu is not None:
        ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)

    total_loss = (
        loss_img(logits_per_image, ground_truth)
        + loss_txt(logits_per_text, ground_truth)
    ) / 2
    return total_loss

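# For intuition, a minimal sketch of the non-distributed branch of get_loss on
# random unit-normalized features (shapes and the fixed scale are illustrative;
# exp(4.6052) ~= 100 matches the logit_scale clamp used in train below):
#
#   import torch
#   import torch.nn.functional as F
#   img = F.normalize(torch.randn(8, 512), dim=-1)
#   txt = F.normalize(torch.randn(8, 512), dim=-1)
#   logits = 100.0 * img @ txt.t()
#   labels = torch.arange(8)
#   loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
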
def train(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
    os.environ["WDS_EPOCH"] = str(epoch)

    model.train()

    dataloader, sampler = data['train'].dataloader, data['train'].sampler

    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    if args.gpu is not None:
        loss_img = loss_img.cuda(args.gpu)
        loss_txt = loss_txt.cuda(args.gpu)

    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)

    num_batches_per_epoch = dataloader.num_batches

    end = time.time()
    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        scheduler(step)

        optimizer.zero_grad()

        images, texts = batch
        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            texts = texts.cuda(args.gpu, non_blocking=True)

        data_time = time.time() - end

        m = model.module if args.distributed or args.dp else model

        # with automatic mixed precision.
        if args.precision == "amp":
            with autocast():
                total_loss = get_loss(model, images, texts, loss_img, loss_txt, args)
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
            scaler.update()

        else:
            total_loss = get_loss(model, images, texts, loss_img, loss_txt, args)
            total_loss.backward()
            optimizer.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)

        batch_time = time.time() - end
        end = time.time()

        if is_master(args) and (i % 100) == 0:
            num_samples = i * len(images) * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * i / num_batches_per_epoch
            logging.info(
                f"Train Epoch: {epoch} [{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)]\t"
                f"Loss: {total_loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                f"\tLR: {optimizer.param_groups[0]['lr']:5f}\tlogit_scale {m.logit_scale.data:.3f}"
            )
            # save train loss / etc.

            timestep = epoch * num_batches_per_epoch + i
            log_data = {
                "loss": total_loss.item(),
                "data_time": data_time,
                "batch_time": batch_time,
                "scale": m.logit_scale.data.item(),
                "lr": optimizer.param_groups[0]["lr"]
            }

            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, timestep)
                if args.wandb:
                    wandb.log({name: val, 'step': timestep})


def evaluate(model, data, epoch, args, tb_writer=None, steps=None):
    if not is_master(args):
        return

    model.eval()

    zero_shot_metrics = zero_shot_eval(model, data, epoch, args)

    dataloader = data['val'].dataloader

    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    if args.gpu is not None:
        loss_img = loss_img.cuda(args.gpu)
        loss_txt = loss_txt.cuda(args.gpu)

    cumulative_loss = 0.0
    num_elements = 0.0
    all_image_features, all_text_features = [], []
    with torch.no_grad():
        for batch in dataloader:
            images, texts = batch
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                texts = texts.cuda(args.gpu, non_blocking=True)

            image_features, text_features, logit_scale = model(images, texts)
            all_image_features.append(image_features)
            all_text_features.append(text_features)
            logit_scale = logit_scale.mean()
            logits_per_image = logit_scale * image_features @ text_features.t()
            logits_per_text = logits_per_image.t()

            ground_truth = torch.arange(len(images)).long()
            if args.gpu is not None:
                ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
            total_loss = (
                loss_img(logits_per_image, ground_truth)
                + loss_txt(logits_per_text, ground_truth)
            ) / 2

            batch_size = len(images)
            cumulative_loss += total_loss * batch_size
            num_elements += batch_size

        metrics = get_metrics(
            torch.cat(all_image_features), torch.cat(all_text_features)
        )
        loss = cumulative_loss / num_elements
        metrics.update(
            **{"val_loss": loss.item(), "epoch": epoch, "num_elements": num_elements}
        )
        metrics.update(zero_shot_metrics)

        logging.info(
            f"Eval Epoch: {epoch} "
            + "\t".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
        )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val/{name}", val, epoch)
    if args.wandb:
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, 'epoch': epoch})

    if args.save_logs:
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    return metrics


def get_metrics(image_features, text_features):
    metrics = {}
    logits_per_image = image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
    ground_truth = (
        torch.arange(len(text_features)).view(-1, 1).to(logits_per_image.device)
    )

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)

    return metrics
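
# get_metrics expects two aligned feature matrices (row i of each side comes
# from the same image-text pair). A minimal sketch on random features:
#
#   import torch
#   image_feats = torch.randn(16, 512)
#   text_feats = torch.randn(16, 512)
#   metrics = get_metrics(image_feats, text_feats)
#   # keys: image_to_text_mean_rank, image_to_text_R@1, ..., text_to_image_R@10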
90
src/training/zero_shot.py
Normal file
@@ -0,0 +1,90 @@
from tqdm import tqdm
import torch
import clip.clip as clip
from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template

import logging

def zero_shot_classifier(model, classnames, templates, args):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).to(args.gpu)  # tokenize
            if args.distributed:
                class_embeddings = model.module.encode_text(texts)
            elif args.dp:
                class_embeddings = model(None, texts)
            else:
                class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.gpu)
    return zeroshot_weights


def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]

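# Quick sanity check for accuracy() above; note it returns correct-prediction
# counts (not fractions), which run() divides by n. Toy values:
#
#   logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#   target = torch.tensor([1, 0])
#   accuracy(logits, target, topk=(1,))  # -> [2.0]
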
def run(model, classifier, dataloader, args):
    with torch.no_grad():
        top1, top5, n = 0., 0., 0.
        for images, target in tqdm(dataloader):
            images = images.to(args.gpu)
            target = target.to(args.gpu)

            # predict
            if args.distributed:
                image_features = model.module.encode_image(images)
            elif args.dp:
                image_features = model(images, None)
            else:
                image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = 100. * image_features @ classifier

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    top1 = (top1 / n)
    top5 = (top5 / n)
    return top1, top5


def zero_shot_eval(model, data, epoch, args):
    if 'imagenet-val' not in data and 'imagenet-v2' not in data:
        return {}
    if args.zeroshot_frequency == 0:
        return {}
    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
        return {}

    logging.info('Starting zero-shot imagenet.')

    logging.info('Building zero-shot classifier')
    classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)

    logging.info('Using classifier')
    results = {}
    if 'imagenet-val' in data:
        top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
        results['imagenet-zeroshot-val-top1'] = top1
        results['imagenet-zeroshot-val-top5'] = top5
    if 'imagenet-v2' in data:
        top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
        results['imagenetv2-zeroshot-val-top1'] = top1
        results['imagenetv2-zeroshot-val-top5'] = top5

    logging.info('Finished zero-shot imagenet.')

    return results