Implementation of H-Transformer-1D, Hierarchical Attention for Sequence Learning

Overview

H-Transformer-1D

Implementation of H-Transformer-1D, a Transformer that uses hierarchical attention for sequence learning at subquadratic cost.

For now, the H-Transformer will only act as a long-context encoder.

Install

$ pip install h-transformer-1d

Usage

import torch
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,          # number of tokens
    dim = 512,                 # dimension
    depth = 2,                 # depth
    max_seq_len = 8192,        # maximum sequence length
    heads = 8,                 # heads
    dim_head = 64,             # dimension per head
    block_size = 128           # block size
)

x = torch.randint(0, 256, (1, 8000))   # variable sequence length
mask = torch.ones((1, 8000)).bool()    # variable mask length

# network will automatically pad to power of 2, do hierarchical attention, etc

logits = model(x, mask = mask) # (1, 8000, 256)
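
The comment above says the input is padded to a power of two before hierarchical attention. A minimal sketch of that length computation, assuming "power of 2" means the smallest power of two that is at least the sequence length; `next_power_of_two` is a hypothetical helper for illustration, not part of the library's API:

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n (hypothetical helper, not library code)."""
    if n < 1:
        raise ValueError("sequence length must be positive")
    return 1 << (n - 1).bit_length()

seq_len = 8000                               # sequence length from the usage example
padded_len = next_power_of_two(seq_len)      # length after padding
num_pad_tokens = padded_len - seq_len        # how many padding positions are added
print(padded_len, num_pad_tokens)
```

With the values from the usage example, a length-8000 input would be padded to 8192, which equals max_seq_len.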

Citations

@misc{zhu2021htransformer1d,
    title   = {H-Transformer-1D: Fast One-Dimensional Hierarchical Attention for Sequences}, 
    author  = {Zhenhai Zhu and Radu Soricut},
    year    = {2021},
    eprint  = {2107.11906},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Masking not working in training, thanks

    Hi, I have tried to train the model on GPU with masking enabled. Line 94, t = torch.flip(t, dims = (2,)), raises an error: RuntimeError: "flip_cuda" not implemented for 'Bool', even after I tried moving the mask to the CPU.

    Any ideas to solve the problem? Thanks a lot.
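
One possible workaround, sketched under the assumption that the installed PyTorch build lacks a CUDA flip kernel for boolean tensors: cast the mask to an integer dtype, flip, and cast back. `flip_bool` is a hypothetical helper, not a function of this library.

```python
import torch

def flip_bool(t: torch.Tensor, dims) -> torch.Tensor:
    """Flip a boolean tensor by round-tripping through uint8 (hypothetical helper)."""
    return torch.flip(t.to(torch.uint8), dims=dims).bool()

mask = torch.tensor([[[True, False, False],
                      [True, True, False]]])

# Same call shape as the line quoted in the issue, but on a uint8 copy.
flipped = flip_bool(mask, dims=(2,))
```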

    opened by junyongyou 6
  • Application to sequence classification?

    Hi,

    Forgive the naive question, I am trying to make sense of this paper but it's tough going. If I understand correctly, this attention mechanism focuses mainly on nearby tokens and only attends to distant tokens via a hierarchical, low-rank approximation. In that case, can the usual sequence classification approach of having a global [CLS] token that can attend to all other tokens (and vice versa) still work? If not, how can this attention mechanism handle the text classification tasks in the long range arena benchmark?

    Cheers for whatever insights you can share, and thanks for the great work!

    opened by trpstra 4
  • eos token does not work in batch mode generation

    When generating sequences with the current code, it seems the eos_token only works when generating one sequence at a time: https://github.com/lucidrains/h-transformer-1d/blob/main/h_transformer_1d/autoregressive_wrapper.py#L59
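
A minimal sketch of one way to handle this in batch mode, assuming generation continues until every row has produced an eos token and that everything after a row's first eos is then overwritten with a pad value. The function names and the `pad_value` parameter are hypothetical, and plain Python lists stand in for tensors:

```python
def all_rows_finished(sequences, eos_token):
    """True once every sequence in the batch contains at least one eos token."""
    return all(eos_token in seq for seq in sequences)

def mask_after_eos(sequences, eos_token, pad_value=0):
    """Keep tokens up to and including the first eos; replace the rest with pad_value."""
    cleaned = []
    for seq in sequences:
        if eos_token in seq:
            cut = seq.index(eos_token) + 1
            seq = seq[:cut] + [pad_value] * (len(seq) - cut)
        cleaned.append(list(seq))
    return cleaned

batch = [[5, 2, 7, 7], [1, 3, 2, 9]]          # two generated rows, eos_token = 2
done = all_rows_finished(batch, eos_token=2)  # stopping condition for the whole batch
cleaned = mask_after_eos(batch, eos_token=2)  # tokens after each row's eos are padded
```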

    opened by tmphex 4
  • RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    lib/python3.7/site-packages/einops/_backends.py", line 52, in get_backend
        raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))

    RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    I understand why this gets raised. Could it be a pytorch version problem? Mine is 1.6.0
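
For context, one way to end up with a `torch.return_types.max` object is to call `Tensor.max` with a `dim` argument, which returns a named tuple of values and indices rather than a tensor. Whether that is the cause here is an assumption; the sketch below only shows the behaviour and how to extract the tensor:

```python
import torch

x = torch.tensor([[1.0, 3.0, 2.0],
                  [4.0, 0.0, 5.0]])

result = x.max(dim=-1)      # named tuple (values, indices), not a tensor
values = result.values      # the tensor that tensor-only code expects
indices = result.indices    # positions of the maxima along the last dimension
```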

    opened by wajihullahbaig 4
  • Algorithm Mismatch

    Paper Implementation

    In the implementation, we get blocked Q, K, V tensors by level with the code below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L164-L179

    The final result of the matrix-matrix product (Equation 29 or 69) is then returned by the for loop below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L234-L247

    What is the problem?

    However, according to the current code, it is not possible to include information about the level 0 white blocks in the figure below. (Equation 70 of the paper includes the corresponding attention matrix entries.)

    fig2

    I think you should also add an off-diagonal term of near-interaction (level 0) to match Equation 69!

    opened by jinmang2 3
  • I have some questions about implementation details

    Thanks for making your implementation public. I have three questions about your h-transformer 1d implementation.

    1. The number of levels M

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L105-L114

    In the paper, eq (32) gives a guide on how M is determined.

    img1

    In your implementation,

    • n is the sequence length, which is not padded
    • bsz is the block size (the same as N_r, the numerical rank of the off-diagonal blocks)
    • Because code line 111 already includes level 0, M is equal to int(log2(n // bsz)) - 1

    However, in Section 6.1, Constructing Hierarchical Attention, I found that the sequence length (L) must be a multiple of 2. In my opinion, the L in eq (31) is equal to 2^{M+1}. In the implementation, n is the unpadded sequence length, so one level of M is missing.

    Since x is the sequence padded by processing below, https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L83-L91

    I think the above implementation should be modified as below

    107    num_levels = int(log2(x.size(1) // bsz)) - 1 
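
The effect of using the unpadded length can be checked with a few lines of arithmetic. This sketch uses the README's example values (sequence length 8000, block size 128) and assumes padding goes to the next power of two; `num_levels` and `next_power_of_two` are illustrative helpers, not the library's functions:

```python
from math import log2

def num_levels(length: int, block_size: int) -> int:
    """Level count using the formula quoted in the issue: int(log2(length // bsz)) - 1."""
    return int(log2(length // block_size)) - 1

def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n (assumed padding target)."""
    return 1 << (n - 1).bit_length()

n, bsz = 8000, 128                                       # README example values
levels_unpadded = num_levels(n, bsz)                     # uses n as-is
levels_padded = num_levels(next_power_of_two(n), bsz)    # uses the padded length
print(levels_unpadded, levels_padded)
```

Under these assumptions the unpadded length yields 4 and the padded length yields 5, so one level is lost when n is used directly.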
    

    2. Super- and Sub-diagonal blocks of the coarsened matrix \tilde{A} as level-l

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L190-L198

    Ys contains the Y and A computed by calculate_Y_and_A. For example,

    # Ys
    [
        (batch_size*n_heads, N_b(2), N_r), (batch_size*n_heads, N_b(2)),  # level 2, (Y(2), tilde_A(2))
        (batch_size*n_heads, N_b(1), N_r), (batch_size*n_heads, N_b(1)),  # level 1, (Y(1), tilde_A(1))
        (batch_size*n_heads, N_b(0), N_r), (batch_size*n_heads, N_b(0)),  # level 0, (Y(0), tilde_A(0))
    ]
    

    In eq (29), Y is calculated as

    Y = AV = Y(0) + P(0)( Y(1) + P(1)Y(2) )

    However, in code line 190, Y is calculated using only level-0 and level-1 blocks, no matter how many M there are:

    Y = AV = Y(0) + P(0)Y(1)

    Does increasing the level cause performance degradation issues in implementation? I'm so curious!
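
To make the difference concrete, here is a minimal sketch of the nested sum from eq (29) as quoted above. It assumes each interpolation P(l) simply duplicates every coarse row into two finer rows, and it uses plain Python lists in place of tensors; `interpolate`, `add` and `combine_levels` are hypothetical names, not the repository's code:

```python
def interpolate(rows):
    """Assumed P(l): copy each coarse row to the two finer rows it covers."""
    return [list(row) for row in rows for _ in range(2)]

def add(a, b):
    """Element-wise sum of two equally shaped row lists."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def combine_levels(ys):
    """Evaluate Y(0) + P(0)(Y(1) + P(1)(Y(2) + ...)); ys is ordered fine to coarse."""
    total = ys[-1]
    for y in reversed(ys[:-1]):
        total = add(y, interpolate(total))
    return total

y0 = [[1.0], [2.0], [3.0], [4.0]]   # level 0: 4 rows
y1 = [[10.0], [20.0]]               # level 1: 2 rows
y2 = [[100.0]]                      # level 2: 1 row

all_levels = combine_levels([y0, y1, y2])   # includes the level-2 contribution
two_levels = combine_levels([y0, y1])       # stops at level 1
```

In this toy setup the two results differ by exactly the interpolated level-2 term, which is the contribution the question is about.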


    3. Comparison with Luna: Linear Unified Nested Attention

    h-transformer significantly exceeded the scores of BigBird and Luna on LRA. However, what I missed while reading the paper was a comparison of computation time with Luna and other sub-quadratic methods. Is this algorithm much faster than other sub-quadratic methods? And how does it compare to Luna?


    Thanks again for the implementation release! The idea of calculating off-diagonal with flip was amazing and I learned a lot. Thank you!! 😄

    opened by jinmang2 3
  • Add Norm Missing

    I am using the code now, and I wonder whether add & norm is implemented. I can only find the layer norm, but no add operation. Here is the code in h-transformer-1d.py, line 489 ... Is this a bug or something? Thanks @Lucidrains

    for ind in range(depth):
        attn = attn_class(dim, dim_head = dim_head, heads = heads, block_size = block_size, pos_emb = self.pos_emb, **attn_kwargs)
        ff = FeedForward(dim, mult = ff_mult)

        if shift_tokens:
            attn, ff = map(lambda t: PreShiftTokens(shift_token_ranges, t), (attn, ff))

        attn, ff = map(lambda t: PreNorm(dim, t), (attn, ff))
        layers.append(nn.ModuleList([attn, ff]))
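
For reference, the "add" in a pre-norm block is the residual sum x + f(norm(x)). The sketch below shows that computation in plain Python on a single feature vector; it is a generic illustration with hypothetical function names and makes no claim about where, or whether, this library applies the residual:

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one feature vector to zero mean and unit variance (no learned scale)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def pre_norm_residual(x, fn):
    """Return x + fn(layer_norm(x)): normalize, transform, then add the input back."""
    return [xi + yi for xi, yi in zip(x, fn(layer_norm(x)))]

x = [1.0, 2.0, 3.0]
unchanged = pre_norm_residual(x, lambda v: [0.0] * len(v))   # fn contributes nothing
shifted = pre_norm_residual(x, lambda v: [1.0] * len(v))     # fn adds a constant
```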
    
    opened by wwx13 2
  • Mask not working

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, 'sequence length must be less than the maximum sequence length'
        x = self.token_emb(x)
        x = self.layers(x)
        return self.to_logits(x)
    

    I think... Masking does not work ???

    opened by wwx13 2
  • One simple question

    Hi, Phil!

    One simple question, (my math is not good) https://github.com/lucidrains/h-transformer-1d/blob/7c11d036d53926495ec0917a34a1aad7261892b5/train.py#L65

    Why not randint(0, self.data.size(0)-self.seq_len+1)? The high bound is exclusive, after all.
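
Whether the `+1` is needed depends on how many tokens each sample slices, which is not shown in this thread. A minimal sketch of the bound, assuming the sampler draws a start index with an exclusive upper bound (as `torch.randint` does) and then reads `window` consecutive tokens; `exclusive_high` is a hypothetical helper:

```python
def exclusive_high(data_len: int, window: int) -> int:
    """Exclusive upper bound for a start index so data[start:start + window] is full length."""
    return data_len - window + 1

data_len, seq_len = 10, 4

# If a sample reads exactly seq_len tokens, the bound includes the +1.
high_inputs_only = exclusive_high(data_len, seq_len)

# If a sample reads seq_len + 1 tokens (inputs plus a shifted target), the +1 cancels.
high_with_target = exclusive_high(data_len, seq_len + 1)

# Every start below the bound yields a full window of seq_len + 1 tokens.
data = list(range(data_len))
window_lengths = [len(data[s:s + seq_len + 1]) for s in range(high_with_target)]
```

So `data_len - seq_len` is already the tight bound if each sample takes one extra token for the target, and `data_len - seq_len + 1` is the tight bound otherwise.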

    opened by CiaoHe 2
  • Mini-batching (b > 1) does not work with masking

    When x and mask have a batch size larger than 1, the following error arises:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    batch_size = 2
    x = torch.randint(0, 256, (batch_size, 8000))   # variable sequence length
    mask = torch.ones((batch_size, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # expected: (batch_size, 8000, 256)
    

    This gives the following error:

    ~/git/h-transformer-1d/h_transformer_1d/h_transformer_1d.py in masked_aggregate(tensor, mask, dim, average)
         19     diff_len = len(tensor.shape) - len(mask.shape)
         20     mask = mask[(..., *((None,) * diff_len))]
    ---> 21     tensor = tensor.masked_fill(~mask, 0.)
         22 
         23     total_el = mask.sum(dim = dim)
    
    RuntimeError: The size of tensor a (2) must match the size of tensor b (16) at non-singleton dimension 0
    

    It seems the tensor has size heads * batch in dimension 0, rather than batch as the mask has.
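
A minimal sketch of one possible fix, assuming the attention folds heads into the batch dimension in batch-major order (all heads of sample 0, then all heads of sample 1, and so on). That ordering is an assumption; the error message only shows that dimension 0 is 16 = 2 × 8. Under that assumption the mask can be repeated once per head before it is applied:

```python
import torch

batch, heads, seq_len, dim_head = 2, 8, 6, 4

tensor = torch.randn(batch * heads, seq_len, dim_head)   # heads folded into batch
mask = torch.ones(batch, seq_len, dtype=torch.bool)
mask[1, 4:] = False                                       # pad the tail of sample 1

# Repeat each mask row `heads` times so dimension 0 matches the folded tensor.
expanded_mask = mask.repeat_interleave(heads, dim=0)      # (batch * heads, seq_len)

# Broadcast the mask over the feature dimension and zero out masked positions.
masked = tensor.masked_fill(~expanded_mask[..., None], 0.)
```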

    opened by jaak-s 2
  • Example in README does not work

    Executing the example:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    x = torch.randint(0, 256, (1, 8000))   # variable sequence length
    mask = torch.ones((1, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # (1, 8000, 256)
    

    Gives the following error:

    ~/miniconda3/lib/python3.7/site-packages/rotary_embedding_torch/rotary_embedding_torch.py in apply_rotary_emb(freqs, t, start_index)
         43     assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
         44     t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    ---> 45     t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
         46     return torch.cat((t_left, t, t_right), dim = -1)
         47 
    
    RuntimeError: The size of tensor a (8192) must match the size of tensor b (8000) at non-singleton dimension 1
    
    opened by jaak-s 2
  • Fix indexing

    I am fixing a few apparent bugs in the code. The upshot is that the attention now supports a block size equal to the input length rounded up to the next power of two, and for this block size it becomes exact. This allows one to look at the systematic error in the output as a function of decreased block size (and memory usage).

    I've found this module to reduce memory consumption by a factor of two, but the approximation quickly becomes too inaccurate with decreasing block size to use it as a drop-in replacement for an existing (full) attention layer.

    This repository shows how to compute the full attention with linear memory complexity: https://github.com/CHARM-Tx/linear_mem_attention_pytorch
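
The general idea behind computing exact softmax attention without materialising all scores at once can be sketched as follows. This is a plain-Python illustration of a running (online) softmax over key/value chunks for a single query; it is not taken from the linked repository, and the function names are hypothetical:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_row_exact(q, keys, values):
    """Reference: softmax(q . K^T) V for one query, using all scores at once."""
    scores = [dot(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]

def attention_row_chunked(q, keys, values, chunk_size):
    """Same result, but only chunk_size scores are held at any time."""
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), chunk_size):
        scores = [dot(q, k) for k in keys[start:start + chunk_size]]
        new_max = max(running_max, max(scores))
        scale = math.exp(running_max - new_max)   # rescale earlier partial sums
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, values[start:start + chunk_size]):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

keys = values = [[0.1, 0.2], [-0.5, 0.7], [-0.5, -0.75], [0.123, 0.456]]
query = keys[0]
exact = attention_row_exact(query, keys, values)
chunked = attention_row_chunked(query, keys, values, chunk_size=2)
```

The chunked version trades a little extra arithmetic for a working set that depends on the chunk size rather than on the full sequence length.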

    opened by jglaser 0
  • Approximated values are off

    I wrote a simple test to check the output of the hierarchical transformer self attention against the BERT self attention from huggingface transformers.

    import torch
    import torch.nn as nn
    import math
    
    from h_transformer_1d.h_transformer_1d import HAttention1D
    
    def transpose_for_scores(x, num_attention_heads, attention_head_size):
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    
    def bert_self_attention(query, key, value, attention_mask=None, num_attention_heads=1):
            dim_head = query.size()[-1] // num_attention_heads
            all_head_size = dim_head*num_attention_heads
    
            query_layer = transpose_for_scores(query, num_attention_heads, dim_head)
            key_layer = transpose_for_scores(key, num_attention_heads, dim_head)
            value_layer = transpose_for_scores(value, num_attention_heads, dim_head)
    
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    
            attention_scores = attention_scores / math.sqrt(dim_head)
    
            if attention_mask is not None:
                # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
                attention_scores = attention_scores + attention_mask
    
            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
    
            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            #attention_probs = self.dropout(attention_probs)
    
            context_layer = torch.matmul(attention_probs, value_layer)
    
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
    
            return context_layer, attention_probs
    
    if __name__ == "__main__":
        query = torch.tensor([[[0.1,0.2],[-0.5,0.7],[-0.5,-0.75],[.123,.456]]])
    #    query = torch.tensor([[[0.1,0.2],[-0.5,0.7]]])
        key = value = query
    
        n_heads = 1
        attn, probs = bert_self_attention(query, key, value, num_attention_heads=n_heads)
        print('bert_self_attention out: ', attn)
    
        block_size = 1
        for _ in range(0,2):
            dim_head = query.size()[-1]//n_heads
            h_attn = HAttention1D(
                dim=query.size()[-1],
                heads=n_heads,
                dim_head=dim_head,
                block_size=block_size
            )
    
            h_attn.to_qkv = torch.nn.Identity()
            h_attn.to_out = torch.nn.Identity()
    
            qkv = torch.stack([query, key, value], dim=2)
            qkv = torch.flatten(qkv, start_dim=2)
    
            attn_scores = h_attn(qkv)
            print('hattention_1d: (block_size = {})'.format(block_size), attn_scores)
    
            block_size *= 2
    

    This is the output I get

    bert_self_attention:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2000,  0.4500],
             [-0.2000,  0.4500],
             [-0.1885, -0.1470],
             [-0.1885, -0.1470]]])
    

    before it errors out with

    assert num_levels >= 0, 'number of levels must be at least greater than 0'
    

    Some of the values are off in absolute magnitude by more than a factor of two.

    Looking at the code, this line seems problematic: https://github.com/lucidrains/h-transformer-1d/blob/8afd75cc6bc41754620bb6ab3737176cb69bdf93/h_transformer_1d/h_transformer_1d.py#L172

    I believe it should read

    num_levels = int(log2(pad_to_len // bsz)) - 1
    

    If I make that change, the approximated attention output is much closer to the exact one:

    bert_self_attention out:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020]]])
    hattention_1d: (block_size = 2) tensor([[[-0.1808,  0.1972],
             [-0.1980,  0.2314],
             [-0.2438,  0.0910],
             [-0.1719,  0.2413]]])
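
To put numbers on "much closer", the reported outputs can be compared directly. The sketch below copies the values printed above and computes the largest absolute deviation from the exact attention output for each run; the helper name is hypothetical:

```python
exact = [[-0.1807, 0.1959], [-0.2096, 0.2772], [-0.2656, -0.0568], [-0.1725, 0.2442]]

# Outputs reported in the issue, copied verbatim.
before_fix_bs1 = [[-0.2000, 0.4500], [-0.2000, 0.4500], [-0.1885, -0.1470], [-0.1885, -0.1470]]
after_fix_bs1 = [[-0.2590, 0.2020]] * 4
after_fix_bs2 = [[-0.1808, 0.1972], [-0.1980, 0.2314], [-0.2438, 0.0910], [-0.1719, 0.2413]]

def max_abs_error(approx, reference):
    """Largest element-wise absolute difference between two row lists."""
    return max(abs(a - r)
               for row_a, row_r in zip(approx, reference)
               for a, r in zip(row_a, row_r))

errors = {
    "before fix, block_size=1": max_abs_error(before_fix_bs1, exact),
    "after fix, block_size=1": max_abs_error(after_fix_bs1, exact),
    "after fix, block_size=2": max_abs_error(after_fix_bs2, exact),
}
for name, err in errors.items():
    print(f"{name}: {err:.4f}")
```

With the reported values this gives roughly 0.39, 0.26 and 0.15 respectively, so the change helps, although the block_size = 2 output still deviates noticeably in the third row.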
    
    opened by jglaser 1
Owner
Phil Wang
Working with Attention. It's all we need