ZeroGen: Efficient Zero-shot Learning via Dataset Generation

Overview


This repository contains the code for our paper “ZeroGen: Efficient Zero-shot Learning via Dataset Generation”. Our implementation is built on the source code from dino. We thank the dino authors for their work.

If you use this code, please cite our paper:

@article{ye2022zerogen,
      title={ZeroGen: Efficient Zero-shot Learning via Dataset Generation}, 
      author={Jiacheng Ye and Jiahui Gao and Qintong Li and Hang Xu and Jiangtao Feng and Zhiyong Wu and Tao Yu and Lingpeng Kong},
      year={2022},
      eprint={2202.07922},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Setup

All requirements for ZEROGEN can be found in requirements.txt. You can install all required packages in a new environment with pip install -r requirements.txt.

Usage

The scripts/run_cls.sh and scripts/run_qa.sh scripts contain the commands for the following settings:

  • supervised learning with human annotations (SUPERVISED)
  • prompt-based zero-shot learning (PROMPTING)
  • efficient zero-shot learning via dataset generation (ZEROGEN)

For text classification (TC) tasks (e.g., SST-2 and IMDb) and natural language inference (NLI) tasks (e.g., QNLI and RTE), run bash scripts/run_cls.sh. For question answering (QA) tasks, run bash scripts/run_qa.sh.
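
As a rough illustration of which entry point covers which task, the hypothetical Python helper below maps a task name to the script that handles it. The task-name strings in CLS_TASKS and QA_TASKS are assumptions chosen for illustration; the task itself is presumably configured inside each script.

```python
from typing import List

# Hypothetical task names, for illustration only; the real task_name values
# are whatever the scripts themselves use.
CLS_TASKS = {"sst-2", "imdb", "qnli", "rte"}   # TC and NLI tasks
QA_TASKS = {"squad", "adversarial_qa"}          # QA tasks


def command_for(task: str) -> List[str]:
    """Return the bash command whose script covers `task` (TC/NLI vs. QA)."""
    name = task.lower()
    if name in CLS_TASKS:
        return ["bash", "scripts/run_cls.sh"]
    if name in QA_TASKS:
        return ["bash", "scripts/run_qa.sh"]
    raise ValueError(f"unknown task: {task!r}")
```

For example, subprocess.run(command_for("SST-2"), check=True, cwd=home_dir) would launch the classification script from the repository root, with home_dir being your path to ZeroGen.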

In the final stage of the scripts, while generating X (the text in TC, the hypothesis in NLI, and the question in QA), we also train the small model and evaluate it on human annotations. Specifically, after every log_every generated examples, we train on the synthetic dataset and evaluate on the gold validation set. This produces a trend graph similar to Figure 2 in the paper, which is displayed in wandb, a toolkit for tracking experiments.
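
The periodic train-and-evaluate loop described above can be sketched as follows. This is a minimal sketch, not the repository's code: generate_one and train_and_eval are hypothetical stand-ins for PLM generation and for training and evaluating the small model, and it assumes that each evaluation trains on all examples generated so far. The real scripts generate in batches and log to wandb.

```python
from typing import Callable, Dict, List, Tuple


def generate_with_checkpoints(
    generate_one: Callable[[int], Dict],
    train_and_eval: Callable[[List[Dict]], float],
    num_examples: int,
    log_every: int,
) -> List[Tuple[int, float]]:
    """Generate `num_examples` synthetic examples; every `log_every` examples,
    train on the data generated so far and record the gold-validation metric."""
    synthetic: List[Dict] = []
    trend: List[Tuple[int, float]] = []
    for i in range(num_examples):
        synthetic.append(generate_one(i))
        if len(synthetic) % log_every == 0:
            # One point of the trend curve: (dataset size, validation metric).
            trend.append((len(synthetic), train_and_eval(synthetic)))
    return trend


# Toy usage with stubs standing in for the PLM and the small model.
trend = generate_with_checkpoints(
    generate_one=lambda i: {"X": f"example {i}", "Y": i % 2},
    train_and_eval=lambda data: len(data) / 10000,  # fake "accuracy"
    num_examples=3500,
    log_every=1000,
)
```

In a real run, each (size, metric) pair could be sent to wandb with wandb.log({"dataset_size": size, "accuracy": metric}), which is how a curve like Figure 2 would accumulate.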

Before running, set the following parameters to your own values:

  • home_dir: path to ZeroGen
  • gpu: gpu id
  • batch_size: the batch size for generation with the PLM. For SST-2, a batch size of 32 with gpt2-xl needs ~16 GB of GPU memory, while SQuAD needs ~60 GB with the same batch size and PLM because of its longer contexts. Decrease the batch size if needed.
  • WANDB_PROJECT: project name, by default ZeroGen
  • WANDB_ENTITY: your wandb username
  • WANDB_API_KEY: your api-key
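
If you are unsure which batch_size fits your GPU, one common approach is to halve it until a trial batch succeeds. The helper below is a hypothetical sketch of that idea (the scripts themselves take batch_size as a fixed parameter). try_batch is an assumed callable that runs one generation batch; with PyTorch you would pass oom_error=torch.cuda.OutOfMemoryError, which exists in recent PyTorch releases.

```python
from typing import Callable, Type


def find_working_batch_size(
    try_batch: Callable[[int], None],
    start: int = 32,
    oom_error: Type[BaseException] = MemoryError,
) -> int:
    """Halve the batch size until `try_batch` runs without raising `oom_error`."""
    batch_size = start
    while batch_size >= 1:
        try:
            try_batch(batch_size)  # e.g. generate one batch with the PLM
            return batch_size
        except oom_error:
            batch_size //= 2  # out of memory: retry with half the batch
    raise RuntimeError("even batch_size=1 does not fit in memory")
```

Halving keeps the number of trial runs logarithmic in the starting size; the value found can then be written into batch_size in the script.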

By default, we use GPT2-XL as the pre-trained language model (PLM) and DistilBERT as the tiny task model (TAM). To change the PLM or TAM, modify model_name and small_model_name in the run_xxx.sh scripts.

Run with a synthesized dataset

After dataset generation, we save the synthetic dataset at:

  • For TC and NLI: out-${task_name}-x2/${dataset}/${task_name}-dataset.jsonl (e.g., out-sst-2-x2/gpt2-xl_topk0_topp0.9_sst-2-x2/sst-2-dataset.jsonl). The file is in JSON Lines format, with one example per line (e.g., {"C": "The Book of Mormon Musical", "X": "The Book of Mormon Musical brings all the drama and excitement of a real revival of the Broadway production to the big screen.", "Y": 0}).
  • For QA: out-${task_name}-x2/${dataset}, saved in the Hugging Face Dataset format.
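
A TC/NLI file can be inspected with a few lines of standard-library Python. This is a minimal sketch: the helper names are hypothetical, and reading C as the prompt condition, X as the generated text, and Y as the label is inferred from the example line above.

```python
import json
from collections import Counter
from typing import Dict, Iterable, List


def parse_jsonl_lines(lines: Iterable[str]) -> List[Dict]:
    """Parse JSON Lines text: one JSON object per non-empty line."""
    return [json.loads(line) for line in lines if line.strip()]


def load_synthetic_dataset(path: str) -> List[Dict]:
    """Read a synthetic TC/NLI file such as .../sst-2-dataset.jsonl."""
    with open(path, encoding="utf-8") as f:
        return parse_jsonl_lines(f)


def label_distribution(examples: List[Dict]) -> Counter:
    """Count how many synthetic examples were generated for each label Y."""
    return Counter(ex["Y"] for ex in examples)
```

For the QA output, if the directory was written with Dataset.save_to_disk (an assumption), datasets.load_from_disk(path) should load it back.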

To train DistilBERT on a generated dataset, use the scripts/run_distilbert.sh script.

To train an LSTM-based model on a generated dataset, use the scripts/run_cls_lstm.sh script. Before that, download the datasets from the Google Drive link, which contain the standard test files.

Diversity and Correctness of a synthesized dataset

Diversity

We use Self-BLEU to measure the diversity of a synthesized dataset. For an example of computing Self-BLEU on a given dataset, see the scripts/run_self_bleu.sh script.
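
Self-BLEU scores each sentence with BLEU against all other sentences as references and averages the results, so a higher value means a less diverse dataset. Below is a minimal pure-Python sketch. Whitespace tokenization, uniform 4-gram weights, and epsilon smoothing for zero matches (similar in spirit to NLTK's SmoothingFunction method1) are assumptions, so numbers from scripts/run_self_bleu.sh may differ.

```python
import math
from collections import Counter
from typing import List, Sequence


def _ngrams(tokens: Sequence[str], n: int) -> Counter:
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(hyp: List[str], refs: List[List[str]],
                  max_n: int = 4, epsilon: float = 0.1) -> float:
    """Sentence BLEU: clipped n-gram precisions, uniform weights, brevity penalty."""
    log_precision = 0.0
    for n in range(1, max_n + 1):
        hyp_counts = _ngrams(hyp, n)
        total = sum(hyp_counts.values())
        if total == 0:
            return 0.0  # hypothesis has fewer than n tokens
        max_ref: Counter = Counter()
        for ref in refs:
            max_ref |= _ngrams(ref, n)  # Counter union keeps the max count
        clipped = sum(min(c, max_ref[g]) for g, c in hyp_counts.items())
        log_precision += math.log((clipped or epsilon) / total) / max_n
    hyp_len = len(hyp)
    # Closest reference length; ties are broken toward the shorter reference.
    ref_len = min((len(r) for r in refs), key=lambda m: (abs(m - hyp_len), m))
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity * math.exp(log_precision)


def self_bleu(sentences: List[str], max_n: int = 4) -> float:
    """Average BLEU of each sentence against all others (higher = less diverse)."""
    if len(sentences) < 2:
        raise ValueError("Self-BLEU needs at least two sentences")
    tokenized = [s.split() for s in sentences]
    scores = [
        sentence_bleu(hyp, tokenized[:i] + tokenized[i + 1:], max_n)
        for i, hyp in enumerate(tokenized)
    ]
    return sum(scores) / len(scores)
```

The cost grows quadratically with the number of sentences, so on a large synthetic dataset you would typically score a random sample.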

Correctness

To calculate Correctness, take the following steps:

  1. In the scripts/run_distilbert.sh script, set the following parameters:

    • small_model_name=roberta-large
    • dataset=: leave empty to use the standard training set
    • limit=: leave empty to use the full standard training set

    This gives you a RoBERTa-Large model trained on the full set of human annotations, which serves as the evaluator.

  2. In the scripts/run_distilbert.sh script, set the following parameters:

    • small_model_ckpt=tmp/checkpoint-xxx: the final RoBERTa-Large checkpoint saved in step 1.
    • limit=10000: the number of samples to use (10000 by default)
    • dataset=xxx: the name of the synthetic dataset (e.g., gpt2-xl_topk0_topp0.9_sst-2-x2)
    • no_train=true: disable training

    Running the script reports Metric on standard dataset and Metric on synthetic dataset, which give the Correctness of the standard dataset and the synthetic dataset, respectively.
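
Conceptually, Correctness is the fraction of examples whose label agrees with the evaluator's prediction. The sketch below illustrates that computation under stated assumptions: predict is a hypothetical wrapper around the RoBERTa-Large checkpoint from step 1 (it receives the whole example so NLI can use both C and X), and taking the first limit examples is an assumption; the script may select samples differently.

```python
from typing import Callable, Dict, List, Optional


def correctness(
    examples: List[Dict],
    predict: Callable[[Dict], int],
    limit: Optional[int] = 10000,
) -> float:
    """Fraction of examples whose label Y matches the evaluator's prediction."""
    subset = examples[:limit] if limit else examples
    if not subset:
        raise ValueError("no examples to score")
    agree = sum(1 for ex in subset if predict(ex) == ex["Y"])
    return agree / len(subset)
```

Applying the same function to the standard and synthetic datasets mirrors the two numbers the script reports.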

Resources

We provide some synthetic datasets and the standard datasets for training the LSTM model in this Google Drive link. When training DistilBERT, the standard dataset is downloaded directly through the Hugging Face datasets package. Note that we use the same prompt for IMDb and SST-2, and for SQuAD and AdversarialQA, so each pair also shares the same synthetic dataset.
