Code for "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation" (long paper, EMNLP 2021)

Scheduled Sampling Based on Decoding Steps for Neural Machine Translation (EMNLP-2021 main conference)

Overview

We propose to conduct scheduled sampling based on decoding steps instead of the original training steps. We observe that our proposal can more realistically simulate the distribution of real translation errors, thus better bridging the gap between training and inference. The paper has been accepted to the main conference of EMNLP-2021.

Background

We conduct scheduled sampling for the Transformer with a two-pass decoder. An example of pseudo-code is as follows:

# first-pass: the same as the standard Transformer decoder
first_decoder_outputs = decoder(first_decoder_inputs)

# sampling tokens between model predictions and ground-truth tokens
second_decoder_inputs = sampling_function(first_decoder_outputs, first_decoder_inputs)

# second-pass: computing the decoder again with the above sampled tokens
second_decoder_outputs = decoder(second_decoder_inputs)
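
The two-pass flow above can be tried on a single sequence without any deep-learning framework. The following is a minimal sketch, not the repository's implementation: `two_pass_inputs`, `decoder`, and `keep_gold_prob` are hypothetical names introduced here, and the one-position shift of the predictions (so that the prediction made at step t-1 becomes the candidate input at step t) is an assumption of this sketch, since the pseudo-code does not spell out the alignment.

```python
import random


def two_pass_inputs(gold_inputs, decoder, keep_gold_prob, rng=random):
    """Build second-pass decoder inputs for one target sequence.

    gold_inputs:    list of ground-truth (teacher-forcing) input tokens
    decoder:        hypothetical callable mapping a list of input tokens to a
                    list of predicted tokens of the same length
    keep_gold_prob: callable mapping a decoding step to the probability of
                    keeping the ground-truth token at that step
    """
    # first pass: standard teacher forcing
    predictions = decoder(gold_inputs)

    # assumption: the prediction made at step t-1 is the candidate input for
    # step t, and step 0 always receives the original start token
    candidates = [gold_inputs[0]] + predictions[:-1]

    # per-step choice between the ground-truth token and the model prediction
    return [gold if rng.random() < keep_gold_prob(t) else pred
            for t, (gold, pred) in enumerate(zip(gold_inputs, candidates))]


if __name__ == "__main__":
    # stand-in "model" that simply upper-cases its inputs
    toy_decoder = lambda tokens: [tok.upper() for tok in tokens]
    gold = ["<s>", "a", "b", "c"]
    second_inputs = two_pass_inputs(gold, toy_decoder, lambda t: 0.9 ** t)
    second_outputs = toy_decoder(second_inputs)  # second pass
    print(second_inputs, second_outputs)
```

Passing the schedule as a callable keeps the two-pass logic independent of how the per-step probability is defined.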

Quick to Use

Our approach is suitable for most autoregressive tasks. Please try the following code when conducting scheduled sampling:

import torch

def sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths,
                      sampling_strategy="exponential", exp_radix=0.99, sigmoid_k=20.0, epsilon=0.1):
    '''
    conduct scheduled sampling based on the index of decoded tokens
    param first_decoder_outputs: [batch_size, seq_len, hidden_size], model predictions
    param first_decoder_inputs: [batch_size, seq_len, hidden_size], ground-truth target tokens
    param max_seq_len: scalar, the max length of target sequences
    param tgt_lengths: [batch_size], LongTensor with the lengths of target sequences
                       in a mini-batch (each assumed to be < max_seq_len)
    The remaining keyword arguments are hyperparameters; their default values
    are placeholders for illustration, not tuned settings.
    '''
    batch_size, seq_len = first_decoder_inputs.shape[:2]
    device = first_decoder_inputs.device

    # indices of decoding steps: 0, 1, ..., max_seq_len - 1
    t = torch.arange(max_seq_len, dtype=torch.float, device=device)

    # different sampling strategies based on decoding steps
    if sampling_strategy == "exponential":
        threshold_table = exp_radix ** t
    elif sampling_strategy == "sigmoid":
        threshold_table = sigmoid_k / (sigmoid_k + torch.exp(t / sigmoid_k))
    elif sampling_strategy == "linear":
        threshold_table = torch.clamp(1 - t / max_seq_len, min=epsilon)
    else:
        raise ValueError("Unknown sampling_strategy %s" % sampling_strategy)

    # convert threshold_table to [batch_size, seq_len];
    # row i of the lower-triangular table is zero beyond position i
    threshold_table = threshold_table.unsqueeze(0).repeat(max_seq_len, 1).tril()
    thresholds = threshold_table[tgt_lengths].view(-1, max_seq_len)
    thresholds = thresholds[:, :seq_len]

    # conduct sampling based on the above thresholds
    random_select_seed = torch.rand([batch_size, seq_len], device=device)
    # add a trailing dimension so the mask broadcasts over hidden_size
    use_ground_truth = (random_select_seed < thresholds).unsqueeze(-1)
    second_decoder_inputs = torch.where(use_ground_truth, first_decoder_inputs, first_decoder_outputs)

    return second_decoder_inputs
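
To get a feel for how the three strategies behave, it can help to print the probability of keeping the ground-truth token at each decoding step. The snippet below is a framework-free sketch of the same three formulas used in the code above; the helper names `keep_gold_probability` and `schedule` are hypothetical, and the hyperparameter defaults are placeholders chosen for illustration rather than values taken from the paper.

```python
import math


def keep_gold_probability(step, strategy, max_seq_len,
                          exp_radix=0.99, sigmoid_k=20.0, epsilon=0.1):
    """Probability of feeding the ground-truth token at decoding step `step`.

    The hyperparameter defaults are placeholders for illustration only.
    """
    if strategy == "exponential":
        return exp_radix ** step
    if strategy == "sigmoid":
        return sigmoid_k / (sigmoid_k + math.exp(step / sigmoid_k))
    if strategy == "linear":
        return max(epsilon, 1 - step / max_seq_len)
    raise ValueError("Unknown strategy %s" % strategy)


def schedule(strategy, max_seq_len, **hyperparams):
    """Per-step probabilities for steps 0 .. max_seq_len - 1."""
    return [keep_gold_probability(t, strategy, max_seq_len, **hyperparams)
            for t in range(max_seq_len)]


if __name__ == "__main__":
    for name in ("exponential", "sigmoid", "linear"):
        probs = schedule(name, max_seq_len=50)
        # show every 10th step to keep the output short
        print(name, ["%.3f" % p for p in probs[::10]])
```

In all three cases the probability never increases with the decoding step, so later positions are more likely to receive model predictions than earlier ones.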
    

Further Usage

Error accumulation is a common phenomenon in NLP tasks. Whenever you want to simulate the accumulation of errors, our method may come in handy. For example:

# sampling tokens between noisy target tokens and ground-truth tokens
decoder_inputs = sampling_function(noisy_decoder_inputs, golden_decoder_inputs, max_seq_len, tgt_lengths)

# computing the decoder with the above sampled tokens
decoder_outputs = decoder(decoder_inputs)

# sampling utterances from model predictions and ground-truth utterances
contexts = sampling_function(predicted_utterances, golden_utterances, max_turns, current_turns)

# computing the dialogue model with the above sampled contexts
model_predictions = dialogue_model(contexts, target_inputs)

Experiments

We provide scripts to reproduce the results in this paper (NMT and text summarization).

Citation

Please cite this paper if you find this repo useful.

@inproceedings{liu_ss_decoding_2021,
    title = "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation",
    author = "Liu, Yijin  and
      Meng, Fandong  and
      Chen, Yufeng  and
      Xu, Jinan  and
      Zhou, Jie",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    year = "2021",
    address = "Online"
}

Contact

Please feel free to contact us ([email protected]) for any further questions.
