/Image GPT——手把手教你搭建

Image GPT——手把手教你搭建

关注并星标

从此不迷路

计算机视觉研究院

公众号IDComputerVisionGzq

学习群扫码在主页获取加入方式

论文地址:
https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf

计算机视觉研究院专栏

作者:Edison_G

研究者发现,正如在语言上训练的大型变换器模型可以生成连贯的文本一样,在像素序列上训练的相同精确模型可以生成一致的图像补全和样本。通过建立样本质量和图像分类精度之间的相关性,研究者表明,Open AI的最佳生成模型还包含在无监督设置下与顶级卷积网络竞争的特征。


Install

You can get miniconda from https://docs.conda.io/en/latest/miniconda.html, or install the dependencies shown below manually.

conda create --name image-gpt python=3.7.3conda activate image-gpt
conda install numpy=1.16.3conda install tensorflow-gpu=1.13.1
conda install imageio=2.8.0conda install requests=2.21.0conda install tqdm=4.46.0

Usage

This repository is meant to be a starting point for researchers and engineers to experiment with image GPT (iGPT). Our code forks GPT-2 to highlight that it can be easily applied across domains. The diff from gpt-2/src/model.py to image-gpt/src/model.py includes a new activation function, renaming of several variables, and the introduction of a start-of-sequence token, none of which change the model architecture.

Downloading Pre-trained Models

To download a model checkpoint, run download.py. The --model argument should be one of "s", "m", or "l", and the --ckpt argument should be one of "131000", "262000", "524000", or "1000000".

python download.py --model s --ckpt 1000000

This command downloads the iGPT-S checkpoint at 1M training iterations. The default download directory is set to /root/downloads/, and can be changed using the --download_dir argument.

Downloading Datasets

To download datasets, run download.py with the --dataset argument set to "imagenet" or "cifar10".

python download.py --model s --ckpt 1000000 --dataset imagenet

This command additionally downloads 32x32 ImageNet encoded with the 9-bit color palette described in the paper. The datasets we provide are center-cropped images intended for evaluation; random cropped images are required to faithfully replicate training.

Downloading Color Clusters

To download the color cluster file defining our 9-bit color palette, run download.py with the --clusters flag set.

python download.py --model s --ckpt 1000000 --dataset imagenet --clusters

This command additionally downloads the color cluster file. src/run.py:sample shows how to decode from 9-bit color to RGB and src/utils.py:color_quantize shows how to go the other way around.

Sampling

Once the desired checkpoint and color cluster file are downloaded, we can run the script in sampling mode. The following commands sample from iGPT-S, iGPT-M, and iGPT-L respectively:

python src/run.py --sample --n_embd 512  --n_head 8  --n_layer 24python src/run.py --sample --n_embd 1024 --n_head 8  --n_layer 36python src/run.py --sample --n_embd 1536 --n_head 16 --n_layer 48

If your data is not in /root/downloads/, set --ckpt_path and --color_cluster_path manually. To run on fewer than 8 GPUs, use a command of the following form:

CUDA_VISIBLE_DEVICES=0,1 python src/run.py --sample --n_embd 512  --n_head 8  --n_layer 24 --n_gpu 2

Evaluating

Once the desired checkpoint and evaluation dataset are downloaded, we can run the script in evaluation mode. The following commands evaluate iGPT-S, iGPT-M, and iGPT-L on ImageNet respectively:

python src/run.py --eval --n_embd 512  --n_head 8  --n_layer 24python src/run.py --eval --n_embd 1024 --n_head 8  --n_layer 36python src/run.py --eval --n_embd 1536 --n_head 16 --n_layer 48

If your data is not in /root/downloads/, set --ckpt_path and --data_path manually. You should see that the test generative losses are 2.0895, 2.0614, and 2.0466, matching Figure 3 in the paper.

© The Ending

转载请联系本公众号获得授权


计算机视觉研究院学习群等你加入!

计算机视觉研究院主要涉及深度学习领域,主要致力于人脸检测、人脸识别,多目标检测、目标跟踪、图像分割等研究方向。研究院接下来会不断分享最新的论文算法新框架,我们这次改革不同点就是,我们要着重”研究“。之后我们会针对相应领域分享实践过程,让大家真正体会摆脱理论的真实场景,培养爱动手编程爱动脑思考的习惯!

扫码关注

计算机视觉研究院

公众号IDComputerVisionGzq

学习群扫码在主页获取加入方式

 往期推荐 

🔗

本文来自微信公众号“计算机视觉研究院”(ID:ComputerVisionGzq)。大作社经授权转载,该文观点仅代表作者本人,大作社平台仅提供信息存储空间服务。