Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"

Overview

Memory Compressed Attention

Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers both the causal and non-causal variant, and will take care of the padding if the sequence length is not divisible by the compression ratio.

The code also resolves an edge-case where the very first query have no keys to attend to in the auto-regressive scenario. The solution is to use null key/values, appended to the final compressed set, so that there is always at least 1 key for all queries to attend to.

Install

$ pip install memory_compressed_attention

Usage

import torch
from memory_compressed_attention import MemoryCompressedAttention

attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

attn(x, input_mask = mask) # (1, 1024, 512)

Citations

@misc{liu2018generating,
    title={Generating Wikipedia by Summarizing Long Sequences},
    author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year={2018},
    eprint={1801.10198},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
You might also like...
Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch

Memory Efficient Attention This is unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch. Implementation is almo

 Attention for PyTorch with Linear Memory Footprint
Attention for PyTorch with Linear Memory Footprint

Attention for PyTorch with Linear Memory Footprint Unofficially implements https://arxiv.org/abs/2112.05682 to get Linear Memory Cost on Attention (+

PyTorch code for our paper "Attention in Attention Network for Image Super-Resolution"

Under construction... Attention in Attention Network for Image Super-Resolution (A2N) This repository is an PyTorch implementation of the paper "Atten

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch
Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Transformer in Transformer Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image c

Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding
Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding

Relational Self-Attention: What's Missing in Attention for Video Understanding This repository is the official implementation of "Relational Self-Atte

Official implementation of cosformer-attention in cosFormer: Rethinking Softmax in Attention

cosFormer Official implementation of cosformer-attention in cosFormer: Rethinking Softmax in Attention Update log 2022/2/28 Add core code License This

This is a pytorch implementation of the NeurIPS paper GAN Memory with No Forgetting.

GAN Memory for Lifelong learning This is a pytorch implementation of the NeurIPS paper GAN Memory with No Forgetting. Please consider citing our paper

Official and maintained implementation of the paper
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Comments
  • The order of masking and softmax operation

    The order of masking and softmax operation

    Hi,

    In memory_compressed_attention.py, I'm wondering if we need to do softmax operation after masking? Btw, if the entry in the mask should be float('-inf') instead of -float('-inf')? If I make something wrong, please correct me.

    image

    Thanks!

    opened by cfeng16 3
  • mask error in attention

    mask error in attention

    Very grateful for your pioneering work! I want to use it in Standard Transformer released in http://nlp.seas.harvard.edu/2018/04/03/attention.html. but it mat a mask error in training. more detail information shown as follow, the code i use: image class ConvCompress(nn.Module): def init(self, dim, ratio = 2, groups = 1): super(ConvCompress, self).init() self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups) #self.linear = nn.Linear(dim, dim)

    def forward(self, mem):
        mem = mem.transpose(1, 2)
        compressed_mem = self.conv(mem)
        return compressed_mem.transpose(1, 2)
    

    class MemoryCompressedAttention(nn.Module): def init(self, h, d_model, compression_factor = 2, dropout = 0.1): super(MemoryCompressedAttention, self).init() assert (d_model % h) == 0, 'dimension must be divisible by number of heads' self.h = h self.d_model = d_model self.d_k = d_model // h

        self.compression_factor = compression_factor
        self.compress_fn = ConvCompress(d_model, compression_factor, groups = h)
    
        #self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.wq = nn.Linear(d_model, d_model, bias = False)
        self.wk = nn.Linear(d_model, d_model, bias = False)
        self.wv = nn.Linear(d_model, d_model, bias = False)
    
        self.wo = nn.Linear(d_model, d_model)
    
        self.dropout = nn.Dropout(dropout)
    
        #self.null_k = nn.Parameter(torch.zeros(1, 1, d_model))
        #self.null_v = nn.Parameter(torch.zeros(1, 1, d_model))
    
    def forward(self, query, key, value, mask = None):
        
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        t = query.size(1)
        cf = self.compression_factor
    
        query = self.wq(query)
        key = self.wk(key)
        value = self.wv(value)
    
        # make sure keys and values sequence lengths
        # are divisible by the compression factor
        padding = cf - (t % cf)
        if padding != 0:
            key, value = map(lambda t: F.pad(t, (0, 0, padding, 0)), (key, value))
    
    
        # compress keys and values
        key, value = map(self.compress_fn, (key, value))
    
        # attach a null key and value, in the case that the first query has no keys to pay attention to
        null_k = nn.Parameter(torch.zeros(key.size(0), 1, self.d_model)).cuda()
        null_v = nn.Parameter(torch.zeros(value.size(0), 1, self.d_model)).cuda()
    
        key = torch.cat((null_k, key), dim=1)
        value = torch.cat((null_v, value), dim=1)
        
        # merge heads
        #query, key, value = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (query, key, value))
        # 1) Do all the linear projections in batch from d_model => h x d_k
        query = query.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        key = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        value = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
    
      
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
    
        # 3) "Concat" using a view and apply a final linear.   # split heads and combine
        x = x.contiguous().view(nbatches, -1, self.d_model)
        out = self.wo(x)
    
        return out
    

    The error was show that image

    I want to know how to fix it, and how to do mask for N*M matrix??

    opened by HN123-123 0
Releases(0.0.5)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
ReConsider is a re-ranking model that re-ranks the top-K (passage, answer-span) predictions of an Open-Domain QA Model like DPR (Karpukhin et al., 2020).

ReConsider ReConsider is a re-ranking model that re-ranks the top-K (passage, answer-span) predictions of an Open-Domain QA Model like DPR (Karpukhin

Facebook Research 47 Jul 26, 2022
Code of Adverse Weather Image Translation with Asymmetric and Uncertainty aware GAN

Adverse Weather Image Translation with Asymmetric and Uncertainty-aware GAN (AU-GAN) Official Tensorflow implementation of Adverse Weather Image Trans

Jeong-gi Kwak 36 Dec 26, 2022
Implementation of the master's thesis "Temporal copying and local hallucination for video inpainting".

Temporal copying and local hallucination for video inpainting This repository contains the implementation of my master's thesis "Temporal copying and

David Álvarez de la Torre 1 Dec 02, 2022
Symmetry and Uncertainty-Aware Object SLAM for 6DoF Object Pose Estimation

SUO-SLAM This repository hosts the code for our CVPR 2022 paper "Symmetry and Uncertainty-Aware Object SLAM for 6DoF Object Pose Estimation". ArXiv li

Robot Perception & Navigation Group (RPNG) 97 Jan 03, 2023
Learning Dense Representations of Phrases at Scale (Lee et al., 2020)

DensePhrases DensePhrases provides answers to your natural language questions from the entire Wikipedia in real-time. While it efficiently searches th

Princeton Natural Language Processing 540 Dec 30, 2022
Differentiable Quantum Chemistry (only Differentiable Density Functional Theory and Hartree Fock at the moment)

DQC: Differentiable Quantum Chemistry Differentiable quantum chemistry package. Currently only support differentiable density functional theory (DFT)

75 Dec 02, 2022
This repository provides some of the code implemented and the data used for the work proposed in "A Cluster-Based Trip Prediction Graph Neural Network Model for Bike Sharing Systems".

cluster-link-prediction This repository provides some of the code implemented and the data used for the work proposed in "A Cluster-Based Trip Predict

Bárbara 0 Dec 28, 2022
Implementation of ProteinBERT in Pytorch

ProteinBERT - Pytorch (wip) Implementation of ProteinBERT in Pytorch. Original Repository Install $ pip install protein-bert-pytorch Usage import torc

Phil Wang 92 Dec 25, 2022
Revitalizing CNN Attention via Transformers in Self-Supervised Visual Representation Learning

Revitalizing CNN Attention via Transformers in Self-Supervised Visual Representation Learning This repository is the official implementation of CARE.

ChongjianGE 89 Dec 02, 2022
EMNLP 2021: Single-dataset Experts for Multi-dataset Question-Answering

MADE (Multi-Adapter Dataset Experts) This repository contains the implementation of MADE (Multi-adapter dataset experts), which is described in the pa

Princeton Natural Language Processing 68 Jul 18, 2022
RL-GAN: Transfer Learning for Related Reinforcement Learning Tasks via Image-to-Image Translation

RL-GAN: Transfer Learning for Related Reinforcement Learning Tasks via Image-to-Image Translation RL-GAN is an official implementation of the paper: T

42 Nov 10, 2022
Recreate CenternetV2 based on MMDET.

Introduction This project is trying to Recreate CenternetV2 based on MMDET, which is proposed in paper Probabilistic two-stage detection. This project

25 Dec 09, 2022
Implemented fully documented Particle Swarm Optimization algorithm (basic model with few advanced features) using Python programming language

Implemented fully documented Particle Swarm Optimization (PSO) algorithm in Python which includes a basic model along with few advanced features such as updating inertia weight, cognitive, social lea

9 Nov 29, 2022
The code repository for "PyCIL: A Python Toolbox for Class-Incremental Learning" in PyTorch.

PyCIL: A Python Toolbox for Class-Incremental Learning Introduction • Methods Reproduced • Reproduced Results • How To Use • License • Acknowledgement

Fu-Yun Wang 258 Dec 31, 2022
UnsupervisedR&R: Unsupervised Pointcloud Registration via Differentiable Rendering

UnsupervisedR&R: Unsupervised Pointcloud Registration via Differentiable Rendering This repository holds all the code and data for our recent work on

Mohamed El Banani 118 Dec 06, 2022
Code for "NeuralRecon: Real-Time Coherent 3D Reconstruction from Monocular Video", CVPR 2021 oral

NeuralRecon: Real-Time Coherent 3D Reconstruction from Monocular Video Project Page | Paper NeuralRecon: Real-Time Coherent 3D Reconstruction from Mon

ZJU3DV 1.4k Dec 30, 2022
C3D is a modified version of BVLC caffe to support 3D ConvNets.

C3D C3D is a modified version of BVLC caffe to support 3D convolution and pooling. The main supporting features include: Training or fine-tuning 3D Co

Meta Archive 1.1k Nov 14, 2022
The code is an implementation of Feedback Convolutional Neural Network for Visual Localization and Segmentation.

Feedback Convolutional Neural Network for Visual Localization and Segmentation The code is an implementation of Feedback Convolutional Neural Network

19 Dec 04, 2022
Example scripts for the detection of lanes using the ultra fast lane detection model in ONNX.

Example scripts for the detection of lanes using the ultra fast lane detection model in ONNX.

Ibai Gorordo 35 Sep 07, 2022
Jupyter notebooks for using & learning Keras

deep-learning-with-keras-notebooks 這個github的repository主要是個人在學習Keras的一些記錄及練習。希望在學習過程中發現到一些好的資訊與範例也可以對想要學習使用 Keras來解決問題的同好,或是對深度學習有興趣的在學學生可以有一些方便理解與上手範例

ErhWen Kuo 2.1k Dec 27, 2022