Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch

Overview

Perceiver - Pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch

Install

$ pip install perceiver-pytorch

Usage

import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    num_fourier_features = 6,    # number of fourier features, with original value (2 * K + 1)
    depth = 48,                  # depth of net, in paper, they went deep, making up for lack of attention
    num_latents = 6,             # number of latents, or induced set points, or centroids. different papers giving it different names
    cross_dim = 512,             # cross attention dimension
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,
    latent_dim_head = 64,
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,
    ff_dropout = 0.,
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
)

img = torch.randn(1, 224 * 224) # 1 imagenet image, pixelized

model(img) # (1, 1000)

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Latent averaging to the logits?

    Latent averaging to the logits?

    I read through the paper last night and came away confused about a few things. I looked through your code hoping for some clarity.

    One issue that doesn't seem to be explained in the paper (or I am missing it) is how the authors go from a set of latents to the logits used at the classification head. You implemented this by taking the mean of the latent set:

    https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_pytorch.py#L203

    Is this actually how the authors convert to logits?

    opened by neonbjb 7
  • PerceiverAR?

    PerceiverAR?

    Hey @lucidrains - love this repo, and still trying to wrap my head around the various difference between Perceiver architectures; how hard would it be to extend PerceiverIO to PerceiverAR; what fundamentally needs to change?

    opened by siddk 5
  • Not using the classification head in Perceiver

    Not using the classification head in Perceiver

    Hi @lucidrains, thank you for your great job!

    I'd like to use the Perceiver (not PerceiverIO) without the classification head (average and projection). Do you think we could add an option to avoid using it? I can do a PR if you want.

    Thanks!

    opened by gegallego 4
  • Decoder Attention Module needs a FF network as well in perceiver_io.py script

    Decoder Attention Module needs a FF network as well in perceiver_io.py script

    Hi,

    According to perceiver io paper's (https://arxiv.org/abs/2107.14795) architectural details, they mention that the decoder attention block contains a cross attention block (4), which is already implemented in the perceiver_io.py script (Line 151), followed by a Feedforward network, given by equation (6) in the paper, which is not present in that script. I am not aware of the repercussions of not having FF in the decoder module but it might be a good idea to have it in the implementation. Something like self.decoder_ff = PreNorm(FeedForward(queries_dim)) would do the job. Experimentally, the authors had found that omitting equation (5) is helpful.

    opened by Hritikbansal 4
  • Positional encoding are already part of the input

    Positional encoding are already part of the input

    Hello! First of all, thank you for this implementation.

    My inputs already have the proper positional encoding as part of the channel axis. Would it be possible to add a feature to deactivate the default implementation of the positional encoding?

    Thank you!

    opened by Atlis 4
  • x = self.latents + self.pos_emb

    x = self.latents + self.pos_emb

    self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
    self.pos_emb = nn.Parameter(torch.randn(num_latents, latent_dim))
    ...
    x = self.latents + self.pos_emb
    

    I'm not very familiar with pytorch, but does this make sense? I mean, what's intended when 2 trainable weight matrices are simply summed and that's that's the only place where both latents and pos_emb appear. It looks like it can be replaced with only one matrix.

    opened by galchinsky 4
  • Fourier encoding is not similar to the paper

    Fourier encoding is not similar to the paper

    First of all, thanks for sharing the code !

    I have a follow up question to #4.

    In the paper, the authors mentioned about [sin(f_kπx_d), cos(f_kπx_d)], where f_k is a bank of frequencies spaced log-linearly between 1 and µ/2. Can you maybe point out how you came to the 1/2**i scaling in the code ?

    https://github.com/lucidrains/perceiver-pytorch/blob/6ae733773d29cb29383f3ac7b45af8cb6bd2c0dc/perceiver_pytorch/perceiver_pytorch.py#L28-L35

    Thanks!

    opened by cheneeheng 4
  • Fourier encoding should be for position coordinates instead of byte array

    Fourier encoding should be for position coordinates instead of byte array

    The fourier_encode function as implemented takes as input a byte array x and directly encodes it with sin/cos before concating with the input.

    As I understand the NeRF position encodings, they encode the x/y/etc. position coordinates, and not a transformation of the data itself. From the Perceiver paper:

    We parametrize the frequency encoding to take the values [sin(fkπxd), cos(fkπxd)], where the frequencies fk is the kth band of a bank of frequencies spaced log-linearly between 1 and µ/2... For example, by allowing the network to resolve the maximum frequency present in an input array, we can encourage it to learn to compare the values of bytes at any positions in the input array. xd is the value of the input position along the dth dimension (e.g. for images d = 2 and for video d = 3). xd takes values in [−1, 1] for each dimension. We concatenate the raw positional value xd to produce the final representation of position. This results in a positional encoding of size d(2K + 1).

    NeRF position encoding examples:

    • https://github.com/bmild/nerf/blob/20a91e764a28816ee2234fcadb73bd59a613a44c/run_nerf_helpers.py#L22
    • https://github.com/ankurhanda/nerf2D
    opened by eridgd 4
  • Positional encoding frequency bands should be linearly spaced

    Positional encoding frequency bands should be linearly spaced

    A small bug, but as alluded to in this comment by @marcdumon, it seems as though the frequency bands are indeed spaced linearly in the official JAX implementation.

    opened by djl11 2
  • Bug in fourier_encode (?)

    Bug in fourier_encode (?)

    Thank you for this great implementation. I'm learning a lot from it!

    I think I found a problem in the fourier_encode method. In this line: https://github.com/lucidrains/perceiver-pytorch/blob/b33aced4e1b266aeb1383e03ab63f0a9951f9126/perceiver_pytorch/perceiver_pytorch.py#L36

    the scales are always the same whatever value of parameter base. Example:

    max_freq = 10, num_bands=6, base = 2
    => scales = [1.0000, 1.3797, 1.9037, 2.6265, 3.6239, 5.0000]
    
    max_freq = 10, num_bands=6, base = 10
    => scales = [1.0000, 1.3797, 1.9037, 2.6265, 3.6239, 5.0000]
    
    opened by marcdumon 2
  • Attention softmax is applied to incorrect dimension?

    Attention softmax is applied to incorrect dimension?

    I am studying multi-head attention. When I was reading through [1], I found that the attenion softmax is applied over the last dimension of the similarity tensor sim:

            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
    
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
    
            if exists(mask):
                <removed>
    
            # attention, what we cannot get enough of
            attn = sim.softmax(dim = -1)
    

    If I understand correctly sim has the shape (b*h) n1 n2. The softmax is computed over the last dimension n2. Shouldn't the softmax be applied to matrices with all the similarity values of a single head (i.e. with shape n1, n2)?

    [1] https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_io.py#L97

    opened by breuderink 2
  • Issue defining base in fourier_encode for experimental.py, gated.py, mixed_latents.py

    Issue defining base in fourier_encode for experimental.py, gated.py, mixed_latents.py

    Hey Lucid, love the work, it appears you deprecated base in fourier_encode at https://github.com/lucidrains/perceiver-pytorch/commit/144b0d9716a7212b5fd6d95a2267c4d4a08b56a7

    But experimental.py, gated.py, mixed_latents.py are still trying to define the base within the forward pass. https://github.com/lucidrains/perceiver-pytorch/blob/abbb5d5949d3509c57749bd134f5068f2761aac7/perceiver_pytorch/experimental.py#L122 https://github.com/lucidrains/perceiver-pytorch/blob/2d59df42ebb0b7538af77d584f5ae5b50759618b/perceiver_pytorch/mixed_latents.py#L85 https://github.com/lucidrains/perceiver-pytorch/blob/2d59df42ebb0b7538af77d584f5ae5b50759618b/perceiver_pytorch/gated.py#L103

    Thanks again, keep up the great work.

    opened by TannerLaBorde 0
  • Audio + Text data?

    Audio + Text data?

    Can someone please guide me on how you can process both audio and .txt data through perceiver simultaneously for multimodality learning?

    An example code would be nice.

    Thanks

    opened by Sidz1812 1
  • just a suggestion

    just a suggestion

    Hi I like to start with thanking you for such a great work with a lot of great implementations. I have a small suggestion. I suggest for all your codes/modules try to add if __name__ == "__main__": so that if someone just wants to use one file/module can easily try that without having going through whole implementations. for example I am trying to use the this, in case of having a if __name__ == "__main__": I can easily try to run a random input and see how it will work. This will increase the usability with a huge amount.

    Keep up the great work :)

    opened by seyeeet 4
  • What should I change if I want to use data with input size 720*184

    What should I change if I want to use data with input size 720*184

    thanks for sharing this code, I was wondering what should I change if I want to be able to use data that can be converted into images with an input size of 720*184? thanks in advance

    opened by Oussamab21 0
  • Question regarding queries dimensionality in Perceiver IO

    Question regarding queries dimensionality in Perceiver IO

    Hi @lucidrains,

    I think I may be missing something - why do we define the perceiver IO queries vector to have a batch dimension (i.e. queries = torch.randn(1, 128, 32))? Was this just to make the code work nicely? Shouldnt we be using queries = torch.randn(128, 32) ? I expect to use the same embedding for all of my batch elements, which is IIUC what your code is doing.

    opened by pcicales 3
Releases(0.8.6)
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
CM-NAS: Cross-Modality Neural Architecture Search for Visible-Infrared Person Re-Identification (ICCV2021)

CM-NAS Official Pytorch code of paper CM-NAS: Cross-Modality Neural Architecture Search for Visible-Infrared Person Re-Identification in ICCV2021. Vis

JDAI-CV 40 Nov 25, 2022
code and models for "Laplacian Pyramid Reconstruction and Refinement for Semantic Segmentation"

Laplacian Pyramid Reconstruction and Refinement for Semantic Segmentation This repository contains code and models for the method described in: Golnaz

55 Jun 18, 2022
Robot Servers and Server Manager software for robo-gym

robo-gym-server-modules Robot Servers and Server Manager software for robo-gym. For info on how to use this package please visit the robo-gym website

JR ROBOTICS 4 Aug 16, 2021
NudeNet: Neural Nets for Nudity Classification, Detection and selective censoring

NudeNet: Neural Nets for Nudity Classification, Detection and selective censoring Uncensored version of the following image can be found at https://i.

notAI.tech 1.1k Dec 29, 2022
Research on Tabular Deep Learning (Python package & papers)

Research on Tabular Deep Learning For paper implementations, see the section "Papers and projects". rtdl is a PyTorch-based package providing a user-f

Yura Gorishniy 510 Dec 30, 2022
Решения, подсказки, тесты и утилиты для тренировки по алгоритмам от Яндекса.

Решения и подсказки к тренировке по алгоритмам от Яндекса Что есть внутри Решения с подсказками и комментариями; рекомендую сначала смотреть md файл п

Yankovsky Andrey 50 Dec 26, 2022
WRENCH: Weak supeRvision bENCHmark

🔧 What is it? Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development

Jieyu Zhang 176 Dec 28, 2022
Graph Posterior Network: Bayesian Predictive Uncertainty for Node Classification (NeurIPS 2021)

Graph Posterior Network This is the official code repository to the paper Graph Posterior Network: Bayesian Predictive Uncertainty for Node Classifica

Maximilian Stadler 30 Dec 05, 2022
PyTorch implementation of "MLP-Mixer: An all-MLP Architecture for Vision" Tolstikhin et al. (2021)

mlp-mixer-pytorch PyTorch implementation of "MLP-Mixer: An all-MLP Architecture for Vision" Tolstikhin et al. (2021) Usage import torch from mlp_mixer

isaac 27 Jul 09, 2022
Enhancing Knowledge Tracing via Adversarial Training

Enhancing Knowledge Tracing via Adversarial Training This repository contains source code for the paper "Enhancing Knowledge Tracing via Adversarial T

Xiaopeng Guo 14 Oct 24, 2022
Implementation of the paper Scalable Intervention Target Estimation in Linear Models (NeurIPS 2021), and the code to generate simulation results.

Scalable Intervention Target Estimation in Linear Models Implementation of the paper Scalable Intervention Target Estimation in Linear Models (NeurIPS

0 Oct 25, 2021
USAD - UnSupervised Anomaly Detection on multivariate time series

USAD - UnSupervised Anomaly Detection on multivariate time series Scripts and utility programs for implementing the USAD architecture. Implementation

116 Jan 04, 2023
3.8% and 18.3% on CIFAR-10 and CIFAR-100

Wide Residual Networks This code was used for experiments with Wide Residual Networks (BMVC 2016) http://arxiv.org/abs/1605.07146 by Sergey Zagoruyko

Sergey Zagoruyko 1.2k Dec 29, 2022
Scalable machine learning based time series forecasting

mlforecast Scalable machine learning based time series forecasting. Install PyPI pip install mlforecast Optional dependencies If you want more functio

Nixtla 145 Dec 24, 2022
Geometry-Aware Learning of Maps for Camera Localization (CVPR2018)

Geometry-Aware Learning of Maps for Camera Localization This is the PyTorch implementation of our CVPR 2018 paper "Geometry-Aware Learning of Maps for

NVIDIA Research Projects 321 Nov 26, 2022
[CVPR2021 Oral] End-to-End Video Instance Segmentation with Transformers

VisTR: End-to-End Video Instance Segmentation with Transformers This is the official implementation of the VisTR paper: Installation We provide instru

Yuqing Wang 687 Jan 07, 2023
Radar-to-Lidar: Heterogeneous Place Recognition via Joint Learning

radar-to-lidar-place-recognition This page is the coder of a pre-print, implemented by PyTorch. If you have some questions on this project, please fee

Huan Yin 37 Oct 09, 2022
Pgn2tex - Scripts to convert pgn files to latex document. Useful to build books or pdf from pgn studies

Pgn2Latex (WIP) A simple script to make pdf from pgn files and studies. It's sti

12 Jul 23, 2022
PyTorch implementation of Algorithm 1 of "On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models"

Code for On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models This repository will reproduce the main results from our pape

Mitch Hill 32 Nov 25, 2022
This project aim to create multi-label classification annotation tool to boost annotation speed and make it more easier.

This project aim to create multi-label classification annotation tool to boost annotation speed and make it more easier.

4 Aug 02, 2022