Implementation of Bottleneck Transformer in Pytorch

Overview

Bottleneck Transformer - Pytorch

Implementation of Bottleneck Transformer, a SotA visual recognition model that combines convolution and attention and outperforms EfficientNet and DeiT on the performance-compute trade-off, in PyTorch.

Install

$ pip install bottleneck-transformer-pytorch

Usage

import torch
from torch import nn
from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,              # channels in
    fmap_size = 64,         # feature map size
    dim_out = 2048,         # channels out
    proj_factor = 4,        # projection factor
    downsample = True,      # downsample on first layer or not
    heads = 4,              # number of heads
    dim_head = 128,         # dimension per head, defaults to 128
    rel_pos_emb = False,    # use relative positional embedding - uses absolute if False
    activation = nn.ReLU()  # activation throughout the network
)

fmap = torch.randn(2, 256, 64, 64) # feature map from previous resnet block(s)

layer(fmap) # (2, 2048, 32, 32)

BotNet

With some simple model surgery on a ResNet, you can have the 'BotNet' (what a weird name) ready for training.

import torch
from torch import nn
from torchvision.models import resnet50

from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,
    fmap_size = 56,        # set specifically for imagenet's 224 x 224
    dim_out = 2048,
    proj_factor = 4,
    downsample = True,
    heads = 4,
    dim_head = 128,
    rel_pos_emb = True,
    activation = nn.ReLU()
)

resnet = resnet50()

# model surgery

backbone = list(resnet.children())

model = nn.Sequential(
    *backbone[:5],
    layer,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(1),
    nn.Linear(2048, 1000)
)

# use the 'BotNet'

img = torch.randn(2, 3, 224, 224)
preds = model(img) # (2, 1000)
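
The value `fmap_size = 56` follows from the layers kept by `backbone[:5]`: the stem convolution and the max-pool each halve the resolution, and `layer1` leaves it unchanged. Here is a minimal sketch of that arithmetic, assuming torchvision's standard ResNet-50 stem (a 7x7 convolution with stride 2 and padding 3, then a 3x3 max-pool with stride 2 and padding 1). `conv_out` and `resnet_stem_fmap_size` are hypothetical helper names, not part of this package.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv / pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_stem_fmap_size(image_size):
    """Spatial size after conv1 + maxpool of a standard ResNet (layer1 keeps it)."""
    size = conv_out(image_size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)        # maxpool
    return size


print(resnet_stem_fmap_size(224))  # 56, the fmap_size used for 224 x 224 inputs
print(resnet_stem_fmap_size(256))  # 64
```

For a different input resolution, the same calculation gives the `fmap_size` to pass to `BottleStack`.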

Citations

@misc{srinivas2021bottleneck,
    title   = {Bottleneck Transformers for Visual Recognition}, 
    author  = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
    year    = {2021},
    eprint  = {2101.11605},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • How should I modify the code if the input feature map has unequal width and height?

    Assume that the width and height of the feature map are 10 and 8, respectively. Could you please help me check whether my modification of the class RelPosEmb is correct?

    class RelPosEmb(nn.Module):
        def __init__(
            self,
            fmap_size,
            dim_head
        ):
            super().__init__()
            scale = dim_head ** -0.5
            self.fmap_size = fmap_size
            self.scale = scale
            # self.rel_height = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_height = nn.Parameter(torch.randn(8 * 2 - 1, dim_head) * scale)
            # self.rel_width = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_width = nn.Parameter(torch.randn(10 * 2 - 1, dim_head) * scale)

        def forward(self, q):
            q = rearrange(q, 'b h (x y) d -> b h x y d', x = 8)
            rel_logits_w = relative_logits_1d(q, self.rel_width)
            rel_logits_w = rearrange(rel_logits_w, 'b h x i y j-> b h (x y) (i j)')

            q = rearrange(q, 'b h x y d -> b h y x d')
            rel_logits_h = relative_logits_1d(q, self.rel_height)
            rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
            return rel_logits_w + rel_logits_h
    
    opened by ShuweiShao 4
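
One way to sidestep the square-map assumption is to give the module an explicit height and width and look up the relative embeddings by index. The following is a sketch under that assumption, not the repository's implementation. `RelPosEmb2D` and its `height` / `width` arguments are hypothetical names, and plain indexing stands in for the repository's `relative_logits_1d` helper.

```python
import torch
from torch import nn


class RelPosEmb2D(nn.Module):
    """Relative position logits for a (height x width) feature map (illustrative)."""

    def __init__(self, height, width, dim_head):
        super().__init__()
        scale = dim_head ** -0.5
        self.height, self.width = height, width
        self.rel_height = nn.Parameter(torch.randn(2 * height - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(2 * width - 1, dim_head) * scale)
        # idx_h[x, i] = i - x + (height - 1): table row for query row x, key row i
        rows = torch.arange(height)
        cols = torch.arange(width)
        self.register_buffer('idx_h', rows[None, :] - rows[:, None] + height - 1)
        self.register_buffer('idx_w', cols[None, :] - cols[:, None] + width - 1)

    def forward(self, q):
        # q: (batch, heads, height * width, dim_head), positions in row-major order
        b, heads, _, d = q.shape
        H, W = self.height, self.width
        q = q.reshape(b, heads, H, W, d)
        emb_h = self.rel_height[self.idx_h]  # (H, H, d)
        emb_w = self.rel_width[self.idx_w]   # (W, W, d)
        logits_h = torch.einsum('bhxyd,xid->bhxyi', q, emb_h)  # (b, heads, H, W, H)
        logits_w = torch.einsum('bhxyd,yjd->bhxyj', q, emb_w)  # (b, heads, H, W, W)
        # combine to (b, heads, H, W, H, W), then flatten query and key positions
        logits = logits_h[..., :, None] + logits_w[..., None, :]
        return logits.reshape(b, heads, H * W, H * W)


emb = RelPosEmb2D(height=8, width=10, dim_head=128)
q = torch.randn(2, 4, 8 * 10, 128)
print(emb(q).shape)  # torch.Size([2, 4, 80, 80])
```

Materialising the `(H, H, d)` and `(W, W, d)` lookup tables trades some memory for readability; the padding-and-reshape trick used in the repository avoids that cost.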
  • Feature map size

    Hi. In my case, the input images all have different sizes, so the feature map size keeps changing. How should the fmap_size parameter of BottleStack be set in this case? Is it possible to train with a variable feature map size?

    opened by benlee73 3
  • A little bug.

    https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/b789de6db39f33854862fbc9bcee27c697cf003c/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L16

    It is necessary to specify the device here.

    flat_pad = torch.zeros((b, h, l - 1), device = device, dtype = dtype) 
    
    opened by lartpang 1
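
The effect of the proposed fix can be seen in isolation. Here is a minimal sketch, assuming a helper named `pad_flat` (hypothetical, written only for this illustration) that appends `l - 1` zeros the way the linked line does. Creating the padding with the input's `device` and `dtype` keeps `torch.cat` from being handed a CPU tensor alongside a GPU one.

```python
import torch


def pad_flat(flat_x, l):
    """Append l - 1 zeros along the last axis, on the same device/dtype as flat_x."""
    b, h, _ = flat_x.shape
    flat_pad = torch.zeros((b, h, l - 1), device=flat_x.device, dtype=flat_x.dtype)
    return torch.cat((flat_x, flat_pad), dim=2)


x = torch.randn(2, 4, 12, dtype=torch.float64)
out = pad_flat(x, l=3)
print(out.shape, out.dtype, out.device)
```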
  • fix inplace operations

    Latest versions of PyTorch throw runtime errors for inplace operations like *= and += on tensors that require gradients. This pull request fixes the issue by replacing them with binary versions.

    opened by AminRezaei0x443 0
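
A small illustration of the kind of change described, not the pull request's actual diff. It uses the simplest case that reproduces such an error: an in-place `*=` on a leaf tensor that requires gradients raises a `RuntimeError`, while the out-of-place form builds a new tensor and backpropagates normally.

```python
import torch

w = torch.ones(3, requires_grad=True)

try:
    w *= 2.0            # in-place update of a leaf tensor that requires grad
    inplace_failed = False
except RuntimeError:
    inplace_failed = True

y = w * 2.0             # out-of-place version: creates a new tensor
y.sum().backward()

print(inplace_failed)   # True
print(w.grad)           # tensor([2., 2., 2.])
```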
  • Could you explain the implementation of relative position embedding?

    reference https://github.com/tensorflow/tensor2tensor/blob/5f9dd2db6d7797162e53adf152310ed13e9fc711/tensor2tensor/layers/common_attention.py

    def _generate_relative_positions_matrix(length_q, length_k,
                                            max_relative_position,
                                            cache=False):
      """Generates matrix of relative positions between inputs."""
      if not cache:
        if length_q == length_k:
          range_vec_q = range_vec_k = tf.range(length_q)
        else:
          range_vec_k = tf.range(length_k)
          range_vec_q = range_vec_k[-length_q:]
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
      else:
        distance_mat = tf.expand_dims(tf.range(-length_k+1, 1, 1), 0)
      distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                              max_relative_position)
      # Shift values to be >= 0. Each integer still uniquely identifies a relative
      # position difference.
      final_mat = distance_mat_clipped + max_relative_position
      return final_mat
    
    
    def _generate_relative_positions_embeddings(length_q, length_k, depth,
                                                max_relative_position, name,
                                                cache=False):
      """Generates tensor of size [1 if cache else length_q, length_k, depth]."""
      with tf.variable_scope(name):
        relative_positions_matrix = _generate_relative_positions_matrix(
            length_q, length_k, max_relative_position, cache=cache)
        vocab_size = max_relative_position * 2 + 1
        # Generates embedding for each relative position of dimension depth.
        embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
        embeddings = tf.gather(embeddings_table, relative_positions_matrix)
        return embeddings
    
    
    def _relative_attention_inner(x, y, z, transpose):
      """Relative position-aware dot-product attention inner calculation.
      This batches matrix multiply calculations to avoid unnecessary broadcasting.
      Args:
        x: Tensor with shape [batch_size, heads, length or 1, length or depth].
        y: Tensor with shape [batch_size, heads, length or 1, depth].
        z: Tensor with shape [length or 1, length, depth].
        transpose: Whether to transpose inner matrices of y and z. Should be true if
            last dimension of x is depth, not length.
      Returns:
        A Tensor with shape [batch_size, heads, length, length or depth].
      """
      batch_size = tf.shape(x)[0]
      heads = x.get_shape().as_list()[1]
      length = tf.shape(x)[2]
    
      # xy_matmul is [batch_size, heads, length or 1, length or depth]
      xy_matmul = tf.matmul(x, y, transpose_b=transpose)
      # x_t is [length or 1, batch_size, heads, length or depth]
      x_t = tf.transpose(x, [2, 0, 1, 3])
      # x_t_r is [length or 1, batch_size * heads, length or depth]
      x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
      # x_tz_matmul is [length or 1, batch_size * heads, length or depth]
      x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
      # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
      x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
      # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
      x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
      return xy_matmul + x_tz_matmul_r_t
    
    
    def dot_product_attention_relative(q,
                                       k,
                                       v,
                                       bias,
                                       max_relative_position,
                                       dropout_rate=0.0,
                                       image_shapes=None,
                                       save_weights_to=None,
                                       name=None,
                                       make_image_summary=True,
                                       cache=False,
                                       allow_memory=False,
                                       hard_attention_k=0,
                                       gumbel_noise_weight=0.0):
      """Calculate relative position-aware dot-product self-attention.
      The attention calculation is augmented with learned representations for the
      relative position between each element in q and each element in k and v.
      Args:
        q: a Tensor with shape [batch, heads, length, depth].
        k: a Tensor with shape [batch, heads, length, depth].
        v: a Tensor with shape [batch, heads, length, depth].
        bias: bias Tensor.
        max_relative_position: an integer specifying the maximum distance between
            inputs that unique position embeddings should be learned for.
        dropout_rate: a floating point number.
        image_shapes: optional tuple of integer scalars.
        save_weights_to: an optional dictionary to capture attention weights
          for visualization; the weights tensor will be appended there under
          a string key created from the variable scope (including name).
        name: an optional string.
        make_image_summary: Whether to make an attention image summary.
        cache: whether use cache mode
        allow_memory: whether to assume that recurrent memory is in use. If True,
          the length dimension of k/v/bias may be longer than the queries, and it is
          assumed that the extra memory entries precede the non-memory entries.
        hard_attention_k: integer, if > 0 triggers hard attention (picking top-k)
        gumbel_noise_weight: if > 0, apply Gumbel noise with weight
          `gumbel_noise_weight` before picking top-k. This is a no op if
          hard_attention_k <= 0.
      Returns:
        A Tensor.
      Raises:
        ValueError: if max_relative_position is not > 0.
      """
      if not max_relative_position:
        raise ValueError("Max relative position (%s) should be > 0 when using "
                         "relative self attention." % (max_relative_position))
      with tf.variable_scope(
          name, default_name="dot_product_attention_relative",
          values=[q, k, v]) as scope:
    
        # This calculation only works for self attention.
        # q, k and v must therefore have the same shape, unless memory is enabled.
        if not cache and not allow_memory:
          q.get_shape().assert_is_compatible_with(k.get_shape())
          q.get_shape().assert_is_compatible_with(v.get_shape())
    
        # Use separate embeddings suitable for keys and values.
        depth = k.get_shape().as_list()[3]
        length_k = common_layers.shape_list(k)[2]
        length_q = common_layers.shape_list(q)[2] if allow_memory else length_k
        relations_keys = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_keys", cache=cache)
        relations_values = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_values", cache=cache)
    
        # Compute self attention considering the relative position embeddings.
        logits = _relative_attention_inner(q, k, relations_keys, True)
        if bias is not None:
          logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if hard_attention_k > 0:
          weights = harden_attention_weights(weights, hard_attention_k,
                                             gumbel_noise_weight)
        if save_weights_to is not None:
          save_weights_to[scope.name] = weights
          save_weights_to[scope.name + "/logits"] = logits
        weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
        if (not tf.get_variable_scope().reuse and
            common_layers.should_generate_summaries() and
            make_image_summary):
          attention_image_summary(weights, image_shapes)
        return _relative_attention_inner(weights, v, relations_values, False)
    

    This corresponds to the formula clip(x, k) = max(-k, min(k, x)).

    But in your code there is a randn with grad. I don't understand it; could you explain?

    opened by AncientRemember 0
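
The relationship between the two approaches can be sketched in plain Python. The quoted TensorFlow code clips relative distances to `max_relative_position` and looks them up in a trainable table with `2 * max_relative_position + 1` rows. When the table has `2 * length - 1` rows, as with the `fmap_size * 2 - 1` parameters quoted in the first issue, every distance already fits and clipping never triggers. In PyTorch, wrapping `torch.randn(...)` in `nn.Parameter` only sets the initial values of such a table; the values are then learned. The function names below are illustrative only.

```python
def clip(x, k):
    """clip(x, k) = max(-k, min(k, x))"""
    return max(-k, min(k, x))


def relative_position_indices(length, max_relative_position):
    """Row index into the embedding table for every (query i, key j) pair."""
    k = max_relative_position
    return [[clip(j - i, k) + k for j in range(length)] for i in range(length)]


# Clipped: distances beyond 2 share a row, table needs 2 * 2 + 1 = 5 rows
print(relative_position_indices(4, 2))
# [[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]]

# k = length - 1: nothing is clipped, table needs 2 * 4 - 1 = 7 rows
print(relative_position_indices(4, 3))
# [[3, 4, 5, 6], [2, 3, 4, 5], [1, 2, 3, 4], [0, 1, 2, 3]]
```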
  • Is it possible to modify this code to support 3D images as well?

    Thank you for your great work!

    I was wondering whether this code could be modified to support 3D images as well (i.e. adding a z-axis).

    I can't imagine how to change the dimensions of vectors in the "content-position" part. E.g. Hx1xd and 1xWxd -> Hx1x1xd and 1x1xZxd and 1xWx1xd ?

    Thank you for your answer!

    opened by kyuchoi 0
  • Hello, I get the following error during training. How can I solve it?

    einops.EinopsError: Error while processing rearrange-reduction pattern "b h (x y) d -> b h x y d". Input tensor shape: torch.Size([2, 4, 900, 128]). Additional info: {'x': 26, 'y': 26}. Shape mismatch, 900 != 676

    opened by glt999 1
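
The numbers in the message already point at the likely cause: the tensor holds 900 positions per head, while `x = 26, y = 26` expects 26 x 26 = 676. Here is a quick check, assuming a square feature map. `infer_fmap_size` is a hypothetical helper, not part of the package.

```python
import math


def infer_fmap_size(num_positions):
    """Side length of a square feature map with the given number of positions."""
    side = math.isqrt(num_positions)
    if side * side != num_positions:
        raise ValueError(f'{num_positions} positions do not form a square feature map')
    return side


print(infer_fmap_size(900))  # 30: the feature map actually reaching the layer
print(infer_fmap_size(676))  # 26: the fmap_size the layer was built with
```

Under that assumption, the layer would need `fmap_size = 30`, or the input would have to be resized so that the incoming feature map is 26 x 26.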
  • The size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3

    Hello, I want to ask a question. The input feature map is 228 * 304, but I get this error: the size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3.

    opened by shezhi 1
  • The 2D relative position embedding is not inductive; maybe the FLOATER embedding is better

    opened by AncientRemember 2
Owner
Phil Wang
Working with Attention. It's all we need.